In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]"],
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,
    early_stop_patience = 5,   
)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
if curr_task == "next_diag_6m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl"
elif curr_task == "next_diag_12m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"

Current task: stay


In [6]:
ehr_full_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks
tokenizer = EHRTokenizer(token_type_sentences, age_gender_sentences, task_sentences, diag_sentences, 
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
train_dataset = FineTuneEHRDataset(train_data, tokenizer, token_type=config.token_type, task=curr_task)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, token_type=config.token_type, task=curr_task)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, token_type=config.token_type, task=curr_task)

In [10]:
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=True,
    batch_size=config.batch_size,
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

In [11]:
if curr_task in ["death", "stay", "readmission"]:
    eval_metric = "prauc"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [12]:
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age/Sex IDs shape:", age_gender_ids.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    def __init__(self, d_model: int, heads: int = 1, dropout: float = 0.0):
        super().__init__()
        self.act = nn.GELU()

        # 第1层
        self.conv1 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')

        # 第2层
        self.conv2 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')
        self.lin = nn.Linear(d_model, d_model)

    def forward(self, hg):
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # 第1层：HeteroConv → Linear → GELU → Dropout
        x_dict = self.conv1(x_dict, hg.edge_index_dict)

        # 第2层：HeteroConv → Linear（末层通常不再加激活/随你需要）
        x_dict = self.conv2(x_dict, hg.edge_index_dict)
        x_dict = {k: self.lin(v) for k, v in x_dict.items()}

        return x_dict  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# multi-class classification task
class MultiPredictionHead(nn.Module):
    def __init__(self, hidden_size, label_size):
        super(MultiPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, label_size)
            )

    def forward(self, input):
        return self.cls(input)
    
class BinaryPredictionHead(nn.Module):
    def __init__(self, hidden_size):
        super(BinaryPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, 1)
            )
    def forward(self, input):
        return self.cls(input)

In [16]:
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
exp_i = i
id_seq = torch.concat([train_dataset[exp_i][0][0], torch.zeros(5, dtype=train_dataset[exp_i][0][0].dtype)], dim=0)
type_seq = torch.concat([train_dataset[exp_i][1][0], torch.zeros(5, dtype=train_dataset[exp_i][1][0].dtype)], dim=0)
visit_seq = torch.concat([train_dataset[exp_i][2][0], torch.zeros(5, dtype=train_dataset[exp_i][2][0].dtype)], dim=0)
age_sex = torch.concat([train_dataset[exp_i][3][0], torch.zeros(3, dtype=train_dataset[exp_i][3][0].dtype)], dim=0)

tensor([[27, 27, 27,  1]])


In [17]:
class HeteroGT(nn.Module):
    def __init__(self, tokenizer, d_model, num_heads, num_layers, max_num_adms, device, task, use_hetero_graph):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.use_hetero_graph = use_hetero_graph
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_sex_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0] #0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0] #0
        self.adm_pad_id = 0
        self.age_sex_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0] #0
        self.diag_type_id = 1
        self.visit_type_id = 5
        
        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id) # n_type already have PAD, + 1 for visit
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_sex_emb = nn.Embedding(self.age_sex_vocab_size, d_model, padding_idx=self.age_sex_pad_id)
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)  # 5 task in total, task embedding, not used in this model
        
        # GNN
        self.het_gnn = DiseaseOccHetGNN(d_model)    

        # encoder transformer
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, norm_first = True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers, enable_nested_tensor=False)

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, config.label_vocab_size)

    def visit_segment(self, B, input_ids, token_types, adm_index, age_gender_index):
        graphs = []
        for p in range(B):
            hg_p = self.build_patient_graph(input_ids[p], token_types[p], adm_index[p], age_gender_index[p])
            graphs.append(hg_p)

        batch_graph = HeteroBatch.from_data_list(graphs).to(self.device)
        out = self.het_gnn(batch_graph)
        h_visit_all = out['visit']  # extract virtual visit node representations

        # 取出每个样本的 visit 表示序列（按我们在 build_patient_graph 中的保序构造）
        visit_emb_seq = []
        offset = 0
        for p in range(B):
            n_v = graphs[p]['visit'].num_nodes
            visit_emb_p = h_visit_all[offset:offset + n_v]  # [N_visit_p, d]
            offset += n_v
            visit_emb_seq.append(visit_emb_p)
            
        visit_emb_pad = []
        visit_pad_mask = []
        visit_index_pad = []
        for p in range(B):
            v = visit_emb_seq[p]                          # [N_visit_p, d]
            Np = v.size(0)
            if Np < self.max_num_adms:
                pad_len = self.max_num_adms - Np
                v_pad = torch.cat([v, torch.zeros(pad_len, self.d_model, device=self.device, dtype=v.dtype)], dim=0)
                m_pad = torch.cat([torch.zeros(Np, dtype=torch.bool, device=self.device), torch.ones(pad_len, dtype=torch.bool, device=self.device)], dim=0)
                i_pad = torch.cat([torch.arange(1, Np + 1, device=self.device), torch.full((pad_len,), self.adm_pad_id, dtype=torch.long, device=self.device)], dim=0)
            else:
                v_pad = v[:self.max_num_adms]
                m_pad = torch.zeros(self.max_num_adms, dtype=torch.bool, device=self.device)
                i_pad = torch.arange(1, self.max_num_adms + 1, device=self.device)
            visit_emb_pad.append(v_pad)      # [V_max, d]
            visit_pad_mask.append(m_pad)     # [V_max]
            visit_index_pad.append(i_pad)    # [V_max]

        visit_emb_pad  = torch.stack(visit_emb_pad,  dim=0)  # [B, V_max, d]
        visit_pad_mask = torch.stack(visit_pad_mask, dim=0)  # [B, V_max]
        visit_index_pad = torch.stack(visit_index_pad, dim=0)  # [B, V_max]

        visit_type_ids = torch.full((B, self.max_num_adms), self.visit_type_id, dtype=torch.long, device=self.device) # [B, V_max]
        visit_type_ids_pad = visit_type_ids * (~visit_pad_mask)

        return visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad

    def build_patient_graph(self, id_seq: torch.Tensor, type_seq: torch.Tensor, visit_seq: torch.Tensor, age_sex: torch.Tensor):
        # build a graph just for one patient
        hg = HeteroData()
        occ_mask = (type_seq == self.diag_type_id) & (id_seq != self.seq_pad_id) # 疾病token mask
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1) # 疾病 token 的位置索引，形状 [N_occ]
        N_occ = occ_pos.numel() # 疾病 token 数量

        # build visit virtual nodes
        nonpad = id_seq != self.seq_pad_id
        visit_used = visit_seq[nonpad] # seq非pad部分
        visit_ids_unique, visit_lid_nonpad = torch.unique(visit_used, return_inverse=True)
        visit_lid_full = torch.full_like(id_seq, fill_value=-1)
        visit_lid_full[nonpad] = visit_lid_nonpad
        N_visit = visit_ids_unique.numel()
        age_sex_nonpad = age_sex[age_sex!=self.age_sex_pad_id]
        assert N_visit == len(visit_ids_unique) == len(age_sex_nonpad)
        visit_x = self.age_sex_emb(age_sex_nonpad.to(self.device))
        hg['visit'].x = visit_x
        hg['visit'].num_nodes = N_visit
        
        # build diag nodes
        gid_occ = id_seq[occ_pos]
        x_occ = self.token_emb(gid_occ) # [N_occ, d]
        hg['occ'].x = x_occ
        hg['occ'].num_nodes = N_occ

        # build edges between diag nodes and virtual visit nodes
        occ_visit_lid = visit_lid_full[occ_pos]
        e_v2o = torch.stack([occ_visit_lid, torch.arange(N_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(N_occ, device=self.device), occ_visit_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v
        
        # build forward edges between virtual visit nodes
        if N_visit > 1:
            src = torch.arange(0, N_visit - 1, device=self.device)
            dst = torch.arange(1, N_visit, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, N_visit-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg
    
    def forward(self, input_ids, token_types, adm_index, age_gender_index, task_id):
        B, L = input_ids.shape
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)
        # 基础表示
        token_embed = self.token_emb(input_ids)  # [B, L, d]
        seq_pad_mask = (input_ids == self.seq_pad_id) # [B, L]
        
        if self.use_hetero_graph:
            # get visit embed and mask
            visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad = self.visit_segment(B, input_ids, token_types, adm_index, age_gender_index)
            token_embed = torch.concat([token_embed, visit_emb_pad], dim=1)  # [B, L+V, d]
            adm_index = torch.concat([adm_index, visit_index_pad], dim=1)  # [B, L+V]
            token_types = torch.concat([token_types, visit_type_ids_pad], dim=1) # [B, L+V]
            seq_pad_mask = torch.concat([seq_pad_mask, visit_pad_mask], dim=1) # [B, L+V]
            
        adm_emb = self.adm_index_emb(adm_index) # [B, L+V, d]
        token_type_emb = self.type_emb(token_types) # [B, L+V, d]
        x = token_embed + adm_emb + token_type_emb # [B, L+V, d]
        task_id_emb = self.task_emb(task_id).unsqueeze(1) # [B, 1, d]
        x = torch.concat([task_id_emb, x], dim=1) # [B, 1+L+V, d]

        # mask
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)
        mask = torch.concat([task_pad_mask, seq_pad_mask], dim=1)  # [B, 1+L+V]

        # ===== Transformer 编码（batch_first=True） =====
        h = self.encoder(x, src_key_padding_mask=mask)   # [B, 1+L(+V), d]

        # ===== 分类：取 CLS（task 位） =====
        logits = self.cls_head(h[:, 0, :])  # [B, label_size]
        return logits

In [18]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, num_layers=2, max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, use_hetero_graph=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.88it/s, loss=0.6914]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.95it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.32it/s]


