# dev
- Reimplementation of the 'Quickstart' notebook by S. Rasp, adapted for training in PyTorch

# define setup

In [None]:
import numpy as np
import torch
from src.pytorch.util import init_torch_device

# Compute device, as chosen by the project helper.
device = init_torch_device()

# Input data directory and results directory (model files are addressed
# relative to res_dir further down).
datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'

# Network to build; options: 'simpleResnet', 'tvfcnResnet50', 'cnnbn', 'Unetbn'
model_name = 'simpleResnet'

# Forecast lead time (three days, expressed in hours) and mini-batch size.
lead_time = 24 * 3
batch_size = 32

# First and last year of each data split.
train_years = ('1979', '2015')
validation_years = ('2016', '2016')
test_years = ('2017', '2018')

# Input fields: long variable name -> (short name, levels), plus the list of
# constant fields stored under 'constants'.
var_dict = {
    long_name: (short_name, [100, 200, 500, 850, 1000])
    for long_name, short_name in [
        ('geopotential', 'z'),
        ('temperature', 't'),
        ('u_component_of_wind', 'u'),
        ('v_component_of_wind', 'v'),
    ]
}
var_dict['constants'] = ['lsm', 'orography', 'lat2d']

# Smaller alternative kept for reference (two levels per variable):
# var_dict = {'geopotential': ('z', [500, 850]),
#            'temperature': ('t', [500, 850]),
#            'u_component_of_wind': ('u', [500, 850]),
#            'v_component_of_wind': ('v', [500, 850]),
#            'constants': ['lsm','orography','lat2d']
#            }

# Fields (and their level) that the network has to predict.
target_var_dict = {'geopotential': 500, 'temperature': 850}

# Per-layer settings handed to the network constructor.
filters = [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
kernel_sizes = [7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]

# Index offsets of additional past states that are fed to the network.
past_times = [-6, -12]
verbose = True
loss_fun = 'mse'

In [None]:
import xarray as xr
import dask
from src.pytorch.Dataset import Dataset_dask, Dataset_xr

# Lazily open the NetCDF files of every entry in var_dict and merge them into
# a single dataset. It is opened once and shared by all three splits; each
# split only takes a time slice of it.
x = xr.merge(
    [xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords')
     for var in var_dict],
    fill_value=0,  # For the 'tisr' NaNs
)

# NOTE(review): the training set is built with normalize=False while the
# validation and test sets use normalize=True - confirm this is intended.
dg_train = Dataset_dask(x.sel(time=slice(train_years[0], train_years[1])), var_dict, lead_time,
                   normalize=False, res_dir=res_dir, train_years=train_years,
                   target_var_dict=target_var_dict, past_times=past_times, verbose=verbose)
dg_validation = Dataset_xr(x.sel(time=slice(validation_years[0], validation_years[1])), var_dict, lead_time,
                        normalize=True, res_dir=res_dir, train_years=train_years, randomize_order=False,
                        target_var_dict=target_var_dict, past_times=past_times, verbose=verbose)
# The test set reuses the normalisation statistics of the validation set.
dg_test =  Dataset_xr(x.sel(time=slice(test_years[0], test_years[1])), var_dict, lead_time,
                   normalize=True, mean=dg_validation.mean, std=dg_validation.std, randomize_order=False,
                   target_var_dict=target_var_dict, past_times=past_times, verbose=verbose)
print('chunks', dg_train.data.chunks)

def collate_fn(batch):
    """Collate lazily evaluated ``(X, y)`` samples into two torch tensors.

    The inputs and the targets of all samples in ``batch`` are each stacked
    along a new leading axis with ``dask.array.stack``, evaluated with
    ``.compute()`` and converted with ``torch.tensor``.

    Returns:
        Tuple ``(X_stack, Y_stack)`` of tensors created with
        ``requires_grad=False``.
    """
    def _materialize(parts):
        # Stack first, then evaluate, so that only one compute call is issued.
        stacked = dask.array.stack(parts, axis=0).compute()
        return torch.tensor(stacked, requires_grad=False)

    inputs = _materialize([sample_x for sample_x, _ in batch])
    targets = _materialize([sample_y for _, sample_y in batch])
    return (inputs, targets)

# One loader worker per training year (first and last year both counted).
first_year, last_year = int(train_years[0]), int(train_years[1])
num_workers = last_year - first_year + 1

# Training batches are assembled by collate_fn; incomplete final batches are
# dropped.
loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=collate_fn,
    drop_last=True,
)
train_loader = torch.utils.data.DataLoader(dg_train, **loader_options)

