### Check the influence of different parameters on performance

Check the influence of:
- specifying different chunk sizes and chunking along different dimensions
- using already standardized data (does it save memory?)
- using `.persist()` to load data in a distributed way and speed up reading



In [1]:
# Number of time steps per dask chunk; passed as chunks={'time': chunk_size}
# when the input files are opened. The trailing commented values look like
# previously tried settings.
chunk_size = 521#483*2 #483

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import warnings
warnings.filterwarnings("ignore")

In [4]:
import sys
sys.path.append('/'.join(sys.path[0].split('/')[:-1]))

import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import healpy as hp
import random

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

from modules.utils import train_model_2steps, init_device
from modules.data import WeatherBenchDatasetXarrayHealpix
from modules.healpix_models import UNetSphericalHealpix
from modules.test import create_iterative_predictions_healpix
from modules.test import compute_rmse_healpix
from modules.plotting import plot_rmses

# Locations of the input data and of the saved models / predictions
datadir = "../data/healpix/"
input_dir = datadir + "5.625deg_nearest/"
model_save_path = datadir + "models/"
pred_save_path = datadir + "predictions/"

# (first, last) year of each split; unpacked into ds.sel(time=slice(...)) below
train_years = ('1979', '2012')
val_years = ('2013', '2016')
test_years = ('2017', '2018')

# Number of nodes per sample
# (presumably 12 * nside**2 HEALPix pixels with nside = 16 -- TODO confirm)
nodes = 12*16*16
# Maximum lead time for iterative predictions, in hours (5 days)
max_lead_time = 5*24
lead_time = 6
# Number of features the model predicts
out_features = 2
nb_timesteps = 2
# Length of the input and output sequences
len_sqce = 2
# define time resolution
delta_t = 6

# Enumerate GPUs by PCI bus id and expose only physical devices 2 and 4 to this
# process; they are then addressed as devices 0 and 1 through `gpu`.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2,4"
gpu = [0,1]
# DataLoader settings
num_workers = 10
pin_memory = True
batch_size = 95

# Optimisation settings
nb_epochs = 10
learning_rate = 8e-3

#obs = xr.open_mfdataset(pred_save_path + 'observations_nearest.nc', combine='by_coords', chunks={'time':483})
#rmses_weyn = xr.open_dataset(datadir + 'metrics/rmses_weyn.nc')

