In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support
import pandas as pd

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
PHENO_ORDER = [
    "Acute and unspecified renal failure",
    "Acute cerebrovascular disease",
    "Acute myocardial infarction",
    "Cardiac dysrhythmias",
    "Chronic kidney disease",
    "Chronic obstructive pulmonary disease",
    "Conduction disorders",
    "Congestive heart failure; nonhypertensive",
    "Coronary atherosclerosis and related",
    "Disorders of lipid metabolism",
    "Essential hypertension",
    "Fluid and electrolyte disorders",
    "Gastrointestinal hemorrhage",
    "Hypertension with complications",
    "Other liver diseases",
    "Other lower respiratory disease",
    "Pneumonia",
    "Septicemia (except in labor)",
]

In [4]:
@torch.no_grad()
def evaluate(model, dataloader, device, long_seq_idx=None, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []

    # 推理：收集 logits 与 labels
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)

    if task_type == "binary":
        # —— 标准二分类评估 —— #
        logits_all = torch.cat(predicted_scores, dim=0).view(-1)          # logits [N]
        labels_all = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()    # y_true [N]
        scores_all = logits_all.cpu().numpy()                              # 连续分数（logits）
        ypred_all  = (logits_all > 0).float().cpu().numpy()                # logits > 0

        tp = (ypred_all * labels_all).sum()
        precision = tp / (ypred_all.sum() + 1e-8)
        recall    = tp / (labels_all.sum() + 1e-8)
        f1        = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc   = roc_auc_score(labels_all, scores_all)
        prec_curve, rec_curve, _ = precision_recall_curve(labels_all, scores_all)
        pr_auc    = auc(rec_curve, prec_curve)

        all_performance = {"precision": float(precision),
                           "recall": float(recall),
                           "f1": float(f1),
                           "auc": float(roc_auc),
                           "prauc": float(pr_auc)}

        subset_performance = None
        if long_seq_idx is not None:
            idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
            logits_sub = logits_all.index_select(0, idx).view(-1)
            labels_sub = torch.as_tensor(labels_all, device=logits_all.device)[idx].cpu().numpy()
            scores_sub = logits_sub.cpu().numpy()
            ypred_sub  = (logits_sub > 0).float().cpu().numpy()

            tp = (ypred_sub * labels_sub).sum()
            precision = tp / (ypred_sub.sum() + 1e-8)
            recall    = tp / (labels_sub.sum() + 1e-8)
            f1        = 2 * precision * recall / (precision + recall + 1e-8)
            roc_auc   = roc_auc_score(labels_sub, scores_sub)
            prec_curve, rec_curve, _ = precision_recall_curve(labels_sub, scores_sub)
            pr_auc    = auc(rec_curve, prec_curve)

            subset_performance = {"precision": float(precision),
                                  "recall": float(recall),
                                  "f1": float(f1),
                                  "auc": float(roc_auc),
                                  "prauc": float(pr_auc)}

        return all_performance, subset_performance

    else:
        # —— Multi-label evaluation（按类聚合） —— #
        logits_all = torch.cat(predicted_scores, dim=0)    # [B, C]
        labels_all_t = torch.cat(gt_labels, dim=0)         # [B, C]

        def _compute_metrics(logits_sub, labels_sub):
            # 连续分数（概率）：sigmoid(logits)，CPU + fp16 先升为 fp32
            if logits_sub.device.type == "cpu" and logits_sub.dtype == torch.float16:
                prob_t = torch.sigmoid(logits_sub.float())
            else:
                prob_t = torch.sigmoid(logits_sub)
            # 二值化：logits > 0
            ypred_t = (logits_sub > 0).to(torch.int32)

            y_true = labels_sub.cpu().numpy().astype(np.int32)       # [N, C]
            y_pred = ypred_t.cpu().numpy().astype(np.int32)          # [N, C]
            scores = prob_t.cpu().numpy()                             # [N, C]

            # per-class P/R/F1
            p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
                y_true, y_pred, average=None, zero_division=0
            )

            # per-class AUC / PR-AUC
            C = y_true.shape[1]
            aucs, praucs = [], []
            for c in range(C):
                yt, ys = y_true[:, c], scores[:, c]
                if yt.max() == yt.min():
                    aucs.append(np.nan)
                    praucs.append(np.nan)
                else:
                    aucs.append(roc_auc_score(yt, ys))
                    prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
                    praucs.append(auc(rec_curve, prec_curve))

            # 宏平均（忽略 NaN）
            summary = {
                "precision": float(np.mean(p_cls)),
                "recall":    float(np.mean(r_cls)),
                "f1":        float(np.mean(f1_cls)),
                "auc":       float(np.nanmean(aucs)) if np.any(~np.isnan(aucs)) else float("nan"),
                "prauc":     float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else float("nan"),
            }

            per_class_df = pd.DataFrame({
                "precision": p_cls,
                "recall":    r_cls,
                "f1":        f1_cls,
                "auc":       aucs,
                "prauc":     praucs,
            }, index=PHENO_ORDER)   # 确保 PHENO_ORDER 已定义且长度=C

            return {"global": summary, "per_class": per_class_df}

        # 全量
        all_performance = _compute_metrics(logits_all, labels_all_t)

        # 子集
        subset_performance = None
        if long_seq_idx is not None:
            idx = torch.as_tensor(long_seq_idx, device=logits_all.device, dtype=torch.long)
            subset_performance = _compute_metrics(
                logits_all.index_select(0, idx),
                labels_all_t.index_select(0, idx)
            )

        return all_performance, subset_performance

