In [1]:
import torch
import random
from model import SetGNN 
import pickle
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed
import pandas as pd

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda


In [3]:
# Experiment configuration shared by every later cell. Later cells also write
# derived entries (vocabulary sizes, max admission count) into this same dict.
args = dict(
    # run-level settings
    dataset="MIMIC-III",
    task="next_diag_12m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    special_tokens=["[PAD]", "[CLS]"],
    predicted_token_type=["diag", "med"],
    batch_size=256,
    lr=1e-3,
    epochs=500,
    model_name="HG",
    early_stop_patience=10,
    # model hyperparameters
    level="patient",  # "visit" or "patient"
    hg_all_num_layers=3,
    hg_use_type_embed=True,
    MLP_num_layers=2,
    hg_aggregate="mean",
    hg_dropout=0.0,
    normtype="all_one",
    add_self_loop=True,
    hg_normalization="ln",
    hg_hidden_size=128,
    PMA=True,
    hg_num_heads=4,
)

In [4]:
# Every pickle for the selected dataset sits in the same processed-data directory.
processed_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"

# The two next-diagnosis tasks have their own split files; any other task
# falls back to the shared downstream file.
task_specific_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_file_name = task_specific_files.get(args["task"], "mimic_downstream.pkl")
finetune_data_path = f"{processed_dir}/{finetune_file_name}"

In [5]:
# Load the full processed EHR table. The `with` block closes the file handle as
# soon as unpickling finishes instead of leaving it open until garbage collection.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# One list entry per table row for the diagnosis and medication code columns.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
# Lab-test and procedure codes are not read from the table in this run; the
# tokenizer receives empty placeholders instead. The original reads are kept
# below for reference.
# lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
# pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
lab_sentences = [[]]
pro_sentences = [[]]
# "[PAD]" followed by one "<age>_<gender>" token for every distinct AGE value
# combined with both gender letters.
age_gender_sentences = ["[PAD]"] + [
    str(c) + "_" + gender
    for c in set(ehr_full_data["AGE"].values.tolist())
    for gender in ["M", "F"]
]  # PAD token special for age_gender vocabulary
# Largest number of distinct admissions (HADM_ID) recorded for any single
# patient (SUBJECT_ID); stored in args for later cells.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the tokenizer from the five sentence collections prepared above.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)

# Record the vocabulary sizes in args so later cells can read them from there.
args.update(
    age_gender_vocab_size=tokenizer.token_number("age_gender"),
    global_vocab_size=len(tokenizer.vocab.id2word),
    label_vocab_size=len(PHENO_ORDER),
)

for description, size_key in (
    ("Age and gender vocabulary size", "age_gender_vocab_size"),
    ("Global vocabulary size", "global_vocab_size"),
    ("Label vocabulary size", "label_vocab_size"),
):
    print(f"{description}: {args[size_key]}")

Age and gender vocabulary size: 37
Global vocabulary size: 2145
Label vocabulary size: 18


In [7]:
# Load the pre-split (train, val, test) fine-tuning data. The `with` block
# closes the file handle once unpickling is done.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(finetune_data_path, "rb") as splits_file:
    train_data, val_data, test_data = pickle.load(splits_file)

# Options shared by all three splits, defined once so they cannot drift apart.
dataset_kwargs = dict(
    token_type=args["predicted_token_type"],
    task=args["task"],
    level=args["level"],
)

# output: input_ids (a patient has multiple visits), labels
train_dataset = FinetuneHGDataset(train_data, tokenizer, **dataset_kwargs)
val_dataset = FinetuneHGDataset(val_data, tokenizer, **dataset_kwargs)
test_dataset = FinetuneHGDataset(test_data, tokenizer, **dataset_kwargs)
print(len(train_dataset), len(val_dataset), len(test_dataset))

1883 1410 1410


In [8]:
# Minimum number of entries a record needs to count as a "long" sequence.
long_adm_seq_crite = 3


def _long_sequence_indices(dataset, min_admissions):
    """Return the dataset positions whose record has at least `min_admissions` entries.

    Position ``i`` is paired with the i-th key of ``dataset.records`` in
    iteration order, exactly as the original per-split loops did.
    NOTE(review): this assumes dataset item ``i`` corresponds to that i-th
    record key -- confirm against FinetuneHGDataset.

    Args:
        dataset: object exposing ``__len__`` and a ``records`` mapping whose
            values support ``len()``.
        min_admissions: inclusive lower bound on the record length.

    Returns:
        List of integer positions, in ascending order.
    """
    # Materialise the key list once; rebuilding it on every iteration made the
    # original loop quadratic in the number of records.
    record_keys = list(dataset.records.keys())
    return [
        position
        for position in range(len(dataset))
        if len(dataset.records[record_keys[position]]) >= min_admissions
    ]


val_long_seq_idx = _long_sequence_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_sequence_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

208 186


In [9]:
use_full_graph = True

# With use_full_graph each split is served as a single batch holding every
# sample; otherwise the configured mini-batch size is used.
train_batch_size, val_batch_size, test_batch_size = [
    len(split) if use_full_graph else args["batch_size"]
    for split in (train_dataset, val_dataset, test_dataset)
]

# Each loader gets its own collate function instance; only the training loader shuffles.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    collate_fn=batcher_SetGNN_finetune(device=device),
)

In [10]:
# Sanity check: draw one training batch, unpack its four components and show
# the shape of the label tensor.
train_batches = iter(train_dataloader)
batch = next(train_batches)
graph, global_node_ids, last_visit_indices, labels = batch
print(labels.shape)

torch.Size([1883, 18])


In [11]:
# Task-dependent training settings.
# - death / stay / readmission are treated as "binary" tasks; every other task
#   (the next-diagnosis tasks) uses task type "l2r".
# - Both task families are selected on PRAUC and trained with binary
#   cross-entropy on logits, so those two settings are assigned once rather
#   than duplicated per branch (the former pass-through lambda only forwarded
#   to this same function).
is_binary_task = args["task"] in ["death", "stay", "readmission"]
task_type = "binary" if is_binary_task else "l2r"
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits

In [12]:
# Derive 15 reproducible per-run seeds from a fixed master seed.
random.seed(42)
largest_seed = 2**32 - 1
seeds = []
for _ in range(15):
    seeds.append(random.randint(0, largest_seed))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844, 3460967357, 2998485882, 1461364854, 667779376, 1445662585]