In [None]:
import time
# Time one pass over the training loader without any training step: for each
# batch, print the running counter, the elapsed time and the tensor shapes.
num_steps = 1
t = time.time()
for batch in train_loader:
    elapsed = time.time() - t
    if num_steps % 1 == 0:  # reports every batch
        print('- #, time: ', num_steps, elapsed)
    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    print(inputs.shape, targets.shape)
    num_steps = num_steps + 1

In [None]:
import time
# Time the evaluation of 32 consecutive samples fetched by index, with the
# indices offset by dg_train.max_input_lag.
t = time.time()
sample_idx = dg_train.max_input_lag + np.arange(32)
for item in dg_train[sample_idx]:
    item.compute()
print(time.time() - t)

In [None]:
import numpy as np
import torch
import xarray as xr
import math

class Dataset(torch.utils.data.IterableDataset):
    r"""Iterable dataset of (input, target) pairs for fixed-lead-time forecasts.

    Based on the DataGenerator() object written by S. Rasp (for tensorflow v1.x):
    https://github.com/pangeo-data/WeatherBench/blob/ced939e20da0432bc816d64c34344e72f9b4cd17/src/train_nn.py#L18

    All requested fields are concatenated along a ``level`` dimension into
    ``self.data``. The sample for time index ``i`` uses the data at ``i``
    (followed by the data at ``i + l`` for every ``l`` in ``past_times``) as
    input and the target levels at ``i + lead_time`` as target.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.

    Args:
        ds: xarray dataset containing every variable named in ``var_dict``.
        var_dict: maps a long variable name to ``(short_name, levels)``; the
            special key ``'constants'`` maps to a list of constant fields.
        lead_time: offset in time indices between input and target.
        mean, std: normalisation statistics. Only used when ``normalize`` is
            true; if either is ``None`` they are loaded from disk or computed.
        load: if true, load ``self.data`` into memory at the end of setup.
        start, end: first and one-past-last time index served by ``__iter__``.
            Both are derived from the data unless both are given.
        normalize: if true, standardise the data with ``mean`` and ``std``.
        norm_subsample: stride over time used when computing the statistics.
        randomize_order: if true, ``__iter__`` visits the samples in random
            order.
        target_var_dict: maps a long variable name to the level to predict.
            Defaults to ``{'geopotential': 500, 'temperature': 850}``.
        dtype: dtype the constant fields are cast to.
        res_dir, train_years: passed to ``load_mean_std`` when loading stored
            statistics.
        past_times: index offsets of additional input states. Defaults to no
            additional states.
        verbose: if true, each loader worker prints its index range.

    Raises:
        ValueError: if a target variable or target level is not part of
            ``var_dict``.
    """
    def __init__(self, ds, var_dict, lead_time, mean=None, std=None, load=False,
                 start=None, end=None, normalize=False, norm_subsample=1, randomize_order=True,
                 target_var_dict=None,
                 dtype=np.float32, res_dir=None, train_years=None, past_times=None, verbose=False):
        super().__init__()

        # Fresh objects per instance instead of shared mutable defaults.
        if target_var_dict is None:
            target_var_dict = {'geopotential': 500, 'temperature': 850}
        if past_times is None:
            past_times = []

        self.ds = ds
        self.var_dict = var_dict
        self.lead_time = lead_time
        self.past_times = past_times
        self.normalize = normalize
        self.randomize_order = randomize_order
        self.verbose = verbose

        # Every target must be one of the input fields, otherwise its index
        # cannot be located further down.
        unknown_vars = [var for var in target_var_dict if var not in var_dict]
        if unknown_vars:
            raise ValueError(f'target variables {unknown_vars} are not in var_dict')
        unknown_levels = [(var, level) for var, level in target_var_dict.items()
                          if level not in var_dict[var][1]]
        if unknown_levels:
            raise ValueError(f'target levels {unknown_levels} are not in var_dict')

        self.max_input_lag = -np.min(self.past_times) if len(self.past_times) > 0 else 0
        if start is None or end is None:
            start = np.max([0, self.max_input_lag])
            end = self.ds.time.isel(time=slice(0, -self.lead_time)).values.shape[0]
        assert end > start, "this example code only works with end > start"
        assert start >= self.max_input_lag
        self.start, self.end = start, end

        self.data = []
        self.level_names = []
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for long_var, params in var_dict.items():
            if long_var == 'constants':
                # Constant fields get a dummy level and are repeated over time.
                for var in params:
                    self.data.append(ds[var].expand_dims(
                        {'level': generic_level, 'time': ds.time}, (1, 0)
                    ).astype(dtype))
                    self.level_names.append(var)
            else:
                var, levels = params
                try:
                    self.data.append(ds[var].sel(level=levels))
                    self.level_names += [f'{var}_{level}' for level in levels]
                except ValueError:
                    # Fallback: give the variable a dummy level dimension.
                    self.data.append(ds[var].expand_dims({'level': generic_level}, 1))
                    self.level_names.append(var)

        self.data = xr.concat(self.data, 'level')  # .transpose('time', 'lat', 'lon', 'level')
        self.data['level_names'] = xr.DataArray(
            self.level_names, dims=['level'], coords={'level': self.data.level})
        self.output_idxs = range(len(self.data.level))

        if self.normalize:
            if mean is None or std is None:
                try:
                    print('Loading means and standard deviations from disk')
                    # NOTE(review): load_mean_std is not imported next to this
                    # class; if it is undefined, the resulting NameError is
                    # handled by the fallback below - confirm where it lives.
                    mean, std, level, level_names = load_mean_std(res_dir, var_dict, train_years)
                    assert np.all( level_names == self.level_names )
                    self.mean = xr.DataArray(mean, coords={'level': level}, dims=['level'])
                    self.std = xr.DataArray(std, coords={'level': level}, dims=['level'])
                except Exception:
                    # Best effort: any failure to load falls back to computing.
                    print('WARNING! Could not load means and stds. Computing. Can take a while !')
                    subsampled = self.data.isel(time=slice(0, None, norm_subsample))
                    self.mean = subsampled.mean(
                        ('time', 'lat', 'lon')).compute() if mean is None else mean
                    self.std = subsampled.std(
                        ('time', 'lat', 'lon')).compute() if std is None else std
            else:
                self.mean, self.std = mean, std
            self.data = (self.data - self.mean) / self.std

        self.valid_time = self.data.isel(time=slice(lead_time+self.max_input_lag, None)).time

        # Positions of the targets along the level axis of self.data.
        level_names = np.array(self.level_names)
        self._target_idx = []
        for var, level in target_var_dict.items():
            target_name = var_dict[var][0] + '_' + str(level)
            self._target_idx.append(np.where(level_names == target_name)[0][0])

        # According to S. Rasp, this has to go after computation of self.mean, self.std:
        if load:
            print('Loading data into RAM')
            self.data.load()

    def _gather(self, idx):
        """Return inputs and targets for the array of time indices ``idx``."""
        X = self.data.data[idx, :, :, :]
        y = self.data.data[idx + self.lead_time, :, :, :][:, self._target_idx, :, :]

        if self.max_input_lag > 0:
            # Local import: the imports next to this class do not include dask.
            import dask.array
            Xl = [X] + [self.data.data[idx + lag, :, :, :] for lag in self.past_times]
            # NOTE(review): for a single index the states are joined along
            # axis 0 rather than along the level axis - confirm intended.
            X = dask.array.concatenate(Xl, axis=1) if len(idx) > 1 else dask.array.concatenate(Xl, axis=0)

        return X, y

    def __getitem__(self, index):
        """ Generate one batch of data for the given time indices """
        assert np.min(index) >= self.start
        return self._gather(np.asarray(index))

    def __iter__(self):
        """ Return iterable over (input, target) pairs of this process' share of the data """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: serve the full index range.
            iter_start, iter_end = int(self.start), int(self.end)
        else:
            # Each worker is assigned a contiguous run of time chunks.
            worker_id, num_workers = worker_info.id, worker_info.num_workers
            n_chunks = len(self.data.chunks[0])
            worker_yrs = math.ceil(n_chunks / num_workers)
            cumidx = np.concatenate(([0], np.cumsum(self.data.chunks[0])))
            # Clamp so that surplus workers get an empty range instead of
            # indexing past the last chunk.
            first_chunk = min(worker_id * worker_yrs, n_chunks)
            last_chunk = min((worker_id + 1) * worker_yrs, n_chunks)
            iter_start = int(cumidx[first_chunk] + self.start)
            iter_end = int(min(cumidx[last_chunk], self.end) - self.lead_time)

            if self.verbose:
                print(f'worker stats: worker #{worker_id} / {num_workers}')
                print('len(data.chunks)', n_chunks)
                print('#assigned years:', worker_yrs)
                print('index start', iter_start)
                print('index end', iter_end)

        if iter_end <= iter_start:
            # Nothing assigned to this process.
            return iter(())

        idx = np.arange(iter_start, iter_end)
        if self.randomize_order:
            idx = (torch.randperm(iter_end - iter_start) + iter_start).cpu().numpy()

        X, y = self._gather(idx)
        return zip(X, y)

    def __len__(self):
        return self.data.isel(time=slice(0, -self.lead_time)).shape[0]

