In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed
import pandas as pd
from print_hg import summarize_hypergraph

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Experiment configuration shared by the data pipeline, the model and the training loop.
# Later cells add derived entries (e.g. "max_adm_len" and the vocabulary sizes).
args = {
    "dataset": "MIMIC-III", 
    # options: death, stay, readmission (binary tasks), next_diag_6m, next_diag_12m
    "task": "readmission",
    "special_tokens":["[PAD]", "[CLS]"],
    "predicted_token_type": ["diag", "med"],
    # only used when the data loaders are not built over the full graph
    "batch_size": 256,
    "lr": 1e-3,
    "epochs": 500,
    "model_name": "HG",
    "early_stop_patience": 10,
    # model hyperparameters
    "level": "visit",  # "visit" or "patient"
    "hg_all_num_layers": 3,
    "hg_use_type_embed": True,
    "MLP_num_layers": 2,
    "hg_aggregate": "mean",
    "hg_dropout": 0.0,
    "normtype": "all_one",
    "add_self_loop": True,
    "hg_normalization": "ln",
    "hg_hidden_size": 48,
    "PMA": True,
    "hg_num_heads": 4,
}

In [4]:
# Every processed artefact of a dataset lives in the same directory.
processed_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"

# The next-diagnosis tasks have dedicated split files; every other task
# shares the generic downstream file.
finetune_file_by_task = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_file_name = finetune_file_by_task.get(args["task"], "mimic_downstream.pkl")
finetune_data_path = f"{processed_dir}/{finetune_file_name}"

In [5]:
# Load the full EHR table. The context manager guarantees the file handle is
# closed once the pickle has been read.
# NOTE: pickle can execute arbitrary code while loading - only load trusted files.
with open(full_data_path, 'rb') as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# Per-row code "sentences" used to build the diagnosis and medication vocabularies
# (presumably one list of codes per row - confirm against the preprocessing step).
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
# Lab and procedure vocabularies are left empty in this notebook.
lab_sentences = [[]]
pro_sentences = [[]]
# Age/gender vocabulary: a dedicated "[PAD]" token followed by one
# "<age>_<gender>" token for every distinct age combined with "M" and "F".
age_gender_sentences = ["[PAD]"] + [
    str(c) + "_" + gender
    for c in set(ehr_full_data["AGE"].values.tolist())
    for gender in ["M", "F"]
]
# Largest number of distinct admissions recorded for a single patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the tokenizer from the per-type sentences collected above.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)

# Record the vocabulary sizes in the shared configuration.
args.update({
    "age_gender_vocab_size": tokenizer.token_number("age_gender"),
    "global_vocab_size": len(tokenizer.vocab.id2word),
    "label_vocab_size": len(PHENO_ORDER),
})

print(f"Age and gender vocabulary size: {args['age_gender_vocab_size']}")
print(f"Global vocabulary size: {args['global_vocab_size']}")
print(f"Label vocabulary size: {args['label_vocab_size']}")

Age and gender vocabulary size: 37
Global vocabulary size: 2145
Label vocabulary size: 18


In [7]:
# Load the train/validation/test splits. The context manager guarantees the
# file handle is closed once the pickle has been read.
# NOTE: pickle can execute arbitrary code while loading - only load trusted files.
with open(finetune_data_path, 'rb') as finetune_file:
    train_data, val_data, test_data = pickle.load(finetune_file)

# Rows that belong to one of the experiment splits are flagged with 1.
for split_frame in (train_data, val_data, test_data):
    split_frame["EXP_FLAG"] = 1

# For transductive learning: patients outside all three splits are added to
# every graph, otherwise the graph is too sparse.
all_exp_ids = pd.concat([
    train_data["SUBJECT_ID"],
    val_data["SUBJECT_ID"],
    test_data["SUBJECT_ID"]
]).unique()

