In [23]:
import pickle
import pandas as pd
from matplotlib import pyplot as plt
from seaborn import heatmap

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

from tools import validation
from models import Model1, Model2, Model3, Model4, Model5, Model6
from sklearn.metrics import r2_score, mean_absolute_error
# Index of the GPU the checkpoints were trained/saved on. Fall back to CPU when that GPU is
# not present so the notebook can still be re-run (slowly) on other machines.
GPU_INDEX = 1
device = torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cpu')
criterion = nn.MSELoss()
# Metrics handed to `validation`; presumably reported in this order as the 'R2' and 'MAE'
# result columns further down — TODO confirm against tools.validation.
eval_metrics = [r2_score, mean_absolute_error]

In [24]:
# NOTE: unpickling can execute arbitrary code — only load trusted, locally produced files.
with open('data/test_dataset_masked.pickle', 'rb') as f:
    test_dataset = pickle.load(f)

# Fixed sample order (no shuffling) so the batch used for the attention plots is reproducible.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

# LSTNet

In [25]:
# NOTE(review): this import belongs in the top import cell.
from LSTNet import LSTNet
lstnet_best = LSTNet(P=60, m=40, dropout=0.2, output_func='sigmoid', device=device, hidR=32, hidC=32, hidS=40, Ck=10, hw=1).to(device)
# Map the checkpoint onto the configured `device` rather than a hard-coded 'cuda:1', so changing
# `device` in the setup cell is enough to move the whole notebook.
lstnet_state_dict = torch.load('saved_model_masked/LSTNet_best.pt', map_location=device)
lstnet_best.load_state_dict(lstnet_state_dict)

<All keys matched successfully>

# Simple LSTM

In [27]:
# Only parameter shapes are inspected here, so load onto the CPU: this works on any machine and
# avoids allocating GPU memory (on whatever device the checkpoint was saved from) just to print sizes.
slstm_state_dict = torch.load('checkpoints/Simple_LSTM_best.pt', map_location='cpu')
for name, data in slstm_state_dict.items():
    print(name, '\t', *data.size())

lstm.weight_ih_l0 	 160 40
lstm.weight_hh_l0 	 160 40
lstm.bias_ih_l0 	 160
lstm.bias_hh_l0 	 160
lstm.weight_ih_l1 	 160 40
lstm.weight_hh_l1 	 160 40
lstm.bias_ih_l1 	 160
lstm.bias_hh_l1 	 160
lstm.weight_ih_l2 	 160 40
lstm.weight_hh_l2 	 160 40
lstm.bias_ih_l2 	 160
lstm.bias_hh_l2 	 160
linear_output.0.weight 	 80 40
linear_output.0.bias 	 80
linear_output.2.weight 	 1 80
linear_output.2.bias 	 1


# Input Attention + LSTM

In [None]:
# Only parameter shapes are inspected here, so load onto the CPU: this works on any machine and
# avoids allocating GPU memory (on whatever device the checkpoint was saved from) just to print sizes.
ia_lstm_state_dict = torch.load('checkpoints/Input_Attention_LSTM_masked_best.pt', map_location='cpu')
for name, data in ia_lstm_state_dict.items():
    print(name, '\t', *data.size())

# Model 1

In [None]:
model1_best = Model1(T=60, n=40, m=40, cnn_kernel_height=10, cnn_hidden_size=32, skip_hidden_size=40, skip=8).to(device)
# Map the checkpoint onto `device` so loading does not depend on the GPU it was saved from.
model1_state_dict = torch.load('saved_model_masked/model1_best.pt', map_location=device)
model1_best.load_state_dict(model1_state_dict)

# Model 2

In [None]:
model2_best = Model2(T=60, n=40, m=32, skip_hidden_size=40, T_modified=20, skip=10).to(device)
# Map the checkpoint onto `device` so loading does not depend on the GPU it was saved from.
model2_state_dict = torch.load('saved_model_masked/model2_best.pt', map_location=device)
model2_best.load_state_dict(model2_state_dict)

# Model 3

In [None]:
model3_best = Model3(T=60, n=40, m=40, skip_hidden_size=40, T_modified=20, skip=8).to(device)
# Map the checkpoint onto `device` so loading does not depend on the GPU it was saved from.
model3_state_dict = torch.load('saved_model_masked/model3_best.pt', map_location=device)
model3_best.load_state_dict(model3_state_dict)

# Model 4

In [None]:
model4_best = Model4(T=60, n=40, m=40, skip_hidden_size=40, skip=10).to(device)
# Use the configured `device` instead of a hard-coded 'cuda:1' so both stay in sync.
model4_state_dict = torch.load('saved_model_masked/model4_best.pt', map_location=device)
model4_best.load_state_dict(model4_state_dict)

# Model 5

In [None]:
model5_best = Model5(T=60, n=40, p=32, cnn_kernel_height=30, cnn_hidden_size=32, skip_hidden_size=40, skip=8).to(device)
# Use the configured `device` instead of a hard-coded 'cuda:1' so both stay in sync.
model5_state_dict = torch.load('saved_model_masked/model5_best.pt', map_location=device)
model5_best.load_state_dict(model5_state_dict)

# Model 6

In [None]:
model6_best = Model6(T=60, n=40, m=40, T_modified=20).to(device)
# Use the configured `device` instead of a hard-coded 'cuda:1' so both stay in sync.
model6_state_dict = torch.load('saved_model_masked/model6_best.pt', map_location=device)
model6_best.load_state_dict(model6_state_dict)