In [None]:
# Results of train_with_early_stopping collected across seeds: the first list
# takes its first return value, the second list its second (named for the full
# test set and the long-sequence subset respectively).
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer for every seed so runs are independent.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        # Use the metric chosen in the task-setup cell instead of repeating the
        # literal here, so the two cannot silently diverge.
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  2.98it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.38it/s]


Epoch: 001, Average Loss: 0.6933
Validation: {'precision': 0.12555253325580906, 'recall': 0.37436539376380373, 'f1': 0.1689312515045155, 'auc': 0.53986987347967, 'prauc': 0.23358664913895102}
Test:       {'precision': 0.11672113616644293, 'recall': 0.37703396325910343, 'f1': 0.16264748726738584, 'auc': 0.5244515109196481, 'prauc': 0.22512015920636091}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.45it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.42it/s]


Epoch: 002, Average Loss: 0.6837
Validation: {'precision': 0.14693124377257838, 'recall': 0.2321835665509887, 'f1': 0.14767606061478822, 'auc': 0.5468580111809253, 'prauc': 0.2419990648587324}
Test:       {'precision': 0.11107427522633367, 'recall': 0.23097923977975116, 'f1': 0.1411878523761014, 'auc': 0.5343453608257313, 'prauc': 0.2326085997407975}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.97it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.48it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 003, Average Loss: 0.6737
Validation: {'precision': 0.12113658693635382, 'recall': 0.07942123701335077, 'f1': 0.07883712590937292, 'auc': 0.5435221234894875, 'prauc': 0.24076103069110794}
Test:       {'precision': 0.11797542575830218, 'recall': 0.08041751258305206, 'f1': 0.07756626541581496, 'auc': 0.5308104727675979, 'prauc': 0.23094702746094592}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.76it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.25it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.51it/s]


Epoch: 004, Average Loss: 0.6624
Validation: {'precision': 0.1313602091617452, 'recall': 0.023436145790402187, 'f1': 0.03340967243297452, 'auc': 0.5331006100497802, 'prauc': 0.2334590380665737}
Test:       {'precision': 0.11355934044915705, 'recall': 0.023321165785860756, 'f1': 0.03176988388012328, 'auc': 0.5216353429311871, 'prauc': 0.22460840644349822}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.88it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.57it/s]


Epoch: 005, Average Loss: 0.6486
Validation: {'precision': 0.14688628365098955, 'recall': 0.008279094249143804, 'f1': 0.014273890426049473, 'auc': 0.5259225419820585, 'prauc': 0.2283800038922369}
Test:       {'precision': 0.09673651904623032, 'recall': 0.007254730319194844, 'f1': 0.012306784125692133, 'auc': 0.5153716412341167, 'prauc': 0.2203787889634275}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]


Epoch: 006, Average Loss: 0.6324
Validation: {'precision': 0.0954835713456403, 'recall': 0.004180855172612897, 'f1': 0.007727538539489601, 'auc': 0.5221283103788351, 'prauc': 0.22515468879213665}
Test:       {'precision': 0.08460356069051721, 'recall': 0.0037193373429314147, 'f1': 0.006892837154187477, 'auc': 0.511806137637908, 'prauc': 0.21795025853745198}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.26it/s]


Epoch: 007, Average Loss: 0.6140
Validation: {'precision': 0.04102693602693603, 'recall': 0.0014600514242345325, 'f1': 0.002797802219522175, 'auc': 0.5175002244438768, 'prauc': 0.22112701058004372}
Test:       {'precision': 0.03975795971410006, 'recall': 0.001343022800508578, 'f1': 0.002597476144315873, 'auc': 0.5073395450565905, 'prauc': 0.21517164522564056}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.15it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 008, Average Loss: 0.5934
Validation: {'precision': 0.02202581369248036, 'recall': 0.0005262330541943618, 'f1': 0.001026281273676019, 'auc': 0.5129154423573811, 'prauc': 0.21741511576251535}
Test:       {'precision': 0.047378547378547374, 'recall': 0.0007983224148490537, 'f1': 0.0015685683518075904, 'auc': 0.5032420731221534, 'prauc': 0.21264660140427333}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.28it/s]


Epoch: 009, Average Loss: 0.5708
Validation: {'precision': 0.03756613756613757, 'recall': 0.0003982248616700402, 'f1': 0.0007870935705328174, 'auc': 0.5114136428155319, 'prauc': 0.21569152178272047}
Test:       {'precision': 0.04351851851851852, 'recall': 0.00041239033889442175, 'f1': 0.0008168200047180939, 'auc': 0.5016076636886392, 'prauc': 0.2117690583811139}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.10it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.51it/s]


Epoch: 010, Average Loss: 0.5490
Validation: {'precision': 0.03888888888888889, 'recall': 0.0002702166691457187, 'f1': 0.0005365415957076673, 'auc': 0.5116363368196263, 'prauc': 0.21533320322558486}
Test:       {'precision': 0.046296296296296294, 'recall': 0.0002822836514106851, 'f1': 0.0005611306387207802, 'auc': 0.5017949313848304, 'prauc': 0.21167377230133724}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.87it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Epoch: 011, Average Loss: 0.5273
Validation: {'precision': 0.041666666666666664, 'recall': 0.0002702166691457187, 'f1': 0.0005369163633517278, 'auc': 0.5139943664147091, 'prauc': 0.21684020249854694}
Test:       {'precision': 0.07407407407407407, 'recall': 0.0002822836514106851, 'f1': 0.0005622684194112765, 'auc': 0.5047299917570817, 'prauc': 0.21305567330203204}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.12it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 012, Average Loss: 0.5089
Validation: {'precision': 0.06944444444444445, 'recall': 0.0002702166691457187, 'f1': 0.0005379153663467389, 'auc': 0.516488748023844, 'prauc': 0.2182104253621801}
Test:       {'precision': 0.05555555555555555, 'recall': 0.0001786352268667381, 'f1': 0.0003561253561253561, 'auc': 0.5081102385254744, 'prauc': 0.2148253817316237}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.14693124377257838, 'recall': 0.2321835665509887, 'f1': 0.14767606061478822, 'auc': 0.5468580111809253, 'prauc': 0.2419990648587324}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.329499  0.923963  0.485766   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.411168  0.8

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.26it/s]


Epoch: 001, Average Loss: 0.6925
Validation: {'precision': 0.1218112873800757, 'recall': 0.31829523311325325, 'f1': 0.1597691427687816, 'auc': 0.5393133742790032, 'prauc': 0.23617583469630604}
Test:       {'precision': 0.13201264574858343, 'recall': 0.3260267877951, 'f1': 0.15906635499744634, 'auc': 0.5307097621007452, 'prauc': 0.22677851557063214}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.11it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.89it/s]


