In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 0,  # index of the task to train
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]"],
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,
    early_stop_patience = 5,
    group_code_thre = 5,  # if there are group_code_thre diag codes belongs to the same group ICD code, then the group code is generated
)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
if curr_task == "next_diag_6m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl"
elif curr_task == "next_diag_12m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"

Current task: death


In [6]:
ehr_full_data = pickle.load(open(full_data_path, 'rb'))
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())] # important of [[]]
token_type_sentences = ["[PAD]"] + config.token_type
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks
tokenizer = EHRTokenizer(age_sentences, group_code_sentences, diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, special_tokens=config.special_tokens)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")
print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")


Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
train_dataset = FineTuneEHRDataset(train_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                   max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                 max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, token_type=config.token_type, task=curr_task, 
                                   max_num_adms=config.max_num_adms, group_code_thre=config.group_code_thre)

In [10]:
num_group_code = []
for i in range(len(train_dataset)):
    input_ids, token_types, adm_index, age_ids, diag_group_codes, labels = train_dataset[i]
    count = (token_types[0] == 5).sum().item()
    num_group_code.append(count)
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=True,
    batch_size=config.batch_size,
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

In [12]:
for batch in train_dataloader:
    pass  # just to check if the dataloader works
for batch in val_dataloader:
    pass  # just to check if the dataloader works
for batch in test_dataloader:
    pass  # just to check if the dataloader works

In [13]:
if curr_task in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "f1"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, task_index, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age IDs shape:", age_ids.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 292])
Token Types shape: torch.Size([32, 292])
Admission Index shape: torch.Size([32, 292])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Task Index: 0
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [15]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['gnn', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 20.42it/s, loss=0.5437]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 39.64it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.64it/s]


Validation: {'precision': 0.7983425414254373, 'recall': 0.3406010606933377, 'f1': 0.47748863686108806, 'auc': 0.8402554539252076, 'prauc': 0.6859472704925622}
Test:      {'precision': 0.8275862068851367, 'recall': 0.3322259136194229, 'f1': 0.47412089673805385, 'auc': 0.841449589752453, 'prauc': 0.7037269256107657}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 22.20it/s, loss=0.4093]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.36it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.01it/s]


Validation: {'precision': 0.7713815789410249, 'recall': 0.552740129637285, 'f1': 0.6440096072156671, 'auc': 0.8752920987177102, 'prauc': 0.7459394182362864}
Test:      {'precision': 0.8032520325137947, 'recall': 0.547065337759983, 'f1': 0.6508563851625119, 'auc': 0.8747025871471361, 'prauc': 0.7622956601572937}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 22.30it/s, loss=0.3498]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.09it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.06it/s]


Validation: {'precision': 0.753643303256394, 'recall': 0.6399528579809078, 'f1': 0.6921606068835502, 'auc': 0.8877966639078589, 'prauc': 0.7769316638603886}
Test:      {'precision': 0.7654656696073049, 'recall': 0.6234772978924503, 'f1': 0.6872139102143691, 'auc': 0.884438241777938, 'prauc': 0.7897292450568626}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 22.68it/s, loss=0.3165]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.03it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.59it/s]


Validation: {'precision': 0.6490099009868862, 'recall': 0.7725397760708749, 'f1': 0.7054075817974798, 'auc': 0.8961793047012613, 'prauc': 0.7838599598671249}
Test:      {'precision': 0.6708922476807856, 'recall': 0.7619047619005432, 'f1': 0.71350790271676, 'auc': 0.8960023822572554, 'prauc': 0.8039844753149333}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 21.88it/s, loss=0.2794]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.30it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.38it/s]


Validation: {'precision': 0.6906392694024507, 'recall': 0.7130229817282674, 'f1': 0.7016526479400848, 'auc': 0.8899183995014984, 'prauc': 0.7799375316990136}
Test:      {'precision': 0.7168338907429386, 'recall': 0.7120708748576298, 'f1': 0.7144444394405309, 'auc': 0.8902209890750017, 'prauc': 0.7995273659504942}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 23.35it/s, loss=0.2641]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.17it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.56it/s]


Validation: {'precision': 0.7020364415825187, 'recall': 0.7719505008793639, 'f1': 0.7353353865311181, 'auc': 0.906874975170092, 'prauc': 0.8065777559301819}
Test:      {'precision': 0.7173800928276567, 'recall': 0.7702104097410288, 'f1': 0.742857137859482, 'auc': 0.9057057843353044, 'prauc': 0.8236435968745119}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 22.10it/s, loss=0.2327]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.19it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.83it/s]


