In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support
import pandas as pd

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
PHENO_ORDER = [
    "Acute and unspecified renal failure",
    "Acute cerebrovascular disease",
    "Acute myocardial infarction",
    "Cardiac dysrhythmias",
    "Chronic kidney disease",
    "Chronic obstructive pulmonary disease",
    "Conduction disorders",
    "Congestive heart failure; nonhypertensive",
    "Coronary atherosclerosis and related",
    "Disorders of lipid metabolism",
    "Essential hypertension",
    "Fluid and electrolyte disorders",
    "Gastrointestinal hemorrhage",
    "Hypertension with complications",
    "Other liver diseases",
    "Other lower respiratory disease",
    "Pneumonia",
    "Septicemia (except in labor)",
]

In [4]:
@torch.no_grad()
def evaluate(model, dataloader, device, long_seq_idx=None, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []

    # 推理：收集 logits 与 labels
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)

    if task_type == "binary":
        # —— 标准二分类评估 —— #
        logits_all = torch.cat(predicted_scores, dim=0).view(-1)          # logits [N]
        labels_all = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()    # y_true [N]
        scores_all = logits_all.cpu().numpy()                              # 连续分数（logits）
        ypred_all  = (logits_all > 0).float().cpu().numpy()                # logits > 0

        tp = (ypred_all * labels_all).sum()
        precision = tp / (ypred_all.sum() + 1e-8)
        recall    = tp / (labels_all.sum() + 1e-8)
        f1        = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc   = roc_auc_score(labels_all, scores_all)
        prec_curve, rec_curve, _ = precision_recall_curve(labels_all, scores_all)
        pr_auc    = auc(rec_curve, prec_curve)

        all_performance = {"precision": float(precision),
                           "recall": float(recall),
                           "f1": float(f1),
                           "auc": float(roc_auc),
                           "prauc": float(pr_auc)}

        subset_performance = None
        if long_seq_idx is not None:
            idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
            logits_sub = logits_all.index_select(0, idx).view(-1)
            labels_sub = torch.as_tensor(labels_all, device=logits_all.device)[idx].cpu().numpy()
            scores_sub = logits_sub.cpu().numpy()
            ypred_sub  = (logits_sub > 0).float().cpu().numpy()

            tp = (ypred_sub * labels_sub).sum()
            precision = tp / (ypred_sub.sum() + 1e-8)
            recall    = tp / (labels_sub.sum() + 1e-8)
            f1        = 2 * precision * recall / (precision + recall + 1e-8)
            roc_auc   = roc_auc_score(labels_sub, scores_sub)
            prec_curve, rec_curve, _ = precision_recall_curve(labels_sub, scores_sub)
            pr_auc    = auc(rec_curve, prec_curve)

            subset_performance = {"precision": float(precision),
                                  "recall": float(recall),
                                  "f1": float(f1),
                                  "auc": float(roc_auc),
                                  "prauc": float(pr_auc)}

        return all_performance, subset_performance

    else:
        # —— Multi-label evaluation（按类聚合） —— #
        logits_all = torch.cat(predicted_scores, dim=0)    # [B, C]
        labels_all_t = torch.cat(gt_labels, dim=0)         # [B, C]

        def _compute_metrics(logits_sub, labels_sub):
            # 连续分数（概率）：sigmoid(logits)，CPU + fp16 先升为 fp32
            if logits_sub.device.type == "cpu" and logits_sub.dtype == torch.float16:
                prob_t = torch.sigmoid(logits_sub.float())
            else:
                prob_t = torch.sigmoid(logits_sub)
            # 二值化：logits > 0
            ypred_t = (logits_sub > 0).to(torch.int32)

            y_true = labels_sub.cpu().numpy().astype(np.int32)       # [N, C]
            y_pred = ypred_t.cpu().numpy().astype(np.int32)          # [N, C]
            scores = prob_t.cpu().numpy()                             # [N, C]

            # per-class P/R/F1
            p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
                y_true, y_pred, average=None, zero_division=0
            )

            # per-class AUC / PR-AUC
            C = y_true.shape[1]
            aucs, praucs = [], []
            for c in range(C):
                yt, ys = y_true[:, c], scores[:, c]
                if yt.max() == yt.min():
                    aucs.append(np.nan)
                    praucs.append(np.nan)
                else:
                    aucs.append(roc_auc_score(yt, ys))
                    prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
                    praucs.append(auc(rec_curve, prec_curve))

            # 宏平均（忽略 NaN）
            summary = {
                "precision": float(np.mean(p_cls)),
                "recall":    float(np.mean(r_cls)),
                "f1":        float(np.mean(f1_cls)),
                "auc":       float(np.nanmean(aucs)) if np.any(~np.isnan(aucs)) else float("nan"),
                "prauc":     float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else float("nan"),
            }

            per_class_df = pd.DataFrame({
                "precision": p_cls,
                "recall":    r_cls,
                "f1":        f1_cls,
                "auc":       aucs,
                "prauc":     praucs,
            }, index=PHENO_ORDER)   # 确保 PHENO_ORDER 已定义且长度=C

            return {"global": summary, "per_class": per_class_df}

        # 全量
        all_performance = _compute_metrics(logits_all, labels_all_t)

        # 子集
        subset_performance = None
        if long_seq_idx is not None:
            idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
            subset_performance = _compute_metrics(
                logits_all.index_select(0, idx),
                labels_all_t.index_select(0, idx)
            )

        return all_performance, subset_performance