In [5]:
from modules.data import WeatherBenchDatasetIterative
class WeatherBenchDatasetXarrayHealpixTemp(Dataset):
    """ Dataset used for graph models (1D), backed by an xarray Dataset.

    Every variable of ``ds`` becomes one feature along a new ``level``
    dimension; the data is rearranged to ``(time, node, level)`` and
    standardized per feature.

    Parameters
    ----------
    ds : xarray Dataset
        Dataset containing the input data
    out_features : int
        Number of output features
    delta_t : int
        Temporal spacing between samples in temporal sequence (in hours).
        It is applied as an index offset along ``time``
        (presumably the data is hourly -- TODO confirm)
    len_sqce : int
        Length of the input and output (predicted) sequences
    years : tuple(str)
        Years used to split the data (only stored on the instance)
    nodes : int
        Number of nodes each sample has
    nb_timesteps : int
        Stored on the instance; not used by this class itself
    max_lead_time : int
        Maximum lead time (in case of iterative predictions) in hours
    load : bool
        If true, persist the standardized data (``.persist()``) so that it
        is not recomputed on every access. If false, the data stays lazy.
    mean : xarray Dataset, xarray DataArray or np.ndarray
        Per-feature mean used for normalization. If None, the mean is
        computed from the data
    std : xarray Dataset, xarray DataArray or np.ndarray
        Per-feature std used for normalization. If None, the std is
        computed from the data
    """

    def __init__(self, ds, out_features, delta_t, len_sqce, years, nodes, nb_timesteps,
                 max_lead_time=None, load=True, mean=None, std=None):

        self.delta_t = delta_t
        self.len_sqce = len_sqce
        self.years = years

        self.nodes = nodes
        self.out_features = out_features
        self.max_lead_time = max_lead_time
        self.nb_timesteps = nb_timesteps

        self.data = ds.to_array(dim='level', name='Dataset').transpose('time', 'node', 'level')
        self.in_features = self.data.shape[-1]

        self.mean = self.data.mean(('time', 'node')).compute() if mean is None else mean
        self.std = self.data.std(('time', 'node')).compute() if std is None else std

        eps = 0.001  # add to std to avoid division by 0

        # Count total number of samples
        total_samples = self.data.shape[0]

        if max_lead_time is None:
            self.n_samples = total_samples - (len_sqce + 1) * delta_t
        else:
            self.n_samples = total_samples - (len_sqce + 1) * delta_t - max_lead_time

        # Normalize. Statistics given as a Dataset are stacked along 'level';
        # statistics computed above already carry that dimension.
        self.data = (self.data - self._per_level(self.mean)) / (self._per_level(self.std) + eps)

        # persist() returns a new object, so the result has to be kept;
        # otherwise the computation is triggered and thrown away.
        if load:
            self.data = self.data.persist()

        self.idxs = np.arange(self.n_samples)

    @staticmethod
    def _per_level(stat):
        """ Return ``stat`` with one entry per feature along ``level``. """
        return stat.to_array(dim='level') if hasattr(stat, 'to_array') else stat

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """ Returns sample and labels corresponding to one index or a batch of indices

        ``idx`` selects positions in ``self.idxs``. For an array, list or
        slice the tensors have shape [len(idx), nodes, n_features]; for a
        single integer (as passed by a DataLoader) the leading dimension
        is dropped.

        Returns
        -------
        X : tuple of one tensor
            Input taken at ``delta_t`` after the selected start index
        y : tuple of two tensors
            All features at ``delta_t * len_sqce`` and the first
            ``out_features`` features at ``delta_t * (len_sqce + 1)``
        """
        idx_data = self.idxs[idx]
        single = np.ndim(idx_data) == 0
        idx_data = np.atleast_1d(idx_data)
        n = len(idx_data)

        # Read the three time steps of every sample with a single isel call
        idx_full = np.concatenate([idx_data + self.delta_t,
                                   idx_data + self.delta_t * self.len_sqce,
                                   idx_data + self.delta_t * (self.len_sqce + 1)])
        dat = self.data.isel(time=idx_full).values

        X = (
            torch.tensor(dat[:n, :, :], dtype=torch.float).reshape(n, self.nodes, -1),
        )

        y = (
            torch.tensor(dat[n:n * 2, :, :], dtype=torch.float).reshape(n, self.nodes, -1),
            torch.tensor(dat[n * 2:, :, :self.out_features],
                         dtype=torch.float).reshape(n, self.nodes, -1),
        )

        if single:
            X = tuple(t.squeeze(0) for t in X)
            y = tuple(t.squeeze(0) for t in y)
        return X, y

`standardization_contants` is a boolean flag that enables the creation of the files containing the standardization statistics and the already standardized constants. Set it to True only the first time the notebook is executed, or if those files were lost.

In [6]:
# When True, the guarded cells below recompute the standardization statistics
# and write them to disk; when False those cells do nothing.
standardization_contants = False

In [7]:
# One-off preparation: standardize the constant fields and store the result.
if standardization_contants:
    # Load the raw constants and add cos/sin of the longitude as extra fields
    constants = xr.open_dataset(f'{input_dir}constants/constants_5.625deg.nc').rename({'orography' :'orog'})
    constants = constants.assign(cos_lon=lambda x: np.cos(np.deg2rad(x.lon)))
    constants = constants.assign(sin_lon=lambda x: np.sin(np.deg2rad(x.lon)))
    
    # Per-variable mean and std, reduced over all dimensions
    constants_mean = constants.mean().compute()
    constants_std = constants.std().compute()
    
    constants_mean.to_netcdf(f'{input_dir}constants/mean.nc')
    constants_std.to_netcdf(f'{input_dir}constants/std.nc')
    
    # Read the statistics back from the files that were just written
    c_mean = xr.open_dataset(f'{input_dir}constants/mean.nc')
    c_std = xr.open_dataset(f'{input_dir}constants/std.nc')
    
    constants_ss = (constants - c_mean)/c_std
    
    # This file is the one opened by the data-loading cell
    constants_ss.to_netcdf(f'{input_dir}constants/constants_5.625deg_standardized.nc')