Validation: {'precision': 0.7041522491308873, 'recall': 0.7195050088348881, 'f1': 0.7117458416883592, 'auc': 0.8930700573217986, 'prauc': 0.7921187272772959}
Test:      {'precision': 0.743367935405171, 'recall': 0.7137320044257269, 'f1': 0.7282485825685756, 'auc': 0.8953879218867793, 'prauc': 0.8093861733226053}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 23.35it/s, loss=0.2344]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.38it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.20it/s]


Validation: {'precision': 0.7153891859731706, 'recall': 0.7094873305792017, 'f1': 0.7124260304988289, 'auc': 0.9008311679950374, 'prauc': 0.7943286421139275}
Test:      {'precision': 0.7382319173321572, 'recall': 0.7120708748576298, 'f1': 0.7249154403188486, 'auc': 0.9004150115690644, 'prauc': 0.8077780217090972}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 22.76it/s, loss=0.1965]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.18it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.62it/s]


Validation: {'precision': 0.7542043984427284, 'recall': 0.687094873301785, 'f1': 0.7190872598846045, 'auc': 0.9025934367798927, 'prauc': 0.7967304179305411}
Test:      {'precision': 0.7792536369337176, 'recall': 0.6821705426318817, 'f1': 0.7274874470402082, 'auc': 0.9036580549888791, 'prauc': 0.8164846440355277}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 21.84it/s, loss=0.1912]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.47it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.78it/s]


Validation: {'precision': 0.7157534246534489, 'recall': 0.7389510901547499, 'f1': 0.7271672898651327, 'auc': 0.9027120454977164, 'prauc': 0.7989247682677788}
Test:      {'precision': 0.7251236943335617, 'recall': 0.730343300106698, 'f1': 0.7277241329270838, 'auc': 0.9006345628203405, 'prauc': 0.8159741395709427}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 22.74it/s, loss=0.1678]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.68it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.37it/s]


Validation: {'precision': 0.7012770682914977, 'recall': 0.7442545668783486, 'f1': 0.7221269246743905, 'auc': 0.8974619650520084, 'prauc': 0.7905746032456024}
Test:      {'precision': 0.71935483870581, 'recall': 0.7408637873713131, 'f1': 0.7299508951607688, 'auc': 0.8988225725522405, 'prauc': 0.8093342742826807}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7020364415825187, 'recall': 0.7719505008793639, 'f1': 0.7353353865311181, 'auc': 0.906874975170092, 'prauc': 0.8065777559301819}
Corresponding test performance:
{'precision': 0.7173800928276567, 'recall': 0.7702104097410288, 'f1': 0.742857137859482, 'auc': 0.9057057843353044, 'prauc': 0.8236435968745119}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 23.61it/s, loss=0.5187]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 39.70it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.68it/s]


Validation: {'precision': 0.7137978142027747, 'recall': 0.6157925751289582, 'f1': 0.6611831649059317, 'auc': 0.8492555370854395, 'prauc': 0.7183229457017054}
Test:      {'precision': 0.7213333333285245, 'recall': 0.599114064227026, 'f1': 0.6545674481544235, 'auc': 0.853505027705189, 'prauc': 0.7253950364775412}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 22.66it/s, loss=0.4101]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.92it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.43it/s]


Validation: {'precision': 0.8156779660930542, 'recall': 0.4537418974634429, 'f1': 0.5831124528045483, 'auc': 0.8737122459253722, 'prauc': 0.7410491903917584}
Test:      {'precision': 0.823845327595877, 'recall': 0.42469545957682897, 'f1': 0.5604676608339234, 'auc': 0.8744382885564497, 'prauc': 0.7535721849432668}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 21.84it/s, loss=0.3623]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.75it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.66it/s]


Validation: {'precision': 0.7959022852577156, 'recall': 0.5951679434260745, 'f1': 0.6810519168796985, 'auc': 0.8956723594167337, 'prauc': 0.7939073283822418}
Test:      {'precision': 0.8139892390406305, 'recall': 0.5863787375382814, 'f1': 0.6816865094501997, 'auc': 0.8959023624133182, 'prauc': 0.8044009912571983}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 22.73it/s, loss=0.3247]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.49it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.37it/s]


Validation: {'precision': 0.8387681159344316, 'recall': 0.5456688273391535, 'f1': 0.6611924264939296, 'auc': 0.8986452418998293, 'prauc': 0.8024218195695776}
Test:      {'precision': 0.8431200701065459, 'recall': 0.5326688815031414, 'f1': 0.6528673179512181, 'auc': 0.8965439666306718, 'prauc': 0.8051209002562068}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 22.66it/s, loss=0.2925]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.40it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.17it/s]


