In [None]:
import numpy as np
import pandas as pd
import math
import os
import sys
import yaml
# from geopy.distance import great_circle
# from geopy import Point
import torch
from models.GRU import GRU
from models.GC_GRU import GC_GRU
from models.Seq2Seq_GC_GRU import Seq2Seq_GC_GRU
from models.Seq2Seq_GNN_GRU import Seq2Seq_GNN_GRU
# from models.DGC_GRU import DGC_GRU
from dataset import Dataset
from graph import Graph
from utils import load_model
from datetime import datetime

import matplotlib.pyplot as plt

# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Mean-squared-error loss with torch's default settings.
criterion = torch.nn.MSELoss()

In [None]:
# The project root is the current working directory; config.yaml is expected there.
proj_dir = os.path.abspath('')
sys.path.append(proj_dir)
config_fp = os.path.join(proj_dir, 'config.yaml')

with open(config_fp, 'r') as f:
    config = yaml.safe_load(f)

# ------------- Config parameters start ------------- #
# Directories (joined to file names by plain string concatenation below).
_dir_cfg = config['dirpath']
data_dir = _dir_cfg['data_dir']
model_dir = _dir_cfg['model_dir']
plots_dir = _dir_cfg['plots_dir']

# Everything under config[location] is specific to the selected region.
location = config['location']
_loc_cfg = config[location]
_is_china = location == 'china'

_fp_cfg = _loc_cfg['filepath']
npy_fp = data_dir + _fp_cfg['npy_fp']
locations_fp = data_dir + _fp_cfg['locations_fp']
# The altitude file is only read when the location is 'china'.
altitude_fp = data_dir + _fp_cfg['altitude_fp'] if _is_china else None
# map_fp = data_dir + config[location]['filepath']['map_fp'] if location == 'bihar' else None

# Training hyper-parameters, cast explicitly to int/float.
_train_cfg = config['train']
batch_size = int(_train_cfg['batch_size'])
num_epochs = int(_train_cfg['num_epochs'])
forecast_len = int(_train_cfg['forecast_len'])
hist_len = int(_train_cfg['hist_len'])
hidden_dim = int(_train_cfg['hidden_dim'])
lr = float(_train_cfg['lr'])
model_type = _train_cfg['model']

# Dataset description. `update` is used as 24 // update (time steps per day)
# by get_indices, i.e. it acts as the number of hours between samples.
_dataset_cfg = _loc_cfg['dataset']
dataset_num = int(_dataset_cfg['num'])
update = int(_dataset_cfg['update'])
data_start = _dataset_cfg['data_start']
data_end = _dataset_cfg['data_end']

# Thresholds; the altitude threshold is only read when the location is 'china'.
_thresh_cfg = _loc_cfg['threshold']
dist_thresh = float(_thresh_cfg['distance'])
alt_thresh = float(_thresh_cfg['altitude']) if _is_china else None
haze_thresh = float(_thresh_cfg['haze'])

# Date boundaries of the train / validation / test split selected by dataset_num.
_split_cfg = _loc_cfg['split'][dataset_num]
train_start, train_end = _split_cfg['train_start'], _split_cfg['train_end']
val_start, val_end = _split_cfg['val_start'], _split_cfg['val_end']
test_start, test_end = _split_cfg['test_start'], _split_cfg['test_end']

# NOTE(review): hard-coded rather than derived from the data — confirm it
# matches the number of locations in the loaded dataset.
num_locs = 511
# ------------- Config parameters end   ------------- #

In [None]:
# china_pm25 = np.load(china_npy_fp)[:, :, -1].flatten()
# Take index -1 along the third axis of the stored array (used as the PM2.5
# channel throughout this notebook) and flatten it to 1-D.
pm25_all = np.load(npy_fp)[:, :, -1]
bihar_pm25 = pm25_all.flatten()

# Overall mean across every entry; displayed as the cell output.
bihar_avg = np.mean(bihar_pm25)
bihar_avg

## Train / Validation / Test Data Mean and Standard Deviation

