In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]"],
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,
    early_stop_patience = 5,   
)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
if curr_task == "next_diag_6m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl"
elif curr_task == "next_diag_12m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"

Current task: stay


In [6]:
ehr_full_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks
tokenizer = EHRTokenizer(token_type_sentences, age_gender_sentences, task_sentences, diag_sentences, 
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
train_dataset = FineTuneEHRDataset(train_data, tokenizer, token_type=config.token_type, task=curr_task)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, token_type=config.token_type, task=curr_task)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, token_type=config.token_type, task=curr_task)

In [10]:
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=True,
    batch_size=config.batch_size,
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

In [11]:
if curr_task in ["death", "stay", "readmission"]:
    eval_metric = "prauc"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [12]:
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age/Sex IDs shape:", age_gender_ids.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv
from heterogt.model.layer import TransformerEncoder

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.act = nn.GELU()

        # 第1层
        self.conv1 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')

        # 第2层
        self.conv2 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')
        self.lin = nn.Linear(d_model, d_model)

    def forward(self, hg):
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # 第1层：HeteroConv → Linear → GELU → Dropout
        x_dict = self.conv1(x_dict, hg.edge_index_dict)

        # 第2层：HeteroConv → Linear（末层通常不再加激活/随你需要）
        x_dict = self.conv2(x_dict, hg.edge_index_dict)
        x_dict = {k: self.lin(v) for k, v in x_dict.items()}

        return x_dict  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# multi-class classification task
class MultiPredictionHead(nn.Module):
    def __init__(self, hidden_size, label_size):
        super(MultiPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, label_size)
            )

    def forward(self, input):
        return self.cls(input)
    
class BinaryPredictionHead(nn.Module):
    def __init__(self, hidden_size):
        super(BinaryPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, 1)
            )
    def forward(self, input):
        return self.cls(input)

In [16]:
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
exp_i = i
id_seq = torch.concat([train_dataset[exp_i][0][0], torch.zeros(5, dtype=train_dataset[exp_i][0][0].dtype)], dim=0)
type_seq = torch.concat([train_dataset[exp_i][1][0], torch.zeros(5, dtype=train_dataset[exp_i][1][0].dtype)], dim=0)
visit_seq = torch.concat([train_dataset[exp_i][2][0], torch.zeros(5, dtype=train_dataset[exp_i][2][0].dtype)], dim=0)
age_sex = torch.concat([train_dataset[exp_i][3][0], torch.zeros(3, dtype=train_dataset[exp_i][3][0].dtype)], dim=0)

tensor([[17, 17, 17, 23]])


