In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTFineTune

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,  # index of the task to train
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]", "[CLS]"],
    attn_mask_dicts = [{1:[6,7], 2:[6,7], 3:[6,7], 4:[6,7], 6:[2,3,4,5,6,7], 7:[2,3,4,5,6,7]}, 
                       {1:[6,7], 2:[6,7], 3:[6,7], 4:[6,7], 6:[2,3,4,5,6,7], 7:[2,3,4,5,6,7]}],
    d_model = 64,
    num_heads = 4,
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,
    early_stop_patience = 5,
    group_code_thre = 5,  # if there are group_code_thre diag codes belongs to the same group ICD code, then the group code is generated
    use_pretrained_model = True,
    pretrain_mask_rate = 0.7,
    pretrain_cls_ontology_weight = 5e-2,
    pretrain_visit_ontology_weight = 5e-2,
    pretrain_adm_type_weight = 5e-2,
    dec_loss_lambda = 1e-2,
)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
if curr_task == "next_diag_6m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl"
elif curr_task == "next_diag_12m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"

Current task: stay


In [6]:
ehr_full_data = pickle.load(open(full_data_path, 'rb'))
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())] # important of [[]]
adm_type_sentences = ehr_full_data["ADMISSION_TYPE"].values.tolist()
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
tokenizer = EHRTokenizer(age_sentences, group_code_sentences, diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, special_tokens=config.special_tokens, adm_types_sentences=adm_type_sentences)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")
print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
train_dataset = FineTuneEHRDataset(train_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                   max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                 max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                   max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)

In [10]:
num_group_code = []
for i in range(len(train_dataset)):
    input_ids, token_types, adm_index, age_ids, diag_group_codes, labels = train_dataset[i]
    count = (token_types[0] == 6).sum().item()
    num_group_code.append(count)
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=True,
    batch_size=config.batch_size,
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

In [12]:
for batch in train_dataloader:
    pass  # just to check if the dataloader works
for batch in val_dataloader:
    pass  # just to check if the dataloader works
for batch in test_dataloader:
    pass  # just to check if the dataloader works
print("All pass!")

All pass!


In [13]:
if curr_task in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "f1"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age IDs shape:", age_ids.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [15]:
# load pretrained model
if config.use_pretrained_model:
    pretrain_exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}-{config.pretrain_cls_ontology_weight}-{config.pretrain_visit_ontology_weight}-{config.pretrain_adm_type_weight}"
)
    print(pretrain_exp_name)
    save_path = "./pretrained_models/" + pretrain_exp_name
    state_dict = torch.load(f"{save_path}/pretrained_model.pt", map_location="cpu")

MIMIC-III-0.7-64-0.05-0.05-0.05


In [None]:
final_metrics = []
for i in range(15):
    model = HeteroGTFineTune(tokenizer=tokenizer, token_types=config.token_type, d_model=config.d_model, num_heads=config.num_heads, layer_types=['gnn', 'tf', 'gnn', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size, attn_mask_dicts=config.attn_mask_dicts,
                     use_cls_cat=True).to(device)
    if config.use_pretrained_model:
        model.load_weight(state_dict)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             dec_loss_lambda=config.dec_loss_lambda, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    final_metrics.append(best_test_metric)



Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 14.51it/s, loss=0.5771]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.02it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.12it/s]


Validation: {'precision': 0.6748027613395592, 'recall': 0.8582627782977164, 'f1': 0.755555550625073, 'auc': 0.8197307731067294, 'prauc': 0.8268807315986392}
Test:      {'precision': 0.6728031418736554, 'recall': 0.859517089994169, 'f1': 0.7547845193734123, 'auc': 0.8155210502394115, 'prauc': 0.8211445440919768}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.12it/s, loss=0.5034]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.13it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.98it/s]


Validation: {'precision': 0.7420430433442193, 'recall': 0.7676387582290134, 'f1': 0.7546239160841912, 'auc': 0.8307834655514907, 'prauc': 0.841366802801847}
Test:      {'precision': 0.7362275449079754, 'recall': 0.7710881153942581, 'f1': 0.7532547047568383, 'auc': 0.8312451459597019, 'prauc': 0.8407584518726899}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.46it/s, loss=0.4497]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.94it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.07it/s]


