# dev
- reimplementation of the 'Quickstart' notebook from S. Rasp, adapted for training in PyTorch

# define setup

In [None]:
import numpy as np
import torch
from src.pytorch.util import init_torch_device

# torch device used for every `.to(device)` call and for loading weights below
device = init_torch_device()

# datadir: input data location; res_dir: prefix under which models/results are stored
datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'

# forecast lead time in hours (file names use lead_time//24 as days, plots label it with 'h')
lead_time = 3*24
batch_size = 32

# (first, last) year per split, given as strings; presumably inclusive — TODO confirm in load_data
train_years = ('1979', '2015')
validation_years = ('2016', '2016')
test_years = ('2017', '2018')

# input fields handed to load_data
# presumably: long name -> (short name, pressure levels); 'constants' lists static fields — verify against load_data
var_dict = {'geopotential': ('z', [100, 200, 500, 850, 1000]),
           'temperature': ('t', [100, 200, 500, 850, 1000]),
           'u_component_of_wind': ('u', [100, 200, 500, 850, 1000]), 
           'v_component_of_wind': ('v', [100, 200, 500, 850, 1000]),
           'constants': ['lsm','orography','lat2d']
           }

# prediction targets: variable name -> level; its length sets the number of model outputs
target_var_dict = {'geopotential': 500, 'temperature': 850}

# network architecture name passed to named_network
model_name = 'ConvLSTM' # 'simpleResnet', 'tvfcnResnet50', 'cnnbn', 'Unetbn'

#filters = [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
#kernel_sizes = [7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
#past_times = [-6, -12]

# filters / kernel_sizes are passed to named_network, past_times to load_data
filters = [128, 64, 64, 2]
kernel_sizes = [3, 3, 3, 3]
past_times = [-3, -6, -9, -12, -15, -18, -21, -24, -27, -30, -33, -36]

dropout_rate = 0.1
# if True, n_channels counts variables only and the plotting code reshapes
# inputs to carry the past time steps on a separate axis
past_times_own_axis = True

# NOTE(review): `verbose` is not read anywhere in the visible cells (train_model gets verbose=True directly)
verbose = True
# name of the loss, passed to loss_function() in the training cell
loss_fun = 'mse'

# forwarded to load_data
mmap_mode = None #'r'

# NOTE(review): the three paths below are not referenced in the visible cells
datadirdg = '/gpfs/work/greenber/'
filedir = datadir + '5_625deg_all_zscored.npy'
leveldir = datadir + '5_625deg_all_level_names.npy'

# load data

In [None]:
from src.pytorch.util import load_data
from src.pytorch.Dataset import collate_fn_memmap, load_mean_std

# load data
# build the three dataset splits plus the shared metadata object;
# only the first two entries of each *_years setting are forwarded
dg_train, dg_validation, dg_test, dg_meta = load_data(
    var_dict=var_dict,
    lead_time=lead_time,
    train_years=(train_years[0], train_years[1]),
    validation_years=(validation_years[0], validation_years[1]),
    test_years=(test_years[0], test_years[1]),
    target_var_dict=target_var_dict,
    datadir=datadir,
    mmap_mode=mmap_mode,
    past_times=past_times,
    past_times_own_axis=past_times_own_axis,
)

def collate_fn(batch):
    """Collate a training batch via collate_fn_memmap.

    Thin wrapper that binds the training dataset `dg_train` and the
    notebook-wide `past_times_own_axis` flag so the function can be handed
    to the DataLoader as a one-argument callable.
    """
    collated = collate_fn_memmap(
        batch,
        dg_train,
        past_times_own_axis=past_times_own_axis,
    )
    return collated

# validation: no custom collate function, last (possibly smaller) batch is kept
validation_loader = torch.utils.data.DataLoader(
    dg_validation,
    batch_size=batch_size,
    drop_last=False,
)
# training: batches are assembled by collate_fn, incomplete last batch is dropped
train_loader = torch.utils.data.DataLoader(
    dg_train,
    batch_size=batch_size,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=0,  #int(train_years[1]) - int(train_years[0]) + 1
)

# with past time steps on their own axis only the variables count as channels,
# otherwise every (variable, past time) combination is a separate channel
if past_times_own_axis:
    n_channels = len(dg_train._var_idx)
else:
    n_channels = len(dg_train._var_idx) * len(dg_train.past_times)
print('n_channels', n_channels)

# file name for saving/loading prediction model
model_fn = f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch.pt'
#model_fn = f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch_lrdecay_weightdecay_normed_test2.pt' # file name for saving/loading prediction model
print('model filename', model_fn)