In [17]:
class HeteroGT(nn.Module):
    def __init__(self, tokenizer, d_model, num_heads, layer_types, max_num_adms, device, task, label_vocab_size):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_gender_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_attn_heads = num_heads
        self.layer_types = layer_types
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0] #0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0] #0
        self.adm_pad_id = 0
        self.age_gender_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0] #0
        self.node_type_id_dict = {'diag': 1, 'med': 2, 'lab': 3, 'pro': 4, 'visit': 5}
        self.graph_node_types = ['diag']
        
        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id) # already contains [PAD]
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id) # n_type already have PAD, + 1 for visit
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_gender_emb = nn.Embedding(self.age_gender_vocab_size, d_model, padding_idx=self.age_gender_pad_id) # already contains [PAD]
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)  # 5 task in total, task embedding
        
        # stack together
        self.stack_layers = nn.ModuleList(self.make_gnn_layer() if layer_type == 'gnn' else self.make_tf_layer()
            for layer_type in self.layer_types
        )

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, label_vocab_size)

    def make_tf_layer(self):
        assert self.d_model % self.num_attn_heads == 0, "Invalid model and attention head dimensions"
        layer_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_attn_heads, batch_first=True, norm_first=True)
        tf_wrapper = nn.TransformerEncoder(layer_layer, num_layers=1, enable_nested_tensor=False)
        return tf_wrapper

    def make_gnn_layer(self):
        return DiseaseOccHetGNN(d_model=self.d_model, heads=self.num_attn_heads)
    
    def forward(self, input_ids, token_types, adm_index, age_gender_ids, task_id):
        """Forward pass for the model.

        Args:
            input_ids (Tensor): Input token IDs. Shape of [B, L]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            adm_index (Tensor): Admission index IDs. Shape of [B, L]
            age_gender_ids (Tensor): Age and gender IDs. Shape of [B, V]
            task_id (Tensor): Task ID. Shape of [1]

        Returns:
            Tensor: Output logits. Shape of [B, label_size]
        """
        B, L = input_ids.shape
        V = age_gender_ids.shape[1]
        num_visits = adm_index.max(dim=1).values
        
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device) # [1] -> [B]
        # 基础表示
        task_id_embed = self.task_emb(task_id).unsqueeze(1) # [B, 1, d]
        seq_embed = self.token_emb(input_ids)  # [B, L, d]
        visit_embed = self.age_gender_emb(age_gender_ids) #[B, V, d], the virtual node representation in gnn

        # run through layers
        for i, layer_type in enumerate(self.layer_types):
            if layer_type == 'gnn':
                hg_batch = self.build_graph_batch(seq_embed, token_types, self.graph_node_types, visit_embed, adm_index) # num_visits is a 1d tensor of [B]
                gnn_out = self.stack_layers[i](hg_batch)['visit']  # extract virtual visit node representations
                visit_embed = self.process_gnn_out(gnn_out, num_visits, V) # [B, V, d]
            elif layer_type == 'tf':
                x, src_key_padding_mask, attn_mask = self.prepare_tf_input(task_id_embed, seq_embed, visit_embed, i, input_ids, 
                                                                            adm_index, token_types, num_visits)
                h = self.stack_layers[i](src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask) # [B, 1+L+V, d]
                task_id_embed, seq_embed, visit_embed = self.process_tf_out(h, L, V) # # [B, 1, d], [B, L, d], [B, V, d]
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

        logits = self.cls_head(task_id_embed.squeeze())  # [B, label_size]
        return logits

    def build_graph_batch(self, seq_embed, token_types, graph_node_types, visit_embed, adm_index):
        """Build a batch of heterogeneous graphs from the input sequences.

        Args:
            seq_embed (Tensor): Sequence embeddings. Shape of [B, L, d]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            graph_node_types: a list controls what types of tokens are connected to the virtual visit nodes. e.g. ['diag']
            visit_embed (Tensor): Visit embeddings. Shape of [B, V, d]
        Returns:
            A batch of heterogeneous graphs.
        """
        B, L = seq_embed.shape[0], seq_embed.shape[1]
        V = visit_embed.shape[1]
        graph_node_type_ids = [self.node_type_id_dict[t] for t in graph_node_types]
        graphs = [] # contains heterogeneous graphs for each patient
        for p in range(B):
            hg_p = self.build_patient_graph(seq_embed[p], token_types[p], visit_embed[p], adm_index[p], graph_node_type_ids)
            graphs.append(hg_p)
        hg_batch = HeteroBatch.from_data_list(graphs).to(self.device)
        return hg_batch

    def build_patient_graph(self, seq_embed_p, token_types_p, visit_embed_p, adm_index_p, graph_node_type_ids):
        """Build a heterogeneous graph for a single patient.

        Args:
            seq_embed_p (Tensor): Sequence embeddings for patient p. Shape [L, d]
            token_types_p (Tensor): Token type IDs for patient p. Shape [L]
            visit_embed_p (Tensor): Visit embeddings for patient p. Shape [V, d]
            graph_node_type_ids (list): List of graph node type IDs that the graph uses.
            adm_index_p (Tensor): Admission index for patient p. Shape [L]

        Returns:
            A heterogeneous graph for patient p.
        """
        hg = HeteroData()
        occ_mask = torch.isin(token_types_p, torch.tensor(graph_node_type_ids, device=token_types_p.device)) # [L], a mask for the token types needed in the graph
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1) # [L], seq position index for the token types needed in the graph
        num_occ = occ_pos.numel() # int, number of occurrences of the token types needed in the graph
        
        # build visit virtual nodes
        nonpad = adm_index_p != self.adm_pad_id
        adm_index_used_p = adm_index_p[nonpad] # adm_index非pad部分
        adm_ids_unique, adm_lid_nonpad = torch.unique(adm_index_used_p, return_inverse=True)
        num_visit_p = adm_ids_unique.numel()  # int, number of visits for patient
        adm_lid_full = torch.full_like(token_types_p, fill_value=-1) # [L]
        adm_lid_full[nonpad] = adm_lid_nonpad
        hg['visit'].x = visit_embed_p[:num_visit_p, :]
        hg['visit'].num_nodes = num_visit_p
        
        # build medical code nodes
        gid_occ_embed = seq_embed_p[occ_pos, :]
        hg['occ'].x = gid_occ_embed
        hg['occ'].num_nodes = num_occ

        # build edges between occ nodes and virtual visit nodes
        occ_adm_lid = adm_lid_full[occ_pos]
        assert (occ_adm_lid != -1).all(), "occ_adm_lid contains -1"
        e_v2o = torch.stack([occ_adm_lid, torch.arange(num_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(num_occ, device=self.device), occ_adm_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v
        
        # build forward edges between virtual visit nodes
        if num_visit_p > 1:
            src = torch.arange(0, num_visit_p - 1, device=self.device)
            dst = torch.arange(1, num_visit_p, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, num_visit_p-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg

    def process_gnn_out(self, gnn_out, num_visits, V):
        """Process the output of the GNN layer.

        Args:
            gnn_out (Tensor): The output of the GNN layer. Shape [sum(num_visits), d]
            num_visits (Tensor): A tensor containing the number of visits for each patient.
            V (int): The maximum number of visits.

        Returns:
            Tensor: The processed visit embeddings. Shape [B, V, d]
        """
        B = len(num_visits)
        # 计算每个批次的累积偏移量
        cumsum = torch.cumsum(num_visits, dim=0)  # [B]
        offsets = torch.cat([torch.tensor([0], device=self.device), cumsum[:-1]])  # [B]

        # 创建索引以从 gnn_out 中提取所有批次的嵌入
        indices = torch.arange(sum(num_visits), device=self.device)  # [N]
        batch_indices = torch.repeat_interleave(torch.arange(B, device=self.device), num_visits)  # [N]
        visit_pos = indices - offsets[batch_indices]  # [N]，每个嵌入的相对位置

        # 创建目标张量 visit_emb_pad，初始化为零
        visit_emb_pad = torch.zeros(B, V, self.d_model, device=self.device, dtype=gnn_out.dtype)  # [B, V, d]

        # 创建掩码，选择有效位置 (visit_pos < V 且 visit_pos < num_visits)
        mask = (visit_pos < V) & (visit_pos < num_visits[batch_indices])  # [N]
        valid_indices = indices[mask]  # [N_valid]
        valid_batch_indices = batch_indices[mask]  # [N_valid]
        valid_visit_pos = visit_pos[mask]  # [N_valid]

        # 使用 scatter 将 gnn_out 的值分配到 visit_emb_pad
        visit_emb_pad[valid_batch_indices, valid_visit_pos] = gnn_out[valid_indices]
        return visit_emb_pad
    
    def prepare_tf_input(self, task_id_embed, seq_embed, visit_embed, layer_i, input_ids, adm_index, token_types, num_visits):
        """Prepare the input for the Transformer layer.
        Args:
            task_id_emb (Tensor): Task ID embeddings. Shape [B, 1, d]
            seq_embed (Tensor): Sequence embeddings. Shape [B, L, d]
            visit_embed (Tensor): Visit embeddings. Shape [B, V, d]
            layer_i (int): The current layer index.
            adm_index (tensor): The admission index. Shape [B, L]
            token_types (Tensor): Token types. Shape [B, L]

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Processed inputs for the Transformer layer.
        """
        
        B, L, d = seq_embed.shape
        V = visit_embed.shape[1]

        # Part 1: prepare main seq embedding x
        # important: initiate new tensor to ensure safe autograd
        x = torch.empty(B, 1 + L + V, d, device=seq_embed.device, dtype=seq_embed.dtype)
        x[:, 0:1, :] = task_id_embed
        x[:, 1:1 + L, :] = seq_embed
        x[:, 1 + L:, :] = visit_embed
        
        # we already have token_types for main seq, just prepare token types for visit nodes
        # here it is out of the if branch because it is needed in the mask making
        arange_V = torch.arange(1, V + 1, device=self.device, dtype=torch.long)[None, :]  # [1, V]
        n_v = num_visits.view(B, 1)  # [B, 1]
        visit_index = torch.where(arange_V <= n_v, arange_V, torch.full((B, V), self.adm_pad_id, device=self.device, dtype=torch.long))  # [B, V]
        visit_type_id = torch.full((B, V), self.node_type_id_dict['visit'], dtype=torch.long, device=self.device)  # [B, V]
        visit_type_id_mask = (visit_index != self.adm_pad_id).long() # [B, V]
        visit_type_id = visit_type_id * visit_type_id_mask # [B, V]
        token_types = torch.cat([token_types, visit_type_id], dim=1)  # [B, L+V]

        # if it is the first time transformer going through, we need extra information of admission index and token types
        if (layer_i == 0) or (layer_i == 1 and self.layer_types[0] == 'gnn'):
            adm_index = torch.cat([adm_index, visit_index], dim=1)  # [B, L+V]
            # transform into embedding and add
            adm_index_embed = self.adm_index_emb(adm_index)
            token_type_embed = self.type_emb(token_types)
            x_non_task = x[:, 1:, :]
            x_non_task.add_(adm_index_embed).add_(token_type_embed)
            x[:, 1:, :] = x_non_task
        else:
            x = x
            
        # part 2: prepare mask (src_key_padding_mask and attn_mask)
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device) # [B, 1]
        seq_pad_mask = (input_ids == self.seq_pad_id) # [B, L], bool
        visit_pad_mask = (visit_index == self.adm_pad_id) # [B, V], bool
        src_key_padding_mask = torch.cat([task_pad_mask, seq_pad_mask, visit_pad_mask], dim=1)  # [B, 1+L+V]
        attn_mask = self.build_attn_mask(torch.cat([torch.full((B, 1), -1, device=self.device), token_types], dim=1), forbid_map=None, num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"
        return x, src_key_padding_mask, attn_mask
    
    def process_tf_out(self, h, L, V):
        return h[:, 0:1, :], h[:, 1:1 + L, :], h[:, 1 + L:, :]  # [B, 1, d], [B, L, d], [B, V, d]
        
    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        B, L = token_types.shape
        device = token_types.device
        
        if forbid_map == None:
            mask = torch.zeros((B, L, L), dtype=torch.bool, device=device)
        else:
            # 收集所有出现的 token 类型
            observed = torch.unique(token_types)
            for q_t, ks in forbid_map.items():
                observed = torch.unique(torch.cat([observed, torch.tensor([q_t] + list(ks), device=device)]))
            type_list = observed.sort().values
            t2i = {t.item(): i for i, t in enumerate(type_list)}  # Map token types to indices
            T = len(type_list)

            # 构造禁止矩阵 (T, T)，单向关系
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                if q_t in t2i:
                    qi = t2i[q_t]
                    for k_t in ks:
                        if k_t in t2i:
                            ban_table[qi, t2i[k_t]] = True  # 只设置 q -> k 的禁止

            # 向量化映射 token_types 到类型索引
            mapping = torch.zeros_like(type_list, dtype=torch.long, device=device)
            for t, i in t2i.items():
                mapping[type_list == t] = i
            q_idx = mapping[torch.searchsorted(type_list, token_types.unsqueeze(-1))]
            k_idx = mapping[torch.searchsorted(type_list, token_types.unsqueeze(-2))]

            # 查询 ban_table 得到 (B, L, L)
            mask = ban_table[q_idx, k_idx].to(torch.bool)
        
        # 扩展到 num_heads
        mask = mask.unsqueeze(1).expand(B, num_heads, L, L)
        mask = mask.reshape(B * num_heads, L, L)
        return mask

In [18]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['gnn', 'tf', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 26.99it/s, loss=0.6870]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.10it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.93it/s]


Validation: {'precision': 0.6437924345280218, 'recall': 0.8325493885204374, 'f1': 0.7261041930833033, 'auc': 0.756861614475183, 'prauc': 0.7314555103458248}
Test:      {'precision': 0.6397500600801737, 'recall': 0.8347444339892294, 'f1': 0.7243537365820711, 'auc': 0.7450262665553289, 'prauc': 0.7220247910309557}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.24it/s, loss=0.5978]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.06it/s]


Validation: {'precision': 0.6728849185478905, 'recall': 0.8030730636538004, 'f1': 0.7322373074027836, 'auc': 0.7925261047844596, 'prauc': 0.8077377834991725}
Test:      {'precision': 0.6614173228329097, 'recall': 0.7902163687651609, 'f1': 0.7201028668761388, 'auc': 0.7835963008878333, 'prauc': 0.8026325372736457}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.58it/s, loss=0.5438]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.33it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.27it/s]


Validation: {'precision': 0.7745706978415252, 'recall': 0.6647851991198972, 'f1': 0.7154910513884695, 'auc': 0.8160437960149983, 'prauc': 0.8204824995825031}
Test:      {'precision': 0.7545293072797352, 'recall': 0.6660395108163498, 'f1': 0.7075283094740732, 'auc': 0.8017717404380125, 'prauc': 0.8142003155344855}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.15it/s, loss=0.5230]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.77it/s]


Validation: {'precision': 0.750895473784595, 'recall': 0.7231106930049448, 'f1': 0.7367412090569309, 'auc': 0.8195792087688486, 'prauc': 0.8221592157973597}
Test:      {'precision': 0.7480694980670912, 'recall': 0.7290686735630948, 'f1': 0.7384468744648945, 'auc': 0.8143950387241052, 'prauc': 0.8263494259189552}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.46it/s, loss=0.4731]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.00it/s]


