In [1]:
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyhealth.models import RETAINLayer, StageNetLayer
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher
from set_seed_utils import set_random_seed
from token_utils_rep import EHRTokenizer

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
args = {
    "dataset": "MIMIC-IV", 
    "task": "death",  # options: death, stay, readmission, next_diag_12m
    "batch_size": 64,
    "hidden_size": 256,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
    "dropout": 0.0,
    "backbone": "stagenet",  # options: retain, stagenet
}

In [4]:
# Display names of the phenotype labels, in alphabetical order. evaluate() uses
# this list as the row index of its per-class table, so its order is assumed to
# match the label columns produced by the dataset (TODO confirm).
_PHENO_NAMES = """
Acute and unspecified renal failure
Acute cerebrovascular disease
Acute myocardial infarction
Cardiac dysrhythmias
Chronic kidney disease
Chronic obstructive pulmonary disease
Conduction disorders
Congestive heart failure; nonhypertensive
Coronary atherosclerosis and related
Disorders of lipid metabolism
Essential hypertension
Fluid and electrolyte disorders
Gastrointestinal hemorrhage
Hypertension with complications
Other liver diseases
Other lower respiratory disease
Pneumonia
Septicemia (except in labor)
"""
PHENO_ORDER = _PHENO_NAMES.strip().splitlines()

In [5]:
def _binary_metrics(logits, labels):
    """Threshold-at-zero P/R/F1 plus ROC-AUC and PR-AUC for one binary target.

    Args:
        logits: 1-D float tensor of raw model outputs (any device).
        labels: 1-D numpy array of 0/1 ground truth aligned with ``logits``.

    Returns:
        dict with float entries "precision", "recall", "f1", "auc", "prauc".
        "auc" and "prauc" are NaN when ``labels`` holds fewer than two distinct
        values, because both curves are undefined there (sklearn would raise).
    """
    scores = logits.cpu().numpy()                  # continuous scores (raw logits)
    y_pred = (logits > 0).float().cpu().numpy()    # logit > 0  <=>  sigmoid > 0.5

    tp = (y_pred * labels).sum()
    # The 1e-8 terms guard against empty predicted / actual positive sets.
    precision = tp / (y_pred.sum() + 1e-8)
    recall = tp / (labels.sum() + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    if np.unique(labels).size < 2:
        roc_auc = pr_auc = float("nan")
    else:
        roc_auc = roc_auc_score(labels, scores)
        prec_curve, rec_curve, _ = precision_recall_curve(labels, scores)
        pr_auc = auc(rec_curve, prec_curve)

    return {"precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
            "auc": float(roc_auc),
            "prauc": float(pr_auc)}


def _multilabel_metrics(logits, labels):
    """Per-class and macro-averaged metrics for a multi-label target.

    Args:
        logits: [N, C] tensor of raw model outputs (any device).
        labels: [N, C] tensor of 0/1 ground truth.

    Returns:
        {"global": dict of macro averages, "per_class": DataFrame with one row
        per class}. Classes whose labels take a single value get NaN AUC /
        PR-AUC, and the macro AUC / PR-AUC skip them (nanmean).
    """
    # Upcast fp16 on CPU before sigmoid; half-precision CPU kernels may be
    # unavailable in some torch builds.
    if logits.device.type == "cpu" and logits.dtype == torch.float16:
        probs = torch.sigmoid(logits.float())
    else:
        probs = torch.sigmoid(logits)

    y_true = labels.cpu().numpy().astype(np.int32)                         # [N, C]
    y_pred = (logits > 0).to(torch.int32).cpu().numpy().astype(np.int32)   # [N, C]
    scores = probs.cpu().numpy()                                           # [N, C]

    p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )

    n_classes = y_true.shape[1]
    aucs, praucs = [], []
    for c in range(n_classes):
        yt, ys = y_true[:, c], scores[:, c]
        if yt.max() == yt.min():
            # Only one label value present: ROC / PR curves are undefined.
            aucs.append(np.nan)
            praucs.append(np.nan)
        else:
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

    def _nan_macro(values):
        # Mean over classes, ignoring NaN; NaN if every class is NaN.
        arr = np.asarray(values, dtype=float)
        return float(np.nanmean(arr)) if np.any(~np.isnan(arr)) else float("nan")

    summary = {
        "precision": float(np.mean(p_cls)),
        "recall":    float(np.mean(r_cls)),
        "f1":        float(np.mean(f1_cls)),
        "auc":       _nan_macro(aucs),
        "prauc":     _nan_macro(praucs),
    }

    # Name the rows with PHENO_ORDER when it matches the number of classes;
    # otherwise use generic names rather than letting pandas raise.
    if len(PHENO_ORDER) == n_classes:
        index = PHENO_ORDER
    else:
        index = [f"class_{c}" for c in range(n_classes)]

    per_class_df = pd.DataFrame({
        "precision": p_cls,
        "recall":    r_cls,
        "f1":        f1_cls,
        "auc":       aucs,
        "prauc":     praucs,
    }, index=index)

    return {"global": summary, "per_class": per_class_df}


@torch.no_grad()
def evaluate(model, dataloader, device, long_seq_idx=None, task_type="binary"):
    """Run ``model`` over ``dataloader`` and score its predictions.

    Args:
        model: called as ``model(*batch[:-1])``. Its outputs are concatenated
            along dim 0 and must align row-for-row with the concatenated labels
            ``batch[-1]``.
        dataloader: yields batches whose last element is the label tensor.
        device: device that every tensor in a batch is moved to.
        long_seq_idx: optional positions (in dataloader order) of a subset that
            is scored separately. These are only meaningful for a
            non-shuffling dataloader.
        task_type: "binary" returns flat metric dicts. Any other value uses the
            multi-label path and returns {"global": dict, "per_class": DataFrame}.

    Returns:
        (all_performance, subset_performance). subset_performance is None when
        ``long_seq_idx`` is None.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    logits_chunks, label_chunks = [], []

    # Inference: collect logits and labels for every batch.
    for batch in tqdm(dataloader, desc="Running inference"):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        logits_chunks.append(model(*batch[:-1]))
        label_chunks.append(batch[-1])

    if not logits_chunks:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    if task_type == "binary":
        logits_all = torch.cat(logits_chunks, dim=0).view(-1)                 # [N]
        labels_all = torch.cat(label_chunks, dim=0).view(-1).cpu().numpy()    # [N]
        all_performance = _binary_metrics(logits_all, labels_all)

        subset_performance = None
        if long_seq_idx is not None:
            idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
            subset_performance = _binary_metrics(
                logits_all.index_select(0, idx),
                labels_all[idx.cpu().numpy()],
            )
        return all_performance, subset_performance

    # Multi-label evaluation (per-class metrics plus macro averages).
    logits_all = torch.cat(logits_chunks, dim=0)    # [N, C]
    labels_all = torch.cat(label_chunks, dim=0)     # [N, C]
    all_performance = _multilabel_metrics(logits_all, labels_all)

    subset_performance = None
    if long_seq_idx is not None:
        idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
        subset_performance = _multilabel_metrics(
            logits_all.index_select(0, idx),
            labels_all.index_select(0, idx),
        )
    return all_performance, subset_performance

In [6]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

# Root of the preprocessed data. Set HETEROGT_DATA_ROOT to run on another
# machine; the default reproduces the original absolute location.
DATA_ROOT = os.environ.get("HETEROGT_DATA_ROOT", "/home/lideyi/HeteroGT-cuda/data_process")
processed_dir = os.path.join(DATA_ROOT, f"{args['dataset']}-processed")

full_data_path = os.path.join(processed_dir, "mimic.pkl")

# The next-diagnosis tasks have dedicated fine-tuning files; every other task
# (death, stay, readmission) shares mimic_downstream.pkl.
_FINETUNE_FILES = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = os.path.join(
    processed_dir, _FINETUNE_FILES.get(args["task"], "mimic_downstream.pkl")
)

In [7]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(full_data_path, "rb") as f:   # `with` closes the handle (the bare open() leaked it)
    ehr_data = pickle.load(f)

diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()

# Unique ages in a fixed (sorted) order. Iterating a raw set would make the
# tokenizer vocabulary order depend on hashing, which varies between
# interpreter runs for string values.
ages = sorted(set(ehr_data["AGE"].values.tolist()))

gender_set = [["M"], ["F"]]
age_gender_set = [[str(age) + "_" + gender] for age in ages for gender in ["M", "F"]]
age_set = [[age] for age in ages]

In [8]:
# Build the shared vocabulary over every code type plus the demographic tokens.
tokenizer = EHRTokenizer(
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    gender_set,
    age_set,
    age_gender_set,
    special_tokens=args["special_tokens"],
)
tokenizer.build_tree()

In [9]:
# The fine-tuning splits were pickled as a (train, val, test) tuple.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(finetune_data_path, "rb") as f:
    train_data, val_data, test_data = pickle.load(f)


def _make_finetune_dataset(split_data):
    """Wrap one raw split in HBERTFinetuneEHRDataset with this notebook's settings."""
    return HBERTFinetuneEHRDataset(
        split_data, tokenizer,
        token_type=args["predicted_token_type"],
        task=args["task"],
    )