Validation: {'precision': 0.8111831442397798, 'recall': 0.5898644667024758, 'f1': 0.683043325041635, 'auc': 0.9000395490135038, 'prauc': 0.8012537413353666}
Test:      {'precision': 0.8213166144136261, 'recall': 0.5802879291219254, 'f1': 0.6800778666554541, 'auc': 0.901135092894738, 'prauc': 0.8126639139478262}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 22.94it/s, loss=0.2484]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.66it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.24it/s]


Validation: {'precision': 0.7428023032582035, 'recall': 0.6841484973442301, 'f1': 0.712269933654385, 'auc': 0.9030568858141035, 'prauc': 0.8021577280092125}
Test:      {'precision': 0.7711442786021695, 'recall': 0.6866002214801407, 'f1': 0.7264206159850287, 'auc': 0.9024908080224654, 'prauc': 0.8138439238557807}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 22.82it/s, loss=0.2372]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.33it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.94it/s]


Validation: {'precision': 0.8428571428491157, 'recall': 0.5215085444872039, 'f1': 0.6443392744863676, 'auc': 0.8993363276978119, 'prauc': 0.797716802023468}
Test:      {'precision': 0.8512173128868241, 'recall': 0.5227021040945588, 'f1': 0.6476843863620372, 'auc': 0.9024432293520881, 'prauc': 0.8112650524216589}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 22.84it/s, loss=0.2093]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.58it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.79it/s]


Validation: {'precision': 0.78304048891686, 'recall': 0.604007071298739, 'f1': 0.6819693896230096, 'auc': 0.8939130925757329, 'prauc': 0.7928106952378919}
Test:      {'precision': 0.7979576951072359, 'recall': 0.6057585824994144, 'f1': 0.6887000265656377, 'auc': 0.8975908820309949, 'prauc': 0.8052882072843588}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 22.00it/s, loss=0.1937]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.30it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.59it/s]


Validation: {'precision': 0.6970033296298724, 'recall': 0.7401296405377719, 'f1': 0.7179194005448404, 'auc': 0.8991111819087279, 'prauc': 0.8027343147673945}
Test:      {'precision': 0.7247854077214336, 'recall': 0.748062015499734, 'f1': 0.7362397770135855, 'auc': 0.901802425293486, 'prauc': 0.8129792512191218}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 22.58it/s, loss=0.1719]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.62it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.49it/s]


Validation: {'precision': 0.745631067956339, 'recall': 0.6788450206206315, 'f1': 0.7106724194359715, 'auc': 0.9004269913841656, 'prauc': 0.8034064881653589}
Test:      {'precision': 0.7733835530397152, 'recall': 0.6821705426318817, 'f1': 0.7249190888665021, 'auc': 0.9034138834683263, 'prauc': 0.8157718365603273}


Epoch 011: 100%|██████████| 98/98 [00:05<00:00, 16.99it/s, loss=0.1576]
Running inference: 100%|██████████| 198/198 [00:07<00:00, 26.98it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.43it/s]


Validation: {'precision': 0.6832603938693477, 'recall': 0.736004714197195, 'f1': 0.7086524772723886, 'auc': 0.8950544699518832, 'prauc': 0.7919975899734921}
Test:      {'precision': 0.7391793142173178, 'recall': 0.7281284606825685, 'f1': 0.7336122683574183, 'auc': 0.8966730999430286, 'prauc': 0.8097350810928345}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 20.19it/s, loss=0.1485]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.16it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.25it/s]


Validation: {'precision': 0.6842957366395883, 'recall': 0.7472009428359034, 'f1': 0.7143661921887293, 'auc': 0.8978683132818387, 'prauc': 0.7945282504381128}
Test:      {'precision': 0.7147453083071595, 'recall': 0.7380952380911512, 'f1': 0.7262326291569646, 'auc': 0.8968303619130343, 'prauc': 0.8082487469660042}


Epoch 013: 100%|██████████| 98/98 [00:04<00:00, 22.21it/s, loss=0.1291]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.03it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.41it/s]


Validation: {'precision': 0.6449704141980032, 'recall': 0.770771950496342, 'f1': 0.7022818742303399, 'auc': 0.8903532342623732, 'prauc': 0.7808794904496232}
Test:      {'precision': 0.6742574257392364, 'recall': 0.75415282391609, 'f1': 0.7119707216193438, 'auc': 0.8898756897983945, 'prauc': 0.7942088962126658}


Epoch 014: 100%|██████████| 98/98 [00:04<00:00, 20.08it/s, loss=0.1135]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.18it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.00it/s]


