In [None]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [None]:
# Seed the random number generators through the project helper so that data
# shuffling and weight initialisation are repeatable across runs. Which libraries
# are seeded is defined in heterogt.utils.seed (presumably Python/NumPy/torch —
# confirm there).
set_random_seed(123)

[INFO] Random seed set to 123


In [None]:
# Prefer a GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# Experiment configuration. `task_index` picks which entry of `tasks` this notebook
# trains; later cells attach data-dependent fields (max_num_adms, label_vocab_size,
# global_vocab_size, ...) to this same namespace.
config = Namespace(
    # --- data / task ---
    dataset="MIMIC-III",
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=0,  # index into `tasks` of the task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]"],
    # --- optimisation ---
    batch_size=32,
    lr=1e-3,
    epochs=500,  # upper bound; early stopping normally ends training sooner
    early_stop_patience=5,
)

In [None]:
# The tokenizer vocabulary is built from the full processed cohort, while the
# fine-tuning splits come from a task-specific pickle: each next-diagnosis horizon
# has its own file and every other task shares the downstream file.
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
finetune_file_name = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}.get(curr_task, "mimic_downstream.pkl")
finetune_data_path = f"{processed_dir}/{finetune_file_name}"

Current task: death


In [None]:
# Load the full processed cohort. It is used here only to build the tokenizer
# vocabularies and to size the admission-index embedding.
# NOTE: pickle.load can execute arbitrary code - only load files produced by this
# project's own preprocessing.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Every distinct age becomes its own single-token "sentence" ([[age], [age], ...]):
# the tokenizer is given a list of token lists here, presumably mirroring the
# per-row code lists above, so the [[...]] nesting matters. Sorting makes the age
# vocabulary order (and therefore the token ids) independent of set iteration
# order, which for string values changes with the interpreter's hash seed.
age_sentences = [[str(c)] for c in sorted(set(ehr_full_data["AGE"].values.tolist()))]
token_type_sentences = ["[PAD]"] + config.token_type
# Cast the numpy scalar to a plain int before it is used to size an nn.Embedding.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [None]:
task_sentences = config.tasks
tokenizer = EHRTokenizer(token_type_sentences, task_sentences, age_sentences, diag_sentences,
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # labels of the next-diagnosis tasks are diagnosis codes
config.global_vocab_size = len(tokenizer.vocab.id2word)
# EHRTokenizer exposes no separate "age" vocabulary: token_number("age") raises
# ValueError("Unknown vocabulary type: age"). Age tokens are looked up in the shared
# global vocabulary/token embedding by the model, so fall back to the number of
# distinct age tokens that were handed to the tokenizer.
try:
    config.age_vocab_size = tokenizer.token_number("age")
except ValueError:
    config.age_vocab_size = len(age_sentences)
print(f"Age vocabulary size: {config.age_vocab_size}")

ValueError: Unknown vocabulary type: age

In [None]:
# Pre-split fine-tuning data: a (train, val, test) tuple, presumably DataFrames
# given the column indexing below. Use a context manager so the file is closed.
# NOTE: pickle.load can execute arbitrary code - only load trusted, locally produced files.
with open(finetune_data_path, "rb") as finetune_file:
    train_data, val_data, test_data = pickle.load(finetune_file)

# Label prevalence in the test split, as a quick class-balance sanity check.
for label_name, positives in (
    ("DEATH", test_data["DEATH"] == True),
    ("READMISSION", test_data["READMISSION"] == 1),
    ("STAY>7 days", test_data["STAY_DAYS"] > 7),
):
    print(f"Percentage of {label_name} in test dataset:", positives.mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [None]:
# Wrap each split in the fine-tuning dataset for the current task (train, val, test order).
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split_frame, tokenizer, token_type=config.token_type, task=curr_task)
    for split_frame in (train_data, val_data, test_data)
)

In [None]:
# One DataLoader per split, each with its own `batcher` collate function for the
# current task. Only the training split is shuffled.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(
        split_dataset,
        collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain=False),
        shuffle=shuffle_split,
        batch_size=config.batch_size,
    )
    for split_dataset, shuffle_split in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [None]:
# Model selection uses F1 for every task. Binary tasks and the multi-label
# next-diagnosis ("l2r") tasks both train with BCE-with-logits.
eval_metric = "f1"
is_binary_task = curr_task in ["death", "stay", "readmission"]
task_type = "binary" if is_binary_task else "l2r"
loss_fn = (F.binary_cross_entropy_with_logits if is_binary_task
           else (lambda x, y: F.binary_cross_entropy_with_logits(x, y)))

In [None]:
# Peek at one training batch to check the tensor layout produced by `batcher`.
input_ids, token_types, adm_index, age_ids, task_index, labels = next(iter(train_dataloader))
for shape_label, batch_tensor in (("Input IDs", input_ids), ("Token Types", token_types),
                                  ("Admission Index", adm_index), ("Age IDs", age_ids)):
    print(f"{shape_label} shape:", batch_tensor.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 0
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv
from heterogt.model.layer import TransformerEncoder

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [None]:
class DiseaseOccHetGNN(nn.Module):
    """Two-layer heterogeneous GAT over a patient graph of visit and code-occurrence nodes.

    Node types are ``'visit'`` and ``'occ'``. Relations are visit -contains-> occ,
    occ -contained_by-> visit and visit -next-> visit. Each layer works as follows:

    * a ``HeteroConv`` of one ``GATConv`` per relation, summed across relations;
    * dropout on the update;
    * a learnably scaled residual;
    * a LayerNorm specific to the node type and layer.

    A zero-initialised linear residual is applied to both node types at the end.

    Args:
        d_model: feature width of both node types (input and output; every GATConv
            outputs d_model channels with ``concat=False``).
        heads: attention heads per GATConv.
        dropout: dropout probability applied to each graph update before the residual.
    """
    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.d = d_model
        self.act = nn.GELU()  # NOTE(review): defined but not used in forward()
        self.drop = nn.Dropout(dropout)

        # -- Normalisation: a separate LayerNorm per node type and per layer -- #
        self.ln_v1 = nn.LayerNorm(d_model)
        self.ln_o1 = nn.LayerNorm(d_model)
        self.ln_v2 = nn.LayerNorm(d_model)
        self.ln_o2 = nn.LayerNorm(d_model)

        # -- Learnable residual scales, initialised small (0.1) so the graph update
        #    barely perturbs the input early in training -- #
        self.alpha_v1 = nn.Parameter(torch.tensor(0.1))
        self.alpha_o1 = nn.Parameter(torch.tensor(0.1))
        self.alpha_v2 = nn.Parameter(torch.tensor(0.1))
        self.alpha_o2 = nn.Parameter(torch.tensor(0.1))

        # Note: aggr='sum' combines the per-relation messages by addition, so a
        # relation's signal is not diluted by averaging over relations.
        # Only visit->visit ('next') edges add self-loops.
        self.conv1 = HeteroConv({
            ('visit','contains','occ'):   GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):     GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=True),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('visit','contains','occ'):   GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):     GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=True),
        }, aggr='sum')

        # Final linear + residual; zero-initialised so this stage starts as an exact
        # identity on top of layer 2 ("near-identity" for the block overall).
        self.lin_v = nn.Linear(d_model, d_model)
        self.lin_o = nn.Linear(d_model, d_model)
        nn.init.zeros_(self.lin_v.weight); nn.init.zeros_(self.lin_v.bias)
        nn.init.zeros_(self.lin_o.weight); nn.init.zeros_(self.lin_o.bias)

    def forward(self, hg):
        """Update the visit and occ node features of a (batched) heterogeneous graph.

        Args:
            hg: heterogeneous graph exposing ``hg['visit'].x``, ``hg['occ'].x`` and an
                ``edge_index_dict`` for the three relations built in ``__init__``.

        Returns:
            dict: ``{'visit': Tensor, 'occ': Tensor}``, each with the same shape as the
            corresponding input features.
        """
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_v = hg['visit'].x
        x_o = hg['occ'].x

        # ===== Layer 1: graph convolution (sum aggregation) -> residual + LN =====
        h1 = self.conv1({'visit': x_v, 'occ': x_o}, hg.edge_index_dict)
        # Dropout on the update before it is injected into the residual (regularisation).
        dv = self.drop(h1['visit'])
        do = self.drop(h1['occ'])
        # y = LN(x + α * Δx)
        v1 = self.ln_v1(x_v + self.alpha_v1 * dv)
        o1 = self.ln_o1(x_o + self.alpha_o1 * do)

        # ===== Layer 2: a second graph convolution -> residual + LN =====
        h2 = self.conv2({'visit': v1, 'occ': o1}, hg.edge_index_dict)
        dv2 = self.drop(h2['visit'])
        do2 = self.drop(h2['occ'])
        v2 = self.ln_v2(v1 + self.alpha_v2 * dv2)
        o2 = self.ln_o2(o1 + self.alpha_o2 * do2)

        # ===== Final linear: zero-initialised, acts as a fine-tuning residual and
        #       starts as the identity, so the overall scale is unchanged at init =====
        v_out = v2 + self.lin_v(v2)
        o_out = o2 + self.lin_o(o2)

        return {'visit': v_out, 'occ': o_out}

In [None]:
# Multi-label prediction head: one independent logit per label (trained with
# BCE-with-logits for the next-diagnosis tasks in this notebook).
class MultiPredictionHead(nn.Module):
    """Two-layer MLP that maps a pooled representation to ``label_size`` logits.

    Args:
        hidden_size: width of the input representation and of the hidden layer.
        label_size: number of output logits.
    """

    def __init__(self, hidden_size, label_size):
        super().__init__()
        hidden_layer = nn.Linear(hidden_size, hidden_size)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_size, label_size)
        self.cls = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, input):
        """Return logits of shape ``[..., label_size]``."""
        return self.cls(input)
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP that maps a pooled representation to a single logit.

    Args:
        hidden_size: width of the input representation and of the hidden layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        hidden_layer = nn.Linear(hidden_size, hidden_size)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_size, 1)
        self.cls = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, input):
        """Return logits of shape ``[..., 1]``."""
        return self.cls(input)

In [None]:
# Find the first training example whose age-id row has more than three entries
# (presumably one age token per admission) and print it.
for i in range(len(train_dataset)):
    age_ids = train_dataset[i][3]
    if len(age_ids[0]) > 3:
        print(age_ids)
        break
exp_i = i

# Append trailing zeros to the first row of each field of that example (5 for the
# three sequence fields, 3 for the age ids) for manual inspection of padded layouts.
example = train_dataset[exp_i]
id_seq, type_seq, visit_seq = (
    torch.concat([example[field_idx][0], torch.zeros(5, dtype=example[field_idx][0].dtype)], dim=0)
    for field_idx in range(3)
)
age_row = example[3][0]
age_seq = torch.concat([age_row, torch.zeros(3, dtype=age_row.dtype)], dim=0)

tensor([[15, 15, 15, 11]])