In [5]:
args = {
    "seed": 0,
    "dataset": "MIMIC-III", 
    "task": "next_diag_12m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 500,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [6]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-III-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [7]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [8]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [9]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [10]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [11]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [23]:
def print_pheno_percentages(df: pd.DataFrame):
    # 选出所有以 "PHENO_" 开头的列
    pheno_cols = [c for c in df.columns if c.startswith("PHENO_")]
    
    for col in pheno_cols:
        # 计算该列中 1 的百分比
        percentage = df[col].mean() * 100
        print(f"{col}: {percentage:.1f}%")

In [24]:
print_pheno_percentages(test_data)

PHENO_Acute and unspecified renal failure: 16.0%
PHENO_Acute cerebrovascular disease: 0.9%
PHENO_Acute myocardial infarction: 3.7%
PHENO_Cardiac dysrhythmias: 20.1%
PHENO_Chronic kidney disease: 12.4%
PHENO_Chronic obstructive pulmonary disease: 6.4%
PHENO_Conduction disorders: 1.4%
PHENO_Congestive heart failure; nonhypertensive: 20.1%
PHENO_Coronary atherosclerosis and related: 12.1%
PHENO_Disorders of lipid metabolism: 13.7%
PHENO_Essential hypertension: 18.9%
PHENO_Fluid and electrolyte disorders: 21.0%
PHENO_Gastrointestinal hemorrhage: 3.6%
PHENO_Hypertension with complications: 11.5%
PHENO_Other liver diseases: 0.9%
PHENO_Other lower respiratory disease: 21.4%
PHENO_Pneumonia: 7.1%
PHENO_Septicemia (except in labor): 11.7%


In [12]:
long_adm_seq_crite = 3
val_long_seq_idx, test_long_seq_idx = [], []
for i in range(len(val_dataset)):
    hadm_id = list(val_dataset.records.keys())[i]
    num_adms = len(val_dataset.records[hadm_id])
    if num_adms >= long_adm_seq_crite:
        val_long_seq_idx.append(i)
for i in range(len(test_dataset)):
    hadm_id = list(test_dataset.records.keys())[i]
    num_adms = len(test_dataset.records[hadm_id])
    if num_adms >= long_adm_seq_crite:
        test_long_seq_idx.append(i)
print(len(val_long_seq_idx), len(test_long_seq_idx))

208 186


In [13]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([12, 73])
input_types shape: torch.Size([12, 73])
visit_positions shape: torch.Size([12])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [14]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [15]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [16]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              val_long_seq_idx = None,
                              test_long_seq_idx = None,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    best_test_long_seq_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric, val_long_seq_metric = evaluate(model, val_dataloader, device, task_type=task_type, long_seq_idx=val_long_seq_idx)
        test_metric, test_long_seq_metric = evaluate(model, test_dataloader, device, task_type=task_type, long_seq_idx=test_long_seq_idx)

        if task_type != "binary":
            val_per_class_df = val_metric["per_class"]
            val_metric = val_metric["global"]
            test_per_class_df = test_metric["per_class"]
            test_metric = test_metric["global"]
            
            if val_long_seq_idx != None:
                val_long_seq_per_class_df = val_long_seq_metric["per_class"]
                val_long_seq_metric = val_long_seq_metric["global"]
            if test_long_seq_idx != None:
                test_long_seq_per_class_df = test_long_seq_metric["per_class"]
                test_long_seq_metric = test_long_seq_metric["global"]

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric if task_type == "binary" else {"global": val_metric, "per_class": val_per_class_df}
            best_test_metric = test_metric if task_type == "binary" else {"global": test_metric, "per_class": test_per_class_df}
            best_test_long_seq_metric = test_long_seq_metric if task_type == "binary" else {"global": test_long_seq_metric, "per_class": test_long_seq_per_class_df}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    if best_test_long_seq_metric is not None:
        print("Corresponding test-long performance:")
        print(best_test_long_seq_metric)
    return best_test_metric, best_test_long_seq_metric

In [17]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [18]:
final_metrics, final_long_seq_metrics = [], []

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        val_long_seq_idx,
        test_long_seq_idx,
        task_type=task_type)
    
    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 24.98it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.10it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.24it/s]


Epoch: 001, Average Loss: 0.4276
Validation: {'precision': 0.35283331460943335, 'recall': 0.2584575353538129, 'f1': 0.27769093168901726, 'auc': 0.7016175242771628, 'prauc': 0.3963236935141467}
Test:       {'precision': 0.35117930825242893, 'recall': 0.2524380800991009, 'f1': 0.26997903204143586, 'auc': 0.7084981210996388, 'prauc': 0.39248710286939276}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.38it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.56it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.43it/s] 


Epoch: 002, Average Loss: 0.3877
Validation: {'precision': 0.3713067402634899, 'recall': 0.2820737753525924, 'f1': 0.3122772005509067, 'auc': 0.7178276959958872, 'prauc': 0.422563155070022}
Test:       {'precision': 0.3627016199612071, 'recall': 0.27231453316457344, 'f1': 0.3011248598949525, 'auc': 0.7255587898930295, 'prauc': 0.4095277861556732}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.42it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.84it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.18it/s] 


Epoch: 003, Average Loss: 0.3691
Validation: {'precision': 0.40613581561636225, 'recall': 0.3427372300552201, 'f1': 0.35480242171119225, 'auc': 0.7289993410554786, 'prauc': 0.44195496013435537}
Test:       {'precision': 0.4074749717972427, 'recall': 0.3368446250030969, 'f1': 0.3494997125316862, 'auc': 0.7409954732242249, 'prauc': 0.425853810196908}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.77it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.61it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.76it/s] 


Epoch: 004, Average Loss: 0.3556
Validation: {'precision': 0.43435847043703, 'recall': 0.33129854863297703, 'f1': 0.35631145643866646, 'auc': 0.7406178835039119, 'prauc': 0.4482600728054269}
Test:       {'precision': 0.4096783641181909, 'recall': 0.314653212442428, 'f1': 0.33843057526024906, 'auc': 0.7451690613130327, 'prauc': 0.4283417673771827}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.00it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.89it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.02it/s] 


Epoch: 005, Average Loss: 0.3447
Validation: {'precision': 0.4291526679455893, 'recall': 0.35863369832612324, 'f1': 0.37679349890837477, 'auc': 0.7448661986686509, 'prauc': 0.45324593886509607}
Test:       {'precision': 0.4226632595890296, 'recall': 0.3400908106352614, 'f1': 0.36073902421177484, 'auc': 0.743918205661721, 'prauc': 0.4303253037649466}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.43it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.58it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.88it/s]


Epoch: 006, Average Loss: 0.3349
Validation: {'precision': 0.4463198649839074, 'recall': 0.33512185474783096, 'f1': 0.3667413201093372, 'auc': 0.7428290728277825, 'prauc': 0.45002160900315896}
Test:       {'precision': 0.4574578720092412, 'recall': 0.32382885830864516, 'f1': 0.3529637975272611, 'auc': 0.7442609313683857, 'prauc': 0.4327282401870567}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.33it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.94it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.32it/s] 


Epoch: 007, Average Loss: 0.3261
Validation: {'precision': 0.4295589156107668, 'recall': 0.38502777553325296, 'f1': 0.3951907240571486, 'auc': 0.735417474149385, 'prauc': 0.44811642666806784}
Test:       {'precision': 0.42256583678813153, 'recall': 0.3806148979771963, 'f1': 0.3902797603921597, 'auc': 0.7404241284718522, 'prauc': 0.43510523656218414}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.38it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.68it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.09it/s]


