# Compare models on Nino indices

In [None]:
import os, torch, nc_time_axis, yaml
from importlib import reload
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import cartopy as ctp
from tqdm import tqdm
from joblib import Parallel, delayed, parallel_backend

import hyblim.geoplot as gpl
from hyblim.utils import metric

# Apply the matplotlib style sheet used for the figures below. The path is
# relative, so it is resolved against the current working directory.
plt.style.use("../../paper.mplstyle")

def get_model_specs_by_name(experiments, exp_name):
    """Return the first experiment entry whose ``'name'`` equals ``exp_name``.

    Args:
        experiments: Iterable of mappings, each providing a ``'name'`` key.
        exp_name: Name of the experiment to look up.

    Returns:
        The first matching entry (the same object, not a copy), or ``None``
        if no entry has that name.
    """
    matches = (spec for spec in experiments if spec['name'] == exp_name)
    return next(matches, None)

# Read the experiment definitions from the YAML file; the relative path is
# resolved against the current working directory.
with open("experiments.yaml") as yaml_file:
    experiments = yaml.safe_load(yaml_file)

In [None]:
# Experiments to compare; each name is looked up in `experiments` by its
# 'name' entry.
load_experiments = ['LIM', 'LSTM', 'LIM+LSTM']
# Data split whose forecast/target files are loaded (used in the file names).
datasplit = 'test'

# Both dicts are keyed by experiment name and hold the two return values of
# metric.time_series_score. The second is indexed by 'month' in the plotting
# cells further down.
nino_scores, nino_scores_month = {}, {}
for exp_name in load_experiments:
    exp = get_model_specs_by_name(experiments, exp_name)
    if exp is None:
        print(f"Experiment {exp_name} not found")
        continue

    # Only the first entry of the experiment's 'paths' list is evaluated.
    metrics_dir = exp['paths'][0] + '/metrics'
    nino_frcst = xr.open_dataset(
        f"{metrics_dir}/nino_frcst_{datasplit}.nc"
    ).transpose('time', 'member', 'lag')
    nino_target = xr.open_dataset(
        f"{metrics_dir}/nino_target_{datasplit}.nc"
    ).transpose('time', 'lag')

    nino_scores[exp_name], nino_scores_month[exp_name] = metric.time_series_score(
        nino_frcst, nino_target
    )

## Plot skill score averaged over all months

In [None]:
# Line plots: one panel per skill score, one curve per model, for a single
# Nino index.
reload(gpl)
model_name = ["LIM", "LSTM", "LIM+LSTM"]
scores = ['rmsess', 'crpss']
idx_name = 'nino4'
plot_mean = True  # NOTE(review): not read anywhere in this cell
ncols = len(scores)


fig, axs = plt.subplots(1, ncols, figsize=(7, 3.0),
                        sharex=True, sharey=True)

for i, score_name in enumerate(scores):
    # plt.subplots returns a bare Axes instead of an array for a single panel.
    ax = axs[i] if ncols > 1 else axs

    for model in model_name:
        score = nino_scores[model][score_name][idx_name]
        ax.plot(score['lag'], score, '-o', label=model)

    # Decorate each panel once rather than once per model, so the zero
    # reference line is not stacked several times. It is added after the
    # curves.
    ax.set_xlabel(r'$\tau$ [months]')
    ax.set_ylabel(score_name)
    ax.axhline(0.0, color='k', linestyle='--')

    if i == 0:
        ax.legend(fontsize='small')


# The panels are created with sharex/sharey=True, so the limits and ticks set
# on the last panel are shared with the others.
ax = axs[-1] if ncols > 1 else axs
ax.set_ylim(-.1, 0.95)
# `score` is the last curve plotted above; its lags are used as tick positions.
_ = ax.set_xticks(score['lag'])

#gpl.enumerate_axes(axs, pos_x=0.01, pos_y=1.1, fontsize='medium')

## Plot skill over months

In [None]:
# Filled-contour panels of one skill score over lag (x) and month (y):
# one row per model in `model_name`, one column per Nino index.
metrickey = 'rmsess'
pltspec = dict(
    crpss=dict(cmap='plasma_r', vmin=0, vmax=0.6),
    rmsess=dict(cmap='plasma', vmin=0, vmax=0.9),
)
ids_name = ['nino5', 'nino4', 'nino3', 'nino12']
nrows, ncols = len(model_name), len(ids_name)

