In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed
import pandas as pd
from print_hg import summarize_hypergraph

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [3]:
# Experiment configuration. Later cells add derived entries
# (max_adm_len, age_gender_vocab_size, global_vocab_size, label_vocab_size).
args = dict(
    # --- data / task ---
    dataset="MIMIC-III",
    task="death",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med"],
    # --- optimization ---
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=10,
    # --- model hyperparameters ---
    level="visit",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=48,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Root of the preprocessed pickles for the selected dataset. This absolute path is
# machine-specific; change it here (the only place it appears) when running elsewhere.
data_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"

# Full EHR table (all patients), used for the vocabulary and for transductive context.
full_data_path = f"{data_dir}/mimic.pkl"

# Next-diagnosis tasks have their own train/val/test splits; every other task
# (death, stay, readmission, ...) shares the downstream splits.
finetune_file = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}.get(args["task"], "mimic_downstream.pkl")
finetune_data_path = f"{data_dir}/{finetune_file}"

In [5]:
# NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(full_data_path, "rb") as full_data_file:  # close the handle deterministically
    ehr_full_data = pickle.load(full_data_file)

# One token "sentence" per row of the full table for each code type.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
# Lab and procedure codes are not fed to the tokenizer here: a single empty sentence each.
lab_sentences = [[]]
pro_sentences = [[]]
# Age/gender vocabulary: "[PAD]" first (its own PAD token), then "<age>_M" / "<age>_F"
# for every distinct AGE value.
# NOTE(review): token ids follow `set` iteration order; if AGE is not an int column
# (e.g. strings) that order can change between processes. Consider sorted() — this
# would change token ids relative to earlier runs.
age_gender_sentences = ["[PAD]"] + [
    str(c) + "_" + gender
    for c in set(ehr_full_data["AGE"].values.tolist())
    for gender in ["M", "F"]
]
# Largest number of distinct admissions of any single patient, stored for the model config.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the joint vocabulary over all token types and record the resulting sizes in args.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)
args.update(
    age_gender_vocab_size=tokenizer.token_number("age_gender"),
    global_vocab_size=len(tokenizer.vocab.id2word),
    label_vocab_size=len(PHENO_ORDER),
)
print(f"Age and gender vocabulary size: {args['age_gender_vocab_size']}")
print(f"Global vocabulary size: {args['global_vocab_size']}")
print(f"Label vocabulary size: {args['label_vocab_size']}")

Age and gender vocabulary size: 37
Global vocabulary size: 2145
Label vocabulary size: 18


In [7]:
# NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(finetune_data_path, "rb") as finetune_file_handle:  # close the handle deterministically
    train_data, val_data, test_data = pickle.load(finetune_file_handle)

# EXP_FLAG = 1 marks rows that belong to the experiment's train/val/test splits.
for split_df in (train_data, val_data, test_data):
    split_df["EXP_FLAG"] = 1

# Transductive learning: patients that appear in none of the splits are added to every
# split's graph (EXP_FLAG = 0); otherwise the hypergraph is too sparse.
all_exp_ids = pd.concat([
    train_data["SUBJECT_ID"],
    val_data["SUBJECT_ID"],
    test_data["SUBJECT_ID"],
]).unique()

non_exp_data = ehr_full_data[~ehr_full_data["SUBJECT_ID"].isin(all_exp_ids)].copy(deep=True)
non_exp_data["EXP_FLAG"] = 0

# Each split = its own experiment rows + the shared non-experiment context rows.
train_data_full = pd.concat([train_data, non_exp_data], axis=0)
val_data_full = pd.concat([val_data, non_exp_data], axis=0)
test_data_full = pd.concat([test_data, non_exp_data], axis=0)

In [8]:
# Each dataset yields input_ids (a patient can have multiple visits) and labels.
# All three splits share the same token types, task and aggregation level.
dataset_kwargs = dict(
    token_type=args["predicted_token_type"],
    task=args["task"],
    level=args["level"],
)
train_dataset, val_dataset, test_dataset = (
    FinetuneHGDataset(split_frame, tokenizer, **dataset_kwargs)
    for split_frame in (train_data_full, val_data_full, test_data_full)
)
print(len(train_dataset), len(val_dataset), len(test_dataset))