Validation: {'precision': 0.8009618941886757, 'recall': 0.6788962057049893, 'f1': 0.7348947676045045, 'auc': 0.8393449667628499, 'prauc': 0.8455511790117729}
Test:      {'precision': 0.7926062846551105, 'recall': 0.6723110692986131, 'f1': 0.7275195063987402, 'auc': 0.834269990089327, 'prauc': 0.8442693842294524}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.70it/s, loss=0.4234]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.15it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.00it/s]


Validation: {'precision': 0.7746338215686597, 'recall': 0.7296958294113212, 'f1': 0.7514936168331346, 'auc': 0.8416369189627634, 'prauc': 0.8471270728218124}
Test:      {'precision': 0.7651888341518385, 'recall': 0.7306365631836607, 'f1': 0.7475136299056275, 'auc': 0.8377582550772695, 'prauc': 0.8445053133101323}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.46it/s, loss=0.3735]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.93it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.23it/s]


Validation: {'precision': 0.73978241693402, 'recall': 0.7889620570687083, 'f1': 0.7635811786143899, 'auc': 0.8361162494954972, 'prauc': 0.8407854445834125}
Test:      {'precision': 0.7252018454419689, 'recall': 0.788648479144595, 'f1': 0.7555956086462902, 'auc': 0.8296478112210565, 'prauc': 0.8357787403367725}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.45it/s, loss=0.3234]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.11it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.07it/s]


Validation: {'precision': 0.755810251509851, 'recall': 0.7444339918446397, 'f1': 0.7500789839394658, 'auc': 0.8317846341592561, 'prauc': 0.834855947777735}
Test:      {'precision': 0.7428393524260809, 'recall': 0.7481969269339975, 'f1': 0.7455085092923767, 'auc': 0.8283680810893385, 'prauc': 0.8345375165979549}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 15.64it/s, loss=0.2774]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.34it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.00it/s]


Validation: {'precision': 0.7914831130661106, 'recall': 0.6760740043879709, 'f1': 0.7292406512097505, 'auc': 0.8323373893497443, 'prauc': 0.8341600185584677}
Test:      {'precision': 0.7888970051103401, 'recall': 0.6773283160844236, 'f1': 0.72886788772094, 'auc': 0.8302290125516736, 'prauc': 0.836496413604754}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 15.66it/s, loss=0.2541]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.12it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.04it/s]


Validation: {'precision': 0.7156678312262412, 'recall': 0.8350580119133426, 'f1': 0.7707669993689404, 'auc': 0.834805371194979, 'prauc': 0.8353048768203521}
Test:      {'precision': 0.7032704068048304, 'recall': 0.8294136092793057, 'f1': 0.7611510741683687, 'auc': 0.8296143394201678, 'prauc': 0.8331096524012321}


Epoch 009: 100%|██████████| 98/98 [00:06<00:00, 15.42it/s, loss=0.2284]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.04it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.12it/s]


Validation: {'precision': 0.7570242656425382, 'recall': 0.7434932580723002, 'f1': 0.7501977485180457, 'auc': 0.828766559672392, 'prauc': 0.8314415798035637}
Test:      {'precision': 0.7474747474723881, 'recall': 0.7425525242999607, 'f1': 0.745005500739413, 'auc': 0.8269205892848964, 'prauc': 0.833753598287468}


Epoch 010: 100%|██████████| 98/98 [00:06<00:00, 15.35it/s, loss=0.1879]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.89it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.04it/s]


Validation: {'precision': 0.7322074788880211, 'recall': 0.7613671997467502, 'f1': 0.746502685237889, 'auc': 0.8184408435396868, 'prauc': 0.8144429263408942}
Test:      {'precision': 0.7263473053870468, 'recall': 0.7607400438985239, 'f1': 0.7431459591603, 'auc': 0.8150713704361954, 'prauc': 0.8140408994813709}


