# Train the LSTM-FC baseline with spatial correction

![picture](https://repository.ust.hk/ir/profileImages/xclu.jpg)![picture](https://envrpg.ust.hk/stuphotos/xluad.jpg)

This notebook is for building the LSTM-FC models with the spatial correction. For Google Colab only.

## Set up the environment

### Import the packages

In [1]:
! pip install geopy
! pip install pykrige
import torch
import torch.nn as nn
import torch.nn.functional as F

import os, sys
import numpy as np



### Training device

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device {device}...')

Using device cuda...


### Saving and loading files to/from different places

- **Temporary files** will be saved in the COLAB workspace directory
- **Models** and other files necessary for long-term use will be saved in the drive directory 

In [3]:
workspace_dir = '/content'
drive_dir = '/content/drive/Othercomputers/DESKTOP-P14JC7J/2130/code'
sys.path.append(drive_dir)

from data_utils import *

### Copy the processed data here

The processed data should be saved on the Google drive.

In [4]:
! mkdir data
! mkdir baseline-results
! cp ./drive/MyDrive/urop-paper2-data/* ./data
! cp ./drive/MyDrive/urop-paper2-baseline-results/* ./baseline-results

mkdir: cannot create directory ‘data’: File exists
mkdir: cannot create directory ‘baseline-results’: File exists


## Dataset
The dataset used for training and testing.

In [5]:
from data_utils.obs import ObsReader

from torch.utils.data import Dataset

from datetime import datetime, date, time, timedelta

class BaselineSourceDataset(Dataset):
  def __init__(self, train, device):
    super().__init__()
    self.period = 'train' if train else 'test'
    self.device = device
    self.first_date = eval(f'{self.period}_first_dt').date()
    self.last_date = eval(f'{self.period}_last_dt').date()

    self.obs_normalizing = load_dict(f'{data_dir}/obs_normalizing.pkl')
    self.wrf_normalizing = np.load(f'{data_dir}/wrf_normalizing.npz')
    self.wrf_normalizing = (self.wrf_normalizing['mean'], self.wrf_normalizing['std'])
    self.cmaq_normalizing = np.load(f'{data_dir}/cmaq_normalizing.npz')
    self.cmaq_normalizing = (self.cmaq_normalizing['mean'], self.cmaq_normalizing['std'])
    
    source_obs_data = load_dict(f'{data_dir}/{self.period}_source_data.pkl')
    source_wrf_match = load_dict(f'{data_dir}/source_wrf_match.pkl')
    source_cmaq_match = load_dict(f'{data_dir}/source_cmaq_match.pkl')
    assert list(source_obs_data.keys()) == list(source_wrf_match.keys()) == list(source_cmaq_match.keys())
    self.source_stations = list(source_obs_data.keys())

    target_obs_data = load_dict(f'{data_dir}/{self.period}_target_data.pkl')

    self.source_obs_reader = ObsReader(source_obs_data)
    self.target_obs_reader = ObsReader(target_obs_data, self.source_stations)

    target_wrf_data = np.load(f'{data_dir}/{self.period}_source_wrf_data.npy')
    target_cmaq_data = np.load(f'{data_dir}/{self.period}_source_cmaq_data.npy')
        
    self.target_wrf_data, self.target_cmaq_data = target_wrf_data, target_cmaq_data

    assert len(self) == self.target_wrf_data.shape[0] == self.target_cmaq_data.shape[0]

  def get_source_obs(self, day0):
    first_dt = datetime.combine(day0-timedelta(days = history_days), time(0))
    last_dt = datetime.combine(day0-timedelta(days = 1), time(23))
    source_obs = self.source_obs_reader(first_dt, last_dt)
    out = {}
    for st, df in source_obs.items():
      means, stds = np.array([self.obs_normalizing[sp][0] for sp in df.columns]), np.array([self.obs_normalizing[sp][1] for sp in df.columns])
      out[st] = torch.nan_to_num(torch.tensor(((df.values - means) / stds).T, dtype = torch.float, device = self.device))
    return out

  def get_target_obs(self, day0):
    means, stds = np.array([self.obs_normalizing[sp][0] for sp in target_species]), np.array([self.obs_normalizing[sp][1] for sp in target_species])
    first_dt = datetime.combine(day0, time(0))
    last_dt = datetime.combine(day0 + timedelta(days = horizon_days - 1), time(23))
    target_obs = self.target_obs_reader(first_dt, last_dt)
    return torch.tensor(np.array([((df.values - means)/stds).T for df in target_obs.values()]), dtype = torch.float, device = self.device)

  def __getitem__(self, index):
    day0 = self.first_date + timedelta(days = history_days + index)

    source_obs = self.get_source_obs(day0)
    target_wrf = torch.tensor((self.target_wrf_data[index] - self.wrf_normalizing[0][:, None])/self.wrf_normalizing[1][:, None], dtype = torch.float, device = self.device)
    target_cmaq = torch.tensor((self.target_cmaq_data[index] - self.cmaq_normalizing[0][:, None])/self.cmaq_normalizing[1][:, None], dtype = torch.float, device = self.device)
    target_wrf_cmaq = torch.cat([target_wrf, target_cmaq], dim = -2)

    target_obs = self.get_target_obs(day0)

    return source_obs, target_wrf_cmaq, target_obs

  def __len__(self):
    return (self.last_date - self.first_date).days - (history_days + horizon_days) + 2

def regional_dataset_collate_fn(batch):
  source_obs_out, target_wrf_cmaq_out, target_obs_out = [], [], []
  for source_obs, target_wrf_cmaq, target_obs in batch:
    source_obs_out.append(source_obs)
    target_wrf_cmaq_out.append(target_wrf_cmaq)
    target_obs_out.append(target_obs)
  return torch.utils.data._utils.collate.default_collate(source_obs_out), torch.stack(target_wrf_cmaq_out, axis = 1), torch.stack(target_obs_out, axis = 1)

## Training
### Hyperparameters
Hyperparameters for training the model and model configurations.

In [None]:
num_epoch = 32
batch_size = 128
learning_rate, gamma, step_size = 1e-3, 0.9, 100

### Initialization
Build the models, optimizers and learning rate schedulers, and the dataset

In [None]:
# the model, the optimizer and the scheduler
from baselines.lstm_fc.model import LSTM_FC
models = nn.ModuleDict({st: LSTM_FC(len(df.columns)) for st, df in load_dict(f'{data_dir}/train_source_data.pkl').items()}).to(device)
optimizer = torch.optim.Adam(models.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma = gamma, step_size = step_size)

In [None]:
from torch.utils.data import DataLoader

# the training set and the distance used
training_set = BaselineSourceDataset(train = True, device = device)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, collate_fn=regional_dataset_collate_fn)

# the validation set and the distance used
validation_set = BaselineSourceDataset(train = False, device = device)
validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=True, collate_fn=regional_dataset_collate_fn)

### Training loop

In [None]:
for i in range(num_epoch):
  models.train()
  train_epoch_loss = 0.0
  for source_obs, target_wrf_cmaq, target_obs in training_loader:
    optimizer.zero_grad()
    out = [models[st](X0, X1) for (st, X0), X1 in zip(source_obs.items(), target_wrf_cmaq)]
    out = torch.stack(out)
    
    mask = ~torch.isnan(target_obs)
    loss = torch.abs(out[mask] - target_obs[mask]).mean()
    with torch.no_grad():
      train_epoch_loss += loss.item() * target_obs.shape[1] / len(training_set)
    loss.backward()
    optimizer.step()
    scheduler.step()

  valid_epoch_loss = 0.0
  models.eval()
  with torch.no_grad():
    for source_obs, target_wrf_cmaq, target_obs in validation_loader:
      out = [models[st](X0, X1) for (st, X0), X1 in zip(source_obs.items(), target_wrf_cmaq)]
      out = torch.stack(out)
      
      mask = ~torch.isnan(target_obs)
      loss = torch.abs(out[mask] - target_obs[mask]).mean()
      valid_epoch_loss += loss.item() * target_obs.shape[1] / len(validation_set)

  print(f'Epoch {i}, training loss {train_epoch_loss:.3g}, validation loss {valid_epoch_loss:.3g}...')
  if train_epoch_loss < valid_epoch_loss:
    break

Epoch 0, training loss 0.645, validation loss 0.537...
Epoch 1, training loss 0.578, validation loss 0.466...
Epoch 2, training loss 0.532, validation loss 0.427...
Epoch 3, training loss 0.5, validation loss 0.396...
Epoch 4, training loss 0.477, validation loss 0.38...
Epoch 5, training loss 0.459, validation loss 0.376...
Epoch 6, training loss 0.447, validation loss 0.367...
Epoch 7, training loss 0.438, validation loss 0.36...
Epoch 8, training loss 0.431, validation loss 0.347...
Epoch 9, training loss 0.426, validation loss 0.357...
Epoch 10, training loss 0.42, validation loss 0.344...
Epoch 11, training loss 0.417, validation loss 0.344...
Epoch 12, training loss 0.412, validation loss 0.345...
Epoch 13, training loss 0.408, validation loss 0.348...
Epoch 14, training loss 0.405, validation loss 0.332...
Epoch 15, training loss 0.403, validation loss 0.337...
Epoch 16, training loss 0.4, validation loss 0.341...
Epoch 17, training loss 0.398, validation loss 0.34...
Epoch 18, 

### Save the trained model

In [None]:
import json
model_dir = f'{drive_dir}/models/baseline_lstm_fc'
os.makedirs(model_dir, exist_ok = True)
torch.save(models.state_dict(), f'{model_dir}/state_dict')
torch.save(optimizer.state_dict(), f'{model_dir}/optimizer_state_dict')
torch.save(scheduler.state_dict(), f'{model_dir}/scheduler_state_dict')
with open (f'{model_dir}/training_config.json', 'w') as f:
  json.dump({
    'num_epoch': num_epoch, 
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'gamma': gamma,
    'step_size': step_size
  }, f)

## Prediction
Test on the testing target stations that are neither training target stations nor source stations, report the results on multiple metrics.


### Load the trained model


In [8]:
from model import Regional
model_dir = f'{drive_dir}/models/baseline_lstm_fc'
models = nn.ModuleDict({st: LSTM_FC(len(df.columns)) for st, df in load_dict(f'{data_dir}/train_source_data.pkl').items()}).to(device)
models.load_state_dict(torch.load(f'{model_dir}/state_dict'))

ModuleNotFoundError: ignored

### Load the testing data

In [None]:
# Only the testing target stations that are neither training target stations nor source stations
source_loc = load_dict(f'{data_dir}/source_loc.pkl')
source_stations = set(source_loc.keys())

train_target_stations = set(load_dict(f'{data_dir}/train_target_loc.pkl').keys())
test_target_loc = load_dict(f'{data_dir}/test_target_loc.pkl')
test_target_stations = set(test_target_loc.keys())

test_target_stations = list(test_target_stations - train_target_stations - source_stations)
test_target_loc = {st: test_target_loc[st] for st in test_target_stations}

In [None]:
# denormalize the output
obs_normalizing = load_dict(f'{data_dir}/obs_normalizing.pkl')
means, stds = np.array([obs_normalizing[sp][0] for sp in target_species]), np.array([obs_normalizing[sp][1] for sp in target_species])

In [None]:
# the testing set
testing_set = BaselineSourceDataset(train = False, device = device)
testing_loader = torch.utils.data.DataLoader(testing_set, batch_size=batch_size, collate_fn=regional_dataset_collate_fn)

In [None]:
models.eval()
assert not models.training

In [None]:
pred = {st: [] for st in testing_set.source_stations}
with torch.no_grad():
  testing_loss = 0.0
  for source_obs, target_wrf_cmaq, target_obs in testing_loader:
    assert list(source_obs.keys()) == list(models.keys())
    for (st, model), (st, X0), X1 in zip(models.items(), source_obs.items(), target_wrf_cmaq):
      pred[st].append(model.predict(X0, X1, (means, stds)))

for st in testing_set.source_stations:
  pred[st] = np.concatenate(pred[st])
  print(st, pred[st].shape)

CN_1379A (361, 2, 48)
SP_A (361, 2, 48)
XCNAQ437 (361, 2, 48)
CN_1394A (361, 2, 48)
CN_1370A (361, 2, 48)
TW_A (361, 2, 48)
CW_A (361, 2, 48)
CN_1365A (361, 2, 48)
TP_A (361, 2, 48)
CN_1369A (361, 2, 48)
XCNAQ1145 (361, 2, 48)
CL_R (361, 2, 48)
YL_A (361, 2, 48)
EN_A (361, 2, 48)
CN_1392A (361, 2, 48)
TC_A (361, 2, 48)
CN_1396A (361, 2, 48)
XCNAQ418 (361, 2, 48)
XCNAQ1144 (361, 2, 48)
CN_1391A (361, 2, 48)
XCNAQ445 (361, 2, 48)
CN_1364A (361, 2, 48)
MKaR (361, 2, 48)
XCNAQ439 (361, 2, 48)
CN_1380A (361, 2, 48)
CN_1381A (361, 2, 48)
CN_1357A (361, 2, 48)
KC_A (361, 2, 48)
TM_A (361, 2, 48)
KT_A (361, 2, 48)
CN_1358A (361, 2, 48)
CB_R (361, 2, 48)


### Interpolation

In [None]:
from data_utils.interpolation import *
source_cmaq_pred = load_dict('data/test_source_cmaq')

In [None]:
cmaq_pred = load_dict('baseline-results/cmaq_pred.pkl')
for st in test_target_stations:
  print(st, cmaq_pred[st].shape)

XCNAQ3249 (361, 2, 48)
XCNAQ3278 (361, 2, 48)
TK_A (361, 2, 48)
XCNAQ3238 (361, 2, 48)
XCNAQ3290 (361, 2, 48)
XCNAQ1295 (361, 2, 48)
XCNAQ3237 (361, 2, 48)
XCNAQ3282 (361, 2, 48)
NH_A (361, 2, 48)
XCNAQ1296 (361, 2, 48)
MACG (361, 2, 48)
XCNAQ2914 (361, 2, 48)
XCNAQ3260 (361, 2, 48)
SN_A (361, 2, 48)
XCNAQ3279 (361, 2, 48)
XCNAQ3950 (361, 2, 48)
MACT (361, 2, 48)
XCNAQ3277 (361, 2, 48)
MCKO (361, 2, 48)
MACH (361, 2, 48)
MACC (361, 2, 48)