Epoch: 008, Average Loss: 0.3161
Validation: {'precision': 0.4501533574565022, 'recall': 0.3729178686265433, 'f1': 0.39768141684956954, 'auc': 0.7335108182513613, 'prauc': 0.4440768516694446}
Test:       {'precision': 0.42580984886876144, 'recall': 0.35870740633409975, 'f1': 0.37951712429211093, 'auc': 0.7340517266813751, 'prauc': 0.42933652716632964}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.49it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.09it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.08it/s]


Epoch: 009, Average Loss: 0.3080
Validation: {'precision': 0.44500540976066183, 'recall': 0.35708581286445923, 'f1': 0.38694590189271805, 'auc': 0.737043832203674, 'prauc': 0.44553154449816673}
Test:       {'precision': 0.42875878315434435, 'recall': 0.3427918305165544, 'f1': 0.3710046290022618, 'auc': 0.740002702226083, 'prauc': 0.42711286505206586}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.25it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.75it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.91it/s] 


Epoch: 010, Average Loss: 0.2999
Validation: {'precision': 0.43285358654689493, 'recall': 0.38622145105137484, 'f1': 0.3885949253335661, 'auc': 0.735021772643178, 'prauc': 0.44382143990240586}
Test:       {'precision': 0.40223502191268573, 'recall': 0.37125183772334225, 'f1': 0.3728708640155178, 'auc': 0.7373873690314774, 'prauc': 0.4235727390659972}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.26it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.81it/s]


Epoch: 011, Average Loss: 0.2946
Validation: {'precision': 0.4141882669218621, 'recall': 0.3951242236737322, 'f1': 0.3945253874777283, 'auc': 0.7293607389086356, 'prauc': 0.43940072783989387}
Test:       {'precision': 0.4055104912979682, 'recall': 0.3806497177305827, 'f1': 0.38186983745602615, 'auc': 0.73046574988029, 'prauc': 0.4204540324071912}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.20it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.21it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.09it/s]


Epoch: 012, Average Loss: 0.2850
Validation: {'precision': 0.5347113029722004, 'recall': 0.3831398906540611, 'f1': 0.3921309916151927, 'auc': 0.7278228572767431, 'prauc': 0.43846744929011544}
Test:       {'precision': 0.4019331012922792, 'recall': 0.3640476126079648, 'f1': 0.3729076372351234, 'auc': 0.7299949502017381, 'prauc': 0.41620238782388924}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.31it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.15it/s]


Epoch: 013, Average Loss: 0.2752
Validation: {'precision': 0.5168367974001359, 'recall': 0.39722857675433243, 'f1': 0.39626482250772854, 'auc': 0.721910086662563, 'prauc': 0.4295384692121841}
Test:       {'precision': 0.3923388783384754, 'recall': 0.3806079373359318, 'f1': 0.3774439727867864, 'auc': 0.7232828196299944, 'prauc': 0.4135228479960465}

Early stopping triggered after 13 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.4501533574565022, 'recall': 0.3729178686265433, 'f1': 0.39768141684956954, 'auc': 0.7335108182513613, 'prauc': 0.4440768516694446}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.560117  0.440092  0.492903   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.266667  0.084211  0.128000   
Cardiac dysrhythmias                        0.862745  0.642336  0.73640

Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.24it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.99it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.92it/s]


Epoch: 001, Average Loss: 0.4313
Validation: {'precision': 0.3527923758456731, 'recall': 0.2906151198386232, 'f1': 0.3075321380533551, 'auc': 0.6959050887241458, 'prauc': 0.4075612197136186}
Test:       {'precision': 0.34958033243190967, 'recall': 0.28125263967848974, 'f1': 0.30057796532469433, 'auc': 0.7136820295884212, 'prauc': 0.40486797797439933}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.42it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.09it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.37it/s]


Epoch: 002, Average Loss: 0.3847
Validation: {'precision': 0.4368228705146864, 'recall': 0.29198684900389915, 'f1': 0.31453986678612744, 'auc': 0.7160751258690667, 'prauc': 0.4204411882390817}
Test:       {'precision': 0.37656149666757244, 'recall': 0.2833043041539991, 'f1': 0.3048505980974714, 'auc': 0.7254948482001491, 'prauc': 0.41740899773763496}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.56it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.76it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.35it/s]


Epoch: 003, Average Loss: 0.3683
Validation: {'precision': 0.46213068131812274, 'recall': 0.3014795509899517, 'f1': 0.32660463505394005, 'auc': 0.7308118991421461, 'prauc': 0.4327653669494707}
Test:       {'precision': 0.388098905872581, 'recall': 0.29153073991072714, 'f1': 0.3149484005769832, 'auc': 0.7386383834780192, 'prauc': 0.42335767292320825}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.35it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.65it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.87it/s] 


Epoch: 004, Average Loss: 0.3527
Validation: {'precision': 0.43917853922796407, 'recall': 0.3763373680005294, 'f1': 0.38381055390763547, 'auc': 0.7412350522982578, 'prauc': 0.4510014244163691}
Test:       {'precision': 0.4063141187260534, 'recall': 0.3681185139743755, 'f1': 0.3709427109146162, 'auc': 0.7448256922962228, 'prauc': 0.4393636891467099}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.39it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.47it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.07it/s]


Epoch: 005, Average Loss: 0.3417
Validation: {'precision': 0.4316483546886791, 'recall': 0.3833156232227838, 'f1': 0.3904652029838221, 'auc': 0.7404784799670785, 'prauc': 0.4465953439839365}
Test:       {'precision': 0.41001775830935155, 'recall': 0.3742558393797996, 'f1': 0.3797825252937336, 'auc': 0.7458800766634703, 'prauc': 0.43640721329948523}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.39it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.38it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.24it/s] 


Epoch: 006, Average Loss: 0.3327
Validation: {'precision': 0.4990362242028521, 'recall': 0.37515350090490274, 'f1': 0.3924507266856811, 'auc': 0.7401250657457036, 'prauc': 0.4503079758718795}
Test:       {'precision': 0.41709491150377154, 'recall': 0.35552504193160206, 'f1': 0.371285864944815, 'auc': 0.7402049956694539, 'prauc': 0.4311242105600782}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.51it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.65it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.31it/s] 