Validation: {'precision': 0.6884447860602397, 'recall': 0.8425838820920584, 'f1': 0.7577552121946224, 'auc': 0.8250878214110925, 'prauc': 0.8351096417501299}
Test:      {'precision': 0.6783431347464118, 'recall': 0.831922232672211, 'f1': 0.7473239387115712, 'auc': 0.8190295996420279, 'prauc': 0.8341726642396683}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.34it/s, loss=0.4360]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.79it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.53it/s]


Validation: {'precision': 0.7672872340400024, 'recall': 0.7237378488531712, 'f1': 0.7448765481727507, 'auc': 0.8321783397329956, 'prauc': 0.8420200981074151}
Test:      {'precision': 0.7667997338630512, 'recall': 0.7227971150808317, 'f1': 0.744148501862332, 'auc': 0.8301147050933007, 'prauc': 0.8427487468335182}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.92it/s, loss=0.4123]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.65it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.64it/s]


Validation: {'precision': 0.6957210776526791, 'recall': 0.825964252114061, 'f1': 0.7552688122387651, 'auc': 0.821758429654806, 'prauc': 0.8328849547522609}
Test:      {'precision': 0.6884249471440581, 'recall': 0.8168704923147794, 'f1': 0.7471676416712909, 'auc': 0.8181531921276337, 'prauc': 0.8345096316731729}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.09it/s, loss=0.3904]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.44it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.61it/s]


Validation: {'precision': 0.7331737959894311, 'recall': 0.7685794920013529, 'f1': 0.7504592724040332, 'auc': 0.8232297642016588, 'prauc': 0.8322767578748171}
Test:      {'precision': 0.7338563669259691, 'recall': 0.7626215114432029, 'f1': 0.7479624738554598, 'auc': 0.8203744613682568, 'prauc': 0.8334315820749832}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.36it/s, loss=0.3637]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.64it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.15it/s]


Validation: {'precision': 0.7240981240960344, 'recall': 0.7867670115999161, 'f1': 0.7541328474259331, 'auc': 0.8217928418428897, 'prauc': 0.8280173468259712}
Test:      {'precision': 0.7240476882793718, 'recall': 0.780809031041766, 'f1': 0.7513578706837842, 'auc': 0.8234740507976104, 'prauc': 0.8359447322705349}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.46it/s, loss=0.3448]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.24it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.13it/s]


Validation: {'precision': 0.7384568651253388, 'recall': 0.7623079335190897, 'f1': 0.7501928664693999, 'auc': 0.8221515323872945, 'prauc': 0.8340242752713076}
Test:      {'precision': 0.7311504956421414, 'recall': 0.7632486672914292, 'f1': 0.7468548584550627, 'auc': 0.8172349071119775, 'prauc': 0.8331543630271326}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 28.61it/s, loss=0.3231]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.16it/s]


Validation: {'precision': 0.6929217668954419, 'recall': 0.8165569143906662, 'f1': 0.7496761142195235, 'auc': 0.8131560859486846, 'prauc': 0.8207518497081199}
Test:      {'precision': 0.693964371176033, 'recall': 0.8184383819353451, 'f1': 0.7510791317223544, 'auc': 0.8151422903872512, 'prauc': 0.8280447116813429}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7672872340400024, 'recall': 0.7237378488531712, 'f1': 0.7448765481727507, 'auc': 0.8321783397329956, 'prauc': 0.8420200981074151}
Corresponding test performance:
{'precision': 0.7667997338630512, 'recall': 0.7227971150808317, 'f1': 0.744148501862332, 'auc': 0.8301147050933007, 'prauc': 0.8427487468335182}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.57it/s, loss=0.6862]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.24it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.98it/s]


Validation: {'precision': 0.8233650083799701, 'recall': 0.4619002822186833, 'f1': 0.5918039327173215, 'auc': 0.7736148240271222, 'prauc': 0.7909020338803208}
Test:      {'precision': 0.824411134899227, 'recall': 0.48291000313426496, 'f1': 0.6090567483543982, 'auc': 0.7697587060657446, 'prauc': 0.7876225339272951}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.25it/s, loss=0.5777]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.39it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.10it/s]


Validation: {'precision': 0.8864902506902056, 'recall': 0.3991846973960515, 'f1': 0.5504864822024136, 'auc': 0.8060733543262751, 'prauc': 0.8217608776486537}
Test:      {'precision': 0.8851035404082492, 'recall': 0.41549074944993575, 'f1': 0.5655142935581282, 'auc': 0.8039277774170541, 'prauc': 0.8194772800791071}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.44it/s, loss=0.5645]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.48it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.01it/s]


Validation: {'precision': 0.7861024844689981, 'recall': 0.6349952963291471, 'f1': 0.7025151728511457, 'auc': 0.8118706274542546, 'prauc': 0.8194397443675241}
Test:      {'precision': 0.7710298000725349, 'recall': 0.6409532768872971, 'f1': 0.6999999950400363, 'auc': 0.8052551734065786, 'prauc': 0.817061147700396}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.58it/s, loss=0.5183]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.19it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.01it/s]


Validation: {'precision': 0.7313387730410054, 'recall': 0.7588585763538449, 'f1': 0.7448445625586618, 'auc': 0.8147214134939382, 'prauc': 0.8208960295413035}
Test:      {'precision': 0.7219298245592927, 'recall': 0.7742238946353898, 'f1': 0.747162954604392, 'auc': 0.8189720684113276, 'prauc': 0.8270063022622547}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.49it/s, loss=0.4849]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.77it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.10it/s]


Validation: {'precision': 0.7851420247604328, 'recall': 0.6760740043879709, 'f1': 0.726537484494659, 'auc': 0.8256758428147702, 'prauc': 0.8327173641328465}
Test:      {'precision': 0.7705210563855442, 'recall': 0.6770147381603104, 'f1': 0.7207477833676482, 'auc': 0.8224512733629394, 'prauc': 0.8329315077159989}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.49it/s, loss=0.4448]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.38it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.30it/s]


Validation: {'precision': 0.7510772290329762, 'recall': 0.7105675760404184, 'f1': 0.7302610327069335, 'auc': 0.8101638331620762, 'prauc': 0.8160163880678859}
Test:      {'precision': 0.7395561357678213, 'recall': 0.7105675760404184, 'f1': 0.7247721043871741, 'auc': 0.8087568271148767, 'prauc': 0.8169699100852921}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.68it/s, loss=0.4330]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.63it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.23it/s]


Validation: {'precision': 0.7172818791926251, 'recall': 0.804327375350253, 'f1': 0.7583148508899525, 'auc': 0.8293437299335498, 'prauc': 0.8372529437521848}
Test:      {'precision': 0.7223311852039161, 'recall': 0.8084038883637241, 'f1': 0.7629476126519056, 'auc': 0.8304370410525346, 'prauc': 0.8392810329892688}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.39it/s, loss=0.4254]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.27it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.96it/s]


Validation: {'precision': 0.8008333333299965, 'recall': 0.6026967701454917, 'f1': 0.6877795621038041, 'auc': 0.8069532513690274, 'prauc': 0.8184486208031045}
Test:      {'precision': 0.7917189460443675, 'recall': 0.5936030103462101, 'f1': 0.6784946187557426, 'auc': 0.804570133992904, 'prauc': 0.8178082388009109}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.40it/s, loss=0.3736]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.28it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.17it/s]


Validation: {'precision': 0.811889596599525, 'recall': 0.5995609909043601, 'f1': 0.6897546848653519, 'auc': 0.8205274278200587, 'prauc': 0.8151028646479568}
Test:      {'precision': 0.8097381342029062, 'recall': 0.6205707118199418, 'f1': 0.7026451220155517, 'auc': 0.8237012563753714, 'prauc': 0.8232290490942189}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.53it/s, loss=0.3592]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.62it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.67it/s]


Validation: {'precision': 0.7850098619298422, 'recall': 0.6240200689851866, 'f1': 0.6953179545317454, 'auc': 0.8016193119792896, 'prauc': 0.7937613599943104}
Test:      {'precision': 0.772585669778923, 'recall': 0.6221386014405076, 'f1': 0.6892478672114205, 'auc': 0.7982588623513713, 'prauc': 0.7939129466357512}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.21it/s, loss=0.3098]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.27it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.35it/s]


Validation: {'precision': 0.7440381558004959, 'recall': 0.7337723424247922, 'f1': 0.7388695876723638, 'auc': 0.8122571491697519, 'prauc': 0.8097990898339629}
Test:      {'precision': 0.7342902711300342, 'recall': 0.7218563813084922, 'f1': 0.7280202353522685, 'auc': 0.8070417622374666, 'prauc': 0.8094891109641273}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 29.61it/s, loss=0.2867]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.20it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.49it/s]