In [5]:
args = {
    "seed": 0,
    "dataset": "MIMIC-IV", 
    "task": "next_diag_12m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 500,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [6]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-IV-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [7]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [8]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [9]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [10]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [11]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [12]:
long_adm_seq_crite = 3
val_long_seq_idx, test_long_seq_idx = [], []
for i in range(len(val_dataset)):
    hadm_id = list(val_dataset.records.keys())[i]
    num_adms = len(val_dataset.records[hadm_id])
    if num_adms >= long_adm_seq_crite:
        val_long_seq_idx.append(i)
for i in range(len(test_dataset)):
    hadm_id = list(test_dataset.records.keys())[i]
    num_adms = len(test_dataset.records[hadm_id])
    if num_adms >= long_adm_seq_crite:
        test_long_seq_idx.append(i)
print(len(val_long_seq_idx), len(test_long_seq_idx))

857 876


In [13]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([8, 41])
input_types shape: torch.Size([8, 41])
visit_positions shape: torch.Size([8])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [14]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [15]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [16]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              val_long_seq_idx = None,
                              test_long_seq_idx = None,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    best_test_long_seq_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric, val_long_seq_metric = evaluate(model, val_dataloader, device, task_type=task_type, long_seq_idx=val_long_seq_idx)
        test_metric, test_long_seq_metric = evaluate(model, test_dataloader, device, task_type=task_type, long_seq_idx=test_long_seq_idx)

        if task_type != "binary":
            val_per_class_df = val_metric["per_class"]
            val_metric = val_metric["global"]
            test_per_class_df = test_metric["per_class"]
            test_metric = test_metric["global"]
            
            if val_long_seq_idx != None:
                val_long_seq_per_class_df = val_long_seq_metric["per_class"]
                val_long_seq_metric = val_long_seq_metric["global"]
            if test_long_seq_idx != None:
                test_long_seq_per_class_df = test_long_seq_metric["per_class"]
                test_long_seq_metric = test_long_seq_metric["global"]

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric if task_type == "binary" else {"global": val_metric, "per_class": val_per_class_df}
            best_test_metric = test_metric if task_type == "binary" else {"global": test_metric, "per_class": test_per_class_df}
            best_test_long_seq_metric = test_long_seq_metric if task_type == "binary" else {"global": test_long_seq_metric, "per_class": test_long_seq_per_class_df}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    if best_test_long_seq_metric is not None:
        print("Corresponding test-long performance:")
        print(best_test_long_seq_metric)
    return best_test_metric, best_test_long_seq_metric

In [17]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [None]:
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type)
    
    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.21it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.19it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.72it/s]


Epoch: 001, Average Loss: 0.3193
Validation: {'precision': 0.4001425728603908, 'recall': 0.30998449938587375, 'f1': 0.32913068916921273, 'auc': 0.7479186222206698, 'prauc': 0.4098777838353179}
Test:       {'precision': 0.43328917244914433, 'recall': 0.31328084257819566, 'f1': 0.3305960263151405, 'auc': 0.7453023594041245, 'prauc': 0.41128964003607593}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.89it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.92it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.64it/s]


Epoch: 002, Average Loss: 0.2899
Validation: {'precision': 0.4312611191608904, 'recall': 0.33920526644069415, 'f1': 0.3613446600712906, 'auc': 0.7617896553360688, 'prauc': 0.42843195280385693}
Test:       {'precision': 0.42717113992508743, 'recall': 0.34071118158008573, 'f1': 0.35960140617175124, 'auc': 0.7599495079036331, 'prauc': 0.42594710288692184}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.10it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.17it/s]


