In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.training.model.emos import NormalEMOSModel, LinearModel, NormalCubeNormalEMOS
from crims2s.training.model.util import PytorchMultiplexer
from crims2s.transform import AddLatLonFeature, AddBiweeklyDimTransform, CompositeTransform
from crims2s.util import ECMWF_FORECASTS

In [None]:
# Compose the project's Hydra config from the `crims2s.training.conf` package,
# overriding the transform group with the `latlon` option.
with hydra.initialize_config_module('crims2s.training.conf'):
    cfg = hydra.compose('config', overrides=['experiment/transform=latlon'])

In [None]:
# Build the transform object described by the selected `experiment.transform` config node.
t = hydra.utils.instantiate(cfg.experiment.transform)

In [None]:
# Raw S2S dataset read from the configured directory, wrapped with the transform above
# (presumably applied to every example on access — confirm in crims2s.dataset).
d = TransformedDataset(S2SDataset(cfg.experiment.dataset.dataset_dir), t)

In [None]:
# Small batches for interactive inspection; `shuffle` is left at its default (False),
# so batches follow dataset order.
loader = torch.utils.data.DataLoader(d, batch_size=4)

In [None]:
# Pull a single batch to experiment with in the cells below.
one_batch = next(iter(loader))

In [None]:
# Inspect the 'monthday' entry of the batch — the key WeeklyMultiplexer dispatches on.
one_batch['monthday']

In [None]:
# Inspect the 'month' entry of the batch — the key MonthlyMultiplexer dispatches on.
one_batch['month']

In [None]:
# 2 m temperature forecast parameters from the sample batch (presumably the forecast
# mean and spread). Only `t2m_mu` is used further down; `t2m_sigma` is kept for inspection.
t2m_mu, t2m_sigma = (
    one_batch['model_parameters_t2m_mu'],
    one_batch['model_parameters_t2m_sigma'],
)

In [None]:
class MultiplexedNormalEMOSModel(nn.Module):
    """Normal EMOS model whose location and scale come from key-multiplexed linear models.

    ``linear_model_cls`` is instantiated twice: once for the location and once for the
    scale. Each instance receives the grid shape as positional arguments. At forward
    time both sub-models are called as ``model(batch[key], forecast)``, so the value
    stored under ``key`` in the batch selects which coefficients are applied.

    Args:
        loc_key: Batch key of the forecast location tensor.
        scale_key: Batch key of the forecast scale tensor.
        linear_model_cls: Module class built as ``linear_model_cls(*shape)`` for the
            location and ``linear_model_cls(*shape, fill_intercept=1.0)`` for the scale.
        key: Batch key whose value is forwarded to the sub-models as the selector.
        biweekly: If True, the sub-models are built with shape ``(2, 121, 240)``
            instead of ``(121, 240)``.
        regularization: Lower bound applied to the predicted scale, so the resulting
            Normal distribution always has a strictly positive scale.
    """

    def __init__(self, loc_key, scale_key, linear_model_cls, key, biweekly=False, regularization=1e-9):
        super().__init__()

        self.loc_key = loc_key
        self.scale_key = scale_key

        shape = (2, 121, 240) if biweekly else (121, 240)

        self.loc_model = linear_model_cls(*shape)
        self.scale_model = linear_model_cls(*shape, fill_intercept=1.0)

        self.regularization = regularization

        self.key = key

    def forward(self, batch):
        """Build a ``torch.distributions.Normal`` from the post-processed forecast.

        Args:
            batch: Mapping that contains ``self.loc_key``, ``self.scale_key`` and ``self.key``.

        Returns:
            ``torch.distributions.Normal`` with the linear-model location and the
            linear-model scale clipped from below at ``self.regularization``.
        """
        forecast_loc = batch[self.loc_key]
        forecast_scale = batch[self.scale_key]
        key = batch[self.key]

        loc = self.loc_model(key, forecast_loc)

        # Clip the *learned* scale, not the raw forecast scale. Clipping
        # `forecast_scale` would discard `scale_model`'s output (and its gradients).
        # The lower bound keeps the Normal's scale strictly positive.
        scale = self.scale_model(key, forecast_scale)
        scale = torch.clip(scale, min=self.regularization)

        return torch.distributions.Normal(loc=loc, scale=scale)

