### Train best model according to report to try to reproduce results

- All input features (the dynamic fields Z500 and T850 plus all static/prescribed features): 
    * Z500, 
    * T850, 
    * latitude, 
    * orography, 
    * land-sea mask, 
    * soil type, and 
    * top-of-atmosphere radiation
- $L = 2$ (sequence length taken into account for the loss, `len_sqce`)
- $\Delta t = 6$ (time resolution, `delta_t`)

#### Import libraries

In [None]:
import sys
sys.path.append('/'.join(sys.path[0].split('/')[:-1]))

import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import healpy as hp

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

from modules.utils import train_model_2steps_temp, init_device
from modules.data import WeatherBenchDatasetXarrayHealpixTemp
from modules.healpix_models import UNetSphericalHealpix
from modules.test import create_iterative_predictions_healpix_temp
from modules.test import compute_rmse_healpix
from modules.plotting import plot_rmses

# Root of the HEALPix-resampled data; every other path is derived from it.
datadir = "../data/healpix/"
input_dir = datadir + "5.625deg/"
model_save_path = datadir + "models/"
pred_save_path = datadir + "predictions/"

# Create the output directories up front. `makedirs(..., exist_ok=True)` also
# creates missing parents and does not fail if the directory already exists,
# so the cell is safe to re-run on a fresh checkout.
os.makedirs(model_save_path, exist_ok=True)
os.makedirs(pred_save_path, exist_ok=True)

Define constants and load data

In [None]:
# --- Data splits: (first year, last year) pairs used to slice the time axis ---
train_years = ('2000', '2012')  # full range would be ('1979', '2012')
val_years = ('2013', '2016')
test_years = ('2017', '2018')

# --- Problem size ---
# presumably 12 * nside**2 HEALPix pixels with nside = 16 -- TODO confirm
nodes = 12 * 16 * 16
max_lead_time = 5 * 24
nb_timesteps = 2

# --- Hardware / data loading ---
# Expose physical GPUs 0 and 2 (ordered by PCI bus id) to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,2",
})
# NOTE(review): these look like indices into the *visible* devices above -- confirm against init_device
gpu = [0, 1]
num_workers = 10
pin_memory = True

# --- Optimisation ---
nb_epochs = 20
learning_rate = 8e-3

# Ground-truth observations used later to score the predictions.
obs = xr.open_mfdataset(f'{pred_save_path}observations.nc', combine='by_coords')
# Baseline RMSEs (currently disabled; the final RMSE cell references `rmses_weyn`):
#rmses_weyn = xr.open_dataset(datadir + 'metrics/rmses_weyn.nc')

Define functions:

**TODO**
Check whether the code is the same as the functions with the same name in ```modules/*.py``` and, if so, substitute it with imports.

#### Load data

In [None]:
# Dynamic fields: open every yearly file lazily and give the variables explicit names.
z500 = xr.open_mfdataset(f'{input_dir}geopotential_500/*.nc', combine='by_coords').rename({'z':'z500'})
t850 = xr.open_mfdataset(f'{input_dir}temperature_850/*.nc', combine='by_coords').rename({'t':'t850'})
rad = xr.open_mfdataset(f'{input_dir}toa_incident_solar_radiation/*.nc', combine='by_coords')

# Drop the first 7 time steps of both dynamic fields.
# NOTE(review): presumably to align their start with the radiation record -- confirm.
z500 = z500.isel(time=slice(7, None))
t850 = t850.isel(time=slice(7, None))

# Time-independent fields, plus cos/sin encodings of the longitude.
constants = xr.open_dataset(f'{input_dir}constants/constants_5.625deg.nc').rename({'orography' :'orog'})
constants = constants.assign(cos_lon=lambda x: np.cos(np.deg2rad(x.lon)))
constants = constants.assign(sin_lon=lambda x: np.sin(np.deg2rad(x.lon)))

# Broadcast the constants along the time axis of z500 using a dummy array.
# `.sizes` is the supported mapping from dimension name to length
# (the mapping behaviour of `Dataset.dims[...]` is deprecated in xarray).
temp = xr.DataArray(np.zeros(z500.sizes['time']), coords=[('time', z500.time.values)])
constants, _ = xr.broadcast(constants, temp)

# Pull out the individual (now time-broadcast) constant fields.
orog = constants['orog']
lsm = constants['lsm']
lats = constants['lat2d']
slt = constants['slt']
cos_lon = constants['cos_lon']
sin_lon = constants['sin_lon']

In [None]:
# Raw geopotential / temperature arrays, each tagged with a scalar `level` coordinate.
# NOTE(review): both fields receive level=1, so the concatenated `level` axis
# carries duplicate labels -- confirm this is intended.
z = (
    xr.open_mfdataset(f'{input_dir}geopotential_500/*.nc', combine='by_coords')['z']
    .assign_coords(level=1)
)

t = (
    xr.open_mfdataset(f'{input_dir}temperature_850/*.nc', combine='by_coords')['t']
    .assign_coords(level=1)
)

# Stack the two fields along a new `level` dimension.
predictors = xr.concat([z, t], dim='level')

In [None]:
#predictors_mean = predictors.mean(('time','node')).compute()
#predictors_std = predictors.std('time').mean('node').compute()

#const_mean = constants.mean(('time','node')).compute()
#const_std = constants.std('time').mean(('node')).compute()

