In [1]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import SetGNN
from tokenizer import EHRTokenizer
from dataset import FinetuneHGDataset, batcher_SetGNN_finetune
from train import PHENO_ORDER, train_with_early_stopping
from set_seed import set_random_seed

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
args = {
    "dataset": "MIMIC-IV", 
    "task": "death",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "special_tokens":["[PAD]", "[CLS]"],
    "predicted_token_type": ["diag", "med", "lab", "pro"],
    "batch_size": 256,
    "lr": 1e-3,
    "epochs": 500,
    "model_name": "HG",
    "early_stop_patience": 10,
    # model hyperparameters
    "level": "visit",  # "visit" or "patient"
    "hg_all_num_layers": 3,
    "hg_use_type_embed": True,
    "MLP_num_layers": 2,
    "hg_aggregate": "mean",
    "hg_dropout": 0.0,
    "normtype": "all_one",
    "add_self_loop": True,
    "hg_normalization": "ln",
    "hg_hidden_size": 128,
    "PMA": True,
    "hg_num_heads": 4,
}

In [4]:
# Root of the preprocessed data. Set HETEROGT_DATA_ROOT to point elsewhere instead of
# editing this machine-specific absolute path; the default reproduces the original paths.
DATA_ROOT = os.environ.get("HETEROGT_DATA_ROOT", "/home/lideyi/HeteroGT-cuda/data_process")
processed_dir = f"{DATA_ROOT}/{args['dataset']}-processed"

# Full EHR table (used below to build vocabularies).
full_data_path = f"{processed_dir}/mimic.pkl"

# The next-diagnosis tasks have their own split files; every other task uses the shared
# downstream split.
FINETUNE_FILES = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{processed_dir}/{FINETUNE_FILES.get(args['task'], 'mimic_downstream.pkl')}"

In [5]:
# Load the full EHR table. In this notebook it is used to build the tokenizer's input
# sentences and the max-admissions setting.
# A context manager closes the file handle; the bare open() never closed it.
# NOTE: pickle can execute arbitrary code on load — only open trusted files.
with open(full_data_path, "rb") as fh:
    ehr_full_data = pickle.load(fh)

diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Age/gender vocabulary: "[PAD]" first (special token for this vocabulary), then
# "<age>_M" and "<age>_F" for every distinct AGE value.
# NOTE(review): ages are taken in set-iteration order. That order is stable across runs
# for int ages but depends on PYTHONHASHSEED for string ages — confirm the AGE dtype if
# vocabulary ids must be reproducible.
age_gender_sentences = ["[PAD]"] + [
    str(c) + "_" + gender
    for c in set(ehr_full_data["AGE"].values.tolist())
    for gender in ["M", "F"]
]

# Longest admission history (distinct HADM_IDs per SUBJECT_ID); passed to the model via args.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
args["max_adm_len"] = max_admissions
print(f"Max admissions per patient: {max_admissions}")

Max admissions per patient: 8


In [6]:
# Build the tokenizer from the age/gender, diagnosis, medication, lab and procedure
# sentences, plus the configured special tokens.
tokenizer = EHRTokenizer(
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=args["special_tokens"],
)

# Record the vocabulary sizes in args; the model is later constructed from args.
args.update(
    age_gender_vocab_size=tokenizer.token_number("age_gender"),
    global_vocab_size=len(tokenizer.vocab.id2word),
    label_vocab_size=len(PHENO_ORDER),
)

print(f"Age and gender vocabulary size: {args['age_gender_vocab_size']}")
print(f"Global vocabulary size: {args['global_vocab_size']}")
print(f"Label vocabulary size: {args['label_vocab_size']}")

Age and gender vocabulary size: 41
Global vocabulary size: 4207
Label vocabulary size: 18


In [7]:
# Load the (train, val, test) splits; the context manager closes the file handle.
# NOTE: pickle can execute arbitrary code on load — only open trusted files.
with open(finetune_data_path, "rb") as fh:
    train_data, val_data, test_data = pickle.load(fh)


def build_finetune_dataset(split_data):
    """Wrap one split in a FinetuneHGDataset using the notebook's tokenizer and args.

    Items hold a patient's input_ids (a patient can have multiple visits) and labels.
    """
    return FinetuneHGDataset(
        split_data,
        tokenizer,
        token_type=args["predicted_token_type"],
        task=args["task"],
        level=args["level"],
    )


train_dataset = build_finetune_dataset(train_data)
val_dataset = build_finetune_dataset(val_data)
test_dataset = build_finetune_dataset(test_data)
# NOTE(review): the saved run printed 7621 / 15401 / 15621, i.e. a training split about
# half the size of validation/test. Confirm the split order and sizes in the pickle.
print(len(train_dataset), len(val_dataset), len(test_dataset))