In [None]:
class HeteroGT(nn.Module):
    """Heterogeneous graph-transformer for patient-level EHR prediction.

    The model carries three groups of token states through a stack of layers chosen
    by ``layer_types``:

    * a single task token (``task_emb``), whose final state feeds the prediction head;
    * the flattened code sequence ``input_ids``, embedded with ``token_emb``;
    * one virtual node per admission, initialised from ``token_emb(age_ids)``.

    A ``'gnn'`` layer runs ``DiseaseOccHetGNN`` on a per-patient graph that links each
    visit node to the occurrences of the ``graph_node_types`` codes it contains. It
    only refreshes the visit nodes. A ``'tf'`` layer runs a pre-norm Transformer
    encoder layer over ``[task | codes | visits]`` with a type-based attention mask.

    Args:
        tokenizer: EHRTokenizer providing the vocabularies and ``[PAD]`` ids.
        d_model: hidden size.
        num_heads: attention heads for both the Transformer and the GAT layers.
        layer_types: sequence of ``'gnn'`` / ``'tf'`` strings, one per layer.
        max_num_adms: largest admission index; sizes the admission-index embedding.
        device: device on which tensors created inside ``forward`` are placed.
        task: task name. ``death``/``stay``/``readmission`` get a 1-logit head;
            every other task gets a ``label_vocab_size``-logit multi-label head.
        label_vocab_size: number of labels for the multi-label head.
        freeze_gnn: if True (the default, which is this notebook's original
            behaviour), GNN layers run under ``torch.no_grad`` on detached inputs.
            Their parameters therefore never receive gradients and stay at their
            initial values, and nothing upstream receives gradient through the visit
            nodes of a GNN layer. Set False to train the GNN end-to-end.
    """

    def __init__(self, tokenizer, d_model, num_heads, layer_types, max_num_adms, device, task, label_vocab_size,
                 freeze_gnn=True):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_attn_heads = num_heads
        self.layer_types = layer_types
        self.freeze_gnn = freeze_gnn
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]  # 0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0]  # 0
        self.adm_pad_id = 0  # admission indices are 1-based; 0 marks padding
        self.age_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]  # 0; not used below
        # Type ids used for graph construction and the attention mask. 1-4 presumably
        # follow the tokenizer's token-type vocabulary ([PAD], diag, med, lab, pro);
        # 5 is the extra type given to the virtual visit nodes.
        self.node_type_id_dict = {'diag': 1, 'med': 2, 'lab': 3, 'pro': 4, 'visit': 5}
        # Code types whose occurrences become 'occ' nodes linked to their visit node.
        self.graph_node_types = ['diag']
        # Attention forbid map: query type -> key types it may NOT attend to
        # (-1 = task token, 5 = visit node). Code tokens never attend to visit nodes.
        # Visit nodes ignore the task token, diagnosis tokens and every visit node,
        # including themselves (diag<->visit information presumably flows via the GNN).
        self.forbid_map = {-1: [], 1: [5], 2: [5], 3: [5], 4: [5], 5: [-1, 1, 5]}

        # Embedding layers. token_emb already contains [PAD] and is shared by the code
        # tokens and the per-visit age tokens.
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        # n_type already includes [PAD]; +1 for the 'visit' type.
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id)
        # +1 for the pad index 0.
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id)
        # One learned embedding per task (5 tasks in total).
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)

        # Layer stack: one module per entry of layer_types ('gnn' -> GNN, anything else -> Transformer).
        self.stack_layers = nn.ModuleList(self.make_gnn_layer() if layer_type == 'gnn' else self.make_tf_layer()
                                          for layer_type in self.layer_types)

        # Prediction head.
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, label_vocab_size)

    def make_tf_layer(self):
        """Return one pre-norm, batch-first Transformer encoder layer wrapped in nn.TransformerEncoder."""
        assert self.d_model % self.num_attn_heads == 0, "Invalid model and attention head dimensions"
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_attn_heads,
                                                   batch_first=True, norm_first=True)
        return nn.TransformerEncoder(encoder_layer, num_layers=1, enable_nested_tensor=False)

    def make_gnn_layer(self):
        """Return one heterogeneous visit/occurrence GNN block."""
        return DiseaseOccHetGNN(d_model=self.d_model, heads=self.num_attn_heads)

    def forward(self, input_ids, token_types, adm_index, age_ids, task_id):
        """Compute task logits for a batch of patients.

        Args:
            input_ids (Tensor): code token ids. Shape [B, L].
            token_types (Tensor): token type ids of the codes. Shape [B, L].
            adm_index (Tensor): 1-based admission index of each code (0 = pad). Shape [B, L].
            age_ids (Tensor): token ids used to initialise the visit nodes (presumably
                one age token per admission). Shape [B, V].
            task_id: scalar task index (an int in this notebook's batches).

        Returns:
            Tensor: logits of shape [B, 1] for binary tasks, [B, label_vocab_size] otherwise.
        """
        B, L = input_ids.shape
        V = age_ids.shape[1]
        num_visits = adm_index.max(dim=1).values  # [B]

        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)  # scalar -> [B]
        # Base representations.
        task_id_embed = self.task_emb(task_id).unsqueeze(1)  # [B, 1, d]
        seq_embed = self.token_emb(input_ids)  # [B, L, d]
        visit_embed = self.token_emb(age_ids)  # [B, V, d]

        for i, layer_type in enumerate(self.layer_types):
            if layer_type == 'gnn':  # a GNN layer only refreshes the visit nodes
                visit_embed = self._run_gnn_layer(i, seq_embed, token_types, visit_embed, adm_index, num_visits, V)
            elif layer_type == 'tf':
                x, src_key_padding_mask, attn_mask = self.prepare_tf_input(
                    task_id_embed, seq_embed, visit_embed, i, input_ids, adm_index, token_types, num_visits)
                h = self.stack_layers[i](src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask)  # [B, 1+L+V, d]
                task_id_embed, seq_embed, visit_embed = self.process_tf_out(h, L, V)  # [B, 1, d], [B, L, d], [B, V, d]
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

        # squeeze(1), not squeeze(): a bare squeeze() would also drop the batch
        # dimension when B == 1 and hand the head a [d] vector instead of [1, d].
        logits = self.cls_head(task_id_embed.squeeze(1))  # [B, label_size]
        return logits

    def _run_gnn_layer(self, layer_i, seq_embed, token_types, visit_embed, adm_index, num_visits, V):
        """Apply GNN layer ``layer_i`` and return the refreshed visit embeddings [B, V, d]."""
        if self.freeze_gnn:
            # Original behaviour: the graph is built from detached inputs and the GNN
            # runs without autograd, so no gradient flows through this layer.
            seq_embed_det = seq_embed.detach()
            visit_embed_det = visit_embed.detach()
            with torch.no_grad():
                hg_batch = self.build_graph_batch(seq_embed_det, token_types, self.graph_node_types,
                                                  visit_embed_det, adm_index)
                gnn_out = self.stack_layers[layer_i](hg_batch)['visit']  # virtual visit node states
                return self.process_gnn_out(gnn_out, num_visits, V).detach()
        hg_batch = self.build_graph_batch(seq_embed, token_types, self.graph_node_types, visit_embed, adm_index)
        gnn_out = self.stack_layers[layer_i](hg_batch)['visit']
        return self.process_gnn_out(gnn_out, num_visits, V)

    def build_graph_batch(self, seq_embed, token_types, graph_node_types, visit_embed, adm_index):
        """Build a batch of heterogeneous graphs, one per patient.

        Args:
            seq_embed (Tensor): sequence embeddings. Shape [B, L, d].
            token_types (Tensor): token type ids. Shape [B, L].
            graph_node_types (list): token types connected to the virtual visit nodes, e.g. ['diag'].
            visit_embed (Tensor): visit embeddings. Shape [B, V, d].
            adm_index (Tensor): admission index of every token (0 = pad). Shape [B, L].

        Returns:
            A ``torch_geometric`` batch of heterogeneous graphs, moved to ``self.device``.
        """
        B = seq_embed.shape[0]
        graph_node_type_ids = [self.node_type_id_dict[t] for t in graph_node_types]
        graphs = [
            self.build_patient_graph(seq_embed[p], token_types[p], visit_embed[p], adm_index[p], graph_node_type_ids)
            for p in range(B)
        ]
        return HeteroBatch.from_data_list(graphs).to(self.device)

    def build_patient_graph(self, seq_embed_p, token_types_p, visit_embed_p, adm_index_p, graph_node_type_ids):
        """Build the heterogeneous graph of a single patient.

        Args:
            seq_embed_p (Tensor): sequence embeddings of patient p. Shape [L, d].
            token_types_p (Tensor): token type ids of patient p. Shape [L].
            visit_embed_p (Tensor): visit embeddings of patient p. Shape [V, d].
            adm_index_p (Tensor): admission index of every token of patient p (0 = pad). Shape [L].
            graph_node_type_ids (list): token type ids whose occurrences become 'occ' nodes.

        Returns:
            HeteroData with 'visit' and 'occ' nodes and the contains / contained_by / next edges.
        """
        hg = HeteroData()
        # Positions in the sequence holding a token type that belongs in the graph.
        occ_mask = torch.isin(token_types_p, torch.tensor(graph_node_type_ids, device=token_types_p.device))  # [L]
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1)  # [num_occ]
        num_occ = occ_pos.numel()

        # Virtual visit nodes: map every non-pad admission index to a local id 0..num_visit_p-1.
        nonpad = adm_index_p != self.adm_pad_id
        adm_index_used_p = adm_index_p[nonpad]  # non-pad part of adm_index
        adm_ids_unique, adm_lid_nonpad = torch.unique(adm_index_used_p, return_inverse=True)
        num_visit_p = adm_ids_unique.numel()
        adm_lid_full = torch.full_like(token_types_p, fill_value=-1)  # [L], -1 at padded positions
        adm_lid_full[nonpad] = adm_lid_nonpad
        hg['visit'].x = visit_embed_p[:num_visit_p, :]
        hg['visit'].num_nodes = num_visit_p

        # Medical-code occurrence nodes.
        hg['occ'].x = seq_embed_p[occ_pos, :]
        hg['occ'].num_nodes = num_occ

        # Edges between each occurrence and the visit it belongs to, in both directions.
        occ_adm_lid = adm_lid_full[occ_pos]
        assert (occ_adm_lid != -1).all(), "occ_adm_lid contains -1"
        occ_ids = torch.arange(num_occ, device=self.device)
        hg['visit', 'contains', 'occ'].edge_index = torch.stack([occ_adm_lid, occ_ids], dim=0)
        hg['occ', 'contained_by', 'visit'].edge_index = torch.stack([occ_ids, occ_adm_lid], dim=0)

        # Forward edges between consecutive visit nodes.
        if num_visit_p > 1:
            src = torch.arange(0, num_visit_p - 1, device=self.device)
            dst = torch.arange(1, num_visit_p, device=self.device)
            e_next = torch.stack([src, dst], dim=0)  # [2, num_visit_p - 1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit', 'next', 'visit'].edge_index = e_next
        return hg

    def process_gnn_out(self, gnn_out, num_visits, V):
        """Scatter the flat GNN visit outputs back into a padded [B, V, d] tensor.

        Args:
            gnn_out (Tensor): visit node states of the whole batch, patients
                concatenated in order. Shape [sum(num_visits), d].
            num_visits (Tensor): number of visits of each patient. Shape [B].
            V (int): number of visit slots per patient.

        Returns:
            Tensor: visit embeddings, zero in unused slots. Shape [B, V, d].
        """
        B = len(num_visits)
        # Exclusive prefix sum: row of each patient's first visit inside gnn_out.
        offsets = torch.cumsum(num_visits, dim=0) - num_visits  # [B]

        # Flat row index, owning patient and within-patient position of every row.
        indices = torch.arange(int(num_visits.sum()), device=self.device)  # [N]
        batch_indices = torch.repeat_interleave(torch.arange(B, device=self.device), num_visits)  # [N]
        visit_pos = indices - offsets[batch_indices]  # [N]

        visit_emb_pad = torch.zeros(B, V, self.d_model, device=self.device, dtype=gnn_out.dtype)  # [B, V, d]

        # Keep only rows that fit into the V slots (visit_pos < num_visits holds by construction).
        mask = (visit_pos < V) & (visit_pos < num_visits[batch_indices])  # [N]
        visit_emb_pad[batch_indices[mask], visit_pos[mask]] = gnn_out[indices[mask]]
        return visit_emb_pad

    def prepare_tf_input(self, task_id_embed, seq_embed, visit_embed, layer_i, input_ids, adm_index, token_types, num_visits):
        """Assemble the Transformer input and its masks.

        Args:
            task_id_embed (Tensor): task token states. Shape [B, 1, d].
            seq_embed (Tensor): code token states. Shape [B, L, d].
            visit_embed (Tensor): visit node states. Shape [B, V, d].
            layer_i (int): index of the current layer in ``layer_types``.
            input_ids (Tensor): code token ids, used for the padding mask. Shape [B, L].
            adm_index (Tensor): admission index of the codes (0 = pad). Shape [B, L].
            token_types (Tensor): token type ids of the codes. Shape [B, L].
            num_visits (Tensor): number of visits per patient. Shape [B].

        Returns:
            Tuple of:
            * x, shape [B, 1+L+V, d];
            * src_key_padding_mask, bool, shape [B, 1+L+V];
            * attn_mask, bool, shape [B*heads, 1+L+V, 1+L+V].
        """
        B = seq_embed.shape[0]
        V = visit_embed.shape[1]

        # Admission index of each visit slot (1..num_visits, 0 for padded slots) and its
        # token type ('visit' for real visits, 0 for padding). Both are needed for the
        # masks of every layer, and for the embeddings of the first Transformer layer.
        arange_V = torch.arange(1, V + 1, device=self.device, dtype=torch.long)[None, :]  # [1, V]
        n_v = num_visits.view(B, 1)  # [B, 1]
        visit_index = torch.where(arange_V <= n_v, arange_V,
                                  torch.full((B, V), self.adm_pad_id, device=self.device, dtype=torch.long))  # [B, V]
        visit_type_id = torch.full((B, V), self.node_type_id_dict['visit'], dtype=torch.long, device=self.device)
        visit_type_id = visit_type_id * (visit_index != self.adm_pad_id).long()  # [B, V]
        token_types = torch.cat([token_types, visit_type_id], dim=1)  # [B, L+V]

        # Part 1: main input x = [task | codes | visits]. torch.cat builds a new tensor,
        # so no in-place write touches an autograd-tracked input.
        body = torch.cat([seq_embed, visit_embed], dim=1)  # [B, L+V, d]
        # Admission-index and token-type embeddings are added once, at the first
        # Transformer layer, however many GNN layers precede it.
        if layer_i == list(self.layer_types).index('tf'):
            adm_index_full = torch.cat([adm_index, visit_index], dim=1)  # [B, L+V]
            body = body + self.adm_index_emb(adm_index_full) + self.type_emb(token_types)
        x = torch.cat([task_id_embed, body], dim=1)  # [B, 1+L+V, d]

        # Part 2: masks (src_key_padding_mask and attn_mask).
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)  # [B, 1]
        seq_pad_mask = (input_ids == self.seq_pad_id)  # [B, L]
        visit_pad_mask = (visit_index == self.adm_pad_id)  # [B, V]
        src_key_padding_mask = torch.cat([task_pad_mask, seq_pad_mask, visit_pad_mask], dim=1)  # [B, 1+L+V]
        # NOTE(review): a query row whose keys are all forbidden or padded yields NaN
        # attention, e.g. a visit with no med/lab/pro tokens.
        attn_mask = self.build_attn_mask(torch.cat([torch.full((B, 1), -1, device=self.device), token_types], dim=1),
                                         forbid_map=self.forbid_map,
                                         num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"
        return x, src_key_padding_mask, attn_mask

    def process_tf_out(self, h, L, V):
        """Split the Transformer output [B, 1+L+V, d] into task [B, 1, d], code [B, L, d] and visit [B, V, d] states."""
        return h[:, 0:1, :], h[:, 1:1 + L, :], h[:, 1 + L:, :]

    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        """Build a boolean attention mask from per-position token types.

        Args:
            token_types (Tensor): type id of every position. Shape [B, S].
            forbid_map (dict | None): maps a query type to the key types it may not
                attend to. ``None`` allows everything.
            num_heads (int): number of attention heads.

        Returns:
            Tensor: bool mask of shape [B * num_heads, S, S]. ``True`` marks a forbidden
            (query, key) pair, the convention of nn.TransformerEncoder's ``mask``.
        """
        B, S = token_types.shape
        device = token_types.device

        if forbid_map is None:
            mask = torch.zeros((B, S, S), dtype=torch.bool, device=device)
        else:
            # Sorted set of every type that occurs in the batch or in the forbid map.
            extra_types = [t for q_t, ks in forbid_map.items() for t in [q_t, *ks]]
            type_list = torch.unique(torch.cat([
                token_types.reshape(-1),
                torch.tensor(extra_types, dtype=token_types.dtype, device=device),
            ]))
            t2i = {t: i for i, t in enumerate(type_list.tolist())}
            T = len(t2i)

            # Directed (T, T) ban table: ban_table[q, k] is True when type q may not attend to type k.
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                for k_t in ks:
                    ban_table[t2i[q_t], t2i[k_t]] = True

            # type_list is sorted and contains every observed type, so searchsorted
            # yields each position's index into type_list.
            type_idx = torch.searchsorted(type_list, token_types)  # [B, S]
            mask = ban_table[type_idx.unsqueeze(-1), type_idx.unsqueeze(-2)]  # [B, S, S]

        # Repeat for every head: [B, S, S] -> [B * num_heads, S, S] (batch-major).
        mask = mask.unsqueeze(1).expand(B, num_heads, S, S)
        return mask.reshape(B * num_heads, S, S)

In [None]:
# Train five independent runs from scratch, each with a fresh model and optimizer.
# For each metric, collect the test score reported at every run's best-validation epoch.
N_RUNS = 5
final_metrics = {metric_name: [] for metric_name in ("precision", "recall", "f1", "auc", "prauc")}
for run_idx in range(N_RUNS):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        layer_types=['gnn', 'tf', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    for metric_name, run_scores in final_metrics.items():
        run_scores.append(best_test_metric[metric_name])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 27.91it/s, loss=0.5407]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.20it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.99it/s]