# Rows of patients that appear in no split are flagged with 0.
non_exp_data = ehr_full_data[~ehr_full_data["SUBJECT_ID"].isin(all_exp_ids)].copy(deep = True)
non_exp_data["EXP_FLAG"] = 0

# Each split is extended with the same non-experiment rows.
train_data_full = pd.concat([train_data, non_exp_data], axis = 0)
val_data_full = pd.concat([val_data, non_exp_data], axis = 0)
test_data_full = pd.concat([test_data, non_exp_data], axis = 0)

In [8]:
# Each dataset yields input_ids (a patient has multiple visits) and labels.
# The three splits share every constructor option except the data frame.
dataset_options = {
    "token_type": args["predicted_token_type"],
    "task": args["task"],
    "level": args["level"],
}
train_dataset = FinetuneHGDataset(train_data_full, tokenizer, **dataset_options)
val_dataset = FinetuneHGDataset(val_data_full, tokenizer, **dataset_options)
test_dataset = FinetuneHGDataset(test_data_full, tokenizer, **dataset_options)
print(len(train_dataset), len(val_dataset), len(test_dataset))

27362 30541 30535


In [9]:
# Minimum number of entries a record must hold to count as a "long sequence".
long_adm_seq_crite = 3


def _long_seq_indices(dataset, min_adms):
    """Return the positions of flagged records that hold at least ``min_adms`` entries.

    Args:
        dataset: object exposing ``records`` and ``exp_flags`` mappings keyed by
            the same ids, and supporting ``len()``.
        min_adms: minimum ``len(dataset.records[key])`` for a record to qualify.

    Returns:
        list[int]: positions ``i`` (in ``dataset.records`` key order) whose
        record is long enough and whose experiment flag equals True.
    """
    # Materialise the key list once; rebuilding it for every index makes the
    # scan quadratic in the number of records.
    record_ids = list(dataset.records.keys())
    selected = []
    for i in range(len(dataset)):
        record_id = record_ids[i]
        exp_flag = dataset.exp_flags[record_id]
        num_adms = len(dataset.records[record_id])
        # The comparison is kept as "== True" so 0/1 flags behave exactly as before.
        if num_adms >= min_adms and exp_flag == True:
            selected.append(i)
    return selected


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

835 854


In [10]:
# With use_full_graph every split is served as one single batch (batch size
# equal to the dataset length); otherwise the configured batch size is used.
use_full_graph = True
if use_full_graph:
    train_batch_size = len(train_dataset)
    val_batch_size = len(val_dataset)
    test_batch_size = len(test_dataset)
else:
    train_batch_size = val_batch_size = test_batch_size = args["batch_size"]

# Only the training loader shuffles; each loader gets its own collate function.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=train_batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=val_batch_size,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=test_batch_size,
    shuffle=False,
)

In [11]:
# Examine the hypergraph properties of the first training batch.
batch = next(iter(train_dataloader))
(HG_sample, global_node_ids, last_visit_indices, exp_flags, labels) = batch
print(HG_sample)
# Shapes of the per-record index tensor, the experiment flags and the labels.
for batch_tensor in (last_visit_indices, exp_flags, labels):
    print(batch_tensor.shape)

Data(edge_index=[2, 588516], n_x=[1], num_hyperedges=[1], totedges=32434, norm=[588516])
torch.Size([27362])
torch.Size([27362])
torch.Size([3131])


In [12]:
# PR-AUC is the selection metric for every task.
eval_metric = "prauc"

# death / stay / readmission are binary tasks; everything else is handled as "l2r".
is_binary_task = args["task"] in ["death", "stay", "readmission"]
task_type = "binary" if is_binary_task else "l2r"

# Both task types train with BCE-with-logits.
if is_binary_task:
    loss_fn = F.binary_cross_entropy_with_logits
else:
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [13]:
# Draw a fixed, reproducible list of 32-bit seeds, one per training run.
# NOTE(review): the stored cell output lists ten seeds while this cell draws
# eight, so that output comes from an earlier version of the cell.
random.seed(42)
seeds = []
for _ in range(8):
    seeds.append(random.randint(0, 2**32 - 1))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844]


