In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import sys
import torch.nn as nn
from torch.utils.data import DataLoader
sys.path.append('/home/mei/nas/docker/thesis')
from model.lstm_gnn_embedding import PatientOutcomeModelEmbedding
from dataloader.ts_reader import MultiModalDataset, collate_fn
from dataloader.pyg_reader import GraphDataset
from captum.attr import IntegratedGradients
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import gc
gc.collect()

torch.cuda.empty_cache()
torch.cuda.synchronize()

In [4]:
train_data_dir = "/home/mei/nas/docker/thesis/data/hdf/train"
val_data_dir = "/home/mei/nas/docker/thesis/data/hdf/val"
test_data_dir = "/home/mei/nas/docker/thesis/data/hdf/test"

config = {  
    "data_dir": "/home/mei/nas/docker/thesis/data/hdf",
    "graph_dir": "/home/mei/nas/docker/thesis/data/graphs",
    "mode": "k_closest",
    "k": 3         
}

In [5]:
# === LSTM + Flat Dataset ===
lstm_dataset_train = MultiModalDataset(train_data_dir)
lstm_dataset_val = MultiModalDataset(val_data_dir)
lstm_dataset_test = MultiModalDataset(test_data_dir)

lstm_loader_train = DataLoader(lstm_dataset_train , batch_size=32, shuffle=True, collate_fn=collate_fn)
lstm_loader_val = DataLoader(lstm_dataset_val , batch_size=32, shuffle=False, collate_fn=collate_fn)
lstm_loader_test = DataLoader(lstm_dataset_test , batch_size=32, shuffle=False, collate_fn=collate_fn)

# === Graph Dataset ===

graph_dataset = GraphDataset(config)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5


In [None]:
flat_data= pd.read_hdf(test_data_dir+'/flat.h5')
ts_data= pd.read_hdf(test_data_dir+'/timeseries.h5')
feature_flat = list(flat_data.columns)
feature_ts = list(ts_data.columns)
feature_ts = feature_ts[1:] # remove 'time' column

