In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
# Seed the run through the project helper so repeated executions start from the same state.
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Use the GPU when CUDA is available; otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda


In [4]:
# Experiment settings. `task_index` picks the active entry of `tasks`.
_config_fields = {
    "dataset": "MIMIC-III",
    "tasks": ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    "task_index": 0,
    "token_type": ["diag", "med", "lab", "pro"],
    "special_tokens": ["[PAD]"],
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
}
config = Namespace(**_config_fields)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The two next-diagnosis tasks have dedicated pickles; every other task
# shares the common downstream file.
_task_specific_paths = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}
_default_finetune_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"
finetune_data_path = _task_specific_paths.get(curr_task, _default_finetune_path)

Current task: death


In [6]:
# Load the full processed EHR table; its columns supply the sentences the
# tokenizer builds its vocabularies from.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(full_data_path, 'rb') as full_data_file:  # context manager closes the handle after loading
    ehr_full_data = pickle.load(full_data_file)

# One list entry per table row for each code type.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Age/gender tokens look like "<age>_<M|F>", preceded by the pad token.
# NOTE(review): the token order follows set iteration order over the AGE values;
# consider sorting if a stable vocabulary order across environments matters.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type

# Largest number of distinct admissions recorded for any single patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from the sentence lists prepared above.
task_sentences = config.tasks
tokenizer = EHRTokenizer(
    token_type_sentences,
    age_gender_sentences,
    task_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# Record the vocabulary sizes on the config.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
# Load the (train, val, test) splits for the selected task.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(finetune_data_path, 'rb') as finetune_file:  # context manager closes the handle after loading
    train_data, val_data, test_data = pickle.load(finetune_file)

# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in the fine-tuning dataset for the active task.
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split, tokenizer, token_type=config.token_type, task=curr_task)
    for split in (train_data, val_data, test_data)
)

In [10]:
def _make_loader(dataset, shuffle):
    """Create a DataLoader for `dataset` with the fine-tuning collate function.

    Each loader gets its own collate function built by `batcher`; only the
    shuffle flag differs between the splits.
    """
    collate = batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain=False)
    return DataLoader(
        dataset,
        collate_fn=collate,
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [11]:
# Every task is selected on PR-AUC and trained with BCE-with-logits; only the
# task type passed to the training routine differs.
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits  # used directly; no wrapper lambda needed
if curr_task in ["death", "stay", "readmission"]:
    task_type = "binary"
else:
    task_type = "l2r"

In [12]:
# Pull one batch from the training loader and report what the collate function produced.
first_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = first_batch
for title, value in (
    ("Input IDs shape:", input_ids.shape),
    ("Token Types shape:", token_types.shape),
    ("Admission Index shape:", adm_index.shape),
    ("Age/Sex IDs shape:", age_gender_ids.shape),
    ("Task Index:", task_index),
    ("Labels shape:", labels.shape),
):
    print(title, value)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 0
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    """One heterogeneous GAT layer over 'visit' and 'occ' nodes, followed by a shared linear map."""

    def __init__(self, d_model: int):
        super().__init__()
        relations = (
            ('visit', 'contains', 'occ'),
            ('occ', 'contained_by', 'visit'),
            ('visit', 'next', 'visit'),
        )
        # One GAT operator per relation; HeteroConv combines them with aggr='mean'.
        per_relation_convs = {
            relation: GATConv(d_model, d_model, add_self_loops=False)
            for relation in relations
        }
        self.conv1 = HeteroConv(per_relation_convs, aggr='mean')
        # The same projection is applied to every node type after the convolution.
        self.lin = nn.Linear(d_model, d_model)

    def forward(self, hg: HeteroData):
        """Run the convolution and projection on `hg`; return a dict of features per node type."""
        # Inputs: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        node_feats = {'visit': hg['visit'].x, 'occ': hg['occ'].x}
        conv_out = self.conv1(node_feats, hg.edge_index_dict)
        projected = {}
        for node_type, feats in conv_out.items():
            projected[node_type] = self.lin(feats)
        return projected  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# Head with one output logit per label.
class MultiPredictionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing `label_size` logits per input vector."""

    def __init__(self, hidden_size, label_size):
        super().__init__()
        # Kept as a Sequential named `cls` so parameter names remain cls.0.* and cls.2.*.
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, label_size),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        logits = self.cls(input)
        return logits
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing a single logit per input vector."""

    def __init__(self, hidden_size):
        super().__init__()
        # Kept as a Sequential named `cls` so parameter names remain cls.0.* and cls.2.*.
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        logit = self.cls(input)
        return logit

In [16]:
# Find the first training sample whose age/gender sequence has more than three
# entries (presumably one entry per admission - TODO confirm against the dataset class).
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
# NOTE(review): if no sample matches, `i` is simply the last index.
exp_i = i


def _append_zeros(seq, n_pad):
    """Return `seq` with `n_pad` zeros of the same dtype appended along dim 0."""
    return torch.cat([seq, torch.zeros(n_pad, dtype=seq.dtype)], dim=0)


# Fetch the example once instead of re-indexing the dataset for every field,
# then append zeros to each sequence (0 presumably acts as the pad id - verify).
example = train_dataset[exp_i]
id_seq = _append_zeros(example[0][0], 5)
type_seq = _append_zeros(example[1][0], 5)
visit_seq = _append_zeros(example[2][0], 5)
age_sex = _append_zeros(example[3][0], 3)

tensor([[25, 25, 25, 11]])


In [None]:
class HeteroGT(nn.Module):
    """Transformer encoder over EHR token sequences, optionally extended with visit nodes from a heterogeneous GNN.

    The input to the encoder is `[task token | code tokens | (visit tokens)]`.
    The visit tokens are produced by `DiseaseOccHetGNN` on a per-patient graph
    and are appended only when `use_hetero_graph` is true. The prediction is
    read from the task-token position.

    Args:
        tokenizer: project tokenizer; its vocabularies size the embedding tables
            and the multi-label head.
        d_model: embedding / hidden width.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        max_num_adms: maximum number of admissions per patient; visit sequences
            are padded or truncated to this length.
        device: device on which helper tensors (masks, indices) are created.
            Inputs are expected to live on the same device.
        task: task name; "death", "stay" and "readmission" get a single-logit
            head, anything else gets a head sized to the diagnosis vocabulary.
        use_hetero_graph: whether to append GNN-derived visit tokens.
        num_tasks: number of rows in the task-embedding table (defaults to 5,
            the value that was previously hard-coded).
    """

    def __init__(self, tokenizer, d_model, num_heads, num_layers, max_num_adms, device, task, use_hetero_graph,
                 num_tasks=5):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.use_hetero_graph = use_hetero_graph
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_sex_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0] #0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0] #0
        self.adm_pad_id = 0
        self.age_sex_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0] #0
        # NOTE(review): assumes "diag" has id 1 in the tokenizer's type vocabulary - confirm.
        self.diag_type_id = 1
        # Extra type id reserved for visit tokens; type_emb below has n_type + 1 rows for it.
        # NOTE(review): assumes the tokenizer's own type ids are 0..4 - confirm.
        self.visit_type_id = 5

        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id)
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_sex_emb = nn.Embedding(self.age_sex_vocab_size, d_model, padding_idx=self.age_sex_pad_id)
        # Task embedding: forward() prepends it to the sequence as the CLS-like token.
        self.task_emb = nn.Embedding(num_tasks, d_model, padding_idx=None)

        # GNN
        self.het_gnn = DiseaseOccHetGNN(d_model)

        # encoder transformer
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, norm_first = True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers, enable_nested_tensor=False)

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            # Size the head from the tokenizer's diagnosis vocabulary rather than
            # from a notebook-level global, so the class does not depend on hidden state.
            label_size = len(self.tokenizer.diag_voc.id2word)
            self.cls_head = MultiPredictionHead(self.d_model, label_size)


    def forward(self, input_ids, token_types, adm_index, age_gender_index, task_id):
        """Encode a batch and return the logits read from the task-token position.

        Args:
            input_ids: [B, L] token ids; positions equal to the pad id are masked.
            token_types: [B, L] token-type ids.
            adm_index: [B, L] admission index of each token (0 is the pad index).
            age_gender_index: per-patient age/gender ids, one per visit (padded).
            task_id: integer task id, broadcast to the whole batch.

        Returns:
            Logits of shape [B, label_size] ([B, 1] for the binary head).
        """
        B, L = input_ids.shape
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)
        # base token representation
        token_embed = self.token_emb(input_ids)  # [B, L, d]
        adm_emb  = self.adm_index_emb(adm_index)          # [B, L, d]
        type_emb = self.type_emb(token_types)       # [B, L, d]
        x_tokens = token_embed + adm_emb + type_emb  # [B, L, d]
        task_emb = self.task_emb(task_id).unsqueeze(1)           # [B, 1, d]
        x = torch.cat([task_emb, x_tokens], dim=1)  # [B, 1+L, d]

        # key-padding mask (True = ignore); the task token is never masked
        seq_pad_mask = (input_ids == self.seq_pad_id)         # [B, L]
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)
        mask = torch.cat([task_pad_mask, seq_pad_mask], dim=1)  # [B, 1+L]

        if self.use_hetero_graph:
            # append the visit embeddings and their padding mask
            visit_emb_pad, visit_pad_mask = self.visit_segment(B, input_ids, token_types, adm_index, age_gender_index)
            x = torch.cat([x, visit_emb_pad], dim=1)  # [B, 1+L(+V), d]
            mask = torch.cat([mask, visit_pad_mask], dim=1)

        # ===== Transformer encoding (batch_first=True) =====
        h = self.encoder(x, src_key_padding_mask=mask)   # [B, 1+L(+V), d]

        # ===== Classification: read the CLS (task) position =====
        logits = self.cls_head(h[:, 0, :])  # [B, label_size]
        return logits

    def visit_segment(self, B, input_ids, token_types, adm_index, age_gender_index):
        """Build per-patient graphs, run the GNN, and return padded visit tokens.

        Returns:
            visit_emb_pad: [B, max_num_adms, d] visit embeddings with type and
                visit-index embeddings added on non-pad positions.
            visit_pad_mask: [B, max_num_adms] bool mask, True on padded positions.
        """
        graphs = []
        for p in range(B):
            hg_p = self.build_patient_graph(input_ids[p], token_types[p], adm_index[p], age_gender_index[p])
            graphs.append(hg_p)

        batch_graph = HeteroBatch.from_data_list(graphs).to(self.device)
        out = self.het_gnn(batch_graph)
        h_visit_all = out['visit']  # extract virtual visit node representations

        # Slice out each patient's visit representations; the batch keeps the
        # per-graph node order established in build_patient_graph.
        visit_emb_seq = []
        offset = 0
        for p in range(B):
            n_v = graphs[p]['visit'].num_nodes
            visit_emb_p = h_visit_all[offset:offset + n_v]  # [N_visit_p, d]
            offset += n_v
            visit_emb_seq.append(visit_emb_p)

        # Pad (or truncate) every patient's visits to max_num_adms.
        visit_emb_pad = []
        visit_pad_mask = []
        visit_index_pad = []
        for p in range(B):
            v = visit_emb_seq[p]                          # [N_visit_p, d]
            Np = v.size(0)
            if Np < self.max_num_adms:
                pad_len = self.max_num_adms - Np
                v_pad = torch.cat([v, torch.zeros(pad_len, self.d_model, device=self.device, dtype=v.dtype)], dim=0)
                m_pad = torch.cat([torch.zeros(Np, dtype=torch.bool, device=self.device), torch.ones(pad_len, dtype=torch.bool, device=self.device)], dim=0)
                # visit indices run 1..Np; padded slots get the pad index
                i_pad = torch.cat([torch.arange(1, Np + 1, device=self.device), torch.full((pad_len,), self.adm_pad_id, dtype=torch.long, device=self.device)], dim=0)
            else:
                v_pad = v[:self.max_num_adms]
                m_pad = torch.zeros(self.max_num_adms, dtype=torch.bool, device=self.device)
                i_pad = torch.arange(1, self.max_num_adms + 1, device=self.device)
            visit_emb_pad.append(v_pad)      # [V_max, d]
            visit_pad_mask.append(m_pad)     # [V_max]
            visit_index_pad.append(i_pad)    # [V_max]

        visit_emb_pad  = torch.stack(visit_emb_pad,  dim=0)  # [B, V_max, d]
        visit_pad_mask = torch.stack(visit_pad_mask, dim=0)  # [B, V_max]
        visit_index_pad = torch.stack(visit_index_pad, dim=0)  # [B, V_max]

        # ====== Alignment and type embedding (the key part) ======
        nonpad = (~visit_pad_mask).unsqueeze(-1)             # [B, V_max, 1], bool

        # 1. add the type embedding (non-pad positions only)
        visit_type_ids = torch.full((B, self.max_num_adms), self.visit_type_id, dtype=torch.long, device=self.device) # [B, V_max]
        visit_type_emb = self.type_emb(visit_type_ids) * nonpad      # [B, V_max, d]

        # 2. add the visit-index embedding (non-pad positions only)
        visit_index_emb = self.adm_index_emb(visit_index_pad) * nonpad  # [B, V_max, d]

        # 3. final visit embedding
        visit_emb_pad  = visit_emb_pad + visit_type_emb + visit_index_emb

        return visit_emb_pad, visit_pad_mask

    def build_patient_graph(self, id_seq: torch.Tensor, type_seq: torch.Tensor, visit_seq: torch.Tensor, age_sex: torch.Tensor):
        """Build the heterogeneous graph for a single patient.

        Node types: 'visit' (one per distinct admission index among non-pad
        tokens, initialised from the age/gender embedding) and 'occ' (one per
        diagnosis token, initialised from the token embedding). Edges link each
        diagnosis occurrence with its visit in both directions, and each visit
        with the next one.

        Raises:
            AssertionError: if the number of distinct visits differs from the
                number of non-pad age/gender ids.
        """
        # build a graph just for one patient
        hg = HeteroData()
        occ_mask = (type_seq == self.diag_type_id) & (id_seq != self.seq_pad_id) # mask of diagnosis tokens
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1) # positions of diagnosis tokens, shape [N_occ]
        N_occ = occ_pos.numel() # number of diagnosis tokens

        # build visit virtual nodes
        nonpad = id_seq != self.seq_pad_id
        visit_used = visit_seq[nonpad] # admission indices of the non-pad part of the sequence
        # torch.unique returns sorted values, so local visit ids follow admission-index order
        visit_ids_unique, visit_lid_nonpad = torch.unique(visit_used, return_inverse=True)
        visit_lid_full = torch.full_like(id_seq, fill_value=-1)
        visit_lid_full[nonpad] = visit_lid_nonpad
        N_visit = visit_ids_unique.numel()
        age_sex_nonpad = age_sex[age_sex!=self.age_sex_pad_id]
        assert N_visit == len(age_sex_nonpad), "number of visits must match the number of non-pad age/gender ids"
        visit_x = self.age_sex_emb(age_sex_nonpad.to(self.device))
        hg['visit'].x = visit_x
        hg['visit'].num_nodes = N_visit

        # build diag nodes
        gid_occ = id_seq[occ_pos]
        x_occ = self.token_emb(gid_occ) # [N_occ, d]
        hg['occ'].x = x_occ
        hg['occ'].num_nodes = N_occ

        # build edges between diag nodes and virtual visit nodes
        occ_visit_lid = visit_lid_full[occ_pos]
        e_v2o = torch.stack([occ_visit_lid, torch.arange(N_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(N_occ, device=self.device), occ_visit_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v

        # build forward edges between virtual visit nodes
        if N_visit > 1:
            src = torch.arange(0, N_visit - 1, device=self.device)
            dst = torch.arange(1, N_visit, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, N_visit-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg

In [31]:
# Train several independently initialised models and collect, for each run, the
# test metrics returned by the training routine.
N_RUNS = 5
metric_names = ("precision", "recall", "f1", "auc", "prauc")
final_metrics = {name: [] for name in metric_names}

for run_idx in range(N_RUNS):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=True,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    for name, history in final_metrics.items():
        history.append(best_test_metric[name])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.7180]

Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 38.73it/s, loss=0.5260]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.52it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.89it/s]


Validation: {'precision': 0.7248868778214943, 'recall': 0.4720094284002828, 'f1': 0.5717344705938429, 'auc': 0.8336861788333322, 'prauc': 0.6783495571486897}
Test:      {'precision': 0.7383773928829683, 'recall': 0.4485049833862209, 'f1': 0.5580433986702152, 'auc': 0.8372240744871475, 'prauc': 0.696012729040115}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 39.14it/s, loss=0.4290]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.52it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.10it/s]


