In [1]:
import pickle
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from model import SetGNN
from set_seed import set_random_seed
from tokenizer import EHRTokenizer
from train import PHENO_ORDER, train_with_early_stopping

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Experiment configuration. Later cells add derived entries to this dict
# (max_adm_len and the vocabulary sizes).
args = dict(
    dataset="MIMIC-IV",
    # one of: death, stay, readmission, next_diag_6m, next_diag_12m
    task="next_diag_12m",
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med", "lab", "pro"],
    # optimisation
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=10,
    # model hyperparameters
    level="patient",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=128,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Directory holding the preprocessed pickles for the selected dataset.
# NOTE(review): absolute, user-specific path; consider making it configurable.
processed_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"

# Fine-tuning split file per task. The binary tasks share one downstream file.
# An unknown task now fails loudly instead of silently loading the binary-task data.
finetune_files = {
    "death": "mimic_downstream.pkl",
    "stay": "mimic_downstream.pkl",
    "readmission": "mimic_downstream.pkl",
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
if args["task"] not in finetune_files:
    raise ValueError(
        f"Unknown task {args['task']!r}; expected one of {sorted(finetune_files)}"
    )
finetune_data_path = f"{processed_dir}/{finetune_files[args['task']]}"

In [5]:
# Load the full preprocessed EHR table. `with` closes the file handle that
# the bare `open(...)` inside pickle.load used to leak.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(full_data_path, "rb") as f:
    ehr_full_data = pickle.load(f)

# Token "sentences" per code type, used to build the tokenizer vocabularies.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Age/gender vocabulary: "[PAD]" first (special to this vocabulary), then
# "<age>_M" and "<age>_F" for every distinct age.
# NOTE(review): the token order follows set iteration order. That order is only
# stable across runs if AGE holds non-string values (str hashes are salted per
# process). Sorting would change token ids, so it is left unchanged here.
age_gender_sentences = ["[PAD]"] + [
    str(c) + "_" + gender
    for c in set(ehr_full_data["AGE"].values.tolist())
    for gender in ["M", "F"]
]

# Longest admission history of any patient. Cast to a plain int so `args` holds
# no numpy scalars (e.g. for JSON serialisation).
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build one tokenizer over every code vocabulary and record the resulting
# vocabulary sizes in the config for the model.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)
args.update(
    age_gender_vocab_size=tokenizer.token_number("age_gender"),
    global_vocab_size=len(tokenizer.vocab.id2word),
    label_vocab_size=len(PHENO_ORDER),
)
for caption, key in (
    ("Age and gender vocabulary size", "age_gender_vocab_size"),
    ("Global vocabulary size", "global_vocab_size"),
    ("Label vocabulary size", "label_vocab_size"),
):
    print(f"{caption}: {args[key]}")

Age and gender vocabulary size: 41
Global vocabulary size: 4207
Label vocabulary size: 18


In [7]:
# Load the (train, val, test) splits for the selected task. `with` closes the
# file handle that the bare `open(...)` used to leak.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(finetune_data_path, "rb") as f:
    train_data, val_data, test_data = pickle.load(f)


def make_finetune_dataset(split_data):
    """Wrap one data split in a FinetuneHGDataset with the shared tokenizer and task settings."""
    return FinetuneHGDataset(
        split_data,
        tokenizer,
        token_type=args["predicted_token_type"],
        task=args["task"],
        level=args["level"],
    )


# Each item: input_ids (a patient has multiple visits) and labels.
train_dataset = make_finetune_dataset(train_data)
val_dataset = make_finetune_dataset(val_data)
test_dataset = make_finetune_dataset(test_data)
print(len(train_dataset), len(val_dataset), len(test_dataset))

6850 5046 5062


In [8]:
def long_sequence_indices(dataset, min_admissions):
    """Return the positions i in `dataset` whose record has >= `min_admissions` entries.

    Position i is matched to the i-th key of `dataset.records` (dict insertion
    order), as the original loop did. This presumably mirrors how the dataset
    indexes items; TODO confirm against FinetuneHGDataset.
    """
    # Materialise the key list once; the original rebuilt it on every
    # iteration, which made the loop quadratic in the dataset size.
    record_keys = list(dataset.records.keys())
    return [
        i
        for i in range(len(dataset))
        if len(dataset.records[record_keys[i]]) >= min_admissions
    ]


# Patients with at least this many admissions form the "long sequence" subset.
long_adm_seq_crite = 3
val_long_seq_idx = long_sequence_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = long_sequence_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

857 876


In [9]:
# With use_full_graph, every split is served as one batch containing the whole
# split; otherwise the configured batch_size is used.
use_full_graph = True
splits = (train_dataset, val_dataset, test_dataset)

train_batch_size, val_batch_size, test_batch_size = (
    len(split) if use_full_graph else args["batch_size"] for split in splits
)

# Only the training loader shuffles; evaluation order stays fixed.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(
        split,
        collate_fn=batcher_SetGNN_finetune(device=device),
        batch_size=size,
        shuffle=do_shuffle,
    )
    for split, size, do_shuffle in zip(
        splits,
        (train_batch_size, val_batch_size, test_batch_size),
        (True, False, False),
    )
)

