In [None]:
%load_ext autoreload
%autoreload 2

# Linear Model

Train a linear model for a single day of the year, using one of the ML-ready datasets we generated.

In [None]:
import dask.distributed
import logging
import matplotlib.pyplot as plt
import pathlib
import torch
import torch.nn as nn
import xarray as xr

In [None]:
from crims2s.util import add_biweekly_dim

In [None]:
ML_DATASET = '***BASEDIR***/mlready/2021-08-07-test'

In [None]:
ml_dataset_path = pathlib.Path(ML_DATASET)
# iterdir() yields entries in arbitrary order; sort so the listing is the same on every run.
ml_files = sorted(x for x in ml_dataset_path.iterdir() if x.name.endswith('0102.nc'))

In [None]:
ml_files

In [None]:
def apply_to_all(transform, example):
    """Return a new dict holding ``transform(value)`` for every key of ``example``.

    Args:
        transform: callable applied to each value.
        example: mapping to read from; it is not modified.

    Returns:
        A new dict with the same keys, in the same order, and transformed values.
    """
    return {key: transform(example[key]) for key in example}

In [None]:
def add_biweekly_dim_transform(example):
    """Run ``add_biweekly_dim`` on every entry of ``example`` and return the resulting dict."""
    return apply_to_all(transform=add_biweekly_dim, example=example)

In [None]:
def rename_to_lead_time(dataset):
    """Rename the ``biweekly_forecast`` dimension to ``lead_time``.

    The values of ``dataset.biweekly_forecast`` are attached as the coordinate of
    the renamed dimension.
    """
    renamed = dataset.rename_dims({'biweekly_forecast': 'lead_time'})
    return renamed.assign_coords(lead_time=dataset.biweekly_forecast.data)

In [None]:
def aggregate_weekly(example):
    """Reduce the ``lead_time`` dimension of an example.

    ``obs`` goes through ``aggregate_obs_weekly``, ``model`` through
    ``aggregate_model_weekly``, and ``features`` is averaged over ``lead_time``.
    Every other key of ``example`` is carried over unchanged.

    The input dict is not modified: a shallow copy is returned, so the caller
    can keep inspecting the un-aggregated example after this call.

    NOTE(review): later cells index ``lead_time`` on the aggregated output even
    though that dimension is reduced here. Confirm whether
    ``rename_to_lead_time`` is meant to be applied to the results.

    Args:
        example: dict with at least the keys ``'obs'``, ``'model'`` and
            ``'features'``.

    Returns:
        A new dict with the same keys as ``example``, where ``'obs'``,
        ``'model'`` and ``'features'`` hold the aggregated data.
    """
    # Copy first so that the caller's dict keeps its original entries.
    aggregated = dict(example)
    aggregated['obs'] = aggregate_obs_weekly(example['obs'])
    aggregated['model'] = aggregate_model_weekly(example['model'])
    aggregated['features'] = example['features'].mean(dim='lead_time')
    return aggregated

In [None]:
def aggregate_model_weekly(model):
    """Reduce the model forecast over ``lead_time``.

    ``tp`` keeps only its last ``lead_time`` entry, while ``t2m`` is averaged
    over ``lead_time`` with NaNs ignored. The two results are merged into one
    dataset.
    """
    # NOTE(review): taking the last step presumably relies on tp being
    # accumulated along lead_time -- confirm against the data source.
    return xr.merge([
        model.tp.isel(lead_time=-1),
        model.t2m.mean(dim='lead_time', skipna=True),
    ])

In [None]:
def aggregate_obs_weekly(obs):
    """Reduce the observations over ``lead_time``.

    ``pr`` is summed over ``lead_time`` (with ``skipna=False``) and renamed to
    ``tp``; ``t2m`` is averaged over ``lead_time``. The two results are merged
    into one dataset.
    """
    total_tp = obs.pr.sum(dim='lead_time', skipna=False)
    mean_t2m = obs.t2m.mean(dim='lead_time')
    return xr.merge([total_tp.rename('tp'), mean_t2m])

