In [1]:
import sys
# Append sys.path[0] with its last '/'-separated component removed (i.e. the
# parent of the first search-path entry) to the import search path.
# NOTE(review): presumably this is what lets the project-local `modules`
# package imported below resolve -- confirm against the repository layout.
sys.path.append('/'.join(sys.path[0].split('/')[:-1]))

In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import time


import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

from modules.data import Dataset_WeatherBench_1D, load_test_data
from modules.models import SphericalCNN
from modules.test import create_predictions, compute_weighted_rmse, assess_model

In [3]:
# Root directory holding the per-variable NetCDF folders that are opened below
# (absolute, machine-specific path).
datadir = "/mnt/scratch/students/illorens/data/5.625deg/"

# (start, end) year labels; each pair is unpacked into slice(...) to select the
# train / validation / test period along the time axis.
train_years=('1979', '2015')
valid_years=('2016', '2016')
test_years=('2017', '2018')
# Variable names used as the keys of the dict handed to Dataset_WeatherBench_1D.
# NOTE(review): this name shadows the builtin `vars`.
vars = ['z', 't']

In [4]:
def init_device(model, gpu=None):
    """Place ``model`` on the available device(s) and return both.

    Args:
        model (torch.nn.Module): the model to place on the device(s).
        gpu (int or sequence of int/str, optional): CUDA device index or
            indices to use. ``None`` (default) selects the default CUDA device
            and wraps the model in ``nn.DataParallel`` with its default device
            set. A single index places the model on that GPU only. Several
            indices place the model on the first one and wrap it in
            ``nn.DataParallel`` over all of them. Ignored when CUDA is not
            available.

    Raises:
        ValueError: if CUDA is available and ``gpu`` is an empty sequence.

    Returns:
        (torch.nn.Module, torch.device): the model placed on the device
        (possibly wrapped in ``nn.DataParallel``) and the primary device.
    """
    if not torch.cuda.is_available():
        device = torch.device("cpu")
        return model.to(device), device

    if gpu is None:
        device = torch.device("cuda")
        return nn.DataParallel(model.to(device)), device

    # Accept a bare index (e.g. gpu=1) as shorthand for a one-element list.
    if isinstance(gpu, int):
        gpu = [gpu]
    if len(gpu) == 0:
        raise ValueError("gpu must be None or contain at least one device index")

    # The first listed GPU is the primary device in both remaining cases.
    device = torch.device("cuda:{}".format(gpu[0]))
    model = model.to(device)
    if len(gpu) > 1:
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpu])

    return model, device