Validation: {'precision': 0.7448471926030928, 'recall': 0.6175604007034912, 'f1': 0.675257726998056, 'auc': 0.8864855170260675, 'prauc': 0.7755512745005032}
Test:      {'precision': 0.7837455830333305, 'recall': 0.6140642303399, 'f1': 0.6886060180436346, 'auc': 0.8881690742827993, 'prauc': 0.792369209558786}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6970033296298724, 'recall': 0.7401296405377719, 'f1': 0.7179194005448404, 'auc': 0.8991111819087279, 'prauc': 0.8027343147673945}
Corresponding test performance:
{'precision': 0.7247854077214336, 'recall': 0.748062015499734, 'f1': 0.7362397770135855, 'auc': 0.901802425293486, 'prauc': 0.8129792512191218}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 22.58it/s, loss=0.5529]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.81it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.11it/s]


Validation: {'precision': 0.6905330151098605, 'recall': 0.5114908662315175, 'f1': 0.587677720225436, 'auc': 0.8055457016571113, 'prauc': 0.6510418311318309}
Test:      {'precision': 0.7065820777104951, 'recall': 0.493355481724843, 'f1': 0.5810237969147725, 'auc': 0.8100902603691957, 'prauc': 0.6677231085977664}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 21.31it/s, loss=0.4277]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.18it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.18it/s]


Validation: {'precision': 0.7972972972890093, 'recall': 0.45197407188891003, 'f1': 0.5769086076379483, 'auc': 0.86660868869855, 'prauc': 0.7364006352224314}
Test:      {'precision': 0.8272921108653807, 'recall': 0.4296788482811203, 'f1': 0.565597663134674, 'auc': 0.8691746596370873, 'prauc': 0.7468218067938449}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 22.69it/s, loss=0.3733]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.80it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.81it/s]


Validation: {'precision': 0.8166023165944344, 'recall': 0.49852681201827626, 'f1': 0.6190998855184637, 'auc': 0.8867022318239006, 'prauc': 0.7674221842109686}
Test:      {'precision': 0.8526315789392094, 'recall': 0.493355481724843, 'f1': 0.6250438396170273, 'auc': 0.8920716947165395, 'prauc': 0.7920540569371874}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 22.76it/s, loss=0.3332]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.26it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.56it/s]


Validation: {'precision': 0.7436974789863887, 'recall': 0.6258102533846447, 'f1': 0.679679995032699, 'auc': 0.8844699991479589, 'prauc': 0.7683242331843698}
Test:      {'precision': 0.76907356947705, 'recall': 0.6251384274605475, 'f1': 0.6896762320680142, 'auc': 0.8921965194817728, 'prauc': 0.7932037158014157}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 22.65it/s, loss=0.2891]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.53it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.83it/s]


Validation: {'precision': 0.7856012658165696, 'recall': 0.585150265170388, 'f1': 0.6707193466728075, 'auc': 0.8919675136023185, 'prauc': 0.7775707401892599}
Test:      {'precision': 0.8211764705817948, 'recall': 0.579734219265893, 'f1': 0.679649459603696, 'auc': 0.8996157145269928, 'prauc': 0.8068467481880993}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 22.14it/s, loss=0.2775]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.94it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.93it/s]


Validation: {'precision': 0.7481580709929796, 'recall': 0.6582203889177477, 'f1': 0.7003134746398817, 'auc': 0.9009170746861914, 'prauc': 0.7894816286887782}
Test:      {'precision': 0.772518080205309, 'recall': 0.6506090808380365, 'f1': 0.7063420449272447, 'auc': 0.9017775588215163, 'prauc': 0.8023776536144338}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 22.16it/s, loss=0.2477]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.02it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.33it/s]


Validation: {'precision': 0.8021555042278511, 'recall': 0.6140247495544253, 'f1': 0.6955941205842615, 'auc': 0.9019948364010857, 'prauc': 0.8040512300684086}
Test:      {'precision': 0.8209064327425373, 'recall': 0.6218161683243533, 'f1': 0.7076244437359986, 'auc': 0.9034630009055336, 'prauc': 0.8163022389335797}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 22.70it/s, loss=0.2337]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.22it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.05it/s]


Validation: {'precision': 0.6899841017451511, 'recall': 0.7672362993472762, 'f1': 0.7265624950099977, 'auc': 0.9035444142702957, 'prauc': 0.8020517001606757}
Test:      {'precision': 0.7082255561267035, 'recall': 0.7580287929083166, 'f1': 0.7322813536538976, 'auc': 0.9025072420522078, 'prauc': 0.811515287818804}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 22.55it/s, loss=0.1944]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.56it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.08it/s]