Validation: {'precision': 0.9137055837099641, 'recall': 0.10606953447197368, 'f1': 0.19007391576874788, 'auc': 0.8554749260404066, 'prauc': 0.7136796609242874}
Test:      {'precision': 0.9317073170277216, 'recall': 0.10575858250218297, 'f1': 0.18995524431335267, 'auc': 0.8573256626422375, 'prauc': 0.7273704931597026}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 38.77it/s, loss=0.4023]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.04it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 62.45it/s]


Validation: {'precision': 0.7478532396506804, 'recall': 0.5645256334675043, 'f1': 0.643384817121454, 'auc': 0.8643959367220894, 'prauc': 0.7324070934201978}
Test:      {'precision': 0.7629513343739173, 'recall': 0.5382059800634651, 'f1': 0.6311688263139065, 'auc': 0.866866016987001, 'prauc': 0.7515654082108046}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 39.09it/s, loss=0.3672]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.74it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.54it/s]


Validation: {'precision': 0.7641815234946178, 'recall': 0.5556865055948398, 'f1': 0.6434663888426556, 'auc': 0.8777450062025269, 'prauc': 0.7471428060089337}
Test:      {'precision': 0.7722616233193675, 'recall': 0.542635658911724, 'f1': 0.6373983691320796, 'auc': 0.877365884776155, 'prauc': 0.7557563143820812}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 38.49it/s, loss=0.3355]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.60it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.67it/s]


Validation: {'precision': 0.7067846607627919, 'recall': 0.7059516794301358, 'f1': 0.7063679195241388, 'auc': 0.8866593742850424, 'prauc': 0.773322185394425}
Test:      {'precision': 0.7134670487065131, 'recall': 0.6893687707603026, 'f1': 0.7012109214971037, 'auc': 0.8892852957411366, 'prauc': 0.7888877069910376}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 38.65it/s, loss=0.3055]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.84it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.67it/s]


Validation: {'precision': 0.7738876732255735, 'recall': 0.6252209781931337, 'f1': 0.6916557968772387, 'auc': 0.8945385188357925, 'prauc': 0.7867118962632126}
Test:      {'precision': 0.7749658002682971, 'recall': 0.627353266884677, 'f1': 0.6933904479275351, 'auc': 0.8952074553129802, 'prauc': 0.7978664934394755}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 38.44it/s, loss=0.2839]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.51it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.92it/s]


Validation: {'precision': 0.8165829145660254, 'recall': 0.5745433117231907, 'f1': 0.6745070861186754, 'auc': 0.8952622427892989, 'prauc': 0.7932734859462562}
Test:      {'precision': 0.8171114599621891, 'recall': 0.5764119601296988, 'f1': 0.6759740211188101, 'auc': 0.8925330170121137, 'prauc': 0.7974708889364917}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 39.14it/s, loss=0.2529]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.28it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 62.40it/s]


Validation: {'precision': 0.6793650793614848, 'recall': 0.7566293459000788, 'f1': 0.7159185899366055, 'auc': 0.8882863639830099, 'prauc': 0.7931578965016712}
Test:      {'precision': 0.6805207811683499, 'recall': 0.7524916943479929, 'f1': 0.7146989169126137, 'auc': 0.8851415813205032, 'prauc': 0.7966185216060899}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 39.03it/s, loss=0.2355]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.13it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.87it/s]


Validation: {'precision': 0.7517594369753566, 'recall': 0.6923983500253836, 'f1': 0.7208588907095469, 'auc': 0.8960305488026012, 'prauc': 0.7973641763347763}
Test:      {'precision': 0.7421083978514467, 'recall': 0.6899224806163349, 'f1': 0.7150645574128665, 'auc': 0.8966727921896627, 'prauc': 0.8058725491554759}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 39.19it/s, loss=0.2127]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.58it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.41it/s]


Validation: {'precision': 0.6874663797703741, 'recall': 0.7530936947510131, 'f1': 0.7187851468623526, 'auc': 0.8919859084923203, 'prauc': 0.7948925434345352}
Test:      {'precision': 0.6944444444408723, 'recall': 0.7475083056437015, 'f1': 0.7199999950029313, 'auc': 0.8884627940952716, 'prauc': 0.7980145996142819}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 39.06it/s, loss=0.1896]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.25it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.33it/s]


Validation: {'precision': 0.7113938692844917, 'recall': 0.7248084855584868, 'f1': 0.7180385238929172, 'auc': 0.89027575856247, 'prauc': 0.7919288138426661}
Test:      {'precision': 0.7086311159939053, 'recall': 0.7137320044257269, 'f1': 0.711172408789244, 'auc': 0.8840297299599624, 'prauc': 0.7908436061582603}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 38.96it/s, loss=0.1664]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.41it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.17it/s]


