In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Seed the project's RNG helper once, before any data shuffling or weight initialisation.
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration: every knob a reader may want to tune lives in this one object.
config = Namespace(
    dataset="MIMIC-III",
    # Downstream tasks available in this notebook; `task_index` selects the one to train.
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=0,
    # Clinical code types fed to the tokenizer and datasets, plus special tokens.
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]"],
    # Optimisation settings.
    batch_size=32,
    lr=1e-3,
    epochs=500,  # upper bound only; early stopping (patience below) normally ends training sooner
    early_stop_patience=5,
    # A group-code token is generated when `group_code_thre` diagnosis codes belong to the
    # same group ICD code (exact comparison is implemented in FineTuneEHRDataset).
    group_code_thre=5,
)

In [5]:
data_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{data_dir}/mimic.pkl"  # full cohort, used to build the tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
# Next-diagnosis tasks have dedicated split files; every other task shares mimic_downstream.pkl.
finetune_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{data_dir}/{finetune_files.get(curr_task, 'mimic_downstream.pkl')}"

Current task: death


In [6]:
# Load the full cohort; it is used to build the tokenizer vocabularies and the max admission count.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open(full_data_path, "rb") as fh:  # context manager closes the handle (bare open() leaked it)
    ehr_full_data = pickle.load(fh)
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# One single-token "sentence" per distinct age. The [[...]] nesting matters — presumably
# EHRTokenizer expects a list of token lists like the other *_sentences (confirm).
# sorted() makes the age-id assignment deterministic instead of following set iteration order.
age_sentences = [[str(age)] for age in sorted(set(ehr_full_data["AGE"].values.tolist()))]
token_type_sentences = ["[PAD]"] + config.token_type
# Longest patient history, counted in distinct admissions; cast the numpy integer to a plain int.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks
# Build the joint vocabulary over ages, group codes and the four clinical code types.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)
# Vocabulary sizes the model needs. label_vocab_size covers the diagnosis vocabulary only.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")
for vocab_label, vocab_size in (
    ("Age", config.age_vocab_size),
    ("Group code", config.group_code_vocab_size),
):
    print(f"{vocab_label} vocabulary size: {vocab_size}")


Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pre-made train/val/test splits for the current task.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open(finetune_data_path, "rb") as fh:  # closes the handle (bare open() leaked it)
    train_data, val_data, test_data = pickle.load(fh)

# Sanity check: positive-label prevalence in the test split for the binary tasks.
# The next-diagnosis split files may not carry these columns (TODO confirm), so absent
# columns are skipped rather than raising KeyError and breaking Restart & Run All.
label_checks = [
    ("DEATH", "Percentage of DEATH in test dataset:", lambda s: s == True),  # noqa: E712
    ("READMISSION", "Percentage of READMISSION in test dataset:", lambda s: s == 1),
    ("STAY_DAYS", "Percentage of STAY>7 days in test dataset:", lambda s: s > 7),
]
for column, message, is_positive in label_checks:
    if column in test_data:
        print(message, is_positive(test_data[column]).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# All three splits share the tokenizer and dataset settings; only the underlying data differs.
dataset_kwargs = dict(
    token_type=config.token_type,
    task=curr_task,
    max_num_adms=config.max_num_adms,
    group_code_thre=config.group_code_thre,
)
train_dataset = FineTuneEHRDataset(train_data, tokenizer, **dataset_kwargs)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, **dataset_kwargs)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, **dataset_kwargs)

In [10]:
# Token-type id counted as a group-code token below.
# NOTE(review): 5 presumably follows [PAD]=0 and diag/med/lab/pro=1..4 — confirm in the dataset code.
GROUP_TOKEN_TYPE_ID = 5

# Average number of generated group-code tokens per training patient.
num_group_code = []
for idx in range(len(train_dataset)):
    # Only token types are needed; the other fields are unpacked to keep the 6-field arity check.
    _input_ids, token_types, _adm_index, _age_ids, _diag_group_codes, _labels = train_dataset[idx]
    # NOTE(review): only token_types[0] is inspected — confirm that row spans the whole sequence.
    num_group_code.append((token_types[0] == GROUP_TOKEN_TYPE_ID).sum().item())