7621 15401 15621


In [8]:
# Patients with at least this many admissions form the "long sequence" subset that
# train_with_early_stopping also scores separately (reported as "test-long").
long_adm_seq_crite = 3


def long_sequence_indices(dataset, min_admissions):
    """Return the positions ``i`` whose record has at least ``min_admissions`` entries.

    Position ``i`` maps to the ``i``-th key of ``dataset.records`` (dict insertion
    order), exactly as before. This assumes the dataset indexes its items in that
    same order — TODO confirm in FinetuneHGDataset.

    The key list is built once. The previous loop rebuilt ``list(records.keys())``
    for every index, which is quadratic in the split size (~15k items here).
    """
    keys = list(dataset.records.keys())
    return [
        i for i in range(len(dataset))
        if len(dataset.records[keys[i]]) >= min_admissions
    ]


val_long_seq_idx = long_sequence_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = long_sequence_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

3379 3608


In [9]:
# Full-graph mode puts each whole split into a single batch; otherwise use the
# configured mini-batch size for every split.
use_full_graph = True
if use_full_graph:
    train_batch_size, val_batch_size, test_batch_size = (
        len(train_dataset),
        len(val_dataset),
        len(test_dataset),
    )
else:
    train_batch_size = val_batch_size = test_batch_size = args["batch_size"]

# Each loader gets its own collate function bound to the target device; only the
# training loader is shuffled.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=train_batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=val_batch_size,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher_SetGNN_finetune(device=device),
    batch_size=test_batch_size,
    shuffle=False,
)

In [10]:
# Map the configured task to the task_type expected by train_with_early_stopping.
BINARY_TASKS = ("death", "stay", "readmission")
NEXT_DIAG_TASKS = ("next_diag_6m", "next_diag_12m")

# Checkpoint-selection metric. It was "prauc" in both branches before, so it is set once
# here and passed to the training loop.
eval_metric = "prauc"

if args["task"] in BINARY_TASKS:
    task_type = "binary"
elif args["task"] in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    # Previously any unrecognised task silently fell into the "l2r" branch.
    raise ValueError(
        f"Unknown task {args['task']!r}; expected one of {BINARY_TASKS + NEXT_DIAG_TASKS}"
    )

# Both task types use BCE on logits. The old `lambda x, y: ...` wrapper added nothing.
loss_fn = F.binary_cross_entropy_with_logits

In [11]:
# Number of independent training runs.
NUM_SEEDS = 15
# A private Random(42) yields exactly the same draws as random.seed(42) followed by
# random.randint, without reseeding the module-level RNG as a side effect.
seed_rng = random.Random(42)
seeds = [seed_rng.randint(0, 2**32 - 1) for _ in range(NUM_SEEDS)]
# NOTE(review): the saved output of this cell lists only 10 seeds, so it appears stale
# relative to this code — re-run to refresh.
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441, 127978094, 939042955, 2340505846, 946785248, 2530876844]


In [None]:
# Train one fresh model per seed. For each run, keep the test metrics that
# train_with_early_stopping returns (logged as the test performance at the best
# validation epoch) for the full test set and for the long-sequence subset.
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    set_random_seed(seed)
    print(f"Training with seed: {seed}")

    # Fresh model and optimizer per seed; loss_fn is shared and was chosen with task_type.
    model = SetGNN(args, tokenizer).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    # Use the configured selection metric instead of a hard-coded "f1", which silently
    # ignored the eval_metric chosen together with task_type/loss_fn.
    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type,
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.97s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.97s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.54s/it]


Epoch: 001, Average Loss: 0.6898
Validation: {'precision': 0.2703101920196409, 'recall': 0.13034188034095198, 'f1': 0.17587697782862788, 'auc': 0.6844667772723785, 'prauc': 0.17964804997537487}
Test:       {'precision': 0.2383638928034081, 'recall': 0.11918194640254456, 'f1': 0.15890925742524153, 'auc': 0.6701963182056832, 'prauc': 0.16740504708482043}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.85s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.50s/it]


Epoch: 002, Average Loss: 0.6443
Validation: {'precision': 0.23008849556504032, 'recall': 0.037037037036773246, 'f1': 0.06380367859228434, 'auc': 0.6348292837272619, 'prauc': 0.1555045944433323}
Test:       {'precision': 0.22310756971222678, 'recall': 0.03949224259492601, 'f1': 0.06710604897160762, 'auc': 0.6187983289253238, 'prauc': 0.14745341698788567}