Validation: {'precision': 0.8026556776520025, 'recall': 0.5497021009703679, 'f1': 0.6525218637747735, 'auc': 0.7841614312415848, 'prauc': 0.7958278993894328}
Test:      {'precision': 0.7758771929790532, 'recall': 0.5547193477561784, 'f1': 0.6469189931244258, 'auc': 0.7735731826951293, 'prauc': 0.7886587564292016}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 32.83it/s, loss=0.5848]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.63it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.63it/s]


Validation: {'precision': 0.8104838709641108, 'recall': 0.5672624647207047, 'f1': 0.667404533061497, 'auc': 0.8004220692546038, 'prauc': 0.807690444464056}
Test:      {'precision': 0.7830551989696917, 'recall': 0.5738476011270811, 'f1': 0.6623235564628235, 'auc': 0.7877640686005817, 'prauc': 0.7980140655614889}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 32.74it/s, loss=0.5453]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.51it/s]


Validation: {'precision': 0.8379888268113866, 'recall': 0.5174035747867125, 'f1': 0.6397828568514952, 'auc': 0.8068119855691862, 'prauc': 0.8208367693094212}
Test:      {'precision': 0.826108374380167, 'recall': 0.5258701787377678, 'f1': 0.6426518442573408, 'auc': 0.7953521510287923, 'prauc': 0.8123854677639399}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 32.81it/s, loss=0.5222]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.20it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.36it/s]


Validation: {'precision': 0.7303012746212911, 'recall': 0.790529946689274, 'f1': 0.7592230035886001, 'auc': 0.8249674541079562, 'prauc': 0.8330842943320962}
Test:      {'precision': 0.7081911262778493, 'recall': 0.780809031041766, 'f1': 0.7427293014973726, 'auc': 0.8113158343764959, 'prauc': 0.8253301635479566}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 32.28it/s, loss=0.4921]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 57.82it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.61it/s]


Validation: {'precision': 0.78449905481745, 'recall': 0.6506741925348051, 'f1': 0.711347269638225, 'auc': 0.8208488426804372, 'prauc': 0.8326343782510636}
Test:      {'precision': 0.7764968389707085, 'recall': 0.6547507055482761, 'f1': 0.7104457248738429, 'auc': 0.8101215202539629, 'prauc': 0.8225504692445436}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 32.79it/s, loss=0.4640]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.86it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.21it/s]


Validation: {'precision': 0.7548344804957233, 'recall': 0.7221699592326053, 'f1': 0.7381410206411053, 'auc': 0.8195571548264123, 'prauc': 0.8244615949317731}
Test:      {'precision': 0.7400768245814978, 'recall': 0.7249921605496238, 'f1': 0.7324568301003798, 'auc': 0.8097508641009651, 'prauc': 0.8159916936823564}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 32.89it/s, loss=0.4397]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.38it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.16it/s]


Validation: {'precision': 0.6822666329352839, 'recall': 0.84571966133319, 'f1': 0.7552506251307101, 'auc': 0.8267749731258395, 'prauc': 0.8377227969404517}
Test:      {'precision': 0.677096370461384, 'recall': 0.8482282847260952, 'f1': 0.7530623558626225, 'auc': 0.8194925171649938, 'prauc': 0.8296172746588077}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 32.80it/s, loss=0.4292]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.08it/s]


Validation: {'precision': 0.8243123336254468, 'recall': 0.5826277830022495, 'f1': 0.6827117349943815, 'auc': 0.821741700810088, 'prauc': 0.8320864071181875}
Test:      {'precision': 0.8046371833370346, 'recall': 0.5876450297880601, 'f1': 0.6792316007732124, 'auc': 0.8128179884001334, 'prauc': 0.8211899459253487}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 32.80it/s, loss=0.3950]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.22it/s]


Validation: {'precision': 0.7256097560954541, 'recall': 0.7836312323587845, 'f1': 0.7535051962715131, 'auc': 0.8158097931360293, 'prauc': 0.8119922598403797}
Test:      {'precision': 0.7181266261904651, 'recall': 0.7789275634970871, 'f1': 0.7472924137785624, 'auc': 0.806161982376216, 'prauc': 0.80056905912264}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 33.14it/s, loss=0.3707]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.91it/s]


Validation: {'precision': 0.7457789104786691, 'recall': 0.7340859203489054, 'f1': 0.7398862149726895, 'auc': 0.8196237185478881, 'prauc': 0.8304555752885543}
Test:      {'precision': 0.7402068317118765, 'recall': 0.7406710567552817, 'f1': 0.7404388664710337, 'auc': 0.8131950369120979, 'prauc': 0.8207104088747595}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 32.76it/s, loss=0.3315]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.58it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.82it/s]


Validation: {'precision': 0.6984211934688295, 'recall': 0.8184383819353451, 'f1': 0.7536817738336299, 'auc': 0.8232994426029319, 'prauc': 0.8316999757492294}
Test:      {'precision': 0.6969218626658855, 'recall': 0.8306679209757584, 'f1': 0.7579399091992498, 'auc': 0.8209498743423295, 'prauc': 0.8275933689170734}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 32.77it/s, loss=0.3076]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.32it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.12it/s]


Validation: {'precision': 0.6745679012329023, 'recall': 0.8566948886771506, 'f1': 0.7548003818624033, 'auc': 0.8271935961379577, 'prauc': 0.8345488142427887}
Test:      {'precision': 0.6745088527754681, 'recall': 0.8720602069586953, 'f1': 0.7606673911407701, 'auc': 0.8277011114147901, 'prauc': 0.8348712267594264}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6822666329352839, 'recall': 0.84571966133319, 'f1': 0.7552506251307101, 'auc': 0.8267749731258395, 'prauc': 0.8377227969404517}
Corresponding test performance:
{'precision': 0.677096370461384, 'recall': 0.8482282847260952, 'f1': 0.7530623558626225, 'auc': 0.8194925171649938, 'prauc': 0.8296172746588077}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 32.94it/s, loss=0.6839]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.20it/s]


Validation: {'precision': 0.7097163548560359, 'recall': 0.7375352775141502, 'f1': 0.7233584449458015, 'auc': 0.7863900348733616, 'prauc': 0.7901652702212882}
Test:      {'precision': 0.7080072245613848, 'recall': 0.7375352775141502, 'f1': 0.7224696617177292, 'auc': 0.7787250213540022, 'prauc': 0.7821790406995777}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 32.91it/s, loss=0.5846]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.35it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.22it/s]


Validation: {'precision': 0.7683892870585878, 'recall': 0.638758231418505, 'f1': 0.6976027347660719, 'auc': 0.7954847491713194, 'prauc': 0.8012590897882644}
Test:      {'precision': 0.7563209966993173, 'recall': 0.6472248353695603, 'f1': 0.6975329453489062, 'auc': 0.7847229667390966, 'prauc': 0.7918780196713886}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 32.85it/s, loss=0.5513]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.98it/s]


Validation: {'precision': 0.6945292974154553, 'recall': 0.7842583882070108, 'f1': 0.7366715708630727, 'auc': 0.8017499778204658, 'prauc': 0.812323356548237}
Test:      {'precision': 0.6954732510268986, 'recall': 0.7949200376268583, 'f1': 0.7418788361087645, 'auc': 0.804449937510916, 'prauc': 0.8152016908523574}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 32.80it/s, loss=0.5292]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.95it/s]


Validation: {'precision': 0.795206055505487, 'recall': 0.5929758544979837, 'f1': 0.6793605124379639, 'auc': 0.8044485966810172, 'prauc': 0.8155007673433736}
Test:      {'precision': 0.7841518778341554, 'recall': 0.5957980558150022, 'f1': 0.677120451256099, 'auc': 0.8033213589853161, 'prauc': 0.8145687856807181}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 32.96it/s, loss=0.5188]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.18it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.79it/s]


