In [2]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from project_utils import utils
from project_utils import read_utils as read

In [3]:
## prediction parameters
# NOTE(review): this import lives outside the notebook's main import cell;
# consider moving it there so all dependencies are declared in one place.
from project_utils.variant_dict import VARIANT_DICT, val_index

# Handed unchanged to the project date/data helpers further down.
# NOTE(review): the meaning of the individual entries is not visible in this
# notebook -- confirm against project_utils before changing them.
input_length = [12, 12, 12, 60]

# Shared calendar and grid metadata used by the cells below.
orig_dates = utils.load_dates()
lats, lons = utils.load_lat_lon()

# First five and last five entries of `lats`, flattened into one 1-D array.
na_lats = np.array((lats[:5], lats[-5:])).ravel()

In [4]:
# Target variable; also used as a prefix in the processed-data file names.
y_var = "tos"

# Models to process, one pass of the main loop per entry.
model_list = [
    "ACCESS-ESM1-5",
    "CanESM5",
    "CNRM-CM6-1",
    "GISS-E2-1-G",
    "IPSL-CM6A-LR",
    "MIROC-ES2L",
    "MIROC6",
    "MPI-ESM1-2-LR",
    "NorCPM1",
]

# Seed identifiers; each one selects a separate validation-loss file.
seeds = [101, 121, 505]

# `leads` and `prediction_lengths` are consumed pairwise (zipped), so entry i
# of one belongs with entry i of the other. Prediction lengths appear in file
# names with an "mo" suffix; leads are presumably in the same unit -- confirm.
prediction_lengths = [36, 60, 60]
leads = [0, 0, 24]

In [5]:
# For every model and every (lead, prediction length) pair: read the
# validation loss of each seed, average it over the "variant" dimension, and
# write out -- per remaining coordinate -- the position in `seeds` of the seed
# with the smallest mean loss.
for MODEL in model_list:
    print(MODEL)
    # presumably the validation ensemble members of this model -- confirm
    # against project_utils.variant_dict
    val_variants = VARIANT_DICT[MODEL][val_index]
    for lead, length in zip(leads, prediction_lengths):
        _, _, prediction_dates = utils.get_prediction_dates(orig_dates, input_length, lead, length)
        _, y_val = read.load_xy_data(orig_dates, input_length, lead, length, "tos", y_var,
                                     val_variants, MODEL)
        # NOTE(review): `y_val` and `y_quantiles` are not used by the seed
        # selection below. They are kept because later cells may read the
        # values left over from the final iteration; drop them if not.
        y_val = xr.DataArray(
            data=y_val.reshape(len(val_variants), len(prediction_dates), len(lats), len(lons)),
            coords=dict(variant=val_variants, time=prediction_dates, lat=lats, lon=lons),
            name=y_var,
        )
        y_quantiles = read.read_data(y_var, MODEL, length, VARIANTS=val_variants,
                                     stat="quantiles", as_xarray=True)

        # Common prefix of the per-seed input files and of the output file.
        file_stem = f"../processed_data/training/{y_var}_{length}mo_{lead}lead_{MODEL}"

        loss_list = []
        for SEED in seeds:
            loss_path = f"{file_stem}-trained_{MODEL}-val_loss_{SEED}.nc"
            # Open inside a context manager so the file handle is released as
            # soon as the reduced result is in memory (previously every
            # dataset stayed open for the lifetime of the kernel).
            with xr.open_dataset(loss_path) as seed_ds:
                seed_mean = seed_ds.mean(dim="variant").rename({"val_loss": f"SEED_{SEED}"})
                loss_list.append(seed_mean.load())

        loss_ds = xr.merge(loss_list)
        # One slice per seed, stacked along "variable" in the order of `seeds`.
        loss_stack = loss_ds.to_array("variable")

        # A missing loss must never win the comparison: fill NaN with +inf
        # (filling with 0 would make it the minimum whenever the real losses
        # are positive). argmin cannot skip NaN itself when a whole slice is
        # NaN, hence the explicit fill followed by masking.
        has_loss = loss_stack.notnull().any("variable")
        best_seed = loss_stack.fillna(np.inf).argmin("variable")
        # NaN where no seed has a loss at all; elsewhere the 0-based position
        # of the winning seed in `seeds` (stored as float because of the NaNs).
        best_seed = best_seed.where(has_loss).rename("best_seed")
        best_seed.to_netcdf(f"{file_stem}_best_seed.nc")


ACCESS-ESM1-5
CanESM5
CNRM-CM6-1
GISS-E2-1-G
IPSL-CM6A-LR
MIROC-ES2L
MIROC6
MPI-ESM1-2-LR
NorCPM1
