In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
# Fix the random seed once, before any data loading or model construction.
# NOTE(review): presumably seeds torch / numpy / random — confirm in heterogt.utils.seed.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment settings. `task_index` selects the active entry of `tasks`;
# later cells attach derived values (vocabulary sizes, max admissions) to this object.
config = Namespace()
config.dataset = "MIMIC-III"
config.tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"]
config.task_index = 1
config.token_type = ["diag", "med", "lab", "pro"]
config.special_tokens = ["[PAD]"]
config.batch_size = 32
config.lr = 1e-3
config.epochs = 500
config.early_stop_patience = 5

In [5]:
# The full cohort file is only read to build the tokenizer vocabularies.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
# The two next-diagnosis tasks have dedicated files; every other task shares
# the generic downstream file.
finetune_data_path = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}.get(curr_task, f"./data_process/{config.dataset}-processed/mimic_downstream.pkl")

Current task: readmission


In [6]:
# Full processed cohort: used to build the tokenizer vocabularies and to find
# the longest admission history.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
with open(full_data_path, "rb") as fh:  # context manager closes the handle
    ehr_full_data = pickle.load(fh)
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# One "<age>_<gender>" token for every distinct age crossed with M/F, plus padding.
# NOTE(review): ages are enumerated in set-iteration order; if the tokenizer
# assigns ids by position, that order determines the ids — confirm.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type
# Largest number of distinct admissions for any patient. Cast to a plain int so
# later size arguments (embedding tables, padding lengths) get a Python integer
# instead of a NumPy scalar.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# The task names double as the tokenizer's task vocabulary.
task_sentences = config.tasks
tokenizer = EHRTokenizer(
    token_type_sentences,
    age_gender_sentences,
    task_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)
# Record the vocabulary sizes on the config object.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # diagnosis vocabulary only
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
# Load the pre-split (train, val, test) frames for the selected task.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
with open(finetune_data_path, "rb") as fh:  # context manager closes the handle
    train_data, val_data, test_data = pickle.load(fh)

# Positive-class rate of each binary label in the test split.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in the fine-tuning dataset for the selected task
# (built in train, val, test order).
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split, tokenizer, token_type=config.token_type, task=curr_task)
    for split in (train_data, val_data, test_data)
)

In [10]:
# One loader per split, each with its own collate function. Only the training
# loader shuffles; validation and test keep dataset order.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(
        dataset,
        collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
        shuffle=do_shuffle,
        batch_size=config.batch_size,
    )
    for dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [11]:
# Model selection uses PR-AUC for every task; only the task type and the loss
# callable depend on which task is active.
eval_metric = "prauc"
if curr_task in ("death", "stay", "readmission"):
    task_type, loss_fn = "binary", F.binary_cross_entropy_with_logits
else:
    # Next-diagnosis tasks: same BCE-with-logits loss, wrapped as a two-argument callable.
    task_type, loss_fn = "l2r", (lambda x, y: F.binary_cross_entropy_with_logits(x, y))

In [12]:
# Pull one collated batch from the training loader to sanity-check tensor shapes.
# The collate function yields six items, unpacked here in order.
# NOTE(review): the loader shuffles, so drawing this batch presumably advances
# the global RNG state before training starts — confirm if exact reproducibility matters.
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age/Sex IDs shape:", age_gender_ids.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 1
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    """Two stacked heterogeneous GAT layers over a visit / diagnosis-occurrence graph.

    Every layer holds one ``GATConv`` per edge type (visit->occ, occ->visit,
    visit->visit) inside a ``HeteroConv`` configured with ``aggr='mean'``.
    After the second layer a single shared linear projection is applied to
    both node types.

    NOTE(review): ``dropout`` is accepted but never used, and ``self.act`` is
    created but never applied in ``forward``; the two conv layers are stacked
    with no non-linearity in between. Confirm whether this is intentional.
    """

    # Edge types in the order their convolutions are created and registered.
    _EDGE_TYPES = (
        ('visit', 'contains', 'occ'),
        ('occ', 'contained_by', 'visit'),
        ('visit', 'next', 'visit'),
    )

    def __init__(self, d_model: int, heads: int = 1, dropout: float = 0.0):
        super().__init__()
        self.act = nn.GELU()
        # Layer 1, layer 2, then the output projection (creation order preserved).
        self.conv1 = self._make_layer(d_model, heads)
        self.conv2 = self._make_layer(d_model, heads)
        self.lin = nn.Linear(d_model, d_model)

    @classmethod
    def _make_layer(cls, d_model, heads):
        """Build one HeteroConv with a d_model -> d_model GATConv per edge type."""
        per_relation = {
            edge_type: GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False)
            for edge_type in cls._EDGE_TYPES
        }
        return HeteroConv(per_relation, aggr='mean')

    def forward(self, hg):
        """Return ``{'visit': [N_visit, d], 'occ': [N_occ, d]}`` for graph ``hg``."""
        edge_index_dict = hg.edge_index_dict
        node_feats = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # Two rounds of message passing, back to back.
        node_feats = self.conv1(node_feats, edge_index_dict)
        node_feats = self.conv2(node_feats, edge_index_dict)

        # Shared output projection for every node type.
        return {node_type: self.lin(feat) for node_type, feat in node_feats.items()}

