In [None]:
from matplotlib import pyplot as plt

import xarray as xr
import netCDF4 as nc
import numpy as np

import os

import datetime as dt
import pickle

In [None]:
# Model names used as keys into `weights`; their training weights are
# averaged (weighted by sample count) further below.
models = [
    'CanESM5',
    'MIROC-ES2L',
    'MPI-ESM1-2-LR',
    'MIROC6',
    'CESM2',
]

## Loading and averaging weights

In [None]:
# Loading weights: unpickle the Anchor-OPLS stability weights for `var`.
# File-name conventions differ per variable:
#   - 'pr' files are tagged 'precip' instead of the variable name,
#   - 'tas' files use n=600 instead of N.
N = 1500
coarse = 6
var = 'psl'
directory = '../weights/'

file_tag = 'precip' if var == 'pr' else var
n_tag = 600 if var == 'tas' else N
file = 'Anchor_OPLS_{}_weights_stability_n{}_coarse{}.pkl'.format(file_tag, n_tag, coarse)

# NOTE(review): pickle.load can execute arbitrary code — only load weight
# files produced by this project's own training pipeline.
with open(directory + file, 'rb') as f:
    weights = pickle.load(f)

In [None]:
# Averaging training weights: each model's weight matrix for the chosen
# gamma contributes in proportion to its 'n_samples', i.e.
#   total_weights = sum_m(n_m * W_m) / sum_m(n_m)
gamma = 5  # Possible gamma : {2, 5, 10}
per_model = [weights[name][gamma] for name in models]
total_samples = sum(entry['n_samples'] for entry in per_model)

# Accumulate in a float64 array (np.zeros default) in model order.
total_weights = np.zeros(weights['CanESM5'][gamma]['weights'].shape)
for entry in per_model:
    total_weights += entry['weights'] * entry['n_samples']
total_weights /= total_samples

## Predictions

In [None]:
# Predict every evaluation file for `var`: compute monthly anomalies, coarsen
# them into `coarse` x `coarse` regional means, map them through the averaged
# training weights, and write the result to a new netCDF file.
path = '../data/Evaluation-Tier1/Amon/{}'.format(var)
path2 = '../predictions/{}'.format(var)
# to_netcdf does not create missing parent directories.
os.makedirs(path2, exist_ok=True)

# Sorted for a reproducible processing order across runs.
for file in sorted(os.listdir(path)):
    # Only netCDF inputs; the output name below is derived by stripping '.nc'.
    if not file.endswith('.nc'):
        continue
    # Opening test file (closed automatically, even if a step below fails)
    file_path = os.path.join(path, file)
    with xr.open_dataset(file_path) as ds:
        # Compute anomalies relative to this file's own monthly climatology
        climatology = ds.groupby('time.month').mean(dim='time')
        anomalies = ds.groupby('time.month') - climatology
        # assumes the variable is laid out as (time, lat, lon) — TODO confirm
        shape = anomalies[var].shape
        if shape[1] % coarse or shape[2] % coarse:
            raise ValueError(
                '{}: grid {}x{} is not divisible by coarse={}'.format(
                    file, shape[1], shape[2], coarse))

        # Coarsening anomalies: mean over non-overlapping coarse x coarse
        # boxes, then flatten the regional grid to one row per time step
        anomalies_regional = anomalies[var].values.reshape(
            shape[0], shape[1] // coarse, coarse, shape[2] // coarse, coarse)
        X = anomalies_regional.mean(axis=(2, 4)).reshape(
            shape[0], (shape[1] // coarse) * (shape[2] // coarse))

        # Prediction using the sample-weighted average weight matrix
        Y_pred = X @ total_weights

        # Reshape to original shape (time, lat, lon)
        Y_pred_spatial = Y_pred.reshape(shape[0], shape[1], shape[2])

        # Saving data: the variable is overwritten with the predicted
        # anomalies; the monthly climatology is NOT added back.
        # TODO(review): confirm whether outputs should instead be full fields
        # (i.e. `ds.groupby('time.month') + climatology`).
        ds[var][:] = Y_pred_spatial

        # Output keeps the input stem and a proper '.nc' extension.
        file_path2 = os.path.join(path2, os.path.splitext(file)[0] + 'predictions.nc')
        ds.to_netcdf(file_path2)

## Checking predictions

In [None]:
# Sanity check: plot the spatial mean of every saved prediction file over time.
path2 = '../predictions/{}'.format(var)
fig, ax = plt.subplots()
# Sorted so that each 'file i' label refers to the same file on every run.
for i, file in enumerate(sorted(os.listdir(path2))):
    file_path = os.path.join(path2, file)
    # Context manager closes the dataset even if plotting raises.
    with xr.open_dataset(file_path) as ds:
        # Mean over axes 1 and 2 (assumed lat, lon) -> one value per time step
        ax.plot(ds[var].values.mean(axis=(1, 2)), label='file {}'.format(i))
ax.set_xlabel('time step')
ax.set_ylabel('spatial mean of predicted {}'.format(var))
ax.set_title('Predictions in {}'.format(path2))
ax.legend()
plt.show()