Validation: {'precision': 0.941176470034602, 'recall': 0.009428403064175438, 'f1': 0.01866977809976595, 'auc': 0.8270847765550995, 'prauc': 0.6725166272479457}
Test:      {'precision': 0.999999999090909, 'recall': 0.006090808416356086, 'f1': 0.012107869995096155, 'auc': 0.8291711537107425, 'prauc': 0.6869925619316344}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 33.34it/s, loss=0.4419]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.27it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.92it/s]


Validation: {'precision': 0.5060670949303853, 'recall': 0.8355922215625481, 'f1': 0.6303622980327604, 'auc': 0.8638226931881806, 'prauc': 0.7424091455144894}
Test:      {'precision': 0.5363540569001182, 'recall': 0.8455149501614313, 'f1': 0.6563507367045712, 'auc': 0.8681820319305197, 'prauc': 0.7574347222412666}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 33.22it/s, loss=0.4129]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.50it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.96it/s]


Validation: {'precision': 0.7363221884442529, 'recall': 0.5710076605741249, 'f1': 0.6432127398483329, 'auc': 0.8559164034004487, 'prauc': 0.734058306841839}
Test:      {'precision': 0.7541229385250816, 'recall': 0.5570321151685658, 'f1': 0.6407643263190881, 'auc': 0.8632502842410088, 'prauc': 0.7515723080674327}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 33.26it/s, loss=0.3928]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.49it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.49it/s]


Validation: {'precision': 0.6800242865775348, 'recall': 0.6599882144922806, 'f1': 0.6698564543272552, 'auc': 0.8685314656728997, 'prauc': 0.7433820656122344}
Test:      {'precision': 0.7046939988074588, 'recall': 0.6566998892543926, 'f1': 0.6798509551628216, 'auc': 0.8716581061990392, 'prauc': 0.7559181629103608}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 35.31it/s, loss=0.3579]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.06it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.19it/s]


Validation: {'precision': 0.5354213273656397, 'recall': 0.8461991750097455, 'f1': 0.6558574969627081, 'auc': 0.8776108767962641, 'prauc': 0.7595318740076812}
Test:      {'precision': 0.5673889092891392, 'recall': 0.855481727570014, 'f1': 0.6822698119385502, 'auc': 0.8811724246090425, 'prauc': 0.7733133213754436}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 35.78it/s, loss=0.3450]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.73it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.69it/s]


Validation: {'precision': 0.8948995363076523, 'recall': 0.34119033588484865, 'f1': 0.494027299753357, 'auc': 0.8756952533902485, 'prauc': 0.7597015502908834}
Test:      {'precision': 0.8849315068371927, 'recall': 0.357696566996912, 'f1': 0.5094637182935695, 'auc': 0.8813471669702346, 'prauc': 0.77830501251987}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 34.11it/s, loss=0.2758]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.57it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.03it/s]


Validation: {'precision': 0.7009966777369824, 'recall': 0.7460223924528815, 'f1': 0.722809015839995, 'auc': 0.9043563442762065, 'prauc': 0.8107132465830378}
Test:      {'precision': 0.7280934678665529, 'recall': 0.7591362126203813, 'f1': 0.7432908597311387, 'auc': 0.9073148420338903, 'prauc': 0.825189605590882}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.73it/s, loss=0.2463]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.42it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.12it/s]


Validation: {'precision': 0.7660550458665397, 'recall': 0.6888626988763179, 'f1': 0.7254111026732408, 'auc': 0.9093606485527246, 'prauc': 0.8187175782911926}
Test:      {'precision': 0.7850122850074631, 'recall': 0.7076411960093708, 'f1': 0.744321485981726, 'auc': 0.9137372854774387, 'prauc': 0.8364219792955678}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 34.12it/s, loss=0.2338]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.85it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.51it/s]


Validation: {'precision': 0.7720694645385524, 'recall': 0.6287566293421994, 'f1': 0.6930821645833941, 'auc': 0.9027737450245975, 'prauc': 0.7999848144666126}
Test:      {'precision': 0.8034602076068965, 'recall': 0.6428571428535833, 'f1': 0.7142417668197931, 'auc': 0.9065591228684386, 'prauc': 0.8243903418336553}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 33.18it/s, loss=0.1920]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.90it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.95it/s]


Validation: {'precision': 0.6687022900729328, 'recall': 0.7743076016454078, 'f1': 0.7176406285564485, 'auc': 0.9041018177600364, 'prauc': 0.8065329894085107}
Test:      {'precision': 0.6919339164203502, 'recall': 0.7884828349900971, 'f1': 0.7370600364253193, 'auc': 0.9044254072315641, 'prauc': 0.8209003602428949}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.24it/s, loss=0.1851]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.28it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.40it/s]


Validation: {'precision': 0.6928327645011785, 'recall': 0.7177371832603552, 'f1': 0.7050651180076075, 'auc': 0.8981911180529112, 'prauc': 0.7984097647064279}
Test:      {'precision': 0.7228525121516864, 'recall': 0.7408637873713131, 'f1': 0.7317473288769849, 'auc': 0.9001044884228109, 'prauc': 0.8173712107700126}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 33.48it/s, loss=0.1587]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.00it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.12it/s]


Validation: {'precision': 0.6523061327893953, 'recall': 0.7583971714746117, 'f1': 0.7013623928446199, 'auc': 0.8955639700822443, 'prauc': 0.7955193818506425}
Test:      {'precision': 0.6824324324291389, 'recall': 0.7829457364297733, 'f1': 0.7292418722760813, 'auc': 0.8992429020994688, 'prauc': 0.8145904934178873}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 34.70it/s, loss=0.1325]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.00it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.69it/s]


Validation: {'precision': 0.7727930535399943, 'recall': 0.6293459045337104, 'f1': 0.6937317261293464, 'auc': 0.8992635784627007, 'prauc': 0.7958770042156564}
Test:      {'precision': 0.7993174061378887, 'recall': 0.648394241413907, 'f1': 0.7159889892413409, 'auc': 0.9037461340022168, 'prauc': 0.8171573221774496}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7660550458665397, 'recall': 0.6888626988763179, 'f1': 0.7254111026732408, 'auc': 0.9093606485527246, 'prauc': 0.8187175782911926}
Corresponding test performance:
{'precision': 0.7850122850074631, 'recall': 0.7076411960093708, 'f1': 0.744321485981726, 'auc': 0.9137372854774387, 'prauc': 0.8364219792955678}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 32.32it/s, loss=0.5431]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.63it/s]


Validation: {'precision': 0.5903682719513293, 'recall': 0.6140247495544253, 'f1': 0.6019641775518889, 'auc': 0.8032119777304307, 'prauc': 0.6203549427520596}
Test:      {'precision': 0.6147308781834859, 'recall': 0.6007751937951231, 'f1': 0.6076729157477458, 'auc': 0.805123613447985, 'prauc': 0.6388426087922203}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 35.70it/s, loss=0.4629]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.19it/s]


Validation: {'precision': 0.7831578947285983, 'recall': 0.43842074248415785, 'f1': 0.5621458208567426, 'auc': 0.847950330220211, 'prauc': 0.7070830861559961}
Test:      {'precision': 0.7871093749923135, 'recall': 0.44629014396209143, 'f1': 0.569611302798247, 'auc': 0.8507940283044464, 'prauc': 0.7203806721012638}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 33.47it/s, loss=0.4167]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.73it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.63it/s]


Validation: {'precision': 0.7896950578255553, 'recall': 0.44254566882473456, 'f1': 0.5672205391991991, 'auc': 0.860274714907947, 'prauc': 0.7236241325795691}
Test:      {'precision': 0.8036809815868744, 'recall': 0.435215946841444, 'f1': 0.5646551678520115, 'auc': 0.858603146864338, 'prauc': 0.7291998275909266}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 33.82it/s, loss=0.3966]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.32it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.86it/s]


Validation: {'precision': 0.7679603633297444, 'recall': 0.5480259281051972, 'f1': 0.6396148507060943, 'auc': 0.8776711072867908, 'prauc': 0.7548854870531209}
Test:      {'precision': 0.7752895752835886, 'recall': 0.555924695456501, 'f1': 0.6475330489851908, 'auc': 0.8780978453817545, 'prauc': 0.7597448947518157}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 33.59it/s, loss=0.3717]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.23it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.86it/s]


Validation: {'precision': 0.8461538461456072, 'recall': 0.5120801414230284, 'f1': 0.6380323007309882, 'auc': 0.8971033924392658, 'prauc': 0.7877199610326849}
Test:      {'precision': 0.8435251798485295, 'recall': 0.5193798449583645, 'f1': 0.6429060953469592, 'auc': 0.8956792412229971, 'prauc': 0.7931190662497049}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 35.60it/s, loss=0.3176]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.74it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.04it/s]


Validation: {'precision': 0.6435316336135811, 'recall': 0.8031820860294451, 'f1': 0.7145478325407613, 'auc': 0.8978679300549637, 'prauc': 0.7934272043948499}
Test:      {'precision': 0.656292286871194, 'recall': 0.8056478405271006, 'f1': 0.7233407855034744, 'auc': 0.8975788180990493, 'prauc': 0.8034465552956903}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 33.51it/s, loss=0.3013]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.76it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.01it/s]


