In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import sys
sys.path.append('/home/mei/nas/docker/thesis')
from dataloader.ts_reader import MultiModalDataset, collate_fn
from dataloader.pyg_reader import GraphDataset

from lstm_gnn_embedding import PatientOutcomeModelEmbedding
import pickle
from sklearn.manifold import TSNE
import numpy as np
from sklearn.cluster import KMeans

In [2]:
train_data_dir = "/home/mei/nas/docker/thesis/data/hdf/train"
val_data_dir = "/home/mei/nas/docker/thesis/data/hdf/val"
test_data_dir = "/home/mei/nas/docker/thesis/data/hdf/test"

config = {  
    "data_dir": "/home/mei/nas/docker/thesis/data/hdf",
    "graph_dir": "/home/mei/nas/docker/thesis/data/graphs",
    "mode": "k_closest",
    "k": 3         
}

In [3]:
# === LSTM + Flat Dataset ===
lstm_dataset_train = MultiModalDataset(train_data_dir)
lstm_dataset_val = MultiModalDataset(val_data_dir)
lstm_dataset_test = MultiModalDataset(test_data_dir)

lstm_loader_train = DataLoader(lstm_dataset_train , batch_size=32, shuffle=True, collate_fn=collate_fn)
lstm_loader_val = DataLoader(lstm_dataset_val , batch_size=32, shuffle=False, collate_fn=collate_fn)
lstm_loader_test = DataLoader(lstm_dataset_test , batch_size=32, shuffle=False, collate_fn=collate_fn)

# === Graph Dataset ===

graph_dataset = GraphDataset(config)


==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5


In [4]:
def debug_patient_data(dataset, patient_id_debug):
    idx_debug = None
    for i in range(len(dataset)):
        if dataset.patient_ids[i] == patient_id_debug:
            idx_debug = i
            break
    if idx_debug is None:
        print(f"Patient {patient_id_debug} not found in dataset.")
        return

    pid, ts_data, flat_data, risk_data = dataset[idx_debug]
    print(f"Raw data for patient {patient_id_debug}:")
    print(f"  ts_data shape: {ts_data.shape}")
    print(f"  flat_data shape: {flat_data.shape}")
    print(f"  risk_data shape: {risk_data.shape}")
    
    return pid, ts_data, flat_data, risk_data


In [5]:
patient_id_debug = '1788546'
pid, ts_data, flat_data, risk_data = debug_patient_data(lstm_dataset_train, patient_id_debug)


Raw data for patient 1788546:
  ts_data shape: torch.Size([422, 162])
  flat_data shape: torch.Size([104])
  risk_data shape: torch.Size([422])


In [6]:
patient_id_debug = '3132351'
pid2, ts_data2, flat_data2, risk_data2 = debug_patient_data(lstm_dataset_train, patient_id_debug)

Raw data for patient 3132351:
  ts_data shape: torch.Size([1759, 162])
  flat_data shape: torch.Size([104])
  risk_data shape: torch.Size([1759])


In [7]:
from torch.utils.data import DataLoader

debug_batch = [(pid, ts_data, flat_data, risk_data),(pid2,ts_data2,flat_data2,risk_data2)]  # 只包含一个患者

debug_patient_ids, debug_padded_ts, debug_flat, debug_padded_risk, debug_lengths = collate_fn(debug_batch)

print("After collate_fn:")
print(f"  debug_patient_ids: {debug_patient_ids}")
print(f"  debug_padded_ts shape: {debug_padded_ts.shape}")
print(f"  debug_flat shape: {debug_flat.shape}")
print(f"  debug_padded_risk shape: {debug_padded_risk.shape}")
print(f"  debug_lengths: {debug_lengths}")


After collate_fn:
  debug_patient_ids: ['3132351', '1788546']
  debug_padded_ts shape: torch.Size([2, 1759, 162])
  debug_flat shape: torch.Size([2, 104])
  debug_padded_risk shape: torch.Size([2, 1759])
  debug_lengths: tensor([1759,  422])