Epoch: 003, Average Loss: 0.2817
Validation: {'precision': 0.4405625393808415, 'recall': 0.35066659723542976, 'f1': 0.372526010212822, 'auc': 0.7669076075231113, 'prauc': 0.43099439065820333}
Test:       {'precision': 0.43570315934934495, 'recall': 0.3524880816617595, 'f1': 0.3716779977294613, 'auc': 0.7656303927267772, 'prauc': 0.4293673980156111}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.56it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.60it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.74it/s]


Epoch: 004, Average Loss: 0.2757
Validation: {'precision': 0.42527471277263507, 'recall': 0.3543751874790645, 'f1': 0.3775427779323395, 'auc': 0.7720453677512876, 'prauc': 0.4359785330322727}
Test:       {'precision': 0.42555309826786836, 'recall': 0.357492519989691, 'f1': 0.3779762594229721, 'auc': 0.7724589940337656, 'prauc': 0.43602964152259605}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.79it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.37it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.46it/s]


Epoch: 005, Average Loss: 0.2705
Validation: {'precision': 0.43243626940217783, 'recall': 0.37067165401875646, 'f1': 0.3864119601866048, 'auc': 0.7726984407181627, 'prauc': 0.44082718300606905}
Test:       {'precision': 0.42878863507132015, 'recall': 0.37048529086379645, 'f1': 0.3842550436954349, 'auc': 0.7721920261850893, 'prauc': 0.43577211226594176}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.63it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.33it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.67it/s]


Epoch: 006, Average Loss: 0.2651
Validation: {'precision': 0.4401977034610917, 'recall': 0.35987979789031455, 'f1': 0.3848007773050865, 'auc': 0.7708422225932218, 'prauc': 0.43856468173386914}
Test:       {'precision': 0.4356642271207302, 'recall': 0.36237898323785445, 'f1': 0.383746048918964, 'auc': 0.7725053354861671, 'prauc': 0.4366918798648893}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.57it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.23it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.33it/s]


Epoch: 007, Average Loss: 0.2601
Validation: {'precision': 0.42872402140006227, 'recall': 0.3455141928284955, 'f1': 0.37455745508982097, 'auc': 0.7712686712752382, 'prauc': 0.43616578413114904}
Test:       {'precision': 0.45289288636741, 'recall': 0.3463417518180554, 'f1': 0.37239907770828284, 'auc': 0.7665680118858234, 'prauc': 0.43360348151276107}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.86it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.21it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.17it/s]


Epoch: 008, Average Loss: 0.2544
Validation: {'precision': 0.4296333569816507, 'recall': 0.36562559157872854, 'f1': 0.3899407980028778, 'auc': 0.7729104059724144, 'prauc': 0.43986462656025954}
Test:       {'precision': 0.4448436393699662, 'recall': 0.3643740969090587, 'f1': 0.386269359579445, 'auc': 0.7705478626693273, 'prauc': 0.4377991329822877}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.61it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.17it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.33it/s]


Epoch: 009, Average Loss: 0.2493
Validation: {'precision': 0.4790596921854736, 'recall': 0.38297644432814715, 'f1': 0.39417488097176945, 'auc': 0.7714955951512047, 'prauc': 0.4408572626369089}
Test:       {'precision': 0.43841417873808947, 'recall': 0.38160894201828877, 'f1': 0.39184981333194846, 'auc': 0.7689667917535467, 'prauc': 0.434756688013496}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.59it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.26it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.16it/s]


Epoch: 010, Average Loss: 0.2439
Validation: {'precision': 0.4480386200998791, 'recall': 0.37910557472796974, 'f1': 0.39241946529991956, 'auc': 0.7695823155298745, 'prauc': 0.43572973194022446}
Test:       {'precision': 0.4344519704197516, 'recall': 0.37930940851803846, 'f1': 0.3904826447059465, 'auc': 0.7647222178387318, 'prauc': 0.4321610702888362}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.75it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.63it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.04it/s]


Epoch: 011, Average Loss: 0.2390
Validation: {'precision': 0.47500530849543043, 'recall': 0.3548091391189512, 'f1': 0.37878730236351704, 'auc': 0.769216072412734, 'prauc': 0.4359348372379398}
Test:       {'precision': 0.4946185746992019, 'recall': 0.3573707830232577, 'f1': 0.3791350854878768, 'auc': 0.7637628673956065, 'prauc': 0.4320815645731335}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.58it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.27it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.63it/s]


Epoch: 012, Average Loss: 0.2334
Validation: {'precision': 0.48266020556337186, 'recall': 0.3774368845381344, 'f1': 0.39384932524894717, 'auc': 0.7676909188016574, 'prauc': 0.4318705957298667}
Test:       {'precision': 0.4357569853800969, 'recall': 0.3787957527588152, 'f1': 0.39322024801274275, 'auc': 0.7597231402713287, 'prauc': 0.4304802527986336}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.69it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.28it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.49it/s]


