In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed
import pandas as pd
from print_hg import summarize_hypergraph

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
# Experiment configuration shared by every cell below (key order preserved).
args = dict(
    dataset="MIMIC-III",
    task="next_diag_12m",  # one of: death, stay, readmission, next_diag_6m, next_diag_12m
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med"],
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=20,
    # --- model hyperparameters ---
    level="visit",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=48,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Directory holding the preprocessed pickles for the selected dataset.
# TODO: machine-specific absolute path; make it configurable before sharing the notebook.
data_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{data_dir}/mimic.pkl"

# Fine-tuning split file for each supported task. Unknown tasks fail loudly instead of
# silently falling back to the downstream file (e.g. a typo in args["task"]).
_FINETUNE_FILES = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
    "death": "mimic_downstream.pkl",
    "stay": "mimic_downstream.pkl",
    "readmission": "mimic_downstream.pkl",
}
if args["task"] not in _FINETUNE_FILES:
    raise ValueError(
        f"Unknown task {args['task']!r}; expected one of {sorted(_FINETUNE_FILES)}"
    )
finetune_data_path = f"{data_dir}/{_FINETUNE_FILES[args['task']]}"

In [5]:
# Load the full preprocessed EHR table. pickle.load can execute arbitrary code,
# so only point this at trusted, locally produced files.
with open(full_data_path, "rb") as fh:
    ehr_full_data = pickle.load(fh)
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
# Lab and procedure modalities are not used here; each gets a single empty sentence.
lab_sentences = [[]]
pro_sentences = [[]]
# Age/gender vocabulary: "[PAD]" first (special PAD token for this vocabulary), then one
# "<age>_<gender>" token per distinct AGE value and gender. sorted() fixes the token order;
# iterating a raw set depends on hashing, which is randomized for str across interpreter runs.
age_gender_sentences = ["[PAD]"] + [
    str(age) + "_" + gender
    for age in sorted(set(ehr_full_data["AGE"].values.tolist()))
    for gender in ["M", "F"]
]
# Largest number of distinct admissions (HADM_ID) for any patient (SUBJECT_ID).
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)

# Record vocabulary sizes in args (presumably read by SetGNN(args, tokenizer) later).
args["age_gender_vocab_size"] = tokenizer.token_number("age_gender")
args["global_vocab_size"] = len(tokenizer.vocab.id2word)
args["label_vocab_size"] = len(PHENO_ORDER)

for vocab_label, size_key in (
    ("Age and gender", "age_gender_vocab_size"),
    ("Global", "global_vocab_size"),
    ("Label", "label_vocab_size"),
):
    print(f"{vocab_label} vocabulary size: {args[size_key]}")

Age and gender vocabulary size: 37
Global vocabulary size: 2145
Label vocabulary size: 18


In [7]:
# The fine-tuning pickle unpacks into (train, val, test) tables; trusted local file only.
with open(finetune_data_path, "rb") as fh:
    train_data, val_data, test_data = pickle.load(fh)

# Flag rows that belong to the experiment splits.
for split_df in (train_data, val_data, test_data):
    split_df["EXP_FLAG"] = 1

# Transductive learning: patients outside every split are appended to each split's graph
# (with EXP_FLAG = 0); otherwise the hypergraph is too sparse.
all_exp_ids = pd.concat([
    train_data["SUBJECT_ID"],
    val_data["SUBJECT_ID"],
    test_data["SUBJECT_ID"],
]).unique()

non_exp_data = ehr_full_data[~ehr_full_data["SUBJECT_ID"].isin(all_exp_ids)].copy(deep=True)
non_exp_data["EXP_FLAG"] = 0

# Each split = its own experiment rows + the shared non-experiment rows.
train_data_full = pd.concat([train_data, non_exp_data], axis=0)
val_data_full = pd.concat([val_data, non_exp_data], axis=0)
test_data_full = pd.concat([test_data, non_exp_data], axis=0)

In [8]:
# Each dataset yields input_ids (a patient has multiple visits) and labels.
# All three splits share the same tokenizer and task settings.
dataset_kwargs = dict(
    token_type=args["predicted_token_type"],
    task=args["task"],
    level=args["level"],
)
train_dataset, val_dataset, test_dataset = (
    FinetuneHGDataset(split_frame, tokenizer, **dataset_kwargs)
    for split_frame in (train_data_full, val_data_full, test_data_full)
)
print(len(train_dataset), len(val_dataset), len(test_dataset))

32949 32476 32476


In [9]:
# A record with at least this many admissions counts as a "long sequence".
long_adm_seq_crite = 3