In [None]:
# Evaluate every restored model on the test loader, in the same order as the 'Model'
# labels attached in the next cell (LSTNet first, then Model 1 … Model 6).
# NOTE(review): the `unmasked_` prefix looks inconsistent with the masked dataset/checkpoints
# loaded above; the name is kept because later cells reference it.
best_models = [lstnet_best, model1_best, model2_best, model3_best,
               model4_best, model5_best, model6_best]

unmasked_grid_search_result = []
with torch.no_grad():
    for best_model in best_models:
        # Every model must be in inference mode (previously only LSTNet and Model 6 were),
        # otherwise any dropout / batch-norm layers stay in training behaviour during evaluation.
        best_model.eval()
        unmasked_grid_search_result.append(
            validation(best_model, test_loader, criterion, eval_metrics, device)
        )

# Assumes `validation` returns (loss, r2, mae) per model — TODO confirm against tools.validation.
unmasked_grid_search_result = pd.DataFrame(unmasked_grid_search_result, columns=['Loss', 'R2', 'MAE'])

In [None]:
# Attach model names and move them to the first column. Columns are selected by name
# (not position) so re-running this cell leaves the table unchanged; the former positional
# `iloc[:, [3, 0, 1, 2]]` reordered the columns again on every re-run.
model_names = ['LSTNet'] + [f'Model {i}' for i in range(1, 7)]
unmasked_grid_search_result = (
    unmasked_grid_search_result
    .assign(Model=model_names)
    [['Model', 'Loss', 'R2', 'MAE']]
)
unmasked_grid_search_result

# Attention Score Visualization

In [None]:
# Take the first test batch for the attention plots. Fail loudly on an empty loader instead of
# silently leaving `X` undefined (or stale from an earlier run) as the `for … break` form did.
first_batch = next(iter(test_loader), None)
if first_batch is None:
    raise RuntimeError('test_loader yielded no batches; cannot plot attention scores')
X, y = first_batch
X = X.float().to(device)

In [None]:
def plot_attention_scores(model, X, X_index, kind='Input', masked=False, save_filename=None):
    """Plot the attention scores a model produces for one sample of a batch.

    Runs ``model`` on the whole batch ``X`` in eval mode (the model is left in eval mode),
    reads the scores stored on ``model.attention_scores_`` after the forward pass, and draws
    a heatmap for sample ``X_index``. When ``masked`` is True it also draws step histograms
    comparing scores at positions where the input equals 0 ("masked") with all others.

    Args:
        model: network that sets an ``attention_scores_`` tensor during its forward pass
            (presumably shaped (batch, T, dim) — TODO confirm for each model).
        X: input batch, already on the model's device.
        X_index: index of the sample within the batch to plot.
        kind: label for the heatmap x-axis, e.g. 'Input' or 'Temporal'.
        masked: also draw the masked/unmasked score histogram.
        save_filename: if truthy, the heatmap is saved as ``heatmap_<name>`` and the
            histogram (if drawn) as ``hist_<name>``, both in the directory of this path.

    Returns:
        The attention-score tensor for the whole batch, on the CPU.
    """
    import pathlib

    model.eval()
    with torch.no_grad():
        model(X)
        attention_scores_ = model.attention_scores_.cpu().detach()  # (-1 x T x `dim`)
        # `dim` means the dimension of features of input to (Input or Temporal) Attention
    attention_scores = attention_scores_[X_index]  # (T x `dim`)

    fig1, sub1 = plt.subplots(1, 1, dpi=100, figsize=(7, 5))
    heatmap(attention_scores, cmap='Reds', vmin=0, vmax=1, ax=sub1)
    sub1.set_ylabel('Time steps')
    sub1.set_xlabel(f'{kind} Features')

    fig2 = None
    if masked:
        # The scores live on the CPU, so the mask must too: indexing a CPU tensor with a CUDA
        # boolean mask raises a device-mismatch error.
        # NOTE(review): assumes X[X_index] has the same (T x dim) shape as the scores, which
        # should hold for input attention — confirm before using with temporal attention.
        sample = X[X_index].cpu()
        fig2, sub = plt.subplots(1, 1, dpi=100, figsize=(7, 10))
        sub.hist(attention_scores[sample == 0].numpy(), label='masked', bins=15, histtype='step')
        sub.hist(attention_scores[sample != 0].numpy(), label='unmasked', bins=15, histtype='step')
        sub.set_xlabel('Attention score')
        sub.set_ylabel('Count')
        sub.legend()  # labels were set but never shown

    if save_filename:
        # Give each figure its own file. Previously both were written to `save_filename`, so the
        # histogram overwrote the heatmap and the printed paths did not match the files on disk.
        save_path = pathlib.Path(save_filename)
        heatmap_path = save_path.with_name(f'heatmap_{save_path.name}')
        fig1.savefig(heatmap_path)
        print(f'Saving Process Complete. Directory: {heatmap_path}')
        if fig2 is not None:
            hist_path = save_path.with_name(f'hist_{save_path.name}')
            fig2.savefig(hist_path)
            print(f'Saving Process Complete. Directory: {hist_path}')

    return attention_scores_

In [None]:
# Model 1: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model1_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)

In [None]:
# Model 2: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model2_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)

In [None]:
# Model 3: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model3_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)

In [None]:
# Model 4: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model4_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)

In [None]:
# Model 5: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model5_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)

In [None]:
# Model 6: input-attention heatmap for the first sample (no mask histogram, nothing saved).
_ = plot_attention_scores(model=model6_best, X=X, X_index=0, kind='Input', save_filename=None, masked=False)