Epoch: 007, Average Loss: 0.3224
Validation: {'precision': 0.4111451522438766, 'recall': 0.36485413654483045, 'f1': 0.3745988335263572, 'auc': 0.7348138171296374, 'prauc': 0.44665065020869693}
Test:       {'precision': 0.4442601027949291, 'recall': 0.35161809912004743, 'f1': 0.36000665098585705, 'auc': 0.7358868340131326, 'prauc': 0.42758407265948695}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.77it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.32it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.81it/s]


Epoch: 008, Average Loss: 0.3159
Validation: {'precision': 0.4491332574542038, 'recall': 0.374560140953509, 'f1': 0.39694120296197033, 'auc': 0.7364762745414934, 'prauc': 0.4459942948397652}
Test:       {'precision': 0.4074865261496664, 'recall': 0.35674878281418126, 'f1': 0.37496995579518516, 'auc': 0.732071324123794, 'prauc': 0.42648276751316067}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.50it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.33it/s] 


Epoch: 009, Average Loss: 0.3043
Validation: {'precision': 0.48208312323895597, 'recall': 0.3932115061387446, 'f1': 0.4003905251025906, 'auc': 0.7310301401702947, 'prauc': 0.44759501211713165}
Test:       {'precision': 0.41668218697853454, 'recall': 0.3722617979333396, 'f1': 0.3815294201146276, 'auc': 0.7278274884338762, 'prauc': 0.42963552792109494}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.15it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.73it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.46it/s] 


Epoch: 010, Average Loss: 0.2981
Validation: {'precision': 0.42849427104011095, 'recall': 0.35521368468257164, 'f1': 0.3771958754907863, 'auc': 0.7327182553130225, 'prauc': 0.44498976166403953}
Test:       {'precision': 0.41300967712133496, 'recall': 0.33712887761174254, 'f1': 0.3572027568880968, 'auc': 0.7285917670686826, 'prauc': 0.42680639861938185}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.33it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.98it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.98it/s] 


Epoch: 011, Average Loss: 0.2880
Validation: {'precision': 0.44579295131708185, 'recall': 0.3798834449607793, 'f1': 0.397496181749444, 'auc': 0.7302739204387911, 'prauc': 0.44171176224991826}
Test:       {'precision': 0.40575394089213634, 'recall': 0.36009545436617213, 'f1': 0.3739237078525887, 'auc': 0.725480555105881, 'prauc': 0.42338175065964356}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.42it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.91it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.68it/s] 


Epoch: 012, Average Loss: 0.2784
Validation: {'precision': 0.4652942727388163, 'recall': 0.3831323686253028, 'f1': 0.3968162711750254, 'auc': 0.7325570592382431, 'prauc': 0.44509534700228204}
Test:       {'precision': 0.38886008302127817, 'recall': 0.3529122475326481, 'f1': 0.3654125392772649, 'auc': 0.7268563876736048, 'prauc': 0.42170449503004115}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.12it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.05it/s] 


Epoch: 013, Average Loss: 0.2711
Validation: {'precision': 0.439082409136874, 'recall': 0.38484405427386736, 'f1': 0.3954375033856879, 'auc': 0.7260887819248558, 'prauc': 0.4315377735283706}
Test:       {'precision': 0.4182204014410459, 'recall': 0.370689114282654, 'f1': 0.37850705692004, 'auc': 0.7231582324634611, 'prauc': 0.41391266114275305}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.31it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.67it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 93.06it/s] 


Epoch: 014, Average Loss: 0.2617
Validation: {'precision': 0.4391125942337515, 'recall': 0.3845362172738627, 'f1': 0.39725518138162585, 'auc': 0.7272624654760629, 'prauc': 0.43899482386198563}
Test:       {'precision': 0.41475583690361983, 'recall': 0.3659761829802653, 'f1': 0.37741856915646604, 'auc': 0.7210751310787922, 'prauc': 0.41789569903730395}

Early stopping triggered after 14 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.48208312323895597, 'recall': 0.3932115061387446, 'f1': 0.4003905251025906, 'auc': 0.7310301401702947, 'prauc': 0.44759501211713165}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.471366  0.493088  0.481982   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.166667  0.010526  0.019802   
Cardiac dysrhythmias                        0.751938  0.708029  0.

Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.70it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.76it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.71it/s] 


Epoch: 001, Average Loss: 0.4271
Validation: {'precision': 0.3625327573982084, 'recall': 0.284505238344102, 'f1': 0.3021751107622203, 'auc': 0.6998851994171678, 'prauc': 0.3983104922261744}
Test:       {'precision': 0.36616973050561, 'recall': 0.28366178273649983, 'f1': 0.3005113558967635, 'auc': 0.7115722518988791, 'prauc': 0.3954610674507564}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.28it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.39it/s]


Epoch: 002, Average Loss: 0.3843
Validation: {'precision': 0.41877379632097944, 'recall': 0.3187122766259973, 'f1': 0.3468533324401018, 'auc': 0.7297553782748654, 'prauc': 0.43451849875983184}
Test:       {'precision': 0.41486913636489664, 'recall': 0.31427783327383274, 'f1': 0.34332468423731594, 'auc': 0.7393337121025699, 'prauc': 0.43387689019822745}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.41it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.77it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.81it/s]


Epoch: 003, Average Loss: 0.3640
Validation: {'precision': 0.47761235277587766, 'recall': 0.3207791267682186, 'f1': 0.35363727231868697, 'auc': 0.7436979539980871, 'prauc': 0.44961881855372204}
Test:       {'precision': 0.4038083628538122, 'recall': 0.3098232188929653, 'f1': 0.34050978345881483, 'auc': 0.7488417450002632, 'prauc': 0.4400521469300923}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.57it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.34it/s] 


Epoch: 004, Average Loss: 0.3505
Validation: {'precision': 0.4237231579568129, 'recall': 0.4019480743372299, 'f1': 0.39232690914749085, 'auc': 0.7345427553921859, 'prauc': 0.44327066620228994}
Test:       {'precision': 0.40968298910202855, 'recall': 0.3904579119670958, 'f1': 0.38261617559980043, 'auc': 0.734955407686861, 'prauc': 0.43584652586460804}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.14it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.59it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.76it/s]


Epoch: 005, Average Loss: 0.3423
Validation: {'precision': 0.4434621322700853, 'recall': 0.3514821477091499, 'f1': 0.37354183360133364, 'auc': 0.734667278135289, 'prauc': 0.44619915969073154}
Test:       {'precision': 0.42042730419140606, 'recall': 0.3389373473509545, 'f1': 0.35879210636687286, 'auc': 0.7403057563070494, 'prauc': 0.43780343470146876}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.37it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.56it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.31it/s] 