In [None]:
def s2s_to_pytorch(example):
    """Convert the variables of an example into torch tensors.

    Reads ``t2m`` and ``tp`` from ``example['obs']`` and ``example['model']``,
    and ``x`` from ``example['features']``. Each variable's ``.data`` array is
    wrapped with ``torch.from_numpy``.

    Returns:
        dict with the keys ``obs_t2m``, ``obs_tp``, ``features``, ``model_t2m``
        and ``model_tp``.
    """
    obs, model, features = (example[key] for key in ('obs', 'model', 'features'))

    variables = {
        'obs_t2m': obs.t2m,
        'obs_tp': obs.tp,
        'features': features.x,
        'model_t2m': model.t2m,
        'model_tp': model.tp,
    }
    return {name: torch.from_numpy(var.data) for name, var in variables.items()}

In [None]:
class CompositeTransform:
    """Chain several transforms into a single callable.

    Calling the instance feeds the example to the first transform, the result
    to the second, and so on, and returns the output of the last one. With no
    transforms the example is returned as is.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, example):
        result = example
        for step in self.transforms:
            result = step(result)
        return result

In [None]:
class TransformedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that applies ``transform`` to each example it returns.

    The length is that of the wrapped dataset; the transform is applied every
    time an item is fetched.
    """

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        raw_example = self.dataset[idx]
        return self.transform(raw_example)

