In [None]:
%load_ext autoreload
%autoreload 2

# Linear Model

Train a linear model for a single day of the year using one of the ML-ready datasets we made.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.distributions
import torch.nn as nn
import xarray as xr

from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.transform import CompositeTransform, add_biweekly_dim_transform

In [None]:
# Directory of the ML-ready dataset; passed as the first argument to S2SDataset below.
DATASET = '***BASEDIR***/mlready/2021-08-08-test/'

## Make a transform to interface the dataset with the linear model

In [None]:
def std_estimator(dataset, dim=None):
    """Sample standard deviation (``n - 1`` denominator) of an xarray object.

    Parameters
    ----------
    dataset : xarray.DataArray or xarray.Dataset
        Values to reduce.
    dim : str, iterable of str, or None
        Name(s) of the dimension(s) to reduce over. ``None`` reduces over
        every dimension of ``dataset``.

    Returns
    -------
    Object of the same xarray type as ``dataset`` with ``dim`` reduced away.

    Notes
    -----
    ``n`` is the product of the sizes of the reduced dimensions; missing
    values are not discounted from it. When ``n == 1`` the denominator is
    zero and the result is not finite.
    """
    dataset_mean = dataset.mean(dim=dim)

    if dim is None:
        # Sizes must come from the input: after a full reduction the mean has
        # no dimensions left, which would give n == 1 and a zero denominator.
        dim_sizes = list(dataset.sizes.values())
    elif isinstance(dim, str):
        dim_sizes = [dataset.sizes[dim]]
    else:
        dim_sizes = [dataset.sizes[name] for name in dim]

    n = np.prod(dim_sizes)

    # numpy ufuncs dispatch on xarray objects, so the deprecated `xr.ufuncs`
    # namespace is not needed.
    squared_deviation = np.square(dataset - dataset_mean)
    return np.sqrt(squared_deviation.sum(dim=dim) / (n - 1))

In [None]:
def model_to_distribution(model):
    """Summarise an ensemble forecast as per-variable mean and standard deviation.

    Parameters
    ----------
    model : xarray.Dataset
        Forecast with ``tp`` and ``t2m`` variables that have ``lead_time`` and
        ``realization`` dimensions.

    Returns
    -------
    xarray.Dataset
        Variables ``tp_mean``, ``tp_std``, ``t2m_mean`` and ``t2m_std``. The
        leftover ``lead_time`` coordinate is removed and ``biweekly_forecast``
        is renamed to ``lead_time``.
    """
    # Precipitation statistics use only the last lead time, taken across members.
    # NOTE(review): presumably tp is accumulated over lead time -- confirm.
    last_tp = model.tp.isel(lead_time=-1)
    model_tp_mean = last_tp.mean(dim='realization').rename('tp_mean')
    model_tp_std = std_estimator(last_tp, dim='realization').rename('tp_std')

    # Temperature statistics pool every lead time and every member.
    pooled_dims = ['lead_time', 'realization']
    model_t2m_mean = model.t2m.mean(dim=pooled_dims).rename('t2m_mean')
    model_t2m_std = std_estimator(model.t2m, dim=pooled_dims).rename('t2m_std')

    merged = xr.merge([model_tp_mean, model_tp_std, model_t2m_mean, model_t2m_std])

    # `drop_vars` is the non-deprecated spelling of `Dataset.drop` for variables
    # and coordinates; `lead_time` must be gone before the rename below reuses it.
    return merged.drop_vars('lead_time').rename(biweekly_forecast='lead_time')

In [None]:
def obs_to_biweekly(obs):
    """Collapse the ``lead_time`` dimension of the observations.

    ``obs.pr`` is summed over ``lead_time`` (with ``min_count=2``) and renamed
    to ``tp``; ``obs.t2m`` is averaged over ``lead_time``. Both results are
    merged into a single dataset.
    """
    aggregated_fields = [
        obs.pr.sum(dim='lead_time', min_count=2).rename('tp'),
        obs.t2m.mean(dim='lead_time'),
    ]
    return xr.merge(aggregated_fields)

In [None]:
def linear_model_adapter(example):
    """Reduce one example to the inputs and targets the linear model uses.

    ``example['model']`` is summarised with ``model_to_distribution`` and
    ``example['obs']`` is aggregated with ``obs_to_biweekly``. Only the
    ``'model'`` and ``'obs'`` keys are carried over to the returned dict.
    """
    adapted = {}
    adapted['model'] = model_to_distribution(example['model'])
    adapted['obs'] = obs_to_biweekly(example['obs'])
    return adapted