Epoch: 002, Average Loss: 0.6832
Validation: {'precision': 0.10386111705300943, 'recall': 0.20782251862362178, 'f1': 0.13245067148472456, 'auc': 0.5521734384374158, 'prauc': 0.24656518283187084}
Test:       {'precision': 0.09348896548424587, 'recall': 0.2036359431251772, 'f1': 0.12628658951779295, 'auc': 0.5392807842064816, 'prauc': 0.2340867596589749}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.94it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 003, Average Loss: 0.6738
Validation: {'precision': 0.09948626628184333, 'recall': 0.04243969698015917, 'f1': 0.057287450014004405, 'auc': 0.5437447438439874, 'prauc': 0.24132646651256584}
Test:       {'precision': 0.08921798187298231, 'recall': 0.04355109408276639, 'f1': 0.05593778036104535, 'auc': 0.529617626052256, 'prauc': 0.229340018733236}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.54it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 004, Average Loss: 0.6627
Validation: {'precision': 0.09496060684722957, 'recall': 0.009547209487569843, 'f1': 0.01718479756519031, 'auc': 0.5307768452833526, 'prauc': 0.23206111907928031}
Test:       {'precision': 0.08163691512487488, 'recall': 0.006575782730344158, 'f1': 0.012145354335322463, 'auc': 0.5167956083524872, 'prauc': 0.22072126414779228}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.76it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.38it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


Epoch: 005, Average Loss: 0.6495
Validation: {'precision': 0.08391754850088182, 'recall': 0.0022337348554748926, 'f1': 0.004266563857064373, 'auc': 0.5218216696342213, 'prauc': 0.22505556884557443}
Test:       {'precision': 0.05567901234567902, 'recall': 0.0017251450444080792, 'f1': 0.0033186921390473303, 'auc': 0.5079705383841455, 'prauc': 0.21495134578813596}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.66it/s]


Epoch: 006, Average Loss: 0.6330
Validation: {'precision': 0.030086756924992224, 'recall': 0.0009808175586178197, 'f1': 0.0018806939661113865, 'auc': 0.5158215695874452, 'prauc': 0.22032557468754044}
Test:       {'precision': 0.058333333333333334, 'recall': 0.0014269935834659204, 'f1': 0.00275621259295786, 'auc': 0.5024834248552131, 'prauc': 0.21136759291499374}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.11it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 007, Average Loss: 0.6142
Validation: {'precision': 0.08395061728395062, 'recall': 0.0006608471056036522, 'f1': 0.0012960350174128108, 'auc': 0.5112995204675355, 'prauc': 0.21743943393810644}
Test:       {'precision': 0.05864197530864198, 'recall': 0.001175402761629798, 'f1': 0.0022775036297686736, 'auc': 0.49823355586708623, 'prauc': 0.2092405889278407}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.60it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.25it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 008, Average Loss: 0.5932
Validation: {'precision': 0.0617283950617284, 'recall': 0.0003125364124653729, 'f1': 0.0006101639611713843, 'auc': 0.5087896709775228, 'prauc': 0.21600535272448856}
Test:       {'precision': 0.0609567901234568, 'recall': 0.0010651734847338542, 'f1': 0.0020742873684050154, 'auc': 0.4963356896399049, 'prauc': 0.20833485116819891}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.58it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.12it/s]


Epoch: 009, Average Loss: 0.5692
Validation: {'precision': 0.06851851851851852, 'recall': 0.0008680919680209285, 'f1': 0.0015810943358175436, 'auc': 0.5082356737635103, 'prauc': 0.2157297305760648}
Test:       {'precision': 0.042328042328042326, 'recall': 0.0009657896644198013, 'f1': 0.00188598301805849, 'auc': 0.4961067558631966, 'prauc': 0.2081185886802143}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.37it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 010, Average Loss: 0.5459
Validation: {'precision': 0.010648148148148148, 'recall': 0.0006634304207119741, 'f1': 0.0011786331181127092, 'auc': 0.5084884447697349, 'prauc': 0.21623266257396279}
Test:       {'precision': 0.046296296296296294, 'recall': 0.0007033533860017877, 'f1': 0.0013827797514990884, 'auc': 0.4969038277315067, 'prauc': 0.20853819331722934}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.78it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.37it/s]


Epoch: 011, Average Loss: 0.5247
Validation: {'precision': 0.014814814814814815, 'recall': 0.0006634304207119741, 'f1': 0.0011798587885544407, 'auc': 0.5087122043413492, 'prauc': 0.21605632086223522}
Test:       {'precision': 0.039351851851851846, 'recall': 0.00048289483220990064, 'f1': 0.0009529740338164251, 'auc': 0.49747873335241544, 'prauc': 0.20886904383933835}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.63it/s]


Epoch: 012, Average Loss: 0.5069
Validation: {'precision': 0.018518518518518517, 'recall': 0.00010787486515641856, 'f1': 0.0002145002145002145, 'auc': 0.508524102049691, 'prauc': 0.21590061041200026}
Test:       {'precision': 0.018518518518518517, 'recall': 0.0002204585537918871, 'f1': 0.00043572984749455336, 'auc': 0.49843644820147237, 'prauc': 0.2091419402362239}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.10386111705300943, 'recall': 0.20782251862362178, 'f1': 0.13245067148472456, 'auc': 0.5521734384374158, 'prauc': 0.24656518283187084}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.000000  0.000000  0.000000   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.420

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.13it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Epoch: 001, Average Loss: 0.6926
Validation: {'precision': 0.10350587233544034, 'recall': 0.2570886517846184, 'f1': 0.10641333752674655, 'auc': 0.5245468593188924, 'prauc': 0.22790927312451462}
Test:       {'precision': 0.09437995791621055, 'recall': 0.25220368768098256, 'f1': 0.10333376999566773, 'auc': 0.5223947409746539, 'prauc': 0.2217316479604134}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 14.34it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 002, Average Loss: 0.6841
Validation: {'precision': 0.11269959630735624, 'recall': 0.07932749246993594, 'f1': 0.0564892578546584, 'auc': 0.5320367554566166, 'prauc': 0.23312111012105804}
Test:       {'precision': 0.09205490035531944, 'recall': 0.07365111988715525, 'f1': 0.0558351232819519, 'auc': 0.5205871864152828, 'prauc': 0.22346362744637988}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.72it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 003, Average Loss: 0.6743
Validation: {'precision': 0.06415996070875364, 'recall': 0.01997591375722003, 'f1': 0.02682962201948959, 'auc': 0.5268594811214552, 'prauc': 0.22932167594598438}
Test:       {'precision': 0.10324197308068277, 'recall': 0.016679730836665005, 'f1': 0.021891312098368552, 'auc': 0.5150796404245115, 'prauc': 0.22055668183652008}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.61it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Epoch: 004, Average Loss: 0.6636
Validation: {'precision': 0.057279956559006605, 'recall': 0.0069873541691988054, 'f1': 0.011590486538293239, 'auc': 0.5229126101994327, 'prauc': 0.22642582618559912}
Test:       {'precision': 0.05140901771336554, 'recall': 0.006700603895311333, 'f1': 0.01118583587210468, 'auc': 0.5100320917559645, 'prauc': 0.21773291149131604}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.60it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.24it/s]


