# NCI WeatherBench-3d: Train a CNN (PyTorch)

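In this notebook we go through all the steps required to train a fully convolutional neural network (CNN) in PyTorch. The network forecasts 500 hPa geopotential (Z500) and 850 hPa temperature (T850) in two ways: directly at 3-day and 5-day lead times, and iteratively by rolling out a 6-hour model.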

In [None]:
# Automatically reload imported modules (e.g. the local train_nn_pytorch and score helpers)
# when their source files change, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:

import os
import warnings
warnings.filterwarnings('ignore')
from score import *
import glob  # explicit (and after the star import): cells below call glob.glob(...)
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar
from datetime import datetime
from dask.distributed import Client

import torch
import torch.nn as nn
import torch.nn.functional as F
import train_nn_pytorch

In [None]:
# Start a local Dask cluster. Once created, it becomes the default scheduler for later
# dask computations, such as the parallel open_mfdataset(...).load() calls below.
# NOTE(review): 12 single-threaded workers is a machine-specific choice; match it to the allocated CPUs.
client = Client(n_workers=12, threads_per_worker=1)
# Display the cluster summary (dashboard link, workers, memory).
client

# Dataset: 5.625° resolution

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )

res     = '5.625'
datadir = f'/g/data/wb00/NCI-Weatherbench/{res}deg' 
print ("Data loading..." )

years = list(range(1999, 2022+1))
print (years)
z_files = [ file for year in years for file in glob.glob (fr'{datadir}/geopotential/*{year}*')  ] 
t_files = [ file for year in years for file in glob.glob (fr'{datadir}/temperature/*{year}*')    ] 

z = xr.open_mfdataset(z_files, combine='by_coords', parallel=True, chunks={'time': 10}).z.sel(level=[500]).load() 
t = xr.open_mfdataset(t_files, combine='by_coords', parallel=True, chunks={'time': 10}).t.sel(level=[850]).drop('level').load() 
datasets = [z, t]
print ("Merging ... ")
ds = xr.merge(datasets).compute()  

In [None]:
# Sanity check: the first 10 timestamps of the loaded geopotential record.
z['time'][:10]

In [None]:
# Sanity check: the last 10 timestamps of the loaded geopotential record.
z['time'][-10:]

In [None]:
def data_generate(ds, lead_time, batch_size, variables, train_years, valid_years, test_years):
    """Split `ds` along time and wrap each split in a train_nn_pytorch.DataGenerator.

    The validation and test generators reuse the training generator's mean/std and
    are created with shuffle=False.

    Args:
        ds: merged xarray Dataset containing every name in `variables`.
        lead_time: forecast lead time handed to DataGenerator (presumably hours; confirm).
        batch_size: number of samples per generated batch.
        variables: variable names to include; each one is requested at level 500.
        train_years, valid_years, test_years: (start, stop) pairs used as `time` slices.

    Returns:
        Tuple (dg_train, dg_valid, dg_test).
    """
    ds_train, ds_valid, ds_test = (
        ds.sel(time=slice(*span)) for span in (train_years, valid_years, test_years)
    )

    print ("Data generation ... ")
    # NOTE(review): every variable, t included, is requested at level 500. t was loaded at
    # 850 hPa with its level coordinate dropped; confirm this mapping is intended.
    var_dict = dict.fromkeys(variables, 500)
    dg_train = train_nn_pytorch.DataGenerator(ds_train, var_dict, lead_time, batch_size=batch_size)

    # Evaluation splits share the training normalisation statistics and keep their order.
    eval_kwargs = dict(batch_size=batch_size, mean=dg_train.mean, std=dg_train.std, shuffle=False)
    dg_valid = train_nn_pytorch.DataGenerator(ds_valid, var_dict, lead_time, **eval_kwargs)
    dg_test = train_nn_pytorch.DataGenerator(ds_test, var_dict, lead_time, **eval_kwargs)

    print(f'Mean = {dg_train.mean}; Std = {dg_train.std}')
    return dg_train, dg_valid, dg_test

def train(dg_train_generator, dg_valid_generator, dg_test_generator, model_save_fn):
    """Train a train_nn_pytorch.Model_cnn with MSE loss, Adam and early stopping.

    Reads the notebook-level settings `channels`, `kernels`, `lr`, `n_epochs` and
    `n_patience` defined in the configuration cells.

    Args:
        dg_train_generator: DataLoader over the training DataGenerator. It is built with
            batch_size=1, so every item carries an extra leading dimension of size 1.
        dg_valid_generator: DataLoader over the validation DataGenerator (same layout).
        dg_test_generator: Unused; kept so existing call sites keep working.
        model_save_fn: Unused; the calling cell saves the returned weights.

    Returns:
        The model with the weights of the epoch that had the lowest mean validation loss
        restored. If no epoch produced a comparable validation loss, the final weights are kept.
    """
    print ("Train model ... ")

    model = train_nn_pytorch.Model_cnn(channels, kernels)

    loss_fn = nn.MSELoss()
    # Use the configured learning rate rather than a hard-coded 1e-4.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    model.to(device)

    patience = n_patience
    best_loss = float('inf')
    best_state = None
    epoch_len = len(str(n_epochs))

    for epoch in range(n_epochs):
        model.train()
        train_losses = []
        for X, y in dg_train_generator:
            # Remove only the DataLoader's size-1 leading dim. A bare squeeze() would also
            # collapse the batch dim of a final batch that holds a single sample.
            X = X.squeeze(0).to(device)
            y = y.squeeze(0).to(device)
            loss = loss_fn(model(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        valid_losses = []
        with torch.no_grad():
            for X_v, y_v in dg_valid_generator:
                X_v = X_v.squeeze(0).to(device)
                y_v = y_v.squeeze(0).to(device)
                valid_losses.append(loss_fn(model(X_v), y_v).item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)

        print_msg = (f'[{epoch + 1:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

        # Early stopping uses the epoch-mean validation loss, not the loss of whichever
        # validation batch happened to come last.
        if valid_loss < best_loss:
            best_loss = valid_loss
            # state_dict() holds references to the live tensors, so clone a snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience = n_patience
        else:
            patience -= 1
            if patience == 0:
                print(f'Early stopping')
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
        
def evaluate(pred, iterative, test_years):
    """Print the weighted RMSE of `pred` against the Z500/T850 ground truth of the test period.

    Args:
        pred: forecast returned by train_nn_pytorch.create_predictions or
            create_iterative_predictions.
        iterative: True when `pred` is an iterative forecast; it is then scored with
            train_nn_pytorch.evaluate_iterative_forecast.
        test_years: (start, stop) pair of years, e.g. ('2021', '2022'). Strings or ints are
            accepted. It selects the files (inclusive year range) and the `time` slice.

    Reads the notebook-level `datadir`. Returns None; the score is printed.
    """
    print("Evaluating forecast ...")
    start, stop = str(test_years[0]), str(test_years[1])
    # Normalise to strings so the time slice is a date-string slice even if ints were passed.
    time_slice = slice(start, stop)
    eval_years = list(range(int(start), int(stop) + 1))
    print ('all test_years:', eval_years)
    z500_valid_files = sorted({file for year in eval_years
                               for file in glob.glob(f'{datadir}/geopotential/*{year}*')})
    t850_valid_files = sorted({file for year in eval_years
                               for file in glob.glob(f'{datadir}/temperature/*{year}*')})

    z500_valid = load_test_data(z500_valid_files, 'z', time_slice)
    t850_valid = load_test_data(t850_valid_files, 't', time_slice)

    valid = xr.merge([z500_valid, t850_valid], compat='override').compute()

    # Score both forecast kinds with the same RMSE implementation.
    rmse_fn = train_nn_pytorch.compute_weighted_rmse
    if iterative:
        score = train_nn_pytorch.evaluate_iterative_forecast(pred, valid, rmse_fn).load()
    else:
        score = rmse_fn(pred, valid).compute()
    print(score)


In [None]:
# Test period as (start, stop) year strings, the form used for time slices and by evaluate().
test_years = ('2021', '2022')
# Expanded, inclusive list of test years. It gets its own name so that `test_years` keeps
# its tuple-of-strings form even if this cell is re-run after the configuration cells.
test_year_list = list(range(int(test_years[0]), int(test_years[1]) + 1))
print ('test_years:', test_year_list)

# 72 hours (3 days)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
print (60*"-")
# Inclusive (start, stop) year strings for the three data splits.
train_years=('1999', '2015')
valid_years=('2016', '2020')
test_years =('2021', '2022')
batch_size = 32
variables = ('z', 't')
# Direct forecast 72 hours (3 days) ahead.
lead_time = 72
# Passed to train_nn_pytorch.Model_cnn: presumably per-layer channel counts (2 in = z,t ...
# 2 out) with one kernel size per layer. Confirm against Model_cnn.
channels = [2, 64, 64, 64, 64, 64, 2]
kernels = [5, 5, 5, 5, 5, 5]
# NOTE(review): not referenced by any cell in this notebook (possibly a dropout rate).
dr = 0
save_prefix = 'PyTorch_NCI_tutorial' 
print ('save_prefix:', save_prefix)
# NOTE(review): this file receives a PyTorch state_dict (torch.save) despite the .h5 extension.
model_save_fn = f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/saved_models/{save_prefix}_cnn_3d.h5'
pred_save_fn  = f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/{save_prefix}_cnn_3d.nc'
# Intended Adam learning rate; confirm train() reads it rather than a hard-coded value.
lr = 1e-4
# False selects the direct (non-iterative) scoring path in evaluate().
iterative = False
iterative_lead_time = None
# Maximum number of epochs, and epochs without validation improvement before train() stops.
n_epochs = 100   
n_patience = 4

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
dg_train, dg_valid, dg_test = data_generate(ds, lead_time, batch_size,
                                           variables, train_years, valid_years, test_years)

params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}
dg_train_generator = torch.utils.data.DataLoader(dg_train, **params)
params['shuffle'] = False
dg_valid_generator = torch.utils.data.DataLoader(dg_valid, **params)
dg_test_generator = torch.utils.data.DataLoader(dg_test, **params)

model = train(dg_train_generator, dg_valid_generator, dg_test_generator, model_save_fn)

print(f'Saving model weights: {model_save_fn}')
torch.save(model.state_dict(), model_save_fn)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
# Forecast the test period with the trained model. The training mean/std are passed,
# presumably so the output is returned in physical units; confirm in create_predictions.
pred = train_nn_pytorch.create_predictions(model, dg_test_generator, 
                        mean=dg_train.mean.values, std=dg_train.std.values,
                        var_dict=dg_test.var_dict,
                        valid_time=dg_test.valid_time, 
                        lat=dg_test.ds.lat, 
                        lon=dg_test.ds.lon 
                       )
print(f'Saving predictions: {pred_save_fn}')
pred.to_netcdf(pred_save_fn)

# Print the weighted RMSE against the test-period ground truth.
evaluate(pred, iterative, test_years)

# 120 hours (5 days)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )

print (60*"-")
# Only the settings below are redefined. Split years, variables, channels/kernels,
# n_epochs and n_patience carry over from the 72-hour configuration cell.
batch_size = 32
# Direct forecast 120 hours (5 days) ahead.
lead_time = 120
print ('save_prefix:', save_prefix)
# NOTE(review): this file receives a PyTorch state_dict (torch.save) despite the .h5 extension.
model_save_fn= f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/saved_models/{save_prefix}_cnn_5d.h5'
pred_save_fn = f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/{save_prefix}_cnn_5d.nc'
lr = 1e-4
# False selects the direct (non-iterative) scoring path in evaluate().
iterative = False
iterative_lead_time = None

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )

dg_train, dg_valid, dg_test = data_generate(ds, lead_time, batch_size,
                                           variables, train_years, valid_years, test_years)

params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}
dg_train_generator = torch.utils.data.DataLoader(dg_train, **params)
params['shuffle'] = False
dg_valid_generator = torch.utils.data.DataLoader(dg_valid, **params)
dg_test_generator = torch.utils.data.DataLoader(dg_test, **params)

model = train(dg_train_generator, dg_valid_generator, dg_test_generator, model_save_fn)

print(f'Saving model weights: {model_save_fn}')
torch.save(model.state_dict(), model_save_fn)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
# Forecast the test period with the trained model. The training mean/std are passed,
# presumably so the output is returned in physical units; confirm in create_predictions.
pred = train_nn_pytorch.create_predictions(model, dg_test_generator, 
                        mean=dg_train.mean.values, std=dg_train.std.values,
                        var_dict=dg_test.var_dict,
                        valid_time=dg_test.valid_time, 
                        lat=dg_test.ds.lat, 
                        lon=dg_test.ds.lon 
                       )
print(f'Saving predictions: {pred_save_fn}')
pred.to_netcdf(pred_save_fn)

# Print the weighted RMSE against the test-period ground truth.
evaluate(pred, iterative, test_years)

# Iterative forecast: 6-hour CNN rolled out to 120 hours (fccnn_6h_iter)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )

print (60*"-")
batch_size = 32
# The model learns a single 6-hour step.
lead_time = 6
# Intended rollout horizon (120 h = 5 days); check that the prediction cell passes it
# as max_lead_time.
iterative_lead_time = 5 * 24 
print ('save_prefix:', save_prefix)
# NOTE(review): this file receives a PyTorch state_dict (torch.save) despite the .h5 extension.
model_save_fn=f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/saved_models/{save_prefix}_fccnn_6h_iter.h5'
pred_save_fn =f'/scratch/vp91/mah900/NCI-Weatherbench/pred_dir/{save_prefix}_fccnn_6h_iter.nc'
lr = 1e-4
# True selects the iterative scoring path (evaluate_iterative_forecast) in evaluate().
iterative = True

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
dg_train, dg_valid, dg_test = data_generate(ds, lead_time, batch_size,
                                           variables, train_years, valid_years, test_years)

params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}
dg_train_generator = torch.utils.data.DataLoader(dg_train, **params)
params['shuffle'] = False
dg_valid_generator = torch.utils.data.DataLoader(dg_valid, **params)
dg_test_generator = torch.utils.data.DataLoader(dg_test, **params)

model = train(dg_train_generator, dg_valid_generator, dg_test_generator, model_save_fn)

print(f'Saving model weights: {model_save_fn}')
torch.save(model.state_dict(), model_save_fn)

In [None]:
%%time
print( f'[{datetime.now().replace(microsecond=0)}]' )
    
pred = train_nn_pytorch.create_iterative_predictions(model, dg_test_generator, 
                                        max_lead_time= 5*24,
                                        mean=dg_train.mean.values, std=dg_train.std.values,
                                        var_dict=dg_test.var_dict,
                                        valid_time=dg_test.valid_time, 
                                        lat=dg_test.ds.lat, 
                                        lon=dg_test.ds.lon,
                                        state=dg_test.data[:dg_test.n_samples], 
                                        lead_time=dg_test.lead_time, 
                                        init_time=dg_test.init_time
                                        )
print(f'Saving predictions: {pred_save_fn}')
pred.to_netcdf(pred_save_fn)

evaluate(pred, iterative, test_years)

In [None]:
# Display the most recent forecast (here the iterative 6-hour rollout).
pred