# <b> Title: Explainability pipeline for a GRU-D model predicting neurobehavioral complications in the PICU </b>

# Library Imports

In [3]:
# Data wrangling libraries
import numpy as np
import pandas as pd
import math

# Machine Learning libraries
## PyTorch 
import torch 
import torch.nn as nn
from typing import Union, Optional 
from torch.nn import Parameter
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
## PyPots 
from pypots.data.generating import gene_incomplete_random_walk_dataset
from pypots.utils.metrics import cal_binary_classification_metrics
from pypots.utils.logging import Logger
## Custom GRUD
from custom_grud import _GRUD

# Progress meter TQDM
import tqdm

# Explainer libraries
import timeshap
from timeshap.wrappers import TorchModelWrapper
from timeshap.explainer import local_report, global_report

# Data

## Imports

In [4]:
# Generate the unified data for testing and cache it first, DATA here is a singleton
# Otherwise, file lock will cause bug if running test parallely with pytest-xdist.
DATA = gene_incomplete_random_walk_dataset()

TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]}
VAL_SET = {"X": DATA["val_X"], "y": DATA["val_y"]}
TEST = {"X": DATA["test_X"], "y": DATA["test_y"]}

In [14]:
# Shape should be (batch_size, seq_len, feature_dim)
print(f"Train_X {TRAIN_SET['X'].shape}")
#Shape of test
print(f"Test_X {TEST['X'].shape}")
#shape of val
print(f"Val_X {VAL_SET['X'].shape}")

Train_X (1280, 24, 10)
Test_X (400, 24, 10)
Val_X (320, 24, 10)


# Model Training

## Model Specification

In [15]:
# Set the logger to only print ERRORS otherwise it will print too much
logger = Logger(logging_level="error")
# We will use the modfiied GRUD model in custom_grud.py
# 1. The original GRUD model is defined in pypots/classification/grud/model.py
# 2. The custom GRUD model is called _GRUD in custom_grud.py
# Arguments of _GRUD:
#     n_steps: int,
#         The number of time steps in the input.
#     n_features: int,
#         The number of features in the input.
#     rnn_hidden_size: int,
#         The number of features in the hidden state h.
#     n_classes: int,
#         The number of classes in the output.
#     device: torch.device,
#         The device to use. 'cpu' or 'cuda'.

model = _GRUD(
    n_steps = DATA["n_steps"], # DATA["n_steps"]=24
    n_features = DATA["n_features"], # DATA["n_features"]=10
    rnn_hidden_size = 24,
    n_classes=1, # DATA["n_classes"]=2
    device = torch.device("mps")
    )

## Training Parameters

In [16]:
# Learning rate
lr = 1e-3
# Weight decay
weight_decay = 1e-5
# Number of epochs to train for
epochs = 20
# Patience for early stopping
original_patience = 30
# Set up Optimizer
optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )
#Specify the device to use, should be the same as the model
device = torch.device("mps")

# Move the model to the device
model = model.to(device)

# each training starts from the very beginning, so reset the loss and model dict here
best_loss = float("inf")
best_model_dict = None
training_step = 0
epoch_train_loss = []
epoch_val_loss = []
loss_function = torch.nn.BCELoss()

# Convert the labels to tensors
train_label = torch.unsqueeze(torch.from_numpy(TRAIN_SET['y']), 1).to(torch.float32).to(device)
val_label = torch.unsqueeze(torch.from_numpy(VAL_SET['y']), 1).to(torch.float32).to(device)

## Training Loop

In [17]:
# Training loop
for epoch in tqdm.tqdm(range(epochs)):

    training_step += 1
    prediction = model.forward(TRAIN_SET['X'])
    # classification_loss = F.nll_loss(torch.log(prediction), torch.from_numpy(TRAIN_SET['y']).to(torch.long).to(device))
    train_loss = loss_function(prediction, train_label)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    epoch_train_loss.append(train_loss.item())


    
    with torch.no_grad():
        prediction = model.forward(VAL_SET['X'])
        # classification_loss = F.nll_loss(torch.log(prediction), torch.from_numpy(VAL_SET['y']).to(torch.long).to(device)) # this is for nll_loss (prediction is 2 dim)
        test_loss = loss_function(prediction, val_label)
        
        epoch_val_loss.append(test_loss.item())
        mean_loss = test_loss.item()
        print(f"Train loss: {train_loss.item()} --- Val loss {test_loss.item()} ")

    if mean_loss < best_loss:
        best_loss = mean_loss
        best_model_dict = model.state_dict()
        patience = original_patience
   
    else:
        patience -= 1
        if patience == 0:
            break


  5%|▌         | 1/20 [00:00<00:10,  1.83it/s]