# load data

In [None]:
from src.pytorch.util import load_data
from src.pytorch.Dataset import collate_fn

# Build the three datasets through the project helper.
split_years = dict(
    train_years=(train_years[0], train_years[1]),
    validation_years=(validation_years[0], validation_years[1]),
    test_years=(test_years[0], test_years[1]),
)
dg_train, dg_validation, dg_test = load_data(
    var_dict=var_dict, lead_time=lead_time, target_var_dict=target_var_dict,
    datadir=datadir, res_dir=res_dir, past_times=past_times, verbose=True,
    **split_years
)

# Validation batches use the loader defaults; the final partial batch is kept.
validation_loader = torch.utils.data.DataLoader(
    dg_validation, batch_size=batch_size, drop_last=False)

# Training batches: one worker per training year, collated by collate_fn,
# final partial batch dropped.
n_train_years = int(train_years[1]) - int(train_years[0]) + 1
train_loader = torch.utils.data.DataLoader(
    dg_train, batch_size=batch_size, num_workers=n_train_years,
    collate_fn=collate_fn, drop_last=True)

# Input channels = number of levels times number of input states
# (the current state plus one per entry of past_times).
n_levels = len(dg_train.data.level.level)
n_states = len(dg_train.past_times) + 1
n_channels = n_levels * n_states
print('n_channels', n_channels)