Validation: {'precision': 0.7805237919558078, 'recall': 0.6635308874234446, 'f1': 0.7172881306236076, 'auc': 0.8165544527914514, 'prauc': 0.8134237480665958}
Test:      {'precision': 0.7790697674390296, 'recall': 0.6723110692986131, 'f1': 0.7217640078170824, 'auc': 0.8152857912960232, 'prauc': 0.8146408267799495}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7172818791926251, 'recall': 0.804327375350253, 'f1': 0.7583148508899525, 'auc': 0.8293437299335498, 'prauc': 0.8372529437521848}
Corresponding test performance:
{'precision': 0.7223311852039161, 'recall': 0.8084038883637241, 'f1': 0.7629476126519056, 'auc': 0.8304370410525346, 'prauc': 0.8392810329892688}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.08it/s, loss=0.6701]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.04it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.37it/s]


Validation: {'precision': 0.7981308411177658, 'recall': 0.5355910943852756, 'f1': 0.6410208246152458, 'auc': 0.7907395345000522, 'prauc': 0.7973553807153831}
Test:      {'precision': 0.7877713779318221, 'recall': 0.5575415490731969, 'f1': 0.652956293344552, 'auc': 0.7893809830844088, 'prauc': 0.7974087057993914}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 28.92it/s, loss=0.5739]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.90it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.79it/s]


Validation: {'precision': 0.7235724743756264, 'recall': 0.774851050483616, 'f1': 0.748334337826121, 'auc': 0.8161243757955621, 'prauc': 0.8288430788405463}
Test:      {'precision': 0.7184073860337034, 'recall': 0.780809031041766, 'f1': 0.7483095367043849, 'auc': 0.8148323364776693, 'prauc': 0.8300032669021539}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.46it/s, loss=0.5328]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.86it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.66it/s]


Validation: {'precision': 0.7789272030621497, 'recall': 0.6375039197220523, 'f1': 0.7011553666632225, 'auc': 0.8079005159215901, 'prauc': 0.822820100285046}
Test:      {'precision': 0.7807909604490366, 'recall': 0.6500470366865787, 'f1': 0.7094455802549257, 'auc': 0.8128471315169974, 'prauc': 0.8293071948101607}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.23it/s, loss=0.5127]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.64it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.01it/s]


Validation: {'precision': 0.7678072625671516, 'recall': 0.6895578551248368, 'f1': 0.7265818552466081, 'auc': 0.808829544526307, 'prauc': 0.8193313950109768}
Test:      {'precision': 0.759892689467606, 'recall': 0.7105675760404184, 'f1': 0.7344028470531567, 'auc': 0.8139421375746383, 'prauc': 0.8287256319619588}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.84it/s, loss=0.4828]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.51it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.10it/s]


Validation: {'precision': 0.8122192273099182, 'recall': 0.5669488867965916, 'f1': 0.6677746950633318, 'auc': 0.8040759905510663, 'prauc': 0.8182500444627878}
Test:      {'precision': 0.8098681412130588, 'recall': 0.5970523675114549, 'f1': 0.6873646160505501, 'auc': 0.8083876809679341, 'prauc': 0.8236225106552842}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.16it/s, loss=0.4681]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.32it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.20it/s]


Validation: {'precision': 0.7303785780217594, 'recall': 0.7441204139205264, 'f1': 0.7371854563215368, 'auc': 0.8053327136125271, 'prauc': 0.8152410617471375}
Test:      {'precision': 0.7205266307578679, 'recall': 0.755095641264487, 'f1': 0.737406211506381, 'auc': 0.8081324396110828, 'prauc': 0.8208997784658141}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.57it/s, loss=0.4467]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.29it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.87it/s]


Validation: {'precision': 0.7549420586204671, 'recall': 0.6945751019106473, 'f1': 0.7235015465333382, 'auc': 0.8053069923858136, 'prauc': 0.8212585065884452}
Test:      {'precision': 0.7418289864617306, 'recall': 0.7046095954822684, 'f1': 0.7227404260077407, 'auc': 0.8063033189429756, 'prauc': 0.8253664093024937}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7235724743756264, 'recall': 0.774851050483616, 'f1': 0.748334337826121, 'auc': 0.8161243757955621, 'prauc': 0.8288430788405463}
Corresponding test performance:
{'precision': 0.7184073860337034, 'recall': 0.780809031041766, 'f1': 0.7483095367043849, 'auc': 0.8148323364776693, 'prauc': 0.8300032669021539}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.38it/s, loss=0.6677]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.92it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.07it/s]


Validation: {'precision': 0.766809172169581, 'recall': 0.6186892442752628, 'f1': 0.6848316507298817, 'auc': 0.7841575127734526, 'prauc': 0.7976663059641764}
Test:      {'precision': 0.7581448830940661, 'recall': 0.6202571338958286, 'f1': 0.6823042378900404, 'auc': 0.7836591674732616, 'prauc': 0.8015278716454809}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.10it/s, loss=0.5777]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.05it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.46it/s]


Validation: {'precision': 0.7450761798560198, 'recall': 0.628723737846884, 'f1': 0.6819727841491919, 'auc': 0.7833941650392464, 'prauc': 0.7992344613689741}
Test:      {'precision': 0.7311827956962137, 'recall': 0.6183756663511497, 'f1': 0.670064555005475, 'auc': 0.7788514088608162, 'prauc': 0.7989141013420193}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 27.01it/s, loss=0.5726]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.01it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.77it/s]


Validation: {'precision': 0.6857142857124547, 'recall': 0.8052681091225925, 'f1': 0.740698004836761, 'auc': 0.8027365777646627, 'prauc': 0.8124239430228842}
Test:      {'precision': 0.6838487972490515, 'recall': 0.8112260896807425, 'f1': 0.7421112973864309, 'auc': 0.8048231103406724, 'prauc': 0.8207017070610059}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.12it/s, loss=0.5261]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.48it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.78it/s]


Validation: {'precision': 0.768432355043778, 'recall': 0.6732518030709526, 'f1': 0.717700145445602, 'auc': 0.8118998150181621, 'prauc': 0.8236194613917156}
Test:      {'precision': 0.7713371265975415, 'recall': 0.6801505174014421, 'f1': 0.7228795150973112, 'auc': 0.8153434231937937, 'prauc': 0.8300476992008331}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.88it/s, loss=0.4930]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.64it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.80it/s]


Validation: {'precision': 0.7342589602817751, 'recall': 0.7130761994333237, 'f1': 0.7235125626093322, 'auc': 0.8002231316417406, 'prauc': 0.8038555837646456}
Test:      {'precision': 0.7306472081195093, 'recall': 0.7221699592326053, 'f1': 0.7263838461254616, 'auc': 0.8033905172626409, 'prauc': 0.813726242839385}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.32it/s, loss=0.4639]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.62it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.79it/s]


Validation: {'precision': 0.715382253605418, 'recall': 0.7306365631836607, 'f1': 0.7229289431832343, 'auc': 0.7973586309635945, 'prauc': 0.8090624763752505}
Test:      {'precision': 0.7158590308347519, 'recall': 0.7133897773574369, 'f1': 0.7146222660830538, 'auc': 0.788799932754397, 'prauc': 0.806624946189753}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.25it/s, loss=0.4335]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.66it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.58it/s]


Validation: {'precision': 0.7530446549365594, 'recall': 0.6980244590758921, 'f1': 0.7244914514735044, 'auc': 0.8101902577035828, 'prauc': 0.8177918596081881}
Test:      {'precision': 0.7455581629207323, 'recall': 0.6973973032276658, 'f1': 0.7206740066688213, 'auc': 0.8083316597432888, 'prauc': 0.8189561623507442}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.13it/s, loss=0.4090]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.37it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.80it/s]


Validation: {'precision': 0.6549988012451331, 'recall': 0.8566948886771506, 'f1': 0.7423912994348185, 'auc': 0.80840393860303, 'prauc': 0.820129507253571}
Test:      {'precision': 0.6594412331390669, 'recall': 0.8585763562218295, 'f1': 0.7459474136918293, 'auc': 0.8109254978112462, 'prauc': 0.8214286962675373}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.14it/s, loss=0.3825]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.19it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.67it/s]


Validation: {'precision': 0.7025741466124719, 'recall': 0.7873941674481424, 'f1': 0.7425698604583373, 'auc': 0.8110702552198769, 'prauc': 0.8141630057252044}
Test:      {'precision': 0.7101901788228493, 'recall': 0.784571966131124, 'f1': 0.745530388335547, 'auc': 0.8112741078758392, 'prauc': 0.8183374981469732}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.768432355043778, 'recall': 0.6732518030709526, 'f1': 0.717700145445602, 'auc': 0.8118998150181621, 'prauc': 0.8236194613917156}
Corresponding test performance:
{'precision': 0.7713371265975415, 'recall': 0.6801505174014421, 'f1': 0.7228795150973112, 'auc': 0.8153434231937937, 'prauc': 0.8300476992008331}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.52it/s, loss=0.6734]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.29it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.11it/s]