Validation: {'precision': 0.8063520070542287, 'recall': 0.5732204452788547, 'f1': 0.6700879716799183, 'auc': 0.8081343680902462, 'prauc': 0.8198490559079963}
Test:      {'precision': 0.80255057167633, 'recall': 0.5722797115065152, 'f1': 0.6681310586560326, 'auc': 0.8083381024357909, 'prauc': 0.8206544691610089}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 32.87it/s, loss=0.4971]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.23it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.17it/s]


Validation: {'precision': 0.7832740213495257, 'recall': 0.6901850109730631, 'f1': 0.7337889598449818, 'auc': 0.8213612075071017, 'prauc': 0.835366456370743}
Test:      {'precision': 0.7614143494652886, 'recall': 0.6955158356829868, 'f1': 0.7269747573808389, 'auc': 0.8165498676983027, 'prauc': 0.8305123168823996}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 32.92it/s, loss=0.4726]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.41it/s]


Validation: {'precision': 0.8143107989428696, 'recall': 0.5816870492299101, 'f1': 0.6786171526267526, 'auc': 0.807420001207692, 'prauc': 0.8218983304189131}
Test:      {'precision': 0.7949367088574054, 'recall': 0.5907808090291917, 'f1': 0.6778197468600027, 'auc': 0.8053502031209813, 'prauc': 0.8179184454582032}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 32.92it/s, loss=0.4535]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.98it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.44it/s]


Validation: {'precision': 0.618314077876298, 'recall': 0.9062402006870297, 'f1': 0.7350883838321987, 'auc': 0.8003942883202824, 'prauc': 0.8097228860649364}
Test:      {'precision': 0.6154335382645423, 'recall': 0.9103167137005007, 'f1': 0.7343789478790669, 'auc': 0.7996612049747652, 'prauc': 0.8067291671605789}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 32.69it/s, loss=0.4303]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.74it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.31it/s]


Validation: {'precision': 0.7391443167281637, 'recall': 0.7259328943219633, 'f1': 0.7324790331249676, 'auc': 0.8167882547233365, 'prauc': 0.8277684618058168}
Test:      {'precision': 0.7365421152604923, 'recall': 0.729382251487208, 'f1': 0.732944693280463, 'auc': 0.812660494768584, 'prauc': 0.8221616840232198}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.60it/s, loss=0.4246]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.44it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.61it/s]


Validation: {'precision': 0.7897310513415252, 'recall': 0.6077140169313022, 'f1': 0.6868686819510776, 'auc': 0.802566074164143, 'prauc': 0.8149664013142965}
Test:      {'precision': 0.781148867310756, 'recall': 0.6055189714625101, 'f1': 0.6822116184658926, 'auc': 0.8043651758376886, 'prauc': 0.8119468390951975}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.08it/s, loss=0.3779]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.83it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.78it/s]


Validation: {'precision': 0.746335245376844, 'recall': 0.7343994982730185, 'f1': 0.740319261633041, 'auc': 0.8186014504963344, 'prauc': 0.8315049136507136}
Test:      {'precision': 0.7395538799851098, 'recall': 0.7381624333623764, 'f1': 0.7388574965670514, 'auc': 0.816440895594658, 'prauc': 0.8231922381644421}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7832740213495257, 'recall': 0.6901850109730631, 'f1': 0.7337889598449818, 'auc': 0.8213612075071017, 'prauc': 0.835366456370743}
Corresponding test performance:
{'precision': 0.7614143494652886, 'recall': 0.6955158356829868, 'f1': 0.7269747573808389, 'auc': 0.8165498676983027, 'prauc': 0.8305123168823996}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 32.76it/s, loss=0.6852]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.48it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.86it/s]


Validation: {'precision': 0.6511919698854425, 'recall': 0.8137347130736478, 'f1': 0.723445771474371, 'auc': 0.7629505120583824, 'prauc': 0.7616230648002622}
Test:      {'precision': 0.6505216095363872, 'recall': 0.8212605832523636, 'f1': 0.7259875210528032, 'auc': 0.7685528152301224, 'prauc': 0.769007150615504}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 32.88it/s, loss=0.6042]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.07it/s]


Validation: {'precision': 0.7589743589717642, 'recall': 0.6961429915312132, 'f1': 0.7262021539863386, 'auc': 0.8088039740099061, 'prauc': 0.8164348079471739}
Test:      {'precision': 0.7457171649286338, 'recall': 0.6961429915312132, 'f1': 0.7200778412572242, 'auc': 0.8083859192942031, 'prauc': 0.8183027270429202}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 32.82it/s, loss=0.5688]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.27it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.04it/s]


Validation: {'precision': 0.7674005681790931, 'recall': 0.6776418940085367, 'f1': 0.7197335503874188, 'auc': 0.8054911101512539, 'prauc': 0.8151032253770584}
Test:      {'precision': 0.7412250516147928, 'recall': 0.6754468485397446, 'f1': 0.706808854729543, 'auc': 0.7994987786567691, 'prauc': 0.8106698297327911}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 32.78it/s, loss=0.5266]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.98it/s]


Validation: {'precision': 0.6981333333314717, 'recall': 0.8209470053282504, 'f1': 0.7545755822909188, 'auc': 0.8193393784244523, 'prauc': 0.830334404116373}
Test:      {'precision': 0.698953581967537, 'recall': 0.8168704923147794, 'f1': 0.7533256167747528, 'auc': 0.8147575911779407, 'prauc': 0.8265503581696183}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 32.90it/s, loss=0.5034]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.75it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.75it/s]


Validation: {'precision': 0.7200347423256513, 'recall': 0.7798682972694266, 'f1': 0.7487580862295471, 'auc': 0.8183654381465284, 'prauc': 0.829178851493672}
Test:      {'precision': 0.7088571428551176, 'recall': 0.7779868297247476, 'f1': 0.7418149150265305, 'auc': 0.8130755450996024, 'prauc': 0.8266051105164459}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 32.89it/s, loss=0.4650]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.10it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.69it/s]


Validation: {'precision': 0.6685176243344556, 'recall': 0.8682972718693375, 'f1': 0.7554221748669108, 'auc': 0.8202248015120062, 'prauc': 0.828115966329366}
Test:      {'precision': 0.6633924293227997, 'recall': 0.8682972718693375, 'f1': 0.7521390688345523, 'auc': 0.8138326621356419, 'prauc': 0.8226520014279564}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 32.62it/s, loss=0.4376]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.77it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.54it/s]


Validation: {'precision': 0.7668131125354282, 'recall': 0.7115083098127579, 'f1': 0.7381262149135103, 'auc': 0.8226447067674658, 'prauc': 0.8331916264173519}
Test:      {'precision': 0.7554794520522072, 'recall': 0.6917529005936289, 'f1': 0.7222131231788805, 'auc': 0.81411795261299, 'prauc': 0.8240068060273789}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 32.40it/s, loss=0.3998]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.97it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.79it/s]


Validation: {'precision': 0.6088103938274283, 'recall': 0.9404201944153641, 'f1': 0.7391250722451295, 'auc': 0.8177380813512165, 'prauc': 0.8243636833924721}
Test:      {'precision': 0.6065739060282183, 'recall': 0.9432423957323824, 'f1': 0.7383406923377185, 'auc': 0.8128584565624108, 'prauc': 0.8190732546544167}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 32.59it/s, loss=0.3755]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.53it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.08it/s]


Validation: {'precision': 0.8200692041487022, 'recall': 0.5945437441185496, 'f1': 0.6893292079949637, 'auc': 0.828410581913617, 'prauc': 0.8355856805510037}
Test:      {'precision': 0.8146734520745773, 'recall': 0.6023831922213786, 'f1': 0.6926266401430704, 'auc': 0.8237202321181307, 'prauc': 0.83060998959667}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.42it/s, loss=0.3529]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.67it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.66it/s]


Validation: {'precision': 0.7058981233225043, 'recall': 0.8256506741899479, 'f1': 0.7610926384739536, 'auc': 0.8260489010756596, 'prauc': 0.8338284886298845}
Test:      {'precision': 0.6919431279602635, 'recall': 0.824082784569382, 'f1': 0.7522541813819039, 'auc': 0.8145578677103829, 'prauc': 0.8219927005069806}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 32.78it/s, loss=0.3164]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.87it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.92it/s]


Validation: {'precision': 0.7176470588214187, 'recall': 0.7651301348361081, 'f1': 0.7406283149300306, 'auc': 0.807918701632665, 'prauc': 0.815235492547344}
Test:      {'precision': 0.7153088630238278, 'recall': 0.7516462840992423, 'f1': 0.7330275179366061, 'auc': 0.8027056288495718, 'prauc': 0.8124522977254078}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 32.90it/s, loss=0.3152]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.10it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.95it/s]