27362 30541 30535


In [9]:
# A record is a "long sequence" when it has at least this many admissions.
long_adm_seq_crite = 3


def collect_long_seq_indices(dataset, min_adms=long_adm_seq_crite):
    """Return positions of experiment records with at least `min_adms` admissions.

    Walks the first ``len(dataset)`` keys of ``dataset.records`` in insertion order.
    A position is kept when the record has ``len(records[key]) >= min_adms`` and its
    ``exp_flags[key] == True``, i.e. it belongs to an experiment split rather than the
    transductive context.

    Args:
        dataset: object exposing ``records`` (dict key -> sequence of admissions),
            ``exp_flags`` (dict key -> 0/1 flag) and ``__len__``.
        min_adms: minimum number of admissions for a record to count as long.

    Returns:
        list[int]: the selected positions, in ascending order.
    """
    # Materialize the keys once; indexing a fresh list(keys()) per iteration was O(n^2).
    # NOTE(review): keys were named `hadm_id`, but each maps to several admissions, so
    # they are presumably patient-level ids — verify against FinetuneHGDataset.
    record_keys = list(dataset.records.keys())
    selected = []
    for i in range(len(dataset)):
        key = record_keys[i]
        exp_flag = dataset.exp_flags[key]
        num_adms = len(dataset.records[key])
        if num_adms >= min_adms and exp_flag == True:  # noqa: E712 (flag is stored as 0/1)
            selected.append(i)
    return selected


val_long_seq_idx = collect_long_seq_indices(val_dataset)
test_long_seq_idx = collect_long_seq_indices(test_dataset)
print(len(val_long_seq_idx), len(test_long_seq_idx))

835 854


In [10]:
# Full-graph mode: each split is served as a single batch holding the whole hypergraph.
use_full_graph = True
train_batch_size, val_batch_size, test_batch_size = (
    len(split_ds) if use_full_graph else args["batch_size"]
    for split_ds in (train_dataset, val_dataset, test_dataset)
)

# Only the training loader is shuffled; evaluation order stays fixed.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=train_batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=val_batch_size,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=test_batch_size,
    shuffle=False,
)

In [11]:
# Examine the hypergraph and tensors produced by one training batch.
batch = next(iter(train_dataloader))
HG_sample, global_node_ids, last_visit_indices, exp_flags, labels = batch
print(HG_sample)
for _batch_tensor in (last_visit_indices, exp_flags, labels):
    print(_batch_tensor.shape)

Data(edge_index=[2, 588516], n_x=[1], num_hyperedges=[1], totedges=32434, norm=[588516])
torch.Size([27362])
torch.Size([27362])
torch.Size([3131])


In [12]:
# Every task is selected on PR-AUC and trained with BCE-with-logits; only the
# task type differs:
#   * death / stay / readmission -> "binary"
#   * anything else (e.g. next_diag_*) -> "l2r" (presumably multi-label over
#     PHENO_ORDER, since label_vocab_size = len(PHENO_ORDER) — confirm in train.py)
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits
task_type = "binary" if args["task"] in ("death", "stay", "readmission") else "l2r"

In [13]:
# Number of independent training runs (one per seed).
# NOTE(review): the last saved output of this cell lists 10 seeds, so it predates a
# change of this count — re-run the notebook before trusting downstream results.
N_SEEDS = 8
# A private generator yields exactly the seeds random.seed(42) would, without touching
# the global `random` state (set_random_seed reseeds everything per run anyway).
seed_rng = random.Random(42)
seeds = [seed_rng.randint(0, 2**32 - 1) for _ in range(N_SEEDS)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844]


In [None]:
# Best test metrics per seed: all patients, and the long-sequence subset.
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every seed; loss_fn is shared across runs.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    # Presumably returns the test metrics at the epoch selected by early stopping on
    # validation — confirm in train.train_with_early_stopping.
    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Use the metric configured together with the task instead of a hard-coded copy.
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:06<00:00,  6.05s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.46s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.61s/it]


