In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed
import pandas as pd
from print_hg import summarize_hypergraph

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda


In [3]:
# Experiment configuration shared by the dataset, model and training code.
args = dict(
    # --- data / task selection ---
    dataset="MIMIC-IV",
    # options: death, stay, readmission, next_diag_6m, next_diag_12m
    task="next_diag_12m",
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med"],
    # --- optimisation ---
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=30,
    # --- model hyperparameters ---
    level="visit",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=48,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Every processed artefact for the selected dataset lives in one directory; keep that
# directory in a single variable so relocating the data needs exactly one edit.
data_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{data_dir}/mimic.pkl"

# The next-diagnosis tasks have dedicated split files; every other task uses the
# shared downstream split.
finetune_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_file = finetune_files.get(args["task"], "mimic_downstream.pkl")
finetune_data_path = f"{data_dir}/{finetune_file}"

In [5]:
# Load the full EHR table. The context manager closes the file handle deterministically
# (the previous pickle.load(open(...)) left it to the garbage collector).
# NOTE: pickle can execute arbitrary code while loading -- only use trusted files here.
with open(full_data_path, 'rb') as pkl_file:
    ehr_full_data = pickle.load(pkl_file)

# Per-row code lists that seed the tokenizer vocabularies.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
# No lab / procedure codes are supplied: a single empty sentence is passed for each.
lab_sentences = [[]]
pro_sentences = [[]]
# One "<age>_<gender>" token per distinct AGE value and gender, preceded by a
# PAD token that is special to the age_gender vocabulary.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender \
    for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]

# Largest number of distinct admissions (HADM_ID) recorded for any single patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the tokenizer from all code "sentences", then record the vocabulary sizes
# in the shared configuration.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)
args.update(
    age_gender_vocab_size=tokenizer.token_number("age_gender"),
    global_vocab_size=len(tokenizer.vocab.id2word),
    label_vocab_size=len(PHENO_ORDER),
)
print(f"Age and gender vocabulary size: {args['age_gender_vocab_size']}")
print(f"Global vocabulary size: {args['global_vocab_size']}")
print(f"Label vocabulary size: {args['label_vocab_size']}")

Age and gender vocabulary size: 41
Global vocabulary size: 2125
Label vocabulary size: 18


In [7]:
# Load the train/val/test splits; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
# NOTE: pickle can execute arbitrary code while loading -- only use trusted files here.
with open(finetune_data_path, 'rb') as pkl_file:
    train_data, val_data, test_data = pickle.load(pkl_file)

# Rows that belong to the experiment splits are flagged with EXP_FLAG = 1.
for split_df in (train_data, val_data, test_data):
    split_df["EXP_FLAG"] = 1

# For transductive learning, otherwise the graph is too sparse: collect every
# subject that appears in any split ...
all_exp_ids = pd.concat([
    train_data["SUBJECT_ID"],
    val_data["SUBJECT_ID"],
    test_data["SUBJECT_ID"]
]).unique()

# ... and take all remaining subjects from the full table as extra rows flagged
# with EXP_FLAG = 0.
non_exp_data = ehr_full_data[~ehr_full_data["SUBJECT_ID"].isin(all_exp_ids)].copy(deep = True)
non_exp_data["EXP_FLAG"] = 0

# Each split is extended with the same non-experiment rows.
train_data_full = pd.concat([train_data, non_exp_data], axis = 0)
val_data_full = pd.concat([val_data, non_exp_data], axis = 0)
test_data_full = pd.concat([test_data, non_exp_data], axis = 0)