def _long_seq_indices(dataset, min_adms):
    """Return dataset positions whose record has >= ``min_adms`` admissions and a true EXP_FLAG.

    Positions follow the iteration order of ``dataset.records``; this assumes that order
    matches the dataset's integer indexing (TODO confirm in FinetuneHGDataset).
    """
    # Build the key list once; rebuilding it for every index would be O(n^2).
    record_keys = list(dataset.records.keys())
    selected = []
    for i in range(len(dataset)):
        key = record_keys[i]
        # Keep the original `== True` comparison so flag semantics are unchanged.
        if len(dataset.records[key]) >= min_adms and dataset.exp_flags[key] == True:  # noqa: E712
            selected.append(i)
    return selected


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

208 186


In [10]:
# When True, each split is served as a single batch containing the whole graph.
use_full_graph = True


def _split_batch_size(dataset):
    # Whole split when training on the full graph, otherwise the configured mini-batch size.
    if use_full_graph:
        return len(dataset)
    return args["batch_size"]


train_batch_size = _split_batch_size(train_dataset)
val_batch_size = _split_batch_size(val_dataset)
test_batch_size = _split_batch_size(test_dataset)

# Each loader gets its own collate function instance; only training is shuffled.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=train_batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=val_batch_size,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=test_batch_size,
    shuffle=False,
)

In [11]:
# Examine hypergraph properties on one training batch.
batch = next(iter(train_dataloader))
HG_sample, global_node_ids, last_visit_indices, exp_flags, labels = batch
print(HG_sample)
for inspected in (last_visit_indices, exp_flags, labels):
    print(inspected.shape)

Data(edge_index=[2, 702150], n_x=[1], num_hyperedges=[1], totedges=37760, norm=[702150])
torch.Size([32949])
torch.Size([32949])
torch.Size([1883, 18])


In [12]:
# Evaluation metric handed to the trainer; identical for every task.
eval_metric = "prauc"

if args["task"] in ["death", "stay", "readmission"]:
    task_type = "binary"
elif args["task"] in ["next_diag_6m", "next_diag_12m"]:
    task_type = "l2r"
else:
    # Fail fast rather than silently treating an unknown task as next-diagnosis prediction.
    raise ValueError(f"Unknown task {args['task']!r}")

# Both task types train with binary cross-entropy on logits. Passing the function itself
# (instead of a 2-argument lambda wrapper) also keeps its optional keyword arguments usable.
loss_fn = F.binary_cross_entropy_with_logits

In [13]:
# Derive ten reproducible 32-bit run seeds from a fixed master seed.
random.seed(42)
seeds = []
while len(seeds) < 10:
    seeds.append(random.randint(0, 2**32 - 1))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844]


In [14]:
# Per-seed results returned by train_with_early_stopping, collected for the summary cells.
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every seed so runs are independent.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Use the metric selected with the task configuration rather than a duplicated literal.
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.21s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.43s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.38s/it]


Epoch: 001, Average Loss: 0.6991
Validation: {'precision': 0.16116466791526102, 'recall': 0.5838161933992092, 'f1': 0.1903875620316031, 'auc': 0.4945387778626855, 'prauc': 0.2149823233955287}
Test:       {'precision': 0.15723778571459243, 'recall': 0.5831737580406986, 'f1': 0.19209292123181013, 'auc': 0.4934768196309641, 'prauc': 0.213984985079847}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.73s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.50s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.35s/it]


Epoch: 002, Average Loss: 0.6954
Validation: {'precision': 0.136121596393143, 'recall': 0.5413912679934659, 'f1': 0.18374426208345945, 'auc': 0.5101065531717444, 'prauc': 0.2251152338551713}
Test:       {'precision': 0.1624012321098145, 'recall': 0.5420122794742329, 'f1': 0.18506165466392177, 'auc': 0.5119172574601856, 'prauc': 0.22340207556745387}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.33s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.73s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.41s/it]


Epoch: 003, Average Loss: 0.6915
Validation: {'precision': 0.1433551912732732, 'recall': 0.4872017634306455, 'f1': 0.1704201379805579, 'auc': 0.5232529668712698, 'prauc': 0.2347034314966334}
Test:       {'precision': 0.1640211627095286, 'recall': 0.4880086337788587, 'f1': 0.1711153821187326, 'auc': 0.5285655504710065, 'prauc': 0.23209013413787638}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.90s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.14s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.51s/it]