Validation: {'precision': 0.7197867298535559, 'recall': 0.7159693576858223, 'f1': 0.7178729639765916, 'auc': 0.9007916317557629, 'prauc': 0.8006390614571495}
Test:      {'precision': 0.7328417470179647, 'recall': 0.715393133993824, 'f1': 0.7240123233799693, 'auc': 0.8987737628683992, 'prauc': 0.8043828841551048}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 22.51it/s, loss=0.1870]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.50it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.77it/s]


Validation: {'precision': 0.684627575273721, 'recall': 0.7637006481982104, 'f1': 0.7220055660415221, 'auc': 0.9012115206685112, 'prauc': 0.7992389742908788}
Test:      {'precision': 0.7023255813917193, 'recall': 0.7524916943479929, 'f1': 0.7265436998938014, 'auc': 0.9010912688154253, 'prauc': 0.8114705321708989}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 22.89it/s, loss=0.1662]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.44it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.27it/s]


Validation: {'precision': 0.5993179880622366, 'recall': 0.8285209192644165, 'f1': 0.6955231215166935, 'auc': 0.8917205238813576, 'prauc': 0.787712925865122}
Test:      {'precision': 0.6135508155557777, 'recall': 0.812292358799489, 'f1': 0.699070759925833, 'auc': 0.8900048846614245, 'prauc': 0.7909988050032114}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 22.77it/s, loss=0.1579]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.69it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.75it/s]


Validation: {'precision': 0.77421652421101, 'recall': 0.6405421331724188, 'f1': 0.7010641678875849, 'auc': 0.8957184743840299, 'prauc': 0.7948832034125018}
Test:      {'precision': 0.7825174825120105, 'recall': 0.6196013289002237, 'f1': 0.6915945562498799, 'auc': 0.8939737951701926, 'prauc': 0.7984327854527056}


Epoch 013: 100%|██████████| 98/98 [00:04<00:00, 22.46it/s, loss=0.1408]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.07it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.21it/s]


Validation: {'precision': 0.7183264584518662, 'recall': 0.7183264584518662, 'f1': 0.7183264534518662, 'auc': 0.8967193352393336, 'prauc': 0.7889848374492153}
Test:      {'precision': 0.726700971979836, 'recall': 0.7037652270171442, 'f1': 0.7150492214388943, 'auc': 0.8944254540100756, 'prauc': 0.7957297553850641}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6899841017451511, 'recall': 0.7672362993472762, 'f1': 0.7265624950099977, 'auc': 0.9035444142702957, 'prauc': 0.8020517001606757}
Corresponding test performance:
{'precision': 0.7082255561267035, 'recall': 0.7580287929083166, 'f1': 0.7322813536538976, 'auc': 0.9025072420522078, 'prauc': 0.811515287818804}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 21.99it/s, loss=0.5244]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.48it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.39it/s]


Validation: {'precision': 0.6520051746400258, 'recall': 0.5939893930430525, 'f1': 0.6216466185037686, 'auc': 0.8355610524483024, 'prauc': 0.6803111185884663}
Test:      {'precision': 0.6708860759451211, 'recall': 0.5869324473943138, 'f1': 0.6261074964952453, 'auc': 0.8389518634343207, 'prauc': 0.6915619272610999}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 21.86it/s, loss=0.3951]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.62it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.31it/s]


Validation: {'precision': 0.8138248847851262, 'recall': 0.5203299941041819, 'f1': 0.6347951066680303, 'auc': 0.8844972721272323, 'prauc': 0.770134216754648}
Test:      {'precision': 0.8109058926929561, 'recall': 0.5105204872618465, 'f1': 0.6265715209082063, 'auc': 0.8854945744312595, 'prauc': 0.774941655608389}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 22.20it/s, loss=0.3457]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.41it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.25it/s]


Validation: {'precision': 0.8399280575464037, 'recall': 0.5503830288712411, 'f1': 0.6650053351907648, 'auc': 0.8909461501092004, 'prauc': 0.7886492745401759}
Test:      {'precision': 0.8344887348281241, 'recall': 0.5332225913591738, 'f1': 0.6506756709138743, 'auc': 0.8928789933461261, 'prauc': 0.7964259501095542}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 22.92it/s, loss=0.3182]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.21it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.42it/s]


Validation: {'precision': 0.7454545454495256, 'recall': 0.6523276370026381, 'f1': 0.6957888070857029, 'auc': 0.8950580467360504, 'prauc': 0.7924789110146645}
Test:      {'precision': 0.7686170212714853, 'recall': 0.6400885935734215, 'f1': 0.6984894210192752, 'auc': 0.8971954805064093, 'prauc': 0.8025656317984975}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 22.63it/s, loss=0.2963]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.51it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.97it/s]