print("Mean group token number per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader that collates with its own fine-tuning batcher."""
    collate = batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain=False)
    return DataLoader(
        dataset,
        collate_fn=collate,
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled; val/test keep a fixed order.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [12]:
# Smoke test: iterate every split once so collation errors surface before training starts.
# NOTE(review): iterating the shuffled train loader presumably draws from torch's global RNG,
# so removing this cell would change the batch order seen during training.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass

In [13]:
# Task-type dispatch. Both task families are scored with F1 and trained with BCE-with-logits;
# they differ only in the task_type passed to train_with_early_stopping.
BINARY_TASKS = {"death", "stay", "readmission"}
NEXT_DIAG_TASKS = {"next_diag_6m", "next_diag_12m"}

eval_metric = "f1"
loss_fn = F.binary_cross_entropy_with_logits
if curr_task in BINARY_TASKS:
    task_type = "binary"
elif curr_task in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    # Previously any unrecognised task silently fell through to "l2r".
    raise ValueError(f"Unknown task {curr_task!r}; expected one of {sorted(BINARY_TASKS | NEXT_DIAG_TASKS)}")

In [14]:
# Peek at a single training batch to check tensor shapes before building the model.
(input_ids, token_types, adm_index, age_ids,
 diag_code_group_dicts, task_index, labels) = next(iter(train_dataloader))
for caption, value in (
    ("Input IDs shape:", input_ids.shape),
    ("Token Types shape:", token_types.shape),
    ("Admission Index shape:", adm_index.shape),
    ("Age IDs shape:", age_ids.shape),
    ("Diag Code Group Dict number:", len(diag_code_group_dicts)),
    ("Task Index:", task_index),
    ("Labels shape:", labels.shape),
):
    print(caption, value)

Input IDs shape: torch.Size([32, 292])
Token Types shape: torch.Size([32, 292])
Admission Index shape: torch.Size([32, 292])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Task Index: 0
Labels shape: torch.Size([32, 1])


# Model Training and Evaluation (5 independent runs)

In [15]:
# Train 5 freshly initialised models and record, for each run, the test metrics that
# train_with_early_stopping returns (per its log: test performance at the best-validation epoch).
final_metrics = {metric: [] for metric in ("precision", "recall", "f1", "auc", "prauc")}
model_kwargs = dict(
    d_model=128,
    num_heads=4,
    max_num_adms=config.max_num_adms,
    device=device,
    task=curr_task,
    label_vocab_size=config.label_vocab_size,
)
for run_idx in range(5):
    model = HeteroGT(tokenizer, layer_types=["gnn", "tf"], **model_kwargs).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    for metric, history in final_metrics.items():
        history.append(best_test_metric[metric])

Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 21.80it/s, loss=0.5206]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.43it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.00it/s]


Validation: {'precision': 0.7319587628803097, 'recall': 0.502062463167342, 'f1': 0.5955959406429828, 'auc': 0.85450312655646, 'prauc': 0.7037385391621366}
Test:      {'precision': 0.7547008546944043, 'recall': 0.48892580287658405, 'f1': 0.5934139737189948, 'auc': 0.8528156601654384, 'prauc': 0.7229726690227163}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 22.68it/s, loss=0.3962]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.13it/s]


Validation: {'precision': 0.7577125658332754, 'recall': 0.5934001178515416, 'f1': 0.665565097515022, 'auc': 0.8814601991425682, 'prauc': 0.7579248951594755}
Test:      {'precision': 0.7841409691572384, 'recall': 0.5913621262425728, 'f1': 0.67424241933638, 'auc': 0.8805681816503164, 'prauc': 0.7738165989497076}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 23.38it/s, loss=0.3480]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.87it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.21it/s]


Validation: {'precision': 0.7820224719042546, 'recall': 0.6152032999374473, 'f1': 0.6886543486287364, 'auc': 0.8947735646524815, 'prauc': 0.786268185614024}
Test:      {'precision': 0.7910983488816146, 'recall': 0.6101882613476735, 'f1': 0.6889652967357982, 'auc': 0.8955238873238605, 'prauc': 0.8030081980070783}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 23.66it/s, loss=0.3030]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.21it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.06it/s]