Epoch: 004, Average Loss: 0.6878
Validation: {'precision': 0.1595125513854843, 'recall': 0.4095067868330127, 'f1': 0.16017243915519538, 'auc': 0.5351473317871308, 'prauc': 0.2448752200119866}
Test:       {'precision': 0.1622438010919371, 'recall': 0.4137657478892811, 'f1': 0.16138231164830696, 'auc': 0.5432461675297239, 'prauc': 0.24051170196408783}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.85s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.28s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.36s/it]


Epoch: 005, Average Loss: 0.6844
Validation: {'precision': 0.17378641425456762, 'recall': 0.3116310984277254, 'f1': 0.14484079874459185, 'auc': 0.5446907843652088, 'prauc': 0.2539943560780182}
Test:       {'precision': 0.15768650570439874, 'recall': 0.33173323480494354, 'f1': 0.15074525350469895, 'auc': 0.5539724809791109, 'prauc': 0.2481078768690598}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.56s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.64s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.81s/it]


Epoch: 006, Average Loss: 0.6795
Validation: {'precision': 0.19111435056594353, 'recall': 0.24417621852129834, 'f1': 0.1294461588043428, 'auc': 0.5515653597663721, 'prauc': 0.2622320188494004}
Test:       {'precision': 0.1788343399478003, 'recall': 0.27072858375064485, 'f1': 0.13379643093508098, 'auc': 0.5611211454214727, 'prauc': 0.25468737103215844}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.15s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.20s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.11s/it]


Epoch: 007, Average Loss: 0.6754
Validation: {'precision': 0.18950147013655988, 'recall': 0.17154958378475638, 'f1': 0.11020043777180383, 'auc': 0.5552198570287586, 'prauc': 0.26671430923529904}
Test:       {'precision': 0.20351460971481505, 'recall': 0.18586498005779725, 'f1': 0.11075354261707741, 'auc': 0.5653284963995018, 'prauc': 0.25950507050260363}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.73s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.54s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.56s/it]


Epoch: 008, Average Loss: 0.6696
Validation: {'precision': 0.13602237519873855, 'recall': 0.10148657150417105, 'f1': 0.08270193638589224, 'auc': 0.5558529921755422, 'prauc': 0.2675092813907908}
Test:       {'precision': 0.19939930302283923, 'recall': 0.1075724617619413, 'f1': 0.08235758114543419, 'auc': 0.5663525422024288, 'prauc': 0.2611995454535061}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.09s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.49s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.74s/it]


Epoch: 009, Average Loss: 0.6641
Validation: {'precision': 0.14193658245529306, 'recall': 0.045951203410460245, 'f1': 0.04926758121946259, 'auc': 0.5535820924374524, 'prauc': 0.26564058063754387}
Test:       {'precision': 0.12460065913917193, 'recall': 0.05501389816594776, 'f1': 0.05247913074001247, 'auc': 0.5653282054553164, 'prauc': 0.26159205509523603}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.64s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.52s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.49s/it]


Epoch: 010, Average Loss: 0.6585
Validation: {'precision': 0.08448004201680673, 'recall': 0.015809095640357627, 'f1': 0.0234251325365787, 'auc': 0.5494731865761785, 'prauc': 0.2633122348471138}
Test:       {'precision': 0.08677734997007198, 'recall': 0.017501980680216005, 'f1': 0.024786025315532447, 'auc': 0.5632806508204091, 'prauc': 0.26077713028750477}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.62s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.41s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.87s/it]


Epoch: 011, Average Loss: 0.6516
Validation: {'precision': 0.07032163742690059, 'recall': 0.003109983741777376, 'f1': 0.005800389118150082, 'auc': 0.5438338168487208, 'prauc': 0.2597209789277269}
Test:       {'precision': 0.09166666666666666, 'recall': 0.004352800936805736, 'f1': 0.008010788385044437, 'auc': 0.5587379075358672, 'prauc': 0.2582349882055919}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.86s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.29s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.40s/it]


Epoch: 012, Average Loss: 0.6430
Validation: {'precision': 0.09427609427609428, 'recall': 0.0006230297862908833, 'f1': 0.001229662149202379, 'auc': 0.5372724899348029, 'prauc': 0.2558986925199849}
Test:       {'precision': 0.05555555555555555, 'recall': 0.0007966290619148367, 'f1': 0.001563125865305484, 'auc': 0.5531740296896396, 'prauc': 0.2547493505086973}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.17s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.26s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.69s/it]


Epoch: 013, Average Loss: 0.6366
Validation: {'precision': 0.018518518518518517, 'recall': 0.00010787486515641856, 'f1': 0.0002145002145002145, 'auc': 0.5305806061395776, 'prauc': 0.2520875877476212}
Test:       {'precision': 0.06018518518518518, 'recall': 0.0004984776009726781, 'f1': 0.000987885608802656, 'auc': 0.5475855820260632, 'prauc': 0.25019944192824334}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.94s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.75s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.01s/it]


