In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Select the compute device: CUDA when torch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Single configuration mapping for the run. Later cells read from it and also
# add derived entries to it (e.g. vocabulary sizes), so it is a plain mutable dict.
args = dict(
    # --- data / task ---
    dataset="MIMIC-III",
    task="readmission",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med", "lab", "pro"],
    # --- optimisation ---
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=10,
    # --- model hyperparameters ---
    level="visit",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=128,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Directory holding the processed pickles for the configured dataset.
processed_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"

# Full table, used to build the vocabularies.
full_data_path = f"{processed_dir}/mimic.pkl"

# The two next-diagnosis tasks each have a dedicated split file; every other
# task falls back to the shared downstream file.
finetune_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{processed_dir}/{finetune_files.get(args['task'], 'mimic_downstream.pkl')}"

In [5]:
# Load the full processed EHR table. The context manager closes the file handle
# as soon as unpickling finishes instead of leaving it to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project's own preprocessing.
with open(full_data_path, 'rb') as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# Per-column value lists (presumably ehr_full_data is a pandas DataFrame --
# TODO confirm); they are handed to EHRTokenizer in the next cell.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Age/gender tokens: one "<age>_<gender>" string per distinct AGE value and
# gender, preceded by a PAD token special to the age_gender vocabulary.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender \
    for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]

# Largest number of distinct admissions (HADM_ID) observed for any patient
# (SUBJECT_ID); stored in args under "max_adm_len".
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the tokenizer from the age/gender, diagnosis, medication, lab and
# procedure corpora prepared above, plus the configured special tokens.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)

# Record the derived vocabulary sizes in args; the model is later constructed
# from args and the tokenizer.
args.update({
    "age_gender_vocab_size": tokenizer.token_number("age_gender"),
    "global_vocab_size": len(tokenizer.vocab.id2word),
    "label_vocab_size": len(PHENO_ORDER),
})

print(f"Age and gender vocabulary size: {args['age_gender_vocab_size']}")
print(f"Global vocabulary size: {args['global_vocab_size']}")
print(f"Label vocabulary size: {args['label_vocab_size']}")

Age and gender vocabulary size: 37
Global vocabulary size: 4446
Label vocabulary size: 18


In [7]:
# Load the pre-split fine-tuning data. The context manager closes the file
# handle once unpickling is done.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project's own preprocessing.
with open(finetune_data_path, 'rb') as finetune_file:
    train_data, val_data, test_data = pickle.load(finetune_file)

# Wrap each split in a FinetuneHGDataset with identical tokenizer, token-type,
# task and level settings.
# output: input_ids (a patient has multiple visits), labels
train_dataset = FinetuneHGDataset(train_data, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
val_dataset = FinetuneHGDataset(val_data, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
test_dataset = FinetuneHGDataset(test_data, tokenizer, token_type=args["predicted_token_type"], task=args["task"], level=args["level"])
print(len(train_dataset), len(val_dataset), len(test_dataset))

3131 6310 6304


In [8]:
# Minimum number of entries in a record for it to count as a "long" sequence.
long_adm_seq_crite = 3


def _long_seq_indices(dataset, min_adms):
    """Return the dataset positions whose record holds at least ``min_adms`` entries.

    Position ``i`` refers to the i-th key of ``dataset.records`` (in the dict's
    iteration order), for ``i`` in ``range(len(dataset))``.
    NOTE(review): this assumes dataset index ``i`` corresponds to the i-th key
    of ``dataset.records`` (presumably HADM_IDs) -- confirm against
    FinetuneHGDataset.

    Args:
        dataset: object exposing ``__len__`` and a ``records`` mapping whose
            values support ``len()``.
        min_adms: inclusive lower bound on the record length.

    Returns:
        List of integer positions, in ascending order.

    Raises:
        IndexError: if ``len(dataset)`` exceeds the number of records.
    """
    # Materialise the key list once; rebuilding it on every iteration made the
    # original loop quadratic in the number of records.
    record_keys = list(dataset.records.keys())
    return [
        idx
        for idx in range(len(dataset))
        if len(dataset.records[record_keys[idx]]) >= min_adms
    ]


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

835 854


In [9]:
# With use_full_graph enabled, each split is served as a single batch containing
# the whole dataset; otherwise the configured mini-batch size is used.
use_full_graph = True

if use_full_graph:
    train_batch_size = len(train_dataset)
    val_batch_size = len(val_dataset)
    test_batch_size = len(test_dataset)
else:
    train_batch_size = val_batch_size = test_batch_size = args["batch_size"]

# Each loader gets its own collate function bound to the target device.
# Only the training loader shuffles.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)

In [10]:
# Both task families use the same selection metric and the same loss, so these
# are set once instead of being duplicated per branch. The former `else` branch
# wrapped the loss in a lambda that only forwarded its two arguments; the
# function is now bound directly.
# NOTE(review): the training cell in this notebook passes eval_metric="f1"
# literally rather than this variable -- confirm which metric is intended.
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits

# "death", "stay" and "readmission" are handled as binary tasks; any other task
# (the next-diagnosis variants listed in args) uses the "l2r" task type.
task_type = "binary" if args["task"] in ("death", "stay", "readmission") else "l2r"

In [11]:
# Derive a reproducible list of per-run seeds from a fixed master seed.
# NOTE(review): this draws 15 seeds, but the saved cell output lists only 5 --
# the stored output looks stale; re-run the cell or confirm the intended count.
NUM_SEEDS = 15
MAX_SEED = 2**32 - 1  # inclusive upper bound for each drawn seed

random.seed(42)
seeds = [random.randint(0, MAX_SEED) for _ in range(NUM_SEEDS)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [None]:
# Per-seed results: one entry per run for the full test set and for the
# long-sequence test subset.
final_metrics = []
final_long_seq_metrics = []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every seed so runs do not share state.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    # NOTE(review): eval_metric is hard-coded to "f1" here although an earlier
    # cell defines eval_metric = "prauc"; confirm which metric should drive
    # early stopping / model selection.
    run_result = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        eval_metric="f1",
    )
    best_test_metric, best_test_long_seq_metric = run_result

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.71s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  2.00s/it]


Epoch: 001, Average Loss: 0.6949
Validation: {'precision': 0.6059539052477403, 'recall': 0.7482213438705604, 'f1': 0.6696144273131643, 'auc': 0.7607799527364746, 'prauc': 0.6554701514531603}
Test:       {'precision': 0.5966016362473361, 'recall': 0.7491110233079846, 'f1': 0.6642143933803359, 'auc': 0.754512688305091, 'prauc': 0.6336997801174096}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.20s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]