Epoch: 006, Average Loss: 0.3330
Validation: {'precision': 0.4100343842578722, 'recall': 0.3731236909708342, 'f1': 0.3832469876014285, 'auc': 0.7406160666982861, 'prauc': 0.44949454790278043}
Test:       {'precision': 0.398805521164234, 'recall': 0.3626059691957594, 'f1': 0.3738286627657168, 'auc': 0.7426531804155295, 'prauc': 0.43853396441981096}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.49it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.44it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.80it/s]


Epoch: 007, Average Loss: 0.3239
Validation: {'precision': 0.44004339068522946, 'recall': 0.38474277183190353, 'f1': 0.3917064355008425, 'auc': 0.7415340988555795, 'prauc': 0.45110466142791616}
Test:       {'precision': 0.46291059508518273, 'recall': 0.3719648913350455, 'f1': 0.3800582159091297, 'auc': 0.737929374751623, 'prauc': 0.43657267668299005}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.67it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.01it/s]


Epoch: 008, Average Loss: 0.3133
Validation: {'precision': 0.438129121534078, 'recall': 0.34949259017991713, 'f1': 0.3668822485015434, 'auc': 0.7344028791229864, 'prauc': 0.44304283808705747}
Test:       {'precision': 0.4246632632169436, 'recall': 0.33897963068654174, 'f1': 0.3590060786367155, 'auc': 0.7329502462795592, 'prauc': 0.4319928196037831}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.38it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.63it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.66it/s]


Epoch: 009, Average Loss: 0.3069
Validation: {'precision': 0.4455699407929397, 'recall': 0.35078987475830314, 'f1': 0.37643661360937103, 'auc': 0.728715702058532, 'prauc': 0.44110004938380987}
Test:       {'precision': 0.43102256980638576, 'recall': 0.3312341320052432, 'f1': 0.35993453487121413, 'auc': 0.7250317998569055, 'prauc': 0.4339783935712716}

Early stopping triggered after 9 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.4237231579568129, 'recall': 0.4019480743372299, 'f1': 0.39232690914749085, 'auc': 0.7345427553921859, 'prauc': 0.44327066620228994}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.479853  0.603687  0.534694   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.000000  0.000000  0.000000   
Cardiac dysrhythmias                        0.850467  0.664234  0.74

Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.56it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.57it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.42it/s] 


Epoch: 001, Average Loss: 0.4288
Validation: {'precision': 0.3552461591265729, 'recall': 0.2538480023010162, 'f1': 0.27889116888739723, 'auc': 0.7173387082915408, 'prauc': 0.4114367816019394}
Test:       {'precision': 0.3516647873178465, 'recall': 0.24531353875895986, 'f1': 0.27052310548353314, 'auc': 0.7190754645518006, 'prauc': 0.40734114532051924}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.53it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.73it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.83it/s] 


Epoch: 002, Average Loss: 0.3851
Validation: {'precision': 0.42101465017867923, 'recall': 0.34763273378481946, 'f1': 0.3467305206666961, 'auc': 0.7256066722296864, 'prauc': 0.4294530814781678}
Test:       {'precision': 0.3681724453502756, 'recall': 0.3442708656449879, 'f1': 0.3452232033586648, 'auc': 0.7344996476445187, 'prauc': 0.4170395840166014}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.13it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.05it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.53it/s]


Epoch: 003, Average Loss: 0.3678
Validation: {'precision': 0.4142597520880693, 'recall': 0.34096209300424124, 'f1': 0.3534918849420689, 'auc': 0.7338714615698709, 'prauc': 0.4362055743752503}
Test:       {'precision': 0.40411496770090743, 'recall': 0.3339236318010071, 'f1': 0.34600124929967835, 'auc': 0.7361464350115203, 'prauc': 0.4298767301191917}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.24it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.12it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.15it/s]


Epoch: 004, Average Loss: 0.3553
Validation: {'precision': 0.4026718669756295, 'recall': 0.3940745075710366, 'f1': 0.37962210517633727, 'auc': 0.7376253008465286, 'prauc': 0.44454075643296487}
Test:       {'precision': 0.40248190131530187, 'recall': 0.39217354836054574, 'f1': 0.37925005872220885, 'auc': 0.7433556446463152, 'prauc': 0.4353472950344842}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.47it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.18it/s]


Epoch: 005, Average Loss: 0.3439
Validation: {'precision': 0.4239879166682379, 'recall': 0.35033755247597026, 'f1': 0.37366306846086145, 'auc': 0.7337689435998224, 'prauc': 0.445135145178166}
Test:       {'precision': 0.42149573466038626, 'recall': 0.3372921607392533, 'f1': 0.3635948636170651, 'auc': 0.7412691911941187, 'prauc': 0.4342087937121935}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.29it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.31it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 93.20it/s] 


Epoch: 006, Average Loss: 0.3345
Validation: {'precision': 0.47182529906432946, 'recall': 0.3608431952675124, 'f1': 0.38128998795690827, 'auc': 0.7334038280785348, 'prauc': 0.44642830587689303}
Test:       {'precision': 0.42329149139313854, 'recall': 0.3442495569847429, 'f1': 0.36612110188933783, 'auc': 0.7371974973452301, 'prauc': 0.43131891680808704}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.26it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.05it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.50it/s] 


Epoch: 007, Average Loss: 0.3259
Validation: {'precision': 0.4093361513124629, 'recall': 0.3811392184773965, 'f1': 0.3907547313807319, 'auc': 0.7368914063843912, 'prauc': 0.4406972240168548}
Test:       {'precision': 0.40261033852339173, 'recall': 0.3754494010594669, 'f1': 0.38428734174442086, 'auc': 0.736395028040294, 'prauc': 0.426589490469699}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.15it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.80it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.66it/s]


Epoch: 008, Average Loss: 0.3142
Validation: {'precision': 0.43444150265295906, 'recall': 0.34766332521364357, 'f1': 0.37899263184595666, 'auc': 0.7343840598393974, 'prauc': 0.44551204190369514}
Test:       {'precision': 0.4764924139205022, 'recall': 0.3315894483576046, 'f1': 0.36516875960374084, 'auc': 0.7356802731152647, 'prauc': 0.42598097663149886}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.58it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.36it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.95it/s] 


Epoch: 009, Average Loss: 0.3087
Validation: {'precision': 0.4242346700312474, 'recall': 0.39261857393655136, 'f1': 0.3976261559981861, 'auc': 0.731969107459802, 'prauc': 0.44151509039881176}
Test:       {'precision': 0.4428560668397085, 'recall': 0.3819137910257124, 'f1': 0.3910004582560428, 'auc': 0.7272602065785423, 'prauc': 0.4213446013519129}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.62it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.21it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.18it/s] 


