# Tutorial 3: Evaluating the trained model on the test data

### Outline

* Imports, including library code from previous steps
* Loading the trained model from its hyperparameters and a saved weights file
* Setting up the datapipe for the test data
* Some functions for "undoing/inverting" the ETL pipeline (aka recovering spatiotemporal relations)
* Running the trained model in eval mode
* Some basic metrics and analysis

In [None]:
import os
import dask
import time
import torch
import torchdata
import intake
import regionmask
import xbatcher
import zen3geo as zg
import xarray as xr
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import warnings

from torch import nn
from tqdm.autonotebook import tqdm
from functools import partial
from dask.distributed import Client, LocalCluster
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper
from torchdata.dataloader2 import DataLoader2
from torch.utils.data import DataLoader
from dask.diagnostics import ProgressBar

# Silence every warning category for the rest of the session. Note that this
# also hides deprecation warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')
# Run on the first CUDA device when torch reports one, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Floating-point precision constant; not referenced by the code shown here.
DTYPE = torch.float32

In [None]:
#dask.config.set(**{'array.slicing.split_large_chunks': False})
#cluster = LocalCluster(
#    n_workers=24,
#    threads_per_worker=1,
#    memory_limit='6GB',
#    dashboard_address=':2345'
#)
#client = Client(cluster)
#client

In [None]:
# Open the intake-esm catalog of globally downscaled CMIP6 data.
cat = intake.open_esm_datastore(
    'https://cpdataeuwest.blob.core.windows.net/cp-cmip/version1/catalogs/global-downscaled-cmip6.json'
)

# Narrow the catalog to one downscaling method / model / scenario and the
# three daily meteorological variables used as model inputs.
search_criteria = dict(
    method="GARD-SV",
    source_id="CanESM5",
    experiment_id="ssp245",
    variable_id=['tasmin', 'tasmax', 'pr'],
    timescale='day',
)
cat_subset = cat.search(**search_criteria)

# Materialise the matches as a dict of datasets and keep the first entry.
# (A rechunking step, `.chunk({'time': 1, 'lat': 48, 'lon': 48})`, was left
# disabled in the original cell and remains disabled here.)
dsets = cat_subset.to_dataset_dict()
met_ds = list(dsets.values())[0]


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.timescale.method'


In [None]:
import xarray as xr
import numpy as np
import intake

def merge_data():
    """Build the merged meteorology + mask + terrain dataset.

    Steps, in order:

    1. Open the downscaled-CMIP6 catalog and take the first dataset matching
       the GARD-SV / CanESM5 / ssp245 daily search.
    2. Wrap its ``lon`` coordinate with ``% 360``.
    3. Attach the ``sd`` variable of the mask store as ``mask`` (renaming its
       ``latitude``/``longitude`` dimensions to ``lat``/``lon``).
    4. Merge in the slope/aspect/elevation terrain store.
    5. Recompute ``mask`` as 1 where elevation is not NaN *and* the mask
       value is positive, 0 elsewhere.

    Returns:
        The merged ``xarray`` dataset.
    """
    catalog = intake.open_esm_datastore(
      'https://cpdataeuwest.blob.core.windows.net/cp-cmip/version1/catalogs/global-downscaled-cmip6.json'
    )
    matches = catalog.search(
        method="GARD-SV",
        source_id="CanESM5",
        experiment_id="ssp245",
        variable_id=['tasmin', 'tasmax', 'pr'],
        timescale='day',
    )
    datasets = matches.to_dataset_dict()
    merged = list(datasets.values())[0]
    merged['lon'] = merged['lon'] % 360

    mask_store = xr.open_dataset(
        'https://esiptutorial.blob.core.windows.net/eraswe/mask_10k_household.zarr',
        engine='zarr',
    )
    terrain_store = xr.open_dataset(
        'https://esiptutorial.blob.core.windows.net/eraswe/processed_slope_aspect_elevation.zarr',
        engine='zarr',
    )

    merged['mask'] = mask_store['sd'].rename(
        {'latitude': 'lat', 'longitude': 'lon'}
    )
    merged = xr.merge([merged, terrain_store])

    # A cell is usable only when terrain data exists and the mask is positive.
    has_elevation = ~np.isnan(merged['elevation'])
    mask_is_positive = merged['mask'] > 0
    merged['mask'] = np.logical_and(has_elevation, mask_is_positive).astype(int)
    return merged