Validation: {'precision': 0.820526893519097, 'recall': 0.4687989965491728, 'f1': 0.596687283336191, 'auc': 0.7762435635393172, 'prauc': 0.7807777693245315}
Test:      {'precision': 0.7942534633104451, 'recall': 0.4854186265271702, 'f1': 0.6025690883211855, 'auc': 0.7670400408305638, 'prauc': 0.7760180780521471}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.38it/s, loss=0.5985]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.20it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.31it/s]


Validation: {'precision': 0.7199131513625313, 'recall': 0.7278143618666422, 'f1': 0.7238421905382004, 'auc': 0.7925342933781204, 'prauc': 0.8009219052379647}
Test:      {'precision': 0.7072578196152225, 'recall': 0.7303229852595474, 'f1': 0.71860536371243, 'auc': 0.7883768793912864, 'prauc': 0.8014335812474234}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 28.84it/s, loss=0.5678]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 52.96it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.27it/s]


Validation: {'precision': 0.7346128391769206, 'recall': 0.6961429915312132, 'f1': 0.7148607259625107, 'auc': 0.7946769418948445, 'prauc': 0.8067465746629442}
Test:      {'precision': 0.7223642172500884, 'recall': 0.7089996864198527, 'f1': 0.7156195550551422, 'auc': 0.7879594130505797, 'prauc': 0.8031006864678859}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.42it/s, loss=0.5372]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.09it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.56it/s]


Validation: {'precision': 0.622670807452035, 'recall': 0.8802132329856375, 'f1': 0.7293750763454032, 'auc': 0.7871301732194003, 'prauc': 0.802245796289836}
Test:      {'precision': 0.6170921198654449, 'recall': 0.8717466290345822, 'f1': 0.7226410141202244, 'auc': 0.7802454967844421, 'prauc': 0.8005303325835422}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.42it/s, loss=0.5135]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.11it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.67it/s]


Validation: {'precision': 0.7969957081510859, 'recall': 0.5823142050781364, 'f1': 0.67294799294438, 'auc': 0.8031219440344287, 'prauc': 0.8117493164972686}
Test:      {'precision': 0.7876741240996722, 'recall': 0.5851364063951549, 'f1': 0.6714645507019549, 'auc': 0.7965672529013508, 'prauc': 0.8076525361063002}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.12it/s, loss=0.4801]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.15it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.51it/s]


Validation: {'precision': 0.7657819225223609, 'recall': 0.6694888679815946, 'f1': 0.7144052150301538, 'auc': 0.8067954576715517, 'prauc': 0.8129107310637826}
Test:      {'precision': 0.7546638507541195, 'recall': 0.6723110692986131, 'f1': 0.7111111061254056, 'auc': 0.8013258356499343, 'prauc': 0.8084475070496911}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 25.99it/s, loss=0.4546]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.65it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.08it/s]


Validation: {'precision': 0.6886543535601882, 'recall': 0.8184383819353451, 'f1': 0.7479581552298062, 'auc': 0.8147179973935155, 'prauc': 0.8274859412071645}
Test:      {'precision': 0.6858475894227946, 'recall': 0.8297271872034189, 'f1': 0.7509578494490613, 'auc': 0.8124652006521214, 'prauc': 0.8252185583546638}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.16it/s, loss=0.4145]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.87it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.13it/s]


Validation: {'precision': 0.6967418546346512, 'recall': 0.784571966131124, 'f1': 0.7380530923605333, 'auc': 0.8111487250560617, 'prauc': 0.8212226861935757}
Test:      {'precision': 0.6957823129232769, 'recall': 0.8018187519573478, 'f1': 0.7450466150695154, 'auc': 0.8114738816769322, 'prauc': 0.8191870629340979}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.70it/s, loss=0.3985]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.38it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.10it/s]


Validation: {'precision': 0.7153511497802506, 'recall': 0.7218563813084922, 'f1': 0.7185890382318223, 'auc': 0.7881980562589541, 'prauc': 0.8045257496234766}
Test:      {'precision': 0.7049332919618215, 'recall': 0.7124490435850974, 'f1': 0.7086712364202633, 'auc': 0.7805904325009676, 'prauc': 0.7990685367729256}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.07it/s, loss=0.3651]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.79it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.43it/s]


Validation: {'precision': 0.6964438122313046, 'recall': 0.7676387582290134, 'f1': 0.7303102575394775, 'auc': 0.7990723579301605, 'prauc': 0.8081743397908485}
Test:      {'precision': 0.6909191891122163, 'recall': 0.7801818751935398, 'f1': 0.7328424103328924, 'auc': 0.7923351085971189, 'prauc': 0.8007272773248102}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.58it/s, loss=0.3252]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.22it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.93it/s]


Validation: {'precision': 0.6390909090894566, 'recall': 0.8817811226062033, 'f1': 0.741072600211674, 'auc': 0.8041555153594405, 'prauc': 0.8165573356650647}
Test:      {'precision': 0.634353361624672, 'recall': 0.8905613044813717, 'f1': 0.7409339893995821, 'auc': 0.7999896816252899, 'prauc': 0.8125217522362662}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s, loss=0.3178]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.18it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.07it/s]

Validation: {'precision': 0.7015225509890792, 'recall': 0.7657572906843344, 'f1': 0.7322338780658579, 'auc': 0.7965539886036881, 'prauc': 0.8036118574028495}
Test:      {'precision': 0.6876560332851938, 'recall': 0.7773596738765213, 'f1': 0.7297615493292263, 'auc': 0.7954460734054211, 'prauc': 0.7999740964043472}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6886543535601882, 'recall': 0.8184383819353451, 'f1': 0.7479581552298062, 'auc': 0.8147179973935155, 'prauc': 0.8274859412071645}
Corresponding test performance:
{'precision': 0.6858475894227946, 'recall': 0.8297271872034189, 'f1': 0.7509578494490613, 'auc': 0.8124652006521214, 'prauc': 0.8252185583546638}





In [19]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7329 ± 0.0321
recall: 0.7644 ± 0.0553
f1: 0.7458 ± 0.0131
auc: 0.8206 ± 0.0079
prauc: 0.8335 ± 0.0065


In [20]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['tf', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 66.36it/s, loss=0.6820]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.51it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.41it/s]


Validation: {'precision': 0.7802152317848501, 'recall': 0.5910943869533049, 'f1': 0.6726137328268653, 'auc': 0.7779310669114603, 'prauc': 0.7777426134816886}
Test:      {'precision': 0.772115776596934, 'recall': 0.5939165882703232, 'f1': 0.6713931180887326, 'auc': 0.773911977720364, 'prauc': 0.7720519788630555}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 67.90it/s, loss=0.5837]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.23it/s]


Validation: {'precision': 0.8650234741733274, 'recall': 0.4622138601427965, 'f1': 0.6024933533162475, 'auc': 0.7981641775853777, 'prauc': 0.8051068662292509}
Test:      {'precision': 0.8506711409348396, 'recall': 0.4769520225761149, 'f1': 0.6112115686306304, 'auc': 0.7929814918557825, 'prauc': 0.7995481641001062}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 66.75it/s, loss=0.5502]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.96it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.80it/s]


Validation: {'precision': 0.840271055175023, 'recall': 0.5443712762604441, 'f1': 0.6607040865674061, 'auc': 0.8133547221409223, 'prauc': 0.8272673851986445}
Test:      {'precision': 0.829166666662828, 'recall': 0.5616180620866679, 'f1': 0.6696578751600978, 'auc': 0.8107038289223539, 'prauc': 0.8240000547074331}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 65.78it/s, loss=0.5214]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 212.63it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 208.89it/s]


Validation: {'precision': 0.7975308641942489, 'recall': 0.6077140169313022, 'f1': 0.6898024510417909, 'auc': 0.8099723808280809, 'prauc': 0.8213436696103045}
Test:      {'precision': 0.7895990472378341, 'recall': 0.6237064910610733, 'f1': 0.6969166033355435, 'auc': 0.8103293977542183, 'prauc': 0.8196388109068522}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 66.75it/s, loss=0.4868]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 209.64it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 208.02it/s]


Validation: {'precision': 0.8006535947679713, 'recall': 0.6146127312617917, 'f1': 0.6954053525435737, 'auc': 0.8075939209086345, 'prauc': 0.817452171528826}
Test:      {'precision': 0.78876582278169, 'recall': 0.6252743806816392, 'f1': 0.6975686499533276, 'auc': 0.8032812431577849, 'prauc': 0.8148629490349683}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 64.94it/s, loss=0.4714]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 210.89it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.81it/s]


Validation: {'precision': 0.7460875119746213, 'recall': 0.7325180307283395, 'f1': 0.7392405013271958, 'auc': 0.8184862073438322, 'prauc': 0.8233309445337794}
Test:      {'precision': 0.7379695746639616, 'recall': 0.745374725616979, 'f1': 0.7416536611444564, 'auc': 0.8158878810437362, 'prauc': 0.8214888690273356}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 66.49it/s, loss=0.4525]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.96it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.94it/s]


