In [20]:
import torch
import pickle
import numpy as np
import pandas as pd
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping, PHENO_ORDER
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTFineTune

In [2]:
# Global seed for every stochastic component seeded by the project helper.
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
def _tf_attn_mask():
    """Build one attention-mask spec for a transformer ('tf') layer.

    A fresh dict, with fresh lists, is returned on every call, so the per-layer
    specs never alias each other.
    NOTE(review): keys/values look like token-type ids (query type -> attendable
    types) — TODO confirm against HeteroGTFineTune.
    """
    mask = {key: [6, 7] for key in range(1, 6)}
    mask[6] = list(range(1, 8))
    mask[7] = list(range(1, 8))
    return mask


config = Namespace(
    dataset="MIMIC-III",
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=4,  # index into `tasks` selecting the task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # one mask spec per 'tf' layer; only used by transformer layers
    attn_mask_dicts=[_tf_attn_mask(), _tf_attn_mask()],
    d_model=64,
    num_heads=4,
    batch_size=32,
    lr=1e-3,
    epochs=500,
    early_stop_patience=5,
    # a group code is generated when at least this many diag codes belong to the same group ICD code
    group_code_thre=5,
    use_pretrained_model=True,
    pretrain_mask_rate=0.7,
    pretrain_cls_ontology_weight=5e-2,
    pretrain_visit_ontology_weight=5e-2,
    pretrain_adm_type_weight=0,
    dec_loss_lambda=1e-2,
)

In [5]:
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"  # full table, used to build the tokenizer vocab

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# Next-diagnosis tasks have their own split files; every other task shares the downstream split.
_finetune_file_by_task = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{processed_dir}/{_finetune_file_by_task.get(curr_task, 'mimic_downstream.pkl')}"

Current task: next_diag_12m


In [6]:
# Load the full (pre-split) table; it is used here only to collect the token sentences
# that build the tokenizer vocabularies.
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally produced files.
# `with` closes the file handle (the original `pickle.load(open(...))` left it open).
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Each distinct age becomes its own one-token sentence; keep the nested [[...]] form.
# NOTE(review): set() iteration order presumably decides the age-token ordering in the
# tokenizer; if AGE values are strings that order depends on PYTHONHASHSEED. sorted() would
# make it deterministic, but must match the order used when the pretrained checkpoint was built.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
adm_type_sentences = ehr_full_data["ADMISSION_TYPE"].values.tolist()

# Longest admission history of any patient; stored on config for the datasets/model.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
    adm_types_sentences=adm_type_sentences,
)

# Vocabulary sizes derived from the tokenizer, recorded on config for model construction.
config.label_vocab_size = len(PHENO_ORDER)  # one label per predefined phenotype in PHENO_ORDER
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size, config.group_code_vocab_size = (
    tokenizer.token_number(kind) for kind in ("age", "group")
)
print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# The fine-tune pickle holds a (train, val, test) triple.
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally produced files.
# `with` closes the file handle (the original `pickle.load(open(...))` left it open).
with open(finetune_data_path, "rb") as split_file:
    train_data, val_data, test_data = pickle.load(split_file)

# Sanity check: base rates of the binary outcomes in the test split
# (printed for reference regardless of which task is being trained).
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")  # noqa: E712 — element-wise comparison

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 8.114199849737041 %
Percentage of READMISSION in test dataset: 64.08715251690458 %
Percentage of STAY>7 days in test dataset: 55.10894064613073 %


In [9]:
def _build_finetune_dataset(split):
    """Wrap one data split from the fine-tune pickle in a FineTuneEHRDataset for `curr_task`."""
    return FineTuneEHRDataset(
        split,
        tokenizer,
        token_type=config.token_type,
        task=curr_task,
        max_num_adms=config.max_num_adms,
        group_code_thre=config.group_code_thre,
    )


# Built in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    _build_finetune_dataset(split) for split in (train_data, val_data, test_data)
)

In [10]:
long_adm_seq_crite = 3  # minimum number of admissions for a record to count as a "long" sequence


def long_sequence_indices(dataset, min_adms):
    """Return the dataset indices whose record holds at least `min_adms` admissions.

    Index i is mapped to the i-th key of `dataset.records`, exactly as before.
    NOTE(review): this assumes dataset[i] corresponds to the i-th key of `records`
    (insertion order) — confirm in FineTuneEHRDataset.

    The key list is built once. The original rebuilt `list(records.keys())` on every
    iteration, which made the scan O(n^2).
    """
    record_keys = list(dataset.records.keys())
    return [
        idx
        for idx in range(len(dataset))
        if len(dataset.records[record_keys[idx]]) >= min_adms
    ]


val_long_seq_idx = long_sequence_indices(val_dataset, long_adm_seq_crite)
test_long_seq_idx = long_sequence_indices(test_dataset, long_adm_seq_crite)
print(len(val_long_seq_idx), len(test_long_seq_idx))

208 186


In [11]:
def _group_token_count(sample):
    """Count entries equal to type id 6 in the first row of a sample's token types."""
    _, sample_token_types, _, _, _, _ = sample
    # NOTE(review): 6 appears to be the group-code token type (it is also a key of
    # config.attn_mask_dicts); why row [0] is used is not visible here — confirm in FineTuneEHRDataset.
    return (sample_token_types[0] == 6).sum().item()


num_group_code = [_group_token_count(train_dataset[idx]) for idx in range(len(train_dataset))]
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.8268720127456187


