In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Inference

Rework the model multiplexer notebook so that we include an inference part that produces prediction files.

In [None]:
import collections.abc
import matplotlib.pyplot as plt
import numpy as np
import re
import seaborn as sns
import torch
import torch.distributions
import torch.nn as nn
import torch.utils.data.dataloader
from typing import Union, Callable, Any, Hashable
import xarray as xr

from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.transform import CompositeTransform, add_biweekly_dim_transform, add_metadata, example_to_pytorch
from crims2s.util import ECMWF_FORECASTS

In [None]:
# Root directory of the dataset; passed to S2SDataset when building the datasets below.
DATASET = '***BASEDIR***/mlready/2021-08-08-test/'

## Make a transform to interface the dataset with the linear model

In [None]:
def std_estimator(dataset, dim=None):
    """Sample standard deviation of an xarray object, using the n - 1 divisor.

    Args:
        dataset: xarray DataArray or Dataset to reduce.
        dim: Name of the dimension, or iterable of names, to reduce over.
            When None, every dimension of `dataset` is reduced.

    Returns:
        An object of the same kind as `dataset`, without the reduced
        dimensions, holding sqrt(sum((x - mean) ** 2) / (n - 1)), where n is
        the product of the sizes of the reduced dimensions.
    """
    if dim is None:
        # The sizes must be read from the input itself: once the mean has been
        # taken over every dimension there are no dimensions left to count.
        reduced_sizes = list(dataset.sizes.values())
    elif isinstance(dim, str):
        reduced_sizes = [dataset.sizes[dim]]
    else:
        reduced_sizes = [dataset.sizes[name] for name in dim]

    n = np.prod(reduced_sizes)

    squared_deviations = (dataset - dataset.mean(dim=dim)) ** 2

    # NumPy ufuncs dispatch on xarray objects, so the xr.ufuncs namespace
    # (not available in every xarray release) is not needed here.
    return np.sqrt(squared_deviations.sum(dim=dim) / (n - 1))

In [None]:
def model_to_distribution(model):
    """Summarize an ensemble forecast as a mean and a spread per variable.

    Args:
        model: xarray Dataset with `tp` and `t2m` variables that have
            `lead_time` and `realization` dimensions, plus a
            `biweekly_forecast` dimension or coordinate.

    Returns:
        xarray Dataset with variables `tp_mu`, `tp_sigma`, `t2m_mu` and
        `t2m_sigma`, in which `biweekly_forecast` has been renamed to
        `lead_time`.
    """
    # Only the last lead time is used for precipitation.
    # NOTE(review): presumably because tp is accumulated over the period -- confirm.
    final_tp = model.tp.isel(lead_time=-1)
    tp_mu = final_tp.mean(dim='realization').rename('tp_mu')
    tp_sigma = std_estimator(final_tp, dim='realization').rename('tp_sigma')

    # Temperature statistics pool every lead time and every ensemble member.
    t2m_dims = ['lead_time', 'realization']
    t2m_mu = model.t2m.mean(dim=t2m_dims).rename('t2m_mu')
    t2m_sigma = std_estimator(model.t2m, dim=t2m_dims).rename('t2m_sigma')

    merged = xr.merge([tp_mu, tp_sigma, t2m_mu, t2m_sigma])

    # isel() leaves a scalar `lead_time` coordinate behind; it has to go before
    # `biweekly_forecast` can take over that name. drop_vars replaces the
    # deprecated Dataset.drop.
    return merged.drop_vars('lead_time').rename(biweekly_forecast='lead_time')

In [None]:
def obs_to_biweekly(obs):
    """Aggregate observations over the `lead_time` dimension.

    Precipitation (`pr`) is summed, requiring at least two valid values, and
    renamed to `tp`; temperature (`t2m`) is averaged.

    Returns:
        xarray Dataset merging the aggregated `tp` and `t2m`.
    """
    biweekly_fields = [
        obs.pr.sum(dim='lead_time', min_count=2).rename('tp'),
        obs.t2m.mean(dim='lead_time'),
    ]
    return xr.merge(biweekly_fields)