Training Batches: 100%|██████████| 1/1 [00:03<00:00,  3.27s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.57s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.73s/it]


Epoch: 003, Average Loss: 0.5824
Validation: {'precision': 0.25252525249974495, 'recall': 0.01780626780614098, 'f1': 0.03326679850283027, 'auc': 0.636021210894398, 'prauc': 0.14588751390378552}
Test:       {'precision': 0.19827586205187278, 'recall': 0.016220028208630326, 'f1': 0.029986960791940763, 'auc': 0.6236105286562652, 'prauc': 0.1412498238302756}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.50s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.29s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.27s/it]


Epoch: 004, Average Loss: 0.5155
Validation: {'precision': 0.149999999925, 'recall': 0.002136752136736918, 'f1': 0.004213482869054589, 'auc': 0.6373595623970704, 'prauc': 0.1414823046344209}
Test:       {'precision': 0.14285714280612244, 'recall': 0.0028208744710661437, 'f1': 0.00553250307796203, 'auc': 0.6289407808020854, 'prauc': 0.13911899728716248}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.72s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.49s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.04s/it]


Epoch: 005, Average Loss: 0.4512
Validation: {'precision': 0.33333333277777777, 'recall': 0.0014245014244912786, 'f1': 0.002836879347839648, 'auc': 0.6347918825503308, 'prauc': 0.14154896760600896}
Test:       {'precision': 0.1999999996, 'recall': 0.0007052186177665359, 'f1': 0.0014054813073248527, 'auc': 0.6282973799114929, 'prauc': 0.14190885507313794}


Training Batches: 100%|██████████| 1/1 [00:03<00:00,  3.22s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.48s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]


Epoch: 006, Average Loss: 0.3941
Validation: {'precision': 0.16666666638888888, 'recall': 0.0007122507122456393, 'f1': 0.001418439631547714, 'auc': 0.6326746197343468, 'prauc': 0.1408725998847107}
Test:       {'precision': 0.1999999996, 'recall': 0.0007052186177665359, 'f1': 0.0014054813073248527, 'auc': 0.6275844154580267, 'prauc': 0.14371590454321323}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.92s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.56s/it]


Epoch: 007, Average Loss: 0.3478
Validation: {'precision': 0.1999999996, 'recall': 0.0007122507122456393, 'f1': 0.0014194463451571878, 'auc': 0.6358545085057911, 'prauc': 0.1386527025060025}
Test:       {'precision': 0.3999999992, 'recall': 0.0014104372355330718, 'f1': 0.002810962684676847, 'auc': 0.6323276226332126, 'prauc': 0.14167435455861302}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.71s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.87s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.27s/it]


Epoch: 008, Average Loss: 0.3123
Validation: {'precision': 0.1999999996, 'recall': 0.0007122507122456393, 'f1': 0.0014194463451571878, 'auc': 0.6443709854797945, 'prauc': 0.139405297894629}
Test:       {'precision': 0.3999999992, 'recall': 0.0014104372355330718, 'f1': 0.002810962684676847, 'auc': 0.6432144691813555, 'prauc': 0.14261276100169565}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.54s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.27s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.32s/it]


Epoch: 009, Average Loss: 0.2875
Validation: {'precision': 0.3333333329629629, 'recall': 0.002136752136736918, 'f1': 0.004246284374424329, 'auc': 0.6585668184492932, 'prauc': 0.14469886449292374}
Test:       {'precision': 0.33333333277777777, 'recall': 0.0014104372355330718, 'f1': 0.0028089886800909004, 'auc': 0.6596668228081495, 'prauc': 0.14837951778543618}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.74s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.30s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.04s/it]


Epoch: 010, Average Loss: 0.2698
Validation: {'precision': 0.3333333330555555, 'recall': 0.0028490028489825572, 'f1': 0.005649717345989345, 'auc': 0.6793032776457796, 'prauc': 0.15706056224677953}
Test:       {'precision': 0.2222222219753086, 'recall': 0.0014104372355330718, 'f1': 0.00280308326634842, 'auc': 0.6821044482248978, 'prauc': 0.16202101385568718}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.73s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.60s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.07s/it]


Epoch: 011, Average Loss: 0.2537
Validation: {'precision': 0.45454545413223135, 'recall': 0.0035612535612281966, 'f1': 0.007067137654819018, 'auc': 0.7048835454565254, 'prauc': 0.17148943417691087}
Test:       {'precision': 0.3333333329629629, 'recall': 0.0021156558532996078, 'f1': 0.004204624962194228, 'auc': 0.7101204407936622, 'prauc': 0.17564505952134077}

Early stopping triggered after 11 epochs (no improvement for 10 epochs).