Epoch: 010, Average Loss: 0.2985
Validation: {'precision': 0.46750032305320377, 'recall': 0.4019033381179105, 'f1': 0.39874609125129107, 'auc': 0.734718297091416, 'prauc': 0.4389975487904606}
Test:       {'precision': 0.40620801485797053, 'recall': 0.39535968394002324, 'f1': 0.3927001040525497, 'auc': 0.7316804960588167, 'prauc': 0.4216254212491996}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.23it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.95it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.61it/s] 


Epoch: 011, Average Loss: 0.2935
Validation: {'precision': 0.4364443567857548, 'recall': 0.3645223064611155, 'f1': 0.3753891376047801, 'auc': 0.7301855559619891, 'prauc': 0.4401951129238304}
Test:       {'precision': 0.45914606256737245, 'recall': 0.3623769353308324, 'f1': 0.3728548795460489, 'auc': 0.7301450504538529, 'prauc': 0.4229446790872652}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.67it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.48it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.39it/s] 


Epoch: 012, Average Loss: 0.2832
Validation: {'precision': 0.44769437628341724, 'recall': 0.3531492782836296, 'f1': 0.37542685195462255, 'auc': 0.7317526329344531, 'prauc': 0.4382675407893064}
Test:       {'precision': 0.42880165120344604, 'recall': 0.3427008429682432, 'f1': 0.3677148450817993, 'auc': 0.7215022171199412, 'prauc': 0.41604842623673594}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.29it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.09it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.32it/s]


Epoch: 013, Average Loss: 0.2771
Validation: {'precision': 0.5250826672873621, 'recall': 0.3905782491072007, 'f1': 0.403352531731772, 'auc': 0.7292612994300817, 'prauc': 0.44545639987209573}
Test:       {'precision': 0.41319372656879066, 'recall': 0.37230341061382105, 'f1': 0.381235191216211, 'auc': 0.7258164454666293, 'prauc': 0.4168385443784904}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.20it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.05it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.77it/s]


Epoch: 014, Average Loss: 0.2672
Validation: {'precision': 0.4653899438701643, 'recall': 0.38866726281075714, 'f1': 0.3924513941092344, 'auc': 0.7280808595705991, 'prauc': 0.4354678788478866}
Test:       {'precision': 0.40453897108315645, 'recall': 0.37283819006271607, 'f1': 0.378425892434907, 'auc': 0.7209001426961089, 'prauc': 0.4133100703076949}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.49it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.11it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.72it/s]


Epoch: 015, Average Loss: 0.2641
Validation: {'precision': 0.47568788869937084, 'recall': 0.3903329947937466, 'f1': 0.39501575040297476, 'auc': 0.7255367810817313, 'prauc': 0.4323679407266525}
Test:       {'precision': 0.4006836176129965, 'recall': 0.37495593196164106, 'f1': 0.37893160329269804, 'auc': 0.7195818160729125, 'prauc': 0.40790808962111497}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.33it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.42it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.76it/s] 


Epoch: 016, Average Loss: 0.2549
Validation: {'precision': 0.47521466529593387, 'recall': 0.3728629935160509, 'f1': 0.3930815348861917, 'auc': 0.7249549820112576, 'prauc': 0.4349265790471077}
Test:       {'precision': 0.4199766270614276, 'recall': 0.3577490728866885, 'f1': 0.37801467465274463, 'auc': 0.7222905427734669, 'prauc': 0.41286399109543576}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.47it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.34it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.10it/s] 


Epoch: 017, Average Loss: 0.2452
Validation: {'precision': 0.4313090824202147, 'recall': 0.3831313931988141, 'f1': 0.394778030977019, 'auc': 0.7261451904231312, 'prauc': 0.4277964912743972}
Test:       {'precision': 0.4172769362408584, 'recall': 0.36430381911836807, 'f1': 0.37802561929470285, 'auc': 0.7181800043179608, 'prauc': 0.407460125094813}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.50it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.92it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.59it/s]


Epoch: 018, Average Loss: 0.2379
Validation: {'precision': 0.4268736157007136, 'recall': 0.39679200747325116, 'f1': 0.3967107849532998, 'auc': 0.7205870363562341, 'prauc': 0.42730461499243855}
Test:       {'precision': 0.41092331356191836, 'recall': 0.3871520673460712, 'f1': 0.3858633035360851, 'auc': 0.7153190293606669, 'prauc': 0.41131733203868714}

Early stopping triggered after 18 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.5250826672873621, 'recall': 0.3905782491072007, 'f1': 0.403352531731772, 'auc': 0.7292612994300817, 'prauc': 0.44545639987209573}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.500000  0.403226  0.446429   
Acute cerebrovascular disease               1.000000  0.055556  0.105263   
Acute myocardial infarction                 0.444444  0.084211  0.141593   
Cardiac dysrhythmias                        0.680395  0.753650  0.715

Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.47it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.19it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.77it/s]


Epoch: 001, Average Loss: 0.4244
Validation: {'precision': 0.36736892767025225, 'recall': 0.32284666013888413, 'f1': 0.33314905468417666, 'auc': 0.7060657864342816, 'prauc': 0.4082520549408274}
Test:       {'precision': 0.3715610157663922, 'recall': 0.3309896264158787, 'f1': 0.34080273287098134, 'auc': 0.7131096485647664, 'prauc': 0.40970444497330355}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.43it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.66it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 94.19it/s] 


Epoch: 002, Average Loss: 0.3817
Validation: {'precision': 0.42884085634444014, 'recall': 0.32386145638117697, 'f1': 0.33568829697715247, 'auc': 0.7296972822810702, 'prauc': 0.4423410163810193}
Test:       {'precision': 0.4192242910522088, 'recall': 0.3169723249644077, 'f1': 0.3324365500585169, 'auc': 0.7362459262605123, 'prauc': 0.4306151482651836}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.93it/s] 


Epoch: 003, Average Loss: 0.3654
Validation: {'precision': 0.42015570498666244, 'recall': 0.36656040448770155, 'f1': 0.3798545110457876, 'auc': 0.7351712846765817, 'prauc': 0.43919572138930746}
Test:       {'precision': 0.41314122028550027, 'recall': 0.3610652372793124, 'f1': 0.37530280748878453, 'auc': 0.7412766329038276, 'prauc': 0.43526388716063846}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.67it/s] 