Validation: {'precision': 0.7565406976689205, 'recall': 0.6134354743629143, 'f1': 0.6775138251835683, 'auc': 0.8823430261203606, 'prauc': 0.7759032755509105}
Test:      {'precision': 0.7599999999948475, 'recall': 0.6207087486122884, 'f1': 0.6833282486279477, 'auc': 0.8755162254960616, 'prauc': 0.7821931311523548}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 39.08it/s, loss=0.1764]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.49it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.53it/s]


Validation: {'precision': 0.8521276595654029, 'recall': 0.4720094284002828, 'f1': 0.6075085278306427, 'auc': 0.8792603491375671, 'prauc': 0.7794149802206571}
Test:      {'precision': 0.8506743737875658, 'recall': 0.48892580287658405, 'f1': 0.6209563947976594, 'auc': 0.8768138367882957, 'prauc': 0.7823373633616915}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 39.09it/s, loss=0.1652]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.42it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.26it/s]


Validation: {'precision': 0.7465975372602295, 'recall': 0.6788450206206315, 'f1': 0.7111111061180175, 'auc': 0.8760022181171527, 'prauc': 0.7791349220452829}
Test:      {'precision': 0.7419158023139603, 'recall': 0.6733111849353638, 'f1': 0.7059506481281157, 'auc': 0.8733785706161024, 'prauc': 0.7808226190239043}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7517594369753566, 'recall': 0.6923983500253836, 'f1': 0.7208588907095469, 'auc': 0.8960305488026012, 'prauc': 0.7973641763347763}
Corresponding test performance:
{'precision': 0.7421083978514467, 'recall': 0.6899224806163349, 'f1': 0.7150645574128665, 'auc': 0.8966727921896627, 'prauc': 0.8058725491554759}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 39.25it/s, loss=0.5533]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.88it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.84it/s]


Validation: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.7953311086587429, 'prauc': 0.5948428723600928}
Test:      {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.7959131091608576, 'prauc': 0.6138694400998913}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 39.10it/s, loss=0.4792]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.61it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.41it/s]


Validation: {'precision': 0.6235498839871024, 'recall': 0.6334708308742872, 'f1': 0.6284712022459796, 'auc': 0.8377278171997586, 'prauc': 0.6776266586552998}
Test:      {'precision': 0.6515759312283578, 'recall': 0.6295681063088064, 'f1': 0.640382985704712, 'auc': 0.8367978360752927, 'prauc': 0.6902619854067021}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 38.68it/s, loss=0.4187]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.35it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.25it/s]


Validation: {'precision': 0.575463371577337, 'recall': 0.7684148497302982, 'f1': 0.6580873026950086, 'auc': 0.8591411936827349, 'prauc': 0.7216365343815501}
Test:      {'precision': 0.5946061643810163, 'recall': 0.769102990028964, 'f1': 0.6706904827657348, 'auc': 0.8600438019210703, 'prauc': 0.7333186099231481}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 39.08it/s, loss=0.3804]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.20it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.07it/s]


Validation: {'precision': 0.698301113059764, 'recall': 0.7024160282810701, 'f1': 0.7003525214354114, 'auc': 0.8852724123531395, 'prauc': 0.7664365304258}
Test:      {'precision': 0.7175792507163252, 'recall': 0.6893687707603026, 'f1': 0.7031911839277194, 'auc': 0.886998072725321, 'prauc': 0.7805893177101151}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 38.77it/s, loss=0.3243]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.33it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.24it/s]


Validation: {'precision': 0.7261064785072091, 'recall': 0.6670595167904122, 'f1': 0.695331690336406, 'auc': 0.8897262750948134, 'prauc': 0.7756750082972339}
Test:      {'precision': 0.7419962335169994, 'recall': 0.6544850498302631, 'f1': 0.6954986710967428, 'auc': 0.8909581814878226, 'prauc': 0.7818405341249494}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 39.38it/s, loss=0.2940]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.62it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.09it/s]


Validation: {'precision': 0.831858407070444, 'recall': 0.44313494401624554, 'f1': 0.5782391342530933, 'auc': 0.8752040204075975, 'prauc': 0.7592247816177227}
Test:      {'precision': 0.8345177664889897, 'recall': 0.4551495016586094, 'f1': 0.5890361831747566, 'auc': 0.8772661111349107, 'prauc': 0.7648631379943085}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 39.23it/s, loss=0.2825]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.37it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.12it/s]


Validation: {'precision': 0.833161688971852, 'recall': 0.47672362993237055, 'f1': 0.6064467719773782, 'auc': 0.8809551699924159, 'prauc': 0.7725082361216404}
Test:      {'precision': 0.8380487804796288, 'recall': 0.4756367663318071, 'f1': 0.606852697601608, 'auc': 0.8820430588813437, 'prauc': 0.7799869791762815}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 39.11it/s, loss=0.2492]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.26it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.18it/s]


Validation: {'precision': 0.7710464201355544, 'recall': 0.5774896876807455, 'f1': 0.6603773535891219, 'auc': 0.8871354698061295, 'prauc': 0.777154001234529}
Test:      {'precision': 0.7793357933521821, 'recall': 0.5847176079701843, 'f1': 0.668142987821377, 'auc': 0.8849236919374035, 'prauc': 0.7794268341777462}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 39.41it/s, loss=0.2410]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.26it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.38it/s]


Validation: {'precision': 0.6912362158984838, 'recall': 0.7018267530895591, 'f1': 0.6964912230663914, 'auc': 0.8873409432822947, 'prauc': 0.7779270351859466}
Test:      {'precision': 0.6957001102497481, 'recall': 0.6987818383128529, 'f1': 0.6972375640569459, 'auc': 0.8827888068377382, 'prauc': 0.7814958031378836}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 39.16it/s, loss=0.2078]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.63it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 62.35it/s]


Validation: {'precision': 0.7709359605848035, 'recall': 0.553329404828796, 'f1': 0.6442538544787882, 'auc': 0.8755691078772156, 'prauc': 0.7599481041440443}
Test:      {'precision': 0.7823022709414073, 'recall': 0.5531561461763391, 'f1': 0.648070056771289, 'auc': 0.8739302492999226, 'prauc': 0.7722695774368167}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 38.88it/s, loss=0.2097]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.16it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.95it/s]


Validation: {'precision': 0.6163284568163458, 'recall': 0.7695934001133201, 'f1': 0.6844863682231785, 'auc': 0.8852028566753203, 'prauc': 0.778880465086523}
Test:      {'precision': 0.6313901345263166, 'recall': 0.7796234772935791, 'f1': 0.6977205104134692, 'auc': 0.880890522525823, 'prauc': 0.7826627653748193}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 38.98it/s, loss=0.1628]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.31it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.68it/s]


Validation: {'precision': 0.7795744680784717, 'recall': 0.5397760754240438, 'f1': 0.6378830035172782, 'auc': 0.8714232318007793, 'prauc': 0.7620236202227304}
Test:      {'precision': 0.7902155887164991, 'recall': 0.52768549279885, 'f1': 0.6328021200282051, 'auc': 0.868245613775927, 'prauc': 0.7634407378461205}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 38.97it/s, loss=0.1552]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.01it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.93it/s]


Validation: {'precision': 0.6646136618103885, 'recall': 0.6994696523235152, 'f1': 0.6815963200065286, 'auc': 0.8759979387503815, 'prauc': 0.7659807226151197}
Test:      {'precision': 0.6868905742107969, 'recall': 0.7021040974490471, 'f1': 0.6944140147120216, 'auc': 0.8731822855192931, 'prauc': 0.7671382576360822}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 38.83it/s, loss=0.1318]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.97it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.20it/s]


Validation: {'precision': 0.6561992420105242, 'recall': 0.7142015321112893, 'f1': 0.6839729069689799, 'auc': 0.8742898965683439, 'prauc': 0.7687140804828153}
Test:      {'precision': 0.6731657260098863, 'recall': 0.7264673311144715, 'f1': 0.6988015928730369, 'auc': 0.8742140594540111, 'prauc': 0.7746372016550302}


Epoch 015: 100%|██████████| 98/98 [00:02<00:00, 38.85it/s, loss=0.1281]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.09it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.07it/s]


Validation: {'precision': 0.6387061403473756, 'recall': 0.686505598110274, 'f1': 0.6617438177803663, 'auc': 0.8656739344791903, 'prauc': 0.746284277639963}
Test:      {'precision': 0.6561679789991802, 'recall': 0.6921373200404644, 'f1': 0.6736728594569475, 'auc': 0.8677990020912456, 'prauc': 0.7577830451675467}


Epoch 016: 100%|██████████| 98/98 [00:02<00:00, 34.44it/s, loss=0.1261]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.29it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.20it/s]


Validation: {'precision': 0.6461785141600953, 'recall': 0.7124337065367565, 'f1': 0.6776905779677337, 'auc': 0.8710779443863714, 'prauc': 0.7547166002115671}
Test:      {'precision': 0.6658084448987343, 'recall': 0.7159468438498564, 'f1': 0.6899679779271279, 'auc': 0.8725382808256851, 'prauc': 0.7661572022088735}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6163284568163458, 'recall': 0.7695934001133201, 'f1': 0.6844863682231785, 'auc': 0.8852028566753203, 'prauc': 0.778880465086523}
Corresponding test performance:
{'precision': 0.6313901345263166, 'recall': 0.7796234772935791, 'f1': 0.6977205104134692, 'auc': 0.880890522525823, 'prauc': 0.7826627653748193}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 38.93it/s, loss=0.5269]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.13it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.75it/s]


Validation: {'precision': 0.6305137519427582, 'recall': 0.7159693576858223, 'f1': 0.6705297963409423, 'auc': 0.8645637262222096, 'prauc': 0.7376394827696036}
Test:      {'precision': 0.6601842374582386, 'recall': 0.7142857142817592, 'f1': 0.6861702077700544, 'auc': 0.864858172476804, 'prauc': 0.7559071564569965}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 38.78it/s, loss=0.4307]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.94it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.53it/s]


Validation: {'precision': 0.7182432432383903, 'recall': 0.6263995285761557, 'f1': 0.6691847605211602, 'auc': 0.8714011323843188, 'prauc': 0.7518632079583415}
Test:      {'precision': 0.7303149606251292, 'recall': 0.6162790697640295, 'f1': 0.6684684635003112, 'auc': 0.8742932751704092, 'prauc': 0.7690097805705717}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 39.15it/s, loss=0.3812]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.51it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.27it/s]