Epoch: 002, Average Loss: 0.6811
Validation: {'precision': 0.6364710515926094, 'recall': 0.638735177863088, 'f1': 0.6376010997518895, 'auc': 0.7646097099357969, 'prauc': 0.6611181803409429}
Test:       {'precision': 0.6386554621823183, 'recall': 0.630580798101025, 'f1': 0.634592440325711, 'auc': 0.760234894883618, 'prauc': 0.6404378899702141}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  2.00s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.23s/it]


Epoch: 003, Average Loss: 0.6709
Validation: {'precision': 0.6445378151233423, 'recall': 0.6063241106695403, 'f1': 0.6248472455112863, 'auc': 0.7648252190643495, 'prauc': 0.6626842004470748}
Test:       {'precision': 0.6440823327588161, 'recall': 0.5934413275361776, 'f1': 0.6177256787401527, 'auc': 0.7609542023462471, 'prauc': 0.6422377374122807}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.29s/it]


Epoch: 004, Average Loss: 0.6610
Validation: {'precision': 0.6470337174534911, 'recall': 0.5992094861636396, 'f1': 0.6222039761252692, 'auc': 0.7646612083568605, 'prauc': 0.6625033770972801}
Test:       {'precision': 0.6453309951747246, 'recall': 0.5815883050154816, 'f1': 0.6118038187848517, 'auc': 0.760962527421699, 'prauc': 0.6428918679221126}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.29s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Epoch: 005, Average Loss: 0.6510
Validation: {'precision': 0.6457707170431698, 'recall': 0.5944664031597057, 'f1': 0.61905741422807, 'auc': 0.7645927180709791, 'prauc': 0.66213109410354}
Test:       {'precision': 0.6461742591744972, 'recall': 0.5772421967578932, 'f1': 0.609766272141876, 'auc': 0.7608350333416654, 'prauc': 0.6439579909000015}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]


Epoch: 006, Average Loss: 0.6430
Validation: {'precision': 0.6462526766567612, 'recall': 0.5964426877446781, 'f1': 0.6203494297434067, 'auc': 0.7644622205491771, 'prauc': 0.662049377217506}
Test:       {'precision': 0.6459161147874353, 'recall': 0.5780323982592729, 'f1': 0.6100917381321026, 'auc': 0.7607766007366069, 'prauc': 0.6447740903006746}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.29s/it]


Epoch: 007, Average Loss: 0.6337
Validation: {'precision': 0.6465333900440386, 'recall': 0.6007905138316174, 'f1': 0.6228231869730314, 'auc': 0.7644426668339711, 'prauc': 0.6619418800180537}
Test:       {'precision': 0.6454305799620148, 'recall': 0.5804030027634121, 'f1': 0.611192006661205, 'auc': 0.7608224671900399, 'prauc': 0.645298651014421}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.19s/it]