In [None]:
# limits read by do_dummy_epoch below:
# max_steps caps the number of batches, print_every sets the reporting interval
max_steps = 200
print_every = 10

import time
def do_dummy_epoch(train_loader, t=None):
    """Run through the loader without any training to gauge I/O speed.

    Every batch is moved to `device`. On the first batch and whenever the
    batch counter is a multiple of `print_every`, the elapsed time since `t`
    and the input/target shapes are printed. The loop ends after `max_steps`
    batches (at least one batch is always processed). `device`,
    `print_every` and `max_steps` are read from the notebook scope.

    train_loader: iterable yielding batches indexable as (inputs, targets)
    t: reference time stamp from time.time(); defaults to the time of the call
    """
    if t is None:
        t = time.time()
    for step, batch in enumerate(train_loader, start=1):
        inputs = batch[0].to(device)
        targets = batch[1].to(device)
        if np.mod(step, print_every) == 0 or step == 1:
            elapsed = time.time() - t
            print(f"- batch #{step}, time: {elapsed:.2f}")
            print(inputs.shape, targets.shape)
        # same stopping rule as counting up first and then comparing
        if step + 1 > max_steps:
            break

# time one pass over the training loader (no model computation involved)
t = time.time()
do_dummy_epoch(train_loader, t)


# define model

In [None]:
# NOTE: this cell repeats the model settings of the setup cell with the same
# values (dropout_rate is not repeated here), so the model can be redefined
# without re-running the setup. Changing past_times here does not affect
# datasets that were already loaded with the earlier value.
model_name = 'ConvLSTM' # 'simpleResnet', 'tvfcnResnet50', 'cnnbn', 'Unetbn'

#filters = [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
#kernel_sizes = [7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
#past_times = [-6, -12]

filters = [128, 64, 64, 2]
kernel_sizes = [3, 3, 3, 3]
past_times = [-3, -6, -9, -12, -15, -18, -21, -24, -27, -30, -33, -36]

past_times_own_axis = True

In [None]:
"""
from src.pytorch.util import named_network
model, model_forward = named_network(model_name, n_channels, len(target_var_dict), 
                                     filters=filters, kernel_sizes=kernel_sizes)
"""
from src.pytorch.util import named_network

# one output channel per target variable
model, model_forward = named_network(
    model_name,
    n_channels,
    len(target_var_dict),
    kernel_sizes=kernel_sizes,
    filters=filters,
    dropout_rate=dropout_rate,
)

# sums the element counts of all state_dict entries
print(
    'total #parameters: ',
    np.sum([np.prod(tensor.shape) for tensor in model.state_dict().values()]),
)

# train model

In [None]:
# replaces the model_fn generated in the data-loading cell;
# the path is relative to res_dir (used as res_dir + model_fn when saving/loading)
model_fn = 'models/convlstm_test/convlstm_test_23D_fcConvLSTM_72h.pt'
#model_fn = 'models/multi_delay_test/multi_delay_test_69D_fcsimpleResnet_72h.pt'

In [None]:
from src.pytorch.train import train_model, loss_function

# set to True to (re-)train the model; with False the stored weights are loaded
train_again = False
if train_again:
    # Keep the configured loss name in `loss_fun` untouched and store the
    # object returned by loss_function() under its own name. Rebinding
    # `loss_fun` would hand that object back to loss_function() when this
    # cell is executed a second time.
    loss_fn = loss_function(loss_fun)
    training_outputs = train_model(model, train_loader, validation_loader, device, model_forward,
                    loss_fun=loss_fn, lr=5e-4, lr_min=1e-5, lr_decay=0.2, weight_decay=1e-5,
                    max_epochs=200, max_patience=20, max_lr_patience=5, eval_every=2000,
                    verbose=True, save_dir=res_dir + model_fn)

# if skip training, load model from disk
else:
    # map_location moves the stored tensors onto the device chosen in the setup cell
    model.load_state_dict(torch.load(res_dir + model_fn, map_location=torch.device(device)))

# evaluate

In [None]:
from src.score import compute_weighted_rmse, load_test_data
# ground-truth test fields for the two target variables
z500_test = load_test_data(f'{datadir}geopotential_500/', 'z')
t850_test = load_test_data(f'{datadir}temperature_850/', 't')
# shape of the test targets after skipping the first lead_time + max_input_lag time steps
# (shown as the cell output)
z500_test.isel(time=slice(lead_time+dg_test.max_input_lag, None)).values.shape

