In [None]:
%load_ext autoreload
%autoreload 2

# Linear Model

Train a linear model for a single day of year (forecast files matching `0102.nc`) using one of the ML-ready datasets we prepared.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.distributions
import xarray as xr

from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.transform import CompositeTransform, add_biweekly_dim_transform

In [None]:
# Root directory of the ML-ready dataset (`***BASEDIR***` looks like a placeholder to fill in locally).
DATASET = "***BASEDIR***/mlready/2021-08-07-test/"

## Make a transform that adapts the dataset to the linear model

In [None]:
def std_estimator(dataset, dim=None):
    """Sample standard deviation (``ddof=1``) of an xarray object over ``dim``.

    Parameters
    ----------
    dataset : xarray.DataArray or xarray.Dataset
        Object to reduce.
    dim : str, sequence of str, or None
        Dimension(s) to reduce over. ``None`` reduces over all dimensions.

    Returns
    -------
    Same type as ``dataset`` with ``dim`` reduced away.

    Notes
    -----
    Delegates to xarray's own ``std`` with ``ddof=1`` instead of a hand-rolled formula, which:

    * divides by ``n - 1`` where ``n`` is the number of non-NaN values actually used, consistent
      with the NaN-skipping mean (the hand-rolled version counted NaNs in ``n``);
    * handles ``dim=None`` correctly (the hand-rolled version took ``n`` from the dims of the
      already fully-reduced mean, i.e. none, which made the divisor 0);
    * no longer depends on ``xr.ufuncs``, which has been removed from xarray.
    """
    return dataset.std(dim=dim, ddof=1)

In [None]:
def model_to_distribution(model):
    """Summarize an ensemble forecast into per-window mean / standard-deviation fields.

    Parameters
    ----------
    model : xarray.Dataset
        Must contain ``tp`` and ``t2m`` with ``lead_time`` and ``realization`` dimensions, plus a
        ``biweekly_forecast`` dimension (presumably added by ``add_biweekly_dim_transform`` —
        confirm).

    Returns
    -------
    xarray.Dataset
        Variables ``tp_mean``, ``tp_std``, ``t2m_mean`` and ``t2m_std``, with the
        ``biweekly_forecast`` dimension renamed to ``lead_time``.
    """
    # tp is taken at the last lead time of each window rather than summed over it.
    # NOTE(review): this presumes tp is an accumulated field whose last value is the window
    # total, while obs_to_biweekly sums the observations over lead_time — confirm both describe
    # the same quantity.
    model_tp = model.tp.isel(lead_time=-1)
    model_tp_mean = model_tp.mean(dim='realization').rename('tp_mean')
    model_tp_std = std_estimator(model_tp, dim='realization').rename('tp_std')

    # t2m is summarized over both the lead times in the window and the ensemble members.
    t2m_dims = ['lead_time', 'realization']
    model_t2m_mean = model.t2m.mean(dim=t2m_dims).rename('t2m_mean')
    model_t2m_std = std_estimator(model.t2m, dim=t2m_dims).rename('t2m_std')

    merged = xr.merge([model_tp_mean, model_tp_std, model_t2m_mean, model_t2m_std])

    # Drop the leftover ``lead_time`` variable (presumably the scalar coordinate kept by
    # isel(lead_time=-1)) so ``biweekly_forecast`` can take over the ``lead_time`` name.
    # ``drop_vars`` replaces the deprecated ``Dataset.drop``.
    return merged.drop_vars('lead_time').rename(biweekly_forecast='lead_time')

In [None]:
def obs_to_biweekly(obs):
    """Collapse the ``lead_time`` dimension of the observations.

    * ``tp``: sum of ``pr`` over ``lead_time``; a cell with fewer than two non-NaN values
      becomes NaN (``min_count=2``).
    * ``t2m``: mean of ``t2m`` over ``lead_time``.

    Returns an ``xarray.Dataset`` holding both variables.
    """
    aggregated = [
        obs['pr'].sum(dim='lead_time', min_count=2).rename('tp'),
        obs['t2m'].mean(dim='lead_time'),
    ]
    return xr.merge(aggregated)

In [None]:
def linear_model_adapter(example):
    """Turn a raw ``{'model', 'obs'}`` example into linear-model inputs and targets.

    The forecast is summarized by ``model_to_distribution`` and the observations are
    aggregated by ``obs_to_biweekly``; the result uses the same two keys.
    """
    return {
        'model': model_to_distribution(example['model']),
        'obs': obs_to_biweekly(example['obs']),
    }

In [None]:
def to_pytorch(example):
    """Convert the adapted example's variables into torch tensors.

    Each output tensor is built with ``torch.from_numpy`` on the variable's ``.data`` array, so
    it shares memory with that array.
    """
    model, obs = example['model'], example['obs']

    # (output key, source dataset, variable name), in the order the keys are emitted.
    fields = (
        ('model_tp_mu', model, 'tp_mean'),
        ('model_tp_sigma', model, 'tp_std'),
        ('model_t2m_mu', model, 't2m_mean'),
        ('model_t2m_sigma', model, 't2m_std'),
        ('obs_t2m', obs, 't2m'),
        ('obs_tp', obs, 'tp'),
    )
    return {key: torch.from_numpy(getattr(source, var).data) for key, source, var in fields}

In [None]:
# Per-example pipeline; presumably applied in list order by CompositeTransform (confirm):
# add the biweekly dimension, summarize into distribution parameters, convert to tensors.
transform = CompositeTransform([
    add_biweekly_dim_transform,
    linear_model_adapter,
    to_pytorch,
])

