In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and code edits are picked up live.
%load_ext autoreload
%autoreload 2

# Basic transforms

Create basic transforms that will be useful when training models.

In [None]:
import dask.distributed
import logging
import matplotlib.pyplot as plt
import pathlib
import torch
import torch.nn as nn
import xarray as xr

In [None]:
from crims2s.dataset import S2SDataset
from crims2s.transform import add_biweekly_dim_transform, example_to_pytorch
from crims2s.util import add_biweekly_dim


In [None]:
# Root directory of the ML-ready dataset that this notebook lists and opens.
ML_DATASET = '***BASEDIR***/mlready/2021-08-07-test'

In [None]:
ml_dataset_path = pathlib.Path(ML_DATASET)
# iterdir() yields directory entries in arbitrary order; sort the matches so
# the list (and any slice taken from it) is identical on every run.
ml_files = sorted(
    entry for entry in ml_dataset_path.iterdir() if entry.name.endswith('0102.nc')
)

In [None]:
# Peek at the first few matching files.
ml_files[:5]

In [None]:
def rename_to_lead_time(dataset):
    """Turn the ``biweekly_forecast`` dimension of ``dataset`` into ``lead_time``.

    The dimension is renamed, the values of the ``biweekly_forecast``
    coordinate are attached to it under the name ``lead_time``, and the old
    ``biweekly_forecast`` variable is then removed.

    Args:
        dataset: xarray dataset exposing a ``biweekly_forecast`` dimension
            and a ``biweekly_forecast`` coordinate.

    Returns:
        The relabelled dataset, indexed by ``lead_time``.
    """
    lead_time_values = dataset.biweekly_forecast.data
    renamed = dataset.rename_dims({'biweekly_forecast': 'lead_time'})
    relabelled = renamed.assign_coords(lead_time=lead_time_values)
    # drop_vars is the supported way to remove a variable; Dataset.drop is
    # deprecated for this use.
    return relabelled.drop_vars('biweekly_forecast')

In [None]:
def aggregate_weekly(example):
    """Collapse the ``lead_time`` axis of every part of an example.

    Observations and model output are reduced by ``aggregate_obs_weekly`` and
    ``aggregate_model_weekly``; features are averaged over ``lead_time``.
    Each result then has its ``biweekly_forecast`` dimension relabelled as
    ``lead_time`` through ``rename_to_lead_time``.

    Args:
        example: mapping with ``'obs'``, ``'model'`` and ``'features'``
            entries. It is not modified.

    Returns:
        A new ``dict`` with the same keys as ``example``, in which
        ``'obs'``, ``'model'`` and ``'features'`` hold the aggregated data
        and every other entry is carried over unchanged.
    """
    new_obs = aggregate_obs_weekly(example['obs'])
    new_model = aggregate_model_weekly(example['model'])
    new_features = example['features'].mean(dim='lead_time')

    # Fill a shallow copy instead of assigning into `example`: writing into
    # the argument would also overwrite the caller's dict, so the
    # pre-aggregation entries could no longer be inspected or reused.
    aggregated = dict(example)
    aggregated['features'] = rename_to_lead_time(new_features)
    # The merged model data still has a `lead_time` variable (presumably the
    # scalar coordinate left behind by the `isel` on `tp` — confirm); it must
    # be removed before `lead_time` is reused as the new dimension name.
    # drop_vars replaces the deprecated Dataset.drop.
    aggregated['model'] = rename_to_lead_time(new_model.drop_vars('lead_time'))
    aggregated['obs'] = rename_to_lead_time(new_obs)

    return aggregated

In [None]:
def aggregate_model_weekly(model):
    """Reduce the model variables along ``lead_time`` and merge them.

    ``tp`` keeps only its last ``lead_time`` entry, while ``t2m`` is averaged
    over ``lead_time`` with missing values skipped.

    NOTE(review): keeping the last step of ``tp`` presumably relies on the
    model precipitation being accumulated over lead time — confirm.

    Args:
        model: xarray dataset with ``tp`` and ``t2m`` variables that have a
            ``lead_time`` dimension.

    Returns:
        The result of ``xr.merge`` on the two reduced variables.
    """
    reduced = [
        model.tp.isel(lead_time=-1),
        model.t2m.mean(dim='lead_time', skipna=True),
    ]
    return xr.merge(reduced)