Validation: {'precision': 0.7786589762020285, 'recall': 0.636417206831842, 'f1': 0.700389100104344, 'auc': 0.9024797461402985, 'prauc': 0.8004361601910164}
Test:      {'precision': 0.7768154563572498, 'recall': 0.6456256921337452, 'f1': 0.7051708447509966, 'auc': 0.8990212581253044, 'prauc': 0.804512791513452}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.35it/s, loss=0.2828]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.00it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.85it/s]


Validation: {'precision': 0.7040586245732579, 'recall': 0.736004714197195, 'f1': 0.7196773214172133, 'auc': 0.9000643310180895, 'prauc': 0.7924424742533838}
Test:      {'precision': 0.7062863180100039, 'recall': 0.7403100775152808, 'f1': 0.7228980755611717, 'auc': 0.8974362667399366, 'prauc': 0.7997546450062547}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 34.76it/s, loss=0.2601]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.82it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.94it/s]


Validation: {'precision': 0.7499999999951924, 'recall': 0.6894519740678289, 'f1': 0.7184525587133695, 'auc': 0.8978336312496479, 'prauc': 0.7952172752189735}
Test:      {'precision': 0.7527472527426573, 'recall': 0.6827242524879141, 'f1': 0.7160278695721998, 'auc': 0.8985145729835877, 'prauc': 0.8015081918646673}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.19it/s, loss=0.2380]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.29it/s]


Validation: {'precision': 0.7325728770549267, 'recall': 0.6812021213866754, 'f1': 0.705954193475573, 'auc': 0.8890669332563133, 'prauc': 0.7819407161665561}
Test:      {'precision': 0.7403076923031366, 'recall': 0.6661129568069429, 'f1': 0.7012532739372538, 'auc': 0.8903705572108582, 'prauc': 0.7843920783878702}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 35.32it/s, loss=0.2363]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.69it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.84it/s]


Validation: {'precision': 0.705464480870462, 'recall': 0.7607542722406556, 'f1': 0.732066907393092, 'auc': 0.902114850284118, 'prauc': 0.798545779672157}
Test:      {'precision': 0.7037422037385461, 'recall': 0.749723145067831, 'f1': 0.7260053569314062, 'auc': 0.9004616669793442, 'prauc': 0.800177122116429}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 34.66it/s, loss=0.1989]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.96it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.07it/s]


Validation: {'precision': 0.7230948225670559, 'recall': 0.7324690630481293, 'f1': 0.7277517514362276, 'auc': 0.9022800849384046, 'prauc': 0.7949692655697613}
Test:      {'precision': 0.7337845459631485, 'recall': 0.7203765226981154, 'f1': 0.7270187153092991, 'auc': 0.8973834562623378, 'prauc': 0.7963787374401666}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 33.50it/s, loss=0.1907]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.43it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.31it/s]


Validation: {'precision': 0.7709190672100761, 'recall': 0.6623453152583245, 'f1': 0.7125198048498494, 'auc': 0.9033314678700672, 'prauc': 0.808553602434262}
Test:      {'precision': 0.7856668878514554, 'recall': 0.6555924695423279, 'f1': 0.7147600312573588, 'auc': 0.8985881875887254, 'prauc': 0.8040127508654858}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 33.91it/s, loss=0.1770]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.13it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.91it/s]


Validation: {'precision': 0.7207908927458312, 'recall': 0.7088980553876907, 'f1': 0.714795003908755, 'auc': 0.8935604599795536, 'prauc': 0.7905234376219142}
Test:      {'precision': 0.7289334118990635, 'recall': 0.6849390919120436, 'f1': 0.7062517791857933, 'auc': 0.8884482681363982, 'prauc': 0.7844341094976768}


Epoch 015: 100%|██████████| 98/98 [00:02<00:00, 33.67it/s, loss=0.1632]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.13it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 46.72it/s]


Validation: {'precision': 0.8470464134931746, 'recall': 0.47318797878330476, 'f1': 0.607183360235671, 'auc': 0.8859455503591411, 'prauc': 0.7796107165901567}
Test:      {'precision': 0.8604887983619095, 'recall': 0.46788482834735395, 'f1': 0.6061692924194952, 'auc': 0.8853523308255127, 'prauc': 0.7841210969594055}


Epoch 016: 100%|██████████| 98/98 [00:02<00:00, 33.65it/s, loss=0.1491]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.88it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.78it/s]


Validation: {'precision': 0.7691751085327846, 'recall': 0.6263995285761557, 'f1': 0.6904839183995852, 'auc': 0.8865701462943048, 'prauc': 0.7817705695260377}
Test:      {'precision': 0.7632120796104104, 'recall': 0.6157253599079971, 'f1': 0.6815813618933522, 'auc': 0.8796238096715312, 'prauc': 0.7769327263090837}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.705464480870462, 'recall': 0.7607542722406556, 'f1': 0.732066907393092, 'auc': 0.902114850284118, 'prauc': 0.798545779672157}
Corresponding test performance:
{'precision': 0.7037422037385461, 'recall': 0.749723145067831, 'f1': 0.7260053569314062, 'auc': 0.9004616669793442, 'prauc': 0.800177122116429}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 33.27it/s, loss=0.5381]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.66it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.99it/s]


Validation: {'precision': 0.5946808510606666, 'recall': 0.6588096641092587, 'f1': 0.6251048314647218, 'auc': 0.8220735103236849, 'prauc': 0.680911998129169}
Test:      {'precision': 0.625592417058317, 'recall': 0.6578073089664573, 'f1': 0.6412955415583931, 'auc': 0.8241091032460841, 'prauc': 0.6812715238835922}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 33.55it/s, loss=0.4426]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.85it/s]


Validation: {'precision': 0.7228127555133049, 'recall': 0.5209192692956929, 'f1': 0.6054794471824007, 'auc': 0.8449829406556578, 'prauc': 0.7034504922477456}
Test:      {'precision': 0.7358339983979583, 'recall': 0.5105204872618465, 'f1': 0.6028113714262162, 'auc': 0.8470486698407118, 'prauc': 0.7047621906904948}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 35.11it/s, loss=0.4065]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.89it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 44.39it/s]


Validation: {'precision': 0.6526027397224515, 'recall': 0.7018267530895591, 'f1': 0.6763202675751657, 'auc': 0.8649288775629734, 'prauc': 0.7373068146407635}
Test:      {'precision': 0.6692913385791639, 'recall': 0.7059800664412736, 'f1': 0.6871463167460153, 'auc': 0.8702910041967711, 'prauc': 0.7436024620405515}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 32.66it/s, loss=0.3726]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.89it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.86it/s]


Validation: {'precision': 0.7611111111050706, 'recall': 0.5651149086590153, 'f1': 0.6486303637216564, 'auc': 0.8696287719584208, 'prauc': 0.7535856343648885}
Test:      {'precision': 0.7698113207489071, 'recall': 0.5647840531530189, 'f1': 0.6515490209841704, 'auc': 0.8740402403529167, 'prauc': 0.7586272078464743}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 35.60it/s, loss=0.3394]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.36it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.95it/s]


Validation: {'precision': 0.6801169590603502, 'recall': 0.6853270477272521, 'f1': 0.6827120583949499, 'auc': 0.8749848785062225, 'prauc': 0.7582076195393439}
Test:      {'precision': 0.7048710601678805, 'recall': 0.681063122919817, 'f1': 0.6927625970814938, 'auc': 0.8767674891313821, 'prauc': 0.7635081676173703}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 34.19it/s, loss=0.3155]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.36it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.91it/s]


Validation: {'precision': 0.7086871325884475, 'recall': 0.6393635827893969, 'f1': 0.672242869854163, 'auc': 0.8703636094913033, 'prauc': 0.7479971910905587}
Test:      {'precision': 0.7383237363996269, 'recall': 0.6389811738613568, 'f1': 0.6850697486580395, 'auc': 0.8731595733208852, 'prauc': 0.7535651817239255}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 33.48it/s, loss=0.3086]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.86it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.00it/s]


Validation: {'precision': 0.7928388746735479, 'recall': 0.5480259281051972, 'f1': 0.648083618855538, 'auc': 0.873641157340053, 'prauc': 0.7545714850686935}
Test:      {'precision': 0.7801302931532563, 'recall': 0.5304540420790119, 'f1': 0.6315095535161296, 'auc': 0.8731195653833106, 'prauc': 0.7571123721977252}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.68it/s, loss=0.2988]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.57it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.93it/s]


Validation: {'precision': 0.5925760286198901, 'recall': 0.7807896287520284, 'f1': 0.6737859091509948, 'auc': 0.876824239764106, 'prauc': 0.7664093304485518}
Test:      {'precision': 0.6219035202059022, 'recall': 0.7923588039823236, 'f1': 0.6968590162525845, 'auc': 0.8756832124724314, 'prauc': 0.7646085394613574}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 33.39it/s, loss=0.2662]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.69it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.01it/s]


Validation: {'precision': 0.7350746268601861, 'recall': 0.5804360636383004, 'f1': 0.6486664422166132, 'auc': 0.8686933151564568, 'prauc': 0.7508515338071821}
Test:      {'precision': 0.7556338028115801, 'recall': 0.5941306755227347, 'f1': 0.6652200818622523, 'auc': 0.8702271145979978, 'prauc': 0.7594509273592058}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 33.37it/s, loss=0.2438]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.16it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.81it/s]


Validation: {'precision': 0.643021914644504, 'recall': 0.6570418385347258, 'f1': 0.6499562759644407, 'auc': 0.8559761229218086, 'prauc': 0.733698906742223}
Test:      {'precision': 0.673660960791421, 'recall': 0.6755260243594933, 'f1': 0.6745921984798294, 'auc': 0.855983734865305, 'prauc': 0.7435088480529006}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6801169590603502, 'recall': 0.6853270477272521, 'f1': 0.6827120583949499, 'auc': 0.8749848785062225, 'prauc': 0.7582076195393439}
Corresponding test performance:
{'precision': 0.7048710601678805, 'recall': 0.681063122919817, 'f1': 0.6927625970814938, 'auc': 0.8767674891313821, 'prauc': 0.7635081676173703}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 34.18it/s, loss=0.5707]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.34it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.48it/s]


Validation: {'precision': 0.999999995, 'recall': 0.0011785503830219297, 'f1': 0.0023543260506180114, 'auc': 0.7731654578200701, 'prauc': 0.5346815875728816}
Test:      {'precision': 0.999999995, 'recall': 0.001107419712064743, 'f1': 0.0022123893584070796, 'auc': 0.7683907256430444, 'prauc': 0.537357991828385}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 35.06it/s, loss=0.4701]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.64it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 44.53it/s]


Validation: {'precision': 0.8207547169700706, 'recall': 0.3588685916301776, 'f1': 0.4993849896124159, 'auc': 0.8473258620273392, 'prauc': 0.7173317912140598}
Test:      {'precision': 0.8145161290213102, 'recall': 0.33554817275561716, 'f1': 0.4752941135105698, 'auc': 0.8427756990063753, 'prauc': 0.7226832939925175}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 32.01it/s, loss=0.4336]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.80it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.79it/s]


Validation: {'precision': 0.691176470583816, 'recall': 0.637006482023353, 'f1': 0.6629868088650299, 'auc': 0.8609288193124884, 'prauc': 0.7316795675815823}
Test:      {'precision': 0.7175718849794405, 'recall': 0.6218161683243533, 'f1': 0.6662711311829792, 'auc': 0.861449249992737, 'prauc': 0.7367312853712331}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 35.21it/s, loss=0.3992]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.00it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.51it/s]


Validation: {'precision': 0.9114219114006661, 'recall': 0.23040659988078724, 'f1': 0.36782690176104393, 'auc': 0.871968627515102, 'prauc': 0.7482940727598212}
Test:      {'precision': 0.9160671462610056, 'recall': 0.21151716500436593, 'f1': 0.3436797090497436, 'auc': 0.8740305153465525, 'prauc': 0.7536902603076712}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 34.26it/s, loss=0.3811]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.18it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.76it/s]


Validation: {'precision': 0.845671267241585, 'recall': 0.3971714790783903, 'f1': 0.5404971889106205, 'auc': 0.8721704603359545, 'prauc': 0.7521565550850196}
Test:      {'precision': 0.8458646616435356, 'recall': 0.37375415282185076, 'f1': 0.5184331754687401, 'auc': 0.8699902676075548, 'prauc': 0.751594396438346}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 33.31it/s, loss=0.3504]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.58it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.39it/s]


