# Import packages

In [None]:
import os
from climsim_utils.data_utils import *
import torch.nn as nn
import torch
import torch.optim as optim

# Instantiate data class

In [None]:
# Grid description and per-variable normalization statistics (input mean/max/min,
# output scale); these are passed to data_utils, which is built with normalize=True.
grid_path = '/work/FAC/FGSE/IDYST/tbeucler/ai4pex/Physics_no_offline/ClimSim_offline/modules/climsim3/grid_info/ClimSim_low-res_grid-info.nc'
norm_path = '/work/FAC/FGSE/IDYST/tbeucler/ai4pex/Physics_no_offline/ClimSim_offline/modules/climsim3/preprocessing/normalizations/'

grid_info = xr.open_dataset(grid_path)
input_mean = xr.open_dataset(os.path.join(norm_path, 'inputs', 'input_mean_v2_rh_mc_pervar.nc'))
input_max = xr.open_dataset(os.path.join(norm_path, 'inputs', 'input_max_v2_rh_mc_pervar.nc'))
input_min = xr.open_dataset(os.path.join(norm_path, 'inputs', 'input_min_v2_rh_mc_pervar.nc'))
output_scale = xr.open_dataset(os.path.join(norm_path, 'outputs', 'output_scale_std_lowerthred_v2_rh_mc.nc'))

# NOTE(review): qinput_log is set explicitly; an earlier version of this call used
# qinput_log=True. Confirm False matches how the normalization files were computed.
data = data_utils(grid_info = grid_info,
                  input_mean = input_mean,
                  input_max = input_max,
                  input_min = input_min,
                  output_scale = output_scale,
                  qinput_log = False,
                  input_abbrev = 'mlexpand',
                  output_abbrev = 'mlo',
                  normalize=True,
                  save_zarr=False,
                  save_h5=False,
                  save_npy=True,
                  cpuonly=False
                  )

# Select the v2 "rh" variable subset (not V1).
# TODO: confirm this is the intended variable set for these models.
data.set_to_v2_rh_vars()

# Load validation data

The .npy files shown below were created using the `create_npy_data_splits.ipynb` notebook in the preprocessing folder.

In [None]:
# Directory holding the preprocessed .npy splits (created by create_npy_data_splits.ipynb).
data_path = '/work/FAC/FGSE/IDYST/tbeucler/ai4pex/container/climsim-container/storage/climsim_highres_v2_rh_mc'

# data_path has no trailing '/', so plain string concatenation dropped the path
# separator; os.path.join inserts it. Both validation files live in 'val_set'.
val_input_path = os.path.join(data_path, 'val_set', 'val_input.npy')
val_target_path = os.path.join(data_path, 'val_set', 'val_target.npy')

data.input_val = data.load_npy_file(val_input_path)
data.target_val = data.load_npy_file(val_target_path)

# Load models

In [None]:
# Build the pressure grid for the validation split, then start with empty
# model-name and prediction lists that each model cell appends to.
data.set_pressure_grid(data_split='val')
data.model_names, preds = [], []

### Train and load the Multiple Linear Regression model

In [None]:
from MLR.mlr import MultiLinearRegression, train_mlr

# NOTE(review): data.input_train / data.target_train are not loaded by any cell in
# this notebook; they must be set on `data` before this cell runs on a fresh kernel.
train_inputs = torch.tensor(data.input_train, dtype=torch.float32)
train_targets = torch.tensor(data.target_train, dtype=torch.float32)

# Fit inputs -> targets (the targets were previously replaced by a second copy of the inputs).
mlr_trained = train_mlr(train_inputs, train_targets, loss_name = 'mse')
# NOTE(review): the reload below reads 'fs_mlr.pth'; confirm train_mlr writes that
# file, otherwise save mlr_trained's weights here before reloading.

# Multiple Linear Regression: rebuild the model and load the saved weights on CPU
# (map_location avoids failures if the checkpoint was written from GPU tensors).
mlr = MultiLinearRegression(124, 128)
mlr.load_state_dict(torch.load('fs_mlr.pth', map_location='cpu'))