Epoch: 014, Average Loss: 0.6297
Validation: {'precision': 0.018518518518518517, 'recall': 0.00010787486515641856, 'f1': 0.0002145002145002145, 'auc': 0.52475140500412, 'prauc': 0.24868096654374094}
Test:       {'precision': 0.041666666666666664, 'recall': 0.0002888645037626817, 'f1': 0.0005723208621759347, 'auc': 0.542530377076912, 'prauc': 0.24747144784298905}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.51s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.68s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.30s/it]


Epoch: 015, Average Loss: 0.6198
Validation: {'precision': 0.027777777777777776, 'recall': 0.00010787486515641856, 'f1': 0.0002149151085321298, 'auc': 0.5193427110787343, 'prauc': 0.24561448769742814}
Test:       {'precision': 0.07407407407407407, 'recall': 0.0002888645037626817, 'f1': 0.0005752794214332677, 'auc': 0.5377491384565826, 'prauc': 0.24501963951537353}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.90s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.69s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.82s/it]


Epoch: 016, Average Loss: 0.6105
Validation: {'precision': 0.027777777777777776, 'recall': 0.0002157497303128371, 'f1': 0.00042817383857846286, 'auc': 0.5148606783734428, 'prauc': 0.24328427939629987}
Test:       {'precision': 0.027777777777777776, 'recall': 0.00011022927689594356, 'f1': 0.00021958717610891522, 'auc': 0.5340702369203288, 'prauc': 0.2422836990074414}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.71s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.38s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.84s/it]


Epoch: 017, Average Loss: 0.6020
Validation: {'precision': 0.031746031746031744, 'recall': 0.0004314994606256742, 'f1': 0.0008514261387824605, 'auc': 0.5114196905335603, 'prauc': 0.24155344582276475}
Test:       {'precision': 0.08333333333333333, 'recall': 0.0003935288203200976, 'f1': 0.0007825108817919499, 'auc': 0.5311101726894026, 'prauc': 0.2407267676761901}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.71s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.49s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.82s/it]


Epoch: 018, Average Loss: 0.5946
Validation: {'precision': 0.034722222222222224, 'recall': 0.0005393743257820927, 'f1': 0.0010622477161674102, 'auc': 0.5083674463095623, 'prauc': 0.239724733476188}
Test:       {'precision': 0.08333333333333333, 'recall': 0.0003935288203200976, 'f1': 0.0007825108817919499, 'auc': 0.5285973362782451, 'prauc': 0.23917114636146256}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.64s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.72s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.37s/it]


Epoch: 019, Average Loss: 0.5857
Validation: {'precision': 0.04040404040404041, 'recall': 0.0008629989212513484, 'f1': 0.0016899028305872412, 'auc': 0.5058373034613663, 'prauc': 0.23817985507809183}
Test:       {'precision': 0.07777777777777778, 'recall': 0.0003935288203200976, 'f1': 0.0007816514607323524, 'auc': 0.5265747864656283, 'prauc': 0.23802432398473228}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.84s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.46s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.98s/it]


Epoch: 020, Average Loss: 0.5755
Validation: {'precision': 0.042735042735042736, 'recall': 0.0010787486515641855, 'f1': 0.0021043771043771043, 'auc': 0.5037216831180602, 'prauc': 0.23646643806530646}
Test:       {'precision': 0.018518518518518517, 'recall': 0.0002204585537918871, 'f1': 0.00043572984749455336, 'auc': 0.5250720457375565, 'prauc': 0.2368804683762208}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.10s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.62s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.21s/it]


Epoch: 021, Average Loss: 0.5661
Validation: {'precision': 0.044444444444444446, 'recall': 0.0012944983818770227, 'f1': 0.0025157232704402514, 'auc': 0.5026897031621477, 'prauc': 0.23535792866861133}
Test:       {'precision': 0.018518518518518517, 'recall': 0.00033068783068783067, 'f1': 0.000649772579597141, 'auc': 0.5242876089589588, 'prauc': 0.23615158493828206}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.60s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.16s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.44s/it]


Epoch: 022, Average Loss: 0.5589
Validation: {'precision': 0.0457516339869281, 'recall': 0.0015102481121898597, 'f1': 0.0029239766081871343, 'auc': 0.5018965865558301, 'prauc': 0.23372528798834377}
Test:       {'precision': 0.01515151515151515, 'recall': 0.00033068783068783067, 'f1': 0.0006472491909385113, 'auc': 0.5236248329227009, 'prauc': 0.2349320809287804}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.30s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.16s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.06s/it]