Validation: {'precision': 0.7612380250496593, 'recall': 0.6087212728308267, 'f1': 0.6764898444354062, 'auc': 0.8843349755456544, 'prauc': 0.7725509754756503}
Test:      {'precision': 0.7607621736008415, 'recall': 0.5968992248028965, 'f1': 0.6689419745908698, 'auc': 0.8802798167464118, 'prauc': 0.7710696490745493}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 30.19it/s, loss=0.3214]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.61it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.09it/s]


Validation: {'precision': 0.7418712674137898, 'recall': 0.6588096641092587, 'f1': 0.6978776479470595, 'auc': 0.8872559946583283, 'prauc': 0.7752347681446365}
Test:      {'precision': 0.733757961778766, 'recall': 0.637873754149292, 'f1': 0.6824644499966942, 'auc': 0.88225953259896, 'prauc': 0.770026932001871}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.58it/s, loss=0.2870]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.20it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.72it/s]


Validation: {'precision': 0.7600574712589077, 'recall': 0.6234531526186008, 'f1': 0.6850113255719892, 'auc': 0.8850135425990524, 'prauc': 0.7743224390447352}
Test:      {'precision': 0.750865051897918, 'recall': 0.6007751937951231, 'f1': 0.6674869221569003, 'auc': 0.8770670562578077, 'prauc': 0.7629590840907517}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 33.75it/s, loss=0.2687]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.92it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.74it/s]


Validation: {'precision': 0.8387715930821615, 'recall': 0.5150265173805832, 'f1': 0.6381891153981069, 'auc': 0.8913787493799709, 'prauc': 0.7832011752870942}
Test:      {'precision': 0.8239700374454685, 'recall': 0.48726467330848694, 'f1': 0.6123869125140184, 'auc': 0.8827547077647898, 'prauc': 0.7758852107816818}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 33.67it/s, loss=0.2357]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.13it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.58it/s]


Validation: {'precision': 0.7682926829213179, 'recall': 0.6311137301082433, 'f1': 0.6929796132900845, 'auc': 0.8902698824170527, 'prauc': 0.7819283359864371}
Test:      {'precision': 0.7600281491853623, 'recall': 0.5980066445149612, 'f1': 0.6693523347013564, 'auc': 0.8802198048400496, 'prauc': 0.7748869906894172}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.28it/s, loss=0.2377]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.33it/s]


Validation: {'precision': 0.7906018136785606, 'recall': 0.5651149086590153, 'f1': 0.6591065243434089, 'auc': 0.8793924346671629, 'prauc': 0.7630732716825817}
Test:      {'precision': 0.7801587301525384, 'recall': 0.5442967884798212, 'f1': 0.6412263487095043, 'auc': 0.8653582101458159, 'prauc': 0.758063707247648}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 34.21it/s, loss=0.1963]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.75it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.00it/s]


Validation: {'precision': 0.5995698924705395, 'recall': 0.8214496169662849, 'f1': 0.6931874640393884, 'auc': 0.8942629148414954, 'prauc': 0.7927243488981313}
Test:      {'precision': 0.6203783318975908, 'recall': 0.799003322254712, 'f1': 0.6984511083381494, 'auc': 0.8805156173754104, 'prauc': 0.7809633092724552}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7418712674137898, 'recall': 0.6588096641092587, 'f1': 0.6978776479470595, 'auc': 0.8872559946583283, 'prauc': 0.7752347681446365}
Corresponding test performance:
{'precision': 0.733757961778766, 'recall': 0.637873754149292, 'f1': 0.6824644499966942, 'auc': 0.88225953259896, 'prauc': 0.770026932001871}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 32.09it/s, loss=0.5496]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.72it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.24it/s]


Validation: {'precision': 0.7259786476803739, 'recall': 0.4808485562729473, 'f1': 0.5785182511397969, 'auc': 0.8199886922523407, 'prauc': 0.6731242263151953}
Test:      {'precision': 0.7301451750578126, 'recall': 0.47342192690767765, 'f1': 0.5744037574003178, 'auc': 0.8228333424428331, 'prauc': 0.6875632050644037}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 35.38it/s, loss=0.4602]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.46it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.35it/s]


Validation: {'precision': 0.7022742935857873, 'recall': 0.6004714201496731, 'f1': 0.6473951665639044, 'auc': 0.8469802552572021, 'prauc': 0.7153618236524658}
Test:      {'precision': 0.714380825561156, 'recall': 0.5941306755227347, 'f1': 0.6487303457033591, 'auc': 0.8510093941099451, 'prauc': 0.7291885431404185}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 34.74it/s, loss=0.4160]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.06it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.16it/s]


Validation: {'precision': 0.8705882352813149, 'recall': 0.3488509133744912, 'f1': 0.4981068532943444, 'auc': 0.8568031265181373, 'prauc': 0.728840300466693}
Test:      {'precision': 0.8628571428448163, 'recall': 0.3344407530435524, 'f1': 0.4820430925382977, 'auc': 0.8611438355523581, 'prauc': 0.7432614084682263}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 33.59it/s, loss=0.3712]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.14it/s]


Validation: {'precision': 0.7679521276544684, 'recall': 0.6806128461951644, 'f1': 0.7216494795497503, 'auc': 0.9017211485411638, 'prauc': 0.802643875895946}
Test:      {'precision': 0.7721196689957217, 'recall': 0.6716500553672666, 'f1': 0.7183890977738818, 'auc': 0.9024374435888081, 'prauc': 0.8197235328337437}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 32.01it/s, loss=0.3113]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.85it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.26it/s]


Validation: {'precision': 0.7675324675274836, 'recall': 0.6965232763659605, 'f1': 0.7303058337468236, 'auc': 0.9105146724157511, 'prauc': 0.8092192903203199}
Test:      {'precision': 0.7715355805195285, 'recall': 0.6843853820560112, 'f1': 0.7253521076897154, 'auc': 0.9089408261676041, 'prauc': 0.8245798835498346}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 33.73it/s, loss=0.2851]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.81it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.21it/s]


Validation: {'precision': 0.8075734157588287, 'recall': 0.6157925751289582, 'f1': 0.6987629506193652, 'auc': 0.9084409679237828, 'prauc': 0.8076705240101167}
Test:      {'precision': 0.8167272727213329, 'recall': 0.6218161683243533, 'f1': 0.7060672695293505, 'auc': 0.911370538991859, 'prauc': 0.8303168774849969}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 33.53it/s, loss=0.2512]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.38it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.14it/s]


Validation: {'precision': 0.8439781021820806, 'recall': 0.5450795521476425, 'f1': 0.6623702064691627, 'auc': 0.9071358888008461, 'prauc': 0.8094500017551822}
Test:      {'precision': 0.8637152777702803, 'recall': 0.5509413067522096, 'f1': 0.6727518546043026, 'auc': 0.9072376574897074, 'prauc': 0.8258171493921703}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 30.27it/s, loss=0.2408]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.38it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.08it/s]


Validation: {'precision': 0.7346059113255259, 'recall': 0.7030053034725811, 'f1': 0.7184582906921574, 'auc': 0.9067220037758067, 'prauc': 0.8017244314920786}
Test:      {'precision': 0.7516034985378915, 'recall': 0.7137320044257269, 'f1': 0.7321783534200843, 'auc': 0.905454042081949, 'prauc': 0.8170594704372459}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 34.55it/s, loss=0.2237]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.75it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.28it/s]


Validation: {'precision': 0.6484988452625936, 'recall': 0.8273423688813946, 'f1': 0.7270844072913053, 'auc': 0.9076473050655822, 'prauc': 0.8062419605578973}
Test:      {'precision': 0.6623376623346963, 'recall': 0.8189368770718775, 'f1': 0.7323594899767428, 'auc': 0.9047290367024202, 'prauc': 0.8169699955973226}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.03it/s, loss=0.1996]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.60it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.70it/s]

Validation: {'precision': 0.7403486923988771, 'recall': 0.7006482027065372, 'f1': 0.7199515541880522, 'auc': 0.9050850501790882, 'prauc': 0.8008533250937044}
Test:      {'precision': 0.7520515826450642, 'recall': 0.7104097452895326, 'f1': 0.7306378082117382, 'auc': 0.9036270950002634, 'prauc': 0.8153206528891458}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7675324675274836, 'recall': 0.6965232763659605, 'f1': 0.7303058337468236, 'auc': 0.9105146724157511, 'prauc': 0.8092192903203199}
Corresponding test performance:
{'precision': 0.7715355805195285, 'recall': 0.6843853820560112, 'f1': 0.7253521076897154, 'auc': 0.9089408261676041, 'prauc': 0.8245798835498346}





In [None]:
# Summarize the repeated runs: for every metric collected in `final_metrics`
# (a dict mapping metric name -> list of per-run values), report mean ± std.
# Note: np.std uses its default ddof=0, i.e. the population standard deviation,
# which is not Bessel-corrected for the small number of runs.
print("\nFinal Metrics:")
for metric_name, per_run_values in final_metrics.items():
    print(f"{metric_name}: {np.mean(per_run_values):.4f} ± {np.std(per_run_values):.4f}")


Final Metrics:
precision: 0.7398 ± 0.0335
recall: 0.6921 ± 0.0366
f1: 0.7142 ± 0.0230
auc: 0.8964 ± 0.0146
prauc: 0.7989 ± 0.0288


In [None]:
# Train the model from scratch several times and collect, per run, the test
# metrics returned by train_with_early_stopping (the log shows these are the test
# scores at the epoch with the best validation performance). The summary cell
# aggregates the lists in `final_metrics` into mean ± std.
num_runs = 5
base_seed = 123  # same base seed as the set_random_seed call at the top of the notebook
final_metrics = {"precision": [], "recall": [], "f1": [], "auc": [], "prauc": []}
for run_idx in range(num_runs):
    # Re-seed every run so each run is reproducible on its own (weight init and any
    # data shuffling) instead of depending on how much RNG state earlier cells used.
    # Assumes set_random_seed seeds the torch/numpy/python RNGs — TODO confirm in heterogt.utils.seed.
    set_random_seed(base_seed + run_idx)
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['tf', 'tf'], max_num_adms=config.max_num_adms,
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                                 optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
                                                 val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics:
        final_metrics[key].append(best_test_metric[key])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.6213]

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 63.64it/s, loss=0.5596]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.12it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.11it/s]


Validation: {'precision': 0.6233354470474107, 'recall': 0.5792575132552784, 'f1': 0.6004886938423891, 'auc': 0.809171602735269, 'prauc': 0.6408833845365225}
Test:      {'precision': 0.6532823454387937, 'recall': 0.5675526024331808, 'f1': 0.6074074024284638, 'auc': 0.8128257569378687, 'prauc': 0.6525078605875494}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 64.52it/s, loss=0.4581]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.71it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 169.27it/s]


Validation: {'precision': 0.760489510482863, 'recall': 0.5126694166145394, 'f1': 0.6124603964522906, 'auc': 0.8569168171577315, 'prauc': 0.7253112747735293}
Test:      {'precision': 0.7721297107732504, 'recall': 0.4878183831645193, 'f1': 0.597896160842667, 'auc': 0.8581333921265364, 'prauc': 0.7342194003978738}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 61.65it/s, loss=0.4390]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.88it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.09it/s]


Validation: {'precision': 0.7866184448391808, 'recall': 0.5126694166145394, 'f1': 0.6207634629310145, 'auc': 0.8687799883013609, 'prauc': 0.7356737337761556}
Test:      {'precision': 0.7798245613966682, 'recall': 0.49224806201277826, 'f1': 0.6035302057062935, 'auc': 0.8677942011387367, 'prauc': 0.7438042802047115}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 60.47it/s, loss=0.4049]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.01it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.08it/s]


Validation: {'precision': 0.6891385767747246, 'recall': 0.6505598114281052, 'f1': 0.6692937203714134, 'auc': 0.8673941760500832, 'prauc': 0.7441746593272818}
Test:      {'precision': 0.7176100628885685, 'recall': 0.6317829457329359, 'f1': 0.6719670150398274, 'auc': 0.8680776419887859, 'prauc': 0.750978169198757}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 59.07it/s, loss=0.3734]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.66it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.63it/s]


Validation: {'precision': 0.8735019973252529, 'recall': 0.3865645256311929, 'f1': 0.5359477081605933, 'auc': 0.890156000164021, 'prauc': 0.7761607044866097}
Test:      {'precision': 0.8782722512974048, 'recall': 0.3715393133977213, 'f1': 0.522178984144723, 'auc': 0.8869051312088011, 'prauc': 0.7788561616103306}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 59.65it/s, loss=0.3457]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.81it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.68it/s]