Epoch: 013, Average Loss: 0.2286
Validation: {'precision': 0.5041683627949929, 'recall': 0.3695616868322826, 'f1': 0.3884138538287703, 'auc': 0.7654679135323829, 'prauc': 0.43209097632402327}
Test:       {'precision': 0.45964745665776924, 'recall': 0.3707600522894431, 'f1': 0.387784911826359, 'auc': 0.764251070432521, 'prauc': 0.42993339352725035}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.77it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.31it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.51it/s]


Epoch: 014, Average Loss: 0.2228
Validation: {'precision': 0.4838502327007339, 'recall': 0.3655724009462628, 'f1': 0.3892057340311156, 'auc': 0.7666702774103381, 'prauc': 0.43164456822701}
Test:       {'precision': 0.45717927608283265, 'recall': 0.36443029911098307, 'f1': 0.38475126522897285, 'auc': 0.7657634276547199, 'prauc': 0.42933128863429015}

Early stopping triggered after 14 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.4790596921854736, 'recall': 0.38297644432814715, 'f1': 0.39417488097176945, 'auc': 0.7714955951512047, 'prauc': 0.4408572626369089}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.445680  0.450303  0.447979   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.781463  0.671293  0.722

Training Batches: 100%|██████████| 1713/1713 [00:51<00:00, 32.99it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.95it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.73it/s]


Epoch: 001, Average Loss: 0.3166
Validation: {'precision': 0.44012884961226795, 'recall': 0.3240743898384623, 'f1': 0.3434866546379857, 'auc': 0.7580570901830508, 'prauc': 0.42356559509248265}
Test:       {'precision': 0.4594012136463209, 'recall': 0.3281289800888677, 'f1': 0.3459872811903517, 'auc': 0.7551027556144446, 'prauc': 0.4234990160707823}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.44it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.58it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.44it/s]


Epoch: 002, Average Loss: 0.2888
Validation: {'precision': 0.42819986661639253, 'recall': 0.36624136098564963, 'f1': 0.38783421971756954, 'auc': 0.7666437612659364, 'prauc': 0.4313735226178135}
Test:       {'precision': 0.42523636002643683, 'recall': 0.37048058038270604, 'f1': 0.38851825487838415, 'auc': 0.7742715997407061, 'prauc': 0.43221337289213685}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.82it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 119.42it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.05it/s]


Epoch: 003, Average Loss: 0.2806
Validation: {'precision': 0.445410313200182, 'recall': 0.347977133402003, 'f1': 0.3752840203491949, 'auc': 0.773816187935981, 'prauc': 0.4369767058588623}
Test:       {'precision': 0.44286116490433297, 'recall': 0.3523173716493642, 'f1': 0.3766988085452089, 'auc': 0.773559627797018, 'prauc': 0.43576449304417614}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.76it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.35it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.24it/s]


Epoch: 004, Average Loss: 0.2749
Validation: {'precision': 0.44841268788966615, 'recall': 0.3664733914035723, 'f1': 0.3868295912197428, 'auc': 0.7762875858822675, 'prauc': 0.44026050718794657}
Test:       {'precision': 0.4405758214752914, 'recall': 0.36835891380984137, 'f1': 0.3852948127640079, 'auc': 0.778315444707041, 'prauc': 0.43830603900414117}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.71it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.49it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.57it/s]


Epoch: 005, Average Loss: 0.2696
Validation: {'precision': 0.44922082189767387, 'recall': 0.3473628515532392, 'f1': 0.36739102928820466, 'auc': 0.7748807988627616, 'prauc': 0.43875812695936794}
Test:       {'precision': 0.45031370042518165, 'recall': 0.3493512826922662, 'f1': 0.36764928383364354, 'auc': 0.7743560487776464, 'prauc': 0.441117882552632}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.79it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.05it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.28it/s]


Epoch: 006, Average Loss: 0.2639
Validation: {'precision': 0.4388911814880074, 'recall': 0.358670923975567, 'f1': 0.38595328241556714, 'auc': 0.7769591395381847, 'prauc': 0.4423220123728926}
Test:       {'precision': 0.43594794455586755, 'recall': 0.36044822508993096, 'f1': 0.3855879201037738, 'auc': 0.7763862576834772, 'prauc': 0.440979204242157}


Training Batches: 100%|██████████| 1713/1713 [00:51<00:00, 33.01it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.94it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.32it/s]