Validation: {'precision': 0.7644429160909063, 'recall': 0.6970837253035526, 'f1': 0.7292110824282713, 'auc': 0.8185115266763784, 'prauc': 0.8285470141333441}
Test:      {'precision': 0.754742096503312, 'recall': 0.7111947318886448, 'f1': 0.7323215965519346, 'auc': 0.8153599325933297, 'prauc': 0.8246691229212939}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 67.06it/s, loss=0.4189]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.63it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.38it/s]


Validation: {'precision': 0.7272182254174844, 'recall': 0.7607400438985239, 'f1': 0.7436015275673084, 'auc': 0.8162918149530551, 'prauc': 0.8214974615291964}
Test:      {'precision': 0.7241379310323486, 'recall': 0.7704609595460318, 'f1': 0.7465815811465638, 'auc': 0.8173869143881933, 'prauc': 0.8201820226348031}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 66.00it/s, loss=0.4124]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.94it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.94it/s]


Validation: {'precision': 0.7544889502736378, 'recall': 0.6851677641872526, 'f1': 0.7181594033904978, 'auc': 0.8069147197657279, 'prauc': 0.8092984934584941}
Test:      {'precision': 0.7524752475221834, 'recall': 0.6911257447454026, 'f1': 0.7204968894166128, 'auc': 0.8046008374493582, 'prauc': 0.8091342648953093}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 66.47it/s, loss=0.3798]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.14it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.71it/s]


Validation: {'precision': 0.741121495324794, 'recall': 0.7460018814652054, 'f1': 0.7435536752602709, 'auc': 0.8177198454033706, 'prauc': 0.8216753949267185}
Test:      {'precision': 0.7309815950897823, 'recall': 0.747256193161658, 'f1': 0.7390293018675963, 'auc': 0.8128063110199737, 'prauc': 0.816905717321505}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 66.31it/s, loss=0.3739]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.23it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.20it/s]


Validation: {'precision': 0.7452768729617418, 'recall': 0.7174662903709079, 'f1': 0.7311072006233729, 'auc': 0.8119561304383691, 'prauc': 0.8134965575598239}
Test:      {'precision': 0.7346809854683048, 'recall': 0.729382251487208, 'f1': 0.7320220248954802, 'auc': 0.8115890448054031, 'prauc': 0.8151821380009914}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 66.76it/s, loss=0.3422]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.20it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.52it/s]


Validation: {'precision': 0.7174366616968184, 'recall': 0.7547820633403739, 'f1': 0.7356356918224851, 'auc': 0.8050714824037171, 'prauc': 0.8027855614720343}
Test:      {'precision': 0.7159024956450496, 'recall': 0.7735967387871634, 'f1': 0.7436322482079729, 'auc': 0.8084618222652405, 'prauc': 0.8084965501652762}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7644429160909063, 'recall': 0.6970837253035526, 'f1': 0.7292110824282713, 'auc': 0.8185115266763784, 'prauc': 0.8285470141333441}
Corresponding test performance:
{'precision': 0.754742096503312, 'recall': 0.7111947318886448, 'f1': 0.7323215965519346, 'auc': 0.8153599325933297, 'prauc': 0.8246691229212939}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 67.62it/s, loss=0.6741]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.55it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.96it/s]


Validation: {'precision': 0.7259598450132936, 'recall': 0.6462841015972208, 'f1': 0.6838088868526764, 'auc': 0.7751152456643406, 'prauc': 0.7838679989813095}
Test:      {'precision': 0.7199585635334256, 'recall': 0.6538099717759367, 'f1': 0.6852916959132024, 'auc': 0.776751795774701, 'prauc': 0.7861070587887498}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 65.80it/s, loss=0.5865]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.47it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.87it/s]


Validation: {'precision': 0.6677792041061161, 'recall': 0.8156161806183267, 'f1': 0.7343308815523883, 'auc': 0.796995117689181, 'prauc': 0.8081394804005039}
Test:      {'precision': 0.6679457661789001, 'recall': 0.8187519598594584, 'f1': 0.7357001922880332, 'auc': 0.7971154354328962, 'prauc': 0.808823175244052}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 66.46it/s, loss=0.5492]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.73it/s]


Validation: {'precision': 0.641978175062359, 'recall': 0.8670429601728848, 'f1': 0.7377267827293187, 'auc': 0.7991972967794512, 'prauc': 0.810105479426211}
Test:      {'precision': 0.6374885426199874, 'recall': 0.8723737848828085, 'f1': 0.7366609245510691, 'auc': 0.7919851898606113, 'prauc': 0.8049538638959202}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 67.74it/s, loss=0.5147]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.58it/s]


Validation: {'precision': 0.7775297619018693, 'recall': 0.6553778613965024, 'f1': 0.7112472300177507, 'auc': 0.8132043634855437, 'prauc': 0.8232223119410809}
Test:      {'precision': 0.7679355783280822, 'recall': 0.6578864847894077, 'f1': 0.7086640720414102, 'auc': 0.8093351091004541, 'prauc': 0.8195537149559171}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 65.63it/s, loss=0.4919]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 211.01it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 208.07it/s]


Validation: {'precision': 0.721025043678739, 'recall': 0.7764189401041819, 'f1': 0.7476974130930639, 'auc': 0.8167343506681339, 'prauc': 0.824928908386181}
Test:      {'precision': 0.7182608695631355, 'recall': 0.7770460959524081, 'f1': 0.7464979615667079, 'auc': 0.8116592600869662, 'prauc': 0.8166648163365459}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 68.02it/s, loss=0.4568]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.72it/s]


Validation: {'precision': 0.752027448531653, 'recall': 0.7560363750368265, 'f1': 0.7540265782658554, 'auc': 0.8267224757002227, 'prauc': 0.8342930295772399}
Test:      {'precision': 0.7443445924984681, 'recall': 0.753214173719808, 'f1': 0.7487531122048239, 'auc': 0.8255206626711906, 'prauc': 0.8352043650364971}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 67.11it/s, loss=0.4160]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.52it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.89it/s]


Validation: {'precision': 0.7532552083308813, 'recall': 0.72561931639785, 'f1': 0.7391790398803944, 'auc': 0.8200268686345615, 'prauc': 0.8269643867687336}
Test:      {'precision': 0.7467220978549833, 'recall': 0.7322044528042264, 'f1': 0.7393920152641317, 'auc': 0.8173650696339292, 'prauc': 0.8240164572747245}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 66.76it/s, loss=0.3979]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.68it/s]


Validation: {'precision': 0.7954722872724611, 'recall': 0.6390718093426182, 'f1': 0.7087463000474036, 'auc': 0.8243731531079128, 'prauc': 0.8265064322153706}
Test:      {'precision': 0.7856609409978131, 'recall': 0.6597679523340867, 'f1': 0.717231970491425, 'auc': 0.8190249186232569, 'prauc': 0.8226607666809336}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 67.32it/s, loss=0.3694]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.47it/s]


Validation: {'precision': 0.7176870748278977, 'recall': 0.7939793038545188, 'f1': 0.7539079896509554, 'auc': 0.818451493735123, 'prauc': 0.8250984976522793}
Test:      {'precision': 0.7093513058109819, 'recall': 0.7920978363098398, 'f1': 0.7484444394574131, 'auc': 0.8130015044693664, 'prauc': 0.81696409569918}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 67.40it/s, loss=0.3416]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.53it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.02it/s]


Validation: {'precision': 0.6481394253399714, 'recall': 0.8629664471594137, 'f1': 0.7402824429807047, 'auc': 0.8044736648297088, 'prauc': 0.8169446044716173}
Test:      {'precision': 0.6533553875220857, 'recall': 0.8670429601728848, 'f1': 0.7451825850442056, 'auc': 0.8097731115235105, 'prauc': 0.8169796908696796}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 66.98it/s, loss=0.3175]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.49it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.07it/s]


Validation: {'precision': 0.7222857142836506, 'recall': 0.7927249921580661, 'f1': 0.7558678377354137, 'auc': 0.8215036789894452, 'prauc': 0.8265638406542284}
Test:      {'precision': 0.7169065763457045, 'recall': 0.796487927247424, 'f1': 0.7546048672635151, 'auc': 0.8184183491909136, 'prauc': 0.8208460578544405}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.752027448531653, 'recall': 0.7560363750368265, 'f1': 0.7540265782658554, 'auc': 0.8267224757002227, 'prauc': 0.8342930295772399}
Corresponding test performance:
{'precision': 0.7443445924984681, 'recall': 0.753214173719808, 'f1': 0.7487531122048239, 'auc': 0.8255206626711906, 'prauc': 0.8352043650364971}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 67.17it/s, loss=0.6819]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.28it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.77it/s]