In [None]:
class DatasetPipe(IterDataPipe):
    """Minimal source datapipe that emits one pre-built dataset.

    Args:
        ds: The object to emit; it is stored as-is and yielded exactly once
            per iteration.
    """

    def __init__(self, ds):
        super().__init__()
        self.ds = ds

    def __iter__(self):
        # Wrap the dataset in a 1-tuple so the dataset itself is the single
        # item produced (rather than iterating over its contents).
        yield from (self.ds,)

In [None]:
#@functional_datapipe("subset_regions")
class RegionalSubsetterPipe(IterDataPipe):
    """Datapipe that yields the dataset restricted to AR6 land regions.

    Each requested region is yielded ``repeat_region`` times in a row; the
    regional subset is computed only when the region changes and is reused
    for the consecutive repeats.

    Args:
        ds: Dataset with ``lon`` and ``lat`` coordinates to subset.
        selected_regions: Iterable of AR6 land region abbreviations.
        repeat_region: How many consecutive times each region is yielded.
        preload: When true, ``.load()`` is called on each new regional
            subset before it is yielded.
    """

    def __init__(self, ds, selected_regions, repeat_region=10, preload=True):
        super().__init__()
        self.current_region = None
        self.ds = ds
        self.repeat_region = repeat_region
        # Expand [a, b] into [a, a, ..., b, b, ...] so repeats are adjacent
        # and the cached subset can be reused.
        self.selected_regions = [s for s in selected_regions
                                 for _ in range(self.repeat_region)]
        self.preload = preload

    def select_region(self, region):
        """Return ``self.ds`` restricted to the region with this abbreviation.

        Args:
            region: AR6 land region abbreviation.

        Returns:
            ``self.ds`` masked to the region, with all-masked rows/columns
            dropped.

        Raises:
            IndexError: If the region does not overlap the dataset grid.
        """
        regions = regionmask.defined_regions.ar6.land
        # Build the mask on this pipe's own dataset (not a module-level
        # global), so the mask grid always matches the data being subset.
        region_id_mask = regions.mask(self.ds['lon'], self.ds['lat'])
        reg = np.unique(region_id_mask.values)
        reg = reg[~np.isnan(reg)]

        # Regions actually present on the grid, with their numbers and
        # abbreviations taken from the same object so they stay aligned.
        present = regions[reg]
        present_numbers = np.array(present.numbers)
        present_abbrevs = np.array(present.abbrevs)

        matches = np.flatnonzero(present_abbrevs == region)
        if matches.size == 0:
            raise IndexError(
                f'region {region!r} does not overlap the dataset grid'
            )
        # The mask stores region *numbers*, so translate the position in the
        # present-regions list back into the region number before comparing.
        region_number = present_numbers[matches[0]]
        region_mask = region_id_mask == region_number
        return self.ds.where(region_mask, drop=True)

    def __iter__(self):
        for region in self.selected_regions:
            if region != self.current_region:
                self.selected_ds = self.select_region(region)
                if self.preload:
                    self.selected_ds = self.selected_ds.load()
            self.current_region = region
            yield self.selected_ds


def filter_batch(batch):
    """Keep only the entries of ``batch`` whose ``mask`` value is positive.

    Entries failing the condition are removed via ``where(..., drop=True)``.
    """
    is_valid = batch['mask'] > 0
    return batch.where(is_valid, drop=True)


def transform_batch(batch):
    """Shift and scale the variables of ``batch`` with fixed constants.

    Each variable is transformed as ``(value - mean) / std`` using the
    hard-coded per-variable constants below.

    Args:
        batch: Dataset-like object supporting arithmetic with ``xr.Dataset``.

    Returns:
        The normalised batch.
    """
    # Per-variable offsets subtracted from the batch.
    mean_by_var = {
        'mask': 0.0,
        'swe': 0.0,
        'pr': 0.00,
        'tasmax': 295.0,
        'tasmin': 280.0,
        'elevation': 630.0,
        'aspect_cosine': 0.0,
    }
    # Per-variable divisors applied after the offset.
    std_by_var = {
        'mask': 1.0,
        'swe': 3.0,
        'pr': (3600*25)/100.0,
        'tasmax': 80.0,
        'tasmin': 80.0,
        'elevation': 830.0,
        'aspect_cosine': 1.0,
    }

    scale_means = xr.Dataset()
    for var_name, mean_value in mean_by_var.items():
        scale_means[var_name] = mean_value

    scale_stds = xr.Dataset()
    for var_name, std_value in std_by_var.items():
        scale_stds[var_name] = std_value

    centred = batch - scale_means
    return centred / scale_stds