Validation: {'precision': 0.7092800437976752, 'recall': 0.8124804013771951, 'f1': 0.7573808778035935, 'auc': 0.8233638461432577, 'prauc': 0.8259870886990123}
Test:      {'precision': 0.7009194158985914, 'recall': 0.8127939793013083, 'f1': 0.752722515716283, 'auc': 0.8158752976599437, 'prauc': 0.819651715810248}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 32.68it/s, loss=0.2614]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.94it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.71it/s]


Validation: {'precision': 0.6995381689739214, 'recall': 0.8074631545913846, 'f1': 0.7496360940045389, 'auc': 0.8115440382064709, 'prauc': 0.8165830042409429}
Test:      {'precision': 0.6921631776685664, 'recall': 0.8087174662878372, 'f1': 0.7459146732637153, 'auc': 0.8071623613877359, 'prauc': 0.8124873084757651}


Epoch 014: 100%|██████████| 98/98 [00:03<00:00, 32.66it/s, loss=0.2377]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.06it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.56it/s]


Validation: {'precision': 0.6929277171068827, 'recall': 0.8356851677615689, 'f1': 0.7576403646221228, 'auc': 0.8159735147724743, 'prauc': 0.8154894089411817}
Test:      {'precision': 0.6863905325426128, 'recall': 0.8366259015339084, 'f1': 0.7540983557022589, 'auc': 0.8107520987825827, 'prauc': 0.8125362893927844}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8200692041487022, 'recall': 0.5945437441185496, 'f1': 0.6893292079949637, 'auc': 0.828410581913617, 'prauc': 0.8355856805510037}
Corresponding test performance:
{'precision': 0.8146734520745773, 'recall': 0.6023831922213786, 'f1': 0.6926266401430704, 'auc': 0.8237202321181307, 'prauc': 0.83060998959667}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 32.96it/s, loss=0.6681]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.95it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.98it/s]


Validation: {'precision': 0.5299813147604383, 'recall': 0.9783631232330563, 'f1': 0.6875275406144105, 'auc': 0.7812846727913328, 'prauc': 0.7871168121467873}
Test:      {'precision': 0.5304140398696023, 'recall': 0.9761680777642642, 'f1': 0.6873481959219394, 'auc': 0.7731661353962029, 'prauc': 0.7842854425722557}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 32.72it/s, loss=0.6030]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.81it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.01it/s]


Validation: {'precision': 0.6710594315228138, 'recall': 0.8143618689218741, 'f1': 0.735798266754324, 'auc': 0.7905282386415414, 'prauc': 0.798256928532831}
Test:      {'precision': 0.6770642201817115, 'recall': 0.8099717779842899, 'f1': 0.7375785215940892, 'auc': 0.7863887550855746, 'prauc': 0.7978456632883071}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 32.75it/s, loss=0.5630]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.14it/s]


Validation: {'precision': 0.7457684495573942, 'recall': 0.6908121668212894, 'f1': 0.7172391289786106, 'auc': 0.7945632058454704, 'prauc': 0.7932357518696607}
Test:      {'precision': 0.7393190921203627, 'recall': 0.6945751019106473, 'f1': 0.7162489844932559, 'auc': 0.7914544227322352, 'prauc': 0.7989171371840761}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 32.63it/s, loss=0.5413]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.55it/s]


Validation: {'precision': 0.6744599745853254, 'recall': 0.8322358105963241, 'f1': 0.7450870248112984, 'auc': 0.8044929055129734, 'prauc': 0.8137866727629026}
Test:      {'precision': 0.6711494544514814, 'recall': 0.8294136092793057, 'f1': 0.7419354789245061, 'auc': 0.8000014596725199, 'prauc': 0.8130272195609071}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 32.56it/s, loss=0.4990]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.13it/s]


Validation: {'precision': 0.7630948305348747, 'recall': 0.6989651928482316, 'f1': 0.7296235629286717, 'auc': 0.8152710540046291, 'prauc': 0.8201040508194203}
Test:      {'precision': 0.7589708869304029, 'recall': 0.7030417058617027, 'f1': 0.7299365081092871, 'auc': 0.8104270448124498, 'prauc': 0.8184685894135868}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 32.64it/s, loss=0.4720]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.30it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.09it/s]


Validation: {'precision': 0.768305624334035, 'recall': 0.6810912511737814, 'f1': 0.7220744631008098, 'auc': 0.813343619814548, 'prauc': 0.811841163147998}
Test:      {'precision': 0.7575547064926796, 'recall': 0.6839134524907999, 'f1': 0.7188529943514848, 'auc': 0.8087069969150575, 'prauc': 0.8141765082408741}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 32.81it/s, loss=0.4631]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.74it/s]


Validation: {'precision': 0.703468361131312, 'recall': 0.8331765443686636, 'f1': 0.7628481144707018, 'auc': 0.8190344412249373, 'prauc': 0.8228771031125202}
Test:      {'precision': 0.6953186987551037, 'recall': 0.8243963624934951, 'f1': 0.7543758917340491, 'auc': 0.8152094353231689, 'prauc': 0.8235838347701517}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 32.84it/s, loss=0.4243]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.30it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.64it/s]


Validation: {'precision': 0.7027855153183766, 'recall': 0.7911571025375003, 'f1': 0.7443575698788488, 'auc': 0.8037852201209521, 'prauc': 0.7978765147853619}
Test:      {'precision': 0.6926240745799489, 'recall': 0.7920978363098398, 'f1': 0.73902866675814, 'auc': 0.7976403638711924, 'prauc': 0.7985950824005522}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 32.75it/s, loss=0.4010]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.88it/s]


Validation: {'precision': 0.7584026622270935, 'recall': 0.7146440890538895, 'f1': 0.7358734208980647, 'auc': 0.8196176901353771, 'prauc': 0.8256600663120053}
Test:      {'precision': 0.7548029957643934, 'recall': 0.7268736280943027, 'f1': 0.740575074871615, 'auc': 0.8200685341414886, 'prauc': 0.8279044236788468}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 32.74it/s, loss=0.3946]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.78it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.85it/s]


Validation: {'precision': 0.6916578669464317, 'recall': 0.8215741611764767, 'f1': 0.7510391235999874, 'auc': 0.8039315598346568, 'prauc': 0.7980608650100385}
Test:      {'precision': 0.688023012550502, 'recall': 0.8250235183417215, 'f1': 0.7503208277780518, 'auc': 0.8001387192229308, 'prauc': 0.798328616559079}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.10it/s, loss=0.3814]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.09it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.99it/s]


Validation: {'precision': 0.7480089200358458, 'recall': 0.7362809658176975, 'f1': 0.7420986043532134, 'auc': 0.8183177634509206, 'prauc': 0.8179229913475433}
Test:      {'precision': 0.7446202531622006, 'recall': 0.7378488554382633, 'f1': 0.7412190846181822, 'auc': 0.819827033839739, 'prauc': 0.8233018510288418}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 33.18it/s, loss=0.3527]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.90it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.58it/s]


Validation: {'precision': 0.7589285714260617, 'recall': 0.7196613358397, 'f1': 0.738773534354118, 'auc': 0.8200122497342224, 'prauc': 0.821062960541402}
Test:      {'precision': 0.7614080834395, 'recall': 0.7325180307283395, 'f1': 0.7466837092395343, 'auc': 0.8238718870596005, 'prauc': 0.8233179149059902}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 32.78it/s, loss=0.3128]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.81it/s]


Validation: {'precision': 0.7135963667308726, 'recall': 0.7883349012204819, 'f1': 0.7491060736752266, 'auc': 0.8103680958726575, 'prauc': 0.8087022053016092}
Test:      {'precision': 0.7108639863110612, 'recall': 0.7817497648141055, 'f1': 0.7446236509230314, 'auc': 0.8136324856662674, 'prauc': 0.8165435055412598}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 32.78it/s, loss=0.2977]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.16it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 59.08it/s]


Validation: {'precision': 0.7335252719106413, 'recall': 0.7190341799914738, 'f1': 0.726207437595189, 'auc': 0.8019099316990911, 'prauc': 0.8069457117252106}
Test:      {'precision': 0.7334611697004364, 'recall': 0.7196613358397, 'f1': 0.7264957214938778, 'auc': 0.8043013025815566, 'prauc': 0.8083029977993124}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7584026622270935, 'recall': 0.7146440890538895, 'f1': 0.7358734208980647, 'auc': 0.8196176901353771, 'prauc': 0.8256600663120053}
Corresponding test performance:
{'precision': 0.7548029957643934, 'recall': 0.7268736280943027, 'f1': 0.740575074871615, 'auc': 0.8200685341414886, 'prauc': 0.8279044236788468}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 32.72it/s, loss=0.6818]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.94it/s]