Validation: {'precision': 0.7170818505295548, 'recall': 0.7124337065367565, 'f1': 0.7147502166925463, 'auc': 0.8980640783438365, 'prauc': 0.7901937262702552}
Test:      {'precision': 0.7429742388715282, 'recall': 0.7026578073050794, 'f1': 0.7222538367755323, 'auc': 0.8966181967425414, 'prauc': 0.7973859982436947}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 23.45it/s, loss=0.2767]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.34it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.17it/s]


Validation: {'precision': 0.6912568305973156, 'recall': 0.7454331172613705, 'f1': 0.717323499397712, 'auc': 0.8952991603115941, 'prauc': 0.7899319368245424}
Test:      {'precision': 0.7090032154302841, 'recall': 0.7325581395308275, 'f1': 0.720588230291528, 'auc': 0.8964172337945695, 'prauc': 0.8030279570415557}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.59it/s, loss=0.2498]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.48it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.05it/s]


Validation: {'precision': 0.662222222218952, 'recall': 0.7902180318162039, 'f1': 0.7205803281891761, 'auc': 0.9018720121876366, 'prauc': 0.8004580794691909}
Test:      {'precision': 0.6688836104481288, 'recall': 0.7796234772935791, 'f1': 0.7200204501521079, 'auc': 0.9012630567443043, 'prauc': 0.8139168772055492}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.27it/s, loss=0.2110]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.36it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.23it/s]


Validation: {'precision': 0.7620643431584313, 'recall': 0.670005892747967, 'f1': 0.7130761944517495, 'auc': 0.9031215872848388, 'prauc': 0.8084194975010804}
Test:      {'precision': 0.7841207349029914, 'recall': 0.661683277958684, 'f1': 0.7177177127492647, 'auc': 0.9065591844191118, 'prauc': 0.8203935353242968}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.69it/s, loss=0.2019]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.46it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.24it/s]


Validation: {'precision': 0.6422613531017505, 'recall': 0.8167354154341973, 'f1': 0.7190661429276946, 'auc': 0.9025319288664493, 'prauc': 0.8004737379364466}
Test:      {'precision': 0.649576460095187, 'recall': 0.8067552602391653, 'f1': 0.7196838676158136, 'auc': 0.9016055246899446, 'prauc': 0.810266386172872}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 23.65it/s, loss=0.1822]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.36it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.21it/s]


Validation: {'precision': 0.7116170461137676, 'recall': 0.7183264584518662, 'f1': 0.7149560067261223, 'auc': 0.8994435034805304, 'prauc': 0.7977673500354447}
Test:      {'precision': 0.7220338983010055, 'recall': 0.7076411960093708, 'f1': 0.7147650956676502, 'auc': 0.89859194217979, 'prauc': 0.8078909039391132}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 23.48it/s, loss=0.1489]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.35it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.61it/s]


Validation: {'precision': 0.7246200607858686, 'recall': 0.7024160282810701, 'f1': 0.7133452972111846, 'auc': 0.8929647338023095, 'prauc': 0.7941846406590323}
Test:      {'precision': 0.723076923072803, 'recall': 0.7026578073050794, 'f1': 0.7127211407425997, 'auc': 0.8936289882989708, 'prauc': 0.8065980149387668}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 23.47it/s, loss=0.1299]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.20it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.13it/s]


Validation: {'precision': 0.7029876977111718, 'recall': 0.7071302298131578, 'f1': 0.7050528739618233, 'auc': 0.8954678440077559, 'prauc': 0.7874686210856655}
Test:      {'precision': 0.7274826789796335, 'recall': 0.6976744186007882, 'f1': 0.7122668124091278, 'auc': 0.8936825989353211, 'prauc': 0.8010876219774745}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.662222222218952, 'recall': 0.7902180318162039, 'f1': 0.7205803281891761, 'auc': 0.9018720121876366, 'prauc': 0.8004580794691909}
Corresponding test performance:
{'precision': 0.6688836104481288, 'recall': 0.7796234772935791, 'f1': 0.7200204501521079, 'auc': 0.9012630567443043, 'prauc': 0.8139168772055492}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 23.61it/s, loss=0.5352]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.02it/s]