Validation: {'precision': 0.8191094619590065, 'recall': 0.5203299941041819, 'f1': 0.6363963916405955, 'auc': 0.8904681384537383, 'prauc': 0.7809035548040626}
Test:      {'precision': 0.8266094420529906, 'recall': 0.5332225913591738, 'f1': 0.6482665721385126, 'auc': 0.8875006339719339, 'prauc': 0.7944148100149316}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 38.98it/s, loss=0.3362]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.60it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.41it/s]


Validation: {'precision': 0.7207833228002478, 'recall': 0.6723629935140109, 'f1': 0.695731702318871, 'auc': 0.8875476941813768, 'prauc': 0.7805002884104382}
Test:      {'precision': 0.7430806257476349, 'recall': 0.6838316721999789, 'f1': 0.7122260619018604, 'auc': 0.8889373497855821, 'prauc': 0.7927455548319642}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 39.57it/s, loss=0.3296]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.53it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.29it/s]


Validation: {'precision': 0.6900921658946424, 'recall': 0.7059516794301358, 'f1': 0.6979318330391077, 'auc': 0.8892895880707095, 'prauc': 0.7835490388879757}
Test:      {'precision': 0.7135593220298669, 'recall': 0.6993355481688852, 'f1': 0.7063758339227306, 'auc': 0.886960219061308, 'prauc': 0.7928652028526557}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 39.09it/s, loss=0.2878]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.60it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.32it/s]


Validation: {'precision': 0.7193090684718119, 'recall': 0.687094873301785, 'f1': 0.7028330269453429, 'auc': 0.8818050394589553, 'prauc': 0.7718644371538499}
Test:      {'precision': 0.7226493199247627, 'recall': 0.676633444071558, 'f1': 0.6988847533657225, 'auc': 0.886203761287778, 'prauc': 0.7922519781609836}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 39.04it/s, loss=0.2835]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.28it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.98it/s]


Validation: {'precision': 0.6879765395854078, 'recall': 0.6912197996423617, 'f1': 0.6895943512569966, 'auc': 0.8769249006899489, 'prauc': 0.7710317293358729}
Test:      {'precision': 0.7140439932277832, 'recall': 0.7009966777369824, 'f1': 0.7074601794055246, 'auc': 0.8822896308781508, 'prauc': 0.7897995451090303}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 38.85it/s, loss=0.2551]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.26it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.95it/s]


Validation: {'precision': 0.7694779116404059, 'recall': 0.5645256334675043, 'f1': 0.651257642972194, 'auc': 0.867928394313884, 'prauc': 0.7648183405695933}
Test:      {'precision': 0.794342507639187, 'recall': 0.575304540417634, 'f1': 0.6673089225481253, 'auc': 0.8769972577944081, 'prauc': 0.7857645800722046}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 38.76it/s, loss=0.2352]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.20it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.35it/s]


Validation: {'precision': 0.7450532724448626, 'recall': 0.5769004124892345, 'f1': 0.6502822933163677, 'auc': 0.8649237040001604, 'prauc': 0.7545259996662806}
Test:      {'precision': 0.7638190954719037, 'recall': 0.5891472868184433, 'f1': 0.6652078725408858, 'auc': 0.8696076686229932, 'prauc': 0.7669336347380027}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 39.11it/s, loss=0.2079]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.92it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.40it/s]


Validation: {'precision': 0.6617733411587683, 'recall': 0.6641131408328573, 'f1': 0.6629411714667043, 'auc': 0.8550652565109926, 'prauc': 0.738349719855061}
Test:      {'precision': 0.7011952191195151, 'recall': 0.6821705426318817, 'f1': 0.6915520578654332, 'auc': 0.8682517688432463, 'prauc': 0.7591824738758762}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6900921658946424, 'recall': 0.7059516794301358, 'f1': 0.6979318330391077, 'auc': 0.8892895880707095, 'prauc': 0.7835490388879757}
Corresponding test performance:
{'precision': 0.7135593220298669, 'recall': 0.6993355481688852, 'f1': 0.7063758339227306, 'auc': 0.886960219061308, 'prauc': 0.7928652028526557}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 39.02it/s, loss=0.5430]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.48it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.37it/s]


Validation: {'precision': 0.8196286471931132, 'recall': 0.18208603417688812, 'f1': 0.29797492469846637, 'auc': 0.818438539542818, 'prauc': 0.6433551773868899}
Test:      {'precision': 0.8263027295080322, 'recall': 0.18438538205877972, 'f1': 0.30149388565160723, 'auc': 0.817313724273665, 'prauc': 0.6636117666558194}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 39.24it/s, loss=0.4512]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.75it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.52it/s]


Validation: {'precision': 0.6376419686282443, 'recall': 0.6947554507914275, 'f1': 0.6649746142947768, 'auc': 0.8525063867952282, 'prauc': 0.7250357782359015}
Test:      {'precision': 0.6542497376670816, 'recall': 0.6904761904723673, 'f1': 0.6718749950000089, 'auc': 0.8523999469186995, 'prauc': 0.7386378758501939}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 39.13it/s, loss=0.3830]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.90it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.78it/s]


Validation: {'precision': 0.7627795527095625, 'recall': 0.5627578078929714, 'f1': 0.6476771738141053, 'auc': 0.8804047923287176, 'prauc': 0.7602331645895745}
Test:      {'precision': 0.7734976887459669, 'recall': 0.555924695456501, 'f1': 0.6469072116245999, 'auc': 0.8827591394132596, 'prauc': 0.7732203626941281}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 39.43it/s, loss=0.3349]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.59it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.21it/s]


Validation: {'precision': 0.7269076305172228, 'recall': 0.6399528579809078, 'f1': 0.6806643635524781, 'auc': 0.8806132038775918, 'prauc': 0.7717957070749752}
Test:      {'precision': 0.760797342187636, 'recall': 0.6339977851570654, 'f1': 0.6916339424850456, 'auc': 0.8838463705045235, 'prauc': 0.7858489500223945}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 39.17it/s, loss=0.3025]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 62.97it/s]


Validation: {'precision': 0.7345879299108723, 'recall': 0.6670595167904122, 'f1': 0.6991970302142049, 'auc': 0.8920668971052448, 'prauc': 0.7917359297375273}
Test:      {'precision': 0.7656348704942159, 'recall': 0.6710963455112343, 'f1': 0.7152552325506236, 'auc': 0.8921767617156782, 'prauc': 0.7942213714267401}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 39.40it/s, loss=0.3080]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.75it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.63it/s]


Validation: {'precision': 0.7116221255393933, 'recall': 0.6747200942800547, 'f1': 0.6926799708009252, 'auc': 0.88065606141645, 'prauc': 0.7797633700286982}
Test:      {'precision': 0.7336561743296994, 'recall': 0.6710963455112343, 'f1': 0.7009832223048793, 'auc': 0.8800198267028487, 'prauc': 0.7844839781998068}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 39.05it/s, loss=0.2780]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.35it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.45it/s]


Validation: {'precision': 0.7598072952459752, 'recall': 0.6505598114281052, 'f1': 0.700952375977931, 'auc': 0.8946113319420494, 'prauc': 0.7998544775973249}
Test:      {'precision': 0.7894021739076807, 'recall': 0.6434108527096157, 'f1': 0.7089688785131115, 'auc': 0.8962631109088965, 'prauc': 0.8118819544265783}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 39.14it/s, loss=0.2561]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.44it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.37it/s]


Validation: {'precision': 0.8165217391233347, 'recall': 0.553329404828796, 'f1': 0.6596417233148186, 'auc': 0.8905548754697883, 'prauc': 0.7918871240497816}
Test:      {'precision': 0.8294970161907119, 'recall': 0.5387596899194975, 'f1': 0.6532393372824633, 'auc': 0.8908431432796268, 'prauc': 0.799742132055086}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 39.51it/s, loss=0.2243]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.57it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.66it/s]


Validation: {'precision': 0.6342504743802929, 'recall': 0.78786093105016, 'f1': 0.7027595219928824, 'auc': 0.8923228287866232, 'prauc': 0.7932433576449608}
Test:      {'precision': 0.6791921089681578, 'recall': 0.8006644518228092, 'f1': 0.7349428158685812, 'auc': 0.8897109801969326, 'prauc': 0.7998907388297861}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 39.24it/s, loss=0.2064]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.62it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.56it/s]


Validation: {'precision': 0.6312849161981784, 'recall': 0.7990571596888683, 'f1': 0.7053315945449661, 'auc': 0.8929278801511601, 'prauc': 0.7881199917672609}
Test:      {'precision': 0.6631090487208209, 'recall': 0.7912513842702589, 'f1': 0.7215349609528705, 'auc': 0.8944860814231698, 'prauc': 0.7978509554892705}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 34.70it/s, loss=0.1830]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.67it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.15it/s]


Validation: {'precision': 0.6512207274506666, 'recall': 0.770182675304831, 'f1': 0.705723537147843, 'auc': 0.8888915430898382, 'prauc': 0.7831310178606523}
Test:      {'precision': 0.6812865497042823, 'recall': 0.7740863787332554, 'f1': 0.724727833274737, 'auc': 0.8868523207312022, 'prauc': 0.7915835472839658}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 39.13it/s, loss=0.1746]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.53it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.19it/s]


Validation: {'precision': 0.6852260198418676, 'recall': 0.7324690630481293, 'f1': 0.7080603766591667, 'auc': 0.8905813181241659, 'prauc': 0.7828294455427458}
Test:      {'precision': 0.7000532765013316, 'recall': 0.7275747508265362, 'f1': 0.7135487324402859, 'auc': 0.8882031733557476, 'prauc': 0.786889222897619}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7598072952459752, 'recall': 0.6505598114281052, 'f1': 0.700952375977931, 'auc': 0.8946113319420494, 'prauc': 0.7998544775973249}
Corresponding test performance:
{'precision': 0.7894021739076807, 'recall': 0.6434108527096157, 'f1': 0.7089688785131115, 'auc': 0.8962631109088965, 'prauc': 0.8118819544265783}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 39.30it/s, loss=0.5393]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.46it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.50it/s]


Validation: {'precision': 0.7580893682471789, 'recall': 0.2899233942233947, 'f1': 0.41943733614766104, 'auc': 0.8177768983430675, 'prauc': 0.6382502495863877}
Test:      {'precision': 0.7828655834448617, 'recall': 0.2934662236971569, 'f1': 0.4269029360222303, 'auc': 0.8304133693971036, 'prauc': 0.6780288530133204}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 38.78it/s, loss=0.4420]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.34it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.27it/s]


Validation: {'precision': 0.7643724696294384, 'recall': 0.5562757807863508, 'f1': 0.6439290537827803, 'auc': 0.8522803468101026, 'prauc': 0.7334234894136957}
Test:      {'precision': 0.7860016090041352, 'recall': 0.540974529343627, 'f1': 0.6408658528245244, 'auc': 0.8548779770214102, 'prauc': 0.7482596111993287}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 39.05it/s, loss=0.3750]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.22it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.96it/s]