Epoch 011: 100%|██████████| 98/98 [00:06<00:00, 15.37it/s, loss=0.1697]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.11it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.32it/s]


Validation: {'precision': 0.7685694635461879, 'recall': 0.7008466603929105, 'f1': 0.733147444573576, 'auc': 0.8271223101600151, 'prauc': 0.8291752632940486}
Test:      {'precision': 0.7669198895001142, 'recall': 0.6964565694553263, 'f1': 0.7299917780823242, 'auc': 0.8292928087974966, 'prauc': 0.8348861507524836}


Epoch 012: 100%|██████████| 98/98 [00:06<00:00, 15.14it/s, loss=0.1471]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.77it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.01it/s]


Validation: {'precision': 0.7771739130406624, 'recall': 0.6726246472227262, 'f1': 0.7211295966372937, 'auc': 0.8227990843645184, 'prauc': 0.8282632114164241}
Test:      {'precision': 0.771681415926472, 'recall': 0.6835998745666868, 'f1': 0.724975053213445, 'auc': 0.8223995808223191, 'prauc': 0.827271242740526}


Epoch 013: 100%|██████████| 98/98 [00:06<00:00, 15.29it/s, loss=0.1420]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.85it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.11it/s]


Validation: {'precision': 0.7297859690822539, 'recall': 0.7698338036978055, 'f1': 0.7492751361580011, 'auc': 0.8214782591833572, 'prauc': 0.8207954513157901}
Test:      {'precision': 0.7248459958910975, 'recall': 0.774851050483616, 'f1': 0.7490148479890417, 'auc': 0.8215522157577186, 'prauc': 0.8234318245176353}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7156678312262412, 'recall': 0.8350580119133426, 'f1': 0.7707669993689404, 'auc': 0.834805371194979, 'prauc': 0.8353048768203521}
Corresponding test performance:
{'precision': 0.7032704068048304, 'recall': 0.8294136092793057, 'f1': 0.7611510741683687, 'auc': 0.8296143394201678, 'prauc': 0.8331096524012321}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 15.67it/s, loss=0.5840]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.98it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.77it/s]


Validation: {'precision': 0.6634684576621169, 'recall': 0.867356538096998, 'f1': 0.7518347327871143, 'auc': 0.8188545433482547, 'prauc': 0.823861663335688}
Test:      {'precision': 0.6689221846286396, 'recall': 0.8679836939452242, 'f1': 0.7555616164821163, 'auc': 0.8177025056537143, 'prauc': 0.8210573354679316}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.60it/s, loss=0.5005]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.97it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.34it/s]


Validation: {'precision': 0.7526847757398841, 'recall': 0.747256193161658, 'f1': 0.7499606558946369, 'auc': 0.8306122586361782, 'prauc': 0.8414151033300019}
Test:      {'precision': 0.7463099630973361, 'recall': 0.761053621822637, 'f1': 0.7536096829347942, 'auc': 0.8316778130280302, 'prauc': 0.8398901846905868}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.48it/s, loss=0.4514]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.88it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.99it/s]


Validation: {'precision': 0.8074824629742499, 'recall': 0.6497334587624656, 'f1': 0.7200694998345454, 'auc': 0.8371970433851788, 'prauc': 0.8475333344821528}
Test:      {'precision': 0.8016910069146745, 'recall': 0.6541235497000498, 'f1': 0.720428245782783, 'auc': 0.8365127014159327, 'prauc': 0.8475629779656834}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.60it/s, loss=0.4219]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.90it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.13it/s]


Validation: {'precision': 0.6890885750945082, 'recall': 0.841643148319719, 'f1': 0.7577639702028021, 'auc': 0.8200837366592486, 'prauc': 0.8285085769365297}
Test:      {'precision': 0.6868354430362359, 'recall': 0.8507369081190005, 'f1': 0.7600504222853905, 'auc': 0.8217120247318859, 'prauc': 0.8302057024536351}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.59it/s, loss=0.3735]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.91it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.85it/s]