Epoch: 008, Average Loss: 0.6260
Validation: {'precision': 0.646186440675228, 'recall': 0.6027667984165899, 'f1': 0.6237218763940852, 'auc': 0.7645604073864944, 'prauc': 0.6620016087011016}
Test:       {'precision': 0.6446280991707498, 'recall': 0.5855393125223803, 'f1': 0.6136645912822869, 'auc': 0.7609947281852393, 'prauc': 0.6456850993866612}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]


Epoch: 009, Average Loss: 0.6188
Validation: {'precision': 0.646166807073942, 'recall': 0.6063241106695403, 'f1': 0.6256117405163745, 'auc': 0.7646902775163644, 'prauc': 0.6621164502919963}
Test:       {'precision': 0.6441850410694155, 'recall': 0.5887001185278993, 'f1': 0.6151940495079998, 'auc': 0.7611629051811605, 'prauc': 0.6461198554389364}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]


Epoch: 010, Average Loss: 0.6110
Validation: {'precision': 0.6457023060769572, 'recall': 0.6086956521715071, 'f1': 0.6266530977484956, 'auc': 0.7648988853336679, 'prauc': 0.6623674716041775}
Test:       {'precision': 0.6446067898554164, 'recall': 0.592651126034798, 'f1': 0.6175380765213014, 'auc': 0.7613932846276277, 'prauc': 0.6467358804161426}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.23s/it]


Epoch: 011, Average Loss: 0.6051
Validation: {'precision': 0.6461538461511594, 'recall': 0.61422924900943, 'f1': 0.6297872290432088, 'auc': 0.7651560114603594, 'prauc': 0.6625739391485383}
Test:       {'precision': 0.6437713310552741, 'recall': 0.5962070327910067, 'f1': 0.6190769180817404, 'auc': 0.7616617814006923, 'prauc': 0.6474230899523232}

Early stopping triggered after 11 epochs (no improvement for 10 epochs).

Best validation performance:
{'precision': 0.6059539052477403, 'recall': 0.7482213438705604, 'f1': 0.6696144273131643, 'auc': 0.7607799527364746, 'prauc': 0.6554701514531603}
Corresponding test performance:
{'precision': 0.5966016362473361, 'recall': 0.7491110233079846, 'f1': 0.6642143933803359, 'auc': 0.754512688305091, 'prauc': 0.6336997801174096}
Corresponding test-long performance:
{'precision': 0.7716436637293395, 'recall': 0.9461538461392899, 'f1': 0.850034549290026, 'auc': 0.5609426847662141, 'prauc': 0.7855621052055546}
[INFO] Random seed set to 1181241943
Trainin

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.20s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]


Epoch: 001, Average Loss: 0.6971
Validation: {'precision': 0.5251030802800749, 'recall': 0.8557312252930603, 'f1': 0.6508342054154711, 'auc': 0.7546983813288162, 'prauc': 0.6490277878249386}
Test:       {'precision': 0.5191670610494105, 'recall': 0.8668510470135644, 'f1': 0.6494006168903321, 'auc': 0.7465282602801854, 'prauc': 0.6240625369929516}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]


Epoch: 002, Average Loss: 0.6843
Validation: {'precision': 0.6434968017030128, 'recall': 0.5964426877446781, 'f1': 0.6190769180815839, 'auc': 0.7601032059727711, 'prauc': 0.6560659298312884}
Test:       {'precision': 0.6413556413528899, 'recall': 0.5906756222813486, 'f1': 0.6149732570380165, 'auc': 0.755174715059894, 'prauc': 0.6369001991169612}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.49s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]


Epoch: 003, Average Loss: 0.6743
Validation: {'precision': 0.6631684157887898, 'recall': 0.5245059288516818, 'f1': 0.5857426567296607, 'auc': 0.7610208189556016, 'prauc': 0.6592176819135301}
Test:       {'precision': 0.6531234128966322, 'recall': 0.5080995653871667, 'f1': 0.5715555506310016, 'auc': 0.757992255690189, 'prauc': 0.6412941785467695}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]


Epoch: 004, Average Loss: 0.6654
Validation: {'precision': 0.6740780911026352, 'recall': 0.49130434782414506, 'f1': 0.5683584770591176, 'auc': 0.7601840349666437, 'prauc': 0.6600793825548477}
Test:       {'precision': 0.6601779755246932, 'recall': 0.4689845910688701, 'f1': 0.5483945435353665, 'auc': 0.7578912552464993, 'prauc': 0.6429434626787971}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  2.00s/it]


