In [3]:
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
# Per-dataset time span of the yearly graph files: (name, number of years, first year).
# The data-loading cell uses these to build the list of years to read.
_DATASET_SPANS = (
    ('CESM2-omip1', 62, 1948),
    ('CESM2-omip2', 61, 1958),
    ('GFDL_ESM4', 95, 1920),
)
dataset_param = {
    name: {'num_years': n_years, 'start_year': first_year}
    for name, n_years, first_year in _DATASET_SPANS
}

In [None]:

## select the dataset
# dataset = 'CESM2-omip1'
# dataset = 'CESM2-omip2'
# dataset = 'GFDL_ESM4'
dataset = 'CESM2-omip1'
num_years = dataset_param[dataset]['num_years']
start_year = dataset_param[dataset]['start_year']
year_list = list(range(start_year, start_year + num_years))
# Number of depth levels per node in `data.y`; also the depth axis of the
# (year, depth, lat, lon) grid that `sampled_data` is reshaped to below.
num_depths = 33

## find the sampled data place
# Read every yearly graph file and collect node coordinates and labels.
# Chunks are gathered in lists and concatenated once at the end.
# Concatenating inside the loop would re-copy the growing arrays every year.
pwd = f'data/{dataset}/graph/'
geo_chunks = []
label_chunks = []
for year in tqdm(year_list):
    file_path = pwd + str(year) + '.pt'
    data = torch.load(file_path)
    # The first three x_geo columns are treated as Cartesian coordinates on the
    # unit sphere: z -> latitude, atan2(y, x) -> longitude, both in degrees.
    xyz = data.x_geo[:, :3]
    latitude = torch.rad2deg(torch.arcsin(xyz[:, -1]))
    longitude = torch.rad2deg(torch.arctan2(xyz[:, 1], xyz[:, 0]))
    label = data.y
    # NOTE(review): assumes x_geo[:, 4] is the year offset normalised by
    # num_years relative to start_year — confirm against the graph builder.
    years = data.x_geo[:, 4] * num_years + start_year
    geo_information = torch.stack((latitude, longitude, years), dim=1)  # (num_nodes, 3)
    geo_chunks.append(geo_information.cpu().numpy())
    label_chunks.append(label.cpu().numpy())
all_geoinformation = np.concatenate(geo_chunks)
all_label = np.concatenate(label_chunks)

# Expand every node to one row per depth level. Node rows are repeated
# num_depths times, and the depth index 1..num_depths cycles within each node.
# This matches the row-major flatten of all_label (one label per node and depth).
depths = np.tile(np.arange(1, num_depths + 1), len(all_geoinformation))
all_geoinformation = np.repeat(all_geoinformation, num_depths, axis=0)
assert all_label.size == len(all_geoinformation), \
    'expected exactly one label per (node, depth) pair'
df = pd.DataFrame(all_geoinformation, columns=['latitude', 'longitude', 'year'])
# Plain array assignment: a length mismatch raises instead of being silently
# NaN-filled or truncated by Series index alignment.
df['depth'] = depths
df['oxygen'] = all_label.flatten()
# Round the floating-point merge keys so they line up with the grid coordinates
# stored in all_df (float keys only match after identical rounding).
df['latitude'] = df['latitude'].round(1)
df['longitude'] = df['longitude'].round(1)
df['year'] = df['year'].round(0)

# Scatter the sampled oxygen values onto the full grid.
# A left merge keeps all_df's row order.
# Grid cells with no sampled node get NaN.
# NOTE(review): assumes all_df rows are ordered (year, depth, lat, lon) on a
# 180 x 360 grid, as required by the reshape — confirm.
all_df = pd.read_csv(f'infer_result/all_df_empty_{dataset}.csv')
sampled_data = pd.merge(all_df, df, on=['year', 'depth', 'latitude', 'longitude'], how='left')
sampled_data = sampled_data['oxygen'].values
sampled_data = sampled_data.reshape((num_years, num_depths, 180, 360))
sampled_data = torch.tensor(sampled_data)
sampled_data = torch.tensor(sampled_data)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 62/62 [00:13<00:00,  4.73it/s]


In [None]:

# Model predictions for the selected dataset. The path is relative to the
# notebook's working directory, like the other inputs, and follows `dataset`
# instead of a hardcoded name. Change the model suffix as needed.
predict_result_path = f'infer_result/inference_by_{dataset}_random_MLP_model_1.npy'  ## change the path
ground_truth_path = f'data/{dataset}/ground_truth/ground_truth.pt'
predict_result = torch.from_numpy(np.load(predict_result_path))
# NOTE(review): the ground truth is read with np.load despite its .pt suffix,
# so the file is presumably a NumPy-serialised array rather than a torch.save
# archive — confirm.
ground_truth = torch.from_numpy(np.load(ground_truth_path))
# Swap the two halves of the last axis (presumably longitude) so the ground
# truth is aligned element-wise with predict_result and sampled_data.
# Rolling by -width//2 is identical to cat([x[..., half:], x[..., :half]]).
half_width = ground_truth.shape[3] // 2
ground_truth = torch.roll(ground_truth, shifts=-half_width, dims=3)


In [8]:
## select the valid data
# Score only the grid cells that were absent from the sampled graph data
# (sampled_data is NaN there) and where both prediction and ground truth exist.
is_unsampled = torch.isnan(sampled_data)
has_prediction = ~torch.isnan(predict_result)
has_truth = ~torch.isnan(ground_truth)
select_indices = is_unsampled & has_prediction & has_truth
predict_result_valid = predict_result[select_indices]
ground_truth_valid = ground_truth[select_indices]

# Regression metrics over the selected cells.
delta = predict_result_valid - ground_truth_valid
squared_error = delta ** 2
mae = delta.abs().mean()
mse = squared_error.mean()
rmse = mse.sqrt()
total_variance = ((ground_truth_valid - ground_truth_valid.mean()) ** 2).sum()
r2 = 1 - squared_error.sum() / total_variance

# MAPE is restricted to ground-truth values above a small threshold so that
# near-zero values do not blow up the relative error.
mape_threshold = 5/1000
mape_indices = ground_truth_valid > mape_threshold
mape = (delta[mape_indices].abs() / ground_truth_valid[mape_indices]).mean()

for metric_name, metric_value in (('MAE', mae), ('MSE', mse), ('RMSE', rmse), ('R2', r2), ('MAPE', mape)):
    print(f'{metric_name}:', metric_value.item())

MAE: 0.06316649270201627
MSE: 0.006535294994396422
RMSE: 0.08084117140663179
R2: 0.22696399334128536
MAPE: 1.206215771634637
