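ConvLSTM trained on gridded forcings and geophysical (landcover) data, predicting runoff for all stations on the grid.
Tests spatial generalization by training on a random subset of stations and evaluating on the held-out stations.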

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../..')
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt 
from datetime import datetime, timedelta
from sklearn import preprocessing
import netCDF4 as nc
import torch
from torch import nn, utils
from torch.utils.tensorboard import SummaryWriter
from src import load_data, evaluate, conv_lstm, datasets
import torch.autograd as autograd
import pickle

# Run identifier (local time, YYYYmmdd-HHMMSS). It tags this run's log lines,
# the TensorBoard parameter text and the pickled model/results.
time_stamp = f'{datetime.now():%Y%m%d-%H%M%S}'
time_stamp

In [None]:
import logging

# Send log records both to a file shared across runs (append mode) and to stdout.
# Every line carries this run's time_stamp so runs can be told apart in log.out.
# NOTE(review): re-running this cell attaches another pair of handlers to the
# root logger, which duplicates every subsequent log line.
logger = logging.getLogger()  # root logger
fhandler = logging.FileHandler(filename='../../log.out', mode='a')
chandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(f'%(asctime)s - {time_stamp} - %(message)s')
for handler in (fhandler, chandler):
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)

In [None]:
# Use the first GPU when CUDA is available, otherwise the CPU. num_devices > 1
# later enables DataParallel across all visible GPUs.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    print('CUDA Available')
device = torch.device('cuda:0') if USE_CUDA else torch.device('cpu')
num_devices = torch.cuda.device_count() if USE_CUDA else 0
device_names = [torch.cuda.get_device_name(i) for i in range(num_devices)]
logger.warning('cuda devices: {}'.format(device_names))

# Fixed seeds for reproducibility.
torch.manual_seed(0)
np.random.seed(0)

In [None]:
# Read only the lat/lon axes of the landcover grid. Latitudes are reversed
# ([::-1]) relative to their order in the file. The context manager closes the
# file even if reading fails.
with nc.Dataset('../../data/NA_NALCMS_LC_30m_LAEA_mmu12_urb05_n40-45w75-90_erie.nc', 'r') as landcover_nc:
    landcover_nc.set_auto_mask(False)
    landcover_lats = landcover_nc['lat'][:][::-1]
    landcover_lons = landcover_nc['lon'][:]

# Subsample with strides that leave at most 34 latitudes and 39 longitudes.
lat_stride = len(landcover_lats) // 34 + 1
lon_stride = len(landcover_lons) // 39 + 1
out_lats = landcover_lats[::lat_stride]
out_lons = landcover_lons[::lon_stride]

In [None]:
seq_len = 5 * 24  # hours per sequence step (5 days; see timedelta(hours=...) below)
seq_steps = 2
stateful_lstm = False
validation_fraction = val_start = val_end = None

# The first date to predict is shifted forward by seq_len * seq_steps hours,
# presumably to leave room for the input history of the first sample.
first_prediction_date = datetime(2010, 1, 1) + timedelta(hours=seq_len * seq_steps)

if not stateful_lstm:
    # Validation samples are drawn at random from the training period.
    validation_fraction = 0.1
    train_start, train_end = first_prediction_date, '2012-12-31'
else:
    # Stateful mode needs contiguous periods: validate on 2010, train afterwards.
    val_start, val_end = first_prediction_date, '2010-09-30'
    train_start, train_end = '2010-10-01', '2012-12-31'
test_start, test_end = '2013-01-01', '2014-12-31'

In [None]:
# Stations withheld from the training (and stateful validation) datasets via
# exclude_stations=. The test dataset still contains them, and their scores are
# reported as a separate "Exclude" group.
exclude_downstream_stations = [
    '02GB001', '02GB007', '02GC026', '02GG009', '02GG003',
    '04165500', '04164000', '04166500', '04198000', '04208504',
]

In [None]:
rdrs_vars = list(range(8))  # RDRS variable selection: the first 8 variables
# Positional arguments shared by every split.
common_args = (rdrs_vars, seq_len, seq_steps)

train_dataset = datasets.RdrsGridDataset(*common_args, train_start, train_end,
                                         exclude_stations=exclude_downstream_stations,
                                         resample_rdrs=True)
# The other splits reuse train_dataset.conv_scalers, presumably the input scalers
# fitted on the training period, so all splits share one normalization.
if stateful_lstm:
    val_dataset = datasets.RdrsGridDataset(*common_args, val_start, val_end,
                                           conv_scalers=train_dataset.conv_scalers,
                                           exclude_stations=exclude_downstream_stations,
                                           resample_rdrs=True)
# No exclude_stations here: the excluded stations are evaluated on the test period.
test_dataset = datasets.RdrsGridDataset(*common_args, test_start, test_end,
                                        conv_scalers=train_dataset.conv_scalers,
                                        resample_rdrs=True)

In [None]:
# Passed through as values_to_use; the meaning of None is defined by
# load_data.load_landcover_reduced.
landcover_types = None
landcover, landcover_legend = load_data.load_landcover_reduced(values_to_use=landcover_types)
# Convert to a float32 tensor on the training device once; batches repeat it later.
landcover = torch.from_numpy(landcover).to(device=device, dtype=torch.float32)

In [None]:
# Split stations: the model is trained (and validated) only on train_stations.
# All stations are evaluated on the test period, so test_stations measure
# generalization to unseen locations.
stations = train_dataset.data_runoff['station'].unique()
np.random.seed(2)  # fixed seed -> same station split on every run
# NOTE(review): if train_dataset already dropped the excluded stations, they are
# subtracted twice here; confirm the intended test-set size.
n_test_stations = int(0.2 * (len(stations) - len(exclude_downstream_stations)))
test_stations = np.random.choice(stations, size=n_test_stations, replace=False)
test_station_set = set(test_stations)  # O(1) membership instead of scanning the array
train_stations = [s for s in stations if s not in test_station_set]

train_station_indices = [train_dataset.station_to_index[s] for s in train_stations]
test_station_indices = [test_dataset.station_to_index[s] for s in test_stations]

# Boolean grid that is True exactly at the grid cells of training stations,
# flattened so it can be combined with the flattened per-batch target masks.
# Setting cells directly avoids scanning the station list for every cell.
# Unpacking (row, col) also works whether station_to_index returns tuples,
# lists or arrays.
grid_shape = (train_dataset.out_lats.shape[0], train_dataset.out_lats.shape[1])
train_mask = torch.zeros(grid_shape, dtype=torch.bool)
for row, col in train_station_indices:
    train_mask[row, col] = True
train_mask = train_mask.reshape(-1).to(device)

In [None]:
# Training / early-stopping hyperparameters. A validation loss counts as progress
# only if it beats the best so far by more than min_improvement. Training stops
# once more than `patience` epochs have passed since the last progress.
num_epochs = 800
learning_rate = 2e-3
patience = 500
min_improvement = 0.01
best_loss_model = (-1, np.inf, None)  # (epoch, validation loss, state_dict) of the best model so far

# Prepare model
batch_size = 32
num_convlstm_layers = 2
num_conv_layers = 2
convlstm_hidden_dims = [8] * num_convlstm_layers  # one hidden size per ConvLSTM layer
# num_conv_layers - 1 entries. NOTE(review): presumably the last conv layer's
# output size is fixed inside the model; confirm.
conv_hidden_dims = [4] * (num_conv_layers - 1)
convlstm_kernel_size = [(5,5)] * num_convlstm_layers  # one kernel size per ConvLSTM layer
conv_kernel_size = [(3,3)] * num_conv_layers  # one kernel size per conv layer
conv_activation = nn.Sigmoid  # passed as a class, not an instance
dropout = 0.2
weight_decay = 1e-5  # L2 penalty applied by Adam below

# Inputs to the model: forcing grid size (conv_height, conv_width), number of
# forcing variables, and landcover.shape[0] (first axis of the landcover tensor;
# presumably the number of landcover channels, TODO confirm).
model = conv_lstm.ConvLSTMGridWithGeophysicalInput((train_dataset.conv_height, train_dataset.conv_width), 
                                         train_dataset.n_conv_vars, landcover.shape[0], convlstm_hidden_dims, 
                                         conv_hidden_dims, convlstm_kernel_size, conv_kernel_size, 
                                         num_convlstm_layers, num_conv_layers, conv_activation, dropout=dropout).to(device)
# Split batches across all visible GPUs when there is more than one.
if num_devices > 1:
    model = torch.nn.DataParallel(model, device_ids=list(range(num_devices)))
loss_fn = evaluate.NSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Record every setting of this run as TensorBoard text so results can be traced
# back to their configuration; the last expression also displays it in the notebook.
writer = SummaryWriter(comment='ConvLSTM_withLandcover_generalizationTest')
param_description = {'time_stamp': time_stamp, 'H_convlstm': convlstm_hidden_dims, 'H_conv': conv_hidden_dims, 'batch_size': batch_size, 'num_convlstm_layers': num_convlstm_layers, 'num_conv_layers': num_conv_layers, 'convlstm_kernel_size': convlstm_kernel_size, 'conv_kernel_size': conv_kernel_size, 'loss': loss_fn, 
                     'optimizer': optimizer, 'lr': learning_rate, 'patience': patience, 'min_improvement': min_improvement, 'stateful_lstm': stateful_lstm, 'dropout': dropout, 'landcover_shape': landcover.shape, 'conv_activation': conv_activation,
                     'num_epochs': num_epochs, 'seq_len': seq_len, 'seq_steps': seq_steps, 'train_start': train_start, 'train_end': train_end, 'weight_decay': weight_decay, 'validation_fraction': validation_fraction, 'landcover_types': landcover_types,
                     'test_start': test_start, 'test_end': test_end, 'n_conv_vars': train_dataset.n_conv_vars, 'model': str(model).replace('\n','').replace(' ', ''), 'val_start': val_start, 'val_end': val_end, 'train_stations': train_stations, 'test_stations': test_stations,
                     'train len':len(train_dataset), 'conv_height': train_dataset.conv_height, 'conv_width': train_dataset.conv_width, 'test len': len(test_dataset)}
writer.add_text('Parameter Description', str(param_description))
str(param_description)

In [None]:
if stateful_lstm:
    # Stateful mode: StatefulBatchSampler controls batch order (presumably so that
    # consecutive batches continue each other's sequences; TODO confirm).
    # Validation uses its own val_dataset period.
    train_sampler = datasets.StatefulBatchSampler(train_dataset, batch_size)
    val_sampler = datasets.StatefulBatchSampler(val_dataset, batch_size)
    test_sampler = datasets.StatefulBatchSampler(test_dataset, batch_size)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler, pin_memory=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_sampler=val_sampler, pin_memory=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_sampler, pin_memory=True)
else:
    # Hold out a random validation_fraction of the training samples for validation.
    n_samples = len(train_dataset)
    val_indices = np.random.choice(n_samples, size=int(validation_fraction * n_samples), replace=False)
    # Set membership keeps this O(n); `i not in val_indices` on the numpy array
    # rescans the whole array for every sample (O(n * m)).
    val_index_set = set(val_indices.tolist())
    train_indices = [i for i in range(n_samples) if i not in val_index_set]
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, sampler=train_sampler, pin_memory=True, drop_last=False)
    val_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, sampler=val_sampler, pin_memory=True, drop_last=False)
    # Test batches stay in dataset order so predictions line up with dates.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, pin_memory=True, drop_last=False)

In [None]:
import copy


def _detach_hidden(state):
    """Return ``state`` with every tensor detached from the autograd graph.

    Accepts None, a tensor, or (nested) lists/tuples of tensors; any other value is
    returned unchanged. Needed in stateful mode: the hidden state returned for one
    batch is fed into the next, and without detaching, the next ``backward()`` would
    reach into the previous batch's graph, which its own ``backward()`` already freed.
    """
    if state is None:
        return None
    if torch.is_tensor(state):
        return state.detach()
    if isinstance(state, (list, tuple)):
        detached = [_detach_hidden(s) for s in state]
        return detached if isinstance(state, list) else tuple(detached)
    return state


# Reseed so training is reproducible regardless of randomness used by earlier cells.
torch.manual_seed(0)
np.random.seed(0)
for epoch in range(num_epochs):
    model.train()

    epoch_loss_sum, n_train_batches = 0.0, 0
    conv_hidden_states = None
    for i, train_batch in enumerate(train_dataloader):
        y_train = train_batch['y'].reshape((train_batch['y'].shape[0], -1)).to(device, non_blocking=True)
        # Cells with a target anywhere in the batch, restricted to training stations.
        mask = train_batch['mask'].any(dim=0).reshape(-1).to(device, non_blocking=True)
        mask = mask & train_mask
        landcover_batch = landcover.repeat(y_train.shape[0], 1, 1, 1).to(device, non_blocking=True)

        if not mask.any():
            print('Batch {} has no target values. skipping.'.format(i))
            continue
        if not stateful_lstm:
            conv_hidden_states = None

        y_pred, conv_hidden_states = model(train_batch['x_conv'].to(device), landcover_batch, hidden_state=conv_hidden_states)
        y_pred = y_pred.reshape((train_batch['y'].shape[0], -1))
        loss = loss_fn(y_pred[:, mask], y_train[:, mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if stateful_lstm:
            # Carry the state forward, but cut it from this batch's (freed) graph.
            conv_hidden_states = _detach_hidden(conv_hidden_states)

        epoch_loss_sum += loss.item()
        n_train_batches += 1

    # Average over the batches actually trained on; skipped batches contribute no loss.
    epoch_loss = epoch_loss_sum / n_train_batches if n_train_batches else float('nan')
    print('Epoch', epoch, 'mean train loss:\t{}'.format(epoch_loss))
    writer.add_scalar('loss_nse', epoch_loss, epoch)

    # Evaluate on the validation split (training stations only). No autograd graph
    # is needed, so no_grad avoids keeping activations for every batch.
    model.eval()
    val_loss_sum, n_val_batches = 0.0, 0
    conv_hidden_states = None
    with torch.no_grad():
        for i, val_batch in enumerate(val_dataloader):
            y_val = val_batch['y'].reshape((val_batch['y'].shape[0], -1)).to(device, non_blocking=True)
            mask = val_batch['mask'].any(dim=0).reshape(-1).to(device, non_blocking=True)
            mask = mask & train_mask
            landcover_batch = landcover.repeat(y_val.shape[0], 1, 1, 1).to(device, non_blocking=True)

            if not stateful_lstm:
                conv_hidden_states = None

            # Run the forward pass even for empty batches so stateful hidden states stay in sequence.
            batch_pred, conv_hidden_states = model(val_batch['x_conv'].to(device), landcover_batch, hidden_state=conv_hidden_states)
            if not mask.any():
                continue  # a loss over zero cells is not meaningful (and may be NaN)
            batch_pred = batch_pred.reshape((val_batch['y'].shape[0], -1))
            val_loss_sum += loss_fn(batch_pred[:, mask], y_val[:, mask]).item()
            n_val_batches += 1

    val_nse = val_loss_sum / n_val_batches if n_val_batches else float('nan')
    print('Epoch {} mean val loss:  \t{}'.format(epoch, val_nse))
    writer.add_scalar('loss_nse_val', val_nse, epoch)
    if val_nse < best_loss_model[1] - min_improvement:
        # state_dict() returns tensors that share storage with the live parameters,
        # which the optimizer keeps updating in place. Deep-copy so the snapshot
        # really holds this epoch's weights.
        best_loss_model = (epoch, val_nse, copy.deepcopy(model.state_dict()))  # new best model
        load_data.pickle_model('ConvLSTM_withLandcover_generalizationTest', model, 'allStations', time_stamp)
    elif epoch > best_loss_model[0] + patience:
        print('Patience exhausted in epoch {}. Best val-loss was {}'.format(epoch, best_loss_model[1]))
        break

if best_loss_model[2] is None:
    print('No epoch improved the validation loss; keeping the final model weights.')
else:
    print('Using best model from epoch', str(best_loss_model[0]), 'which had loss', str(best_loss_model[1]))
    model.load_state_dict(best_loss_model[2])
load_data.pickle_model('ConvLSTM_withLandcover_generalizationTest', model, 'allStations', time_stamp)

In [None]:
logger.warning('predicting')
model.eval()

# Collect per-batch predictions on the CPU so the whole test period is never
# held in GPU memory at once.
predictions = []
conv_hidden_states = None
with torch.no_grad():  # inference only: no autograd graphs needed
    for test_batch in test_dataloader:
        if not stateful_lstm:
            conv_hidden_states = None

        landcover_batch = landcover.repeat(test_batch['y'].shape[0], 1, 1, 1).to(device)
        pred, conv_hidden_states = model(test_batch['x_conv'].to(device), landcover_batch, hidden_state=conv_hidden_states)
        predictions.append(pred.cpu())

predictions = torch.cat(predictions)

if stateful_lstm:
    # The stateful sampler emits samples in its own order. argsort of the emitted
    # index sequence restores dataset order. NOTE(review): this assumes iterating
    # test_sampler again yields the same order as during prediction.
    pred_indices = np.array(list(iter(test_sampler))).reshape(-1)
    predictions = predictions[pred_indices.argsort()]

In [None]:
actuals = test_dataset.data_runoff.copy()
n_actual_dates = len(actuals['date'].unique())
if n_actual_dates != len(predictions):
    print('Warning: length of prediction {} and actuals {} does not match.'.format(len(predictions), n_actual_dates))

nse_dict = {}
mse_dict = {}
station_frames = []  # one prediction frame per station, concatenated once after the loop
n_pred = predictions.shape[0]
for station in actuals['station'].unique():
    row, col = test_dataset.station_to_index[station]

    act = actuals[actuals['station'] == station].set_index('date')['runoff']
    if n_pred < act.shape[0]:
        print('Warning: length of prediction {} and actuals {} does not match for station {}. Ignoring excess actuals.'.format(n_pred, len(act), station))
        act = act.iloc[:n_pred]
    elif n_pred > act.shape[0]:
        # Previously a crash (DataFrame length mismatch); align on the shorter series instead.
        print('Warning: length of prediction {} and actuals {} does not match for station {}. Ignoring excess predictions.'.format(n_pred, len(act), station))
    # predictions is indexed (time, row, col). Convert to numpy so the column
    # holds plain floats rather than tensor elements.
    pred = pd.DataFrame({'runoff': predictions[:len(act), row, col].numpy()}, index=act.index)
    pred['station'] = station
    pred['is_test_station'] = station in test_stations
    station_frames.append(pred.reset_index())

    nse, mse = evaluate.evaluate_daily(station, pred['runoff'], act, writer=writer)
    nse_dict[station] = nse
    mse_dict[station] = mse

    print(station, '\tNSE:', nse, '\tMSE:', mse, '(clipped to 0)')

# A single concat replaces per-station DataFrame.append, which was removed in
# pandas 2.0 and copied the growing frame on every call.
if station_frames:
    predictions_df = pd.concat(station_frames, sort=True)
else:
    predictions_df = pd.DataFrame(columns=list(actuals.columns) + ['is_test_station'])

In [None]:
def _summarize_stations(label, stations, nse_by_station, mse_by_station):
    """Print median/min/max NSE and MSE for one station group; return its NSE values.

    Stations without results are reported and skipped. An empty group prints a
    notice instead of crashing np.min/np.max.
    """
    missing = [s for s in stations if s not in nse_by_station]
    if missing:
        print('Warning: no results for {} stations {}'.format(label, missing))
    present = [s for s in stations if s in nse_by_station]
    if not present:
        print('{}: no stations to summarize'.format(label))
        return []
    nse_values = [nse_by_station[s] for s in present]
    mse_values = [mse_by_station[s] for s in present]
    print(label, 'Median NSE (clipped to 0)', np.median(nse_values), '/ Min', np.min(nse_values), '/ Max', np.max(nse_values))
    print(label, 'Median MSE (clipped to 0)', np.median(mse_values), '/ Min', np.min(mse_values), '/ Max', np.max(mse_values))
    return nse_values


nse_train = _summarize_stations('Train', train_stations, nse_dict, mse_dict)
nse_test = _summarize_stations('Test', test_stations, nse_dict, mse_dict)
nse_exclude = _summarize_stations('Exclude', exclude_downstream_stations, nse_dict, mse_dict)

# 'nse_median' is the headline generalization metric, taken over the held-out test
# stations. The excluded downstream stations are logged under their own tag
# instead of overwriting it.
if nse_test:
    writer.add_scalar('nse_median', np.median(nse_test))
if nse_exclude:
    writer.add_scalar('nse_median_exclude', np.median(nse_exclude))

In [None]:
# Show every (station, NSE) pair in evaluation order.
print(list(nse_dict.items()))

In [None]:
# Close the TensorBoard writer; no later cell logs to it.
writer.close()

In [None]:
# Pair each prediction with the observed runoff for the same station and date
# (inner join), keep the reporting columns, and pickle the table.
prediction_table = predictions_df.rename(columns={'runoff': 'prediction'})
actual_table = actuals.rename(columns={'runoff': 'actual'})
joined = prediction_table.merge(actual_table, on=['date', 'station'])
save_df = joined[['date', 'station', 'prediction', 'actual', 'is_test_station']]
load_data.pickle_results('ConvLSTM_withLandcover_generalizationTest', save_df, time_stamp)

In [None]:
# Finish time, in the same format as time_stamp, to gauge how long the run took.
f'{datetime.now():%Y%m%d-%H%M%S}'