Epoch: 005, Average Loss: 0.6554
Validation: {'precision': 0.6807195053362524, 'recall': 0.4786561264803215, 'f1': 0.5620793639123239, 'auc': 0.7594169960474308, 'prauc': 0.6607998279445287}
Test:       {'precision': 0.6657077100076773, 'recall': 0.45713156854817416, 'f1': 0.5420473130430287, 'auc': 0.7572832105847208, 'prauc': 0.6438284476787166}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.00s/it]


Epoch: 006, Average Loss: 0.6472
Validation: {'precision': 0.6802952867650183, 'recall': 0.47351778655939325, 'f1': 0.5583779956240748, 'auc': 0.7592429470690341, 'prauc': 0.6613146014613582}
Test:       {'precision': 0.6691816598916472, 'recall': 0.4555511655454147, 'f1': 0.5420780393715355, 'auc': 0.7569679048968513, 'prauc': 0.6449591605883467}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


Epoch: 007, Average Loss: 0.6380
Validation: {'precision': 0.6795952782423856, 'recall': 0.4778656126463326, 'f1': 0.5611510742859649, 'auc': 0.7591848610326871, 'prauc': 0.661644663215896}
Test:       {'precision': 0.6672423719017431, 'recall': 0.45792177004955387, 'f1': 0.5431115228181115, 'auc': 0.7567327084255941, 'prauc': 0.6453308211367683}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.66s/it]


Epoch: 008, Average Loss: 0.6275
Validation: {'precision': 0.6771349862221646, 'recall': 0.48577075098622224, 'f1': 0.5657077051442986, 'auc': 0.7592438881569317, 'prauc': 0.6619105626264001}
Test:       {'precision': 0.6647759500813116, 'recall': 0.4630580798085221, 'f1': 0.5458779644168336, 'auc': 0.7567510340633813, 'prauc': 0.6458928885589871}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.28s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.29s/it]


Epoch: 009, Average Loss: 0.6182
Validation: {'precision': 0.6788596019328733, 'recall': 0.4988142292470403, 'f1': 0.5750740438725037, 'auc': 0.7595426312817618, 'prauc': 0.6621137223324778}
Test:       {'precision': 0.6618222470617003, 'recall': 0.4677992888168005, 'f1': 0.5481481432931168, 'auc': 0.7571313695859128, 'prauc': 0.6463684634491328}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]


Epoch: 010, Average Loss: 0.6116
Validation: {'precision': 0.6743697478956179, 'recall': 0.5075098814209189, 'f1': 0.5791610235138286, 'auc': 0.7601792249618338, 'prauc': 0.6621686606608501}
Test:       {'precision': 0.6566757493152225, 'recall': 0.47609640458128766, 'f1': 0.5519926657612735, 'auc': 0.7578096276198986, 'prauc': 0.646882946927554}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]


Epoch: 011, Average Loss: 0.6029
Validation: {'precision': 0.6692386831241295, 'recall': 0.5142292490098253, 'f1': 0.5815824716142461, 'auc': 0.7609604324821716, 'prauc': 0.6625223447507407}
Test:       {'precision': 0.6557203389795778, 'recall': 0.4891347293540532, 'f1': 0.5603077570404171, 'auc': 0.7585861110724237, 'prauc': 0.6473650392304368}

Early stopping triggered after 11 epochs (no improvement for 10 epochs).

Best validation performance:
{'precision': 0.5251030802800749, 'recall': 0.8557312252930603, 'f1': 0.6508342054154711, 'auc': 0.7546983813288162, 'prauc': 0.6490277878249386}
Corresponding test performance:
{'precision': 0.5191670610494105, 'recall': 0.8668510470135644, 'f1': 0.6494006168903321, 'auc': 0.7465282602801854, 'prauc': 0.6240625369929516}
Corresponding test-long performance:
{'precision': 0.767386091117897, 'recall': 0.9846153846002367, 'f1': 0.8625336877876142, 'auc': 0.5580844645550529, 'prauc': 0.7885898735615631}
[INFO] Random seed set to 958682846
Train

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]


Epoch: 001, Average Loss: 0.6894
Validation: {'precision': 0.6997792494429744, 'recall': 0.3758893280617554, 'f1': 0.4890717362582136, 'auc': 0.7607169521299957, 'prauc': 0.6531940618122754}
Test:       {'precision': 0.6777531411627661, 'recall': 0.36230738838260645, 'f1': 0.4721936102875818, 'auc': 0.7579217805231562, 'prauc': 0.6375031150000732}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]