X_val = torch.tensor(data.input_val, dtype=torch.float32)
mlr.eval()
with torch.no_grad():
    # Evaluate the MLR model itself (this previously called `mlp`, which is not
    # defined at this point on a fresh kernel and is the wrong model anyway).
    mlr_pred_val = mlr(X_val).numpy()

print(mlr_pred_val.shape)
data.model_names.append('mlr')
preds.append(mlr_pred_val)

### Train and load the Multi-Layer Perceptron

### Set pressure grid

In [None]:
from MLP.mlp import MLP , train_mlp

# Seed torch so weight initialisation is reproducible under Restart & Run All
# (and the batch order too, if train_mlp shuffles with torch's RNG).
torch.manual_seed(42)

# Hidden layer widths passed to MLP(124, hidden_layers, 128).
hidden_layers = [384, 1024, 640]

mlp = MLP(124, hidden_layers , 128)
print(data.input_train.shape,data.target_train.shape)
criterion = nn.MSELoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.001)

train_inputs = torch.tensor(data.input_train, dtype=torch.float32)
train_targets = torch.tensor(data.target_train, dtype=torch.float32)

# train_mlp is expected to update `mlp` in place; its state_dict is checkpointed right after.
train_mlp(mlp, train_inputs , train_targets, criterion, optimizer, epochs=20, batch_size=1024)
torch.save(mlp.state_dict(), 'fs_mlp.pth')


# Multiple layer perceptron: rebuild the same architecture and load the checkpoint
# on CPU, so evaluation uses exactly the saved weights (even if trained on GPU).
mlp = MLP(124, hidden_layers , 128)
mlp.load_state_dict(torch.load('fs_mlp.pth', map_location='cpu'))

X_val = torch.tensor(data.input_val, dtype=torch.float32)
mlp.eval()
with torch.no_grad():
    mlp_pred_val = mlp(X_val).numpy()

print(mlp_pred_val.shape)
data.model_names.append('mlp')
preds.append(mlp_pred_val)

In [None]:
# Pair each registered model name with its validation predictions, positionally.
data.preds_val = {name: pred for name, pred in zip(data.model_names, preds)}

# Evaluate on validation data

### Load predictions

### Weight predictions and target (TODO: confirm these steps also apply to the V2 variable set)

1. Undo output scaling

2.  Weight vertical levels by dp/g

3. Weight horizontal area of each grid cell by a[x]/mean(a[x])

4. Convert units to a common energy unit

In [None]:
# Apply the weighting steps (undo output scaling, dp/g vertical weighting,
# area weighting, energy-unit conversion) to the targets first, then the predictions.
for reweight in (data.reweight_target, data.reweight_preds):
    reweight(data_split='val')

### Set and calculate metrics

In [None]:
# Metrics to compute for every model on the validation split.
data.metrics_names = 'MAE RMSE R2 bias'.split()
data.create_metrics_df(data_split='val')

### Create plots

In [None]:
# set plotting settings
%config InlineBackend.figure_format = 'retina'
letters = string.ascii_lowercase

# create custom dictionary for plotting
dict_var = data.metrics_var_val
plot_df_byvar = {}
for metric in data.metrics_names:
    plot_df_byvar[metric] = pd.DataFrame([dict_var[model][metric] for model in data.model_names],
                                               index=data.model_names)
    plot_df_byvar[metric] = plot_df_byvar[metric].rename(columns = data.var_short_names).transpose()

# plot figure
fig, axes = plt.subplots(nrows  = len(data.metrics_names), sharex = True)
for i in range(len(data.metrics_names)):
    plot_df_byvar[data.metrics_names[i]].plot.bar(
        legend = False,
        ax = axes[i])
    print(data.metrics_names[i],plot_df_byvar[data.metrics_names[i]])
    if data.metrics_names[i] != 'R2':
        axes[i].set_ylabel('$W/m^2$')
    else:
        axes[i].set_ylim(0,1)

    axes[i].set_title(f'({letters[i]}) {data.metrics_names[i]}')
axes[i].set_xlabel('Output variable')
axes[i].set_xticklabels(plot_df_byvar[data.metrics_names[i]].index, \
    rotation=0, ha='center')