Epoch: 023, Average Loss: 0.5506
Validation: {'precision': 0.0457516339869281, 'recall': 0.0015102481121898597, 'f1': 0.0029239766081871343, 'auc': 0.5023813414222015, 'prauc': 0.23328074493475537}
Test:       {'precision': 0.018518518518518517, 'recall': 0.0004409171075837742, 'f1': 0.0008613264427217916, 'auc': 0.5240767501790262, 'prauc': 0.23445513891848987}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.62s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.25s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.76s/it]


Epoch: 024, Average Loss: 0.5417
Validation: {'precision': 0.044444444444444446, 'recall': 0.001725997842502697, 'f1': 0.0033229491173416407, 'auc': 0.5039593292174671, 'prauc': 0.23404030116100882}
Test:       {'precision': 0.017361111111111112, 'recall': 0.0005511463844797178, 'f1': 0.0010683760683760685, 'auc': 0.5256764162565573, 'prauc': 0.2350509143438669}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.38s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.40s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.83s/it]


Epoch: 025, Average Loss: 0.5336
Validation: {'precision': 0.0966183574879227, 'recall': 0.0019367533660953297, 'f1': 0.0037163268718491387, 'auc': 0.5063618418887966, 'prauc': 0.23577882366451228}
Test:       {'precision': 0.0196078431372549, 'recall': 0.0006613756613756613, 'f1': 0.0012795905310300703, 'auc': 0.5279033013581867, 'prauc': 0.23651440416312797}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.62s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.19s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.18s/it]


Epoch: 026, Average Loss: 0.5266
Validation: {'precision': 0.1183127572016461, 'recall': 0.0020381321171091173, 'f1': 0.003892436034089798, 'auc': 0.5098197018390299, 'prauc': 0.23805437400622498}
Test:       {'precision': 0.08333333333333333, 'recall': 0.0016472464974337632, 'f1': 0.003131660928020987, 'auc': 0.5310452167133977, 'prauc': 0.23894221659899997}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.10s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.99s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.81s/it]


Epoch: 027, Average Loss: 0.5178
Validation: {'precision': 0.11035353535353536, 'recall': 0.0028781503966347943, 'f1': 0.005473425395289781, 'auc': 0.5136941495277565, 'prauc': 0.24062011318482476}
Test:       {'precision': 0.13804713804713806, 'recall': 0.002283075142094651, 'f1': 0.004344358064945384, 'auc': 0.5345624979720891, 'prauc': 0.2414433475394525}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.27s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.86s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.13s/it]


Epoch: 028, Average Loss: 0.5116
Validation: {'precision': 0.11983618233618233, 'recall': 0.003917433878890268, 'f1': 0.007429668342821983, 'auc': 0.5176846492216929, 'prauc': 0.24354008339737623}
Test:       {'precision': 0.1302748741773132, 'recall': 0.00334665918963532, 'f1': 0.006335934367910502, 'auc': 0.5382183389215156, 'prauc': 0.2440601111590325}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.54s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.37s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.41s/it]


Epoch: 029, Average Loss: 0.5057
Validation: {'precision': 0.12535612535612536, 'recall': 0.005497593594350262, 'f1': 0.010313766594540927, 'auc': 0.5219089399443104, 'prauc': 0.24610619744946416}
Test:       {'precision': 0.13032012245848723, 'recall': 0.00526808322101519, 'f1': 0.00985231407114021, 'auc': 0.5420826617838385, 'prauc': 0.24643453488597217}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.21s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.14s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.96s/it]


Epoch: 030, Average Loss: 0.4984
Validation: {'precision': 0.12560906346217082, 'recall': 0.007151645696831093, 'f1': 0.013279012563211425, 'auc': 0.527080648913291, 'prauc': 0.24898830613891912}
Test:       {'precision': 0.12455851905979783, 'recall': 0.007409189413493734, 'f1': 0.013595895427134207, 'auc': 0.5468502501933115, 'prauc': 0.24907530110121767}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.75s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.14s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.71s/it]


Epoch: 031, Average Loss: 0.4938
Validation: {'precision': 0.12283903835627973, 'recall': 0.009135330024470884, 'f1': 0.016725193438011533, 'auc': 0.5330757951645908, 'prauc': 0.2523048639804315}
Test:       {'precision': 0.1165508468140047, 'recall': 0.008570229229572962, 'f1': 0.015589056940723548, 'auc': 0.5522634125912519, 'prauc': 0.25217082899711046}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.46s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.37s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.68s/it]


