In [1]:
import torch
from pyhealth.models import RETAINLayer, StageNetLayer
from set_seed_utils import set_random_seed
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support
from tqdm import tqdm
import pandas as pd
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
import pickle
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Experiment configuration read by the cells below.
args = dict(
    dataset="MIMIC-III",
    task="stay",  # options: death, stay, readmission, next_diag_12m
    batch_size=64,
    hidden_size=512,
    lr=1e-3,
    epochs=500,
    early_stop_patience=5,
    dropout=0.0,
    backbone="stagenet",  # options: retain, stagenet
)

In [4]:
# Phenotype names, one per line. `evaluate` uses this list as the row index of
# its per-class metrics table, so its length must match the number of label
# columns in the multi-label setting.
PHENO_ORDER = [
    entry.strip()
    for entry in """
        Acute and unspecified renal failure
        Acute cerebrovascular disease
        Acute myocardial infarction
        Cardiac dysrhythmias
        Chronic kidney disease
        Chronic obstructive pulmonary disease
        Conduction disorders
        Congestive heart failure; nonhypertensive
        Coronary atherosclerosis and related
        Disorders of lipid metabolism
        Essential hypertension
        Fluid and electrolyte disorders
        Gastrointestinal hemorrhage
        Hypertension with complications
        Other liver diseases
        Other lower respiratory disease
        Pneumonia
        Septicemia (except in labor)
    """.splitlines()
    if entry.strip()
]

In [5]:
def _binary_metrics(logits, labels):
    """Score binary predictions at the ``logit > 0`` threshold.

    Args:
        logits: tensor of raw scores; flattened to 1-D.
        labels: tensor of binary targets (assumed 0/1) with the same number
            of elements as ``logits``.

    Returns:
        dict of floats with keys "precision", "recall", "f1", "auc", "prauc".
        "auc" and "prauc" are NaN when the labels are empty or contain a
        single class, where the ranking metrics are undefined.
    """
    flat_logits = logits.reshape(-1)
    y_true = labels.reshape(-1).cpu().numpy()
    scores = flat_logits.cpu().numpy()                  # continuous scores (raw logits)
    y_pred = (flat_logits > 0).float().cpu().numpy()    # logit > 0  <=>  probability > 0.5

    # The 1e-8 terms keep the ratios finite when there are no predicted / true positives.
    tp = (y_pred * y_true).sum()
    precision = tp / (y_pred.sum() + 1e-8)
    recall = tp / (y_true.sum() + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    if y_true.size == 0 or y_true.max() == y_true.min():
        # Same convention as the multi-label path: undefined AUCs become NaN
        # instead of aborting the whole evaluation.
        roc_auc = pr_auc = float("nan")
    else:
        roc_auc = roc_auc_score(y_true, scores)
        prec_curve, rec_curve, _ = precision_recall_curve(y_true, scores)
        pr_auc = auc(rec_curve, prec_curve)

    return {"precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
            "auc": float(roc_auc),
            "prauc": float(pr_auc)}


def _multilabel_metrics(logits, labels):
    """Score multi-label predictions per class and macro-averaged.

    Args:
        logits: tensor [N, C] of raw scores.
        labels: tensor [N, C] of binary targets.

    Returns:
        dict with
          "global":    macro-averaged precision / recall / f1 / auc / prauc
                       (AUC averages ignore classes whose value is NaN);
          "per_class": DataFrame of the same metrics indexed by PHENO_ORDER,
                       so PHENO_ORDER must have exactly C entries.
    """
    # Continuous scores are sigmoid(logits); fp16 on CPU is upcast to fp32 first.
    if logits.device.type == "cpu" and logits.dtype == torch.float16:
        prob_t = torch.sigmoid(logits.float())
    else:
        prob_t = torch.sigmoid(logits)
    # Hard decisions: logit > 0.
    ypred_t = (logits > 0).to(torch.int32)

    y_true = labels.cpu().numpy().astype(np.int32)     # [N, C]
    y_pred = ypred_t.cpu().numpy().astype(np.int32)    # [N, C]
    scores = prob_t.cpu().numpy()                      # [N, C]

    # Per-class precision / recall / F1.
    p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )

    # Per-class ROC-AUC / PR-AUC; NaN for classes with a single label value.
    aucs, praucs = [], []
    for c in range(y_true.shape[1]):
        yt, ys = y_true[:, c], scores[:, c]
        if yt.max() == yt.min():
            aucs.append(np.nan)
            praucs.append(np.nan)
        else:
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

    # Macro averages (NaN classes are ignored for the AUC metrics).
    summary = {
        "precision": float(np.mean(p_cls)),
        "recall":    float(np.mean(r_cls)),
        "f1":        float(np.mean(f1_cls)),
        "auc":       float(np.nanmean(aucs)) if np.any(~np.isnan(aucs)) else float("nan"),
        "prauc":     float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else float("nan"),
    }

    per_class_df = pd.DataFrame({
        "precision": p_cls,
        "recall":    r_cls,
        "f1":        f1_cls,
        "auc":       aucs,
        "prauc":     praucs,
    }, index=PHENO_ORDER)

    return {"global": summary, "per_class": per_class_df}


@torch.no_grad()
def evaluate(model, dataloader, device, long_seq_idx=None, task_type="binary"):
    """Run inference over ``dataloader`` and compute classification metrics.

    Args:
        model: module called as ``model(*batch[:-1])`` and returning logits.
        dataloader: yields batches whose last element is the label tensor.
        device: device every tensor in a batch is moved to before the forward pass.
        long_seq_idx: optional positions (into the concatenated predictions)
            defining a subset that is scored separately.
        task_type: "binary" for single-logit tasks; any other value selects
            the multi-label path.

    Returns:
        (all_performance, subset_performance). For "binary" each is a flat
        metrics dict; otherwise each is {"global": dict, "per_class": DataFrame}.
        ``subset_performance`` is None when ``long_seq_idx`` is None.
    """
    model.eval()
    predicted_scores, gt_labels = [], []

    # Inference: collect logits and labels batch by batch.
    for batch in tqdm(dataloader, desc="Running inference"):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        predicted_scores.append(model(*batch[:-1]))
        gt_labels.append(batch[-1])

    logits_all = torch.cat(predicted_scores, dim=0)
    labels_all = torch.cat(gt_labels, dim=0)

    if task_type == "binary":
        # One logit per sample: flatten so subset positions index samples directly.
        logits_all = logits_all.reshape(-1)
        labels_all = labels_all.reshape(-1)
        compute_metrics = _binary_metrics
    else:
        compute_metrics = _multilabel_metrics

    # Full evaluation set.
    all_performance = compute_metrics(logits_all, labels_all)

    # Optional subset (e.g. samples with long admission histories).
    subset_performance = None
    if long_seq_idx is not None:
        idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
        subset_performance = compute_metrics(
            logits_all.index_select(0, idx),
            labels_all.index_select(0, idx),
        )

    return all_performance, subset_performance

