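LSTM trained on gridded forcings for each station: a separate model per station predicts daily runoff from the preceding 5 days (5 × 24 hourly steps) of forcing data.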

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from datetime import datetime, timedelta
from sklearn import preprocessing
import netCDF4 as nc
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src import load_data, evaluate
import torch.autograd as autograd
import pickle

# Run identifier (YYYYmmdd-HHMMSS); it tags the pickled models and results of this run.
time_stamp = f"{datetime.now():%Y%m%d-%H%M%S}"
time_stamp  # shown as the cell output

In [None]:
# Pick the compute device once; all tensors and models below are allocated on `device`.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    print('CUDA Available')
device = torch.device('cuda' if USE_CUDA else 'cpu')

# Global seeds (the training loop re-seeds per station as well).
torch.manual_seed(0)
np.random.seed(0)

# TensorBoard writer for per-epoch train/validation losses and final metrics.
writer = SummaryWriter()

In [None]:
# Per-station forcing data (iterated by station key below) and the runoff table
# (filtered by its 'station' column and indexed by 'date' in the training loop).
station_data_dict, data_runoff = (
    load_data.load_train_test_lstm(),
    load_data.load_discharge_gr4j_vic(),
)

In [None]:
class LSTMRegression(nn.Module):
    """LSTM followed by a linear head applied to the last time step's output.

    `forward` takes input shaped (seq_len, batch, input_dim) (nn.LSTM default,
    batch_first=False) and returns (batch, 1). The recurrent state lives on
    `self.hidden` and is fed into the next `forward` call unless the caller
    resets it with `init_hidden()`.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, batch_size, dropout):
        super().__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, dropout=dropout)
        self.linear = nn.Linear(hidden_dim, 1)
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=None):
        """Return a fresh random (h_0, c_0) pair, each (num_layers, batch, hidden_dim).

        Args:
            batch_size: batch dimension of the state; defaults to the batch size
                given to the constructor.

        The tensors are created on the device/dtype of the model's own parameters,
        so the class does not rely on a module-level `device` global and stays
        correct after `model.to(...)`.
        """
        if batch_size is None:
            batch_size = self.batch_size
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (torch.randn(shape, device=reference.device, dtype=reference.dtype),
                torch.randn(shape, device=reference.device, dtype=reference.dtype))

    def forward(self, input):
        """Run the LSTM from `self.hidden`, store the new state, and regress on the last step."""
        lstm_out, self.hidden = self.lstm(input, self.hidden)
        return self.linear(lstm_out[-1])

In [None]:
predictions = {}  # station -> DataFrame of predicted test-period runoff (column 'runoff')
actuals = {}  # station -> Series of observed test-period runoff
seq_len = 5 * 24
train_start = datetime.strptime('2010-01-01', '%Y-%m-%d') + timedelta(hours=seq_len)  # first day for which to make a prediction in train set
train_end = '2013-12-31'  # last day for which to make a prediction in train set
test_start = '2014-01-01'
test_end = '2014-12-31'
validation_fraction = 0.1

# Training hyperparameters (identical for every station).
learning_rate = 2e-3
patience = 50  # stop after this many epochs without a sufficient val-MSE improvement
min_improvement = 0.05  # val MSE must drop by more than this to count as an improvement
max_epochs = 300
H = 20  # LSTM hidden size
batch_size = 5
lstm_layers = 1
dropout = 0
weight_decay = 2e-5

# Number of days to predict; these depend only on the dates above, not on the station.
num_train_days = len(pd.date_range(train_start, train_end, freq='D'))
num_test_days = len(pd.date_range(test_start, test_end, freq='D'))
num_total_days = len(pd.date_range(train_start, test_end, freq='D'))


def _predict_in_batches(model, x, batch_size):
    """Predict one value per sample of `x` (seq_len, num_samples, features), batch by batch.

    The hidden state is re-initialised before every batch, exactly as in training,
    so a prediction never depends on the final state of the previous (unrelated)
    batch. Samples beyond the last full batch are dropped. Returns a 1-D numpy array.
    """
    batches = []
    with torch.no_grad():  # inference only: don't build (and chain) autograd graphs
        for i in range(x.shape[1] // batch_size):
            model.hidden = model.init_hidden()
            batch_pred = model(x[:, i * batch_size:(i + 1) * batch_size, :])
            batches.append(batch_pred.cpu().numpy().reshape(batch_size))
    return np.concatenate(batches) if batches else np.array([])


for station in station_data_dict.keys():
    # Re-seed so each station's split, initialisation and shuffling are reproducible on their own.
    torch.manual_seed(0)
    np.random.seed(0)

    station_rdrs = station_data_dict[station]
    station_runoff = data_runoff[data_runoff['station'] == station].set_index('date')
    if station_runoff['runoff'].isna().any():
        print('Station', station, 'had NA runoff values. Skipping.')
        continue

    # x[:, day, :] is the input sequence for the day-th predicted day (train days first, then test days).
    x = np.zeros((seq_len, num_total_days, station_rdrs.shape[1]))
    for day in range(num_total_days):
        # For each day that is to be predicted, cut out a sequence that ends with that day's 23:00:00 and is seq_len long
        x[:, day, :] = station_rdrs[train_start + timedelta(days=day, hours=-seq_len + 24) : train_start + timedelta(hours=23, days=day)]

    # Scale training data: one scaler per feature, fitted on the training period only
    scalers = []  # save scalers to apply them to test data later
    x_train = x[:, :num_train_days, :].copy()
    for i in range(x.shape[2]):
        scalers.append(preprocessing.StandardScaler())
        x_train[:, :, i] = scalers[i].fit_transform(x_train[:, :, i].reshape((-1, 1))).reshape(x_train[:, :, i].shape)
    x_train = torch.from_numpy(x_train).float().to(device)
    y_train = torch.from_numpy(station_runoff.loc[train_start:train_end, 'runoff'].to_numpy()).float().to(device)

    # Get validation split: a random subset of the training days
    num_validation_samples = int(x_train.shape[1] * validation_fraction)
    shuffle_indices = np.arange(x_train.shape[1])
    np.random.shuffle(shuffle_indices)
    x_train = x_train[:, shuffle_indices, :]
    y_train = y_train[shuffle_indices]
    x_val, x_train = x_train[:, :num_validation_samples, :], x_train[:, num_validation_samples:, :]
    y_val, y_train = y_train[:num_validation_samples], y_train[num_validation_samples:]
    print('Shapes: x_train {}, y_train {}, x_val {}, y_val {}'.format(x_train.shape, y_train.shape, x_val.shape, y_val.shape))
    y_val_series = pd.Series(y_val.cpu().numpy().flatten())

    # Prepare model
    model = LSTMRegression(station_rdrs.shape[1], H, lstm_layers, batch_size, dropout).to(device)
    loss_fn = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    best_loss_model = (-1, np.inf, None)  # (epoch, val MSE, state-dict snapshot)

    # Train model
    for epoch in range(max_epochs):
        epoch_losses = []

        shuffle_indices = np.arange(x_train.shape[1])
        np.random.shuffle(shuffle_indices)
        x_train = x_train[:, shuffle_indices, :]
        y_train = y_train[shuffle_indices]

        model.train()
        for i in range(x_train.shape[1] // batch_size):
            model.hidden = model.init_hidden()
            y_pred = model(x_train[:, i * batch_size:(i + 1) * batch_size, :])

            loss = loss_fn(y_pred, y_train[i * batch_size:(i + 1) * batch_size].reshape((batch_size, 1)))
            epoch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_loss = np.array(epoch_losses).mean()
        print('Epoch', epoch, 'mean train loss:\t{}'.format(epoch_loss))
        writer.add_scalar('loss_' + station, epoch_loss, epoch)

        # eval on validation split; a fresh RangeIndex keeps predictions aligned with y_val_series
        model.eval()
        val_pred = pd.Series(_predict_in_batches(model, x_val, batch_size))
        val_nse, val_mse = evaluate.evaluate_daily(station, val_pred, y_val_series[:val_pred.shape[0]])
        print('Epoch {} mean val mse:    \t{},\tnse: {}'.format(epoch, val_mse, val_nse))
        writer.add_scalar('loss_eval_' + station, val_mse, epoch)

        if val_mse < best_loss_model[1] - min_improvement:
            # state_dict() returns references to the live parameter tensors; clone them so
            # later optimizer steps do not overwrite the snapshot of the best model.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            best_loss_model = (epoch, val_mse, best_state)  # new best model
        elif epoch > best_loss_model[0] + patience:
            print('Patience exhausted in epoch {}. Best val-loss was {}'.format(epoch, best_loss_model[1]))
            break

    if best_loss_model[2] is None:
        # e.g. every epoch's validation MSE was NaN; there is no snapshot to restore.
        print('No epoch improved the validation loss; keeping the final model.')
    else:
        print('Using best model from epoch', str(best_loss_model[0]), 'which had loss', str(best_loss_model[1]))
        model.load_state_dict(best_loss_model[2])
    load_data.pickle_model('LSTM_VIC', model, station, time_stamp)
    model.eval()

    # scale test data with the scalers fitted on the training period
    x_test = x[:, num_train_days:num_train_days + num_test_days, :].copy()
    for i in range(x.shape[2]):
        x_test[:, :, i] = scalers[i].transform(x_test[:, :, i].reshape((-1, 1))).reshape(x_test[:, :, i].shape)
    print('x_test shape: {}'.format(x_test.shape))
    # if batch size doesn't align with number of samples, add dummies to the last batch
    remainder = x_test.shape[1] % batch_size
    if remainder != 0:
        padding = np.zeros((x_test.shape[0], batch_size - remainder, x_test.shape[2]))
        x_test = np.concatenate([x_test, padding], axis=1)
        print('Appended dummy entries to x_test. New shape: {}'.format(x_test.shape))

    # Predict
    x_test = torch.from_numpy(x_test).float().to(device)
    predict = station_runoff[test_start:test_end].copy()
    print('Predicting')
    pred_array = _predict_in_batches(model, x_test, batch_size)
    predict['runoff'] = pred_array[:predict.shape[0]]  # ignore dummies

    predictions[station] = predict
    actuals[station] = station_runoff['runoff'].loc[test_start:test_end]

In [None]:
# Score each station's test-period predictions, then summarise across stations.
nse_list = []
mse_list = []
for station, predict in predictions.items():
    nse, mse = evaluate.evaluate_daily(station, predict['runoff'], actuals[station], writer=writer)
    nse_list.append(nse)
    mse_list.append(mse)

    print(station, '\tNSE:', nse, '\tMSE:', mse, '(clipped to 0)')

# np.min / np.max raise ValueError on an empty list, which would abort the notebook
# before the results are pickled; report the situation instead.
if nse_list:
    print('Median NSE (clipped to 0)', np.median(nse_list), '/ Min', np.min(nse_list), '/ Max', np.max(nse_list))
    print('Median MSE (clipped to 0)', np.median(mse_list), '/ Min', np.min(mse_list), '/ Max', np.max(mse_list))
else:
    print('No stations were evaluated; skipping summary statistics.')

In [None]:
# Close the TensorBoard SummaryWriter created during setup; nothing is logged after this.
writer.close()

In [None]:
# Persist the (predictions, actuals) dicts for all stations under the 'LSTM_VIC' tag and
# this run's time_stamp; the on-disk format is defined by load_data.pickle_results.
load_data.pickle_results('LSTM_VIC', (predictions, actuals), time_stamp)

In [None]:
f"{datetime.now():%Y%m%d-%H%M%S}"  # finish time of the run, shown as the cell output