Validation: {'precision': 0.736938253559602, 'recall': 0.6324866729362418, 'f1': 0.6807289859144052, 'auc': 0.7801711245270082, 'prauc': 0.7839782629930537}
Test:      {'precision': 0.7262892174514018, 'recall': 0.6315459391639023, 'f1': 0.6756122056896379, 'auc': 0.7701826150989532, 'prauc': 0.7761312220643608}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 67.45it/s, loss=0.6033]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.42it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.50it/s]


Validation: {'precision': 0.797916666663342, 'recall': 0.6005017246766996, 'f1': 0.6852746417244966, 'auc': 0.797923945346814, 'prauc': 0.8089546798397671}
Test:      {'precision': 0.7788657690706697, 'recall': 0.5986202571320206, 'f1': 0.6769503496931386, 'auc': 0.7941590952446386, 'prauc': 0.8066727466665151}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 67.03it/s, loss=0.5473]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.57it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.98it/s]


Validation: {'precision': 0.8031959629907351, 'recall': 0.5989338350561338, 'f1': 0.6861864510044924, 'auc': 0.8111109470043261, 'prauc': 0.8196323299196819}
Test:      {'precision': 0.7980008329829321, 'recall': 0.6008153026008127, 'f1': 0.6855098340951157, 'auc': 0.8105015384445026, 'prauc': 0.8239306386412032}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 67.45it/s, loss=0.5220]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.88it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 214.35it/s]


Validation: {'precision': 0.7827701448169968, 'recall': 0.6610222640305393, 'f1': 0.7167630008134633, 'auc': 0.8142002572323619, 'prauc': 0.8236235468878259}
Test:      {'precision': 0.7718773373194769, 'recall': 0.6472248353695603, 'f1': 0.7040764064296613, 'auc': 0.807619742221833, 'prauc': 0.8173628501786974}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 65.39it/s, loss=0.5018]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 211.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 209.26it/s]


Validation: {'precision': 0.7703916636695279, 'recall': 0.6723110692986131, 'f1': 0.7180174096221782, 'auc': 0.8183205767100923, 'prauc': 0.8284495292909162}
Test:      {'precision': 0.7669595782046856, 'recall': 0.6842270304149131, 'f1': 0.7232349966711292, 'auc': 0.8168521205770034, 'prauc': 0.8293327442348487}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 66.39it/s, loss=0.4738]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.79it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.45it/s]


Validation: {'precision': 0.7778170257831939, 'recall': 0.6904985888971763, 'f1': 0.7315614568092721, 'auc': 0.8267926062324341, 'prauc': 0.8356179454093817}
Test:      {'precision': 0.7705657757696267, 'recall': 0.6961429915312132, 'f1': 0.7314662223580747, 'auc': 0.8231857906416871, 'prauc': 0.8347818236691937}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 67.62it/s, loss=0.4478]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.11it/s]


Validation: {'precision': 0.731125037300116, 'recall': 0.7682659140772397, 'f1': 0.749235469006893, 'auc': 0.81838181533385, 'prauc': 0.8259783761700576}
Test:      {'precision': 0.7197452229278525, 'recall': 0.7795547193453134, 'f1': 0.7484570174353286, 'auc': 0.8167531648468578, 'prauc': 0.8270986741468578}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 66.77it/s, loss=0.4167]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.56it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.58it/s]


Validation: {'precision': 0.5845918562405189, 'recall': 0.9589212919380404, 'f1': 0.7263657910168148, 'auc': 0.808848986156655, 'prauc': 0.8226739105276896}
Test:      {'precision': 0.5862466384929192, 'recall': 0.9570398243933614, 'f1': 0.727099459253544, 'auc': 0.8121730648139899, 'prauc': 0.8249596139840718}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 67.27it/s, loss=0.4134]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.03it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.19it/s]


Validation: {'precision': 0.7590361445757005, 'recall': 0.6914393226695157, 'f1': 0.7236626139779625, 'auc': 0.8218239886408633, 'prauc': 0.8307800743008277}
Test:      {'precision': 0.7663333333307789, 'recall': 0.7209156475361527, 'f1': 0.7429310016269188, 'auc': 0.8254637354429124, 'prauc': 0.8364308502357167}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 67.06it/s, loss=0.3709]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.86it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.05it/s]


Validation: {'precision': 0.7160529344053048, 'recall': 0.7804954531176529, 'f1': 0.7468867166874501, 'auc': 0.8157398133141309, 'prauc': 0.8179935066366191}
Test:      {'precision': 0.7196367763884233, 'recall': 0.7952336155509714, 'f1': 0.7555489299126289, 'auc': 0.8163835656980984, 'prauc': 0.8216616598552875}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 66.52it/s, loss=0.3426]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.15it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.52it/s]


Validation: {'precision': 0.7244509516816268, 'recall': 0.7757917842559555, 'f1': 0.7492428781047374, 'auc': 0.8227233273139634, 'prauc': 0.8310379930730293}
Test:      {'precision': 0.7194202898529872, 'recall': 0.7783004076488608, 'f1': 0.7477029623198298, 'auc': 0.8171544237892394, 'prauc': 0.829096636039102}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7778170257831939, 'recall': 0.6904985888971763, 'f1': 0.7315614568092721, 'auc': 0.8267926062324341, 'prauc': 0.8356179454093817}
Corresponding test performance:
{'precision': 0.7705657757696267, 'recall': 0.6961429915312132, 'f1': 0.7314662223580747, 'auc': 0.8231857906416871, 'prauc': 0.8347818236691937}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 67.24it/s, loss=0.6756]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.64it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.33it/s]


Validation: {'precision': 0.9069767441749452, 'recall': 0.23236124176785086, 'f1': 0.3699450791275706, 'auc': 0.7747322907595791, 'prauc': 0.7806328352890998}
Test:      {'precision': 0.9058679706490725, 'recall': 0.23236124176785086, 'f1': 0.36985275442285465, 'auc': 0.7736111341806481, 'prauc': 0.7829726608354701}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 66.95it/s, loss=0.5928]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.47it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.04it/s]


Validation: {'precision': 0.7113616137623514, 'recall': 0.7519598620233554, 'f1': 0.7310975559772296, 'auc': 0.8009696500576868, 'prauc': 0.8095670862379283}
Test:      {'precision': 0.7041055718454425, 'recall': 0.7529005957956949, 'f1': 0.7276860080356801, 'auc': 0.7966324348293969, 'prauc': 0.8089268850230624}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 66.71it/s, loss=0.5555]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.06it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.49it/s]


Validation: {'precision': 0.6790838852078392, 'recall': 0.7717152712424844, 'f1': 0.7224423845676533, 'auc': 0.7872150231254927, 'prauc': 0.7982742935251389}
Test:      {'precision': 0.6752827140531091, 'recall': 0.7864534336758029, 'f1': 0.7266405860741866, 'auc': 0.789112906676089, 'prauc': 0.800406500449212}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 66.32it/s, loss=0.5328]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 211.88it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.48it/s]


Validation: {'precision': 0.7863426895652018, 'recall': 0.5885857635603996, 'f1': 0.6732424628198166, 'auc': 0.7996881602681599, 'prauc': 0.8121500474168935}
Test:      {'precision': 0.7773262901228065, 'recall': 0.5998745688284733, 'f1': 0.6771681366735345, 'auc': 0.7984805312402636, 'prauc': 0.8138659106637867}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 65.98it/s, loss=0.4938]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 211.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 206.34it/s]


Validation: {'precision': 0.7862870890104801, 'recall': 0.614926309185905, 'f1': 0.6901284483540686, 'auc': 0.8059956380416542, 'prauc': 0.8153791071342451}
Test:      {'precision': 0.7815912636474978, 'recall': 0.6284101599227707, 'f1': 0.6966799881036961, 'auc': 0.8039752922742553, 'prauc': 0.8181761949453559}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 65.17it/s, loss=0.4711]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 206.89it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 205.66it/s]


Validation: {'precision': 0.8319057815801291, 'recall': 0.4873000940718492, 'f1': 0.6145936279272455, 'auc': 0.7924038284840281, 'prauc': 0.8047773064240333}
Test:      {'precision': 0.8079096045156245, 'recall': 0.4932580746299992, 'f1': 0.6125389360999741, 'auc': 0.7850509903878049, 'prauc': 0.8001263619845511}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 66.85it/s, loss=0.4625]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.18it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.22it/s]


Validation: {'precision': 0.7196289646896481, 'recall': 0.7541549074921475, 'f1': 0.7364875160539263, 'auc': 0.8000386119821329, 'prauc': 0.8052400323823548}
Test:      {'precision': 0.7108963093125047, 'recall': 0.761053621822637, 'f1': 0.7351203948218435, 'auc': 0.8007841964779612, 'prauc': 0.8092375438968185}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 66.13it/s, loss=0.4357]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.19it/s]