In [8]:
# output: input_ids (a patient has multiple visits), labels
train_dataset = FinetuneHGDataset(train_data_full, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
val_dataset = FinetuneHGDataset(val_data_full, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
test_dataset = FinetuneHGDataset(test_data_full, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
print(len(train_dataset), len(val_dataset), len(test_dataset))

60335 58531 58547


In [9]:
# Minimum number of entries in a record for it to count as a "long" sequence.
long_adm_seq_crite = 3


def _long_seq_indices(dataset, min_adms):
    """Return positions of experiment records that have at least ``min_adms`` entries.

    Position ``i`` refers to the ``i``-th key of ``dataset.records`` (in key order),
    limited to the first ``len(dataset)`` keys. A position is kept when the record
    holds at least ``min_adms`` items and its entry in ``dataset.exp_flags`` equals True.

    The keys are walked once; the previous version rebuilt ``list(records.keys())``
    on every iteration, which made the scan quadratic in the number of records.
    """
    indices = []
    for i, record_id in zip(range(len(dataset)), dataset.records.keys()):
        exp_flag = dataset.exp_flags[record_id]
        num_adms = len(dataset.records[record_id])
        # `== True` is kept on purpose so integer flags (1 / 0) compare as before.
        if num_adms >= min_adms and exp_flag == True:
            indices.append(i)
    return indices


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

857 876


In [10]:
use_full_graph = True

# In full-graph mode each split is served as one batch covering the whole dataset;
# otherwise the configured mini-batch size is used.
train_batch_size, val_batch_size, test_batch_size = (
    len(split) if use_full_graph else args["batch_size"]
    for split in (train_dataset, val_dataset, test_dataset)
)

# Only the training loader shuffles; each loader gets its own collate function.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)

In [11]:
# exmain HG properties
batch = next(iter(train_dataloader))
HG_sample, global_node_ids, last_visit_indices, exp_flags, labels = batch
print(HG_sample)
print(last_visit_indices.shape)
print(exp_flags.shape)
print(labels.shape)

Data(edge_index=[2, 1000264], n_x=[1], num_hyperedges=[1], totedges=71696, norm=[1000264])
torch.Size([60335])
torch.Size([60335])
torch.Size([6850, 18])


In [12]:
# Every task is selected on PR-AUC and trained with binary cross-entropy on logits;
# only the task type handed to the training loop differs. The loss is referenced
# directly -- the former lambda wrapper only forwarded its two arguments.
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits
task_type = "binary" if args["task"] in ["death", "stay", "readmission"] else "l2r"

In [13]:
# Derive ten reproducible 32-bit run seeds from one fixed master seed.
random.seed(42)
max_seed = 2**32 - 1
seeds = []
for _ in range(10):
    seeds.append(random.randint(0, max_seed))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844]


In [14]:
# One entry per seed: best test metrics for all patients and for long sequences only.
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every seed so runs do not share weights or state.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Use the metric chosen alongside task_type / loss_fn rather than a
        # second hard-coded copy of its name.
        eval_metric=eval_metric)

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:22<00:00, 22.29s/it]
Running inference: 100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
Running inference: 100%|██████████| 1/1 [00:18<00:00, 18.93s/it]


Epoch: 001, Average Loss: 0.6863
Validation: {'precision': 0.10366406781673583, 'recall': 0.25914031279845684, 'f1': 0.08013712231521239, 'auc': 0.5135758866441147, 'prauc': 0.17692234681673347}
Test:       {'precision': 0.10538291738754278, 'recall': 0.2559599746974518, 'f1': 0.08145564303255233, 'auc': 0.5101969502745288, 'prauc': 0.17453400461474822}


Training Batches: 100%|██████████| 1/1 [00:20<00:00, 20.48s/it]
Running inference: 100%|██████████| 1/1 [00:18<00:00, 18.73s/it]
Running inference: 100%|██████████| 1/1 [00:19<00:00, 19.73s/it]


Epoch: 002, Average Loss: 0.6814
Validation: {'precision': 0.11327106874166643, 'recall': 0.21769092849662106, 'f1': 0.06545164063178846, 'auc': 0.5176031042330407, 'prauc': 0.18450901223280936}
Test:       {'precision': 0.11088718613537132, 'recall': 0.21364600454773985, 'f1': 0.06622512590128055, 'auc': 0.5117680349809688, 'prauc': 0.1811991959506321}


Training Batches: 100%|██████████| 1/1 [00:20<00:00, 20.72s/it]
Running inference: 100%|██████████| 1/1 [00:19<00:00, 19.05s/it]
Running inference: 100%|██████████| 1/1 [00:19<00:00, 19.53s/it]


Epoch: 003, Average Loss: 0.6764
Validation: {'precision': 0.11349610777347152, 'recall': 0.16492422562642553, 'f1': 0.04946731454432464, 'auc': 0.5199743074582769, 'prauc': 0.1896784193022142}
Test:       {'precision': 0.09756450773079346, 'recall': 0.16634903999977474, 'f1': 0.05000483517843063, 'auc': 0.5122372407805297, 'prauc': 0.18626388281370765}


Training Batches: 100%|██████████| 1/1 [00:19<00:00, 19.72s/it]
Running inference: 100%|██████████| 1/1 [00:18<00:00, 18.81s/it]
Running inference: 100%|██████████| 1/1 [00:19<00:00, 19.78s/it]