Epoch: 002, Average Loss: 0.6764
Validation: {'precision': 0.6909090909044384, 'recall': 0.40553359683634177, 'f1': 0.5110834324470013, 'auc': 0.7625178283873936, 'prauc': 0.6579796584347243}
Test:       {'precision': 0.6749663526199532, 'recall': 0.39628605294193486, 'f1': 0.4993776403446023, 'auc': 0.759503230705224, 'prauc': 0.6430286626326729}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch: 003, Average Loss: 0.6647
Validation: {'precision': 0.6937106918195365, 'recall': 0.4359683794449171, 'f1': 0.5354368884615587, 'auc': 0.763399104920844, 'prauc': 0.6596455999883414}
Test:       {'precision': 0.6747759282927351, 'recall': 0.416436191227118, 'f1': 0.5150256488325757, 'auc': 0.7601898661736268, 'prauc': 0.6451299036266667}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]


Epoch: 004, Average Loss: 0.6539
Validation: {'precision': 0.686977299876422, 'recall': 0.45454545454365797, 'f1': 0.5470979971076435, 'auc': 0.7638530752661188, 'prauc': 0.6605886608622957}
Test:       {'precision': 0.6699147381201589, 'recall': 0.4346108257588518, 'f1': 0.5271986532641749, 'auc': 0.76056130067209, 'prauc': 0.6459306135183734}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]


Epoch: 005, Average Loss: 0.6456
Validation: {'precision': 0.6779852857912961, 'recall': 0.47351778655939325, 'f1': 0.5575983195674336, 'auc': 0.7642525670786541, 'prauc': 0.6612485193713318}
Test:       {'precision': 0.6621936989459616, 'recall': 0.4484393520329971, 'f1': 0.5347467560778589, 'auc': 0.7609506419366199, 'prauc': 0.646722294318443}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.53s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]


Epoch: 006, Average Loss: 0.6365
Validation: {'precision': 0.6756164383524624, 'recall': 0.4873517786542002, 'f1': 0.5662456897323336, 'auc': 0.7645337432293955, 'prauc': 0.6617277477303545}
Test:       {'precision': 0.6620267260542204, 'recall': 0.46977479257024984, 'f1': 0.5495724471870252, 'auc': 0.7612829119291838, 'prauc': 0.647516358096793}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Epoch: 007, Average Loss: 0.6277
Validation: {'precision': 0.6696288551977542, 'recall': 0.5063241106699355, 'f1': 0.5766374022511551, 'auc': 0.7647964113181505, 'prauc': 0.6622904653575228}
Test:       {'precision': 0.6609349811893556, 'recall': 0.4859739233485343, 'f1': 0.5601092847312936, 'auc': 0.761554550240155, 'prauc': 0.6480629326278394}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]


Epoch: 008, Average Loss: 0.6195
Validation: {'precision': 0.6664974619255508, 'recall': 0.518972332013759, 'f1': 0.5835555506303942, 'auc': 0.7649811259593868, 'prauc': 0.6626661550870936}
Test:       {'precision': 0.654564315349302, 'recall': 0.49861714737060997, 'f1': 0.5660461937881598, 'auc': 0.7617775994315074, 'prauc': 0.6484509756207005}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.52it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch: 009, Average Loss: 0.6113
Validation: {'precision': 0.6624203821623595, 'recall': 0.534387351776544, 'f1': 0.591555453378852, 'auc': 0.7651641152728109, 'prauc': 0.6628775518022783}
Test:       {'precision': 0.6530303030270049, 'recall': 0.5108652706419958, 'f1': 0.5732653464353905, 'auc': 0.7620152567741244, 'prauc': 0.648783008656223}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]


Epoch: 010, Average Loss: 0.6041
Validation: {'precision': 0.6593824227997179, 'recall': 0.5486166007883454, 'f1': 0.5989212463878902, 'auc': 0.7653714160235899, 'prauc': 0.6630400852052301}
Test:       {'precision': 0.6500974658837715, 'recall': 0.5270644014202803, 'f1': 0.5821514242469289, 'auc': 0.7622631241149372, 'prauc': 0.6494627344312521}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.55s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]


Epoch: 011, Average Loss: 0.5987
Validation: {'precision': 0.6546463245462106, 'recall': 0.559683794464191, 'f1': 0.6034519447403434, 'auc': 0.7656081519124998, 'prauc': 0.6631725950887173}
Test:       {'precision': 0.646117647055783, 'recall': 0.5424733306971851, 'f1': 0.5897766273378909, 'auc': 0.7625503130385447, 'prauc': 0.65004787643276}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


Epoch: 012, Average Loss: 0.5922
Validation: {'precision': 0.6532912533784794, 'recall': 0.572727272725009, 'f1': 0.6103622528117741, 'auc': 0.7658594223811616, 'prauc': 0.6631108598489659}
Test:       {'precision': 0.6460055096389072, 'recall': 0.5559067562206405, 'f1': 0.5975790988692627, 'auc': 0.7628724253918782, 'prauc': 0.6506447155857673}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.00s/it]