Validation: {'precision': 0.7206502107120972, 'recall': 0.705362404238625, 'f1': 0.712924354734268, 'auc': 0.8929739951184561, 'prauc': 0.7935262750055629}
Test:      {'precision': 0.7369994022669636, 'recall': 0.6827242524879141, 'f1': 0.7088243698235834, 'auc': 0.8915887681346748, 'prauc': 0.793817259696232}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 22.33it/s, loss=0.2703]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.74it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.83it/s]


Validation: {'precision': 0.7638225255920559, 'recall': 0.6593989393007696, 'f1': 0.7077798811704477, 'auc': 0.8975692685770187, 'prauc': 0.7988426027378048}
Test:      {'precision': 0.7836021505323684, 'recall': 0.6456256921337452, 'f1': 0.7079538505371399, 'auc': 0.8995454236582077, 'prauc': 0.8063966119019965}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 19.75it/s, loss=0.2249]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.95it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.39it/s]


Validation: {'precision': 0.7176749703397529, 'recall': 0.7130229817282674, 'f1': 0.7153414079429123, 'auc': 0.897584980878895, 'prauc': 0.794774754272545}
Test:      {'precision': 0.7374269005804829, 'recall': 0.6982281284568205, 'f1': 0.7172923727015814, 'auc': 0.8991427591541854, 'prauc': 0.8051325850877641}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 21.98it/s, loss=0.2201]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.45it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.09it/s]


Validation: {'precision': 0.724803862397557, 'recall': 0.7077195050046687, 'f1': 0.7161598041795058, 'auc': 0.8982231813681224, 'prauc': 0.7988352778985586}
Test:      {'precision': 0.75554187191653, 'recall': 0.6794019933517198, 'f1': 0.7154518900536376, 'auc': 0.8989358257909138, 'prauc': 0.8060771836002067}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 22.44it/s, loss=0.1851]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.18it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.65it/s]


Validation: {'precision': 0.845261121848692, 'recall': 0.5150265173805832, 'f1': 0.6400585818883103, 'auc': 0.8922018568364034, 'prauc': 0.7850704736784782}
Test:      {'precision': 0.8494726749678863, 'recall': 0.49058693244468116, 'f1': 0.6219726173268756, 'auc': 0.8877655480693523, 'prauc': 0.7855122508979234}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 22.37it/s, loss=0.1823]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.45it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.31it/s]


Validation: {'precision': 0.7077570655399413, 'recall': 0.6935769004084056, 'f1': 0.70059523309158, 'auc': 0.8896995130847069, 'prauc': 0.7822258112800374}
Test:      {'precision': 0.7325029655947064, 'recall': 0.6838316721999789, 'f1': 0.7073310373844423, 'auc': 0.8880256612142619, 'prauc': 0.7901331331876228}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 22.74it/s, loss=0.1528]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.87it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.98it/s]


Validation: {'precision': 0.6607422895940369, 'recall': 0.7448438420698595, 'f1': 0.7002770033242702, 'auc': 0.8946874024767443, 'prauc': 0.7889219878135109}
Test:      {'precision': 0.6955613576987177, 'recall': 0.7375415282351189, 'f1': 0.7159365711896389, 'auc': 0.8925656388689054, 'prauc': 0.7955708535104478}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 22.57it/s, loss=0.1343]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.88it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.08it/s]


Validation: {'precision': 0.755496335771116, 'recall': 0.6682380671734341, 'f1': 0.7091932407929578, 'auc': 0.8965670025565065, 'prauc': 0.7917921784848856}
Test:      {'precision': 0.7686762778455526, 'recall': 0.6495016611259717, 'f1': 0.7040816276841433, 'auc': 0.894081324196259, 'prauc': 0.7990464566250519}


Epoch 013: 100%|██████████| 98/98 [00:04<00:00, 22.78it/s, loss=0.1123]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.27it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.17it/s]


Validation: {'precision': 0.7654600301601399, 'recall': 0.5981143193836292, 'f1': 0.6715183543166477, 'auc': 0.8830578719845954, 'prauc': 0.7694621320327317}
Test:      {'precision': 0.7853915662591462, 'recall': 0.5775193798417635, 'f1': 0.6656030582901126, 'auc': 0.8778983596499391, 'prauc': 0.7741389215835576}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.724803862397557, 'recall': 0.7077195050046687, 'f1': 0.7161598041795058, 'auc': 0.8982231813681224, 'prauc': 0.7988352778985586}
Corresponding test performance:
{'precision': 0.75554187191653, 'recall': 0.6794019933517198, 'f1': 0.7154518900536376, 'auc': 0.8989358257909138, 'prauc': 0.8060771836002067}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 21.14it/s, loss=0.5528]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.89it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.49it/s]