Validation: {'precision': 0.38615346163362624, 'recall': 0.9104301708844407, 'f1': 0.5422955381111433, 'auc': 0.8416761398221138, 'prauc': 0.6906384568136634}
Test:      {'precision': 0.41144548593879765, 'recall': 0.9235880398619957, 'f1': 0.569283272184382, 'auc': 0.8458366755348876, 'prauc': 0.7081651165984572}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 23.38it/s, loss=0.4220]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.30it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.24it/s]


Validation: {'precision': 0.6451282051248968, 'recall': 0.7413081909207937, 'f1': 0.6898820898927772, 'auc': 0.8821189022696101, 'prauc': 0.7517186783017301}
Test:      {'precision': 0.6636178861754898, 'recall': 0.7231450719782772, 'f1': 0.692103863580002, 'auc': 0.8771217748062754, 'prauc': 0.762004104966062}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 23.46it/s, loss=0.3621]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.38it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.75it/s]


Validation: {'precision': 0.7218899521487926, 'recall': 0.7112551561537346, 'f1': 0.7165330908701685, 'auc': 0.8982516678991669, 'prauc': 0.7951543177229051}
Test:      {'precision': 0.7382744643848392, 'recall': 0.7059800664412736, 'f1': 0.721766199357316, 'auc': 0.8958346566728069, 'prauc': 0.7994842147868365}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 23.55it/s, loss=0.3196]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.38it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.12it/s]


Validation: {'precision': 0.8404740200470331, 'recall': 0.5433117265731096, 'f1': 0.6599856788335856, 'auc': 0.9071298210419914, 'prauc': 0.8073367595431001}
Test:      {'precision': 0.8528634361158339, 'recall': 0.5359911406393356, 'f1': 0.6582794920256002, 'auc': 0.9032880123416485, 'prauc': 0.8116292923047754}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 23.45it/s, loss=0.2687]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.47it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.38it/s]


Validation: {'precision': 0.6946658491680279, 'recall': 0.6676487919819232, 'f1': 0.6808894180747979, 'auc': 0.8852418180742825, 'prauc': 0.7763218327075245}
Test:      {'precision': 0.7370689655127028, 'recall': 0.6627906976707487, 'f1': 0.6979591786834773, 'auc': 0.884676442883191, 'prauc': 0.7829639458686638}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.70it/s, loss=0.2436]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.38it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.20it/s]


Validation: {'precision': 0.7843551796984894, 'recall': 0.6558632881517038, 'f1': 0.7143774019671771, 'auc': 0.9087198293465176, 'prauc': 0.8118720141246567}
Test:      {'precision': 0.7865319865266901, 'recall': 0.64673311184581, 'f1': 0.7098146410475097, 'auc': 0.9042851947980326, 'prauc': 0.815570682190862}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.47it/s, loss=0.2269]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.39it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.62it/s]


Validation: {'precision': 0.892768079789367, 'recall': 0.4219210371218508, 'f1': 0.5730292073214209, 'auc': 0.9010447531067245, 'prauc': 0.799569678275225}
Test:      {'precision': 0.9099639855833137, 'recall': 0.4197120708725376, 'f1': 0.5744600184112307, 'auc': 0.899365387939121, 'prauc': 0.8118034812891033}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.56it/s, loss=0.2125]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.38it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.15it/s]


Validation: {'precision': 0.8365051903041826, 'recall': 0.5698291101911029, 'f1': 0.6778829254238969, 'auc': 0.9064964747598476, 'prauc': 0.8126691245906593}
Test:      {'precision': 0.8406374501925049, 'recall': 0.5841638981141519, 'f1': 0.6893172117533917, 'auc': 0.9050745821817202, 'prauc': 0.8197546079593385}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7218899521487926, 'recall': 0.7112551561537346, 'f1': 0.7165330908701685, 'auc': 0.8982516678991669, 'prauc': 0.7951543177229051}
Corresponding test performance:
{'precision': 0.7382744643848392, 'recall': 0.7059800664412736, 'f1': 0.721766199357316, 'auc': 0.8958346566728069, 'prauc': 0.7994842147868365}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 23.43it/s, loss=0.5469]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.23it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.20it/s]


