In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
# Seed the run once, up front, before any data loading or model construction.
# NOTE(review): which RNGs (python / numpy / torch / CUDA) are covered depends on
# heterogt.utils.seed.set_random_seed — confirm there.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration. `task_index` selects the active entry of `tasks`;
# later cells attach derived fields (e.g. vocabulary sizes) to this namespace.
config = Namespace(**{
    "dataset": "MIMIC-III",
    "tasks": ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    "task_index": 2,
    "token_type": ["diag", "med", "lab", "pro"],
    "special_tokens": ["[PAD]"],
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
})

In [5]:
# The full (pre-split) table is only needed to build the tokenizer vocabularies.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The two next-diagnosis tasks each have a dedicated split file; every other
# task shares the generic downstream file.
task_specific_paths = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}
finetune_data_path = task_specific_paths.get(
    curr_task,
    f"./data_process/{config.dataset}-processed/mimic_downstream.pkl",
)

Current task: stay


In [6]:
# Load the full EHR table used to build the tokenizer vocabularies.
# SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle (the bare open() previously leaked it).
with open(full_data_path, 'rb') as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# One "sentence" (list of codes) per row, for each code family.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Age/gender vocabulary: every observed AGE value crossed with both genders, plus padding.
# NOTE(review): iteration order of a set is not guaranteed, so the token order (and thus the
# ids the tokenizer assigns) may not be stable across environments — consider sorting if
# checkpoints are to be shared between runs.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type

# Largest number of distinct admissions observed for any single patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
# Store a plain Python int rather than a numpy scalar so the config stays portable.
config.max_num_adms = int(max_admissions)
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from the vocabulary "sentences" prepared above.
# The task names from the config are passed as the task vocabulary.
task_sentences = config.tasks
# NOTE(review): arguments are positional; their order must match EHRTokenizer's signature — confirm in heterogt.utils.tokenizer.
tokenizer = EHRTokenizer(token_type_sentences, age_gender_sentences, task_sentences, diag_sentences, 
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
# Derived sizes are attached to the config for later cells.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
# Load the (train, val, test) splits for the selected task.
# SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle (the bare open() previously leaked it).
with open(finetune_data_path, 'rb') as finetune_data_file:
    train_data, val_data, test_data = pickle.load(finetune_data_file)

# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in the fine-tuning dataset, using identical settings for all three.
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split, tokenizer, token_type=config.token_type, task=curr_task)
    for split in (train_data, val_data, test_data)
)