# file name for saving/loading prediction model
#model_fn = f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch.pt'
model_fn = f'{n_channels}D_fc{model_name}_{lead_time//24}d_pytorch_lrdecay_weightdecay_normed_test2.pt'
print('model filename', model_fn)


In [None]:
import time
# Time one pass over the training loader without any training step: for each
# batch, print the running counter, the elapsed time and the tensor shapes.
num_steps = 0
t = time.time()
for batch in train_loader:
    elapsed = time.time() - t
    if num_steps % 1 == 0:  # reports every batch
        print('- #, time: ', num_steps, elapsed)
    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    print(inputs.shape, targets.shape)
    num_steps = num_steps + 1

# define model

In [None]:
from src.pytorch.util import named_network
# Build the network selected by model_name, passing the number of input
# channels and the number of target variables.
n_targets = len(target_var_dict)
model, model_forward = named_network(
    model_name, n_channels, n_targets,
    filters=filters, kernel_sizes=kernel_sizes,
)

# train model

In [None]:
# Overrides the model_fn derived in the data-loading cell: training/loading
# below addresses this fixed file relative to res_dir.
model_fn = 'models/multi_delay_test/multi_delay_test_33D_fcsimpleResnet_72h.pt'

In [None]:
from src.pytorch.train import train_model, loss_function

# Set to True to (re)train; otherwise the stored weights are loaded.
train_again = False
if train_again:
    # Bind the loss callable to its own name so that the configured loss name
    # in `loss_fun` stays a string and this cell can be run more than once.
    criterion = loss_function(loss_fun)
    training_outputs = train_model(model, train_loader, validation_loader, device, model_forward,
                    loss_fun=criterion, lr=5e-4, lr_min=1e-5, lr_decay=0.2, weight_decay=1e-5,
                    max_epochs=200, max_patience=20, max_lr_patience=5, eval_every=2000,
                    verbose=True, save_dir=res_dir + model_fn)
else:
    # if skip training, load model from disk
    model.load_state_dict(torch.load(res_dir + model_fn, map_location=torch.device(device)))