In [None]:
# Inputs: z500, t850, orog, lats, lsm, slt, rad  -> 7 channels in, 2 channels out.
in_features, out_features = 7, 2

# Merge all fields into one dataset; on conflicts the first object wins.
ds = xr.merge(
    [z500, t850, orog, lats, lsm, slt, rad],
    compat='override',
)

# Split by year range.
ds_train = ds.sel(time=slice(train_years[0], train_years[1]))
ds_valid = ds.sel(time=slice(val_years[0], val_years[1]))
ds_test = ds.sel(time=slice(test_years[0], test_years[1]))

In [None]:
# Normalisation statistics, computed on the training split only.
# Mean: over time and nodes. Std: over time, then averaged over nodes.
train_mean = ds_train.mean(dim=('time', 'node')).compute()
train_std = ds_train.std(dim='time').mean(dim='node').compute()

#### Define model parameters

In [None]:
# Length of the sequence taken into account for the loss.
len_sqce = 2
# Time resolution.
delta_t = 6

# Forecast horizon: 5 days.
max_lead_time = 5 * 24

# Channel counts and the indices of the features that are used.
in_features = 7
out_features = 2
feature_idx = list(range(in_features))

In [None]:
#del train_mean_
#del train_std_

In [None]:
len_sqce, delta_t = 2, 6

# Tag identifying this experiment; it is embedded in every output file name.
description = f"all_const_len{len_sqce}_delta{delta_t}"

# Output locations for model weights, predictions and RMSE scores.
model_filename = f"{model_save_path}spherical_unet_{description}.h5"
pred_filename = f"{pred_save_path}spherical_unet_{description}.nc"
rmse_filename = f"{datadir}metrics/rmse_{description}.nc"

**Attention:**

If ```load=True```, the kernel dies. Find the origin of the problem and check whether loading the data is necessary at all.

In [None]:
# Train and validation data
training_ds = WeatherBenchDatasetXarrayHealpixTemp(ds=ds_train, out_features=out_features,
                                                   len_sqce=len_sqce, delta_t=delta_t, years=train_years,
                                                   nodes=nodes, nb_timesteps=nb_timesteps, 
                                                   mean=train_mean, std=train_std)

In [None]:
# Display the dataset's `data` attribute as the cell output (inspection only).
training_ds.data

In [None]:
# Validation dataset; it reuses the *training* mean/std for normalisation.
validation_ds = WeatherBenchDatasetXarrayHealpixTemp(
    ds=ds_valid,
    out_features=out_features,
    len_sqce=len_sqce,
    delta_t=delta_t,
    years=val_years,
    nodes=nodes,
    nb_timesteps=nb_timesteps,
    mean=train_mean,
    std=train_std,
)

In [None]:
# Base mini-batch size: used as-is for training, doubled for validation and
# scaled by 0.7 for the test loader further down.
batch_size = 70

In [None]:
# Shuffled loader over the training dataset.
dl_train = DataLoader(
    training_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
# Validation loader: fixed order and twice the training batch size.
dl_val = DataLoader(
    validation_ds,
    batch_size=2 * batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

Define model

**Attention:**

There is a problem with ```pygsp.graphs```: the module we are trying to load from it cannot be found.

In [None]:
# Model
spherical_unet = UNetSphericalHealpix(N=nodes, in_channels=in_features*len_sqce, out_channels=out_features, 
                                      kernel_size=3)
spherical_unet, device = init_device(spherical_unet, gpu=gpu)


Train and test. Plot results

In [None]:
# Train model
train_loss, val_loss = train_model_2steps_temp(spherical_unet, device, dl_train, epochs=nb_epochs, 
                                               lr=learning_rate, validation_data=dl_val, 
                                               model_filename=model_filename)
torch.save(spherical_unet.state_dict(), model_filename)



In [None]:
# Show training losses
plt.plot(train_loss, label='Training loss')
plt.plot(val_loss, label='Validation loss')
plt.xlabel('Epochs')
plt.ylabel('MSE Loss')
plt.legend()
plt.show()

del dl_train, dl_val, training_ds, validation_ds
torch.cuda.empty_cache()

In [None]:
# Testing data
testing_ds = WeatherBenchDatasetXarrayHealpixTemp(ds=ds_test, out_features=out_features,
                                                  len_sqce=len_sqce, delta_t=delta_t, years=test_years, 
                                                  nodes=nodes, nb_timesteps=nb_timesteps, 
                                                  mean=train_mean, std=train_std, 
                                                  max_lead_time=max_lead_time)

dataloader_test = DataLoader(testing_ds, batch_size=int(0.7*batch_size), shuffle=False,
                             num_workers=num_workers)



In [None]:
# Compute predictions
preds = create_iterative_predictions_healpix_temp(spherical_unet, device, dataloader_test)
preds.to_netcdf(pred_filename)


In [None]:
# Compute and save RMSE
rmse = compute_rmse_healpix(preds, obs).load()
rmse.to_netcdf(rmse_filename)

# Show RMSE
print('Z500 - 0:', rmse.z.values[0])
print('T850 - 0:', rmse.t.values[0])
plot_rmses(rmse, rmses_weyn, lead_time=6)

del spherical_unet, preds, rmse
torch.cuda.empty_cache()