In [None]:
def linear_model_adapter(example):
    """Transform an example into the inputs expected by the linear model.

    The `model` entry is replaced by its distribution summary and the `obs`
    entry by its biweekly aggregate. The example is modified in place and
    returned.
    """
    adapters = (
        ('model', model_to_distribution),
        ('obs', obs_to_biweekly),
    )
    for field, adapt in adapters:
        example[field] = adapt(example[field])

    return example

In [None]:
# Keep only files whose name ends in "01DD.nc" (month-day 01DD, following the
# MMDD naming used for monthdays). The dot is escaped so it matches a literal
# "." rather than any character, and a raw string keeps the backslash intact.
filter_re = re.compile(r'01[0-9]{2}\.nc$')
# Single-date alternative for quick experiments: re.compile(r'0109\.nc$')

In [None]:
# Train on the years 2000-2016 and validate on 2017-2019, both restricted to
# the files accepted by filter_re.
raw_train_dataset = S2SDataset(
    DATASET,
    years=list(range(2000, 2017)),
    name_filter=lambda x: filter_re.search(x),
    include_features=False,
)
raw_val_dataset = S2SDataset(
    DATASET,
    years=list(range(2017, 2020)),
    name_filter=lambda x: filter_re.search(x),
    include_features=False,
)

In [None]:
# Transform steps, in the order they are listed; the last one converts the
# example to PyTorch.
transform = CompositeTransform([
    add_biweekly_dim_transform,
    linear_model_adapter,
    add_metadata,
    example_to_pytorch,
])

train_dataset = TransformedDataset(raw_train_dataset, transform)
val_dataset = TransformedDataset(raw_val_dataset, transform)

In [None]:
# batch_size=None disables automatic batching: the loaders yield one example
# at a time, as the ModelMultiplexer defined below recommends.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=None,
    batch_sampler=None,
    num_workers=4,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=None,
    batch_sampler=None,
    num_workers=4,
)

In [None]:
class ModelMultiplexer(nn.Module):
    """Route each example to one of several sub-models.

    For instance, this can be used to apply a different model to every
    monthday forecast.

    Because an arbitrary model is picked for every example, this module does
    not support batching. It is recommended to disable automatic batching on
    the dataloader when using it.
    """

    def __init__(self, key, models):
        """Args:
            key: If a str, the key under which the model name is stored in the
                 example. If a callable, it is called on the example and should
                 return the name of the model to use.
            models: A mapping from model names to model instances. The keys
                 should correspond to what `key` yields for the examples.
        """
        super().__init__()

        self.key_fn = (lambda x: x[key]) if isinstance(key, str) else key
        self.models = nn.ModuleDict(models)

    def forward(self, example):
        """Apply the sub-model selected by the key function to `example`."""
        selected = self.models[self.key_fn(example)]
        return selected(example)
            

In [None]:
class LinearModel(nn.Module):
    """Element-wise linear correction applied on top of the identity.

    The output is `intercept + weights * x + x`, so with the default fill
    values (both 0.0) the module returns its input unchanged.
    """

    def __init__(self, *shape, fill_weights=0.0, fill_intercept=0.0):
        """Args:
            *shape: Shape of the weight and intercept tensors.
            fill_weights: Initial value of every weight.
            fill_intercept: Initial value of every intercept.
        """
        super().__init__()

        # Coerce to float: with an int fill value torch.full returns an integer
        # tensor, which nn.Parameter rejects because it cannot require gradients.
        self.weights = nn.Parameter(torch.full(shape, float(fill_weights)))
        self.intercept = nn.Parameter(torch.full(shape, float(fill_intercept)))

    def forward(self, x):
        """Return `intercept + weights * x + x`."""
        return self.intercept + self.weights * x + x