Epoch: 005, Average Loss: 0.6509
Validation: {'precision': 0.05685531788472965, 'recall': 0.0030197690226609307, 'f1': 0.00546006087806386, 'auc': 0.5193056399839227, 'prauc': 0.22358507929235827}
Test:       {'precision': 0.053765432098765435, 'recall': 0.0025069991452591686, 'f1': 0.004566609785095165, 'auc': 0.5061633009958405, 'prauc': 0.2152086871055479}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


Epoch: 006, Average Loss: 0.6354
Validation: {'precision': 0.031450665061346174, 'recall': 0.0015009350512524336, 'f1': 0.0026648019466371355, 'auc': 0.515533773983764, 'prauc': 0.22059178542120061}
Test:       {'precision': 0.0498936735778841, 'recall': 0.0014833087125093816, 'f1': 0.0026743654644173797, 'auc': 0.5026771777457026, 'prauc': 0.2124118586131615}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.53it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Epoch: 007, Average Loss: 0.6174
Validation: {'precision': 0.009868421052631578, 'recall': 0.0013012677106636838, 'f1': 0.002280508162861104, 'auc': 0.5123995305810859, 'prauc': 0.21844496570090344}
Test:       {'precision': 0.013800705467372132, 'recall': 0.000977860402479329, 'f1': 0.0017018532459291402, 'auc': 0.4995844993716588, 'prauc': 0.21048713180225398}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.64it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 008, Average Loss: 0.5971
Validation: {'precision': 0.010912698412698412, 'recall': 0.0013012677106636838, 'f1': 0.0023214602161970585, 'auc': 0.5102300076010788, 'prauc': 0.21699150498363437}
Test:       {'precision': 0.006535947712418301, 'recall': 0.0008742119779353822, 'f1': 0.0015252895891790777, 'auc': 0.4975886538661226, 'prauc': 0.20970014326077038}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.82it/s]


Epoch: 009, Average Loss: 0.5757
Validation: {'precision': 0.01221001221001221, 'recall': 0.0009284116331096196, 'f1': 0.0017242926554106891, 'auc': 0.5101972867029574, 'prauc': 0.21697317210324124}
Test:       {'precision': 0.008547008547008548, 'recall': 0.0008742119779353822, 'f1': 0.001572159589412773, 'auc': 0.49772923532455216, 'prauc': 0.2097078604682877}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.85it/s]


Epoch: 010, Average Loss: 0.5531
Validation: {'precision': 0.037037037037037035, 'recall': 0.0009284116331096196, 'f1': 0.0017840532022713705, 'auc': 0.5123483322564942, 'prauc': 0.2183692963253453}
Test:       {'precision': 0.004273504273504274, 'recall': 0.0005787037037037037, 'f1': 0.0010193679918450561, 'auc': 0.4999737313419412, 'prauc': 0.21107981675688847}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.87it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 011, Average Loss: 0.5326
Validation: {'precision': 0.009259259259259259, 'recall': 0.0005555555555555556, 'f1': 0.0010482180293501049, 'auc': 0.5144238316372962, 'prauc': 0.22006076315161072}
Test:       {'precision': 0.004273504273504274, 'recall': 0.0005787037037037037, 'f1': 0.0010193679918450561, 'auc': 0.5023071273572745, 'prauc': 0.21237928161795536}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.65it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 012, Average Loss: 0.5149
Validation: {'precision': 0.009259259259259259, 'recall': 0.0005555555555555556, 'f1': 0.0010482180293501049, 'auc': 0.5150010856658691, 'prauc': 0.22065047453127257}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5032134656333969, 'prauc': 0.2128160923076623}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.11269959630735624, 'recall': 0.07932749246993594, 'f1': 0.0564892578546584, 'auc': 0.5320367554566166, 'prauc': 0.23312111012105804}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.000000  0.000000  0.000000   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.000000  0.000000  0.000000   
Chronic kidney disease       

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.65it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]


Epoch: 001, Average Loss: 0.6908
Validation: {'precision': 0.09655835702541567, 'recall': 0.10089595747192452, 'f1': 0.06493924013547342, 'auc': 0.5196502500873489, 'prauc': 0.21956492059286703}
Test:       {'precision': 0.11778623842538069, 'recall': 0.1010438210171873, 'f1': 0.06079435505495217, 'auc': 0.5293805767107919, 'prauc': 0.2181292125351901}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.77it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.34it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 002, Average Loss: 0.6814
Validation: {'precision': 0.10845871442675954, 'recall': 0.016902629502781864, 'f1': 0.02317899420227547, 'auc': 0.5274308924491085, 'prauc': 0.22621802024892107}
Test:       {'precision': 0.08379741919119861, 'recall': 0.01120321571725541, 'f1': 0.0170404773272048, 'auc': 0.524866017204748, 'prauc': 0.22022153123711083}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.54it/s]


Epoch: 003, Average Loss: 0.6705
Validation: {'precision': 0.11270753512132822, 'recall': 0.0035417590632752993, 'f1': 0.00641986734024827, 'auc': 0.5255383614065011, 'prauc': 0.22610080915738262}
Test:       {'precision': 0.0558970317835566, 'recall': 0.00322248482455446, 'f1': 0.005896571280588104, 'auc': 0.517817361122221, 'prauc': 0.21886153504626862}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.77it/s]