def _dataset_to_tensor(batch, variables, selectors, device):
    """Stack ``variables`` of ``batch`` into a float tensor.

    The variables are combined with ``to_array``, ordered as
    ``('sample', 'time', 'variable')``, index-selected with ``selectors``
    and converted to a ``torch`` float tensor (moved to ``device`` if given).
    """
    stacked = (batch[variables]
               .to_array()
               .transpose('sample', 'time', 'variable')
               .isel(**selectors))
    tensor = torch.tensor(stacked.values).float()
    if device is not None:
        tensor = tensor.to(device)
    return tensor


def stack_split_convert(
    batch, 
    in_vars, 
    out_vars, 
    in_selectors=None,
    out_selectors=None,
    device=None,
    min_samples=100
):
    """Convert a batch into ``(x, y)`` tensors for the model.

    Args:
        batch: Object indexable by ``'sample'`` and by lists of variable
            names (an xarray dataset in this notebook).
        in_vars: Names of the input variables, stacked into ``x``.
        out_vars: Names of the target variables, stacked into ``y``; may be
            empty, in which case ``y`` is an empty tensor.
        in_selectors: Optional ``isel`` keyword mapping applied to the
            inputs. ``None`` (the default) means no selection.
        out_selectors: Optional ``isel`` keyword mapping applied to the
            targets. ``None`` (the default) means no selection.
        device: Optional device the tensors are moved to.
        min_samples: The batch is converted only when it has strictly more
            than this many entries along ``'sample'``.

    Returns:
        Tuple ``(x, y)``. Both are empty tensors when the batch is too small.
    """
    # Use None defaults instead of shared mutable dict defaults.
    in_selectors = {} if in_selectors is None else in_selectors
    out_selectors = {} if out_selectors is None else out_selectors

    # Too few samples: signal "skip this batch" with empty tensors.
    if not len(batch['sample']) > min_samples:
        return torch.tensor([]), torch.tensor([])

    x = _dataset_to_tensor(batch, in_vars, in_selectors, device)
    if len(out_vars):
        y = _dataset_to_tensor(batch, out_vars, out_selectors, device)
    else:
        y = torch.tensor([])
    return x, y


class LSTMOutput(nn.Module):
    """Adapter that lets ``nn.LSTM`` be used inside ``nn.Sequential``.

    ``nn.LSTM`` returns ``(output, (hn, cn))``; this module discards the
    state part and keeps the trailing ``out_len`` steps of ``output``.

    Args:
        out_len: Number of final sequence positions to keep.
    """

    def __init__(self, out_len=1):
        super().__init__()
        self.out_len = out_len

    def forward(self, x):
        """Return the last ``out_len`` steps of the LSTM output.

        Args:
            x: Two-element sequence ``(output, state)`` as returned by
                ``nn.LSTM``; ``output`` is indexed as
                (batch, sequence_length, hidden).

        Returns:
            Tensor of shape (batch, out_len, hidden).
        """
        sequence_output, _state = x
        tail_start = -self.out_len
        return sequence_output[:, tail_start:, :]

In [None]:
# Build the merged dataset, drop its length-1 dimensions, and keep only
# latitudes from 17.5 upward.
full_ds = (
    merge_data()
    .squeeze()
    .sel(lat=slice(17.5, None))
)

In [None]:
# Number of leading time steps skipped at the start of the test period.
time_offset = 14

# Test period: 2060-2069, minus the first `time_offset` time steps.
ds = (
    full_ds
    .sel(time=slice('2060', '2069'))
    .isel(time=slice(time_offset, None))
)

# Maximum precipitation over time at every cell where the mask is set
# (NaN elsewhere); computed eagerly with a progress bar.
with ProgressBar():
    gcm_pr = (
        ds['pr']
        .where(ds['mask'], other=np.nan)
        .max(dim='time')
        .compute()
    )

In [None]:
# Model inputs; there are no target variables at evaluation time, so
# `out_vars` is empty and `convert` returns an empty `y`.
in_vars = ['pr',  'tasmax',  'tasmin',  'elevation',  'aspect_cosine']
out_vars = []
# Every variable named here; `mask` is listed first. `varlist` is not used
# by the code shown in this file.
varlist = ['mask'] + in_vars + out_vars
# Length (in time steps) of each input window and number of trailing steps
# kept from the model output.
input_sequence_length = 180  
output_sequence_length = 1
# Selects the last `output_sequence_length` time steps.
output_selector = {'time': slice(-output_sequence_length, None)}
# Arguments for the xbatcher slicing step below: window length along time,
# spatial batch extent, and overlap between consecutive time windows.
# NOTE(review): the overlap (14) equals `time_offset` — presumably
# intentional; confirm.
input_dims={'time': input_sequence_length}
batch_dims={'lat': 290, 'lon': 180}
input_overlap={'time': 14}
           