Epoch: 001, Average Loss: 0.7069
Validation: {'precision': 0.2683906199802949, 'recall': 0.9846788450148223, 'f1': 0.4218099171173868, 'auc': 0.5322707686930724, 'prauc': 0.2729564731502202}
Test:       {'precision': 0.2864015394479051, 'recall': 0.9889258028738155, 'f1': 0.44416811389985345, 'auc': 0.5283754758482544, 'prauc': 0.28580978378396515}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.40s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.07s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.42s/it]


Epoch: 002, Average Loss: 0.7008
Validation: {'precision': 0.28586723768685607, 'recall': 0.9440188568005656, 'f1': 0.43884399040405847, 'auc': 0.5540362029319155, 'prauc': 0.2848938562047667}
Test:       {'precision': 0.3046524160224454, 'recall': 0.9390919158309021, 'f1': 0.46005696090035453, 'auc': 0.5478981183713002, 'prauc': 0.2969665044792387}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.26s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.68s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.35s/it]


Epoch: 003, Average Loss: 0.6947
Validation: {'precision': 0.2971718024476632, 'recall': 0.8296994696474385, 'f1': 0.43760683372209846, 'auc': 0.5647498211927272, 'prauc': 0.2912676831829576}
Test:       {'precision': 0.3109421208273956, 'recall': 0.8150609080796509, 'f1': 0.45015290119960555, 'auc': 0.5559521470598228, 'prauc': 0.3025905442375967}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.44s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.03s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.82s/it]


Epoch: 004, Average Loss: 0.6871
Validation: {'precision': 0.3102417821235799, 'recall': 0.6729522687055218, 'f1': 0.4246931901749902, 'auc': 0.5753277771397759, 'prauc': 0.29942442579251816}
Test:       {'precision': 0.3190233431705956, 'recall': 0.6583610188224898, 'f1': 0.42978492240396904, 'auc': 0.565780558555125, 'prauc': 0.31019668957088736}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.48s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.23s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.39s/it]


Epoch: 005, Average Loss: 0.6811
Validation: {'precision': 0.32398373983608136, 'recall': 0.469652327634239, 'f1': 0.38344959824574415, 'auc': 0.5872601846054954, 'prauc': 0.3097389821069248}
Test:       {'precision': 0.33449747768593524, 'recall': 0.47729789589990423, 'f1': 0.3933378915709, 'auc': 0.5775542175259879, 'prauc': 0.32028077903320507}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.60s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.59s/it]


Epoch: 006, Average Loss: 0.6734
Validation: {'precision': 0.36216216215889946, 'recall': 0.23688862698740787, 'f1': 0.28642678538405386, 'auc': 0.6008117256182439, 'prauc': 0.3224175847405204}
Test:       {'precision': 0.38593622240076914, 'recall': 0.2613510520472794, 'f1': 0.3116540064079978, 'auc': 0.5902966840928933, 'prauc': 0.3325338306861316}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.84s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.89s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.90s/it]


Epoch: 007, Average Loss: 0.6668
Validation: {'precision': 0.2990353697653043, 'recall': 0.05480259281051973, 'f1': 0.09262947945294031, 'auc': 0.6146252660712257, 'prauc': 0.33543052777301724}
Test:       {'precision': 0.31044776118476275, 'recall': 0.05758582502736664, 'f1': 0.09715086144156891, 'auc': 0.6037671720223139, 'prauc': 0.34595038287962954}


Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import numpy as np
def _mean_std_over_runs(runs, run_idx):
    """Return ({metric: mean}, {metric: population std}) over the runs at `run_idx`."""
    metric_names = runs[0].keys()
    means = {m: np.mean([runs[i][m] for i in run_idx]) for m in metric_names}
    stds = {m: np.std([runs[i][m] for i in run_idx], ddof=0) for m in metric_names}
    return means, stds


def _print_percent_block(title, means, stds):
    """Print `title`, then one "metric: mean ± std" line per metric, in percent (2 decimals)."""
    print(title)
    for m in means:
        print(f"{m}: {means[m] * 100:.2f} ± {stds[m] * 100:.2f}")