Validation: {'precision': 0.6650110375239238, 'recall': 0.7100766057707126, 'f1': 0.6868053526532085, 'auc': 0.8813251755402637, 'prauc': 0.7616497558049372}
Test:      {'precision': 0.6857451403850662, 'recall': 0.7032115171611119, 'f1': 0.6943685023780769, 'auc': 0.8794231544769252, 'prauc': 0.7789539095690706}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 38.74it/s, loss=0.3196]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.61it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.12it/s]


Validation: {'precision': 0.7223701730977207, 'recall': 0.6393635827893969, 'f1': 0.6783369753206835, 'auc': 0.8777976998978445, 'prauc': 0.7639579295266115}
Test:      {'precision': 0.7534766118789288, 'recall': 0.6600221483905868, 'f1': 0.7036599714049518, 'auc': 0.8831909173857015, 'prauc': 0.7850185451255319}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 38.74it/s, loss=0.2759]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.73it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.50it/s]


Validation: {'precision': 0.7345132743308664, 'recall': 0.586918090744921, 'f1': 0.6524729724573823, 'auc': 0.8735434344869186, 'prauc': 0.7599330200267869}
Test:      {'precision': 0.770745428967857, 'recall': 0.6068660022114791, 'f1': 0.6790582354630795, 'auc': 0.8808112452587517, 'prauc': 0.7854248443091578}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 38.96it/s, loss=0.2914]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.90it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.62it/s]


Validation: {'precision': 0.731292517001828, 'recall': 0.6334708308742872, 'f1': 0.6788759028205796, 'auc': 0.8869877997169486, 'prauc': 0.777581826077121}
Test:      {'precision': 0.7595099935476499, 'recall': 0.6522702104061336, 'f1': 0.7018170936246095, 'auc': 0.8848226872826952, 'prauc': 0.7850311985934505}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 38.94it/s, loss=0.2412]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 68.93it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 68.85it/s]


Validation: {'precision': 0.7451253481842262, 'recall': 0.6305244549167324, 'f1': 0.6830513834759191, 'auc': 0.8829865917858386, 'prauc': 0.7755923426043634}
Test:      {'precision': 0.7721021610951402, 'recall': 0.6528239202621661, 'f1': 0.7074707421054978, 'auc': 0.8876079783459807, 'prauc': 0.7951323584223474}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 38.69it/s, loss=0.2210]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.49it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.69it/s]


Validation: {'precision': 0.682359952320129, 'recall': 0.6747200942800547, 'f1': 0.6785185135146562, 'auc': 0.8741620904055191, 'prauc': 0.7650506247738511}
Test:      {'precision': 0.7134799774353442, 'recall': 0.70044296788095, 'f1': 0.7069013640939882, 'auc': 0.8776105487020933, 'prauc': 0.7849987386971597}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 39.39it/s, loss=0.2133]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 62.83it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.71it/s]


Validation: {'precision': 0.7322936972012197, 'recall': 0.6641131408328573, 'f1': 0.6965389319668238, 'auc': 0.8854773109889922, 'prauc': 0.7829647952819372}
Test:      {'precision': 0.7612219451324114, 'recall': 0.6760797342155256, 'f1': 0.7161290272714099, 'auc': 0.8870165994779518, 'prauc': 0.7949989873825286}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 39.07it/s, loss=0.1836]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.57it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.46it/s]


Validation: {'precision': 0.7538917716771395, 'recall': 0.5992928697666512, 'f1': 0.6677609930910824, 'auc': 0.8783686440705031, 'prauc': 0.768286202317281}
Test:      {'precision': 0.7956307258576771, 'recall': 0.6251384274605475, 'f1': 0.700155033827348, 'auc': 0.8809472106958328, 'prauc': 0.7820796458134545}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 39.35it/s, loss=0.1636]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.64it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.61it/s]


Validation: {'precision': 0.7346657477551023, 'recall': 0.6281673541506885, 'f1': 0.67725539528036, 'auc': 0.8808545729377188, 'prauc': 0.7699237284167865}
Test:      {'precision': 0.7749343831970149, 'recall': 0.6539313399742308, 'f1': 0.7093093043409069, 'auc': 0.8834674645603534, 'prauc': 0.7909397086775963}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 38.94it/s, loss=0.1519]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.70it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.26it/s]


Validation: {'precision': 0.6533110740086071, 'recall': 0.6918090748338727, 'f1': 0.6720091535577762, 'auc': 0.8686758783336426, 'prauc': 0.7534936181795037}
Test:      {'precision': 0.6856840993096373, 'recall': 0.7187153931300182, 'f1': 0.7018112953504178, 'auc': 0.8699162836983781, 'prauc': 0.7651372878554726}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 39.37it/s, loss=0.1448]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.50it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 69.44it/s]


Validation: {'precision': 0.6695652173874228, 'recall': 0.6806128461951644, 'f1': 0.6750438290115853, 'auc': 0.8701919877224329, 'prauc': 0.7579580581757852}
Test:      {'precision': 0.6937984496085615, 'recall': 0.6937984496085615, 'f1': 0.6937984446085615, 'auc': 0.8762996424644497, 'prauc': 0.7797466763174398}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 39.06it/s, loss=0.1375]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 69.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 62.42it/s]

Validation: {'precision': 0.5881548974916257, 'recall': 0.7607542722406556, 'f1': 0.6634121225193575, 'auc': 0.8622990470041814, 'prauc': 0.7505121227553713}
Test:      {'precision': 0.6036846615226921, 'recall': 0.7801771871496115, 'f1': 0.6806763235804547, 'auc': 0.8674740145367918, 'prauc': 0.7734520232562784}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7322936972012197, 'recall': 0.6641131408328573, 'f1': 0.6965389319668238, 'auc': 0.8854773109889922, 'prauc': 0.7829647952819372}
Corresponding test performance:
{'precision': 0.7612219451324114, 'recall': 0.6760797342155256, 'f1': 0.7161290272714099, 'auc': 0.8870165994779518, 'prauc': 0.7949989873825286}





In [32]:
# Summarise every collected metric as "mean ± std" over its recorded values.
print("\nFinal Metrics:")
for key, values in final_metrics.items():
    # np.std is called with its defaults, i.e. the population standard deviation (ddof=0).
    mean_value, std_value = np.mean(values), np.std(values)
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7275 ± 0.0541
recall: 0.6977 ± 0.0451
f1: 0.7089 ± 0.0067
auc: 0.8896 ± 0.0061
prauc: 0.7977 ± 0.0102


In [33]:
# Repeat training 5 times with use_hetero_graph=False and collect, per metric name,
# the value found in the dict returned by train_with_early_stopping for each run.
final_metrics = {name: [] for name in ("precision", "recall", "f1", "auc", "prauc")}
for i in range(5):
    # A fresh model and optimizer are created for every repetition.
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=False,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # NOTE(review): the returned dict is presumably the test metrics at the epoch
    # with the best validation score (the logged "Corresponding test performance")
    # -- confirm against heterogt.utils.train.
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None,
        eval_metric=eval_metric, return_model=False,
    )
    for key, run_values in final_metrics.items():
        run_values.append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 81.72it/s, loss=0.5307]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.78it/s]


Validation: {'precision': 0.825688073369245, 'recall': 0.1591043017079605, 'f1': 0.2667984162605112, 'auc': 0.8247873953104016, 'prauc': 0.6534444119908926}
Test:      {'precision': 0.815340909067746, 'recall': 0.1589147286812906, 'f1': 0.26598702229054805, 'auc': 0.830052867104218, 'prauc': 0.6700850619612017}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.29it/s, loss=0.4665]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.70it/s]


Validation: {'precision': 0.6424452133757644, 'recall': 0.6564525633432148, 'f1': 0.6493733555330831, 'auc': 0.849634420722559, 'prauc': 0.698726053965898}
Test:      {'precision': 0.6787983824339758, 'recall': 0.6506090808380365, 'f1': 0.6644048578766366, 'auc': 0.8538049025849805, 'prauc': 0.7137733147772144}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 81.55it/s, loss=0.4289]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.87it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.14it/s]


Validation: {'precision': 0.7787418655013152, 'recall': 0.42309958750487275, 'f1': 0.5483008736314252, 'auc': 0.8585832792238276, 'prauc': 0.7175151618887028}
Test:      {'precision': 0.8018769551532651, 'recall': 0.4258028792888937, 'f1': 0.5562386934760142, 'auc': 0.861994158102506, 'prauc': 0.7350741814909747}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 82.92it/s, loss=0.4008]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.41it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.43it/s]


Validation: {'precision': 0.8232118758323453, 'recall': 0.35945786682168857, 'f1': 0.5004101680370583, 'auc': 0.8519756175733026, 'prauc': 0.7122467755817484}
Test:      {'precision': 0.8375499334109514, 'recall': 0.3482834994443617, 'f1': 0.4919827881820838, 'auc': 0.8588250985918683, 'prauc': 0.7377442878362301}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.50it/s, loss=0.3696]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.20it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.05it/s]


Validation: {'precision': 0.7056921086629645, 'recall': 0.6428992339384626, 'f1': 0.6728337908747141, 'auc': 0.8723036955461756, 'prauc': 0.7497247598514636}
Test:      {'precision': 0.7201001878477139, 'recall': 0.6367663344372273, 'f1': 0.675874223636692, 'auc': 0.8731675749084002, 'prauc': 0.7582482365494772}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 82.42it/s, loss=0.3453]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.80it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.01it/s]


Validation: {'precision': 0.6582694413983666, 'recall': 0.7083087801961797, 'f1': 0.6823729725787598, 'auc': 0.8737866558102751, 'prauc': 0.7546315104610981}
Test:      {'precision': 0.6726504751812427, 'recall': 0.7054263565852413, 'f1': 0.6886486436477547, 'auc': 0.8705945721169542, 'prauc': 0.7577154761005642}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 84.00it/s, loss=0.3170]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.59it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.35it/s]


Validation: {'precision': 0.5701244813254351, 'recall': 0.8096641131360657, 'f1': 0.6691015291138357, 'auc': 0.872866336469875, 'prauc': 0.7516179498443241}
Test:      {'precision': 0.5907414993830777, 'recall': 0.7984496123986797, 'f1': 0.6790675722218358, 'auc': 0.8661263625472524, 'prauc': 0.7527447350296275}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.45it/s, loss=0.2871]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.11it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.77it/s]