Epoch: 007, Average Loss: 0.2594
Validation: {'precision': 0.4696234665112612, 'recall': 0.3623720405931986, 'f1': 0.38298907866786647, 'auc': 0.7728471159610397, 'prauc': 0.44122370573943964}
Test:       {'precision': 0.47019650650110534, 'recall': 0.36748284464876707, 'f1': 0.38516613856462134, 'auc': 0.7744248414460191, 'prauc': 0.43736283876113213}

Early stopping triggered after 7 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.42819986661639253, 'recall': 0.36624136098564963, 'f1': 0.38783421971756954, 'auc': 0.7666437612659364, 'prauc': 0.4313735226178135}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.511574  0.382022  0.437407   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.849810  0.653031  0

Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.65it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.61it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.26it/s]


Epoch: 001, Average Loss: 0.3158
Validation: {'precision': 0.4112129497492685, 'recall': 0.3324053875175779, 'f1': 0.35053823902287246, 'auc': 0.7585971458186496, 'prauc': 0.4180176834605478}
Test:       {'precision': 0.4169671103559796, 'recall': 0.3371491553490232, 'f1': 0.35341875689898217, 'auc': 0.761313188650069, 'prauc': 0.4177442644741691}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.58it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.02it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.62it/s]


Epoch: 002, Average Loss: 0.2892
Validation: {'precision': 0.43714720597162204, 'recall': 0.32813738365809036, 'f1': 0.35193393617675733, 'auc': 0.7704175868687718, 'prauc': 0.4354088000938025}
Test:       {'precision': 0.44675214439098393, 'recall': 0.3318869477701953, 'f1': 0.35290904776936416, 'auc': 0.7696957257817099, 'prauc': 0.4332458944515671}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.66it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.92it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.34it/s]


Epoch: 003, Average Loss: 0.2807
Validation: {'precision': 0.43539003546311444, 'recall': 0.34769947426748454, 'f1': 0.36911277215475047, 'auc': 0.7718403663832539, 'prauc': 0.43722367224305836}
Test:       {'precision': 0.43307246222119716, 'recall': 0.35125037851262664, 'f1': 0.370021817262413, 'auc': 0.7726266839376292, 'prauc': 0.43627523095907916}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.43it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.62it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.90it/s]


Epoch: 004, Average Loss: 0.2737
Validation: {'precision': 0.43455232953788225, 'recall': 0.35994430055229265, 'f1': 0.38244913567092337, 'auc': 0.7688925784776597, 'prauc': 0.44021361434607537}
Test:       {'precision': 0.43245606087785826, 'recall': 0.36455976872616386, 'f1': 0.38455946752514225, 'auc': 0.7713566550140548, 'prauc': 0.43847749963207544}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.76it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 119.56it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.24it/s]


Epoch: 005, Average Loss: 0.2681
Validation: {'precision': 0.41876044412331653, 'recall': 0.37679607675705756, 'f1': 0.38674527936175007, 'auc': 0.7738855076426294, 'prauc': 0.43911684778208243}
Test:       {'precision': 0.413427626875844, 'recall': 0.38158012307160155, 'f1': 0.38693023175594504, 'auc': 0.7746674556973069, 'prauc': 0.4405654274392072}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.85it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.69it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.29it/s]


Epoch: 006, Average Loss: 0.2633
Validation: {'precision': 0.46034416397830646, 'recall': 0.3646183735853319, 'f1': 0.3886028613939419, 'auc': 0.7674620994062461, 'prauc': 0.43893420156984214}
Test:       {'precision': 0.4510523161736212, 'recall': 0.36778277373808643, 'f1': 0.3889103161499331, 'auc': 0.7674917381750559, 'prauc': 0.4408812129584067}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.85it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.30it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.02it/s]


Epoch: 007, Average Loss: 0.2578
Validation: {'precision': 0.4428041557618878, 'recall': 0.3878965242825346, 'f1': 0.39768098276061103, 'auc': 0.7739290164985486, 'prauc': 0.4397932064387262}
Test:       {'precision': 0.4518202434596318, 'recall': 0.3891797203397819, 'f1': 0.3974736572567653, 'auc': 0.7714610718346516, 'prauc': 0.4435170513532689}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.77it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.29it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.61it/s]


Epoch: 008, Average Loss: 0.2528
Validation: {'precision': 0.4773714941056496, 'recall': 0.37133898583984887, 'f1': 0.3917353711384394, 'auc': 0.77586355527302, 'prauc': 0.4416335087989209}
Test:       {'precision': 0.48213653052233013, 'recall': 0.37306300019373245, 'f1': 0.39189339405430523, 'auc': 0.7728643384894577, 'prauc': 0.44364028950999207}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.50it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.75it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.40it/s]