axes[0].legend(columnspacing = .9, 
               labelspacing = .3,
               handleheight = .07,
               handlelength = 1.5,
               handletextpad = .2,
               borderpad = .2,
               ncol = 3,
               loc = 'upper right')
fig.set_size_inches(7,8)
fig.tight_layout()

If you trained models with different hyperparameters, use the ones that performed the best on validation data for evaluation on scoring data.

## Evaluate on scoring data

#### Do this at the VERY END (when you have finished tuning the hyperparameters for your model and are seeking a final evaluation)

### Load scoring data

In [None]:
# Held-out scoring split stored as .npy arrays: model inputs and matching targets.
scoring_input_path = "/work/FAC/FGSE/IDYST/tbeucler/ai4pex/Physics_no_offline/First_step/fs_data/test_input.npy"
scoring_target_path = "/work/FAC/FGSE/IDYST/tbeucler/ai4pex/Physics_no_offline/First_step/fs_data/test_target.npy"

# Attach the scoring inputs, then the scoring targets, to the data object.
data.input_scoring = np.load(scoring_input_path)
data.target_scoring = np.load(scoring_target_path)

### Set pressure grid

In [None]:
# Build the pressure grid for the scoring split.
data.set_pressure_grid(data_split='scoring')

### Load predictions

In [None]:
# Predict the scoring split with the same torch models (`mlr`, `mlp`) that were
# evaluated on validation. The previous version used `const_model` and
# `mlr_weights`, which are never defined in this notebook.
X_scoring = torch.tensor(data.input_scoring, dtype=torch.float32)

mlr.eval()
mlp.eval()
with torch.no_grad():
    mlr_pred_scoring = mlr(X_scoring).numpy()
    mlp_pred_scoring = mlp(X_scoring).numpy()
print(mlr_pred_scoring.shape, mlp_pred_scoring.shape)

# Register predictions under the same model names used for validation.
data.model_names = ['mlr', 'mlp']
preds = [mlr_pred_scoring, mlp_pred_scoring]
data.preds_scoring = dict(zip(data.model_names, preds))

### Weight predictions and target

1. Undo output scaling

2.  Weight vertical levels by dp/g

3. Weight horizontal area of each grid cell by a[x]/mean(a[x])

4. Convert units to a common energy unit

In [None]:
# weight predictions and target
data.reweight_target(data_split = 'scoring')
data.reweight_preds(data_split = 'scoring')

# set and calculate metrics
data.metrics_names = ['MAE', 'RMSE', 'R2', 'bias']
data.create_metrics_df(data_split = 'scoring')

### Create plots

In [None]:
# set plotting settings
%config InlineBackend.figure_format = 'retina'
letters = string.ascii_lowercase

# create custom dictionary for plotting
dict_var = data.metrics_var_scoring
plot_df_byvar = {}
for metric in data.metrics_names:
    plot_df_byvar[metric] = pd.DataFrame([dict_var[model][metric] for model in data.model_names],
                                               index=data.model_names)
    plot_df_byvar[metric] = plot_df_byvar[metric].rename(columns = data.var_short_names).transpose()

# plot figure
fig, axes = plt.subplots(nrows  = len(data.metrics_names), sharex = True)
for i in range(len(data.metrics_names)):
    plot_df_byvar[data.metrics_names[i]].plot.bar(
        legend = False,
        ax = axes[i])
    if data.metrics_names[i] != 'R2':
        axes[i].set_ylabel('$W/m^2$')
    else:
        axes[i].set_ylim(0,1)

    axes[i].set_title(f'({letters[i]}) {data.metrics_names[i]}')
axes[i].set_xlabel('Output variable')
axes[i].set_xticklabels(plot_df_byvar[data.metrics_names[i]].index, \
    rotation=0, ha='center')

axes[0].legend(columnspacing = .9, 
               labelspacing = .3,
               handleheight = .07,
               handlelength = 1.5,
               handletextpad = .2,
               borderpad = .2,
               ncol = 3,
               loc = 'upper right')
fig.set_size_inches(7,8)
fig.tight_layout()