In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the random seed once, before any data loading or model construction below.
# NOTE(review): which RNGs (python / numpy / torch / CUDA) get seeded depends on
# heterogt.utils.seed.set_random_seed, which is not visible here — confirm.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration; later cells attach derived values (vocabulary sizes,
# max_num_adms, ...) to this same namespace.
_config_values = {
    "dataset": "MIMIC-III",
    "tasks": ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    # Index into `tasks` of the task to train.
    "task_index": 0,
    "token_type": ["diag", "med", "lab", "pro"],
    "special_tokens": ["[PAD]"],
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
    # If group_code_thre diagnosis codes belong to the same group ICD code,
    # the group code is generated.
    "group_code_thre": 5,
}
config = Namespace(**_config_values)

In [5]:
# All processed pickles for the chosen dataset live in one directory.
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"  # for tokenizer

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The two next-diagnosis tasks each have a dedicated file; every other task
# reads the shared downstream file.
next_diag_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_file = next_diag_files.get(curr_task, "mimic_downstream.pkl")
finetune_data_path = f"{processed_dir}/{finetune_file}"

Current task: death


In [6]:
# Load the full processed EHR table used to build the tokenizer vocabularies.
# The context manager closes the file handle deterministically (the previous
# `pickle.load(open(...))` form left it open until garbage collection).
# NOTE: unpickling can execute arbitrary code — only load files you produced/trust.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# One "sentence" of group codes, taken from the second item returned by expand_level3().
group_code_sentences = [expand_level3()[1]]

# Per-row code lists for each token type (column -> Python list of rows).
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Each distinct age becomes its own single-token sentence; the nested [[...]]
# shape (a list of one-element lists) is required here.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
token_type_sentences = ["[PAD]"] + config.token_type

# Largest number of distinct admissions (HADM_ID) recorded for any patient (SUBJECT_ID).
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks

# Build the tokenizer from every vocabulary source collected above.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# Record vocabulary sizes on the config so later cells can read them.
diag_id2word = tokenizer.diag_voc.id2word
global_id2word = tokenizer.vocab.id2word
config.label_vocab_size = len(diag_id2word)  # only for diagnosis
config.global_vocab_size = len(global_id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")


Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the (train, val, test) splits for the selected task. The context manager
# closes the file handle deterministically (the previous `pickle.load(open(...))`
# form left it open until garbage collection).
# NOTE: unpickling can execute arbitrary code — only load files you produced/trust.
with open(finetune_data_path, "rb") as finetune_file:
    train_data, val_data, test_data = pickle.load(finetune_file)

# Example label percentages in the test split, as a quick class-balance check.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# The three splits are wrapped with identical settings; only the data differs.
dataset_kwargs = dict(
    token_type=config.token_type,
    task=curr_task,
    max_num_adms=config.max_num_adms,
    group_code_thre=config.group_code_thre,
)
train_dataset = FineTuneEHRDataset(train_data, tokenizer, **dataset_kwargs)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, **dataset_kwargs)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, **dataset_kwargs)