In [15]:
# Prediction head that emits one logit per label.
class MultiPredictionHead(nn.Module):
    """Two-layer MLP: ``hidden_size -> hidden_size -> label_size`` with a ReLU in between."""

    def __init__(self, hidden_size, label_size):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, label_size),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        """Map ``[..., hidden_size]`` features to ``[..., label_size]`` logits."""
        logits = self.cls(input)
        return logits
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP: ``hidden_size -> hidden_size -> 1`` with a ReLU in between."""

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        """Map ``[..., hidden_size]`` features to a single logit ``[..., 1]``."""
        logit = self.cls(input)
        return logit

In [16]:
# Find the first training sample whose age/gender row has more than three
# entries and keep it as a worked example. If no sample qualifies, the loop
# runs to the end and the last sample is used.
for i in range(len(train_dataset)):
    example = train_dataset[i]  # fetched once and reused below
    age_gender_ids = example[3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
exp_i = i

# Take the first row of each of the four fields and append trailing zeros
# (5 for the token-level sequences, 3 for the age/gender sequence).
# NOTE(review): the zeros presumably imitate batch padding — confirm.
id_seq, type_seq, visit_seq, age_sex = (
    torch.concat([example[field][0], torch.zeros(n_pad, dtype=example[field][0].dtype)], dim=0)
    for field, n_pad in enumerate((5, 5, 5, 3))
)

tensor([[19, 19, 19, 33]])


In [17]:
class HeteroGT(nn.Module):
    """Transformer encoder over EHR token sequences with optional GNN visit nodes.

    Each token is embedded as the sum of its token, admission-index and
    token-type embeddings. When ``use_hetero_graph`` is true, one "visit"
    vector per admission is computed by ``DiseaseOccHetGNN`` on a per-patient
    visit/diagnosis graph and appended to the sequence. A task embedding is
    prepended to the sequence, and the encoder output at that position is fed
    to the prediction head.
    """

    def __init__(self, tokenizer, d_model, num_heads, num_layers, max_num_adms, device, task, use_hetero_graph):
        """
        Args:
            tokenizer: project tokenizer; supplies vocabulary sizes and PAD ids.
            d_model: embedding / hidden width.
            num_heads: attention heads per encoder layer.
            num_layers: number of encoder layers.
            max_num_adms: number of visit slots appended per patient; also
                sizes the admission-index embedding table.
            device: device used for tensors created inside the model.
            task: task name; "death", "stay" and "readmission" get a binary
                head, anything else a head over the diagnosis vocabulary.
            use_hetero_graph: whether to append GNN visit nodes in ``forward``.
        """
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.use_hetero_graph = use_hetero_graph
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_sex_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0]
        self.adm_pad_id = 0
        self.age_sex_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0]
        # NOTE(review): hard-coded type ids. 1 is assumed to be the "diag" type
        # and 5 the extra slot reserved for visit nodes (the last row of
        # ``type_emb`` when the tokenizer has 5 types including PAD) — confirm.
        self.diag_type_id = 1
        self.visit_type_id = 5

        # Embedding tables.
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        # n_type already contains PAD; the +1 row is for the visit-node type.
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id)
        # +1 row for the padding admission index.
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id)
        self.age_sex_emb = nn.Embedding(self.age_sex_vocab_size, d_model, padding_idx=self.age_sex_pad_id)
        # One vector per task (5 tasks); prepended to the sequence in forward()
        # and used as the pooled position for classification.
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)

        # Heterogeneous GNN that produces the visit-node representations.
        self.het_gnn = DiseaseOccHetGNN(d_model)

        # Transformer encoder.
        self.num_attn_heads = num_heads
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=self.num_attn_heads, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers, enable_nested_tensor=False)

        # Prediction head. The label count comes from the tokenizer's diagnosis
        # vocabulary, so the class does not depend on a notebook-level global.
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, len(self.tokenizer.diag_voc.id2word))

    def visit_segment(self, B, input_ids, token_types, adm_index, age_gender_index):
        """Run the GNN on every patient and pack the visit nodes into fixed-size slots.

        Returns:
            visit_emb_pad:      [B, max_num_adms, d_model] visit vectors, zero where padded.
            visit_pad_mask:     [B, max_num_adms] bool, True where padded.
            visit_index_pad:    [B, max_num_adms] 1-based visit position, ``adm_pad_id`` where padded.
            visit_type_ids_pad: [B, max_num_adms] ``visit_type_id`` on real slots, 0 where padded.
        """
        graphs = [
            self.build_patient_graph(input_ids[p], token_types[p], adm_index[p], age_gender_index[p])
            for p in range(B)
        ]

        batch_graph = HeteroBatch.from_data_list(graphs).to(self.device)
        h_visit_all = self.het_gnn(batch_graph)['visit']  # visit nodes of all patients, concatenated

        V = self.max_num_adms
        visit_emb_pad = h_visit_all.new_zeros((B, V, self.d_model))
        visit_pad_mask = torch.ones((B, V), dtype=torch.bool, device=self.device)
        visit_index_pad = torch.full((B, V), self.adm_pad_id, dtype=torch.long, device=self.device)

        # Batched visit nodes keep per-patient order, so each patient's rows are
        # a contiguous slice. Patients with more than V visits are truncated.
        offset = 0
        for p, graph in enumerate(graphs):
            n_v = graph['visit'].num_nodes
            keep = min(n_v, V)
            visit_emb_pad[p, :keep] = h_visit_all[offset:offset + keep]
            visit_pad_mask[p, :keep] = False
            visit_index_pad[p, :keep] = torch.arange(1, keep + 1, device=self.device)
            offset += n_v

        # Real slots get the visit type id; padded slots get 0.
        visit_type_ids_pad = (~visit_pad_mask).to(torch.long) * self.visit_type_id

        return visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad

    def build_patient_graph(self, id_seq: torch.Tensor, type_seq: torch.Tensor, visit_seq: torch.Tensor, age_sex: torch.Tensor):
        """Build the heterogeneous graph of a single patient.

        Node types: ``visit`` (one per distinct admission index among the
        non-pad tokens, initialised from the age/gender embedding) and ``occ``
        (one per diagnosis token, initialised from the token embedding).
        Edge types: visit->occ, occ->visit, and visit->next visit.
        """
        hg = HeteroData()
        nonpad = id_seq != self.seq_pad_id
        occ_mask = (type_seq == self.diag_type_id) & nonpad        # diagnosis tokens
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1)  # their positions, [N_occ]
        N_occ = occ_pos.numel()

        # Visit nodes: map each non-pad token's admission index to a local id 0..N_visit-1.
        visit_ids_unique, visit_lid_nonpad = torch.unique(visit_seq[nonpad], return_inverse=True)
        visit_lid_full = torch.full_like(id_seq, fill_value=-1)     # -1 marks padding positions
        visit_lid_full[nonpad] = visit_lid_nonpad
        N_visit = visit_ids_unique.numel()
        age_sex_nonpad = age_sex[age_sex != self.age_sex_pad_id]
        assert N_visit == len(age_sex_nonpad), (
            f"found {N_visit} visits but {len(age_sex_nonpad)} age/gender entries"
        )
        hg['visit'].x = self.age_sex_emb(age_sex_nonpad.to(self.device))
        hg['visit'].num_nodes = N_visit

        # Diagnosis-occurrence nodes.
        hg['occ'].x = self.token_emb(id_seq[occ_pos])              # [N_occ, d]
        hg['occ'].num_nodes = N_occ

        # Edges between each diagnosis occurrence and the visit it belongs to.
        occ_visit_lid = visit_lid_full[occ_pos]
        occ_ids = torch.arange(N_occ, device=self.device)
        hg['visit', 'contains', 'occ'].edge_index = torch.stack([occ_visit_lid, occ_ids], dim=0)
        hg['occ', 'contained_by', 'visit'].edge_index = torch.stack([occ_ids, occ_visit_lid], dim=0)

        # Forward-in-time edges between consecutive visits.
        if N_visit > 1:
            src = torch.arange(0, N_visit - 1, device=self.device)
            dst = torch.arange(1, N_visit, device=self.device)
            e_next = torch.stack([src, dst], dim=0)                 # [2, N_visit-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit', 'next', 'visit'].edge_index = e_next
        return hg

    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        """Build a boolean attention mask from token types.

        Args:
            token_types: [B, L] integer type id per position.
            forbid_map: ``None`` for no restriction, or a mapping
                ``{query_type: iterable of key_types}``; a query of type ``q``
                is then blocked from attending to keys of each listed type
                (one direction only).
            num_heads: number of attention heads.

        Returns:
            Bool tensor of shape [B * num_heads, L, L]; True means "blocked".
        """
        B, L = token_types.shape
        device = token_types.device

        if forbid_map is None:
            mask = torch.zeros((B, L, L), dtype=torch.bool, device=device)
        else:
            # Sorted list of every type that appears in the batch or in the map.
            extra_types = [t for q_t, ks in forbid_map.items() for t in (q_t, *ks)]
            type_list = torch.unique(torch.cat([
                token_types.reshape(-1),
                torch.tensor(extra_types, dtype=token_types.dtype, device=device),
            ]))
            t2i = {t.item(): i for i, t in enumerate(type_list)}
            T = len(type_list)

            # ban_table[q, k] is True when query type q may not attend to key type k.
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                if q_t not in t2i:
                    continue
                for k_t in ks:
                    if k_t in t2i:
                        ban_table[t2i[q_t], t2i[k_t]] = True

            # type_list is sorted and contains every observed type, so
            # searchsorted returns the exact row/column index of each token.
            type_idx = torch.searchsorted(type_list, token_types.contiguous())  # [B, L]
            mask = ban_table[type_idx.unsqueeze(-1), type_idx.unsqueeze(-2)]    # [B, L, L]

        # Repeat the per-sample mask for every head (batch-major layout).
        return mask.unsqueeze(1).expand(B, num_heads, L, L).reshape(B * num_heads, L, L)

    def forward(self, input_ids, token_types, adm_index, age_gender_index, task_id):
        """Return logits of shape [B, label_size] (label_size is 1 for the binary head).

        Args:
            input_ids, token_types, adm_index: [B, L] integer tensors.
            age_gender_index: per-patient age/gender ids, one per admission plus padding.
            task_id: integer id of the task; selects the task embedding.
        """
        B, L = input_ids.shape
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)
        token_embed = self.token_emb(input_ids)            # [B, L, d]
        seq_pad_mask = (input_ids == self.seq_pad_id)      # [B, L]

        forbid_map = None
        if self.use_hetero_graph:
            # Append the visit nodes (and their ids / masks) after the token sequence.
            visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad = self.visit_segment(
                B, input_ids, token_types, adm_index, age_gender_index
            )
            token_embed = torch.concat([token_embed, visit_emb_pad], dim=1)        # [B, L+V, d]
            adm_index = torch.concat([adm_index, visit_index_pad], dim=1)          # [B, L+V]
            token_types = torch.concat([token_types, visit_type_ids_pad], dim=1)   # [B, L+V]
            seq_pad_mask = torch.concat([seq_pad_mask, visit_pad_mask], dim=1)     # [B, L+V]

        adm_emb = self.adm_index_emb(adm_index)            # [B, L+V, d]
        token_type_emb = self.type_emb(token_types)        # [B, L+V, d]
        x = token_embed + adm_emb + token_type_emb         # [B, L+V, d]
        task_id_emb = self.task_emb(task_id).unsqueeze(1)  # [B, 1, d]
        x = torch.concat([task_id_emb, x], dim=1)          # [B, 1+L+V, d]

        # Masks: the task position is never padding and carries the pseudo type -1.
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)
        src_key_padding_mask = torch.concat([task_pad_mask, seq_pad_mask], dim=1)  # [B, 1+L+V]
        types_with_task = torch.concat([token_types.new_full((B, 1), -1), token_types], dim=1)
        attn_mask = self.build_attn_mask(types_with_task,
                                         forbid_map=forbid_map,
                                         num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"

        # Transformer encoding (batch_first=True).
        h = self.encoder(src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask)  # [B, 1+L(+V), d]

        # Classify from the task position (index 0).
        logits = self.cls_head(h[:, 0, :])  # [B, label_size]
        return logits

In [18]:
# Train five freshly initialised models and collect, for each run, the test
# metrics returned by the early-stopping trainer.
final_metrics = {name: [] for name in ("precision", "recall", "f1", "auc", "prauc")}
for i in range(5):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=True,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    # Append this run's value to the history of each tracked metric.
    for key, history in final_metrics.items():
        history.append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 26.55it/s, loss=0.6061]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.06it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.78it/s]


Validation: {'precision': 0.6820234868984552, 'recall': 0.5968379446616726, 'f1': 0.6365935869250661, 'auc': 0.7930991070121505, 'prauc': 0.6964381278551244}
Test:      {'precision': 0.6807174887861852, 'recall': 0.5997629395472155, 'f1': 0.6376811544375963, 'auc': 0.7949989962786389, 'prauc': 0.6890966176865467}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.19it/s, loss=0.5570]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.43it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.90it/s]


Validation: {'precision': 0.680783242255552, 'recall': 0.5909090909067554, 'f1': 0.6326703293430747, 'auc': 0.7995532969446013, 'prauc': 0.7068188859202365}
Test:      {'precision': 0.680449438199189, 'recall': 0.5981825365444561, 'f1': 0.6366694651609981, 'auc': 0.7959091521690801, 'prauc': 0.6913205843824715}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s, loss=0.5483]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.95it/s]


Validation: {'precision': 0.6450723638847516, 'recall': 0.7399209486136762, 'f1': 0.6892488904553744, 'auc': 0.8008069305895393, 'prauc': 0.7081839619407706}
Test:      {'precision': 0.6487330787898343, 'recall': 0.7384433030393582, 'f1': 0.6906873564374282, 'auc': 0.7940805153127458, 'prauc': 0.6874282907405301}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.40it/s, loss=0.5334]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.14it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.54it/s]


Validation: {'precision': 0.6083357966284116, 'recall': 0.8134387351746505, 'f1': 0.6960933487292983, 'auc': 0.7993533680490202, 'prauc': 0.7065449783800911}
Test:      {'precision': 0.6183228886880981, 'recall': 0.8186487554294009, 'f1': 0.7045222664310581, 'auc': 0.7977306682061599, 'prauc': 0.6939169092254329}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.44it/s, loss=0.5200]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.95it/s]


Validation: {'precision': 0.6410169491503694, 'recall': 0.7474308300365714, 'f1': 0.6901459804283113, 'auc': 0.7992494824016564, 'prauc': 0.7015221692274259}
Test:      {'precision': 0.6498792687111077, 'recall': 0.7443698142997062, 'f1': 0.6939226469541108, 'auc': 0.7949645440795989, 'prauc': 0.6893325712005286}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.49it/s, loss=0.5014]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.18it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.31it/s]


Validation: {'precision': 0.6464471403790418, 'recall': 0.7371541501947149, 'f1': 0.6888273265055569, 'auc': 0.799739266369701, 'prauc': 0.7031867110738906}
Test:      {'precision': 0.6475724764210703, 'recall': 0.7325167917790102, 'f1': 0.6874304733256256, 'auc': 0.7919641659431529, 'prauc': 0.685730942932568}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.5003]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.27it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.11it/s]


Validation: {'precision': 0.6098370549166873, 'recall': 0.7988142292458545, 'f1': 0.6916495501868676, 'auc': 0.7914757826714348, 'prauc': 0.6887283028577764}
Test:      {'precision': 0.610215871083581, 'recall': 0.7929672066345597, 'f1': 0.6896907167319276, 'auc': 0.7867355996876474, 'prauc': 0.6788814285383911}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.4803]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.14it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.03it/s]


Validation: {'precision': 0.6267310789029735, 'recall': 0.7691699604712682, 'f1': 0.6906832248632749, 'auc': 0.7960231193926847, 'prauc': 0.6958785082107074}
Test:      {'precision': 0.6258945998678404, 'recall': 0.7601738443273008, 'f1': 0.6865298790765911, 'auc': 0.7874666355584603, 'prauc': 0.6801840994716148}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6450723638847516, 'recall': 0.7399209486136762, 'f1': 0.6892488904553744, 'auc': 0.8008069305895393, 'prauc': 0.7081839619407706}
Corresponding test performance:
{'precision': 0.6487330787898343, 'recall': 0.7384433030393582, 'f1': 0.6906873564374282, 'auc': 0.7940805153127458, 'prauc': 0.6874282907405301}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.53it/s, loss=0.6109]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.19it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.06it/s]


Validation: {'precision': 0.7762008733539716, 'recall': 0.28102766798307893, 'f1': 0.4126523466458171, 'auc': 0.7910674550891943, 'prauc': 0.6924969729540251}
Test:      {'precision': 0.7673216132288798, 'recall': 0.29316475701188005, 'f1': 0.4242424202395473, 'auc': 0.7974549982548758, 'prauc': 0.6956754043382274}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.74it/s, loss=0.5531]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.92it/s]


Validation: {'precision': 0.6843601895702164, 'recall': 0.5707509881400366, 'f1': 0.6224137881417323, 'auc': 0.794139322834975, 'prauc': 0.6985476236240861}
Test:      {'precision': 0.6867924528269491, 'recall': 0.5752666930044439, 'f1': 0.6261019086033273, 'auc': 0.7976067868947185, 'prauc': 0.6940337268621632}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 26.87it/s, loss=0.5417]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.97it/s]


Validation: {'precision': 0.6669371196727508, 'recall': 0.6498023715389336, 'f1': 0.6582582532564695, 'auc': 0.7996723445636489, 'prauc': 0.7050969682053583}
Test:      {'precision': 0.6671993607644139, 'recall': 0.659818253652075, 'f1': 0.663488274695572, 'auc': 0.799905293103916, 'prauc': 0.6969536243993214}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.58it/s, loss=0.5400]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.14it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.19it/s]


Validation: {'precision': 0.6355599214124572, 'recall': 0.7671936758862957, 'f1': 0.695200568107442, 'auc': 0.7969038208168643, 'prauc': 0.7007164031085862}
Test:      {'precision': 0.6442908851574719, 'recall': 0.7736072698507562, 'f1': 0.7030520596710224, 'auc': 0.79905445992094, 'prauc': 0.6956286106612977}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.51it/s, loss=0.5194]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.00it/s]


Validation: {'precision': 0.6020172910645475, 'recall': 0.8256916996014795, 'f1': 0.6963333284537345, 'auc': 0.797150385846038, 'prauc': 0.6975176221514661}
Test:      {'precision': 0.6059907834083929, 'recall': 0.8312919794514766, 'f1': 0.7009828370329024, 'auc': 0.7973448349922923, 'prauc': 0.6926662776765421}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.78it/s, loss=0.5088]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.11it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.98it/s]


Validation: {'precision': 0.6354860186396954, 'recall': 0.7545454545424722, 'f1': 0.6899168725188287, 'auc': 0.8016064370412196, 'prauc': 0.703846544942007}
Test:      {'precision': 0.6391580354138351, 'recall': 0.7558277360697123, 'f1': 0.692614042823922, 'auc': 0.794843228357448, 'prauc': 0.6888301896020392}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.46it/s, loss=0.4949]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.42it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.04it/s]


Validation: {'precision': 0.614791987671449, 'recall': 0.7885375494039979, 'f1': 0.6909090859833422, 'auc': 0.7947504025764895, 'prauc': 0.6981956253283012}
Test:      {'precision': 0.6164341085252204, 'recall': 0.7854602923714522, 'f1': 0.6907574655358865, 'auc': 0.7880190749992957, 'prauc': 0.6804483150399178}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.45it/s, loss=0.4734]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.81it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.91it/s]


Validation: {'precision': 0.6236807387842227, 'recall': 0.7474308300365714, 'f1': 0.6799712284075767, 'auc': 0.7880515820733212, 'prauc': 0.6791358395765101}
Test:      {'precision': 0.6271074380144559, 'recall': 0.7495061240586744, 'f1': 0.6828653658074079, 'auc': 0.7833693894620044, 'prauc': 0.6692574652234875}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6669371196727508, 'recall': 0.6498023715389336, 'f1': 0.6582582532564695, 'auc': 0.7996723445636489, 'prauc': 0.7050969682053583}
Corresponding test performance:
{'precision': 0.6671993607644139, 'recall': 0.659818253652075, 'f1': 0.663488274695572, 'auc': 0.799905293103916, 'prauc': 0.6969536243993214}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.50it/s, loss=0.5851]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.09it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.95it/s]


Validation: {'precision': 0.8294701986617637, 'recall': 0.19802371541423708, 'f1': 0.3197192055653233, 'auc': 0.7964881213794258, 'prauc': 0.7017704359252477}
Test:      {'precision': 0.785714285701814, 'recall': 0.19557487159148332, 'f1': 0.31319202464567125, 'auc': 0.799411390986069, 'prauc': 0.6937607142897819}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.51it/s, loss=0.5655]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.00it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.83it/s]


Validation: {'precision': 0.6527031865354361, 'recall': 0.7205533596809465, 'f1': 0.6849520896930816, 'auc': 0.7974311960181526, 'prauc': 0.7011449218843476}
Test:      {'precision': 0.6526277897744686, 'recall': 0.7163176610007258, 'f1': 0.6829911421169332, 'auc': 0.7981358218781517, 'prauc': 0.6928746529564623}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.56it/s, loss=0.5427]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.15it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.18it/s]


Validation: {'precision': 0.6263322883992906, 'recall': 0.7897233201549814, 'f1': 0.6986013936655241, 'auc': 0.7998910952171822, 'prauc': 0.7027775941139289}
Test:      {'precision': 0.6357868020284397, 'recall': 0.7917819043824901, 'f1': 0.7052613007056465, 'auc': 0.8000948849165654, 'prauc': 0.6953483001626659}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.42it/s, loss=0.5362]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.72it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.62it/s]


Validation: {'precision': 0.6414141414119818, 'recall': 0.7529644268744943, 'f1': 0.6927272677567539, 'auc': 0.7971836376184203, 'prauc': 0.7024512030839141}
Test:      {'precision': 0.6481167288746247, 'recall': 0.7546424338176427, 'f1': 0.6973347887466246, 'auc': 0.7983801288093372, 'prauc': 0.6959793696874785}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.48it/s, loss=0.5187]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.72it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.52it/s]


Validation: {'precision': 0.6415352981472189, 'recall': 0.7399209486136762, 'f1': 0.687224664626362, 'auc': 0.794684212727691, 'prauc': 0.69336399121253}
Test:      {'precision': 0.648052395723378, 'recall': 0.7427894112969468, 'f1': 0.6921943985552594, 'auc': 0.7925684407594438, 'prauc': 0.6808270299642386}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.51it/s, loss=0.5061]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.65it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.48it/s]


Validation: {'precision': 0.6236080178153879, 'recall': 0.7747035573091909, 'f1': 0.6909924152921508, 'auc': 0.7957531840140536, 'prauc': 0.7004598224874895}
Test:      {'precision': 0.6288025889947289, 'recall': 0.7676807585904083, 'f1': 0.6913360562460653, 'auc': 0.7918723806773218, 'prauc': 0.6875786007096216}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.34it/s, loss=0.4871]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.39it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.36it/s]


Validation: {'precision': 0.5813953488356186, 'recall': 0.8399209486132809, 'f1': 0.687146316909361, 'auc': 0.781050881485664, 'prauc': 0.6670907048016628}
Test:      {'precision': 0.5834490419311762, 'recall': 0.830106677199407, 'f1': 0.6852576598597261, 'auc': 0.7773388409379669, 'prauc': 0.6601293179216434}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.27it/s, loss=0.4626]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.90it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.32it/s]


Validation: {'precision': 0.6387240356059395, 'recall': 0.6806324110645035, 'f1': 0.6590126241644058, 'auc': 0.7873174289478637, 'prauc': 0.6825838743524302}
Test:      {'precision': 0.6439135381090485, 'recall': 0.6708810746713912, 'f1': 0.6571207380336163, 'auc': 0.7785956131774112, 'prauc': 0.6653020089088866}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6263322883992906, 'recall': 0.7897233201549814, 'f1': 0.6986013936655241, 'auc': 0.7998910952171822, 'prauc': 0.7027775941139289}
Corresponding test performance:
{'precision': 0.6357868020284397, 'recall': 0.7917819043824901, 'f1': 0.7052613007056465, 'auc': 0.8000948849165654, 'prauc': 0.6953483001626659}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.44it/s, loss=0.6030]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.17it/s]


Validation: {'precision': 0.6473087818673963, 'recall': 0.7225296442659189, 'f1': 0.6828539359912336, 'auc': 0.7930917874396134, 'prauc': 0.6929880222223104}
Test:      {'precision': 0.6494881750771285, 'recall': 0.7269853812693521, 'f1': 0.6860551777127692, 'auc': 0.7978002009118209, 'prauc': 0.6884988376987397}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.14it/s, loss=0.5497]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.30it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.13it/s]


Validation: {'precision': 0.6572504708073174, 'recall': 0.6897233201553766, 'f1': 0.6730954626955846, 'auc': 0.7951988309597005, 'prauc': 0.6917262986232744}
Test:      {'precision': 0.6597588545566702, 'recall': 0.6918214144579541, 'f1': 0.6754098310657825, 'auc': 0.7983861500903244, 'prauc': 0.6878645085801778}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.08it/s, loss=0.5406]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.11it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.02it/s]


Validation: {'precision': 0.6364532019683532, 'recall': 0.7660079051353123, 'f1': 0.6952466318114736, 'auc': 0.7980665871970218, 'prauc': 0.6957286412571014}
Test:      {'precision': 0.6442595673855432, 'recall': 0.7649150533355792, 'f1': 0.6994219603520475, 'auc': 0.7963682355751313, 'prauc': 0.685762896162096}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.06it/s, loss=0.5302]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.69it/s]


Validation: {'precision': 0.6365996106403744, 'recall': 0.7754940711431799, 'f1': 0.6992159608334803, 'auc': 0.8019355041094172, 'prauc': 0.6990725647589323}
Test:      {'precision': 0.6424562561880408, 'recall': 0.7688660608424779, 'f1': 0.6999999950375946, 'auc': 0.8012981462936712, 'prauc': 0.6885532804462243}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.89it/s, loss=0.5241]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.88it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.78it/s]


Validation: {'precision': 0.5207435955553826, 'recall': 0.9079051383363325, 'f1': 0.6618642800504976, 'auc': 0.7854273061881757, 'prauc': 0.6760374734976494}
Test:      {'precision': 0.526435733818308, 'recall': 0.9126827340935888, 'f1': 0.6677265454377308, 'auc': 0.7855689372271508, 'prauc': 0.6700601195338458}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.11it/s, loss=0.5155]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.72it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.56it/s]


Validation: {'precision': 0.6320206584873079, 'recall': 0.773913043475202, 'f1': 0.6958066759297634, 'auc': 0.7961624526841918, 'prauc': 0.6974143935357022}
Test:      {'precision': 0.6351706036724568, 'recall': 0.7649150533355792, 'f1': 0.694031183425465, 'auc': 0.7927981395393646, 'prauc': 0.686748715750095}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.03it/s, loss=0.4894]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.76it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.46it/s]


Validation: {'precision': 0.6206255806732096, 'recall': 0.7920948616569483, 'f1': 0.6959541537793519, 'auc': 0.7969112972373842, 'prauc': 0.6961289573572083}
Test:      {'precision': 0.6248042593152997, 'recall': 0.7882259976262812, 'f1': 0.6970649845822626, 'auc': 0.7894755966906203, 'prauc': 0.6791740910205095}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.05it/s, loss=0.4836]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.71it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.17it/s]


Validation: {'precision': 0.614981504313764, 'recall': 0.7885375494039979, 'f1': 0.6910287446410873, 'auc': 0.7904067068197502, 'prauc': 0.6903087240075112}
Test:      {'precision': 0.6171213546546971, 'recall': 0.7775582773576549, 'f1': 0.6881118831756475, 'auc': 0.784342899700224, 'prauc': 0.6772144362839646}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.04it/s, loss=0.4620]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.75it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.04it/s]


Validation: {'precision': 0.6503391107737365, 'recall': 0.6822134387324814, 'f1': 0.6658950567286869, 'auc': 0.784706223727963, 'prauc': 0.6746423767907836}
Test:      {'precision': 0.6513409961660869, 'recall': 0.671671276172771, 'f1': 0.6613499269184678, 'auc': 0.7820053860620225, 'prauc': 0.6650109510146862}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6365996106403744, 'recall': 0.7754940711431799, 'f1': 0.6992159608334803, 'auc': 0.8019355041094172, 'prauc': 0.6990725647589323}
Corresponding test performance:
{'precision': 0.6424562561880408, 'recall': 0.7688660608424779, 'f1': 0.6999999950375946, 'auc': 0.8012981462936712, 'prauc': 0.6885532804462243}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.21it/s, loss=0.5960]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.09it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.89it/s]


Validation: {'precision': 0.7062857142816784, 'recall': 0.4885375494051837, 'f1': 0.5775700886213077, 'auc': 0.793710291318987, 'prauc': 0.6983218991157464}
Test:      {'precision': 0.7100958826807101, 'recall': 0.4974318451185404, 'f1': 0.585037169873554, 'auc': 0.7967616608389394, 'prauc': 0.6929653276524188}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.09it/s, loss=0.5601]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.08it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.78it/s]


Validation: {'precision': 0.6611893583698701, 'recall': 0.66798418972068, 'f1': 0.6645694012106514, 'auc': 0.7919727816466948, 'prauc': 0.6910651058282086}
Test:      {'precision': 0.6643137254875909, 'recall': 0.6693006716686318, 'f1': 0.6667978694316118, 'auc': 0.7954458800458204, 'prauc': 0.689710652325761}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 28.77it/s, loss=0.5422]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.66it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.32it/s]


Validation: {'precision': 0.6539717083763005, 'recall': 0.7126482213410568, 'f1': 0.6820503070928865, 'auc': 0.8009911224041658, 'prauc': 0.7055343964658254}
Test:      {'precision': 0.6557199566919678, 'recall': 0.7178980640034852, 'f1': 0.6854017302019261, 'auc': 0.8017532504183744, 'prauc': 0.6971496399406456}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.67it/s, loss=0.5322]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.84it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.94it/s]


Validation: {'precision': 0.6494736842082475, 'recall': 0.731620553356792, 'f1': 0.688104084234462, 'auc': 0.8026779701779702, 'prauc': 0.7037740011063828}
Test:      {'precision': 0.6530255334010038, 'recall': 0.7376531015379785, 'f1': 0.6927643734946095, 'auc': 0.7990639892525894, 'prauc': 0.6939290468458132}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.90it/s, loss=0.5215]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.04it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.88it/s]


Validation: {'precision': 0.6740934906043072, 'recall': 0.6098814229224906, 'f1': 0.6403818169646098, 'auc': 0.8008477633477633, 'prauc': 0.7024967220398932}
Test:      {'precision': 0.6768898488091711, 'recall': 0.6191228763310189, 'f1': 0.6467189384657872, 'auc': 0.798584642927042, 'prauc': 0.6903453895596958}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 28.84it/s, loss=0.5000]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.84it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.63it/s]


Validation: {'precision': 0.6906187624716037, 'recall': 0.5470355731203674, 'f1': 0.610498451173997, 'auc': 0.7983030093899659, 'prauc': 0.6990464350759154}
Test:      {'precision': 0.6858877086459996, 'recall': 0.5357566179354574, 'f1': 0.6015971556762159, 'auc': 0.7962679157979877, 'prauc': 0.6852508545972036}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.63it/s, loss=0.4903]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.35it/s]


Validation: {'precision': 0.664315352694339, 'recall': 0.6328063241081707, 'f1': 0.6481781326521481, 'auc': 0.7848432565823871, 'prauc': 0.6743959009736633}
Test:      {'precision': 0.6701417848178893, 'recall': 0.6349269063586135, 'f1': 0.6520592362263955, 'auc': 0.7877038216703913, 'prauc': 0.6713459222279877}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.23it/s, loss=0.4770]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.74it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.28it/s]

Validation: {'precision': 0.6260932944583597, 'recall': 0.6790513833965255, 'f1': 0.6514979093023108, 'auc': 0.7748306564610912, 'prauc': 0.6526970898378348}
Test:      {'precision': 0.6331861662964196, 'recall': 0.6799683919372581, 'f1': 0.6557439462326531, 'auc': 0.7744511916533945, 'prauc': 0.6428359776331427}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6539717083763005, 'recall': 0.7126482213410568, 'f1': 0.6820503070928865, 'auc': 0.8009911224041658, 'prauc': 0.7055343964658254}
Corresponding test performance:
{'precision': 0.6557199566919678, 'recall': 0.7178980640034852, 'f1': 0.6854017302019261, 'auc': 0.8017532504183744, 'prauc': 0.6971496399406456}





In [19]:
# Summarise the collected runs: one "mean ± std" line per metric.
print("\nFinal Metrics:")
for key, run_scores in final_metrics.items():
    run_scores = np.asarray(run_scores)
    # ndarray.std() uses NumPy's default ddof=0 (population standard deviation).
    mean_value, std_value = run_scores.mean(), run_scores.std()
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.6500 ± 0.0109
recall: 0.7354 ± 0.0454
f1: 0.6890 ± 0.0145
auc: 0.7994 ± 0.0028
prauc: 0.6931 ± 0.0042


In [20]:
# Repeat training five times with use_hetero_graph=False and gather, per
# metric, the value reported for each run.
final_metrics = {key: [] for key in ("precision", "recall", "f1", "auc", "prauc")}
for run_idx in range(5):
    # A new HeteroGT instance and a new AdamW optimizer are created for every run.
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=False,
    )
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    # NOTE(review): with return_model=False the helper presumably returns only
    # the test metrics of the best-validation epoch (the log prints
    # "Corresponding test performance") -- confirm against heterogt.utils.train.
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )

    # Record this run's value under each tracked metric name.
    for key, collected in final_metrics.items():
        collected.append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 69.19it/s, loss=0.5950]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 232.18it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.34it/s]


Validation: {'precision': 0.6764972776738816, 'recall': 0.5893280632387774, 'f1': 0.6299112751224439, 'auc': 0.7909441725746074, 'prauc': 0.6890348526448014}
Test:      {'precision': 0.6810540663303906, 'recall': 0.592256025284108, 'f1': 0.6335587439650036, 'auc': 0.7977340191799265, 'prauc': 0.6911663217106897}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 68.79it/s, loss=0.5559]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.62it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.36it/s]


Validation: {'precision': 0.6255850233989841, 'recall': 0.7924901185739428, 'f1': 0.6992153394434609, 'auc': 0.8003476796955058, 'prauc': 0.7058448309515591}
Test:      {'precision': 0.6348350253786966, 'recall': 0.7905966021304204, 'f1': 0.7042055203079729, 'auc': 0.7989776493191294, 'prauc': 0.6942260219012742}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 68.41it/s, loss=0.5318]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.27it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.04it/s]


Validation: {'precision': 0.6475409836043459, 'recall': 0.7494071146215439, 'f1': 0.6947599803666575, 'auc': 0.8012745989919903, 'prauc': 0.7076905646272043}
Test:      {'precision': 0.6508264462787506, 'recall': 0.7467404188038453, 'f1': 0.695492175333778, 'auc': 0.7978528740307178, 'prauc': 0.6917968583225995}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 69.52it/s, loss=0.5218]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.44it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.41it/s]


Validation: {'precision': 0.6595256312140033, 'recall': 0.6814229248984924, 'f1': 0.6702954848898626, 'auc': 0.8013506702637138, 'prauc': 0.7046224160025857}
Test:      {'precision': 0.6671732522771004, 'recall': 0.6937969182114034, 'f1': 0.680224670575494, 'auc': 0.7992374545039862, 'prauc': 0.692663678395727}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 68.75it/s, loss=0.5165]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.48it/s]


Validation: {'precision': 0.6498936924144227, 'recall': 0.7249011857678858, 'f1': 0.6853512655653867, 'auc': 0.7959961938222809, 'prauc': 0.7046224513214445}
Test:      {'precision': 0.6520664076275096, 'recall': 0.7293559857734914, 'f1': 0.6885490438754484, 'auc': 0.7938006566442531, 'prauc': 0.690031341107665}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 68.87it/s, loss=0.4982]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 226.80it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.51it/s]


Validation: {'precision': 0.6663957740728712, 'recall': 0.6482213438709556, 'f1': 0.6571829242710131, 'auc': 0.7976054018445322, 'prauc': 0.7041376595226213}
Test:      {'precision': 0.6753036437219624, 'recall': 0.6590280521506953, 'f1': 0.6670665816807397, 'auc': 0.7955416969519646, 'prauc': 0.6914227536801723}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 68.97it/s, loss=0.4906]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.95it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.65it/s]


Validation: {'precision': 0.5792158968835145, 'recall': 0.8525691699571045, 'f1': 0.6897985241215198, 'auc': 0.790356881443838, 'prauc': 0.6950081041684144}
Test:      {'precision': 0.5752906190846843, 'recall': 0.8407743974680333, 'f1': 0.683146062589281, 'auc': 0.7863610236512777, 'prauc': 0.6779212538510551}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 69.52it/s, loss=0.4794]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.35it/s]


Validation: {'precision': 0.6276525198918181, 'recall': 0.7482213438705604, 'f1': 0.6826541602000162, 'auc': 0.7847266662484054, 'prauc': 0.6773685537904122}
Test:      {'precision': 0.628684995029385, 'recall': 0.7499012248093643, 'f1': 0.6839639590001559, 'auc': 0.7816533243806485, 'prauc': 0.6699662205466703}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6475409836043459, 'recall': 0.7494071146215439, 'f1': 0.6947599803666575, 'auc': 0.8012745989919903, 'prauc': 0.7076905646272043}
Corresponding test performance:
{'precision': 0.6508264462787506, 'recall': 0.7467404188038453, 'f1': 0.695492175333778, 'auc': 0.7978528740307178, 'prauc': 0.6917968583225995}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 70.52it/s, loss=0.5994]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.74it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.14it/s]


Validation: {'precision': 0.6526315789449996, 'recall': 0.7106719367560843, 'f1': 0.6804162674757402, 'auc': 0.7915755902712424, 'prauc': 0.6978456477475646}
Test:      {'precision': 0.6541681834704981, 'recall': 0.7099960489896879, 'f1': 0.680939744911061, 'auc': 0.7948946448611823, 'prauc': 0.6927465319009085}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 70.52it/s, loss=0.5624]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 231.05it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.26it/s]


Validation: {'precision': 0.6649724692898561, 'recall': 0.6205533596813417, 'f1': 0.6419954969456876, 'auc': 0.7937123303427651, 'prauc': 0.691874820528759}
Test:      {'precision': 0.6724797958287007, 'recall': 0.624654286840677, 'f1': 0.6476853698505182, 'auc': 0.7987681087407742, 'prauc': 0.6912440205224875}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 69.11it/s, loss=0.5341]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.43it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.20it/s]


Validation: {'precision': 0.624194043596487, 'recall': 0.8035573122497883, 'f1': 0.7026092917759807, 'auc': 0.8021110692849823, 'prauc': 0.7039218811844061}
Test:      {'precision': 0.6287831995039259, 'recall': 0.8044251284045657, 'f1': 0.7058415620686607, 'auc': 0.8027259752721174, 'prauc': 0.6977381811433201}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 67.85it/s, loss=0.5333]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.79it/s]


Validation: {'precision': 0.6748850814848522, 'recall': 0.6383399209460936, 'f1': 0.6561039966262322, 'auc': 0.8008005521049, 'prauc': 0.704338777345127}
Test:      {'precision': 0.6764829617136033, 'recall': 0.6353220071093034, 'f1': 0.655256718718634, 'auc': 0.797880519564294, 'prauc': 0.6925758833717739}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 70.58it/s, loss=0.5154]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.84it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.84it/s]


Validation: {'precision': 0.6453752181478347, 'recall': 0.7308300395228031, 'f1': 0.6854494852855051, 'auc': 0.7960167409080453, 'prauc': 0.6980689874680645}
Test:      {'precision': 0.6521126760540419, 'recall': 0.7317265902776305, 'f1': 0.6896294867287458, 'auc': 0.7980953483981246, 'prauc': 0.6933028689582458}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 69.43it/s, loss=0.5050]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 227.30it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 228.66it/s]


Validation: {'precision': 0.6038011695888778, 'recall': 0.8162055335936119, 'f1': 0.6941176421683611, 'auc': 0.795817230273752, 'prauc': 0.7000249164132761}
Test:      {'precision': 0.605820105818325, 'recall': 0.8143026471718123, 'f1': 0.6947581275847708, 'auc': 0.79397166102429, 'prauc': 0.6896080512346959}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 68.08it/s, loss=0.4843]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.44it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.57it/s]


Validation: {'precision': 0.5955797933392779, 'recall': 0.8201581027635567, 'f1': 0.6900565298757673, 'auc': 0.7908471359558314, 'prauc': 0.6905405761112089}
Test:      {'precision': 0.5954350927229802, 'recall': 0.8245752666897489, 'f1': 0.6915175564267757, 'auc': 0.7885354914721383, 'prauc': 0.6774986518967324}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 69.94it/s, loss=0.4618]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.88it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.02it/s]


Validation: {'precision': 0.5692475463452311, 'recall': 0.825296442684485, 'f1': 0.6737657260473141, 'auc': 0.778334117573248, 'prauc': 0.6767187789398734}
Test:      {'precision': 0.5757162346505438, 'recall': 0.8336625839556158, 'f1': 0.6810845658560539, 'auc': 0.7822784904240165, 'prauc': 0.6766543238689386}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 69.30it/s, loss=0.4493]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.08it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.63it/s]


Validation: {'precision': 0.57608695651995, 'recall': 0.73320158102477, 'f1': 0.6452173863741036, 'auc': 0.7415228370663153, 'prauc': 0.6140910613595193}
Test:      {'precision': 0.5796296296278407, 'recall': 0.741999209795567, 'f1': 0.6508404040144702, 'auc': 0.7379778318424817, 'prauc': 0.5964430359154682}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6748850814848522, 'recall': 0.6383399209460936, 'f1': 0.6561039966262322, 'auc': 0.8008005521049, 'prauc': 0.704338777345127}
Corresponding test performance:
{'precision': 0.6764829617136033, 'recall': 0.6353220071093034, 'f1': 0.655256718718634, 'auc': 0.797880519564294, 'prauc': 0.6925758833717739}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 69.04it/s, loss=0.6164]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.44it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.45it/s]


Validation: {'precision': 0.685175484543294, 'recall': 0.5169960474287866, 'f1': 0.5893219144464051, 'auc': 0.788809210113558, 'prauc': 0.6911992049336193}
Test:      {'precision': 0.6970010341225595, 'recall': 0.5325958119299384, 'f1': 0.6038073859041521, 'auc': 0.7951646600442348, 'prauc': 0.6916530741008602}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 70.50it/s, loss=0.5627]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 227.39it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.91it/s]


Validation: {'precision': 0.6541850220240302, 'recall': 0.7043478260841726, 'f1': 0.6783403071473639, 'auc': 0.7931745508919421, 'prauc': 0.6953966910327346}
Test:      {'precision': 0.6574585635334901, 'recall': 0.7052548399814096, 'f1': 0.6805184852818641, 'auc': 0.7958535993071023, 'prauc': 0.6888291151878287}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 69.16it/s, loss=0.5486]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.29it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.18it/s]


Validation: {'precision': 0.639919759275694, 'recall': 0.7565217391274446, 'f1': 0.693352648537149, 'auc': 0.7947604408473975, 'prauc': 0.6963440406307766}
Test:      {'precision': 0.6469011725271461, 'recall': 0.7629395495821298, 'f1': 0.7001450276636751, 'auc': 0.7983708089135484, 'prauc': 0.6928107683374923}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 68.11it/s, loss=0.5297]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.85it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.65it/s]


Validation: {'precision': 0.6450051849269098, 'recall': 0.7375494071117094, 'f1': 0.688179969203896, 'auc': 0.7977296777296776, 'prauc': 0.7028281311623512}
Test:      {'precision': 0.6530187369859368, 'recall': 0.7435796127983265, 'f1': 0.6953630099824302, 'auc': 0.8011007529952208, 'prauc': 0.6970254993822014}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 68.91it/s, loss=0.5163]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.81it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.46it/s]


Validation: {'precision': 0.6193428394277144, 'recall': 0.7897233201549814, 'f1': 0.6942321006996018, 'auc': 0.7948980488110922, 'prauc': 0.7009802240986636}
Test:      {'precision': 0.6237562189035332, 'recall': 0.7925721058838697, 'f1': 0.6981033533424871, 'auc': 0.7970836684743425, 'prauc': 0.6948791316628661}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 68.86it/s, loss=0.5268]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.17it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.29it/s]


Validation: {'precision': 0.635482793181973, 'recall': 0.7517786561235108, 'f1': 0.6887561058419871, 'auc': 0.7814718091892006, 'prauc': 0.6625730107145472}
Test:      {'precision': 0.6453135536053766, 'recall': 0.7562228368204021, 'f1': 0.6963798385798052, 'auc': 0.7878522069775024, 'prauc': 0.6723470410376834}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 70.43it/s, loss=0.4883]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 231.39it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.80it/s]


Validation: {'precision': 0.6326188881425565, 'recall': 0.7466403162025825, 'f1': 0.6849166012680904, 'auc': 0.7955903235251061, 'prauc': 0.6992356843236152}
Test:      {'precision': 0.6389076196876639, 'recall': 0.7487159225572947, 'f1': 0.689466977018979, 'auc': 0.794914750703783, 'prauc': 0.6890286213328429}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 69.38it/s, loss=0.4677]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.28it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.45it/s]


Validation: {'precision': 0.6037974683525197, 'recall': 0.7541501976254777, 'f1': 0.6706502586793247, 'auc': 0.7748455570194701, 'prauc': 0.678139353590149}
Test:      {'precision': 0.6061933062226894, 'recall': 0.765705254836959, 'f1': 0.6766759727188634, 'auc': 0.7687464206102479, 'prauc': 0.6637875012333353}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 68.58it/s, loss=0.4463]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.48it/s]


Validation: {'precision': 0.6391194514592597, 'recall': 0.6999999999972333, 'f1': 0.6681758108916092, 'auc': 0.7790276993537864, 'prauc': 0.6731643793486382}
Test:      {'precision': 0.6350338922560291, 'recall': 0.7032793362279602, 'f1': 0.6674165679388833, 'auc': 0.7775345063905688, 'prauc': 0.6701052350928162}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6450051849269098, 'recall': 0.7375494071117094, 'f1': 0.688179969203896, 'auc': 0.7977296777296776, 'prauc': 0.7028281311623512}
Corresponding test performance:
{'precision': 0.6530187369859368, 'recall': 0.7435796127983265, 'f1': 0.6953630099824302, 'auc': 0.8011007529952208, 'prauc': 0.6970254993822014}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 69.24it/s, loss=0.6132]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.89it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.03it/s]


Validation: {'precision': 0.6399590862576203, 'recall': 0.7418972331986486, 'f1': 0.687168217613015, 'auc': 0.7866751364577451, 'prauc': 0.6857634410487328}
Test:      {'precision': 0.6469982847319142, 'recall': 0.7451600158010859, 'f1': 0.692618430571342, 'auc': 0.7950868022631221, 'prauc': 0.6918984503562713}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 68.59it/s, loss=0.5556]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 232.40it/s]


Validation: {'precision': 0.6841607564979473, 'recall': 0.57193675889102, 'f1': 0.6230355171039671, 'auc': 0.7943520086998349, 'prauc': 0.6982877074914825}
Test:      {'precision': 0.6857808857776887, 'recall': 0.5811932042647918, 'f1': 0.6291702259980191, 'auc': 0.7964171388485404, 'prauc': 0.6930203743894635}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 69.68it/s, loss=0.5568]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.91it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.63it/s]


Validation: {'precision': 0.644778204678153, 'recall': 0.7296442687718196, 'f1': 0.6845911316751562, 'auc': 0.7963898299767865, 'prauc': 0.6994465034041167}
Test:      {'precision': 0.6514566514543648, 'recall': 0.73330699328039, 'f1': 0.6899628202937141, 'auc': 0.7988883772836233, 'prauc': 0.6946034480515831}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 70.17it/s, loss=0.5312]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.38it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.56it/s]


Validation: {'precision': 0.648266100493106, 'recall': 0.7241106719338968, 'f1': 0.6840926014354202, 'auc': 0.7987430202647594, 'prauc': 0.7012755188338557}
Test:      {'precision': 0.6509734513251293, 'recall': 0.7265902805186623, 'f1': 0.6867064923986103, 'auc': 0.7971009992917927, 'prauc': 0.6916955948632788}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 68.85it/s, loss=0.5181]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.62it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.63it/s]


Validation: {'precision': 0.6519316493289304, 'recall': 0.6936758893253214, 'f1': 0.6721562569708321, 'auc': 0.7896250810381245, 'prauc': 0.6888070776006746}
Test:      {'precision': 0.6590909090884535, 'recall': 0.6989332279703717, 'f1': 0.6784276076575027, 'auc': 0.7920729678726437, 'prauc': 0.690014543233221}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 69.28it/s, loss=0.5052]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.56it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.62it/s]


Validation: {'precision': 0.6372315035777797, 'recall': 0.7387351778626928, 'f1': 0.6842394239099324, 'auc': 0.7861289394985047, 'prauc': 0.6805303988824221}
Test:      {'precision': 0.6436464088375565, 'recall': 0.7364677992859089, 'f1': 0.6869356869309019, 'auc': 0.7842955148368028, 'prauc': 0.6691686015473024}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 69.58it/s, loss=0.4917]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.24it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.73it/s]


Validation: {'precision': 0.6319241982484259, 'recall': 0.6853754940684372, 'f1': 0.6575654102503349, 'auc': 0.7799565531087271, 'prauc': 0.6802636008148021}
Test:      {'precision': 0.6275887573941288, 'recall': 0.6704859739207014, 'f1': 0.6483285527871289, 'auc': 0.770323524998212, 'prauc': 0.662010541892839}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 68.69it/s, loss=0.4744]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 231.55it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.40it/s]


Validation: {'precision': 0.5968969555017656, 'recall': 0.8059288537517553, 'f1': 0.6858392147521676, 'auc': 0.7877568124307255, 'prauc': 0.6908962984504425}
Test:      {'precision': 0.6009445100336454, 'recall': 0.8044251284045657, 'f1': 0.6879540413940967, 'auc': 0.7833217428037575, 'prauc': 0.682890706806769}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 70.60it/s, loss=0.4552]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.17it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.67it/s]


Validation: {'precision': 0.6509992598051406, 'recall': 0.6952569169932994, 'f1': 0.6724006066236285, 'auc': 0.7876141853315766, 'prauc': 0.6895070472579061}
Test:      {'precision': 0.6489795918343266, 'recall': 0.6910312129565744, 'f1': 0.6693455747957036, 'auc': 0.7792239207586855, 'prauc': 0.6708661066439696}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.648266100493106, 'recall': 0.7241106719338968, 'f1': 0.6840926014354202, 'auc': 0.7987430202647594, 'prauc': 0.7012755188338557}
Corresponding test performance:
{'precision': 0.6509734513251293, 'recall': 0.7265902805186623, 'f1': 0.6867064923986103, 'auc': 0.7971009992917927, 'prauc': 0.6916955948632788}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 70.14it/s, loss=0.6067]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.54it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 229.98it/s]


Validation: {'precision': 0.6518895348813522, 'recall': 0.7090909090881063, 'f1': 0.6792881434348859, 'auc': 0.789279231235753, 'prauc': 0.6936382623315853}
Test:      {'precision': 0.6552727272703445, 'recall': 0.7119715527431373, 'f1': 0.6824465013495098, 'auc': 0.7954815888600228, 'prauc': 0.6916799085475168}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 68.96it/s, loss=0.5552]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.45it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.15it/s]


Validation: {'precision': 0.6508166969123382, 'recall': 0.7086956521711119, 'f1': 0.6785241198882355, 'auc': 0.7923320158102767, 'prauc': 0.6920015025643826}
Test:      {'precision': 0.6545586632740481, 'recall': 0.7119715527431373, 'f1': 0.6820590411833827, 'auc': 0.7994976785605641, 'prauc': 0.6947214323530759}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 71.12it/s, loss=0.5419]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.88it/s]


Validation: {'precision': 0.6508605549678579, 'recall': 0.732411067190781, 'f1': 0.6892319087213425, 'auc': 0.7979445594662986, 'prauc': 0.6960464920728742}
Test:      {'precision': 0.6552575864479349, 'recall': 0.7337020940310799, 'f1': 0.6922646734849428, 'auc': 0.8000838895338932, 'prauc': 0.6924532351712995}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 68.90it/s, loss=0.5319]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.30it/s]


Validation: {'precision': 0.6429058663006723, 'recall': 0.7450592885346046, 'f1': 0.6902233564306354, 'auc': 0.8034508124725515, 'prauc': 0.7027859256802842}
Test:      {'precision': 0.6526097476645261, 'recall': 0.7459502173024657, 'f1': 0.6961651867601175, 'auc': 0.7984425406957438, 'prauc': 0.6912904793271131}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 68.90it/s, loss=0.5191]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 228.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.95it/s]


Validation: {'precision': 0.6534582132541302, 'recall': 0.716996047427996, 'f1': 0.6837542354906431, 'auc': 0.8002398728485685, 'prauc': 0.7026738769434975}
Test:      {'precision': 0.6581227436799346, 'recall': 0.7202686685076244, 'f1': 0.6877947507140393, 'auc': 0.79569955923176, 'prauc': 0.6925534793788024}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 69.49it/s, loss=0.5103]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.77it/s]


Validation: {'precision': 0.6428082191758808, 'recall': 0.7418972331986486, 'f1': 0.6888073344726175, 'auc': 0.7929751448229709, 'prauc': 0.6814605074225121}
Test:      {'precision': 0.648154536044677, 'recall': 0.7423943105462568, 'f1': 0.6920810263279666, 'auc': 0.7942233505695556, 'prauc': 0.6852697422121805}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 69.98it/s, loss=0.4951]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.63it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 233.00it/s]


Validation: {'precision': 0.5736559139769526, 'recall': 0.8434782608662313, 'f1': 0.6828799951790757, 'auc': 0.7930898529811574, 'prauc': 0.6952946198410629}
Test:      {'precision': 0.5803717878606333, 'recall': 0.838798893714584, 'f1': 0.6860559008026171, 'auc': 0.7891433266980562, 'prauc': 0.6876457860313263}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 70.31it/s, loss=0.4769]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 230.54it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 230.73it/s]


Validation: {'precision': 0.615484288951683, 'recall': 0.7509881422895218, 'f1': 0.6765177091290082, 'auc': 0.7833529916138613, 'prauc': 0.6794487064930153}
Test:      {'precision': 0.6158357771240931, 'recall': 0.7467404188038453, 'f1': 0.6749999950437379, 'auc': 0.7756836693330295, 'prauc': 0.6646242292727216}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 70.33it/s, loss=0.4614]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 229.64it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 231.79it/s]

Validation: {'precision': 0.6160203432911124, 'recall': 0.7660079051353123, 'f1': 0.6828752593270976, 'auc': 0.7911798105276366, 'prauc': 0.6944167411458915}
Test:      {'precision': 0.61800643086618, 'recall': 0.7593836428259211, 'f1': 0.6814394561387199, 'auc': 0.7852518513344677, 'prauc': 0.6824556506066514}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6429058663006723, 'recall': 0.7450592885346046, 'f1': 0.6902233564306354, 'auc': 0.8034508124725515, 'prauc': 0.7027859256802842}
Corresponding test performance:
{'precision': 0.6526097476645261, 'recall': 0.7459502173024657, 'f1': 0.6961651867601175, 'auc': 0.7984425406957438, 'prauc': 0.6912904793271131}





In [21]:
# Report every entry of `final_metrics` as "mean ± std" with four decimals.
# `final_metrics` is built in an earlier cell; presumably each value is the list
# of per-run scores for that metric -- TODO confirm against that cell.
# np.std is called with its default ddof=0, i.e. the population standard deviation.
print("\nFinal Metrics:")
for metric_name, scores in final_metrics.items():
    avg = np.mean(scores)
    spread = np.std(scores)
    print(f"{metric_name}: {avg:.4f} ± {spread:.4f}")


Final Metrics:
precision: 0.6568 ± 0.0099
recall: 0.7196 ± 0.0428
f1: 0.6858 ± 0.0157
auc: 0.7985 ± 0.0014
prauc: 0.6929 ± 0.0021