# evaluate

In [None]:
from src.score import compute_weighted_rmse, load_test_data
# Ground-truth fields for the two targets.
z500_test, t850_test = (
    load_test_data(f'{datadir}geopotential_500/', 'z'),
    load_test_data(f'{datadir}temperature_850/', 't'),
)
# Shape of the ground truth once the first lead_time + max_input_lag time
# steps are skipped (shown as the cell output).
first_valid = lead_time + dg_test.max_input_lag
z500_test.isel(time=slice(first_valid, None)).values.shape

In [None]:
from src.pytorch.train import calc_val_loss
# Evaluate the model on the validation loader and report the loss.
val_loss = calc_val_loss(validation_loader, model_forward, device)
print('validation loss:', val_loss)

In [None]:
from src.pytorch.train_nn import create_predictions
from src.score import compute_weighted_rmse, load_test_data

# Predictions of the model for the whole test set.
prediction_options = dict(
    var_dict={'z' : None, 't' : None},
    batch_size=100,
    model_forward=model_forward,
    verbose=True,
)
preds = create_predictions(model, dg_test, **prediction_options)

# Ground truth, skipping the first lead_time + max_input_lag time steps.
z500_test = load_test_data(f'{datadir}geopotential_500/', 'z')
t850_test = load_test_data(f'{datadir}temperature_850/', 't')
valid = slice(lead_time+dg_test.max_input_lag, None)
z500_valid = z500_test.isel(time=valid)
t850_valid = t850_test.isel(time=valid)

rmse_z = compute_weighted_rmse(preds.z, z500_valid).load()
rmse_t = compute_weighted_rmse(preds.t, t850_valid).load()
print('RMSE z', rmse_z.values)
print('RMSE t', rmse_t.values)

In [None]:
# Weighted RMSE of the predictions against the ground truth, skipping the
# first lead_time + max_input_lag time steps of the ground truth.
valid = slice(lead_time + dg_test.max_input_lag, None)
rmse_z = compute_weighted_rmse(preds.z, z500_test.isel(time=valid)).load()
rmse_t = compute_weighted_rmse(preds.t, t850_test.isel(time=valid)).load()
print('RMSE z', rmse_z.values)
print('RMSE t', rmse_t.values)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa', 
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [2000] # index relative to start time of dataset !

# Target variables, one panel each.
# Assumes dg._target_idx follows the order of target_var_dict - TODO confirm.
target_names = list(target_var_dict)

