# Tutorial 2: Implementing the model and training pipeline

In [None]:
#!pip install -q zarr torchdata zen3geo dask[distributed] intake xarray fsspec aiohttp regionmask --upgrade
#!pip install -q git+https://github.com/carbonplan/cmip6-downscaling.git@1.0
#!pip install -q git+https://github.com/xarray-contrib/xbatcher.git@463546e7739e68b10f1ae456fb910a1628de1e5c

In [None]:
import os
import dask
import time
import torch
import torchdata
import intake
import regionmask
import xbatcher
import zen3geo as zg
import xarray as xr
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import warnings

from torch import nn
from tqdm.autonotebook import tqdm
from functools import partial
from dask.distributed import Client, LocalCluster
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper
from torchdata.dataloader2 import DataLoader2
from torch.utils.data import DataLoader
from dask.diagnostics import ProgressBar

# Suppress all Python warnings for the remainder of the session.
warnings.filterwarnings('ignore')

# Run on the first GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
DTYPE = torch.float32  # floating-point precision used for tensors/model

In [None]:
from src.datapipes import *

In [None]:
class LSTMOutput(nn.Module):
    """Select the trailing sequence steps from an ``nn.LSTM`` output.

    ``nn.LSTM`` returns a tuple ``(output, (h_n, c_n))``, which cannot be
    passed directly to the next layer of an ``nn.Sequential``. This module
    unpacks that tuple, discards the hidden/cell states and keeps the last
    ``out_len`` steps of ``output`` along its second (sequence) axis.

    Args:
        out_len: Number of trailing sequence steps to keep. Must be >= 1.

    Raises:
        ValueError: If ``out_len`` is less than 1. ``out_len=0`` would
            otherwise slice with ``-0:`` and silently return the whole
            sequence.
    """

    def __init__(self, out_len=1):
        super().__init__()
        if out_len < 1:
            raise ValueError(f'out_len must be >= 1, got {out_len}')
        self.out_len = out_len

    def forward(self, x):
        """Return the last ``out_len`` steps of the LSTM output.

        Args:
            x: The ``(output, (h_n, c_n))`` tuple produced by ``nn.LSTM``.
                Assumes ``batch_first=True`` so that ``output`` is laid out
                as ``(batch, seq_len, hidden)``.

        Returns:
            Tensor of shape ``(batch, out_len, hidden)`` (fewer steps if the
            sequence is shorter than ``out_len``). The sequence axis is kept
            even when ``out_len == 1``.
        """
        output, _ = x
        return output[:, -self.out_len:, :]

In [None]:
ds = merge_data()

# Model inputs and target; 'mask' is loaded alongside them
# (presumably consumed by filter_batch -- confirm).
in_vars = ['pr', 'tasmax', 'tasmin', 'elevation', 'aspect_cosine']
out_vars = ['swe']
varlist = ['mask', *in_vars, *out_vars]

# Length of each input window along 'time', and how many trailing time
# steps are selected as the target (via `out_selectors` below).
input_sequence_length = 180
output_sequence_length = 1
output_selector = {'time': slice(-output_sequence_length, None)}

# Slicing parameters forwarded to xbatcher (see slice_with_xbatcher).
input_dims = {'time': input_sequence_length}
batch_dims = {'lat': 30, 'lon': 30}
input_overlap = {'time': 45}

# Pre-bind everything except the batch so the datapipe can call convert(batch).
convert = partial(
    stack_split_convert,
    in_vars=in_vars,
    out_vars=out_vars,
    out_selectors=output_selector,
    device=DEVICE,
)

In [None]:
# Region abbreviations used to subset the training data (presumably
# regionmask region names -- confirm against RegionalSubsetterPipe).
# Named `regions` because that is the name used below and by the datapipe cell.
regions = ['EEU']
hidden_size = 256
num_layers = 2
dropout = 0.25
train_period = slice('1985', '2015')
# Outer double quotes: reusing the outer quote character inside an f-string
# replacement field is a SyntaxError before Python 3.12.
base_name = f"regional_{'-'.join(regions)}_lstm_h{hidden_size}_d{num_layers}"