Epoch: 004, Average Loss: 0.3529
Validation: {'precision': 0.42476557398611586, 'recall': 0.37835920018895214, 'f1': 0.3780575109624017, 'auc': 0.7403869271205685, 'prauc': 0.4501139316266492}
Test:       {'precision': 0.4149089288840735, 'recall': 0.3682881121788377, 'f1': 0.36858827892759327, 'auc': 0.7377122139055708, 'prauc': 0.43432085747611365}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.19it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.12it/s]


Epoch: 005, Average Loss: 0.3426
Validation: {'precision': 0.4369545585228339, 'recall': 0.3405568612568482, 'f1': 0.365405033281596, 'auc': 0.7419365102507984, 'prauc': 0.44572272823131126}
Test:       {'precision': 0.4405155486848656, 'recall': 0.32609276632241674, 'f1': 0.35220571547351265, 'auc': 0.7405297901864151, 'prauc': 0.43568504869982444}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.24it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.99it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.49it/s]


Epoch: 006, Average Loss: 0.3339
Validation: {'precision': 0.42050557258604937, 'recall': 0.3727537331191286, 'f1': 0.3870336994696238, 'auc': 0.7400860810389449, 'prauc': 0.4460467523988445}
Test:       {'precision': 0.40760375755275596, 'recall': 0.3672469960212752, 'f1': 0.3801464707601324, 'auc': 0.7382379375027792, 'prauc': 0.4349927736296474}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.52it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.59it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 93.07it/s]


Epoch: 007, Average Loss: 0.3246
Validation: {'precision': 0.4098163642746482, 'recall': 0.37663494723310514, 'f1': 0.38340010604398816, 'auc': 0.7360540525019945, 'prauc': 0.4431572156526404}
Test:       {'precision': 0.4627463794187241, 'recall': 0.3681413234671112, 'f1': 0.3795940198533787, 'auc': 0.7327350833832453, 'prauc': 0.42978556099296067}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.35it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.52it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.97it/s]


Epoch: 008, Average Loss: 0.3162
Validation: {'precision': 0.4303944379682163, 'recall': 0.3529193129594901, 'f1': 0.37855479385835133, 'auc': 0.7383015727518282, 'prauc': 0.4458433794216449}
Test:       {'precision': 0.4155284612972552, 'recall': 0.34059398411283187, 'f1': 0.36681704864903636, 'auc': 0.7316761011263578, 'prauc': 0.4330102931161648}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.49it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.83it/s]


Epoch: 009, Average Loss: 0.3066
Validation: {'precision': 0.4324806693679415, 'recall': 0.369416779765323, 'f1': 0.39262398866703885, 'auc': 0.7324417365031841, 'prauc': 0.440243729463845}
Test:       {'precision': 0.41501410499301666, 'recall': 0.3608356239836471, 'f1': 0.3806471903229994, 'auc': 0.7299000550485684, 'prauc': 0.4275070118299419}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.55it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.27it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.99it/s]


Epoch: 010, Average Loss: 0.2996
Validation: {'precision': 0.42454306104360495, 'recall': 0.3521608282782052, 'f1': 0.3768241854862086, 'auc': 0.7307590654154787, 'prauc': 0.43947511001898243}
Test:       {'precision': 0.4320158010447097, 'recall': 0.339068545137619, 'f1': 0.3634114388382955, 'auc': 0.7250033248110594, 'prauc': 0.4252628817529821}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.17it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.43it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.45it/s] 


Epoch: 011, Average Loss: 0.2908
Validation: {'precision': 0.42878111657376106, 'recall': 0.358886804905691, 'f1': 0.38193014285755394, 'auc': 0.7262500675195371, 'prauc': 0.4337356138772488}
Test:       {'precision': 0.40772015570699477, 'recall': 0.3402502154424847, 'f1': 0.36270622771673167, 'auc': 0.7170861574016646, 'prauc': 0.4208952357874015}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.44it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.80it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.64it/s]


Epoch: 012, Average Loss: 0.2825
Validation: {'precision': 0.44245797795738123, 'recall': 0.35378348991337677, 'f1': 0.3797911917207632, 'auc': 0.7275679761544949, 'prauc': 0.4350748704867973}
Test:       {'precision': 0.4215378060551108, 'recall': 0.34652739207098876, 'f1': 0.3723149222117079, 'auc': 0.7259309427053876, 'prauc': 0.42540388626597736}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.26it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.62it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.52it/s] 


Epoch: 013, Average Loss: 0.2756
Validation: {'precision': 0.4235145052025693, 'recall': 0.3774239847151891, 'f1': 0.39286706037665703, 'auc': 0.725527121186674, 'prauc': 0.43310937227592805}
Test:       {'precision': 0.48139126938521337, 'recall': 0.3757327559023011, 'f1': 0.3900289123218956, 'auc': 0.7271199428242963, 'prauc': 0.4205833174896274}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.37it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.93it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.73it/s]


Epoch: 014, Average Loss: 0.2648
Validation: {'precision': 0.41690493341406176, 'recall': 0.3782963608061321, 'f1': 0.3900492336310607, 'auc': 0.7249814459647661, 'prauc': 0.43274812509047234}
Test:       {'precision': 0.4317453726374013, 'recall': 0.3734744998303395, 'f1': 0.3878274291883578, 'auc': 0.7219133032909006, 'prauc': 0.4204592510195203}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 26.11it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.92it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.20it/s] 


Epoch: 015, Average Loss: 0.2572
Validation: {'precision': 0.4276934137892865, 'recall': 0.37636175260176014, 'f1': 0.3910717697888032, 'auc': 0.7202692005641477, 'prauc': 0.42799056570716615}
Test:       {'precision': 0.47136872438264493, 'recall': 0.35829939833726554, 'f1': 0.37596191885883706, 'auc': 0.7161772601827742, 'prauc': 0.411775462081485}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.40it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.93it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.69it/s]


Epoch: 016, Average Loss: 0.2511
Validation: {'precision': 0.43126922958141767, 'recall': 0.38905994196268917, 'f1': 0.400552138007492, 'auc': 0.7224055298892125, 'prauc': 0.42608736164093797}
Test:       {'precision': 0.4227625241415429, 'recall': 0.37655061403372725, 'f1': 0.3868473158971108, 'auc': 0.7179907088672358, 'prauc': 0.4138383070883196}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.35it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.49it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.11it/s] 