Validation: {'precision': 0.5896988214727643, 'recall': 0.7961107837313135, 'f1': 0.6775325929009082, 'auc': 0.874061940448843, 'prauc': 0.7451399270672913}
Test:      {'precision': 0.6110386788326335, 'recall': 0.7785160575815143, 'f1': 0.6846846797539831, 'auc': 0.868764609052282, 'prauc': 0.7495963765844187}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.46it/s, loss=0.2791]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.78it/s]


Validation: {'precision': 0.7918552036127434, 'recall': 0.5156157925720942, 'f1': 0.6245538852972488, 'auc': 0.8598276168870711, 'prauc': 0.7419252838750059}
Test:      {'precision': 0.7862318840508494, 'recall': 0.48062015503609845, 'f1': 0.5965635691700382, 'auc': 0.8499017281951815, 'prauc': 0.7384679797298405}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 82.43it/s, loss=0.2349]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.83it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.15it/s]


Validation: {'precision': 0.6093224411311422, 'recall': 0.7472009428359034, 'f1': 0.6712546271285671, 'auc': 0.8693541260313115, 'prauc': 0.743849594424034}
Test:      {'precision': 0.637922820387623, 'recall': 0.7414174972273455, 'f1': 0.6857874470092719, 'auc': 0.8625547000832658, 'prauc': 0.7342952108011472}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 83.11it/s, loss=0.2270]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.67it/s]


Validation: {'precision': 0.658135283359792, 'recall': 0.636417206831842, 'f1': 0.6470940633019041, 'auc': 0.8567772587040723, 'prauc': 0.7349540618434868}
Test:      {'precision': 0.6951444376109702, 'recall': 0.6262458471726122, 'f1': 0.6588989172352137, 'auc': 0.8536956501400648, 'prauc': 0.7372942203562579}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6582694413983666, 'recall': 0.7083087801961797, 'f1': 0.6823729725787598, 'auc': 0.8737866558102751, 'prauc': 0.7546315104610981}
Corresponding test performance:
{'precision': 0.6726504751812427, 'recall': 0.7054263565852413, 'f1': 0.6886486436477547, 'auc': 0.8705945721169542, 'prauc': 0.7577154761005642}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 82.33it/s, loss=0.5622]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.94it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.14it/s]


Validation: {'precision': 0.6311030741353427, 'recall': 0.41131408367465344, 'f1': 0.4980378118437703, 'auc': 0.7764163714010046, 'prauc': 0.5815283559016233}
Test:      {'precision': 0.6714542190244933, 'recall': 0.4141749723122139, 'f1': 0.5123287624005911, 'auc': 0.7802299360808569, 'prauc': 0.6004967053431717}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.02it/s, loss=0.4709]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.12it/s]


Validation: {'precision': 0.6945510360652759, 'recall': 0.5332940483174232, 'f1': 0.6033333284155534, 'auc': 0.832741652328659, 'prauc': 0.684460711461778}
Test:      {'precision': 0.7172100075760637, 'recall': 0.5238095238066235, 'f1': 0.6054399951175559, 'auc': 0.8379278448844251, 'prauc': 0.6952856727160954}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 83.25it/s, loss=0.4220]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 328.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.63it/s]


Validation: {'precision': 0.8244746600639743, 'recall': 0.3930465527378135, 'f1': 0.5323224218007445, 'auc': 0.850330041371896, 'prauc': 0.7149708404597896}
Test:      {'precision': 0.8009478672890883, 'recall': 0.37430786267788313, 'f1': 0.510188674900346, 'auc': 0.8542019659777422, 'prauc': 0.723981222796491}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 81.94it/s, loss=0.3855]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.65it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.09it/s]


Validation: {'precision': 0.5761861313842328, 'recall': 0.7442545668783486, 'f1': 0.6495242943833975, 'auc': 0.8535847872215808, 'prauc': 0.7168413237646986}
Test:      {'precision': 0.6060473269035669, 'recall': 0.7657807308927698, 'f1': 0.6766144764734813, 'auc': 0.8617539258250375, 'prauc': 0.7289516258670293}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.53it/s, loss=0.3594]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.72it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.73it/s]


Validation: {'precision': 0.7212078651634747, 'recall': 0.6051856216817608, 'f1': 0.658122391701774, 'auc': 0.8679913712636815, 'prauc': 0.7512393971371341}
Test:      {'precision': 0.739541160588802, 'recall': 0.6068660022114791, 'f1': 0.6666666617111624, 'auc': 0.8728805641193058, 'prauc': 0.7554527613745436}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 80.83it/s, loss=0.3275]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.85it/s]


Validation: {'precision': 0.6204556471128432, 'recall': 0.754272245134035, 'f1': 0.6808510588735416, 'auc': 0.8766873638985722, 'prauc': 0.7612247066904945}
Test:      {'precision': 0.6416157820542903, 'recall': 0.7563676633402195, 'f1': 0.6942820788929303, 'auc': 0.8783852870255613, 'prauc': 0.7652687840579147}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 82.20it/s, loss=0.3101]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.73it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.31it/s]


Validation: {'precision': 0.8266666666587937, 'recall': 0.5114908662315175, 'f1': 0.6319621357896977, 'auc': 0.8771695271785138, 'prauc': 0.7645705854925606}
Test:      {'precision': 0.8250676284866992, 'recall': 0.50664451826962, 'f1': 0.6277873023141463, 'auc': 0.8840725692285041, 'prauc': 0.7708153821973146}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.93it/s, loss=0.2740]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.44it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.28it/s]


Validation: {'precision': 0.7352941176419098, 'recall': 0.6187389510865131, 'f1': 0.6719999950327482, 'auc': 0.8739437788290401, 'prauc': 0.7642108959120973}
Test:      {'precision': 0.7469879518022291, 'recall': 0.6179401993321266, 'f1': 0.6763636314042316, 'auc': 0.8770128916653988, 'prauc': 0.7675860976532232}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 83.40it/s, loss=0.2673]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.30it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.31it/s]


Validation: {'precision': 0.666850523989703, 'recall': 0.7124337065367565, 'f1': 0.6888888838904248, 'auc': 0.8802757087429762, 'prauc': 0.7729511317066206}
Test:      {'precision': 0.6837333333296868, 'recall': 0.7098560354335003, 'f1': 0.6965498455820528, 'auc': 0.8823103119043434, 'prauc': 0.7751030027699225}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 82.18it/s, loss=0.2170]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.49it/s]


Validation: {'precision': 0.6231751824789089, 'recall': 0.804949911603978, 'f1': 0.7024942095284068, 'auc': 0.892571479157376, 'prauc': 0.7859493867940827}
Test:      {'precision': 0.6333771353454517, 'recall': 0.8006644518228092, 'f1': 0.7072633846463869, 'auc': 0.8905350206096274, 'prauc': 0.7852011825320651}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 83.37it/s, loss=0.2011]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.66it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.84it/s]


Validation: {'precision': 0.7583454281512457, 'recall': 0.6157925751289582, 'f1': 0.6796747917973567, 'auc': 0.8875547200074193, 'prauc': 0.7815219692105135}
Test:      {'precision': 0.7684495599135515, 'recall': 0.6284606865967417, 'f1': 0.6914407504526415, 'auc': 0.8893149631656152, 'prauc': 0.7842706070759117}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 82.27it/s, loss=0.1921]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.42it/s]


Validation: {'precision': 0.6260280599872955, 'recall': 0.7625220978151885, 'f1': 0.6875664137481678, 'auc': 0.882621695929658, 'prauc': 0.7674202562477099}
Test:      {'precision': 0.649436090222512, 'recall': 0.7652270210367375, 'f1': 0.7025927759145216, 'auc': 0.8833325454847165, 'prauc': 0.7674104559312254}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 81.06it/s, loss=0.1773]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.13it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.38it/s]


Validation: {'precision': 0.7830731306427028, 'recall': 0.5615792575099495, 'f1': 0.6540837288305599, 'auc': 0.8738032623081934, 'prauc': 0.7552044283813798}
Test:      {'precision': 0.7753236862088704, 'recall': 0.5636766334409542, 'f1': 0.6527733199043191, 'auc': 0.8721287841969385, 'prauc': 0.7540941215942637}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 81.89it/s, loss=0.1328]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.19it/s]


Validation: {'precision': 0.6503099173520129, 'recall': 0.7418974661123047, 'f1': 0.6930911042939039, 'auc': 0.881052573489821, 'prauc': 0.7745464154053813}
Test:      {'precision': 0.6676572560640532, 'recall': 0.7464008859316368, 'f1': 0.704836596319009, 'auc': 0.8815670259748765, 'prauc': 0.7798965115343904}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 81.90it/s, loss=0.1457]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.14it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.64it/s]


Validation: {'precision': 0.694565896310465, 'recall': 0.6552740129601929, 'f1': 0.6743480847515116, 'auc': 0.870566591992781, 'prauc': 0.7568968529362148}
Test:      {'precision': 0.7028604786882495, 'recall': 0.6666666666629754, 'f1': 0.6842853033258323, 'auc': 0.8711826887993039, 'prauc': 0.7551902375785011}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6231751824789089, 'recall': 0.804949911603978, 'f1': 0.7024942095284068, 'auc': 0.892571479157376, 'prauc': 0.7859493867940827}
Corresponding test performance:
{'precision': 0.6333771353454517, 'recall': 0.8006644518228092, 'f1': 0.7072633846463869, 'auc': 0.8905350206096274, 'prauc': 0.7852011825320651}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 81.36it/s, loss=0.5356]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.31it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.89it/s]


Validation: {'precision': 0.6333333333298149, 'recall': 0.6717737183224999, 'f1': 0.6519874127872831, 'auc': 0.837242460362525, 'prauc': 0.684711554662137}
Test:      {'precision': 0.6516976998869021, 'recall': 0.6589147286785221, 'f1': 0.6552863386088781, 'auc': 0.8349351280524826, 'prauc': 0.6936082630946004}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.46it/s, loss=0.4489]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.60it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.08it/s]


Validation: {'precision': 0.7443729903477141, 'recall': 0.5456688273391535, 'f1': 0.6297177781813268, 'auc': 0.8644172696847997, 'prauc': 0.7250027225393844}
Test:      {'precision': 0.7591537835576961, 'recall': 0.5166112956782026, 'f1': 0.6148270132985787, 'auc': 0.8611484518528476, 'prauc': 0.7314503501258912}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 81.53it/s, loss=0.4096]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.38it/s]


Validation: {'precision': 0.71270718231552, 'recall': 0.6081319976393157, 'f1': 0.6562798042481542, 'auc': 0.864363809535732, 'prauc': 0.7295254389063469}
Test:      {'precision': 0.7184931506800104, 'recall': 0.5808416389779577, 'f1': 0.6423759901532238, 'auc': 0.8619307609091182, 'prauc': 0.7330527003258654}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 81.63it/s, loss=0.3852]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.23it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.46it/s]