In [None]:
def to_pytorch(example):
    """Flatten an adapted example into a dict of torch tensors.

    Reads the ``.data`` array of each forecast statistic and each observed
    field and wraps it with ``torch.from_numpy`` (no copy is made).
    """
    forecast = example['model']
    observed = example['obs']

    # (output key, source variable) pairs, listed in output order.
    fields = [
        ('model_tp_mu', forecast.tp_mean),
        ('model_tp_sigma', forecast.tp_std),
        ('model_t2m_mu', forecast.t2m_mean),
        ('model_t2m_sigma', forecast.t2m_std),
        ('obs_t2m', observed.t2m),
        ('obs_tp', observed.tp),
    ]

    return {key: torch.from_numpy(variable.data) for key, variable in fields}

In [None]:
# Chain the project's biweekly transform, the xarray adapter and the tensor conversion.
# NOTE(review): presumably CompositeTransform applies them in list order -- confirm in crims2s.transform.
transform = CompositeTransform([add_biweekly_dim_transform, linear_model_adapter, to_pytorch])

In [None]:
# Training examples: years 2000 through 2016 inclusive.
train_dataset = TransformedDataset(
    S2SDataset(
        DATASET,
        filter_str='0312.nc',
        include_features=False,
        years=list(range(2000,2017)),
    ),
    transform,
)

# Validation examples: years 2017 through 2019 inclusive, same file filter.
val_dataset = TransformedDataset(
    S2SDataset(
        DATASET,
        filter_str='0312.nc',
        include_features=False,
        years=list(range(2017,2020)),
    ),
    transform,
)

In [None]:
# One example per batch. No `shuffle` argument is passed, so both loaders use
# DataLoader's default ordering.
train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=1)
val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)

In [None]:
class LinearModel(nn.Module):
    """Element-wise affine correction added on top of the input.

    Holds two learnable tensors of the requested shape, ``weights`` and
    ``intercept``, and computes ``intercept + weights * x + x``. With the
    default fill values of zero the module is the identity.
    """

    def __init__(self, *shape, fill_weights=0.0, fill_intercept=0.0):
        super().__init__()

        initial_weights = torch.full(shape, fill_weights)
        initial_intercept = torch.full(shape, fill_intercept)

        self.weights = nn.Parameter(initial_weights)
        self.intercept = nn.Parameter(initial_intercept)

    def forward(self, x):
        """Return the corrected input, broadcasting the parameters against ``x``."""
        correction = self.intercept + self.weights * x
        return correction + x

In [None]:
class TempPrecipEMOS(nn.Module):
    """Per-grid-cell correction of forecast mean and spread for tp and t2m.

    Four ``LinearModel`` instances correct, respectively, the precipitation
    mean, precipitation spread, temperature mean and temperature spread. The
    spread models start with an intercept of 1.0. Parameters have shape
    ``(3, 121, 240)`` when ``biweekly`` is true and ``(121, 240)`` otherwise.
    """

    def __init__(self, biweekly=False):
        super().__init__()

        if biweekly:
            shape = (3, 121, 240)
        else:
            shape = (121, 240)

        self.tp_mu_model = LinearModel(*shape)
        self.tp_sigma_model = LinearModel(*shape, fill_intercept=1.0)

        self.t2m_mu_model = LinearModel(*shape)
        self.t2m_sigma_model = LinearModel(*shape, fill_intercept=1.0)

    @staticmethod
    def _corrected_normal(mu_model, sigma_model, forecast_mu, forecast_sigma):
        """Build a Normal from corrected mean and spread, flooring the spread at 1e-6."""
        loc = mu_model(forecast_mu)
        # The scale of a Normal must be strictly positive.
        scale = torch.clip(sigma_model(forecast_sigma), min=1e-6)
        return torch.distributions.Normal(loc=loc, scale=scale)

    def forward(self, forecast_t2m_mu, forecast_t2m_sigma, forecast_tp_mu, forecast_tp_sigma):
        """Return ``(t2m_dist, tp_dist)``, the corrected Normal distributions."""
        tp_dist = self._corrected_normal(
            self.tp_mu_model, self.tp_sigma_model, forecast_tp_mu, forecast_tp_sigma
        )
        t2m_dist = self._corrected_normal(
            self.t2m_mu_model, self.t2m_sigma_model, forecast_t2m_mu, forecast_t2m_sigma
        )

        return t2m_dist, tp_dist

In [None]:
# biweekly=True selects the (3, 121, 240) parameter shape in TempPrecipEMOS.
model = TempPrecipEMOS(biweekly=True)

In [None]:
# Adam over every parameter of the EMOS model, learning rate 5e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-3)

In [None]:
def _masked_log_likelihood(dist, obs):
    """Log-likelihood of ``obs`` under ``dist`` with missing observations zeroed.

    Returns ``(log_likelihood, nan_mask)`` where ``nan_mask`` marks the NaN
    entries of ``obs``. ``obs`` itself is left unmodified.
    """
    nan_mask = obs.isnan()
    # Score a placeholder 0.0 where the observation is missing so log_prob stays
    # finite, then zero those entries so they add nothing to the loss or gradient.
    log_likelihood = dist.log_prob(obs.masked_fill(nan_mask, 0.0))
    return log_likelihood.masked_fill(nan_mask, 0.0), nan_mask