In [10]:
# Map the task onto the setup expected by train_with_early_stopping.
# Binary tasks -> task_type "binary". Next-diagnosis tasks are scored per label
# (see the "per_class" results below) -> task_type "l2r".
BINARY_TASKS = ("death", "stay", "readmission")
NEXT_DIAG_TASKS = ("next_diag_6m", "next_diag_12m")

if args["task"] in BINARY_TASKS:
    task_type = "binary"
elif args["task"] in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    # Previously any unrecognised task was silently treated as "l2r".
    raise ValueError(
        f"Unknown task {args['task']!r}; expected one of {BINARY_TASKS + NEXT_DIAG_TASKS}"
    )

# Same evaluation metric for every task (it was assigned identically in both branches).
eval_metric = "prauc"
# Both task types use element-wise BCE on logits. The former lambda wrapper
# around this function added nothing.
loss_fn = F.binary_cross_entropy_with_logits

In [11]:
# Derive 15 reproducible 32-bit seeds, one per training run, from a fixed master seed.
random.seed(42)
seeds = []
for _run in range(15):
    seeds.append(random.randint(0, 0xFFFFFFFF))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844, 3460967357, 2998485882, 1461364854, 667779376, 1445662585]


In [None]:
# Train one fresh model per seed and collect the test metrics returned by
# train_with_early_stopping, for all patients and for the long-sequence subset.
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # A new model and optimizer per seed, so runs share no state.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Use the metric configured alongside the task; it was hard-coded to "prauc".
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 001, Average Loss: 0.6892
Validation: {'precision': 0.11760698198360703, 'recall': 0.18848865371940005, 'f1': 0.09564478904182769, 'auc': 0.5356592983214251, 'prauc': 0.19357521604115235}
Test:       {'precision': 0.10947074532390982, 'recall': 0.18784223596121932, 'f1': 0.09331952433306151, 'auc': 0.5369945186671163, 'prauc': 0.19617338036109344}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 002, Average Loss: 0.6796
Validation: {'precision': 0.14072426929849532, 'recall': 0.10796984170098531, 'f1': 0.07465882485041112, 'auc': 0.5505306341708036, 'prauc': 0.20441414682307515}
Test:       {'precision': 0.1192145411472168, 'recall': 0.10519121431265692, 'f1': 0.07385940242854719, 'auc': 0.5540175667618117, 'prauc': 0.21105374312857006}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.87s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 003, Average Loss: 0.6705
Validation: {'precision': 0.10685244626077331, 'recall': 0.0656404309739096, 'f1': 0.06287167201861112, 'auc': 0.5547758289838539, 'prauc': 0.21095136466634362}
Test:       {'precision': 0.1160454594130049, 'recall': 0.06734835627707456, 'f1': 0.0648236090598583, 'auc': 0.556981580526336, 'prauc': 0.21864330764332526}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.85s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]


