In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
# Seed the run through the project helper (it logs the chosen seed).
# NOTE(review): presumably this seeds torch / numpy / random together — confirm in heterogt.utils.seed.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment settings: dataset, downstream task selection and optimisation hyper-parameters.
# `task_index` selects the active entry of `tasks` (2 -> "stay").
config = Namespace(**{
    "dataset": "MIMIC-III",
    "tasks": ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    "task_index": 2,
    "token_type": ["diag", "med", "lab", "pro"],
    "special_tokens": ["[PAD]"],
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
})

In [5]:
# Resolve the pickle paths: the full cohort (used to build the tokenizer) and the
# fine-tuning split file that belongs to the selected task.
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"  # for tokenizer

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The next-diagnosis tasks have their own files; every other task shares the downstream file.
finetune_file_by_task = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{processed_dir}/{finetune_file_by_task.get(curr_task, 'mimic_downstream.pkl')}"

Current task: stay


In [6]:
# Load the full cohort table; it provides the raw code lists used to build the vocabularies.
# NOTE(security): pickle can execute arbitrary code while loading — only read files
# produced by this project's own preprocessing.
with open(full_data_path, "rb") as full_data_file:  # context manager closes the handle
    ehr_full_data = pickle.load(full_data_file)

# One "sentence" (list of codes) per row for each code family.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Every observed age is combined with both genders; each "<age>_<gender>" token is
# wrapped in its own single-element list (the nested [[...]] shape is required).
age_gender_sentences = [[str(c) + "_" + gender] for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type

# Largest number of distinct admissions of any single patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from all vocabulary sources collected so far.
task_sentences = config.tasks
tokenizer = EHRTokenizer(
    token_type_sentences,
    task_sentences,
    age_gender_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# Store the vocabulary sizes on the config for later model construction.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 36


In [8]:
# Load the (train, val, test) splits for the selected task.
# NOTE(security): pickle can execute arbitrary code while loading — only read trusted files.
with open(finetune_data_path, "rb") as finetune_file:  # context manager closes the handle
    train_data, val_data, test_data = pickle.load(finetune_file)

# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in the fine-tuning dataset (same tokenizer, token types and task for all three).
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split_frame, tokenizer, token_type=config.token_type, task=curr_task)
    for split_frame in (train_data, val_data, test_data)
)

In [10]:
def _make_dataloader(dataset, shuffle):
    """Create a DataLoader for `dataset` with a fresh fine-tuning collate function.

    Args:
        dataset: One of the fine-tuning datasets.
        shuffle (bool): Whether samples are reshuffled every epoch.

    Returns:
        DataLoader: Loader using the batch size from `config`.
    """
    collate = batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain=False)
    return DataLoader(dataset, collate_fn=collate, shuffle=shuffle, batch_size=config.batch_size)


# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = _make_dataloader(train_dataset, shuffle=True)

val_dataloader = _make_dataloader(val_dataset, shuffle=False)

test_dataloader = _make_dataloader(test_dataset, shuffle=False)

In [11]:
# Every task reports F1 and is trained with BCE-with-logits; only the task type
# handed to the training helper differs between the binary tasks and the
# next-diagnosis tasks.
eval_metric = "f1"
loss_fn = F.binary_cross_entropy_with_logits
if curr_task in ["death", "stay", "readmission"]:
    task_type = "binary"
else:
    task_type = "l2r"

In [12]:
# Pull one batch from the training loader and show what the collate function produces.
first_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = first_batch

for caption, batch_tensor in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age/Sex IDs shape:", age_gender_ids),
):
    print(caption, batch_tensor.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv
from heterogt.model.layer import TransformerEncoder

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    """Two-layer heterogeneous GAT over 'visit' and 'occ' nodes.

    Each layer applies a `HeteroConv` (sum aggregation across relations), adds the
    result to the input through a learnable scalar weight, and normalises with a
    per-node-type LayerNorm. A zero-initialised linear residual follows the two layers.

    Args:
        d_model (int): Feature size of every node type (input and output).
        heads (int): Number of attention heads in each `GATConv` (outputs are averaged, not concatenated).
        dropout (float): Dropout probability applied to each convolution output before the residual add.
    """

    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.d = d_model
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

        # Normalisation: a separate LayerNorm per node type and per layer.
        for ln_name in ("ln_v1", "ln_o1", "ln_v2", "ln_o2"):
            setattr(self, ln_name, nn.LayerNorm(d_model))

        # Learnable residual weights, started small so the graph update perturbs
        # the input only slightly early in training.
        for alpha_name in ("alpha_v1", "alpha_o1", "alpha_v2", "alpha_o2"):
            setattr(self, alpha_name, nn.Parameter(torch.tensor(0.1)))

        # aggr='sum' so that messages from different relations are not averaged away.
        self.conv1 = self._make_hetero_conv(d_model, heads)
        self.conv2 = self._make_hetero_conv(d_model, heads)

        # Final linear residual; zero initialisation makes it start as the identity.
        self.lin_v = nn.Linear(d_model, d_model)
        self.lin_o = nn.Linear(d_model, d_model)
        for linear in (self.lin_v, self.lin_o):
            nn.init.zeros_(linear.weight)
            nn.init.zeros_(linear.bias)

    @staticmethod
    def _make_hetero_conv(d_model, heads):
        """Build one HeteroConv with a GATConv per relation (self-loops only on visit->visit)."""
        def gat(self_loops):
            return GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=self_loops)

        relations = {
            ('visit', 'contains', 'occ'): gat(False),
            ('occ', 'contained_by', 'visit'): gat(False),
            ('visit', 'next', 'visit'): gat(True),
        }
        return HeteroConv(relations, aggr='sum')

    def forward(self, hg):
        """Run both graph layers and the final linear residual.

        Args:
            hg: Heterogeneous graph (batch) exposing `hg['visit'].x`, `hg['occ'].x`
                and `hg.edge_index_dict`.

        Returns:
            dict: {'visit': [N_visit, d], 'occ': [N_occ, d]} updated node features.
        """
        feats = {'visit': hg['visit'].x, 'occ': hg['occ'].x}
        edges = hg.edge_index_dict

        stages = (
            (self.conv1, self.alpha_v1, self.alpha_o1, self.ln_v1, self.ln_o1),
            (self.conv2, self.alpha_v2, self.alpha_o2, self.ln_v2, self.ln_o2),
        )
        for conv, alpha_v, alpha_o, ln_v, ln_o in stages:
            delta = conv(feats, edges)
            # Dropout is applied to the update before it is injected into the residual.
            delta_v = self.drop(delta['visit'])
            delta_o = self.drop(delta['occ'])
            # y = LN(x + alpha * delta)
            feats = {
                'visit': ln_v(feats['visit'] + alpha_v * delta_v),
                'occ': ln_o(feats['occ'] + alpha_o * delta_o),
            }

        # Zero-initialised linear layers act as a fine-tuning residual on top of the graph output.
        visit_out = feats['visit'] + self.lin_v(feats['visit'])
        occ_out = feats['occ'] + self.lin_o(feats['occ'])
        return {'visit': visit_out, 'occ': occ_out}