In [None]:
def aggregate_obs_weekly(obs):
    """Reduce the observed variables along ``lead_time`` and merge them.

    ``pr`` is summed over ``lead_time`` and renamed to ``tp``; ``t2m`` is
    averaged over ``lead_time``.

    Args:
        obs: xarray dataset with ``pr`` and ``t2m`` variables that have a
            ``lead_time`` dimension.

    Returns:
        The result of ``xr.merge`` on the ``tp`` total and the ``t2m`` mean.
    """
    # min_count=2: a sum is only produced where at least two values along
    # lead_time are valid; elsewhere the result is missing.
    tp_total = obs.pr.sum(dim='lead_time', min_count=2)
    t2m_mean = obs.t2m.mean(dim='lead_time')
    return xr.merge([tp_total.rename('tp'), t2m_mean])

In [None]:
# Open the dataset with the same '0102.nc' pattern used for the file listing.
# NOTE(review): filter_str presumably restricts which files are loaded —
# confirm against crims2s.dataset.S2SDataset.
dataset = S2SDataset(ML_DATASET, filter_str='0102.nc')

In [None]:
# Use the first example of the dataset as the working sample.
opened = dataset[0]

In [None]:
# Apply the project's biweekly transform to the sample.
# NOTE(review): presumably this adds the `biweekly_forecast` dimension the
# aggregation functions expect; also confirm whether it writes into `opened`.
biweekly = add_biweekly_dim_transform(opened)

In [None]:
# Hand the transform a shallow copy so that `biweekly` keeps its own entries
# for the inspection cells below, even if the transform assigns into the
# mapping it is given.
aggregated = aggregate_weekly(dict(biweekly))

In [None]:
# Collect the sample at each stage so the stages can be compared side by side.
steps = [opened, biweekly, aggregated]

In [None]:
# List the entries carried by the sample at every stage.
for step in steps:
    print(step.keys())

In [None]:
# Display the 'obs' entry of the sample as taken from the dataset.
opened['obs']

In [None]:
# Display the 'features' entry held by `biweekly`.
biweekly['features']

In [None]:
# Display the 'obs' entry held by `aggregated`.
aggregated['obs']

In [None]:
# Plot aggregated observed tp for the first lead_time entry.
aggregated['obs'].tp.isel(lead_time=0).plot()

In [None]:
# Plot aggregated model tp for the first lead_time entry and realization
# index 1, with values below 0 raised to 0 before plotting.
aggregated['model'].tp.isel(lead_time=0, realization=1).clip(min=0.0).plot()

In [None]:
# Display the 'features' entry held by `aggregated`.
aggregated['features']

## Check that t2m matches the raw observations for the same date

In [None]:
# Observations of the sample, to be compared against the raw observation file.
opened_obs = opened['obs']

In [None]:
# Open the processed t2m observation file directly, as the reference for the check.
raw_obs = xr.open_dataset('***BASEDIR***/obs-arlan-processed-2021-08-07/t2m.nc')

In [None]:
# Display the sample's observations to inspect their coordinates.
opened_obs

In [None]:
# Index the sample's observations by valid_time instead of lead_time, then
# pick the date used for the comparison.
slice_opened_obs = (
    opened_obs
    .swap_dims({'lead_time': 'valid_time'})
    .sel(valid_time='2000-02-03')
)

In [None]:
# Select the same date from the raw observations (indexed by `time` there).
slice_raw_obs = raw_obs.sel(time='2000-02-03')

In [None]:
# Plot the difference between the two t2m slices.
# NOTE(review): presumably expected to be zero everywhere if the dates line up — confirm.
(slice_opened_obs.t2m - slice_raw_obs.t2m).plot()