def train_model(model, device, train_generator, epochs, lr, validation_data, patience):
    """Train ``model`` with Adam on an MSE loss, early-stopping on validation loss.

    Args:
        model (torch.nn.Module): model to optimise. It is not moved by this
            function, so it is expected to already live on ``device``.
        device (torch.device): device every batch and label tensor is moved to.
        train_generator (DataLoader): yields ``(batch, labels)`` training pairs;
            ``len(train_generator.dataset)`` normalises the epoch loss.
        epochs (int): maximum number of epochs.
        lr (float): Adam learning rate.
        validation_data (DataLoader): yields ``(batch, labels)`` pairs that are
            evaluated, without gradients, after every epoch.
        patience (int): number of consecutive epochs without a new minimum
            validation loss that are tolerated. Training stops at the next
            such epoch, i.e. on the ``patience + 1``-th consecutive epoch
            without improvement.

    Returns:
        (list of float, list of float): per-epoch sample-weighted mean training
        losses and validation losses, one entry per completed epoch.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, eps=1e-7, weight_decay=0, amsgrad=False)

    min_val_loss = 1e15
    wait = 0  # consecutive epochs without a new best validation loss

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        epoch_start = time.time()

        # ---- training pass ----
        model.train()
        train_loss = 0
        for batch, labels in train_generator:
            batch, labels = batch.to(device), labels.to(device)

            loss = criterion(model(batch), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by the batch size so a smaller final batch does not bias
            # the epoch mean.
            train_loss += loss.item() * batch.shape[0]

        train_loss = train_loss / len(train_generator.dataset)
        train_losses.append(train_loss)

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch, labels in validation_data:
                batch, labels = batch.to(device), labels.to(device)
                val_loss += criterion(model(batch), labels).item() * batch.shape[0]

        val_loss = val_loss / len(validation_data.dataset)
        val_losses.append(val_loss)

        elapsed = time.time() - epoch_start

        # Elapsed time is printed with two decimals.
        print('Epoch: {e:3d}/{n_e:3d}  - loss: {l:.3f}  - val_loss: {v_l:.5f}  - time: {t:.2f}'
              .format(e=epoch+1, n_e=epochs, l=train_loss, v_l=val_loss, t=elapsed))

        # ---- early stopping ----
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            wait = 0
        elif wait >= patience:
            print('Epoch {e:3d}: early stopping'.format(e=epoch + 1))
            break
        else:
            wait += 1

    return train_losses, val_losses

## 3-day prediction

### Data ###

In [5]:
# Number of samples per batch for the train / validation / test DataLoaders.
batch_size = 128

In [6]:
# Lead time handed to the dataset constructor and to load_test_data
# (72, matching the "3 day" section heading if the unit is hours).
lead_time = 72
model_save_fn = "/mnt/scratch/students/illorens/data/predictions/models/torch_fccnn_3d.h5"
pred_save_fn = "/mnt/scratch/students/illorens/data/predictions/torch_fccnn_3d.nc"

# 1. Open dataset and create data generators
z = xr.open_mfdataset(f'{datadir}geopotential_500/*.nc', combine='by_coords')
t = xr.open_mfdataset(f'{datadir}temperature_850/*.nc', combine='by_coords')
ds = xr.merge([z, t], compat='override')  # Override level. discarded later anyway.

# Grid dimensions. `Dataset.sizes` is the supported mapping from dimension
# name to length; using `Dataset.dims` as a mapping is deprecated in xarray.
lat = ds.sizes['lat']
lon = ds.sizes['lon']
nodes = lat * lon
ratio = lon/lat

# Split by year range along the time axis.
ds_train = ds.sel(time=slice(*train_years))
ds_valid = ds.sel(time=slice(*valid_years))
ds_test = ds.sel(time=slice(*test_years))

dic = {var: None for var in vars}

# Validation and test sets reuse the training set's mean/std so all three
# splits share the same normalisation statistics.
dataset_train = Dataset_WeatherBench_1D(ds_train, dic, lead_time)
dataset_valid = Dataset_WeatherBench_1D(ds_valid, dic, lead_time, mean=dataset_train.mean, std=dataset_train.std)
dataset_test = Dataset_WeatherBench_1D(ds_test, dic, lead_time, mean=dataset_train.mean, std=dataset_train.std)

# Only the training loader shuffles; validation/test keep dataset order.
dg_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=1)
dg_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=1)
dg_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=1)

Loading data into RAM
Loading data into RAM
Loading data into RAM


### Models ###

In [7]:
# Optimiser learning rate, maximum number of epochs and early-stopping
# patience passed to train_model.
lr = 1e-4
epochs = 100
patience = 3
# NOTE(review): no visible cell reads `gpu`; the training cell passes a literal
# gpu=[1] to init_device -- confirm whether this value is meant to drive it.
gpu = 1

#### Spherical Graph CNN, WeatherBench architecture ####

In [None]:
# Build the model and place it on GPU 1 (falls back to CPU without CUDA).
spherical_cnn = SphericalCNN(N=nodes, ratio=ratio, in_channels=2, out_channels=2, kernel_size=5)
spherical_cnn, device = init_device(spherical_cnn, gpu=[1])

# Train model
train_loss, val_loss = train_model(spherical_cnn, device, dg_train, epochs=epochs, lr=lr, 
                                   validation_data=dg_valid, patience=patience)

# Save the trained network. The model in this notebook is `spherical_cnn`;
# no variable called `model` is ever defined.
print(f'Saving model weights: {model_save_fn}')
torch.save(spherical_cnn.state_dict(), model_save_fn)

# Create predictions
pred = create_predictions(spherical_cnn, dg_test, lat, lon, mean=dataset_train.mean, std=dataset_train.std)
print(f'Saving predictions: {pred_save_fn}')
pred.to_netcdf(pred_save_fn)

In [None]:
# Training vs. validation loss per epoch, drawn on an explicit Axes.
fig, ax = plt.subplots()
ax.plot(train_loss, label='Training loss')
ax.plot(val_loss, label='Validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('MSE Loss')
ax.legend()
plt.show()

In [None]:
# Load the reference data for this lead time and display the weighted RMSE of
# the predictions against it (.load() forces evaluation of the result).
valid = load_test_data(datadir, lead_time)
compute_weighted_rmse(pred, valid).load()

In [None]:
# Run the project's assessment helper on the predictions; the string is the
# label passed for this model/lead-time combination.
assess_model(pred, valid, 'Spherical_CNN_weatherbench_architecture_3_days')