LSTM trained on gridded forcings for each station, one model for all stations

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../..')
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from datetime import datetime, timedelta
from sklearn import preprocessing
import netCDF4 as nc
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src import load_data, evaluate, datasets, utils

time_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
time_stamp

'20190904-215419'

In [2]:
USE_CUDA = False
if torch.cuda.is_available():
    print('CUDA Available')
    USE_CUDA = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
device = torch.device('cuda:0' if USE_CUDA else 'cpu')
num_devices = torch.cuda.device_count() if USE_CUDA else 0
print('cuda devices: {}'.format(list(torch.cuda.get_device_name(i) for i in range(num_devices))))
torch.manual_seed(0)
np.random.seed(0)

CUDA Available
cuda devices: ['Tesla V100-SXM2-16GB']


In [3]:
rdrs_vars = [4,5]
agg = None #['sum', 'minmax']
seq_len = 5*24
seq_steps = 1
validation_fraction = 0.1
batch_size = 4

train_start = datetime.strptime('2010-01-01', '%Y-%m-%d') + timedelta(hours=seq_len * seq_steps)
train_end = '2012-12-31'
test_start = '2013-01-01'
test_end = '2014-12-31'

In [4]:
train_dataset = datasets.RdrsDataset(rdrs_vars, seq_len, seq_steps, train_start, train_end, station=True, aggregate_daily=agg)
test_dataset = datasets.RdrsDataset(rdrs_vars, seq_len, seq_steps, test_start, test_end, station=True, aggregate_daily=agg,
                                    conv_scalers=train_dataset.conv_scalers, fc_scalers=train_dataset.fc_scalers)

val_indices = np.random.choice(len(train_dataset), size=int(validation_fraction * len(train_dataset)), replace=False)
train_indices = list(i for i in range(len(train_dataset)) if i not in val_indices)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, sampler=train_sampler, pin_memory=True, drop_last=False)
val_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, sampler=val_sampler, pin_memory=True, drop_last=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, pin_memory=True, drop_last=False)

In [5]:
class LSTMRegression(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, batch_size, dropout):
        super(LSTMRegression, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, dropout=dropout)
        self.linear = nn.Linear(hidden_dim, 1)
        self.init_hidden(batch_size)
    def init_hidden(self, batch_size):
        self.hidden = (torch.randn(self.num_layers, batch_size, self.hidden_dim, device=device, requires_grad=True),
                       torch.randn(self.num_layers, batch_size, self.hidden_dim, device=device, requires_grad=True))

    def forward(self, input):
        lstm_out, self.hidden = self.lstm(input.permute(1,0,2), self.hidden)
        return self.linear(lstm_out[-1])

In [6]:
# Create a mask of the grid that contains all grid cells that fall into a station's subwatershed
station_cell_mask = torch.zeros(train_dataset.x_conv.shape[-2:]).bool()
station_cell_mapping = load_data.get_station_cell_mapping()
for station in station_cell_mapping['station'].unique():    
    for _, row in station_cell_mapping[station_cell_mapping['station'] == station].iterrows():
        station_cell_mask[row['col'] - 1, row['row'] - 1] = True

onehot_vars = list(i for i,v in enumerate(train_dataset.fc_var_names) if v.startswith('station') or v.startswith('month'))

In [7]:
num_epochs = 300
learning_rate = 2e-3
patience = 100
min_improvement = 0.01
H = 20
lstm_layers = 1
dropout = 0.3
weight_decay = 1e-5
best_loss_model = (-1, np.inf, None)
input_dim = train_dataset.x_conv.shape[2] * int(station_cell_mask.sum()) + len(onehot_vars)
model = LSTMRegression(input_dim, H, lstm_layers, batch_size, dropout).to(device)
loss_fn = evaluate.NSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

param_description = {'time_stamp': time_stamp, 'H': H, 'batch_size': batch_size, 'lstm_layers': lstm_layers, 'loss': loss_fn, 'optimizer': optimizer, 'lr': learning_rate, 
                     'patience': patience, 'min_improvement': min_improvement, 'dropout': dropout, 'num_epochs': num_epochs, 'seq_len': seq_len, 'seq_steps': seq_steps, 
                     'train_start': train_start, 'train_end': train_end, 'weight_decay': weight_decay, 'validation_fraction': validation_fraction, 'test_start': test_start, 
                     'test_end': test_end, 'input_dim': input_dim, 'model': str(model).replace('\n','').replace(' ', ''), 'train len':len(train_dataset), 
                     'test len': len(test_dataset), 'rdrs_vars': rdrs_vars, 'aggregate_daily': agg}