In [None]:
class TempPrecipEMOS(nn.Module):
    """Per-grid-point linear corrections of forecast mean and spread.

    Four LinearModel instances correct the mean and the standard deviation of
    precipitation (`tp`) and temperature (`t2m`). The corrected values
    parameterize one Normal distribution per variable.
    """

    def __init__(self, biweekly=False):
        """Args:
            biweekly: When True the parameters have shape (3, 121, 240),
                otherwise (121, 240).
        """
        super().__init__()

        grid = (121, 240)
        shape = (3,) + grid if biweekly else grid

        self.tp_mu_model = LinearModel(*shape)
        self.tp_sigma_model = LinearModel(*shape, fill_intercept=1.0)

        self.t2m_mu_model = LinearModel(*shape)
        self.t2m_sigma_model = LinearModel(*shape, fill_intercept=1.0)

    @staticmethod
    def _corrected_normal(mu_model, sigma_model, raw_mu, raw_sigma):
        """Build a Normal from corrected parameters, keeping the scale at or above 1e-6."""
        loc = mu_model(raw_mu)
        scale = torch.clip(sigma_model(raw_sigma), min=1e-6)
        return torch.distributions.Normal(loc=loc, scale=scale)

    def forward(self, example):
        """Return the corrected distributions as the pair `(t2m_dist, tp_dist)`."""
        raw_tp = (example['model_tp_mu'], example['model_tp_sigma'])
        raw_t2m = (example['model_t2m_mu'], example['model_t2m_sigma'])

        tp_dist = self._corrected_normal(self.tp_mu_model, self.tp_sigma_model, *raw_tp)
        t2m_dist = self._corrected_normal(self.t2m_mu_model, self.t2m_sigma_model, *raw_t2m)

        return t2m_dist, tp_dist

In [None]:
# One independent TempPrecipEMOS per forecast date, keyed by its zero-padded
# "MMDD" string.
monthdays = [f'{m:02}{d:02}' for m, d in ECMWF_FORECASTS]
weekly_models = dict(
    (monthday, TempPrecipEMOS(biweekly=True)) for monthday in monthdays
)

In [None]:
# Dispatch each example to the sub-model named by its 'monthday' entry.
model = ModelMultiplexer(key='monthday', models=weekly_models)

In [None]:
# A single Adam optimizer covers the parameters of every sub-model.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def _masked_nll(distribution, obs):
    """Mean negative log-likelihood of `obs`, with NaN observations counting as zero.

    NaN observations are replaced by 0.0 before the density is evaluated, so
    no NaN reaches the log-likelihood or its gradient. Their log-likelihood is
    then set to zero. The mean is taken over every element, masked ones
    included. `obs` itself is left untouched.
    """
    nan_mask = obs.isnan()
    log_likelihood = distribution.log_prob(obs.masked_fill(nan_mask, 0.0))
    return -log_likelihood.masked_fill(nan_mask, 0.0).mean()


def _example_losses(model, example):
    """Return `(total, temperature, rain)` losses of `model` on one example."""
    # Call the module rather than .forward() so registered hooks run.
    t2m_dist, tp_dist = model(example)

    rain_loss = _masked_nll(tp_dist, example['obs_tp'])
    temperature_loss = _masked_nll(t2m_dist, example['obs_t2m'])

    return rain_loss + temperature_loss, temperature_loss, rain_loss


for epoch in range(1):
    train_losses = []
    train_temperature_losses = []
    train_rain_losses = []

    model.train()
    for example in train_dataloader:
        loss, temperature_loss, rain_loss = _example_losses(model, example)

        # Clear gradients first so nothing stale (e.g. from an interrupted
        # run of this cell) leaks into the step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        train_temperature_losses.append(temperature_loss.item())
        train_rain_losses.append(rain_loss.item())

    train_mean_loss = np.array(train_losses).mean()
    train_mean_rain_loss = np.array(train_rain_losses).mean()
    train_mean_temperature_loss = np.array(train_temperature_losses).mean()
    print(f'Epoch {epoch} train loss: {train_mean_loss}. Temperature: {train_mean_temperature_loss}. Rain: {train_mean_rain_loss}.')

    val_losses = []
    val_rain_losses = []
    val_t2m_losses = []

    model.eval()
    with torch.no_grad():
        for example in val_dataloader:
            val_loss, val_temperature_loss, val_rain_loss = _example_losses(model, example)

            # Store plain floats, as the training loop does, rather than tensors.
            val_losses.append(val_loss.item())
            val_t2m_losses.append(val_temperature_loss.item())
            val_rain_losses.append(val_rain_loss.item())

    val_mean_loss = np.array(val_losses).mean()
    val_mean_rain_loss = np.array(val_rain_losses).mean()
    val_mean_temperature_loss = np.array(val_t2m_losses).mean()
    print(f'Epoch {epoch} val loss: {val_mean_loss}. Temperature: {val_mean_temperature_loss}. Rain: {val_mean_rain_loss}.')
    print()

