In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the global random seed once, before any data loading or model construction,
# so the notebook's stochastic steps are repeatable (set_random_seed is a project helper).
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration. Later cells attach further fields to this namespace
# (vocabulary sizes, max_num_adms).
config = Namespace(
    dataset="MIMIC-III",
    # Candidate downstream tasks; `task_index` picks the one trained in this notebook.
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,
    # Code types used to build model inputs, and the tokenizer's special tokens.
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # Optimisation settings.
    batch_size=32,
    lr=1e-3,
    epochs=500,
    early_stop_patience=5,
    # A group (ICD) code is generated once at least this many diagnosis codes
    # belong to the same group code.
    group_code_thre=5,
)

In [5]:
# Full processed table: only used to build the tokenizer vocabularies.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# Fine-tuning splits: the next-diagnosis tasks have their own pickles; every other
# task (death / readmission / stay) shares the generic downstream pickle.
task_specific_paths = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}
finetune_data_path = task_specific_paths.get(
    curr_task, f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"
)

Current task: stay


In [6]:
# Load the full processed table; it is only used to build the tokenizer vocabularies.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# The second item returned by expand_level3(), wrapped as a single "sentence"
# (assumption about expand_level3's return layout -- TODO confirm).
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# The [[...]] nesting matters: each distinct age becomes its own one-element sentence.
# Sorting (by string form, safe for any value type) fixes the vocabulary order. Plain
# set iteration order can change between interpreter runs for string values because of
# hash randomization, which would make age token ids irreproducible.
age_sentences = [[str(c)] for c in sorted(set(ehr_full_data["AGE"].values.tolist()), key=str)]

# Longest patient history (distinct admissions per subject), used to size admission inputs.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from the age, group-code and four code-type "sentences";
# [PAD] and [CLS] are registered as special tokens.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# Record vocabulary sizes on the config for model construction.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # label vocabulary = diagnosis codes only
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the (train, val, test) splits for the selected task.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open(finetune_data_path, "rb") as finetune_file:
    train_data, val_data, test_data = pickle.load(finetune_file)

# Sanity check: positive-label rate of each binary target in the test split.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# All three splits share the tokenizer and the same dataset settings.
dataset_kwargs = dict(
    token_type=config.token_type,
    task=curr_task,
    max_num_adms=config.max_num_adms,
    group_code_thre=config.group_code_thre,
)
train_dataset = FineTuneEHRDataset(train_data, tokenizer, **dataset_kwargs)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, **dataset_kwargs)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, **dataset_kwargs)

In [10]:
# Sanity check: average number of group-code tokens per training patient.
# NOTE(review): token-type id 6 is what this check counts as a group-code token, and
# only row 0 of each sample's token_types is inspected -- confirm both against the dataset code.
def _count_group_tokens(sample):
    """Count entries equal to 6 in row 0 of a training sample's token_types."""
    _, sample_token_types, _, _, _, _ = sample  # samples are 6-tuples
    return (sample_token_types[0] == 6).sum().item()