Validation: {'precision': 0.5295068714610772, 'recall': 0.7719505008793639, 'f1': 0.6281467225740013, 'auc': 0.8314506248578069, 'prauc': 0.6733554521073081}
Test:      {'precision': 0.545052292837711, 'recall': 0.7502768549238634, 'f1': 0.6314072644608688, 'auc': 0.8299225643290706, 'prauc': 0.6805412439789148}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 23.62it/s, loss=0.4142]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.22it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.01it/s]


Validation: {'precision': 0.7803163444571151, 'recall': 0.5232763700617368, 'f1': 0.6264550216450033, 'auc': 0.874649810475149, 'prauc': 0.7414937972218506}
Test:      {'precision': 0.8046471600619222, 'recall': 0.5177187153902674, 'f1': 0.6300539035869537, 'auc': 0.8823420720517103, 'prauc': 0.7679706593262169}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 23.42it/s, loss=0.3540]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.08it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.02it/s]


Validation: {'precision': 0.8230008984652022, 'recall': 0.5397760754240438, 'f1': 0.6519572905849902, 'auc': 0.8948735229957203, 'prauc': 0.7877515211605312}
Test:      {'precision': 0.8501314636209455, 'recall': 0.5370985603514004, 'f1': 0.6582965680360133, 'auc': 0.8961001862769573, 'prauc': 0.7958182959041842}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 21.41it/s, loss=0.2986]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.04it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.97it/s]


Validation: {'precision': 0.7241165530023304, 'recall': 0.6882734236848069, 'f1': 0.7057401762678381, 'auc': 0.8898224011693019, 'prauc': 0.7827386164805786}
Test:      {'precision': 0.7451219512149688, 'recall': 0.676633444071558, 'f1': 0.7092280855472428, 'auc': 0.891857744576524, 'prauc': 0.7977257606457258}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 23.23it/s, loss=0.2858]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 39.93it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.57it/s]


Validation: {'precision': 0.7647887323889804, 'recall': 0.6399528579809078, 'f1': 0.6968238641399248, 'auc': 0.8996801460758653, 'prauc': 0.7946714392199936}
Test:      {'precision': 0.7838205302462011, 'recall': 0.6384274640053244, 'f1': 0.703692396634775, 'auc': 0.9016887412001002, 'prauc': 0.8064255613797453}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.52it/s, loss=0.2605]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 39.82it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.26it/s]


Validation: {'precision': 0.6716188524555757, 'recall': 0.7725397760708749, 'f1': 0.7185530232473907, 'auc': 0.9000325870586072, 'prauc': 0.812484283760276}
Test:      {'precision': 0.6904761904728355, 'recall': 0.7868217054219999, 'f1': 0.7355072413942713, 'auc': 0.9021988731795157, 'prauc': 0.81897862808027}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.44it/s, loss=0.2384]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.28it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.11it/s]


Validation: {'precision': 0.8507853403067123, 'recall': 0.5745433117231907, 'f1': 0.6858951763296613, 'auc': 0.9074286741333739, 'prauc': 0.8167149281187841}
Test:      {'precision': 0.8478081058656096, 'recall': 0.5675526024331808, 'f1': 0.6799336601998214, 'auc': 0.9073690681769724, 'prauc': 0.8242518576140675}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.50it/s, loss=0.2195]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.71it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.26it/s]


Validation: {'precision': 0.6674828599379653, 'recall': 0.8031820860294451, 'f1': 0.7290719394088222, 'auc': 0.908643311713802, 'prauc': 0.8156776191166395}
Test:      {'precision': 0.6713352007438314, 'recall': 0.7962347729745503, 'f1': 0.7284701064813601, 'auc': 0.905260095910721, 'prauc': 0.8118244650492015}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 23.64it/s, loss=0.1980]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.19it/s]


Validation: {'precision': 0.810730253347163, 'recall': 0.6411314083639297, 'f1': 0.7160250032899065, 'auc': 0.910335705465109, 'prauc': 0.8201993152929735}
Test:      {'precision': 0.8206442166850613, 'recall': 0.6207087486122884, 'f1': 0.7068095789505152, 'auc': 0.9090651585274518, 'prauc': 0.8178575777137821}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 23.45it/s, loss=0.1757]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.40it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.33it/s]