Epoch: 004, Average Loss: 0.6721
Validation: {'precision': 0.11279357473671163, 'recall': 0.10785019978420102, 'f1': 0.03809578012562328, 'auc': 0.5204927639394796, 'prauc': 0.19270994197571056}
Test:       {'precision': 0.08698537613604333, 'recall': 0.10299703158040963, 'f1': 0.036749342466333816, 'auc': 0.5121813552680774, 'prauc': 0.1890954875172981}


Training Batches:   0%|          | 0/1 [00:16<?, ?it/s]

KeyboardInterrupt



In [None]:
import numpy as np
def _mean_std_over(records, indices):
    """Return ``(mean, std)`` dicts over ``records[i]`` for ``i`` in ``indices``, keyed like ``records[0]``."""
    keys = list(records[0].keys())
    mean = {m: np.mean([records[i][m] for i in indices]) for m in keys}
    std = {m: np.std([records[i][m] for i in indices], ddof=0) for m in keys}
    return mean, std


def _print_percent_table(mean, std):
    """Print one ``metric: mean ± std`` line per metric, as percentages with two decimals."""
    for m in mean:
        print(f"{m}: {mean[m] * 100:.2f} ± {std[m] * 100:.2f}")


def topk_avg_performance_formatted(performances, long_seq_performances, k=2):
    """Print mean ± std of the ``k`` best runs, for all patients and for long sequences.

    Runs are ranked separately on "f1", "auc" and "prauc" (larger is better); the
    ``k`` runs with the best average rank are selected from ``performances`` and the
    same run positions are used for ``long_seq_performances``.

    Args:
        performances: sequence of metric dicts, one per run; each must contain
            "f1", "auc" and "prauc".
        long_seq_performances: sequence of metric dicts for the same runs, indexed
            by the same positions.
        k: number of top-ranked runs to average over.

    Raises:
        ValueError: if either sequence is empty (e.g. training was interrupted
            before any run finished).
    """
    if len(performances) == 0 or len(long_seq_performances) == 0:
        raise ValueError(
            "topk_avg_performance_formatted needs at least one completed run; "
            "got an empty performance list"
        )

    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # Rank the runs per metric (larger value -> better rank, 1 is best). The stable
    # sort makes tie-breaking deterministic: earlier runs win ties.
    ranks = {m: (-scores[m]).argsort(kind="stable").argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Select the top-k runs by average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]
    final_avg, final_std = _mean_std_over(performances, topk_idx)
    final_long_seq_avg, final_long_seq_std = _mean_std_over(long_seq_performances, topk_idx)

    # Print the results (converted to percent, two decimals each).
    print("Final Metrics:")
    _print_percent_table(final_avg, final_std)
    print("\nFinal Long Sequence Metrics:")
    _print_percent_table(final_long_seq_avg, final_long_seq_std)

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print mean ± std of one metric column per disease across several tables.

    The tables are concatenated row-wise and grouped by index label, so each
    distinct index label (disease) yields one line ``<label>: <mean> ± <std>``,
    with both values converted to percent and shown with two decimals.

    Args:
        dfs (list[pd.DataFrame]): tables to combine.
        col_name (str): name of the metric column to summarise (default: "prauc").
        ddof (int): delta degrees of freedom for the standard deviation. The
            default of 1 (sample std) keeps the previous behaviour; pass 0 for a
            population std, which is what ``np.std(..., ddof=0)`` reports.
    """
    # Stack all tables on top of each other.
    all_values = pd.concat(dfs, axis=0)

    # Group by disease (index label) and compute mean and std of the chosen column.
    per_class = all_values.groupby(all_values.index)[col_name]
    grouped = pd.DataFrame({"mean": per_class.mean(), "std": per_class.std(ddof=ddof)})

    # Print one line per disease.
    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
# Binary tasks carry flat metric dicts; the other task type carries a dict with a
# "global" part and a "per_class" part, which are reported separately.
if task_type == "binary":
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)
else:
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    per_class_reports = (
        ("\nPer-class performance, all patients:", final_metrics_per_class),
        ("\nPer-class performance, long seq:", final_long_seq_metrics_per_class),
    )
    for heading, per_class_tables in per_class_reports:
        print(heading)
        print_per_class_performance(per_class_tables, col_name="prauc")