Validation: {'precision': 0.7510696227119756, 'recall': 0.6055189714625101, 'f1': 0.6704861061663406, 'auc': 0.7778892699180506, 'prauc': 0.7765625391883813}
Test:      {'precision': 0.7392273402647132, 'recall': 0.6240200689851866, 'f1': 0.676755648833782, 'auc': 0.7725083264250556, 'prauc': 0.7679425864573239}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 32.63it/s, loss=0.5921]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.06it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.79it/s]


Validation: {'precision': 0.6558441558425179, 'recall': 0.8234556287211557, 'f1': 0.7301543117589481, 'auc': 0.7903416090375549, 'prauc': 0.8024169768366712}
Test:      {'precision': 0.6505724240899687, 'recall': 0.8196926936317978, 'f1': 0.7254058504824864, 'auc': 0.7825611917370456, 'prauc': 0.7966473176998566}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 32.65it/s, loss=0.5599]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.16it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.85it/s]


Validation: {'precision': 0.7803812549611582, 'recall': 0.6161806208823576, 'f1': 0.6886279957675998, 'auc': 0.7993382109218959, 'prauc': 0.8059153682299334}
Test:      {'precision': 0.7661444401957732, 'recall': 0.628723737846884, 'f1': 0.6906648245328885, 'auc': 0.7966202037803505, 'prauc': 0.808364911526467}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 33.06it/s, loss=0.5266]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.37it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.10it/s]


Validation: {'precision': 0.702133632788596, 'recall': 0.7842583882070108, 'f1': 0.7409272650471375, 'auc': 0.805497741405016, 'prauc': 0.8157716648227409}
Test:      {'precision': 0.687757909214064, 'recall': 0.7839448102828976, 'f1': 0.7327080841165143, 'auc': 0.8013377143642345, 'prauc': 0.816432580695795}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 32.37it/s, loss=0.4916]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.45it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.39it/s]


Validation: {'precision': 0.7829549980056434, 'recall': 0.6164941988064707, 'f1': 0.6898245564718308, 'auc': 0.8093215132239759, 'prauc': 0.8003942823441483}
Test:      {'precision': 0.7853525516136137, 'recall': 0.6321730950121287, 'f1': 0.7004864439788153, 'auc': 0.8076730454355788, 'prauc': 0.8032657142731854}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 32.53it/s, loss=0.4898]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.26it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.04it/s]


Validation: {'precision': 0.7629370629343953, 'recall': 0.6842270304149131, 'f1': 0.7214415556009326, 'auc': 0.8115731755336074, 'prauc': 0.8138799721967678}
Test:      {'precision': 0.7645027624282994, 'recall': 0.6942615239865342, 'f1': 0.7276910385589135, 'auc': 0.8146190229556153, 'prauc': 0.8216906420410597}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 32.71it/s, loss=0.4294]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.45it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.71it/s]


Validation: {'precision': 0.6799574694293462, 'recall': 0.8021323298814609, 'f1': 0.7360092023401603, 'auc': 0.7992098057354116, 'prauc': 0.8151735005590621}
Test:      {'precision': 0.6846339501188399, 'recall': 0.8005644402608951, 'f1': 0.7380745830595539, 'auc': 0.7984849102578235, 'prauc': 0.814844958851816}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 32.95it/s, loss=0.4186]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.58it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.51it/s]


Validation: {'precision': 0.7408117609436216, 'recall': 0.7268736280943027, 'f1': 0.7337765065524176, 'auc': 0.8173563823657279, 'prauc': 0.8249955754003117}
Test:      {'precision': 0.7467178994532606, 'recall': 0.7312637190318869, 'f1': 0.7389100076724767, 'auc': 0.8165365293114826, 'prauc': 0.8265110986743529}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 32.60it/s, loss=0.3890]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.54it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.37it/s]


Validation: {'precision': 0.7520233085116478, 'recall': 0.7284415177148685, 'f1': 0.7400445951900548, 'auc': 0.8211245420792739, 'prauc': 0.8282378489759883}
Test:      {'precision': 0.7450542437755423, 'recall': 0.7322044528042264, 'f1': 0.7385734569622997, 'auc': 0.8200777451784249, 'prauc': 0.8281903666072032}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.40it/s, loss=0.3579]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.85it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.76it/s]


Validation: {'precision': 0.7559681697587667, 'recall': 0.7149576669780027, 'f1': 0.7348912117621949, 'auc': 0.8139345047141684, 'prauc': 0.8213177017251998}
Test:      {'precision': 0.7556001337320107, 'recall': 0.7086861084957394, 'f1': 0.7313915807632834, 'auc': 0.8142153480035454, 'prauc': 0.823683497148304}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 32.84it/s, loss=0.3508]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.62it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.82it/s]


Validation: {'precision': 0.780036297637822, 'recall': 0.6738789589191788, 'f1': 0.7230820946204544, 'auc': 0.8201991807588345, 'prauc': 0.8279288386358876}
Test:      {'precision': 0.7849927849899532, 'recall': 0.6823455628702341, 'f1': 0.7300788408532554, 'auc': 0.823905308526954, 'prauc': 0.8330945971439123}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 32.88it/s, loss=0.3105]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.63it/s]


Validation: {'precision': 0.7511520737302464, 'recall': 0.7155848228262289, 'f1': 0.7329372039294443, 'auc': 0.8109205998792911, 'prauc': 0.8172979582676498}
Test:      {'precision': 0.749509483319982, 'recall': 0.7187206020673605, 'f1': 0.733792215265577, 'auc': 0.8190805875131559, 'prauc': 0.8299854173153973}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 32.64it/s, loss=0.2791]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 59.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.88it/s]


Validation: {'precision': 0.6877022653703136, 'recall': 0.7996237064885556, 'f1': 0.7394519306503531, 'auc': 0.7943976254485013, 'prauc': 0.806106544577692}
Test:      {'precision': 0.6863429031374125, 'recall': 0.8021323298814609, 'f1': 0.7397339452883839, 'auc': 0.8024970970133591, 'prauc': 0.8163918450112057}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 32.89it/s, loss=0.2576]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 58.95it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 58.60it/s]

Validation: {'precision': 0.7540106951844777, 'recall': 0.6632173094993314, 'f1': 0.7057057007238722, 'auc': 0.8005585123244363, 'prauc': 0.806152124262013}
Test:      {'precision': 0.7623512441371715, 'recall': 0.6629037315752183, 'f1': 0.7091579956928799, 'auc': 0.8090155918191898, 'prauc': 0.8181824349107634}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7520233085116478, 'recall': 0.7284415177148685, 'f1': 0.7400445951900548, 'auc': 0.8211245420792739, 'prauc': 0.8282378489759883}
Corresponding test performance:
{'precision': 0.7450542437755423, 'recall': 0.7322044528042264, 'f1': 0.7385734569622997, 'auc': 0.8200777451784249, 'prauc': 0.8281903666072032}





In [19]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7506 ± 0.0440
recall: 0.7210 ± 0.0788
f1: 0.7304 ± 0.0206
auc: 0.8200 ± 0.0023
prauc: 0.8294 ± 0.0011


In [20]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, num_layers=2, max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, use_hetero_graph=False).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.9875]

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 77.70it/s, loss=0.6808]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.79it/s]


Validation: {'precision': 0.7596806387195223, 'recall': 0.5967387895873417, 'f1': 0.6684228963694178, 'auc': 0.7741425110689188, 'prauc': 0.7736705791636499}
Test:      {'precision': 0.7453011123868611, 'recall': 0.6092819065518681, 'f1': 0.6704623829017937, 'auc': 0.773390975297811, 'prauc': 0.7793031485869256}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 77.31it/s, loss=0.5913]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.66it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.88it/s]


Validation: {'precision': 0.7406773862444725, 'recall': 0.6788962057049893, 'f1': 0.7084424033841157, 'auc': 0.79324569629119, 'prauc': 0.8035906459345656}
Test:      {'precision': 0.7338250083783647, 'recall': 0.6864220758837052, 'f1': 0.7093324642190848, 'auc': 0.7930329327287269, 'prauc': 0.804023772685664}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 76.05it/s, loss=0.5405]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.65it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.31it/s]


Validation: {'precision': 0.6375198728125425, 'recall': 0.8802132329856375, 'f1': 0.7394625873282186, 'auc': 0.7995902990383978, 'prauc': 0.8084762572832853}
Test:      {'precision': 0.6350482315097955, 'recall': 0.8670429601728848, 'f1': 0.7331300494723579, 'auc': 0.7937257738403531, 'prauc': 0.8105184032038965}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 76.31it/s, loss=0.5145]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.65it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.86it/s]