Train loss: 33.418373107910156 --- Val loss 100.0 


 10%|█         | 2/20 [00:01<00:09,  1.84it/s]

Train loss: 100.0 --- Val loss 100.0 


 15%|█▌        | 3/20 [00:01<00:09,  1.81it/s]

Train loss: 100.0 --- Val loss 100.0 


 20%|██        | 4/20 [00:02<00:08,  1.80it/s]

Train loss: 100.0 --- Val loss 100.0 


 25%|██▌       | 5/20 [00:02<00:08,  1.80it/s]

Train loss: 100.0 --- Val loss 100.0 


 30%|███       | 6/20 [00:03<00:07,  1.80it/s]

Train loss: 100.0 --- Val loss 100.0 


 35%|███▌      | 7/20 [00:03<00:07,  1.79it/s]

Train loss: 100.0 --- Val loss 100.0 


 40%|████      | 8/20 [00:04<00:06,  1.81it/s]

Train loss: 100.0 --- Val loss 100.0 


 45%|████▌     | 9/20 [00:04<00:06,  1.82it/s]

Train loss: 100.0 --- Val loss 100.0 


 50%|█████     | 10/20 [00:05<00:05,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


 55%|█████▌    | 11/20 [00:06<00:04,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


 60%|██████    | 12/20 [00:06<00:04,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


 65%|██████▌   | 13/20 [00:07<00:03,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


 70%|███████   | 14/20 [00:07<00:03,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


 75%|███████▌  | 15/20 [00:08<00:02,  1.84it/s]

Train loss: 100.0 --- Val loss 100.0 


 80%|████████  | 16/20 [00:08<00:02,  1.84it/s]

Train loss: 100.0 --- Val loss 100.0 


 85%|████████▌ | 17/20 [00:09<00:01,  1.84it/s]

Train loss: 100.0 --- Val loss 100.0 


 90%|█████████ | 18/20 [00:09<00:01,  1.84it/s]

Train loss: 100.0 --- Val loss 100.0 


 95%|█████████▌| 19/20 [00:10<00:00,  1.83it/s]

Train loss: 100.0 --- Val loss 100.0 


100%|██████████| 20/20 [00:10<00:00,  1.82it/s]

Train loss: 100.0 --- Val loss 100.0 





# Model Performance

In [19]:
import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [11]:
with torch.no_grad():
    model.eval()
    grud_prediction = model(TEST['X']).cpu().detach().numpy()
    metrics = cal_binary_classification_metrics(grud_prediction, TEST["y"])
    print("Testing classification metrics: \n"
        f'ROC_AUC: {metrics["roc_auc"]}, \n'
        f'PR_AUC: {metrics["pr_auc"]},\n'
        f'F1: {metrics["f1"]},\n'
        f'Precision: {metrics["precision"]},\n'
        f'Recall: {metrics["recall"]},\n'
    )

Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


ValueError: Input contains NaN.

# Explainability with TimeSHAP

## Set up lambda function for predictions

In [9]:
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)

## Calculate Baselines

In [22]:
#For best performance calculate baseline with the TRAIN_SET!
# Average Event when there are nan values
## Approach 1: Set nan to 0 and then take the median
average_event = np.nanmedian(np.nan_to_num(TRAIN_SET['X'].reshape(1280*24, 10)), axis = 0)
## Approach 2: Ignore nan when taking the median
average_event2 = np.nanmedian(TRAIN_SET['X'].reshape(1280*24, 10), axis = 0)

# Average sequence when there are nan values
## Approach 1: Set nan to 0 and then take the median
average_sequence = np.median(np.nan_to_num(TRAIN_SET['X']), axis = 0)
## Approach 2: Ignore nan when taking the median
average_sequence2 = np.nanmedian(TRAIN_SET['X'], axis = 0)

## Data Processing

In [64]:
# TimeSHAP's Local and Global reporters require the data to be in different formats.
# For the local explainer, we can use the data as is. The only requirement is that the input is 3d (sample, sequence/event, features)
# For the global explainer, we need to add 2 new features to the data. One of these features is a sample_id, the other is an event_id.
# This is an utility function to do that
def add_sample_event_id(data):
    sample_id = np.repeat(np.arange(data.shape[0]), data.shape[1]).reshape(data.shape[0], data.shape[1], 1)
    event_id = np.tile(np.arange(data.shape[1]), data.shape[0]).reshape(data.shape[0], data.shape[1], 1)
    return np.concatenate((data, sample_id, event_id), axis = 2), 10, 11

# Add the sample_id and event_id to the data
test_data, entity_col, time_col = add_sample_event_id(TEST['X'])
print(test_data.shape)

(400, 24, 12)


## Global Report

In [106]:
# Set up for the global explainer
pruning_dict = {'tol': [0.0]} #For this example, we will not prune the time steps
event_dict = {'rs': 42, 'nsamples': 400}
feature_dict = {'rs': 42, 'nsamples': 400, 'feature_names': np.arange(10).astype(str).tolist()}
model_features = np.arange(10).tolist()

In [94]:
# Run th global explainer
prun_stats, global_plot = timeshap.explainer.global_report(
                                                        f_hs, 
                                                        test_data,  
                                                        None, # For this example we will not prune the time stemps
                                                        event_dict, 
                                                        feature_dict, 
                                                        average_sequence, 
                                                        entity_col = entity_col, 
                                                        time_col = time_col, 
                                                        model_features = model_features, 
                                                        verbose = False
                                                        )

No path to persist event explanations provided.
No path to persist feature explanations provided.
Calculating event data
No pruning data provided and no pruning tolerances provided. No pruning will take place
Calculating feat data
No pruning data provided and no pruning tolerances provided. No pruning will take place


In [89]:
import altair as alt
alt.data_transformers.disable_max_rows()
global_plot

## Local Report

In [218]:
# Set up for the local explainer
pruning_dict = {'tol': 0.025}
event_dict = {'rs': 42, 'nsamples': 400}
feature_dict = {'rs': 42, 
                'nsamples': 400, 
                'feature_names': [str(a) for a in list(range(10))],
                'plot_features': dict(zip([str(a) for a in list(range(10))], [str(a) for a in list(range(10))]))}
cell_dict = {'rs': 42, 'nsamples': 400, 'top_x_feats': 4, 'top_x_events': 5}


In [222]:
# Let's run the local explainer for the sample with the highest probability of being positive
index = 104
print(f"Sample {index}")
plot = local_report(
                f_hs, 
                np.expand_dims(TEST['X'][index, :, :], axis = 0), 
                pruning_dict, event_dict, feature_dict, cell_dict, average_sequence)
print(f"True Label = {TEST['y'][index]}, Model Probability = {grud_prediction[index][0]}, Average Sequence Probability = {f_hs(average_sequence.reshape(1, 24, 10))[0][0]}")
print(f"Difference between the model and the average sequence = {np.diff([f_hs(average_sequence.reshape(1, 24, 10))[0][0], grud_prediction[index][0]])[0]}")
plot

Sample 104
Assuming all features are model features
True Label = 1, Model Probability = 0.5895307660102844, Average Sequence Probability = 0.41309988498687744
Difference between the model and the average sequence = 0.17643088102340698


In [221]:
# Let's run the local explainer for sample with lowest probability of being positive
index = 391
print(f"Sample {index}")
plot = local_report(
                f_hs, 
                np.expand_dims(TEST['X'][index, :, :], axis = 0), 
                pruning_dict, event_dict, feature_dict, cell_dict, average_sequence)
print(f"True Label = {TEST['y'][index]}, Model Probability = {grud_prediction[index][0]}, Average Sequence Probability = {f_hs(average_sequence.reshape(1, 24, 10))[0][0]}")
print(f"Difference between the model and the average sequence = {np.diff([f_hs(average_sequence.reshape(1, 24, 10))[0][0], grud_prediction[index][0]])[0]}")
plot

Sample 391
Assuming all features are model features
True Label = 0, Model Probability = 0.30300477147102356, Average Sequence Probability = 0.41309988498687744
Difference between the model and the average sequence = [-0.11009511]