Epoch: 013, Average Loss: 0.5875
Validation: {'precision': 0.6524633821542279, 'recall': 0.5810276679818932, 'f1': 0.6146769759884836, 'auc': 0.7661473952778302, 'prauc': 0.6633698451416734}
Test:       {'precision': 0.6436626071238446, 'recall': 0.5638087712344377, 'f1': 0.601095192997432, 'auc': 0.7632295135339024, 'prauc': 0.6513277903341583}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]


Epoch: 014, Average Loss: 0.5848
Validation: {'precision': 0.65004374452909, 'recall': 0.587351778653805, 'f1': 0.6171096295617667, 'auc': 0.7664766714766715, 'prauc': 0.6637012485389953}
Test:       {'precision': 0.6441964285685527, 'recall': 0.5701303832454756, 'f1': 0.604904627168654, 'auc': 0.7635951885462041, 'prauc': 0.6518449373180761}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.68s/it]


Epoch: 015, Average Loss: 0.5810
Validation: {'precision': 0.6498696785375766, 'recall': 0.5913043478237499, 'f1': 0.6192052930218146, 'auc': 0.7668108622456449, 'prauc': 0.6641052934335098}
Test:       {'precision': 0.6445623342146571, 'recall': 0.5760568945058235, 'f1': 0.6083872263923051, 'auc': 0.7639392916648822, 'prauc': 0.6523993457115722}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.34s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.69s/it]


Epoch: 016, Average Loss: 0.5794
Validation: {'precision': 0.6489041684544525, 'recall': 0.5968379446616726, 'f1': 0.6217829886236333, 'auc': 0.7671603195516239, 'prauc': 0.6645002626415091}
Test:       {'precision': 0.6453590192616228, 'recall': 0.5823785065168614, 'f1': 0.6122533698808117, 'auc': 0.7642742319646665, 'prauc': 0.6528086319605793}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Epoch: 017, Average Loss: 0.5785
Validation: {'precision': 0.6468354430352454, 'recall': 0.6059288537525458, 'f1': 0.625714280717063, 'auc': 0.7675408850408851, 'prauc': 0.6649798561043279}
Test:       {'precision': 0.6462994836461002, 'recall': 0.5934413275361776, 'f1': 0.6187435583433067, 'auc': 0.7645860296018739, 'prauc': 0.6531529362738207}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.18s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]


Epoch: 018, Average Loss: 0.5780
Validation: {'precision': 0.6465481603941854, 'recall': 0.6181818181793748, 'f1': 0.6320468731571646, 'auc': 0.7678960411569108, 'prauc': 0.6653579130064482}
Test:       {'precision': 0.6457627118616706, 'recall': 0.6021335440513547, 'f1': 0.6231854376533283, 'auc': 0.7648820671905844, 'prauc': 0.6536696339933254}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.34s/it]


Epoch: 019, Average Loss: 0.5767
Validation: {'precision': 0.6444534741948621, 'recall': 0.6268774703532535, 'f1': 0.6355439741609015, 'auc': 0.7681995942865507, 'prauc': 0.6657792797060949}
Test:       {'precision': 0.646936223423731, 'recall': 0.6131963650706709, 'f1': 0.629614599463505, 'auc': 0.7651330760693036, 'prauc': 0.6540062866428029}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.06s/it]


Epoch: 020, Average Loss: 0.5763
Validation: {'precision': 0.6420884814641606, 'recall': 0.6367588932781156, 'f1': 0.6394125768590296, 'auc': 0.7684038103603321, 'prauc': 0.6659698714748151}
Test:       {'precision': 0.6447908121384546, 'recall': 0.6210983800844683, 'f1': 0.6327228768667839, 'auc': 0.7652932945025286, 'prauc': 0.6541837617867672}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]


Epoch: 021, Average Loss: 0.5763
Validation: {'precision': 0.6406003159532362, 'recall': 0.6411067193650549, 'f1': 0.6408534126189623, 'auc': 0.7685265700483092, 'prauc': 0.6661028814479197}
Test:       {'precision': 0.6438133874213233, 'recall': 0.6270248913448162, 'f1': 0.6353082415956073, 'auc': 0.765415814480877, 'prauc': 0.6541286538041932}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Epoch: 022, Average Loss: 0.5741
Validation: {'precision': 0.6398891966733681, 'recall': 0.6391304347800825, 'f1': 0.6395095856638756, 'auc': 0.7686076081728256, 'prauc': 0.666412438600003}
Test:       {'precision': 0.6444534741948621, 'recall': 0.6266297905941264, 'f1': 0.6354166616651042, 'auc': 0.7655234645131355, 'prauc': 0.6540627279474861}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]