In [8]:
# Open the dynamic fields lazily, chunked along time with chunk_size steps per chunk
z500 = xr.open_mfdataset(f'{input_dir}geopotential_500/*.nc', combine='by_coords', chunks={'time':chunk_size}).rename({'z':'z500'})
t850 = xr.open_mfdataset(f'{input_dir}temperature_850/*.nc', combine='by_coords', chunks={'time':chunk_size}).rename({'t':'t850'})
rad = xr.open_mfdataset(f'{input_dir}toa_incident_solar_radiation/*.nc', combine='by_coords', chunks={'time':chunk_size})

# Drop the first 7 time steps of z500 and t850; rad is left untouched
# (presumably so that all three series start at the same time -- TODO confirm)
z500 = z500.isel(time=slice(7, None))
t850 = t850.isel(time=slice(7, None))

# Constants are read from the already standardized file
constants = xr.open_dataset(f'{input_dir}constants/constants_5.625deg_standardized.nc')
#constants = constants.assign(cos_lon=lambda x: np.cos(np.deg2rad(x.lon)))
#constants = constants.assign(sin_lon=lambda x: np.sin(np.deg2rad(x.lon)))

#temp = xr.DataArray(np.zeros(z500.dims['time']), coords=[('time', z500.time.values)])
#constants, _ = xr.broadcast(constants, temp)

orog = constants['orog']
lsm = constants['lsm']
lats = constants['lat2d']
slt = constants['slt']
# cos_lon and sin_lon are extracted but not included in constants_tensor below
cos_lon = constants['cos_lon']
sin_lon = constants['sin_lon']

num_constants = len([orog, lats, lsm, slt])
# Stack the four constant fields into one float tensor, with the variables on
# the first axis
constants_tensor = torch.tensor(xr.merge([orog, lats, lsm, slt], compat='override').to_array().values, \
                            dtype=torch.float)

In [9]:
# Tag used in the file names of the saved model, predictions and metrics
#description = "no_const"
description = "all_const"

model_filename = model_save_path + "spherical_unet_" + description + ".h5"
pred_filename = pred_save_path + "spherical_unet_" + description + ".nc"
rmse_filename = datadir + 'metrics/rmse_' + description + '.nc'

# z500, t850, orog, lats, lsm, slt, rad
#feature_idx = [0, 1]
# Number of model input channels: the dynamic fields merged into ds plus the
# four constants that are concatenated at training time
in_features = 7 #len(feature_idx)
ds = xr.merge([z500, t850, rad], compat='override')
#ds = xr.merge([z500, t850, orog, lats, lsm, slt, rad], compat='override')

# Split by year range
ds_train = ds.sel(time=slice(*train_years))
ds_valid = ds.sel(time=slice(*val_years))
ds_test = ds.sel(time=slice(*test_years))


In [10]:
# One-off preparation: compute the training-set statistics and store them.
if standardization_contants:
    # Mean over time and nodes; std over time, then averaged over nodes
    mean_features = ds_train.mean(('time','node')).compute()
    std_features = ds_train.std('time').mean('node').compute()

    mean_features.to_netcdf(f'{input_dir}mean_train_features_dynamic.nc')
    std_features.to_netcdf(f'{input_dir}std_train_features_dynamic.nc')


In [11]:
# Training-set mean and std of the dynamic features, read from the files written
# by the guarded cell above; both are passed to the datasets as mean / std.
train_mean_ = xr.open_mfdataset(f'{input_dir}mean_train_features_dynamic.nc')
train_std_ = xr.open_mfdataset(f'{input_dir}std_train_features_dynamic.nc')

In [12]:
# Train and validation data
training_ds = WeatherBenchDatasetXarrayHealpixTemp(ds=ds_train, out_features=out_features, delta_t=delta_t,
                                                   len_sqce=len_sqce, max_lead_time=max_lead_time,
                                                   years=train_years, nodes=nodes, nb_timesteps=nb_timesteps, 
                                                   mean=train_mean_, std=train_std_, load=False)
validation_ds = WeatherBenchDatasetXarrayHealpixTemp(ds=ds_valid, out_features=out_features, delta_t=delta_t,
                                                     len_sqce=len_sqce, max_lead_time=max_lead_time,
                                                     years=train_years, nodes=nodes, nb_timesteps=nb_timesteps, 
                                                     mean=train_mean_, std=train_std_, load=False)