fig, axs = plt.subplots(nrows, ncols,
                        figsize=(ncols*3, nrows*2.5),
                        sharex=True, sharey=True)

# Month initials, January through December, used as y tick labels.
month_labels = list("JFMAMJJASOND")

for i, modelkey in enumerate(model_name):
    is_top_row = (i == 0)
    is_bottom_row = (i == nrows - 1)

    for j, nino_idx in enumerate(ids_name):
        # plt.subplots returns a 1-D array when there is only one row.
        if nrows > 1:
            ax = axs[i, j]
        else:
            ax = axs[j]

        score = nino_scores_month[modelkey][metrickey][nino_idx]
        # NOTE(review): contourf needs score.data laid out as (month, lag) —
        # confirm the dimension order returned by the metric.
        im = ax.contourf(score['lag'], score['month'], score.data,
                         **pltspec[metrickey])
        ax.set_yticks(score['month'])
        ax.set_xticks(score['lag'])

        # Titles on the top row, x labels on the bottom row, and month
        # labels plus the model name on the first column only.
        if is_top_row:
            ax.set_title(nino_idx)
        if is_bottom_row:
            ax.set_xlabel(r"$\tau$")
        if j == 0:
            ax.set_yticklabels(month_labels)
            ax.set_ylabel(modelkey)

# One colorbar for the whole figure, placed at its right edge and built from
# the last contour set drawn.
cbar_ax = fig.add_axes([1, .2, 0.01, 0.55])
fig.colorbar(im, cax=cbar_ax, orientation='vertical', extend='max', label=metrickey)

### Score differences per month

In [None]:
# Difference in skill score between each model and a control model, shown per
# month and lag: one row per model, one column per Nino index.
reload(gpl)
model_ctrl = "LIM"
model_name = ["LSTM", "LIM+LSTM"]
metrickey = 'rmsess'
ids_name = ['nino5', 'nino4', 'nino3', 'nino12']
# Plot settings per metric, passed through to gpl.plot_matrix. The keys must
# match the metric names used for `metrickey` ('rmsess', 'crpss').
plparam = {
    'rmsess': dict(cmap='RdBu_r', vmin=-.2, vmax=.2, eps=0.02, centercolor="#FFFFFF"),
    # NOTE(review): this colormap is 'RdBu' while rmsess uses 'RdBu_r', so the
    # sign-to-colour mapping differs between the two metrics — confirm intended.
    'crpss': dict(cmap='RdBu', vmin=-0.1, vmax=0.1, eps=0.01, centercolor="#FFFFFF"),
}
alpha = 0.05 # Statistical significance level (not read in this cell)

nrows = len(model_name)
ncols = len(ids_name)
fig, axs = plt.subplots(nrows, ncols,
                        figsize=(ncols*3, nrows*2.5),
                        sharex=True, sharey=True)

for i, modelkey in enumerate(model_name):
    for j, nino_idx in enumerate(ids_name):
        score_model = nino_scores_month[modelkey][metrickey][nino_idx]
        score_ctrl = nino_scores_month[model_ctrl][metrickey][nino_idx]
        # Positive values mean the model scores higher than the control.
        skill_diff = score_model - score_ctrl

        # plt.subplots returns a 1-D array when there is only one row.
        ax = axs[i, j] if nrows > 1 else axs[j]
        im = gpl.plot_matrix(skill_diff, 'lag', 'month', ax=ax,
                             bar='discrete', add_bar=False, **plparam[metrickey])

        ax.set_yticks(skill_diff['month'])
        ax.set_xticks(skill_diff['lag'])

        if i == 0:
            ax.set_title(nino_idx)
        if j == 0:
            ax.set_yticklabels(['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D'])
            ax.set_ylabel(modelkey, fontsize=12)
        if i == nrows - 1:
            ax.set_xlabel(r"$\tau$")

# One colorbar for the whole figure, built from the 'im' entry of the last
# gpl.plot_matrix result.
cbar_ax = fig.add_axes([1, .2, 0.01, 0.6])
fig.colorbar(im['im'], cax=cbar_ax, orientation='vertical', extend='both', label=rf"$\Delta$ {metrickey}")