In [10]:
# Count, per training sample, how many positions in the first row of
# `token_types` carry type id 5, then report the mean.
# NOTE(review): id 5 presumably marks group-code tokens — confirm against the dataset.
group_type_id = 5
num_group_code = [
    (train_dataset[sample_idx][1][0] == group_type_id).sum().item()  # [1] -> token_types
    for sample_idx in range(len(train_dataset))
]
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with the shared fine-tuning collate settings.

    A fresh `batcher` collate function is created for every loader; only the
    dataset and the shuffle flag vary between the three splits.
    """
    collate = batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False)
    return DataLoader(
        dataset,
        collate_fn=collate,
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled; validation and test keep dataset order.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [12]:
# Smoke test: exhaust each loader once (train, then val, then test) so any
# collate or indexing failure shows up before training starts.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass  # just to check if the dataloader works

In [13]:
# Every task is scored with F1 and trained with BCE-with-logits; what differs
# between the two task families is the `task_type` handed to the training helper.
eval_metric = "f1"
if curr_task not in ("death", "stay", "readmission"):
    # Next-diagnosis tasks.
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)
else:
    # Single-label outcome tasks.
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits

In [14]:
# Pull one training batch and report what each collated field looks like.
first_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, task_index, labels = first_batch

shape_report = (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age IDs shape:", age_ids),
)
for caption, tensor in shape_report:
    print(caption, tensor.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 292])
Token Types shape: torch.Size([32, 292])
Admission Index shape: torch.Size([32, 292])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Task Index: 0
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [15]:
# Train five freshly initialised models and collect, for each run, the metrics
# returned by train_with_early_stopping (one list entry per run and metric).
metric_names = ("precision", "recall", "f1", "auc", "prauc")
final_metrics = {name: [] for name in metric_names}

for i in range(5):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        layer_types=['gnn', 'tf', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )

    # Append this run's result to the per-metric history.
    for name in metric_names:
        final_metrics[name].append(best_test_metric[name])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 24.83it/s, loss=0.5179]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.56it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.00it/s]


Validation: {'precision': 0.6929601357023498, 'recall': 0.4814378314644583, 'f1': 0.5681502037813371, 'auc': 0.8326625798501098, 'prauc': 0.665543677332173}
Test:      {'precision': 0.729480737012316, 'recall': 0.4822812846041956, 'f1': 0.5806666618708758, 'auc': 0.8345001494450346, 'prauc': 0.6893846986988573}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 27.56it/s, loss=0.4276]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.56it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.06it/s]


Validation: {'precision': 0.796316359688014, 'recall': 0.43311726576055914, 'f1': 0.5610686977221578, 'auc': 0.8633019517361518, 'prauc': 0.7307299406325504}
Test:      {'precision': 0.8328075709691608, 'recall': 0.43853820597763826, 'f1': 0.5745375362819265, 'auc': 0.8706479365506116, 'prauc': 0.7555027594371827}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 27.52it/s, loss=0.3856]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.17it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.84it/s]


Validation: {'precision': 0.7732893652038476, 'recall': 0.552740129637285, 'f1': 0.6446735346527864, 'auc': 0.8741861059563547, 'prauc': 0.7559748685632753}
Test:      {'precision': 0.7932421560676328, 'recall': 0.5459579180479183, 'f1': 0.6467694277670898, 'auc': 0.8769400772190126, 'prauc': 0.7746180070863891}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 27.43it/s, loss=0.3318]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.50it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.80it/s]


Validation: {'precision': 0.6511867905023159, 'recall': 0.7436652916868376, 'f1': 0.6943603801625872, 'auc': 0.8800098770339926, 'prauc': 0.7648860106299795}
Test:      {'precision': 0.6777392166550434, 'recall': 0.7569213731962519, 'f1': 0.7151451689586517, 'auc': 0.8852691143153572, 'prauc': 0.7920071803228398}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 27.52it/s, loss=0.3055]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.99it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.06it/s]


Validation: {'precision': 0.6949339207010191, 'recall': 0.7436652916868376, 'f1': 0.7184742335442034, 'auc': 0.8992614068437422, 'prauc': 0.7955599915847544}
Test:      {'precision': 0.711855396062138, 'recall': 0.7414174972273455, 'f1': 0.7263357693404129, 'auc': 0.8996661860790104, 'prauc': 0.8103837491039155}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 27.89it/s, loss=0.2690]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.25it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.48it/s]


Validation: {'precision': 0.7463672390967876, 'recall': 0.6658809664073903, 'f1': 0.7038305773849489, 'auc': 0.893149385284931, 'prauc': 0.7947239255203736}
Test:      {'precision': 0.7756410256360536, 'recall': 0.6699889257991696, 'f1': 0.7189542433884475, 'auc': 0.8971784309699351, 'prauc': 0.8107307293042927}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 27.78it/s, loss=0.2473]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.71it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.63it/s]


Validation: {'precision': 0.6992609437140463, 'recall': 0.7248084855584868, 'f1': 0.7118055505530456, 'auc': 0.8924201045417366, 'prauc': 0.7957071879599537}
Test:      {'precision': 0.7106126914621959, 'recall': 0.7192691029860506, 'f1': 0.7149146895477073, 'auc': 0.8950711821225331, 'prauc': 0.8094098059597077}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 28.01it/s, loss=0.2152]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.79it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.77it/s]


Validation: {'precision': 0.6326034063229558, 'recall': 0.7660577489642543, 'f1': 0.6929637477070724, 'auc': 0.8868152837520364, 'prauc': 0.7809707674756927}
Test:      {'precision': 0.6522144522114116, 'recall': 0.7746400885892878, 'f1': 0.708175140566001, 'auc': 0.88796257177424, 'prauc': 0.7943924330071277}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 28.29it/s, loss=0.2201]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.76it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.36it/s]


Validation: {'precision': 0.8032786885183195, 'recall': 0.6063641720647828, 'f1': 0.6910678258518255, 'auc': 0.8982334646226026, 'prauc': 0.8000833833325466}
Test:      {'precision': 0.8085106382919406, 'recall': 0.6101882613476735, 'f1': 0.6954875305934772, 'auc': 0.8957535944362129, 'prauc': 0.8064217897404746}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 27.53it/s, loss=0.1733]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.35it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.12it/s]


Validation: {'precision': 0.7044757817246814, 'recall': 0.6770771950460985, 'f1': 0.6905048026901246, 'auc': 0.8858480191194442, 'prauc': 0.7822188156694617}
Test:      {'precision': 0.7273788674797002, 'recall': 0.6899224806163349, 'f1': 0.7081557210580068, 'auc': 0.8890590970171558, 'prauc': 0.7975327488451686}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6949339207010191, 'recall': 0.7436652916868376, 'f1': 0.7184742335442034, 'auc': 0.8992614068437422, 'prauc': 0.7955599915847544}
Corresponding test performance:
{'precision': 0.711855396062138, 'recall': 0.7414174972273455, 'f1': 0.7263357693404129, 'auc': 0.8996661860790104, 'prauc': 0.8103837491039155}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 27.89it/s, loss=0.5281]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.75it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.56it/s]


Validation: {'precision': 0.5983333333300093, 'recall': 0.6346493812573091, 'f1': 0.6159565291729624, 'auc': 0.8095391811795749, 'prauc': 0.6526854606712222}
Test:      {'precision': 0.612974161623898, 'recall': 0.6173864894760942, 'f1': 0.6151724087897738, 'auc': 0.8098305780789985, 'prauc': 0.6590756109660049}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 27.95it/s, loss=0.4258]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.52it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.43it/s]


Validation: {'precision': 0.7154065620491056, 'recall': 0.5910430170854977, 'f1': 0.6473055774870805, 'auc': 0.8601130570378275, 'prauc': 0.7178953326398702}
Test:      {'precision': 0.7208504801047954, 'recall': 0.5819490586900224, 'f1': 0.6439950930921063, 'auc': 0.861127401522616, 'prauc': 0.7290246799377784}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 27.78it/s, loss=0.3797]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.03it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.58it/s]


Validation: {'precision': 0.7696867061746852, 'recall': 0.535651149083467, 'f1': 0.6316886678457045, 'auc': 0.8714287885904672, 'prauc': 0.7447373680911765}
Test:      {'precision': 0.7703703703640299, 'recall': 0.5182724252462998, 'f1': 0.6196623586430637, 'auc': 0.8703176556382632, 'prauc': 0.7479425065971447}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.18it/s, loss=0.3393]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.54it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.83it/s]


Validation: {'precision': 0.8785425101096013, 'recall': 0.3836181496736381, 'f1': 0.5340442943698422, 'auc': 0.8640256118185123, 'prauc': 0.7481527310196737}
Test:      {'precision': 0.8706666666550578, 'recall': 0.3615725359891386, 'f1': 0.5109546124378676, 'auc': 0.8585928679019148, 'prauc': 0.7442103140114997}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 27.99it/s, loss=0.3116]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.35it/s]


Validation: {'precision': 0.655987394954538, 'recall': 0.736004714197195, 'f1': 0.693696190513919, 'auc': 0.8828297881228027, 'prauc': 0.7696151335490495}
Test:      {'precision': 0.6648460774544943, 'recall': 0.7414174972273455, 'f1': 0.701047115430002, 'auc': 0.8812005532666912, 'prauc': 0.7674801361074888}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 27.77it/s, loss=0.2749]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.36it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.06it/s]


Validation: {'precision': 0.754797441359241, 'recall': 0.6258102533846447, 'f1': 0.6842783455546988, 'auc': 0.8867161557336936, 'prauc': 0.7809299206391479}
Test:      {'precision': 0.7556623198300916, 'recall': 0.609634551491641, 'f1': 0.6748391001710522, 'auc': 0.8825149063420336, 'prauc': 0.7772132011597019}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 27.73it/s, loss=0.2609]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.35it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.04it/s]


Validation: {'precision': 0.7615844544038746, 'recall': 0.6004714201496731, 'f1': 0.6714991713423047, 'auc': 0.8703688469252622, 'prauc': 0.7643837650705082}
Test:      {'precision': 0.7790096082721434, 'recall': 0.5836101882581196, 'f1': 0.6672997735094823, 'auc': 0.8706864057213568, 'prauc': 0.7683006479228597}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 27.62it/s, loss=0.2406]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.43it/s]


Validation: {'precision': 0.6979418886156299, 'recall': 0.6794342958121424, 'f1': 0.6885637453700365, 'auc': 0.8808742452506374, 'prauc': 0.7708544612191668}
Test:      {'precision': 0.709006928402373, 'recall': 0.6799557032077522, 'f1': 0.6941774964114911, 'auc': 0.8796130998543957, 'prauc': 0.7781069152289667}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 27.24it/s, loss=0.2127]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.06it/s]


Validation: {'precision': 0.6830537883129957, 'recall': 0.6959340011744495, 'f1': 0.6894337369695581, 'auc': 0.8831501019191874, 'prauc': 0.7674218360956243}
Test:      {'precision': 0.707234997191771, 'recall': 0.6982281284568205, 'f1': 0.7027026976989923, 'auc': 0.8820852210924802, 'prauc': 0.7749447847913659}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 27.90it/s, loss=0.2015]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.56it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.41it/s]


Validation: {'precision': 0.7392197125206077, 'recall': 0.636417206831842, 'f1': 0.6839771957835665, 'auc': 0.8833880219374393, 'prauc': 0.775337237124335}
Test:      {'precision': 0.7572293207750018, 'recall': 0.6234772978924503, 'f1': 0.6838748811648449, 'auc': 0.8812718289462476, 'prauc': 0.7783965015922621}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.655987394954538, 'recall': 0.736004714197195, 'f1': 0.693696190513919, 'auc': 0.8828297881228027, 'prauc': 0.7696151335490495}
Corresponding test performance:
{'precision': 0.6648460774544943, 'recall': 0.7414174972273455, 'f1': 0.701047115430002, 'auc': 0.8812005532666912, 'prauc': 0.7674801361074888}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 28.00it/s, loss=0.5314]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.92it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.54it/s]


Validation: {'precision': 0.7319004524804084, 'recall': 0.38126104890759427, 'f1': 0.501356059033485, 'auc': 0.8371065425641786, 'prauc': 0.6667637388232126}
Test:      {'precision': 0.7758420441257161, 'recall': 0.3698781838296242, 'f1': 0.5009373784511414, 'auc': 0.8436729231694953, 'prauc': 0.6917131623481277}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 27.84it/s, loss=0.4213]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.76it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.08it/s]


Validation: {'precision': 0.6694711538421306, 'recall': 0.6564525633432148, 'f1': 0.6628979420361091, 'auc': 0.8698633706770891, 'prauc': 0.7391096829178843}
Test:      {'precision': 0.6894197952179214, 'recall': 0.6710963455112343, 'f1': 0.6801346751317705, 'auc': 0.8712512562492399, 'prauc': 0.7567372127691896}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 25.69it/s, loss=0.3631]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.55it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.30it/s]


Validation: {'precision': 0.8325074330938305, 'recall': 0.49499116086921047, 'f1': 0.6208425673807106, 'auc': 0.8875565722706485, 'prauc': 0.7791286550957287}
Test:      {'precision': 0.8287037036960306, 'recall': 0.4955703211489725, 'f1': 0.6202356155477323, 'auc': 0.8901918140559086, 'prauc': 0.7861586214977657}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 27.94it/s, loss=0.3258]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.55it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.34it/s]


Validation: {'precision': 0.6449589966201401, 'recall': 0.78786093105016, 'f1': 0.7092838146746197, 'auc': 0.8980250530737286, 'prauc': 0.7964530370088311}
Test:      {'precision': 0.6738418343440625, 'recall': 0.797342192686615, 'f1': 0.7304083135707135, 'auc': 0.9014445696795474, 'prauc': 0.8071335384274342}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 27.96it/s, loss=0.2854]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.61it/s]


Validation: {'precision': 0.8131672597792423, 'recall': 0.5385975250410219, 'f1': 0.64799715932789, 'auc': 0.892101770750873, 'prauc': 0.7879041037526906}
Test:      {'precision': 0.8213378492733164, 'recall': 0.5370985603514004, 'f1': 0.6494810799149267, 'auc': 0.8967367433391091, 'prauc': 0.8040404371448873}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 27.97it/s, loss=0.2548]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.45it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.33it/s]


Validation: {'precision': 0.6937745372927999, 'recall': 0.7289334118990635, 'f1': 0.710919535228853, 'auc': 0.8931065277460728, 'prauc': 0.7938691177179179}
Test:      {'precision': 0.7240085744870096, 'recall': 0.748062015499734, 'f1': 0.7358387749537544, 'auc': 0.8985750772953355, 'prauc': 0.8030812291725493}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 25.81it/s, loss=0.2512]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.81it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.53it/s]


Validation: {'precision': 0.7255871446185069, 'recall': 0.6918090748338727, 'f1': 0.7082956209412513, 'auc': 0.8908090187590832, 'prauc': 0.7949019258655265}
Test:      {'precision': 0.7531760435526124, 'recall': 0.6893687707603026, 'f1': 0.7198612265754383, 'auc': 0.8938838696366591, 'prauc': 0.7999599157097173}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 28.11it/s, loss=0.2301]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.52it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.60it/s]


Validation: {'precision': 0.6498266468516601, 'recall': 0.7731290512623858, 'f1': 0.7061356247431076, 'auc': 0.8893570998718618, 'prauc': 0.7897495496050713}
Test:      {'precision': 0.6666666666634601, 'recall': 0.7674418604608669, 'f1': 0.7135135085345301, 'auc': 0.8880059034481672, 'prauc': 0.790935656555287}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 27.23it/s, loss=0.1988]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.64it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.43it/s]


Validation: {'precision': 0.784012298226103, 'recall': 0.6010606953411841, 'f1': 0.6804536308398684, 'auc': 0.8848589105549751, 'prauc': 0.7817946391416215}
Test:      {'precision': 0.80146520145933, 'recall': 0.6057585824994144, 'f1': 0.6900031486716669, 'auc': 0.8887063501090924, 'prauc': 0.7915073995195633}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 27.90it/s, loss=0.1650]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.52it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.41it/s]


Validation: {'precision': 0.7775263951675904, 'recall': 0.6075427224478047, 'f1': 0.6821038653982847, 'auc': 0.8794347173657087, 'prauc': 0.7852198487075581}
Test:      {'precision': 0.7958452722006029, 'recall': 0.6151716500519647, 'f1': 0.6939412817734583, 'auc': 0.885732221580454, 'prauc': 0.7927522626985817}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 27.74it/s, loss=0.1629]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 40.04it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.92it/s]


Validation: {'precision': 0.683827644092424, 'recall': 0.720094284026399, 'f1': 0.7014925323127424, 'auc': 0.8830642590991793, 'prauc': 0.7774363052729458}
Test:      {'precision': 0.7021960364183064, 'recall': 0.7259136212584391, 'f1': 0.7138578768379495, 'auc': 0.8875657545841711, 'prauc': 0.7901127769120115}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6937745372927999, 'recall': 0.7289334118990635, 'f1': 0.710919535228853, 'auc': 0.8931065277460728, 'prauc': 0.7938691177179179}
Corresponding test performance:
{'precision': 0.7240085744870096, 'recall': 0.748062015499734, 'f1': 0.7358387749537544, 'auc': 0.8985750772953355, 'prauc': 0.8030812291725493}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 27.45it/s, loss=0.5321]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.20it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.73it/s]


Validation: {'precision': 0.7551020408082159, 'recall': 0.41426045963220826, 'f1': 0.535007605770797, 'auc': 0.832380077772062, 'prauc': 0.6711370882281571}
Test:      {'precision': 0.761652542364813, 'recall': 0.3981173864872751, 'f1': 0.5229090863965568, 'auc': 0.835156772026647, 'prauc': 0.6856043044535105}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 26.99it/s, loss=0.4165]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.68it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.33it/s]


Validation: {'precision': 0.727339181281233, 'recall': 0.58632881555341, 'f1': 0.6492659004367344, 'auc': 0.864438283291781, 'prauc': 0.7295934647194035}
Test:      {'precision': 0.7403433476341893, 'recall': 0.5730897009935045, 'f1': 0.6460674108073828, 'auc': 0.868908822279571, 'prauc': 0.7433971159896074}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 27.80it/s, loss=0.3703]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.47it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.29it/s]


Validation: {'precision': 0.8069087688148193, 'recall': 0.536829699466489, 'f1': 0.6447275252752723, 'auc': 0.886006419561126, 'prauc': 0.7774434487334951}
Test:      {'precision': 0.828003457209784, 'recall': 0.5304540420790119, 'f1': 0.6466419122115543, 'auc': 0.8872238405945894, 'prauc': 0.7827500777182649}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 27.82it/s, loss=0.3260]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.56it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.52it/s]


Validation: {'precision': 0.8712753277607714, 'recall': 0.4307601649945153, 'f1': 0.5764984182807168, 'auc': 0.8757897188149449, 'prauc': 0.7652011631860911}
Test:      {'precision': 0.8605714285615934, 'recall': 0.41694352159237574, 'f1': 0.5617306930996844, 'auc': 0.877701828350437, 'prauc': 0.7736825421159137}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 27.91it/s, loss=0.2960]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.58it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.46it/s]


Validation: {'precision': 0.616067465598686, 'recall': 0.8179139658172192, 'f1': 0.7027848052220902, 'auc': 0.896484033938061, 'prauc': 0.7931659304890607}
Test:      {'precision': 0.6504897595696773, 'recall': 0.8089700996632948, 'f1': 0.7211253652429596, 'auc': 0.8947151114781173, 'prauc': 0.7935435493855152}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 28.04it/s, loss=0.2734]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.58it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.38it/s]


Validation: {'precision': 0.7423750811113409, 'recall': 0.6741308190885438, 'f1': 0.7066090129195327, 'auc': 0.8976187687150441, 'prauc': 0.7960928934652882}
Test:      {'precision': 0.7664326738942794, 'recall': 0.6650055370948782, 'f1': 0.7121256991418415, 'auc': 0.895787201103776, 'prauc': 0.7987790561990205}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 27.81it/s, loss=0.2443]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.53it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.72it/s]


Validation: {'precision': 0.6894618834042071, 'recall': 0.7248084855584868, 'f1': 0.7066934738844408, 'auc': 0.8966372608169297, 'prauc': 0.7910262961091186}
Test:      {'precision': 0.7251782775604763, 'recall': 0.7320044296747952, 'f1': 0.7285753601104512, 'auc': 0.8965711104775494, 'prauc': 0.7973398684431695}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 27.36it/s, loss=0.2142]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.85it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.08it/s]


Validation: {'precision': 0.6858552631541346, 'recall': 0.737183264580217, 'f1': 0.710593576371398, 'auc': 0.8967957890009033, 'prauc': 0.790212264707172}
Test:      {'precision': 0.7131952017409313, 'recall': 0.7242524916903419, 'f1': 0.7186813136776657, 'auc': 0.8933294827232183, 'prauc': 0.7878719087076996}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 27.89it/s, loss=0.2015]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.66it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.49it/s]


Validation: {'precision': 0.7148711943752057, 'recall': 0.7195050088348881, 'f1': 0.717180611735928, 'auc': 0.8985242060784636, 'prauc': 0.8038086640617153}
Test:      {'precision': 0.7387802071303868, 'recall': 0.710963455145565, 'f1': 0.7246049611377064, 'auc': 0.8958559532057314, 'prauc': 0.801151087323765}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 27.87it/s, loss=0.1751]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.55it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.14it/s]


Validation: {'precision': 0.8417602996175866, 'recall': 0.5297583971683574, 'f1': 0.650271242993649, 'auc': 0.8886728482864841, 'prauc': 0.7853580094149122}
Test:      {'precision': 0.8255093002584101, 'recall': 0.5160575858221702, 'f1': 0.6350936920249054, 'auc': 0.8841304268613046, 'prauc': 0.7805663935803782}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 27.58it/s, loss=0.1707]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.09it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.40it/s]


Validation: {'precision': 0.7039039038996763, 'recall': 0.6906305244508508, 'f1': 0.6972040402074893, 'auc': 0.88716861893082, 'prauc': 0.781387391329799}
Test:      {'precision': 0.7212049616023556, 'recall': 0.6760797342155256, 'f1': 0.6979136846268331, 'auc': 0.8789759888361852, 'prauc': 0.7805450375804965}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 27.91it/s, loss=0.1373]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.63it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 40.01it/s]


Validation: {'precision': 0.7955482275284786, 'recall': 0.5686505598080811, 'f1': 0.6632302356835867, 'auc': 0.8871701518383202, 'prauc': 0.7767023602688307}
Test:      {'precision': 0.8012718600890202, 'recall': 0.5581395348806305, 'f1': 0.6579634416308399, 'auc': 0.8805171561422402, 'prauc': 0.7742768036659889}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 27.78it/s, loss=0.1384]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.20it/s]


Validation: {'precision': 0.7648745519658432, 'recall': 0.6287566293421994, 'f1': 0.6901681709811387, 'auc': 0.8938097490617649, 'prauc': 0.7897731518175699}
Test:      {'precision': 0.7745987438885235, 'recall': 0.6146179401959324, 'f1': 0.6853967224470713, 'auc': 0.8853568240246557, 'prauc': 0.7876099396093054}


Epoch 014: 100%|██████████| 98/98 [00:03<00:00, 27.55it/s, loss=0.1228]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.65it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.94it/s]


Validation: {'precision': 0.6995277449781611, 'recall': 0.6982911019404933, 'f1': 0.6989088714335109, 'auc': 0.8943654280305677, 'prauc': 0.7821291291313173}
Test:      {'precision': 0.7218695903016626, 'recall': 0.6926910298964968, 'f1': 0.706979367702286, 'auc': 0.8879935317628556, 'prauc': 0.7879984588927723}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7148711943752057, 'recall': 0.7195050088348881, 'f1': 0.717180611735928, 'auc': 0.8985242060784636, 'prauc': 0.8038086640617153}
Corresponding test performance:
{'precision': 0.7387802071303868, 'recall': 0.710963455145565, 'f1': 0.7246049611377064, 'auc': 0.8958559532057314, 'prauc': 0.801151087323765}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 27.66it/s, loss=0.5342]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.00it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.89it/s]


Validation: {'precision': 0.6605744125283253, 'recall': 0.5963464938090964, 'f1': 0.6268194437549153, 'auc': 0.8324511663573814, 'prauc': 0.687452449809661}
Test:      {'precision': 0.6784586228636863, 'recall': 0.594684385378767, 'f1': 0.6338152797626712, 'auc': 0.8381446879060805, 'prauc': 0.6998002496189676}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 27.75it/s, loss=0.4164]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.60it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.94it/s]


Validation: {'precision': 0.8945518453269851, 'recall': 0.2999410724790811, 'f1': 0.44924977558189333, 'auc': 0.8735152673116033, 'prauc': 0.7496201183718907}
Test:      {'precision': 0.8977072310247318, 'recall': 0.2818383167204771, 'f1': 0.42899283243193337, 'auc': 0.8720397203728296, 'prauc': 0.7587833117690215}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 27.66it/s, loss=0.3737]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.73it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.38it/s]


Validation: {'precision': 0.6243996157510836, 'recall': 0.7660577489642543, 'f1': 0.688012696821211, 'auc': 0.8826577192559114, 'prauc': 0.76951503482804}
Test:      {'precision': 0.6462264150912914, 'recall': 0.758582502764349, 'f1': 0.6979113551914443, 'auc': 0.8833886181479945, 'prauc': 0.7861577722146899}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 26.97it/s, loss=0.3264]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.71it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.52it/s]


Validation: {'precision': 0.8252595155637954, 'recall': 0.5621685327014605, 'f1': 0.6687697112634281, 'auc': 0.8899579996119188, 'prauc': 0.7851045203373097}
Test:      {'precision': 0.8288814691082732, 'recall': 0.5498338870401449, 'f1': 0.6611185038555475, 'auc': 0.894728837278239, 'prauc': 0.8018005394695285}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.16it/s, loss=0.3017]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.68it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.48it/s]


Validation: {'precision': 0.8286713286640851, 'recall': 0.5586328815523947, 'f1': 0.6673706393241322, 'auc': 0.9022734423392373, 'prauc': 0.7999380018449589}
Test:      {'precision': 0.8495188101412992, 'recall': 0.5376522702074328, 'f1': 0.6585283099312019, 'auc': 0.9078548876404771, 'prauc': 0.8173654454672116}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 27.86it/s, loss=0.2642]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.02it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.64it/s]


Validation: {'precision': 0.7879699248061055, 'recall': 0.6175604007034912, 'f1': 0.6924347489506544, 'auc': 0.9035861221285288, 'prauc': 0.7995377679582873}
Test:      {'precision': 0.8156547183554086, 'recall': 0.6173864894760942, 'f1': 0.7028049115741043, 'auc': 0.908250412266409, 'prauc': 0.8151789138483638}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 25.95it/s, loss=0.2428]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.63it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.46it/s]


Validation: {'precision': 0.753755715213888, 'recall': 0.6800235710036534, 'f1': 0.7149937992219279, 'auc': 0.8997610069464981, 'prauc': 0.7978868381700873}
Test:      {'precision': 0.7761479591787236, 'recall': 0.6738648947913961, 'f1': 0.7213989280377933, 'auc': 0.9047974810510098, 'prauc': 0.8130904490186842}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 27.85it/s, loss=0.2161]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.62it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.35it/s]


Validation: {'precision': 0.7631403858897995, 'recall': 0.6758986446630767, 'f1': 0.7168749950138966, 'auc': 0.9018565553703435, 'prauc': 0.8107760932998641}
Test:      {'precision': 0.790426908144952, 'recall': 0.676633444071558, 'f1': 0.7291169401331304, 'auc': 0.9073736844774618, 'prauc': 0.8230447731478187}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 27.95it/s, loss=0.2191]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.82it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.51it/s]


Validation: {'precision': 0.7687411598248322, 'recall': 0.6405421331724188, 'f1': 0.6988106668465905, 'auc': 0.8885851532032465, 'prauc': 0.7904807484723617}
Test:      {'precision': 0.7995795374856371, 'recall': 0.6317829457329359, 'f1': 0.7058459585657381, 'auc': 0.8895513177506724, 'prauc': 0.8034025852161994}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 27.95it/s, loss=0.1958]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.64it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.45it/s]


Validation: {'precision': 0.8208232445454333, 'recall': 0.5992928697666512, 'f1': 0.6927792866700859, 'auc': 0.8980629925343572, 'prauc': 0.8037397581264404}
Test:      {'precision': 0.8392434988113535, 'recall': 0.5897009966744756, 'f1': 0.6926829219772485, 'auc': 0.8988481160816152, 'prauc': 0.8124815389697394}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 25.92it/s, loss=0.1730]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.52it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.18it/s]


Validation: {'precision': 0.7544320420173708, 'recall': 0.6770771950460985, 'f1': 0.7136645912834593, 'auc': 0.8958202211193521, 'prauc': 0.8027934180297791}
Test:      {'precision': 0.7716635041064411, 'recall': 0.6755260243594933, 'f1': 0.7204015302997716, 'auc': 0.8974360205372438, 'prauc': 0.8126758324251184}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 27.91it/s, loss=0.1552]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.24it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.54it/s]


Validation: {'precision': 0.7409261576924848, 'recall': 0.6977018267489824, 'f1': 0.718664638399241, 'auc': 0.9017475911955414, 'prauc': 0.803423423603917}
Test:      {'precision': 0.7658150851534926, 'recall': 0.6971207087447557, 'f1': 0.7298550674705618, 'auc': 0.9033789842366264, 'prauc': 0.8167096637889677}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 27.19it/s, loss=0.1406]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.30it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.00it/s]


Validation: {'precision': 0.7812499999940454, 'recall': 0.604007071298739, 'f1': 0.681289460015845, 'auc': 0.8874193131782397, 'prauc': 0.7833804078133041}
Test:      {'precision': 0.7969653179133168, 'recall': 0.6107419712037058, 'f1': 0.6915360452399054, 'auc': 0.8864020775568027, 'prauc': 0.7963817724002626}


Epoch 014: 100%|██████████| 98/98 [00:03<00:00, 27.65it/s, loss=0.1349]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.00it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.80it/s]


Validation: {'precision': 0.6493638676811737, 'recall': 0.7519151443679911, 'f1': 0.696886942046458, 'auc': 0.8822587162078526, 'prauc': 0.7735112672750998}
Test:      {'precision': 0.6648648648615978, 'recall': 0.7491694352117987, 'f1': 0.7045040304215504, 'auc': 0.8856962144366365, 'prauc': 0.7922744530799964}


Epoch 015: 100%|██████████| 98/98 [00:03<00:00, 27.63it/s, loss=0.1326]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.48it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.45it/s]


Validation: {'precision': 0.703379224025636, 'recall': 0.6623453152583245, 'f1': 0.6822458220109948, 'auc': 0.8717496772271645, 'prauc': 0.7641850019981075}
Test:      {'precision': 0.7316926770664365, 'recall': 0.6749723145034608, 'f1': 0.7021889350962507, 'auc': 0.8712571035631931, 'prauc': 0.7802352988880039}


Epoch 016: 100%|██████████| 98/98 [00:03<00:00, 27.07it/s, loss=0.1267]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 41.70it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 41.84it/s]


Validation: {'precision': 0.6811926605465528, 'recall': 0.7000589275150262, 'f1': 0.6904969435583842, 'auc': 0.8858777192022596, 'prauc': 0.7781808076301601}
Test:      {'precision': 0.6988082340157161, 'recall': 0.7142857142817592, 'f1': 0.7064622074830399, 'auc': 0.887971558172526, 'prauc': 0.7934880910039235}


Epoch 017: 100%|██████████| 98/98 [00:03<00:00, 28.06it/s, loss=0.0982]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 42.61it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 42.42it/s]

Validation: {'precision': 0.6701256144146908, 'recall': 0.7230406599839538, 'f1': 0.695578226295787, 'auc': 0.8852626400678261, 'prauc': 0.7699822014049037}
Test:      {'precision': 0.6796671866839331, 'recall': 0.7236987818343096, 'f1': 0.7009922181172795, 'auc': 0.8850603959825629, 'prauc': 0.7819207809320521}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7409261576924848, 'recall': 0.6977018267489824, 'f1': 0.718664638399241, 'auc': 0.9017475911955414, 'prauc': 0.803423423603917}
Corresponding test performance:
{'precision': 0.7658150851534926, 'recall': 0.6971207087447557, 'f1': 0.7298550674705618, 'auc': 0.9033789842366264, 'prauc': 0.8167096637889677}





In [16]:
# Report every entry of `final_metrics` as "mean ± std", formatted to 4 decimals.
# Note: np.std is called with its default ddof=0 (population standard deviation).
# Presumably each entry holds one value per training run -- confirm against the
# cell that fills `final_metrics`.
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    print(f"{metric_name}: {np.mean(run_values):.4f} ± {np.std(run_values):.4f}")


Final Metrics:
precision: 0.7211 ± 0.0334
recall: 0.7278 ± 0.0200
f1: 0.7235 ± 0.0119
auc: 0.8957 ± 0.0077
prauc: 0.7998 ± 0.0171
