In [1]:
# Re-import edited Python modules (e.g. the crims2s package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import hydra
import numpy as np
import pathlib
import torch
import xarray as xr

from crims2s.training.util import find_checkpoint_file
from crims2s.training.lightning import S2STercilesModule

In [None]:
# Compose the EMOS experiment config, selecting the rolling ECMWF model variant.
with hydra.initialize_config_module(config_module='crims2s.training.conf'):
    cfg = hydra.compose(
        config_name='config',
        overrides=['experiment=emos', 'experiment/model=emos_ecmwf_rolling'],
    )

In [None]:
# Pick one dataset file; it is used below for its coordinates and NaN pattern.
# The glob is sorted so the choice is reproducible (Path.glob yields files in
# arbitrary filesystem order), and an empty directory fails with a clear message
# instead of a bare StopIteration.
dataset_path = pathlib.Path(cfg.experiment.dataset.dataset_dir)
sample_files = sorted(dataset_path.glob('*.nc'))
if not sample_files:
    raise FileNotFoundError(f'No *.nc files found in {dataset_path}')
sample_path = sample_files[0]

In [None]:
# Show which sample file was selected.
sample_path

In [None]:
# Only the '/edges' group is needed: its coordinates and NaNs drive the parameter maps below.
sample = xr.open_dataset(filename_or_obj=sample_path, group='/edges')

In [None]:
# Inspect the '/edges' group (its coordinates and NaN pattern are reused further down).
sample

In [6]:
# NOTE(review): this peeks at the checkpoint of run 2021-09-28/12-38-58, while the
# checkpoint actually loaded below comes from run 2021-09-28/11-36-48 — confirm which run is intended.
find_checkpoint_file('***BASEDIR***/runs/train/outputs/2021-09-28/12-38-58')

PosixPath('***BASEDIR***/runs/train/outputs/2021-09-28/12-38-58/lightning/default_8/0_b0d8fe0320d4436b8ea0c3b63b0e285c/checkpoints/epoch=46-step=2396.ckpt')

In [None]:
# Resolve the training-run directory, then locate the checkpoint file inside it.
run_dir = hydra.utils.to_absolute_path('***BASEDIR***/runs/train/outputs/2021-09-28/11-36-48')
checkpoint_path = find_checkpoint_file(run_dir)

checkpoint_path

In [None]:
# Load the raw checkpoint onto the CPU. torch.load unpickles the file, so only use trusted checkpoints.
chkpt = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# The parameter tensors of the Lightning module are stored under the checkpoint's 'state_dict' key.
state_dict = chkpt['state_dict']

In [None]:
# Inspect the raw parameter names and tensors stored in the checkpoint.
state_dict

In [None]:
# Keep only the entries belonging to the wrapped model, with the 'model.' prefix stripped.
# NOTE(review): model_state is not used by any of the cells shown here.
model_state = {
    name[len('model.'):]: tensor
    for name, tensor in state_dict.items()
    if name.startswith('model.')
}

In [None]:
# Build a fresh model and its optimizer from the config, then let Lightning restore the
# checkpoint; `model` and `optimizer` are forwarded to the module's constructor.
model = hydra.utils.instantiate(cfg.experiment.model)
optimizer = hydra.utils.call(cfg.experiment.optimizer, model)

lightning_module = S2STercilesModule.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    model=model,
    optimizer=optimizer,
)

In [None]:
# Display the model held by the restored Lightning module.
lightning_module.model

In [None]:
def params_of_model(model, prefix='', nan_mask=None, coords=None):
    """Collect the per-member weights and intercepts of a sub-model into an xarray Dataset.

    Parameters
    ----------
    model
        Object with a ``models`` mapping whose values expose ``weights`` and
        ``intercept`` tensors. NOTE(review): each tensor is assumed to have shape
        (lead_time, latitude, longitude), matching the dims used below — confirm.
    prefix
        Prefix of the output variable names: ``{prefix}_weights`` and ``{prefix}_intercept``.
    nan_mask
        Optional boolean DataArray; wherever it is True the parameters are set to NaN.
    coords
        Coordinates for lead_time/latitude/longitude. Defaults to those of the
        notebook-level ``sample`` dataset.

    Returns
    -------
    xr.Dataset
        Two variables over the dims model, lead_time, latitude and longitude, with
        members stacked along ``model`` in sorted key order.
    """
    dims = ['model', 'lead_time', 'latitude', 'longitude']
    if coords is None:
        coords = {'lead_time': sample.lead_time, 'latitude': sample.latitude, 'longitude': sample.longitude}

    # Sort the keys once so weights and intercepts are stacked in the same member order.
    keys = sorted(model.models)

    def stack_param(name):
        # Stack one parameter across members and wrap it as a labelled DataArray.
        stacked = torch.stack([getattr(model.models[k], name) for k in keys])
        # .cpu() is a no-op for CPU tensors and avoids .numpy() failing on GPU tensors.
        param = xr.DataArray(data=stacked.detach().cpu().numpy(), dims=dims, coords=coords)
        if nan_mask is not None:
            param = xr.where(~nan_mask, param, np.nan)
        return param

    return xr.Dataset(
        data_vars={
            f'{prefix}_weights': stack_param('weights'),
            f'{prefix}_intercept': stack_param('intercept'),
        }
    )

In [None]:
# A grid point is masked when any of its category edges is NaN in the sample file.
nan_mask = sample.isnull().any(dim='category_edge')
tp_nan_mask, t2m_nan_mask = nan_mask['tp'], nan_mask['t2m']

In [None]:
# Gather the loc/scale parameters of both variables into one Dataset. Each variable's
# parameters are masked with that variable's own NaN mask (the t2m parameters were
# previously masked with the tp mask by mistake).
params = xr.merge([
    params_of_model(model.model.tp_model.loc_model, prefix='tp_loc', nan_mask=tp_nan_mask),
    params_of_model(model.model.tp_model.scale_model, prefix='tp_scale', nan_mask=tp_nan_mask),
    params_of_model(model.model.t2m_model.loc_model, prefix='t2m_loc', nan_mask=t2m_nan_mask),
    params_of_model(model.model.t2m_model.scale_model, prefix='t2m_scale', nan_mask=t2m_nan_mask),
])

In [None]:
# Overview of all merged parameter variables.
params

# Analysis

In [None]:
# tp scale weights at the second lead time.
# NOTE(review): if the `model` dim has more than one member, xarray's .plot() falls back
# to a histogram; select a member with .isel(model=...) to get a map.
params['tp_scale_weights'].isel(lead_time=1).plot()

In [None]:
# t2m scale intercept at the first lead time.
# NOTE(review): if the `model` dim has more than one member, xarray's .plot() falls back
# to a histogram; select a member with .isel(model=...) to get a map.
params['t2m_scale_intercept'].isel(lead_time=0).plot()