Validation: {'precision': 0.7595269382341687, 'recall': 0.6812021213866754, 'f1': 0.7182354718664818, 'auc': 0.9023085075983031, 'prauc': 0.8048477745393762}
Test:      {'precision': 0.7674870466271536, 'recall': 0.6561461793983603, 'f1': 0.7074626815935238, 'auc': 0.8972612781760516, 'prauc': 0.7994014302326475}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 23.50it/s, loss=0.1876]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.40it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.84it/s]


Validation: {'precision': 0.6892685531196515, 'recall': 0.7607542722406556, 'f1': 0.7232492947279885, 'auc': 0.9005912679712647, 'prauc': 0.8064251585951279}
Test:      {'precision': 0.7030174695568925, 'recall': 0.7353266888109894, 'f1': 0.7188091966224484, 'auc': 0.8956682236524957, 'prauc': 0.7992823594170982}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 23.30it/s, loss=0.1455]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.55it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.29it/s]


Validation: {'precision': 0.7812499999945747, 'recall': 0.6629345904498354, 'f1': 0.7172457712483039, 'auc': 0.9047174717347826, 'prauc': 0.8089277534125626}
Test:      {'precision': 0.7888662593293356, 'recall': 0.6434108527096157, 'f1': 0.7087526635437373, 'auc': 0.900440124243727, 'prauc': 0.8113343501248347}


Epoch 013: 100%|██████████| 98/98 [00:04<00:00, 23.66it/s, loss=0.1344]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.22it/s]


Validation: {'precision': 0.7830802602980255, 'recall': 0.6381850324063749, 'f1': 0.7032467482941539, 'auc': 0.9017198072471012, 'prauc': 0.8044038597549112}
Test:      {'precision': 0.7948164146811029, 'recall': 0.6112956810597382, 'f1': 0.6910798072874196, 'auc': 0.8925384334713545, 'prauc': 0.7937502305119323}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6674828599379653, 'recall': 0.8031820860294451, 'f1': 0.7290719394088222, 'auc': 0.908643311713802, 'prauc': 0.8156776191166395}
Corresponding test performance:
{'precision': 0.6713352007438314, 'recall': 0.7962347729745503, 'f1': 0.7284701064813601, 'auc': 0.905260095910721, 'prauc': 0.8118244650492015}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 23.66it/s, loss=0.5272]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.50it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.57it/s]


Validation: {'precision': 0.6839714471078263, 'recall': 0.6210960518525569, 'f1': 0.6510191426295734, 'auc': 0.8564075725119538, 'prauc': 0.7121999271124206}
Test:      {'precision': 0.7056222362558079, 'recall': 0.6184939091881589, 'f1': 0.6591914969357288, 'auc': 0.8587261866600487, 'prauc': 0.7233953163431541}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 23.49it/s, loss=0.4030]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.52it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.34it/s]


Validation: {'precision': 0.6770428015526568, 'recall': 0.7177371832603552, 'f1': 0.696796333673039, 'auc': 0.8846152421335978, 'prauc': 0.7659363397925468}
Test:      {'precision': 0.6933911159225711, 'recall': 0.7087486157214355, 'f1': 0.700985756223486, 'auc': 0.8872438445633768, 'prauc': 0.7800784284404747}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 23.87it/s, loss=0.3507]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.84it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.32it/s]


Validation: {'precision': 0.7209976798102031, 'recall': 0.7324690630481293, 'f1': 0.7266880978899539, 'auc': 0.9012805015060178, 'prauc': 0.8010936114999369}
Test:      {'precision': 0.7344632768320087, 'recall': 0.7198228128420829, 'f1': 0.7270693462268655, 'auc': 0.8986213634015758, 'prauc': 0.805612638362415}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 23.48it/s, loss=0.3161]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.47it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.39it/s]


Validation: {'precision': 0.7945425361091931, 'recall': 0.5833824395958552, 'f1': 0.6727828697305851, 'auc': 0.8970510180996776, 'prauc': 0.7947204903933824}
Test:      {'precision': 0.8108527131720089, 'recall': 0.5791805094098607, 'f1': 0.6757105894497694, 'auc': 0.8952927029953511, 'prauc': 0.800583173746888}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 23.57it/s, loss=0.2927]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.50it/s]