# LSTM over the input variables -> keep the last `output_sequence_length`
# steps -> linear projection onto the output variables -> SELU.
model = nn.Sequential(
    nn.LSTM(
        input_size=len(in_vars),
        hidden_size=hidden_size,
        batch_first=True,
        num_layers=num_layers,
        # nn.LSTM applies dropout to the outputs of every layer except the
        # last, so it only has an effect when num_layers > 1.
        dropout=dropout,
    ),
    LSTMOutput(output_sequence_length),
    nn.Linear(in_features=hidden_size, out_features=len(out_vars)),
    nn.SELU()
).float()
model = model.to(DEVICE)

opt = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fun = nn.MSELoss()

In [None]:
# Training datapipe: restrict the selected variables to the training period
# and regions, cut them into xbatcher windows, then apply filter_batch,
# transform_batch and convert to every window in turn.
dp = (
    RegionalSubsetterPipe(
        ds[varlist].sel(time=train_period).astype(np.float32),
        selected_regions=regions,
    )
    .slice_with_xbatcher(
        input_dims=input_dims,
        batch_dims=batch_dims,
        input_overlap=input_overlap,
        preload_batch=False,
    )
    .map(filter_batch)
    .map(transform_batch)
    .map(convert)
)

In [None]:
def train_epoch(model, datapipe, loss_fun, optimizer):
    """Run one training pass over ``datapipe`` and return the summed loss.

    Args:
        model: The ``nn.Module`` to train. It is switched to training mode
            so that dropout is active even if it was left in eval mode.
        datapipe: Iterable yielding ``(x, y)`` pairs of input and target
            tensors. Pairs whose ``x`` is empty are skipped.
        loss_fun: Callable ``loss_fun(yhat, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        float: Sum of the per-batch losses. Batches whose loss is NaN get no
        backward pass or optimizer step and are excluded from the sum.
    """
    model.train()
    tot_loss = 0.0
    # Use the datapipe/optimizer passed in, not the notebook globals.
    for x, y in tqdm(datapipe):
        if len(x) == 0:
            continue
        optimizer.zero_grad()
        yhat = model(x)
        loss = loss_fun(yhat, y)
        loss_value = loss.item()  # single device->host transfer per batch
        if np.isnan(loss_value):
            # Skip the update instead of pushing NaN gradients into the weights.
            continue
        loss.backward()
        optimizer.step()
        tot_loss += loss_value
    return tot_loss

In [None]:
all_loss = []  # summed training loss of each epoch run in this session
max_epochs = 20
# `starting_epoch` is the index of the next epoch to train and names the
# checkpoints. It is only initialised when it does not exist yet, so it can
# be set beforehand to resume from a checkpoint, and re-running this cell
# continues the numbering instead of overwriting earlier checkpoints.
if 'starting_epoch' not in globals():
    starting_epoch = 0

checkpoint_dir = './logging/model_checkpoints/pt_files'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for _ in tqdm(range(max_epochs)):
    loss = train_epoch(model, dp, loss_fun, opt)
    all_loss.append(loss)
    torch.save(
        model.state_dict(),
        f'{checkpoint_dir}/{base_name}_e{starting_epoch:04}.pt'
    )
    # Append mode: the log accumulates across re-runs and resumed sessions.
    with open(f'./logging/loss_{base_name}.txt', 'a') as f:
        f.write(f'{loss}\n')
    # Advance once per completed epoch; the checkpoint name uses this
    # counter alone, so consecutive epochs get consecutive numbers.
    starting_epoch += 1

In [None]:
# Read back the per-epoch losses appended by the training loop. The path
# must match the one the training cell writes to ('./logging/...').
with open(f'./logging/loss_{base_name}.txt', 'r') as f:
    txt = f.readlines()

In [None]:
# One summed training loss per line of the log, i.e. one value per epoch.
stack_loss = np.array([float(t) for t in txt])
# Hide the first epochs (presumably so their large losses do not dominate
# the y-range), but plot against the true epoch index so the x-axis does
# not restart at 0 after slicing.
first_plotted_epoch = 15
epochs = np.arange(len(stack_loss))
plt.plot(epochs[first_plotted_epoch:], stack_loss[first_plotted_epoch:])
plt.semilogy()
plt.xlabel('epoch')
plt.ylabel('summed training loss')