Epoch: 009, Average Loss: 0.2479
Validation: {'precision': 0.47645713932564543, 'recall': 0.3615401586654328, 'f1': 0.3842600965858904, 'auc': 0.7713513261154333, 'prauc': 0.44104919604506976}
Test:       {'precision': 0.46630978289140645, 'recall': 0.3630397072989, 'f1': 0.3825612842577016, 'auc': 0.7693505259186028, 'prauc': 0.4415691540347545}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.64it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.98it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.72it/s]


Epoch: 010, Average Loss: 0.2410
Validation: {'precision': 0.45109467101900985, 'recall': 0.36863985732216153, 'f1': 0.39313590466486675, 'auc': 0.771443619460419, 'prauc': 0.4333958080890319}
Test:       {'precision': 0.5077691710047811, 'recall': 0.37059482346969175, 'f1': 0.39445642723822133, 'auc': 0.7665525921957539, 'prauc': 0.4348167007180134}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.60it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.90it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.63it/s]


Epoch: 011, Average Loss: 0.2369
Validation: {'precision': 0.5047871606158677, 'recall': 0.36891576452261926, 'f1': 0.38469128542341635, 'auc': 0.771726921082505, 'prauc': 0.4333649396290859}
Test:       {'precision': 0.4780024625443555, 'recall': 0.36827214568541433, 'f1': 0.38252346621696415, 'auc': 0.769531999827666, 'prauc': 0.43638925021596575}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.87it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.44it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.01it/s]


Epoch: 012, Average Loss: 0.2317
Validation: {'precision': 0.4657023933359056, 'recall': 0.3545436951647182, 'f1': 0.38286355395732163, 'auc': 0.7664124488082781, 'prauc': 0.4290132099126463}
Test:       {'precision': 0.5093236796023191, 'recall': 0.3584194489930836, 'f1': 0.3854919050047022, 'auc': 0.7646718853386564, 'prauc': 0.4320590112140231}

Early stopping triggered after 12 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.4428041557618878, 'recall': 0.3878965242825346, 'f1': 0.39768098276061103, 'auc': 0.7739290164985486, 'prauc': 0.4397932064387262}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.492308  0.387208  0.433478   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.866469  0.639883  0.73613

Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.80it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.83it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.92it/s]


Epoch: 001, Average Loss: 0.3163
Validation: {'precision': 0.3622829113582403, 'recall': 0.3129755526308087, 'f1': 0.32388436884957666, 'auc': 0.7426339684358788, 'prauc': 0.40283532444244163}
Test:       {'precision': 0.3669178969925635, 'recall': 0.31840218450861385, 'f1': 0.32824673608307126, 'auc': 0.7423187837409636, 'prauc': 0.40422357677040716}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.56it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.92it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.02it/s]


Epoch: 002, Average Loss: 0.2906
Validation: {'precision': 0.4000559932095158, 'recall': 0.3296372308240556, 'f1': 0.3501158833253393, 'auc': 0.7594194126631517, 'prauc': 0.4294725278692123}
Test:       {'precision': 0.4470138719157435, 'recall': 0.3346270357246143, 'f1': 0.3535393050612773, 'auc': 0.7631416261257009, 'prauc': 0.42928660842669947}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.74it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.38it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.43it/s]


Epoch: 003, Average Loss: 0.2811
Validation: {'precision': 0.44763765943490647, 'recall': 0.335630529172556, 'f1': 0.36987538003822945, 'auc': 0.7646552415332992, 'prauc': 0.44120875011394745}
Test:       {'precision': 0.4452772183341616, 'recall': 0.3377366454923841, 'f1': 0.36958830696227796, 'auc': 0.7712418904296005, 'prauc': 0.43969736175156376}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.52it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.74it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.25it/s]


Epoch: 004, Average Loss: 0.2744
Validation: {'precision': 0.4622246261403171, 'recall': 0.3628953095116507, 'f1': 0.3780629213056899, 'auc': 0.7735588319680555, 'prauc': 0.4396550736954501}
Test:       {'precision': 0.42735976455391383, 'recall': 0.3615008143301534, 'f1': 0.37416521669485964, 'auc': 0.7747449225912246, 'prauc': 0.4375017660294964}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.56it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.86it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.04it/s]


Epoch: 005, Average Loss: 0.2695
Validation: {'precision': 0.4658589110520508, 'recall': 0.3595136520013279, 'f1': 0.3789580238338753, 'auc': 0.7779183713120218, 'prauc': 0.4445703002942493}
Test:       {'precision': 0.43206732293015476, 'recall': 0.36036553865866505, 'f1': 0.37661632413687, 'auc': 0.7749410956942393, 'prauc': 0.44372925169490773}