Validation: {'precision': 0.7456423498983497, 'recall': 0.6806128461951644, 'f1': 0.7116450966695955, 'auc': 0.8970238728626958, 'prauc': 0.7993101267473757}
Test:      {'precision': 0.760197775026204, 'recall': 0.681063122919817, 'f1': 0.7184579389361109, 'auc': 0.8969833153359165, 'prauc': 0.8049039972013722}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.63it/s, loss=0.2737]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.33it/s]


Validation: {'precision': 0.7262334536659072, 'recall': 0.7112551561537346, 'f1': 0.7186662647193959, 'auc': 0.9011634256956941, 'prauc': 0.8050638031456672}
Test:      {'precision': 0.7647761193984193, 'recall': 0.7093023255774679, 'f1': 0.7359953986225021, 'auc': 0.9050035527048567, 'prauc': 0.815992023452804}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.53it/s, loss=0.2411]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.59it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.81it/s]


Validation: {'precision': 0.666009852213468, 'recall': 0.7967000589228245, 'f1': 0.7255164962434276, 'auc': 0.9001642893613282, 'prauc': 0.7998692330065725}
Test:      {'precision': 0.69147894221013, 'recall': 0.7818383167177085, 'f1': 0.7338877289027268, 'auc': 0.901218740259606, 'prauc': 0.8124445414488507}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.65it/s, loss=0.2210]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.49it/s]


Validation: {'precision': 0.7087155963262115, 'recall': 0.7283441367075525, 'f1': 0.7183958101667661, 'auc': 0.901583059123859, 'prauc': 0.8028334453389547}
Test:      {'precision': 0.7382966723026605, 'recall': 0.7248062015463743, 'f1': 0.7314892378015904, 'auc': 0.902078480062752, 'prauc': 0.8116271966467934}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7209976798102031, 'recall': 0.7324690630481293, 'f1': 0.7266880978899539, 'auc': 0.9012805015060178, 'prauc': 0.8010936114999369}
Corresponding test performance:
{'precision': 0.7344632768320087, 'recall': 0.7198228128420829, 'f1': 0.7270693462268655, 'auc': 0.8986213634015758, 'prauc': 0.805612638362415}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 23.62it/s, loss=0.5357]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.22it/s]


Validation: {'precision': 0.9107142856600765, 'recall': 0.09015910430117761, 'f1': 0.16407506538305028, 'auc': 0.8321836739986057, 'prauc': 0.6784198957639137}
Test:      {'precision': 0.8780487804342653, 'recall': 0.0797342192686615, 'f1': 0.14619289187316345, 'auc': 0.8285401977598509, 'prauc': 0.687549661805976}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 23.54it/s, loss=0.4225]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.62it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.45it/s]


Validation: {'precision': 0.5990220048870464, 'recall': 0.7218621096009319, 'f1': 0.6547300859002467, 'auc': 0.8656823015992952, 'prauc': 0.7345990091017726}
Test:      {'precision': 0.611243449258641, 'recall': 0.7104097452895326, 'f1': 0.657106269032466, 'auc': 0.8617484478151235, 'prauc': 0.7367782773307254}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 23.65it/s, loss=0.3621]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.53it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.72it/s]


Validation: {'precision': 0.7095975232154205, 'recall': 0.6753093694715657, 'f1': 0.6920289805061325, 'auc': 0.8925864250055024, 'prauc': 0.7767299236793677}
Test:      {'precision': 0.7271084337305597, 'recall': 0.6683277962310724, 'f1': 0.6964800873303003, 'auc': 0.8893389063774868, 'prauc': 0.7810618265264988}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 23.54it/s, loss=0.3287]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.50it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.39it/s]


Validation: {'precision': 0.7264573990984853, 'recall': 0.6682380671734341, 'f1': 0.6961325916895222, 'auc': 0.8939081744975033, 'prauc': 0.7866094927549006}
Test:      {'precision': 0.7422810333916681, 'recall': 0.6522702104061336, 'f1': 0.6943707583530178, 'auc': 0.892896289085293, 'prauc': 0.7926250043535713}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 23.65it/s, loss=0.3111]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.17it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.27it/s]