Validation: {'precision': 0.728412256262337, 'recall': 0.6163818503204692, 'f1': 0.6677306046697609, 'auc': 0.8744783803197159, 'prauc': 0.7472894924969021}
Test:      {'precision': 0.7533333333283112, 'recall': 0.6256921373165798, 'f1': 0.6836055606769339, 'auc': 0.8743753837684474, 'prauc': 0.7559776067261019}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.05it/s, loss=0.3465]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.72it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.09it/s]


Validation: {'precision': 0.7222222222172687, 'recall': 0.6205067766610459, 'f1': 0.6675118809198651, 'auc': 0.8767521931115991, 'prauc': 0.7459307848279431}
Test:      {'precision': 0.7493438320160805, 'recall': 0.6323366555889682, 'f1': 0.685885880917624, 'auc': 0.8746925543874058, 'prauc': 0.7505394684135436}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 81.85it/s, loss=0.3357]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.77it/s]


Validation: {'precision': 0.7609178386324137, 'recall': 0.6057748968732718, 'f1': 0.6745406774747027, 'auc': 0.8921234869404584, 'prauc': 0.7780205628506933}
Test:      {'precision': 0.771731448757797, 'recall': 0.6046511627873496, 'f1': 0.6780502900089285, 'auc': 0.8863915523916869, 'prauc': 0.7772977966867874}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 82.13it/s, loss=0.3026]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.74it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.26it/s]


Validation: {'precision': 0.7862481315337352, 'recall': 0.619917501469535, 'f1': 0.6932454645876309, 'auc': 0.8988035784703652, 'prauc': 0.7963564516842186}
Test:      {'precision': 0.8057347670193138, 'recall': 0.6223698781803856, 'f1': 0.7022805324101254, 'auc': 0.8968290693488973, 'prauc': 0.8078860073043493}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 82.47it/s, loss=0.2721]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.32it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.74it/s]


Validation: {'precision': 0.7639366827201381, 'recall': 0.654095462577171, 'f1': 0.7047618997874309, 'auc': 0.897161834537709, 'prauc': 0.7937010314827213}
Test:      {'precision': 0.7810218978050364, 'recall': 0.6517165005501013, 'f1': 0.7105342540162132, 'auc': 0.8946318949679617, 'prauc': 0.8047983635650586}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.03it/s, loss=0.2474]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.59it/s]


Validation: {'precision': 0.7559870550112882, 'recall': 0.6882734236848069, 'f1': 0.7205428697752072, 'auc': 0.9022073995744393, 'prauc': 0.803493149787453}
Test:      {'precision': 0.7644171779094208, 'recall': 0.6899224806163349, 'f1': 0.7252619274885246, 'auc': 0.9004333536696757, 'prauc': 0.815501304155296}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 82.03it/s, loss=0.2167]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.06it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.48it/s]


Validation: {'precision': 0.6219458018630745, 'recall': 0.8249852681153508, 'f1': 0.7092198532508902, 'auc': 0.9017798461241902, 'prauc': 0.7973427264043931}
Test:      {'precision': 0.6418485237456489, 'recall': 0.8305647840485573, 'f1': 0.7241129567006528, 'auc': 0.9039133056306062, 'prauc': 0.8183866314080296}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 80.55it/s, loss=0.2039]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 315.53it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.88it/s]


Validation: {'precision': 0.7578176979324165, 'recall': 0.6711844431309889, 'f1': 0.7118749950139278, 'auc': 0.9025026758816548, 'prauc': 0.80038300984218}
Test:      {'precision': 0.773248407638387, 'recall': 0.672203765223299, 'f1': 0.7191943078163818, 'auc': 0.898554765573182, 'prauc': 0.8093209899118187}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 78.73it/s, loss=0.1678]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.59it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.71it/s]


Validation: {'precision': 0.7494766224651118, 'recall': 0.6328815556827763, 'f1': 0.6862619758618564, 'auc': 0.8906216846883362, 'prauc': 0.785789168389962}
Test:      {'precision': 0.749006622511596, 'recall': 0.6262458471726122, 'f1': 0.682147160295075, 'auc': 0.8830632612895014, 'prauc': 0.7907883749788389}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 80.65it/s, loss=0.1594]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.18it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 324.98it/s]


Validation: {'precision': 0.6400759734062674, 'recall': 0.7943429581567806, 'f1': 0.7089140103052209, 'auc': 0.8982800905590653, 'prauc': 0.7944049792503285}
Test:      {'precision': 0.6496350364933867, 'recall': 0.7884828349900971, 'f1': 0.712356173132089, 'auc': 0.8930321314210278, 'prauc': 0.7986673108051304}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 82.22it/s, loss=0.1444]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.06it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.82it/s]


Validation: {'precision': 0.7317073170685947, 'recall': 0.6894519740678289, 'f1': 0.709951451310792, 'auc': 0.8982468775632289, 'prauc': 0.7965917722814725}
Test:      {'precision': 0.7509157509111666, 'recall': 0.681063122919817, 'f1': 0.714285709293464, 'auc': 0.8954478722424682, 'prauc': 0.8045002864285016}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7559870550112882, 'recall': 0.6882734236848069, 'f1': 0.7205428697752072, 'auc': 0.9022073995744393, 'prauc': 0.803493149787453}
Corresponding test performance:
{'precision': 0.7644171779094208, 'recall': 0.6899224806163349, 'f1': 0.7252619274885246, 'auc': 0.9004333536696757, 'prauc': 0.815501304155296}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 80.01it/s, loss=0.5612]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.23it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.35it/s]


Validation: {'precision': 0.5813382443181011, 'recall': 0.5580436063608837, 'f1': 0.5694527911501992, 'auc': 0.7828086595477591, 'prauc': 0.5932131546633252}
Test:      {'precision': 0.6135656502762068, 'recall': 0.5459579180479183, 'f1': 0.5777907948964134, 'auc': 0.7851554671523754, 'prauc': 0.6071685766306697}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.65it/s, loss=0.5022]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.09it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.19it/s]


Validation: {'precision': 0.6822001527831765, 'recall': 0.5262227460192916, 'f1': 0.5941450383261887, 'auc': 0.821640336212602, 'prauc': 0.6683691903379304}
Test:      {'precision': 0.7150496562206643, 'recall': 0.5182724252462998, 'f1': 0.6009630769853819, 'auc': 0.8246580121496105, 'prauc': 0.6830999131704272}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 80.41it/s, loss=0.4397]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 328.13it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.02it/s]


Validation: {'precision': 0.8007928642140655, 'recall': 0.4761343547408596, 'f1': 0.5971914217785204, 'auc': 0.862744675988703, 'prauc': 0.7399906941826678}
Test:      {'precision': 0.8062088428898758, 'recall': 0.4745293466197424, 'f1': 0.5974206994092524, 'auc': 0.8645521425296933, 'prauc': 0.7450722699190726}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 82.30it/s, loss=0.3847]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.35it/s]


Validation: {'precision': 0.6818960593907374, 'recall': 0.703594578664092, 'f1': 0.692575401029692, 'auc': 0.8797610350498022, 'prauc': 0.7607515189021464}
Test:      {'precision': 0.6985619468987911, 'recall': 0.6993355481688852, 'f1': 0.6989485284770411, 'auc': 0.8810628644107608, 'prauc': 0.7718598106629864}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 83.84it/s, loss=0.3449]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.11it/s]


Validation: {'precision': 0.6984126984084347, 'recall': 0.6741308190885438, 'f1': 0.6860569665116936, 'auc': 0.8791103796871362, 'prauc': 0.7627592075582246}
Test:      {'precision': 0.714534883716776, 'recall': 0.6805094130637845, 'f1': 0.6971071986291962, 'auc': 0.8774098319568141, 'prauc': 0.7707665335968834}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 80.80it/s, loss=0.3323]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.56it/s]


Validation: {'precision': 0.6806771745436038, 'recall': 0.687094873301785, 'f1': 0.6838709627380346, 'auc': 0.8814513849244423, 'prauc': 0.7677305919105659}
Test:      {'precision': 0.6972425436089069, 'recall': 0.6860465116241083, 'f1': 0.6915992135284237, 'auc': 0.8782551073517602, 'prauc': 0.769350462367904}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 81.68it/s, loss=0.2906]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.75it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.42it/s]


Validation: {'precision': 0.7034005037739081, 'recall': 0.6582203889177477, 'f1': 0.6800608778019734, 'auc': 0.8787708406758539, 'prauc': 0.7666108818201387}
Test:      {'precision': 0.7231812577020766, 'recall': 0.6495016611259717, 'f1': 0.6843640556871922, 'auc': 0.8780353714484647, 'prauc': 0.7702156273967382}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 82.33it/s, loss=0.2739]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 322.57it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.80it/s]


Validation: {'precision': 0.7634660421486069, 'recall': 0.5763111372977237, 'f1': 0.6568166505666292, 'auc': 0.8782644063604932, 'prauc': 0.7672178153036927}
Test:      {'precision': 0.7715355805185653, 'recall': 0.5703211517133426, 'f1': 0.6558420836150973, 'auc': 0.8748197796288937, 'prauc': 0.7677763392279631}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 81.99it/s, loss=0.2405]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 324.43it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.09it/s]


Validation: {'precision': 0.7160256410210512, 'recall': 0.6582203889177477, 'f1': 0.6859072716395749, 'auc': 0.8841213904339673, 'prauc': 0.7800255480258407}
Test:      {'precision': 0.7310819261993053, 'recall': 0.6472868217018423, 'f1': 0.6866372931054884, 'auc': 0.8804389867872863, 'prauc': 0.7805275386200679}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 81.62it/s, loss=0.2153]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.85it/s]


Validation: {'precision': 0.6706014614914525, 'recall': 0.7030053034725811, 'f1': 0.686421168761779, 'auc': 0.8849499269377963, 'prauc': 0.7781594293883471}
Test:      {'precision': 0.6854111405799183, 'recall': 0.715393133993824, 'f1': 0.700081273784734, 'auc': 0.8830224531931752, 'prauc': 0.7883112458547382}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 80.83it/s, loss=0.2110]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.62it/s]


Validation: {'precision': 0.6782460136635635, 'recall': 0.7018267530895591, 'f1': 0.6898349211486372, 'auc': 0.8856510405056756, 'prauc': 0.7757297258205751}
Test:      {'precision': 0.6857601713025387, 'recall': 0.7093023255774679, 'f1': 0.697332602509876, 'auc': 0.882432982396015, 'prauc': 0.7843510259430402}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 80.55it/s, loss=0.1890]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 322.12it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.40it/s]