def topk_avg_performance_formatted(performances, long_seq_performances, k=5,
                                   rank_metrics=("f1", "auc", "prauc")):
    """Print mean ± std of metrics over the k best runs, chosen by average metric rank.

    Every run is ranked separately on each metric in `rank_metrics` (rank 1 = highest
    value). Runs are then ordered by their mean rank and the first `k` are kept. Both
    the all-patient metrics and the long-sequence metrics are averaged over that same
    subset and printed as percentages (population std, ddof=0).

    NOTE(review): runs are ranked on the same (test) metrics that are then reported,
    so the top-k average is optimistically biased; selecting on validation metrics
    would avoid this.

    Args:
        performances: list of per-run metric dicts (metric name -> value); every dict
            must contain all `rank_metrics`.
        long_seq_performances: list of per-run metric dicts for the long-sequence
            subset, aligned with `performances` (same run order and length).
        k: number of runs to keep; if it exceeds the number of runs, all are used.
        rank_metrics: metric names used to rank runs.

    Returns:
        tuple: (final_avg, final_std, final_long_seq_avg, final_long_seq_std), each a
        dict metric -> unscaled value.

    Raises:
        ValueError: if `performances` is empty, the two lists differ in length,
            or `k < 1`.
    """
    if not performances:
        raise ValueError("performances must contain at least one run")
    if len(long_seq_performances) != len(performances):
        raise ValueError(
            "long_seq_performances must have one entry per run in performances "
            f"({len(long_seq_performances)} != {len(performances)})"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    scores = {m: np.array([p[m] for p in performances]) for m in rank_metrics}

    # Rank of every run on each metric (higher value -> better rank). Stable sorts
    # make tie-breaking deterministic regardless of numpy's sorting backend.
    ranks = np.stack(
        [(-scores[m]).argsort(kind="stable").argsort(kind="stable") + 1 for m in rank_metrics],
        axis=1,
    )
    avg_ranks = ranks.mean(axis=1)

    # Keep the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    final_avg, final_std = _mean_std_over_runs(performances, topk_idx)
    final_long_seq_avg, final_long_seq_std = _mean_std_over_runs(long_seq_performances, topk_idx)

    _print_percent_block("Final Metrics:", final_avg, final_std)
    _print_percent_block("\nFinal Long Sequence Metrics:", final_long_seq_avg, final_long_seq_std)
    return final_avg, final_std, final_long_seq_avg, final_long_seq_std

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print mean ± std of one metric for every class across a list of per-run tables.

    The tables are stacked and grouped by their index (one group per class/disease).
    For each group, the mean and std of `col_name` are printed as percentages with
    two decimals, in the group order produced by ``groupby`` (sorted index).

    Args:
        dfs (list[pd.DataFrame]): one table per run, indexed by class, containing
            the column `col_name`.
        col_name (str): metric column to summarize (default: "prauc").
        ddof (int): delta degrees of freedom for the std. The default 1 (sample std)
            matches the previous behavior; pass 0 to match the population std used
            by topk_avg_performance_formatted.

    Returns:
        pd.DataFrame: indexed by class, with unscaled "mean" and "std" columns.

    Raises:
        ValueError: if `dfs` is empty.
    """
    if len(dfs) == 0:
        raise ValueError("dfs must contain at least one DataFrame")

    # Stack all runs so rows of the same class share an index label.
    all_values = pd.concat(dfs, axis=0)

    per_class = all_values.groupby(all_values.index)[col_name]
    # With a single run and ddof=1 the std is NaN (undefined), as with pandas' default.
    grouped = pd.DataFrame({"mean": per_class.mean(), "std": per_class.std(ddof=ddof)})

    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")
    return grouped

In [None]:
if task_type == "binary":
    # Binary tasks: each run's result is already a flat metric dict.
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)
else:
    # Other tasks: each run's result holds a "global" metric dict and a
    # "per_class" table; summarize both.
    (
        final_metrics_global,
        final_metrics_per_class,
        final_long_seq_metrics_global,
        final_long_seq_metrics_per_class,
    ) = (
        [run_result[part] for run_result in run_results]
        for run_results, part in (
            (final_metrics, "global"),
            (final_metrics, "per_class"),
            (final_long_seq_metrics, "global"),
            (final_long_seq_metrics, "per_class"),
        )
    )
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    for section_title, per_class_tables in (
        ("\nPer-class performance, all patients:", final_metrics_per_class),
        ("\nPer-class performance, long seq:", final_long_seq_metrics_per_class),
    ):
        print(section_title)
        print_per_class_performance(per_class_tables, col_name="prauc")