In [None]:
def get_indices(start_date, end_date, data_start, update_hours=None):
    """Convert a calendar date range into time-step indices relative to `data_start`.

    Only whole days are counted: the difference between each date and
    `data_start` is taken in days (any sub-day remainder is discarded) and
    multiplied by the number of time steps per day, ``24 // update_hours``.

    Args:
        start_date: Sequence of ``datetime`` constructor arguments
            (year, month, day, ...) for the start of the range.
        end_date: Same, for the end of the range.
        data_start: Same, for the date that corresponds to index 0.
        update_hours: Hours between consecutive time steps. When omitted, the
            notebook-level ``update`` value is used, which keeps existing
            three-argument calls working unchanged.

    Returns:
        Tuple ``(start_idx, end_idx)`` of integer time-step indices.
    """
    # Fall back to the notebook global only when the caller gives no value, so
    # the function can also be used (and tested) without hidden state.
    hours_per_step = update if update_hours is None else update_hours
    steps_per_day = 24 // hours_per_step
    origin = datetime(*data_start)

    start_idx = (datetime(*start_date) - origin).days * steps_per_day
    end_idx = (datetime(*end_date) - origin).days * steps_per_day

    return start_idx, end_idx

def get_mean_std(arr, start_idx, end_idx):
    """Return the mean and standard deviation of the last feature over an index range.

    Args:
        arr: Array sliced along its first axis by the index range; its last
            axis holds the features and only the final feature is summarised.
        start_idx: First index along axis 0 to include.
        end_idx: Last index along axis 0 to include (inclusive).

    Returns:
        Tuple ``(mean, std)``, each rounded to 4 decimal places.
    """
    # end_idx is inclusive, hence the +1 on the slice bound.
    window = arr[start_idx:end_idx + 1]
    # Collapse every leading axis so each row is one feature vector, then keep
    # only the last column.
    last_feature = window.reshape(-1, window.shape[-1])[:, -1]
    mean, std = last_feature.mean(), last_feature.std()
    return mean.round(4), std.round(4)

In [None]:
npy_data = np.load(npy_fp)

# For each split: translate its date range into indices along the first axis
# of npy_data, then summarise the last feature over that (inclusive) range.
train_start_idx, train_end_idx = get_indices(train_start, train_end, data_start)
train_pm25_mean, train_pm25_std = get_mean_std(npy_data, train_start_idx, train_end_idx)

val_start_idx, val_end_idx = get_indices(val_start, val_end, data_start)
val_pm25_mean, val_pm25_std = get_mean_std(npy_data, val_start_idx, val_end_idx)

test_start_idx, test_end_idx = get_indices(test_start, test_end, data_start)
test_pm25_mean, test_pm25_std = get_mean_std(npy_data, test_start_idx, test_end_idx)

In [None]:
# One output line per split, in the order train, validation, test:
# the mean followed by the standard deviation.
for split_mean, split_std in (
    (train_pm25_mean, train_pm25_std),
    (val_pm25_mean, val_pm25_std),
    (test_pm25_mean, test_pm25_std),
):
    print(split_mean, split_std)

## Evaluating Stats and getting the plots

In [None]:
def get_data_model_info(model_type, location):
    """Build the train/val/test datasets and a freshly constructed model.

    File paths, date ranges and hyper-parameters are taken from the
    notebook-level configuration variables rather than passed in.

    Args:
        model_type: Name of the architecture to instantiate.
        location: Region to build the graph for; 'china' or 'bihar'.

    Returns:
        Tuple ``(train_data, val_data, test_data, model)``.

    Raises:
        AssertionError: If `location` or `model_type` is not a recognised name.
        Exception: If `model_type` is recognised but has no constructor branch
            below.
    """
    known_locations = {'china', 'bihar'}
    known_models = {
        'GRU', 'GC_GRU', 'Seq2Seq_GC_GRU', 'Seq2Seq_Attn_GC_GRU', 'DGC_GRU',
        'Seq2Seq_GNN_GRU', 'Seq2Seq_GNN_Transformer',
    }
    assert location in known_locations, "Incorrect Location"
    assert model_type in known_models, "Incorrect model type"

    def _make_split(split_start, split_end):
        # Every split shares the same source file and window lengths; only the
        # date range differs.
        return Dataset(npy_fp, forecast_len, hist_len, split_start, split_end, data_start, update)

    train_data = _make_split(train_start, train_end)
    val_data = _make_split(val_start, val_end)
    test_data = _make_split(test_start, test_end)

    graph = Graph(location, locations_fp, dist_thresh, altitude_fp, alt_thresh)

    # Last axis of the feature array is the per-step input size, the one before
    # it is the number of nodes passed to the models as city_num.
    feature_shape = train_data.feature.shape
    in_dim, city_num = feature_shape[-1], feature_shape[-2]

    # Decoder input dim: 3, since the last 3 elements are the only known
    # features during forecasting (is_weekend, cyclic hour embedding).
    in_dim_dec = 3

    # Positional arguments that every constructor takes after its input dim(s).
    shared_args = (hidden_dim, city_num, hist_len, forecast_len, batch_size, device)

    if model_type == 'GRU':
        model = GRU(in_dim, *shared_args)
    elif model_type == 'GC_GRU':
        model = GC_GRU(in_dim, *shared_args, graph.adj_mat)
    elif model_type == 'Seq2Seq_GC_GRU':
        model = Seq2Seq_GC_GRU(in_dim, *shared_args, graph.adj_mat)
    elif model_type == 'Seq2Seq_GNN_GRU':
        model = Seq2Seq_GNN_GRU(in_dim, in_dim_dec, *shared_args, graph.adj_mat)
    else:
        # 'Seq2Seq_Attn_GC_GRU', 'Seq2Seq_GNN_Transformer' and 'DGC_GRU' pass
        # the validation above but have no constructor branch here, so they
        # end up in this error as well.
        raise Exception('Wrong model name!')

    return train_data, val_data, test_data, model