Epoch: 004, Average Loss: 0.6582
Validation: {'precision': 0.03761214630779849, 'recall': 0.0018977858878270223, 'f1': 0.00361268523194125, 'auc': 0.5193699475599558, 'prauc': 0.2217126122498252}
Test:       {'precision': 0.04135802469135803, 'recall': 0.0017053841041444091, 'f1': 0.0032198830697658383, 'auc': 0.5085274134857635, 'prauc': 0.21493867579849965}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]


Epoch: 005, Average Loss: 0.6432
Validation: {'precision': 0.03367003367003366, 'recall': 0.001129736732681093, 'f1': 0.0021827554266843864, 'auc': 0.5139526013067434, 'prauc': 0.2183278664101263}
Test:       {'precision': 0.011728395061728396, 'recall': 0.0006970001115200179, 'f1': 0.001283648582323404, 'auc': 0.5017469994575724, 'prauc': 0.21169214407290235}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.39it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 006, Average Loss: 0.6255
Validation: {'precision': 0.009259259259259259, 'recall': 0.00012800819252432156, 'f1': 0.0002525252525252525, 'auc': 0.5099333047665827, 'prauc': 0.21544665909111216}
Test:       {'precision': 0.014109347442680775, 'recall': 0.0006970001115200179, 'f1': 0.0012944379842179058, 'auc': 0.4972045975897511, 'prauc': 0.20929406189582198}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.87it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.62it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 007, Average Loss: 0.6055
Validation: {'precision': 0.011111111111111112, 'recall': 0.00012800819252432156, 'f1': 0.0002531004808909137, 'auc': 0.5072665138325299, 'prauc': 0.21316260749978816}
Test:       {'precision': 0.016203703703703703, 'recall': 0.0006970001115200179, 'f1': 0.0013048256762579827, 'auc': 0.4939083683576329, 'prauc': 0.20771134552222184}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.10it/s]


Epoch: 008, Average Loss: 0.5835
Validation: {'precision': 0.018518518518518517, 'recall': 0.00012800819252432156, 'f1': 0.0002542588354945334, 'auc': 0.5039223633671863, 'prauc': 0.21105682702799472}
Test:       {'precision': 0.024691358024691357, 'recall': 0.0006970001115200179, 'f1': 0.0012968195319858, 'auc': 0.49028687751711747, 'prauc': 0.20595399120638114}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.67it/s]


Epoch: 009, Average Loss: 0.5606
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5002605990673096, 'prauc': 0.20885979118938947}
Test:       {'precision': 0.03395061728395062, 'recall': 0.0006970001115200179, 'f1': 0.0012974218581695215, 'auc': 0.48691709604484107, 'prauc': 0.20415594679969773}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.75it/s]


Epoch: 010, Average Loss: 0.5384
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.49977258621458603, 'prauc': 0.20860314480037342}
Test:       {'precision': 0.006172839506172839, 'recall': 0.0005668934240362811, 'f1': 0.0010384215991692627, 'auc': 0.48719931113933534, 'prauc': 0.20421533592816005}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Epoch: 011, Average Loss: 0.5188
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5013956130550127, 'prauc': 0.2090235111304818}
Test:       {'precision': 0.006172839506172839, 'recall': 0.0005668934240362811, 'f1': 0.0010384215991692627, 'auc': 0.4900320109748205, 'prauc': 0.20512038263462407}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.65it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 012, Average Loss: 0.5031
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5058678817308417, 'prauc': 0.21112116889881422}
Test:       {'precision': 0.006944444444444444, 'recall': 0.0005668934240362811, 'f1': 0.0010482180293501049, 'auc': 0.49655786090959864, 'prauc': 0.2076467136268297}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.10845871442675954, 'recall': 0.016902629502781864, 'f1': 0.02317899420227547, 'auc': 0.5274308924491085, 'prauc': 0.22621802024892107}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.362245  0.163594  0.225397   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.062500  0.010526  0.018018   
Cardiac dysrhythmias                        0.000000  0.000000  0.000000   
Chronic kidney disease    

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.65it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]


Epoch: 001, Average Loss: 0.6967
Validation: {'precision': 0.16397862543753686, 'recall': 0.42757823960698793, 'f1': 0.18345254573446335, 'auc': 0.5213285075873966, 'prauc': 0.22231434164120734}
Test:       {'precision': 0.1736211756331792, 'recall': 0.41769532389644215, 'f1': 0.1787846998808347, 'auc': 0.5206229541168903, 'prauc': 0.21503320842188012}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]


Epoch: 002, Average Loss: 0.6864
Validation: {'precision': 0.13925564066501128, 'recall': 0.1398718354473482, 'f1': 0.10524826584872814, 'auc': 0.5299857805712415, 'prauc': 0.22764425041791275}
Test:       {'precision': 0.1215287743656836, 'recall': 0.14673072403084744, 'f1': 0.10077906462238541, 'auc': 0.5253042761265263, 'prauc': 0.22080812183002138}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.83it/s]


Epoch: 003, Average Loss: 0.6769
Validation: {'precision': 0.11373025020441882, 'recall': 0.05360097248516764, 'f1': 0.04629479986746394, 'auc': 0.5259583703676279, 'prauc': 0.22808954855817445}
Test:       {'precision': 0.0997764972019418, 'recall': 0.05642283141881407, 'f1': 0.04594213227791072, 'auc': 0.5206173846357943, 'prauc': 0.22170410429197504}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.24it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.79it/s]


Epoch: 004, Average Loss: 0.6656
Validation: {'precision': 0.08647829315863972, 'recall': 0.008164100919519947, 'f1': 0.01413929593968201, 'auc': 0.5210235268362652, 'prauc': 0.22465458036551444}
Test:       {'precision': 0.07105848773346447, 'recall': 0.008425835493760558, 'f1': 0.014439834198650537, 'auc': 0.5144401242521962, 'prauc': 0.21782824666664372}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 005, Average Loss: 0.6520
Validation: {'precision': 0.048059116809116814, 'recall': 0.0021461034304283986, 'f1': 0.0039970914188545365, 'auc': 0.5174121552875042, 'prauc': 0.2220868965707427}
Test:       {'precision': 0.044060705090116864, 'recall': 0.002372388657467125, 'f1': 0.0043612493149585855, 'auc': 0.5093483015982875, 'prauc': 0.21465043708900533}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.68it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.30it/s]