Epoch: 032, Average Loss: 0.4884
Validation: {'precision': 0.12230955600631058, 'recall': 0.011859155788044992, 'f1': 0.021253659740667737, 'auc': 0.5402014744125813, 'prauc': 0.2560809983739619}
Test:       {'precision': 0.11069928493863584, 'recall': 0.009744042554009574, 'f1': 0.01751051732304226, 'auc': 0.5581118045150403, 'prauc': 0.25559475653870606}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.85s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.28s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.37s/it]


Epoch: 033, Average Loss: 0.4835
Validation: {'precision': 0.11981532118310234, 'recall': 0.013002821836159107, 'f1': 0.023030715619551303, 'auc': 0.5478333265245338, 'prauc': 0.2604739844208043}
Test:       {'precision': 0.1123220181472609, 'recall': 0.011475971311624505, 'f1': 0.0201975455910942, 'auc': 0.5649070283988857, 'prauc': 0.2596819199190799}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.17s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.05s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.71s/it]


Epoch: 034, Average Loss: 0.4783
Validation: {'precision': 0.11951713910098244, 'recall': 0.015315124965988362, 'f1': 0.026600715797731433, 'auc': 0.5572821478121275, 'prauc': 0.2662295796711244}
Test:       {'precision': 0.11425106326422114, 'recall': 0.013278644232023403, 'f1': 0.023129839509149857, 'auc': 0.5736493921937367, 'prauc': 0.26502609784744363}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.15s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.45s/it]
Running inference: 100%|██████████| 1/1 [00:08<00:00,  8.25s/it]


Epoch: 035, Average Loss: 0.4736
Validation: {'precision': 0.11922692805935935, 'recall': 0.017212901647227902, 'f1': 0.029272919895454726, 'auc': 0.5655855142753912, 'prauc': 0.27160226547101907}
Test:       {'precision': 0.11429177268871925, 'recall': 0.015106475972790467, 'f1': 0.02577138902264492, 'auc': 0.5804706006006456, 'prauc': 0.2700921352141308}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.40s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.02s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.29s/it]


Epoch: 036, Average Loss: 0.4720
Validation: {'precision': 0.11938807402942425, 'recall': 0.020140986820252012, 'f1': 0.03347595182333073, 'auc': 0.5731263407423446, 'prauc': 0.2774764880261492}
Test:       {'precision': 0.11533119658119657, 'recall': 0.018193967887215133, 'f1': 0.03037891007121016, 'auc': 0.5869709792276034, 'prauc': 0.2755127589414721}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.18s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.38s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.81s/it]


Epoch: 037, Average Loss: 0.4663
Validation: {'precision': 0.17438463890929415, 'recall': 0.02228814260503888, 'f1': 0.036543372722585205, 'auc': 0.5792662592912989, 'prauc': 0.28270787533600295}
Test:       {'precision': 0.11304713804713806, 'recall': 0.02052146378268576, 'f1': 0.033648110753031306, 'auc': 0.5918762133171677, 'prauc': 0.280080356928428}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.83s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.51s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.99s/it]


Epoch: 038, Average Loss: 0.4623
Validation: {'precision': 0.17129462205128793, 'recall': 0.023634566155180572, 'f1': 0.0383138515356002, 'auc': 0.5846903030766767, 'prauc': 0.28765240722206886}
Test:       {'precision': 0.11403186390290296, 'recall': 0.02293480652737738, 'f1': 0.0371288937514337, 'auc': 0.5964835935752141, 'prauc': 0.2844877365261858}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.49s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.18s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.94s/it]


Epoch: 039, Average Loss: 0.4602
Validation: {'precision': 0.1697419344478168, 'recall': 0.02509186838532353, 'f1': 0.040253566497521494, 'auc': 0.5882091261179203, 'prauc': 0.29110522251809345}
Test:       {'precision': 0.11264513658130679, 'recall': 0.024081908246059398, 'f1': 0.03853778634125455, 'auc': 0.5994014284742689, 'prauc': 0.28746033298654006}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.38s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.24s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.20s/it]


Epoch: 040, Average Loss: 0.4563
Validation: {'precision': 0.17062632243979395, 'recall': 0.026711140165191143, 'f1': 0.04256421509656587, 'auc': 0.5911725200231785, 'prauc': 0.29377953835864634}
Test:       {'precision': 0.11272380500962087, 'recall': 0.025650960907962414, 'f1': 0.040662592881776795, 'auc': 0.6019794487570036, 'prauc': 0.2901245854619169}


Training Batches: 100%|██████████| 1/1 [00:08<00:00,  8.00s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.62s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.01s/it]