Validation: {'precision': 0.7231197771567601, 'recall': 0.8140482909977609, 'f1': 0.7658946697460222, 'auc': 0.8397886076868891, 'prauc': 0.8487314526565466}
Test:      {'precision': 0.7151432469284716, 'recall': 0.8218877391005899, 'f1': 0.7648088657543028, 'auc': 0.8360410258578469, 'prauc': 0.8430934667489693}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.44it/s, loss=0.3285]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.90it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 20.87it/s]


Validation: {'precision': 0.712239935150737, 'recall': 0.8265914079622873, 'f1': 0.7651669035885243, 'auc': 0.8373025406041212, 'prauc': 0.8470954819109424}
Test:      {'precision': 0.715085287844576, 'recall': 0.8413295703956057, 'f1': 0.7730874464065509, 'auc': 0.8401359106116681, 'prauc': 0.8466006909701613}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 15.63it/s, loss=0.3053]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.02it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.77it/s]


Validation: {'precision': 0.785413744737779, 'recall': 0.7024145500134763, 'f1': 0.7415990680142637, 'auc': 0.8301760025174652, 'prauc': 0.8400482952513257}
Test:      {'precision': 0.7754542337991929, 'recall': 0.7093132643439658, 'f1': 0.7409105747651106, 'auc': 0.8339524358159344, 'prauc': 0.8381713191256286}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 15.66it/s, loss=0.2612]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.97it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.14it/s]


Validation: {'precision': 0.7164220437532085, 'recall': 0.8112260896807425, 'f1': 0.7608823479581951, 'auc': 0.8255613029770611, 'prauc': 0.8285882630253218}
Test:      {'precision': 0.7183639857240781, 'recall': 0.8206334274041373, 'f1': 0.7661006975959491, 'auc': 0.8297564813234901, 'prauc': 0.8312874106796706}


Epoch 009: 100%|██████████| 98/98 [00:06<00:00, 15.41it/s, loss=0.2217]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.80it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.08it/s]


Validation: {'precision': 0.704913606909544, 'recall': 0.8187519598594584, 'f1': 0.7575801488049089, 'auc': 0.8266765592915972, 'prauc': 0.8332102992512176}
Test:      {'precision': 0.7016574585616899, 'recall': 0.8363123236097952, 'f1': 0.7630901237915096, 'auc': 0.8311153357724965, 'prauc': 0.8384506516651578}


Epoch 010: 100%|██████████| 98/98 [00:06<00:00, 15.65it/s, loss=0.1944]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.92it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 20.98it/s]


Validation: {'precision': 0.7365960099727662, 'recall': 0.7409846346793949, 'f1': 0.7387837999062853, 'auc': 0.8175110613834061, 'prauc': 0.8255477858869429}
Test:      {'precision': 0.7460757156025052, 'recall': 0.7601128880502975, 'f1': 0.7530288859580204, 'auc': 0.8264841975349655, 'prauc': 0.8326506262981991}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7231197771567601, 'recall': 0.8140482909977609, 'f1': 0.7658946697460222, 'auc': 0.8397886076868891, 'prauc': 0.8487314526565466}
Corresponding test performance:
{'precision': 0.7151432469284716, 'recall': 0.8218877391005899, 'f1': 0.7648088657543028, 'auc': 0.8360410258578469, 'prauc': 0.8430934667489693}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 15.66it/s, loss=0.5814]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.87it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.74it/s]


Validation: {'precision': 0.6671480144388275, 'recall': 0.869238005641677, 'f1': 0.7549019558687664, 'auc': 0.8082656870094442, 'prauc': 0.8081504276269926}
Test:      {'precision': 0.6665048543673143, 'recall': 0.8610849796147347, 'f1': 0.7514023757056935, 'auc': 0.8103684565775109, 'prauc': 0.8126415545532387}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.38it/s, loss=0.5061]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.04it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.23it/s]