In [6]:
# Tokenizer-related settings added to the shared config.
args.update(
    predicted_token_type=["diag", "med", "pro", "lab"],
    special_tokens=("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]"),
    max_visit_size=15,
)

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

# The two next-diagnosis horizons have dedicated files; every other task
# reads the shared downstream file.
finetune_file = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}.get(args["task"], "mimic_downstream.pkl")
finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/{finetune_file}"

In [7]:
# Load the full EHR table. The context manager closes the file handle that the
# previous bare open() call left dangling.
# NOTE: pickle can execute arbitrary code while loading; only read files from a trusted source.
with open(full_data_path, "rb") as ehr_file:
    ehr_data = pickle.load(ehr_file)

# One list entry per row of the table, for each code type.
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()

gender_set = [["M"], ["F"]]

# Distinct AGE values, computed once and shared by both demographic vocabularies.
# NOTE(review): a set is iterated in hash order, not sorted order; if AGE values
# are strings this order can change between interpreter sessions — confirm the
# tokenizer does not depend on it.
age_values = set(ehr_data["AGE"].values.tolist())
age_gender_set = [[str(c) + "_" + gender] for c in age_values for gender in ["M", "F"]]
age_set = [[c] for c in age_values]

In [8]:
# Build the tokenizer from every code source plus the demographic token sets,
# then call build_tree() on it.
tokenizer = EHRTokenizer(
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    gender_set,
    age_set,
    age_gender_set,
    special_tokens=args["special_tokens"],
)
tokenizer.build_tree()

In [9]:
# Load the pre-split fine-tuning data. The context manager closes the file
# handle that the previous bare open() call left dangling.
# NOTE: pickle can execute arbitrary code while loading; only read files from a trusted source.
with open(finetune_data_path, "rb") as finetune_file_handle:
    train_data, val_data, test_data = pickle.load(finetune_file_handle)


def _make_dataset(split_data):
    """Wrap one raw split in the fine-tuning dataset for the configured task."""
    return HBERTFinetuneEHRDataset(
        split_data, tokenizer,
        token_type=args["predicted_token_type"],
        task=args["task"],
    )


def _make_dataloader(dataset, shuffle):
    """Batch ``dataset`` with the shared collate settings; only ``shuffle`` differs per split."""
    return DataLoader(
        dataset,
        # NOTE(review): is_train=False is passed for the training split as well,
        # exactly as before — confirm this is intended.
        collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
        batch_size=args["batch_size"],
        shuffle=shuffle,
    )


train_dataset = _make_dataset(train_data)
val_dataset = _make_dataset(val_data)
test_dataset = _make_dataset(test_data)

print(len(train_dataset), len(val_dataset), len(test_dataset))

# Only the training loader shuffles; evaluation order must stay fixed because
# subset indices are positions into the concatenated predictions.
train_dataloader = _make_dataloader(train_dataset, shuffle=True)
val_dataloader = _make_dataloader(val_dataset, shuffle=False)
test_dataloader = _make_dataloader(test_dataset, shuffle=False)

3131 6310 6304


In [10]:
# A record counts as a "long sequence" when it holds at least this many entries.
long_adm_seq_crite = 3


def _long_seq_indices(dataset, min_adms):
    """Return the positions whose record has at least ``min_adms`` entries.

    Position ``i`` refers to the i-th key of ``dataset.records``.
    NOTE(review): this presumes dataset item ``i`` corresponds to that key,
    which is what the evaluation subset indexing relies on — confirm against
    the dataset implementation.
    """
    # Materialise the key list once; rebuilding it per index made the scan quadratic.
    record_keys = list(dataset.records.keys())
    return [
        i
        for i in range(len(dataset))
        if len(dataset.records[record_keys[i]]) >= min_adms
    ]


val_long_seq_idx = _long_seq_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = _long_seq_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

835 854


In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([108, 181])
input_types shape: torch.Size([108, 181])
labeled_batch_idx shape: 64
labels shape: torch.Size([64, 1])


In [12]:
# Model-size settings: the token vocabulary comes from the tokenizer.
# NOTE(review): 18 presumably equals len(PHENO_ORDER) — confirm.
args.update(
    vocab_size=len(tokenizer.vocab.word2id),
    label_vocab_size=18,  # only for diagnosis
)

In [13]:
# Both task families are trained with BCE on raw logits, so the loss is shared
# and referenced directly instead of through a pass-through lambda.
loss_fn = F.binary_cross_entropy_with_logits

if args["task"] in ("death", "stay", "readmission"):
    # Single-logit outcome tasks: select the best epoch by F1.
    eval_metric = "f1"
    task_type = "binary"
else:
    # Any other task goes through the multi-label evaluation path and is
    # selected by PR-AUC.
    eval_metric = "prauc"
    task_type = "l2r"

In [14]:
def train_with_early_stopping(model,
                              train_dataloader,
                              val_dataloader,
                              test_dataloader,
                              optimizer,
                              loss_fn,
                              device,
                              args,
                              val_long_seq_idx = None,
                              test_long_seq_idx = None,
                              task_type="binary",
                              eval_metric="f1"):
    """Train ``model`` and report the test metrics of the best validation epoch.

    After every epoch the model is evaluated on the validation and test loaders.
    The epoch with the highest validation ``eval_metric`` is remembered, and
    training stops once ``args["early_stop_patience"]`` consecutive epochs fail
    to improve on it (or after ``args["epochs"]`` epochs).

    Args:
        model: module called as ``model(*batch[:-1])``; returns logits.
        train_dataloader / val_dataloader / test_dataloader: yield batches whose
            last element is the label tensor.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(flat_logits, flat_labels)``.
        device: device every tensor in a batch is moved to.
        args: config dict; reads "epochs" and "early_stop_patience".
        val_long_seq_idx / test_long_seq_idx: optional subset positions passed
            through to ``evaluate``.
        task_type: "binary" or anything else for the multi-label path.
        eval_metric: key of the validation summary used for model selection.

    Returns:
        (best_test_metric, best_test_long_seq_metric) in the format returned by
        ``evaluate`` for the given ``task_type``. The second entry is None when
        ``test_long_seq_idx`` is None. Both are None only if no epoch produced a
        comparable (non-NaN) score.
    """
    # Start below any achievable score so the first valid epoch is always recorded
    # (a 0.0 start silently discarded runs whose metric never rose above zero).
    best_score = float("-inf")
    best_val_metric = None
    best_test_metric = None
    best_test_long_seq_metric = None
    epochs_no_improve = 0
    multi_label = task_type != "binary"

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        total_loss = 0.
        num_batches = 0

        for batch in tqdm(train_dataloader, desc="Training Batches"):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])

            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            num_batches += 1

        # Count batches explicitly so an empty loader cannot leave the divisor undefined.
        ave_loss = total_loss / max(num_batches, 1)

        # Evaluation
        val_metric, val_long_seq_metric = evaluate(model, val_dataloader, device, task_type=task_type, long_seq_idx=val_long_seq_idx)
        test_metric, test_long_seq_metric = evaluate(model, test_dataloader, device, task_type=task_type, long_seq_idx=test_long_seq_idx)

        # Multi-label results are {"global": ..., "per_class": ...}; only the
        # global summary is logged and used for model selection.
        val_summary = val_metric["global"] if multi_label else val_metric
        test_summary = test_metric["global"] if multi_label else test_metric

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_summary}")
        print(f"Test:       {test_summary}")

        # Check for improvement
        current_score = val_summary[eval_metric]
        if current_score > best_score:
            best_score = current_score
            # Keep the results exactly as evaluate() returned them. The subset
            # result is simply None when no subset indices were supplied, which
            # avoids reading a per-class table that was never created.
            best_val_metric = val_metric
            best_test_metric = test_metric
            best_test_long_seq_metric = test_long_seq_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    if best_test_long_seq_metric is not None:
        print("Corresponding test-long performance:")
        print(best_test_long_seq_metric)
    return best_test_metric, best_test_long_seq_metric