In [None]:
# Checkpoints are named after the model type and the history / forecast lengths.
checkpoint_fp = f'{model_dir}/{model_type}_{hist_len}_{forecast_len}.pth.tar'
model_state_dict, _, train_losses, test_losses = load_model(checkpoint_fp)

# Rebuild the datasets and the architecture, then restore the saved weights.
train_data, val_data, test_data, model = get_data_model_info(model_type, location)
model.load_state_dict(model_state_dict)
model.to(device)

# Normalisation statistics of the training split; they are reused to
# de-normalise the outputs of every split.
pm25_mean, pm25_std = train_data.pm25_mean, train_data.pm25_std

# shuffle=False keeps samples in dataset order, which get_values relies on
# when it stitches windows back together by position; drop_last=True discards
# a final incomplete batch.
eval_loader_kwargs = dict(drop_last=True, batch_size=batch_size, shuffle=False)
train_loader = torch.utils.data.DataLoader(train_data, **eval_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_data, **eval_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, **eval_loader_kwargs)

In [None]:
def test(model, loader, pm25_mean, pm25_std):
    """Run `model` over `loader` and collect de-normalised labels and predictions.

    The model is put in eval mode and called as ``model(features, pm25)`` for
    every batch. Labels are the part of ``pm25`` after the first ``hist_len``
    steps along axis 1. Labels and predictions are mapped back to the original
    scale with ``value * pm25_std + pm25_mean``, and only index ``[:, :, 0, 0]``
    of each is kept (NOTE(review): presumably the first node and the single
    PM2.5 channel — confirm against the dataset layout).

    Args:
        model: Model to evaluate; called as ``model(features, pm25)``.
        loader: Iterable yielding ``(features, pm25)`` batches.
        pm25_mean: Mean used to normalise the PM2.5 values.
        pm25_std: Standard deviation used to normalise the PM2.5 values.

    Returns:
        Tuple ``(y, y_hat)`` of arrays with one row per sample, concatenated
        over batches in loader order. Both are empty 1-D arrays if the loader
        yields no batches.
    """
    model.eval()

    label_batches, pred_batches = [], []

    # Inference only: without no_grad() an autograd graph would be built and
    # kept alive for every batch.
    with torch.no_grad():
        for features, pm25 in loader:
            # Only pm25 is moved here; features are passed to the model as-is.
            pm25 = pm25.to(device)

            pm25_label = pm25[:, hist_len:]
            pm25_preds = model(features, pm25)

            # Undo the normalisation so values are on the original scale.
            pm25_label = pm25_label * pm25_std + pm25_mean
            pm25_preds = pm25_preds * pm25_std + pm25_mean

            label_batches.append(pm25_label.detach().cpu().numpy()[:, :, 0, 0])
            pred_batches.append(pm25_preds.detach().cpu().numpy()[:, :, 0, 0])

    if not label_batches:
        return np.array([]), np.array([])

    # Concatenate once at the end instead of re-copying the growing arrays on
    # every iteration.
    return np.concatenate(label_batches, axis=0), np.concatenate(pred_batches, axis=0)

