In [7]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import sys
import torch.nn as nn
from torch.utils.data import DataLoader
sys.path.append('/home/mei/nas/docker/thesis')
from model.lstm_gnn_embedding import PatientOutcomeModelEmbedding
from dataloader.ts_reader import MultiModalDataset, collate_fn
from dataloader.pyg_reader import GraphDataset
from captum.attr import IntegratedGradients
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import gc

# Free unreachable Python objects (and any tensors they still hold), then ask
# PyTorch to release its cached GPU memory. The CUDA calls are skipped on
# CPU-only machines, where torch.cuda.synchronize() would raise.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

In [None]:
class SHAPWrapper(nn.Module):
    """Expose the patient-outcome model as a two-input (flat, ts-summary) module for SHAP.

    Everything except the flat features and the time-series summary is held
    fixed at construction time: the graph, the patient ids and the captured
    batch's metadata.

    Args:
        model: the outcome model; called as
            ``model(flat, graph_data, patient_ids, ts, lengths)`` and expected to
            return a ``(risk_scores, _)`` pair.
        graph_data: graph object forwarded unchanged to the model.
        patient_ids_fixed: tensor of patient ids forwarded unchanged to the model.
            NOTE(review): its length is fixed at construction; presumably the model
            expects one id per input row - confirm against the model.
        ts_data_fixed: stored for reference only; ``forward`` does not use it.
        lengths_fixed: sequence lengths of the captured batch. Only its dtype and
            device are used by ``forward`` (see the comment there).
    """

    def __init__(self, model, graph_data, patient_ids_fixed, ts_data_fixed, lengths_fixed):
        super(SHAPWrapper, self).__init__()
        self.model = model
        self.graph_data = graph_data
        self.patient_ids_fixed = patient_ids_fixed
        self.ts_data_fixed = ts_data_fixed          # not used by forward()
        self.lengths_fixed = lengths_fixed

    def forward(self, flat_input, ts_summary_input):
        """Return the per-row risk, averaged over dim 1 of the model's risk scores.

        Args:
            flat_input: flat features, first dim N.
            ts_summary_input: (N, D_ts) summary; fed to the model as a one-step
                sequence of shape (N, 1, D_ts).

        Returns:
            ``risk_scores.mean(dim=1)`` - shape (N,) when the model returns (N, T).
        """
        ts_summary = ts_summary_input.unsqueeze(1)  # (N, D_ts) -> (N, 1, D_ts)
        # The summary is a one-step sequence, so every row's valid length is 1.
        # lengths_fixed holds the captured batch's real lengths (generally > 1 and
        # possibly a different batch size), which do not describe this input and
        # would presumably let the model pack/index past the single available step.
        lengths = torch.ones(
            ts_summary.shape[0],
            dtype=self.lengths_fixed.dtype,
            device=self.lengths_fixed.device,
        )
        risk_scores, _ = self.model(
            flat_input, self.graph_data, self.patient_ids_fixed, ts_summary, lengths
        )
        return risk_scores.mean(dim=1)

In [5]:
# Single place to edit when the data moves; every path below derives from it.
DATA_ROOT = "/home/mei/nas/docker/thesis/data"
HDF_DIR = os.path.join(DATA_ROOT, "hdf")

train_data_dir = os.path.join(HDF_DIR, "train")
val_data_dir = os.path.join(HDF_DIR, "val")
test_data_dir = os.path.join(HDF_DIR, "test")

# Passed to GraphDataset; "mode" and "k" select the precomputed graph file
# (e.g. diagnosis_graph_k_closest_k3.pt for the values below).
config = {
    "data_dir": HDF_DIR,
    "graph_dir": os.path.join(DATA_ROOT, "graphs"),
    "mode": "k_closest",
    "k": 3,
}

In [6]:
# === LSTM + Flat Dataset ===
# One MultiModalDataset per split, constructed in train -> val -> test order.
lstm_dataset_train, lstm_dataset_val, lstm_dataset_test = (
    MultiModalDataset(split_dir)
    for split_dir in (train_data_dir, val_data_dir, test_data_dir)
)

# Only the training loader shuffles; val/test batches come out in dataset order.
lstm_loader_train = DataLoader(lstm_dataset_train, collate_fn=collate_fn, batch_size=32, shuffle=True)
lstm_loader_val = DataLoader(lstm_dataset_val, collate_fn=collate_fn, batch_size=32, shuffle=False)
lstm_loader_test = DataLoader(lstm_dataset_test, collate_fn=collate_fn, batch_size=32, shuffle=False)

# === Graph Dataset ===
graph_dataset = GraphDataset(config)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5