Validation: {'precision': 0.7353374594732233, 'recall': 0.7823769206623319, 'f1': 0.7581282235044745, 'auc': 0.8317841820283177, 'prauc': 0.8348473608723351}
Test:      {'precision': 0.7320588235272587, 'recall': 0.7804954531176529, 'f1': 0.7555015885678669, 'auc': 0.8272268185128756, 'prauc': 0.8348934017854026}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.65it/s, loss=0.4513]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.95it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.90it/s]


Validation: {'precision': 0.7153324287633234, 'recall': 0.8265914079622873, 'f1': 0.7669479147212115, 'auc': 0.8376863495339887, 'prauc': 0.845320500484377}
Test:      {'precision': 0.7086110370548957, 'recall': 0.8334901222927767, 'f1': 0.7659942313418203, 'auc': 0.8344826996089588, 'prauc': 0.8431892151447173}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.45it/s, loss=0.4063]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.94it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 20.99it/s]


Validation: {'precision': 0.7293211637172581, 'recall': 0.8018187519573478, 'f1': 0.763853617114977, 'auc': 0.835449456835009, 'prauc': 0.8445395964697192}
Test:      {'precision': 0.7233146067395413, 'recall': 0.8074631545913846, 'f1': 0.7630760062737755, 'auc': 0.8322940464991265, 'prauc': 0.8397296368946243}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.42it/s, loss=0.3711]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.92it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.92it/s]


Validation: {'precision': 0.7610953729909944, 'recall': 0.7582314205056186, 'f1': 0.7596606924528622, 'auc': 0.8371375630484034, 'prauc': 0.8430402211395149}
Test:      {'precision': 0.7518450184478725, 'recall': 0.7666980244566739, 'f1': 0.7591988771592763, 'auc': 0.838637984604985, 'prauc': 0.8449162408632573}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.59it/s, loss=0.3137]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.77it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.10it/s]


Validation: {'precision': 0.6996838777642264, 'recall': 0.8328629664445505, 'f1': 0.7604867523727321, 'auc': 0.8208336711756178, 'prauc': 0.8247186803459964}
Test:      {'precision': 0.7024967148470368, 'recall': 0.8381937911544742, 'f1': 0.7643694545733465, 'auc': 0.8273852684815933, 'prauc': 0.8305979581664524}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 15.56it/s, loss=0.2743]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.94it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.87it/s]


Validation: {'precision': 0.7607223476273437, 'recall': 0.7397303229829423, 'f1': 0.7500794862545556, 'auc': 0.8313640519130715, 'prauc': 0.8398612350296016}
Test:      {'precision': 0.7593840230966975, 'recall': 0.7422389463758475, 'f1': 0.7507136010877096, 'auc': 0.8345492405424546, 'prauc': 0.8414620667795152}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 15.38it/s, loss=0.2397]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.01it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.13it/s]


Validation: {'precision': 0.697963501717276, 'recall': 0.8275321417346267, 'f1': 0.7572453321931513, 'auc': 0.8248372906344894, 'prauc': 0.8308872792261702}
Test:      {'precision': 0.6932023778736284, 'recall': 0.8410159924714926, 'f1': 0.7599886603885482, 'auc': 0.8256196687348716, 'prauc': 0.8313115268862741}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7153324287633234, 'recall': 0.8265914079622873, 'f1': 0.7669479147212115, 'auc': 0.8376863495339887, 'prauc': 0.845320500484377}
Corresponding test performance:
{'precision': 0.7086110370548957, 'recall': 0.8334901222927767, 'f1': 0.7659942313418203, 'auc': 0.8344826996089588, 'prauc': 0.8431892151447173}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 15.22it/s, loss=0.5743]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.26it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.89it/s]


Validation: {'precision': 0.7410296411833354, 'recall': 0.7447475697687528, 'f1': 0.7428839487043076, 'auc': 0.823070865295223, 'prauc': 0.8269596444672752}
Test:      {'precision': 0.739008419080951, 'recall': 0.743179680148187, 'f1': 0.7410881751102927, 'auc': 0.8165844971705003, 'prauc': 0.822400875854453}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.62it/s, loss=0.4955]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.01it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.16it/s]