Training Batches: 100%|██████████| 1713/1713 [00:51<00:00, 33.01it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.06it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.28it/s]


Epoch: 006, Average Loss: 0.2635
Validation: {'precision': 0.49337218368586705, 'recall': 0.361098593080772, 'f1': 0.38473264230898657, 'auc': 0.7765675158158427, 'prauc': 0.4444568425720732}
Test:       {'precision': 0.48792160527541206, 'recall': 0.36300685312760916, 'f1': 0.38333726767103254, 'auc': 0.7711151981444616, 'prauc': 0.4395774467271672}


Training Batches: 100%|██████████| 1713/1713 [00:51<00:00, 33.00it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.97it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.77it/s]


Epoch: 007, Average Loss: 0.2590
Validation: {'precision': 0.45515986749095305, 'recall': 0.3585398284880762, 'f1': 0.3788655350038396, 'auc': 0.7704376672776323, 'prauc': 0.44118907719528644}
Test:       {'precision': 0.4229451539100494, 'recall': 0.3611247698095854, 'f1': 0.3771458518155178, 'auc': 0.7730594514842472, 'prauc': 0.4389914458029416}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.78it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.12it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.33it/s]


Epoch: 008, Average Loss: 0.2529
Validation: {'precision': 0.4665947933825416, 'recall': 0.3425454565616777, 'f1': 0.37199297402532233, 'auc': 0.7715065332952585, 'prauc': 0.4402514432315996}
Test:       {'precision': 0.44544916154083875, 'recall': 0.3433882557548348, 'f1': 0.36786806861490706, 'auc': 0.770798325274499, 'prauc': 0.4383742957489899}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.58it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.42it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.05it/s]


Epoch: 009, Average Loss: 0.2480
Validation: {'precision': 0.46940400629752727, 'recall': 0.35897739544047175, 'f1': 0.38747768395834253, 'auc': 0.7692075370950557, 'prauc': 0.43959676368149175}
Test:       {'precision': 0.45487611882301615, 'recall': 0.3584753284527331, 'f1': 0.38361337537664764, 'auc': 0.7683939343190067, 'prauc': 0.43650614285415057}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.36it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.04it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.68it/s]


Epoch: 010, Average Loss: 0.2433
Validation: {'precision': 0.4592646560293272, 'recall': 0.35684472283164087, 'f1': 0.37954664057974924, 'auc': 0.771158860803776, 'prauc': 0.43714512169175584}
Test:       {'precision': 0.4515307939684469, 'recall': 0.3594900942688167, 'f1': 0.3797606539077049, 'auc': 0.7710523885073095, 'prauc': 0.4374560286963859}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.70it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.63it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.46it/s]


Epoch: 011, Average Loss: 0.2368
Validation: {'precision': 0.49506278115437297, 'recall': 0.36884324256179607, 'f1': 0.39064945269990453, 'auc': 0.7724499725910179, 'prauc': 0.43680541884414054}
Test:       {'precision': 0.47358622103568926, 'recall': 0.3735946176912273, 'f1': 0.3925895656061502, 'auc': 0.7688608012989481, 'prauc': 0.4345109254991718}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.91it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 124.20it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.90it/s]


Epoch: 012, Average Loss: 0.2313
Validation: {'precision': 0.4825427090438114, 'recall': 0.3715318267493917, 'f1': 0.39554646408644323, 'auc': 0.7702136404535205, 'prauc': 0.4345965012708435}
Test:       {'precision': 0.5108902315544129, 'recall': 0.3693705167318601, 'f1': 0.3902104846246799, 'auc': 0.7642816995391571, 'prauc': 0.433303154229437}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.86it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.70it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.55it/s]


Epoch: 013, Average Loss: 0.2263
Validation: {'precision': 0.4886488179612056, 'recall': 0.36807495968710885, 'f1': 0.3909092468912237, 'auc': 0.7692263690070008, 'prauc': 0.43515584283077174}
Test:       {'precision': 0.4750273524807457, 'recall': 0.37314652707747453, 'f1': 0.3937952462949941, 'auc': 0.7657450724099396, 'prauc': 0.43346654288410413}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.73it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.73it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.85it/s]


Epoch: 014, Average Loss: 0.2208
Validation: {'precision': 0.5074857287635477, 'recall': 0.3826972635567656, 'f1': 0.39772332276786504, 'auc': 0.7677922145504161, 'prauc': 0.4297934908531442}
Test:       {'precision': 0.5206063537551502, 'recall': 0.38605959295926856, 'f1': 0.4000128265434976, 'auc': 0.7633459793988345, 'prauc': 0.43218652611607405}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.71it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.02it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.75it/s]