In [15]:
# multi-class classification task
class MultiPredictionHead(nn.Module):
    """Two-layer MLP head producing `label_size` logits per input vector."""

    def __init__(self, hidden_size, label_size):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, label_size),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        """Map `[..., hidden_size]` features to `[..., label_size]` logits."""
        logits = self.cls(input)
        return logits
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP head producing a single logit per input vector."""

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.cls = nn.Sequential(*layers)

    def forward(self, input):
        """Map `[..., hidden_size]` features to `[..., 1]` logits."""
        logit = self.cls(input)
        return logit

In [16]:
def _pad_with_zeros(seq, n_pad):
    """Return `seq` (1-D tensor) with `n_pad` zeros of the same dtype appended."""
    return torch.concat([seq, torch.zeros(n_pad, dtype=seq.dtype)], dim=0)


# Pick the first training sample whose age/gender sequence has more than three
# entries (presumably one entry per admission — verify against the dataset class).
# A separate variable name is used so the batch-level `age_gender_ids` from the
# batch-inspection cell is not overwritten.
exp_i = len(train_dataset) - 1  # fallback when no sample qualifies
for sample_idx in range(len(train_dataset)):
    sample_age_gender = train_dataset[sample_idx][3]
    if len(sample_age_gender[0]) > 3:
        print(sample_age_gender)
        exp_i = sample_idx
        break

# Fetch the sample once, then append artificial padding to each of its sequences.
example_sample = train_dataset[exp_i]
id_seq = _pad_with_zeros(example_sample[0][0], 5)
type_seq = _pad_with_zeros(example_sample[1][0], 5)
visit_seq = _pad_with_zeros(example_sample[2][0], 5)
age_sex = _pad_with_zeros(example_sample[3][0], 3)

tensor([[11, 11, 11, 29]])


In [17]:
class HeteroGT(nn.Module):
    """Stack of graph ('gnn') and Transformer ('tf') layers over an EHR token sequence.

    A task token, the medical-code tokens and one age/gender token per visit are
    concatenated as ``[task | codes | visits]`` for the Transformer layers, while
    'gnn' layers only refresh the per-visit embeddings. The final state of the task
    token is passed to the prediction head.

    Args:
        tokenizer: Project tokenizer exposing `vocab`, `token_type_voc` and `convert_tokens_to_ids`.
        d_model (int): Embedding / hidden size.
        num_heads (int): Attention heads (Transformer layers and GAT layers).
        layer_types (list[str]): Sequence of 'gnn' / 'tf' entries defining the stack.
        max_num_adms (int): Largest admission index that can occur.
        device: Device on which helper tensors are created; expected to be the model's device.
        task (str): Task name; "death", "stay" and "readmission" get a binary head.
        label_vocab_size (int): Output size of the multi-label head used for all other tasks.
    """

    def __init__(self, tokenizer, d_model, num_heads, layer_types, max_num_adms, device, task, label_vocab_size):
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_attn_heads = num_heads
        self.layer_types = layer_types
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]  # 0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0]  # 0
        self.adm_pad_id = 0
        self.age_gender_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]  # 0
        self.node_type_id_dict = {'diag': 1, 'med': 2, 'lab': 3, 'pro': 4, 'visit': 5}
        self.graph_node_types = ['diag']
        # Attention forbid map: query type -> key types that query must not attend to
        # (-1 is the type assigned to the task token).
        self.forbid_map = {-1: [], 1: [5], 2: [5], 3: [5], 4: [5], 5: [-1, 1, 5]}

        # embedding layers
        # token_emb already contains [PAD]; it is also used for the age/gender tokens
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        # n_type already includes PAD; + 1 for the virtual visit type
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id)
        # + 1 for the pad admission index
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id)
        # 5 tasks in total
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)

        # stack together
        self.stack_layers = nn.ModuleList(
            self.make_gnn_layer() if layer_type == 'gnn' else self.make_tf_layer()
            for layer_type in self.layer_types
        )

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, label_vocab_size)

    def make_tf_layer(self):
        """Create a single pre-norm, batch-first Transformer encoder layer."""
        assert self.d_model % self.num_attn_heads == 0, "Invalid model and attention head dimensions"
        layer_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_attn_heads, batch_first=True, norm_first=True)
        tf_wrapper = nn.TransformerEncoder(layer_layer, num_layers=1, enable_nested_tensor=False)
        return tf_wrapper

    def make_gnn_layer(self):
        """Create one heterogeneous graph layer."""
        return DiseaseOccHetGNN(d_model=self.d_model, heads=self.num_attn_heads)

    def forward(self, input_ids, token_types, adm_index, age_gender_ids, task_id):
        """Forward pass for the model.

        Args:
            input_ids (Tensor): Input token IDs. Shape of [B, L]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            adm_index (Tensor): Admission index IDs. Shape of [B, L]
            age_gender_ids (Tensor): Age and gender IDs. Shape of [B, V]
            task_id: Scalar task index; it is broadcast to every sample of the batch.

        Returns:
            Tensor: Output logits. Shape of [B, label_size]
        """
        B, L = input_ids.shape
        V = age_gender_ids.shape[1]
        # Largest admission index per sample is used as that sample's visit count.
        num_visits = adm_index.max(dim=1).values

        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device)  # scalar -> [B]
        # base representations
        task_id_embed = self.task_emb(task_id).unsqueeze(1)  # [B, 1, d]
        seq_embed = self.token_emb(input_ids)  # [B, L, d]
        visit_embed = self.token_emb(age_gender_ids)  # [B, V, d]

        # run through layers
        for i, layer_type in enumerate(self.layer_types):
            if layer_type == 'gnn':  # purpose is just to update visit_embed
                # NOTE(review): this layer runs under torch.no_grad() and its output is
                # detached, so no gradient reaches the GNN layer's parameters from here;
                # confirm that leaving them untrained is intended.
                seq_embed_det = seq_embed.detach()
                visit_embed_det = visit_embed.detach()
                with torch.no_grad():
                    hg_batch = self.build_graph_batch(seq_embed_det, token_types, self.graph_node_types, visit_embed_det, adm_index)
                    gnn_out = self.stack_layers[i](hg_batch)['visit']  # virtual visit node representations
                    visit_embed = self.process_gnn_out(gnn_out, num_visits, V).detach()  # [B, V, d]
            elif layer_type == 'tf':
                x, src_key_padding_mask, attn_mask = self.prepare_tf_input(task_id_embed, seq_embed, visit_embed, i, input_ids, adm_index, token_types, num_visits)
                h = self.stack_layers[i](src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask)  # [B, 1+L+V, d]
                task_id_embed, seq_embed, visit_embed = self.process_tf_out(h, L, V)  # [B, 1, d], [B, L, d], [B, V, d]
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

        # Squeeze only the token axis: a bare squeeze() would also drop the batch
        # axis for a batch of one and change the logits' shape.
        logits = self.cls_head(task_id_embed.squeeze(1))  # [B, label_size]
        return logits

    def build_graph_batch(self, seq_embed, token_types, graph_node_types, visit_embed, adm_index):
        """Build a batch of heterogeneous graphs from the input sequences.

        Args:
            seq_embed (Tensor): Sequence embeddings. Shape of [B, L, d]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            graph_node_types: a list controlling which token types are connected to the virtual visit nodes, e.g. ['diag']
            visit_embed (Tensor): Visit embeddings. Shape of [B, V, d]
            adm_index (Tensor): Admission index IDs. Shape of [B, L]

        Returns:
            A batch of heterogeneous graphs (one graph per patient), moved to `self.device`.
        """
        B = seq_embed.shape[0]
        graph_node_type_ids = [self.node_type_id_dict[t] for t in graph_node_types]
        # one heterogeneous graph per patient
        graphs = [
            self.build_patient_graph(seq_embed[p], token_types[p], visit_embed[p], adm_index[p], graph_node_type_ids)
            for p in range(B)
        ]
        hg_batch = HeteroBatch.from_data_list(graphs).to(self.device)
        return hg_batch

    def build_patient_graph(self, seq_embed_p, token_types_p, visit_embed_p, adm_index_p, graph_node_type_ids):
        """Build a heterogeneous graph for a single patient.

        Nodes are 'visit' (one per distinct non-pad admission index) and 'occ' (one per
        token whose type is in `graph_node_type_ids`). Edges connect each occurrence
        with its visit in both directions, and each visit with the following visit.

        Args:
            seq_embed_p (Tensor): Sequence embeddings for patient p. Shape [L, d]
            token_types_p (Tensor): Token type IDs for patient p. Shape [L]
            visit_embed_p (Tensor): Visit embeddings for patient p. Shape [V, d]
            adm_index_p (Tensor): Admission index for patient p. Shape [L]
            graph_node_type_ids (list): List of graph node type IDs that the graph uses.

        Returns:
            A heterogeneous graph for patient p.
        """
        # Create every index tensor on the inputs' device so they can be stacked together.
        device = token_types_p.device
        hg = HeteroData()
        occ_mask = torch.isin(token_types_p, torch.tensor(graph_node_type_ids, device=device))  # [L], tokens used as graph nodes
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1)  # sequence positions of those tokens
        num_occ = occ_pos.numel()  # number of occurrence nodes

        # build visit virtual nodes
        nonpad = adm_index_p != self.adm_pad_id
        adm_index_used_p = adm_index_p[nonpad]  # non-pad part of adm_index
        adm_ids_unique, adm_lid_nonpad = torch.unique(adm_index_used_p, return_inverse=True)
        num_visit_p = adm_ids_unique.numel()  # number of visits of this patient
        # local (0-based) visit id per sequence position; -1 marks padding
        adm_lid_full = torch.full_like(token_types_p, fill_value=-1)  # [L]
        adm_lid_full[nonpad] = adm_lid_nonpad
        hg['visit'].x = visit_embed_p[:num_visit_p, :]
        hg['visit'].num_nodes = num_visit_p

        # build medical code nodes
        gid_occ_embed = seq_embed_p[occ_pos, :]
        hg['occ'].x = gid_occ_embed
        hg['occ'].num_nodes = num_occ

        # build edges between occ nodes and virtual visit nodes
        occ_adm_lid = adm_lid_full[occ_pos]
        assert (occ_adm_lid != -1).all(), "occ_adm_lid contains -1"
        occ_ids = torch.arange(num_occ, device=device)
        hg['visit', 'contains', 'occ'].edge_index = torch.stack([occ_adm_lid, occ_ids], dim=0)
        hg['occ', 'contained_by', 'visit'].edge_index = torch.stack([occ_ids, occ_adm_lid], dim=0)

        # build forward edges between virtual visit nodes
        if num_visit_p > 1:
            src = torch.arange(0, num_visit_p - 1, device=device)
            dst = torch.arange(1, num_visit_p, device=device)
            e_next = torch.stack([src, dst], dim=0)  # [2, num_visit_p-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=device)
        hg['visit', 'next', 'visit'].edge_index = e_next
        return hg

    def process_gnn_out(self, gnn_out, num_visits, V):
        """Scatter the flat GNN visit outputs back into a zero-padded [B, V, d] tensor.

        Args:
            gnn_out (Tensor): The output of the GNN layer. Shape [sum(num_visits), d]
            num_visits (Tensor): A tensor containing the number of visits for each patient. Shape [B]
            V (int): The maximum number of visits.

        Returns:
            Tensor: The processed visit embeddings. Shape [B, V, d]
        """
        B = len(num_visits)
        # start offset of every patient inside gnn_out (exclusive prefix sum)
        offsets = torch.cumsum(num_visits, dim=0) - num_visits  # [B]

        # flat row index, owning patient and within-patient position of every visit row
        total_visits = int(num_visits.sum())
        indices = torch.arange(total_visits, device=self.device)  # [N]
        batch_indices = torch.repeat_interleave(torch.arange(B, device=self.device), num_visits)  # [N]
        visit_pos = indices - offsets[batch_indices]  # [N]

        # target tensor, zero-initialised so padded visits stay zero
        visit_emb_pad = torch.zeros(B, V, self.d_model, device=self.device, dtype=gnn_out.dtype)  # [B, V, d]

        # keep only valid positions (visit_pos < V and visit_pos < num_visits)
        mask = (visit_pos < V) & (visit_pos < num_visits[batch_indices])  # [N]
        valid_indices = indices[mask]  # [N_valid]
        valid_batch_indices = batch_indices[mask]  # [N_valid]
        valid_visit_pos = visit_pos[mask]  # [N_valid]

        # write the gnn_out rows into their (patient, visit) slots
        visit_emb_pad[valid_batch_indices, valid_visit_pos] = gnn_out[valid_indices]
        return visit_emb_pad

    def prepare_tf_input(self, task_id_embed, seq_embed, visit_embed, layer_i, input_ids, adm_index, token_types, num_visits):
        """Prepare the input for the Transformer layer.

        Args:
            task_id_embed (Tensor): Task ID embeddings. Shape [B, 1, d]
            seq_embed (Tensor): Sequence embeddings. Shape [B, L, d]
            visit_embed (Tensor): Visit embeddings. Shape [B, V, d]
            layer_i (int): The current layer index.
            input_ids (Tensor): Input token IDs, used for the padding mask. Shape [B, L]
            adm_index (Tensor): The admission index. Shape [B, L]
            token_types (Tensor): Token types. Shape [B, L]
            num_visits (Tensor): Number of visits per patient. Shape [B]

        Returns:
            Tuple[Tensor, Tensor, Tensor]: the layer input `x` [B, 1+L+V, d], the boolean
            key-padding mask [B, 1+L+V] (True = padding) and the boolean attention mask
            [B*num_heads, 1+L+V, 1+L+V] (True = attention forbidden).
        """
        B, L, d = seq_embed.shape
        V = visit_embed.shape[1]

        # Admission index and token type of the visit slots. Computed for every
        # Transformer layer because the masks below need them.
        arange_V = torch.arange(1, V + 1, device=self.device, dtype=torch.long)[None, :]  # [1, V]
        n_v = num_visits.view(B, 1)  # [B, 1]
        visit_index = torch.where(arange_V <= n_v, arange_V, torch.full((B, V), self.adm_pad_id, device=self.device, dtype=torch.long))  # [B, V]
        visit_type_id = torch.full((B, V), self.node_type_id_dict['visit'], dtype=torch.long, device=self.device)  # [B, V]
        visit_type_id_mask = (visit_index != self.adm_pad_id).long()  # [B, V]
        visit_type_id = visit_type_id * visit_type_id_mask  # padded visit slots get type 0
        token_types = torch.cat([token_types, visit_type_id], dim=1)  # [B, L+V]

        # Part 1: main sequence embedding x = [task | codes | visits]; torch.cat
        # allocates a fresh tensor, so the incoming embeddings are never modified in place.
        non_task = torch.cat([seq_embed, visit_embed], dim=1)  # [B, L+V, d]
        # Only the first Transformer layer of the stack adds the admission-index and
        # token-type embeddings, wherever that layer sits in `layer_types`.
        if layer_i == self.layer_types.index('tf'):
            adm_index = torch.cat([adm_index, visit_index], dim=1)  # [B, L+V]
            non_task = non_task + self.adm_index_emb(adm_index) + self.type_emb(token_types)
        x = torch.cat([task_id_embed, non_task], dim=1)  # [B, 1+L+V, d]

        # Part 2: masks (src_key_padding_mask and attn_mask)
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)  # [B, 1]
        seq_pad_mask = (input_ids == self.seq_pad_id)  # [B, L], bool
        visit_pad_mask = (visit_index == self.adm_pad_id)  # [B, V], bool
        src_key_padding_mask = torch.cat([task_pad_mask, seq_pad_mask, visit_pad_mask], dim=1)  # [B, 1+L+V]
        # the task token is given type -1 for the attention forbid map
        task_type = torch.full((B, 1), -1, dtype=token_types.dtype, device=self.device)
        attn_mask = self.build_attn_mask(torch.cat([task_type, token_types], dim=1),
                                         forbid_map=self.forbid_map,
                                         num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"
        return x, src_key_padding_mask, attn_mask

    def process_tf_out(self, h, L, V):
        """Split the Transformer output back into task, code and visit parts."""
        return h[:, 0:1, :], h[:, 1:1 + L, :], h[:, 1 + L:, :]  # [B, 1, d], [B, L, d], [B, V, d]

    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        """Build a boolean attention mask from token types.

        Args:
            token_types (Tensor): Integer type ID of every position. Shape [B, L]
            forbid_map (dict | None): Maps a query type to the key types it must not
                attend to. `None` forbids nothing.
            num_heads (int): Number of attention heads the mask is repeated for.

        Returns:
            Tensor: Bool mask of shape [B * num_heads, L, L]; entry [b*num_heads + h, q, k]
            is True when position q must not attend to position k.
        """
        B, L = token_types.shape
        device = token_types.device

        if forbid_map is None:
            mask = torch.zeros((B, L, L), dtype=torch.bool, device=device)
        else:
            # Sorted list of every type that occurs in the batch or in forbid_map.
            rule_types = sorted({q_t for q_t in forbid_map} | {k_t for ks in forbid_map.values() for k_t in ks})
            type_list = torch.unique(torch.cat([
                token_types.reshape(-1),
                torch.tensor(rule_types, dtype=token_types.dtype, device=device),
            ]))
            t2i = {t.item(): i for i, t in enumerate(type_list)}  # type value -> row/column in ban_table
            T = len(type_list)

            # Ban table (T, T); the relation is one-directional (query -> key).
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                for k_t in ks:
                    ban_table[t2i[q_t], t2i[k_t]] = True

            # type_list is sorted and contains every observed type, so searchsorted
            # returns each token's exact ban_table index.
            q_idx = torch.searchsorted(type_list, token_types.unsqueeze(-1))  # [B, L, 1]
            k_idx = torch.searchsorted(type_list, token_types.unsqueeze(-2))  # [B, 1, L]

            # look up the ban table -> (B, L, L)
            mask = ban_table[q_idx, k_idx]

        # repeat for every head (batch-major: all heads of sample 0 first)
        mask = mask.unsqueeze(1).expand(B, num_heads, L, L)
        mask = mask.reshape(B * num_heads, L, L)
        return mask

In [18]:
# Train five freshly initialised models and collect, for each run, the test
# metrics reported by the training helper.
metric_names = ("precision", "recall", "f1", "auc", "prauc")
final_metrics = {metric_name: [] for metric_name in metric_names}

for _run in range(5):
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        layer_types=['gnn', 'tf', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    for metric_name, history in final_metrics.items():
        history.append(best_test_metric[metric_name])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.33it/s, loss=0.6497]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.58it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.31it/s]


Validation: {'precision': 0.8015952143529332, 'recall': 0.5042333019739598, 'f1': 0.619056780627445, 'auc': 0.7815021477726674, 'prauc': 0.7848365746312035}
Test:      {'precision': 0.806420699564895, 'recall': 0.5277516462824468, 'f1': 0.6379833159132139, 'auc': 0.7840194549180142, 'prauc': 0.7879161664018952}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 32.17it/s, loss=0.5945]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.45it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.27it/s]


Validation: {'precision': 0.5274120979827905, 'recall': 0.9924741298181484, 'f1': 0.6887921608638052, 'auc': 0.7846365203842228, 'prauc': 0.7984566136491054}
Test:      {'precision': 0.5294314942134338, 'recall': 0.9899655064252432, 'f1': 0.6899038416116962, 'auc': 0.7856935986313305, 'prauc': 0.7975601226493783}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 33.67it/s, loss=0.5524]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.61it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.81it/s]


Validation: {'precision': 0.7901234567868719, 'recall': 0.6020696142972654, 'f1': 0.6833956170855937, 'auc': 0.8101620748750937, 'prauc': 0.8238749438971518}
Test:      {'precision': 0.7901874750666527, 'recall': 0.6211978676681681, 'f1': 0.6955758377658672, 'auc': 0.8081832764816054, 'prauc': 0.8254112423320001}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 31.06it/s, loss=0.5078]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.53it/s]


Validation: {'precision': 0.885365853653138, 'recall': 0.455315145812307, 'f1': 0.6013667381087885, 'auc': 0.8245745523225514, 'prauc': 0.8346506162832065}
Test:      {'precision': 0.8747795414410654, 'recall': 0.46660395108038066, 'f1': 0.6085889525157013, 'auc': 0.8215056572376855, 'prauc': 0.8302603332495617}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 33.96it/s, loss=0.5097]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.56it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.12it/s]


Validation: {'precision': 0.6861499364658548, 'recall': 0.8466603951055295, 'f1': 0.7580011180173264, 'auc': 0.8248079523602692, 'prauc': 0.8320325184589825}
Test:      {'precision': 0.6809152627591629, 'recall': 0.8491690184984347, 'f1': 0.755791231452434, 'auc': 0.8154495766194689, 'prauc': 0.8190985327045398}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 33.32it/s, loss=0.4444]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.35it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.07it/s]


Validation: {'precision': 0.8071046600427536, 'recall': 0.6625901536511051, 'f1': 0.7277423749321812, 'auc': 0.8277700630843228, 'prauc': 0.8364929116934356}
Test:      {'precision': 0.7939278937351274, 'recall': 0.6560050172447288, 'f1': 0.7184065884493689, 'auc': 0.8205475583957091, 'prauc': 0.8264118948410772}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 32.34it/s, loss=0.4303]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.94it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.86it/s]


Validation: {'precision': 0.7839178179214172, 'recall': 0.693947946062421, 'f1': 0.7361942731265275, 'auc': 0.8256152572690347, 'prauc': 0.8299818407195549}
Test:      {'precision': 0.7772851296017146, 'recall': 0.7146440890538895, 'f1': 0.7446495620705866, 'auc': 0.8245822442414661, 'prauc': 0.8266331399612751}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 32.02it/s, loss=0.4239]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.43it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.52it/s]


Validation: {'precision': 0.8373943311743541, 'recall': 0.5280652242065599, 'f1': 0.647692302946415, 'auc': 0.8211720660645688, 'prauc': 0.828160715425018}
Test:      {'precision': 0.8308505526149407, 'recall': 0.542176230791652, 'f1': 0.6561669781407293, 'auc': 0.8144092831145586, 'prauc': 0.8217955184164887}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 32.46it/s, loss=0.3657]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.13it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.77it/s]


Validation: {'precision': 0.7924739195201473, 'recall': 0.6669802445886893, 'f1': 0.7243316824174408, 'auc': 0.8231121599209233, 'prauc': 0.8293543596445477}
Test:      {'precision': 0.7778176597244156, 'recall': 0.6795233615532157, 'f1': 0.7253556435558829, 'auc': 0.8191833685919747, 'prauc': 0.8252826000933664}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.24it/s, loss=0.3324]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.44it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.65it/s]


Validation: {'precision': 0.7440220723459717, 'recall': 0.761053621822637, 'f1': 0.7524414769390919, 'auc': 0.8291257525845059, 'prauc': 0.8306847432775001}
Test:      {'precision': 0.7415048543666823, 'recall': 0.7663844465325608, 'f1': 0.7537393936112187, 'auc': 0.826874534100215, 'prauc': 0.8286007132530359}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6861499364658548, 'recall': 0.8466603951055295, 'f1': 0.7580011180173264, 'auc': 0.8248079523602692, 'prauc': 0.8320325184589825}
Corresponding test performance:
{'precision': 0.6809152627591629, 'recall': 0.8491690184984347, 'f1': 0.755791231452434, 'auc': 0.8154495766194689, 'prauc': 0.8190985327045398}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 32.36it/s, loss=0.6700]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.46it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.88it/s]


Validation: {'precision': 0.5579080709363773, 'recall': 0.9667607400408694, 'f1': 0.7075157727532163, 'auc': 0.7891112100440586, 'prauc': 0.7983994990800036}
Test:      {'precision': 0.5613843351538044, 'recall': 0.9664471621167562, 'f1': 0.7102200667866151, 'auc': 0.7835929285409767, 'prauc': 0.7949875282015515}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 32.64it/s, loss=0.5947]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.44it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.98it/s]


Validation: {'precision': 0.6722855726309562, 'recall': 0.8504233301948874, 'f1': 0.750934509810917, 'auc': 0.8167251573390547, 'prauc': 0.8250806958013426}
Test:      {'precision': 0.6622353520416981, 'recall': 0.8435246158643979, 'f1': 0.741966620363494, 'auc': 0.8087975469448299, 'prauc': 0.8215801786481814}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 31.11it/s, loss=0.5593]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.92it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.93it/s]


Validation: {'precision': 0.6931053811639768, 'recall': 0.7754782063318424, 'f1': 0.7319816436742121, 'auc': 0.8042156990110088, 'prauc': 0.8159137553325255}
Test:      {'precision': 0.6911314984690267, 'recall': 0.7795547193453134, 'f1': 0.7326849345974064, 'auc': 0.8012864244918956, 'prauc': 0.8141386203926089}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 34.74it/s, loss=0.5275]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.39it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.71it/s]


Validation: {'precision': 0.7718984626357815, 'recall': 0.6770147381603104, 'f1': 0.7213498112569205, 'auc': 0.8201031782895967, 'prauc': 0.8269138253813524}
Test:      {'precision': 0.7720098211126902, 'recall': 0.6901850109730631, 'f1': 0.728807942033112, 'auc': 0.8189359289330751, 'prauc': 0.8283650274286332}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 34.09it/s, loss=0.4961]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.13it/s]


Validation: {'precision': 0.7658989343390653, 'recall': 0.6986516149241184, 'f1': 0.7307313823482566, 'auc': 0.8201052882339757, 'prauc': 0.827185608326384}
Test:      {'precision': 0.7557944239141559, 'recall': 0.7055503292546079, 'f1': 0.7298086229633229, 'auc': 0.8199723467557771, 'prauc': 0.8286280429311708}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 32.92it/s, loss=0.4856]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.61it/s]


Validation: {'precision': 0.7529089664587512, 'recall': 0.6898714330489499, 'f1': 0.7200130861543002, 'auc': 0.8010634923457749, 'prauc': 0.8074721976053916}
Test:      {'precision': 0.7480528276303825, 'recall': 0.6926936343659684, 'f1': 0.71930966612194, 'auc': 0.7986978714451312, 'prauc': 0.8083142228860276}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 33.86it/s, loss=0.4490]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.50it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.07it/s]


Validation: {'precision': 0.7079772079751909, 'recall': 0.7792411414212003, 'f1': 0.7419017713938005, 'auc': 0.8099487193089752, 'prauc': 0.8125128347360443}
Test:      {'precision': 0.7067157835060733, 'recall': 0.7820633427382188, 'f1': 0.7424828768206592, 'auc': 0.8111206409271035, 'prauc': 0.8169468786085469}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6722855726309562, 'recall': 0.8504233301948874, 'f1': 0.750934509810917, 'auc': 0.8167251573390547, 'prauc': 0.8250806958013426}
Corresponding test performance:
{'precision': 0.6622353520416981, 'recall': 0.8435246158643979, 'f1': 0.741966620363494, 'auc': 0.8087975469448299, 'prauc': 0.8215801786481814}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 33.05it/s, loss=0.6590]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.25it/s]


Validation: {'precision': 0.6175094675860604, 'recall': 0.869238005641677, 'f1': 0.7220630323907408, 'auc': 0.7804262268497657, 'prauc': 0.7910914462163353}
Test:      {'precision': 0.6205197132602587, 'recall': 0.8686108497934506, 'f1': 0.7238991196632177, 'auc': 0.7811549230979082, 'prauc': 0.7940954549325}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 33.21it/s, loss=0.5808]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.44it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.86it/s]


Validation: {'precision': 0.5624536693836871, 'recall': 0.9517089996834378, 'f1': 0.7070471706345599, 'auc': 0.7479032930102867, 'prauc': 0.7596651501226466}
Test:      {'precision': 0.5683345780422548, 'recall': 0.9545312010004562, 'f1': 0.7124634242256865, 'auc': 0.7454895867465762, 'prauc': 0.7614326833658639}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 32.01it/s, loss=0.5566]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.47it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.51it/s]


Validation: {'precision': 0.7189702727529299, 'recall': 0.7356538099694712, 'f1': 0.7272163620163824, 'auc': 0.7971796373487885, 'prauc': 0.8050369409040241}
Test:      {'precision': 0.7084959471608872, 'recall': 0.7400439009070554, 'f1': 0.7239263753682492, 'auc': 0.7929022668714235, 'prauc': 0.8067691880704354}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 33.34it/s, loss=0.5286]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.27it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.27it/s]


Validation: {'precision': 0.7246029367673821, 'recall': 0.7582314205056186, 'f1': 0.7410358515740059, 'auc': 0.8079712995318235, 'prauc': 0.8186746090280342}
Test:      {'precision': 0.7133947133925946, 'recall': 0.753214173719808, 'f1': 0.7327638754163377, 'auc': 0.8041117968216387, 'prauc': 0.818796306719876}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 31.59it/s, loss=0.4913]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.73it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.93it/s]


Validation: {'precision': 0.8160705991601669, 'recall': 0.5509564126668205, 'f1': 0.6578060603297318, 'auc': 0.8135968131400102, 'prauc': 0.8220054275687896}
Test:      {'precision': 0.8112947658364955, 'recall': 0.5540921919079521, 'f1': 0.6584684132856454, 'auc': 0.8102449380822017, 'prauc': 0.8198669559475237}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 34.94it/s, loss=0.4736]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.61it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.55it/s]


Validation: {'precision': 0.6788283658769815, 'recall': 0.8284728755069662, 'f1': 0.7462222800351708, 'auc': 0.8137890190255695, 'prauc': 0.8232385188910004}
Test:      {'precision': 0.6757862439256564, 'recall': 0.8287864534310794, 'f1': 0.7445070373031284, 'auc': 0.8105253965401735, 'prauc': 0.8213268120244583}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 34.84it/s, loss=0.4542]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.73it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.87it/s]


Validation: {'precision': 0.7084571737349137, 'recall': 0.8222013170247031, 'f1': 0.7611030429209018, 'auc': 0.8276391460592921, 'prauc': 0.8392213709164975}
Test:      {'precision': 0.6994173728795037, 'recall': 0.8281592975828531, 'f1': 0.7583632398287424, 'auc': 0.8254903618830179, 'prauc': 0.8367705795635179}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.05it/s, loss=0.4198]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.43it/s]


Validation: {'precision': 0.7912287060500499, 'recall': 0.6845406083390262, 'f1': 0.7340282398118275, 'auc': 0.8264412502565842, 'prauc': 0.8351696634311401}
Test:      {'precision': 0.7749648382532526, 'recall': 0.6911257447454026, 'f1': 0.7306480971190176, 'auc': 0.8249885365373648, 'prauc': 0.8348584141542339}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 33.55it/s, loss=0.4038]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.78it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.93it/s]


Validation: {'precision': 0.7469325153351322, 'recall': 0.7635622452155423, 'f1': 0.7551558331127005, 'auc': 0.8273848977616404, 'prauc': 0.8408843900843888}
Test:      {'precision': 0.7429438543224797, 'recall': 0.7676387582290134, 'f1': 0.7550894459552072, 'auc': 0.8300416208002328, 'prauc': 0.8419449206383945}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 32.82it/s, loss=0.3629]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.97it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.18it/s]


Validation: {'precision': 0.7780678851145914, 'recall': 0.6541235497000498, 'f1': 0.7107325333655199, 'auc': 0.8128454719940552, 'prauc': 0.8149910306938599}
Test:      {'precision': 0.7786400591249866, 'recall': 0.6607086861064262, 'f1': 0.714843082393312, 'auc': 0.8141091945778702, 'prauc': 0.817953609625486}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.30it/s, loss=0.3449]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.12it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.10it/s]


Validation: {'precision': 0.7608401083985067, 'recall': 0.7042960175581553, 'f1': 0.7314769531552033, 'auc': 0.8222444201767349, 'prauc': 0.8340632379767887}
Test:      {'precision': 0.7518370073455183, 'recall': 0.705863907178721, 'f1': 0.7281255004206999, 'auc': 0.8193711126781618, 'prauc': 0.828021843735765}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 32.76it/s, loss=0.2909]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.88it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.01it/s]


Validation: {'precision': 0.7766323024028294, 'recall': 0.7086861084957394, 'f1': 0.7411050942046226, 'auc': 0.8243229163369878, 'prauc': 0.8341892619726579}
Test:      {'precision': 0.7699864498618904, 'recall': 0.7127626215092105, 'f1': 0.740270309286098, 'auc': 0.824851830655841, 'prauc': 0.8343264969302487}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7084571737349137, 'recall': 0.8222013170247031, 'f1': 0.7611030429209018, 'auc': 0.8276391460592921, 'prauc': 0.8392213709164975}
Corresponding test performance:
{'precision': 0.6994173728795037, 'recall': 0.8281592975828531, 'f1': 0.7583632398287424, 'auc': 0.8254903618830179, 'prauc': 0.8367705795635179}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 31.97it/s, loss=0.6668]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.53it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.36it/s]


Validation: {'precision': 0.72948539638133, 'recall': 0.6578864847894077, 'f1': 0.6918384121586034, 'auc': 0.7780123500068171, 'prauc': 0.7903449974281109}
Test:      {'precision': 0.721065209966811, 'recall': 0.662276575726992, 'f1': 0.6904217014467864, 'auc': 0.7754536939026458, 'prauc': 0.7871244862388118}


Epoch 002: 100%|██████████| 98/98 [00:02<00:00, 36.00it/s, loss=0.6236]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.44it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.70it/s]


Validation: {'precision': 0.6971301198085897, 'recall': 0.784571966131124, 'f1': 0.7382708713799444, 'auc': 0.8005510772823394, 'prauc': 0.8044502564863738}
Test:      {'precision': 0.6857841804819631, 'recall': 0.7911571025375003, 'f1': 0.7347117014880754, 'auc': 0.7932206768149141, 'prauc': 0.8005789042831879}


Epoch 003: 100%|██████████| 98/98 [00:02<00:00, 34.97it/s, loss=0.5543]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.09it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.64it/s]


Validation: {'precision': 0.7649690295913112, 'recall': 0.6970837253035526, 'f1': 0.7294503641634309, 'auc': 0.8109545096996654, 'prauc': 0.8220096257017216}
Test:      {'precision': 0.7403241812744283, 'recall': 0.7017873941652499, 'f1': 0.7205408836039552, 'auc': 0.8009567901700618, 'prauc': 0.8141017995308915}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 34.63it/s, loss=0.5147]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.88it/s]


Validation: {'precision': 0.7977786288747691, 'recall': 0.6531828159277103, 'f1': 0.7182758571161446, 'auc': 0.818348307407643, 'prauc': 0.8328577023612613}
Test:      {'precision': 0.7763602251377998, 'recall': 0.6487927249901261, 'f1': 0.706867094456847, 'auc': 0.8121584177552552, 'prauc': 0.8262327834610704}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 33.12it/s, loss=0.4990]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.77it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.24it/s]


Validation: {'precision': 0.8307763830737761, 'recall': 0.5603637503902152, 'f1': 0.6692883846995238, 'auc': 0.8149006582925988, 'prauc': 0.8276540428643933}
Test:      {'precision': 0.831398049229766, 'recall': 0.5613044841625547, 'f1': 0.6701609835794046, 'auc': 0.8112349987190116, 'prauc': 0.8254718836756479}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 33.33it/s, loss=0.4777]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.90it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.65it/s]


Validation: {'precision': 0.859132297799362, 'recall': 0.5029789902775071, 'f1': 0.6344936662254029, 'auc': 0.8171559376497369, 'prauc': 0.8279476443046747}
Test:      {'precision': 0.8435725348433476, 'recall': 0.5123863280009019, 'f1': 0.6375341349758531, 'auc': 0.8169850514433896, 'prauc': 0.827400010213271}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 33.24it/s, loss=0.4569]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.45it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.88it/s]


Validation: {'precision': 0.7119722382859689, 'recall': 0.7720288491665976, 'f1': 0.7407853116901784, 'auc': 0.8143986924775158, 'prauc': 0.8258757299853514}
Test:      {'precision': 0.7123209169034032, 'recall': 0.7795547193453134, 'f1': 0.7444228127950198, 'auc': 0.8106965808932893, 'prauc': 0.8222473098502141}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.30it/s, loss=0.4264]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.65it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.06it/s]


Validation: {'precision': 0.7625463118869568, 'recall': 0.7099404201921922, 'f1': 0.7353036650267283, 'auc': 0.8193162695098268, 'prauc': 0.8327064460794685}
Test:      {'precision': 0.7552840158495533, 'recall': 0.7171527124467948, 'f1': 0.7357246210264007, 'auc': 0.820981836137163, 'prauc': 0.8338611580290678}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 31.84it/s, loss=0.3976]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.07it/s]


Validation: {'precision': 0.7278931750720241, 'recall': 0.7692066478495792, 'f1': 0.7479798699824692, 'auc': 0.8151774629003959, 'prauc': 0.8234360285211231}
Test:      {'precision': 0.7220732797118795, 'recall': 0.7601128880502975, 'f1': 0.7406049445885651, 'auc': 0.8127728895526203, 'prauc': 0.8187186735992835}


Epoch 010: 100%|██████████| 98/98 [00:02<00:00, 35.50it/s, loss=0.3863]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.39it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.48it/s]


Validation: {'precision': 0.7029163259724641, 'recall': 0.8087174662878372, 'f1': 0.7521143140657535, 'auc': 0.8192631190061881, 'prauc': 0.8291924338012026}
Test:      {'precision': 0.700513097486631, 'recall': 0.8134211351495346, 'f1': 0.7527568145264965, 'auc': 0.8185449380318681, 'prauc': 0.8289778851389309}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 35.09it/s, loss=0.3613]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.80it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.59it/s]


Validation: {'precision': 0.7560113154145828, 'recall': 0.6704296017539341, 'f1': 0.7106531444256409, 'auc': 0.8059174696260948, 'prauc': 0.8138183772663146}
Test:      {'precision': 0.7566817077377415, 'recall': 0.6835998745666868, 'f1': 0.7182866506941971, 'auc': 0.811373315273661, 'prauc': 0.8189017501308403}


Epoch 012: 100%|██████████| 98/98 [00:02<00:00, 32.75it/s, loss=0.3411]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.77it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.99it/s]


Validation: {'precision': 0.7285669092036426, 'recall': 0.7221699592326053, 'f1': 0.7253543257064741, 'auc': 0.7994041215653496, 'prauc': 0.8090125404487303}
Test:      {'precision': 0.7289631263765239, 'recall': 0.7253057384737369, 'f1': 0.7271298283834736, 'auc': 0.8015792649995193, 'prauc': 0.8137859667137267}


Epoch 013: 100%|██████████| 98/98 [00:02<00:00, 32.90it/s, loss=0.3166]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.64it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.80it/s]


Validation: {'precision': 0.744701986752501, 'recall': 0.7052367513304947, 'f1': 0.7244322707301516, 'auc': 0.8050463137814836, 'prauc': 0.814186648882547}
Test:      {'precision': 0.7427536231859594, 'recall': 0.7071182188751737, 'f1': 0.7244979869685643, 'auc': 0.8092816548861026, 'prauc': 0.8216587241360547}


Epoch 014: 100%|██████████| 98/98 [00:02<00:00, 33.20it/s, loss=0.2915]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.66it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.55it/s]


Validation: {'precision': 0.7545968882575863, 'recall': 0.6691752900574814, 'f1': 0.7093235781966276, 'auc': 0.8024646963604163, 'prauc': 0.8115813433503707}
Test:      {'precision': 0.7676450034913779, 'recall': 0.6889306992766104, 'f1': 0.7261609601419325, 'auc': 0.8098456421476916, 'prauc': 0.8206502424553752}


Epoch 015: 100%|██████████| 98/98 [00:02<00:00, 33.13it/s, loss=0.2719]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.04it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.59it/s]


Validation: {'precision': 0.6784406070102605, 'recall': 0.8131075572254215, 'f1': 0.7396947604073547, 'auc': 0.798875781445531, 'prauc': 0.8053073957145753}
Test:      {'precision': 0.6811669705634771, 'recall': 0.820006271555911, 'f1': 0.7441661874140154, 'auc': 0.8076038871582542, 'prauc': 0.8205597684399248}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7029163259724641, 'recall': 0.8087174662878372, 'f1': 0.7521143140657535, 'auc': 0.8192631190061881, 'prauc': 0.8291924338012026}
Corresponding test performance:
{'precision': 0.700513097486631, 'recall': 0.8134211351495346, 'f1': 0.7527568145264965, 'auc': 0.8185449380318681, 'prauc': 0.8289778851389309}


Epoch 001: 100%|██████████| 98/98 [00:02<00:00, 33.34it/s, loss=0.6778]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.15it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.38it/s]


Validation: {'precision': 0.7403273809496268, 'recall': 0.6240200689851866, 'f1': 0.6772162618368218, 'auc': 0.7760193065939077, 'prauc': 0.7746910306189052}
Test:      {'precision': 0.7428888067944481, 'recall': 0.6306052053915628, 'f1': 0.6821573898750638, 'auc': 0.7695707103118816, 'prauc': 0.7641799513609133}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.90it/s, loss=0.6016]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.51it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 46.49it/s]


Validation: {'precision': 0.6248906386687995, 'recall': 0.8958921291912955, 'f1': 0.7362453243670299, 'auc': 0.7968796233528241, 'prauc': 0.8030417733758017}
Test:      {'precision': 0.6217639315475608, 'recall': 0.8886798369366928, 'f1': 0.731638048594291, 'auc': 0.7938447119839618, 'prauc': 0.8054745715954369}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 31.47it/s, loss=0.5544]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.53it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.75it/s]


Validation: {'precision': 0.8696581196519256, 'recall': 0.3828786453421672, 'f1': 0.5316786371637128, 'auc': 0.8011503014859334, 'prauc': 0.8077829798768841}
Test:      {'precision': 0.8588235294058214, 'recall': 0.3891502038244304, 'f1': 0.5356063832760141, 'auc': 0.7938817071323123, 'prauc': 0.8028438231284258}


Epoch 004: 100%|██████████| 98/98 [00:02<00:00, 34.18it/s, loss=0.5234]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.24it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.64it/s]


Validation: {'precision': 0.75704567541057, 'recall': 0.7328316086524528, 'f1': 0.7447418688039189, 'auc': 0.8218007792526959, 'prauc': 0.8272013724079741}
Test:      {'precision': 0.7416693113273829, 'recall': 0.7328316086524528, 'f1': 0.737223969761261, 'auc': 0.8131807925216447, 'prauc': 0.8223002040102056}


Epoch 005: 100%|██████████| 98/98 [00:02<00:00, 34.24it/s, loss=0.4947]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.44it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.41it/s]


Validation: {'precision': 0.7567208762006084, 'recall': 0.7149576669780027, 'f1': 0.7352466896162961, 'auc': 0.8186412882556778, 'prauc': 0.8293581946980587}
Test:      {'precision': 0.7543391188225823, 'recall': 0.7086861084957394, 'f1': 0.7308003183654804, 'auc': 0.8141706518243138, 'prauc': 0.8233305565727149}


Epoch 006: 100%|██████████| 98/98 [00:02<00:00, 32.71it/s, loss=0.4762]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.08it/s]


Validation: {'precision': 0.7028123309878237, 'recall': 0.8149890247701004, 'f1': 0.7547553311656744, 'auc': 0.8267673371366587, 'prauc': 0.8360914149461838}
Test:      {'precision': 0.6956406166913992, 'recall': 0.8206334274041373, 'f1': 0.7529851770200137, 'auc': 0.818948864651614, 'prauc': 0.8279817912737398}


Epoch 007: 100%|██████████| 98/98 [00:02<00:00, 35.35it/s, loss=0.4470]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.18it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.91it/s]


Validation: {'precision': 0.7348203221786406, 'recall': 0.7438068359964133, 'f1': 0.7392862658409536, 'auc': 0.8187063448740257, 'prauc': 0.8277444301455945}
Test:      {'precision': 0.7283428916367994, 'recall': 0.7566635308850528, 'f1': 0.742233154027527, 'auc': 0.8138861666835284, 'prauc': 0.8253020049555493}


Epoch 008: 100%|██████████| 98/98 [00:02<00:00, 33.61it/s, loss=0.4287]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.52it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.80it/s]


Validation: {'precision': 0.780478955004425, 'recall': 0.6745061147674052, 'f1': 0.72363330032266, 'auc': 0.8157826652797299, 'prauc': 0.8264812299070784}
Test:      {'precision': 0.7723087818669536, 'recall': 0.6839134524907999, 'f1': 0.7254282338319097, 'auc': 0.8135834104694759, 'prauc': 0.8234758262105519}


Epoch 009: 100%|██████████| 98/98 [00:02<00:00, 32.71it/s, loss=0.4015]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.62it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.16it/s]


Validation: {'precision': 0.721145501878157, 'recall': 0.7817497648141055, 'f1': 0.7502256946748467, 'auc': 0.8188164136391225, 'prauc': 0.8301526729893548}
Test:      {'precision': 0.7118304210807452, 'recall': 0.7792411414212003, 'f1': 0.7440119710558962, 'auc': 0.816249628161009, 'prauc': 0.8286529609977614}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 32.59it/s, loss=0.3686]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.33it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.40it/s]


Validation: {'precision': 0.7615826851512967, 'recall': 0.7061774851028342, 'f1': 0.7328343588186026, 'auc': 0.8167541439558783, 'prauc': 0.8248911984507582}
Test:      {'precision': 0.7506729475075684, 'recall': 0.6995923486964579, 'f1': 0.7242330740494614, 'auc': 0.8099171157676341, 'prauc': 0.817995972157124}


Epoch 011: 100%|██████████| 98/98 [00:02<00:00, 33.20it/s, loss=0.3571]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.63it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 46.91it/s]

Validation: {'precision': 0.7941878980860104, 'recall': 0.6255879586057523, 'f1': 0.6998772095918216, 'auc': 0.815371075415541, 'prauc': 0.8265262526417874}
Test:      {'precision': 0.788235294114556, 'recall': 0.6302916274674497, 'f1': 0.7004704602973921, 'auc': 0.8139748543724994, 'prauc': 0.823787837577763}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7028123309878237, 'recall': 0.8149890247701004, 'f1': 0.7547553311656744, 'auc': 0.8267673371366587, 'prauc': 0.8360914149461838}
Corresponding test performance:
{'precision': 0.6956406166913992, 'recall': 0.8206334274041373, 'f1': 0.7529851770200137, 'auc': 0.818948864651614, 'prauc': 0.8279817912737398}





In [19]:
# Summarize the repeated runs: report mean ± std of every collected metric.
print("\nFinal Metrics:")
for key, scores in final_metrics.items():
    # np.std uses its default ddof=0 (population standard deviation).
    print(f"{key}: {np.mean(scores):.4f} ± {np.std(scores):.4f}")


Final Metrics:
precision: 0.6877 ± 0.0146
recall: 0.8310 ± 0.0135
f1: 0.7524 ± 0.0056
auc: 0.8174 ± 0.0054
prauc: 0.8269 ± 0.0062


In [20]:
# Train the two-layer ('tf', 'tf') HeteroGT configuration five times and collect
# the metrics returned by each run.
# NOTE: this rebinds `final_metrics`, so values gathered by an earlier cell are replaced.
final_metrics = {name: [] for name in ("precision", "recall", "f1", "auc", "prauc")}
for i in range(5):
    # A fresh model and optimizer are created on every repetition.
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        layer_types=['tf', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # Presumably returns the test metrics at the best-validation epoch (the training
    # log prints "Corresponding test performance") — confirm against
    # train_with_early_stopping.
    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )
    # Append this run's value to each metric's history.
    for key, history in final_metrics.items():
        history.append(best_test_metric[key])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.7715]

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 60.79it/s, loss=0.6772]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.76it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 163.24it/s]


Validation: {'precision': 0.7550598926032422, 'recall': 0.5732204452788547, 'f1': 0.6516933997259637, 'auc': 0.7831272068385509, 'prauc': 0.7903610627256853}
Test:      {'precision': 0.7498999599809928, 'recall': 0.5876450297880601, 'f1': 0.6589310780529772, 'auc': 0.7783084106833935, 'prauc': 0.7875573141768277}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 64.69it/s, loss=0.5955]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.18it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.82it/s]


Validation: {'precision': 0.7222029788682987, 'recall': 0.6538099717759367, 'f1': 0.6863067757869202, 'auc': 0.7685157415414591, 'prauc': 0.7562942646500237}
Test:      {'precision': 0.7232295586701224, 'recall': 0.6629037315752183, 'f1': 0.6917539217087776, 'auc': 0.7700719819886477, 'prauc': 0.7613059275929781}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 62.41it/s, loss=0.5612]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.91it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.07it/s]


Validation: {'precision': 0.7368257904502888, 'recall': 0.745374725616979, 'f1': 0.7410755990508568, 'auc': 0.817901802987661, 'prauc': 0.8252391650874983}
Test:      {'precision': 0.7331701346366795, 'recall': 0.7513327061751292, 'f1': 0.7421403078372294, 'auc': 0.8150777627951621, 'prauc': 0.8245035329436828}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 62.87it/s, loss=0.5056]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.15it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 164.00it/s]


Validation: {'precision': 0.7443438914003092, 'recall': 0.7221699592326053, 'f1': 0.7330892835552316, 'auc': 0.8209590119190758, 'prauc': 0.8274861892278798}
Test:      {'precision': 0.7433712121188657, 'recall': 0.7384760112864895, 'f1': 0.7409155211893235, 'auc': 0.8166891909236557, 'prauc': 0.8251969452650496}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 64.10it/s, loss=0.4825]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.86it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.12it/s]


Validation: {'precision': 0.7508250825057728, 'recall': 0.7133897773574369, 'f1': 0.731628874241951, 'auc': 0.8212498828227318, 'prauc': 0.824592851666328}
Test:      {'precision': 0.7460937499975714, 'recall': 0.7187206020673605, 'f1': 0.7321514085116259, 'auc': 0.8175604644174622, 'prauc': 0.8222064540342141}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 61.45it/s, loss=0.4485]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.69it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.51it/s]


Validation: {'precision': 0.7455061494773084, 'recall': 0.741298212603508, 'f1': 0.7433962214127967, 'auc': 0.822539310022065, 'prauc': 0.8309428183515131}
Test:      {'precision': 0.7541139240482465, 'recall': 0.747256193161658, 'f1': 0.7506693917531342, 'auc': 0.8280878742990426, 'prauc': 0.836745392479688}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 60.83it/s, loss=0.4245]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.80it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.96it/s]


Validation: {'precision': 0.7522328812413092, 'recall': 0.7130761994333237, 'f1': 0.7321313536618701, 'auc': 0.8186071774882198, 'prauc': 0.8267151025277144}
Test:      {'precision': 0.7611592271793433, 'recall': 0.7165255565985684, 'f1': 0.7381683038375834, 'auc': 0.8225685504998875, 'prauc': 0.8310854763882314}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 53.32it/s, loss=0.3858]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.30it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.90it/s]


Validation: {'precision': 0.7398605830141323, 'recall': 0.7322044528042264, 'f1': 0.7360126033508492, 'auc': 0.8217591832063699, 'prauc': 0.8302404053117654}
Test:      {'precision': 0.7483485372735189, 'recall': 0.7460018814652054, 'f1': 0.7471733618318366, 'auc': 0.8235714461881659, 'prauc': 0.829959374973449}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 59.80it/s, loss=0.3509]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 155.64it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.20it/s]


Validation: {'precision': 0.6633569739937036, 'recall': 0.8798996550615243, 'f1': 0.7564361724787995, 'auc': 0.8209636839387718, 'prauc': 0.8222085127529689}
Test:      {'precision': 0.6676272814585311, 'recall': 0.8717466290345822, 'f1': 0.7561539458542496, 'auc': 0.8238507973083637, 'prauc': 0.8288679283649225}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 57.85it/s, loss=0.3338]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.05it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.91it/s]


Validation: {'precision': 0.8004807692275622, 'recall': 0.6265286923780918, 'f1': 0.7029023697420097, 'auc': 0.8231905795203374, 'prauc': 0.8220682839192873}
Test:      {'precision': 0.808112324489828, 'recall': 0.6497334587624656, 'f1': 0.7203198281870486, 'auc': 0.8270665565368917, 'prauc': 0.8292246418974727}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 60.34it/s, loss=0.2988]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 155.33it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.09it/s]


Validation: {'precision': 0.6836400302478356, 'recall': 0.8504233301948874, 'f1': 0.7579653388244497, 'auc': 0.816095640362593, 'prauc': 0.8095986819246023}
Test:      {'precision': 0.6891376240124927, 'recall': 0.8494825964225479, 'f1': 0.7609550512319401, 'auc': 0.8210123885930116, 'prauc': 0.81697332884635}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 61.06it/s, loss=0.2906]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.77it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 168.77it/s]


Validation: {'precision': 0.7296082209353577, 'recall': 0.7124490435850974, 'f1': 0.7209265379144922, 'auc': 0.803694844170058, 'prauc': 0.7961200686680223}
Test:      {'precision': 0.7469055374568505, 'recall': 0.7190341799914738, 'f1': 0.732704899936357, 'auc': 0.8107405220694935, 'prauc': 0.8083004529860323}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 62.56it/s, loss=0.2600]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.91it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.99it/s]


Validation: {'precision': 0.7313386798250113, 'recall': 0.7434932580723002, 'f1': 0.7373658789974246, 'auc': 0.8110545311105772, 'prauc': 0.7990272070170091}
Test:      {'precision': 0.7420253948567915, 'recall': 0.7513327061751292, 'f1': 0.7466500417414014, 'auc': 0.8161421660634192, 'prauc': 0.8100640263401492}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 59.43it/s, loss=0.2450]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.01it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.53it/s]


Validation: {'precision': 0.7255827677759646, 'recall': 0.7710881153942581, 'f1': 0.7476436556894881, 'auc': 0.8134726780790544, 'prauc': 0.8080298458453824}
Test:      {'precision': 0.7301119622842366, 'recall': 0.7770460959524081, 'f1': 0.7528482404833454, 'auc': 0.8177232430702046, 'prauc': 0.8169315779319822}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 62.03it/s, loss=0.2036]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.80it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.26it/s]


Validation: {'precision': 0.7175707547148656, 'recall': 0.7632486672914292, 'f1': 0.7397052069763738, 'auc': 0.8065317648609662, 'prauc': 0.8011941776335139}
Test:      {'precision': 0.7322551662152373, 'recall': 0.7666980244566739, 'f1': 0.7490808773532862, 'auc': 0.8151580951172948, 'prauc': 0.8125258567404787}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 60.21it/s, loss=0.1936]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.83it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.94it/s]


Validation: {'precision': 0.7274138466292226, 'recall': 0.7347130761971317, 'f1': 0.7310452368075162, 'auc': 0.7990272453098699, 'prauc': 0.7925443074613794}
Test:      {'precision': 0.7329000309479019, 'recall': 0.7425525242999607, 'f1': 0.7376946990477602, 'auc': 0.8077259459810434, 'prauc': 0.8070018821533358}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6836400302478356, 'recall': 0.8504233301948874, 'f1': 0.7579653388244497, 'auc': 0.816095640362593, 'prauc': 0.8095986819246023}
Corresponding test performance:
{'precision': 0.6891376240124927, 'recall': 0.8494825964225479, 'f1': 0.7609550512319401, 'auc': 0.8210123885930116, 'prauc': 0.81697332884635}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 60.63it/s, loss=0.6694]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.09it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.37it/s]


Validation: {'precision': 0.6354312354297542, 'recall': 0.8548134211324716, 'f1': 0.7289744569328552, 'auc': 0.7811808836226016, 'prauc': 0.787947872055196}
Test:      {'precision': 0.6241085806288841, 'recall': 0.8507369081190005, 'f1': 0.7200106108274026, 'auc': 0.7757127102746348, 'prauc': 0.7843456993531529}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 60.94it/s, loss=0.6042]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.16it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 155.63it/s]


Validation: {'precision': 0.8196962273355234, 'recall': 0.5246158670413151, 'f1': 0.639770549731769, 'auc': 0.7917480879131434, 'prauc': 0.805511928460425}
Test:      {'precision': 0.8057692307653569, 'recall': 0.5255566008136546, 'f1': 0.6361738422488837, 'auc': 0.7839021274475311, 'prauc': 0.7995009299742211}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 60.36it/s, loss=0.5504]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 152.77it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.10it/s]


Validation: {'precision': 0.6994187655668436, 'recall': 0.792411414233953, 'f1': 0.7430167547937796, 'auc': 0.8077030351750837, 'prauc': 0.8158431523221861}
Test:      {'precision': 0.6885832187051208, 'recall': 0.7848855440552371, 'f1': 0.7335873338234286, 'auc': 0.7982736604107117, 'prauc': 0.810250777952266}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 59.68it/s, loss=0.5222]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.04it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.18it/s]


Validation: {'precision': 0.7685147713445224, 'recall': 0.5954844778908891, 'f1': 0.6710247300604222, 'auc': 0.7943078021020873, 'prauc': 0.8055617708108579}
Test:      {'precision': 0.7734562951051586, 'recall': 0.6048918156142838, 'f1': 0.6788667907788838, 'auc': 0.7948837974840279, 'prauc': 0.8108612568026653}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 60.84it/s, loss=0.4996]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.35it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.47it/s]


Validation: {'precision': 0.7255882352919836, 'recall': 0.7735967387871634, 'f1': 0.7488237922406751, 'auc': 0.8218144436543875, 'prauc': 0.8311570188907345}
Test:      {'precision': 0.7202755315941292, 'recall': 0.7541549074921475, 'f1': 0.7368259753925395, 'auc': 0.8111510927158818, 'prauc': 0.825146361957962}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 59.72it/s, loss=0.4677]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 152.69it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.59it/s]


Validation: {'precision': 0.7162709901542668, 'recall': 0.7757917842559555, 'f1': 0.7448441919047056, 'auc': 0.817464944027697, 'prauc': 0.8324210236159716}
Test:      {'precision': 0.7133683596009535, 'recall': 0.7663844465325608, 'f1': 0.7389266767880088, 'auc': 0.808871587575066, 'prauc': 0.8229772350879065}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 62.11it/s, loss=0.4416]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.95it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 164.57it/s]


Validation: {'precision': 0.6980371067472492, 'recall': 0.8140482909977609, 'f1': 0.7515923517151539, 'auc': 0.8156462222098975, 'prauc': 0.8257140719541786}
Test:      {'precision': 0.6931260229113663, 'recall': 0.7968015051715371, 'f1': 0.7413566689826596, 'auc': 0.807079311054704, 'prauc': 0.8201249124144472}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 57.57it/s, loss=0.4290]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 145.26it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 144.70it/s]


Validation: {'precision': 0.6780331109447046, 'recall': 0.8604578237665085, 'f1': 0.7584300669311101, 'auc': 0.8257617476930521, 'prauc': 0.8369633653904051}
Test:      {'precision': 0.670465807728765, 'recall': 0.8485418626502085, 'f1': 0.7490657390112798, 'auc': 0.8173821327023523, 'prauc': 0.8304804204240928}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 57.13it/s, loss=0.3897]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 145.96it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 146.99it/s]


Validation: {'precision': 0.7303743961330604, 'recall': 0.7585449984297318, 'f1': 0.7441931960454933, 'auc': 0.8165583712595835, 'prauc': 0.8226964632359126}
Test:      {'precision': 0.7277676950976172, 'recall': 0.7544684854162607, 'f1': 0.7408775931517662, 'auc': 0.8093513164987792, 'prauc': 0.8190900815907091}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 60.06it/s, loss=0.3697]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.50it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.07it/s]


Validation: {'precision': 0.6979973297711669, 'recall': 0.8196926936317978, 'f1': 0.7539659598410491, 'auc': 0.8165456111197686, 'prauc': 0.8200748020546728}
Test:      {'precision': 0.690571049134952, 'recall': 0.8153026026942136, 'f1': 0.7477710620439453, 'auc': 0.8115978028405226, 'prauc': 0.8163647939813318}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 64.99it/s, loss=0.3366]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 165.19it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.04it/s]


Validation: {'precision': 0.7666314677903322, 'recall': 0.6829727187184604, 'f1': 0.7223880547157497, 'auc': 0.8175643123605867, 'prauc': 0.8248838897716213}
Test:      {'precision': 0.7578427916786823, 'recall': 0.674192536843292, 'f1': 0.7135745054693888, 'auc': 0.8108630338940993, 'prauc': 0.8209996886556893}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 65.03it/s, loss=0.3125]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.35it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.22it/s]


Validation: {'precision': 0.6818068380317399, 'recall': 0.8566948886771506, 'f1': 0.7593107232448224, 'auc': 0.8144405397076964, 'prauc': 0.8191978444348448}
Test:      {'precision': 0.671952428145015, 'recall': 0.8504233301948874, 'f1': 0.7507266386652544, 'auc': 0.8092180332976469, 'prauc': 0.8167446131050006}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 64.04it/s, loss=0.3108]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.48it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.19it/s]


Validation: {'precision': 0.7417175940793126, 'recall': 0.7231106930049448, 'f1': 0.7322959619721238, 'auc': 0.8145433743777799, 'prauc': 0.8219318679752932}
Test:      {'precision': 0.7361594432118439, 'recall': 0.7296958294113212, 'f1': 0.7329133808245606, 'auc': 0.8088362030998411, 'prauc': 0.8195309844019592}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 61.57it/s, loss=0.2929]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.49it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 158.46it/s]


Validation: {'precision': 0.7516475893140768, 'recall': 0.6795233615532157, 'f1': 0.7137681109523765, 'auc': 0.8048383837866248, 'prauc': 0.812109846617036}
Test:      {'precision': 0.7430531732393034, 'recall': 0.6792097836291026, 'f1': 0.7096985533301612, 'auc': 0.798559856891693, 'prauc': 0.8057867745242302}


Epoch 015: 100%|██████████| 98/98 [00:01<00:00, 60.47it/s, loss=0.2730]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.36it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.51it/s]


Validation: {'precision': 0.7252982563452882, 'recall': 0.7434932580723002, 'f1': 0.7342830547693202, 'auc': 0.8060772225576365, 'prauc': 0.8141422651106813}
Test:      {'precision': 0.7184759600824964, 'recall': 0.7450611476928659, 'f1': 0.7315270885954568, 'auc': 0.8052002091861722, 'prauc': 0.8145776044969963}


Epoch 016: 100%|██████████| 98/98 [00:01<00:00, 60.73it/s, loss=0.2369]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.69it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.66it/s]


Validation: {'precision': 0.7155096011795702, 'recall': 0.7594857322020713, 'f1': 0.7368421002653609, 'auc': 0.8005727795673789, 'prauc': 0.801027628358725}
Test:      {'precision': 0.7102390085550008, 'recall': 0.7547820633403739, 'f1': 0.7318333790096943, 'auc': 0.7985171740538681, 'prauc': 0.8004769045720661}


Epoch 017: 100%|██████████| 98/98 [00:01<00:00, 61.03it/s, loss=0.2363]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.90it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 162.12it/s]


Validation: {'precision': 0.6814020402807183, 'recall': 0.8168704923147794, 'f1': 0.7430119745025329, 'auc': 0.7983531683175976, 'prauc': 0.8025802308356826}
Test:      {'precision': 0.6742305685950073, 'recall': 0.8105989338325161, 'f1': 0.7361526363614501, 'auc': 0.7966577022640527, 'prauc': 0.8036038929267946}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6818068380317399, 'recall': 0.8566948886771506, 'f1': 0.7593107232448224, 'auc': 0.8144405397076964, 'prauc': 0.8191978444348448}
Corresponding test performance:
{'precision': 0.671952428145015, 'recall': 0.8504233301948874, 'f1': 0.7507266386652544, 'auc': 0.8092180332976469, 'prauc': 0.8167446131050006}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 59.76it/s, loss=0.6771]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 159.47it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.29it/s]


Validation: {'precision': 0.6757062146873568, 'recall': 0.7500783944786765, 'f1': 0.7109525882645753, 'auc': 0.7720675314826309, 'prauc': 0.7836292638940019}
Test:      {'precision': 0.6779079737048359, 'recall': 0.7438068359964133, 'f1': 0.7093301385492911, 'auc': 0.7663187109380308, 'prauc': 0.7769235763169671}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 58.94it/s, loss=0.5953]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.81it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.72it/s]


Validation: {'precision': 0.7927411652302162, 'recall': 0.520539354027844, 'f1': 0.6284308110367646, 'auc': 0.7811125113773727, 'prauc': 0.7785158651889198}
Test:      {'precision': 0.7847222222185893, 'recall': 0.5315145813718046, 'f1': 0.6337633154294413, 'auc': 0.7701722463907081, 'prauc': 0.7707967017650452}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 60.94it/s, loss=0.5698]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.62it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.13it/s]


Validation: {'precision': 0.625332741790982, 'recall': 0.8839761680749953, 'f1': 0.7324931743056394, 'auc': 0.7970892111611234, 'prauc': 0.8064483887403411}
Test:      {'precision': 0.6234374999986084, 'recall': 0.8758231420480532, 'f1': 0.7283870077881169, 'auc': 0.7956030637016188, 'prauc': 0.8076650653876327}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 58.87it/s, loss=0.5296]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 163.32it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 159.86it/s]


Validation: {'precision': 0.7103888566433245, 'recall': 0.7676387582290134, 'f1': 0.7379050439879451, 'auc': 0.7897957362846834, 'prauc': 0.7914641209755038}
Test:      {'precision': 0.6999419616926874, 'recall': 0.7563499529609397, 'f1': 0.7270534991499974, 'auc': 0.7796501013969066, 'prauc': 0.7822532286647524}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 59.02it/s, loss=0.5196]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.73it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.91it/s]


Validation: {'precision': 0.7486214725892033, 'recall': 0.7237378488531712, 'f1': 0.7359693827541833, 'auc': 0.8097715342179225, 'prauc': 0.8182082266189142}
Test:      {'precision': 0.7341206511307561, 'recall': 0.7212292254602658, 'f1': 0.7276178374530098, 'auc': 0.8047710654653059, 'prauc': 0.8124745622142506}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 60.34it/s, loss=0.4718]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.80it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.16it/s]


Validation: {'precision': 0.7594192879337732, 'recall': 0.6889306992766104, 'f1': 0.7224597122077574, 'auc': 0.8014399667070873, 'prauc': 0.7891437801832769}
Test:      {'precision': 0.7365935919030806, 'recall': 0.6848541862631394, 'f1': 0.7097822504479318, 'auc': 0.7923118041703348, 'prauc': 0.7835347664578334}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 63.24it/s, loss=0.4453]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.18it/s]


Validation: {'precision': 0.7919876733405548, 'recall': 0.6447162119766551, 'f1': 0.710803797988715, 'auc': 0.8236742089140328, 'prauc': 0.8310782912527035}
Test:      {'precision': 0.7843432442574921, 'recall': 0.6534963938518235, 'f1': 0.7129661257266178, 'auc': 0.8177071866724852, 'prauc': 0.8265337608631464}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 59.98it/s, loss=0.4165]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.47it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 154.99it/s]


Validation: {'precision': 0.737646001794736, 'recall': 0.7723424270907108, 'f1': 0.7545955832356224, 'auc': 0.8151142650425722, 'prauc': 0.810265223807879}
Test:      {'precision': 0.7325200833063596, 'recall': 0.7720288491665976, 'f1': 0.7517557201919922, 'auc': 0.8099996627653143, 'prauc': 0.8095337771573579}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 61.06it/s, loss=0.3714]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 150.79it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.94it/s]


Validation: {'precision': 0.7742050732376771, 'recall': 0.6795233615532157, 'f1': 0.723780890142373, 'auc': 0.8191136646126861, 'prauc': 0.8168092942722711}
Test:      {'precision': 0.7747716092734548, 'recall': 0.6914393226695157, 'f1': 0.7307373603824122, 'auc': 0.8231263970701855, 'prauc': 0.8259895229991983}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 59.30it/s, loss=0.3493]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 151.00it/s]


Validation: {'precision': 0.7675070027984331, 'recall': 0.6873628096560447, 'f1': 0.7252274557241051, 'auc': 0.8160011952332539, 'prauc': 0.8166810880282511}
Test:      {'precision': 0.7687963289771663, 'recall': 0.6829727187184604, 'f1': 0.7233477200233746, 'auc': 0.812735038734172, 'prauc': 0.8164226636368725}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 60.93it/s, loss=0.3140]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 161.19it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.88it/s]


Validation: {'precision': 0.7304009575082873, 'recall': 0.7654437127602213, 'f1': 0.7475118614833824, 'auc': 0.8176010354401329, 'prauc': 0.8173598116525277}
Test:      {'precision': 0.7286775631479089, 'recall': 0.768893069925466, 'f1': 0.7482453413546943, 'auc': 0.817410671816794, 'prauc': 0.8210334346312166}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 61.54it/s, loss=0.2625]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 143.18it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 145.72it/s]


Validation: {'precision': 0.7509829619896757, 'recall': 0.7187206020673605, 'f1': 0.734497671654438, 'auc': 0.814629781623771, 'prauc': 0.8131826651823014}
Test:      {'precision': 0.7486285898652836, 'recall': 0.727500783942529, 'f1': 0.7379134810037665, 'auc': 0.8136662594683671, 'prauc': 0.8163463242158825}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 56.44it/s, loss=0.2441]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 147.36it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 148.05it/s]


Validation: {'precision': 0.7639645776540737, 'recall': 0.7033552837858158, 'f1': 0.7324081582714457, 'auc': 0.8140377412784192, 'prauc': 0.8148023262785897}
Test:      {'precision': 0.7566835871378793, 'recall': 0.7011602383170237, 'f1': 0.7278645783382167, 'auc': 0.8127639805168951, 'prauc': 0.8154472601538022}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.737646001794736, 'recall': 0.7723424270907108, 'f1': 0.7545955832356224, 'auc': 0.8151142650425722, 'prauc': 0.810265223807879}
Corresponding test performance:
{'precision': 0.7325200833063596, 'recall': 0.7720288491665976, 'f1': 0.7517557201919922, 'auc': 0.8099996627653143, 'prauc': 0.8095337771573579}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 57.79it/s, loss=0.6846]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 162.83it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.60it/s]


Validation: {'precision': 0.7273676880197515, 'recall': 0.6550642834723893, 'f1': 0.6893251888738016, 'auc': 0.7709497130927776, 'prauc': 0.775855612855261}
Test:      {'precision': 0.7270261105434824, 'recall': 0.6723110692986131, 'f1': 0.6985988871526474, 'auc': 0.7714684859219618, 'prauc': 0.7751700780922268}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 63.83it/s, loss=0.5908]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.81it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.94it/s]


Validation: {'precision': 0.8063820612298129, 'recall': 0.5863907180916075, 'f1': 0.6790123408012912, 'auc': 0.8040223376797182, 'prauc': 0.8163705769709064}
Test:      {'precision': 0.7963743676189023, 'recall': 0.5923486986497575, 'f1': 0.6793742083764728, 'auc': 0.8030768890049915, 'prauc': 0.8187623218044859}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 62.74it/s, loss=0.5503]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.41it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 165.57it/s]


Validation: {'precision': 0.7405819295536277, 'recall': 0.7582314205056186, 'f1': 0.7493027529779192, 'auc': 0.8159731128783068, 'prauc': 0.828403001788468}
Test:      {'precision': 0.7225728884859556, 'recall': 0.7538413295680344, 'f1': 0.7378759925444837, 'auc': 0.8124367622047498, 'prauc': 0.8279090015892575}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 64.50it/s, loss=0.5220]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 166.10it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.71it/s]


Validation: {'precision': 0.8040141676473671, 'recall': 0.640639698963184, 'f1': 0.7130890002970588, 'auc': 0.8220510588454446, 'prauc': 0.8369158221807611}
Test:      {'precision': 0.7826086956492151, 'recall': 0.6491063029142392, 'f1': 0.7096331798199242, 'auc': 0.8134106154432346, 'prauc': 0.8292914633817089}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 62.98it/s, loss=0.5021]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 164.32it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 166.34it/s]


Validation: {'precision': 0.7280047718439367, 'recall': 0.7654437127602213, 'f1': 0.7462549629005857, 'auc': 0.8142447670114015, 'prauc': 0.8275144825555888}
Test:      {'precision': 0.7142440163435078, 'recall': 0.7673251803049003, 'f1': 0.739833706266464, 'auc': 0.8136769301778233, 'prauc': 0.8277940007175655}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 62.58it/s, loss=0.4829]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.13it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.95it/s]


Validation: {'precision': 0.6841835426796099, 'recall': 0.8369394794580215, 'f1': 0.7528913913811635, 'auc': 0.8182476329187092, 'prauc': 0.8283799057372544}
Test:      {'precision': 0.6833839918928992, 'recall': 0.8460332392573031, 'f1': 0.7560599642291533, 'auc': 0.8193493685909681, 'prauc': 0.8305607021080226}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 62.21it/s, loss=0.4697]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.29it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 161.11it/s]


Validation: {'precision': 0.6864012409496215, 'recall': 0.8325493885204374, 'f1': 0.7524443765093676, 'auc': 0.8162600150770597, 'prauc': 0.8239933560630007}
Test:      {'precision': 0.6797352342141555, 'recall': 0.8372530573821347, 'f1': 0.7503161394946842, 'auc': 0.8140875511577468, 'prauc': 0.8246936417176316}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 62.69it/s, loss=0.4552]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.14it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.92it/s]


Validation: {'precision': 0.7578668433230942, 'recall': 0.7174662903709079, 'f1': 0.7371133970632304, 'auc': 0.8195573055367251, 'prauc': 0.833409668396827}
Test:      {'precision': 0.7479884132579724, 'recall': 0.7287550956389817, 'f1': 0.7382465007164193, 'auc': 0.8173988434360289, 'prauc': 0.8314569612367593}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 63.01it/s, loss=0.4245]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.53it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.47it/s]


Validation: {'precision': 0.7516951888900495, 'recall': 0.7300094073354343, 'f1': 0.7406935998348584, 'auc': 0.8217515472171895, 'prauc': 0.8293385242957223}
Test:      {'precision': 0.7403138008301623, 'recall': 0.7249921605496238, 'f1': 0.7325728720577946, 'auc': 0.8165715614519614, 'prauc': 0.824770527450202}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 61.51it/s, loss=0.3983]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.42it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 160.90it/s]


Validation: {'precision': 0.7567475230585694, 'recall': 0.6945751019106473, 'f1': 0.7243296222141322, 'auc': 0.8108283651678728, 'prauc': 0.8176211355944156}
Test:      {'precision': 0.7590970350378737, 'recall': 0.7064910630269474, 'f1': 0.7318499219165222, 'auc': 0.8148901697095805, 'prauc': 0.818952428277414}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 59.19it/s, loss=0.3663]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 160.94it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.29it/s]


Validation: {'precision': 0.7418851435682213, 'recall': 0.745374725616979, 'f1': 0.743625835761036, 'auc': 0.819694904052289, 'prauc': 0.8266562720925392}
Test:      {'precision': 0.7431023911687826, 'recall': 0.7601128880502975, 'f1': 0.7515113885807007, 'auc': 0.8240495644387534, 'prauc': 0.8253117169384951}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6841835426796099, 'recall': 0.8369394794580215, 'f1': 0.7528913913811635, 'auc': 0.8182476329187092, 'prauc': 0.8283799057372544}
Corresponding test performance:
{'precision': 0.6833839918928992, 'recall': 0.8460332392573031, 'f1': 0.7560599642291533, 'auc': 0.8193493685909681, 'prauc': 0.8305607021080226}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 59.89it/s, loss=0.6721]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.83it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.88it/s]


Validation: {'precision': 0.727714748782082, 'recall': 0.7039824396340421, 'f1': 0.7156518917157008, 'auc': 0.7893688744421332, 'prauc': 0.7868081427777367}
Test:      {'precision': 0.7182284980721495, 'recall': 0.7017873941652499, 'f1': 0.7099127626431448, 'auc': 0.783397332423303, 'prauc': 0.7864249820614669}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 62.73it/s, loss=0.5791]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 153.61it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 149.88it/s]


Validation: {'precision': 0.7818659658310037, 'recall': 0.5597365945419889, 'f1': 0.6524122758364373, 'auc': 0.7886583757909403, 'prauc': 0.7911619079926978}
Test:      {'precision': 0.7723611699840087, 'recall': 0.5713389777341759, 'f1': 0.6568132611513553, 'auc': 0.7875477350664177, 'prauc': 0.7955695390725754}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 60.34it/s, loss=0.5494]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 148.00it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 157.67it/s]


Validation: {'precision': 0.7278849097008635, 'recall': 0.7456883035410923, 'f1': 0.7366790532388444, 'auc': 0.80644616140331, 'prauc': 0.8142552456934412}
Test:      {'precision': 0.7282608695629558, 'recall': 0.735340232045358, 'f1': 0.7317834245500257, 'auc': 0.8049975663735747, 'prauc': 0.8112030722659163}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 57.87it/s, loss=0.5313]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 154.49it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 149.17it/s]


Validation: {'precision': 0.8318666049058274, 'recall': 0.5631859517072336, 'f1': 0.6716529495584206, 'auc': 0.809306743613324, 'prauc': 0.8199334775011925}
Test:      {'precision': 0.811977102594399, 'recall': 0.5782376920646654, 'f1': 0.675457870596743, 'auc': 0.8042106015511788, 'prauc': 0.818654791859407}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 60.90it/s, loss=0.5195]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 148.88it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 151.03it/s]


Validation: {'precision': 0.7697205518189965, 'recall': 0.6823455628702341, 'f1': 0.723404250334848, 'auc': 0.8082970347545015, 'prauc': 0.8179885894525718}
Test:      {'precision': 0.7467532467506947, 'recall': 0.6851677641872526, 'f1': 0.7146361356446875, 'auc': 0.8031872704476212, 'prauc': 0.8192233721838432}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 59.08it/s, loss=0.4893]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 152.84it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 152.45it/s]


Validation: {'precision': 0.7692041522464733, 'recall': 0.6970837253035526, 'f1': 0.7313702861760002, 'auc': 0.8131421703631386, 'prauc': 0.8177208989330682}
Test:      {'precision': 0.7503382949906958, 'recall': 0.6955158356829868, 'f1': 0.7218877085931221, 'auc': 0.8016718787042336, 'prauc': 0.8092380576918117}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 59.05it/s, loss=0.4791]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 157.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 146.57it/s]


Validation: {'precision': 0.7205290396759042, 'recall': 0.7858262778275766, 'f1': 0.751762406886546, 'auc': 0.8165841929598391, 'prauc': 0.8216505027915164}
Test:      {'precision': 0.7136177673854074, 'recall': 0.7657572906843344, 'f1': 0.7387687138059725, 'auc': 0.8075803310637943, 'prauc': 0.8199174366701973}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 60.75it/s, loss=0.4399]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 149.15it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 156.45it/s]


Validation: {'precision': 0.72418096723643, 'recall': 0.7278143618666422, 'f1': 0.7259931135464, 'auc': 0.7956501286212045, 'prauc': 0.7968930030839658}
Test:      {'precision': 0.7228877679674562, 'recall': 0.7190341799914738, 'f1': 0.7209558195536564, 'auc': 0.7950191946936374, 'prauc': 0.8045031601588978}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 58.01it/s, loss=0.4307]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 156.93it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 150.79it/s]


Validation: {'precision': 0.717255717253587, 'recall': 0.7572906867332792, 'f1': 0.7367297082412186, 'auc': 0.8050203916076861, 'prauc': 0.8112927132202179}
Test:      {'precision': 0.7045387994122854, 'recall': 0.7544684854162607, 'f1': 0.728649298456102, 'auc': 0.8006912304385007, 'prauc': 0.8164654689863866}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 62.08it/s, loss=0.3915]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 158.78it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 148.45it/s]


Validation: {'precision': 0.780725712053138, 'recall': 0.6274694261504313, 'f1': 0.6957579922751614, 'auc': 0.8045359081888851, 'prauc': 0.8070570812330199}
Test:      {'precision': 0.7741194486953518, 'recall': 0.6340545625568076, 'f1': 0.697121181049477, 'auc': 0.8017204002321383, 'prauc': 0.812876971878644}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 60.26it/s, loss=0.3646]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 145.98it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 153.68it/s]


Validation: {'precision': 0.7127564674376665, 'recall': 0.7516462840992423, 'f1': 0.7316849766862746, 'auc': 0.79897349196498, 'prauc': 0.7999241146561251}
Test:      {'precision': 0.7040846312056889, 'recall': 0.7513327061751292, 'f1': 0.7269417425758795, 'auc': 0.7991836907265997, 'prauc': 0.8115236372062105}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 57.98it/s, loss=0.3328]
Running inference: 100%|██████████| 198/198 [00:01<00:00, 150.92it/s]
Running inference: 100%|██████████| 197/197 [00:01<00:00, 149.16it/s]

Validation: {'precision': 0.780024262026769, 'recall': 0.6048918156142838, 'f1': 0.6813846648055616, 'auc': 0.8015953490395584, 'prauc': 0.8081888938887773}
Test:      {'precision': 0.7781262250067499, 'recall': 0.6224521793646207, 'f1': 0.6916376257213827, 'auc': 0.8046613886921687, 'prauc': 0.8160369299024629}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7205290396759042, 'recall': 0.7858262778275766, 'f1': 0.751762406886546, 'auc': 0.8165841929598391, 'prauc': 0.8216505027915164}
Corresponding test performance:
{'precision': 0.7136177673854074, 'recall': 0.7657572906843344, 'f1': 0.7387687138059725, 'auc': 0.8075803310637943, 'prauc': 0.8199174366701973}





In [21]:
# Summarise every metric stored in `final_metrics` as "mean ± std".
# Presumably each entry holds one value per training run -- confirm against the
# cell that populates `final_metrics`.
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    # np.std is called with its default ddof=0, i.e. the population standard deviation.
    avg, spread = np.mean(run_values), np.std(run_values)
    print(f"{metric_name}: {avg:.4f} ± {spread:.4f}")


Final Metrics:
precision: 0.6981 ± 0.0219
recall: 0.8167 ± 0.0391
f1: 0.7517 ± 0.0074
auc: 0.8134 ± 0.0056
prauc: 0.8187 ± 0.0068