train_dataset = _make_finetune_dataset(train_data)
val_dataset = _make_finetune_dataset(val_data)
test_dataset = _make_finetune_dataset(test_data)

print(len(train_dataset), len(val_dataset), len(test_dataset))

pad_id = tokenizer.vocab.word2id["[PAD]"]


def _make_loader(dataset, shuffle):
    """Build a DataLoader with this notebook's collate function and batch size."""
    # NOTE(review): is_train=False is used for the training split too; confirm
    # that is intended in dataset_utils_rep.batcher.
    return DataLoader(
        dataset,
        collate_fn=batcher(pad_id=pad_id, is_train=False),
        batch_size=args["batch_size"],
        shuffle=shuffle,
    )


# Only the training split is shuffled. Val/test keep dataset order, which the
# positional long-sequence index lists used during evaluation rely on.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

7621 15401 15621


In [10]:
long_adm_seq_crite = 3  # minimum number of records for a sample to count as "long"


def _long_seq_indices(dataset, min_admissions):
    """Return positions ``i`` whose record list has >= ``min_admissions`` entries.

    Position ``i`` is matched to the ``i``-th key of ``dataset.records``. This
    assumes the dataset's item order follows the insertion order of ``records``
    (TODO confirm in HBERTFinetuneEHRDataset). The key list is built once;
    rebuilding it for every position would make the scan quadratic in the
    dataset size.
    """
    record_keys = list(dataset.records.keys())
    return [
        i for i in range(len(dataset))
        if len(dataset.records[record_keys[i]]) >= min_admissions
    ]


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

3379 3608


In [11]:
# examine a batch
# Sanity-check the collated shapes of the first training batch.
batch = next(iter(train_dataloader))
input_ids, input_types, labeled_batch_idx, labels = batch

shape_report = (
    ("input_ids shape:", input_ids.shape),
    ("input_types shape:", input_types.shape),
    ("labeled_batch_idx shape:", len(labeled_batch_idx)),  # a plain list, so report its length
    ("labels shape:", labels.shape),
)
for label_text, shape_value in shape_report:
    print(label_text, shape_value)

input_ids shape: torch.Size([114, 129])
input_types shape: torch.Size([114, 129])
labeled_batch_idx shape: 64
labels shape: torch.Size([64, 1])


In [12]:
args.update({
    "vocab_size": len(tokenizer.vocab.word2id),  # full token vocabulary, incl. special tokens
    "label_vocab_size": 18,                      # label count for the diagnosis (multi-label) tasks
})

In [13]:
# Single-logit tasks are selected on F1; every other task is treated as
# multi-label and selected on macro PR-AUC.
is_binary_task = args["task"] in ["death", "stay", "readmission"]

eval_metric = "f1" if is_binary_task else "prauc"
task_type = "binary" if is_binary_task else "l2r"

# Both task families use element-wise BCE on logits (flattened by the trainer).
loss_fn = F.binary_cross_entropy_with_logits

In [14]:
def train_with_early_stopping(model,
                              train_dataloader,
                              val_dataloader,
                              test_dataloader,
                              optimizer,
                              loss_fn,
                              device,
                              args,
                              val_long_seq_idx=None,
                              test_long_seq_idx=None,
                              task_type="binary",
                              eval_metric="f1"):
    """Train ``model`` and keep the test metrics of the best validation epoch.

    Each epoch makes one pass over ``train_dataloader`` and then scores the
    validation and test loaders with ``evaluate``. The epoch with the highest
    validation ``eval_metric`` is remembered. Training stops after
    ``args["early_stop_patience"]`` consecutive epochs without improvement, or
    after ``args["epochs"]`` epochs. Weights are not checkpointed; only the
    metric dictionaries of the best epoch are kept.

    Args:
        model, optimizer, loss_fn, device: the usual training objects.
            ``loss_fn`` is called with flattened logits and labels.
        train_dataloader, val_dataloader, test_dataloader: batches whose last
            element is the label tensor.
        args: dict providing "epochs" and "early_stop_patience".
        val_long_seq_idx, test_long_seq_idx: optional subset positions passed
            to ``evaluate``.
        task_type: "binary" for single-logit tasks. Any other value selects the
            multi-label path of ``evaluate``.
        eval_metric: key of the validation summary used for model selection.

    Returns:
        (best_test_metric, best_test_long_seq_metric). For binary tasks these
        are flat metric dicts; otherwise {"global": dict, "per_class": DataFrame}.
        Both are None if no epoch ever produced a comparable validation score
        (e.g. always NaN). The long-sequence entry is also None when
        ``test_long_seq_idx`` is None.
    """
    is_binary = task_type == "binary"

    def _split(metric):
        # evaluate() returns a flat dict for binary tasks and
        # {"global": dict, "per_class": DataFrame} otherwise (None if no subset).
        if is_binary or metric is None:
            return metric, None
        return metric["global"], metric["per_class"]

    def _pack(summary, per_class):
        # Inverse of _split, used when storing the best epoch's metrics.
        if is_binary or summary is None:
            return summary
        return {"global": summary, "per_class": per_class}

    # Start at -inf (not 0) so the first epoch is recorded even if its score is 0.
    best_score = float("-inf")
    best_val_metric = None
    best_test_metric = None
    best_test_long_seq_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        total_loss, n_batches = 0.0, 0

        for batch in tqdm(train_dataloader, desc="Training Batches"):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])

            # Flatten both sides so binary ([P,1]) and multi-label ([P,C])
            # outputs share one element-wise loss call.
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            n_batches += 1

        ave_loss = total_loss / max(n_batches, 1)

        # Evaluation
        val_eval, _ = evaluate(model, val_dataloader, device, task_type=task_type, long_seq_idx=val_long_seq_idx)
        test_eval, test_long_eval = evaluate(model, test_dataloader, device, task_type=task_type, long_seq_idx=test_long_seq_idx)

        val_metric, val_per_class_df = _split(val_eval)
        test_metric, test_per_class_df = _split(test_eval)
        test_long_seq_metric, test_long_seq_per_class_df = _split(test_long_eval)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement (NaN never compares greater, so it counts as none).
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = _pack(val_metric, val_per_class_df)
            best_test_metric = _pack(test_metric, test_per_class_df)
            best_test_long_seq_metric = _pack(test_long_seq_metric, test_long_seq_per_class_df)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    if best_test_long_seq_metric is not None:
        print("Corresponding test-long performance:")
        print(best_test_long_seq_metric)
    return best_test_metric, best_test_long_seq_metric