In [None]:
class MonthlyNormalEMOSModel(MultiplexedNormalEMOSModel):
    """Normal EMOS model with one set of linear coefficients per calendar month.

    The location and scale sub-models are ``MonthlyLinearModel`` instances, selected
    by the batch's ``'month'`` entry.
    """

    def __init__(self, loc_key, scale_key, biweekly=False, regularization=1e-9):
        super().__init__(
            loc_key,
            scale_key,
            linear_model_cls=MonthlyLinearModel,
            key='month',
            biweekly=biweekly,
            regularization=regularization,
        )

In [None]:
class WeeklyNormalEMOSModel(MultiplexedNormalEMOSModel):
    """Normal EMOS model with one set of linear coefficients per forecast date.

    The location and scale sub-models are ``WeeklyLinearModel`` instances, selected
    by the batch's ``'monthday'`` entry.
    """

    def __init__(self, loc_key, scale_key, biweekly=False, regularization=1e-9):
        super().__init__(
            loc_key,
            scale_key,
            linear_model_cls=WeeklyLinearModel,
            key='monthday',
            biweekly=biweekly,
            regularization=regularization,
        )

In [None]:
# Library EMOS model, built with `biweekly=True` (presumably adding a leading lead-time
# dimension, as in MultiplexedNormalEMOSModel — confirm in crims2s.training.model.emos).
emos = NormalCubeNormalEMOS(biweekly=True)

In [None]:
# List the model's trainable parameters as (name, shape) pairs rather than dumping
# every Parameter tensor, which is noisy for full-grid weights.
[(name, tuple(param.shape)) for name, param in emos.named_parameters()]

In [None]:
# Smoke test: run the library EMOS model on the sample batch and display its output.
emos(one_batch)

In [None]:
class MonthlyMultiplexer(PytorchMultiplexer):
    """Multiplexer holding one ``cls`` instance per calendar month.

    Sub-models are keyed by zero-padded month strings ``'01'`` through ``'12'``, and
    dispatch uses the ``'month'`` key. Every sub-model is built with the same
    ``*args`` / ``**kwargs``.
    """

    def __init__(self, cls, *args, **kwargs):
        models_by_month = {}
        for month_number in range(1, 13):
            models_by_month["%02d" % month_number] = cls(*args, **kwargs)

        super().__init__('month', models_by_month)

In [None]:
class WeeklyMultiplexer(PytorchMultiplexer):
    """Multiplexer holding one ``cls`` instance per ``(month, day)`` in ``ECMWF_FORECASTS``.

    Sub-models are keyed by ``'MMDD'`` strings built from each pair, and dispatch uses
    the ``'monthday'`` key. Every sub-model is built with the same
    ``*args`` / ``**kwargs``.
    """

    def __init__(self, cls, *args, **kwargs):
        keys = [f"{month:02}{day:02}" for month, day in ECMWF_FORECASTS]

        models_by_monthday = {}
        for monthday_key in keys:
            models_by_monthday[monthday_key] = cls(*args, **kwargs)

        super().__init__('monthday', models_by_monthday)

In [None]:
class MonthlyLinearModel(MonthlyMultiplexer):
    """Twelve independent ``LinearModel`` instances, one per calendar month.

    All positional and keyword arguments are forwarded unchanged to every ``LinearModel``.
    """

    def __init__(self, *args, **kwargs):
        super(MonthlyLinearModel, self).__init__(LinearModel, *args, **kwargs)

In [None]:
class WeeklyLinearModel(WeeklyMultiplexer):
    """One independent ``LinearModel`` per forecast date in ``ECMWF_FORECASTS``.

    All positional and keyword arguments are forwarded unchanged to every ``LinearModel``.
    """

    def __init__(self, *args, **kwargs):
        super(WeeklyLinearModel, self).__init__(LinearModel, *args, **kwargs)

In [None]:
# NOTE(review): no shape arguments are passed, so each monthly LinearModel is built with an
# empty `*shape`, whereas MultiplexedNormalEMOSModel passes the grid shape — confirm intended.
m = MonthlyLinearModel()

In [None]:
# Shape check: apply the per-month linear models to `t2m_mu`, dispatching on the batch's month.
m(one_batch['month'], t2m_mu).shape