In [None]:
#from src.pytorch.train import calc_val_loss
#print('validation loss:', calc_val_loss(validation_loader, model_forward, device))

In [None]:
from src.pytorch.train_nn import create_predictions
from src.score import compute_weighted_rmse, load_test_data

# statistics for the two target fields; passed on to create_predictions
mean, std, _, _ = load_mean_std(
    res_dir,
    {'geopotential': ['z', [500]], 'temperature': ['t', [850]]},
    train_years,
)

# time stamps of the test period, minus the leading steps that have no prediction
valid_test_time = dg_meta['time'].sel(time=slice(test_years[0], test_years[1]))
dg_meta['valid_time'] = valid_test_time.isel(
    time=slice(dg_test.lead_time+dg_test.max_input_lag, None)
).time

preds = create_predictions(
    model,
    dg_test,
    var_dict={'z' : None, 't' : None},
    batch_size=100,
    model_forward=model_forward,
    verbose=True,
    past_times_own_axis=past_times_own_axis,
    mean=mean,
    std=std,
    dg_meta=dg_meta,
)

# compare against ground truth over the same time range
z500_test = load_test_data(f'{datadir}geopotential_500/', 'z')
t850_test = load_test_data(f'{datadir}temperature_850/', 't')
rmse_z = compute_weighted_rmse(
    preds.z,
    z500_test.isel(time=slice(lead_time+dg_test.max_input_lag, None)),
).load()
rmse_t = compute_weighted_rmse(
    preds.t,
    t850_test.isel(time=slice(lead_time+dg_test.max_input_lag, None)),
).load()
print('RMSE z', rmse_z.values)
print('RMSE t', rmse_t.values)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Visual check of single forecasts for inputs that carry the past time steps
# on their own axis: stacks true future state, predicted future state and the
# most recent input state for both target variables.
dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa', 
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [dg_test.start+100] # index relative to start time of dataset !