Validation: {'precision': 0.6386186770396954, 'recall': 0.7737183264538968, 'f1': 0.6997068961877725, 'auc': 0.8891499018747586, 'prauc': 0.7812795380553967}
Test:      {'precision': 0.6524189760420271, 'recall': 0.769102990028964, 'f1': 0.7059720407734299, 'auc': 0.8875899439987356, 'prauc': 0.7871070687684083}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 59.07it/s, loss=0.3148]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 155.58it/s]


Validation: {'precision': 0.7596355991537518, 'recall': 0.6387743075978859, 'f1': 0.6939820692966703, 'auc': 0.895640104488085, 'prauc': 0.7881728028516937}
Test:      {'precision': 0.7809973045769475, 'recall': 0.6417497231415186, 'f1': 0.7045592655603294, 'auc': 0.8927039432315679, 'prauc': 0.7932997157000569}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 61.57it/s, loss=0.2957]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.29it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.68it/s]


Validation: {'precision': 0.7893569844731017, 'recall': 0.6293459045337104, 'f1': 0.7003278639114712, 'auc': 0.896274089481687, 'prauc': 0.7969338287764691}
Test:      {'precision': 0.7947443181761737, 'recall': 0.6196013289002237, 'f1': 0.6963285576112327, 'auc': 0.8932877513667943, 'prauc': 0.8003237852626941}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 61.70it/s, loss=0.2635]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.68it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.17it/s]


Validation: {'precision': 0.7719869706790099, 'recall': 0.6982911019404933, 'f1': 0.7332920742159451, 'auc': 0.9115511733704331, 'prauc': 0.8179693643529615}
Test:      {'precision': 0.7868956742952488, 'recall': 0.6849390919120436, 'f1': 0.7323860222547071, 'auc': 0.9073034551593496, 'prauc': 0.8216330507502276}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 59.01it/s, loss=0.2268]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.36it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 155.81it/s]


Validation: {'precision': 0.8029878618038937, 'recall': 0.5067766646994297, 'f1': 0.6213872784882373, 'auc': 0.8831871471837742, 'prauc': 0.76927831438089}
Test:      {'precision': 0.8242990654128571, 'recall': 0.4883720930205517, 'f1': 0.6133518729309756, 'auc': 0.8759267069355792, 'prauc': 0.7737627070666879}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 61.01it/s, loss=0.2146]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.70it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.68it/s]


Validation: {'precision': 0.7904903417474705, 'recall': 0.6269888037676665, 'f1': 0.6993098866163153, 'auc': 0.9043665636595407, 'prauc': 0.8030599373546229}
Test:      {'precision': 0.792932862185209, 'recall': 0.6212624584683208, 'f1': 0.6966780453642925, 'auc': 0.8953941385047717, 'prauc': 0.8060238961846886}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 58.42it/s, loss=0.1999]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.39it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.06it/s]


Validation: {'precision': 0.7457180500609637, 'recall': 0.6670595167904122, 'f1': 0.7041990618851468, 'auc': 0.9001619261289321, 'prauc': 0.7963436889551831}
Test:      {'precision': 0.7571243523267026, 'recall': 0.6472868217018423, 'f1': 0.6979104427876106, 'auc': 0.8879012673037407, 'prauc': 0.7915668842847724}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 63.62it/s, loss=0.1576]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 169.61it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.68it/s]


Validation: {'precision': 0.7766917293174685, 'recall': 0.6087212728308267, 'f1': 0.6825239461756952, 'auc': 0.8936399795561236, 'prauc': 0.791397983100384}
Test:      {'precision': 0.7715091678365903, 'recall': 0.6057585824994144, 'f1': 0.678660044695999, 'auc': 0.8854471804129015, 'prauc': 0.7907501461514056}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 58.46it/s, loss=0.1400]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.77it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.70it/s]


Validation: {'precision': 0.7275661717189957, 'recall': 0.6641131408328573, 'f1': 0.6943930942051301, 'auc': 0.8997003293579506, 'prauc': 0.7864934430352951}
Test:      {'precision': 0.7389975550077079, 'recall': 0.6694352159431372, 'f1': 0.7024985423643033, 'auc': 0.892325406591437, 'prauc': 0.7852218473304479}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7719869706790099, 'recall': 0.6982911019404933, 'f1': 0.7332920742159451, 'auc': 0.9115511733704331, 'prauc': 0.8179693643529615}
Corresponding test performance:
{'precision': 0.7868956742952488, 'recall': 0.6849390919120436, 'f1': 0.7323860222547071, 'auc': 0.9073034551593496, 'prauc': 0.8216330507502276}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 61.06it/s, loss=0.5290]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.56it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.20it/s]


Validation: {'precision': 0.6630747126389147, 'recall': 0.5439010017646205, 'f1': 0.5976043977642028, 'auc': 0.8188234909387923, 'prauc': 0.667454329403183}
Test:      {'precision': 0.6910511363587284, 'recall': 0.5387596899194975, 'f1': 0.6054760373877782, 'auc': 0.819458395930368, 'prauc': 0.6720768318323151}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 60.80it/s, loss=0.5056]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 155.31it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.27it/s]


Validation: {'precision': 0.8156312625087048, 'recall': 0.23983500294496268, 'f1': 0.37067394912584484, 'auc': 0.8238790837454193, 'prauc': 0.6636670161845681}
Test:      {'precision': 0.8247011952026951, 'recall': 0.2292358803974018, 'f1': 0.3587521629707836, 'auc': 0.8283328950925403, 'prauc': 0.6851066139897712}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 58.33it/s, loss=0.4307]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.03it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.48it/s]


Validation: {'precision': 0.6734693877511752, 'recall': 0.6806128461951644, 'f1': 0.6770222693220793, 'auc': 0.866957489025979, 'prauc': 0.7353845909921312}
Test:      {'precision': 0.6990179087192432, 'recall': 0.6699889257991696, 'f1': 0.6841956410260867, 'auc': 0.8701965854640946, 'prauc': 0.7495090159744304}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 62.00it/s, loss=0.3780]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.38it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.19it/s]


Validation: {'precision': 0.6139122315564244, 'recall': 0.7748968768369188, 'f1': 0.6850742331464473, 'auc': 0.8756983830763946, 'prauc': 0.753044288872592}
Test:      {'precision': 0.6355738454474825, 'recall': 0.7696566998849964, 'f1': 0.6962183772108304, 'auc': 0.8804460035640302, 'prauc': 0.7755358257863573}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 59.72it/s, loss=0.3539]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.58it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.61it/s]


Validation: {'precision': 0.8180933852060498, 'recall': 0.49558043606072144, 'f1': 0.6172477017188508, 'auc': 0.8825290188970448, 'prauc': 0.7601453171904047}
Test:      {'precision': 0.8445945945864423, 'recall': 0.4844961240283251, 'f1': 0.6157635421607278, 'auc': 0.8858243629382223, 'prauc': 0.779906603198566}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 63.27it/s, loss=0.3568]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 165.36it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 151.53it/s]


Validation: {'precision': 0.5482853764266918, 'recall': 0.8196817913917521, 'f1': 0.6570618752125919, 'auc': 0.8730387885636415, 'prauc': 0.7582714586849323}
Test:      {'precision': 0.5768330733206832, 'recall': 0.8189368770718775, 'f1': 0.676887867000883, 'auc': 0.8746513769870403, 'prauc': 0.7772375957324399}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 56.86it/s, loss=0.3339]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 144.91it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 147.25it/s]


Validation: {'precision': 0.6657810839497037, 'recall': 0.7383618149632389, 'f1': 0.7001955803684857, 'auc': 0.8773805574443673, 'prauc': 0.7603766484249079}
Test:      {'precision': 0.6837965390630111, 'recall': 0.7220376522662124, 'f1': 0.7023969785711525, 'auc': 0.8803349661495918, 'prauc': 0.7692468118560818}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 57.06it/s, loss=0.2963]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 148.64it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.32it/s]


Validation: {'precision': 0.7653203342565089, 'recall': 0.6476134354705504, 'f1': 0.701563991200027, 'auc': 0.8927779107007291, 'prauc': 0.7819606349300485}
Test:      {'precision': 0.781144781139521, 'recall': 0.642303432997551, 'f1': 0.7049528968968249, 'auc': 0.8942862263873153, 'prauc': 0.7960472652162922}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 63.74it/s, loss=0.2765]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 168.57it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 168.49it/s]


Validation: {'precision': 0.7528409090855622, 'recall': 0.6246317030016227, 'f1': 0.6827697212869048, 'auc': 0.8848690660671636, 'prauc': 0.7716935823021781}
Test:      {'precision': 0.7708039492188237, 'recall': 0.6052048726433821, 'f1': 0.678039697301462, 'auc': 0.8823357323323717, 'prauc': 0.7825723921763883}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 65.84it/s, loss=0.2525]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.52it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 164.73it/s]


Validation: {'precision': 0.7314285714239275, 'recall': 0.6788450206206315, 'f1': 0.7041564742202512, 'auc': 0.8919258057440855, 'prauc': 0.7802903744802062}
Test:      {'precision': 0.7471054235176898, 'recall': 0.6788482834956875, 'f1': 0.7113431919902131, 'auc': 0.8888468087453165, 'prauc': 0.786109772788509}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 65.35it/s, loss=0.2512]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 168.12it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.94it/s]


Validation: {'precision': 0.7081850533765826, 'recall': 0.703594578664092, 'f1': 0.7058823479370564, 'auc': 0.8899169304651442, 'prauc': 0.7749856990524127}
Test:      {'precision': 0.7229536347983804, 'recall': 0.6993355481688852, 'f1': 0.7109484892276049, 'auc': 0.8873500194746331, 'prauc': 0.7902911357406093}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 64.46it/s, loss=0.2211]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 168.48it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 170.18it/s]


Validation: {'precision': 0.6490335706986317, 'recall': 0.7519151443679911, 'f1': 0.6966966917198578, 'auc': 0.8804256781934071, 'prauc': 0.76541403305483}
Test:      {'precision': 0.6587887739997106, 'recall': 0.7408637873713131, 'f1': 0.6974198542786088, 'auc': 0.8761056962932214, 'prauc': 0.7741851972290366}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 62.03it/s, loss=0.2031]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.10it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.74it/s]


Validation: {'precision': 0.5932787591529803, 'recall': 0.8114319387105986, 'f1': 0.6854156247836813, 'auc': 0.885978635612686, 'prauc': 0.7721806285385271}
Test:      {'precision': 0.6112510495356371, 'recall': 0.8062015503831329, 'f1': 0.6953199568868665, 'auc': 0.8829245260721265, 'prauc': 0.7760925738106467}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 59.80it/s, loss=0.1881]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.36it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.57it/s]


Validation: {'precision': 0.6648907831610821, 'recall': 0.7354154390056841, 'f1': 0.6983771634474987, 'auc': 0.8826025345859061, 'prauc': 0.7664354761824448}
Test:      {'precision': 0.677336086728563, 'recall': 0.7264673311144715, 'f1': 0.7010419399663115, 'auc': 0.8791339278635959, 'prauc': 0.7760579451178045}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 59.87it/s, loss=0.1932]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.92it/s]


Validation: {'precision': 0.6892877173264763, 'recall': 0.7242192103669758, 'f1': 0.706321834079454, 'auc': 0.888999421455161, 'prauc': 0.7756732865084672}
Test:      {'precision': 0.7047200878117195, 'recall': 0.710963455145565, 'f1': 0.7078279994063386, 'auc': 0.8845428163716912, 'prauc': 0.7821815743972982}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 61.46it/s, loss=0.1659]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.48it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.09it/s]


Validation: {'precision': 0.6773028296166328, 'recall': 0.6629345904498354, 'f1': 0.6700416864796096, 'auc': 0.8714131401597367, 'prauc': 0.7376299740272487}
Test:      {'precision': 0.6871060171880395, 'recall': 0.6638981173828135, 'f1': 0.6753027266225665, 'auc': 0.8689058062965846, 'prauc': 0.7544670383584171}


Epoch 017: 100%|██████████| 98/98 [00:01<00:00, 62.34it/s, loss=0.1449]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.37it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.83it/s]


Validation: {'precision': 0.7255291005243021, 'recall': 0.6464348850875284, 'f1': 0.6837020828902004, 'auc': 0.870198694192746, 'prauc': 0.7550791941153412}
Test:      {'precision': 0.7412451361819634, 'recall': 0.6328903654450007, 'f1': 0.6827956939517414, 'auc': 0.8683514809338172, 'prauc': 0.7653839065394331}


Epoch 018: 100%|██████████| 98/98 [00:01<00:00, 60.05it/s, loss=0.1419]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.04it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.75it/s]


Validation: {'precision': 0.6498630136950693, 'recall': 0.6988803771320042, 'f1': 0.6734809717205537, 'auc': 0.8705720210401775, 'prauc': 0.7398247800835793}
Test:      {'precision': 0.665765278525334, 'recall': 0.6816168327758494, 'f1': 0.6735978062145165, 'auc': 0.8635195684362238, 'prauc': 0.7463326258224514}