Validation: {'precision': 0.7648514851458104, 'recall': 0.6782690498567631, 'f1': 0.7189629333569745, 'auc': 0.8208248295039351, 'prauc': 0.8239824906084964}
Test:      {'precision': 0.7579281183905641, 'recall': 0.6745061147674052, 'f1': 0.7137879492206625, 'auc': 0.8142736342372734, 'prauc': 0.8197385962547254}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.45it/s, loss=0.4604]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.76it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.92it/s]


Validation: {'precision': 0.7686809616609205, 'recall': 0.7419253684517343, 'f1': 0.7550662148810802, 'auc': 0.8380871384924288, 'prauc': 0.8482714402004221}
Test:      {'precision': 0.7620115310673863, 'recall': 0.7460018814652054, 'f1': 0.7539217189721881, 'auc': 0.8353581507861847, 'prauc': 0.8448788083983879}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.56it/s, loss=0.4109]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.92it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 20.95it/s]


Validation: {'precision': 0.7639569049926087, 'recall': 0.7337723424247922, 'f1': 0.7485604556522274, 'auc': 0.83589591101822, 'prauc': 0.8435209327038548}
Test:      {'precision': 0.7583870967717472, 'recall': 0.7372216995900369, 'f1': 0.7476546300757425, 'auc': 0.8343333096765717, 'prauc': 0.8423216767752015}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.29it/s, loss=0.3676]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.97it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.69it/s]


Validation: {'precision': 0.6950729009535066, 'recall': 0.8670429601728848, 'f1': 0.7715920140343052, 'auc': 0.8426877717369736, 'prauc': 0.8494583708762548}
Test:      {'precision': 0.6904761904744781, 'recall': 0.8730009407310348, 'f1': 0.7710843324154065, 'auc': 0.8379324091089606, 'prauc': 0.8447891206910372}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.48it/s, loss=0.3284]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.00it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.27it/s]


Validation: {'precision': 0.7435590173734465, 'recall': 0.7783004076488608, 'f1': 0.7605331649098815, 'auc': 0.8385364059348113, 'prauc': 0.8435352931719894}
Test:      {'precision': 0.7426796805656827, 'recall': 0.7873941674481424, 'f1': 0.764383556645779, 'auc': 0.8378219269992606, 'prauc': 0.8416335608529212}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 15.31it/s, loss=0.2924]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.00it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.99it/s]


Validation: {'precision': 0.7538119440890921, 'recall': 0.7441204139205264, 'f1': 0.748934822203148, 'auc': 0.8315174247747057, 'prauc': 0.840970965666193}
Test:      {'precision': 0.7540006275470537, 'recall': 0.7535277516439213, 'f1': 0.7537641104305095, 'auc': 0.8315943096931819, 'prauc': 0.8386212532035349}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 15.57it/s, loss=0.2534]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.84it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.11it/s]


Validation: {'precision': 0.7602452403976759, 'recall': 0.7387895892106028, 'f1': 0.7493638626831193, 'auc': 0.8326194688184885, 'prauc': 0.8413176151428419}
Test:      {'precision': 0.759922555660665, 'recall': 0.7384760112864895, 'f1': 0.7490457965253595, 'auc': 0.8324256183600629, 'prauc': 0.8387005579741498}


Epoch 009:  88%|████████▊ | 86/98 [00:05<00:00, 15.84it/s, loss=0.2189]

In [None]:
def topk_avg_performance_formatted(performances, k=5):
    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # 计算排名（值越大排名越靠前）
    ranks = {m: (-scores[m]).argsort().argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # 选 top-k
    topk_idx = np.argsort(avg_ranks)[:k]
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in performances[0].keys()}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in performances[0].keys()}

    # 打印结果（转百分比，均保留两位小数）
    print("Final Metrics:")
    for m in performances[0].keys():
        mean_val = final_avg[m] * 100
        std_val = final_std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
topk_avg_performance_formatted(final_metrics, 5)

Final Metrics:
precision: 62.22 ± 1.07
recall: 76.43 ± 2.00
f1: 68.57 ± 0.37
auc: 85.29 ± 0.16
prauc: 75.12 ± 0.43