Validation: {'precision': 0.6778464254132582, 'recall': 0.452563347080421, 'f1': 0.5427561789403366, 'auc': 0.8105331439511279, 'prauc': 0.6472991207557782}
Test:      {'precision': 0.7082961641328431, 'recall': 0.439645625689703, 'f1': 0.5425350140607091, 'auc': 0.8111393915937537, 'prauc': 0.6527144574106867}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 21.88it/s, loss=0.4389]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 37.74it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.76it/s]


Validation: {'precision': 0.7838297872273717, 'recall': 0.5427224513815986, 'f1': 0.6413648976676717, 'auc': 0.8770374416489178, 'prauc': 0.7503599270380693}
Test:      {'precision': 0.7713147410297108, 'recall': 0.5359911406393356, 'f1': 0.6324730431814011, 'auc': 0.877217178349723, 'prauc': 0.7553922055194452}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 22.20it/s, loss=0.3739]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.76it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.53it/s]


Validation: {'precision': 0.888268156412175, 'recall': 0.3747790218009736, 'f1': 0.5271446290586717, 'auc': 0.8892979551908144, 'prauc': 0.7819257168390379}
Test:      {'precision': 0.9006622516436998, 'recall': 0.3765227021020126, 'f1': 0.5310425573373528, 'auc': 0.8871502875401249, 'prauc': 0.787181803912109}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 22.35it/s, loss=0.3220]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.00it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.03it/s]


Validation: {'precision': 0.7499999999951675, 'recall': 0.685916322918763, 'f1': 0.7165281575170902, 'auc': 0.9011381327219417, 'prauc': 0.7967674111402161}
Test:      {'precision': 0.7608825283198517, 'recall': 0.7065337762973061, 'f1': 0.7327016889446556, 'auc': 0.9040851551101585, 'prauc': 0.8082062644764771}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 22.39it/s, loss=0.2862]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 35.87it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.88it/s]


Validation: {'precision': 0.714784633290368, 'recall': 0.7236299351754648, 'f1': 0.7191800828437077, 'auc': 0.9063103542408716, 'prauc': 0.8068143596540762}
Test:      {'precision': 0.7471719456971314, 'recall': 0.7314507198187628, 'f1': 0.7392277510120974, 'auc': 0.909052355987428, 'prauc': 0.8240566046834747}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 22.70it/s, loss=0.2574]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.28it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.27it/s]


Validation: {'precision': 0.7314814814769662, 'recall': 0.6982911019404933, 'f1': 0.7145010501687209, 'auc': 0.9019071413178483, 'prauc': 0.797665026779781}
Test:      {'precision': 0.7538644470823195, 'recall': 0.7021040974490471, 'f1': 0.7270642151856366, 'auc': 0.9033611345414007, 'prauc': 0.8129950884768012}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 23.37it/s, loss=0.2523]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.46it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.67it/s]


Validation: {'precision': 0.7669322709112422, 'recall': 0.6806128461951644, 'f1': 0.721198871066976, 'auc': 0.9092851528583423, 'prauc': 0.8101520550178256}
Test:      {'precision': 0.7792288557165471, 'recall': 0.6937984496085615, 'f1': 0.7340363160435666, 'auc': 0.9102961719912923, 'prauc': 0.8245040933324809}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 22.35it/s, loss=0.2082]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.07it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.76it/s]


Validation: {'precision': 0.8811777076668645, 'recall': 0.4938126104861885, 'f1': 0.6329305089872223, 'auc': 0.9087245558113098, 'prauc': 0.8135021726225345}
Test:      {'precision': 0.8764367816008005, 'recall': 0.50664451826962, 'f1': 0.6421052585108182, 'auc': 0.9111496336257728, 'prauc': 0.8265113867743048}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 22.22it/s, loss=0.2060]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.31it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 37.49it/s]


Validation: {'precision': 0.6033163265280471, 'recall': 0.8361814967540591, 'f1': 0.7009138010053774, 'auc': 0.9004680605309403, 'prauc': 0.7974451963475009}
Test:      {'precision': 0.6278210915033737, 'recall': 0.8471760797295285, 'f1': 0.7211878339004888, 'auc': 0.9040919256842095, 'prauc': 0.8146522568839175}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 22.28it/s, loss=0.1913]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.14it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.04it/s]


Validation: {'precision': 0.726999398672718, 'recall': 0.7124337065367565, 'f1': 0.7196428521390856, 'auc': 0.89886936575058, 'prauc': 0.7981604822631924}
Test:      {'precision': 0.7498552402967583, 'recall': 0.7170542635619211, 'f1': 0.733088022170725, 'auc': 0.8994257075988492, 'prauc': 0.8089307807761034}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 22.45it/s, loss=0.1578]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.00it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.62it/s]


