# dev
- Reimplementation of the 'Quickstart' notebook by S. Rasp, with model training done in PyTorch.

# loading data

In [8]:
import numpy as np
import xarray as xr
import torch
from src.pytorch.Dataset import Dataset
from src.pytorch.util import init_torch_device

# ---------------------------------------------------------------------------
# compute device
# ---------------------------------------------------------------------------
device = init_torch_device()

# ---------------------------------------------------------------------------
# input data / results locations
# ---------------------------------------------------------------------------
datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'

# ---------------------------------------------------------------------------
# experiment configuration
# ---------------------------------------------------------------------------
# supported names: 'simpleResnet', 'tvfcnResnet50', 'cnnbn', 'Unetbn'
model_name = 'simpleResnet'

lead_time = 3 * 24  # forecast lead time in hours (3 days)
batch_size = 32

train_years = ('1979', '2015')

# input fields: folder name -> (short variable name, pressure levels);
# 'constants' lists the constant fields that are used
var_dict = {
    'geopotential': ('z', [100, 200, 500, 850, 1000]),
    'temperature': ('t', [100, 200, 500, 850, 1000]),
    'u_component_of_wind': ('u', [100, 200, 500, 850, 1000]),
    'v_component_of_wind': ('v', [100, 200, 500, 850, 1000]),
    'constants': ['lsm', 'orography', 'lat2d'],
}
# fields (and their levels) that the network is trained to predict
target_vars = ['geopotential', 'temperature']
target_levels = [500, 850]

