In [27]:
# %matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import logging
import h5py
import gzip
import json
import os
import tqdm

import warnings
warnings.filterwarnings('ignore')

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

# Install the default handler on the root logger and surface INFO-level
# (and higher) messages for the rest of the notebook.
logging.basicConfig()
logger = logging.getLogger()  # no name given -> this is the root logger
logger.setLevel(logging.INFO)

## Model definition for loading

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineLSTM(nn.Module):
    """Stacked LSTM that predicts the next grid from a sequence of flattened grids.

    Every time step of the input is flattened to a vector of length
    ``input_size`` and fed to a multi-layer LSTM. The final hidden state of
    *every* layer is concatenated and projected back to ``input_size`` values,
    which are reshaped into a single-step grid.

    :param input_size: number of cells in one flattened grid (rows * cols).
    :param hidden_size: hidden units per LSTM layer.
    :param num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Attribute names ``lstm`` / ``linear`` are part of the checkpoint's
        # state-dict keys and must not be renamed.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(hidden_size * num_layers, input_size)

    def forward(self, lstm_input):
        """
        :lstm_input: tensor of shape (batch, n_rows, n_cols, seq_len) with
            n_rows * n_cols == input_size, e.g. (b, 232, 232, 10)
        :return: tensor of shape (batch, n_rows, n_cols, 1)
        """
        # Keep the two grid axes in separate names: unpacking both into one
        # name silently dropped the first size and broke non-square grids.
        batch, n_rows, n_cols, seq_len = lstm_input.size()

        # (batch, rows, cols, seq) -> (batch, seq, rows * cols) for batch_first LSTM.
        sequence = lstm_input.permute(0, 3, 1, 2).reshape(
            batch, seq_len, n_rows * n_cols
        )
        _, (hidden, _cell) = self.lstm(sequence)

        # hidden is (num_layers, batch, hidden); put batch first so that each
        # sample's per-layer states are concatenated into one feature vector.
        hidden = hidden.permute(1, 0, 2).reshape(batch, -1)
        return self.linear(hidden).view(batch, n_rows, n_cols, 1)

In [3]:
model = BaselineLSTM(input_size=232*232, hidden_size=1024, num_layers=10)
# The notebook only runs inference, so put the module in eval mode up front.
model.eval()

# map_location="cpu" lets a checkpoint that was saved from a CUDA device load
# on a CPU-only machine; this notebook runs inference on the CPU.
model.load_state_dict(
    torch.load(
        "model_checkpoints/country_level-lr=0.001,input_size=53824,hidden_size=1024,num_layers=10_epi_64_mae=0.02303406.pth",
        map_location="cpu",
    )
)

<All keys matched successfully>

## Setup data and get predictions

In [4]:
import joblib
# Load the persisted scaler; its mean_ and scale_ attributes are used further
# down to map model outputs back to the original value range.
# NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
scaler = joblib.load('data/EpiGCN/standard_scaler.pkl')

In [5]:
# Read the four splits fully into memory and keep index 0 of the trailing axis.
# ``Dataset.value`` was removed in h5py 3.0; ``ds[()]`` is the supported way to
# read a whole dataset as a NumPy array.
with h5py.File("data/EpiGCN/train_test.hdf5", "r") as f:
    x_train = f["x_train"][()][..., 0]
    x_test = f["x_test"][()][..., 0]
    y_train = f["y_train"][()][..., 0]
    y_test = f["y_test"][()][..., 0]

In [6]:
# Report the shape of every split, one line per array.
shape_report = (
    ("X train shape: ", x_train),
    ("X test shape: ", x_test),
    ("Y train shape: ", y_train),
    ("Y test shape: ", y_test),
)
for shape_label, split_array in shape_report:
    print(shape_label, split_array.shape)

X train shape:  (312, 232, 232, 10)
X test shape:  (34, 232, 232, 10)
Y train shape:  (312, 232, 232, 1)
Y test shape:  (34, 232, 232, 1)


In [7]:
import torch
# GPU availability is recorded in use_cuda, but the only line that reads it is
# the commented-out one below: the device is pinned to the CPU.
use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
device = torch.device("cpu")
print(device)

cpu


In [8]:
def get_scaled_output(y):
    """Map standardised values back to the original scale.

    Applies ``scale * y + mean`` element-wise, using the first entry of the
    module-level ``scaler``'s ``scale_`` and ``mean_``, and returns the result
    with the same shape as ``y``.
    """
    original_shape = y.shape
    flat_values = y.reshape(-1)
    rescaled = scaler.scale_[0] * flat_values + scaler.mean_[0]
    return rescaled.reshape(*original_shape)

# Define Dataset

In [34]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Wrap the NumPy splits as float32 tensors; each dataset item is an
# (input, target) pair taken along the first axis.
train_dataset = TensorDataset(torch.FloatTensor(x_train), torch.FloatTensor(y_train))
test_dataset = TensorDataset(torch.FloatTensor(x_test), torch.FloatTensor(y_test))

# Define Dataloader

In [35]:
# params = {"batch_size": 12, "shuffle": True, "num_workers": 4}

# Batches of 12, loaded in the main process (num_workers=0). Only the training
# loader shuffles; the test loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=12, shuffle=False, num_workers=0)

In [36]:
# Run the model over the test set without tracking gradients, map outputs and
# targets back to the original scale, and stack the batches into two arrays.
predictions, truth = [], []
with torch.no_grad():
    for inputs, targets in test_dataloader:
        outputs = model(inputs.to(device)).cpu().numpy()
        predictions += [get_scaled_output(outputs)]
        truth += [get_scaled_output(targets.cpu().numpy())]
predictions, truth = np.concatenate(predictions), np.concatenate(truth)

In [37]:
truth.shape

(34, 232, 232, 1)

In [38]:
import seaborn as sns


def animate(i, axes=None):
    """Draw the truth and prediction heatmaps for test sample ``i`` side by side.

    Both heatmaps share one colour range (the min/max over the two grids of
    this sample) so they are directly comparable.

    :param i: index along the first axis of the module-level ``truth`` and
        ``predictions`` arrays.
    :param axes: optional pair of matplotlib axes to draw into. When omitted a
        new 1x2 figure is created, so the function can be called on its own
        instead of failing on an undefined ``axes`` name.
    """
    if axes is None:
        _, axes = plt.subplots(1, 2, figsize=(25, 8))
        axes[0].set_title("Test: Truth")
        axes[1].set_title("Test: Predictions")
    t = truth[i, :, :, 0]
    p = predictions[i, :, :, 0]
    vmin = min(t.min(), p.min())
    vmax = max(t.max(), p.max())
    sns.heatmap(t, vmin=vmin, vmax=vmax, cmap="Blues", ax=axes[0])
    sns.heatmap(p, vmin=vmin, vmax=vmax, cmap="Blues", ax=axes[1])
    

# Animator implementation

In [14]:
import seaborn as sns
import os
from tqdm import tnrange
import imageio


class Animator:
    """Render truth-vs-prediction heatmaps frame by frame and stitch them into a GIF.

    One PNG per index along the first axis is written to ``dir_name``, the PNGs
    are combined into ``<name>.gif`` in the same directory, and the PNGs are
    then deleted.

    :param truth: array indexed as ``truth[i, :, :, 0]`` for frame ``i``.
    :param predictions: array indexed the same way as ``truth``.
    :param dir_name: working/output directory; created if missing.
    """

    def __init__(self, truth, predictions, dir_name="visualizations"):
        self.time_dim = truth.shape[0]
        self.truth = truth
        self.predictions = predictions
        self.dir_name = dir_name
        # One global colour range so every frame and both panels are comparable.
        self.vmax = max(truth.max(), predictions.max())
        self.vmin = min(truth.min(), predictions.min())

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(dir_name, exist_ok=True)

    @staticmethod
    def _frame_order(path):
        """Sort key: numerically named frames first (by index), other PNGs after (by name)."""
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal():
            return (0, int(stem), "")
        return (1, 0, stem)

    def list_pngs(self):
        """Return the paths of all PNG files in ``dir_name`` in frame order.

        ``os.listdir`` returns entries in arbitrary order, and plain string
        sorting would put ``10.png`` before ``2.png``; frames are therefore
        ordered by their integer index so the GIF plays in time order.
        """
        paths = [
            os.path.join(self.dir_name, file)
            for file in os.listdir(self.dir_name)
            if file.endswith(".png")
        ]
        return sorted(paths, key=self._frame_order)

    def remove_pngs(self):
        """Delete every PNG file in ``dir_name`` (GIFs and other files are kept)."""
        for file in self.list_pngs():
            os.remove(file)

    def animate_img(self, i, name):
        """Save frame ``i`` as ``<dir_name>/<i>.png`` with truth and prediction panels."""
        fig, axes = plt.subplots(1, 2, figsize=(25, 8))
        try:
            axes[0].set_title("{}: Truth".format(name))
            axes[1].set_title("{}: Predictions".format(name))
            t = self.truth[i, :, :, 0]
            p = self.predictions[i, :, :, 0]
            sns.heatmap(t, vmin=self.vmin, vmax=self.vmax, cmap="Blues", ax=axes[0])
            sns.heatmap(p, vmin=self.vmin, vmax=self.vmax, cmap="Blues", ax=axes[1])
            fig.savefig(os.path.join(self.dir_name, "{}.png".format(i)))
        finally:
            # Without this, one open figure per frame accumulates in memory.
            plt.close(fig)

    def create_gif(self, name):
        """Combine the PNG frames, in frame order, into ``<dir_name>/<name>.gif``."""
        images = [imageio.imread(file) for file in self.list_pngs()]
        imageio.mimsave(os.path.join(self.dir_name, "{}.gif".format(name)), images)

    def animate(self, name):
        """Render all frames, build ``<name>.gif`` and clean up the intermediate PNGs."""
        # Start from a clean directory so stale frames never end up in the GIF.
        self.remove_pngs()

        for i in tnrange(self.time_dim):
            self.animate_img(i, name)

        self.create_gif(name)
        self.remove_pngs()

In [None]:
# Render one frame per test sample and write visualizations/test.gif
# (default output directory of Animator).
animator = Animator(truth, predictions)
animator.animate("test")

In [30]:
# Rebind the training loader without shuffling so batches come out in dataset
# order; this one uses 12 worker processes.
train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=False, num_workers=12)

In [31]:
# Collect rescaled predictions and targets for the whole training set.
train_predictions, train_truth = [], []
with torch.no_grad():
    for batch_x, batch_y in train_dataloader:
        logits = model(batch_x.to(device))
        # Convert to NumPy before rescaling, as the test-set loop does, so that
        # get_scaled_output and np.concatenate always operate on ndarrays
        # rather than relying on implicit tensor -> array conversion.
        train_predictions.append(get_scaled_output(logits.cpu().numpy()))
        train_truth.append(get_scaled_output(batch_y.cpu().numpy()))
train_predictions = np.concatenate(train_predictions)
train_truth = np.concatenate(train_truth)

In [32]:
# Rebind `animator` to the training-set arrays (colour range is recomputed from them).
animator = Animator(train_truth, train_predictions)

In [None]:
# Renders one frame per sample and writes "train.gif" into the animator's directory.
animator.animate("train")

## Visualize after rearranging the grid based on graph clusters

In [43]:
import pickle

import networkx as nx

# nx.read_gpickle was deprecated in NetworkX 2.6 and removed in 3.0; the
# documented replacement is the standard pickle module, which also works on
# older NetworkX versions.
# NOTE: unpickling can execute arbitrary code - only load trusted files.
with open('data/EpiGCN/hcs_components.gpkl', 'rb') as graph_file:
    clustered_graph = pickle.load(graph_file)

We now create a mapping from the old index to the new index so that we can rearrange the grid. The old index is the natural `country_id` order; the new index is an incrementing integer (starting from 0) assigned one connected component at a time, which ensures that nodes belonging to the same component end up next to each other.

In [52]:
def get_reordered_mapping_of_nodes(clustered_graph):
    """
    Number the nodes of a networkx graph consecutively, one connected
    component after another, and return ``{node: new_index}``.

    The graph is expected to have each cluster as its own connected component,
    so nodes of the same cluster receive adjacent indices (starting from 0).
    """
    nodes_by_component = (
        node
        for component in nx.connected_components(clustered_graph)
        for node in component
    )
    return {node: new_ix for new_ix, node in enumerate(nodes_by_component)}

In [53]:
# Maps each node of the clustered graph to its new, component-grouped index.
clustered_mapping = get_reordered_mapping_of_nodes(clustered_graph)

## Reorder original grids based on the new mapping

In [57]:
def get_reordered_grid(grid, clustered_mapping):
    """Permute both country axes of ``grid`` according to ``clustered_mapping``.

    :param grid: array whose first axis is time and whose next two axes are
        indexed by country (e.g. shape ``(time, n, n, 1)``).
    :param clustered_mapping: ``{old_index: new_index}`` for the country axes.
    :return: a new array of the same shape and dtype where
        ``out[:, new_r, new_c] == grid[:, old_r, old_c]``. Positions whose
        index is not covered by the mapping are left at zero. ``grid`` itself
        is not modified.

    The same permutation is applied to rows *and* columns; permuting only one
    axis would leave row ``k`` and column ``k`` referring to different
    countries, so clusters could not show up as blocks in the heatmap.
    """
    count = len(clustered_mapping)
    old_ix = np.fromiter(clustered_mapping.keys(), dtype=np.intp, count=count)
    new_ix = np.fromiter(clustered_mapping.values(), dtype=np.intp, count=count)

    new_grid = np.zeros_like(grid)
    # np.ix_ builds the (rows x cols) open mesh, so one vectorised assignment
    # moves every (old_row, old_col) cell to (new_row, new_col).
    new_rows, new_cols = np.ix_(new_ix, new_ix)
    old_rows, old_cols = np.ix_(old_ix, old_ix)
    new_grid[:, new_rows, new_cols] = grid[:, old_rows, old_cols]  # first axis is time
    return new_grid

In [61]:
# Apply the cluster-based reordering to all four arrays in one pass.
(
    reordered_truth,
    reordered_predictions,
    reordered_train_truth,
    reordered_train_predictions,
) = (
    get_reordered_grid(source_grid, clustered_mapping)
    for source_grid in (truth, predictions, train_truth, train_predictions)
)

In [64]:
# Sanity check: reordering must not change any array's shape.
shape_pairs = (
    (truth, reordered_truth),
    (predictions, reordered_predictions),
    (train_truth, reordered_train_truth),
    (train_predictions, reordered_train_predictions),
)
for before, after in shape_pairs:
    print("original shape: {}, new shape: {}".format(before.shape, after.shape))

original shape: (34, 232, 232, 1), new shape: (34, 232, 232, 1)
original shape: (34, 232, 232, 1), new shape: (34, 232, 232, 1)
original shape: (312, 232, 232, 1), new shape: (312, 232, 232, 1)
original shape: (312, 232, 232, 1), new shape: (312, 232, 232, 1)


In [None]:
# Animate the cluster-reordered test grids; the GIF is named "reordered test.gif"
# (the name contains a space) inside the animator's output directory.
animator = Animator(reordered_truth, reordered_predictions)
animator.animate('reordered test')