In [None]:
# FIXME(review): train_dataset and val_dataset are built with identical arguments, so the
# "validation" set is exactly the training data. val_dataset is also not used in the cells that
# follow. Point it at a held-out split before using it for evaluation.
train_dataset = TransformedDataset(S2SDataset(DATASET, filter_str='0102.nc', include_features=False), transform)
val_dataset = TransformedDataset(S2SDataset(DATASET, filter_str='0102.nc', include_features=False), transform)

In [None]:
# One example per optimizer step; shuffle is left at the DataLoader default (False).
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    num_workers=2,
)

In [None]:
# Shape shared by every per-cell regression parameter.
# NOTE(review): presumably (biweekly window, latitude, longitude) — confirm against the dataset;
# the training loop broadcasts these against the batched forecast tensors.
PARAM_SHAPE = (3, 121, 240)


def _trainable(fill_value):
    """Return a leaf tensor of ``PARAM_SHAPE`` filled with ``fill_value`` that requires grad."""
    return torch.full(PARAM_SHAPE, fill_value, requires_grad=True)


# Precipitation: mean correction starts at zero, the scale intercept at 1.
tp_mu_intercept = _trainable(0.0)
tp_mu_weights = _trainable(0.0)
tp_sigma_intercept = _trainable(1.0)
tp_sigma_weights = _trainable(0.0)

# Temperature: mean correction starts at zero, the scale intercept at 2.
t2m_mu_intercept = _trainable(0.0)
t2m_mu_weights = _trainable(0.0)
t2m_sigma_intercept = _trainable(2.0)
t2m_sigma_weights = _trainable(0.0)

In [None]:
# Plain SGD over all eight per-cell regression parameters.
linear_params = [
    tp_mu_intercept, tp_mu_weights, tp_sigma_intercept, tp_sigma_weights,
    t2m_mu_intercept, t2m_mu_weights, t2m_sigma_intercept, t2m_sigma_weights,
]
optimizer = torch.optim.SGD(linear_params, lr=5e-2)

In [None]:
def linear_correction(intercept, weights, forecast):
    """Return ``intercept + weights * forecast + forecast`` (element-wise, with broadcasting).

    The raw forecast is added back, so zero intercept and weights leave it unchanged; the
    trainable parameters therefore learn a correction on top of the forecast.
    """
    return intercept + weights * forecast + forecast


def masked_nll(dist, obs):
    """Mean negative log-likelihood of ``obs`` under ``dist``, skipping NaN observations.

    NaN observations are replaced by 0 before ``log_prob`` so no NaN enters the autograd graph.
    They are then excluded from both the sum and the divisor, so the result is the mean over
    valid cells only. Averaging over all cells would let the masked zeros dilute the loss.
    Returns 0 when every observation is NaN. ``obs`` is not modified in place.
    """
    valid = ~torch.isnan(obs)
    log_likelihood = dist.log_prob(torch.where(valid, obs, torch.zeros_like(obs)))
    log_likelihood = torch.where(valid, log_likelihood, torch.zeros_like(log_likelihood))
    return -log_likelihood.sum() / valid.sum().clamp(min=1)


for epoch in range(5):
    for example in dataloader:
        # Sigma is clipped so Normal always receives a strictly positive scale.
        tp_mu = linear_correction(tp_mu_intercept, tp_mu_weights, example['model_tp_mu'])
        tp_sigma = torch.clip(
            linear_correction(tp_sigma_intercept, tp_sigma_weights, example['model_tp_sigma']),
            min=1e-6,
        )

        t2m_mu = linear_correction(t2m_mu_intercept, t2m_mu_weights, example['model_t2m_mu'])
        t2m_sigma = torch.clip(
            linear_correction(t2m_sigma_intercept, t2m_sigma_weights, example['model_t2m_sigma']),
            min=1e-6,
        )

        tp_dist = torch.distributions.Normal(loc=tp_mu, scale=tp_sigma)
        t2m_dist = torch.distributions.Normal(loc=t2m_mu, scale=t2m_sigma)

        # Each variable is scored against its own observations. Previously both were scored on
        # obs_t2m, and the t2m NaN mask came from the already-zeroed tp tensor.
        rain_loss = masked_nll(tp_dist, example['obs_tp'])
        temperature_loss = masked_nll(t2m_dist, example['obs_t2m'])
        loss = rain_loss + temperature_loss

        print(f'T2M: {temperature_loss}, TP: {rain_loss}, TOTAL: {loss}')

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Largest learned t2m sigma intercept over all parameter cells.
torch.max(t2m_sigma_intercept)

In [None]:
# Slice of the learned tp sigma intercepts at index 1 of the first axis.
tp_sigma_intercept.select(0, 1)

In [None]:
# Histogram of the learned t2m mean weights at index 2 of the first axis, with labelled axes
# so the figure stands on its own; the trailing `;` suppresses the Axes repr.
fig, ax = plt.subplots()
sns.histplot(data=t2m_mu_weights[2].detach().numpy().flatten(), bins=30, ax=ax)
ax.set(title='Distribution of t2m_mu_weights[2]', xlabel='weight value', ylabel='count');

In [None]:
# Average learned t2m mean weight at index 2 of the first axis.
torch.mean(t2m_mu_weights[2])

In [None]:
# NOTE: optimizer.zero_grad() runs at the end of every training iteration, so after the loop this
# is all zeros, or None when zero_grad uses set_to_none=True (the default in recent PyTorch).
t2m_mu_weights.grad

In [None]:
# `obs_tp` is never defined in this notebook, so this cell raised NameError on a fresh kernel.
# Presumably the intent was the precipitation observations from the last training batch.
example['obs_tp']