Validation: {'precision': 0.8221024258686245, 'recall': 0.5391868002325328, 'f1': 0.6512455468127533, 'auc': 0.8854161863024239, 'prauc': 0.7784030971836496}
Test:      {'precision': 0.8303647158538562, 'recall': 0.5420819490556917, 'f1': 0.6559463938761771, 'auc': 0.8895127870292543, 'prauc': 0.7957264905839213}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 23.23it/s, loss=0.1455]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.51it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 39.12it/s]


Validation: {'precision': 0.7132987910148918, 'recall': 0.7301119622820854, 'f1': 0.721607449859609, 'auc': 0.9002404237671687, 'prauc': 0.8011546511088312}
Test:      {'precision': 0.7213203463164432, 'recall': 0.7380952380911512, 'f1': 0.7296113797804659, 'auc': 0.8986857454057346, 'prauc': 0.8064093841634901}


Epoch 013: 100%|██████████| 98/98 [00:04<00:00, 22.85it/s, loss=0.1607]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 36.02it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 39.81it/s]


Validation: {'precision': 0.731245923021975, 'recall': 0.6605774896837916, 'f1': 0.6941176420674158, 'auc': 0.8913253531020491, 'prauc': 0.7869824034560186}
Test:      {'precision': 0.7443973349440073, 'recall': 0.6805094130637845, 'f1': 0.7110211115810033, 'auc': 0.8946753497432351, 'prauc': 0.7995803981826874}


Epoch 014: 100%|██████████| 98/98 [00:04<00:00, 23.79it/s, loss=0.1262]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 39.22it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.60it/s]


Validation: {'precision': 0.7634795111375738, 'recall': 0.6258102533846447, 'f1': 0.6878238292415338, 'auc': 0.8943182272537924, 'prauc': 0.7937207376949704}
Test:      {'precision': 0.7818545697079263, 'recall': 0.6489479512699394, 'f1': 0.7092284367937676, 'auc': 0.8968216217174411, 'prauc': 0.804834020383906}


Epoch 015: 100%|██████████| 98/98 [00:04<00:00, 22.52it/s, loss=0.1179]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.63it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.45it/s]


Validation: {'precision': 0.7942857142792303, 'recall': 0.5733647613401688, 'f1': 0.6659821990957902, 'auc': 0.8870717902737275, 'prauc': 0.7838725386215321}
Test:      {'precision': 0.8057054741649522, 'recall': 0.5786267995538282, 'f1': 0.6735417289361901, 'auc': 0.8889429508968425, 'prauc': 0.7936653373317695}


Epoch 016: 100%|██████████| 98/98 [00:04<00:00, 22.95it/s, loss=0.0822]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.03it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 38.03it/s]


Validation: {'precision': 0.7751999999937985, 'recall': 0.5710076605741249, 'f1': 0.6576179116358507, 'auc': 0.8852301935257396, 'prauc': 0.7735804824060666}
Test:      {'precision': 0.7939622641449513, 'recall': 0.5825027685460549, 'f1': 0.6719897747368345, 'auc': 0.8795777697679834, 'prauc': 0.7796357026714728}


Epoch 017: 100%|██████████| 98/98 [00:04<00:00, 22.43it/s, loss=0.0960]
Running inference: 100%|██████████| 198/198 [00:05<00:00, 38.29it/s]
Running inference: 100%|██████████| 197/197 [00:05<00:00, 36.42it/s]

Validation: {'precision': 0.8300469483490137, 'recall': 0.5209192692956929, 'f1': 0.6401158533310167, 'auc': 0.8942175024568038, 'prauc': 0.7929031655763956}
Test:      {'precision': 0.8427672955899124, 'recall': 0.5193798449583645, 'f1': 0.6426858465963593, 'auc': 0.8926233734003596, 'prauc': 0.7989546073533438}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7132987910148918, 'recall': 0.7301119622820854, 'f1': 0.721607449859609, 'auc': 0.9002404237671687, 'prauc': 0.8011546511088312}
Corresponding test performance:
{'precision': 0.7213203463164432, 'recall': 0.7380952380911512, 'f1': 0.7296113797804659, 'auc': 0.8986857454057346, 'prauc': 0.8064093841634901}





In [16]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7255 ± 0.0160
recall: 0.7388 ± 0.0315
f1: 0.7313 ± 0.0091
auc: 0.9015 ± 0.0026
prauc: 0.8121 ± 0.0064