for i in idx:
    # fetch the sample once and reuse it for inputs and targets
    sample = dg[[i]]
    pre, post = sample[0], sample[1]
    if past_times_own_axis:
        # split the leading axis into (past time, batch=1) and move batch to the front
        # assumes the dataset returns the past time steps along the first axis — TODO confirm
        pre = pre.reshape((len(dg._past_idx), 1,  len(dg._var_idx), *pre.shape[2:])).transpose(1,0,2,3,4)
    # predict for single time stamp; .cpu() is required before .numpy() when
    # the model runs on a GPU and is a no-op for CPU tensors
    pred = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy()
    
    plt.figure(figsize=(16,6))
    for j in range(2):
        plt.subplot(1,2,j+1)

        # top: current state, middle: model-predicted future state, bottom: future state
        j_ = dg._target_idx[j] # index for dg object in case first two dimensions not Z500, T850
        # pre[0,-1,...] selects the last entry along the past-time axis
        plt.imshow(np.vstack((post[0,j,:,:], pred[0,j,:,:], pre[0,-1,j_,:,:])))

        # separator lines between the three stacked fields
        plt.plot([0.5, pred.shape[3]+.5], (1*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.plot([0.5, pred.shape[3]+.5], (2*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.yticks([pred.shape[2]//2, 3*pred.shape[2]//2, 5*pred.shape[2]//2],
                   [f'+{lead_time}h true', f'+{lead_time}h est.', 'state'])
        plt.axis([-0.5, pred.shape[3]-0.5, -0.5, 3*pred.shape[2]-0.5])
        plt.colorbar()
        plt.xlabel(var_names[list(dg.var_dict.keys())[j]])
        #plt.title(dg.data.time.isel(time=i).values)
    plt.show()


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Visual check of a single forecast for inputs without a separate past-time
# axis (the input is indexed as pre[0, channel, :, :] below).
dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa', 
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [2000] # index relative to start time of dataset !

for i in idx:
    # fetch the sample once and reuse it for inputs and targets
    sample = dg[[i]]
    pre, post = sample[0], sample[1]
    # predict for single time stamp; .cpu() is required before .numpy() when
    # the model runs on a GPU and is a no-op for CPU tensors
    pred = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy()
    
    plt.figure(figsize=(16,6))
    for j in range(2):
        plt.subplot(1,2,j+1)

        # top: current state, middle: model-predicted future state, bottom: future state
        j_ = dg._target_idx[j] # index for dg object in case first two dimensions not Z500, T850
        plt.imshow(np.vstack((post[0,j,:,:], pred[0,j,:,:], pre[0,j_,:,:])))

        # separator lines between the three stacked fields
        plt.plot([0.5, pred.shape[3]+.5], (1*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.plot([0.5, pred.shape[3]+.5], (2*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.yticks([pred.shape[2]//2, 3*pred.shape[2]//2, 5*pred.shape[2]//2],
                   [f'+{lead_time}h true', f'+{lead_time}h est.', 'state'])
        plt.axis([-0.5, pred.shape[3]-0.5, -0.5, 3*pred.shape[2]-0.5])
        plt.colorbar()
        plt.xlabel(var_names[list(dg.var_dict.keys())[j]])
        plt.title(dg.data.time.isel(time=i).values)
    plt.show()


In [None]:
import cartopy.crs as ccrs

## model survey

In [None]:
import matplotlib.pyplot as plt
import os

# experiment folders below res_dir + 'models/' that are compared in the survey plot
exp_ids = ['resnet_baseline', 'resnet_baseline_no_L2',  'resnet_latmse', 'multi_delay_test']
def find_weights(fn):
    return fn[-4:] == 'h.pt'
# input dimensionality shown in the legend, one entry per experiment in exp_ids
dims = [23, 23, 23, 33]

# Survey figure: validation curves (left half) and test RMSEs for the two
# target variables (right half, one horizontal line per experiment).
fig = plt.figure(figsize=(12,5))
for exp_id, dim in zip(exp_ids, dims):
    save_dir = res_dir + 'models/' + exp_id + '/'

    # Loop-local names on purpose: the notebook-wide `lead_time` (an int used
    # for slicing in other cells) and `model_fn` must not be overwritten here.
    weights_fn = list(filter(find_weights, os.listdir(save_dir)))[0]
    lead_hours = weights_fn[-6:-4]  # the two characters right before the 'h.pt' suffix
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load files
    # written by your own training runs.
    training_outputs = np.load(save_dir + '_training_outputs' + '.npy', allow_pickle=True)[()]

    try:
        training_loss, validation_loss = training_outputs['training_loss'], training_outputs['validation_loss']
        RMSEs = np.load(save_dir + weights_fn[:-3] + '_RMSE_zt.npy')

        plt.subplot(1,2,1)
        plt.semilogy(validation_loss, label=exp_id + f' ({dim}D)')
        plt.title('training')
        
        plt.subplot(1,4,3)
        plt.plot([0,1], RMSEs[0]*np.ones(2), label=exp_id)
        plt.title(f'RMSE {lead_hours}h, z 500')
        plt.xticks([])
        plt.axis([-0.1, 1.1, 0, 600])
        
        plt.subplot(1,4,4)
        plt.plot([0,1], RMSEs[1]*np.ones(2), label=exp_id)
        plt.title(f'RMSE {lead_hours}h, t 850')
        plt.xticks([])
        plt.axis([-0.1, 1.1, 0, 3.0])
    except (KeyError, IndexError, OSError, ValueError) as err:
        # stay best-effort (incomplete experiments are skipped), but say so
        # instead of silently swallowing every possible error
        print(f'skipping {exp_id}: {err!r}')

plt.subplot(1,2,1)
plt.ylabel('validation error')
plt.legend()
#plt.subplot(1,4,3)
#plt.legend()
fig.patch.set_facecolor('xkcd:white')
plt.show()

# debug

### RMSE per pixel 

In [None]:
import matplotlib.pyplot as plt

# Per-grid-point RMSE: squared errors are averaged over the first axis (axis=0)
# only, so the results keep the two trailing axes and are shown as images.
# NOTE(review): preds is indexed here as a 4-D array (preds[:,0,:,:]), whereas
# the evaluate cell accesses preds.z / preds.t — confirm which form preds has
# when this cell is run.
# NOTE(review): the test data is sliced from lead_time here, but from
# lead_time + dg_test.max_input_lag in the evaluate cell — confirm the time alignment.
RMSEs_z = np.sqrt(np.mean((preds[:,0,:,:] - z500_test.isel(time=slice(lead_time, None)))**2, axis=0))
RMSEs_t = np.sqrt(np.mean((preds[:,1,:,:] - t850_test.isel(time=slice(lead_time, None)))**2, axis=0))


# latitude weights: cosine of the latitude (degrees converted to radians), normalized to mean 1
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# same as above, with the squared errors multiplied by the latitude weights before averaging
wRMSEs_z = np.sqrt(
    np.mean( ((preds[:,0,:,:] - z500_test.isel(time=slice(lead_time, None)))**2)*weights_lat, 
            axis=0))
wRMSEs_t = np.sqrt(
    np.mean( ((preds[:,1,:,:] - t850_test.isel(time=slice(lead_time, None)))**2)*weights_lat, 
            axis=0))

plt.figure(figsize=(16,8))

# top row: unweighted maps
plt.subplot(2,2,1)
plt.imshow(RMSEs_z)
plt.title('RMSEs Z500')
plt.colorbar()
plt.subplot(2,2,2)
plt.imshow(RMSEs_t)
plt.title('RMSEs T850')
plt.colorbar()

# bottom row: latitude-weighted maps
plt.subplot(2,2,3)
plt.imshow(wRMSEs_z)
plt.title('weighted RMSEs Z500')
plt.colorbar()
plt.subplot(2,2,4)
plt.imshow(wRMSEs_t)
plt.title('weighted RMSEs T850')
plt.colorbar()

plt.show()


### RMSEs per time point

In [None]:
import matplotlib.pyplot as plt

# RMSE per time point: squared errors are averaged over axes 1 and 2, leaving
# one value per entry of the first axis, plotted as curves.
# NOTE(review): preds is indexed here as a 4-D array (preds[:,0,:,:]), whereas
# the evaluate cell accesses preds.z / preds.t — confirm which form preds has
# when this cell is run.
# NOTE(review): the test data is sliced from lead_time here, but from
# lead_time + dg_test.max_input_lag in the evaluate cell — confirm the time alignment.
# NOTE(review): axis is given as a list; plain numpy arrays require an int or
# tuple here — presumably this relies on the operands being xarray objects, verify.
RMSEs_z = np.sqrt(np.mean((preds[:,0,:,:] - z500_test.isel(time=slice(lead_time, None)))**2, axis=[1,2]))
RMSEs_t = np.sqrt(np.mean((preds[:,1,:,:] - t850_test.isel(time=slice(lead_time, None)))**2, axis=[1,2]))


# latitude weights: cosine of the latitude (degrees converted to radians), normalized to mean 1
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# same as above, with the squared errors multiplied by the latitude weights before averaging
wRMSEs_z = np.sqrt(
    np.mean( ((preds[:,0,:,:] - z500_test.isel(time=slice(lead_time, None)))**2)*weights_lat, 
            axis=[1,2]))
wRMSEs_t = np.sqrt(
    np.mean( ((preds[:,1,:,:] - t850_test.isel(time=slice(lead_time, None)))**2)*weights_lat, 
            axis=[1,2]))

plt.figure(figsize=(16,8))

# top row: unweighted curves
plt.subplot(2,2,1)
plt.plot(RMSEs_z)
plt.title('RMSEs Z500')
plt.subplot(2,2,2)
plt.plot(RMSEs_t)
plt.title('RMSEs T850')

# bottom row: latitude-weighted curves
plt.subplot(2,2,3)
plt.plot(wRMSEs_z)
plt.title('weighted RMSEs Z500')
plt.subplot(2,2,4)
plt.plot(wRMSEs_t)
plt.title('weighted RMSEs T850')

plt.show()


# quickplot of predictions

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Quick look at a handful of predictions: for each of the two output channels
# the selected forecasts are placed side by side in one image.
dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa', 
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [2000, 2024, 2048, 2072, 2096, 2120] # index relative to start time of dataset !
pre = dg[idx][0]
# .cpu() is required before .numpy() when the model runs on a GPU and is a
# no-op for CPU tensors
preds = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy() 

# panel geometry is taken from the prediction array instead of being
# hard-coded (64 columns and rows 0..31 on a 32 x 64 grid)
n_rows, n_cols = preds.shape[2], preds.shape[3]

# positions within preds (not dataset indices) of the forecasts to show
idx_plot = [0,1,2,3,4]
for channel in (0, 1):
    plt.figure(figsize=(16,3))
    plt.imshow(np.hstack(preds[idx_plot,channel,:,:]))
    # vertical separators between neighbouring forecasts
    for i in range(1,len(idx_plot)):
        plt.plot(i*n_cols*np.ones(2), [0, n_rows-1], 'k')
    plt.xticks([])
    plt.yticks([])
    # export used for the second (channel 1) figure:
    #plt.savefig('/gpfs/home/nonnenma/projects/seasonal_forecasting/results/weatherbench/figs/T850_example_dreds_N5_dt24h.pdf')
    plt.show()