Epoch: 023, Average Loss: 0.5746
Validation: {'precision': 0.6392405063265854, 'recall': 0.638735177863088, 'f1': 0.6389877371880633, 'auc': 0.7686847773804295, 'prauc': 0.6666822859476975}
Test:       {'precision': 0.6445344129528562, 'recall': 0.6290003950982656, 'f1': 0.6366726604651044, 'auc': 0.7655674460438247, 'prauc': 0.6541159308578248}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]


Epoch: 024, Average Loss: 0.5728
Validation: {'precision': 0.6392156862720031, 'recall': 0.6442687747010108, 'f1': 0.641732278462118, 'auc': 0.7687445887445887, 'prauc': 0.6671283484334655}
Test:       {'precision': 0.6436041834246033, 'recall': 0.6321612011037845, 'f1': 0.6378313683285354, 'auc': 0.7656110087027931, 'prauc': 0.6538028907312932}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch: 025, Average Loss: 0.5716
Validation: {'precision': 0.6376021798340303, 'recall': 0.6474308300369667, 'f1': 0.6424789124325636, 'auc': 0.768773030512161, 'prauc': 0.6674223128871957}
Test:       {'precision': 0.6425140112063951, 'recall': 0.6341367048572338, 'f1': 0.6382978673381025, 'auc': 0.765659807258272, 'prauc': 0.6535732818501605}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]


Epoch: 026, Average Loss: 0.5699
Validation: {'precision': 0.6366099071182794, 'recall': 0.6501976284559281, 'f1': 0.6433320247203725, 'auc': 0.768806752828492, 'prauc': 0.6677134176632261}
Test:       {'precision': 0.6413738019143715, 'recall': 0.6345318056079237, 'f1': 0.6379344537860905, 'auc': 0.7656981863796949, 'prauc': 0.6534197452771339}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Epoch: 027, Average Loss: 0.5701
Validation: {'precision': 0.6370686312499533, 'recall': 0.6494071146219391, 'f1': 0.6431786992453488, 'auc': 0.7688289206767467, 'prauc': 0.6678423315494013}
Test:       {'precision': 0.6407185628716938, 'recall': 0.6341367048572338, 'f1': 0.6374106383653543, 'auc': 0.7657527967803006, 'prauc': 0.6533788166304233}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Epoch: 028, Average Loss: 0.5676
Validation: {'precision': 0.636750483556531, 'recall': 0.6505928853729226, 'f1': 0.6435972579501633, 'auc': 0.7688689169123952, 'prauc': 0.6680469203927057}
Test:       {'precision': 0.6411999999974353, 'recall': 0.6333465033558541, 'f1': 0.6372490508513637, 'auc': 0.7658477759430034, 'prauc': 0.6533666414944685}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]


Epoch: 029, Average Loss: 0.5673
Validation: {'precision': 0.6367851622850201, 'recall': 0.6513833992069116, 'f1': 0.6440015581087156, 'auc': 0.7690301043561913, 'prauc': 0.6684033946685186}
Test:       {'precision': 0.6415999999974337, 'recall': 0.6337416041065439, 'f1': 0.6376465861326184, 'auc': 0.766076375184657, 'prauc': 0.6534401858945215}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch: 030, Average Loss: 0.5683
Validation: {'precision': 0.6364683301319137, 'recall': 0.6553359683768564, 'f1': 0.6457643572186099, 'auc': 0.7691034569295438, 'prauc': 0.6684728796452601}
Test:       {'precision': 0.6411177644684986, 'recall': 0.6345318056079237, 'f1': 0.6378077789531207, 'auc': 0.7663275934992365, 'prauc': 0.6536901286606418}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]


Epoch: 031, Average Loss: 0.5661
Validation: {'precision': 0.6343881052206085, 'recall': 0.6577075098788233, 'f1': 0.6458373712847809, 'auc': 0.7691050254093732, 'prauc': 0.6680136305494979}
Test:       {'precision': 0.6403960396014242, 'recall': 0.6388779138655122, 'f1': 0.639636070946844, 'auc': 0.7665480247423336, 'prauc': 0.6537431814304189}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]


Epoch: 032, Average Loss: 0.5661
Validation: {'precision': 0.633245382583365, 'recall': 0.6640316205507351, 'f1': 0.6482731958492436, 'auc': 0.7690710416797373, 'prauc': 0.6680235236195193}
Test:       {'precision': 0.6374705420242047, 'recall': 0.6412485183696514, 'f1': 0.6393539441801133, 'auc': 0.7667053634324779, 'prauc': 0.6537317794389288}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
Running inference: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]