Epoch 019: 100%|██████████| 98/98 [00:01<00:00, 58.01it/s, loss=0.1143]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.01it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.12it/s]


Validation: {'precision': 0.7480974124752809, 'recall': 0.5792575132552784, 'f1': 0.6529392179261142, 'auc': 0.8687039177666661, 'prauc': 0.7415451107243418}
Test:      {'precision': 0.7476294675364871, 'recall': 0.5675526024331808, 'f1': 0.6452628216556183, 'auc': 0.8642777496286032, 'prauc': 0.7535862261923091}


Epoch 020: 100%|██████████| 98/98 [00:01<00:00, 61.85it/s, loss=0.1132]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 155.59it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.24it/s]


Validation: {'precision': 0.6542857142819756, 'recall': 0.6747200942800547, 'f1': 0.6643458029462688, 'auc': 0.8639922710803842, 'prauc': 0.7287514662201136}
Test:      {'precision': 0.6625067824163727, 'recall': 0.6760797342155256, 'f1': 0.6692244400502854, 'auc': 0.8578654004954582, 'prauc': 0.7344005993098404}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6892877173264763, 'recall': 0.7242192103669758, 'f1': 0.706321834079454, 'auc': 0.888999421455161, 'prauc': 0.7756732865084672}
Corresponding test performance:
{'precision': 0.7047200878117195, 'recall': 0.710963455145565, 'f1': 0.7078279994063386, 'auc': 0.8845428163716912, 'prauc': 0.7821815743972982}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 60.42it/s, loss=0.5377]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.33it/s]


Validation: {'precision': 0.6312419974351393, 'recall': 0.5810253388298113, 'f1': 0.6050935819947405, 'auc': 0.7882781373794258, 'prauc': 0.6292788445498949}
Test:      {'precision': 0.6482188295124159, 'recall': 0.5642303432969866, 'f1': 0.6033155663644114, 'auc': 0.7886525917511266, 'prauc': 0.6387808515481211}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 59.92it/s, loss=0.4496]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.31it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.09it/s]


Validation: {'precision': 0.5720108695626268, 'recall': 0.7442545668783486, 'f1': 0.6468629912410768, 'auc': 0.8513202996169902, 'prauc': 0.7245377284547245}
Test:      {'precision': 0.6035714285687341, 'recall': 0.7486157253557663, 'f1': 0.6683143796315874, 'auc': 0.8530410587306676, 'prauc': 0.7274735133960826}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 60.41it/s, loss=0.3916]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.38it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.59it/s]


Validation: {'precision': 0.8111353711701841, 'recall': 0.4378314672926469, 'f1': 0.5686949820477583, 'auc': 0.8742109518320863, 'prauc': 0.7434802178534355}
Test:      {'precision': 0.8071654372939182, 'recall': 0.4241417497207966, 'f1': 0.5560798502892258, 'auc': 0.8731937339445068, 'prauc': 0.750122085831075}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 60.23it/s, loss=0.3951]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.75it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.42it/s]


Validation: {'precision': 0.8621118012315265, 'recall': 0.4089569829086096, 'f1': 0.5547561906750444, 'auc': 0.8741550645794768, 'prauc': 0.7554084137684428}
Test:      {'precision': 0.8333333333231955, 'recall': 0.3792912513821745, 'f1': 0.5213089759101088, 'auc': 0.8701780587114637, 'prauc': 0.7545373650725277}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 59.94it/s, loss=0.3363]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.97it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.32it/s]


Validation: {'precision': 0.7705041384441648, 'recall': 0.603417796107228, 'f1': 0.6768010525711274, 'auc': 0.8882916652881144, 'prauc': 0.7762118618956382}
Test:      {'precision': 0.7942942942883312, 'recall': 0.5858250276822491, 'f1': 0.6743148453328573, 'auc': 0.88531718539112, 'prauc': 0.7824417733272966}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 59.53it/s, loss=0.3038]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 155.22it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.28it/s]


Validation: {'precision': 0.8110773899786717, 'recall': 0.6299351797252214, 'f1': 0.709121056434172, 'auc': 0.9033183104140242, 'prauc': 0.8078817265605415}
Test:      {'precision': 0.8015988372034768, 'recall': 0.6107419712037058, 'f1': 0.6932746651058063, 'auc': 0.9001432653469218, 'prauc': 0.8101998258239964}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 59.45it/s, loss=0.2865]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.55it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.84it/s]


Validation: {'precision': 0.7611940298458068, 'recall': 0.6912197996423617, 'f1': 0.7245213044574085, 'auc': 0.9023290102361176, 'prauc': 0.806831963039089}
Test:      {'precision': 0.772955974837906, 'recall': 0.6805094130637845, 'f1': 0.7237926923068956, 'auc': 0.9008750413005017, 'prauc': 0.8113464759236904}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 60.88it/s, loss=0.2345]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.63it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.93it/s]


Validation: {'precision': 0.7925813777381334, 'recall': 0.6169711255119802, 'f1': 0.6938369732042229, 'auc': 0.9026237755741666, 'prauc': 0.8015592017539219}
Test:      {'precision': 0.8049311094938004, 'recall': 0.6146179401959324, 'f1': 0.6970172635313315, 'auc': 0.9003341339844902, 'prauc': 0.8054575584939342}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 60.81it/s, loss=0.2153]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.20it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.43it/s]


Validation: {'precision': 0.8013100436622904, 'recall': 0.6487919858535722, 'f1': 0.7170302783459854, 'auc': 0.9049912234658504, 'prauc': 0.8183573298945772}
Test:      {'precision': 0.8088130774640454, 'recall': 0.6301218161648388, 'f1': 0.7083722328567003, 'auc': 0.9017283182829627, 'prauc': 0.8193440809568008}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 60.51it/s, loss=0.2049]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.16it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.86it/s]


Validation: {'precision': 0.8490374873267574, 'recall': 0.4938126104861885, 'f1': 0.624441127983083, 'auc': 0.8938324871896837, 'prauc': 0.7882298126637316}
Test:      {'precision': 0.8575727181458619, 'recall': 0.47342192690767765, 'f1': 0.6100606447164696, 'auc': 0.8900810843948362, 'prauc': 0.7906514435162495}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 62.98it/s, loss=0.1926]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 165.12it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.89it/s]


Validation: {'precision': 0.6998274870575053, 'recall': 0.7171479080688442, 'f1': 0.708381834344703, 'auc': 0.891318646631736, 'prauc': 0.7787605170818557}
Test:      {'precision': 0.7044956140312254, 'recall': 0.7115171650015975, 'f1': 0.7079889757124758, 'auc': 0.8877151996186813, 'prauc': 0.7804141520349198}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 60.19it/s, loss=0.1768]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 146.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 149.74it/s]


Validation: {'precision': 0.6498316498285241, 'recall': 0.7961107837313135, 'f1': 0.715572028945687, 'auc': 0.898963128592672, 'prauc': 0.7916668771234153}
Test:      {'precision': 0.6677374942603592, 'recall': 0.8056478405271006, 'f1': 0.7302383890175562, 'auc': 0.8959279674933661, 'prauc': 0.7948053663254108}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7611940298458068, 'recall': 0.6912197996423617, 'f1': 0.7245213044574085, 'auc': 0.9023290102361176, 'prauc': 0.806831963039089}
Corresponding test performance:
{'precision': 0.772955974837906, 'recall': 0.6805094130637845, 'f1': 0.7237926923068956, 'auc': 0.9008750413005017, 'prauc': 0.8113464759236904}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 58.20it/s, loss=0.5262]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 147.67it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 151.96it/s]


Validation: {'precision': 0.8298429319154491, 'recall': 0.18680023570897586, 'f1': 0.3049543019517515, 'auc': 0.8230666427703419, 'prauc': 0.6538985223920684}
Test:      {'precision': 0.8328690807567446, 'recall': 0.16555924695367907, 'f1': 0.2762124683626154, 'auc': 0.8275371679895138, 'prauc': 0.6697032994734657}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 58.93it/s, loss=0.4608]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.70it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 169.80it/s]


Validation: {'precision': 0.7886075949267265, 'recall': 0.36711844431133106, 'f1': 0.5010052228423315, 'auc': 0.8445642014235346, 'prauc': 0.6941001414948355}
Test:      {'precision': 0.7992073976116353, 'recall': 0.3349944628995848, 'f1': 0.47210300012573675, 'auc': 0.8483054114859465, 'prauc': 0.7005934149108075}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 63.52it/s, loss=0.4260]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 172.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 168.05it/s]


Validation: {'precision': 0.5835866261372836, 'recall': 0.7919858573907367, 'f1': 0.6719999951114013, 'auc': 0.8708944425843748, 'prauc': 0.7476228699485307}
Test:      {'precision': 0.6143157894710977, 'recall': 0.8078626799512301, 'f1': 0.6979191531854158, 'auc': 0.8752006551946061, 'prauc': 0.74826039220033}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 63.73it/s, loss=0.3831]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.17it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.91it/s]


Validation: {'precision': 0.7910817506127905, 'recall': 0.5645256334675043, 'f1': 0.6588720721640087, 'auc': 0.8834954532047412, 'prauc': 0.7723985706400939}
Test:      {'precision': 0.809407948087515, 'recall': 0.5526024363203068, 'f1': 0.6567949935281531, 'auc': 0.8868791568247141, 'prauc': 0.7724992663737493}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 64.92it/s, loss=0.3606]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.84it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 170.15it/s]


Validation: {'precision': 0.5823408624206065, 'recall': 0.8355922215625481, 'f1': 0.6863504307805735, 'auc': 0.8892278885438285, 'prauc': 0.7822292808859006}
Test:      {'precision': 0.6137096774168802, 'recall': 0.8427464008812694, 'f1': 0.7102193138324194, 'auc': 0.8891349889972017, 'prauc': 0.7783000840301888}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 61.99it/s, loss=0.3286]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 168.22it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.79it/s]


Validation: {'precision': 0.7690615835720744, 'recall': 0.6181496758950021, 'f1': 0.6853969241628306, 'auc': 0.895322664893263, 'prauc': 0.792372464511717}
Test:      {'precision': 0.7776998597420919, 'recall': 0.6140642303399, 'f1': 0.6862623713024957, 'auc': 0.8945829006321009, 'prauc': 0.7915159169340362}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 60.19it/s, loss=0.3038]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.37it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.47it/s]


Validation: {'precision': 0.8579439252256268, 'recall': 0.5409546258070657, 'f1': 0.6635345091659264, 'auc': 0.9060860387766836, 'prauc': 0.808612935495065}
Test:      {'precision': 0.8328865058013147, 'recall': 0.5160575858221702, 'f1': 0.6372649525364238, 'auc': 0.9020179757510044, 'prauc': 0.804018920543841}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 60.43it/s, loss=0.2873]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.47it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.94it/s]


Validation: {'precision': 0.8177737881434671, 'recall': 0.536829699466489, 'f1': 0.6481679069856302, 'auc': 0.8964381105842025, 'prauc': 0.7916134478112338}
Test:      {'precision': 0.8200692041451552, 'recall': 0.5249169435186882, 'f1': 0.6401080303478727, 'auc': 0.8933666593298264, 'prauc': 0.7872720600267911}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 61.14it/s, loss=0.2576]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.56it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.05it/s]


Validation: {'precision': 0.8270676691660229, 'recall': 0.5833824395958552, 'f1': 0.6841741485653923, 'auc': 0.9004542004922933, 'prauc': 0.8081420081813848}
Test:      {'precision': 0.8306709265109372, 'recall': 0.5758582502736663, 'f1': 0.6801831213859455, 'auc': 0.8974050605486282, 'prauc': 0.8079875582438645}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 61.77it/s, loss=0.2394]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.45it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.60it/s]


Validation: {'precision': 0.6813186813151161, 'recall': 0.7672362993472762, 'f1': 0.7217294850357623, 'auc': 0.9102264219345778, 'prauc': 0.8105374378394559}
Test:      {'precision': 0.6966236345546345, 'recall': 0.7768549280134173, 'f1': 0.734554968832968, 'auc': 0.9033436541502142, 'prauc': 0.8096698798681518}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 61.42it/s, loss=0.2278]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.66it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.63it/s]


Validation: {'precision': 0.724629080114394, 'recall': 0.7195050088348881, 'f1': 0.7220579488692408, 'auc': 0.9080588268582257, 'prauc': 0.8129014322219554}
Test:      {'precision': 0.7480091012471672, 'recall': 0.7281284606825685, 'f1': 0.7379348995983372, 'auc': 0.9027939450879363, 'prauc': 0.8183341791167057}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 60.65it/s, loss=0.1988]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.75it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 164.85it/s]