In [10]:
def _build_loader(dataset, shuffle):
    """Create a DataLoader for `dataset` with the shared batch size and a fresh collate function."""
    return DataLoader(
        dataset,
        collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training loader is shuffled; validation and test keep dataset order.
train_dataloader = _build_loader(train_dataset, shuffle=True)
val_dataloader = _build_loader(val_dataset, shuffle=False)
test_dataloader = _build_loader(test_dataset, shuffle=False)

In [11]:
# Task family decides the task type and loss; PR-AUC is the selection metric either way.
is_binary_task = curr_task in ["death", "stay", "readmission"]
eval_metric = "prauc"
task_type = "binary" if is_binary_task else "l2r"
if is_binary_task:
    loss_fn = F.binary_cross_entropy_with_logits
else:
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [12]:
# Pull a single mini-batch to sanity-check what the collate function produces.
sample_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = sample_batch
print("Input IDs shape:", input_ids.size())
print("Token Types shape:", token_types.size())
print("Admission Index shape:", adm_index.size())
print("Age/Sex IDs shape:", age_gender_ids.size())
print("Task Index:", task_index)
print("Labels shape:", labels.size())

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
# class DiseaseOccHetGNN(nn.Module):
#     def __init__(self, d_model: int):
#         super().__init__()
#         self.conv1 = HeteroConv({
#             ('visit','contains','occ'): GATConv(d_model, d_model, add_self_loops=False),
#             ('occ','contained_by','visit'): GATConv(d_model, d_model, add_self_loops=False),
#             ('visit','next','visit'): GATConv(d_model, d_model, add_self_loops=False),
#         }, aggr='mean')
#         self.lin = nn.Linear(d_model, d_model)
    
#     def forward(self, hg: HeteroData):
#         # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
#         x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}
#         x_dict = self.conv1(x_dict, hg.edge_index_dict)
#         x_dict = {k: self.lin(v) for k, v in x_dict.items()}
#         return x_dict # {'visit': [N_visit, d], 'occ': [N_occ, d]}


class DiseaseOccHetGNN(nn.Module):
    """Two-layer heterogeneous GAT over 'visit' and 'occ' (diagnosis occurrence) nodes.

    Each layer runs one GATConv per relation
    ('visit'->'occ', 'occ'->'visit', 'visit'->'visit') and averages the messages
    arriving at a node type from different relations. A shared linear projection
    is applied to both node types after the second layer.

    Args:
        d_model: input and output feature size of every node type.
        heads: number of attention heads per GATConv; heads are averaged
            (``concat=False``) so the feature size stays ``d_model``.
        dropout: feature dropout probability applied between the two layers.
            The default 0.0 makes it a no-op.
    """

    def __init__(self, d_model: int, heads: int = 1, dropout: float = 0.0):
        super().__init__()
        # Kept for the optional inter-layer activation, which is currently disabled in forward().
        self.act = nn.GELU()
        # Previously the `dropout` argument was accepted but never used.
        self.dropout = nn.Dropout(dropout)

        # Layer 1
        self.conv1 = self._make_hetero_layer(d_model, heads)
        # Layer 2
        self.conv2 = self._make_hetero_layer(d_model, heads)
        self.lin = nn.Linear(d_model, d_model)

    @staticmethod
    def _make_hetero_layer(d_model: int, heads: int):
        """Build one HeteroConv layer with a GATConv for each of the three relations."""
        return HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')

    def forward(self, hg):
        """Run both layers on a (batched) heterogeneous graph.

        Args:
            hg: graph object exposing ``hg['visit'].x``, ``hg['occ'].x`` and
                ``hg.edge_index_dict``.

        Returns:
            dict with keys 'visit' and 'occ' mapping to the updated node features,
            [N_visit, d] and [N_occ, d] respectively.
        """
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # Layer 1: HeteroConv, then feature dropout (the activation is intentionally left disabled).
        x_dict = self.conv1(x_dict, hg.edge_index_dict)
        # x_dict = {k: self.act(v)   for k, v in x_dict.items()}
        x_dict = {k: self.dropout(v) for k, v in x_dict.items()}

        # Layer 2: HeteroConv followed by the shared linear projection (no activation on the last layer).
        x_dict = self.conv2(x_dict, hg.edge_index_dict)
        x_dict = {k: self.lin(v) for k, v in x_dict.items()}

        return x_dict  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# multi-class classification task
class MultiPredictionHead(nn.Module):
    """Two-layer MLP head: hidden_size -> hidden_size -> label_size logits."""

    def __init__(self, hidden_size, label_size):
        super().__init__()
        hidden_proj = nn.Linear(hidden_size, hidden_size)
        output_proj = nn.Linear(hidden_size, label_size)
        # Stored under `cls` so parameter names stay `cls.0.*` / `cls.2.*`.
        self.cls = nn.Sequential(hidden_proj, nn.ReLU(), output_proj)

    def forward(self, input):
        """Return one raw logit per label for each row of `input`."""
        logits = self.cls(input)
        return logits
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP head producing a single logit: hidden_size -> hidden_size -> 1."""

    def __init__(self, hidden_size):
        super().__init__()
        hidden_proj = nn.Linear(hidden_size, hidden_size)
        output_proj = nn.Linear(hidden_size, 1)
        # Stored under `cls` so parameter names stay `cls.0.*` / `cls.2.*`.
        self.cls = nn.Sequential(hidden_proj, nn.ReLU(), output_proj)

    def forward(self, input):
        """Return the raw (pre-sigmoid) logit for each row of `input`."""
        logit = self.cls(input)
        return logit

In [16]:
# Find the first training sample whose first age/gender row has more than 3 entries,
# to use as a worked example for the graph-construction walkthrough.
# NOTE: this overwrites the `age_gender_ids` unpacked from the sample batch in an earlier cell.
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
# If no sample matched, the loop ends without `break` and `i` is simply the last index.
exp_i = i
# Append trailing zeros to each sequence of the example (5 for the token-level sequences,
# 3 for age/gender). The names mirror the parameters of HeteroGT.build_patient_graph.
# NOTE(review): presumably the zeros emulate batch padding (pad id 0) — confirm.
id_seq = torch.concat([train_dataset[exp_i][0][0], torch.zeros(5, dtype=train_dataset[exp_i][0][0].dtype)], dim=0)
type_seq = torch.concat([train_dataset[exp_i][1][0], torch.zeros(5, dtype=train_dataset[exp_i][1][0].dtype)], dim=0)
visit_seq = torch.concat([train_dataset[exp_i][2][0], torch.zeros(5, dtype=train_dataset[exp_i][2][0].dtype)], dim=0)
age_sex = torch.concat([train_dataset[exp_i][3][0], torch.zeros(3, dtype=train_dataset[exp_i][3][0].dtype)], dim=0)

tensor([[17, 17, 17,  7]])


In [None]:
class HeteroGT(nn.Module):
    """Transformer encoder over EHR token sequences, optionally augmented with
    per-visit nodes produced by a heterogeneous GNN.

    The input sequence fed to the encoder is
    ``[task token] + tokens (+ visit nodes when use_hetero_graph is True)``.
    The encoder output at position 0 (the task token) is passed to the
    prediction head.

    Args:
        tokenizer: tokenizer exposing ``vocab``, ``age_gender_voc``,
            ``token_type_voc`` (and ``diag_voc`` for non-binary tasks) as well as
            ``convert_tokens_to_ids``.
        d_model: embedding / hidden size.
        num_heads: attention heads of each encoder layer.
        num_layers: number of encoder layers.
        max_num_adms: maximum number of admissions (visits) per patient; visit
            sequences are padded or truncated to this length.
        device: device on which helper tensors are created; inputs are expected
            to already live on this device.
        task: task name; "death", "stay" and "readmission" get a single-logit
            head, any other task gets one logit per diagnosis label.
        use_hetero_graph: whether to append GNN-derived visit nodes.
    """

    def __init__(self, tokenizer, d_model, num_heads, num_layers, max_num_adms, device, task, use_hetero_graph):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.use_hetero_graph = use_hetero_graph
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_sex_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]  # expected to be 0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0]  # expected to be 0
        self.adm_pad_id = 0
        self.age_sex_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0]  # expected to be 0
        # NOTE(review): hard-coded type ids; presumably "diag" is id 1 in the type vocabulary
        # and 5 is the first id after it — confirm against the tokenizer.
        self.diag_type_id = 1
        self.visit_type_id = 5

        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        # +1 row for the extra "visit" type used by the virtual visit nodes.
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id)
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_sex_emb = nn.Embedding(self.age_sex_vocab_size, d_model, padding_idx=self.age_sex_pad_id)
        # Task embedding: prepended to the sequence in forward() and read out at position 0.
        # The table has 5 rows (hard-coded), so task ids must be in [0, 5).
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)

        # GNN
        self.het_gnn = DiseaseOccHetGNN(d_model)

        # encoder transformer
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, norm_first = True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers, enable_nested_tensor=False)

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            # Derive the label count from the tokenizer's diagnosis vocabulary instead of
            # reading a notebook-global `config`, so the class works wherever it is constructed.
            label_vocab_size = len(self.tokenizer.diag_voc.id2word)
            self.cls_head = MultiPredictionHead(self.d_model, label_vocab_size)

    def forward(self, input_ids, token_types, adm_index, age_gender_index, task_id):
        """Compute task logits for a batch.

        Args:
            input_ids: [B, L] token ids.
            token_types: [B, L] token-type ids.
            adm_index: [B, L] admission index of each token.
            age_gender_index: per-visit age/gender ids of each patient
                (only used when ``use_hetero_graph`` is True).
            task_id: scalar task id shared by the whole batch.

        Returns:
            Logits of shape [B, label_size] from the prediction head.
        """
        B, L = input_ids.shape
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)
        # Base representation: token + admission-index + token-type embeddings.
        token_embed = self.token_emb(input_ids)  # [B, L, d]
        adm_emb  = self.adm_index_emb(adm_index)          # [B, L, d]
        type_emb = self.type_emb(token_types)       # [B, L, d]
        x_tokens = token_embed + adm_emb + type_emb  # [B, L, d]
        task_emb = self.task_emb(task_id).unsqueeze(1)           # [B, 1, d]
        x = torch.cat([task_emb, x_tokens], dim=1)  # [B, 1+L, d]

        # Padding mask (True = ignore); the task position is never masked.
        seq_pad_mask = (input_ids == self.seq_pad_id)         # [B, L]
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)
        mask = torch.concat([task_pad_mask, seq_pad_mask], dim=1)  # [B, 1+L]

        if self.use_hetero_graph:
            # get visit embed and mask
            visit_emb_pad, visit_pad_mask = self.visit_segment(B, input_ids, token_types, adm_index, age_gender_index)
            x = torch.cat([x, visit_emb_pad], dim=1)  # [B, 1+L(+V), d]
            mask = torch.concat([mask, visit_pad_mask], dim=1)

        # ===== Transformer encoding (batch_first=True) =====
        h = self.encoder(x, src_key_padding_mask=mask)   # [B, 1+L(+V), d]

        # ===== Classification: read out the task position (index 0) =====
        logits = self.cls_head(h[:, 0, :])  # [B, label_size]
        return logits

    def visit_segment(self, B, input_ids, token_types, adm_index, age_gender_index):
        """Build per-patient graphs, run the GNN and return padded visit-node embeddings.

        Returns:
            visit_emb_pad: [B, max_num_adms, d] visit embeddings (GNN output plus
                type and visit-index embeddings; zero at padded positions).
            visit_pad_mask: [B, max_num_adms] bool mask, True at padded positions.
        """
        graphs = []
        for p in range(B):
            hg_p = self.build_patient_graph(input_ids[p], token_types[p], adm_index[p], age_gender_index[p])
            graphs.append(hg_p)

        batch_graph = HeteroBatch.from_data_list(graphs).to(self.device)
        out = self.het_gnn(batch_graph)
        h_visit_all = out['visit']  # extract virtual visit node representations

        # Slice out each sample's visit representations; this relies on the batch
        # concatenating the visit nodes of the graphs in list order.
        visit_emb_seq = []
        offset = 0
        for p in range(B):
            n_v = graphs[p]['visit'].num_nodes
            visit_emb_p = h_visit_all[offset:offset + n_v]  # [N_visit_p, d]
            offset += n_v
            visit_emb_seq.append(visit_emb_p)

        # Pad (or truncate) every patient's visit sequence to max_num_adms.
        visit_emb_pad = []
        visit_pad_mask = []
        visit_index_pad = []
        for p in range(B):
            v = visit_emb_seq[p]                          # [N_visit_p, d]
            Np = v.size(0)
            if Np < self.max_num_adms:
                pad_len = self.max_num_adms - Np
                v_pad = torch.cat([v, torch.zeros(pad_len, self.d_model, device=self.device, dtype=v.dtype)], dim=0)
                m_pad = torch.cat([torch.zeros(Np, dtype=torch.bool, device=self.device), torch.ones(pad_len, dtype=torch.bool, device=self.device)], dim=0)
                # Visit positions are numbered 1..Np; padded slots get the admission pad id.
                i_pad = torch.cat([torch.arange(1, Np + 1, device=self.device), torch.full((pad_len,), self.adm_pad_id, dtype=torch.long, device=self.device)], dim=0)
            else:
                v_pad = v[:self.max_num_adms]
                m_pad = torch.zeros(self.max_num_adms, dtype=torch.bool, device=self.device)
                i_pad = torch.arange(1, self.max_num_adms + 1, device=self.device)
            visit_emb_pad.append(v_pad)      # [V_max, d]
            visit_pad_mask.append(m_pad)     # [V_max]
            visit_index_pad.append(i_pad)    # [V_max]

        visit_emb_pad  = torch.stack(visit_emb_pad,  dim=0)  # [B, V_max, d]
        visit_pad_mask = torch.stack(visit_pad_mask, dim=0)  # [B, V_max]
        visit_index_pad = torch.stack(visit_index_pad, dim=0)  # [B, V_max]

        # ====== Alignment and type embeddings ======
        nonpad = (~visit_pad_mask).unsqueeze(-1)             # [B, V_max, 1], bool

        # 1. Type embedding, added at non-padded positions only.
        visit_type_ids = torch.full((B, self.max_num_adms), self.visit_type_id, dtype=torch.long, device=self.device)  # [B, V_max]
        visit_type_emb = self.type_emb(visit_type_ids) * nonpad      # [B, V_max, d]

        # 2. Visit-index embedding, added at non-padded positions only.
        visit_index_emb = self.adm_index_emb(visit_index_pad) * nonpad  # [B, V_max, d]

        # 3. Final visit embedding.
        visit_emb_pad  = visit_emb_pad + visit_type_emb + visit_index_emb

        return visit_emb_pad, visit_pad_mask

    def build_patient_graph(self, id_seq: torch.Tensor, type_seq: torch.Tensor, visit_seq: torch.Tensor, age_sex: torch.Tensor):
        """Build the heterogeneous graph of a single patient.

        Nodes:
            'visit': one virtual node per distinct admission index among the
                non-padded tokens, initialised from the age/gender embedding.
            'occ': one node per non-padded diagnosis token, initialised from the
                token embedding.
        Edges:
            visit -> occ and occ -> visit between a diagnosis and its admission,
            plus visit -> visit edges linking consecutive visits.

        Raises:
            AssertionError: if the number of distinct admissions differs from the
                number of non-padded age/gender entries.
        """
        # build a graph just for one patient
        hg = HeteroData()
        occ_mask = (type_seq == self.diag_type_id) & (id_seq != self.seq_pad_id)  # diagnosis-token mask
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1)  # positions of diagnosis tokens, [N_occ]
        N_occ = occ_pos.numel()  # number of diagnosis tokens

        # build visit virtual nodes
        nonpad = id_seq != self.seq_pad_id
        visit_used = visit_seq[nonpad]  # admission indices of the non-padded tokens
        # torch.unique returns sorted values; the inverse maps every token to its local visit id.
        visit_ids_unique, visit_lid_nonpad = torch.unique(visit_used, return_inverse=True)
        visit_lid_full = torch.full_like(id_seq, fill_value=-1)  # -1 marks padded positions
        visit_lid_full[nonpad] = visit_lid_nonpad
        N_visit = visit_ids_unique.numel()
        age_sex_nonpad = age_sex[age_sex!=self.age_sex_pad_id]
        assert N_visit == len(visit_ids_unique) == len(age_sex_nonpad), \
            "number of distinct admissions does not match the number of age/gender entries"
        visit_x = self.age_sex_emb(age_sex_nonpad.to(self.device))
        hg['visit'].x = visit_x
        hg['visit'].num_nodes = N_visit

        # build diag nodes
        gid_occ = id_seq[occ_pos]
        x_occ = self.token_emb(gid_occ) # [N_occ, d]
        hg['occ'].x = x_occ
        hg['occ'].num_nodes = N_occ

        # build edges between diag nodes and virtual visit nodes
        occ_visit_lid = visit_lid_full[occ_pos]
        e_v2o = torch.stack([occ_visit_lid, torch.arange(N_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(N_occ, device=self.device), occ_visit_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v

        # build forward edges between virtual visit nodes
        if N_visit > 1:
            src = torch.arange(0, N_visit - 1, device=self.device)
            dst = torch.arange(1, N_visit, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, N_visit-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg

In [18]:
# Train five freshly initialised models and collect, for each run, the test metrics
# reported at the best validation epoch.
final_metrics = {name: [] for name in ("precision", "recall", "f1", "auc", "prauc")}
for i in range(5):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=True,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    for key, collected in final_metrics.items():
        collected.append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 33.56it/s, loss=0.7016]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.53it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.49it/s]


Validation: {'precision': 0.7512324611272233, 'recall': 0.6211978676681681, 'f1': 0.6800549212354794, 'auc': 0.776896440614259, 'prauc': 0.7812321168399854}
Test:      {'precision': 0.7332116788294408, 'recall': 0.6299780495433366, 'f1': 0.6776859454396119, 'auc': 0.7655372324709688, 'prauc': 0.7670472141096356}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 36.89it/s, loss=0.5851]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 61.02it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.44it/s]


Validation: {'precision': 0.8185966913815255, 'recall': 0.44998432110238323, 'f1': 0.5807365393291547, 'auc': 0.7889364363180105, 'prauc': 0.7975717189829002}
Test:      {'precision': 0.8121964382039277, 'recall': 0.4719347757903044, 'f1': 0.5969853186330997, 'auc': 0.7804364622168801, 'prauc': 0.7907392164579675}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 36.84it/s, loss=0.5398]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.68it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.81it/s]


Validation: {'precision': 0.7873585641795265, 'recall': 0.632800250860355, 'f1': 0.7016689797577556, 'auc': 0.8112359863271585, 'prauc': 0.8215528250971054}
Test:      {'precision': 0.774653312785922, 'recall': 0.6306052053915628, 'f1': 0.6952463217571352, 'auc': 0.7954322316832491, 'prauc': 0.8101810248414388}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 37.29it/s, loss=0.5209]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.06it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.94it/s]


Validation: {'precision': 0.7529752331915311, 'recall': 0.7340859203489054, 'f1': 0.7434106015402054, 'auc': 0.8202467047441295, 'prauc': 0.8292892166949365}
Test:      {'precision': 0.7352941176447824, 'recall': 0.7447475697687528, 'f1': 0.739990647747548, 'auc': 0.8095564256546002, 'prauc': 0.8206227996546465}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 36.89it/s, loss=0.4977]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.84it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.81it/s]


Validation: {'precision': 0.7249014255361695, 'recall': 0.7494512386304502, 'f1': 0.7369719345612478, 'auc': 0.8088296952366196, 'prauc': 0.8181006170335507}
Test:      {'precision': 0.7031990521306185, 'recall': 0.7444339918446397, 'f1': 0.7232292410033768, 'auc': 0.7952249581854156, 'prauc': 0.8070056017937368}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 36.96it/s, loss=0.4714]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.10it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 60.99it/s]


Validation: {'precision': 0.779138827020122, 'recall': 0.6582000627135208, 'f1': 0.7135815010673088, 'auc': 0.816482363025174, 'prauc': 0.8175778045217283}
Test:      {'precision': 0.7620751341654309, 'recall': 0.6679209783610288, 'f1': 0.711898390741222, 'auc': 0.8043603438183121, 'prauc': 0.8051074290062261}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 37.06it/s, loss=0.4448]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.97it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.00it/s]


Validation: {'precision': 0.6774683544286647, 'recall': 0.8391345249268136, 'f1': 0.7496848248628114, 'auc': 0.8177241657656702, 'prauc': 0.8254618862602946}
Test:      {'precision': 0.6621050012424432, 'recall': 0.8344308560651162, 'f1': 0.7383462769732388, 'auc': 0.8084458162010562, 'prauc': 0.8180261568502593}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 36.91it/s, loss=0.4186]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.09it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.86it/s]


Validation: {'precision': 0.7120699379562547, 'recall': 0.7917842583857266, 'f1': 0.7498143973874717, 'auc': 0.8048520984250873, 'prauc': 0.7946590721589851}
Test:      {'precision': 0.7005005561715781, 'recall': 0.7899027908410476, 'f1': 0.7425202603068858, 'auc': 0.7945353384200404, 'prauc': 0.7875646412029058}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 37.04it/s, loss=0.3792]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.05it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.92it/s]


Validation: {'precision': 0.716975493123978, 'recall': 0.7522734399474686, 'f1': 0.7342004540672037, 'auc': 0.8060343203552665, 'prauc': 0.8055615796591242}
Test:      {'precision': 0.7201903062720185, 'recall': 0.7594857322020713, 'f1': 0.739316234317509, 'auc': 0.8027900381880533, 'prauc': 0.8001100830362586}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7529752331915311, 'recall': 0.7340859203489054, 'f1': 0.7434106015402054, 'auc': 0.8202467047441295, 'prauc': 0.8292892166949365}
Corresponding test performance:
{'precision': 0.7352941176447824, 'recall': 0.7447475697687528, 'f1': 0.739990647747548, 'auc': 0.8095564256546002, 'prauc': 0.8206227996546465}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 37.36it/s, loss=0.6742]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.19it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.16it/s]


Validation: {'precision': 0.7012867647037339, 'recall': 0.717779868295021, 'f1': 0.7094374659422238, 'auc': 0.766816733948774, 'prauc': 0.7725094677713455}
Test:      {'precision': 0.6929745889366429, 'recall': 0.7268736280943027, 'f1': 0.7095194317928425, 'auc': 0.7607463355928058, 'prauc': 0.7672093856852096}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 36.89it/s, loss=0.5906]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 61.29it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.87it/s]


Validation: {'precision': 0.746294381245273, 'recall': 0.6788962057049893, 'f1': 0.711001637044972, 'auc': 0.7976353350978496, 'prauc': 0.8046114811700842}
Test:      {'precision': 0.7396708095373206, 'recall': 0.6904985888971763, 'f1': 0.7142393722335649, 'auc': 0.7944637641330274, 'prauc': 0.8036450637123028}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 37.49it/s, loss=0.5525]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.21it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.00it/s]


Validation: {'precision': 0.6530074287067985, 'recall': 0.8544998432083585, 'f1': 0.7402879603141536, 'auc': 0.8043152180542112, 'prauc': 0.8145168735803731}
Test:      {'precision': 0.6493381468095082, 'recall': 0.8460332392573031, 'f1': 0.7347494504221982, 'auc': 0.8031438326067686, 'prauc': 0.8177376102982592}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 37.15it/s, loss=0.5146]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.12it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.05it/s]


Validation: {'precision': 0.7581163039573382, 'recall': 0.6663530887404631, 'f1': 0.709279033736673, 'auc': 0.8040603669153085, 'prauc': 0.8158073628034143}
Test:      {'precision': 0.753546099288108, 'recall': 0.6663530887404631, 'f1': 0.7072724197127896, 'auc': 0.8058767925659382, 'prauc': 0.8220751677946724}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 37.18it/s, loss=0.4856]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.43it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.06it/s]


Validation: {'precision': 0.7534663510289027, 'recall': 0.6986516149241184, 'f1': 0.7250244011225657, 'auc': 0.8100744619466005, 'prauc': 0.8198097310686926}
Test:      {'precision': 0.7468227424724186, 'recall': 0.7002195045446842, 'f1': 0.7227706698693298, 'auc': 0.8110701563913271, 'prauc': 0.8219188668682799}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 37.18it/s, loss=0.4496]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.28it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 61.23it/s]


Validation: {'precision': 0.7291421856617828, 'recall': 0.7783004076488608, 'f1': 0.752919758388442, 'auc': 0.8223709163659242, 'prauc': 0.8317124778682918}
Test:      {'precision': 0.7214677838754076, 'recall': 0.7830040765105581, 'f1': 0.7509774386151291, 'auc': 0.8229977445542891, 'prauc': 0.8359646576760589}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 36.79it/s, loss=0.4281]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.03it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.23it/s]


Validation: {'precision': 0.8014642549491755, 'recall': 0.583568516774589, 'f1': 0.6753765148091966, 'auc': 0.8200770049319447, 'prauc': 0.824876686192703}
Test:      {'precision': 0.8117697966213251, 'recall': 0.5882721856362864, 'f1': 0.6821818133067563, 'auc': 0.8215908215791945, 'prauc': 0.8279041673820952}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 37.31it/s, loss=0.3968]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.10it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.23it/s]


Validation: {'precision': 0.7785497177560636, 'recall': 0.5622452179348942, 'f1': 0.6529497402115102, 'auc': 0.794125040729462, 'prauc': 0.7942924985785802}
Test:      {'precision': 0.7785684386740119, 'recall': 0.5832549388504759, 'f1': 0.6669056960685451, 'auc': 0.7977490843071613, 'prauc': 0.8078616786196291}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 37.30it/s, loss=0.3687]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.37it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.33it/s]


Validation: {'precision': 0.7538411245480424, 'recall': 0.7231106930049448, 'f1': 0.7381562049869979, 'auc': 0.8225016324438712, 'prauc': 0.8288714798940928}
Test:      {'precision': 0.7543173672181351, 'recall': 0.7259328943219633, 'f1': 0.7398529831746099, 'auc': 0.822797366750774, 'prauc': 0.8293322791197251}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 37.20it/s, loss=0.3277]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.47it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 61.71it/s]


Validation: {'precision': 0.7452830188655809, 'recall': 0.743179680148187, 'f1': 0.744229858398521, 'auc': 0.8174304816028424, 'prauc': 0.8222064246826339}
Test:      {'precision': 0.7376449054278901, 'recall': 0.7582314205056186, 'f1': 0.7477965003334112, 'auc': 0.8217391041738077, 'prauc': 0.828506126061022}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 37.38it/s, loss=0.3109]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.18it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.07it/s]


Validation: {'precision': 0.7504866969475974, 'recall': 0.7253057384737369, 'f1': 0.7376813855269296, 'auc': 0.818761806269127, 'prauc': 0.8220359048933393}
Test:      {'precision': 0.747294716738551, 'recall': 0.7362809658176975, 'f1': 0.7417469544040295, 'auc': 0.8195293109792038, 'prauc': 0.8274515105455527}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7291421856617828, 'recall': 0.7783004076488608, 'f1': 0.752919758388442, 'auc': 0.8223709163659242, 'prauc': 0.8317124778682918}
Corresponding test performance:
{'precision': 0.7214677838754076, 'recall': 0.7830040765105581, 'f1': 0.7509774386151291, 'auc': 0.8229977445542891, 'prauc': 0.8359646576760589}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 37.17it/s, loss=0.6891]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.30it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.30it/s]


Validation: {'precision': 0.7341309405093582, 'recall': 0.6926936343659684, 'f1': 0.7128105790612911, 'auc': 0.7846636984772933, 'prauc': 0.7983661327785581}
Test:      {'precision': 0.7149774047749679, 'recall': 0.6945751019106473, 'f1': 0.7046285936945275, 'auc': 0.7769291208191078, 'prauc': 0.7969955576697249}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 37.11it/s, loss=0.6176]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.10it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.10it/s]


Validation: {'precision': 0.8124392614149056, 'recall': 0.524302289117202, 'f1': 0.6373165570747474, 'auc': 0.7755344715177102, 'prauc': 0.7894767954806408}
Test:      {'precision': 0.7947882736119368, 'recall': 0.5355910943852756, 'f1': 0.6399400476414978, 'auc': 0.7684085593183229, 'prauc': 0.7823552510061793}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 37.04it/s, loss=0.5682]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.99it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.34it/s]


Validation: {'precision': 0.7549530761183353, 'recall': 0.6810912511737814, 'f1': 0.7161226458516181, 'auc': 0.8043744472071319, 'prauc': 0.8125214857998292}
Test:      {'precision': 0.742629205685943, 'recall': 0.6713703355262736, 'f1': 0.705204211084157, 'auc': 0.7935232316948257, 'prauc': 0.8060523367253574}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 37.19it/s, loss=0.5303]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 61.71it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.24it/s]


Validation: {'precision': 0.6937536148042401, 'recall': 0.7522734399474686, 'f1': 0.721829391726342, 'auc': 0.7854458850005963, 'prauc': 0.791642364815492}
Test:      {'precision': 0.6823430019693542, 'recall': 0.7597993101261844, 'f1': 0.7189910929351386, 'auc': 0.779986732080129, 'prauc': 0.7919550365216961}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 37.34it/s, loss=0.5052]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.11it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.94it/s]


Validation: {'precision': 0.6888947230226072, 'recall': 0.8228284728729294, 'f1': 0.7499285460516792, 'auc': 0.8149620476266692, 'prauc': 0.8208839868919833}
Test:      {'precision': 0.6764553014535435, 'recall': 0.816243336466553, 'f1': 0.7398038887464506, 'auc': 0.8041398829342639, 'prauc': 0.815317345428293}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 37.09it/s, loss=0.5042]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.82it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.85it/s]


Validation: {'precision': 0.7783964365227232, 'recall': 0.6575729068652946, 'f1': 0.712901575859084, 'auc': 0.8185729662472199, 'prauc': 0.8266745993204457}
Test:      {'precision': 0.7643884892058835, 'recall': 0.6663530887404631, 'f1': 0.7120120573430868, 'auc': 0.8093214183788877, 'prauc': 0.8203040283886383}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 36.97it/s, loss=0.4531]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.98it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.82it/s]


Validation: {'precision': 0.7530695770779227, 'recall': 0.6923800564418552, 'f1': 0.721450738349048, 'auc': 0.8107975197905247, 'prauc': 0.8172516135531398}
Test:      {'precision': 0.7579130434756247, 'recall': 0.6832862966425736, 'f1': 0.7186675411851786, 'auc': 0.8064182807373057, 'prauc': 0.817551792511682}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 37.20it/s, loss=0.4340]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.91it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 61.48it/s]


Validation: {'precision': 0.7032115171630586, 'recall': 0.796487927247424, 'f1': 0.7469489731086028, 'auc': 0.8112488469405155, 'prauc': 0.8167147045480188}
Test:      {'precision': 0.6958195819562821, 'recall': 0.7933521480062924, 'f1': 0.7413919364112165, 'auc': 0.807912381395316, 'prauc': 0.8190617327576497}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 37.40it/s, loss=0.4061]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.34it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.19it/s]


Validation: {'precision': 0.6688311688295603, 'recall': 0.8720602069586953, 'f1': 0.7570436863874873, 'auc': 0.8128045290257513, 'prauc': 0.8148770798586222}
Test:      {'precision': 0.6620091544190267, 'recall': 0.8617121354629611, 'f1': 0.7487738370456999, 'auc': 0.809676320135377, 'prauc': 0.8177622028400311}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 37.10it/s, loss=0.3767]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.12it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.13it/s]


Validation: {'precision': 0.6890343698835542, 'recall': 0.7920978363098398, 'f1': 0.7369803013677926, 'auc': 0.8030383500476094, 'prauc': 0.8088342577200301}
Test:      {'precision': 0.6950122649204279, 'recall': 0.7996237064885556, 'f1': 0.7436570378919665, 'auc': 0.8048948356282908, 'prauc': 0.8109058733632233}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 37.31it/s, loss=0.3776]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.36it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.34it/s]


Validation: {'precision': 0.7424142480186596, 'recall': 0.705863907178721, 'f1': 0.7236778602958267, 'auc': 0.804777898714431, 'prauc': 0.8085676255947514}
Test:      {'precision': 0.7460474308275822, 'recall': 0.7102539981163053, 'f1': 0.7277108383741765, 'auc': 0.8087914062535391, 'prauc': 0.8144686274544731}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7783964365227232, 'recall': 0.6575729068652946, 'f1': 0.712901575859084, 'auc': 0.8185729662472199, 'prauc': 0.8266745993204457}
Corresponding test performance:
{'precision': 0.7643884892058835, 'recall': 0.6663530887404631, 'f1': 0.7120120573430868, 'auc': 0.8093214183788877, 'prauc': 0.8203040283886383}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 37.59it/s, loss=0.6796]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.54it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.49it/s]


Validation: {'precision': 0.7262550881929232, 'recall': 0.6713703355262736, 'f1': 0.6977350447039866, 'auc': 0.773314006242823, 'prauc': 0.7891361538670264}
Test:      {'precision': 0.7173333333309423, 'recall': 0.6748196926915183, 'f1': 0.695427366144765, 'auc': 0.7645691675890287, 'prauc': 0.7804196761162413}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 33.66it/s, loss=0.5892]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.40it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.35it/s]


Validation: {'precision': 0.8233805667974525, 'recall': 0.5101912825321098, 'f1': 0.6300096758154433, 'auc': 0.7795588387629738, 'prauc': 0.789986409332988}
Test:      {'precision': 0.8011667476869171, 'recall': 0.5167764189384861, 'f1': 0.6282882148262998, 'auc': 0.7732892008896957, 'prauc': 0.7869455382380754}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 37.31it/s, loss=0.5482]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.33it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.36it/s]


Validation: {'precision': 0.758584442884518, 'recall': 0.6788962057049893, 'f1': 0.7165315190904394, 'auc': 0.8019704167712847, 'prauc': 0.8131091545268385}
Test:      {'precision': 0.743737305346162, 'recall': 0.6889306992766104, 'f1': 0.7152856860354296, 'auc': 0.7993972559163296, 'prauc': 0.8098158402942764}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 37.15it/s, loss=0.5257]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.25it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.26it/s]


Validation: {'precision': 0.6859174964418608, 'recall': 0.7560363750368265, 'f1': 0.7192720713819927, 'auc': 0.7878603646847959, 'prauc': 0.796914927529499}
Test:      {'precision': 0.6854769407027056, 'recall': 0.7503919724027897, 'f1': 0.7164670608763379, 'auc': 0.7854371492696353, 'prauc': 0.7958445153739651}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 37.36it/s, loss=0.5117]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.32it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.03it/s]


Validation: {'precision': 0.6528424976685628, 'recall': 0.8786453433650717, 'f1': 0.7490977093160208, 'auc': 0.8073649417067581, 'prauc': 0.8169135121340757}
Test:      {'precision': 0.6440599769304636, 'recall': 0.8755095641239401, 'f1': 0.7421584214829756, 'auc': 0.7997143571879056, 'prauc': 0.8083305378608219}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 37.31it/s, loss=0.4762]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 61.48it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.11it/s]


Validation: {'precision': 0.7692307692278106, 'recall': 0.6271558482263181, 'f1': 0.6909656195095558, 'auc': 0.8033250010625077, 'prauc': 0.8138109138840369}
Test:      {'precision': 0.7664206642038139, 'recall': 0.6513013483830313, 'f1': 0.7041871453950485, 'auc': 0.8030158847603646, 'prauc': 0.8129220799286946}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 37.14it/s, loss=0.4544]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.88it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.16it/s]


Validation: {'precision': 0.7653239929920654, 'recall': 0.6851677641872526, 'f1': 0.7230311002412023, 'auc': 0.8179999154012778, 'prauc': 0.825105053629539}
Test:      {'precision': 0.7590361445757005, 'recall': 0.6914393226695157, 'f1': 0.7236626139779625, 'auc': 0.8147114856597242, 'prauc': 0.8207061568483447}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 37.42it/s, loss=0.4308]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.23it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.25it/s]


Validation: {'precision': 0.7260630604549434, 'recall': 0.7870805895240293, 'f1': 0.7553415538382379, 'auc': 0.82271227522436, 'prauc': 0.8294309778408451}
Test:      {'precision': 0.714082098059538, 'recall': 0.7855126999034635, 'f1': 0.748096157469914, 'auc': 0.8152676208898265, 'prauc': 0.8234894008901932}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 37.09it/s, loss=0.4214]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.28it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.21it/s]


Validation: {'precision': 0.6914478814642802, 'recall': 0.8341172781410031, 'f1': 0.756111421988709, 'auc': 0.819550674282963, 'prauc': 0.8270716798794073}
Test:      {'precision': 0.6825275006378038, 'recall': 0.8366259015339084, 'f1': 0.7517610545026964, 'auc': 0.8156890635798115, 'prauc': 0.823914148583915}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 37.07it/s, loss=0.3810]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 61.56it/s]


Validation: {'precision': 0.6368980330078298, 'recall': 0.8833490122267691, 'f1': 0.7401471312303509, 'auc': 0.8073685587542646, 'prauc': 0.8143760590177036}
Test:      {'precision': 0.6344206974113962, 'recall': 0.8842897459991086, 'f1': 0.7388000999277515, 'auc': 0.8003025548799116, 'prauc': 0.8080684894298632}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 37.12it/s, loss=0.3755]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.15it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.14it/s]


Validation: {'precision': 0.6862086776841781, 'recall': 0.8331765443686636, 'f1': 0.7525846147868966, 'auc': 0.8152291565376777, 'prauc': 0.8206212352066721}
Test:      {'precision': 0.6728016359901002, 'recall': 0.8253370962658346, 'f1': 0.7413040367340151, 'auc': 0.8074942607186522, 'prauc': 0.8148742585683033}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 37.23it/s, loss=0.3664]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.11it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.09it/s]


Validation: {'precision': 0.7691737680593954, 'recall': 0.6509877704589182, 'f1': 0.7051630385104936, 'auc': 0.809155229512214, 'prauc': 0.8149623782995125}
Test:      {'precision': 0.7602189780994153, 'recall': 0.6531828159277103, 'f1': 0.7026479963756046, 'auc': 0.8040464135594517, 'prauc': 0.8091535545430885}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 37.27it/s, loss=0.3474]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.17it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.12it/s]


Validation: {'precision': 0.6636825550431559, 'recall': 0.8601442458423953, 'f1': 0.7492488341960596, 'auc': 0.8079223689169425, 'prauc': 0.8092337164762614}
Test:      {'precision': 0.6601637764916662, 'recall': 0.859517089994169, 'f1': 0.7467647410314265, 'auc': 0.8019368847669079, 'prauc': 0.8062043171819306}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7260630604549434, 'recall': 0.7870805895240293, 'f1': 0.7553415538382379, 'auc': 0.82271227522436, 'prauc': 0.8294309778408451}
Corresponding test performance:
{'precision': 0.714082098059538, 'recall': 0.7855126999034635, 'f1': 0.748096157469914, 'auc': 0.8152676208898265, 'prauc': 0.8234894008901932}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 37.34it/s, loss=0.6958]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.38it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.25it/s]


Validation: {'precision': 0.715524977562921, 'recall': 0.7500783944786765, 'f1': 0.73239436119772, 'auc': 0.7981347388376155, 'prauc': 0.8083094443241943}
Test:      {'precision': 0.7030606238943407, 'recall': 0.749137660706337, 'f1': 0.7253681443879839, 'auc': 0.7930821085925888, 'prauc': 0.8069151494400856}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 33.76it/s, loss=0.5833]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.34it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.13it/s]


Validation: {'precision': 0.60235732009801, 'recall': 0.9134524929416323, 'f1': 0.72598130362001, 'auc': 0.8055421507105137, 'prauc': 0.8168920583625514}
Test:      {'precision': 0.601568951278692, 'recall': 0.9137660708657455, 'f1': 0.7255072776702786, 'auc': 0.8001220588227891, 'prauc': 0.8154582540761307}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 37.47it/s, loss=0.5529]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.52it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.97it/s]


Validation: {'precision': 0.7678509830949712, 'recall': 0.6980244590758921, 'f1': 0.7312746336423201, 'auc': 0.8160773541779763, 'prauc': 0.8282435119181344}
Test:      {'precision': 0.7563797209909616, 'recall': 0.6970837253035526, 'f1': 0.7255221882174422, 'auc': 0.8080380642326376, 'prauc': 0.8230902727585464}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 37.25it/s, loss=0.5278]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.26it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.32it/s]


Validation: {'precision': 0.7605734766997829, 'recall': 0.6654123549681236, 'f1': 0.7098176902866598, 'auc': 0.8092154131637823, 'prauc': 0.8210917948601998}
Test:      {'precision': 0.7613595706591723, 'recall': 0.6672938225128024, 'f1': 0.7112299415433632, 'auc': 0.8078940599885138, 'prauc': 0.8201508681118791}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 37.29it/s, loss=0.5124]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.51it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.94it/s]


Validation: {'precision': 0.7796610169461493, 'recall': 0.634681718405034, 'f1': 0.6997407037795916, 'auc': 0.8072056409061549, 'prauc': 0.8197521457959733}
Test:      {'precision': 0.7707641195984838, 'recall': 0.6547507055482761, 'f1': 0.7080366176146422, 'auc': 0.8089806603457813, 'prauc': 0.8232374376199614}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 37.08it/s, loss=0.4862]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 60.87it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.81it/s]


Validation: {'precision': 0.7429837518441862, 'recall': 0.788648479144595, 'f1': 0.7651353768092357, 'auc': 0.834663904448054, 'prauc': 0.842451013613768}
Test:      {'precision': 0.7314652656137436, 'recall': 0.7858262778275766, 'f1': 0.7576719526760851, 'auc': 0.8298115965445022, 'prauc': 0.8384766131054688}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 36.79it/s, loss=0.4516]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.66it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.58it/s]


Validation: {'precision': 0.7704688593395015, 'recall': 0.6904985888971763, 'f1': 0.7282950173376925, 'auc': 0.8159646731007913, 'prauc': 0.823505248564107}
Test:      {'precision': 0.7683264177013543, 'recall': 0.6967701473794394, 'f1': 0.7308008501320363, 'auc': 0.8160339489628021, 'prauc': 0.8219656293536692}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 36.96it/s, loss=0.4263]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.42it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 66.65it/s]


Validation: {'precision': 0.7363717604982533, 'recall': 0.7751646284077292, 'f1': 0.7552703891348079, 'auc': 0.8234483443919538, 'prauc': 0.8281470506774974}
Test:      {'precision': 0.7306784660745408, 'recall': 0.776732518028295, 'f1': 0.7530019709865702, 'auc': 0.8211880022972226, 'prauc': 0.8283601519746688}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 36.82it/s, loss=0.4124]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 66.95it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.46it/s]


Validation: {'precision': 0.7264739229004352, 'recall': 0.8037002195020266, 'f1': 0.7631383008017393, 'auc': 0.8313227070506002, 'prauc': 0.8348720418947138}
Test:      {'precision': 0.7176536626418253, 'recall': 0.8018187519573478, 'f1': 0.7574052082832395, 'auc': 0.8233132351527397, 'prauc': 0.8281372644814688}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 37.27it/s, loss=0.3866]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 61.35it/s]


Validation: {'precision': 0.708913649023095, 'recall': 0.7980558168679898, 'f1': 0.7508482027155312, 'auc': 0.8158642497957121, 'prauc': 0.8148847129275961}
Test:      {'precision': 0.7035287579863754, 'recall': 0.7939793038545188, 'f1': 0.746022387473321, 'auc': 0.8118871703342196, 'prauc': 0.8130526419030113}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 37.48it/s, loss=0.3488]
Running inference: 100%|██████████| 198/198 [00:02<00:00, 67.41it/s]
Running inference: 100%|██████████| 197/197 [00:02<00:00, 67.19it/s]

Validation: {'precision': 0.689646335210944, 'recall': 0.843838193788511, 'f1': 0.758990264403932, 'auc': 0.8244369035702168, 'prauc': 0.827649992290554}
Test:      {'precision': 0.6825599184072347, 'recall': 0.8394481028509269, 'f1': 0.7529180093949839, 'auc': 0.8192846396647384, 'prauc': 0.8247632884110783}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7429837518441862, 'recall': 0.788648479144595, 'f1': 0.7651353768092357, 'auc': 0.834663904448054, 'prauc': 0.842451013613768}
Corresponding test performance:
{'precision': 0.7314652656137436, 'recall': 0.7858262778275766, 'f1': 0.7576719526760851, 'auc': 0.8298115965445022, 'prauc': 0.8384766131054688}





In [19]:
# Report each collected metric as "mean ± std" over the values stored in
# `final_metrics`. np.std is called with its defaults, i.e. the population
# standard deviation (ddof=0), not the sample estimate.
print("\nFinal Metrics:")
for key, values in final_metrics.items():
    print(f"{key}: {np.mean(values):.4f} ± {np.std(values):.4f}")


Final Metrics:
precision: 0.7333 ± 0.0172
recall: 0.7531 ± 0.0461
f1: 0.7417 ± 0.0159
auc: 0.8174 ± 0.0080
prauc: 0.8278 ± 0.0078


In [None]:
# Repeat training five times with `use_hetero_graph=False`, building a fresh
# model and optimizer on every pass, and collect the metrics returned by
# `train_with_early_stopping` for each run.
# NOTE: this rebinds `final_metrics`, discarding the lists from earlier cells.
final_metrics = {name: [] for name in ("precision", "recall", "f1", "auc", "prauc")}
for i in range(5):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=False,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )
    # One entry per run for every tracked metric.
    for key, values in final_metrics.items():
        values.append(best_test_metric[key])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.7284]

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 82.23it/s, loss=0.6664]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.35it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.37it/s]


Validation: {'precision': 0.649757281551821, 'recall': 0.8394481028509269, 'f1': 0.7325215438546027, 'auc': 0.7838736247809551, 'prauc': 0.7941442224893596}
Test:      {'precision': 0.6478017974237362, 'recall': 0.8363123236097952, 'f1': 0.73008485683613, 'auc': 0.7779395162041266, 'prauc': 0.7890250631825644}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.75it/s, loss=0.5958]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.38it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.65it/s]


Validation: {'precision': 0.8079019073532793, 'recall': 0.55785512699731, 'f1': 0.6599888655046029, 'auc': 0.7987582274015663, 'prauc': 0.8042098917250383}
Test:      {'precision': 0.7897321428536173, 'recall': 0.5547193477561784, 'f1': 0.651685388408805, 'auc': 0.7899071195275493, 'prauc': 0.79772077041645}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 82.40it/s, loss=0.5647]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.77it/s]


Validation: {'precision': 0.7003444316857051, 'recall': 0.7651301348361081, 'f1': 0.7313052550105772, 'auc': 0.7931829003275337, 'prauc': 0.8107699362350751}
Test:      {'precision': 0.6927262313840414, 'recall': 0.7585449984297318, 'f1': 0.724143087359561, 'auc': 0.7874183275474935, 'prauc': 0.808460688577173}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 82.16it/s, loss=0.5400]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.86it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.08it/s]


Validation: {'precision': 0.8113848768019908, 'recall': 0.5989338350561338, 'f1': 0.6891574910518026, 'auc': 0.8168227673849621, 'prauc': 0.8286893813105896}
Test:      {'precision': 0.7991666666633368, 'recall': 0.6014424584490391, 'f1': 0.6863481790299142, 'auc': 0.8124596136297173, 'prauc': 0.8263396450941074}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 80.76it/s, loss=0.5010]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.63it/s]


Validation: {'precision': 0.726781170481149, 'recall': 0.7165255565985684, 'f1': 0.7216169222046686, 'auc': 0.8040897051895288, 'prauc': 0.8091870766977081}
Test:      {'precision': 0.7354838709653694, 'recall': 0.7149576669780027, 'f1': 0.7250755236996018, 'auc': 0.7991945627701967, 'prauc': 0.8042491954003581}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 81.58it/s, loss=0.4706]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.31it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.10it/s]


Validation: {'precision': 0.7930029154485922, 'recall': 0.5970523675114549, 'f1': 0.6812164530575638, 'auc': 0.8120437936036333, 'prauc': 0.8196860489645771}
Test:      {'precision': 0.7962510187416616, 'recall': 0.6127312637171128, 'f1': 0.6925394244639054, 'auc': 0.8066178028707229, 'prauc': 0.8140835727410924}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 81.06it/s, loss=0.4327]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.10it/s]


Validation: {'precision': 0.8495120698466898, 'recall': 0.5186578864831651, 'f1': 0.6440809921746177, 'auc': 0.824347632828283, 'prauc': 0.8369469438696044}
Test:      {'precision': 0.8383084577072721, 'recall': 0.528378802130673, 'f1': 0.6482015724810293, 'auc': 0.8186464607723077, 'prauc': 0.8316005309696537}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.86it/s, loss=0.4097]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.87it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.97it/s]


Validation: {'precision': 0.6856486210401286, 'recall': 0.8419567262438321, 'f1': 0.7558057656343177, 'auc': 0.8186766047056382, 'prauc': 0.8271512608922567}
Test:      {'precision': 0.6819105691039585, 'recall': 0.841643148319719, 'f1': 0.7534035038247744, 'auc': 0.8135300065886598, 'prauc': 0.823302864428956}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.74it/s, loss=0.3772]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.48it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.96it/s]


Validation: {'precision': 0.7219892150966987, 'recall': 0.7557227971127134, 'f1': 0.7384709617538284, 'auc': 0.810760445053582, 'prauc': 0.8179302301749222}
Test:      {'precision': 0.725508037607748, 'recall': 0.7500783944786765, 'f1': 0.7375886474813816, 'auc': 0.8074611915860449, 'prauc': 0.8182053166173573}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 83.08it/s, loss=0.3487]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.55it/s]


Validation: {'precision': 0.7350349437838498, 'recall': 0.7585449984297318, 'f1': 0.7466049332705395, 'auc': 0.8204574982349311, 'prauc': 0.8308807165450851}
Test:      {'precision': 0.7301539390258673, 'recall': 0.7585449984297318, 'f1': 0.7440787400010678, 'auc': 0.8183598112895099, 'prauc': 0.8297320742376826}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 82.48it/s, loss=0.3179]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.66it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 314.26it/s]


Validation: {'precision': 0.7398560209399874, 'recall': 0.7089996864198527, 'f1': 0.7240992744234878, 'auc': 0.814073459622547, 'prauc': 0.8219116558312942}
Test:      {'precision': 0.7430779137129973, 'recall': 0.7237378488531712, 'f1': 0.733280376253504, 'auc': 0.814473659706042, 'prauc': 0.8231374853157393}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 82.55it/s, loss=0.2991]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.11it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.00it/s]


Validation: {'precision': 0.7587082728564634, 'recall': 0.6556914393206156, 'f1': 0.7034482708862266, 'auc': 0.8073019950327891, 'prauc': 0.812376303563308}
Test:      {'precision': 0.769424460428887, 'recall': 0.6707431796780472, 'f1': 0.7167029603418983, 'auc': 0.8123933243638973, 'prauc': 0.8224364146249481}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8495120698466898, 'recall': 0.5186578864831651, 'f1': 0.6440809921746177, 'auc': 0.824347632828283, 'prauc': 0.8369469438696044}
Corresponding test performance:
{'precision': 0.8383084577072721, 'recall': 0.528378802130673, 'f1': 0.6482015724810293, 'auc': 0.8186464607723077, 'prauc': 0.8316005309696537}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 81.11it/s, loss=0.7044]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.85it/s]


Validation: {'precision': 0.6041711002328066, 'recall': 0.8902477265572586, 'f1': 0.7198275813882987, 'auc': 0.7512075663811109, 'prauc': 0.7602884881890424}
Test:      {'precision': 0.6068848278780135, 'recall': 0.8955785512671823, 'f1': 0.7234958786533205, 'auc': 0.7499887001213543, 'prauc': 0.7553733071524591}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 82.36it/s, loss=0.6095]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.89it/s]


Validation: {'precision': 0.6912028725293501, 'recall': 0.7243650047013974, 'f1': 0.7073954933928609, 'auc': 0.7763731744083037, 'prauc': 0.7809376081665497}
Test:      {'precision': 0.6849752978765331, 'recall': 0.7391031671347159, 'f1': 0.7110105530744603, 'auc': 0.7762904385913254, 'prauc': 0.7854079259905636}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 82.56it/s, loss=0.5776]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.92it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.78it/s]


Validation: {'precision': 0.8052872394468465, 'recall': 0.49670743179524396, 'f1': 0.6144297858137733, 'auc': 0.7831656379683084, 'prauc': 0.7944085569484074}
Test:      {'precision': 0.7874810701626074, 'recall': 0.48918156161652815, 'f1': 0.6034816200288609, 'auc': 0.7760128491448584, 'prauc': 0.7872392592048586}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 80.93it/s, loss=0.5440]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.71it/s]


Validation: {'precision': 0.7245247740706622, 'recall': 0.7290686735630948, 'f1': 0.7267896167545761, 'auc': 0.8054915120454214, 'prauc': 0.8164791115957505}
Test:      {'precision': 0.7198156682005535, 'recall': 0.7347130761971317, 'f1': 0.727188076934953, 'auc': 0.8008041285578889, 'prauc': 0.8113270963808559}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.24it/s, loss=0.5099]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.49it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.57it/s]


Validation: {'precision': 0.740247383441991, 'recall': 0.7318908748801133, 'f1': 0.7360454065399403, 'auc': 0.81272626013665, 'prauc': 0.8238574176161562}
Test:      {'precision': 0.733021806851299, 'recall': 0.7378488554382633, 'f1': 0.7354274055306511, 'auc': 0.8099955857489655, 'prauc': 0.82124217994547}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 82.14it/s, loss=0.4715]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.56it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.10it/s]


Validation: {'precision': 0.7584308327571975, 'recall': 0.6911257447454026, 'f1': 0.7232157456236649, 'auc': 0.8060129194908524, 'prauc': 0.8132956296025577}
Test:      {'precision': 0.7498309668669715, 'recall': 0.6955158356829868, 'f1': 0.7216528337878594, 'auc': 0.8046907834767083, 'prauc': 0.8104787002834565}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 82.82it/s, loss=0.4549]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.54it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.79it/s]


Validation: {'precision': 0.7800074321784467, 'recall': 0.6582000627135208, 'f1': 0.7139455732647294, 'auc': 0.8174104371312433, 'prauc': 0.8288056903229332}
Test:      {'precision': 0.7839668799368311, 'recall': 0.6531828159277103, 'f1': 0.7126240114604541, 'auc': 0.8161396997201958, 'prauc': 0.8264232531476481}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.79it/s, loss=0.4261]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.83it/s]


Validation: {'precision': 0.7391716724605149, 'recall': 0.7331451865765659, 'f1': 0.7361460907156503, 'auc': 0.8176788019615248, 'prauc': 0.8240145545985252}
Test:      {'precision': 0.738744451487829, 'recall': 0.7306365631836607, 'f1': 0.7346681331028334, 'auc': 0.8153012940248557, 'prauc': 0.8232476043516996}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.72it/s, loss=0.3916]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 322.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.09it/s]


Validation: {'precision': 0.712750918855774, 'recall': 0.790529946689274, 'f1': 0.7496283030694373, 'auc': 0.8147319632158326, 'prauc': 0.8183328212554174}
Test:      {'precision': 0.7076006806559626, 'recall': 0.7823769206623319, 'f1': 0.7431124298577368, 'auc': 0.8076451103235591, 'prauc': 0.8149529719241984}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 81.64it/s, loss=0.3570]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 316.18it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.07it/s]


Validation: {'precision': 0.6935270805794095, 'recall': 0.8231420507970425, 'f1': 0.7527960948336129, 'auc': 0.8039355787763307, 'prauc': 0.798663008535198}
Test:      {'precision': 0.6910112359532076, 'recall': 0.8099717779842899, 'f1': 0.7457773878399943, 'auc': 0.8001904620970863, 'prauc': 0.7980951866390886}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 80.95it/s, loss=0.3498]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.45it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.90it/s]


Validation: {'precision': 0.7477995937686263, 'recall': 0.6926936343659684, 'f1': 0.7191925719217916, 'auc': 0.8040386646302689, 'prauc': 0.8081153276194301}
Test:      {'precision': 0.7323383084552825, 'recall': 0.6923800564418552, 'f1': 0.7117988344600522, 'auc': 0.802796178879344, 'prauc': 0.8097582154172291}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 81.64it/s, loss=0.3080]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 316.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.06it/s]


Validation: {'precision': 0.7259370314820812, 'recall': 0.7591721542779581, 'f1': 0.7421827049940978, 'auc': 0.8085893625245144, 'prauc': 0.8046123375889568}
Test:      {'precision': 0.7228699551547897, 'recall': 0.7582314205056186, 'f1': 0.7401285533109613, 'auc': 0.8079729326381265, 'prauc': 0.805884784782996}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7800074321784467, 'recall': 0.6582000627135208, 'f1': 0.7139455732647294, 'auc': 0.8174104371312433, 'prauc': 0.8288056903229332}
Corresponding test performance:
{'precision': 0.7839668799368311, 'recall': 0.6531828159277103, 'f1': 0.7126240114604541, 'auc': 0.8161396997201958, 'prauc': 0.8264232531476481}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 80.75it/s, loss=0.6782]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.43it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.36it/s]


Validation: {'precision': 0.7257810034688363, 'recall': 0.7212292254602658, 'f1': 0.723497950329639, 'auc': 0.7883254567100201, 'prauc': 0.7929704735365867}
Test:      {'precision': 0.7131948686600086, 'recall': 0.7322044528042264, 'f1': 0.7225746507312609, 'auc': 0.7858507902616689, 'prauc': 0.7946047370677148}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 80.50it/s, loss=0.5762]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.63it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.50it/s]


Validation: {'precision': 0.7180576631237692, 'recall': 0.7419253684517343, 'f1': 0.7297964169608374, 'auc': 0.7956346556957596, 'prauc': 0.7984137925109932}
Test:      {'precision': 0.7143281807350942, 'recall': 0.7535277516439213, 'f1': 0.7334045425368075, 'auc': 0.7908049691279262, 'prauc': 0.7938095192714885}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 82.02it/s, loss=0.5555]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.08it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.29it/s]


Validation: {'precision': 0.6519561068686738, 'recall': 0.8570084666012637, 'f1': 0.7405500560576717, 'auc': 0.8010693198112022, 'prauc': 0.8089526801375302}
Test:      {'precision': 0.6573426573410722, 'recall': 0.8548134211324716, 'f1': 0.7431842917026524, 'auc': 0.8009157180053625, 'prauc': 0.8113682035145968}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 82.19it/s, loss=0.5335]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 316.95it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.32it/s]


Validation: {'precision': 0.7322560202765138, 'recall': 0.7246785826255105, 'f1': 0.7284475915305422, 'auc': 0.80508685485562, 'prauc': 0.8179390332792336}
Test:      {'precision': 0.7347513293689873, 'recall': 0.7365945437418107, 'f1': 0.7356717770208797, 'auc': 0.8058853492669171, 'prauc': 0.82017379537891}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.54it/s, loss=0.5130]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 321.85it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.31it/s]


Validation: {'precision': 0.7754805948466613, 'recall': 0.6704296017539341, 'f1': 0.7191389119429112, 'auc': 0.8118230029954177, 'prauc': 0.8203062534510964}
Test:      {'precision': 0.7650409398335172, 'recall': 0.6738789589191788, 'f1': 0.716572185747923, 'auc': 0.8101450763484228, 'prauc': 0.8183470950317174}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 82.10it/s, loss=0.4801]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.11it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.26it/s]


Validation: {'precision': 0.7503351206409171, 'recall': 0.7021009720893632, 'f1': 0.725417134157546, 'auc': 0.8010261664249776, 'prauc': 0.8125807490113361}
Test:      {'precision': 0.7329564349825974, 'recall': 0.6911257447454026, 'f1': 0.7114267219226118, 'auc': 0.7987862067993559, 'prauc': 0.8137054783678046}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 81.02it/s, loss=0.4631]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 323.42it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.98it/s]


Validation: {'precision': 0.7202363367777836, 'recall': 0.7645029789878818, 'f1': 0.7417097607460273, 'auc': 0.8127983499029275, 'prauc': 0.8190849968101703}
Test:      {'precision': 0.7160782025073941, 'recall': 0.7695202257736924, 'f1': 0.741837963565292, 'auc': 0.8100663043658805, 'prauc': 0.8201797459591431}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.12it/s, loss=0.4417]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.91it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.96it/s]


Validation: {'precision': 0.7441244620961797, 'recall': 0.7049231734063816, 'f1': 0.7239935537774952, 'auc': 0.8101995515062039, 'prauc': 0.8151617001572766}
Test:      {'precision': 0.7410141206651444, 'recall': 0.7240514267772843, 'f1': 0.7324345707318917, 'auc': 0.8098692485756868, 'prauc': 0.8188356976993639}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 81.54it/s, loss=0.4173]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.48it/s]


Validation: {'precision': 0.7305444478599491, 'recall': 0.7447475697687528, 'f1': 0.7375776347497257, 'auc': 0.8086949602169988, 'prauc': 0.8152402673956889}
Test:      {'precision': 0.7252279635236315, 'recall': 0.7481969269339975, 'f1': 0.7365334106495048, 'auc': 0.811273453539882, 'prauc': 0.8184362289232683}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 82.56it/s, loss=0.4027]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.58it/s]


Validation: {'precision': 0.8063801506388729, 'recall': 0.5707118218859495, 'f1': 0.6683804578689168, 'auc': 0.8087119402455714, 'prauc': 0.8163029859852434}
Test:      {'precision': 0.7980017376159948, 'recall': 0.5760426465958732, 'f1': 0.6690948776630915, 'auc': 0.8102583771360923, 'prauc': 0.8170003089952736}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7754805948466613, 'recall': 0.6704296017539341, 'f1': 0.7191389119429112, 'auc': 0.8118230029954177, 'prauc': 0.8203062534510964}
Corresponding test performance:
{'precision': 0.7650409398335172, 'recall': 0.6738789589191788, 'f1': 0.716572185747923, 'auc': 0.8101450763484228, 'prauc': 0.8183470950317174}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 81.58it/s, loss=0.6891]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.90it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.37it/s]


Validation: {'precision': 0.7980072463731974, 'recall': 0.5525243022873862, 'f1': 0.6529553407251272, 'auc': 0.7875873780715893, 'prauc': 0.7913531807991625}
Test:      {'precision': 0.7867583834875892, 'recall': 0.5738476011270811, 'f1': 0.6636446007410606, 'auc': 0.7832471874878886, 'prauc': 0.782370790665681}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 81.11it/s, loss=0.6146]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.96it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.55it/s]


Validation: {'precision': 0.7920792079173824, 'recall': 0.5769833803682127, 'f1': 0.6676342476609111, 'auc': 0.7908673870820565, 'prauc': 0.7985715291235543}
Test:      {'precision': 0.7785234899296203, 'recall': 0.5820006271540232, 'f1': 0.6660685398713576, 'auc': 0.7873405622356545, 'prauc': 0.7958949306000732}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 81.87it/s, loss=0.5621]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.60it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.52it/s]


Validation: {'precision': 0.7239185750613107, 'recall': 0.71370335528155, 'f1': 0.718774667349159, 'auc': 0.7967322286669302, 'prauc': 0.8130244227500474}
Test:      {'precision': 0.7256214149116456, 'recall': 0.7140169332056632, 'f1': 0.71977239898098, 'auc': 0.7952172571545346, 'prauc': 0.8101997956483213}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 82.34it/s, loss=0.5293]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.58it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.21it/s]


Validation: {'precision': 0.7082981715873184, 'recall': 0.7895892129169345, 'f1': 0.7467378360564029, 'auc': 0.8080782536171229, 'prauc': 0.8192195550415076}
Test:      {'precision': 0.7025298860141714, 'recall': 0.792411414233953, 'f1': 0.7447686363362449, 'auc': 0.807372503897074, 'prauc': 0.8219056269515133}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 81.85it/s, loss=0.5226]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.14it/s]


Validation: {'precision': 0.7310527867864021, 'recall': 0.7773596738765213, 'f1': 0.7534954357319052, 'auc': 0.825403258095731, 'prauc': 0.8356184522784769}
Test:      {'precision': 0.7200699096978734, 'recall': 0.7751646284077292, 'f1': 0.7466022299788615, 'auc': 0.8187734019480085, 'prauc': 0.8340712474415439}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 82.61it/s, loss=0.4599]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.87it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.64it/s]


Validation: {'precision': 0.7435324177555301, 'recall': 0.7300094073354343, 'f1': 0.7367088557575835, 'auc': 0.8152406105214487, 'prauc': 0.8263462367625716}
Test:      {'precision': 0.741782553727112, 'recall': 0.7359673878935843, 'f1': 0.7388635240391492, 'auc': 0.8139438992483693, 'prauc': 0.8280436480836628}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 82.20it/s, loss=0.4434]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.76it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 322.50it/s]


Validation: {'precision': 0.7408748114608119, 'recall': 0.7701473816219186, 'f1': 0.755227547275077, 'auc': 0.8288577896483919, 'prauc': 0.8382144020596137}
Test:      {'precision': 0.7326997326975566, 'recall': 0.7735967387871634, 'f1': 0.7525930395407432, 'auc': 0.8238216038579649, 'prauc': 0.8377983101253612}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 81.49it/s, loss=0.4226]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 322.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 321.53it/s]


Validation: {'precision': 0.8380864765370833, 'recall': 0.5713389777341759, 'f1': 0.6794704408226561, 'auc': 0.8296376652802323, 'prauc': 0.840215561367446}
Test:      {'precision': 0.8382616487417641, 'recall': 0.5867042960157206, 'f1': 0.690278541546932, 'auc': 0.8282272478579304, 'prauc': 0.8426821693031619}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 82.17it/s, loss=0.3872]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.97it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.90it/s]


Validation: {'precision': 0.8308128544384178, 'recall': 0.5512699905909336, 'f1': 0.6627709659843319, 'auc': 0.8217792276779691, 'prauc': 0.8302777539842163}
Test:      {'precision': 0.8410997204061459, 'recall': 0.5660081530242521, 'f1': 0.6766635378314915, 'auc': 0.8283253479179784, 'prauc': 0.8402224741388288}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 82.23it/s, loss=0.3748]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 320.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 319.36it/s]


Validation: {'precision': 0.6866716679152621, 'recall': 0.8610849796147347, 'f1': 0.7640511914998754, 'auc': 0.825644595543255, 'prauc': 0.8308749706730807}
Test:      {'precision': 0.6885204722414255, 'recall': 0.859517089994169, 'f1': 0.7645746115163364, 'auc': 0.8309527584538947, 'prauc': 0.8411784978162753}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 79.28it/s, loss=0.3258]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 317.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 320.35it/s]


Validation: {'precision': 0.8137168141556915, 'recall': 0.5766698024440995, 'f1': 0.6749862311494634, 'auc': 0.8116562671527175, 'prauc': 0.8194730555594634}
Test:      {'precision': 0.8089935760136661, 'recall': 0.5923486986497575, 'f1': 0.6839246873690181, 'auc': 0.8101347579737128, 'prauc': 0.8210057810222049}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 81.72it/s, loss=0.3131]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 319.96it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 318.78it/s]


Validation: {'precision': 0.725076452597171, 'recall': 0.7434932580723002, 'f1': 0.7341693710629192, 'auc': 0.8111218986203877, 'prauc': 0.8132447702401917}
Test:      {'precision': 0.7299999999977879, 'recall': 0.7554092191886002, 'f1': 0.7424872811757812, 'auc': 0.8154954304700095, 'prauc': 0.821061593542323}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 82.41it/s, loss=0.2828]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 318.79it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 317.45it/s]


Validation: {'precision': 0.6786079836216004, 'recall': 0.8316086547480979, 'f1': 0.7473580336570754, 'auc': 0.8004029290448815, 'prauc': 0.7977663463071984}
Test:      {'precision': 0.6891578676263478, 'recall': 0.8391345249268136, 'f1': 0.7567873253627527, 'auc': 0.8125188562006134, 'prauc': 0.8117823903357132}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8380864765370833, 'recall': 0.5713389777341759, 'f1': 0.6794704408226561, 'auc': 0.8296376652802323, 'prauc': 0.840215561367446}
Corresponding test performance:
{'precision': 0.8382616487417641, 'recall': 0.5867042960157206, 'f1': 0.690278541546932, 'auc': 0.8282272478579304, 'prauc': 0.8426821693031619}


Epoch 001:   9%|▉         | 9/98 [00:00<00:01, 83.19it/s, loss=0.7705]

In [None]:
# Report every metric stored in `final_metrics` as "mean ± std".
# NOTE(review): presumably each entry holds one value per training run — confirm
# against the cell that fills `final_metrics`.
print("\nFinal Metrics:")
for metric_name in final_metrics.keys():
    collected = final_metrics[metric_name]
    # np.std is called with its defaults, i.e. the population form (ddof=0).
    avg, spread = np.mean(collected), np.std(collected)
    print(f"{metric_name}: {avg:.4f} ± {spread:.4f}")


Final Metrics:
precision: 0.7711 ± 0.0397
recall: 0.6721 ± 0.0880
f1: 0.7122 ± 0.0370
auc: 0.8141 ± 0.0069
prauc: 0.8249 ± 0.0061