Epoch: 033, Average Loss: 0.5657
Validation: {'precision': 0.6313036981672346, 'recall': 0.66798418972068, 'f1': 0.649126171302625, 'auc': 0.7688698580002927, 'prauc': 0.6678180527380047}
Test:       {'precision': 0.6379914363540756, 'recall': 0.6475701303806892, 'f1': 0.6427450930369728, 'auc': 0.7666808070778429, 'prauc': 0.6535383100237918}


Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import numpy as np
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    """Print mean ± std over the top-k runs, overall and for long sequences.

    Every run is ranked separately on "f1", "auc" and "prauc" (a larger score
    gets a better rank, rank 1 being the best). The three ranks are averaged
    per run and the ``k`` runs with the lowest average rank are selected.
    Ties are broken by run order (stable sort), so the selection does not
    depend on the sorting backend.

    Args:
        performances: non-empty list with one dict per run, mapping a metric
            name to a numeric score. Each dict must contain "f1", "auc" and
            "prauc"; every key of the first dict is reported.
        long_seq_performances: non-empty list of dicts where entry ``i``
            belongs to the same run as ``performances[i]``; every key of its
            first dict is reported for the same selected runs.
        k: number of best runs to aggregate. When fewer runs are available,
            all of them are used.

    Returns:
        None. The summary is printed as percentages with two decimals; the
        spread is the population standard deviation (``ddof=0``).

    Raises:
        ValueError: if ``performances`` or ``long_seq_performances`` is empty.
    """
    if len(performances) == 0 or len(long_seq_performances) == 0:
        raise ValueError(
            "performances and long_seq_performances must each contain at least one run"
        )

    ranking_metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([run[m] for run in performances]) for m in ranking_metrics}

    # Rank the runs per metric (larger value -> better rank, starting at 1).
    # The stable sort keeps tied runs in their original order; the second
    # argsort inverts the permutation to turn sort positions into ranks.
    ranks = {
        m: np.argsort(-scores[m], kind="stable").argsort() + 1
        for m in ranking_metrics
    }
    avg_ranks = np.mean(np.stack([ranks[m] for m in ranking_metrics], axis=1), axis=1)

    # Pick the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    def summarize(runs):
        # (mean, std) per metric over the selected runs, still as fractions.
        summary = {}
        for name in runs[0].keys():
            picked = [runs[i][name] for i in topk_idx]
            summary[name] = (np.mean(picked), np.std(picked, ddof=0))
        return summary

    def report(title, summary):
        # Convert to percentages and keep two decimals for both numbers.
        print(title)
        for name, (mean_val, std_val) in summary.items():
            print(f"{name}: {mean_val * 100:.2f} ± {std_val * 100:.2f}")

    # Compute both summaries before printing so a malformed run dict fails
    # before any partial report is emitted.
    overall_summary = summarize(performances)
    long_seq_summary = summarize(long_seq_performances)

    report("Final Metrics:", overall_summary)
    report("\nFinal Long Sequence Metrics:", long_seq_summary)

In [None]:
def print_per_class_performance(dfs, col_name="prauc"):
    """Print mean ± std of one metric column per disease across several tables.

    The tables are concatenated row-wise and grouped by index label, so every
    label (one disease per label) is aggregated over all tables in which it
    appears.

    Args:
        dfs (list[pd.DataFrame]): tables to aggregate; each must contain the
            column ``col_name``.
        col_name (str): name of the metric column to summarise
            (default: "prauc").

    Returns:
        None. One line per index label is printed, as percentages with two
        decimals. The spread is the population standard deviation
        (``ddof=0``), the same convention topk_avg_performance_formatted uses,
        so a label seen in only one table reports 0.00 rather than nan.
    """
    # Stack all tables on top of each other.
    all_values = pd.concat(dfs, axis=0)

    # Group by disease (index label) and aggregate the requested column.
    per_disease = all_values.groupby(all_values.index)[col_name]
    means = per_disease.mean()
    # pandas defaults to ddof=1; request the population std explicitly.
    stds = per_disease.std(ddof=0)

    # Both series come from the same groupby and therefore share their order.
    for disease, mean_val, std_val in zip(means.index, means.to_numpy(), stds.to_numpy()):
        print(f"{disease}: {mean_val * 100:.2f} ± {std_val * 100:.2f}")

In [None]:
# Final report over all collected runs.
if task_type != "binary":
    # Each run's entry is indexed by "global" and "per_class"; split the two
    # parts into parallel lists before summarising them.
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    # Top-k summary of the global metrics.
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    # Per-class PRAUC, first for all patients, then for long sequences only.
    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")
else:
    # Binary task: the per-run entries are passed to the top-k summary as is.
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)