# dev
- Reimplementation of S. Rasp's WeatherBench 'Quickstart' notebook, with training done in PyTorch.

# loading data

In [None]:
import numpy as np
import xarray as xr
import torch
from src.train_nn_pytorch import Dataset

# Select the compute device; on GPU, newly created float tensors (e.g. model
# parameters) are made to default to CUDA as well.
if not torch.cuda.is_available():
    print("CUDA not available")
    device = torch.device("cpu")
    torch.set_default_tensor_type("torch.FloatTensor")
else:
    print('using CUDA !')
    device = torch.device("cuda")
    torch.set_default_tensor_type("torch.cuda.FloatTensor")

# Experiment settings.
lead_time = 3 * 24  # forecast offset; presumably in (hourly) time steps -- TODO confirm against Dataset
var_dict = {'z': None, 't': None}  # variable name -> levels (None for single-level fields)
batch_size = 32

datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'

# Open the Z500 and T850 fields and merge them into a single dataset.
z500 = xr.open_mfdataset(datadir + 'geopotential_500/*.nc', combine='by_coords')
t850 = xr.open_mfdataset(datadir + 'temperature_850/*.nc', combine='by_coords')
dataset_list = [z500, t850]
x = xr.merge(dataset_list, compat='override')
# One channel per loaded field (= 1 if only one of Z500 / T850 is loaded).
n_channels = len(dataset_list)

# TODO: separating train and test datasets / loaders should be avoidable
# with the start/end arguments of Dataset.

# Training data (2015); its mean/std are reused for the validation set below.
dg_train = Dataset(x.sel(time=slice('2015', '2015')), var_dict, lead_time, normalize=True)
# NOTE(review): batches are drawn in order (no shuffle=True). Enabling shuffling
# with a CUDA default tensor type may need a CUDA torch.Generator -- TODO check.
train_loader = torch.utils.data.DataLoader(dg_train, batch_size=batch_size, drop_last=True)

# Validation data (2016), normalized with the training statistics.
dg_validation = Dataset(x.sel(time=slice('2016', '2016')), var_dict, lead_time,
                        mean=dg_train.mean, std=dg_train.std, normalize=True)
test_validation = torch.utils.data.DataLoader(dg_validation, batch_size=batch_size, drop_last=False)

# define model

In [None]:
import torch.nn.functional as F