Validation: {'precision': 0.7458745874562843, 'recall': 0.7086861084957394, 'f1': 0.726804947565652, 'auc': 0.8020429586685006, 'prauc': 0.8118951009382478}
Test:      {'precision': 0.7362780123392781, 'recall': 0.7108811539645316, 'f1': 0.7233567276093148, 'auc': 0.8011351218851721, 'prauc': 0.8171981268484294}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 75.40it/s, loss=0.4924]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.64it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.75it/s]


Validation: {'precision': 0.7566475934003479, 'recall': 0.7049231734063816, 'f1': 0.7298701248740224, 'auc': 0.8131958734712573, 'prauc': 0.8249760708295899}
Test:      {'precision': 0.7427073090765562, 'recall': 0.7105675760404184, 'f1': 0.7262820462821689, 'auc': 0.8060778750389456, 'prauc': 0.8202185682468744}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 76.40it/s, loss=0.4818]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.23it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.98it/s]


Validation: {'precision': 0.7594086021479859, 'recall': 0.7086861084957394, 'f1': 0.7331711223353012, 'auc': 0.8156032697707566, 'prauc': 0.8234205661296161}
Test:      {'precision': 0.7511490479292477, 'recall': 0.7174662903709079, 'f1': 0.7339214063876055, 'auc': 0.8118749392851732, 'prauc': 0.8246863151127584}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 75.73it/s, loss=0.4569]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.30it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.92it/s]


Validation: {'precision': 0.7772277227693175, 'recall': 0.6400125431149577, 'f1': 0.7019776390685306, 'auc': 0.8074511480056655, 'prauc': 0.8150303510408229}
Test:      {'precision': 0.7695457929978796, 'recall': 0.6481655691418998, 'f1': 0.7036595695023411, 'auc': 0.8068491357983679, 'prauc': 0.8170335810146248}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 76.76it/s, loss=0.4251]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.92it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 272.08it/s]


Validation: {'precision': 0.723103547733709, 'recall': 0.7861398557516898, 'f1': 0.7533052834679974, 'auc': 0.8234725082787687, 'prauc': 0.827604467163499}
Test:      {'precision': 0.7138009049753569, 'recall': 0.7914706804616135, 'f1': 0.7506319652713027, 'auc': 0.8190424346934965, 'prauc': 0.827743723541574}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 77.34it/s, loss=0.3907]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 273.18it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.67it/s]


Validation: {'precision': 0.7041436464068946, 'recall': 0.7993101285644425, 'f1': 0.7487149311318013, 'auc': 0.817015123980834, 'prauc': 0.8222931115615574}
Test:      {'precision': 0.7062860279969633, 'recall': 0.8068359987431583, 'f1': 0.7532201355350969, 'auc': 0.8178121824268515, 'prauc': 0.8270057406167948}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 76.93it/s, loss=0.3678]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.62it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 272.36it/s]


Validation: {'precision': 0.7600401606400267, 'recall': 0.7121354656609842, 'f1': 0.7353083971398732, 'auc': 0.8179682662355949, 'prauc': 0.8209088948245178}
Test:      {'precision': 0.761174116074846, 'recall': 0.7155848228262289, 'f1': 0.7376757667819184, 'auc': 0.8191929822971924, 'prauc': 0.8257576841553562}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 76.82it/s, loss=0.3519]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 273.06it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 273.01it/s]


Validation: {'precision': 0.6469248291557018, 'recall': 0.8905613044813717, 'f1': 0.7494392351288547, 'auc': 0.8112326204635066, 'prauc': 0.8105900018423349}
Test:      {'precision': 0.648089171973048, 'recall': 0.8933835057983902, 'f1': 0.7512195073197527, 'auc': 0.814205784631863, 'prauc': 0.8186206411355479}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 75.37it/s, loss=0.3277]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.66it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.37it/s]


Validation: {'precision': 0.7507609063214381, 'recall': 0.6961429915312132, 'f1': 0.7224210818905533, 'auc': 0.8083076347131666, 'prauc': 0.8136984470655553}
Test:      {'precision': 0.7460784313701109, 'recall': 0.7158984007503422, 'f1': 0.7306769033051213, 'auc': 0.8126675414635078, 'prauc': 0.8200435577556371}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 78.35it/s, loss=0.3089]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.93it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.65it/s]


Validation: {'precision': 0.7327345927509051, 'recall': 0.7419253684517343, 'f1': 0.7373013349791993, 'auc': 0.812017168115043, 'prauc': 0.8178229870401195}
Test:      {'precision': 0.7354740061139589, 'recall': 0.7541549074921475, 'f1': 0.7446973165652865, 'auc': 0.8138855123475712, 'prauc': 0.8226934140880926}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.723103547733709, 'recall': 0.7861398557516898, 'f1': 0.7533052834679974, 'auc': 0.8234725082787687, 'prauc': 0.827604467163499}
Corresponding test performance:
{'precision': 0.7138009049753569, 'recall': 0.7914706804616135, 'f1': 0.7506319652713027, 'auc': 0.8190424346934965, 'prauc': 0.827743723541574}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 76.88it/s, loss=0.6675]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 273.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.83it/s]


Validation: {'precision': 0.653055983562781, 'recall': 0.7974286610197635, 'f1': 0.7180573153919946, 'auc': 0.7697164003665676, 'prauc': 0.7781810873333918}
Test:      {'precision': 0.6523649521823408, 'recall': 0.7914706804616135, 'f1': 0.7152167703348353, 'auc': 0.7619645581445448, 'prauc': 0.7721218331369964}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 76.60it/s, loss=0.5922]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.27it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.49it/s]


Validation: {'precision': 0.7645492582701996, 'recall': 0.6302916274674497, 'f1': 0.6909590875155879, 'auc': 0.7940072355016428, 'prauc': 0.8007158723119753}
Test:      {'precision': 0.751215862324163, 'recall': 0.6296644716192235, 'f1': 0.6850904078647905, 'auc': 0.7898473736213015, 'prauc': 0.8012395962527846}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 78.28it/s, loss=0.5625]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.07it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.37it/s]


Validation: {'precision': 0.7485855728403515, 'recall': 0.6638444653475577, 'f1': 0.7036729217233207, 'auc': 0.8005267124484408, 'prauc': 0.8106282316918073}
Test:      {'precision': 0.7451737451711297, 'recall': 0.6657259328922367, 'f1': 0.703212979445456, 'auc': 0.7944257119804383, 'prauc': 0.8083932374845886}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 77.58it/s, loss=0.5307]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.38it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.60it/s]


Validation: {'precision': 0.7516960651263511, 'recall': 0.6948886798347604, 'f1': 0.722176954431787, 'auc': 0.8084616104160518, 'prauc': 0.8193932427482901}
Test:      {'precision': 0.7340565417464364, 'recall': 0.7002195045446842, 'f1': 0.7167388812145732, 'auc': 0.7977095221485172, 'prauc': 0.8103922019855921}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 77.58it/s, loss=0.5004]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.39it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.74it/s]


Validation: {'precision': 0.7030567685570331, 'recall': 0.8077767325154978, 'f1': 0.7517875333262195, 'auc': 0.8128637079419009, 'prauc': 0.820932364727543}
Test:      {'precision': 0.6870821075135408, 'recall': 0.8055816870467056, 'f1': 0.741628170549002, 'auc': 0.8017449126637666, 'prauc': 0.8117955363066887}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 78.41it/s, loss=0.4781]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.16it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.02it/s]


Validation: {'precision': 0.7391304347802282, 'recall': 0.7196613358397, 'f1': 0.7292659625867525, 'auc': 0.8065824537628297, 'prauc': 0.8163821276216885}
Test:      {'precision': 0.7296689167446813, 'recall': 0.711821887736871, 'f1': 0.7206349156333993, 'auc': 0.7985444044963953, 'prauc': 0.8086593489164879}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 77.43it/s, loss=0.4398]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.30it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 272.67it/s]


Validation: {'precision': 0.8496982995016473, 'recall': 0.4857322044512834, 'f1': 0.6181165157200972, 'auc': 0.8104524434110405, 'prauc': 0.823709619196587}
Test:      {'precision': 0.8273495248108377, 'recall': 0.49137660708532027, 'f1': 0.6165650159792073, 'auc': 0.7996632686497074, 'prauc': 0.8128254839487375}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 76.08it/s, loss=0.4268]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.16it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.38it/s]


Validation: {'precision': 0.8041894353333143, 'recall': 0.5537786139838389, 'f1': 0.6558960025956238, 'auc': 0.8064074790896977, 'prauc': 0.8125005976833715}
Test:      {'precision': 0.8019625334486978, 'recall': 0.56381310755546, 'f1': 0.662124834037451, 'auc': 0.8008102692491798, 'prauc': 0.811277286896215}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 77.86it/s, loss=0.3953]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.32it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 272.11it/s]