## Inference

In [None]:
# Keep only files whose name ends in "01DD.nc" (month-day 01DD, following the
# MMDD naming used for monthdays). The dot is escaped so it matches a literal
# "." rather than any character, and a raw string keeps the backslash intact.
filter_re = re.compile(r'01[0-9]{2}\.nc$')
# Single-date alternative for quick experiments: re.compile(r'0109\.nc$')

In [None]:
# Rebuild the 2017-2019 validation dataset for the inference pass.
raw_val_dataset = S2SDataset(
    DATASET,
    years=list(range(2017, 2020)),
    name_filter=lambda x: filter_re.search(x),
    include_features=False,
)

In [None]:
# Number of examples in the validation dataset.
len(raw_val_dataset)

In [None]:
# Same steps as the training transform but without example_to_pytorch: the
# inference loop converts to PyTorch itself and also reads the unconverted
# `model` entry of each example.
t = CompositeTransform([
    add_biweekly_dim_transform,
    linear_model_adapter,
    add_metadata,
])
val_dataset = TransformedDataset(raw_val_dataset, t)

Here we have to specify `collate_fn=lambda x: x` because, if we leave it set to `None`, PyTorch will use its own default `collate_fn`, which mangles the datasets and turns them into dictionaries.

In [None]:
# The identity collate_fn hands each example through unchanged.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=None,
    batch_sampler=None,
    collate_fn=lambda x: x,
    num_workers=4,
)

In [None]:
def edges_cdf_to_terciles(edges_cdf):
    """Convert CDF values at the two category edges into three probabilities.

    Args:
        edges_cdf: Tensor whose first dimension indexes the edges; entry 0 is
            the CDF at the lower edge and entry 1 the CDF at the upper edge.

    Returns:
        Tensor stacking, along a new first dimension, the probability of
        being below the lower edge, between the two edges, and above the
        upper edge.
    """
    cdf_lower = edges_cdf[0]
    cdf_upper = edges_cdf[1]

    below = cdf_lower
    near = cdf_upper - cdf_lower
    above = 1.0 - cdf_upper

    return torch.stack((below, near, above), dim=0)

In [None]:
def compute_edges_cdf_from_distribution(distribution, edges):
    """Evaluate `distribution.cdf` at `edges`, propagating NaN edges as NaN.

    Args:
        distribution: Object exposing a `cdf(value)` method, such as a
            torch.distributions.Normal.
        edges: Tensor of values at which to evaluate the CDF. It may contain
            NaN and is not modified.

    Returns:
        Tensor of CDF values, NaN wherever the corresponding edge is NaN.
    """
    nan_mask = edges.isnan()

    # The CDF is evaluated on a NaN-free copy (NaN replaced by 0.0) so that no
    # NaN is passed to the distribution; those positions are overwritten below.
    cdf = distribution.cdf(edges.masked_fill(nan_mask, 0.0))

    return cdf.masked_fill(nan_mask, float('nan'))

In [None]:
def terciles_pytorch_to_xarray(t2m, tp, example_dataset, dims=('category', 'lead_time', 'latitude', 'longitude')):
    """Wrap tercile probability tensors in an xarray Dataset.

    Args:
        t2m: Tensor of temperature tercile probabilities, laid out as `dims`.
        tp: Tensor of precipitation tercile probabilities, laid out as `dims`.
        example_dataset: xarray object from which the coordinates are copied
            (`forecast_year`, `forecast_monthday`, `lead_time`, `valid_time`,
            `forecast_time`, `latitude` and `longitude`).
        dims: Dimension names of `t2m` and `tp`, in order.

    Returns:
        xarray Dataset with variables `t2m` and `tp`, the coordinates listed
        above plus `category`, and with `forecast_year` and
        `forecast_monthday` expanded into dimensions.
    """
    dims = list(dims)

    def to_data_array(tensor, name):
        # .cpu() is a no-op for CPU tensors and makes GPU tensors convertible.
        return xr.DataArray(data=tensor.detach().cpu().numpy(), dims=dims, name=name)

    dataset = xr.Dataset(data_vars={
        't2m': to_data_array(t2m, 't2m'),
        'tp': to_data_array(tp, 'tp'),
    })

    # Coordinates come from the `example_dataset` argument, not from a
    # notebook-level global.
    dataset = dataset.assign_coords({
        'forecast_year': example_dataset.forecast_year.data,
        'forecast_monthday': example_dataset.forecast_monthday.data,
        'lead_time': example_dataset.lead_time.data,
        # Kept as an xarray object (no .data) so it retains its own dimensions.
        'valid_time': example_dataset.valid_time,
        'forecast_time': example_dataset.forecast_time.data,
        'latitude': example_dataset.latitude.data,
        'longitude': example_dataset.longitude.data,
        'category': ['below normal', 'near normal', 'above normal'],
    }).expand_dims(['forecast_year', 'forecast_monthday'])

    return dataset