class PeriodicConv2D(torch.nn.Conv2d):
    """2D convolutional layer with mixed circular and zero padding.

    With ``padding_mode='circular'`` the input is padded circularly along the
    last axis (W) and with zeros along the second-to-last axis (H). ``padding``
    is interpreted as the *total* padding per axis, ``(pad_H, pad_W)``. It is
    split between the two sides, and when the total is odd the extra element
    goes to the left/top. With stride 1, ``padding=(k_H - 1, k_W - 1)``
    therefore keeps the spatial size unchanged.

    For any other ``padding_mode`` the standard ``torch.nn.Conv2d`` behaviour
    is used.
    """

    def conv2d_forward(self, input, weight):
        """Convolve ``input`` (N, C, H, W) with ``weight`` using the padding scheme above."""
        if self.padding_mode == 'circular':
            pad_h, pad_w = self.padding
            # F.pad lists the last dimension first: (W_left, W_right, H_top, H_bottom).
            expanded_padding_circ = ((pad_w + 1) // 2, pad_w // 2, 0, 0)
            expanded_padding_zero = (0, 0, (pad_h + 1) // 2, pad_h // 2)
            padded = F.pad(F.pad(input, expanded_padding_circ, mode='circular'),
                           expanded_padding_zero, mode='constant')
            # Padding was applied explicitly above, so the convolution itself uses none.
            return F.conv2d(padded, weight, self.bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Apply the convolution; circular mode always routes through ``conv2d_forward``."""
        # Since PyTorch 1.5, Conv2d.forward calls ``_conv_forward`` instead of
        # ``conv2d_forward``. Without this explicit dispatch, the mixed padding would
        # be silently replaced by circular padding on both axes.
        if self.padding_mode == 'circular':
            return self.conv2d_forward(input, self.weight)
        return super().forward(input)

class Net(torch.nn.Module):
    """Fully convolutional network built from stacked ``PeriodicConv2D`` layers.

    Layer ``i`` maps ``in_channels[i]`` to ``filters[i]`` feature maps with a
    ``kernels[i] x kernels[i]`` kernel and padding ``kernels[i] - 1`` per axis
    in circular mode. This is intended to preserve the spatial size.
    ``activation`` is applied after every layer except the last.

    Parameters
    ----------
    filters : sequence of int
        Output channels of each layer; the last entry is the number of
        network outputs.
    kernels : sequence of int
        Kernel size of each layer. It must have the same length as
        ``filters``, and size 2 is rejected.
    channels : int
        Number of input channels.
    activation : callable
        Nonlinearity applied between layers, e.g. ``torch.nn.functional.elu``.

    Raises
    ------
    ValueError
        If any kernel size is 2, or ``filters`` and ``kernels`` differ in length.
    """

    def __init__(self, filters, kernels, channels, activation):
        super(Net, self).__init__()
        filters, kernels = list(filters), list(kernels)
        # Check each entry: comparing a list itself to 2 is always False.
        if any(k == 2 for k in kernels):
            raise ValueError('kernel size 2 not allowed for circular padding')
        # zip() below would otherwise silently drop layers on a length mismatch.
        if len(filters) != len(kernels):
            raise ValueError(
                f'filters and kernels must have the same length, '
                f'got {len(filters)} and {len(kernels)}')
        self.activation = activation
        in_channels = [channels] + filters[:-1]
        self.layers = torch.nn.ModuleList([
            PeriodicConv2D(i, f, k, padding=(k - 1, k - 1), padding_mode='circular')
            for i, f, k in zip(in_channels, filters, kernels)
        ])

    def forward(self, x):
        """Apply all layers, with ``self.activation`` after every layer but the last."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)

# Four hidden layers of 64 feature maps, then a projection back to the input
# channels; every layer uses a 5x5 periodic convolution.
net = Net(
    filters=[64] * 4 + [n_channels],
    kernels=[5] * 5,
    channels=n_channels,
    activation=torch.nn.functional.elu,
)

# train model

In [None]:
import torch.optim as optim
optimizer = optim.Adam(net.parameters(), lr=0.01)

n_epochs = 200
net.to(device)
net.train()
# Exactly n_epochs passes over the training data.
for epoch in range(1, n_epochs + 1):
    print(f'epoch #{epoch}')
    # Train for a single epoch.
    for batch in train_loader:
        optimizer.zero_grad()
        inputs, targets = batch[0].to(device), batch[1].to(device)
        # Call the module rather than .forward() so registered hooks run.
        loss = F.mse_loss(net(inputs), targets)
        loss.backward()
        optimizer.step()

    # TODO: add early stopping based on convergence on the validation data.

torch.save(net.state_dict(), res_dir + 'test_fccnn_3d_pytorch.pt')

#net_rec = Net(filters=[64, 64, 64, 64, n_channels], kernels=[5, 5, 5, 5, 5], 
#          channels=n_channels, activation=torch.nn.functional.elu)
#net_rec.load_state_dict(torch.load(res_dir + 'test_fccnn_3d_pytorch.pt'))

In [None]:
# Restore the trained weights. The checkpoint tensors are mapped to CPU first so
# the file also loads without CUDA; load_state_dict then copies them into
# net's existing parameters.
net.load_state_dict(
    state_dict=torch.load(f'{res_dir}test_fccnn_3d_pytorch.pt',
                          map_location=torch.device('cpu')),
)

# debug

### Create a prediction and compute score

Now that we have a model (albeit a crappy one), we can create predictions. For this, we create a forecast for each forecast initialization time in the testing range (2017-2018, here subsampled to every 12th time step) and un-normalize it. We then convert the forecasts to an xarray dataset, which allows us to easily compute the RMSE. All of this is taken care of by the `create_predictions()` function.

In [None]:
def create_predictions(model, dg):
    """Create non-iterative predictions for every sample of a dataset.

    The model is applied once to all inputs of ``dg``. The output is
    un-normalized with ``dg.mean``/``dg.std`` when ``dg.normalize`` is set,
    then split back into one ``xr.DataArray`` per variable in ``dg.var_dict``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate. Inputs are moved to the device of its first
        parameter, or to CPU if it has none.
    dg : Dataset
        Dataset whose ``dg[indices][0]`` yields the stacked inputs. It must
        provide ``var_dict``, ``valid_time``, ``ds`` (``lat``/``lon`` and, for
        multi-level variables, ``level`` coordinates), ``normalize`` and, if
        normalized, ``mean``/``std``.

    Returns
    -------
    xr.Dataset
        Forecasts with one data variable per entry of ``dg.var_dict`` and a
        ``time`` coordinate taken from ``dg.valid_time``.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    inputs = torch.as_tensor(dg[np.arange(len(dg))][0], device=device)
    # Pure inference: avoid building an autograd graph over the whole dataset.
    with torch.no_grad():
        preds = model(inputs)
    # Convert to a host numpy array unconditionally, so the DataArrays below
    # always receive numpy data (also when the model runs on CUDA).
    preds = preds.cpu().numpy()
    if dg.normalize:
        # Undo per-channel normalization.
        # Assumes preds is (sample, channel, lat, lon) -- TODO confirm.
        preds = preds * dg.std.values[None, :, None, None] + dg.mean.values[None, :, None, None]
    das = []
    lev_idx = 0  # next channel index in preds belonging to the current variable
    for var, levels in dg.var_dict.items():
        if levels is None:
            # Single-level variable: occupies one channel.
            das.append(xr.DataArray(
                preds[:, lev_idx, :, :],
                dims=['time', 'lat', 'lon'],
                coords={'time': dg.valid_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon},
                name=var
            ))
            lev_idx += 1
        else:
            # Multi-level variable: occupies one channel per requested level.
            nlevs = len(levels)
            das.append(xr.DataArray(
                preds[:, lev_idx:lev_idx + nlevs, :, :],
                dims=['time', 'level', 'lat', 'lon'],
                coords={'time': dg.valid_time, 'level': dg.ds.level, 'lat': dg.ds.lat, 'lon': dg.ds.lon},
                name=var
            ))
            lev_idx += nlevs
    return xr.merge(das, compat='override')

# Test set: 2017-2018, keeping only every 12th time step to limit the number
# of forecasts. It is normalized with the training mean/std.
# NOTE(review): if Dataset applies `lead_time` as an index offset (as the
# valid_time line below does), subsampling every 12th step turns the 72-step
# offset into 72 * 12 steps instead of the intended 3 days (assuming hourly
# data). Confirm against src.train_nn_pytorch.Dataset, and if so rescale
# lead_time (e.g. lead_time // 12) for this dataset.
dg_test =  Dataset(x.sel(time=slice('2017', '2018')).isel(time=slice(0, None, 12)), var_dict, lead_time,
                        mean=dg_train.mean, std=dg_train.std, normalize=True)
# Valid (target) times of the forecasts: the dataset's time axis with the
# first `lead_time` entries dropped.
# NOTE(review): presumably dg_test.data is the subsampled array -- TODO confirm.
dg_test.valid_time = dg_test.data.isel(time=slice(lead_time, None)).time
preds = create_predictions(net, dg_test)

In [None]:
from src.score import compute_weighted_rmse, load_test_data

z500_test = load_test_data(f'{datadir}geopotential_500/', 'z')
# Latitude-weighted RMSE of the Z500 forecast against the ground truth, with
# the first `lead_time` truth time steps dropped.
# NOTE(review): presumably compute_weighted_rmse aligns the two arrays on
# their time coordinates -- verify.
rmse_z = compute_weighted_rmse(
    preds['z'],
    z500_test.isel(time=slice(lead_time, None)),
).load()
rmse_z

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
time = '2017-03-02T00'
# Side-by-side maps at a single time: ground truth (left) and forecast (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
truth_map, forecast_map = z500_test.sel(time=time), preds['z'].sel(time=time)
truth_map.plot(ax=ax1)
forecast_map.plot(ax=ax2);

# The End

This is the end of the quickstart guide. Please refer to the Jupyter notebooks in the `notebooks` directory for more examples. If you have questions, feel free to ask them as a GitHub issue.