In [6]:
# Input widths for the flat, graph and time-series branches, plus the hidden
# size passed to PatientOutcomeModelEmbedding.
flat_input_dim, graph_input_dim, ts_input_dim, hidden_dim = 104, 104, 162, 128

In [7]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = PatientOutcomeModelEmbedding(flat_input_dim, graph_input_dim, ts_input_dim, hidden_dim)
model = model.to(device)  # nn.Module.to returns the same module

# NOTE(review): optimizer and criterion are not used by any visible cell here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
# Explain the trained model: SHAP values of a randomly initialised network are
# meaningless. Fall back (with a warning) when the checkpoint is missing.
best_model = '/home/mei/nas/docker/thesis/data/model_results/best_model.pth'
if os.path.exists(best_model):
    model.load_state_dict(torch.load(best_model, map_location=device, weights_only=True))
else:
    print(f"WARNING: {best_model} not found; explaining an untrained model.")

# Capture the first validation batch (the val loader does not shuffle, so this
# is the same batch on every run).
patient_ids, ts_data, flat_data, risk_data, lengths = next(iter(lstm_loader_val))
patient_ids_fixed = torch.tensor([int(pid) for pid in patient_ids], dtype=torch.long, device=device)
# NOTE(review): if collate_fn zero-pads ts_data, this mean includes padded steps - confirm.
ts_summary_fixed = ts_data.mean(dim=1).to(device)  # (batch_size, D_ts)
lengths_fixed = lengths.to(device)

wrapper = SHAPWrapper(model, graph_dataset.graph_data, patient_ids_fixed, ts_summary_fixed, lengths_fixed)
wrapper.eval();  # also puts the wrapped model in eval mode; ';' hides the module repr


In [24]:
def model_predict_multi(inputs):
    """Black-box predict function over the (flat, ts-summary) inputs of ``wrapper``.

    Args:
        inputs: two-element sequence ``[flat_input, ts_summary]``. Each element may be
            a NumPy array or a torch tensor:
            flat_input: shape (N, D_flat)
            ts_summary: shape (N, D_ts_summary)

    Returns:
        NumPy array with ``wrapper``'s per-row output (for SHAPWrapper: the mean of
        the model's risk scores over dim 1, shape (N,)).

    Uses the module-level ``wrapper`` and ``device``; patient ids, graph and lengths
    are whatever ``wrapper`` captured. Runs under ``torch.no_grad()``, so the output
    is not differentiable with respect to the inputs.
    """
    flat_input_np, ts_summary_np = inputs
    # torch.as_tensor accepts arrays and tensors alike; torch.tensor warns (and
    # copies) when handed a tensor, which is what ModelPredictWrapper passes.
    flat_tensor = torch.as_tensor(flat_input_np, dtype=torch.float32, device=device)
    ts_summary_tensor = torch.as_tensor(ts_summary_np, dtype=torch.float32, device=device)

    with torch.no_grad():
        predictions = wrapper(flat_tensor, ts_summary_tensor)
    return predictions.detach().cpu().numpy()

In [25]:
class ModelPredictWrapper(nn.Module):
    """Present a list-taking predict function as an ``nn.Module``.

    ``forward(*inputs)`` packs its positional arguments into a list, hands that
    list to the wrapped function and returns the function's result unchanged.
    No parameters are added, so the module is only as differentiable as the
    wrapped function.
    """

    def __init__(self, model_predict_func):
        super().__init__()
        self.f = model_predict_func

    def forward(self, *inputs):
        packed_inputs = [*inputs]
        return self.f(packed_inputs)


In [10]:
def summarize_ts(ts_data, lengths):
    """Summarise the valid prefix of one patient's time series as [mean, std].

    Args:
        ts_data: tensor of shape (T, D_ts).
        lengths: number of valid leading rows, as a Python/NumPy int or a
            one-element tensor.

    Returns:
        Tensor of shape (2 * D_ts,): the per-feature mean followed by the
        per-feature (unbiased) standard deviation over the valid rows. With a
        single valid row the standard deviation is reported as 0 rather than NaN.

    Raises:
        ValueError: if there are no valid rows (the mean would be NaN).
    """
    n_valid = int(lengths.item()) if torch.is_tensor(lengths) else int(lengths)
    valid_ts = ts_data[:max(n_valid, 0)]  # (L, D_ts); slicing truncates if L > T
    if valid_ts.shape[0] == 0:
        raise ValueError(f"summarize_ts needs at least one valid row, got lengths={n_valid}")

    mean_val = valid_ts.mean(dim=0)  # (D_ts,)
    if valid_ts.shape[0] > 1:
        std_val = valid_ts.std(dim=0)  # (D_ts,), unbiased
    else:
        # Unbiased std of one sample divides by zero and yields NaN.
        std_val = torch.zeros_like(mean_val)
    return torch.cat((mean_val, std_val), dim=0)  # (2*D_ts,)