Validation: {'precision': 0.7325759901727584, 'recall': 0.7481969269339975, 'f1': 0.7403040595344061, 'auc': 0.8140428151922828, 'prauc': 0.8218101873319212}
Test:      {'precision': 0.7182846932676644, 'recall': 0.7563499529609397, 'f1': 0.736826022189106, 'auc': 0.8059017076658478, 'prauc': 0.8156704736004937}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 76.77it/s, loss=0.3746]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.23it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 272.50it/s]


Validation: {'precision': 0.6909329829153984, 'recall': 0.8243963624934951, 'f1': 0.7517872412476748, 'auc': 0.8121396453625582, 'prauc': 0.8166653600452305}
Test:      {'precision': 0.6722451081342116, 'recall': 0.8187519598594584, 'f1': 0.738300574715353, 'auc': 0.7969956919527247, 'prauc': 0.8079729650408417}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 76.87it/s, loss=0.3498]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.06it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.87it/s]


Validation: {'precision': 0.704743193935715, 'recall': 0.7873941674481424, 'f1': 0.7437796158662183, 'auc': 0.8055325052504961, 'prauc': 0.8129459340302254}
Test:      {'precision': 0.6874487284640759, 'recall': 0.7883349012204819, 'f1': 0.7344434656610104, 'auc': 0.7964765518709731, 'prauc': 0.8062725675567427}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 76.14it/s, loss=0.3353]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.04it/s]


Validation: {'precision': 0.732598425194543, 'recall': 0.729382251487208, 'f1': 0.7309867957519696, 'auc': 0.804575444527603, 'prauc': 0.8086647572609807}
Test:      {'precision': 0.7281191806308873, 'recall': 0.7356538099694712, 'f1': 0.7318670984138531, 'auc': 0.7973731934665058, 'prauc': 0.8070798148489233}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8496982995016473, 'recall': 0.4857322044512834, 'f1': 0.6181165157200972, 'auc': 0.8104524434110405, 'prauc': 0.823709619196587}
Corresponding test performance:
{'precision': 0.8273495248108377, 'recall': 0.49137660708532027, 'f1': 0.6165650159792073, 'auc': 0.7996632686497074, 'prauc': 0.8128254839487375}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 76.79it/s, loss=0.6849]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.90it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.12it/s]


Validation: {'precision': 0.5324453551903476, 'recall': 0.9777359673848299, 'f1': 0.6894416759196416, 'auc': 0.7720011687082389, 'prauc': 0.7770787224300174}
Test:      {'precision': 0.5322718712540109, 'recall': 0.9749137660678115, 'f1': 0.6885935723954073, 'auc': 0.7570513507759167, 'prauc': 0.7637083841347175}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 76.29it/s, loss=0.6048]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 268.86it/s]


Validation: {'precision': 0.6505938242264832, 'recall': 0.8588899341459427, 'f1': 0.7403703154067632, 'auc': 0.798609526559628, 'prauc': 0.8084659640757648}
Test:      {'precision': 0.6382775119601956, 'recall': 0.8366259015339084, 'f1': 0.7241145289465163, 'auc': 0.7846902499412356, 'prauc': 0.8025325066396418}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 77.05it/s, loss=0.5631]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.48it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.44it/s]


Validation: {'precision': 0.6765903307870825, 'recall': 0.83380370021689, 'f1': 0.7470150252529436, 'auc': 0.8114734555433212, 'prauc': 0.8189137869321814}
Test:      {'precision': 0.66896199948822, 'recall': 0.8225148949488162, 'f1': 0.7378340316191356, 'auc': 0.8004622128534735, 'prauc': 0.8176036633103503}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 75.77it/s, loss=0.5224]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.54it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 267.67it/s]


Validation: {'precision': 0.7588057698733351, 'recall': 0.7093132643439658, 'f1': 0.7332252786337756, 'auc': 0.82076047620038, 'prauc': 0.8321130965614592}
Test:      {'precision': 0.7475149105343025, 'recall': 0.7074317967992868, 'f1': 0.7269212129811531, 'auc': 0.8074899823681627, 'prauc': 0.8235474790963113}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 76.72it/s, loss=0.4906]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.53it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.29it/s]


Validation: {'precision': 0.7918864097330959, 'recall': 0.6121041078688865, 'f1': 0.6904846077431438, 'auc': 0.8132666570814908, 'prauc': 0.826185798180048}
Test:      {'precision': 0.7836422240097356, 'recall': 0.6099090624000943, 'f1': 0.685946036337857, 'auc': 0.8011738283737184, 'prauc': 0.8181790049370115}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 76.03it/s, loss=0.4715]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.70it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 267.48it/s]


Validation: {'precision': 0.7831513260499877, 'recall': 0.6296644716192235, 'f1': 0.6980705669321285, 'auc': 0.8115785006313255, 'prauc': 0.818911822629781}
Test:      {'precision': 0.7734374999969788, 'recall': 0.620884289744055, 'f1': 0.6888154412220073, 'auc': 0.7982659593798305, 'prauc': 0.8137203364631582}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 76.90it/s, loss=0.4402]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.81it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 268.52it/s]


Validation: {'precision': 0.7911921032619925, 'recall': 0.6534963938518235, 'f1': 0.7157822378731196, 'auc': 0.8190879936227433, 'prauc': 0.8288479953821108}
Test:      {'precision': 0.7765321375157828, 'recall': 0.6516149263071446, 'f1': 0.7086103957178489, 'auc': 0.8084670569528984, 'prauc': 0.8222128951524295}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 76.98it/s, loss=0.4144]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 269.43it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 266.17it/s]


Validation: {'precision': 0.7106757524113837, 'recall': 0.7848855440552371, 'f1': 0.7459394973197285, 'auc': 0.8130515432283898, 'prauc': 0.8189713092499962}
Test:      {'precision': 0.7010163749274393, 'recall': 0.778613985572974, 'f1': 0.737780413968624, 'auc': 0.8024157076869878, 'prauc': 0.8106723475369689}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 77.24it/s, loss=0.3770]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 266.72it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 265.56it/s]


Validation: {'precision': 0.7716942148733758, 'recall': 0.7027281279375894, 'f1': 0.7355982224826757, 'auc': 0.8166883337859667, 'prauc': 0.8193410153011862}
Test:      {'precision': 0.7532334921689815, 'recall': 0.693947946062421, 'f1': 0.7223763619066371, 'auc': 0.8067564717601184, 'prauc': 0.8150928403267784}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7588057698733351, 'recall': 0.7093132643439658, 'f1': 0.7332252786337756, 'auc': 0.82076047620038, 'prauc': 0.8321130965614592}
Corresponding test performance:
{'precision': 0.7475149105343025, 'recall': 0.7074317967992868, 'f1': 0.7269212129811531, 'auc': 0.8074899823681627, 'prauc': 0.8235474790963113}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 76.35it/s, loss=0.6726]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 266.45it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 265.41it/s]


Validation: {'precision': 0.826494724496914, 'recall': 0.4421448729995543, 'f1': 0.5760980547007035, 'auc': 0.7824252484384151, 'prauc': 0.7893334911459093}
Test:      {'precision': 0.8215313759307915, 'recall': 0.447475697709478, 'f1': 0.5793747416764898, 'auc': 0.7726558036831062, 'prauc': 0.7822448904155225}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 76.51it/s, loss=0.5895]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 266.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 265.08it/s]


Validation: {'precision': 0.8033635187546212, 'recall': 0.5841956726228154, 'f1': 0.6764705833575823, 'auc': 0.8010132053380788, 'prauc': 0.8057981090017364}
Test:      {'precision': 0.7820565342040883, 'recall': 0.5986202571320206, 'f1': 0.6781527481941976, 'auc': 0.7959069775869801, 'prauc': 0.8039403852269505}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 77.48it/s, loss=0.5435]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 266.40it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 267.97it/s]


Validation: {'precision': 0.629359095191731, 'recall': 0.8375666353062479, 'f1': 0.7186869317327232, 'auc': 0.7827918261558552, 'prauc': 0.8006390269711605}
Test:      {'precision': 0.6306413301647729, 'recall': 0.8325493885204374, 'f1': 0.7176645443566826, 'auc': 0.7792754185610953, 'prauc': 0.7996120314079271}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 76.08it/s, loss=0.5231]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 267.82it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 266.18it/s]


Validation: {'precision': 0.8074107959706889, 'recall': 0.5534650360597257, 'f1': 0.6567441812181745, 'auc': 0.785400571433222, 'prauc': 0.806359330144618}
Test:      {'precision': 0.7959731543588547, 'recall': 0.55785512699731, 'f1': 0.6559734464796926, 'auc': 0.7802394567602217, 'prauc': 0.8028985903560122}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 76.95it/s, loss=0.4922]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.85it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 267.71it/s]