In [None]:
class S2SDataset(torch.utils.data.Dataset):
    """Dataset with one example per file of a directory.

    Args:
        dataset_dir: directory holding the example files.
        filter_str: when given, only entries whose name contains this
            substring are kept; when ``None`` every entry is kept.

    Each example is a dict with the keys ``features``, ``obs``, ``model`` and
    ``target``, read from the groups ``/x``, ``/obs``, ``/model`` and ``/y`` of
    the file.
    """

    def __init__(self, dataset_dir, filter_str=None):
        dataset_path = pathlib.Path(dataset_dir)
        # iterdir() yields entries in arbitrary order. Sort them so that an
        # index always maps to the same file, whatever the filesystem.
        self.files = sorted(
            x for x in dataset_path.iterdir()
            if filter_str is None or filter_str in x.name
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        f = self.files[idx]
        # The four datasets are returned without being closed here; the caller
        # owns them.
        features = xr.open_dataset(f, group='/x')
        obs = xr.open_dataset(f, group='/obs')
        model = xr.open_dataset(f, group='/model')
        target = xr.open_dataset(f, group='/y')

        return {'features': features, 'obs': obs, 'model': model, 'target': target}

In [None]:
dataset = S2SDataset(ML_DATASET, filter_str='0102.nc')

In [None]:
dataset_model = dataset[0]['model']

In [None]:
dataset_model

In [None]:
transformed = add_biweekly_dim_transform(dataset[0])

In [None]:
transformed['model']

In [None]:
biweekly_model = transformed['model']

In [None]:
biweekly_model.isel(realization=0, biweekly_forecast=2, lead_time=-1).tp.plot()

In [None]:
transformed = aggregate_weekly(add_biweekly_dim_transform(dataset[0]))

In [None]:
transformed['model'].isel(lead_time=2, realization=10).tp.plot()

In [None]:
transformed['obs'].tp.isel(lead_time=2).plot()

In [None]:
s2s_to_pytorch(transformed)

In [None]:
#transform = CompositeTransform([add_biweekly_dim_transform, aggregate_weekly, s2s_to_pytorch])
#dataset = TransformedDataset(S2SDataset(ML_DATASET, filter_str='0102.nc'), transform)

In [None]:
dataset[0]['obs_tp'].shape

In [None]:
example['features'].shape

In [None]:
example['obs_t2m'].shape

In [None]:
class LinearModel(torch.nn.Module):
    """Linear correction of the model t2m forecast, with one weight per feature element.

    Two sets of parameters are learned, one for the mean (``mu``) and one for
    the spread (``sigma``). Each set has a weight tensor shaped like the
    features and a bias shaped like the first three feature dimensions.

    Args:
        shape: 5-tuple giving the shape of ``example['features']``. The default
            is ``(2, 121, 240, 11, 13)``.
            NOTE(review): the einsum subscripts ``wijmv`` suggest
            (lead time, latitude, longitude, member, variable) -- confirm
            against the dataset.
    """

    def __init__(self, shape=(2, 121, 240, 11, 13)):
        super().__init__()

        shape = tuple(shape)
        if len(shape) != 5:
            raise ValueError(f'shape must have 5 dimensions, got {len(shape)}')
        output_shape = shape[:3]

        # The parameters must be attributes of the module: nn.Parameter objects
        # bound to local variables are never registered, so an optimizer would
        # not see them.
        self.t2m_mu_weights = nn.Parameter(torch.rand(*shape))
        self.t2m_mu_bias = nn.Parameter(torch.rand(*output_shape))

        self.t2m_sigma_weights = nn.Parameter(torch.rand(*shape))
        self.t2m_sigma_bias = nn.Parameter(torch.rand(*output_shape))

    def forward(self, example):
        """Compute the corrected t2m mean and the linear sigma term.

        Args:
            example: dict with ``'features'`` (same shape as the weights) and
                ``'model_t2m'`` (broadcastable against the bias shape).

        Returns:
            Tuple ``(t2m_mu, t2m_sigma)``. ``t2m_mu`` is the model t2m plus the
            weighted sum of the features plus the bias. ``t2m_sigma`` is the
            raw linear output of the sigma parameters.
            NOTE(review): ``t2m_sigma`` is not constrained to be positive --
            decide how it should be mapped to a standard deviation.
        """
        features = example['features']

        # Contract the last two feature dimensions against the weights.
        mu_correction = torch.einsum('wijmv,wijmv->wij', features, self.t2m_mu_weights)
        t2m_mu = example['model_t2m'] + mu_correction + self.t2m_mu_bias

        t2m_sigma = (
            torch.einsum('wijmv,wijmv->wij', features, self.t2m_sigma_weights)
            + self.t2m_sigma_bias
        )

        return t2m_mu, t2m_sigma

In [None]:
t2m_weights = torch.rand(2, 121, 240, 11, 13, requires_grad=True)

In [None]:
# Weighted sum of the features over their last two dimensions, using the t2m_weights tensor defined above.
t2m_correction = torch.einsum('wijmv,wijmv->wij', example['features'], t2m_weights)

In [None]:
t2m_bias = torch.rand(2, 121, 240, requires_grad=True)

In [None]:
example['model_t2m'].shape

In [None]:
pred = example['model_t2m'] + t2m_correction + t2m_bias 

In [None]:
pred

In [None]:
plt.imshow(pred.detach().numpy()[1])

In [None]:
plt.imshow(example['obs_t2m'][0])

In [None]:
opened = dataset[0]

In [None]:
biweekly = add_biweekly_dim_transform(opened)

In [None]:
aggregated = aggregate_weekly(biweekly)

In [None]:
opened['obs'].isel(lead_time=10).pr.plot()

In [None]:
opened['obs'].isel(lead_time=10).t2m.plot()

In [None]:
biweekly['obs']

In [None]:
biweekly['obs'].isel(biweekly_forecast=0, lead_time=3).pr.plot()

In [None]:
biweekly['obs']

In [None]:
biweekly['obs'].isel(biweekly_forecast=0, lead_time=0).t2m.plot()

In [None]:
aggregated['obs'].isel(lead_time=2).tp.plot()

In [None]:
aggregated['obs'].isel(lead_time=2).t2m.plot()

In [None]:
aggregated['obs']

In [None]:
aggregated['features']

In [None]:
aggregated['model']

In [None]:
aggregated.keys()

## Check if t2m fits with date in raw obs

In [None]:
opened_obs = opened['obs']

In [None]:
raw_obs = xr.open_dataset('***BASEDIR***/obs-arlan-processed-2021-08-07/t2m.nc')

In [None]:
opened_obs

In [None]:
# Index the observations by valid_time instead of lead_time so that a calendar date can be selected.
slice_opened_obs = opened_obs.swap_dims(lead_time='valid_time').sel(valid_time='2000-02-03')

In [None]:
slice_raw_obs = raw_obs.sel(time='2000-02-03')

In [None]:
(slice_opened_obs.t2m - slice_raw_obs.t2m).plot()