In [12]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with its own fine-tuning (non-pretrain) collate function."""
    return DataLoader(
        dataset,
        collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=False),
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training loader shuffles; evaluation order stays fixed.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [13]:
# Smoke test: run every loader once end to end so collate errors surface before training.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass
print("All pass!")

All pass!


In [14]:
# Task -> training setup. Every task is trained with BCE-with-logits and selected on F1;
# binary tasks use the "binary" trainer mode and next-diagnosis tasks use "l2r".
# Unknown task names now fail loudly instead of silently falling into the "l2r" branch.
BINARY_TASKS = ("death", "stay", "readmission")
NEXT_DIAG_TASKS = ("next_diag_6m", "next_diag_12m")

eval_metric = "f1"  # same model-selection metric for every task
if curr_task in BINARY_TASKS:
    task_type = "binary"
elif curr_task in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    raise ValueError(
        f"Unsupported task {curr_task!r}; expected one of {BINARY_TASKS + NEXT_DIAG_TASKS}"
    )
# Both branches used the same loss. The former lambda only forwarded (x, y) to this function.
loss_fn = F.binary_cross_entropy_with_logits

In [15]:
# Inspect one training batch to confirm the collated tensor shapes.
sample_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = sample_batch
for caption, tensor in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age IDs shape:", age_ids),
):
    print(caption, tensor.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 193])
Token Types shape: torch.Size([32, 193])
Admission Index shape: torch.Size([32, 193])
Age IDs shape: torch.Size([32, 4])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 18])


# Model Walkthrough

In [16]:
# load pretrained model
if config.use_pretrained_model:
    pretrain_exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}-{config.pretrain_cls_ontology_weight}-{config.pretrain_visit_ontology_weight}-{config.pretrain_adm_type_weight}"
)
    print(pretrain_exp_name)
    save_path = "./pretrained_models/" + pretrain_exp_name
    state_dict = torch.load(f"{save_path}/pretrained_model.pt", map_location="cpu")

MIMIC-III-0.7-64-0.05-0.05-0


In [None]:
num_runs = 15  # independent fine-tuning repetitions (fresh model + optimizer each)
final_metrics, final_long_seq_metrics = [], []


def _build_finetune_model():
    """Construct a fresh HeteroGTFineTune on `device`, warm-started from `state_dict` when enabled.

    NOTE(review): the same `state_dict` object is reused for every run; if `load_weight`
    mutates it, later runs would see a modified checkpoint — confirm.
    """
    net = HeteroGTFineTune(
        tokenizer=tokenizer,
        token_types=config.token_type,
        d_model=config.d_model,
        num_heads=config.num_heads,
        layer_types=['gnn', 'tf', 'gnn', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        attn_mask_dicts=config.attn_mask_dicts,
        use_cls_cat=True,
    ).to(device)
    if config.use_pretrained_model:
        net.load_weight(state_dict)
    return net


for i in range(num_runs):
    print(f"================================{i+1}==================================")
    model = _build_finetune_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric, best_test_long_seq_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        dec_loss_lambda=config.dec_loss_lambda,
        val_long_seq_idx=val_long_seq_idx, test_long_seq_idx=test_long_seq_idx,
        eval_metric=eval_metric, return_model=False,
    )
    # Keep each run's best test metrics (overall and long-sequence subset).
    final_metrics.append(best_test_metric)
    final_long_seq_metrics.append(best_test_long_seq_metric)



Epoch 001: 100%|██████████| 59/59 [00:04<00:00, 13.90it/s, loss=0.4680]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.22it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.13it/s]


Validation: {'precision': 36.353, 'recall': 14.3667, 'f1': 19.0981, 'auc': 66.5156, 'prauc': 36.7519}
Test:      {'precision': 33.1821, 'recall': 13.7059, 'f1': 18.1904, 'auc': 67.2613, 'prauc': 35.7609}

Validation-long: {'precision': 36.6397, 'recall': 18.3494, 'f1': 22.0961, 'auc': 64.3429, 'prauc': 42.5343}
Test-long: {'precision': 31.8076, 'recall': 17.194, 'f1': 20.7733, 'auc': 65.1907, 'prauc': 40.4217}



Epoch 002: 100%|██████████| 59/59 [00:04<00:00, 13.95it/s, loss=0.4036]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.20it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.01it/s]


Validation: {'precision': 41.6864, 'recall': 20.9485, 'f1': 25.4404, 'auc': 70.7111, 'prauc': 39.6872}
Test:      {'precision': 38.4604, 'recall': 19.8374, 'f1': 24.2126, 'auc': 72.4446, 'prauc': 39.2965}

Validation-long: {'precision': 46.6562, 'recall': 25.4735, 'f1': 29.2943, 'auc': 68.7802, 'prauc': 46.5703}
Test-long: {'precision': 41.2657, 'recall': 22.6928, 'f1': 26.8704, 'auc': 71.8297, 'prauc': 43.6032}



Epoch 003: 100%|██████████| 59/59 [00:04<00:00, 14.68it/s, loss=0.3886]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.17it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.09it/s]


Validation: {'precision': 39.772, 'recall': 23.3231, 'f1': 27.0509, 'auc': 72.0743, 'prauc': 41.2629}
Test:      {'precision': 37.8287, 'recall': 22.2875, 'f1': 26.0094, 'auc': 73.5019, 'prauc': 40.3559}

Validation-long: {'precision': 40.1853, 'recall': 26.9921, 'f1': 29.8544, 'auc': 70.3106, 'prauc': 47.3538}
Test-long: {'precision': 44.5227, 'recall': 24.4104, 'f1': 28.5041, 'auc': 72.9209, 'prauc': 44.843}



Epoch 004: 100%|██████████| 59/59 [00:03<00:00, 15.07it/s, loss=0.3750]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.34it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.35it/s]


Validation: {'precision': 38.1766, 'recall': 29.3626, 'f1': 31.4561, 'auc': 73.5524, 'prauc': 42.3312}
Test:      {'precision': 42.5614, 'recall': 28.0491, 'f1': 30.4927, 'auc': 74.5317, 'prauc': 42.0541}

Validation-long: {'precision': 36.7382, 'recall': 30.7527, 'f1': 32.7358, 'auc': 72.027, 'prauc': 48.2985}
Test-long: {'precision': 43.0076, 'recall': 28.3069, 'f1': 31.7314, 'auc': 73.8523, 'prauc': 46.1041}



Epoch 005: 100%|██████████| 59/59 [00:03<00:00, 14.99it/s, loss=0.3639]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.43it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.40it/s]


Validation: {'precision': 42.9251, 'recall': 31.0239, 'f1': 33.9715, 'auc': 73.9891, 'prauc': 43.0932}
Test:      {'precision': 39.1788, 'recall': 30.1164, 'f1': 33.0235, 'auc': 74.5879, 'prauc': 42.5798}

Validation-long: {'precision': 42.7988, 'recall': 36.5903, 'f1': 38.6108, 'auc': 71.6012, 'prauc': 48.1071}
Test-long: {'precision': 41.1173, 'recall': 33.9844, 'f1': 36.749, 'auc': 73.9791, 'prauc': 45.4521}



Epoch 006: 100%|██████████| 59/59 [00:03<00:00, 14.76it/s, loss=0.3546]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.19it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.12it/s]


Validation: {'precision': 42.0042, 'recall': 35.2419, 'f1': 36.38, 'auc': 73.958, 'prauc': 43.4859}
Test:      {'precision': 43.8041, 'recall': 33.9334, 'f1': 35.11, 'auc': 75.12, 'prauc': 43.2257}

Validation-long: {'precision': 41.1679, 'recall': 40.6116, 'f1': 40.1934, 'auc': 71.7995, 'prauc': 47.8623}
Test-long: {'precision': 39.8471, 'recall': 38.0409, 'f1': 38.2263, 'auc': 75.072, 'prauc': 46.5214}



Epoch 007: 100%|██████████| 59/59 [00:03<00:00, 15.03it/s, loss=0.3439]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.21it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 19.96it/s]


Validation: {'precision': 43.6497, 'recall': 33.6528, 'f1': 36.1832, 'auc': 74.7109, 'prauc': 44.4919}
Test:      {'precision': 40.611, 'recall': 31.8107, 'f1': 34.5555, 'auc': 75.1897, 'prauc': 43.6681}

Validation-long: {'precision': 43.5781, 'recall': 37.3066, 'f1': 39.0085, 'auc': 72.5988, 'prauc': 49.0179}
Test-long: {'precision': 42.5037, 'recall': 33.0338, 'f1': 36.0933, 'auc': 74.8821, 'prauc': 46.5278}



Epoch 008: 100%|██████████| 59/59 [00:03<00:00, 15.06it/s, loss=0.3339]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.33it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.24it/s]


Validation: {'precision': 41.1738, 'recall': 34.2901, 'f1': 36.2696, 'auc': 74.07, 'prauc': 44.1097}
Test:      {'precision': 39.7582, 'recall': 32.8992, 'f1': 35.0806, 'auc': 74.7253, 'prauc': 43.4937}

Validation-long: {'precision': 41.9574, 'recall': 38.7914, 'f1': 39.2535, 'auc': 71.8157, 'prauc': 49.3727}
Test-long: {'precision': 40.2678, 'recall': 34.5229, 'f1': 36.7547, 'auc': 74.3022, 'prauc': 46.4946}



Epoch 009: 100%|██████████| 59/59 [00:03<00:00, 15.09it/s, loss=0.3230]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.45it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.36it/s]


Validation: {'precision': 44.6061, 'recall': 33.7991, 'f1': 36.1577, 'auc': 74.276, 'prauc': 43.8628}
Test:      {'precision': 43.5192, 'recall': 32.4662, 'f1': 34.9665, 'auc': 74.652, 'prauc': 43.5068}

Validation-long: {'precision': 43.9903, 'recall': 38.5997, 'f1': 40.0063, 'auc': 71.9383, 'prauc': 48.7059}
Test-long: {'precision': 42.2677, 'recall': 35.422, 'f1': 37.25, 'auc': 73.2069, 'prauc': 46.0242}



Epoch 010: 100%|██████████| 59/59 [00:03<00:00, 15.01it/s, loss=0.3142]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.31it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.21it/s]


Validation: {'precision': 45.9698, 'recall': 36.087, 'f1': 37.1744, 'auc': 75.0015, 'prauc': 44.388}
Test:      {'precision': 44.5829, 'recall': 34.5084, 'f1': 36.0421, 'auc': 74.9538, 'prauc': 43.5775}

Validation-long: {'precision': 43.3859, 'recall': 39.3835, 'f1': 40.1839, 'auc': 72.3218, 'prauc': 50.1186}
Test-long: {'precision': 40.6807, 'recall': 35.6902, 'f1': 37.3101, 'auc': 73.5039, 'prauc': 46.0616}



Epoch 011: 100%|██████████| 59/59 [00:03<00:00, 15.24it/s, loss=0.3034]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.36it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.25it/s]


Validation: {'precision': 41.2956, 'recall': 36.7301, 'f1': 37.9058, 'auc': 74.2613, 'prauc': 43.6294}
Test:      {'precision': 44.0351, 'recall': 35.7154, 'f1': 37.264, 'auc': 74.4019, 'prauc': 43.4557}

Validation-long: {'precision': 41.5965, 'recall': 40.8647, 'f1': 40.5227, 'auc': 72.0398, 'prauc': 49.7409}
Test-long: {'precision': 40.0272, 'recall': 38.331, 'f1': 38.4838, 'auc': 71.9017, 'prauc': 45.6271}



Epoch 012: 100%|██████████| 59/59 [00:03<00:00, 15.03it/s, loss=0.2951]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.39it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.27it/s]


Validation: {'precision': 40.1442, 'recall': 38.1012, 'f1': 38.0196, 'auc': 74.5602, 'prauc': 43.2952}
Test:      {'precision': 44.3706, 'recall': 37.2643, 'f1': 37.2882, 'auc': 74.1609, 'prauc': 43.2734}

Validation-long: {'precision': 42.1886, 'recall': 41.9814, 'f1': 41.1663, 'auc': 72.5214, 'prauc': 49.3294}
Test-long: {'precision': 38.8649, 'recall': 38.1084, 'f1': 38.022, 'auc': 70.6971, 'prauc': 44.4641}



Epoch 013: 100%|██████████| 59/59 [00:04<00:00, 14.24it/s, loss=0.2878]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.35it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.28it/s]


Validation: {'precision': 42.636, 'recall': 36.6145, 'f1': 38.009, 'auc': 74.4111, 'prauc': 43.4659}
Test:      {'precision': 43.0481, 'recall': 35.0734, 'f1': 37.0773, 'auc': 73.6694, 'prauc': 43.0967}

Validation-long: {'precision': 41.9233, 'recall': 39.9678, 'f1': 40.0459, 'auc': 72.091, 'prauc': 48.8091}
Test-long: {'precision': 40.4054, 'recall': 36.6985, 'f1': 38.1304, 'auc': 71.7292, 'prauc': 45.5074}



Epoch 014: 100%|██████████| 59/59 [00:03<00:00, 14.81it/s, loss=0.2777]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.37it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.09it/s]


Validation: {'precision': 49.9633, 'recall': 35.1631, 'f1': 37.7377, 'auc': 74.1439, 'prauc': 43.431}
Test:      {'precision': 43.2247, 'recall': 33.3892, 'f1': 36.2197, 'auc': 73.8245, 'prauc': 43.4873}

Validation-long: {'precision': 42.6116, 'recall': 40.6585, 'f1': 41.0687, 'auc': 71.689, 'prauc': 48.7862}
Test-long: {'precision': 41.5747, 'recall': 37.6227, 'f1': 38.6802, 'auc': 68.4859, 'prauc': 44.6783}



Epoch 015: 100%|██████████| 59/59 [00:03<00:00, 15.02it/s, loss=0.2652]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.44it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.37it/s]


Validation: {'precision': 44.3026, 'recall': 33.99, 'f1': 37.0634, 'auc': 73.9459, 'prauc': 43.42}
Test:      {'precision': 44.6139, 'recall': 32.9591, 'f1': 36.5605, 'auc': 73.2123, 'prauc': 42.7254}

Validation-long: {'precision': 42.214, 'recall': 36.6453, 'f1': 38.7399, 'auc': 72.0205, 'prauc': 48.8879}
Test-long: {'precision': 44.1617, 'recall': 34.9335, 'f1': 38.383, 'auc': 71.0501, 'prauc': 45.2277}



Epoch 016: 100%|██████████| 59/59 [00:03<00:00, 15.04it/s, loss=0.2579]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.26it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.24it/s]


Validation: {'precision': 47.9457, 'recall': 34.6076, 'f1': 37.469, 'auc': 73.7985, 'prauc': 43.2976}
Test:      {'precision': 45.3164, 'recall': 33.2031, 'f1': 36.6295, 'auc': 73.3117, 'prauc': 42.576}

Validation-long: {'precision': 47.8297, 'recall': 36.5577, 'f1': 38.668, 'auc': 72.02, 'prauc': 49.4483}
Test-long: {'precision': 41.8033, 'recall': 34.1061, 'f1': 36.9434, 'auc': 69.9047, 'prauc': 43.8715}



Epoch 017: 100%|██████████| 59/59 [00:03<00:00, 14.95it/s, loss=0.2505]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.34it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.35it/s]


Validation: {'precision': 43.6668, 'recall': 38.916, 'f1': 39.3334, 'auc': 73.2624, 'prauc': 42.2433}
Test:      {'precision': 41.7391, 'recall': 37.5475, 'f1': 38.0526, 'auc': 73.2057, 'prauc': 42.536}

Validation-long: {'precision': 42.3824, 'recall': 43.3106, 'f1': 42.063, 'auc': 70.5677, 'prauc': 48.4945}
Test-long: {'precision': 40.3062, 'recall': 40.621, 'f1': 39.6377, 'auc': 69.2221, 'prauc': 44.3629}



Epoch 018: 100%|██████████| 59/59 [00:03<00:00, 14.97it/s, loss=0.2410]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.37it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.29it/s]


Validation: {'precision': 44.5234, 'recall': 35.3417, 'f1': 37.7924, 'auc': 72.9135, 'prauc': 42.2448}
Test:      {'precision': 43.8009, 'recall': 34.6598, 'f1': 37.3105, 'auc': 72.3298, 'prauc': 41.4924}

Validation-long: {'precision': 45.6791, 'recall': 40.2071, 'f1': 41.7257, 'auc': 71.2287, 'prauc': 49.8367}
Test-long: {'precision': 41.5003, 'recall': 35.8417, 'f1': 38.2045, 'auc': 69.4194, 'prauc': 44.0443}



Epoch 019: 100%|██████████| 59/59 [00:04<00:00, 14.23it/s, loss=0.2316]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.28it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.12it/s]


Validation: {'precision': 43.8915, 'recall': 38.3172, 'f1': 39.5889, 'auc': 72.9493, 'prauc': 42.2979}
Test:      {'precision': 44.2489, 'recall': 36.351, 'f1': 37.8575, 'auc': 72.5153, 'prauc': 41.5347}

Validation-long: {'precision': 42.8558, 'recall': 40.6687, 'f1': 41.1896, 'auc': 72.0298, 'prauc': 49.994}
Test-long: {'precision': 41.7243, 'recall': 36.0073, 'f1': 37.8974, 'auc': 70.3523, 'prauc': 44.2475}



Epoch 020: 100%|██████████| 59/59 [00:03<00:00, 15.07it/s, loss=0.2251]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.17it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.45it/s]


Validation: {'precision': 41.9916, 'recall': 38.7594, 'f1': 38.7223, 'auc': 72.7224, 'prauc': 41.9549}
Test:      {'precision': 41.5051, 'recall': 38.409, 'f1': 38.6749, 'auc': 72.1284, 'prauc': 41.2931}

Validation-long: {'precision': 40.5523, 'recall': 41.0133, 'f1': 39.363, 'auc': 71.0607, 'prauc': 48.0386}
Test-long: {'precision': 39.9567, 'recall': 38.8961, 'f1': 38.7365, 'auc': 68.3093, 'prauc': 43.0158}



Epoch 021: 100%|██████████| 59/59 [00:03<00:00, 14.96it/s, loss=0.2124]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.23it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.02it/s]


Validation: {'precision': 43.1108, 'recall': 38.7757, 'f1': 39.6755, 'auc': 72.7951, 'prauc': 42.4222}
Test:      {'precision': 42.8236, 'recall': 37.7747, 'f1': 38.9767, 'auc': 72.2971, 'prauc': 41.617}

Validation-long: {'precision': 43.4629, 'recall': 43.02, 'f1': 42.2847, 'auc': 70.8918, 'prauc': 48.8378}
Test-long: {'precision': 40.5217, 'recall': 40.3233, 'f1': 39.9148, 'auc': 71.1175, 'prauc': 44.28}



Epoch 022: 100%|██████████| 59/59 [00:04<00:00, 14.70it/s, loss=0.2093]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.08it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.01it/s]


Validation: {'precision': 43.6151, 'recall': 36.8379, 'f1': 38.6473, 'auc': 72.31, 'prauc': 41.8945}
Test:      {'precision': 42.7101, 'recall': 36.533, 'f1': 38.4085, 'auc': 71.3784, 'prauc': 41.7399}

Validation-long: {'precision': 43.4502, 'recall': 40.0199, 'f1': 40.8618, 'auc': 70.154, 'prauc': 49.4684}
Test-long: {'precision': 42.3846, 'recall': 38.1007, 'f1': 39.1195, 'auc': 69.0125, 'prauc': 44.5843}



Epoch 023: 100%|██████████| 59/59 [00:03<00:00, 14.87it/s, loss=0.1997]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.12it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.12it/s]


Validation: {'precision': 44.807, 'recall': 38.8229, 'f1': 39.9382, 'auc': 72.5009, 'prauc': 42.2453}
Test:      {'precision': 43.695, 'recall': 36.8971, 'f1': 38.5626, 'auc': 71.5927, 'prauc': 41.4727}

Validation-long: {'precision': 47.5931, 'recall': 43.5079, 'f1': 43.4123, 'auc': 70.7876, 'prauc': 50.0789}
Test-long: {'precision': 40.7909, 'recall': 38.8832, 'f1': 38.7796, 'auc': 69.7858, 'prauc': 44.2874}



Epoch 024: 100%|██████████| 59/59 [00:03<00:00, 15.02it/s, loss=0.1903]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 19.89it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.17it/s]


Validation: {'precision': 44.9062, 'recall': 37.3485, 'f1': 39.0666, 'auc': 72.1975, 'prauc': 42.0334}
Test:      {'precision': 42.0449, 'recall': 35.5921, 'f1': 37.5445, 'auc': 71.4617, 'prauc': 41.5372}

Validation-long: {'precision': 42.3038, 'recall': 41.6828, 'f1': 41.7863, 'auc': 70.4761, 'prauc': 48.6942}
Test-long: {'precision': 41.8776, 'recall': 37.4985, 'f1': 39.2376, 'auc': 69.7152, 'prauc': 44.4214}



Epoch 025: 100%|██████████| 59/59 [00:03<00:00, 15.09it/s, loss=0.1850]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.34it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.19it/s]


Validation: {'precision': 42.6601, 'recall': 37.3806, 'f1': 38.4611, 'auc': 71.9526, 'prauc': 41.5106}
Test:      {'precision': 41.299, 'recall': 36.6239, 'f1': 37.9482, 'auc': 71.5117, 'prauc': 41.2457}

Validation-long: {'precision': 40.5577, 'recall': 41.346, 'f1': 40.4871, 'auc': 70.0846, 'prauc': 49.37}
Test-long: {'precision': 39.7777, 'recall': 38.3302, 'f1': 38.6138, 'auc': 68.51, 'prauc': 44.041}



Epoch 026: 100%|██████████| 59/59 [00:03<00:00, 15.06it/s, loss=0.1792]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.36it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.26it/s]


Validation: {'precision': 42.5884, 'recall': 38.9795, 'f1': 40.2376, 'auc': 71.9977, 'prauc': 41.5118}
Test:      {'precision': 41.384, 'recall': 37.2724, 'f1': 38.8874, 'auc': 71.5392, 'prauc': 40.8687}

Validation-long: {'precision': 42.0654, 'recall': 42.5912, 'f1': 42.0343, 'auc': 69.9093, 'prauc': 49.6012}
Test-long: {'precision': 43.3806, 'recall': 39.4741, 'f1': 40.2626, 'auc': 69.5915, 'prauc': 44.3678}



Epoch 027: 100%|██████████| 59/59 [00:03<00:00, 14.96it/s, loss=0.1708]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 17.56it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 17.83it/s]


Validation: {'precision': 42.0925, 'recall': 37.1943, 'f1': 38.643, 'auc': 71.3251, 'prauc': 41.0162}
Test:      {'precision': 41.5439, 'recall': 35.7198, 'f1': 37.6343, 'auc': 70.8779, 'prauc': 41.0384}

Validation-long: {'precision': 42.6, 'recall': 39.8202, 'f1': 40.9881, 'auc': 70.0324, 'prauc': 48.4864}
Test-long: {'precision': 43.3326, 'recall': 36.1383, 'f1': 38.2597, 'auc': 67.1983, 'prauc': 44.1493}



Epoch 028: 100%|██████████| 59/59 [00:04<00:00, 14.49it/s, loss=0.1657]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.30it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.16it/s]


Validation: {'precision': 42.8096, 'recall': 36.6415, 'f1': 38.7612, 'auc': 71.7478, 'prauc': 41.2044}
Test:      {'precision': 42.9173, 'recall': 36.3283, 'f1': 38.772, 'auc': 71.6431, 'prauc': 41.1858}

Validation-long: {'precision': 44.2629, 'recall': 39.4811, 'f1': 41.2599, 'auc': 71.208, 'prauc': 48.2574}
Test-long: {'precision': 44.0986, 'recall': 36.5551, 'f1': 38.4487, 'auc': 69.7014, 'prauc': 44.5716}



Epoch 029: 100%|██████████| 59/59 [00:03<00:00, 15.04it/s, loss=0.1615]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.20it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.07it/s]


Validation: {'precision': 41.1373, 'recall': 40.0472, 'f1': 39.9415, 'auc': 71.3495, 'prauc': 40.9582}
Test:      {'precision': 40.1664, 'recall': 38.8735, 'f1': 39.023, 'auc': 71.3777, 'prauc': 40.7869}

Validation-long: {'precision': 43.5518, 'recall': 43.4837, 'f1': 42.4772, 'auc': 69.9432, 'prauc': 48.5759}
Test-long: {'precision': 40.7527, 'recall': 43.1248, 'f1': 41.0535, 'auc': 68.23, 'prauc': 43.4578}



Epoch 030: 100%|██████████| 59/59 [00:04<00:00, 14.07it/s, loss=0.1565]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.95it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.95it/s]


Validation: {'precision': 42.9993, 'recall': 38.1465, 'f1': 38.9522, 'auc': 71.6728, 'prauc': 41.3523}
Test:      {'precision': 40.1585, 'recall': 37.0124, 'f1': 37.8413, 'auc': 71.3992, 'prauc': 40.8537}

Validation-long: {'precision': 46.7894, 'recall': 40.807, 'f1': 41.8368, 'auc': 69.6333, 'prauc': 49.3936}
Test-long: {'precision': 39.5785, 'recall': 38.0404, 'f1': 38.5593, 'auc': 68.6759, 'prauc': 44.2242}



Epoch 031: 100%|██████████| 59/59 [00:03<00:00, 14.84it/s, loss=0.1512]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.11it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.04it/s]


Validation: {'precision': 41.9189, 'recall': 37.3165, 'f1': 38.9654, 'auc': 71.2838, 'prauc': 40.9327}
Test:      {'precision': 40.8686, 'recall': 35.8625, 'f1': 37.8453, 'auc': 71.3907, 'prauc': 41.3148}

Validation-long: {'precision': 42.8342, 'recall': 38.873, 'f1': 40.4403, 'auc': 70.215, 'prauc': 48.68}
Test-long: {'precision': 43.2065, 'recall': 36.4787, 'f1': 38.5465, 'auc': 68.6109, 'prauc': 43.8657}


Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 42.5884, 'recall': 38.9795, 'f1': 40.2376, 'auc': 71.9977, 'prauc': 41.5118}, 'per_class':                                            precision   recall       f1  \
Acute and unspecified renal failure          45.5378  45.8525  45.6946   
Acute cerebrovascular disease                 0.0000   0.0000   0.0000   
Acute myocardial infarction                  17.3333  13.6842  15.2941   
Cardiac dysrhythmias                         70.4198  67.3358  68.8433   
Chronic kidney

Epoch 001: 100%|██████████| 59/59 [00:03<00:00, 14.85it/s, loss=0.4555]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.10it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.00it/s]


Validation: {'precision': 29.1467, 'recall': 21.5835, 'f1': 22.9602, 'auc': 67.4502, 'prauc': 37.8466}
Test:      {'precision': 36.5005, 'recall': 20.4271, 'f1': 22.058, 'auc': 67.0277, 'prauc': 36.2824}

Validation-long: {'precision': 33.3014, 'recall': 24.3009, 'f1': 24.6013, 'auc': 66.8273, 'prauc': 45.4011}
Test-long: {'precision': 25.9765, 'recall': 22.5019, 'f1': 23.1757, 'auc': 64.7486, 'prauc': 40.4448}



Epoch 002: 100%|██████████| 59/59 [00:03<00:00, 14.88it/s, loss=0.3990]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.98it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.68it/s]


Validation: {'precision': 42.7923, 'recall': 23.0194, 'f1': 25.7133, 'auc': 71.4532, 'prauc': 40.4823}
Test:      {'precision': 36.404, 'recall': 22.4167, 'f1': 25.1509, 'auc': 73.0017, 'prauc': 39.6831}

Validation-long: {'precision': 48.1512, 'recall': 26.582, 'f1': 28.8338, 'auc': 67.6524, 'prauc': 47.1035}
Test-long: {'precision': 37.4837, 'recall': 23.6488, 'f1': 26.2916, 'auc': 72.2771, 'prauc': 43.1295}



Epoch 003: 100%|██████████| 59/59 [00:04<00:00, 14.64it/s, loss=0.3821]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.00it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.95it/s]


Validation: {'precision': 39.9508, 'recall': 25.1711, 'f1': 29.9127, 'auc': 72.8151, 'prauc': 41.4806}
Test:      {'precision': 39.5631, 'recall': 24.8576, 'f1': 29.457, 'auc': 74.7243, 'prauc': 41.6904}

Validation-long: {'precision': 44.0171, 'recall': 26.6833, 'f1': 30.4713, 'auc': 69.1345, 'prauc': 47.5193}
Test-long: {'precision': 44.3914, 'recall': 26.2846, 'f1': 30.0785, 'auc': 72.9082, 'prauc': 45.0925}



Epoch 004: 100%|██████████| 59/59 [00:03<00:00, 14.83it/s, loss=0.3699]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 19.74it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.19it/s]


Validation: {'precision': 41.3207, 'recall': 28.0552, 'f1': 32.0714, 'auc': 73.5829, 'prauc': 42.9506}
Test:      {'precision': 40.5394, 'recall': 27.5692, 'f1': 31.5597, 'auc': 74.9581, 'prauc': 42.6932}

Validation-long: {'precision': 45.6564, 'recall': 29.8556, 'f1': 33.8061, 'auc': 69.5118, 'prauc': 47.6394}
Test-long: {'precision': 43.2091, 'recall': 27.9708, 'f1': 32.3746, 'auc': 72.9417, 'prauc': 44.7528}



Epoch 005: 100%|██████████| 59/59 [00:03<00:00, 14.75it/s, loss=0.3571]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.89it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.76it/s]


Validation: {'precision': 39.9631, 'recall': 35.2605, 'f1': 36.0394, 'auc': 73.8818, 'prauc': 43.6531}
Test:      {'precision': 42.474, 'recall': 34.7866, 'f1': 35.9536, 'auc': 75.2026, 'prauc': 43.3153}

Validation-long: {'precision': 40.7726, 'recall': 40.1439, 'f1': 39.1774, 'auc': 69.2298, 'prauc': 47.7368}
Test-long: {'precision': 41.7823, 'recall': 39.2812, 'f1': 39.6335, 'auc': 74.4733, 'prauc': 46.604}



Epoch 006: 100%|██████████| 59/59 [00:03<00:00, 14.96it/s, loss=0.3449]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.27it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.10it/s]


Validation: {'precision': 42.5218, 'recall': 34.685, 'f1': 36.3597, 'auc': 74.3656, 'prauc': 44.4599}
Test:      {'precision': 42.7713, 'recall': 34.3382, 'f1': 36.0414, 'auc': 75.5226, 'prauc': 44.3951}

Validation-long: {'precision': 42.1388, 'recall': 36.1489, 'f1': 37.5003, 'auc': 69.7696, 'prauc': 49.233}
Test-long: {'precision': 43.156, 'recall': 36.2482, 'f1': 38.3323, 'auc': 74.125, 'prauc': 45.8406}



Epoch 007: 100%|██████████| 59/59 [00:03<00:00, 14.85it/s, loss=0.3353]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.24it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.19it/s]


Validation: {'precision': 42.917, 'recall': 32.3592, 'f1': 34.5024, 'auc': 74.1988, 'prauc': 44.2709}
Test:      {'precision': 43.9621, 'recall': 31.4796, 'f1': 33.8826, 'auc': 75.3159, 'prauc': 44.348}

Validation-long: {'precision': 40.8809, 'recall': 33.0553, 'f1': 34.1404, 'auc': 69.6611, 'prauc': 48.498}
Test-long: {'precision': 43.9962, 'recall': 31.5776, 'f1': 34.3186, 'auc': 73.0024, 'prauc': 44.9242}



Epoch 008: 100%|██████████| 59/59 [00:03<00:00, 15.04it/s, loss=0.3266]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.32it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.20it/s]


Validation: {'precision': 41.4851, 'recall': 35.2014, 'f1': 37.3655, 'auc': 74.5775, 'prauc': 44.4236}
Test:      {'precision': 42.4152, 'recall': 34.465, 'f1': 36.5126, 'auc': 75.2349, 'prauc': 43.8623}

Validation-long: {'precision': 40.1167, 'recall': 36.1616, 'f1': 36.9399, 'auc': 70.8099, 'prauc': 50.1317}
Test-long: {'precision': 40.8407, 'recall': 37.2023, 'f1': 38.1086, 'auc': 72.752, 'prauc': 45.6561}



Epoch 009: 100%|██████████| 59/59 [00:03<00:00, 15.19it/s, loss=0.3167]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.32it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 20.03it/s]


Validation: {'precision': 41.6981, 'recall': 36.6968, 'f1': 38.1571, 'auc': 74.5155, 'prauc': 44.3199}
Test:      {'precision': 41.7089, 'recall': 35.2573, 'f1': 36.7116, 'auc': 75.2509, 'prauc': 44.1252}

Validation-long: {'precision': 42.7903, 'recall': 38.2035, 'f1': 39.7164, 'auc': 71.608, 'prauc': 50.3485}
Test-long: {'precision': 41.0807, 'recall': 36.3186, 'f1': 38.2555, 'auc': 74.3077, 'prauc': 46.4982}



Epoch 010: 100%|██████████| 59/59 [00:03<00:00, 15.03it/s, loss=0.3052]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.29it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.11it/s]


Validation: {'precision': 47.3439, 'recall': 36.347, 'f1': 38.5767, 'auc': 74.4262, 'prauc': 44.7539}
Test:      {'precision': 45.6813, 'recall': 35.3457, 'f1': 37.3415, 'auc': 75.4073, 'prauc': 44.466}

Validation-long: {'precision': 43.8055, 'recall': 40.1919, 'f1': 41.3957, 'auc': 71.3699, 'prauc': 50.4531}
Test-long: {'precision': 42.1297, 'recall': 38.9016, 'f1': 39.8262, 'auc': 74.5322, 'prauc': 47.5469}



Epoch 011: 100%|██████████| 59/59 [00:04<00:00, 14.13it/s, loss=0.2946]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.09it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.16it/s]


Validation: {'precision': 47.0512, 'recall': 37.183, 'f1': 38.1722, 'auc': 73.7808, 'prauc': 43.9281}
Test:      {'precision': 41.3432, 'recall': 36.5859, 'f1': 37.1985, 'auc': 74.6998, 'prauc': 43.4862}

Validation-long: {'precision': 40.8403, 'recall': 37.5591, 'f1': 38.5587, 'auc': 68.8758, 'prauc': 49.1486}
Test-long: {'precision': 41.118, 'recall': 38.3824, 'f1': 39.2895, 'auc': 73.1444, 'prauc': 45.1469}



Epoch 012: 100%|██████████| 59/59 [00:04<00:00, 12.73it/s, loss=0.2840]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.01it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.92it/s]


Validation: {'precision': 43.5202, 'recall': 36.5766, 'f1': 37.669, 'auc': 73.9224, 'prauc': 44.4276}
Test:      {'precision': 44.2751, 'recall': 35.4564, 'f1': 36.8829, 'auc': 73.8732, 'prauc': 43.4906}

Validation-long: {'precision': 41.7175, 'recall': 39.4831, 'f1': 39.1757, 'auc': 71.3302, 'prauc': 51.3529}
Test-long: {'precision': 42.7311, 'recall': 39.084, 'f1': 39.981, 'auc': 73.4053, 'prauc': 46.29}



Epoch 013: 100%|██████████| 59/59 [00:04<00:00, 14.74it/s, loss=0.2725]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.02it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.98it/s]


Validation: {'precision': 43.086, 'recall': 37.192, 'f1': 38.568, 'auc': 73.6229, 'prauc': 43.5202}
Test:      {'precision': 48.5667, 'recall': 36.2116, 'f1': 37.9616, 'auc': 73.7783, 'prauc': 43.0383}

Validation-long: {'precision': 42.0405, 'recall': 40.3862, 'f1': 40.6932, 'auc': 70.256, 'prauc': 49.7891}
Test-long: {'precision': 40.4433, 'recall': 40.8527, 'f1': 40.3401, 'auc': 73.4243, 'prauc': 45.4957}



Epoch 014: 100%|██████████| 59/59 [00:04<00:00, 12.63it/s, loss=0.2643]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.69it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.61it/s]


Validation: {'precision': 43.6746, 'recall': 36.69, 'f1': 37.8722, 'auc': 73.737, 'prauc': 43.5932}
Test:      {'precision': 39.8397, 'recall': 35.3115, 'f1': 36.6628, 'auc': 74.0654, 'prauc': 42.8674}

Validation-long: {'precision': 43.9152, 'recall': 39.4518, 'f1': 40.8304, 'auc': 70.1474, 'prauc': 50.2605}
Test-long: {'precision': 41.0402, 'recall': 36.8123, 'f1': 38.273, 'auc': 73.5808, 'prauc': 45.4947}



Epoch 015: 100%|██████████| 59/59 [00:04<00:00, 13.90it/s, loss=0.2518]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 22.06it/s]
Running inference: 100%|██████████| 45/45 [00:02<00:00, 21.93it/s]

Validation: {'precision': 42.6193, 'recall': 37.1741, 'f1': 38.3178, 'auc': 73.2436, 'prauc': 43.0524}
Test:      {'precision': 45.0786, 'recall': 36.3465, 'f1': 37.7212, 'auc': 73.7083, 'prauc': 42.3841}

Validation-long: {'precision': 41.993, 'recall': 40.4108, 'f1': 40.5588, 'auc': 70.7476, 'prauc': 49.898}
Test-long: {'precision': 39.8843, 'recall': 39.346, 'f1': 39.1917, 'auc': 72.9955, 'prauc': 45.9408}


Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'global': {'precision': 47.3439, 'recall': 36.347, 'f1': 38.5767, 'auc': 74.4262, 'prauc': 44.7539}, 'per_class':                                            precision   recall       f1  \
Acute and unspecified renal failure          51.6854  52.9954  52.3322   
Acute cerebrovascular disease                 0.0000   0.0000   0.0000   
Acute myocardial infarction                   0.0000   0.0000   0.0000   
Cardiac dysrhythmias                         77.4127  68.7956  72.8502   
Chronic kidney




In [18]:
def topk_avg_performance_formatted(performances, long_seq_performances, k=5, metrics=("f1", "auc", "prauc")):
    """Print mean ± std of every metric over the k best-ranked runs.

    Each run is ranked independently on every metric in ``metrics`` (higher is
    better, rank 1 = best). The k runs with the lowest average rank are kept.
    The same run indices are then used to aggregate ``long_seq_performances``,
    so the two lists must be aligned run-by-run.

    Values are printed as-is with two decimals. No percent conversion is
    applied here; the metric dicts logged in this notebook already appear to be
    on a 0-100 scale. The std is the population std (ddof=0) over the selected
    runs.

    NOTE(review): if ``performances`` holds test-set metrics, picking the top-k
    runs by those same metrics gives an optimistic estimate. Consider ranking
    on validation metrics instead.

    Args:
        performances: list of dicts mapping metric name -> float, one per run.
        long_seq_performances: list of dicts with the same layout, aligned
            index-by-index with ``performances``.
        k: number of top-ranked runs to aggregate (capped at the number of runs).
        metrics: metric names used for ranking; each must be a key of every
            dict in ``performances``.

    Raises:
        ValueError: if ``performances`` is empty, ``metrics`` is empty, ``k < 1``,
            or the two lists differ in length.
    """
    if not performances:
        raise ValueError("performances must contain at least one run")
    if len(long_seq_performances) != len(performances):
        raise ValueError("performances and long_seq_performances must have the same length")
    if k < 1:
        raise ValueError("k must be >= 1")
    if not metrics:
        raise ValueError("metrics must name at least one metric to rank on")

    # Per-metric rank of each run: negate so the highest score sorts first, then
    # a double argsort converts sort order into 0-based ranks (+1 -> 1-based).
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}
    ranks = np.stack(
        [(-scores[m]).argsort(kind="stable").argsort(kind="stable") + 1 for m in metrics],
        axis=1,
    )
    avg_ranks = ranks.mean(axis=1)

    # Stable sort so runs tied on average rank are ordered by their position in the list.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    def _mean_std(runs):
        """Mean and population std (ddof=0) of every metric of ``runs`` over the top-k indices."""
        keys = runs[0].keys()
        selected = [runs[i] for i in topk_idx]
        means = {m: np.mean([r[m] for r in selected]) for m in keys}
        stds = {m: np.std([r[m] for r in selected], ddof=0) for m in keys}
        return means, stds

    final_avg, final_std = _mean_std(performances)
    final_long_seq_avg, final_long_seq_std = _mean_std(long_seq_performances)

    # Print results with two decimals.
    print("Final Metrics:")
    for m in performances[0].keys():
        print(f"{m}: {final_avg[m]:.2f} ± {final_std[m]:.2f}")
    print("\nFinal Long Sequence Metrics:")
    for m in long_seq_performances[0].keys():
        print(f"{m}: {final_long_seq_avg[m]:.2f} ± {final_long_seq_std[m]:.2f}")

In [30]:
def print_per_class_performance(dfs, col_name="prauc", ddof=1):
    """Print mean ± std of one metric column for every class across several tables.

    Rows of all tables are stacked and grouped by their index label, so each
    label (e.g. a disease name) is aggregated over every table in ``dfs``.
    Labels are printed in sorted order.

    Args:
        dfs (Iterable[pd.DataFrame]): per-run tables, each indexed by class
            label and containing ``col_name``.
        col_name (str): metric column to aggregate (default: "prauc").
        ddof (int): delta degrees of freedom for the std (default 1, the
            sample std). Pass 0 for the population std. With ddof=1 a label
            that appears only once yields an std of NaN.

    Raises:
        ValueError: if ``dfs`` is empty.
        KeyError: if ``col_name`` is not a column of the stacked tables.
    """
    # Materialize so generators work and an empty input gets a clear message.
    frames = list(dfs)
    if not frames:
        raise ValueError("dfs must contain at least one DataFrame")

    # Stack all tables; rows sharing an index label belong to the same class.
    all_values = pd.concat(frames, axis=0)
    per_class = all_values.groupby(all_values.index)[col_name]
    means = per_class.mean()
    stds = per_class.std(ddof=ddof)

    # Both results come from the same groupby, so their row order is identical.
    for disease, mean_val, std_val in zip(means.index, means.to_numpy(), stds.to_numpy()):
        print(f"{disease}: {mean_val:.2f} ± {std_val:.2f}")

In [31]:
# Summarize the collected run metrics. For binary tasks each run's entry is
# passed straight to the top-k summary. Otherwise each entry is indexed with
# "global" (a metric dict) and "per_class" (a per-class table).
# NOTE(review): the global summary averages only the top-k runs (population std),
# while the per-class summary below aggregates all runs (pandas default std,
# ddof=1), so the two sets of numbers are not directly comparable.
if task_type != "binary":
    final_metrics_global, final_metrics_per_class = (
        [run_metrics[part] for run_metrics in final_metrics] for part in ("global", "per_class")
    )
    final_long_seq_metrics_global, final_long_seq_metrics_per_class = (
        [run_metrics[part] for run_metrics in final_long_seq_metrics] for part in ("global", "per_class")
    )

    topk_avg_performance_formatted(final_metrics_global, final_long_seq_metrics_global)

    print("\nPer-class performance, all patients:")
    print_per_class_performance(final_metrics_per_class, col_name="prauc")

    print("\nPer-class performance, long seq:")
    print_per_class_performance(final_long_seq_metrics_per_class, col_name="prauc")
else:
    topk_avg_performance_formatted(final_metrics, final_long_seq_metrics)

Final Metrics:
precision: 43.53 ± 2.15
recall: 36.31 ± 0.96
f1: 38.11 ± 0.77
auc: 73.47 ± 1.93
prauc: 42.67 ± 1.80

Final Long Sequence Metrics:
precision: 42.76 ± 0.63
recall: 39.19 ± 0.29
f1: 40.04 ± 0.22
auc: 72.06 ± 2.47
prauc: 45.96 ± 1.59

Per-class performance, all patients:
Acute and unspecified renal failure: 45.65 ± 1.58
Acute cerebrovascular disease: 7.40 ± 3.99
Acute myocardial infarction: 15.76 ± 5.01
Cardiac dysrhythmias: 74.58 ± 3.39
Chronic kidney disease: 79.53 ± 0.77
Chronic obstructive pulmonary disease: 44.01 ± 8.65
Conduction disorders: 4.93 ± 0.86
Congestive heart failure; nonhypertensive: 72.27 ± 3.32
Coronary atherosclerosis and related: 58.88 ± 2.43
Disorders of lipid metabolism: 54.73 ± 4.33
Essential hypertension: 62.66 ± 0.17
Fluid and electrolyte disorders: 49.45 ± 0.18
Gastrointestinal hemorrhage: 11.33 ± 2.68
Hypertension with complications: 71.66 ± 2.02
Other liver diseases: 3.08 ± 1.60
Other lower respiratory disease: 57.31 ± 4.18
Pneumonia: 17.63 ± 0.1