In [15]:
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP that maps an ``in_dim``-wide representation to one logit."""

    def __init__(self, in_dim: int):
        super().__init__()
        # Hidden projection keeps the input width; the last layer emits a single logit.
        layers = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, 1),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, x):
        logit = self.cls(x)  # last dimension has size 1
        return logit

class PyHealthModel(nn.Module):
    """Token + type embeddings, a PyHealth sequence backbone, and a single-logit head."""

    def __init__(self, args):
        """
        Required keys:
          args["vocab_size"]:  int
          args["hidden_size"]: int  (embedding width; also the RETAIN feature_size)
        Optional keys:
          args["backbone"]:         str   "retain" or "stagenet" (default "retain")
          args["dropout"]:          float (default 0.5); only passed to the RETAIN backbone
          args["type_vocab_size"]:  int   number of token-type ids (default 5)
          args["pad_id"]:           int   padding id in input_ids (default 0)
          args["backbone_out_dim"]: int   width of the StageNet representation fed to
                                          the head (default 384); ignored for RETAIN
        Raises:
          ValueError: if args["backbone"] is neither "retain" nor "stagenet".
        """
        super().__init__()
        self.vocab_size  = int(args["vocab_size"])
        self.hidden_size = int(args["hidden_size"])
        self.dropout     = float(args.get("dropout", 0.5))
        self.backbone    = str(args.get("backbone", "retain")).lower()
        self.pad_id      = int(args.get("pad_id", 0))
        type_vocab_size  = int(args.get("type_vocab_size", 5))

        # Embeddings; padding positions map to a zero vector.
        self.input_emb = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=self.pad_id)
        self.type_emb  = nn.Embedding(type_vocab_size, self.hidden_size, padding_idx=0)

        # Backbone
        if self.backbone == "retain":
            self.backbone_mod = RETAINLayer(feature_size=self.hidden_size, dropout=self.dropout)
            backbone_out_dim = self.hidden_size
        elif self.backbone == "stagenet":
            # StageNetLayer is built with its own defaults apart from the input width.
            self.backbone_mod = StageNetLayer(self.hidden_size)
            # The StageNet output width need not equal hidden_size, so it is
            # configurable; 384 is the value used when no override is given.
            backbone_out_dim = int(args.get("backbone_out_dim", 384))
        else:
            raise ValueError(f"Unknown backbone: {self.backbone}")

        # Prediction head: input width matches the backbone output.
        self.pred_head = BinaryPredictionHead(backbone_out_dim)

    def forward(self, input_ids, input_types, labeled_batch_idx):
        """
        input_ids:         LongTensor [B, T], padded with pad_id
        input_types:       LongTensor [B, T], padded with 0
        labeled_batch_idx: list[int] or LongTensor [P]; rows of the batch to score
        return:
          logits:          Tensor [P, 1]
        """
        # 1) Sum token and type embeddings.
        x = self.input_emb(input_ids) + self.type_emb(input_types)   # [B,T,H]

        # 2) Valid-position mask: every non-padding token.
        mask = (input_ids != self.pad_id)                            # [B,T]

        # 3) Backbone forward pass: one representation per row.
        if self.backbone == "retain":
            c = self.backbone_mod(x, mask=mask)
        else:  # stagenet
            out = self.backbone_mod(x, mask=mask)
            # When the layer returns several tensors, the first one is used.
            c = out[0] if isinstance(out, (tuple, list)) else out

        # 4) Keep only the rows that carry a label.
        idx = torch.as_tensor(labeled_batch_idx, device=c.device, dtype=torch.long)
        # Fail early with a readable message rather than inside index_select.
        assert idx.ndim == 1 and (idx.numel() == 0 or (0 <= idx.min() and idx.max() < c.size(0))), \
            f"Index out of range: max={int(idx.max()) if idx.numel()>0 else 'NA'}, B={c.size(0)}"

        c_sel = c.index_select(0, idx)

        # 5) Classify.
        logits = self.pred_head(c_sel)                                # [P,1]
        return logits

In [16]:
# Derive five run seeds from a fixed master seed so the list is reproducible.
random.seed(42)
seeds = [random.randrange(2**32) for _ in range(5)]  # each seed lies in [0, 2**32 - 1]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [17]:
# Best-epoch test metrics, one entry per seed (full test set / long-sequence subset).
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")

    # Fresh model and optimizer for every seed so the runs are independent.
    model = PyHealthModel(args).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])

    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        args,
        val_long_seq_idx=val_long_seq_idx,
        test_long_seq_idx=test_long_seq_idx,
        task_type=task_type,
        # Pass the metric chosen alongside task_type; previously the function's
        # default ("f1") was always used, even for tasks configured for "prauc".
        eval_metric=eval_metric,
    )

    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  3.81it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.51it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.42it/s]


Epoch: 001, Average Loss: 0.6122
Validation: {'precision': 0.8114915387610725, 'recall': 0.646597679521334, 'f1': 0.7197207629497406, 'auc': 0.8195886030450115, 'prauc': 0.8400692342364835}
Test:       {'precision': 0.8049353701496086, 'recall': 0.6444026340525418, 'f1': 0.7157784694580129, 'auc': 0.8153058240430211, 'prauc': 0.8360143547568614}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.97it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.76it/s]


Epoch: 002, Average Loss: 0.5150
Validation: {'precision': 0.8638239339712801, 'recall': 0.5907808090291917, 'f1': 0.7016759728271923, 'auc': 0.833145899941012, 'prauc': 0.853253154796576}
Test:       {'precision': 0.8528079710106304, 'recall': 0.5904672311050785, 'f1': 0.697795066498539, 'auc': 0.8273051374936012, 'prauc': 0.8504730826944271}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.04it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 17.08it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.63it/s]


Epoch: 003, Average Loss: 0.4957
Validation: {'precision': 0.730414746541675, 'recall': 0.7952336155509714, 'f1': 0.7614472251523628, 'auc': 0.8359010854056252, 'prauc': 0.8534572676804513}
Test:       {'precision': 0.7142857142837096, 'recall': 0.7980558168679898, 'f1': 0.7538507059135819, 'auc': 0.830614869432293, 'prauc': 0.8526172133514625}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.25it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.42it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.91it/s]


Epoch: 004, Average Loss: 0.4465
Validation: {'precision': 0.7990737442080547, 'recall': 0.7033552837858158, 'f1': 0.7481654386468848, 'auc': 0.8372770705612622, 'prauc': 0.8550002342179386}
Test:       {'precision': 0.7834907310220934, 'recall': 0.7024145500134763, 'f1': 0.7407407357531772, 'auc': 0.8302177378397955, 'prauc': 0.8529567311666433}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.06it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.69it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.73it/s]


Epoch: 005, Average Loss: 0.4346
Validation: {'precision': 0.7801795809751242, 'recall': 0.7356538099694712, 'f1': 0.7572627451632643, 'auc': 0.8375042412393854, 'prauc': 0.8553929189131062}
Test:       {'precision': 0.7684210526290512, 'recall': 0.7325180307283395, 'f1': 0.7500401298535592, 'auc': 0.8339705055550606, 'prauc': 0.855794796720476}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.24it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.45it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.09it/s]


Epoch: 006, Average Loss: 0.3848
Validation: {'precision': 0.7082213863495211, 'recall': 0.8265914079622873, 'f1': 0.7628418413594671, 'auc': 0.835383043823846, 'prauc': 0.8533295022780003}
Test:       {'precision': 0.6967999999981419, 'recall': 0.8193791157076846, 'f1': 0.7531344524451239, 'auc': 0.8300419731349788, 'prauc': 0.8525743028213475}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.07it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.75it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.61it/s]


Epoch: 007, Average Loss: 0.3541
Validation: {'precision': 0.7370470200636806, 'recall': 0.7717152712424844, 'f1': 0.7539828381375849, 'auc': 0.824726920448767, 'prauc': 0.8435898434720888}
Test:       {'precision': 0.7237192774630627, 'recall': 0.7663844465325608, 'f1': 0.7444410550079237, 'auc': 0.8198985577932167, 'prauc': 0.843517342173146}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.76it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.13it/s]


Epoch: 008, Average Loss: 0.3093
Validation: {'precision': 0.6827983951838446, 'recall': 0.8538726873601321, 'f1': 0.7588128695199358, 'auc': 0.830097030313571, 'prauc': 0.8466994343735357}
Test:       {'precision': 0.6785445420309196, 'recall': 0.8479147068019821, 'f1': 0.753833281928705, 'auc': 0.8295982326889131, 'prauc': 0.8501467305529968}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.69it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.41it/s]


Epoch: 009, Average Loss: 0.2836
Validation: {'precision': 0.7291603053412851, 'recall': 0.7488240827822239, 'f1': 0.7388613811372129, 'auc': 0.8069131624258292, 'prauc': 0.8288904375252097}
Test:       {'precision': 0.7170495767813874, 'recall': 0.7438068359964133, 'f1': 0.7301831564585648, 'auc': 0.8064460145151848, 'prauc': 0.8332163372523307}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.10it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.60it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.86it/s]


Epoch: 010, Average Loss: 0.2488
Validation: {'precision': 0.8044088176320465, 'recall': 0.6293508936951102, 'f1': 0.7061928170284226, 'auc': 0.8096650322635613, 'prauc': 0.83197333819432}
Test:       {'precision': 0.7986550632879801, 'recall': 0.6331138287844682, 'f1': 0.7063144956765784, 'auc': 0.8093692352373, 'prauc': 0.8361922729552127}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.06it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.35it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.35it/s]


Epoch: 011, Average Loss: 0.2063
Validation: {'precision': 0.7173513182074882, 'recall': 0.7337723424247922, 'f1': 0.7254689145457481, 'auc': 0.8024365637686982, 'prauc': 0.8263905618726104}
Test:       {'precision': 0.7165742452226846, 'recall': 0.729382251487208, 'f1': 0.7229215179196686, 'auc': 0.8037265942769763, 'prauc': 0.8315318921774235}

Early stopping triggered after 11 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7082213863495211, 'recall': 0.8265914079622873, 'f1': 0.7628418413594671, 'auc': 0.835383043823846, 'prauc': 0.8533295022780003}
Corresponding test performance:
{'precision': 0.6967999999981419, 'recall': 0.8193791157076846, 'f1': 0.7531344524451239, 'auc': 0.8300419731349788, 'prauc': 0.8525743028213475}
Corresponding test-long performance:
{'precision': 0.7126213592094637, 'recall': 0.7995642701350857, 'f1': 0.7535934241591651, 'auc': 0.8083395383469844, 'prauc': 0.8463844411407574}
[INFO] Random seed set to 1181241943
Traini

Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.15it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.96it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.56it/s]


Epoch: 001, Average Loss: 0.6483
Validation: {'precision': 0.7052815873921644, 'recall': 0.7579178425815054, 'f1': 0.7306529575193766, 'auc': 0.7943389991368317, 'prauc': 0.8108598094430469}
Test:       {'precision': 0.7064487890262432, 'recall': 0.7591721542779581, 'f1': 0.7318621473621784, 'auc': 0.795080702273616, 'prauc': 0.811508644785323}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.18it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.35it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.82it/s]


Epoch: 002, Average Loss: 0.5311
Validation: {'precision': 0.8509288627057049, 'recall': 0.5888993414845127, 'f1': 0.6960711589880715, 'auc': 0.8254191831521143, 'prauc': 0.8456820046171886}
Test:       {'precision': 0.8490732568365001, 'recall': 0.6033239259937181, 'f1': 0.705407877817005, 'auc': 0.8258464716443512, 'prauc': 0.8466375603394327}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.03it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.65it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.45it/s]


Epoch: 003, Average Loss: 0.4911
Validation: {'precision': 0.7619955513162949, 'recall': 0.7519598620233554, 'f1': 0.756944439442275, 'auc': 0.8357222425011321, 'prauc': 0.8528753584641198}
Test:       {'precision': 0.7527216174160102, 'recall': 0.7588585763538449, 'f1': 0.7557776339733623, 'auc': 0.8353700798340202, 'prauc': 0.8539071902435165}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.38it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.68it/s]


Epoch: 004, Average Loss: 0.4562
Validation: {'precision': 0.7669559627104887, 'recall': 0.7481969269339975, 'f1': 0.7574603124586794, 'auc': 0.8365677273558006, 'prauc': 0.8551541123522091}
Test:       {'precision': 0.755898081155219, 'recall': 0.7535277516439213, 'f1': 0.754711050274024, 'auc': 0.8357264412630294, 'prauc': 0.8551584464315234}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.09it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.50it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.59it/s]


Epoch: 005, Average Loss: 0.4193
Validation: {'precision': 0.791411042942088, 'recall': 0.7281279397907553, 'f1': 0.7584517343496537, 'auc': 0.8396434234189156, 'prauc': 0.8575815241904748}
Test:       {'precision': 0.7814153439127599, 'recall': 0.7409846346793949, 'f1': 0.7606631207052467, 'auc': 0.8385668633197885, 'prauc': 0.8581280314721227}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.99it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.18it/s]


Epoch: 006, Average Loss: 0.3869
Validation: {'precision': 0.7324522760624598, 'recall': 0.7820633427382188, 'f1': 0.7564452482636108, 'auc': 0.831577457715961, 'prauc': 0.8487345864125387}
Test:       {'precision': 0.7278536160304158, 'recall': 0.7858262778275766, 'f1': 0.7557297899387103, 'auc': 0.8320560192113037, 'prauc': 0.8519131072995371}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.23it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.82it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.53it/s]


Epoch: 007, Average Loss: 0.3673
Validation: {'precision': 0.7336989640440778, 'recall': 0.755095641264487, 'f1': 0.7442435431365784, 'auc': 0.818901163071673, 'prauc': 0.8412890660246887}
Test:       {'precision': 0.7208616110925912, 'recall': 0.7660708686084476, 'f1': 0.7427789551726284, 'auc': 0.8206224546960433, 'prauc': 0.8434598047405573}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.09it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.42it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.34it/s]


Epoch: 008, Average Loss: 0.2919
Validation: {'precision': 0.7681940700782743, 'recall': 0.7149576669780027, 'f1': 0.7406204270326215, 'auc': 0.8203724976185259, 'prauc': 0.8398066812736956}
Test:       {'precision': 0.7544483985740716, 'recall': 0.7312637190318869, 'f1': 0.7426751542345211, 'auc': 0.8244196165893292, 'prauc': 0.8456667727803853}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.16it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.74it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.66it/s]


Epoch: 009, Average Loss: 0.2656
Validation: {'precision': 0.7555116814716831, 'recall': 0.7199749137638132, 'f1': 0.7373153450326458, 'auc': 0.8161404515622581, 'prauc': 0.832700360017299}
Test:       {'precision': 0.7564683053015638, 'recall': 0.733458764500679, 'f1': 0.7447858571226866, 'auc': 0.8240408567371689, 'prauc': 0.842296476855461}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.04it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.23it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.97it/s]


Epoch: 010, Average Loss: 0.2437
Validation: {'precision': 0.6950146627545601, 'recall': 0.743179680148187, 'f1': 0.7182906451019307, 'auc': 0.7924802386126051, 'prauc': 0.8135271245065483}
Test:       {'precision': 0.6980369514991397, 'recall': 0.7582314205056186, 'f1': 0.7268901197621072, 'auc': 0.7984500291179502, 'prauc': 0.8219456230264953}

Early stopping triggered after 10 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.791411042942088, 'recall': 0.7281279397907553, 'f1': 0.7584517343496537, 'auc': 0.8396434234189156, 'prauc': 0.8575815241904748}
Corresponding test performance:
{'precision': 0.7814153439127599, 'recall': 0.7409846346793949, 'f1': 0.7606631207052467, 'auc': 0.8385668633197885, 'prauc': 0.8581280314721227}
Corresponding test-long performance:
{'precision': 0.8139534883510606, 'recall': 0.6862745097889701, 'f1': 0.7446808460824406, 'auc': 0.8152891536361381, 'prauc': 0.8559434371750028}
[INFO] Random seed set to 958682846
Trainin

Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.17it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.74it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.66it/s]


Epoch: 001, Average Loss: 0.6415
Validation: {'precision': 0.7602862253997843, 'recall': 0.6663530887404631, 'f1': 0.7102272677465751, 'auc': 0.8022104982995355, 'prauc': 0.8239210235075796}
Test:       {'precision': 0.7633093525152399, 'recall': 0.6654123549681236, 'f1': 0.7110068638433416, 'auc': 0.80207746633064, 'prauc': 0.822789813342325}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.12it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.36it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.92it/s]


Epoch: 002, Average Loss: 0.5378
Validation: {'precision': 0.8547535211229985, 'recall': 0.6089683286277549, 'f1': 0.711225045495455, 'auc': 0.8346340638061247, 'prauc': 0.8530244719040938}
Test:       {'precision': 0.8420138888852343, 'recall': 0.6083411727795286, 'f1': 0.7063535359974156, 'auc': 0.8329757135659447, 'prauc': 0.8527659455181192}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.19it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.45it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.12it/s]


Epoch: 003, Average Loss: 0.4891
Validation: {'precision': 0.8743914313491997, 'recall': 0.5631859517072336, 'f1': 0.6851039433530073, 'auc': 0.8368629186217562, 'prauc': 0.8555544071893643}
Test:       {'precision': 0.8658767772470812, 'recall': 0.5729068673547416, 'f1': 0.6895640638969158, 'auc': 0.8373897632662841, 'prauc': 0.8564377556826442}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.05it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.77it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.45it/s]


Epoch: 004, Average Loss: 0.4470
Validation: {'precision': 0.8035779481533276, 'recall': 0.6901850109730631, 'f1': 0.7425775928670628, 'auc': 0.8374006530177378, 'prauc': 0.8546398486490646}
Test:       {'precision': 0.7994987468643054, 'recall': 0.7002195045446842, 'f1': 0.746573047510221, 'auc': 0.8389665619225799, 'prauc': 0.8563028542342406}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.06it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.48it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.68it/s]


Epoch: 005, Average Loss: 0.4160
Validation: {'precision': 0.7814776274686733, 'recall': 0.7064910630269474, 'f1': 0.7420948566703334, 'auc': 0.8255462821825545, 'prauc': 0.8489312312951076}
Test:       {'precision': 0.7757041058677445, 'recall': 0.7168391345226817, 'f1': 0.7451108163873567, 'auc': 0.8287125638040476, 'prauc': 0.8516283450833592}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.31it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.08it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.85it/s]


Epoch: 006, Average Loss: 0.3582
Validation: {'precision': 0.7427662957051104, 'recall': 0.7325180307283395, 'f1': 0.7376065627276249, 'auc': 0.8164272030506983, 'prauc': 0.8402642640510483}
Test:       {'precision': 0.7391440174922239, 'recall': 0.7419253684517343, 'f1': 0.7405320763748516, 'auc': 0.8204841381413941, 'prauc': 0.8444470861293997}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.03it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.59it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.12it/s]


Epoch: 007, Average Loss: 0.3315
Validation: {'precision': 0.7640170333542796, 'recall': 0.6751332706156314, 'f1': 0.7168303595913484, 'auc': 0.8062824900036363, 'prauc': 0.8312267627995726}
Test:       {'precision': 0.75726795096057, 'recall': 0.6779554719326499, 'f1': 0.715420246501982, 'auc': 0.8102928052741492, 'prauc': 0.8365873141205288}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.24it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.11it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.02it/s]


Epoch: 008, Average Loss: 0.2870
Validation: {'precision': 0.7503209242594663, 'recall': 0.7331451865765659, 'f1': 0.7416336191061688, 'auc': 0.81948099588169, 'prauc': 0.8391391464653823}
Test:       {'precision': 0.7452531645546037, 'recall': 0.7384760112864895, 'f1': 0.7418491050938456, 'auc': 0.8219761247909271, 'prauc': 0.8432043019423969}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.15it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.53it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.36it/s]


Epoch: 009, Average Loss: 0.2673
Validation: {'precision': 0.71453337185578, 'recall': 0.7754782063318424, 'f1': 0.7437593935023687, 'auc': 0.8183581035779732, 'prauc': 0.83699123492475}
Test:       {'precision': 0.7153116920404616, 'recall': 0.780809031041766, 'f1': 0.7466266816640155, 'auc': 0.8220031035657784, 'prauc': 0.8412133777723483}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.20it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.81it/s]


Epoch: 010, Average Loss: 0.2191
Validation: {'precision': 0.7453947368396533, 'recall': 0.7105675760404184, 'f1': 0.7275646121140263, 'auc': 0.8042416714215769, 'prauc': 0.8256355905241263}
Test:       {'precision': 0.7366042608110503, 'recall': 0.7155848228262289, 'f1': 0.725942415867197, 'auc': 0.8103094153407555, 'prauc': 0.8319551444220065}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.09it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.39it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.14it/s]


Epoch: 011, Average Loss: 0.1724
Validation: {'precision': 0.721125074803348, 'recall': 0.7557227971127134, 'f1': 0.7380186751413508, 'auc': 0.8074380362084541, 'prauc': 0.8256885183786647}
Test:       {'precision': 0.7204525156275069, 'recall': 0.7588585763538449, 'f1': 0.7391569895032506, 'auc': 0.8148103407227997, 'prauc': 0.8342110134210435}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.04it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.53it/s]


Epoch: 012, Average Loss: 0.1460
Validation: {'precision': 0.7188968975220342, 'recall': 0.7193477579155869, 'f1': 0.7191222520510379, 'auc': 0.7906671935499202, 'prauc': 0.8094543694790972}
Test:       {'precision': 0.713488372090811, 'recall': 0.721542803384379, 'f1': 0.7174929790952075, 'auc': 0.7968181152406422, 'prauc': 0.8194592629378539}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  3.98it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.12it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.01it/s]


Epoch: 013, Average Loss: 0.1240
Validation: {'precision': 0.7075087310805952, 'recall': 0.7623079335190897, 'f1': 0.733886787457565, 'auc': 0.8024220955786718, 'prauc': 0.815589026039925}
Test:       {'precision': 0.7013769363146833, 'recall': 0.7666980244566739, 'f1': 0.7325842646706252, 'auc': 0.806884822274804, 'prauc': 0.8258902585300273}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.17it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.14it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.79it/s]


Epoch: 014, Average Loss: 0.1333
Validation: {'precision': 0.6975121359202139, 'recall': 0.7209156475361527, 'f1': 0.7090208122697992, 'auc': 0.7764803294406869, 'prauc': 0.791187562318998}
Test:       {'precision': 0.698600973233885, 'recall': 0.7202884916879264, 'f1': 0.7092789821844034, 'auc': 0.7818889873748394, 'prauc': 0.8023885459498972}

Early stopping triggered after 14 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.71453337185578, 'recall': 0.7754782063318424, 'f1': 0.7437593935023687, 'auc': 0.8183581035779732, 'prauc': 0.83699123492475}
Corresponding test performance:
{'precision': 0.7153116920404616, 'recall': 0.780809031041766, 'f1': 0.7466266816640155, 'auc': 0.8220031035657784, 'prauc': 0.8412133777723483}
Corresponding test-long performance:
{'precision': 0.7133891213239877, 'recall': 0.7429193899620279, 'f1': 0.727854850909679, 'auc': 0.7937012216982433, 'prauc': 0.8322044730453255}
[INFO] Random seed set to 3163119785
Training wit

Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.05it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.58it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.37it/s]


Epoch: 001, Average Loss: 0.6237
Validation: {'precision': 0.7786576168899749, 'recall': 0.6475384132936735, 'f1': 0.7070707021105475, 'auc': 0.8108377594440357, 'prauc': 0.8288626975452106}
Test:       {'precision': 0.7813780260678533, 'recall': 0.6578864847894077, 'f1': 0.7143343498181706, 'auc': 0.8078872146277307, 'prauc': 0.8239009981143588}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.23it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.42it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.86it/s]


Epoch: 002, Average Loss: 0.5248
Validation: {'precision': 0.7186544342487666, 'recall': 0.8105989338325161, 'f1': 0.7618626534302116, 'auc': 0.8334091406206592, 'prauc': 0.8504288587362744}
Test:       {'precision': 0.7076133909268153, 'recall': 0.8218877391005899, 'f1': 0.7604816430744492, 'auc': 0.8324465067771588, 'prauc': 0.849972441005923}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.15it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.49it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.46it/s]


Epoch: 003, Average Loss: 0.4975
Validation: {'precision': 0.7930298719744203, 'recall': 0.6992787707723447, 'f1': 0.7432094601064084, 'auc': 0.8373529280853591, 'prauc': 0.8552748076205259}
Test:       {'precision': 0.7906732117784339, 'recall': 0.7071182188751737, 'f1': 0.7465651332352372, 'auc': 0.8363865152432595, 'prauc': 0.8542747246653524}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.17it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.42it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.86it/s]


Epoch: 004, Average Loss: 0.4523
Validation: {'precision': 0.7210789766387067, 'recall': 0.8131075572254215, 'f1': 0.7643330827091797, 'auc': 0.8376964471249445, 'prauc': 0.854189890439153}
Test:       {'precision': 0.7176861334051545, 'recall': 0.8131075572254215, 'f1': 0.7624228118357692, 'auc': 0.8392763648315562, 'prauc': 0.8551726162304064}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.16it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.38it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.37it/s]


Epoch: 005, Average Loss: 0.4220
Validation: {'precision': 0.7559934318530181, 'recall': 0.7218563813084922, 'f1': 0.7385306334346906, 'auc': 0.8192970790633334, 'prauc': 0.8391210353408552}
Test:       {'precision': 0.7508738481069245, 'recall': 0.7409846346793949, 'f1': 0.7458964596443299, 'auc': 0.8227078737252402, 'prauc': 0.8439831906068714}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.03it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.47it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.81it/s]


Epoch: 006, Average Loss: 0.3795
Validation: {'precision': 0.782686781606384, 'recall': 0.6832862966425736, 'f1': 0.7296166030901915, 'auc': 0.8233894166596587, 'prauc': 0.8433800628011021}
Test:       {'precision': 0.7792571828984609, 'recall': 0.6973973032276658, 'f1': 0.7360582442268964, 'auc': 0.8250880459363976, 'prauc': 0.8475677474601294}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.08it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.06it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.23it/s]


Epoch: 007, Average Loss: 0.3406
Validation: {'precision': 0.7977713578175908, 'recall': 0.6061461273107365, 'f1': 0.688880964442087, 'auc': 0.8103496589777278, 'prauc': 0.8298789617213389}
Test:       {'precision': 0.7948103792383441, 'recall': 0.6243336469092997, 'f1': 0.6993326259091757, 'auc': 0.8132938919751735, 'prauc': 0.8323879658624544}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.09it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.30it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.11it/s]


Epoch: 008, Average Loss: 0.3110
Validation: {'precision': 0.838090551176978, 'recall': 0.5340232047647099, 'f1': 0.6523654424753761, 'auc': 0.8113067699373919, 'prauc': 0.8322474916291565}
Test:       {'precision': 0.8426013195059255, 'recall': 0.5606773283143284, 'f1': 0.6733195207123621, 'auc': 0.8153540435697149, 'prauc': 0.8370152570582107}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.09it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.12it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.69it/s]


Epoch: 009, Average Loss: 0.2569
Validation: {'precision': 0.7829457364313498, 'recall': 0.6967701473794394, 'f1': 0.7373485929902873, 'auc': 0.827172747878024, 'prauc': 0.8436657284081295}
Test:       {'precision': 0.7777777777750772, 'recall': 0.7024145500134763, 'f1': 0.7381776190013017, 'auc': 0.8290526171676615, 'prauc': 0.8473212514739774}

Early stopping triggered after 9 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7210789766387067, 'recall': 0.8131075572254215, 'f1': 0.7643330827091797, 'auc': 0.8376964471249445, 'prauc': 0.854189890439153}
Corresponding test performance:
{'precision': 0.7176861334051545, 'recall': 0.8131075572254215, 'f1': 0.7624228118357692, 'auc': 0.8392763648315562, 'prauc': 0.8551726162304064}
Corresponding test-long performance:
{'precision': 0.7422680412218089, 'recall': 0.7843137254731086, 'f1': 0.7627118593944135, 'auc': 0.8203193513692396, 'prauc': 0.8579087051536154}
[INFO] Random seed set to 1812140441
Trainin

Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.25it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.51it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.37it/s]


Epoch: 001, Average Loss: 0.6296
Validation: {'precision': 0.7242339832846666, 'recall': 0.7337723424247922, 'f1': 0.7289719576147656, 'auc': 0.8052960407697518, 'prauc': 0.8263490266404423}
Test:       {'precision': 0.7252512945454608, 'recall': 0.7466290373134318, 'f1': 0.7357849146526749, 'auc': 0.8062122655778516, 'prauc': 0.8255421427832695}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.15it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.96it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.37it/s]


Epoch: 002, Average Loss: 0.5309
Validation: {'precision': 0.7600127145557534, 'recall': 0.7497648165545633, 'f1': 0.7548539807910596, 'auc': 0.8329349055031268, 'prauc': 0.8505411589288507}
Test:       {'precision': 0.754303599371661, 'recall': 0.7557227971127134, 'f1': 0.7550125263259599, 'auc': 0.8319162933176696, 'prauc': 0.8513967697965082}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.65it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.53it/s]


Epoch: 003, Average Loss: 0.4989
Validation: {'precision': 0.7286330935230831, 'recall': 0.7939793038545188, 'f1': 0.7599039565915627, 'auc': 0.8373018875260992, 'prauc': 0.8537414328162272}
Test:       {'precision': 0.7238933030626451, 'recall': 0.7999372844126688, 'f1': 0.7600178707736315, 'auc': 0.8366098954723475, 'prauc': 0.8551325987058134}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.04it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.73it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.82it/s]


Epoch: 004, Average Loss: 0.4582
Validation: {'precision': 0.728155339803746, 'recall': 0.7996237064885556, 'f1': 0.762217899656698, 'auc': 0.8371058636459496, 'prauc': 0.8534782364373976}
Test:       {'precision': 0.7235009945986829, 'recall': 0.798369394792103, 'f1': 0.7590936145685734, 'auc': 0.8350957620673392, 'prauc': 0.854342623471933}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.23it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.13it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.99it/s]


Epoch: 005, Average Loss: 0.4365
Validation: {'precision': 0.7835541963955707, 'recall': 0.7231106930049448, 'f1': 0.7521200210982228, 'auc': 0.8377867728390678, 'prauc': 0.8542048197574895}
Test:       {'precision': 0.7754421087728547, 'recall': 0.7287550956389817, 'f1': 0.7513740654841206, 'auc': 0.8354175946912215, 'prauc': 0.8546519437006411}


Training Batches: 100%|██████████| 49/49 [00:12<00:00,  4.01it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.69it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.85it/s]


Epoch: 006, Average Loss: 0.3880
Validation: {'precision': 0.7846153846126412, 'recall': 0.7036688617099289, 'f1': 0.7419408116762494, 'auc': 0.8314447824039481, 'prauc': 0.8495653791338428}
Test:       {'precision': 0.7753229095826808, 'recall': 0.7152712449021158, 'f1': 0.7440874195693807, 'auc': 0.8298950998793504, 'prauc': 0.8513814612763869}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.24it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.55it/s]
Running inference: 100%|██████████| 99/99 [00:05<00:00, 16.55it/s]


Epoch: 007, Average Loss: 0.3353
Validation: {'precision': 0.7156862745077992, 'recall': 0.8011915961091214, 'f1': 0.7560289933861907, 'auc': 0.8300946189485665, 'prauc': 0.8473856712046484}
Test:       {'precision': 0.7144039735079625, 'recall': 0.8118532455289689, 'f1': 0.7600176084043258, 'auc': 0.8307260058779502, 'prauc': 0.8484293179034104}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.13it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.96it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 15.96it/s]


Epoch: 008, Average Loss: 0.2641
Validation: {'precision': 0.6623376623360694, 'recall': 0.8635936030076401, 'f1': 0.7496937476369969, 'auc': 0.8265677966825445, 'prauc': 0.8411043433685084}
Test:       {'precision': 0.66299401197446, 'recall': 0.8679836939452242, 'f1': 0.7517653400088358, 'auc': 0.8253295462381471, 'prauc': 0.8440033554119033}


Training Batches: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.48it/s]
Running inference: 100%|██████████| 99/99 [00:06<00:00, 16.35it/s]

Epoch: 009, Average Loss: 0.2285
Validation: {'precision': 0.6885078064993896, 'recall': 0.8435246158643979, 'f1': 0.7581736139893019, 'auc': 0.8306115553213852, 'prauc': 0.8446953119145564}
Test:       {'precision': 0.6836115326234607, 'recall': 0.847601128877869, 'f1': 0.7568248585579607, 'auc': 0.8298720471202423, 'prauc': 0.8450639863261656}

Early stopping triggered after 9 epochs (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.728155339803746, 'recall': 0.7996237064885556, 'f1': 0.762217899656698, 'auc': 0.8371058636459496, 'prauc': 0.8534782364373976}
Corresponding test performance:
{'precision': 0.7235009945986829, 'recall': 0.798369394792103, 'f1': 0.7590936145685734, 'auc': 0.8350957620673392, 'prauc': 0.854342623471933}
Corresponding test-long performance:
{'precision': 0.7451403887528047, 'recall': 0.7516339869117291, 'f1': 0.7483730969361381, 'auc': 0.8150133752516479, 'prauc': 0.8529058583460447}





In [18]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    """Print mean ± std of every metric over the k best-ranked runs.

    Each run is ranked separately on "f1", "auc" and "prauc" (rank 1 is the
    highest value). The k runs with the lowest average rank are selected, and
    the same run indices are used to summarise ``long_seq_performances``.

    Args:
        performances: sequence of per-run metric dicts; every dict must hold
            "f1", "auc" and "prauc". The keys of the first dict decide which
            metrics are reported.
        long_seq_performances: sequence of per-run metric dicts indexed in
            parallel with ``performances``.
        k: number of top-ranked runs to aggregate.

    Returns:
        None. Values are printed as percentages with two decimals; the spread
        is the population standard deviation (ddof=0).
    """
    ranking_keys = ["f1", "auc", "prauc"]

    rank_columns = []
    for key in ranking_keys:
        values = np.array([run[key] for run in performances])
        # Double argsort on the negated scores yields 1-based ranks
        # (largest score -> rank 1).
        rank_columns.append((-values).argsort().argsort() + 1)
    mean_rank = np.mean(np.stack(rank_columns, axis=1), axis=1)

    # Indices of the k runs with the best (smallest) average rank.
    best_runs = np.argsort(mean_rank)[:k]

    def _aggregate(runs):
        # Mean and population std of each metric over the selected runs.
        names = runs[0].keys()
        avg = {name: np.mean([runs[i][name] for i in best_runs]) for name in names}
        std = {name: np.std([runs[i][name] for i in best_runs], ddof=0) for name in names}
        return avg, std

    def _report(title, runs, avg, std):
        # Emit one "<metric>: mean ± std" line per metric, as percentages.
        print(title)
        for name in runs[0].keys():
            print(f"{name}: {avg[name] * 100:.2f} ± {std[name] * 100:.2f}")

    # Aggregate both result sets before printing anything, so a malformed
    # input fails without producing partial output.
    overall_avg, overall_std = _aggregate(performances)
    long_avg, long_std = _aggregate(long_seq_performances)

    _report("Final Metrics:", performances, overall_avg, overall_std)
    _report("\nFinal Long Sequence Metrics:", long_seq_performances, long_avg, long_std)

In [19]:
def print_per_class_performance(dfs, col_name="prauc", ddof=0):
    """Print mean ± std of one metric column for every row label across tables.

    All tables are concatenated row-wise and grouped by their index label, so
    each label's statistic is taken over every table in which it appears.
    Labels are printed in the sorted order produced by ``groupby``.

    Args:
        dfs (list[pd.DataFrame]): tables sharing the same kind of index
            labels, each containing ``col_name``.
        col_name (str): metric column to summarise (default: "prauc").
        ddof (int): delta degrees of freedom for the standard deviation.
            Defaults to 0 (population std) so the "±" here means the same
            thing as in ``topk_avg_performance_formatted`` and a single table
            prints 0.00 instead of nan. Pass ``ddof=1`` for the sample std.

    Raises:
        ValueError: if ``dfs`` is empty.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("print_per_class_performance needs at least one DataFrame")

    # Stack all tables row-wise; repeated index labels become one group each.
    all_values = pd.concat(dfs, axis=0)

    # Group by index label and compute both statistics on the chosen column.
    per_label = all_values.groupby(all_values.index)[col_name]
    means = per_label.mean()
    stds = per_label.std(ddof=ddof)

    # Both series come from the same groupby, so their rows are aligned.
    for disease, mean_val, std_val in zip(means.index, means.to_numpy(), stds.to_numpy()):
        print(f"{disease}: {mean_val * 100:.2f} ± {std_val * 100:.2f}")

In [20]:
# Summarise the collected runs. Binary tasks store one flat metric dict per
# run; every other task stores a dict with "global" and "per_class" entries.
if task_type != "binary":
    # Split each run's result into its aggregate part and its per-class table.
    final_metrics_global = [run["global"] for run in final_metrics]
    final_metrics_per_class = [run["per_class"] for run in final_metrics]
    final_long_seq_metrics_global = [run["global"] for run in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [run["per_class"] for run in final_long_seq_metrics]

    # Aggregate metrics first, then the per-class PR-AUC breakdowns.
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)

Final Metrics:
precision: 72.69 ± 2.87
recall: 79.05 ± 2.81
f1: 75.64 ± 0.58
auc: 83.30 ± 0.64
prauc: 85.23 ± 0.58

Final Long Sequence Metrics:
precision: 74.55 ± 3.69
recall: 75.29 ± 3.93
f1: 74.74 ± 1.15
auc: 81.05 ± 0.92
prauc: 84.91 ± 0.93