Epoch: 006, Average Loss: 0.6355
Validation: {'precision': 0.056782581453634094, 'recall': 0.0019299339535585325, 'f1': 0.003589088370592436, 'auc': 0.5142698352193942, 'prauc': 0.2194921193790458}
Test:       {'precision': 0.045465645465645466, 'recall': 0.0017078580531926096, 'f1': 0.003236012437991381, 'auc': 0.5052744734221173, 'prauc': 0.2120670291347955}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.55it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 007, Average Loss: 0.6169
Validation: {'precision': 0.03138616557734205, 'recall': 0.0018331472714059967, 'f1': 0.00340212911943842, 'auc': 0.5115887139870262, 'prauc': 0.2172629682236139}
Test:       {'precision': 0.03780864197530864, 'recall': 0.0010433274489180934, 'f1': 0.0020263024888670326, 'auc': 0.5017752959602869, 'prauc': 0.21013765016628444}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.90it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.93it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.26it/s]


Epoch: 008, Average Loss: 0.5961
Validation: {'precision': 0.03591269841269841, 'recall': 0.0008712656394280027, 'f1': 0.0016631683905579337, 'auc': 0.5084746319871218, 'prauc': 0.21478109413724442}
Test:       {'precision': 0.04034391534391534, 'recall': 0.0007670550418131208, 'f1': 0.0014992355246103801, 'auc': 0.4977584302823696, 'prauc': 0.20822301640938995}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.19it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 13.54it/s]


Epoch: 009, Average Loss: 0.5732
Validation: {'precision': 0.05555555555555555, 'recall': 0.00021616947686986597, 'f1': 0.0004306632213608958, 'auc': 0.5060703917763575, 'prauc': 0.21318618228106226}
Test:       {'precision': 0.027777777777777776, 'recall': 0.00019527436047646942, 'f1': 0.00038782237735117316, 'auc': 0.4947943544893911, 'prauc': 0.2066424551543517}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


Epoch: 010, Average Loss: 0.5509
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5057367831660934, 'prauc': 0.21277177892685528}
Test:       {'precision': 0.027777777777777776, 'recall': 9.763718023823471e-05, 'f1': 0.0001945903872348706, 'auc': 0.49425042473176234, 'prauc': 0.20652503460316646}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.58it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.24it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 011, Average Loss: 0.5300
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5076367057876963, 'prauc': 0.21347035042208842}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.4963731946135369, 'prauc': 0.20749071662185084}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.10it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.45it/s]


Epoch: 012, Average Loss: 0.5139
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5112220338794872, 'prauc': 0.21541907691292253}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5007263440146825, 'prauc': 0.20958024571247252}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.95it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.12it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]


Epoch: 013, Average Loss: 0.5013
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5146609648507776, 'prauc': 0.21717817848261306}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5048535681246875, 'prauc': 0.21155769491154727}

Early stopping triggered after 13 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.11373025020441882, 'recall': 0.05360097248516764, 'f1': 0.04629479986746394, 'auc': 0.5259583703676279, 'prauc': 0.22808954855817445}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.000000  0.000000  0.000000   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.014925  0.010526  0.012346   
Cardiac dysrhythmias                        0.000000  0.000000  0.000000   
Chronic kidney disease                      0.142857  0.002941  0.005764   
Chro

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.74it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.05it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.72it/s]


Epoch: 001, Average Loss: 0.6913
Validation: {'precision': 0.1053932399164546, 'recall': 0.18792427687369453, 'f1': 0.11905423341769371, 'auc': 0.5227656576151882, 'prauc': 0.22734931904258365}
Test:       {'precision': 0.10393074313410108, 'recall': 0.186380271337652, 'f1': 0.11510084147687621, 'auc': 0.5168931842525263, 'prauc': 0.21995945351246973}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.90it/s]


Epoch: 002, Average Loss: 0.6826
Validation: {'precision': 0.10966203901876612, 'recall': 0.094302762226931, 'f1': 0.07794435577932646, 'auc': 0.5344022518632103, 'prauc': 0.23129147941348294}
Test:       {'precision': 0.08415442633416388, 'recall': 0.09079412498604311, 'f1': 0.07223968951619203, 'auc': 0.5219793748530263, 'prauc': 0.2233936945613833}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.88it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.87it/s]


Epoch: 003, Average Loss: 0.6735
Validation: {'precision': 0.06424098693170716, 'recall': 0.05530661504422067, 'f1': 0.04033482521478557, 'auc': 0.5315165773866752, 'prauc': 0.22881617641742238}
Test:       {'precision': 0.06934314528996238, 'recall': 0.05503248749829162, 'f1': 0.038500397075433374, 'auc': 0.5177037355577765, 'prauc': 0.2212203112279964}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


Epoch: 004, Average Loss: 0.6628
Validation: {'precision': 0.058730158730158716, 'recall': 0.002822757762214058, 'f1': 0.005099418551415203, 'auc': 0.5247121676593982, 'prauc': 0.2241816379185376}
Test:       {'precision': 0.03418803418803419, 'recall': 0.0025726056324374582, 'f1': 0.004658119658119658, 'auc': 0.5123312649893276, 'prauc': 0.21753904267211807}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.99it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 005, Average Loss: 0.6504
Validation: {'precision': 0.07642089093701997, 'recall': 0.0020520349683622514, 'f1': 0.0037686825659375917, 'auc': 0.5195293828640309, 'prauc': 0.2208864888963598}
Test:       {'precision': 0.019230769230769232, 'recall': 0.0015459548622988454, 'f1': 0.0028607375256916867, 'auc': 0.5057546174947144, 'prauc': 0.21385754070793989}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.69it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 006, Average Loss: 0.6349
Validation: {'precision': 0.012820512820512822, 'recall': 0.0010422001203920829, 'f1': 0.0019273359848571935, 'auc': 0.5142506921190357, 'prauc': 0.21764174671168998}
Test:       {'precision': 0.018461007591442376, 'recall': 0.0013673196354321074, 'f1': 0.0025460629895324863, 'auc': 0.5008632528741772, 'prauc': 0.21113607893080621}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.02it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.25it/s]


Epoch: 007, Average Loss: 0.6167
Validation: {'precision': 0.013935340022296544, 'recall': 0.0010422001203920829, 'f1': 0.0019378272377750727, 'auc': 0.5111445426305168, 'prauc': 0.2158950446864842}
Test:       {'precision': 0.01893939393939394, 'recall': 0.001188684408565369, 'f1': 0.002236762867712176, 'auc': 0.49844350444635105, 'prauc': 0.20948411625784982}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.48it/s]