# Batch -> tensor converter with the configuration above baked in; batches
# with 90 or fewer samples yield empty tensors (see `stack_split_convert`).
convert = partial(
    stack_split_convert, 
    in_vars=in_vars, 
    out_vars=out_vars, 
    out_selectors=output_selector,
    device=DEVICE,
    min_samples=90,
)

In [None]:
# Wrap the test-period dataset in a one-item datapipe, then slice it into
# batches. `slice_with_xbatcher` is not defined in this file — presumably it
# is the functional datapipe registered by zen3geo (imported as `zg`);
# TODO confirm.
dp = DatasetPipe(ds)
dp = dp.slice_with_xbatcher(
    input_dims=input_dims,
    batch_dims=batch_dims,
    input_overlap=input_overlap,
    preload_batch=False
)

In [None]:
# Hyperparameters; these must match the ones used to train the checkpoint
# because they determine the shapes of the weights being loaded.
hidden_size = 256
num_layers = 2
dropout = 0.25
# Checkpoint family to evaluate. The original cell first assigned a
# 'regional_xen_lstm_...' name and immediately overwrote it with this one,
# so only this value was ever in effect.
base_name = f'regional_xna_lstm_h{hidden_size}_d{num_layers}'

# Pick the last matching checkpoint in sorted filename order.
state_files = sorted(glob(f'../trained_models/{base_name}*.pt'))
if not state_files:
    # Fail with an explicit message instead of a bare IndexError.
    raise FileNotFoundError(
        f'no trained weights found matching ../trained_models/{base_name}*.pt'
    )
# `map_location` lets a checkpoint saved from a GPU be loaded on a
# CPU-only machine (and vice versa).
model_state = torch.load(state_files[-1], map_location=DEVICE)

# LSTM -> keep last `output_sequence_length` steps -> linear head -> SELU.
model = nn.Sequential(
    nn.LSTM(
        input_size=len(in_vars), 
        hidden_size=hidden_size, 
        batch_first=True,
        num_layers=num_layers,
        dropout=dropout,
    ),
    LSTMOutput(output_sequence_length),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.SELU()
).float()
model.load_state_dict(model_state)
model = model.to(DEVICE)
# Evaluation mode (disables dropout).
model = model.eval()

In [None]:
# Collects the prediction arrays appended by the evaluation loop.
all_pred = []

In [None]:
# Create the output directory up front so that a missing directory does not
# make `to_netcdf` fail after a chunk has already been run through the model.
chunk_dir = '../full_chunks'
os.makedirs(chunk_dir, exist_ok=True)

for i, sample in tqdm(enumerate(dp)):
    # (A resume shortcut, `if i <= 282: continue`, was left disabled in the
    # original cell and remains disabled here.)
    print(sample['time'].values[-1], 
          sample['lat'].min().values[()],
          sample['lon'].min().values[()], )
    # Timestamp of the last step in this window; the prediction is stamped
    # with it below.
    t = sample['time'].isel(time=-1)
    # Only evaluate windows that end in December, January or February.
    if t.dt.month.values[()] not in [12, 1, 2]: continue
    print('filtering...')
    filtered = filter_batch(sample)
    print('transforming...')
    transformed = transform_batch(filtered)
    print('converting...')
    x, y = convert(transformed)
    # `convert` returns empty tensors when the batch has too few samples.
    if not len(x):
        print()
        continue
    print('running_model...')
    with torch.no_grad():
        yhat = model(x).cpu().numpy().squeeze()
    torch.cuda.empty_cache()
    print('reassembling_prediction...')
    # Start from an all-NaN array shaped like the mask, write predictions
    # into the cells where the mask equals 1, then unstack the result.
    pred = np.nan * xr.zeros_like(sample['mask'])
    pred.loc[sample['mask'] == 1] = yhat
    pred = pred.unstack()
    pred['time'] = t
    pred.name = 'swe'
    all_pred.append(pred)
    print('writing chunk...')
    pred.to_netcdf(f'{chunk_dir}/ap_{time_offset}_{i}.nc')