Epoch: 017, Average Loss: 0.2436
Validation: {'precision': 0.417531839171465, 'recall': 0.37707221806740177, 'f1': 0.38694937076814595, 'auc': 0.7242090478466899, 'prauc': 0.42764759235450545}
Test:       {'precision': 0.395601608373336, 'recall': 0.3655970113843599, 'f1': 0.3749534671790248, 'auc': 0.7153527784921914, 'prauc': 0.41247221983013294}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.52it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.96it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.35it/s] 


Epoch: 018, Average Loss: 0.2369
Validation: {'precision': 0.4319207405423039, 'recall': 0.38371220705076603, 'f1': 0.39806938605227754, 'auc': 0.7234823290569592, 'prauc': 0.4272884998974946}
Test:       {'precision': 0.42637157364384215, 'recall': 0.36745271632798077, 'f1': 0.383793502251811, 'auc': 0.7124963142445432, 'prauc': 0.4126278926979422}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.23it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.36it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.26it/s] 


Epoch: 019, Average Loss: 0.2267
Validation: {'precision': 0.4172939138284532, 'recall': 0.3796295530518668, 'f1': 0.39200355853591207, 'auc': 0.7187991476399954, 'prauc': 0.41851921725634905}
Test:       {'precision': 0.4279558479132308, 'recall': 0.36451334010901504, 'f1': 0.3815200159287998, 'auc': 0.7113073093865019, 'prauc': 0.41237737011078185}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.94it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.98it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 91.74it/s]


Epoch: 020, Average Loss: 0.2190
Validation: {'precision': 0.4226624905430321, 'recall': 0.40087154933341573, 'f1': 0.3989751230670336, 'auc': 0.7205375438267285, 'prauc': 0.42011251007667727}
Test:       {'precision': 0.4012050159793869, 'recall': 0.38842161836955524, 'f1': 0.38832460778529104, 'auc': 0.7114098037791056, 'prauc': 0.41189691568898074}


Training Batches: 100%|██████████| 471/471 [00:17<00:00, 26.54it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.30it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 92.65it/s]


Epoch: 021, Average Loss: 0.2130
Validation: {'precision': 0.43546889259056193, 'recall': 0.3958201387525165, 'f1': 0.3939920802819047, 'auc': 0.7157267571220759, 'prauc': 0.41830971937633454}
Test:       {'precision': 0.38726894999276307, 'recall': 0.3826729150039668, 'f1': 0.37772894713747823, 'auc': 0.7067864721223946, 'prauc': 0.4044497744803353}

Early stopping triggered after 21 epochs (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 0.43126922958141767, 'recall': 0.38905994196268917, 'f1': 0.400552138007492, 'auc': 0.7224055298892125, 'prauc': 0.42608736164093797}, 'per_class':                                            precision    recall        f1  \
Acute and unspecified renal failure         0.438228  0.433180  0.435689   
Acute cerebrovascular disease               0.000000  0.000000  0.000000   
Acute myocardial infarction                 0.366667  0.115789  0.176000   
Cardiac dysrhythmias                        0.788337  0.666058  0.7

In [19]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5):
    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # 计算排名（值越大排名越靠前）
    ranks = {m: (-scores[m]).argsort().argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # 选 top-k
    topk_idx = np.argsort(avg_ranks)[:k]
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in performances[0].keys()}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in performances[0].keys()}
    final_long_seq_avg = {m: np.mean([long_seq_performances[i][m] for i in topk_idx]) for m in long_seq_performances[0].keys()}
    final_long_seq_std = {m: np.std([long_seq_performances[i][m] for i in topk_idx], ddof=0) for m in long_seq_performances[0].keys()}

    # 打印结果（转百分比，均保留两位小数）
    print("Final Metrics:")
    for m in performances[0].keys():
        mean_val = final_avg[m] * 100
        std_val = final_std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")
    print("\nFinal Long Sequence Metrics:")
    for m in long_seq_performances[0].keys():
        mean_val = final_long_seq_avg[m] * 100
        std_val = final_long_seq_std[m] * 100
        print(f"{m}: {mean_val:.2f} ± {std_val:.2f}")

In [20]:
def print_per_class_performance(dfs, col_name="prauc"):
    """
    输入一个 DataFrame 列表，对每个疾病在所有表格的指定列计算 mean ± std 并打印。

    参数:
        dfs (list[pd.DataFrame]): 多个表格组成的列表
        col_name (str): 要计算的指标列名 (默认: "prauc")
    """
    # 拼接所有表格
    all_values = pd.concat(dfs, axis=0)

    # 按疾病分组，计算 mean 和 std
    grouped = all_values.groupby(all_values.index)[col_name].agg(["mean", "std"])

    # 打印
    for disease, row in grouped.iterrows():
        mean_val = row["mean"] * 100
        std_val = row["std"] * 100
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

In [21]:
if task_type == "binary":
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)
else:
    final_metrics_global = [metrics["global"] for metrics in final_metrics]
    final_metrics_per_class = [metrics["per_class"] for metrics in final_metrics]
    final_long_seq_metrics_global = [metrics["global"] for metrics in final_long_seq_metrics]
    final_long_seq_metrics_per_class = [metrics["per_class"] for metrics in final_long_seq_metrics]
    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)
    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")
    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")

Final Metrics:
precision: 41.76 ± 0.59
recall: 37.41 ± 1.02
f1: 38.23 ± 0.25
auc: 72.81 ± 0.62
prauc: 42.51 ± 0.84

Final Long Sequence Metrics:
precision: 40.96 ± 0.86
recall: 38.17 ± 2.04
f1: 38.52 ± 0.96
auc: 74.08 ± 1.41
prauc: 43.61 ± 0.57

Per-class performance, all patients:
Acute and unspecified renal failure: 47.29 ± 1.81
Acute cerebrovascular disease: 3.21 ± 0.91
Acute myocardial infarction: 16.61 ± 1.41
Cardiac dysrhythmias: 74.98 ± 1.47
Chronic kidney disease: 76.96 ± 2.74
Chronic obstructive pulmonary disease: 43.31 ± 3.21
Conduction disorders: 4.24 ± 0.87
Congestive heart failure; nonhypertensive: 74.27 ± 1.04
Coronary atherosclerosis and related: 61.87 ± 1.33
Disorders of lipid metabolism: 57.82 ± 1.83
Essential hypertension: 64.75 ± 2.07
Fluid and electrolyte disorders: 46.37 ± 0.59
Gastrointestinal hemorrhage: 9.02 ± 0.59
Hypertension with complications: 72.26 ± 3.77
Other liver diseases: 5.71 ± 3.23
Other lower respiratory disease: 56.50 ± 1.10
Pneumonia: 16.51 ± 1.16