Validation: {'precision': 0.8004856333435837, 'recall': 0.6202571338958286, 'f1': 0.6989399244066135, 'auc': 0.8229305037572583, 'prauc': 0.8327218487115126}
Test:      {'precision': 0.7913043478229593, 'recall': 0.6277830040745445, 'f1': 0.7001223940847492, 'auc': 0.8178472145673303, 'prauc': 0.8303928911971568}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 76.99it/s, loss=0.4706]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.52it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 268.45it/s]


Validation: {'precision': 0.774285714282949, 'recall': 0.6798369394773288, 'f1': 0.7239939839984727, 'auc': 0.8209063135463754, 'prauc': 0.832091936664052}
Test:      {'precision': 0.7579705176525267, 'recall': 0.6933207902141947, 'f1': 0.7242056943197018, 'auc': 0.8158145450829923, 'prauc': 0.8300598192430708}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 76.39it/s, loss=0.4393]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 267.29it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 267.55it/s]


Validation: {'precision': 0.7975338106571299, 'recall': 0.628723737846884, 'f1': 0.7031386939979657, 'auc': 0.825816706720444, 'prauc': 0.8348135992659167}
Test:      {'precision': 0.793638479438349, 'recall': 0.6415804327355235, 'f1': 0.7095543560732582, 'auc': 0.824083338240853, 'prauc': 0.8373027448454122}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 77.50it/s, loss=0.4041]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 267.08it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 265.77it/s]


Validation: {'precision': 0.7151310228213155, 'recall': 0.7958607713991978, 'f1': 0.7533392648250379, 'auc': 0.818174689127326, 'prauc': 0.8231845006296772}
Test:      {'precision': 0.7178481367309671, 'recall': 0.8033866415779135, 'f1': 0.7582124839156071, 'auc': 0.8228095474662854, 'prauc': 0.8317623856636667}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 75.57it/s, loss=0.3859]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 265.09it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 263.52it/s]


Validation: {'precision': 0.7264906555896574, 'recall': 0.7679523361531265, 'f1': 0.746646336464987, 'auc': 0.8198077860765574, 'prauc': 0.8238577159566957}
Test:      {'precision': 0.7343108504377294, 'recall': 0.7851991219793503, 'f1': 0.7589028590736216, 'auc': 0.8261823976580813, 'prauc': 0.8307374175436999}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 76.86it/s, loss=0.3525]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 265.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 264.15it/s]


Validation: {'precision': 0.7489249090282868, 'recall': 0.7099404201921922, 'f1': 0.7289117786457826, 'auc': 0.8223267080075102, 'prauc': 0.8270166455847985}
Test:      {'precision': 0.7588424437274636, 'recall': 0.7400439009070554, 'f1': 0.7493252847269358, 'auc': 0.8268126238519551, 'prauc': 0.834419936231689}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 75.83it/s, loss=0.3067]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 263.91it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 268.42it/s]


Validation: {'precision': 0.7447563730211528, 'recall': 0.7237378488531712, 'f1': 0.7340966871106489, 'auc': 0.8154586381072633, 'prauc': 0.818280031888166}
Test:      {'precision': 0.7478455154779833, 'recall': 0.7347130761971317, 'f1': 0.7412211275510371, 'auc': 0.8217760489886231, 'prauc': 0.8271851084232431}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 78.20it/s, loss=0.2906]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.20it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 268.85it/s]


Validation: {'precision': 0.7489999999975033, 'recall': 0.7046095954822684, 'f1': 0.7261269945175853, 'auc': 0.8128585335544957, 'prauc': 0.8131383812379978}
Test:      {'precision': 0.7534201954372853, 'recall': 0.7253057384737369, 'f1': 0.7390956971882939, 'auc': 0.8193321545219396, 'prauc': 0.8216441908342393}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7975338106571299, 'recall': 0.628723737846884, 'f1': 0.7031386939979657, 'auc': 0.825816706720444, 'prauc': 0.8348135992659167}
Corresponding test performance:
{'precision': 0.793638479438349, 'recall': 0.6415804327355235, 'f1': 0.7095543560732582, 'auc': 0.824083338240853, 'prauc': 0.8373027448454122}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 76.11it/s, loss=0.6953]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.40it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.65it/s]


Validation: {'precision': 0.7699009900959608, 'recall': 0.6095954844759812, 'f1': 0.6804340167662224, 'auc': 0.7879993698299455, 'prauc': 0.7952794293645455}
Test:      {'precision': 0.7551981169056289, 'recall': 0.6036375039178312, 'f1': 0.6709654882630707, 'auc': 0.7794883797484029, 'prauc': 0.7882891621456141}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 77.74it/s, loss=0.5812]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.80it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.87it/s]


Validation: {'precision': 0.6801373481228734, 'recall': 0.8074631545913846, 'f1': 0.7383512495147991, 'auc': 0.8021077641029939, 'prauc': 0.8086485612742184}
Test:      {'precision': 0.6731374606488113, 'recall': 0.8046409532743661, 'f1': 0.7330381324464412, 'auc': 0.7904914918708824, 'prauc': 0.7953141518065054}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 78.16it/s, loss=0.5453]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.04it/s]


Validation: {'precision': 0.8071287128680906, 'recall': 0.6390718093426182, 'f1': 0.7133356618483614, 'auc': 0.8201987286278961, 'prauc': 0.8284678812152024}
Test:      {'precision': 0.7880308880278455, 'recall': 0.6400125431149577, 'f1': 0.70635057473634, 'auc': 0.8108525141852486, 'prauc': 0.816701147987353}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 77.23it/s, loss=0.5253]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.67it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.71it/s]


Validation: {'precision': 0.7952522255163381, 'recall': 0.6723110692986131, 'f1': 0.7286321105806163, 'auc': 0.8283980729576568, 'prauc': 0.8413077788946206}
Test:      {'precision': 0.7792161093103948, 'recall': 0.6795233615532157, 'f1': 0.7259631440996481, 'auc': 0.8241270780829165, 'prauc': 0.8336880381691711}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 76.45it/s, loss=0.4735]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 272.57it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.05it/s]


Validation: {'precision': 0.689245087898942, 'recall': 0.8359987456856821, 'f1': 0.7555618485229622, 'auc': 0.823966938578213, 'prauc': 0.8334505772491907}
Test:      {'precision': 0.6878374903556497, 'recall': 0.8388209470027005, 'f1': 0.7558632332496503, 'auc': 0.8238884971262068, 'prauc': 0.8304018679387182}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 78.08it/s, loss=0.4532]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.41it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 271.08it/s]


Validation: {'precision': 0.7504072987919505, 'recall': 0.7221699592326053, 'f1': 0.7360178920912089, 'auc': 0.8173595975190672, 'prauc': 0.825555350578715}
Test:      {'precision': 0.7457789104786691, 'recall': 0.7340859203489054, 'f1': 0.7398862149726895, 'auc': 0.8193341678633465, 'prauc': 0.8262427003031784}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 77.14it/s, loss=0.4166]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 268.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.67it/s]


Validation: {'precision': 0.7362869198290046, 'recall': 0.7660708686084476, 'f1': 0.7508836587463917, 'auc': 0.8269634112535794, 'prauc': 0.8369090431413035}
Test:      {'precision': 0.7317647058802007, 'recall': 0.7801818751935398, 'f1': 0.7551980523711763, 'auc': 0.8296206311120641, 'prauc': 0.8373054489234291}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 76.45it/s, loss=0.4017]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 271.93it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 270.36it/s]


Validation: {'precision': 0.7513884351494564, 'recall': 0.7212292254602658, 'f1': 0.735999994999742, 'auc': 0.8184255715613257, 'prauc': 0.824237878865308}
Test:      {'precision': 0.7538314176221143, 'recall': 0.7403574788311685, 'f1': 0.7470336921978532, 'auc': 0.8246424431495303, 'prauc': 0.826336013520248}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 75.60it/s, loss=0.3663]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 270.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 269.35it/s]

Validation: {'precision': 0.6740778170776299, 'recall': 0.8366259015339084, 'f1': 0.746606963014381, 'auc': 0.8151495312557615, 'prauc': 0.8264073755421899}
Test:      {'precision': 0.6823113802657524, 'recall': 0.8479147068019821, 'f1': 0.756152120336087, 'auc': 0.8155256809246472, 'prauc': 0.8183844442324364}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7952522255163381, 'recall': 0.6723110692986131, 'f1': 0.7286321105806163, 'auc': 0.8283980729576568, 'prauc': 0.8413077788946206}
Corresponding test performance:
{'precision': 0.7792161093103948, 'recall': 0.6795233615532157, 'f1': 0.7259631440996481, 'auc': 0.8241270780829165, 'prauc': 0.8336880381691711}





In [21]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7723 ± 0.0389
recall: 0.6623 ± 0.0987
f1: 0.7059 ± 0.0466
auc: 0.8149 ± 0.0097
prauc: 0.8270 ± 0.0085