Best validation performance:
{'precision': 0.2703101920196409, 'recall': 0.13034188034095198, 'f1': 0.17587697782862788, 'auc': 0.6844667772723785, 'prauc': 0.17964804997537487}
Corresponding test performance:
{'precision': 0.2383638928034081, 'recall': 0.11918194640254456, 'f1': 0.15890925742524153, 'auc': 0.6701963182056832, 'prauc': 0.16740504708482043}
Corresponding test-long performance:
{'precision': 0.18181818180899909, 'recall': 0.15384615383957922, 'f1': 0.16666666169367297, 'auc': 0.7080426489140182, 'prauc': 0.139588845550085}
[INFO] Random seed 

Training Batches: 100%|██████████| 1/1 [00:03<00:00,  3.01s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.97s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.55s/it]


Epoch: 001, Average Loss: 0.6997
Validation: {'precision': 0.23806492679670235, 'recall': 0.2663817663798691, 'f1': 0.25142856644263667, 'auc': 0.6704875403703724, 'prauc': 0.17862193927186812}
Test:       {'precision': 0.22317596566386771, 'recall': 0.25669957686701905, 'f1': 0.23876680381260243, 'auc': 0.6540655160658067, 'prauc': 0.16926796450226383}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.98s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.86s/it]


Epoch: 002, Average Loss: 0.6555
Validation: {'precision': 0.2330917874367984, 'recall': 0.1374643874634084, 'f1': 0.1729390634317954, 'auc': 0.6676093289831948, 'prauc': 0.16249105202511688}
Test:       {'precision': 0.23302107728064378, 'recall': 0.14033850493554065, 'f1': 0.17517605164460034, 'auc': 0.6570531742682942, 'prauc': 0.1583368544642176}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.62s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.49s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.61s/it]


Epoch: 003, Average Loss: 0.5949
Validation: {'precision': 0.19138755979945513, 'recall': 0.028490028489825573, 'f1': 0.04959702192226915, 'auc': 0.6514690164579426, 'prauc': 0.15435431919897621}
Test:       {'precision': 0.21256038646316133, 'recall': 0.03102961918172758, 'f1': 0.05415384393002423, 'auc': 0.6404550648678983, 'prauc': 0.15051101399463465}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.27s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.97s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.70s/it]


Epoch: 004, Average Loss: 0.5319
Validation: {'precision': 0.24390243896490185, 'recall': 0.007122507122456393, 'f1': 0.013840829898262736, 'auc': 0.6503712791935268, 'prauc': 0.15188855737665002}
Test:       {'precision': 0.2222222221728395, 'recall': 0.007052186177665359, 'f1': 0.013670539389890275, 'auc': 0.644372645402494, 'prauc': 0.14934746907345187}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.71s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.81s/it]


Epoch: 005, Average Loss: 0.4704
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6441595797797126, 'prauc': 0.1469256396954959}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6399646442322771, 'prauc': 0.14472833642356114}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.13s/it]


Epoch: 006, Average Loss: 0.4124
Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6405274675261101, 'prauc': 0.14429818593197458}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6367114180668836, 'prauc': 0.14184397584481118}


Training Batches: 100%|██████████| 1/1 [00:01<00:00,  1.65s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.13s/it]
Running inference: 100%|██████████| 1/1 [00:05<00:00,  5.55s/it]


Epoch: 007, Average Loss: 0.3629
Validation: {'precision': 0.4999999975, 'recall': 0.0007122507122456393, 'f1': 0.0014224750782563657, 'auc': 0.6437248356230996, 'prauc': 0.14476326024447658}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6412498074712956, 'prauc': 0.14282122027790412}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.12s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.91s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.57s/it]


Epoch: 008, Average Loss: 0.3240
Validation: {'precision': 0.4999999975, 'recall': 0.0007122507122456393, 'f1': 0.0014224750782563657, 'auc': 0.6539375959073037, 'prauc': 0.15028173374384812}
Test:       {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.6540780285696213, 'prauc': 0.14899347371028632}


Training Batches: 100%|██████████| 1/1 [00:03<00:00,  3.41s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.56s/it]
Running inference: 100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


Epoch: 009, Average Loss: 0.2948
Validation: {'precision': 0.49999999875, 'recall': 0.0014245014244912786, 'f1': 0.0028409090342119718, 'auc': 0.6723840344705531, 'prauc': 0.16331746730310856}
Test:       {'precision': 0.9999999900000002, 'recall': 0.0007052186177665359, 'f1': 0.001409443255804021, 'auc': 0.6749723458769861, 'prauc': 0.16357399031330253}