Epoch: 008, Average Loss: 0.5963
Validation: {'precision': 0.009564509564509565, 'recall': 0.0007075280989730735, 'f1': 0.0013153957650862487, 'auc': 0.5099140607485799, 'prauc': 0.21546862737433495}
Test:       {'precision': 0.01770152505446623, 'recall': 0.001010049181698631, 'f1': 0.001910569105691057, 'auc': 0.4977357589996989, 'prauc': 0.20892043540661007}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.09it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Epoch: 009, Average Loss: 0.5745
Validation: {'precision': 0.010898458266879319, 'recall': 0.0007075280989730735, 'f1': 0.0013275561886673, 'auc': 0.5085509110741148, 'prauc': 0.2146109356059599}
Test:       {'precision': 0.020833333333333332, 'recall': 0.001010049181698631, 'f1': 0.001926050746496356, 'auc': 0.4967548969060698, 'prauc': 0.208476819064727}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.68it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


Epoch: 010, Average Loss: 0.5518
Validation: {'precision': 0.01122334455667789, 'recall': 0.0007075280989730735, 'f1': 0.0013293650793650795, 'auc': 0.5072241500766729, 'prauc': 0.21417759787655008}
Test:       {'precision': 0.024074074074074074, 'recall': 0.001010049181698631, 'f1': 0.001936064202840498, 'auc': 0.49561915857125793, 'prauc': 0.20794249371409512}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.83it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.62it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Epoch: 011, Average Loss: 0.5297
Validation: {'precision': 0.006944444444444444, 'recall': 0.0003346720214190094, 'f1': 0.0006385696040868455, 'auc': 0.5066734125647193, 'prauc': 0.21437636898682358}
Test:       {'precision': 0.015873015873015872, 'recall': 0.0007145409074669524, 'f1': 0.0013675213675213675, 'auc': 0.49548210752813737, 'prauc': 0.2080967237897713}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.82it/s]


Epoch: 012, Average Loss: 0.5124
Validation: {'precision': 0.003968253968253968, 'recall': 0.0001673360107095047, 'f1': 0.0003211303789338471, 'auc': 0.5076297614625466, 'prauc': 0.21601134283047174}
Test:       {'precision': 0.018518518518518517, 'recall': 0.0007145409074669524, 'f1': 0.0013759889920880635, 'auc': 0.49716222285597284, 'prauc': 0.208909097670592}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.10966203901876612, 'recall': 0.094302762226931, 'f1': 0.07794435577932646, 'auc': 0.5344022518632103, 'prauc': 0.23129147941348294}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.000000  0.000000  0.000000   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.396310 

Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.15it/s]


Epoch: 001, Average Loss: 0.6910
Validation: {'precision': 0.08353952086009259, 'recall': 0.1917723948757154, 'f1': 0.09024492021749901, 'auc': 0.5110466092676068, 'prauc': 0.21940639638714773}
Test:       {'precision': 0.09123146609840464, 'recall': 0.19318505281153145, 'f1': 0.0886401222072596, 'auc': 0.5179326414218192, 'prauc': 0.21710521765101098}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.81it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.23it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Epoch: 002, Average Loss: 0.6798
Validation: {'precision': 0.10030309496473266, 'recall': 0.03864757656854064, 'f1': 0.04048250870993838, 'auc': 0.5145016563609052, 'prauc': 0.2217810249817128}
Test:       {'precision': 0.062432844194738796, 'recall': 0.03501656142120853, 'f1': 0.03774852880881494, 'auc': 0.5130515231651375, 'prauc': 0.21777664255641307}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  9.67it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 003, Average Loss: 0.6688
Validation: {'precision': 0.0731927630124873, 'recall': 0.01007784467829125, 'f1': 0.016626763918738494, 'auc': 0.5126147434391841, 'prauc': 0.22171092362436326}
Test:       {'precision': 0.06075122239505801, 'recall': 0.006183030545621579, 'f1': 0.010439477865330515, 'auc': 0.507307872117437, 'prauc': 0.2159461505484565}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 13.89it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]


Epoch: 004, Average Loss: 0.6565
Validation: {'precision': 0.08348030570252792, 'recall': 0.003982637948440752, 'f1': 0.007215280858206314, 'auc': 0.5067823142380415, 'prauc': 0.217550999427999}
Test:       {'precision': 0.055708180708180705, 'recall': 0.002113135930068963, 'f1': 0.003883183897925894, 'auc': 0.4997602378784506, 'prauc': 0.21166100239004199}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.45it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch: 005, Average Loss: 0.6411
Validation: {'precision': 0.06759259259259259, 'recall': 0.0014577322271340182, 'f1': 0.002665097162379771, 'auc': 0.4988495945585668, 'prauc': 0.21212979304764598}
Test:       {'precision': 0.04431216931216931, 'recall': 0.0011504843850403549, 'f1': 0.0020693891827912444, 'auc': 0.4911896636905257, 'prauc': 0.20673081287283468}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.14it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 14.00it/s]


Epoch: 006, Average Loss: 0.6231
Validation: {'precision': 0.02607212475633528, 'recall': 0.0010957476438191243, 'f1': 0.0019555128995141207, 'auc': 0.4923025606030353, 'prauc': 0.20776632479585588}
Test:       {'precision': 0.031746031746031744, 'recall': 0.0010528472048021201, 'f1': 0.0019366957379958888, 'auc': 0.4848008513247122, 'prauc': 0.20297992649748856}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.16it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 16.13it/s]


Epoch: 007, Average Loss: 0.6030
Validation: {'precision': 0.008101851851851851, 'recall': 0.0009284116331096196, 'f1': 0.0016479855310440018, 'auc': 0.487085075354889, 'prauc': 0.20450956809672405}
Test:       {'precision': 0.013227513227513227, 'recall': 0.0008742119779353822, 'f1': 0.001582838696240758, 'auc': 0.47946426475606363, 'prauc': 0.20040033198239202}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.78it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 008, Average Loss: 0.5790
Validation: {'precision': 0.008522727272727272, 'recall': 0.0009284116331096196, 'f1': 0.0016522988505747126, 'auc': 0.48303954408393135, 'prauc': 0.20216897333311318}
Test:       {'precision': 0.014309764309764309, 'recall': 0.0008742119779353822, 'f1': 0.0016111592853090108, 'auc': 0.47543804362868297, 'prauc': 0.19851064023230774}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.70it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Epoch: 009, Average Loss: 0.5564
Validation: {'precision': 0.008754208754208756, 'recall': 0.0009284116331096196, 'f1': 0.0016606280193236715, 'auc': 0.4802901514972896, 'prauc': 0.2005673685387708}
Test:       {'precision': 0.0154320987654321, 'recall': 0.0008742119779353822, 'f1': 0.0016309387443408062, 'auc': 0.4729442409957708, 'prauc': 0.19715245073019966}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.67it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.27it/s]