In [None]:
def get_values(model, loader, pm25_mean, pm25_std):
    """Merge overlapping forecast windows into a single label / prediction series.

    Row ``i`` of the arrays returned by ``test`` is treated as a window of
    ``forecast_len`` values covering positions ``i .. i + forecast_len - 1`` of
    the output series. Each output position is the average of every window
    value that falls on it.

    Args:
        model: Model forwarded to ``test``.
        loader: Data loader forwarded to ``test``.
        pm25_mean: Normalisation mean forwarded to ``test``.
        pm25_std: Normalisation standard deviation forwarded to ``test``.

    Returns:
        Tuple ``(y, y_hat)`` of 1-D arrays holding the averaged labels and
        predictions.
    """
    window_labels, window_preds = test(model, loader, pm25_mean, pm25_std)
    print(window_labels.shape, window_preds.shape)

    # N windows shifted by one step each span N + forecast_len - 1 positions.
    label_sum = np.zeros(window_labels.shape[0] + forecast_len - 1)
    pred_sum = np.zeros(window_preds.shape[0] + forecast_len - 1)
    coverage = np.zeros(window_labels.shape[0] + forecast_len - 1)

    for start in range(min(len(window_labels), len(window_preds))):
        stop = start + forecast_len
        label_sum[start:stop] += window_labels[start]
        pred_sum[start:stop] += window_preds[start]
        # Count how many windows contributed to each position.
        coverage[start:stop] += 1

    return label_sum / coverage, pred_sum / coverage

In [None]:
# Number of consecutive samples averaged into one plotted point.
HOUR_AVG = 3


def _window_mean(series, usable_len):
    """Average the first `usable_len` samples of `series` in consecutive,
    non-overlapping groups of HOUR_AVG; `usable_len` must be a multiple of it."""
    return series[:usable_len].reshape(-1, HOUR_AVG).mean(axis=1)


train_labels, train_preds = get_values(model, train_loader, pm25_mean, pm25_std)
val_labels, val_preds = get_values(model, val_loader, pm25_mean, pm25_std)
test_labels, test_preds = get_values(model, test_loader, pm25_mean, pm25_std)

# Largest multiple of HOUR_AVG that fits in each series; the trailing
# remainder is dropped so the reshape into groups is exact.
train_len = len(train_labels) - len(train_labels) % HOUR_AVG
val_len = len(val_labels) - len(val_labels) % HOUR_AVG
test_len = len(test_labels) - len(test_labels) % HOUR_AVG

train_labels, train_preds = _window_mean(train_labels, train_len), _window_mean(train_preds, train_len)
val_labels, val_preds = _window_mean(val_labels, val_len), _window_mean(val_preds, val_len)
test_labels, test_preds = _window_mean(test_labels, test_len), _window_mean(test_preds, test_len)

In [None]:
# Draw on an explicitly created figure so the plot never lands on whatever
# figure happens to be current.
fig, ax = plt.subplots()
ax.plot(train_labels, label='True Label')
ax.plot(train_preds, label='Preds')
ax.set_xlabel('Averaged time step')
ax.set_ylabel('PM2.5')
ax.legend(prop={'size': 15})
ax.set_title("Training Fit", fontsize=20)
fig.savefig(f'{plots_dir}/training_fit.jpg', dpi=400)

In [None]:
# Only the tail of the training series is plotted.
# NOTE(review): 248 is presumably 31 days x 8 three-hour averages (the final
# December, per the output file name) — confirm.
tail_points = 248

# Draw on an explicitly created figure so the plot never lands on whatever
# figure happens to be current.
fig, ax = plt.subplots()
ax.plot(train_labels[-tail_points:], label='True Label')
ax.plot(train_preds[-tail_points:], label='Preds')
ax.set_xlabel('Averaged time step')
ax.set_ylabel('PM2.5')
ax.legend(prop={'size': 15})
ax.set_title("Training Fit", fontsize=20)
fig.savefig(f'{plots_dir}/training_fit_dec.jpg', dpi=400)

In [None]:
# Draw on an explicitly created figure so the plot never lands on whatever
# figure happens to be current.
fig, ax = plt.subplots()
ax.plot(val_labels, label='True Label')
ax.plot(val_preds, label='Preds')
ax.set_xlabel('Averaged time step')
ax.set_ylabel('PM2.5')
ax.legend(prop={'size': 15})
ax.set_title("Validation Fit", fontsize=20)
fig.savefig(f'{plots_dir}/validation_fit.jpg', dpi=400)

In [None]:
# Draw on an explicitly created figure so the plot never lands on whatever
# figure happens to be current.
fig, ax = plt.subplots()
ax.plot(test_labels, label='True Label')
ax.plot(test_preds, label='Preds')
ax.set_xlabel('Averaged time step')
ax.set_ylabel('PM2.5')
ax.legend(prop={'size': 15})
ax.set_title("Test Fit", fontsize=20)
fig.savefig(f'{plots_dir}/test_fit.jpg', dpi=400)