Training Batches: 100%|██████████| 1/1 [00:02<00:00,  2.15s/it]
Running inference:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
def _mean_std_over(runs, idx):
    """Mean and population std (ddof=0) of every metric in ``runs[0]`` over the runs at ``idx``."""
    picked = [runs[i] for i in idx]
    keys = runs[0].keys()
    mean = {m: np.mean([r[m] for r in picked]) for m in keys}
    std = {m: np.std([r[m] for r in picked], ddof=0) for m in keys}
    return mean, std


def _print_metric_summary(title, mean, std):
    """Print ``title``, then one "<metric>: <mean> ± <std>" line per metric in percent (2 decimals)."""
    print(title)
    for m in mean:
        mean_val = mean[m] * 100
        std_val = std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")


def topk_avg_performance_formatted(performances, long_seq_performances, k=5,
                                   metrics=("f1", "auc", "prauc")):
    """Print mean ± std over the ``k`` runs with the best average rank.

    Every run is ranked separately on each metric in ``metrics`` (rank 1 = highest
    value). Runs are ordered by their mean rank, and the first ``k`` are kept; ties
    are broken by run order. For the kept runs, every metric in ``performances[0]``
    and ``long_seq_performances[0]`` is summarised as mean ± population std (ddof=0)
    and printed as a percentage with two decimals.

    NOTE(review): if ``performances`` are test-set metrics (as in this notebook),
    choosing the top-k runs by them makes the reported numbers optimistic.

    Args:
        performances: one ``{metric: value}`` dict per run.
        long_seq_performances: long-sequence-subset dicts aligned with
            ``performances`` by position; the same runs are selected.
        k: number of runs to keep; all runs are used if there are fewer than ``k``.
        metrics: metric keys used for ranking.

    Returns:
        dict with "topk_idx" (selected run positions), "mean", "std",
        "long_seq_mean" and "long_seq_std" (raw values, not percentages).

    Raises:
        ValueError: if there are no runs, the two lists differ in length, or k < 1.
    """
    performances = list(performances)
    long_seq_performances = list(long_seq_performances)
    if not performances:
        raise ValueError("performances must contain at least one run")
    if len(long_seq_performances) != len(performances):
        raise ValueError(
            f"got {len(performances)} runs but {len(long_seq_performances)} long-sequence runs"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Per-metric ranks, shape (n_runs, n_metrics): argsort of the descending order
    # gives each run's 0-based position, and +1 makes the best run rank 1.
    ranks = np.stack(
        [
            (-np.array([p[m] for p in performances])).argsort(kind="stable").argsort(kind="stable") + 1
            for m in metrics
        ],
        axis=1,
    )
    avg_ranks = ranks.mean(axis=1)
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    final_avg, final_std = _mean_std_over(performances, topk_idx)
    long_avg, long_std = _mean_std_over(long_seq_performances, topk_idx)

    _print_metric_summary("Final Metrics:", final_avg, final_std)
    _print_metric_summary("\nFinal Long Sequence Metrics:", long_avg, long_std)

    return {
        "topk_idx": topk_idx.tolist(),
        "mean": final_avg,
        "std": final_std,
        "long_seq_mean": long_avg,
        "long_seq_std": long_std,
    }

In [None]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print mean ± std of one metric column for every class across several runs.

    The tables are stacked and grouped by their index, so rows that share an index
    label (presumably one disease/phenotype per row — TODO confirm against train.py)
    are aggregated together. One line per label is printed, in groupby's sorted
    order, as a percentage with two decimals.

    Args:
        dfs (list[pd.DataFrame]): one per-class metrics table per run.
        col_name (str): metric column to summarise (default: "prauc").
        ddof (int): delta degrees of freedom for the std. The default 1 (sample std)
            keeps the previous behaviour; with a single run it yields NaN. Note that
            topk_avg_performance_formatted reports the population std (ddof=0).

    Returns:
        pd.DataFrame: indexed by class, with raw (not percent) "mean" and "std" columns.

    Raises:
        ValueError: if ``dfs`` is empty.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("print_per_class_performance needs at least one DataFrame")

    # Stack all runs, then aggregate the chosen column per class label.
    stacked = pd.concat(dfs, axis=0)
    per_class = stacked.groupby(stacked.index)[col_name]
    grouped = pd.DataFrame({"mean": per_class.mean(), "std": per_class.std(ddof=ddof)})

    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

    return grouped

In [None]:
# Non-binary tasks return a {"global": ..., "per_class": ...} pair per run, so the
# global summary and the per-class tables are reported separately. Binary tasks
# return one flat metrics dict per run.
if task_type != "binary":
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)