Epoch: 004, Average Loss: 0.6602
Validation: {'precision': 0.08819728465109786, 'recall': 0.04469012085583524, 'f1': 0.05256911058362132, 'auc': 0.5548787048136268, 'prauc': 0.21515641033044627}
Test:       {'precision': 0.09656243689664008, 'recall': 0.048210928080093945, 'f1': 0.05627590577863902, 'auc': 0.5564413564442686, 'prauc': 0.22246346971533132}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.85s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


Epoch: 005, Average Loss: 0.6483
Validation: {'precision': 0.08850183751661583, 'recall': 0.022764975210427158, 'f1': 0.03326311769935336, 'auc': 0.5534051705869376, 'prauc': 0.21622898492682746}
Test:       {'precision': 0.1607682410385047, 'recall': 0.027913289325287416, 'f1': 0.04086816387017185, 'auc': 0.5543100795883658, 'prauc': 0.22389201827862357}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


Epoch: 006, Average Loss: 0.6342
Validation: {'precision': 0.08258198280453473, 'recall': 0.01423570152569198, 'f1': 0.01833420150354513, 'auc': 0.5492809576688805, 'prauc': 0.21329049433165156}
Test:       {'precision': 0.12259300229432973, 'recall': 0.016037447467044965, 'f1': 0.021305108844034314, 'auc': 0.5508362220297583, 'prauc': 0.2206106603791538}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.87s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 007, Average Loss: 0.6184
Validation: {'precision': 0.02374301675977654, 'recall': 0.014617368873602753, 'f1': 0.018094731240021287, 'auc': 0.545269350392295, 'prauc': 0.20889592838993687}
Test:       {'precision': 0.047214569941842664, 'recall': 0.0164780479561835, 'f1': 0.020976387306656456, 'auc': 0.5467526377806794, 'prauc': 0.21611759559718563}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.86s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


Epoch: 008, Average Loss: 0.6002
Validation: {'precision': 0.024205748865355523, 'recall': 0.013757523645743765, 'f1': 0.017543859649122806, 'auc': 0.5416244784119602, 'prauc': 0.20529171474114213}
Test:       {'precision': 0.02992075044476791, 'recall': 0.016223800754187494, 'f1': 0.02103946320937109, 'auc': 0.5430444522176298, 'prauc': 0.2116226431223013}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.87s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.14s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.11s/it]


Epoch: 009, Average Loss: 0.5800
Validation: {'precision': 0.027583527583527584, 'recall': 0.010174835196331326, 'f1': 0.014865996649916249, 'auc': 0.5385403607333139, 'prauc': 0.20141779962237652}
Test:       {'precision': 0.03260357815442561, 'recall': 0.012145926510567395, 'f1': 0.01769854961344323, 'auc': 0.5398805592614001, 'prauc': 0.20677422488608288}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.81s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


Epoch: 010, Average Loss: 0.5589
Validation: {'precision': 0.03835978835978836, 'recall': 0.002770612400878953, 'f1': 0.00516795865633075, 'auc': 0.535692720074612, 'prauc': 0.19656356637300362}
Test:       {'precision': 0.042105263157894736, 'recall': 0.0031570639305445935, 'f1': 0.005873715124816447, 'auc': 0.5369100575561626, 'prauc': 0.2011282973130574}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.87s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 011, Average Loss: 0.5364
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5344937242186473, 'prauc': 0.19361483994001494}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5354573199762948, 'prauc': 0.19767854108448202}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


Epoch: 012, Average Loss: 0.5142
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5329369977067552, 'prauc': 0.19070120582795508}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5337099335664335, 'prauc': 0.19408466192418852}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.86s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