In [60]:
flat_input_dim = 104
graph_input_dim = 104
ts_input_dim = 162
hidden_dim = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   
model = PatientOutcomeModelEmbedding(flat_input_dim, graph_input_dim, ts_input_dim, hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss() 

In [68]:
graph_dataset.graph_data.x = graph_dataset.graph_data.x.to(device)
graph_dataset.graph_data.edge_index = graph_dataset.graph_data.edge_index.to(device)
graph_dataset.graph_data.patient_ids =graph_dataset.graph_data.patient_ids.clone().detach().to(device)

for patient_ids, ts_data, flat_data, risk_data, lengths in lstm_loader_val:
    fixed_patient_ids = torch.tensor([int(pid) for pid in patient_ids], dtype=torch.long).to(device)  # shape: (N,)
    fixed_ts = ts_data.to(device)  # shape: (N, T, D_ts)
    fixed_lengths = lengths.to(device)  # shape: (N,)
    fixed_flat = flat_data.to(device)  # shape: (N, D_flat)
    break

In [74]:
class ModelWrapperForIG(nn.Module):
    def __init__(self, model, graph_data, patient_ids_fixed, ts_data_fixed, lengths_fixed):
        super(ModelWrapperForIG, self).__init__()
        self.model = model
        self.graph_data = graph_data
        self.patient_ids_fixed = patient_ids_fixed  
        self.ts_data_fixed = ts_data_fixed  
        self.lengths_fixed = lengths_fixed  
        
    def forward(self, flat_input, ts_input):
        """
        flat_input: tensor, shape (N, D_flat)
        ts_input: tensor, shape (N, T, D_ts)
        其他输入固定：graph_data, patient_ids_fixed, ts_data_fixed, lengths_fixed
        这里我们将调用模型并返回每个样本 risk score 取所有时间步均
        """
        risk_scores, _ = self.model(flat_input, self.graph_data, self.patient_ids_fixed, ts_input, self.lengths_fixed)
        # risk_scores shape: (N, T) -> 返回每个样本在时间维度上的均值
        return risk_scores

In [70]:
wrapper = ModelWrapperForIG(model, graph_dataset.graph_data, fixed_patient_ids, fixed_ts, fixed_lengths)
wrapper.eval()

ModelWrapperForIG(
  (model): PatientOutcomeModelEmbedding(
    (flat_encoder): Linear(in_features=104, out_features=128, bias=True)
    (graph_encoder): GraphEncoder(
      (gcn1): GCNConv(104, 128)
      (gcn2): GCNConv(128, 128)
      (bn1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (ts_encoder): TimeSeriesEncoder(
      (lstm): LSTM(162, 128, batch_first=True, bidirectional=True)
    )
    (risk_predictor): RiskPredictor(
      (fc1): Linear(in_features=512, out_features=128, bias=True)
      (fc2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)

In [75]:
def forward_flat(flat_input):
    """
    解释平面特征：
    flat_input: tensor, shape (N, D_flat)
    固定 ts_input 使用 ts_sample（从测试集中选取的样本的时序部分）
    """
    # 在这里，我们固定 ts_input 为测试样本 ts_sample（稍后获取）
    risk_scores = wrapper(flat_input, ts_sample)
    return risk_scores

def forward_ts(ts_input):
    """
    解释时序特征：
    ts_input: tensor, shape (N, T, D_ts)
    固定平面输入为 flat_sample（从测试集中选取的样本的平面部分）
    """
    risk_scores = wrapper(flat_sample, ts_input)
    return risk_scores


In [76]:
print("ts_sample shape:", ts_sample.shape)
print("baseline_ts shape:", baseline_ts.shape)

ts_sample shape: torch.Size([32, 3995, 162])
baseline_ts shape: torch.Size([32, 3995, 162])


In [79]:
for patient_ids, ts_data, flat_data, risk_data, lengths in lstm_loader_test:
    flat_sample = flat_data.to(device)       # shape: (batch_size, D_flat)
    ts_sample = ts_data.to(device)           # shape: (batch_size, T, D_ts)
    # 为 Integrated Gradients 选择基线，通常使用全零
    baseline_flat = torch.zeros_like(flat_sample)
    baseline_ts = torch.zeros_like(ts_sample)
    break


ig_flat = IntegratedGradients(forward_flat)

# 对于平面特征 Integrated Gradients，不设置 target 参数
ig_flat = IntegratedGradients(forward_flat)
attr_flat, delta_flat = ig_flat.attribute(flat_sample, baseline_flat, return_convergence_delta=True)
print("Flat attribution:", attr_flat)
print("Flat convergence delta:", delta_flat)

# 对于时序特征 Integrated Gradients，也不设置 target 参数
ig_ts = IntegratedGradients(forward_ts)
attr_ts, delta_ts = ig_ts.attribute(ts_sample, baseline_ts, return_convergence_delta=True)
print("Time series attribution:", attr_ts)
print("Time series convergence delta:", delta_ts)

RuntimeError: shape '[36, 1, 256]' is invalid for input of size 2560

In [None]:
attr_flat_np = attr_flat.mean(dim=0).cpu().detach().numpy()  # (D_flat,)
plt.figure(figsize=(10, 5))
plt.bar(range(len(attr_flat_np)), attr_flat_np)
plt.xlabel("Flat Feature Index")
plt.ylabel("Attribution")
plt.title("Integrated Gradients for Flat Features")
plt.show()

In [None]:
attr_ts_np = attr_ts.mean(dim=0).cpu().detach().numpy()  # (T, D_ts)
# 例如，我们取每个时间步所有特征的均值
time_step_attr = attr_ts_np.mean(axis=1)
plt.figure(figsize=(10, 5))
plt.plot(time_step_attr, marker='o')
plt.xlabel("Time Step")
plt.ylabel("Mean Attribution")
plt.title("Integrated Gradients for Time Series Features (Averaged over Features)")
plt.show()

In [25]:
class ModelWrapperForIG(torch.nn.Module):
    def __init__(self, model, graph_data, patient_ids_fixed, ts_data_fixed, lengths_fixed):
        """
        model: OutcomeModelEmbedding model, set it to eval mode
        graph_data: fixed graph data
        patient_ids_fixed: shape (N,)
        ts_data_fixed: shape (N, T, D_ts)
        lengths_fixed: the valid lengths of each time series in ts_data_fixed
        """
        super(ModelWrapperForIG, self).__init__()
        self.model = model
        self.graph_data = graph_data
        self.patient_ids_fixed = torch.tensor(patient_ids_fixed, dtype=torch.long).to(model.device)  # Ensure it's a tensor and on the same device
        self.ts_data_fixed = ts_data_fixed.to(model.device)
        self.lengths_fixed = lengths_fixed.to(model.device)
        self.device = self.patient_ids_fixed.device  # Get the device

    def forward(self, flat_input, ts_input):
        """
        flat_input: tensor, shape (N, D_flat)
        ts_input: tensor, shape (N, T, D_ts)
        this function should return the risk scores for each sample in the batch
        """
        # get risk_scores and combined embeddings
        risk_scores, combined = self.model(flat_input, self.graph_data, self.patient_ids_fixed, ts_input, self.lengths_fixed)
        # risk_scores shape is (N, T)
        risk_mean = risk_scores.mean(dim=1)  # shape (N,)
        return risk_mean

In [21]:
for patient_ids, ts_data, flat_data, risk_data ,lengths in lstm_loader_val:
    patient_ids_fixed = patient_ids  # Ensure it's a list
    ts_data_fixed = ts_data
    lengths_fixed = lengths
    break

In [None]:


ts_data_fixed = ts_data_fixed.to(device)
lengths_fixed = lengths_fixed.to(device)

In [27]:
wrapper = ModelWrapperForIG(model, graph_dataset.graph_data, patient_ids_fixed, ts_data_fixed, lengths_fixed)
wrapper.eval()

ig = IntegratedGradients(wrapper)

# 选择一个输入样本和基线数据进行解释
for patient_ids, ts_data, flat_data, risk_data in lstm_loader_test:
    flat_input = flat_data.to(device)
    ts_sample = ts_data.to(device)
    baseline_ts = torch.zeros_like(ts_sample).to(device)  # 使用全零作为基线数据
    break


attr_ts, delta_ts = ig.attribute((flat_input, ts_sample), (flat_input, baseline_ts), target=0, return_convergence_delta=True)
print("Time series attribution:", attr_ts)
print("Time series convergence delta:", delta_ts)

ValueError: too many dimensions 'str'

In [None]:
attr_ts = attr_ts[1].cpu().detach().numpy()  # 获取时间序列部分的归因
ts_sample = ts_sample.cpu().detach().numpy()

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(np.mean(attr_ts, axis=0))
plt.xlabel('Time Step')
plt.ylabel('Attribution')
plt.title('Time Series Feature Importance using Integrated Gradients')
plt.show()