In [11]:
def build_data_for_shap(num, dataset, seed=None):
    """Sample ``num`` patients from ``dataset`` and build SHAP inputs for them.

    Args:
        num: number of patients to draw, without replacement.
        dataset: object exposing a ``patient_ids`` sequence and ``__getitem__``
            returning ``(patient_id, ts_data, flat_data, risk_data)``, where
            ``ts_data`` and ``flat_data`` are tensors.
        seed: optional int. When given, sampling uses a private
            ``np.random.RandomState(seed)``, which gives reproducible results without
            touching the global RNG. ``None`` (the default) draws from the global
            NumPy RNG, as before.

    Returns:
        ``([flat_np, ts_summary_np], background_ids)``:
            flat_np: stacked flat tensors as a NumPy array, first dim ``num``.
            ts_summary_np: per-patient mean of ``ts_data`` over its first axis,
                stacked to shape (num, D_ts).
            background_ids: NumPy array of the sampled patient ids, in draw order.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    all_patient_ids = np.asarray(dataset.patient_ids)
    # Sample positions instead of ids. Legacy ``choice`` on an array draws the same
    # permutation as ``choice`` on its length, so the selection is unchanged. This
    # avoids an O(n) ``list.index`` lookup per sampled patient.
    positions = rng.choice(len(all_patient_ids), size=num, replace=False)
    background_ids = all_patient_ids[positions]

    flat_list = []
    ts_summary_list = []
    for pos in positions:
        _, ts_data, flat_data, _ = dataset[int(pos)]
        flat_list.append(flat_data.detach().cpu())
        ts_summary_list.append(ts_data.detach().mean(dim=0).cpu().numpy())  # (D_ts,)

    flat_np = torch.stack(flat_list).numpy()
    ts_summary_np = np.stack(ts_summary_list, axis=0)
    return [flat_np, ts_summary_np], background_ids


In [12]:
# build_data_for_shap draws with the global NumPy RNG; seed it so the sampled
# background and explained patients are the same after a kernel restart.
SHAP_SEED = 42
N_BACKGROUND = 100  # validation patients used as the SHAP background
N_EXPLAIN = 100     # test patients whose predictions are explained
np.random.seed(SHAP_SEED)
background_inputs, background_patient_ids = build_data_for_shap(N_BACKGROUND, lstm_dataset_val)
test_inputs, test_patient_ids = build_data_for_shap(N_EXPLAIN, lstm_dataset_test)

In [13]:
# Feature names for the SHAP plots, taken from the test split's HDF5 files.
# NOTE(review): this rebinds `flat_data` / `ts_data` (earlier: loader tensors) to
# DataFrames, and reads the full time-series table just for its column names.
flat_data = pd.read_hdf(test_data_dir + '/flat.h5')
ts_data = pd.read_hdf(test_data_dir + '/timeseries.h5')
# The first time-series column is dropped (noted as 'time'); the rest are features.
feature_flat, feature_ts = list(flat_data.columns), list(ts_data.columns)[1:]

In [26]:
import shap  # used below but never imported in the earlier cells; required on a fresh kernel

# NOTE(review): this cell is not expected to work as written; the cuDNN/cuBLAS
# RuntimeError recorded below is presumably a symptom, not the root cause:
#   * shap.DeepExplainer back-propagates through a torch module, but
#     ModelPredictWrapper routes inputs through model_predict_multi, which runs
#     under torch.no_grad() and returns NumPy, so no gradient reaches the inputs.
#   * background_inputs / test_inputs are lists of NumPy arrays; DeepExplainer's
#     PyTorch path expects torch tensors.
#   * `wrapper` holds patient_ids_fixed / lengths_fixed for one 32-row val batch,
#     while these inputs have 100 rows each.
# TODO: explain `wrapper` directly with tensor inputs whose rows line up with the
# fixed patient ids, or use a model-agnostic explainer (e.g. KernelExplainer) with
# a predict function that supplies per-row patient ids.
wrapped_model = ModelPredictWrapper(model_predict_multi)
explainer = shap.DeepExplainer(wrapped_model, background_inputs)
shap_values = explainer.shap_values(test_inputs)

shap.summary_plot(shap_values[0], test_inputs[0], feature_names=feature_flat)
shap.summary_plot(shap_values[1], test_inputs[1], feature_names=feature_ts)

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED_CUBLAS