Epoch: 013, Average Loss: 0.4934
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5321989118003434, 'prauc': 0.187111762489665}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5326847518505664, 'prauc': 0.1896349611198469}


Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5,
                                   rank_metrics=("f1", "auc", "prauc"), ddof=0):
    """Average the k best runs and print mean ± std (as percentages) for every metric.

    Each run is ranked per metric in ``rank_metrics`` (higher value = better
    rank) and the per-metric ranks are averaged. The k runs with the best
    average rank are kept, and the same runs are also summarised from
    ``long_seq_performances``, so the two lists must be aligned run-by-run.

    NOTE(review): in this notebook ``performances`` holds test-set metrics.
    Picking the top-k runs by them is selection on the test set and inflates
    the reported numbers; consider ranking runs by validation metrics instead.

    Args:
        performances: list of dicts (metric name -> value), one per run.
        long_seq_performances: list of dicts for the long-sequence subset,
            aligned with ``performances``.
        k: number of runs to keep (all runs if fewer are available).
        rank_metrics: metric names used for ranking.
        ddof: delta degrees of freedom for the std. The default 0 (population
            std) preserves the earlier output; note that
            print_per_class_performance uses the pandas default of 1.

    Returns:
        dict with "topk_idx" (selected run indices), "mean", "std",
        "long_seq_mean" and "long_seq_std" (metric -> value, as fractions).

    Raises:
        ValueError: if ``performances`` is empty or the two lists differ in length.
    """
    if not performances:
        raise ValueError("performances must contain at least one run")
    if len(long_seq_performances) != len(performances):
        raise ValueError(
            f"long_seq_performances has {len(long_seq_performances)} runs, "
            f"expected {len(performances)} (one per run in performances)"
        )

    # scores[run, j] = value of rank_metrics[j] for that run.
    scores = np.array([[p[m] for m in rank_metrics] for p in performances], dtype=float)
    # Double argsort per column turns sort order into ranks (1 = best).
    # Stable sorts make tie-breaking deterministic (the earlier run wins).
    ranks = (-scores).argsort(axis=0, kind="stable").argsort(axis=0, kind="stable") + 1
    avg_ranks = ranks.mean(axis=1)
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    def _summarise(runs):
        # Mean and std of every metric over the selected runs.
        keys = list(runs[0].keys())
        mean = {m: np.mean([runs[i][m] for i in topk_idx]) for m in keys}
        std = {m: np.std([runs[i][m] for i in topk_idx], ddof=ddof) for m in keys}
        return mean, std

    final_avg, final_std = _summarise(performances)
    final_long_seq_avg, final_long_seq_std = _summarise(long_seq_performances)

    # Print as percentages with two decimals.
    for title, mean, std in (
        ("Final Metrics:", final_avg, final_std),
        ("\nFinal Long Sequence Metrics:", final_long_seq_avg, final_long_seq_std),
    ):
        print(title)
        for m in mean:
            mean_val = mean[m] * 100
            std_val = std[m] * 100
            print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")

    return {
        "topk_idx": [int(i) for i in topk_idx],
        "mean": final_avg,
        "std": final_std,
        "long_seq_mean": final_long_seq_avg,
        "long_seq_std": final_long_seq_std,
    }

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print mean ± std (as percentages) of one metric per class (disease) across runs.

    The per-run tables are stacked and grouped by their index label, so each
    table must be indexed by class name.

    Args:
        dfs (list[pd.DataFrame]): one table per run, each containing ``col_name``.
        col_name (str): metric column to summarise (default: "prauc").
        ddof (int): delta degrees of freedom for the std. The default 1 (pandas
            sample std) preserves the earlier output; with a single table the
            std is then NaN.

    Returns:
        pd.DataFrame: indexed by class, with "mean" and "std" columns (fractions).
        Empty if ``dfs`` is empty.
    """
    # pd.concat raises on an empty list; report nothing instead.
    if len(dfs) == 0:
        return pd.DataFrame(columns=["mean", "std"])

    # Stack all runs, then group rows that share a class label.
    all_values = pd.concat(dfs, axis=0)
    by_class = all_values.groupby(all_values.index)[col_name]
    grouped = pd.DataFrame({"mean": by_class.mean(), "std": by_class.std(ddof=ddof)})

    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

    return grouped

In [None]:
if task_type != "binary":
    # Multi-label results carry a "global" summary dict plus a "per_class" table per run.
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    for heading, per_class_tables in (
        ("\nPer-class performance, all patients:", final_metrics_per_class),
        ("\nPer-class performance, long seq:", final_long_seq_metrics_per_class),
    ):
        print(heading)
        print_per_class_performance(per_class_tables, col_name="prauc")
else:
    # Binary tasks return flat metric dicts.
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)