num_group_code = [_count_group_tokens(train_dataset[idx]) for idx in range(len(train_dataset))]
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader that uses a fresh fine-tuning collate function."""
    return DataLoader(
        dataset,
        collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=False),
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled; validation and test keep a fixed order.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [12]:
# Smoke test: iterate every loader once so collate errors surface before training.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass
print("All pass!")

All pass!


In [13]:
# Pick the task type passed to train_with_early_stopping. Every supported task is
# optimised with BCE-with-logits, and model selection uses F1.
BINARY_TASKS = ("death", "stay", "readmission")
NEXT_DIAG_TASKS = ("next_diag_6m", "next_diag_12m")

if curr_task in BINARY_TASKS:
    task_type = "binary"
elif curr_task in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    # Fail loudly instead of silently treating an unknown task as "l2r".
    raise ValueError(
        f"Unsupported task {curr_task!r}; expected one of {BINARY_TASKS + NEXT_DIAG_TASKS}"
    )

eval_metric = "f1"
loss_fn = F.binary_cross_entropy_with_logits

In [14]:
# Peek at one training batch to confirm the collated shapes.
first_batch = next(iter(train_dataloader))
(input_ids, token_types, adm_index, age_ids,
 diag_code_group_dicts, labels) = first_batch

print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age IDs shape:", age_ids.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [15]:
# Attention-mask specs handed to HeteroGT: two identical dicts, presumably one per
# 'gnn' layer in the layer_types used below (TODO confirm against HeteroGT).
# Keys and values look like token-type ids: types 1-4 map to [6, 7], and types 6 and 7
# map to [2, 3, 4, 5, 6, 7]. Whether a listed type is allowed or blocked is decided
# inside HeteroGT, which is not visible here.
# Each dict (and each inner list) is a fresh object, exactly as with literal copies.
attn_mask_dicts = [
    {
        **{token_type_id: [6, 7] for token_type_id in (1, 2, 3, 4)},
        **{token_type_id: [2, 3, 4, 5, 6, 7] for token_type_id in (6, 7)},
    }
    for _ in range(2)
]

In [None]:
def _build_model():
    """Construct a freshly initialised HeteroGT for the current task on `device`."""
    return HeteroGT(
        tokenizer,
        d_model=64,
        num_heads=4,
        layer_types=['gnn', 'tf', 'gnn', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        attn_mask_dicts=attn_mask_dicts,
        use_cls_cat=True,
    ).to(device)


# Train 10 models in sequence (the global RNG is seeded once near the top, so the
# sequence of runs is repeatable). Keep the metrics dict that train_with_early_stopping
# returns for each run -- presumably the test metrics at the best-validation epoch,
# as its log output suggests.
final_metrics = []
for run_idx in range(10):
    model = _build_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None,
        eval_metric=eval_metric, return_model=False,
    )
    final_metrics.append(best_test_metric)

Epoch 001: 100%|██████████| 98/98 [00:07<00:00, 13.55it/s, loss=0.7189]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.97it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.50it/s]


Validation: {'precision': 0.5480472297901035, 'recall': 0.9460645970494009, 'f1': 0.6940418633081348, 'auc': 0.6809048225190144, 'prauc': 0.6952668485660257}
Test:      {'precision': 0.5453720508157072, 'recall': 0.942301661960043, 'f1': 0.690884005010638, 'auc': 0.664839760674107, 'prauc': 0.6740540901916539}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 16.31it/s, loss=0.6250]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.18it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.14it/s]


Validation: {'precision': 0.6847654190809574, 'recall': 0.8146754468459873, 'f1': 0.7440927918276102, 'auc': 0.7995612119480322, 'prauc': 0.8027942756501161}
Test:      {'precision': 0.6749934262406653, 'recall': 0.8049545311984793, 'f1': 0.7342677295902326, 'auc': 0.7890775725343993, 'prauc': 0.7900197909615001}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 16.30it/s, loss=0.5580]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.20it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.64it/s]


Validation: {'precision': 0.8224508050052663, 'recall': 0.5766698024440995, 'f1': 0.677972345382212, 'auc': 0.8179570634356788, 'prauc': 0.8248015587838707}
Test:      {'precision': 0.8116453462690267, 'recall': 0.5769833803682127, 'f1': 0.6744867986593744, 'auc': 0.8095100181351728, 'prauc': 0.819089370166795}


Epoch 004: 100%|██████████| 98/98 [00:05<00:00, 16.48it/s, loss=0.5341]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.38it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.34it/s]


Validation: {'precision': 0.791131855306141, 'recall': 0.6378174976461656, 'f1': 0.7062499950551052, 'auc': 0.8104633447903313, 'prauc': 0.8180769272161461}
Test:      {'precision': 0.7621560497521216, 'recall': 0.6340545625568076, 'f1': 0.6922286838449236, 'auc': 0.8019092013225639, 'prauc': 0.813566320626925}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 16.30it/s, loss=0.4788]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.44it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.66it/s]


Validation: {'precision': 0.7796914244679595, 'recall': 0.6814048290978947, 'f1': 0.7272422975636992, 'auc': 0.8253373976890482, 'prauc': 0.8325828139910577}
Test:      {'precision': 0.7708484408964074, 'recall': 0.6666666666645762, 'f1': 0.7149823390628689, 'auc': 0.8138634662591664, 'prauc': 0.8221454841980177}


Epoch 006: 100%|██████████| 98/98 [00:05<00:00, 16.38it/s, loss=0.4520]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.33it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.19it/s]


Validation: {'precision': 0.7018261106549419, 'recall': 0.8074631545913846, 'f1': 0.7509477932141968, 'auc': 0.8181661991130398, 'prauc': 0.8255635981487535}
Test:      {'precision': 0.6964732018864302, 'recall': 0.7864534336758029, 'f1': 0.7387334265331695, 'auc': 0.8054365754673343, 'prauc': 0.8160286823341312}


Epoch 007: 100%|██████████| 98/98 [00:05<00:00, 16.40it/s, loss=0.4172]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.36it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.63it/s]


Validation: {'precision': 0.8015444015413068, 'recall': 0.6509877704589182, 'f1': 0.7184633970238911, 'auc': 0.8303334445575442, 'prauc': 0.8352709271745429}
Test:      {'precision': 0.7943925233613925, 'recall': 0.6396989651908445, 'f1': 0.708702444248004, 'auc': 0.8184353115922661, 'prauc': 0.8272012544421933}


Epoch 008: 100%|██████████| 98/98 [00:05<00:00, 16.39it/s, loss=0.3959]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.35it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.32it/s]


Validation: {'precision': 0.6956521739112324, 'recall': 0.837880213230361, 'f1': 0.7601706920536483, 'auc': 0.8160451524078134, 'prauc': 0.8175277639763113}
Test:      {'precision': 0.6801432958017397, 'recall': 0.8334901222927767, 'f1': 0.7490488889480388, 'auc': 0.8088530145005881, 'prauc': 0.8165893868377709}


Epoch 009: 100%|██████████| 98/98 [00:05<00:00, 16.35it/s, loss=0.3784]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.31it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.41it/s]


Validation: {'precision': 0.7361801696380925, 'recall': 0.7892756349928214, 'f1': 0.7618038690957614, 'auc': 0.8265444365840643, 'prauc': 0.8311651226093715}
Test:      {'precision': 0.7159883720909419, 'recall': 0.7723424270907108, 'f1': 0.7430985015670022, 'auc': 0.8162221963843408, 'prauc': 0.8234686307517596}


Epoch 010: 100%|██████████| 98/98 [00:05<00:00, 16.53it/s, loss=0.3659]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.54it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.34it/s]


Validation: {'precision': 0.6840529074103219, 'recall': 0.859517089994169, 'f1': 0.7618121129057382, 'auc': 0.8222882266409817, 'prauc': 0.8266865106385122}
Test:      {'precision': 0.6757979391790003, 'recall': 0.8432110379402847, 'f1': 0.750279012915783, 'auc': 0.8117049126033662, 'prauc': 0.8172713765474119}


Epoch 011: 100%|██████████| 98/98 [00:05<00:00, 16.46it/s, loss=0.3324]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.47it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.69it/s]


Validation: {'precision': 0.7567304573442857, 'recall': 0.731577296956, 'f1': 0.7439413215296683, 'auc': 0.8221078263965897, 'prauc': 0.8250479286761008}
Test:      {'precision': 0.7378826530588716, 'recall': 0.72561931639785, 'f1': 0.7316995997411205, 'auc': 0.8128921800309753, 'prauc': 0.8164375288668286}


Epoch 012: 100%|██████████| 98/98 [00:05<00:00, 16.46it/s, loss=0.3203]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.43it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.41it/s]


Validation: {'precision': 0.7762212643650279, 'recall': 0.6776418940085367, 'f1': 0.7235894810409901, 'auc': 0.8173687406113754, 'prauc': 0.8197216231229482}
Test:      {'precision': 0.7616855524052348, 'recall': 0.6745061147674052, 'f1': 0.7154498536556582, 'auc': 0.8076749581099154, 'prauc': 0.8098434335690865}


Epoch 013: 100%|██████████| 98/98 [00:05<00:00, 16.43it/s, loss=0.3023]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.47it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.74it/s]


Validation: {'precision': 0.7390357698266282, 'recall': 0.7450611476928659, 'f1': 0.7420362223556662, 'auc': 0.8064740930479443, 'prauc': 0.8006381269142404}
Test:      {'precision': 0.7260104905870842, 'recall': 0.7378488554382633, 'f1': 0.7318817990415964, 'auc': 0.8020405215158246, 'prauc': 0.801447729511845}


Epoch 014: 100%|██████████| 98/98 [00:05<00:00, 16.46it/s, loss=0.2838]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.46it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.41it/s]


Validation: {'precision': 0.686675291071962, 'recall': 0.8322358105963241, 'f1': 0.7524808569660991, 'auc': 0.8081356742462902, 'prauc': 0.80703658205213}
Test:      {'precision': 0.6778697001016601, 'recall': 0.8222013170247031, 'f1': 0.7430919604685836, 'auc': 0.8015190660914551, 'prauc': 0.8019501091483778}


Epoch 015: 100%|██████████| 98/98 [00:05<00:00, 16.53it/s, loss=0.2622]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.47it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.70it/s]


Validation: {'precision': 0.7028068244339494, 'recall': 0.8008780181850083, 'f1': 0.748644286386508, 'auc': 0.8022990154899055, 'prauc': 0.7933333319719723}
Test:      {'precision': 0.6922653454426307, 'recall': 0.788648479144595, 'f1': 0.7373204289093725, 'auc': 0.7988607507649439, 'prauc': 0.7972383259985933}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6840529074103219, 'recall': 0.859517089994169, 'f1': 0.7618121129057382, 'auc': 0.8222882266409817, 'prauc': 0.8266865106385122}
Corresponding test performance:
{'precision': 0.6757979391790003, 'recall': 0.8432110379402847, 'f1': 0.750279012915783, 'auc': 0.8117049126033662, 'prauc': 0.8172713765474119}


Epoch 001: 100%|██████████| 98/98 [00:05<00:00, 16.53it/s, loss=0.6869]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.44it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.38it/s]


Validation: {'precision': 0.6980332829025778, 'recall': 0.7234242709290579, 'f1': 0.7105019968472668, 'auc': 0.7709040981047777, 'prauc': 0.7726237842208918}
Test:      {'precision': 0.6742490521998418, 'recall': 0.7249921605496238, 'f1': 0.698700508754842, 'auc': 0.7582330311811216, 'prauc': 0.7567238090204421}


Epoch 002: 100%|██████████| 98/98 [00:05<00:00, 16.41it/s, loss=0.6159]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.31it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.53it/s]


Validation: {'precision': 0.7982967279211193, 'recall': 0.5584822828455362, 'f1': 0.6571955671095022, 'auc': 0.7970321421893526, 'prauc': 0.7990587085606558}
Test:      {'precision': 0.7708245243096371, 'recall': 0.571652555658289, 'f1': 0.6564638049744546, 'auc': 0.7903113984820412, 'prauc': 0.7964382877151437}


Epoch 003: 100%|██████████| 98/98 [00:05<00:00, 16.36it/s, loss=0.5531]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.21it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.14it/s]


Validation: {'precision': 0.6285714285700681, 'recall': 0.910630291624614, 'f1': 0.743757198393044, 'auc': 0.8105561823430008, 'prauc': 0.8160659398929025}
Test:      {'precision': 0.6237942122173123, 'recall': 0.9125117591692928, 'f1': 0.7410236773748538, 'auc': 0.8048526561258178, 'prauc': 0.8152749485076776}


Epoch 004: 100%|██████████| 98/98 [00:05<00:00, 16.36it/s, loss=0.5303]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.18it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.41it/s]


Validation: {'precision': 0.7269575267242793, 'recall': 0.7889620570687083, 'f1': 0.7566917243293977, 'auc': 0.8246565387327012, 'prauc': 0.8342729117148147}
Test:      {'precision': 0.7129629629609625, 'recall': 0.7968015051715371, 'f1': 0.7525544152708529, 'auc': 0.8214961441995383, 'prauc': 0.8344356154795632}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 16.20it/s, loss=0.4988]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.33it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.39it/s]


Validation: {'precision': 0.7684210526290512, 'recall': 0.7325180307283395, 'f1': 0.7500401298535592, 'auc': 0.8304771719591606, 'prauc': 0.8393102556746204}
Test:      {'precision': 0.7532299741577738, 'recall': 0.7312637190318869, 'f1': 0.742084322763252, 'auc': 0.8275398427681029, 'prauc': 0.8406730174585169}


Epoch 006: 100%|██████████| 98/98 [00:05<00:00, 16.51it/s, loss=0.4592]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.46it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.38it/s]


Validation: {'precision': 0.6710747775891535, 'recall': 0.875195986199827, 'f1': 0.7596624882804912, 'auc': 0.8284442907869078, 'prauc': 0.8365107195804611}
Test:      {'precision': 0.6750911300105101, 'recall': 0.8711194731863559, 'f1': 0.7606790750344712, 'auc': 0.8279638021348465, 'prauc': 0.8370332184395309}


Epoch 007: 100%|██████████| 98/98 [00:05<00:00, 16.34it/s, loss=0.4473]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.44it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.38it/s]


Validation: {'precision': 0.6727359846653744, 'recall': 0.8805268109097506, 'f1': 0.762732576915469, 'auc': 0.8318543125605291, 'prauc': 0.8379244002233328}
Test:      {'precision': 0.6703587767862501, 'recall': 0.8730009407310348, 'f1': 0.7583764592628298, 'auc': 0.8254213546062987, 'prauc': 0.8350442311491446}


Epoch 008: 100%|██████████| 98/98 [00:05<00:00, 16.50it/s, loss=0.4178]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.43it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.68it/s]


Validation: {'precision': 0.697532434493265, 'recall': 0.8598306679182821, 'f1': 0.7702247141532622, 'auc': 0.8340511665530812, 'prauc': 0.8428543056071005}
Test:      {'precision': 0.6927056389014674, 'recall': 0.83976168077504, 'f1': 0.7591778830665686, 'auc': 0.8259588160948526, 'prauc': 0.8342990491455897}


Epoch 009: 100%|██████████| 98/98 [00:05<00:00, 16.50it/s, loss=0.3800]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.43it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.41it/s]


Validation: {'precision': 0.7915419161647024, 'recall': 0.6632173094993314, 'f1': 0.7217198380666422, 'auc': 0.8251140450055154, 'prauc': 0.8348302474457654}
Test:      {'precision': 0.7870894677206981, 'recall': 0.6538099717759367, 'f1': 0.7142857093260463, 'auc': 0.819066443789773, 'prauc': 0.8279939905402085}


Epoch 010:  76%|███████▌  | 74/98 [00:04<00:01, 16.35it/s, loss=0.3667]

In [None]:
def topk_avg_performance_formatted(performances, k=5, metrics=("f1", "auc", "prauc")):
    """Print (and return) mean ± std of every metric over the k best runs.

    Runs are ranked separately on each metric in ``metrics`` (higher is better,
    rank 1 = best). The k runs with the lowest average rank are selected. Mean and
    population standard deviation (ddof=0) of every key of ``performances[0]`` are
    then printed as ``name: mean±std`` over those runs. Ties are broken by run
    order (stable sort).

    NOTE(review): when ``performances`` are test-set results, picking the top-k runs
    by those same test metrics is an optimistic selection. Consider reporting all
    runs, or selecting on validation metrics instead.

    Args:
        performances: non-empty sequence of dicts mapping metric name -> number;
            every dict must contain the keys of ``performances[0]`` and ``metrics``.
        k: number of runs to aggregate; if larger than the number of runs, all
            runs are used.
        metrics: metric names used for ranking.

    Returns:
        ``(final_avg, final_std)``: dicts mapping each key of ``performances[0]``
        to its mean / std over the selected runs.

    Raises:
        ValueError: if ``performances`` or ``metrics`` is empty, or ``k < 1``.
    """
    performances = list(performances)
    metrics = list(metrics)
    if not performances:
        raise ValueError("performances must contain at least one run")
    if not metrics:
        raise ValueError("metrics must name at least one ranking metric")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    scores = {m: np.array([p[m] for p in performances], dtype=float) for m in metrics}

    # Per-metric rank: negate so larger scores sort first; the double argsort turns
    # sort order into each run's 1-based rank.
    ranks = {
        m: (-scores[m]).argsort(kind="stable").argsort(kind="stable") + 1
        for m in metrics
    }
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Select the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    keys = list(performances[0].keys())
    selected = {m: [performances[i][m] for i in topk_idx] for m in keys}
    final_avg = {m: np.mean(values) for m, values in selected.items()}
    final_std = {m: np.std(values, ddof=0) for m, values in selected.items()}

    print("Final Metrics:")
    for m in keys:
        print(f"{m}: {final_avg[m]:.4f}±{final_std[m]:.4f}")
    return final_avg, final_std

In [None]:
# Aggregate the 5 best training runs; the trailing ';' suppresses display of any return value.
topk_avg_performance_formatted(final_metrics, k=5);

Final Metrics:
precision: 0.6955±0.0287
recall: 0.8405±0.0428
f1: 0.7595±0.0028
auc: 0.8283±0.0041
prauc: 0.8352±0.0060