Validation: {'precision': 0.7683772538088185, 'recall': 0.652916912194149, 'f1': 0.7059573062741182, 'auc': 0.9033437311300684, 'prauc': 0.8014848987379105}
Test:      {'precision': 0.7868525896362095, 'recall': 0.6561461793983603, 'f1': 0.71557970518163, 'auc': 0.9034663861925591, 'prauc': 0.8107890916856701}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.55it/s, loss=0.2568]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.45it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.34it/s]


Validation: {'precision': 0.7602564102515369, 'recall': 0.6988803771320042, 'f1': 0.7282775510375341, 'auc': 0.9064318371602583, 'prauc': 0.8181003354321332}
Test:      {'precision': 0.7710396039556248, 'recall': 0.6899224806163349, 'f1': 0.728229100797248, 'auc': 0.9060661019761707, 'prauc': 0.8235721689873131}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.73it/s, loss=0.2515]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.37it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.48it/s]


Validation: {'precision': 0.8025241276852078, 'recall': 0.637006482023353, 'f1': 0.7102496665463243, 'auc': 0.9056766247318528, 'prauc': 0.8170578986589668}
Test:      {'precision': 0.8155061019324085, 'recall': 0.629014396452774, 'f1': 0.7102219394365092, 'auc': 0.9040088938260736, 'prauc': 0.8226498535033857}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.67it/s, loss=0.2140]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.43it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.37it/s]


Validation: {'precision': 0.840197693568038, 'recall': 0.6010606953411841, 'f1': 0.7007901016254507, 'auc': 0.9132012205520486, 'prauc': 0.82503289835586}
Test:      {'precision': 0.8468185388778726, 'recall': 0.5968992248028965, 'f1': 0.7002273416863686, 'auc': 0.9096237308866695, 'prauc': 0.8301596325994949}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 23.59it/s, loss=0.2049]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.40it/s]


Validation: {'precision': 0.795950155757041, 'recall': 0.6022392457242061, 'f1': 0.6856759427599398, 'auc': 0.9025074662175929, 'prauc': 0.8035881319091398}
Test:      {'precision': 0.813622754484928, 'recall': 0.6018826135071879, 'f1': 0.6919159721921355, 'auc': 0.9005392823782392, 'prauc': 0.8131600980550772}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 23.72it/s, loss=0.1921]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.75it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.53it/s]


Validation: {'precision': 0.7380352644789797, 'recall': 0.6906305244508508, 'f1': 0.7135464181366251, 'auc': 0.9031528841463001, 'prauc': 0.81200070115611}
Test:      {'precision': 0.7437097717920205, 'recall': 0.7037652270171442, 'f1': 0.7231863392386688, 'auc': 0.8996094979090005, 'prauc': 0.8117225846421472}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 23.61it/s, loss=0.1676]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.85it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.51it/s]

Validation: {'precision': 0.768384879719805, 'recall': 0.6588096641092587, 'f1': 0.7093908579691346, 'auc': 0.9014836117497871, 'prauc': 0.8055963070438528}
Test:      {'precision': 0.777044854876141, 'recall': 0.6522702104061336, 'f1': 0.7092113135166757, 'auc': 0.895518963270005, 'prauc': 0.8035898384438117}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7602564102515369, 'recall': 0.6988803771320042, 'f1': 0.7282775510375341, 'auc': 0.9064318371602583, 'prauc': 0.8181003354321332}
Corresponding test performance:
{'precision': 0.7710396039556248, 'recall': 0.6899224806163349, 'f1': 0.728229100797248, 'auc': 0.9060661019761707, 'prauc': 0.8235721689873131}





In [16]:
# Summarize every collected metric across the repeated runs as "mean ± std".
# NOTE: np.std is called with its default ddof=0 (population std, not the
# sample std); switch to ddof=1 if a sample estimate across seeds is wanted.
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    print(f"{metric_name}: {np.mean(run_values):.4f} ± {np.std(run_values):.4f}")


Final Metrics:
precision: 0.7168 ± 0.0402
recall: 0.7383 ± 0.0419
f1: 0.7251 ± 0.0035
auc: 0.9014 ± 0.0039
prauc: 0.8109 ± 0.0081