writer = SummaryWriter()
writer.add_text('Parameter Description', str(param_description))
str(param_description)

  "num_layers={}".format(dropout, num_layers))


"{'time_stamp': '20190904-215419', 'H': 20, 'batch_size': 4, 'lstm_layers': 1, 'loss': NSELoss(), 'optimizer': Adam (\nParameter Group 0\n    amsgrad: False\n    betas: (0.9, 0.999)\n    eps: 1e-08\n    lr: 0.002\n    weight_decay: 1e-05\n), 'lr': 0.002, 'patience': 100, 'min_improvement': 0.01, 'dropout': 0.3, 'num_epochs': 300, 'seq_len': 120, 'seq_steps': 1, 'train_start': datetime.datetime(2010, 1, 6, 0, 0), 'train_end': '2012-12-31', 'weight_decay': 1e-05, 'validation_fraction': 0.1, 'test_start': '2013-01-01', 'test_end': '2014-12-31', 'input_dim': 784, 'model': 'LSTMRegression((lstm):LSTM(784,20,dropout=0.3)(linear):Linear(in_features=20,out_features=1,bias=True))', 'train len': 49113, 'test len': 33580, 'rdrs_vars': [4, 5], 'aggregate_daily': None}"

In [8]:
for epoch in range(num_epochs):
    model.train()
    train_losses = torch.tensor(0.0)
    for i, train_batch in enumerate(train_dataloader):
        x_train = train_batch['x_conv'][...,station_cell_mask].reshape(*train_batch['x_conv'].shape[:2], -1)
        x_train = torch.cat([x_train, train_batch['x_fc'][:,onehot_vars].unsqueeze(dim=1).repeat(1,x_train.shape[1],1)], dim=2).to(device)
        model.init_hidden(x_train.shape[0])
        y_pred = model(x_train)

        loss = loss_fn(y_pred.reshape(-1), train_batch['y'].to(device), means=train_batch['y_mean'].to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_losses += loss.detach()
        
    train_loss = (train_losses / len(train_dataloader)).item()
    print('Epoch', epoch, 'mean loss:', train_loss)
    writer.add_scalar('loss_nse', train_loss, epoch)

    model.eval()
    val_losses = torch.tensor(0.0)
    for i, val_batch in enumerate(val_dataloader):
        x_val = val_batch['x_conv'][...,station_cell_mask].reshape(*val_batch['x_conv'].shape[:2], -1)
        x_val = torch.cat([x_val, val_batch['x_fc'][:,onehot_vars].unsqueeze(dim=1).repeat(1,x_val.shape[1],1)], dim=2).to(device)
        
        model.init_hidden(x_val.shape[0])
        y_pred = model(x_val)

        loss = loss_fn(y_pred.reshape(-1), val_batch['y'].to(device), means=val_batch['y_mean'].to(device))
        val_losses += loss.detach()
        
    val_loss = (val_losses / len(val_dataloader)).item()
    print('Epoch', epoch, 'mean val loss:', val_loss)
    writer.add_scalar('loss_nse_val', val_loss, epoch)
    if val_loss < best_loss_model[1] - min_improvement:
        best_loss_model = (epoch, val_loss, model.state_dict())  # new best model
        load_data.pickle_model('LSTM_VIC-oneModel', model, 'allStations', time_stamp)
    elif epoch > best_loss_model[0] + patience:
        print('Patience exhausted in epoch {}. Best loss was {}'.format(epoch, best_loss_model[1]))
        break

Epoch 0 mean loss: 2.2175967693328857
Epoch 0 mean val loss: 2.160881996154785
Saved model as /home/mgauch/runoff-nn/src/../pickle/models/LSTM_VIC-oneModel_allStations_20190904-215419.pkl


  "type " + obj.__name__ + ". It won't be checked "


Epoch 1 mean loss: 2.524139881134033
Epoch 1 mean val loss: 2.033656120300293
Saved model as /home/mgauch/runoff-nn/src/../pickle/models/LSTM_VIC-oneModel_allStations_20190904-215419.pkl
Epoch 2 mean loss: 2.2277350425720215
Epoch 2 mean val loss: 2.5408828258514404
Epoch 3 mean loss: 2.3025639057159424
Epoch 3 mean val loss: 1.8486477136611938
Saved model as /home/mgauch/runoff-nn/src/../pickle/models/LSTM_VIC-oneModel_allStations_20190904-215419.pkl
Epoch 4 mean loss: 1.9110186100006104
Epoch 4 mean val loss: 1.7836331129074097
Saved model as /home/mgauch/runoff-nn/src/../pickle/models/LSTM_VIC-oneModel_allStations_20190904-215419.pkl
Epoch 5 mean loss: 1.8606027364730835
Epoch 5 mean val loss: 1.4809465408325195
Saved model as /home/mgauch/runoff-nn/src/../pickle/models/LSTM_VIC-oneModel_allStations_20190904-215419.pkl
Epoch 6 mean loss: 2.106511116027832
Epoch 6 mean val loss: 1.8561232089996338
Epoch 7 mean loss: 1.502906322479248
Epoch 7 mean val loss: 1.712312936782837
Epoch 8 m

In [9]:
print('Using best model from epoch', str(best_loss_model[0]), 'which had loss', str(best_loss_model[1]))
model.load_state_dict(best_loss_model[2])
model.eval()
predict = test_dataset.data_runoff.copy()
predict['actual'] = predict['runoff']
predict['runoff'] = np.nan
pred_array = np.array([])
for i, test_batch in enumerate(test_dataloader):
    x_test = test_batch['x_conv'][...,station_cell_mask].reshape(*test_batch['x_conv'].shape[:2], -1)
    x_test = torch.cat([x_test, test_batch['x_fc'][:,onehot_vars].unsqueeze(dim=1).repeat(1,x_test.shape[1],1)], dim=2).to(device)
    model.init_hidden(x_test.shape[0])
    pred_array = np.concatenate([pred_array, model(x_test).detach().cpu().numpy().reshape(-1)])
    
predict['runoff'] = pred_array

Using best model from epoch 116 which had loss 0.7315791249275208


In [10]:
nse_list = []
mse_list = []
grouped_predict = predict.groupby('station')
for station in grouped_predict.groups.keys():
    station_predict = grouped_predict.get_group(station).set_index('date')
    nse, mse = evaluate.evaluate_daily(station, station_predict[['runoff']], station_predict['actual'], writer=writer)
    nse_list.append(nse)
    mse_list.append(mse)
    
    print(station, '\tNSE:', nse, '\tMSE:', mse, '(clipped to 0)')

print('Median NSE (clipped to 0)', np.median(nse_list), '/ Min', np.min(nse_list), '/ Max', np.max(nse_list))
print('Median MSE (clipped to 0)', np.median(mse_list), '/ Min', np.min(mse_list), '/ Max', np.max(mse_list))


To register the converters:
	>>> from pandas.plotting import register_matplotlib_converters
	>>> register_matplotlib_converters()


02GA010 	NSE: -0.14521424424021445 	MSE: 546.0078587545494 (clipped to 0)
02GA018 	NSE: -0.11914420261325254 	MSE: 281.07662592673546 (clipped to 0)
02GA038 	NSE: -0.0518527754170941 	MSE: 173.07023258356543 (clipped to 0)
02GA047 	NSE: -0.4593485833655171 	MSE: 114.3983714443547 (clipped to 0)
02GB001 	NSE: -0.21653837312290758 	MSE: 9136.051356463951 (clipped to 0)
02GB007 	NSE: 0.02276503812807873 	MSE: 30.560105925634023 (clipped to 0)
02GC002 	NSE: -0.003552777015745079 	MSE: 128.95957724834642 (clipped to 0)
02GC007 	NSE: -0.5340803971868235 	MSE: 46.2085640941852 (clipped to 0)
02GC010 	NSE: -0.03342948046527394 	MSE: 61.43729973144633 (clipped to 0)
02GC018 	NSE: 0.0204710935297715 	MSE: 66.34036003153619 (clipped to 0)
02GC026 	NSE: -0.16336877486771173 	MSE: 187.10780321361818 (clipped to 0)
02GD004 	NSE: -0.01572587575239215 	MSE: 55.928988464007354 (clipped to 0)
02GE007 	NSE: 0.03350107254075507 	MSE: 34.96315522133642 (clipped to 0)
02GG002 	NSE: -0.08161154881481081 	MSE

In [11]:
writer.close()

In [12]:
load_data.pickle_results('LSTM_VIC-oneModel', predict[['date', 'station', 'runoff', 'actual']].rename({'runoff': 'prediction'}, axis=1).reset_index(drop=True), time_stamp)

'LSTM_VIC-oneModel_20190904-215419.pkl'

In [13]:
datetime.now().strftime('%Y%m%d-%H%M%S')

'20190905-080828'