Validation: {'precision': 0.6938650306705898, 'recall': 0.6664702415989012, 'f1': 0.6798917894073185, 'auc': 0.8785144618964543, 'prauc': 0.7667456120835716}
Test:      {'precision': 0.7012020606714299, 'recall': 0.6782945736396551, 'f1': 0.6895581148961744, 'auc': 0.8757908645998442, 'prauc': 0.7730323023075619}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 80.67it/s, loss=0.1722]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.62it/s]


Validation: {'precision': 0.6531137416332831, 'recall': 0.7477902180274144, 'f1': 0.6972527422717533, 'auc': 0.8852135231566756, 'prauc': 0.7727194656922268}
Test:      {'precision': 0.666009852213468, 'recall': 0.7486157253557663, 'f1': 0.7049009334909551, 'auc': 0.8789830056129289, 'prauc': 0.7765585999861522}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 81.60it/s, loss=0.1808]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.52it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.68it/s]


Validation: {'precision': 0.736571008088767, 'recall': 0.5898644667024758, 'f1': 0.6551047070987618, 'auc': 0.8638931430620416, 'prauc': 0.7452053398631776}
Test:      {'precision': 0.7279829545402843, 'recall': 0.5675526024331808, 'f1': 0.6378344692481867, 'auc': 0.855124733670237, 'prauc': 0.7409137190073336}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7160256410210512, 'recall': 0.6582203889177477, 'f1': 0.6859072716395749, 'auc': 0.8841213904339673, 'prauc': 0.7800255480258407}
Corresponding test performance:
{'precision': 0.7310819261993053, 'recall': 0.6472868217018423, 'f1': 0.6866372931054884, 'auc': 0.8804389867872863, 'prauc': 0.7805275386200679}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 81.36it/s, loss=0.5249]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.57it/s]


Validation: {'precision': 0.75297225890683, 'recall': 0.33588685916124994, 'f1': 0.4645476729914569, 'auc': 0.7986023332640545, 'prauc': 0.6408423980997614}
Test:      {'precision': 0.7691275167681996, 'recall': 0.31727574750654886, 'f1': 0.44923558974615685, 'auc': 0.7981464137869568, 'prauc': 0.6431435799386234}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 80.51it/s, loss=0.4629]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.86it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.82it/s]


Validation: {'precision': 0.6449438202210959, 'recall': 0.6764879198545877, 'f1': 0.6603393680217718, 'auc': 0.8583479140514094, 'prauc': 0.7296918358145127}
Test:      {'precision': 0.6774734488503216, 'recall': 0.6710963455112343, 'f1': 0.6742698141896848, 'auc': 0.8611663630987464, 'prauc': 0.7414714084771248}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 81.62it/s, loss=0.4170]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.56it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.61it/s]


Validation: {'precision': 0.5510370069111385, 'recall': 0.7984678844973574, 'f1': 0.6520692925662945, 'auc': 0.861897105883414, 'prauc': 0.7456672068578514}
Test:      {'precision': 0.5765367617493912, 'recall': 0.7945736434064531, 'f1': 0.6682188542618618, 'auc': 0.8644657253845317, 'prauc': 0.7540594438377235}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 81.30it/s, loss=0.3970]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.95it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.67it/s]


Validation: {'precision': 0.8664383561544927, 'recall': 0.4472598703568223, 'f1': 0.5899727899079034, 'auc': 0.8858721624125716, 'prauc': 0.7758730393568843}
Test:      {'precision': 0.8510392609601497, 'recall': 0.40808416389585783, 'f1': 0.5516467022015009, 'auc': 0.8836267577025743, 'prauc': 0.7725891185117639}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 83.07it/s, loss=0.3708]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.57it/s]


Validation: {'precision': 0.6144796380062694, 'recall': 0.8002357100718902, 'f1': 0.6951625238771152, 'auc': 0.8866611626771259, 'prauc': 0.7736155196602483}
Test:      {'precision': 0.6412655971450925, 'recall': 0.7967884828305826, 'f1': 0.7106172790055882, 'auc': 0.8861489196379639, 'prauc': 0.7747224483880978}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 82.51it/s, loss=0.3264]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.93it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 326.91it/s]


Validation: {'precision': 0.7232742822191615, 'recall': 0.6977018267489824, 'f1': 0.7102579434076768, 'auc': 0.8976704404720282, 'prauc': 0.791078087961009}
Test:      {'precision': 0.7250287026364809, 'recall': 0.6993355481688852, 'f1': 0.7119503895861145, 'auc': 0.8916741389183922, 'prauc': 0.7873219141086721}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 81.87it/s, loss=0.2989]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.94it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.11it/s]


Validation: {'precision': 0.7684996605514699, 'recall': 0.6670595167904122, 'f1': 0.7141955786166746, 'auc': 0.9088407374255917, 'prauc': 0.8081765948617871}
Test:      {'precision': 0.7600510529626034, 'recall': 0.6594684385345545, 'f1': 0.7061962594739255, 'auc': 0.8996668015857423, 'prauc': 0.7971162991040833}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.56it/s, loss=0.2732]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.91it/s]


Validation: {'precision': 0.8622009569295485, 'recall': 0.5309369475513793, 'f1': 0.6571845321123371, 'auc': 0.910956596873814, 'prauc': 0.8176185970431763}
Test:      {'precision': 0.8555023923363111, 'recall': 0.4950166112929401, 'f1': 0.6271483643451781, 'auc': 0.9066511411248608, 'prauc': 0.8187423209599084}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.25it/s, loss=0.2640]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.87it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 327.07it/s]


Validation: {'precision': 0.7873399715448981, 'recall': 0.6523276370026381, 'f1': 0.7135030565927105, 'auc': 0.9106593405610773, 'prauc': 0.8160521293065738}
Test:      {'precision': 0.7822045152670505, 'recall': 0.6522702104061336, 'f1': 0.711352652041559, 'auc': 0.9060615472263542, 'prauc': 0.816782379803323}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 83.16it/s, loss=0.2330]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.22it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.89it/s]


Validation: {'precision': 0.8089461713357927, 'recall': 0.6287566293421994, 'f1': 0.7075596767714608, 'auc': 0.9092834922085505, 'prauc': 0.8167639309618396}
Test:      {'precision': 0.8042059463320943, 'recall': 0.6140642303399, 'f1': 0.6963893200462491, 'auc': 0.9034945148502078, 'prauc': 0.8166213386784587}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 81.19it/s, loss=0.2128]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.19it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.82it/s]


Validation: {'precision': 0.7056277056238873, 'recall': 0.7684148497302982, 'f1': 0.7356840570641597, 'auc': 0.9098955055279838, 'prauc': 0.810509668679462}
Test:      {'precision': 0.704591836731099, 'recall': 0.764673311180705, 'f1': 0.7334041373305413, 'auc': 0.9040200344979212, 'prauc': 0.814324726490316}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 81.06it/s, loss=0.2001]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 327.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.27it/s]


Validation: {'precision': 0.8037313432775841, 'recall': 0.6346493812573091, 'f1': 0.7092525469248081, 'auc': 0.9125310206187555, 'prauc': 0.8206303116089178}
Test:      {'precision': 0.8016997167082033, 'recall': 0.6267995570286445, 'f1': 0.7035425680973054, 'auc': 0.9054422243526963, 'prauc': 0.8219136048179316}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 82.26it/s, loss=0.1806]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.82it/s]


Validation: {'precision': 0.8165829145660254, 'recall': 0.5745433117231907, 'f1': 0.6745070861186754, 'auc': 0.8982737034444814, 'prauc': 0.7977837177717894}
Test:      {'precision': 0.8053745928273178, 'recall': 0.5476190476160154, 'f1': 0.651944622731552, 'auc': 0.8912861850252629, 'prauc': 0.799605043319187}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 81.48it/s, loss=0.1753]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 325.01it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.21it/s]


Validation: {'precision': 0.6753112033159995, 'recall': 0.7672362993472762, 'f1': 0.7183448226025476, 'auc': 0.9030564387160827, 'prauc': 0.8009861152159711}
Test:      {'precision': 0.6833333333298822, 'recall': 0.7491694352117987, 'f1': 0.7147385053078948, 'auc': 0.8950961101451759, 'prauc': 0.7999603920217099}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 80.99it/s, loss=0.1673]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 328.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 328.89it/s]


Validation: {'precision': 0.819023569016675, 'recall': 0.5733647613401688, 'f1': 0.6745233920313775, 'auc': 0.8976993102299475, 'prauc': 0.799144322498208}
Test:      {'precision': 0.8313253011981421, 'recall': 0.5730897009935045, 'f1': 0.6784660718607664, 'auc': 0.8929484840561599, 'prauc': 0.8056489701480083}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 81.35it/s, loss=0.1435]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.28it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 323.73it/s]


Validation: {'precision': 0.6970198675458222, 'recall': 0.7442545668783486, 'f1': 0.7198632038926894, 'auc': 0.8952285188242958, 'prauc': 0.7889977236877954}
Test:      {'precision': 0.6935312831352412, 'recall': 0.7242524916903419, 'f1': 0.7085590415857249, 'auc': 0.8853061678206187, 'prauc': 0.7892736951342193}


Epoch 017: 100%|██████████| 98/98 [00:01<00:00, 81.24it/s, loss=0.1305]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 326.11it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 325.95it/s]

Validation: {'precision': 0.8038834951378264, 'recall': 0.4879198585710789, 'f1': 0.6072607213672778, 'auc': 0.8675190441401992, 'prauc': 0.7473270407589773}
Test:      {'precision': 0.8159111933319528, 'recall': 0.4883720930205517, 'f1': 0.6110148896650887, 'auc': 0.8655265512369962, 'prauc': 0.7617726055287253}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8037313432775841, 'recall': 0.6346493812573091, 'f1': 0.7092525469248081, 'auc': 0.9125310206187555, 'prauc': 0.8206303116089178}
Corresponding test performance:
{'precision': 0.8016997167082033, 'recall': 0.6267995570286445, 'f1': 0.7035425680973054, 'auc': 0.9054422243526963, 'prauc': 0.8219136048179316}





In [35]:
# Report every collected metric as "mean ± std", formatted to four decimals.
# NOTE(review): final_metrics is built in an earlier cell; presumably it maps
# each metric name to a list with one value per training run — confirm there.
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    # np.std is called with its defaults, i.e. the population standard
    # deviation (ddof=0), not the sample estimate.
    avg = np.mean(run_values)
    spread = np.std(run_values)
    print(f"{metric_name}: {avg:.4f} ± {spread:.4f}")


Final Metrics:
precision: 0.7206 ± 0.0609
recall: 0.6940 ± 0.0604
f1: 0.7023 ± 0.0140
auc: 0.8895 ± 0.0128
prauc: 0.7922 ± 0.0237