In [14]:
# Best test metrics of every run: all patients, and the long-sequence subset.
final_metrics, final_long_seq_metrics = [], []

# One independent training run per seed.
for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every run so no state carries over.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Pass the metric chosen together with task_type and loss_fn instead of
        # repeating the literal, so the two can never drift apart.
        eval_metric=eval_metric)

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:06<00:00,  6.32s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.56s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.22s/it]


Epoch: 001, Average Loss: 0.7007
Validation: {'precision': 0.41043307086546843, 'recall': 0.9889328063202019, 'f1': 0.580106650154109, 'auc': 0.5754595645899994, 'prauc': 0.4336762840808578}
Test:       {'precision': 0.4084069349028322, 'recall': 0.9865665744725937, 'f1': 0.577674952479871, 'auc': 0.5652940903587982, 'prauc': 0.4353364135708593}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.58s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.31s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.82s/it]


Epoch: 002, Average Loss: 0.6965
Validation: {'precision': 0.4244847192600844, 'recall': 0.9442687746998251, 'f1': 0.5856827611032867, 'auc': 0.590342451429408, 'prauc': 0.456216235867016}
Test:       {'precision': 0.418405564472234, 'recall': 0.926906361118424, 'f1': 0.576554431692283, 'auc': 0.5806313402125334, 'prauc': 0.460390725481577}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.57s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.02s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.50s/it]


Epoch: 003, Average Loss: 0.6926
Validation: {'precision': 0.4370642912462489, 'recall': 0.892094861656553, 'f1': 0.5866909235803154, 'auc': 0.6041935399544096, 'prauc': 0.4739707031920536}
Test:       {'precision': 0.43417203042630353, 'recall': 0.8794942710356402, 'f1': 0.581352829211085, 'auc': 0.5924335221781581, 'prauc': 0.4741408860884402}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.67s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.08s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.02s/it]


Epoch: 004, Average Loss: 0.6892
Validation: {'precision': 0.45055889939713983, 'recall': 0.8284584980204409, 'f1': 0.5836814212215493, 'auc': 0.611783988957902, 'prauc': 0.47917523209841845}
Test:       {'precision': 0.4493752692795145, 'recall': 0.824180165939059, 'f1': 0.5816255356517056, 'auc': 0.5986543955403566, 'prauc': 0.47865276333116963}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.46s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.40s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.25s/it]


Epoch: 005, Average Loss: 0.6864
Validation: {'precision': 0.46032493524732676, 'recall': 0.7727272727242185, 'f1': 0.5769514487647235, 'auc': 0.6165060543321412, 'prauc': 0.48346707444528286}
Test:       {'precision': 0.4525087514575204, 'recall': 0.7661003555876488, 'f1': 0.5689553943904717, 'auc': 0.60299935190073, 'prauc': 0.48171162311129545}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.00s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.89s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.60s/it]


Epoch: 006, Average Loss: 0.6815
Validation: {'precision': 0.4675057796032173, 'recall': 0.719367588929963, 'f1': 0.5667133690362527, 'auc': 0.6196689461906852, 'prauc': 0.4871507203930653}
Test:       {'precision': 0.4604722792595984, 'recall': 0.7088107467376183, 'f1': 0.5582697945080698, 'auc': 0.6058978394910792, 'prauc': 0.4852302613330282}


Training Batches: 100%|██████████| 1/1 [00:05<00:00,  5.16s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.07s/it]
Running inference:   0%|          | 0/1 [00:01<?, ?it/s]

KeyboardInterrupt