dl_train = DataLoader(training_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,\
                      pin_memory=pin_memory)

dl_val = DataLoader(validation_ds, batch_size=batch_size*2, shuffle=False, num_workers=num_workers,\
                    pin_memory=pin_memory)

In [13]:
# Model #old: in_channels=in_features*len_sqce
# Build the spherical U-Net with in_features input channels and out_features
# output channels, then hand it to init_device together with the GPU list;
# init_device returns the model and the device used below.
spherical_unet = UNetSphericalHealpix(N=nodes, in_channels=in_features, out_channels=out_features, 
                                      kernel_size=3)
spherical_unet, device = init_device(spherical_unet, gpu=gpu)

In [14]:
# Release unused cached GPU memory held by PyTorch before training starts
torch.cuda.empty_cache()

In [15]:
def train_model_2steps_custom(model, device, training_ds, constants, batch_size, epochs, lr, validation_data):
    """Train ``model`` on a two-step loss and print per-batch timings.

    Each batch is run through the model twice. The first pass receives the
    input features plus the constants. The second pass receives the first
    prediction, the last feature of the first label and the constants. The
    loss is the sum of the MSE of both predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; it must already live on ``device``
    device : torch.device
        Device the batches are moved to
    training_ds : dataset
        Object exposing ``n_samples``, ``nodes``, ``out_features`` and
        ``idxs``. Indexing it with a batch of indices must return
        ``((inputs,), (label1, label2))``. ``idxs`` is shuffled in place
        at the start of every epoch.
    constants : torch.Tensor
        Constant features of shape (nodes, num_constants), appended to
        every sample
    batch_size : int
        Number of samples per training batch; the last batch may be smaller
    epochs : int
        Number of epochs
    lr : float
        Learning rate of the Adam optimizer
    validation_data : iterable
        Yields ``((inputs,), (label1, label2))`` batches, e.g. a DataLoader

    Returns
    -------
    train_losses, val_losses : list of float
        Mean loss per sample for every epoch
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, eps=1e-7, weight_decay=0, amsgrad=False)

    train_losses = []
    val_losses = []
    n_samples = training_ds.n_samples
    num_nodes = training_ds.nodes
    num_constants = constants.shape[1]
    out_features = training_ds.out_features

    # Transfer the constants once; expand() below only creates views, so the
    # batch dimension can follow the size of each individual batch.
    constants_device = constants.to(device)

    for epoch in range(epochs):

        print('\rEpoch : {}'.format(epoch), end="")

        time1 = time.time()

        val_loss = 0
        train_loss = 0
        samples_seen = 0

        model.train()

        random.shuffle(training_ds.idxs)
        idxs = training_ds.idxs

        # Step through all samples; the final batch may hold fewer than
        # batch_size samples.
        for batch_idx, start in enumerate(range(0, n_samples, batch_size)):
            batch_idxs = idxs[start:start + batch_size]
            current_bs = len(batch_idxs)

            t1 = time.time()
            batch, labels = training_ds[batch_idxs]
            t2 = time.time()

            # Transfer to GPU
            inputs = batch[0].to(device)
            label1 = labels[0].to(device)
            label2 = labels[1].to(device)
            constants1 = constants_device.expand(current_bs, num_nodes, num_constants)
            batch1 = torch.cat((inputs, constants1), dim=2)
            t3 = time.time()

            # Model
            t4 = time.time()
            output1 = model(batch1)
            t5 = time.time()
            # Second step: prediction + last feature of label1 + constants
            batch2 = torch.cat((output1, label1[:, :, -1:], constants1), dim=2)
            t6 = time.time()
            output2 = model(batch2)
            t7 = time.time()
            loss = criterion(output1, label1[:, :, :out_features]) + criterion(output2, label2)
            t8 = time.time()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss = train_loss + loss.item() * current_bs
            samples_seen += current_bs

            print('\nTime to read batch: {}s'.format(t2-t1))
            print('Time to transfer data to GPU: {}s'.format(t3-t2))
            print('Time to process input 1: {}s'.format(t5-t4))
            print('Time to process input 2: {}s'.format(t7-t6))
            print('Time to compute loss: {}s'.format(t8-t7))
            print('\n')
            print('\rBatch idx: {}; Loss: {:.3f}'.format(batch_idx, train_loss/samples_seen), end="")

        train_loss = train_loss / max(samples_seen, 1)
        train_losses.append(train_loss)

        model.eval()
        val_samples = 0
        with torch.no_grad():
            for batch, labels in validation_data:
                # Transfer to GPU
                inputs = batch[0].to(device)
                label1 = labels[0].to(device)
                label2 = labels[1].to(device)

                # Validation batches may differ in size from training batches
                current_bs = inputs.shape[0]
                constants1 = constants_device.expand(current_bs, num_nodes, num_constants)

                output1 = model(torch.cat((inputs, constants1), dim=2))
                # Same second-step input layout as during training
                batch2 = torch.cat((output1, label1[:, :, -1:], constants1), dim=2)
                output2 = model(batch2)

                val_loss = val_loss + (criterion(output1, label1[:, :, :out_features]).item()
                                       + criterion(output2, label2).item()) * current_bs
                val_samples += current_bs

        val_loss = val_loss / max(val_samples, 1)
        val_losses.append(val_loss)

        time2 = time.time()

        # Print stuff
        print('Epoch: {e:3d}/{n_e:3d}  - loss: {l:.3f}  - val_loss: {v_l:.5f}  - time: {t:2f}'
              .format(e=epoch+1, n_e=epochs, l=train_loss, v_l=val_loss, t=time2-time1))

    return train_losses, val_losses

In [16]:
# Run the instrumented training loop. constants_tensor is transposed so that its
# second axis indexes the constants, which is how the function reads it
# (constants.shape[1] is taken as the number of constants).
train_model_2steps_custom(spherical_unet, device, training_ds, constants_tensor.transpose(1,0), batch_size, epochs=7, \
                                           lr=learning_rate, validation_data=dl_val)

Epoch : 0
Time to read batch: 2.8489584922790527s
Time to transfer data to GPU: 0.005797386169433594s
Time to process input 1: 2.5217576026916504s
Time to process input 2: 0.10119962692260742s
Time to compute loss: 0.048406362533569336s


Batch idx: 0; Loss: 23.178
Time to read batch: 3.2484841346740723s
Time to transfer data to GPU: 0.005156040191650391s
Time to process input 1: 0.05000710487365723s
Time to process input 2: 0.04090094566345215s
Time to compute loss: 0.15488314628601074s


Batch idx: 1; Loss: 17.992
Time to read batch: 2.227978467941284s
Time to transfer data to GPU: 0.0050048828125s
Time to process input 1: 0.04160046577453613s
Time to process input 2: 0.04461359977722168s
Time to compute loss: 0.14378929138183594s


Batch idx: 2; Loss: 14.503
Time to read batch: 2.5850720405578613s
Time to transfer data to GPU: 0.004888057708740234s
Time to process input 1: 0.039263248443603516s
Time to process input 2: 0.03805685043334961s
Time to compute loss: 0.15384244918823242s


Batch idx: 32; Loss: 3.051
Time to read batch: 2.4416162967681885s
Time to transfer data to GPU: 0.00510859489440918s
Time to process input 1: 0.038666725158691406s
Time to process input 2: 0.039018869400024414s
Time to compute loss: 0.15257787704467773s


Batch idx: 33; Loss: 3.004
Time to read batch: 2.553452730178833s
Time to transfer data to GPU: 0.004793882369995117s
Time to process input 1: 0.046502113342285156s
Time to process input 2: 0.038489341735839844s
Time to compute loss: 0.15085411071777344s


Batch idx: 34; Loss: 2.957
Time to read batch: 2.407186985015869s
Time to transfer data to GPU: 0.005374431610107422s
Time to process input 1: 0.04132390022277832s
Time to process input 2: 0.038373470306396484s
Time to compute loss: 0.15488052368164062s


Batch idx: 35; Loss: 2.913
Time to read batch: 2.247410774230957s
Time to transfer data to GPU: 0.004984140396118164s
Time to process input 1: 0.04841923713684082s
Time to process input 2: 0.038690805435180664s
Time to compute los

KeyboardInterrupt: 