In [None]:
def _prepend_nan_lead_time(edges):
    """Prepend one all-NaN slice along dimension 1 of `edges`.

    The slice takes its shape, dtype and device from `edges`, so no grid size
    is hard-coded.
    """
    # NOTE(review): presumably the first lead time has no tercile edges and is
    # therefore padded with NaN -- confirm against the dataset.
    nan_slice = torch.full_like(edges[:, :1], np.nan)
    return torch.cat([nan_slice, edges], dim=1)


# Make the evaluation mode explicit rather than relying on the state left
# behind by the training cell.
model.eval()

with torch.no_grad():
    datasets_of_examples = []

    for example in val_dataloader:
        pytorch_example = example_to_pytorch(example)
        t2m_dist, tp_dist = model(pytorch_example)

        t2m_edges = _prepend_nan_lead_time(pytorch_example['edges_t2m'])
        t2m_cdf = compute_edges_cdf_from_distribution(t2m_dist, t2m_edges)

        tp_edges = _prepend_nan_lead_time(pytorch_example['edges_tp'])
        tp_cdf = compute_edges_cdf_from_distribution(tp_dist, tp_edges)

        t2m_terciles = edges_cdf_to_terciles(t2m_cdf)
        tp_terciles = edges_cdf_to_terciles(tp_cdf)

        # The forecast dataset supplies the coordinates of the prediction.
        example_forecast = example['model']

        dataset = terciles_pytorch_to_xarray(t2m_terciles, tp_terciles, example_forecast)
        datasets_of_examples.append(dataset)

In [None]:
# Assemble the per-example datasets into a single one based on their coordinates.
ml_prediction = xr.combine_by_coords(datasets_of_examples)

In [None]:
# Display the combined prediction dataset.
ml_prediction

In [None]:
# `dataset` still holds the prediction of the last example of the inference loop.
dataset

In [None]:
# Plot t2m for category index 2 ('above normal') at lead-time index 1,
# for the last example processed.
dataset.t2m.isel(category=2, lead_time=1).plot()

In [None]:
# Wrap the last example's temperature terciles in a DataArray for inspection.
t2m_array = xr.DataArray(
    data=t2m_terciles.detach().numpy(),
    dims=['tercile', 'lead_time', 'latitude', 'longitude'],
)


In [None]:
# Display the DataArray built from the temperature terciles.
t2m_array

In [None]:
# Shape of the tercile tensor of the last example.
t2m_terciles.shape

In [None]:
# Shape of the temperature edges before the NaN slice is prepended.
pytorch_example['edges_t2m'].shape

In [None]:
# Sanity check: map of where the CDF at the first edge is strictly below the
# CDF at the second edge, at index 1 of the second dimension.
cax = plt.imshow(t2m_cdf[0,1].detach().numpy() < t2m_cdf[1,1].detach().numpy())
plt.colorbar(cax)

In [None]:
# Map of the temperature CDF at the first edge, at index 2 of the second dimension.
cax = plt.imshow(t2m_cdf[0,2].detach().numpy())
plt.colorbar(cax)

In [None]:
# Map of the third tercile (1 - CDF at the second edge), at index 2 of the
# second dimension.
cax = plt.imshow(t2m_terciles[2,2].detach().numpy())
plt.colorbar(cax)