Validation: {'precision': 0.7376596479090864, 'recall': 0.6701160238298209, 'f1': 0.7022674941876446, 'auc': 0.7887872833451339, 'prauc': 0.7975091786761723}
Test:      {'precision': 0.7315664288116495, 'recall': 0.6751332706156314, 'f1': 0.7022178684565069, 'auc': 0.7928688454040701, 'prauc': 0.8057648008554666}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 65.62it/s, loss=0.4141]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.91it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.85it/s]


Validation: {'precision': 0.7072463768095443, 'recall': 0.7651301348361081, 'f1': 0.7350504544120505, 'auc': 0.7952716950258263, 'prauc': 0.7984654640930176}
Test:      {'precision': 0.7017443522999665, 'recall': 0.7695202257736924, 'f1': 0.7340711885471525, 'auc': 0.7973616670869517, 'prauc': 0.8071895249296552}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 66.43it/s, loss=0.4040]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 212.87it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.28it/s]


Validation: {'precision': 0.724061810152242, 'recall': 0.7199749137638132, 'f1': 0.7220125736141219, 'auc': 0.798696787830725, 'prauc': 0.8109888228443818}
Test:      {'precision': 0.7198767334338371, 'recall': 0.7325180307283395, 'f1': 0.7261423636645811, 'auc': 0.7968873741850371, 'prauc': 0.8137124234145658}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7862870890104801, 'recall': 0.614926309185905, 'f1': 0.6901284483540686, 'auc': 0.8059956380416542, 'prauc': 0.8153791071342451}
Corresponding test performance:
{'precision': 0.7815912636474978, 'recall': 0.6284101599227707, 'f1': 0.6966799881036961, 'auc': 0.8039752922742553, 'prauc': 0.8181761949453559}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 65.89it/s, loss=0.7057]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.33it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.66it/s]


Validation: {'precision': 0.7254823151096242, 'recall': 0.5660081530242521, 'f1': 0.6358992376316861, 'auc': 0.7433590254227198, 'prauc': 0.7304281364573326}
Test:      {'precision': 0.7183544303769053, 'recall': 0.5694575101894969, 'f1': 0.6352982284037815, 'auc': 0.7403219936911947, 'prauc': 0.7225840725730017}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 66.07it/s, loss=0.6129]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.22it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.26it/s]


Validation: {'precision': 0.8154481132027391, 'recall': 0.433678269048499, 'f1': 0.5662231275015762, 'auc': 0.7801704714489863, 'prauc': 0.7806236040201948}
Test:      {'precision': 0.8166287015898825, 'recall': 0.4496707431782701, 'f1': 0.5799797729706228, 'auc': 0.7769742700001561, 'prauc': 0.7812943467244434}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 67.81it/s, loss=0.5614]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 216.12it/s]


Validation: {'precision': 0.7793352022395702, 'recall': 0.6102226403242076, 'f1': 0.684488211744179, 'auc': 0.7955148409971035, 'prauc': 0.8082897427740641}
Test:      {'precision': 0.7788688138225497, 'recall': 0.6218250235163945, 'f1': 0.6915431511197264, 'auc': 0.7928173541975904, 'prauc': 0.8076052825402584}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 66.22it/s, loss=0.5280]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.69it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.80it/s]


Validation: {'precision': 0.7039319248805636, 'recall': 0.7522734399474686, 'f1': 0.7273002830130068, 'auc': 0.7981007787804703, 'prauc': 0.8138832385253201}
Test:      {'precision': 0.7074106879223283, 'recall': 0.7513327061751292, 'f1': 0.7287104572894213, 'auc': 0.7972116228186075, 'prauc': 0.8161028182659663}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 65.51it/s, loss=0.5001]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.30it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.75it/s]


Validation: {'precision': 0.7183470105488912, 'recall': 0.7685794920013529, 'f1': 0.7426147503435565, 'auc': 0.8118326484554353, 'prauc': 0.8253702607365363}
Test:      {'precision': 0.7109127555407114, 'recall': 0.7742238946353898, 'f1': 0.7412188482040994, 'auc': 0.811163525099069, 'prauc': 0.8257552882853234}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 65.69it/s, loss=0.4855]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.48it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.46it/s]


Validation: {'precision': 0.7407407407383757, 'recall': 0.727500783942529, 'f1': 0.734061061285064, 'auc': 0.8099082787083806, 'prauc': 0.8183032585992817}
Test:      {'precision': 0.7316384180767996, 'recall': 0.7309501411077738, 'f1': 0.7312941126447656, 'auc': 0.8097836815658964, 'prauc': 0.825012579975156}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 66.16it/s, loss=0.4500]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 210.44it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 207.03it/s]


Validation: {'precision': 0.7484949832750887, 'recall': 0.7017873941652499, 'f1': 0.7243890547212426, 'auc': 0.8122835234744876, 'prauc': 0.820831826045771}
Test:      {'precision': 0.753706199458377, 'recall': 0.7014738162411368, 'f1': 0.7266525855514261, 'auc': 0.8124533722713563, 'prauc': 0.8238166250606056}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 66.43it/s, loss=0.4181]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 207.93it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 210.41it/s]


Validation: {'precision': 0.7730916695685845, 'recall': 0.6955158356829868, 'f1': 0.7322548646054593, 'auc': 0.8213361393584102, 'prauc': 0.8250849329991723}
Test:      {'precision': 0.7636300897144113, 'recall': 0.693947946062421, 'f1': 0.7271233726991977, 'auc': 0.8122458471058468, 'prauc': 0.8179685464692689}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 66.78it/s, loss=0.3994]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.97it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 215.89it/s]


Validation: {'precision': 0.7568792489454294, 'recall': 0.7331451865765659, 'f1': 0.7448231870982905, 'auc': 0.8260324736515672, 'prauc': 0.8298430235668481}
Test:      {'precision': 0.7448165869194743, 'recall': 0.7322044528042264, 'f1': 0.7384566679898066, 'auc': 0.8206998676731361, 'prauc': 0.8306641312535861}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 67.35it/s, loss=0.3597]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.19it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.76it/s]


Validation: {'precision': 0.6786259541967465, 'recall': 0.8363123236097952, 'f1': 0.7492625319252227, 'auc': 0.8138794954500054, 'prauc': 0.8166604411661491}
Test:      {'precision': 0.6788732394348813, 'recall': 0.8312950768239846, 'f1': 0.7473921574395804, 'auc': 0.8139978064645372, 'prauc': 0.8199006033636341}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 67.92it/s, loss=0.3278]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.60it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.02it/s]


Validation: {'precision': 0.732292917164669, 'recall': 0.7651301348361081, 'f1': 0.7483514748344906, 'auc': 0.8205702295488868, 'prauc': 0.8213479060445483}
Test:      {'precision': 0.7305371152663532, 'recall': 0.7591721542779581, 'f1': 0.7445794198803818, 'auc': 0.8169926518072005, 'prauc': 0.8201434342444918}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 66.99it/s, loss=0.2904]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.13it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.05it/s]


Validation: {'precision': 0.7266918385108142, 'recall': 0.784571966131124, 'f1': 0.7545235173211023, 'auc': 0.8192962250382276, 'prauc': 0.8162981776159262}
Test:      {'precision': 0.7194873288647845, 'recall': 0.7745374725595029, 'f1': 0.7459981828631884, 'auc': 0.814956609976006, 'prauc': 0.8176047671968805}


Epoch 013: 100%|██████████| 98/98 [00:01<00:00, 66.48it/s, loss=0.2663]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.88it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.85it/s]


Validation: {'precision': 0.680656934304795, 'recall': 0.8187519598594584, 'f1': 0.743345190769833, 'auc': 0.8037202137393751, 'prauc': 0.804005926451631}
Test:      {'precision': 0.6760818864973411, 'recall': 0.818124804011232, 'f1': 0.7403518679148202, 'auc': 0.797684707715678, 'prauc': 0.8027133041881584}


Epoch 014: 100%|██████████| 98/98 [00:01<00:00, 66.27it/s, loss=0.2552]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.57it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.34it/s]

Validation: {'precision': 0.7202919708008166, 'recall': 0.7735967387871634, 'f1': 0.7459933424489243, 'auc': 0.809754503952579, 'prauc': 0.8091303778970079}
Test:      {'precision': 0.7050946142628672, 'recall': 0.7594857322020713, 'f1': 0.7312801882414032, 'auc': 0.7978885585331197, 'prauc': 0.8030001426998757}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7568792489454294, 'recall': 0.7331451865765659, 'f1': 0.7448231870982905, 'auc': 0.8260324736515672, 'prauc': 0.8298430235668481}
Corresponding test performance:
{'precision': 0.7448165869194743, 'recall': 0.7322044528042264, 'f1': 0.7384566679898066, 'auc': 0.8206998676731361, 'prauc': 0.8306641312535861}





In [21]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7592 ± 0.0147
recall: 0.7042 ± 0.0425
f1: 0.7295 ± 0.0175
auc: 0.8177 ± 0.0077
prauc: 0.8287 ± 0.0065