for i in idx:
    # Fetch the sample once and split it into input and target.
    sample = dg[[i]]
    pre, post = sample[0], sample[1]
    # predict for single time stamp; move the result to the CPU before
    # converting, since .numpy() only works on CPU tensors
    pred = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy()
    
    plt.figure(figsize=(16,6))
    for j, target in enumerate(target_names):
        plt.subplot(1, len(target_names), j+1)

        # stacked rows, first to last: true future state, model-predicted
        # future state, current state (labelled by the y-ticks below)
        j_ = dg._target_idx[j] # index for dg object in case first two dimensions not Z500, T850
        plt.imshow(np.vstack((post[0,j,:,:], pred[0,j,:,:], pre[0,j_,:,:])))

        # separators between the three stacked fields
        plt.plot([0.5, pred.shape[3]+.5], (1*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.plot([0.5, pred.shape[3]+.5], (2*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.yticks([pred.shape[2]//2, 3*pred.shape[2]//2, 5*pred.shape[2]//2],
                   [f'+{lead_time}h true', f'+{lead_time}h est.', 'state'])
        plt.axis([-0.5, pred.shape[3]-0.5, -0.5, 3*pred.shape[2]-0.5])
        plt.colorbar()
        # label by target variable rather than by position in var_dict
        plt.xlabel(var_names[target])
        plt.title(dg.data.time.isel(time=i).values)
    plt.show()


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

dg = dg_test

# variable names for display in figure
var_names = {'geopotential' : 'geopotential at 500hPa', 
             'temperature' : 'temperature at 850hPa'}

# pick time stamps to visualize
idx = [2000] # index relative to start time of dataset !

# Target variables, one panel each.
# Assumes dg._target_idx follows the order of target_var_dict - TODO confirm.
target_names = list(target_var_dict)

for i in idx:
    # Fetch the sample once and split it into input and target.
    sample = dg[[i]]
    pre, post = sample[0], sample[1]
    # predict for single time stamp; move the result to the CPU before
    # converting, since .numpy() only works on CPU tensors
    pred = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy()
    
    plt.figure(figsize=(16,6))
    for j, target in enumerate(target_names):
        plt.subplot(1, len(target_names), j+1)

        # stacked rows, first to last: true future state, model-predicted
        # future state, current state (labelled by the y-ticks below)
        j_ = dg._target_idx[j] # index for dg object in case first two dimensions not Z500, T850
        plt.imshow(np.vstack((post[0,j,:,:], pred[0,j,:,:], pre[0,j_,:,:])))

        # separators between the three stacked fields
        plt.plot([0.5, pred.shape[3]+.5], (1*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.plot([0.5, pred.shape[3]+.5], (2*pred.shape[2]-0.5)*np.ones(2), 'k', linewidth=1.5)
        plt.yticks([pred.shape[2]//2, 3*pred.shape[2]//2, 5*pred.shape[2]//2],
                   [f'+{lead_time}h true', f'+{lead_time}h est.', 'state'])
        plt.axis([-0.5, pred.shape[3]-0.5, -0.5, 3*pred.shape[2]-0.5])
        plt.colorbar()
        # label by target variable rather than by position in var_dict
        plt.xlabel(var_names[target])
        plt.title(dg.data.time.isel(time=i).values)
    plt.show()


In [None]:
import cartopy.crs as ccrs

## model survey

In [None]:
import matplotlib.pyplot as plt
import os

# Experiments to compare, each paired with the dimensionality tag shown next
# to it in the legend.
experiments = [
    ('resnet_baseline', 23),
    ('resnet_baseline_no_L2', 23),
    ('resnet_latmse', 23),
    ('multi_delay_test', 33),
]
exp_ids = [name for name, _ in experiments]


def find_weights(fn):
    """Return True if the file name ``fn`` ends in 'h.pt'."""
    return fn.endswith('h.pt')


dims = [dim for _, dim in experiments]

# Compare experiments: validation loss curves on the left, the two stored
# RMSE values of each experiment on the right.
fig = plt.figure(figsize=(12,5))
for exp_id, dim in zip(exp_ids, dims):
    save_dir = res_dir + 'models/' + exp_id + '/'

    # Loop-local names are used on purpose so that the notebook-level
    # lead_time, model_fn and training_outputs are not overwritten.
    weights_fn = list(filter(find_weights, os.listdir(save_dir)))[0]
    lead_hours = weights_fn[-6:-4]  # the two characters preceding 'h.pt'
    exp_outputs = np.load(save_dir + '_training_outputs' + '.npy', allow_pickle=True)[()]

    try:
        training_loss, validation_loss = exp_outputs['training_loss'], exp_outputs['validation_loss']
        RMSEs = np.load(save_dir + weights_fn[:-3] + '_RMSE_zt.npy')

        plt.subplot(1,2,1)
        plt.semilogy(validation_loss, label=exp_id + f' ({dim}D)')
        plt.title('training')
        
        plt.subplot(1,4,3)
        plt.plot([0,1], RMSEs[0]*np.ones(2), label=exp_id)
        plt.title(f'RMSE {lead_hours}h, z 500')
        plt.xticks([])
        plt.axis([-0.1, 1.1, 0, 600])
        
        plt.subplot(1,4,4)
        plt.plot([0,1], RMSEs[1]*np.ones(2), label=exp_id)
        plt.title(f'RMSE {lead_hours}h, t 850')
        plt.xticks([])
        plt.axis([-0.1, 1.1, 0, 3.0])
    except Exception as err:
        # Best effort: an experiment with missing results is left out of the
        # figure, but say so instead of skipping it silently.
        print(f'skipping {exp_id}: {err!r}')

plt.subplot(1,2,1)
plt.ylabel('validation error')
plt.legend()
#plt.subplot(1,4,3)
#plt.legend()
fig.patch.set_facecolor('xkcd:white')
plt.show()

# debug

### RMSE per pixel 

In [None]:
import matplotlib.pyplot as plt

# Ground truth with the first lead_time time steps skipped.
# NOTE(review): the evaluation cell skips lead_time + dg_test.max_input_lag
# steps instead - confirm which offset matches `preds` here.
z_true = z500_test.isel(time=slice(lead_time, None))
t_true = t850_test.isel(time=slice(lead_time, None))

# Squared errors of the two predicted channels.
sq_err_z = (preds[:,0,:,:] - z_true)**2
sq_err_t = (preds[:,1,:,:] - t_true)**2

# Plain RMSE, averaged over the first axis.
RMSEs_z = np.sqrt(np.mean(sq_err_z, axis=0))
RMSEs_t = np.sqrt(np.mean(sq_err_t, axis=0))

# Latitude weights: cosine of the latitude, normalised to mean one.
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# Latitude-weighted RMSE, averaged over the first axis.
wRMSEs_z = np.sqrt(np.mean(sq_err_z*weights_lat, axis=0))
wRMSEs_t = np.sqrt(np.mean(sq_err_t*weights_lat, axis=0))

plt.figure(figsize=(16,8))
panels = [
    (RMSEs_z, 'RMSEs Z500'),
    (RMSEs_t, 'RMSEs T850'),
    (wRMSEs_z, 'weighted RMSEs Z500'),
    (wRMSEs_t, 'weighted RMSEs T850'),
]
for panel_no, (field, title) in enumerate(panels, start=1):
    plt.subplot(2,2,panel_no)
    plt.imshow(field)
    plt.title(title)
    plt.colorbar()

plt.show()


### RMSEs per time point

In [None]:
import matplotlib.pyplot as plt

# Ground truth with the first lead_time time steps skipped.
# NOTE(review): the evaluation cell skips lead_time + dg_test.max_input_lag
# steps instead - confirm which offset matches `preds` here.
z_true = z500_test.isel(time=slice(lead_time, None))
t_true = t850_test.isel(time=slice(lead_time, None))

# Squared errors of the two predicted channels.
sq_err_z = (preds[:,0,:,:] - z_true)**2
sq_err_t = (preds[:,1,:,:] - t_true)**2

# Plain RMSE, averaged over the two trailing axes.
RMSEs_z = np.sqrt(np.mean(sq_err_z, axis=[1,2]))
RMSEs_t = np.sqrt(np.mean(sq_err_t, axis=[1,2]))

# Latitude weights: cosine of the latitude, normalised to mean one.
weights_lat = np.cos(np.deg2rad(z500_test.lat))
weights_lat /= weights_lat.mean()

# Latitude-weighted RMSE, averaged over the two trailing axes.
wRMSEs_z = np.sqrt(np.mean(sq_err_z*weights_lat, axis=[1,2]))
wRMSEs_t = np.sqrt(np.mean(sq_err_t*weights_lat, axis=[1,2]))

plt.figure(figsize=(16,8))
curves = [
    (RMSEs_z, 'RMSEs Z500'),
    (RMSEs_t, 'RMSEs T850'),
    (wRMSEs_z, 'weighted RMSEs Z500'),
    (wRMSEs_t, 'weighted RMSEs T850'),
]
for panel_no, (series, title) in enumerate(curves, start=1):
    plt.subplot(2,2,panel_no)
    plt.plot(series)
    plt.title(title)

plt.show()


# quickplot of predictions

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

dg = dg_test

# pick time stamps to visualize
idx = [2000, 2024, 2048, 2072, 2096, 2120] # index relative to start time of dataset !
pre = dg[idx][0]
# Move the predictions to the CPU before converting: .numpy() only works on
# CPU tensors.
preds = model_forward(torch.tensor(pre,requires_grad=False).to(device)).detach().cpu().numpy()

# Which of the predicted samples to show side by side.
idx_plot = [0,1,2,3,4]
# Size of one field, taken from the predictions instead of being hard-coded.
n_rows, n_cols = preds.shape[2], preds.shape[3]

# One figure per predicted channel.
for channel in (0, 1):
    plt.figure(figsize=(16,3))
    plt.imshow(np.hstack(preds[idx_plot,channel,:,:]))
    # vertical separators between neighbouring fields
    for i in range(1,len(idx_plot)):
        plt.plot(i*n_cols*np.ones(2), [0, n_rows-1], 'k')
    plt.xticks([])
    plt.yticks([])
    # optional export (was placed after the second figure):
    #plt.savefig('/gpfs/home/nonnenma/projects/seasonal_forecasting/results/weatherbench/figs/T850_example_dreds_N5_dt24h.pdf')
    plt.show()