Epoch: 015, Average Loss: 0.2163
Validation: {'precision': 0.48747731575082526, 'recall': 0.3775971865833518, 'f1': 0.392305989654367, 'auc': 0.7652556390691833, 'prauc': 0.4302660532612161}
Test:       {'precision': 0.48066893843851105, 'recall': 0.3824107684937445, 'f1': 0.39483656895274455, 'auc': 0.7598546577301541, 'prauc': 0.426575587781434}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.73it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.88it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.67it/s]


Epoch: 016, Average Loss: 0.2099
Validation: {'precision': 0.49023143767025534, 'recall': 0.3862232997327829, 'f1': 0.3984212135992766, 'auc': 0.7629316674464447, 'prauc': 0.42594608443068904}
Test:       {'precision': 0.49773984386661, 'recall': 0.39095485057358786, 'f1': 0.4012831171943006, 'auc': 0.7565799363805041, 'prauc': 0.42394270943227447}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.83it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.53it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.69it/s]


Epoch: 017, Average Loss: 0.2064
Validation: {'precision': 0.4687091543745383, 'recall': 0.3719271087221923, 'f1': 0.3942148590509889, 'auc': 0.7661881160776294, 'prauc': 0.4261809598398301}
Test:       {'precision': 0.49547428830811335, 'recall': 0.36979750908270304, 'f1': 0.3897483534896254, 'auc': 0.7628052735040618, 'prauc': 0.4256670365381677}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.72it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.88it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.83it/s]


Epoch: 018, Average Loss: 0.2018
Validation: {'precision': 0.5162877702677572, 'recall': 0.36346889673324534, 'f1': 0.38912709961727804, 'auc': 0.765111440149161, 'prauc': 0.4267733799233129}
Test:       {'precision': 0.5153824411160148, 'recall': 0.362397531946563, 'f1': 0.3842434321628646, 'auc': 0.7601922802187833, 'prauc': 0.4272038708670924}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.64it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.07it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 120.02it/s]


Epoch: 019, Average Loss: 0.1980
Validation: {'precision': 0.4918959658681314, 'recall': 0.3775991594139352, 'f1': 0.39503992502045226, 'auc': 0.7577736463618705, 'prauc': 0.42315465396473245}
Test:       {'precision': 0.52499260013022, 'recall': 0.3796287516457493, 'f1': 0.3987049855962195, 'auc': 0.7592465354158362, 'prauc': 0.424016338573005}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.54it/s]
Running inference:  87%|████████▋ | 1092/1262 [00:08<00:01, 110.31it/s]

In [None]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # 计算排名（值越大排名越靠前）
    ranks = {m: (-scores[m]).argsort().argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # 选 top-k
    topk_idx = np.argsort(avg_ranks)[:k]
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in performances[0].keys()}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in performances[0].keys()}
    final_long_seq_avg = {m: np.mean([long_seq_performances[i][m] for i in topk_idx]) for m in long_seq_performances[0].keys()}
    final_long_seq_std = {m: np.std([long_seq_performances[i][m] for i in topk_idx], ddof=0) for m in long_seq_performances[0].keys()}

    # 打印结果（转百分比，均保留两位小数）
    print("Final Metrics:")
    for m in performances[0].keys():
        mean_val = final_avg[m] * 100
        std_val = final_std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")
    print("\nFinal Long Sequence Metrics:")
    for m in long_seq_performances[0].keys():
        mean_val = final_long_seq_avg[m] * 100
        std_val = final_long_seq_std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
def print_per_class_performance(dfs, col_name="prauc"):
    """
    输入一个 DataFrame 列表，对每个疾病在所有表格的指定列计算 mean ± std 并打印。

    参数:
        dfs (list[pd.DataFrame]): 多个表格组成的列表
        col_name (str): 要计算的指标列名 (默认: "prauc")
    """
    # 拼接所有表格
    all_values = pd.concat(dfs, axis=0)

    # 按疾病分组，计算 mean 和 std
    grouped = all_values.groupby(all_values.index)[col_name].agg(["mean", "std"])

    # 打印
    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

In [None]:
if task_type == "binary":
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)
else:
    final_metrics_global = [metrics["global"] for metrics in final_metrics]
    final_metrics_per_class = [metrics["per_class"] for metrics in final_metrics]
    final_long_seq_metrics_global = [metrics["global"] for metrics in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [metrics["per_class"] for metrics in final_long_seq_metrics]
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")