Epoch: 041, Average Loss: 0.4544
Validation: {'precision': 0.17156442735953192, 'recall': 0.028986981940556537, 'f1': 0.04566469445394645, 'auc': 0.5940892665475148, 'prauc': 0.29624814902297314}
Test:       {'precision': 0.15619555920855555, 'recall': 0.028341214627897873, 'f1': 0.04462790109491103, 'auc': 0.6047388993621152, 'prauc': 0.2926140237101968}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.52s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.05s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.55s/it]


Epoch: 042, Average Loss: 0.4516
Validation: {'precision': 0.17127448070167997, 'recall': 0.03293450914850714, 'f1': 0.05094247282100398, 'auc': 0.5980757173363553, 'prauc': 0.2997784247875665}
Test:       {'precision': 0.15768954167012417, 'recall': 0.03310773832166006, 'f1': 0.05095535316995651, 'auc': 0.6086258830012226, 'prauc': 0.2967035870910452}


Training Batches: 100%|██████████| 1/1 [00:06<00:00,  6.81s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.53s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.71s/it]


Epoch: 043, Average Loss: 0.4483
Validation: {'precision': 0.15223369357790295, 'recall': 0.03647002523825996, 'f1': 0.05536352808636302, 'auc': 0.6018235511829046, 'prauc': 0.3032804227180894}
Test:       {'precision': 0.15103718186725654, 'recall': 0.0368561252734122, 'f1': 0.05540824418155839, 'auc': 0.6120776570279415, 'prauc': 0.3005615534230178}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.44s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.82s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.18s/it]


Epoch: 044, Average Loss: 0.4479
Validation: {'precision': 0.15702281166236312, 'recall': 0.04217382274220179, 'f1': 0.06197917134954188, 'auc': 0.6062818677609572, 'prauc': 0.30813629872584286}
Test:       {'precision': 0.14546295455386363, 'recall': 0.042002785618203746, 'f1': 0.06090352911477142, 'auc': 0.6158676009130148, 'prauc': 0.30521303295927976}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.23s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.84s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.05s/it]


Epoch: 045, Average Loss: 0.4432
Validation: {'precision': 0.16018899285482513, 'recall': 0.04835971694808443, 'f1': 0.06872161085477282, 'auc': 0.6109298415726107, 'prauc': 0.3133882079043977}
Test:       {'precision': 0.1410831348048353, 'recall': 0.04646836592497869, 'f1': 0.06497186549522997, 'auc': 0.6193199314730131, 'prauc': 0.30952234634023584}


Training Batches: 100%|██████████| 1/1 [00:06<00:00,  6.94s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.55s/it]
Running inference: 100%|██████████| 1/1 [00:07<00:00,  7.25s/it]


Epoch: 046, Average Loss: 0.4411
Validation: {'precision': 0.1599603374113178, 'recall': 0.054961429075442665, 'f1': 0.07468864009555638, 'auc': 0.6158290597662339, 'prauc': 0.3185965426256104}
Test:       {'precision': 0.173085140527001, 'recall': 0.052487327139159705, 'f1': 0.07135365720707362, 'auc': 0.6227022383842002, 'prauc': 0.3138507031126527}


Training Batches: 100%|██████████| 1/1 [00:06<00:00,  6.92s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.72s/it]
Running inference: 100%|██████████| 1/1 [00:06<00:00,  6.53s/it]


Epoch: 047, Average Loss: 0.4394
Validation: {'precision': 0.23693387305837105, 'recall': 0.05827775002101513, 'f1': 0.07777811609541951, 'auc': 0.6192842298424219, 'prauc': 0.32181613940645976}
Test:       {'precision': 0.22847370432438455, 'recall': 0.05664553947124237, 'f1': 0.0757076130689303, 'auc': 0.6255452116421423, 'prauc': 0.3167100234699853}


Training Batches: 100%|██████████| 1/1 [00:07<00:00,  7.34s/it]
Running inference:   0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt



In [None]:
import numpy as np
def _mean_std_over_runs(runs, idx):
    """Return ({metric: mean}, {metric: population std (ddof=0)}) over `runs[i] for i in idx`.

    Metric names are taken from the first run; every selected run must contain them.
    """
    keys = list(runs[0].keys())
    means = {m: float(np.mean([runs[i][m] for i in idx])) for m in keys}
    stds = {m: float(np.std([runs[i][m] for i in idx], ddof=0)) for m in keys}
    return means, stds


def _print_mean_std(title, means, stds):
    """Print `title`, then one `metric: mean ± std` line per metric, scaled to percent (2 d.p.)."""
    print(title)
    for m in means:
        print(f"{m}: {means[m] * 100:.2f} ± {stds[m] * 100:.2f}")