# ---------------------------------------------------------------------------
# open one multi-file dataset per variable folder and merge them
# ---------------------------------------------------------------------------
per_variable = []
for var in var_dict:
    per_variable.append(
        xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords'))
# fill_value=0: gaps left by the merge become zeros (kept for the 'tisr' NaNs)
x = xr.merge(per_variable, fill_value=0)

# re-chunk so that the whole time axis sits in a single chunk,
# while the lat / lon chunking is left as it is
n_time = np.sum(x.chunks['time'])
x = x.chunk({'time': n_time, 'lat': x.chunks['lat'], 'lon': x.chunks['lon']})

# ---------------------------------------------------------------------------
# training data
# ---------------------------------------------------------------------------
x_train = x.sel(time=slice(train_years[0], train_years[1]))
dg_train = Dataset(x_train, var_dict, lead_time,
                   normalize=True, norm_subsample=1,
                   res_dir=res_dir, train_years=train_years,
                   target_vars=target_vars, target_levels=target_levels)
train_loader = torch.utils.data.DataLoader(dg_train,
                                           batch_size=batch_size,
                                           drop_last=True)

# ---------------------------------------------------------------------------
# validation data (2016), normalized with the training-set statistics
# ---------------------------------------------------------------------------
x_validation = x.sel(time=slice('2016', '2016'))
dg_validation = Dataset(x_validation, var_dict, lead_time,
                        mean=dg_train.mean, std=dg_train.std,
                        normalize=True, randomize_order=False,
                        target_vars=target_vars, target_levels=target_levels)
validation_loader = torch.utils.data.DataLoader(dg_validation,
                                                batch_size=batch_size,
                                                drop_last=False)

# ---------------------------------------------------------------------------
# bookkeeping: number of input channels and model file name
# ---------------------------------------------------------------------------
n_channels = len(dg_train.data.level.level)
print('n_channels', n_channels)

# file name for saving/loading prediction model
# (plain variant: f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch.pt')
model_fn = f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch_lrdecay_weightdecay_normed_test.pt'
print('model filename', model_fn)


CUDA not available
Loading means and standard deviations from disk
n_channels 23
model filename 23D_fcsimpleResnet_3d_pytorch_lrdecay_weightdecay_normed_test.pt


In [None]:
# check I/O speed on single (empty) epoch: iterate once over the training
# loader without running the model and report the elapsed wall-clock time
import time

t_start = time.perf_counter()
n_batches = 0
for batch in train_loader:
    inputs, targets = batch[0].to(device), batch[1].to(device)
    if n_batches == 0:
        # print the batch shapes once instead of flooding the output with one
        # line per batch (the loader uses a fixed batch_size with drop_last=True)
        print(inputs.shape, targets.shape)
    n_batches += 1

print(f'{n_batches} batches loaded in {time.perf_counter() - t_start:.1f} s')

# define model

In [9]:
from src.pytorch.util import named_network
# build the network selected by `model_name`: one output channel per target
# variable; returns the model and the callable used to run a forward pass
n_targets = len(target_vars)
model, model_forward = named_network(model_name, n_channels, n_targets)

# train model

In [10]:
from src.pytorch.train import train_model

# set to True to (re-)train the network; otherwise the stored weights are used
train_again = False
if train_again:
    training_outputs = train_model(model, train_loader, validation_loader, device, model_forward,
                    lr=5e-4, lr_min=1e-5, lr_decay=0.2, weight_decay=1e-5,
                    max_epochs=200, max_patience=20, max_lr_patience=5, eval_every=2000,
                    verbose=True, save_dir=res_dir + model_fn)
else:
    # if skip training, load model from disk
    state_dict = torch.load(res_dir + model_fn, map_location=torch.device(device))
    model.load_state_dict(state_dict)

# The cells below only evaluate the model: make sure its parameters live on
# the same device the inputs are sent to, and switch layers such as batch norm
# to inference behaviour (relevant for the '...bn' model variants).
model.to(device)
model.eval()

# evaluate

In [None]:
from src.pytorch.train import calc_val_loss

# loss of the current model on the validation loader (built from the 2016 data)
validation_loss = calc_val_loss(validation_loader, model_forward, device)
# bare expression: shown as the output of this cell
validation_loss

In [None]:
from src.pytorch.train_nn import create_predictions
from src.score import compute_weighted_rmse, load_test_data

# Test data: 2017-2018, normalized with the statistics of the training set.
# Make sure that the model was trained with the same data as in dg_train,
# or else the normalization is off!
test_period = slice('2017', '2018')
dg_test = Dataset(x.sel(time=test_period), var_dict, lead_time,
                  mean=dg_train.mean, std=dg_train.std, normalize=True,
                  randomize_order=False,
                  target_vars=target_vars, target_levels=target_levels)

# model forecasts for the whole test set, as fields 'z' and 't'
preds = create_predictions(model, dg_test,
                           var_dict={'z': None, 't': None},
                           batch_size=100,
                           model_forward=model_forward,
                           verbose=True)

# reference fields to score against
z500_test = load_test_data(f'{datadir}geopotential_500/', 'z')
t850_test = load_test_data(f'{datadir}temperature_850/', 't')

# the first `lead_time` time steps of the reference data are skipped
valid_times = slice(lead_time, None)
rmse_z = compute_weighted_rmse(preds.z, z500_test.isel(time=valid_times)).load()
rmse_t = compute_weighted_rmse(preds.t, t850_test.isel(time=valid_times)).load()

print('RMSE z', rmse_z.values)
print('RMSE t', rmse_t.values)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa',
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [2000] # index relative to start time of dataset !

for i in idx:
    # fetch the sample once: element 0 is the input state, element 1 the target
    sample = dg[[i]]
    pre, post = sample[0], sample[1]

    # predict for single time stamp; no gradients are needed, and the result is
    # moved to the CPU before the numpy conversion so this also works on CUDA
    with torch.no_grad():
        pred = model_forward(torch.tensor(pre, requires_grad=False).to(device)).detach().cpu().numpy()

    n_rows, n_cols = pred.shape[2], pred.shape[3]
    n_targets = len(target_vars)

    plt.figure(figsize=(16,6))
    for j in range(n_targets):
        plt.subplot(1, n_targets, j+1)

        # top: current state, middle: model-predicted future state, bottom: future state
        # (the blocks are stacked in reverse order because plt.axis below uses
        #  ascending y-limits, which draws the first image row at the bottom)
        j_ = dg._target_idx[j] # index for dg object in case first two dimensions not Z500, T850
        plt.imshow(np.vstack((post[0,j,:,:], pred[0,j,:,:], pre[0,j_,:,:])))

        # separators between the three stacked fields, spanning the full width
        plt.axhline(1*n_rows-0.5, color='k', linewidth=1.5)
        plt.axhline(2*n_rows-0.5, color='k', linewidth=1.5)
        plt.yticks([n_rows//2, 3*n_rows//2, 5*n_rows//2],
                   [f'+{lead_time}h true', f'+{lead_time}h est.', 'state'])
        plt.axis([-0.5, n_cols-0.5, -0.5, 3*n_rows-0.5])
        plt.colorbar()
        # label by the j-th *target* variable (channel j of pred / post), not by
        # the order of the input variables in var_dict
        plt.xlabel(var_names.get(target_vars[j], target_vars[j]))
        plt.title(dg.data.time.isel(time=i).values)
    plt.show()


In [None]:
import cartopy.crs as ccrs

# debug

### RMSE per pixel 

In [None]:
import matplotlib.pyplot as plt

# Forecast errors against the reference fields. The predictions are accessed
# by variable name (preds.z / preds.t), as in the evaluation cell; positional
# indexing preds[:, 0] does not select a channel of a Dataset.
err_z = preds.z - z500_test.isel(time=slice(lead_time, None))
err_t = preds.t - t850_test.isel(time=slice(lead_time, None))

# unweighted RMSE per grid point: average the squared error over time
RMSEs_z = np.sqrt((err_z**2).mean('time'))
RMSEs_t = np.sqrt((err_t**2).mean('time'))

# latitude weights proportional to cos(lat), normalized to mean 1
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# latitude-weighted RMSE per grid point (weights broadcast along 'lat')
wRMSEs_z = np.sqrt((err_z**2 * weights_lat).mean('time'))
wRMSEs_t = np.sqrt((err_t**2 * weights_lat).mean('time'))

panels = [(RMSEs_z, 'RMSEs Z500'),
          (RMSEs_t, 'RMSEs T850'),
          (wRMSEs_z, 'weighted RMSEs Z500'),
          (wRMSEs_t, 'weighted RMSEs T850')]

plt.figure(figsize=(16,8))
for k, (field, title) in enumerate(panels):
    plt.subplot(2, 2, k+1)
    plt.imshow(field)
    plt.title(title)
    plt.colorbar()
plt.show()


### RMSEs per time point

In [None]:
import matplotlib.pyplot as plt

# Forecast errors against the reference fields. The predictions are accessed
# by variable name (preds.z / preds.t), as in the evaluation cell; positional
# indexing preds[:, 0] does not select a channel of a Dataset.
err_z = preds.z - z500_test.isel(time=slice(lead_time, None))
err_t = preds.t - t850_test.isel(time=slice(lead_time, None))

# reduce over every dimension except time, i.e. over the spatial grid
# (named dimensions avoid passing a list as `axis`, which numpy rejects)
spatial_dims_z = [d for d in err_z.dims if d != 'time']
spatial_dims_t = [d for d in err_t.dims if d != 'time']

# unweighted RMSE per time point
RMSEs_z = np.sqrt((err_z**2).mean(spatial_dims_z))
RMSEs_t = np.sqrt((err_t**2).mean(spatial_dims_t))

# latitude weights proportional to cos(lat), normalized to mean 1
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# latitude-weighted RMSE per time point (weights broadcast along 'lat')
wRMSEs_z = np.sqrt((err_z**2 * weights_lat).mean(spatial_dims_z))
wRMSEs_t = np.sqrt((err_t**2 * weights_lat).mean(spatial_dims_t))

panels = [(RMSEs_z, 'RMSEs Z500'),
          (RMSEs_t, 'RMSEs T850'),
          (wRMSEs_z, 'weighted RMSEs Z500'),
          (wRMSEs_t, 'weighted RMSEs T850')]

plt.figure(figsize=(16,8))
for k, (series, title) in enumerate(panels):
    plt.subplot(2, 2, k+1)
    plt.plot(series)
    plt.title(title)
plt.show()