Validation: {'precision': 0.7228915662607055, 'recall': 0.7071302298131578, 'f1': 0.7149240343171703, 'auc': 0.8936168582013297, 'prauc': 0.7967828084604842}
Test:      {'precision': 0.7428236672481382, 'recall': 0.7021040974490471, 'f1': 0.721890117402366, 'auc': 0.8884538076969855, 'prauc': 0.7953833806615758}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 58.75it/s, loss=0.1842]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.79it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 155.63it/s]


Validation: {'precision': 0.7003869541144258, 'recall': 0.7466116676443925, 'f1': 0.722760976176108, 'auc': 0.8978486409689201, 'prauc': 0.8096292556033907}
Test:      {'precision': 0.7270777479853776, 'recall': 0.7508305647798957, 'f1': 0.73876327475755, 'auc': 0.8968776328300457, 'prauc': 0.8136322003254789}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 60.89it/s, loss=0.1789]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.52it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.69it/s]


Validation: {'precision': 0.7553619302898434, 'recall': 0.6641131408328573, 'f1': 0.7068046359695059, 'auc': 0.9016199127750084, 'prauc': 0.8064989776168188}
Test:      {'precision': 0.7730179028083567, 'recall': 0.6694352159431372, 'f1': 0.7175074134191514, 'auc': 0.8957166024816247, 'prauc': 0.8063197598616011}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 59.06it/s, loss=0.1444]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.72it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.35it/s]


Validation: {'precision': 0.709946396660453, 'recall': 0.7024160282810701, 'f1': 0.7061611324367171, 'auc': 0.8947616846193555, 'prauc': 0.8024114524223169}
Test:      {'precision': 0.7343749999957502, 'recall': 0.7026578073050794, 'f1': 0.7181663786995599, 'auc': 0.8885225597989411, 'prauc': 0.7987669616954995}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 58.98it/s, loss=0.1526]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.45it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.36it/s]


Validation: {'precision': 0.8334866605259109, 'recall': 0.5338833235089341, 'f1': 0.6508620642008858, 'auc': 0.894506902618602, 'prauc': 0.8010958386361098}
Test:      {'precision': 0.8530183726959492, 'recall': 0.5398671096315623, 'f1': 0.661241093925758, 'auc': 0.8901873208567657, 'prauc': 0.8044200619147818}


Epoch 017: 100%|██████████| 98/98 [00:01<00:00, 61.02it/s, loss=0.1292]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.38it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.99it/s]


Validation: {'precision': 0.6874292185691538, 'recall': 0.7153800824943113, 'f1': 0.7011261861616667, 'auc': 0.8857818486123546, 'prauc': 0.7910054043276932}
Test:      {'precision': 0.712396694210951, 'recall': 0.7159468438498564, 'f1': 0.714167352079765, 'auc': 0.8827513224777643, 'prauc': 0.7915254390308375}


Epoch 018: 100%|██████████| 98/98 [00:01<00:00, 60.51it/s, loss=0.1137]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.40it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.90it/s]


Validation: {'precision': 0.7409551374765488, 'recall': 0.603417796107228, 'f1': 0.6651510181074469, 'auc': 0.8779905907582795, 'prauc': 0.7759608945121333}
Test:      {'precision': 0.7681564245756414, 'recall': 0.6090808416356087, 'f1': 0.6794317430550968, 'auc': 0.8725852439893306, 'prauc': 0.7752635740152416}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7003869541144258, 'recall': 0.7466116676443925, 'f1': 0.722760976176108, 'auc': 0.8978486409689201, 'prauc': 0.8096292556033907}
Corresponding test performance:
{'precision': 0.7270777479853776, 'recall': 0.7508305647798957, 'f1': 0.73876327475755, 'auc': 0.8968776328300457, 'prauc': 0.8136322003254789}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 60.93it/s, loss=0.5714]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.09it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.81it/s]


Validation: {'precision': 0.6684831970874798, 'recall': 0.4337065409520701, 'f1': 0.5260900595547698, 'auc': 0.8168854487605868, 'prauc': 0.6383575012750712}
Test:      {'precision': 0.6987724268111543, 'recall': 0.40974529346395494, 'f1': 0.516579401968065, 'auc': 0.8180503627304272, 'prauc': 0.6581130067562897}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 60.41it/s, loss=0.4762]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.64it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 164.61it/s]


Validation: {'precision': 0.7298311444584444, 'recall': 0.45845609899553064, 'f1': 0.5631559851227869, 'auc': 0.8334918189365429, 'prauc': 0.6582179334413093}
Test:      {'precision': 0.7373188405730317, 'recall': 0.4507198228103504, 'f1': 0.5594501671084375, 'auc': 0.8351333212201608, 'prauc': 0.6757605285336208}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 58.56it/s, loss=0.4374]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.17it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.90it/s]


Validation: {'precision': 0.6362637362602404, 'recall': 0.6823806717696973, 'f1': 0.6585157754971107, 'auc': 0.8551438180203751, 'prauc': 0.7246481777964013}
Test:      {'precision': 0.6446324695893991, 'recall': 0.6749723145034608, 'f1': 0.6594536060350508, 'auc': 0.8554983462565127, 'prauc': 0.7321173008974206}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 63.29it/s, loss=0.3873]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 165.69it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.93it/s]


Validation: {'precision': 0.6694966646411795, 'recall': 0.6505598114281052, 'f1': 0.6598924038434685, 'auc': 0.8606869392831946, 'prauc': 0.7271599044370003}
Test:      {'precision': 0.6933414783097536, 'recall': 0.6284606865967417, 'f1': 0.659308737384052, 'auc': 0.8607049177018259, 'prauc': 0.7308711768853022}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 61.46it/s, loss=0.3628]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 145.35it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 149.24it/s]


Validation: {'precision': 0.6659304251757817, 'recall': 0.7106658809622236, 'f1': 0.6875712606798097, 'auc': 0.8774698493062507, 'prauc': 0.7442747879298779}
Test:      {'precision': 0.6958650707252674, 'recall': 0.7081949058654031, 'f1': 0.7019758457100346, 'auc': 0.8791041373377709, 'prauc': 0.7599992303845307}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 57.34it/s, loss=0.3283]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 147.66it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 151.94it/s]


Validation: {'precision': 0.801792828677273, 'recall': 0.4743665291663267, 'f1': 0.5960755229071083, 'auc': 0.8685665309319656, 'prauc': 0.7395409278411923}
Test:      {'precision': 0.8239921337185449, 'recall': 0.4640088593551273, 'f1': 0.5936946464667778, 'auc': 0.8672854232741314, 'prauc': 0.7548863650938381}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 58.12it/s, loss=0.3217]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.86it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 169.26it/s]


Validation: {'precision': 0.728041639552843, 'recall': 0.6593989393007696, 'f1': 0.6920222584587938, 'auc': 0.8850511627039517, 'prauc': 0.766036528137414}
Test:      {'precision': 0.7469414037298974, 'recall': 0.642303432997551, 'f1': 0.6906817455452416, 'auc': 0.8845324143079217, 'prauc': 0.7743610368401971}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 63.63it/s, loss=0.3044]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 172.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.11it/s]


Validation: {'precision': 0.7923139820049688, 'recall': 0.5710076605741249, 'f1': 0.6636986252641937, 'auc': 0.8809998797945036, 'prauc': 0.7595191395416002}
Test:      {'precision': 0.7955246913518864, 'recall': 0.5708748615693751, 'f1': 0.6647324258207451, 'auc': 0.8792210220661625, 'prauc': 0.7744445873173618}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 63.87it/s, loss=0.2796]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 167.39it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 168.32it/s]


Validation: {'precision': 0.6783831282912808, 'recall': 0.6823806717696973, 'f1': 0.6803760231981609, 'auc': 0.8752857116031263, 'prauc': 0.7401090245842197}
Test:      {'precision': 0.6838032061874859, 'recall': 0.6849390919120436, 'f1': 0.6843706727278909, 'auc': 0.8711479126689504, 'prauc': 0.7534048473499136}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 58.96it/s, loss=0.2522]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 168.26it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 170.20it/s]


Validation: {'precision': 0.7118193891054992, 'recall': 0.6317030052997543, 'f1': 0.6693724583292417, 'auc': 0.8695232568254941, 'prauc': 0.747305992119948}
Test:      {'precision': 0.7319257837445174, 'recall': 0.633444075301033, 'f1': 0.6791332689905174, 'auc': 0.8728340318103727, 'prauc': 0.7675179363428892}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 61.44it/s, loss=0.2280]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 170.21it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 167.44it/s]


Validation: {'precision': 0.7773073666318603, 'recall': 0.5409546258070657, 'f1': 0.6379430111396153, 'auc': 0.8744219182267938, 'prauc': 0.754175236444835}
Test:      {'precision': 0.8060309698385817, 'recall': 0.5476190476160154, 'f1': 0.6521595731535156, 'auc': 0.8784459144386554, 'prauc': 0.7696732895811232}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 59.86it/s, loss=0.2255]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.95it/s]


Validation: {'precision': 0.6485671191520938, 'recall': 0.7601649970491446, 'f1': 0.6999457356678407, 'auc': 0.8812988606281777, 'prauc': 0.7610816350527836}
Test:      {'precision': 0.6676557863468464, 'recall': 0.7475083056437015, 'f1': 0.7053291486172503, 'auc': 0.8818417266293325, 'prauc': 0.7781967368544378}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 60.18it/s, loss=0.2010]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.43it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.93it/s]


Validation: {'precision': 0.6985645932972574, 'recall': 0.6882734236848069, 'f1': 0.6933808201668331, 'auc': 0.8756766030156634, 'prauc': 0.7510616375086445}
Test:      {'precision': 0.7085484796287519, 'recall': 0.6838316721999789, 'f1': 0.6959706909683496, 'auc': 0.8765163008340855, 'prauc': 0.7674487637585106}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 59.86it/s, loss=0.1842]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.96it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.92it/s]


Validation: {'precision': 0.7531068765471989, 'recall': 0.535651149083467, 'f1': 0.626033052989282, 'auc': 0.8590269282028281, 'prauc': 0.7253699968582422}
Test:      {'precision': 0.7785179017420607, 'recall': 0.5177187153902674, 'f1': 0.6218822698906499, 'auc': 0.8586263514681313, 'prauc': 0.7396989965246883}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 60.13it/s, loss=0.1733]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.72it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.97it/s]


Validation: {'precision': 0.6964064436140248, 'recall': 0.6623453152583245, 'f1': 0.6789489530177665, 'auc': 0.8778019153934699, 'prauc': 0.7521514092420383}
Test:      {'precision': 0.7111636148806193, 'recall': 0.6666666666629754, 'f1': 0.6881966226091693, 'auc': 0.8750083093408808, 'prauc': 0.767462029142404}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 59.90it/s, loss=0.1521]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.65it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.54it/s]


Validation: {'precision': 0.7874779541376766, 'recall': 0.5262227460192916, 'f1': 0.6308724784147655, 'auc': 0.8700355034151263, 'prauc': 0.7424348620507327}
Test:      {'precision': 0.7924850555013452, 'recall': 0.5138427463980407, 'f1': 0.6234464177963603, 'auc': 0.8666018414976608, 'prauc': 0.7562075699880323}


Epoch 017: 100%|██████████| 98/98 [00:01<00:00, 62.01it/s, loss=0.1543]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.16it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.96it/s]

Validation: {'precision': 0.6814052089601369, 'recall': 0.6629345904498354, 'f1': 0.6720430057496176, 'auc': 0.8673144648600756, 'prauc': 0.7419531533846142}
Test:      {'precision': 0.7132743362789777, 'recall': 0.6694352159431372, 'f1': 0.6906598064835143, 'auc': 0.8658153469956131, 'prauc': 0.7534180322626077}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6485671191520938, 'recall': 0.7601649970491446, 'f1': 0.6999457356678407, 'auc': 0.8812988606281777, 'prauc': 0.7610816350527836}
Corresponding test performance:
{'precision': 0.6676557863468464, 'recall': 0.7475083056437015, 'f1': 0.7053291486172503, 'auc': 0.8818417266293325, 'prauc': 0.7781967368544378}





In [None]:
# Summarise the repeated runs: for every metric collected in `final_metrics`,
# report mean ± standard deviation across runs.
# NOTE(review): assumes `final_metrics` (built in an earlier cell, not visible
# here) maps metric name -> sequence of per-run test scores — confirm.
# np.std defaults to ddof=0 (population std); pass ddof=1 for the sample std.
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    print(f"{metric_name}: {np.mean(run_values):.4f} ± {np.std(run_values):.4f}")


Final Metrics:
precision: 0.7319 ± 0.0438
recall: 0.7150 ± 0.0298
f1: 0.7216 ± 0.0132
auc: 0.8943 ± 0.0097
prauc: 0.8014 ± 0.0177