def topk_avg_performance_formatted(performances, long_seq_performances, k=5,
                                   rank_metrics=("f1", "auc", "prauc")):
    """Summarise the k best runs as mean ± std (in percent) and print the result.

    Each run is ranked separately on every metric in `rank_metrics` (rank 1 = highest
    value; NaN scores sort last). The k runs with the lowest average rank are
    selected. Those same run indices are then used to aggregate
    `long_seq_performances`, so the two lists must be aligned run-by-run.

    Args:
        performances: list of dicts mapping metric name -> float, one dict per run.
        long_seq_performances: list of dicts, aligned index-by-index with `performances`.
        k: number of top runs to average. Values larger than the number of runs are clipped.
        rank_metrics: metric names used for ranking. Each must be present in every run.

    Returns:
        dict with keys "mean", "std", "long_seq_mean", "long_seq_std"
        (each {metric: float}, as fractions, not percent) and "topk_idx"
        (list of selected run indices, best first).

    Raises:
        ValueError: if `performances` is empty, the two lists differ in length,
            or `k < 1`.
    """
    n_runs = len(performances)
    if n_runs == 0:
        raise ValueError("performances must contain at least one run")
    if len(long_seq_performances) != n_runs:
        raise ValueError(
            f"long_seq_performances has {len(long_seq_performances)} runs, expected {n_runs}"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Per-metric rank of each run (1 = highest score). A stable sort makes ties
    # resolve deterministically in favour of the earlier run.
    ranks = []
    for m in rank_metrics:
        scores = np.array([p[m] for p in performances], dtype=float)
        order = np.argsort(-scores, kind="stable")
        ranks.append(order.argsort() + 1)
    avg_ranks = np.mean(np.stack(ranks, axis=1), axis=1)

    # Lowest average rank first; slicing clips k to the number of runs.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k].tolist()

    means, stds = _mean_std_over_runs(performances, topk_idx)
    long_means, long_stds = _mean_std_over_runs(long_seq_performances, topk_idx)

    _print_mean_std("Final Metrics:", means, stds)
    _print_mean_std("\nFinal Long Sequence Metrics:", long_means, long_stds)

    return {
        "mean": means,
        "std": stds,
        "long_seq_mean": long_means,
        "long_seq_std": long_stds,
        "topk_idx": topk_idx,
    }

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print one metric per class as mean ± std (in percent), aggregated across runs.

    The DataFrames are concatenated. Rows that share an index label (e.g. a
    disease name) are grouped, and `col_name` is reduced to its mean and std.

    Args:
        dfs (Iterable[pd.DataFrame]): one per-class table per run.
        col_name (str): metric column to aggregate (default: "prauc").
        ddof (int): delta degrees of freedom for the std. The default of 1
            (sample std, pandas' default) keeps the historical output, but it
            yields NaN for classes that appear in only one table. Pass 0 for the
            population std, which is what `topk_avg_performance_formatted` reports.

    Returns:
        pd.DataFrame with columns "mean" and "std" (fractions, not percent),
        one row per index label.

    Raises:
        ValueError: if `dfs` is empty.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("dfs must contain at least one DataFrame")

    # Stack all runs, then group rows by their index label (the class/disease).
    all_values = pd.concat(dfs, axis=0)
    per_class = all_values.groupby(all_values.index)[col_name]
    grouped = pd.DataFrame({"mean": per_class.mean(), "std": per_class.std(ddof=ddof)})

    for disease, row in grouped.iterrows():
        print(f"{disease}: {row['mean'] * 100:.2f} ± {row['std'] * 100:.2f}")
    return grouped

In [None]:
# Report the aggregated results of the collected runs.
# Binary tasks store one flat metrics dict per run. Other tasks store, per run,
# a "global" metrics dict plus a "per_class" table.
# NOTE(review): task_type, final_metrics and final_long_seq_metrics are not
# defined in this chunk; they must come from an earlier (training) cell.
if task_type != "binary":
    final_metrics_global, final_metrics_per_class = (
        [run["global"] for run in final_metrics],
        [run["per_class"] for run in final_metrics],
    )
    final_long_seq_metrics_global, final_long_seq_metrics_per_class = (
        [run["global"] for run in final_long_seq_metrics],
        [run["per_class"] for run in final_long_seq_metrics],
    )
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    for header, per_class_tables in (
        ("\nPer-class performance, all patients:", final_metrics_per_class),
        ("\nPer-class performance, long seq:", final_long_seq_metrics_per_class),
    ):
        print(header)
        print_per_class_performance(per_class_tables, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)