def _batch_losses(emos_model, example):
    """Negative log-likelihood losses of one batch.

    Returns ``(loss, rain_loss, temperature_loss, tp_nan_mask, t2m_nan_mask)``.
    Each loss is a mean over every element, so masked cells still count in the
    denominator while contributing 0 to the numerator.
    """
    t2m_dist, tp_dist = emos_model(
        example['model_t2m_mu'],
        example['model_t2m_sigma'],
        example['model_tp_mu'],
        example['model_tp_sigma'],
    )

    tp_log_likelihood, tp_mask = _masked_log_likelihood(tp_dist, example['obs_tp'])
    t2m_log_likelihood, t2m_mask = _masked_log_likelihood(t2m_dist, example['obs_t2m'])

    rain_loss = -tp_log_likelihood.mean()
    temperature_loss = -t2m_log_likelihood.mean()

    return rain_loss + temperature_loss, rain_loss, temperature_loss, tp_mask, t2m_mask


for epoch in range(6):
    train_losses = []
    train_temperature_losses = []
    train_rain_losses = []

    model.train()
    for example in train_dataloader:
        # The masks are kept at top level because later cells index them.
        loss, rain_loss, temperature_loss, tp_nan_mask, t2m_nan_mask = _batch_losses(model, example)

        # Clear gradients first so a previously interrupted run cannot leak into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        train_temperature_losses.append(temperature_loss.item())
        train_rain_losses.append(rain_loss.item())

    train_mean_loss = np.array(train_losses).mean()
    train_mean_rain_loss = np.array(train_rain_losses).mean()
    train_mean_temperature_loss = np.array(train_temperature_losses).mean()
    print(f'Epoch {epoch} train loss: {train_mean_loss}. Temperature: {train_mean_temperature_loss}. Rain: {train_mean_rain_loss}.')

    val_losses = []
    val_rain_losses = []
    val_t2m_losses = []

    model.eval()
    with torch.no_grad():
        for example in val_dataloader:
            # Validation goes through the same helper as training, so the
            # precipitation spread comes from 'model_tp_sigma' here as well.
            val_loss, val_rain_loss, val_temperature_loss, tp_nan_mask, t2m_nan_mask = _batch_losses(model, example)

            # Store plain floats, like the training lists above.
            val_rain_losses.append(val_rain_loss.item())
            val_t2m_losses.append(val_temperature_loss.item())
            val_losses.append(val_loss.item())

    val_mean_loss = np.array(val_losses).mean()
    val_mean_rain_loss = np.array(val_rain_losses).mean()
    val_mean_temperature_loss = np.array(val_t2m_losses).mean()
    print(f'Epoch {epoch} val loss: {val_mean_loss}. Temperature: {val_mean_temperature_loss}. Rain: {val_mean_rain_loss}.')
    print()

In [None]:
# Per-batch validation rain losses from the last epoch (the list is reset at every epoch).
val_rain_losses

In [None]:
# NOTE(review): `epoch_losses` is not assigned in any cell visible here; unless it is
# defined elsewhere this raises NameError on a fresh kernel -- define it or drop this cell.
epoch_losses

In [None]:
# Histogram of the t2m mean-model weights at index 1 of the first parameter axis, keeping
# only cells where `t2m_nan_mask[0, 0]` is False. The mask is whatever the training cell
# assigned last, so that cell must have run first.
# NOTE(review): the other histogram cells index [2]; confirm [1] is intended here.
sns.histplot(data=model.t2m_mu_model.weights[1][~t2m_nan_mask[0,0]].detach().numpy())

In [None]:
# Histogram of the t2m mean-model intercepts at index 2 of the first parameter axis, keeping
# only cells where `t2m_nan_mask[0, 0]` (left over from the training cell) is False.
sns.histplot(data=model.t2m_mu_model.intercept[2][~t2m_nan_mask[0,0]].detach().numpy())

In [None]:
# Histogram of the tp mean-model weights at index 2 of the first parameter axis, keeping
# only cells where `tp_nan_mask[0, 0]` (left over from the training cell) is False.
sns.histplot(data=model.tp_mu_model.weights[2][~tp_nan_mask[0,0]].detach().numpy())

In [None]:
# Histogram of the tp mean-model intercepts at index 2 of the first parameter axis, keeping
# only cells where `tp_nan_mask[0, 0]` (left over from the training cell) is False.
sns.histplot(data=model.tp_mu_model.intercept[2][~tp_nan_mask[0,0]].detach().numpy())

In [None]:
# Image of intercept + weights of the tp mean model at index 2 of the first parameter axis
# (no NaN mask is applied here).
plt.imshow(model.tp_mu_model.intercept[2].detach().numpy() + model.tp_mu_model.weights[2].detach().numpy())