In [15]:
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP that maps a patient representation to prediction logits.

    ``out_dim`` defaults to 1 (a single binary logit). It may be larger so the
    same head also serves multi-label tasks, with one logit per label.
    """

    def __init__(self, in_dim: int, out_dim: int = 1):
        super().__init__()
        self.cls = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        """x: [N, in_dim] -> logits [N, out_dim]."""
        return self.cls(x)


class PyHealthModel(nn.Module):
    """Token + type embeddings -> PyHealth backbone -> MLP prediction head."""

    # Tasks that need a single logit per sample. Keep this in sync with the
    # notebook's task-type selection cell; any other task gets
    # args["label_vocab_size"] logits.
    BINARY_TASKS = ("death", "stay", "readmission")

    def __init__(self, args):
        """
        Required args:
          args["vocab_size"]:  int
          args["hidden_size"]: int  (embedding width and RETAIN feature_size)
          args["backbone"]:    str  ("retain" or "stagenet")
        Optional args:
          args["dropout"]:          float (default 0.5; forwarded to RETAIN only)
          args["task"]:             str   (default "death", i.e. a binary task)
          args["label_vocab_size"]: int   (required for non-binary tasks)
        """
        super().__init__()
        self.vocab_size  = int(args["vocab_size"])
        self.hidden_size = int(args["hidden_size"])
        self.dropout     = float(args.get("dropout", 0.5))
        self.backbone    = str(args.get("backbone", "retain")).lower()

        task = str(args.get("task", "death"))
        # Binary tasks emit one logit; multi-label tasks emit one logit per
        # label, so the output matches labels of shape [P, C] in the loss.
        self.num_labels = 1 if task in self.BINARY_TASKS else int(args["label_vocab_size"])

        # Embeddings (index 0 is padding)
        self.input_emb = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)
        self.type_emb  = nn.Embedding(5,              self.hidden_size, padding_idx=0)

        # Backbone
        if self.backbone == "retain":
            self.backbone_mod = RETAINLayer(feature_size=self.hidden_size, dropout=self.dropout)
            backbone_out_dim = self.hidden_size
        elif self.backbone == "stagenet":
            # NOTE(review): args["dropout"] is not forwarded here, so StageNetLayer
            # runs with its own default dropout settings.
            self.backbone_mod = StageNetLayer(self.hidden_size)
            # PyHealth's StageNetLayer stores its output width as `hidden_dim`
            # (chunk_size * levels, which is 384 with default arguments). Fall back
            # to 384 if the attribute is missing.
            backbone_out_dim = int(getattr(self.backbone_mod, "hidden_dim", 384))
        else:
            raise ValueError(f"Unknown backbone: {self.backbone}")

        # Prediction head; its input width matches the backbone output.
        self.pred_head = BinaryPredictionHead(backbone_out_dim, out_dim=self.num_labels)

    def forward(self, input_ids, input_types, labeled_batch_idx):
        """
        input_ids:         LongTensor [B, T], pad=0
        input_types:       LongTensor [B, T], pad=0
        labeled_batch_idx: list[int] or LongTensor [P]; rows of the backbone
                           output to classify
        return:
          logits:          Tensor [P, num_labels]
        """
        # 1) Sum token and type embeddings.
        x = self.input_emb(input_ids) + self.type_emb(input_types)   # [B,T,H]

        # 2) Valid-position mask (non-padding tokens), passed to the backbone.
        mask = (input_ids != 0)                                       # [B,T]

        # 3) Backbone forward pass -> one representation per row: [B, H*]
        if self.backbone == "retain":
            c = self.backbone_mod(x, mask=mask)                       # [B,H]
        else:  # stagenet
            out = self.backbone_mod(x, mask=mask)
            # If the layer returns a tuple, its first element is used as the
            # row representation.
            c = out[0] if isinstance(out, (tuple, list)) else out     # [B,H*]

        # 4) Select the rows to classify: c_sel [P, H*]
        idx = torch.as_tensor(labeled_batch_idx, device=c.device, dtype=torch.long)
        assert idx.ndim == 1 and (idx.numel() == 0 or (0 <= idx.min() and idx.max() < c.size(0))), \
            f"Index out of range: max={int(idx.max()) if idx.numel()>0 else 'NA'}, B={c.size(0)}"

        c_sel = c.index_select(0, idx)                                # [P,H*]

        # 5) Classify
        logits = self.pred_head(c_sel)                                # [P,num_labels]
        return logits

In [16]:
# Derive the per-run seeds from one fixed master seed so the set of runs is reproducible.
random.seed(42)
seeds = []
while len(seeds) < 5:
    seeds.append(random.randint(0, 2**32 - 1))
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [17]:
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")

    # A fresh model and optimizer for every seed keeps the runs independent.
    model = PyHealthModel(args).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx=val_long_seq_idx,
        test_long_seq_idx=test_long_seq_idx,
        task_type=task_type,
        # Pass the selection metric configured for this task ("f1" for binary
        # tasks, "prauc" otherwise); the default would always select on "f1".
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.35it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.25it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.53it/s]


Epoch: 001, Average Loss: 0.2942
Validation: {'precision': 0.6376470588085259, 'recall': 0.19301994301856826, 'f1': 0.29633679249272626, 'auc': 0.8685879880242957, 'prauc': 0.48204918366761645}
Test:       {'precision': 0.6556122448812344, 'recall': 0.18124118476599974, 'f1': 0.2839778971559477, 'auc': 0.8787738977650981, 'prauc': 0.4975875083074963}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.35it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.16it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.72it/s]


Epoch: 002, Average Loss: 0.1918
Validation: {'precision': 0.6035976015949128, 'recall': 0.6452991452945492, 'f1': 0.6237521464642753, 'auc': 0.9370879128148544, 'prauc': 0.6930726865345461}
Test:       {'precision': 0.6050700466937621, 'recall': 0.6396332863142481, 'f1': 0.6218717810811825, 'auc': 0.9407306527644144, 'prauc': 0.6912896219544262}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.43it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.36it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.68it/s]


Epoch: 003, Average Loss: 0.1613
Validation: {'precision': 0.8144329896813925, 'recall': 0.5064102564066495, 'f1': 0.6245059241201846, 'auc': 0.9495703393502922, 'prauc': 0.7463139205419885}
Test:       {'precision': 0.7901098901012076, 'recall': 0.5070521861741393, 'f1': 0.6176975897344963, 'auc': 0.9492601584897289, 'prauc': 0.7362048591386692}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.43it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.63it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.50it/s]


Epoch: 004, Average Loss: 0.1273
Validation: {'precision': 0.7042471042416661, 'recall': 0.6495726495680231, 'f1': 0.6758058490231547, 'auc': 0.9503072697507219, 'prauc': 0.749491408702321}
Test:       {'precision': 0.6897081413157473, 'recall': 0.6332863187543493, 'f1': 0.6602941126512977, 'auc': 0.952354967419327, 'prauc': 0.7382822467061158}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 20.97it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.48it/s]


Epoch: 005, Average Loss: 0.1057
Validation: {'precision': 0.6291739894514987, 'recall': 0.7649572649518166, 'f1': 0.6904532255155084, 'auc': 0.9541288558577977, 'prauc': 0.7602635987578094}
Test:       {'precision': 0.6327985739712847, 'recall': 0.7510578279213607, 'f1': 0.6868751965799718, 'auc': 0.9541250398339532, 'prauc': 0.7416087260848939}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.41it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.42it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.73it/s]


Epoch: 006, Average Loss: 0.0869
Validation: {'precision': 0.7443548387036746, 'recall': 0.6574074074027251, 'f1': 0.698184563849054, 'auc': 0.9550973682394701, 'prauc': 0.776918591399544}
Test:       {'precision': 0.7202852614839914, 'recall': 0.6410437235497811, 'f1': 0.6783582039671029, 'auc': 0.9559289754533472, 'prauc': 0.7572537132978976}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.26it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.80it/s]


Epoch: 007, Average Loss: 0.0725
Validation: {'precision': 0.6032520325170556, 'recall': 0.7927350427293965, 'f1': 0.6851338824378552, 'auc': 0.9554504149953175, 'prauc': 0.7667137939553507}
Test:       {'precision': 0.6035145524403981, 'recall': 0.775035260925423, 'f1': 0.678604502637276, 'auc': 0.9551945609933418, 'prauc': 0.7543870770435812}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.60it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.51it/s]


Epoch: 008, Average Loss: 0.0599
Validation: {'precision': 0.8030634573216295, 'recall': 0.5227920227882993, 'f1': 0.6333045681256417, 'auc': 0.9456661144522829, 'prauc': 0.7427787134395568}
Test:       {'precision': 0.7910922587402853, 'recall': 0.5260930888538358, 'f1': 0.6319356156968141, 'auc': 0.9461202151713712, 'prauc': 0.7297763477315112}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.32it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.22it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.90it/s]


Epoch: 009, Average Loss: 0.0488
Validation: {'precision': 0.6872005475654538, 'recall': 0.7150997150946219, 'f1': 0.7008725953461267, 'auc': 0.9498626791618149, 'prauc': 0.7573844611889274}
Test:       {'precision': 0.6601483479389066, 'recall': 0.6904090267934386, 'f1': 0.6749396709716592, 'auc': 0.9512815733420907, 'prauc': 0.7457643490350394}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.55it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.52it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.72it/s]


Epoch: 010, Average Loss: 0.0451
Validation: {'precision': 0.5754901960756104, 'recall': 0.8361823361763805, 'f1': 0.6817653842490164, 'auc': 0.9579606191558754, 'prauc': 0.755399189936282}
Test:       {'precision': 0.5793650793622056, 'recall': 0.823695345551314, 'f1': 0.6802562560678733, 'auc': 0.9605190285887872, 'prauc': 0.755950843304909}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.45it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.52it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.62it/s]


Epoch: 011, Average Loss: 0.0464
Validation: {'precision': 0.6800526662233045, 'recall': 0.7357549857497454, 'f1': 0.7068080688995848, 'auc': 0.9531070149952767, 'prauc': 0.7524716646263448}
Test:       {'precision': 0.669749009242604, 'recall': 0.7150916784152674, 'f1': 0.6916780304713107, 'auc': 0.9541506606750972, 'prauc': 0.7487760332668468}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.34it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.49it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.47it/s]


Epoch: 012, Average Loss: 0.0287
Validation: {'precision': 0.5944625407133851, 'recall': 0.779914529908975, 'f1': 0.6746765200406701, 'auc': 0.9502813687996228, 'prauc': 0.7209356181593204}
Test:       {'precision': 0.5773745997834718, 'recall': 0.7630465444233918, 'f1': 0.6573511494054294, 'auc': 0.950448548435356, 'prauc': 0.7080983738382329}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.26it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.73it/s]


Epoch: 013, Average Loss: 0.0240
Validation: {'precision': 0.6007905138305998, 'recall': 0.7578347578293602, 'f1': 0.670236215535025, 'auc': 0.9449378346642047, 'prauc': 0.6949340735508722}
Test:       {'precision': 0.5858641634423751, 'recall': 0.7482369534502946, 'f1': 0.657169397361727, 'auc': 0.9469775699466342, 'prauc': 0.685975131731535}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.44it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.51it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.44it/s]


Epoch: 014, Average Loss: 0.0299
Validation: {'precision': 0.6784682080875833, 'recall': 0.6688034187986553, 'f1': 0.673601142771609, 'auc': 0.9456028123242526, 'prauc': 0.7129622872160489}
Test:       {'precision': 0.6735566642860047, 'recall': 0.6664315937893764, 'f1': 0.66997518109961, 'auc': 0.9472818422616172, 'prauc': 0.7024117465920364}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.31it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.58it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.79it/s]


Epoch: 015, Average Loss: 0.0236
Validation: {'precision': 0.7350936967569414, 'recall': 0.6146723646679867, 'f1': 0.66951124406486, 'auc': 0.9472194591148654, 'prauc': 0.7255239597430732}
Test:       {'precision': 0.7326565143762043, 'recall': 0.6107193229858201, 'f1': 0.6661538411899173, 'auc': 0.9481940633730512, 'prauc': 0.7126499081634102}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.51it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.31it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.39it/s]


Epoch: 016, Average Loss: 0.0179
Validation: {'precision': 0.667534157445234, 'recall': 0.7307692307640259, 'f1': 0.6977218583172795, 'auc': 0.950196083939029, 'prauc': 0.7431874897566544}
Test:       {'precision': 0.6727509778313381, 'recall': 0.727785613535065, 'f1': 0.6991869868729024, 'auc': 0.9541435603257105, 'prauc': 0.7433038064869997}

Early stopping triggered after 16 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6800526662233045, 'recall': 0.7357549857497454, 'f1': 0.7068080688995848, 'auc': 0.9531070149952767, 'prauc': 0.7524716646263448}
Corresponding test performance:
{'precision': 0.669749009242604, 'recall': 0.7150916784152674, 'f1': 0.6916780304713107, 'auc': 0.9541506606750972, 'prauc': 0.7487760332668468}
Corresponding test-long performance:
{'precision': 0.5081967212948132, 'recall': 0.662393162364855, 'f1': 0.575139141633135, 'auc': 0.9242523267419532, 'prauc': 0.5602175222366894}
[INFO] Random seed set to 1181241943
Training w

Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.35it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.43it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.77it/s]


Epoch: 001, Average Loss: 0.2827
Validation: {'precision': 0.49999999999669315, 'recall': 0.5384615384577033, 'f1': 0.5185185135218209, 'auc': 0.8904740627163288, 'prauc': 0.5101113509742404}
Test:       {'precision': 0.5056179775247481, 'recall': 0.5394922425913999, 'f1': 0.5220061362504114, 'auc': 0.8963128779384398, 'prauc': 0.5076845318788831}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.42it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.12it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.48it/s]


Epoch: 002, Average Loss: 0.2064
Validation: {'precision': 0.838362068947449, 'recall': 0.2770655270635537, 'f1': 0.4164882189597252, 'auc': 0.9180249654637025, 'prauc': 0.650931125114053}
Test:       {'precision': 0.7643564356284286, 'recall': 0.2722143864578829, 'f1': 0.40145605436522996, 'auc': 0.9128590505174466, 'prauc': 0.6142108447184866}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.50it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.40it/s]


Epoch: 003, Average Loss: 0.1657
Validation: {'precision': 0.6927560366303684, 'recall': 0.5925925925883719, 'f1': 0.6387715881156701, 'auc': 0.938692754063905, 'prauc': 0.7034598091804927}
Test:       {'precision': 0.6674796747913213, 'recall': 0.578984485186326, 'f1': 0.6200906294616071, 'auc': 0.9371649367468106, 'prauc': 0.6783659343368342}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.00it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.61it/s]


Epoch: 004, Average Loss: 0.1321
Validation: {'precision': 0.771375464676846, 'recall': 0.5911680911638806, 'f1': 0.6693548337917404, 'auc': 0.9541310948398182, 'prauc': 0.7547840779067498}
Test:       {'precision': 0.7438914027082002, 'recall': 0.5796897038040925, 'f1': 0.6516052269386128, 'auc': 0.9518608724770299, 'prauc': 0.734163194628382}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.42it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.72it/s]


Epoch: 005, Average Loss: 0.1050
Validation: {'precision': 0.6832214765054818, 'recall': 0.7250712250660608, 'f1': 0.7035245285171763, 'auc': 0.9587977185587389, 'prauc': 0.7793816359719268}
Test:       {'precision': 0.6803713527806342, 'recall': 0.7235543018284658, 'f1': 0.7012986962986384, 'auc': 0.9568746128944132, 'prauc': 0.7702338522615517}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.81it/s]


Epoch: 006, Average Loss: 0.0919
Validation: {'precision': 0.5947611710293027, 'recall': 0.8247863247804503, 'f1': 0.6911369691647626, 'auc': 0.9614236119380077, 'prauc': 0.7754180618534957}
Test:       {'precision': 0.592252803258959, 'recall': 0.8194640338447147, 'f1': 0.6875739596224923, 'auc': 0.9607779182510459, 'prauc': 0.7755472913176731}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.54it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.57it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.59it/s]


Epoch: 007, Average Loss: 0.0887
Validation: {'precision': 0.7658662092558675, 'recall': 0.6360398860353559, 'f1': 0.6949416292787175, 'auc': 0.9530892557969789, 'prauc': 0.773994206373962}
Test:       {'precision': 0.7621483375894106, 'recall': 0.6304654442832831, 'f1': 0.6900810448271064, 'auc': 0.951977109665244, 'prauc': 0.7649517246176433}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.50it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.43it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.69it/s]


Epoch: 008, Average Loss: 0.0753
Validation: {'precision': 0.7589359933436498, 'recall': 0.6502849002802686, 'f1': 0.7004219359526189, 'auc': 0.9538465405794119, 'prauc': 0.7628392520915883}
Test:       {'precision': 0.7543281121124953, 'recall': 0.6452750352563803, 'f1': 0.6955530166898344, 'auc': 0.9534579049083474, 'prauc': 0.7665057103689066}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.52it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.62it/s]


Epoch: 009, Average Loss: 0.0557
Validation: {'precision': 0.7141828653656287, 'recall': 0.7065527065476742, 'f1': 0.7103472918085199, 'auc': 0.9562273417563837, 'prauc': 0.7730822994302458}
Test:       {'precision': 0.7246790299520351, 'recall': 0.7165021156508005, 'f1': 0.7205673708815754, 'auc': 0.9583184168067952, 'prauc': 0.7803463189769103}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.45it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.61it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.61it/s]


Epoch: 010, Average Loss: 0.0489
Validation: {'precision': 0.7405063291080657, 'recall': 0.6666666666619184, 'f1': 0.7016491704208018, 'auc': 0.9502817758872626, 'prauc': 0.7526122385128154}
Test:       {'precision': 0.7315175097219337, 'recall': 0.6629055007005438, 'f1': 0.6955234874227936, 'auc': 0.952509089688535, 'prauc': 0.7622479213086746}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.58it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.42it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.74it/s]


Epoch: 011, Average Loss: 0.0479
Validation: {'precision': 0.6560587515259727, 'recall': 0.7635327635273254, 'f1': 0.7057274472952435, 'auc': 0.9509921438191782, 'prauc': 0.7397744641316539}
Test:       {'precision': 0.6497244335538903, 'recall': 0.7482369534502946, 'f1': 0.6955096639812666, 'auc': 0.9508554530732943, 'prauc': 0.7560726653507295}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.40it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.46it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.72it/s]


Epoch: 012, Average Loss: 0.0366
Validation: {'precision': 0.5555016965557175, 'recall': 0.8162393162335027, 'f1': 0.661090274957625, 'auc': 0.9528352839955327, 'prauc': 0.7174006476793867}
Test:       {'precision': 0.5653031049750356, 'recall': 0.8088857545782167, 'f1': 0.6655062324610403, 'auc': 0.9518100776698778, 'prauc': 0.7348824136788678}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.50it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.47it/s]


Epoch: 013, Average Loss: 0.0317
Validation: {'precision': 0.6743442098485327, 'recall': 0.7507122507069038, 'f1': 0.7104819633277365, 'auc': 0.9505249853092248, 'prauc': 0.7364499472859303}
Test:       {'precision': 0.6747382198908721, 'recall': 0.7270803949172985, 'f1': 0.6999321063396259, 'auc': 0.9511725358088494, 'prauc': 0.7518240018942477}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 20.46it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.34it/s]


Epoch: 014, Average Loss: 0.0254
Validation: {'precision': 0.7601410934677237, 'recall': 0.613960113955741, 'f1': 0.6792750147517856, 'auc': 0.9463638117814014, 'prauc': 0.7413068012935466}
Test:       {'precision': 0.7724014336848352, 'recall': 0.6078984485147539, 'f1': 0.6803472720980086, 'auc': 0.9440551058612442, 'prauc': 0.7505594808036375}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.45it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.32it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.38it/s]


Epoch: 015, Average Loss: 0.0212
Validation: {'precision': 0.7079388200968104, 'recall': 0.6923076923027613, 'f1': 0.7000360050784047, 'auc': 0.9498840258199407, 'prauc': 0.7366838406058598}
Test:       {'precision': 0.7017673048549208, 'recall': 0.6720733427315086, 'f1': 0.6865994186285131, 'auc': 0.9505855901438015, 'prauc': 0.7549915928267134}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.56it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.75it/s]


Epoch: 016, Average Loss: 0.0322
Validation: {'precision': 0.6296948356770559, 'recall': 0.7642450142395709, 'f1': 0.6904761855183328, 'auc': 0.9498547918387884, 'prauc': 0.7312845223302543}
Test:       {'precision': 0.6334922526779888, 'recall': 0.7496473906858276, 'f1': 0.6866925014907751, 'auc': 0.9482262383828604, 'prauc': 0.7415825524800943}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.42it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.36it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.24it/s]


Epoch: 017, Average Loss: 0.0171
Validation: {'precision': 0.7405172413729266, 'recall': 0.6118233618190042, 'f1': 0.6700467969121291, 'auc': 0.939929842516111, 'prauc': 0.7177610475755639}
Test:       {'precision': 0.7321428571366315, 'recall': 0.6071932298969874, 'f1': 0.6638396249535881, 'auc': 0.939782929906046, 'prauc': 0.7209145103734785}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.44it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.43it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.59it/s]


Epoch: 018, Average Loss: 0.0159
Validation: {'precision': 0.7155621742314553, 'recall': 0.6844729344680593, 'f1': 0.6996723648553985, 'auc': 0.9437288861451182, 'prauc': 0.7275078515483423}
Test:       {'precision': 0.713963963958604, 'recall': 0.6706629054959756, 'f1': 0.6916363586362236, 'auc': 0.942772623873043, 'prauc': 0.7399654882738738}

Early stopping triggered after 18 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6743442098485327, 'recall': 0.7507122507069038, 'f1': 0.7104819633277365, 'auc': 0.9505249853092248, 'prauc': 0.7364499472859303}
Corresponding test performance:
{'precision': 0.6747382198908721, 'recall': 0.7270803949172985, 'f1': 0.6999321063396259, 'auc': 0.9511725358088494, 'prauc': 0.7518240018942477}
Corresponding test-long performance:
{'precision': 0.5315614617763601, 'recall': 0.6837606837314631, 'f1': 0.5981308361775528, 'auc': 0.9238052173736822, 'prauc': 0.5661020104100785}
[INFO] Random seed set to 958682846
Trainin

Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.57it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.34it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.44it/s]


Epoch: 001, Average Loss: 0.2945
Validation: {'precision': 0.6094839609398957, 'recall': 0.31125356125134435, 'f1': 0.41206977392709493, 'auc': 0.8929336353516536, 'prauc': 0.5228877758367302}
Test:       {'precision': 0.6102841677860584, 'recall': 0.3180535966127077, 'f1': 0.41817338445774016, 'auc': 0.8942068299005544, 'prauc': 0.5235960792913007}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.52it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.82it/s]


Epoch: 002, Average Loss: 0.1982
Validation: {'precision': 0.7951807228778799, 'recall': 0.3290598290574854, 'f1': 0.46549117973390874, 'auc': 0.9125869869957889, 'prauc': 0.6397673418493149}
Test:       {'precision': 0.7625418060073154, 'recall': 0.3215796897015404, 'f1': 0.4523809482036762, 'auc': 0.9144022593212443, 'prauc': 0.621526449890682}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.49it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.65it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.29it/s]


Epoch: 003, Average Loss: 0.1609
Validation: {'precision': 0.6110363391614332, 'recall': 0.6467236467190405, 'f1': 0.6283736974218221, 'auc': 0.9423877868008753, 'prauc': 0.6888616665436204}
Test:       {'precision': 0.5940530058138717, 'recall': 0.6480959097274465, 'f1': 0.6198988145668346, 'auc': 0.9399517494019569, 'prauc': 0.6658502770143514}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.49it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.46it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.77it/s]


Epoch: 004, Average Loss: 0.1247
Validation: {'precision': 0.749268292675617, 'recall': 0.547008547004651, 'f1': 0.6323589905879096, 'auc': 0.9484595243954392, 'prauc': 0.7253995615445444}
Test:       {'precision': 0.7377049180256732, 'recall': 0.5394922425913999, 'f1': 0.6232179177222728, 'auc': 0.9463446954481397, 'prauc': 0.7076373074267718}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.50it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.40it/s]


Epoch: 005, Average Loss: 0.1061
Validation: {'precision': 0.8060796645617812, 'recall': 0.5477207977168966, 'f1': 0.65224766269141, 'auc': 0.9507195732011764, 'prauc': 0.7584068156141888}
Test:       {'precision': 0.7926701570597626, 'recall': 0.5338504936492676, 'f1': 0.6380109517799931, 'auc': 0.9496246099897248, 'prauc': 0.7440083042987979}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.51it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.57it/s]


Epoch: 006, Average Loss: 0.0951
Validation: {'precision': 0.9046153846014674, 'recall': 0.4188034188004359, 'f1': 0.5725413783361604, 'auc': 0.951304888898659, 'prauc': 0.772187659409146}
Test:       {'precision': 0.9070866141589435, 'recall': 0.40620592383352466, 'f1': 0.5611300493019638, 'auc': 0.9510875550537755, 'prauc': 0.7633437318310344}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.35it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.57it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.52it/s]


Epoch: 007, Average Loss: 0.0735
Validation: {'precision': 0.6835269993118009, 'recall': 0.7122507122456393, 'f1': 0.6975932981015414, 'auc': 0.956955061798957, 'prauc': 0.7678518534136962}
Test:       {'precision': 0.6883206634368465, 'recall': 0.7023977432954698, 'f1': 0.6952879531108419, 'auc': 0.9539574120050721, 'prauc': 0.74906206711279}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.55it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.67it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.80it/s]


Epoch: 008, Average Loss: 0.0608
Validation: {'precision': 0.8100436681134274, 'recall': 0.5284900284862644, 'f1': 0.6396551676295037, 'auc': 0.9553552837024294, 'prauc': 0.7586156768280639}
Test:       {'precision': 0.8025613660533345, 'recall': 0.530324400560435, 'f1': 0.6386411841628193, 'auc': 0.9515039682015571, 'prauc': 0.7442177106681086}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.53it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.19it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.51it/s]


Epoch: 009, Average Loss: 0.0549
Validation: {'precision': 0.654116145496783, 'recall': 0.7300569800517803, 'f1': 0.6900033608804828, 'auc': 0.9531239600182946, 'prauc': 0.7493602729765777}
Test:       {'precision': 0.6469088591418298, 'recall': 0.715796897033034, 'f1': 0.6796116454936644, 'auc': 0.952044190588472, 'prauc': 0.7382799376919805}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.40it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.53it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.60it/s]


Epoch: 010, Average Loss: 0.0419
Validation: {'precision': 0.6550652579201053, 'recall': 0.7507122507069038, 'f1': 0.6996349103852464, 'auc': 0.953212806895739, 'prauc': 0.7506795321346245}
Test:       {'precision': 0.6374695863708183, 'recall': 0.7390691114193296, 'f1': 0.6845199166426233, 'auc': 0.9525689213039976, 'prauc': 0.7459201507588379}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.50it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.14it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.83it/s]


Epoch: 011, Average Loss: 0.0543
Validation: {'precision': 0.7763033175281868, 'recall': 0.5833333333291786, 'f1': 0.6661244359249051, 'auc': 0.9519864044940847, 'prauc': 0.7549894652027485}
Test:       {'precision': 0.7730627306201747, 'recall': 0.5909732016883571, 'f1': 0.6698641037967772, 'auc': 0.950925115941754, 'prauc': 0.7542495609476618}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.42it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.61it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.76it/s]


Epoch: 012, Average Loss: 0.0283
Validation: {'precision': 0.7722222222150721, 'recall': 0.5940170940128632, 'f1': 0.6714975796207225, 'auc': 0.9515624736029109, 'prauc': 0.7547814511559133}
Test:       {'precision': 0.7778776978347314, 'recall': 0.6100141043680536, 'f1': 0.6837944614708996, 'auc': 0.9506717873923018, 'prauc': 0.7544498287046403}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.52it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.42it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.59it/s]


Epoch: 013, Average Loss: 0.0305
Validation: {'precision': 0.7627416520143873, 'recall': 0.6182336182292149, 'f1': 0.6829268243176695, 'auc': 0.9492014670624375, 'prauc': 0.7525198289563403}
Test:       {'precision': 0.7608695652107751, 'recall': 0.6170662905457189, 'f1': 0.6814641695039779, 'auc': 0.9501366295902642, 'prauc': 0.751294834799041}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.38it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.49it/s]


Epoch: 014, Average Loss: 0.0314
Validation: {'precision': 0.8497913769005592, 'recall': 0.4351851851820856, 'f1': 0.5756005607529849, 'auc': 0.9443467434108286, 'prauc': 0.7469035367038623}
Test:       {'precision': 0.8480845442424295, 'recall': 0.452750352606116, 'f1': 0.5903448230425791, 'auc': 0.9441711940910792, 'prauc': 0.7366423046112662}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.33it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.44it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.75it/s]


Epoch: 015, Average Loss: 0.0258
Validation: {'precision': 0.7161949685478287, 'recall': 0.6488603988557774, 'f1': 0.6808669606274063, 'auc': 0.9423950126064865, 'prauc': 0.7357125252828347}
Test:       {'precision': 0.7104247104192246, 'recall': 0.648801128345213, 'f1': 0.6782159920565125, 'auc': 0.9419045440945102, 'prauc': 0.7263868799424791}

Early stopping triggered after 15 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6550652579201053, 'recall': 0.7507122507069038, 'f1': 0.6996349103852464, 'auc': 0.953212806895739, 'prauc': 0.7506795321346245}
Corresponding test performance:
{'precision': 0.6374695863708183, 'recall': 0.7390691114193296, 'f1': 0.6845199166426233, 'auc': 0.9525689213039976, 'prauc': 0.7459201507588379}
Corresponding test-long performance:
{'precision': 0.49371069180837385, 'recall': 0.6709401709114983, 'f1': 0.5688405748053192, 'auc': 0.9184817027140679, 'prauc': 0.543417159944124}
[INFO] Random seed set to 3163119785
Traini

Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.53it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.56it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.65it/s]


Epoch: 001, Average Loss: 0.2973
Validation: {'precision': 0.48702185792017066, 'recall': 0.5078347578311408, 'f1': 0.4972105947197811, 'auc': 0.8617121760116687, 'prauc': 0.4721998790521194}
Test:       {'precision': 0.4983119513808352, 'recall': 0.5204513399117034, 'f1': 0.5091410781309634, 'auc': 0.874487893507073, 'prauc': 0.4810030132160932}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.42it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.50it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.70it/s]


Epoch: 002, Average Loss: 0.2247
Validation: {'precision': 0.6366322008768592, 'recall': 0.3069800569778705, 'f1': 0.4142239264087494, 'auc': 0.9002911847003439, 'prauc': 0.5271713677338216}
Test:       {'precision': 0.6316546762499042, 'recall': 0.30959097319950923, 'f1': 0.41552294872864426, 'auc': 0.9075284259756797, 'prauc': 0.5306214038310677}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.47it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.66it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.96it/s]


Epoch: 003, Average Loss: 0.1801
Validation: {'precision': 0.6832844574713266, 'recall': 0.4978632478597019, 'f1': 0.5760197726215633, 'auc': 0.9323203568041748, 'prauc': 0.6607057176715103}
Test:       {'precision': 0.6708984374934484, 'recall': 0.48448519040561017, 'f1': 0.5626535577791126, 'auc': 0.9350784767357301, 'prauc': 0.6406035125977617}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.67it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.82it/s]


Epoch: 004, Average Loss: 0.1598
Validation: {'precision': 0.6216586703178777, 'recall': 0.6460113960067948, 'f1': 0.6336011127061164, 'auc': 0.9321831173835174, 'prauc': 0.6795597086386729}
Test:       {'precision': 0.6167341430457711, 'recall': 0.6445698166386138, 'f1': 0.630344822584295, 'auc': 0.9376100740352934, 'prauc': 0.6659551164153933}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.38it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.60it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.37it/s]


Epoch: 005, Average Loss: 0.1288
Validation: {'precision': 0.7843749999918295, 'recall': 0.5363247863209664, 'f1': 0.6370558327344388, 'auc': 0.9431319938928713, 'prauc': 0.7385922626697499}
Test:       {'precision': 0.787301587293256, 'recall': 0.5246826516183027, 'f1': 0.6297079935022457, 'auc': 0.9478445573637228, 'prauc': 0.7320083750862592}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.50it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.45it/s]


Epoch: 006, Average Loss: 0.1026
Validation: {'precision': 0.8142340168779948, 'recall': 0.4807692307658065, 'f1': 0.6045678412732757, 'auc': 0.9504587572387815, 'prauc': 0.7507307988922328}
Test:       {'precision': 0.8141176470492457, 'recall': 0.48801128349444284, 'f1': 0.6102292722041657, 'auc': 0.9529521415597156, 'prauc': 0.7386398680089421}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.54it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.25it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.28it/s]


Epoch: 007, Average Loss: 0.0879
Validation: {'precision': 0.7663551401797538, 'recall': 0.5840455840414243, 'f1': 0.662894093711479, 'auc': 0.950239184342921, 'prauc': 0.749614882811921}
Test:       {'precision': 0.7629151291442536, 'recall': 0.5832157968929251, 'f1': 0.6610711381693495, 'auc': 0.9540283161933548, 'prauc': 0.7411295707377784}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.38it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.57it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.82it/s]


Epoch: 008, Average Loss: 0.0683
Validation: {'precision': 0.7336206896488482, 'recall': 0.606125356121039, 'f1': 0.6638065473021934, 'auc': 0.9465092947267699, 'prauc': 0.7306706003387106}
Test:       {'precision': 0.7379367720404498, 'recall': 0.6255289139589173, 'f1': 0.6770992316700367, 'auc': 0.9509063968388252, 'prauc': 0.7316813478324349}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.29it/s]


Epoch: 009, Average Loss: 0.0607
Validation: {'precision': 0.8073286051914027, 'recall': 0.48646723646377166, 'f1': 0.6071111064132347, 'auc': 0.9460266160005391, 'prauc': 0.7380450214607066}
Test:       {'precision': 0.8178613395908594, 'recall': 0.490832157965509, 'f1': 0.613486112539078, 'auc': 0.9509687110939337, 'prauc': 0.7396953997229128}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.43it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.39it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.75it/s]


Epoch: 010, Average Loss: 0.0563
Validation: {'precision': 0.7615230460845539, 'recall': 0.5413105413066859, 'f1': 0.6328059901417429, 'auc': 0.9447294057924909, 'prauc': 0.719023045242473}
Test:       {'precision': 0.7717717717640463, 'recall': 0.5437235542979991, 'f1': 0.6379809632873066, 'auc': 0.9477321186141667, 'prauc': 0.717952979920025}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.26it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.20it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.47it/s]


Epoch: 011, Average Loss: 0.0393
Validation: {'precision': 0.7112561174493373, 'recall': 0.6210826210781974, 'f1': 0.6631178657402942, 'auc': 0.9459429849334829, 'prauc': 0.7175951906133418}
Test:       {'precision': 0.7327800829814708, 'recall': 0.6227080394878511, 'f1': 0.6732748711239107, 'auc': 0.9498851381941497, 'prauc': 0.7236285651330372}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.65it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.50it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.59it/s]


Epoch: 012, Average Loss: 0.0326
Validation: {'precision': 0.6621160409511119, 'recall': 0.6908831908782701, 'f1': 0.6761937907451939, 'auc': 0.9462660089758754, 'prauc': 0.7199971616236155}
Test:       {'precision': 0.6650519031095845, 'recall': 0.677715091673641, 'f1': 0.6713237812339669, 'auc': 0.9469996157866886, 'prauc': 0.7201845992089511}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.36it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.20it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.61it/s]


Epoch: 013, Average Loss: 0.0360
Validation: {'precision': 0.6633132126045386, 'recall': 0.7044159544109373, 'f1': 0.6832469725472912, 'auc': 0.9487286347685006, 'prauc': 0.7262612954593511}
Test:       {'precision': 0.675202156329682, 'recall': 0.706629055002069, 'f1': 0.6905582306973448, 'auc': 0.9521864458401734, 'prauc': 0.737531552999397}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.38it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.71it/s]


Epoch: 014, Average Loss: 0.0250
Validation: {'precision': 0.7497773820057901, 'recall': 0.5997150997108283, 'f1': 0.666402844284886, 'auc': 0.9423723683565077, 'prauc': 0.7259783682215755}
Test:       {'precision': 0.7499999999934441, 'recall': 0.6050775740436878, 'f1': 0.6697892222182369, 'auc': 0.9468631202589651, 'prauc': 0.7270598503050762}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.52it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 20.97it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.56it/s]


Epoch: 015, Average Loss: 0.0346
Validation: {'precision': 0.6984496123976865, 'recall': 0.641737891733321, 'f1': 0.6688938331628593, 'auc': 0.9493874043420374, 'prauc': 0.7120278075266223}
Test:       {'precision': 0.7169517884858714, 'recall': 0.6502115655807461, 'f1': 0.6819526577287649, 'auc': 0.9522712528104723, 'prauc': 0.7307695573389058}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.41it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.62it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.70it/s]


Epoch: 016, Average Loss: 0.0239
Validation: {'precision': 0.6863988724405469, 'recall': 0.6937321937272527, 'f1': 0.6900460452963506, 'auc': 0.9451534384555745, 'prauc': 0.7214410149944999}
Test:       {'precision': 0.6951646811443928, 'recall': 0.6995768688244036, 'f1': 0.6973637911287154, 'auc': 0.9497356088082863, 'prauc': 0.735960697761532}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.47it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.28it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.74it/s]


Epoch: 017, Average Loss: 0.0261
Validation: {'precision': 0.6796759941039788, 'recall': 0.6574074074027251, 'f1': 0.6683562585736653, 'auc': 0.943272846216334, 'prauc': 0.6998437795814259}
Test:       {'precision': 0.7014042867649564, 'recall': 0.6692524682604426, 'f1': 0.684951276123755, 'auc': 0.9484877596431434, 'prauc': 0.7163038289745917}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.38it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.45it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.61it/s]


Epoch: 018, Average Loss: 0.0194
Validation: {'precision': 0.6520051746400258, 'recall': 0.7179487179436044, 'f1': 0.6833898255154266, 'auc': 0.9457093675140399, 'prauc': 0.69881444466406}
Test:       {'precision': 0.6636539702991372, 'recall': 0.7249647390639989, 'f1': 0.6929558426626428, 'auc': 0.9490818553103711, 'prauc': 0.7158358950240764}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.46it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.35it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.69it/s]


Epoch: 019, Average Loss: 0.0117
Validation: {'precision': 0.7471590909020156, 'recall': 0.5619658119618094, 'f1': 0.6414634097289907, 'auc': 0.9386276709274496, 'prauc': 0.7122344235532092}
Test:       {'precision': 0.7722488038203613, 'recall': 0.5691114245375944, 'f1': 0.6552984116745159, 'auc': 0.9424186987651448, 'prauc': 0.7218950928644962}


Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.48it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.37it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.54it/s]


Epoch: 020, Average Loss: 0.0132
Validation: {'precision': 0.65500685870607, 'recall': 0.6801994301945855, 'f1': 0.6673654736833499, 'auc': 0.9428210807077707, 'prauc': 0.6970164350753152}
Test:       {'precision': 0.664179104473106, 'recall': 0.6904090267934386, 'f1': 0.6770401056472619, 'auc': 0.9465030630311423, 'prauc': 0.707228398537542}


Training Batches: 100%|██████████| 120/120 [00:22<00:00,  5.30it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.23it/s]
Running inference: 100%|██████████| 245/245 [00:11<00:00, 21.67it/s]


Epoch: 021, Average Loss: 0.0120
Validation: {'precision': 0.7460611677410004, 'recall': 0.5733618233577397, 'f1': 0.6484091775210344, 'auc': 0.941794609223344, 'prauc': 0.7208170401713448}
Test:       {'precision': 0.7692307692236402, 'recall': 0.5853314527462248, 'f1': 0.6647977523956038, 'auc': 0.9435226293100237, 'prauc': 0.7263820452840777}

Early stopping triggered after 21 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6863988724405469, 'recall': 0.6937321937272527, 'f1': 0.6900460452963506, 'auc': 0.9451534384555745, 'prauc': 0.7214410149944999}
Corresponding test performance:
{'precision': 0.6951646811443928, 'recall': 0.6995768688244036, 'f1': 0.6973637911287154, 'auc': 0.9497356088082863, 'prauc': 0.735960697761532}
Corresponding test-long performance:
{'precision': 0.5384615384418145, 'recall': 0.6282051281782818, 'f1': 0.5798816518114446, 'auc': 0.93051185789775, 'prauc': 0.5786510293436371}
[INFO] Random seed set to 1812140441
Training

Training Batches: 100%|██████████| 120/120 [00:21<00:00,  5.54it/s]
Running inference: 100%|██████████| 241/241 [00:11<00:00, 21.39it/s]
Running inference:  73%|███████▎  | 180/245 [00:08<00:03, 21.46it/s]


KeyboardInterrupt: 

In [None]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5,
                                   rank_metrics=("f1", "auc", "prauc"), ddof=0):
    """Print mean ± std (as percentages) over the k best runs, chosen by average rank.

    Each run is ranked separately on every metric in ``rank_metrics`` (rank 1 =
    highest value). Runs are then ordered by their mean rank and the first ``k``
    are kept. The SAME run indices are used to aggregate ``long_seq_performances``,
    so the two lists must be aligned run-by-run.

    NOTE(review): the ranking is computed on ``performances`` itself. If these are
    test-set metrics, picking the top-k runs on them optimistically biases the
    reported numbers; consider ranking on validation metrics instead.

    Args:
        performances: list of metric dicts, one per run. Every dict must contain
            the keys in ``rank_metrics``; all keys of ``performances[0]`` are reported.
        long_seq_performances: list of metric dicts aligned with ``performances``;
            all keys of ``long_seq_performances[0]`` are reported.
        k: number of runs to aggregate; values above ``len(performances)`` use all runs.
        rank_metrics: metric keys used to rank runs (higher is better).
        ddof: delta degrees of freedom for the std. The default 0 (population std)
            preserves the previous output; pass 1 for the sample std.

    Raises:
        ValueError: if ``performances`` is empty, the two lists differ in length,
            ``k`` < 1, or ``rank_metrics`` is empty.
    """
    n_runs = len(performances)
    if n_runs == 0:
        raise ValueError("performances is empty; nothing to summarise.")
    if len(long_seq_performances) != n_runs:
        raise ValueError(
            f"long_seq_performances has {len(long_seq_performances)} runs, "
            f"expected {n_runs} (must be aligned with performances)."
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}.")
    rank_metrics = list(rank_metrics)
    if not rank_metrics:
        raise ValueError("rank_metrics must contain at least one metric name.")

    # Rank runs per metric (1 = best). A stable sort makes tie-breaking
    # deterministic: among equal scores, the earlier run gets the better rank.
    rank_matrix = np.empty((n_runs, len(rank_metrics)))
    for col, metric in enumerate(rank_metrics):
        scores = np.array([p[metric] for p in performances], dtype=float)
        order = np.argsort(-scores, kind="stable")
        rank_matrix[order, col] = np.arange(1, n_runs + 1)
    avg_ranks = rank_matrix.mean(axis=1)

    # Select the top-k runs by average rank (stable again for deterministic ties).
    topk_idx = np.argsort(avg_ranks, kind="stable")[:min(k, n_runs)]

    def _mean_std(runs):
        """Return {metric: (mean, std)} over the selected run indices."""
        return {
            m: (
                np.mean([runs[i][m] for i in topk_idx]),
                np.std([runs[i][m] for i in topk_idx], ddof=ddof),
            )
            for m in runs[0].keys()
        }

    def _print_block(summary):
        """Print one metric per line as percentages with two decimals."""
        for m, (mean, std) in summary.items():
            mean_val = mean * 100
            std_val = std * 100
            print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")

    # Print results (converted to percentages, two decimal places).
    print("Final Metrics:")
    _print_block(_mean_std(performances))
    print("\nFinal Long Sequence Metrics:")
    _print_block(_mean_std(long_seq_performances))

In [None]:
def print_per_class_performance(dfs, col_name="prauc"):
    """Print mean ± std (as percentages) of one metric column for each class label.

    The tables in ``dfs`` are stacked vertically and rows are grouped by their
    index label (e.g. disease name), so the same label appearing in several
    tables is aggregated across them. Labels come out in ``groupby``'s default
    sorted order. ``std`` is the pandas sample std (ddof=1), so a label that
    appears in only one table prints ``nan`` for its std.

    Args:
        dfs (list[pd.DataFrame]): per-run tables indexed by class label.
        col_name (str): metric column to summarise (default: "prauc").
    """
    stacked = pd.concat(dfs, axis=0)
    summary = stacked.groupby(stacked.index)[col_name].agg(["mean", "std"])

    for label, mean, std in zip(summary.index, summary["mean"], summary["std"]):
        pct_mean, pct_std = mean * 100, std * 100
        print(f"{label}: {pct_mean:.2f} ± {pct_std:.2f}")

In [None]:
# Summarise all collected runs.
# Binary tasks store one flat metrics dict per seed. Other tasks store
# {"global": {...}, "per_class": <table>} per seed; per_class is presumably a
# DataFrame indexed by class label (it is fed to print_per_class_performance).
# NOTE(review): `task_type`, `final_metrics` and `final_long_seq_metrics` come from
# earlier cells. The training run above ended in KeyboardInterrupt, so these lists
# may contain fewer runs than intended.
if task_type != "binary":
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")

    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)