In [8]:
flat_input_dim = 104
graph_input_dim = 104
ts_input_dim = 162
hidden_dim = 128

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   
model = PatientOutcomeModelEmbedding(flat_input_dim, graph_input_dim, ts_input_dim, hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss() 

In [10]:
train_loss = 0.0
debug_patient_ids_tensor = torch.tensor([int(pid) for pid in debug_patient_ids], dtype=torch.long)

debug_padded_ts = debug_padded_ts.to(device)
debug_flat = debug_flat.to(device)
debug_padded_risk = debug_padded_risk.to(device)
debug_lengths = debug_lengths.to(device)

with torch.no_grad():
    outputs, embedding = model(
        debug_flat,
        graph_dataset.graph_data,  # 包含 x, edge_index, patient_ids
        debug_patient_ids_tensor,
        debug_padded_ts,
        debug_lengths
        
    )

print("Model forward output:")
print("Model output shapes for debug patient:")
print("  risk_scores:", outputs.shape)
print("  embeddings:", embedding.shape)

graph encoder output: node_embeddings.shape = torch.Size([11698, 128])
batch graph embeddings shape = torch.Size([2, 128])
flat encoder output: flat_emb.shape = torch.Size([2, 128])
Time series Encoder input: x.shape = torch.Size([2, 1759, 162]), lengths.shape = torch.Size([2])
packed sequence batch_sizes = tensor([2, 2, 2,  ..., 1, 1, 1])
Time series Encoder output: out.shape = torch.Size([2, 1759, 256])
Time series encoder output: ts_emb.shape = torch.Size([2, 1759, 256])
risk predictor output: risk_scores.shape = torch.Size([2, 1759])
combimed_embeddings.shape = torch.Size([2, 1759, 128])
Model forward output:
Model output shapes for debug patient:
  risk_scores: torch.Size([2, 1759])
  embeddings: torch.Size([2, 1759, 128])


In [16]:
loss_list = []
batch_size = outputs.shape[0]
outputs_np = outputs.detach().cpu().numpy()
padded_risk_np = debug_padded_risk.detach().cpu().numpy()
lengths_np = debug_lengths.detach().cpu().numpy()

for i in range(batch_size):
    L = int(lengths_np[i])
    valid_output = outputs_np[i][:L]
    valid_target = padded_risk_np[i][:L]
    sample_loss = criterion(torch.tensor(valid_output), torch.tensor(valid_target))
    loss_list.append(sample_loss)
    print(f"Loss for patient {debug_patient_ids[i]}:", sample_loss.item())
    print(f"valid_output shape: {valid_output.shape}")
    print(f"valid_target shape: {valid_target.shape}")

loss = sum(loss_list) / len(loss_list)
print("Loss for debug patient:", loss.item())

Loss for patient 3132351: 0.04646523669362068
valid_output shape: (1759,)
valid_target shape: (1759,)
Loss for patient 1788546: 0.03306068480014801
valid_output shape: (422,)
valid_target shape: (422,)
Loss for debug patient: 0.0397629588842392


In [None]:
ts_mask = debug_padded_ts != -99
ts_mask = ts_mask.any(dim=2)
risk_mask =debug_padded_risk!= -99
combined_mask = ts_mask & risk_mask
masked_outputs = outputs[combined_mask]
masked_risk_data = debug_padded_risk[combined_mask]
loss_list = []
batch_size = outputs.size(0)

for i in range(batch_size):
    valid_time_steps = combined_mask[i]  # boolean mask, marks which time steps are valid
    valid_output = outputs[i][valid_time_steps]  # only take the valid time steps
    valid_target = debug_padded_risk[i][valid_time_steps]
    sample_loss = criterion(valid_output, valid_target)
    loss_list.append(sample_loss)
    print(f"Loss for patient {debug_patient_ids[i]}: {sample_loss.item()}")
    print(f"valid_output shape: {valid_output.shape}")
    print(f"valid_target shape: {valid_target.shape}")

loss = torch.mean(torch.stack(loss_list))
print("Loss for debug patient:", loss.item())

Loss for patient 3132351: 0.04646524414420128
valid_output shape: torch.Size([1759])
valid_target shape: torch.Size([1759])
Loss for patient 1788546: 0.03306068480014801
valid_output shape: torch.Size([422])
valid_target shape: torch.Size([422])
Loss for debug patient: 0.039762966334819794