In [None]:
import numpy as np
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    """Print mean ± std over the k best runs, for all patients and for long sequences.

    Runs are ranked separately on "f1", "auc" and "prauc" (rank 1 = highest
    score) and the k runs with the lowest average rank are selected. The same
    run indices are then used to summarise ``long_seq_performances``.
    Values are printed as percentages with two decimals; the standard
    deviation is the population one (``ddof=0``).

    Args:
        performances: per-run metric dicts; each must contain the keys "f1",
            "auc" and "prauc". Every key of the first dict is reported.
        long_seq_performances: per-run metric dicts, indexed in parallel with
            ``performances``. Every key of the first dict is reported.
        k: number of top-ranked runs to average; if fewer runs exist, all of
            them are used.

    Raises:
        ValueError: if either list is empty (e.g. training was interrupted
            before any run finished).
    """
    if len(performances) == 0 or len(long_seq_performances) == 0:
        raise ValueError(
            "topk_avg_performance_formatted needs at least one completed run; "
            f"got {len(performances)} overall and "
            f"{len(long_seq_performances)} long-sequence result(s)."
        )

    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # Rank the runs per metric (larger value -> better rank, starting at 1),
    # then average the three ranks of every run.
    ranks = {m: (-scores[m]).argsort().argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Keep the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks)[:k]

    def _mean_std(runs):
        """Return (metric, mean, std) over the selected runs for every key of runs[0]."""
        summary = []
        for m in runs[0].keys():
            values = [runs[i][m] for i in topk_idx]
            summary.append((m, np.mean(values), np.std(values, ddof=0)))
        return summary

    # Compute both summaries before printing so a malformed input fails
    # without emitting a partial report.
    overall_summary = _mean_std(performances)
    long_seq_summary = _mean_std(long_seq_performances)

    # Print the results as percentages, rounded to two decimals.
    print("Final Metrics:")
    for m, mean_val, std_val in overall_summary:
        print(f"{m}: {mean_val * 100:.2f} ± {std_val * 100:.2f}")
    print("\nFinal Long Sequence Metrics:")
    for m, mean_val, std_val in long_seq_summary:
        print(f"{m}: {mean_val * 100:.2f} ± {std_val * 100:.2f}")

In [None]:
def print_per_class_performance(dfs, col_name="prauc"):
    """Print mean ± std of one metric column per index label across several tables.

    The tables are stacked row-wise, rows are grouped by their index label
    (one label per disease), and the mean and standard deviation of
    ``col_name`` are printed for every label as percentages with two decimals.

    Args:
        dfs (list[pd.DataFrame]): tables to combine, typically one per run.
        col_name (str): name of the metric column to summarise (default: "prauc").
    """
    # Stack all tables on top of each other.
    stacked = pd.concat(dfs, axis=0)

    # Mean and std of the chosen column for every index label.
    # NOTE(review): pandas' "std" aggregation is the sample standard deviation
    # (ddof=1), so a label seen only once yields NaN - confirm this is intended.
    per_label = stacked[col_name].groupby(stacked.index).agg(["mean", "std"])

    # One line per label, values converted to percentages.
    for disease, label_mean, label_std in per_label.itertuples(name=None):
        print(f"{disease}: {label_mean * 100:.2f} ± {label_std * 100:.2f}")

In [None]:
# Binary tasks store one flat metric dict per run. Other tasks store a
# "global" entry plus a "per_class" table for every run.
if task_type != "binary":
    final_metrics_global = [run_metrics["global"] for run_metrics in final_metrics]
    final_metrics_per_class = [run_metrics["per_class"] for run_metrics in final_metrics]
    final_long_seq_metrics_global = [
        run_metrics["global"] for run_metrics in final_long_seq_metrics
    ]
    final_long_seq_metrics_per_class = [
        run_metrics["per_class"] for run_metrics in final_long_seq_metrics
    ]

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    # Per-class PR-AUC, first over all patients and then over long sequences only.
    per_class_reports = (
        ("\nPer-class performance, all patients:", final_metrics_per_class),
        ("\nPer-class performance, long seq:", final_long_seq_metrics_per_class),
    )
    for report_title, per_class_tables in per_class_reports:
        print(report_title)
        print_per_class_performance(per_class_tables, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)