Epoch: 010, Average Loss: 0.5339
Validation: {'precision': 0.005555555555555556, 'recall': 0.0003728560775540641, 'f1': 0.0006988120195667366, 'auc': 0.47764020166642734, 'prauc': 0.19943947085461614}
Test:       {'precision': 0.009259259259259259, 'recall': 0.0002955082742316785, 'f1': 0.000572737686139748, 'auc': 0.47097156260924916, 'prauc': 0.19611566336367986}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.12it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00, 15.73it/s]


Epoch: 011, Average Loss: 0.5144
Validation: {'precision': 0.005555555555555556, 'recall': 0.0003728560775540641, 'f1': 0.0006988120195667366, 'auc': 0.47747553459090825, 'prauc': 0.1989886255023171}
Test:       {'precision': 0.009259259259259259, 'recall': 0.0002955082742316785, 'f1': 0.000572737686139748, 'auc': 0.47221812057559337, 'prauc': 0.19629589795225744}


Training Batches: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
Running inference: 100%|██████████| 1/1 [00:00<00:00,  3.24it/s]


Epoch: 012, Average Loss: 0.4995
Validation: {'precision': 0.005555555555555556, 'recall': 0.0003728560775540641, 'f1': 0.0006988120195667366, 'auc': 0.4792316623288554, 'prauc': 0.199249017804229}
Test:       {'precision': 0.011111111111111112, 'recall': 0.0002955082742316785, 'f1': 0.0005757052389176742, 'auc': 0.47626514795419017, 'prauc': 0.19769146419121483}

Early stopping triggered after 12 epochs (no improvement for 10 epochs).

Best validation performance:
{'global': {'precision': 0.10030309496473266, 'recall': 0.03864757656854064, 'f1': 0.04048250870993838, 'auc': 0.5145016563609052, 'prauc': 0.2217810249817128}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.000000  0.000000  0.000000   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.000000

Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import numpy as np
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    """
    Print mean ± std of every metric over the k best-ranked runs.

    Runs are ranked separately on "f1", "auc" and "prauc" (rank 1 = largest value);
    the k runs with the lowest average rank are selected. The same selected run
    indices are then used for both ``performances`` and ``long_seq_performances``.
    Values are printed as percentages with two decimals; the standard deviation is
    the population std (ddof=0). Nothing is returned.

    Args:
        performances: sequence of metric dicts, one per run; each must contain the
            keys "f1", "auc" and "prauc". Every key of the first dict is reported.
        long_seq_performances: sequence of metric dicts indexed by the same run
            positions as ``performances``.
        k (int): number of top-ranked runs to average over (default: 5).
    """
    ranking_keys = ["f1", "auc", "prauc"]

    # Per-metric rank of each run: negate so larger values sort first, then the
    # double argsort converts sort order into 1-based rank positions.
    rank_columns = []
    for key in ranking_keys:
        values = np.array([run[key] for run in performances])
        rank_columns.append((-values).argsort().argsort() + 1)
    mean_rank = np.mean(np.stack(rank_columns, axis=1), axis=1)

    # Keep the k runs with the best (smallest) average rank.
    selected = np.argsort(mean_rank)[:k]

    def _summarise(runs):
        # Mean and population std of each metric over the selected runs.
        averages = {name: np.mean([runs[i][name] for i in selected]) for name in runs[0].keys()}
        spreads = {name: np.std([runs[i][name] for i in selected], ddof=0) for name in runs[0].keys()}
        return averages, spreads

    # Both summaries are computed before anything is printed.
    reports = (
        ("Final Metrics:", performances, _summarise(performances)),
        ("\nFinal Long Sequence Metrics:", long_seq_performances, _summarise(long_seq_performances)),
    )

    # Emit the results as percentages with two decimals.
    for heading, runs, (averages, spreads) in reports:
        print(heading)
        for name in runs[0].keys():
            mean_val = averages[name] * 100
            std_val = spreads[name] * 100
            print(f"{name}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=0):
    """
    Print ``mean ± std`` of one metric column for every disease across several tables.

    All tables are stacked and the rows are grouped by index label, so each disease
    is summarised over every table in which it appears. One line per disease is
    printed, with values shown as percentages rounded to two decimals.

    Args:
        dfs (list[pd.DataFrame]): per-class metric tables, indexed by disease name.
        col_name (str): name of the metric column to summarise (default: "prauc").
        ddof (int): delta degrees of freedom for the standard deviation. Defaults to
            0 (population std), the same convention ``topk_avg_performance_formatted``
            uses for the global metrics. It also yields 0.00 rather than NaN when a
            disease appears in only one table. Pass ``ddof=1`` for the sample std.

    Raises:
        ValueError: if ``dfs`` is empty (raised by ``pd.concat``).
        KeyError: if ``col_name`` is not a column of the stacked table.
    """
    # Stack all tables on top of each other; the disease names stay in the index.
    all_values = pd.concat(dfs, axis=0)

    # Group rows by disease and compute the mean and std of the requested column.
    per_disease = all_values.groupby(all_values.index)[col_name]
    grouped = pd.DataFrame({
        "mean": per_disease.mean(),
        "std": per_disease.std(ddof=ddof),
    })

    # Print one line per disease, as percentages.
    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
# Report the aggregated results of the collected runs.
# NOTE(review): task_type, final_metrics and final_long_seq_metrics are not defined
# in this part of the notebook; they are expected to come from earlier cells.
if task_type == "binary":
    # Binary task: the per-run entries are passed straight to the top-k summary.
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)
else:
    # Otherwise each per-run entry is indexed by "global" and "per_class";
    # split the two parts so they can be summarised separately.
    final_metrics_global = [metrics["global"] for metrics in final_metrics]
    final_metrics_per_class = [metrics["per_class"] for metrics in final_metrics]
    final_long_seq_metrics_global = [metrics["global"] for metrics in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [metrics["per_class"] for metrics in final_long_seq_metrics]
    # Global metrics: mean ± std over the top-ranked runs (default k of the helper).
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    # Per-class PR-AUC: computed over every collected run, not only the top-ranked ones.
    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")