In [None]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
# Seed the project's random-number helper before any data is loaded or any model is built.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# Experiment configuration read by the cells below.
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,  # position in `tasks` of the task to fine-tune on
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]"],
    batch_size = 32,
    lr = 1e-3,  # learning rate handed to the optimizer
    epochs = 500,  # passed to the training helper together with the early-stopping patience
    early_stop_patience = 5,   
)

In [5]:
# Resolve the two pickles this run needs: the full corpus (vocabulary source for the
# tokenizer) and the fine-tuning split that belongs to the selected task.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
# The two next-diagnosis tasks ship dedicated files; every other task shares one file.
task_specific_paths = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}
finetune_data_path = task_specific_paths.get(
    curr_task,
    f"./data_process/{config.dataset}-processed/mimic_downstream.pkl",
)

Current task: stay


In [6]:
# Load the full processed EHR table; its columns supply the sentences the tokenizer is built from.
# NOTE: unpickling can execute arbitrary code -- only load files produced by a trusted pipeline.
with open(full_data_path, 'rb') as full_data_file:  # context manager closes the handle
    ehr_full_data = pickle.load(full_data_file)
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# One "<age>_<gender>" token per distinct AGE value and gender, preceded by the pad token.
# NOTE(review): the token order follows set iteration order of the AGE values.
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type
# Largest number of distinct admissions any single patient has; stored as a plain
# Python int so it can be used directly as a tensor / embedding size.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from the per-category sentences collected from the full dataset.
task_sentences = config.tasks
tokenizer = EHRTokenizer(token_type_sentences, age_gender_sentences, task_sentences, diag_sentences, 
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
# Record vocabulary sizes on the config so later cells can read them.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
# Load the (train, validation, test) splits for the selected task.
# NOTE: unpickling can execute arbitrary code -- only load files produced by a trusted pipeline.
with open(finetune_data_path, 'rb') as finetune_file:  # context manager closes the handle
    train_data, val_data, test_data = pickle.load(finetune_file)
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in a FineTuneEHRDataset configured for the current task
# (constructed in train, validation, test order).
train_dataset, val_dataset, test_dataset = [
    FineTuneEHRDataset(split, tokenizer, token_type=config.token_type, task=curr_task)
    for split in (train_data, val_data, test_data)
]

In [10]:
def _finetune_loader(dataset, shuffle):
    """Wrap a fine-tuning dataset in a DataLoader that uses a freshly built collate function."""
    collate = batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False)
    return DataLoader(dataset, collate_fn=collate, shuffle=shuffle, batch_size=config.batch_size)


# Only the training split is shuffled; validation and test keep dataset order.
train_dataloader = _finetune_loader(train_dataset, shuffle=True)
val_dataloader = _finetune_loader(val_dataset, shuffle=False)
test_dataloader = _finetune_loader(test_dataset, shuffle=False)

In [11]:
# Every task is selected on PR-AUC and trained with binary cross-entropy on logits;
# only the task-type label handed to the trainer differs between the two families.
eval_metric = "prauc"
loss_fn = F.binary_cross_entropy_with_logits  # passed directly; no wrapper lambda needed
task_type = "binary" if curr_task in ["death", "stay", "readmission"] else "l2r"

In [12]:
# Pull one batch from the training loader and report what each field looks like.
first_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = first_batch
for caption, field in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age/Sex IDs shape:", age_gender_ids),
):
    print(caption, field.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv
from heterogt.model.layer import TransformerEncoder

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    """Two-layer heterogeneous GAT over a graph of 'visit' and 'occ' nodes.

    Each layer is a ``HeteroConv`` holding one ``GATConv`` per edge type
    (visit->occ, occ->visit, visit->visit); messages arriving at a node type from
    several edge types are averaged (``aggr='mean'``). Dropout is applied to the
    output of the first layer, and a shared linear projection to the output of
    the second.

    Args:
        d_model: Feature width of every node type, both in and out.
        heads: Number of attention heads per ``GATConv``; heads are averaged
            (``concat=False``) so the width stays ``d_model``.
        dropout: Dropout probability applied between the two layers. The default
            of 0.0 makes the dropout an identity.
    """

    # Edge types handled by both layers, in construction order.
    EDGE_TYPES = (
        ('visit', 'contains', 'occ'),
        ('occ', 'contained_by', 'visit'),
        ('visit', 'next', 'visit'),
    )

    def __init__(self, d_model: int, heads: int = 1, dropout: float = 0.0):
        super().__init__()
        # NOTE(review): this activation is created but the forward pass has never applied
        # it. It is left unapplied so outputs match existing runs -- confirm whether a
        # GELU between the two layers was intended.
        self.act = nn.GELU()
        # Previously the `dropout` argument was accepted and silently ignored.
        self.dropout = nn.Dropout(dropout)

        # layer 1
        self.conv1 = self._make_layer(d_model, heads)
        # layer 2
        self.conv2 = self._make_layer(d_model, heads)
        self.lin = nn.Linear(d_model, d_model)

    @classmethod
    def _make_layer(cls, d_model: int, heads: int):
        """Build one HeteroConv with a separate GATConv for every edge type."""
        return HeteroConv({
            edge_type: GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False)
            for edge_type in cls.EDGE_TYPES
        }, aggr='mean')

    def forward(self, hg):
        """Run both layers on a (batched) heterogeneous graph.

        Args:
            hg: Graph object exposing ``hg['visit'].x``, ``hg['occ'].x`` and
                ``hg.edge_index_dict``.

        Returns:
            dict mapping 'visit' -> [N_visit, d] and 'occ' -> [N_occ, d].
        """
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # layer 1: HeteroConv -> Dropout (identity when p == 0 or in eval mode)
        x_dict = self.conv1(x_dict, hg.edge_index_dict)
        x_dict = {k: self.dropout(v) for k, v in x_dict.items()}

        # layer 2: HeteroConv -> shared Linear, no activation on the output
        x_dict = self.conv2(x_dict, hg.edge_index_dict)
        x_dict = {k: self.lin(v) for k, v in x_dict.items()}

        return x_dict  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# Head for the multi-label (diagnosis) tasks.
class MultiPredictionHead(nn.Module):
    """Two-layer MLP that maps a hidden vector to ``label_size`` logits."""

    def __init__(self, hidden_size, label_size):
        super().__init__()
        # Linear -> ReLU -> Linear; kept as a Sequential so parameter names stay cls.0 / cls.2.
        stack = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, label_size),
        ]
        self.cls = nn.Sequential(*stack)

    def forward(self, input):
        """Return logits of shape [..., label_size] for inputs of shape [..., hidden_size]."""
        logits = self.cls(input)
        return logits
    
class BinaryPredictionHead(nn.Module):
    """Two-layer MLP that maps a hidden vector to a single logit."""

    def __init__(self, hidden_size):
        super().__init__()
        # Linear -> ReLU -> Linear; kept as a Sequential so parameter names stay cls.0 / cls.2.
        stack = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.cls = nn.Sequential(*stack)

    def forward(self, input):
        """Return logits of shape [..., 1] for inputs of shape [..., hidden_size]."""
        logits = self.cls(input)
        return logits

In [16]:
# Find the first training sample whose age/gender row holds more than three entries
# and print that row as a worked example. This rebinds `age_gender_ids`.
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
exp_i = i  # NOTE(review): if no sample matches, this is simply the last index
# Take row 0 of each field of that sample and append trailing zeros (5 for the three
# sequence fields, 3 for age/sex) -- presumably to mimic batch padding; confirm.
id_seq = torch.concat([train_dataset[exp_i][0][0], torch.zeros(5, dtype=train_dataset[exp_i][0][0].dtype)], dim=0)
type_seq = torch.concat([train_dataset[exp_i][1][0], torch.zeros(5, dtype=train_dataset[exp_i][1][0].dtype)], dim=0)
visit_seq = torch.concat([train_dataset[exp_i][2][0], torch.zeros(5, dtype=train_dataset[exp_i][2][0].dtype)], dim=0)
age_sex = torch.concat([train_dataset[exp_i][3][0], torch.zeros(3, dtype=train_dataset[exp_i][3][0].dtype)], dim=0)

tensor([[11, 11, 11, 13]])


In [None]:
class HeteroGT(nn.Module):
    """Transformer classifier over EHR token sequences with optional graph-derived visit tokens.

    The forward pass embeds every token (token id + admission index + token type),
    prepends a task embedding that acts as the classification slot, runs the
    sequence through ``TransformerEncoder`` and classifies from position 0.
    When ``use_hetero_graph`` is true, ``max_num_adms`` extra positions are
    appended to the sequence; they carry the 'visit' node outputs of
    ``DiseaseOccHetGNN`` computed on a per-patient visit/diagnosis graph.

    Args:
        tokenizer: Tokenizer exposing ``vocab``, ``age_gender_voc``,
            ``token_type_voc``, ``diag_voc`` and ``convert_tokens_to_ids``.
        d_model: Width of all embeddings and of the encoder.
        num_heads: Number of attention heads of the encoder.
        num_layers: Number of encoder layers.
        max_num_adms: Number of visit positions appended per patient; also sizes
            the admission-index embedding.
        device: Device on which helper tensors are created.
        task: Task name. "death", "stay" and "readmission" get a single-logit
            head; any other name gets one logit per diagnosis-vocabulary entry.
        use_hetero_graph: Whether to append the graph-derived visit tokens.
    """

    def __init__(self, tokenizer, d_model, num_heads, num_layers, max_num_adms, device, task, use_hetero_graph):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.use_hetero_graph = use_hetero_graph
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_sex_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Pad ids looked up from the tokenizer (the original author notes each is expected to be 0).
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0]
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0]
        self.adm_pad_id = 0
        self.age_sex_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0]
        # NOTE(review): hard-coded type ids. `diag_type_id` presumably matches the tokenizer's
        # id for "diag", and `visit_type_id` the extra row added to `type_emb` -- confirm
        # against the tokenizer before changing the token-type list.
        self.diag_type_id = 1
        self.visit_type_id = 5

        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id)
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id) # n_type already includes PAD, + 1 for visit
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_sex_emb = nn.Embedding(self.age_sex_vocab_size, d_model, padding_idx=self.age_sex_pad_id)
        # One embedding per task (5 in total); forward() prepends it as the classification slot.
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)

        # GNN
        self.het_gnn = DiseaseOccHetGNN(d_model)

        # encoder transformer
        self.num_attn_heads = num_heads
        self.encoder = TransformerEncoder(d_model = self.d_model, num_heads = self.num_attn_heads, num_layers = self.num_layers)

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            # Size the head from the tokenizer that was passed in instead of reaching for a
            # notebook-level global; this is the same quantity the notebook stores as
            # config.label_vocab_size.
            self.cls_head = MultiPredictionHead(self.d_model, len(self.tokenizer.diag_voc.id2word))

    def run_graph(self, B, input_ids, token_types, adm_index, age_gender_index):
        """Run the heterogeneous graph neural network.

        Args:
            B (int): Batch size.
            input_ids (Tensor): Input token IDs. Shape of [B, L]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            adm_index (Tensor): Admission index IDs. Shape of [B, L]
            age_gender_index (Tensor): Age and gender index IDs, one row per patient.

        Returns:
            visit_emb_pad: Padded visit embeddings. Shape of [B, V_max, d]
            visit_pad_mask: Visit padding mask, True on padded slots. Shape of [B, V_max]
            visit_index_pad: Padded visit index (1-based, pad id on padded slots). Shape of [B, V_max]
            visit_type_ids_pad: Padded visit type IDs. Shape of [B, V_max], using self.visit_type_id and 0.

            V_max is ``self.max_num_adms``.
        """
        # one heterogeneous graph per patient
        graphs = [
            self.build_patient_graph(input_ids[p], token_types[p], adm_index[p], age_gender_index[p])
            for p in range(B)
        ]

        batch_graph = HeteroBatch.from_data_list(graphs).to(self.device)
        h_visit_all = self.het_gnn(batch_graph)['visit']  # visit-node outputs of all patients, concatenated

        visit_emb_pad = []
        visit_pad_mask = []
        visit_index_pad = []
        offset = 0
        for hg_p in graphs:
            # Slice this patient's visit nodes out of the batched output; the batch keeps the
            # graphs in list order, so consecutive slices line up with consecutive patients.
            n_v = hg_p['visit'].num_nodes
            v = h_visit_all[offset:offset + n_v]  # [N_visit_p, d]
            offset += n_v

            Np = v.size(0)
            if Np < self.max_num_adms:
                # pad up to V_max slots
                pad_len = self.max_num_adms - Np
                v_pad = torch.cat([v, torch.zeros(pad_len, self.d_model, device=self.device, dtype=v.dtype)], dim=0)
                m_pad = torch.cat([torch.zeros(Np, dtype=torch.bool, device=self.device), torch.ones(pad_len, dtype=torch.bool, device=self.device)], dim=0)
                i_pad = torch.cat([torch.arange(1, Np + 1, device=self.device), torch.full((pad_len,), self.adm_pad_id, dtype=torch.long, device=self.device)], dim=0)
            else:
                # keep only the first V_max visits
                v_pad = v[:self.max_num_adms]
                m_pad = torch.zeros(self.max_num_adms, dtype=torch.bool, device=self.device)
                i_pad = torch.arange(1, self.max_num_adms + 1, device=self.device)
            visit_emb_pad.append(v_pad)      # [V_max, d]
            visit_pad_mask.append(m_pad)     # [V_max]
            visit_index_pad.append(i_pad)    # [V_max]

        visit_emb_pad  = torch.stack(visit_emb_pad,  dim=0)  # [B, V_max, d]
        visit_pad_mask = torch.stack(visit_pad_mask, dim=0)  # [B, V_max]
        visit_index_pad = torch.stack(visit_index_pad, dim=0)  # [B, V_max]

        # visit type id on real slots, 0 on padded slots
        visit_type_ids = torch.full((B, self.max_num_adms), self.visit_type_id, dtype=torch.long, device=self.device) # [B, V_max]
        visit_type_ids_pad = visit_type_ids * (~visit_pad_mask)

        return visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad

    def build_patient_graph(self, id_seq: torch.Tensor, type_seq: torch.Tensor, adm_seq: torch.Tensor, age_sex: torch.Tensor):
        """Build a heterogeneous graph for a single patient.

        The graph has one 'visit' node per distinct non-pad admission index
        (features: age/sex embedding) and one 'occ' node per non-pad diagnosis
        token (features: token embedding). Each 'occ' node is linked in both
        directions to the visit it belongs to, and consecutive visits are linked
        by forward 'next' edges.

        Args:
            id_seq (torch.Tensor): Sequence of token IDs of the current patient. Shape of [L]
            type_seq (torch.Tensor): Sequence of token types. Shape of [L]
            adm_seq (torch.Tensor): Sequence of admission IDs. Shape of [L]
            age_sex (torch.Tensor): Age/sex IDs of the patient's visits, padded with the age/sex pad id.

        Returns:
            a heterogeneous graph for the patient.

        Raises:
            AssertionError: if the number of distinct admissions in ``adm_seq`` differs
                from the number of non-pad entries in ``age_sex``.
        """
        hg = HeteroData()
        # positions holding a non-pad diagnosis token; each becomes one 'occ' node
        occ_mask = (type_seq == self.diag_type_id) & (id_seq != self.seq_pad_id) # [L]
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1) # [N_occ], sequence positions of those tokens
        N_occ = occ_pos.numel()

        # build visit virtual nodes
        nonpad = adm_seq != self.adm_pad_id # [L], mask of non-pad positions
        adm_used = adm_seq[nonpad] # admission indices at the non-pad positions
        # sorted distinct admission indices, plus each position's rank among them (its local visit id)
        adm_ids_unique, adm_lid_nonpad = torch.unique(adm_used, return_inverse=True) # adm_ids_unique: [N_visit]
        adm_lid_full = torch.full_like(id_seq, fill_value=-1) # [L], -1 on pad positions
        adm_lid_full[nonpad] = adm_lid_nonpad
        N_visit = adm_ids_unique.numel()
        age_sex_nonpad = age_sex[age_sex!=self.age_sex_pad_id]
        assert N_visit == len(age_sex_nonpad), (
            f"found {N_visit} admissions in the sequence but {len(age_sex_nonpad)} age/sex entries"
        )
        adm_x = self.age_sex_emb(age_sex_nonpad.to(self.device)) # [N_visit, d]
        hg['visit'].x = adm_x
        hg['visit'].num_nodes = N_visit

        # build diag nodes
        gid_occ = id_seq[occ_pos]
        x_occ = self.token_emb(gid_occ) # [N_occ, d]
        hg['occ'].x = x_occ
        hg['occ'].num_nodes = N_occ

        # build edges between diag nodes and virtual visit nodes
        occ_adm_lid = adm_lid_full[occ_pos] # local visit id of every 'occ' node
        e_v2o = torch.stack([occ_adm_lid, torch.arange(N_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(N_occ, device=self.device), occ_adm_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v

        # build forward edges between virtual visit nodes
        if N_visit > 1:
            src = torch.arange(0, N_visit - 1, device=self.device)
            dst = torch.arange(1, N_visit, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, N_visit-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg

    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        """Build a per-head boolean attention mask from token types.

        Args:
            token_types (Tensor): Integer token types. Shape of [B, L]
            forbid_map (dict | None): Maps a query token type to the key token types
                it must not attend to; the relation is one-directional. ``None``
                forbids nothing.
            num_heads (int): Number of attention heads to replicate the mask for.

        Returns:
            Bool tensor of shape [B * num_heads, L, L]; entry ``[b * num_heads + h, q, k]``
            is True when position ``q`` of sample ``b`` is forbidden from attending to
            position ``k``. The mask is identical across heads.
        """
        B, L = token_types.shape
        device = token_types.device

        if forbid_map is None:
            mask = torch.zeros((B, L, L), dtype=torch.bool, device=device)
        else:
            # Every type that occurs in the batch or is named by forbid_map. torch.unique
            # returns the values sorted, which searchsorted below relies on.
            referenced = [t for q_t, ks in forbid_map.items() for t in (q_t, *ks)]
            type_list = torch.unique(torch.cat([
                token_types.reshape(-1),
                torch.tensor(referenced, dtype=token_types.dtype, device=device),
            ]))
            t2i = {t: i for i, t in enumerate(type_list.tolist())}  # token type -> row/column index
            T = len(t2i)

            # ban_table[q, k] is True when query type q may not attend to key type k
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                for k_t in ks:
                    ban_table[t2i[q_t], t2i[k_t]] = True  # one direction only: q -> k

            # Each token's index into type_list. Every type in the batch is in type_list,
            # so searchsorted returns its exact position.
            type_idx = torch.searchsorted(type_list, token_types.contiguous())  # [B, L]

            # broadcast query index over rows and key index over columns -> [B, L, L]
            mask = ban_table[type_idx.unsqueeze(-1), type_idx.unsqueeze(-2)]

        # replicate for every head
        mask = mask.unsqueeze(1).expand(B, num_heads, L, L)
        mask = mask.reshape(B * num_heads, L, L)
        return mask

    def forward(self, input_ids, token_types, adm_index, age_gender_index, task_id):
        """Compute classification logits for a batch.

        Args:
            input_ids (Tensor): Token IDs, pad id 0. Shape of [B, L]
            token_types (Tensor): Token type IDs, pad id 0. Shape of [B, L]
            adm_index (Tensor): Admission indices, pad id 0. Shape of [B, L]
            age_gender_index (Tensor): Age/gender IDs per visit, one row per patient.
            task_id: Index of the current task; broadcast to every sample.

        Returns:
            Logits of shape [B, 1] for the binary tasks, [B, label_size] otherwise.
        """
        B, L = input_ids.shape
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device) # scalar -> [B]
        # base representation
        token_embed = self.token_emb(input_ids)  # [B, L, d]
        seq_pad_mask = (input_ids == self.seq_pad_id) # [B, L], bool

        forbid_map = None # a dict that encodes the forbidden attentions,
        # e.g., {3: [1, 2, 4]} means that token type 3 cannot attend to token types 1, 2, and 4, it is uni-directional
        if self.use_hetero_graph: # use graph as middle layer
            # get visit embed and mask
            visit_emb_pad, visit_pad_mask, visit_index_pad, visit_type_ids_pad = self.run_graph(B, input_ids, token_types, adm_index, age_gender_index)
            # elongate the main sequence with visit information
            token_embed = torch.concat([token_embed, visit_emb_pad], dim=1)  # [B, L+V, d]
            adm_index = torch.concat([adm_index, visit_index_pad], dim=1)  # [B, L+V]
            token_types = torch.concat([token_types, visit_type_ids_pad], dim=1) # [B, L+V]
            seq_pad_mask = torch.concat([seq_pad_mask, visit_pad_mask], dim=1) # [B, L+V]

        adm_emb = self.adm_index_emb(adm_index) # [B, L+V, d]
        token_type_emb = self.type_emb(token_types) # [B, L+V, d]
        x = token_embed + adm_emb + token_type_emb # [B, L+V, d]
        task_id_emb = self.task_emb(task_id).unsqueeze(1) # [B, 1, d]
        x = torch.concat([task_id_emb, x], dim=1) # [B, 1+L+V, d]

        # mask
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device)
        src_key_padding_mask = torch.concat([task_pad_mask, seq_pad_mask], dim=1)  # [B, 1+L+V]
        # the task slot gets the pseudo type -1, created with the same dtype as token_types
        task_type = torch.full((B, 1), -1, dtype=token_types.dtype, device=self.device)
        attn_mask = self.build_attn_mask(torch.concat([task_type, token_types], dim=1),
                                         forbid_map=forbid_map,
                                         num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"

        # ===== Transformer encoding =====
        h = self.encoder(src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask)   # [B, 1+L(+V), d]

        # ===== classification: read the task slot at position 0 =====
        logits = self.cls_head(h[:, 0, :])  # [B, label_size]
        return logits

In [18]:
# Train five freshly constructed models and collect, for every metric, the value of
# `best_test_metric` returned by each run.
# NOTE(review): the seed is not reset between runs, so the runs differ only through
# the state the random generators are left in by the previous run -- confirm intended.
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, num_layers=2, max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, use_hetero_graph=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 21.26it/s, loss=0.6710]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.95it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 44.77it/s]


Validation: {'precision': 0.7570834907413786, 'recall': 0.6284101599227707, 'f1': 0.6867717565212385, 'auc': 0.7894704531929437, 'prauc': 0.7930792404377097}
Test:      {'precision': 0.7393655371278322, 'recall': 0.6431483223560892, 'f1': 0.6879087657748875, 'auc': 0.7790617023707599, 'prauc': 0.7860167804328437}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 28.82it/s, loss=0.5732]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.75it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 44.62it/s]


Validation: {'precision': 0.6873655913960017, 'recall': 0.8018187519573478, 'f1': 0.7401939449477855, 'auc': 0.8036371723570359, 'prauc': 0.8156585775057466}
Test:      {'precision': 0.6744680851045892, 'recall': 0.7952336155509714, 'f1': 0.7298891877212547, 'auc': 0.7891822159540194, 'prauc': 0.8022493903051733}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 28.91it/s, loss=0.5295]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.10it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.47it/s]


Validation: {'precision': 0.7479958173553574, 'recall': 0.6729382251468393, 'f1': 0.7084846434104237, 'auc': 0.7998909158756133, 'prauc': 0.8014714628783864}
Test:      {'precision': 0.7476796149853981, 'recall': 0.6820319849461209, 'f1': 0.7133486339062014, 'auc': 0.7985446561640712, 'prauc': 0.7961504596538833}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s, loss=0.5121]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.97it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.80it/s]


Validation: {'precision': 0.7961165048510084, 'recall': 0.591407964877418, 'f1': 0.6786613840952359, 'auc': 0.806327502150385, 'prauc': 0.8219869828375939}
Test:      {'precision': 0.791236047950429, 'recall': 0.6001881467525865, 'f1': 0.6825962861046663, 'auc': 0.8031152431587918, 'prauc': 0.8198482713508173}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.85it/s, loss=0.4812]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.41it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 46.09it/s]


Validation: {'precision': 0.7132205334077625, 'recall': 0.7798682972694266, 'f1': 0.7450569153192944, 'auc': 0.8077675391889514, 'prauc': 0.8128063325315595}
Test:      {'precision': 0.7065615319608376, 'recall': 0.7867670115999161, 'f1': 0.7445103807688908, 'auc': 0.8066177525371877, 'prauc': 0.8111498578508225}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 28.68it/s, loss=0.4465]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.15it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.61it/s]


Validation: {'precision': 0.7552631578922524, 'recall': 0.7199749137638132, 'f1': 0.7371969768595404, 'auc': 0.8157608625211483, 'prauc': 0.8267580635971463}
Test:      {'precision': 0.7409105797551887, 'recall': 0.7093132643439658, 'f1': 0.7247676976594555, 'auc': 0.807432602138068, 'prauc': 0.8218491147745572}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.16it/s, loss=0.4351]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.05it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.69it/s]


Validation: {'precision': 0.7694950188912075, 'recall': 0.7024145500134763, 'f1': 0.7344262245161736, 'auc': 0.8224284877054044, 'prauc': 0.8214763248340963}
Test:      {'precision': 0.7561634582885641, 'recall': 0.7021009720893632, 'f1': 0.7281300763053173, 'auc': 0.8175399283351126, 'prauc': 0.8190578595199994}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.01it/s, loss=0.4085]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.65it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.31it/s]


Validation: {'precision': 0.7957254418380776, 'recall': 0.607086861083076, 'f1': 0.6887228695098772, 'auc': 0.8127686599713109, 'prauc': 0.8155875923437091}
Test:      {'precision': 0.7935272429299733, 'recall': 0.6074004390071891, 'f1': 0.6880994622261338, 'auc': 0.8134131321199931, 'prauc': 0.8169380569939692}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.01it/s, loss=0.3668]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.10it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.03it/s]


Validation: {'precision': 0.7566154851331048, 'recall': 0.7262464722460764, 'f1': 0.7411199949997257, 'auc': 0.8296220416444746, 'prauc': 0.8373906151955047}
Test:      {'precision': 0.7532133676068342, 'recall': 0.7350266541212448, 'f1': 0.7440088824765633, 'auc': 0.8248970301704244, 'prauc': 0.834327804297587}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s, loss=0.3454]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.19it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.93it/s]


Validation: {'precision': 0.7592775041025969, 'recall': 0.7249921605496238, 'f1': 0.7417388464600253, 'auc': 0.8249124448437933, 'prauc': 0.8287111509071432}
Test:      {'precision': 0.756152849738484, 'recall': 0.7322044528042264, 'f1': 0.7439859755628877, 'auc': 0.82402500167359, 'prauc': 0.8295150063174397}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.31it/s, loss=0.3025]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.28it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.82it/s]


Validation: {'precision': 0.7329615861191668, 'recall': 0.7419253684517343, 'f1': 0.7374162331153868, 'auc': 0.812151300293413, 'prauc': 0.812330286083273}
Test:      {'precision': 0.7299134734217246, 'recall': 0.7406710567552817, 'f1': 0.7352529132859166, 'auc': 0.814197831933306, 'prauc': 0.8183452556416169}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 26.33it/s, loss=0.2915]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.63it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.06it/s]


Validation: {'precision': 0.6883793642504211, 'recall': 0.8284728755069662, 'f1': 0.7519567333356572, 'auc': 0.8147360826310484, 'prauc': 0.8193375246046865}
Test:      {'precision': 0.6929665195933222, 'recall': 0.8372530573821347, 'f1': 0.758307294105067, 'auc': 0.8168827233663876, 'prauc': 0.8249521083687893}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 29.35it/s, loss=0.2527]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.12it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.02it/s]


Validation: {'precision': 0.65473330129588, 'recall': 0.8544998432083585, 'f1': 0.7413957235578994, 'auc': 0.7977914709818847, 'prauc': 0.8021744298870811}
Test:      {'precision': 0.6583715873383466, 'recall': 0.8544998432083585, 'f1': 0.7437227025055834, 'auc': 0.8029829162948277, 'prauc': 0.8125763843877304}


Epoch 014: 100%|██████████| 98/98 [00:03<00:00, 29.20it/s, loss=0.2525]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.31it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.91it/s]


Validation: {'precision': 0.7507072135758461, 'recall': 0.6657259328922367, 'f1': 0.7056672710668408, 'auc': 0.8000794042401242, 'prauc': 0.8000033383524767}
Test:      {'precision': 0.7445558244011595, 'recall': 0.6754468485397446, 'f1': 0.7083196267096125, 'auc': 0.8027279769391875, 'prauc': 0.8100030817850019}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7566154851331048, 'recall': 0.7262464722460764, 'f1': 0.7411199949997257, 'auc': 0.8296220416444746, 'prauc': 0.8373906151955047}
Corresponding test performance:
{'precision': 0.7532133676068342, 'recall': 0.7350266541212448, 'f1': 0.7440088824765633, 'auc': 0.8248970301704244, 'prauc': 0.834327804297587}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.01it/s, loss=0.6592]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.12it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.08it/s]


Validation: {'precision': 0.7109869646156015, 'recall': 0.5986202571320206, 'f1': 0.6499829708602691, 'auc': 0.7437492646592656, 'prauc': 0.7402382191487381}
Test:      {'precision': 0.7029992684685333, 'recall': 0.6026967701454917, 'f1': 0.6489954365265548, 'auc': 0.7392615164386809, 'prauc': 0.7345797724386656}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.23it/s, loss=0.5767]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.53it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.24it/s]


Validation: {'precision': 0.7406981204421147, 'recall': 0.6055189714625101, 'f1': 0.6663215961523256, 'auc': 0.7709705111159405, 'prauc': 0.7840358300498396}
Test:      {'precision': 0.7341255105802669, 'recall': 0.6199435559717155, 'f1': 0.6722203282532612, 'auc': 0.7709288097578606, 'prauc': 0.7868821338647192}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.33it/s, loss=0.5505]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.43it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.32it/s]


Validation: {'precision': 0.636425648020381, 'recall': 0.8777046095927322, 'f1': 0.7378410390158856, 'auc': 0.79802256012814, 'prauc': 0.8070867336869428}
Test:      {'precision': 0.6368044331548446, 'recall': 0.8648479147040927, 'f1': 0.7335106334112316, 'auc': 0.7900734215277536, 'prauc': 0.8047923913386558}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.24it/s, loss=0.5041]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.92it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.26it/s]


Validation: {'precision': 0.7584041374186982, 'recall': 0.6437754782043156, 'f1': 0.6964043369577834, 'auc': 0.7970523876080355, 'prauc': 0.8066171762737393}
Test:      {'precision': 0.7452457839944555, 'recall': 0.6513013483830313, 'f1': 0.6951137835075819, 'auc': 0.7888292268718664, 'prauc': 0.8018815495676275}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.33it/s, loss=0.4827]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.33it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.19it/s]


Validation: {'precision': 0.7326607818388, 'recall': 0.7287550956389817, 'f1': 0.7307027146958991, 'auc': 0.7978325144237305, 'prauc': 0.807458308219293}
Test:      {'precision': 0.7239321608017465, 'recall': 0.7227971150808317, 'f1': 0.7233641876856353, 'auc': 0.7952054287737693, 'prauc': 0.8061569700779221}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 28.97it/s, loss=0.4532]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.43it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.19it/s]


Validation: {'precision': 0.7257814917959586, 'recall': 0.735340232045358, 'f1': 0.7305295900135147, 'auc': 0.8028643298731251, 'prauc': 0.8088340706417713}
Test:      {'precision': 0.7159193499827989, 'recall': 0.7460018814652054, 'f1': 0.7306511006509788, 'auc': 0.8005093250423935, 'prauc': 0.8081789755854462}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.29it/s, loss=0.4290]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.14it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.20it/s]


Validation: {'precision': 0.6640936254963543, 'recall': 0.8363123236097952, 'f1': 0.7403192178257898, 'auc': 0.7915009732369631, 'prauc': 0.7953513472411644}
Test:      {'precision': 0.6604814443313428, 'recall': 0.825964252114061, 'f1': 0.7340114204465759, 'auc': 0.790547966097344, 'prauc': 0.796463602168954}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 28.98it/s, loss=0.4108]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.96it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.72it/s]


Validation: {'precision': 0.7121125143492764, 'recall': 0.7779868297247476, 'f1': 0.7435935811007547, 'auc': 0.8042621680241144, 'prauc': 0.8078166377874247}
Test:      {'precision': 0.6966891133538252, 'recall': 0.778613985572974, 'f1': 0.7353768645527045, 'auc': 0.797855036398696, 'prauc': 0.8048042055149245}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.15it/s, loss=0.3833]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.68it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.75it/s]


Validation: {'precision': 0.7182841496783746, 'recall': 0.7403574788311685, 'f1': 0.7291537936400299, 'auc': 0.7975994660434091, 'prauc': 0.7985622861407577}
Test:      {'precision': 0.7013682331924409, 'recall': 0.739416745058829, 'f1': 0.7198900881168441, 'auc': 0.7878453069263474, 'prauc': 0.7930459078595182}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.04it/s, loss=0.3403]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.81it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.57it/s]


Validation: {'precision': 0.70011668611231, 'recall': 0.7525870178715818, 'f1': 0.7254042567543684, 'auc': 0.7900068311961104, 'prauc': 0.7943873536990098}
Test:      {'precision': 0.6923520923500942, 'recall': 0.7522734399474686, 'f1': 0.7210700280692547, 'auc': 0.78341706316909, 'prauc': 0.7892574102921825}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 28.48it/s, loss=0.3237]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.43it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.15it/s]


Validation: {'precision': 0.7361563517891331, 'recall': 0.7086861084957394, 'f1': 0.7221600844706615, 'auc': 0.7953892993065617, 'prauc': 0.7924783038574938}
Test:      {'precision': 0.7249683143196294, 'recall': 0.7174662903709079, 'f1': 0.7211977885360811, 'auc': 0.7875899145688907, 'prauc': 0.7912287966286912}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7257814917959586, 'recall': 0.735340232045358, 'f1': 0.7305295900135147, 'auc': 0.8028643298731251, 'prauc': 0.8088340706417713}
Corresponding test performance:
{'precision': 0.7159193499827989, 'recall': 0.7460018814652054, 'f1': 0.7306511006509788, 'auc': 0.8005093250423935, 'prauc': 0.8081789755854462}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.36it/s, loss=0.6466]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.22it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.97it/s]


Validation: {'precision': 0.8742603550231195, 'recall': 0.370649106301754, 'f1': 0.5205901741907633, 'auc': 0.8105676363267716, 'prauc': 0.8145171758645502}
Test:      {'precision': 0.8938307030065006, 'recall': 0.3907180934449962, 'f1': 0.5437486320290896, 'auc': 0.8081664147473231, 'prauc': 0.8183927563432456}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.15it/s, loss=0.5900]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.14it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.20it/s]


Validation: {'precision': 0.7602201257831752, 'recall': 0.6064597052348496, 'f1': 0.6746903840370382, 'auc': 0.7912819409157299, 'prauc': 0.7984879935139151}
Test:      {'precision': 0.7554812337392959, 'recall': 0.6375039197220523, 'f1': 0.6914965936729691, 'auc': 0.7902037350503108, 'prauc': 0.7933248082010108}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.13it/s, loss=0.5545]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.21it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.30it/s]


Validation: {'precision': 0.5857772666272076, 'recall': 0.9582941360898141, 'f1': 0.7270996859864142, 'auc': 0.8068063088140716, 'prauc': 0.8148864773207966}
Test:      {'precision': 0.5826923076911872, 'recall': 0.9501411100628719, 'f1': 0.7223745333711914, 'auc': 0.8055892874130425, 'prauc': 0.8194315146751447}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.46it/s, loss=0.5113]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.42it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.03it/s]


Validation: {'precision': 0.679947916664896, 'recall': 0.8187519598594584, 'f1': 0.742922174582673, 'auc': 0.8023184068834825, 'prauc': 0.8110903966229119}
Test:      {'precision': 0.6711065573753302, 'recall': 0.8215741611764767, 'f1': 0.7387565155619059, 'auc': 0.8000660375981441, 'prauc': 0.8109406674131272}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.38it/s, loss=0.4909]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.40it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.15it/s]


Validation: {'precision': 0.7024121657034547, 'recall': 0.8400752586991531, 'f1': 0.7651006661785802, 'auc': 0.8295964711280738, 'prauc': 0.8378756918199562}
Test:      {'precision': 0.6892655367213938, 'recall': 0.841643148319719, 'f1': 0.7578709536807426, 'auc': 0.8269267299761872, 'prauc': 0.8367454739544032}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.31it/s, loss=0.4368]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.94it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.75it/s]


Validation: {'precision': 0.8334060183129812, 'recall': 0.599247412980247, 'f1': 0.6971908014061083, 'auc': 0.8331551435068622, 'prauc': 0.8448416776114743}
Test:      {'precision': 0.8239258635180121, 'recall': 0.6133584195653391, 'f1': 0.7032176834024696, 'auc': 0.8319729688782719, 'prauc': 0.8466075035069052}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.96it/s, loss=0.4265]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.04it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.92it/s]


Validation: {'precision': 0.7374387254879368, 'recall': 0.7547820633403739, 'f1': 0.7460096029326575, 'auc': 0.8207677605321642, 'prauc': 0.8253342381214253}
Test:      {'precision': 0.7267303102603618, 'recall': 0.7638758231396555, 'f1': 0.7448402334964705, 'auc': 0.8212824783427382, 'prauc': 0.827950856127843}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 27.63it/s, loss=0.4149]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.76it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.62it/s]


Validation: {'precision': 0.8529111338056542, 'recall': 0.5236751332689756, 'f1': 0.6489216972457944, 'auc': 0.8245798271834985, 'prauc': 0.8363917114814909}
Test:      {'precision': 0.8493493493450984, 'recall': 0.5321417372200309, 'f1': 0.6543281232734252, 'auc': 0.8188920884239412, 'prauc': 0.8335599651168355}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.20it/s, loss=0.3819]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.01it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.53it/s]


Validation: {'precision': 0.8038904327082022, 'recall': 0.6349952963291471, 'f1': 0.7095304785982883, 'auc': 0.8201009176349051, 'prauc': 0.8371341240279162}
Test:      {'precision': 0.7940729483252505, 'recall': 0.6553778613965024, 'f1': 0.7180896703568339, 'auc': 0.8209201272230435, 'prauc': 0.8348170592716337}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 25.94it/s, loss=0.3415]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.40it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.61it/s]


Validation: {'precision': 0.8063205417570821, 'recall': 0.5600501724661021, 'f1': 0.66099185304303, 'auc': 0.8116441098541536, 'prauc': 0.8189813290388949}
Test:      {'precision': 0.8103448275826245, 'recall': 0.5747883348994206, 'f1': 0.672537144289081, 'auc': 0.820184754274198, 'prauc': 0.8282868904859615}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 28.98it/s, loss=0.2978]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.85it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.41it/s]


Validation: {'precision': 0.7338961851134025, 'recall': 0.7359673878935843, 'f1': 0.7349303222248889, 'auc': 0.814127062257124, 'prauc': 0.8184955960942248}
Test:      {'precision': 0.7240319606615734, 'recall': 0.7387895892106028, 'f1': 0.7313363290041367, 'auc': 0.8128592618989736, 'prauc': 0.8224725381605369}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8334060183129812, 'recall': 0.599247412980247, 'f1': 0.6971908014061083, 'auc': 0.8331551435068622, 'prauc': 0.8448416776114743}
Corresponding test performance:
{'precision': 0.8239258635180121, 'recall': 0.6133584195653391, 'f1': 0.7032176834024696, 'auc': 0.8319729688782719, 'prauc': 0.8466075035069052}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.21it/s, loss=0.6970]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.29it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.15it/s]


Validation: {'precision': 0.652702353243722, 'recall': 0.7914706804616135, 'f1': 0.715419496177924, 'auc': 0.7668195472079458, 'prauc': 0.7774008917648437}
Test:      {'precision': 0.6638787245147312, 'recall': 0.796487927247424, 'f1': 0.7241625039486433, 'auc': 0.7773907296701594, 'prauc': 0.789442727205038}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.09it/s, loss=0.5725]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.06it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.13it/s]


Validation: {'precision': 0.9175879396892704, 'recall': 0.28629664471531424, 'f1': 0.43642447056015893, 'auc': 0.8088089474502277, 'prauc': 0.8212923842560595}
Test:      {'precision': 0.9295918367252083, 'recall': 0.28566948886708793, 'f1': 0.4370352566559366, 'auc': 0.8102563134611502, 'prauc': 0.8255603138737563}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.19it/s, loss=0.5485]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.01it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.61it/s]


Validation: {'precision': 0.755023183922894, 'recall': 0.6127312637171128, 'f1': 0.6764756744701559, 'auc': 0.8005415322958638, 'prauc': 0.8085814212720834}
Test:      {'precision': 0.7574838954120596, 'recall': 0.626842270302205, 'f1': 0.6859986223585798, 'auc': 0.8011636609996139, 'prauc': 0.8112279769337214}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.89it/s, loss=0.5173]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.06it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.80it/s]


Validation: {'precision': 0.6759236300503945, 'recall': 0.8548134211324716, 'f1': 0.7549155309288386, 'auc': 0.8192474953704304, 'prauc': 0.8247749997897528}
Test:      {'precision': 0.6733990147766665, 'recall': 0.8573220445253769, 'f1': 0.7543109345100271, 'auc': 0.8201384474218408, 'prauc': 0.8309264761662789}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.03it/s, loss=0.4913]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 47.92it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.21it/s]


Validation: {'precision': 0.717587034811772, 'recall': 0.7497648165545633, 'f1': 0.7333231049526167, 'auc': 0.8118808757555234, 'prauc': 0.8222185795076155}
Test:      {'precision': 0.7195484254286407, 'recall': 0.7594857322020713, 'f1': 0.7389778744827029, 'auc': 0.8140677700784247, 'prauc': 0.8273169827489674}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.37it/s, loss=0.4575]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.28it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.19it/s]


Validation: {'precision': 0.8040421792583302, 'recall': 0.5738476011270811, 'f1': 0.6697163720812899, 'auc': 0.8103868341882126, 'prauc': 0.8238356159373182}
Test:      {'precision': 0.8047538200305401, 'recall': 0.5945437441185496, 'f1': 0.683859327842563, 'auc': 0.8137495614690748, 'prauc': 0.8305591509270315}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.31it/s, loss=0.4431]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.07it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.13it/s]


Validation: {'precision': 0.7595916930631482, 'recall': 0.6767011602361973, 'f1': 0.7157545555449591, 'auc': 0.811548609752625, 'prauc': 0.8177946788640563}
Test:      {'precision': 0.7648489058674371, 'recall': 0.6904985888971763, 'f1': 0.7257745500535054, 'auc': 0.820271428621762, 'prauc': 0.8296053433690486}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 26.62it/s, loss=0.4123]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.88it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.25it/s]


Validation: {'precision': 0.7095652173892476, 'recall': 0.7676387582290134, 'f1': 0.7374604559182941, 'auc': 0.8114811920060436, 'prauc': 0.8192519930289534}
Test:      {'precision': 0.7117376294571008, 'recall': 0.7757917842559555, 'f1': 0.7423855914061434, 'auc': 0.813957942304682, 'prauc': 0.8282532581653597}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.24it/s, loss=0.3500]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.48it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.04it/s]


Validation: {'precision': 0.7415658816017137, 'recall': 0.7306365631836607, 'f1': 0.7360606489230808, 'auc': 0.8124614621171042, 'prauc': 0.8196119091800256}
Test:      {'precision': 0.7389240506305731, 'recall': 0.7322044528042264, 'f1': 0.735548900337211, 'auc': 0.8215157239447196, 'prauc': 0.8330503238850564}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6759236300503945, 'recall': 0.8548134211324716, 'f1': 0.7549155309288386, 'auc': 0.8192474953704304, 'prauc': 0.8247749997897528}
Corresponding test performance:
{'precision': 0.6733990147766665, 'recall': 0.8573220445253769, 'f1': 0.7543109345100271, 'auc': 0.8201384474218408, 'prauc': 0.8309264761662789}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.37it/s, loss=0.6556]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.36it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.11it/s]


Validation: {'precision': 0.7217074784886773, 'recall': 0.6839134524907999, 'f1': 0.7023023617699896, 'auc': 0.7838651347666689, 'prauc': 0.7911542649444518}
Test:      {'precision': 0.7267670157044281, 'recall': 0.6964565694553263, 'f1': 0.7112890262249699, 'auc': 0.780195062582201, 'prauc': 0.7856526099465735}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.30it/s, loss=0.5914]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.24it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.87it/s]


Validation: {'precision': 0.6245898052928799, 'recall': 0.8952649733430691, 'f1': 0.73582473742473, 'auc': 0.8072556264932252, 'prauc': 0.8177872791512203}
Test:      {'precision': 0.6258578702664913, 'recall': 0.8864847914679007, 'f1': 0.7337139842459925, 'auc': 0.7995581722282706, 'prauc': 0.8123289464345484}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 28.49it/s, loss=0.5498]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.21it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 45.18it/s]


Validation: {'precision': 0.7346733668317097, 'recall': 0.6876763875801578, 'f1': 0.7103984400954803, 'auc': 0.7965364057338644, 'prauc': 0.8120940142888605}
Test:      {'precision': 0.7260769483698584, 'recall': 0.6923800564418552, 'f1': 0.7088282454018304, 'auc': 0.7930757665671572, 'prauc': 0.8071657964956187}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.39it/s, loss=0.5062]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.09it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.06it/s]


Validation: {'precision': 0.7229488703902409, 'recall': 0.7626215114432029, 'f1': 0.7422554505178578, 'auc': 0.8162002333196589, 'prauc': 0.8225857652072157}
Test:      {'precision': 0.7281268349949263, 'recall': 0.7776732518006345, 'f1': 0.7520849078158696, 'auc': 0.8189452909706167, 'prauc': 0.8267736786707521}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.02it/s, loss=0.4811]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.31it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.21it/s]


Validation: {'precision': 0.7555123216577319, 'recall': 0.7306365631836607, 'f1': 0.7428662471909663, 'auc': 0.8225150958984792, 'prauc': 0.8325817640764654}
Test:      {'precision': 0.7584650112842358, 'recall': 0.7375352775141502, 'f1': 0.7478537310876309, 'auc': 0.8235538797843913, 'prauc': 0.8318677286589263}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.21it/s, loss=0.4336]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 46.92it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 47.47it/s]


Validation: {'precision': 0.7436868686845212, 'recall': 0.7387895892106028, 'f1': 0.7412301350008688, 'auc': 0.8225400635736289, 'prauc': 0.8331693279929278}
Test:      {'precision': 0.7416744475544922, 'recall': 0.747256193161658, 'f1': 0.7444548528546644, 'auc': 0.8234368039815838, 'prauc': 0.8332993994456532}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.78it/s, loss=0.3966]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 49.51it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.40it/s]


Validation: {'precision': 0.7563850687598023, 'recall': 0.7243650047013974, 'f1': 0.7400288272921345, 'auc': 0.8227638683881, 'prauc': 0.8298237354045601}
Test:      {'precision': 0.7534898477133456, 'recall': 0.7447475697687528, 'f1': 0.7490931979626398, 'auc': 0.8266423958360074, 'prauc': 0.8336631870189529}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.41it/s, loss=0.3724]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 45.16it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.15it/s]


Validation: {'precision': 0.7210210210188558, 'recall': 0.7529005957956949, 'f1': 0.7366160404058163, 'auc': 0.8097656565157243, 'prauc': 0.8151284545504212}
Test:      {'precision': 0.7253373313321579, 'recall': 0.7585449984297318, 'f1': 0.741569584209305, 'auc': 0.8135206948846532, 'prauc': 0.8194030877802825}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.28it/s, loss=0.3464]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.28it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.65it/s]


Validation: {'precision': 0.7386046511605004, 'recall': 0.7469426152375449, 'f1': 0.7427502288612653, 'auc': 0.8201197564240019, 'prauc': 0.8266756263643691}
Test:      {'precision': 0.7437170338171154, 'recall': 0.7516462840992423, 'f1': 0.7476606313047331, 'auc': 0.8244155899065154, 'prauc': 0.83265665821936}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 28.69it/s, loss=0.3097]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 48.47it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 49.10it/s]


Validation: {'precision': 0.7497618291497309, 'recall': 0.7403574788311685, 'f1': 0.7450299729088612, 'auc': 0.8209445939658203, 'prauc': 0.8275478683388723}
Test:      {'precision': 0.7448494453225203, 'recall': 0.7369081216659238, 'f1': 0.7408574981503931, 'auc': 0.8240953176222238, 'prauc': 0.8315355793010908}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.27it/s, loss=0.2644]
Running inference: 100%|██████████| 198/198 [00:04<00:00, 49.11it/s]
Running inference: 100%|██████████| 197/197 [00:04<00:00, 48.75it/s]

Validation: {'precision': 0.7046703296683938, 'recall': 0.804327375350253, 'f1': 0.7512080781943039, 'auc': 0.813847544863697, 'prauc': 0.8171968265078629}
Test:      {'precision': 0.7081627081607749, 'recall': 0.8134211351495346, 'f1': 0.7571511917525988, 'auc': 0.8201537991500678, 'prauc': 0.8255005493252348}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7436868686845212, 'recall': 0.7387895892106028, 'f1': 0.7412301350008688, 'auc': 0.8225400635736289, 'prauc': 0.8331693279929278}
Corresponding test performance:
{'precision': 0.7416744475544922, 'recall': 0.747256193161658, 'f1': 0.7444548528546644, 'auc': 0.8234368039815838, 'prauc': 0.8332993994456532}





In [19]:
# Report every collected metric aggregated over the runs as "mean ± std".
# np.std is called with its default arguments.
print("\nFinal Metrics:")
for metric_name, run_scores in final_metrics.items():
    summary = f"{np.mean(run_scores):.4f} ± {np.std(run_scores):.4f}"
    print(f"{metric_name}: {summary}")


Final Metrics:
precision: 0.7416 ± 0.0494
recall: 0.7398 ± 0.0774
f1: 0.7353 ± 0.0177
auc: 0.8202 ± 0.0106
prauc: 0.8307 ± 0.0125


In [20]:
# Repeat training five times with a freshly constructed HeteroGT model
# (use_hetero_graph=False) and record, per metric, the value found in the
# dict returned by train_with_early_stopping for each run.
final_metrics = {
    metric_name: []
    for metric_name in ("precision", "recall", "f1", "auc", "prauc")
}
for run_idx in range(5):
    # A new model and optimizer are created on every iteration, so no
    # weights or optimizer state are carried over from the previous run.
    model = HeteroGT(
        tokenizer,
        d_model=128,
        num_heads=4,
        num_layers=2,
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        use_hetero_graph=False,
    )
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )

    # Append this run's result to the history list of each tracked metric.
    for metric_name, history in final_metrics.items():
        history.append(best_test_metric[metric_name])

Epoch 001:   0%|          | 0/98 [00:00<?, ?it/s, loss=0.7584]

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 74.60it/s, loss=0.6543]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.65it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.47it/s]


Validation: {'precision': 0.8118489583280479, 'recall': 0.3910316713691094, 'f1': 0.5278306834403983, 'auc': 0.7558063911019024, 'prauc': 0.7659564036686408}
Test:      {'precision': 0.8089053803289493, 'recall': 0.41015992474001206, 'f1': 0.5443195960320538, 'auc': 0.7618889068411832, 'prauc': 0.7712677424020319}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 73.76it/s, loss=0.5754]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.29it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.23it/s]


Validation: {'precision': 0.6325842696614998, 'recall': 0.8827218563785427, 'f1': 0.7370074568439654, 'auc': 0.8057807251356368, 'prauc': 0.8201910613932202}
Test:      {'precision': 0.628500451669764, 'recall': 0.8726873628069217, 'f1': 0.7307338798618984, 'auc': 0.8041602680160082, 'prauc': 0.8224039108805503}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 74.22it/s, loss=0.5337]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.58it/s]


Validation: {'precision': 0.6277765312976716, 'recall': 0.8773910316686191, 'f1': 0.7318859485753046, 'auc': 0.7949047154142188, 'prauc': 0.8042986251307959}
Test:      {'precision': 0.6322900247677724, 'recall': 0.8805268109097506, 'f1': 0.7360419348443613, 'auc': 0.7931022420066571, 'prauc': 0.8071968676445602}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 75.20it/s, loss=0.4902]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.57it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.80it/s]


Validation: {'precision': 0.7413127413103562, 'recall': 0.7224835371567185, 'f1': 0.7317770316826386, 'auc': 0.8098937602815831, 'prauc': 0.8196697972982976}
Test:      {'precision': 0.73652118099892, 'recall': 0.7196613358397, 'f1': 0.7279936508270686, 'auc': 0.8093876573111725, 'prauc': 0.8244505905541931}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 74.67it/s, loss=0.4633]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.65it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.30it/s]


Validation: {'precision': 0.8263888888847898, 'recall': 0.522420821572523, 'f1': 0.6401536936184314, 'auc': 0.8058480424086764, 'prauc': 0.8116171211626564}
Test:      {'precision': 0.833494675697805, 'recall': 0.5399811853228599, 'f1': 0.6553758277662844, 'auc': 0.8138648755981511, 'prauc': 0.8242762360519239}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 71.87it/s, loss=0.4329]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 213.40it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 213.33it/s]


Validation: {'precision': 0.733290897515171, 'recall': 0.7224835371567185, 'f1': 0.7278470965617104, 'auc': 0.8050630928629725, 'prauc': 0.8096659869268064}
Test:      {'precision': 0.7371553884688686, 'recall': 0.7378488554382633, 'f1': 0.7375019539382945, 'auc': 0.8117607828274058, 'prauc': 0.8210202216119991}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 73.58it/s, loss=0.4152]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.95it/s]


Validation: {'precision': 0.7617486338770064, 'recall': 0.6556914393206156, 'f1': 0.7047522700508954, 'auc': 0.805188081949034, 'prauc': 0.8125049709300304}
Test:      {'precision': 0.7758931793548326, 'recall': 0.674192536843292, 'f1': 0.7214765050892873, 'auc': 0.8135731927618364, 'prauc': 0.8225921694814217}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6325842696614998, 'recall': 0.8827218563785427, 'f1': 0.7370074568439654, 'auc': 0.8057807251356368, 'prauc': 0.8201910613932202}
Corresponding test performance:
{'precision': 0.628500451669764, 'recall': 0.8726873628069217, 'f1': 0.7307338798618984, 'auc': 0.8041602680160082, 'prauc': 0.8224039108805503}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 75.90it/s, loss=0.6868]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 220.00it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.79it/s]


Validation: {'precision': 0.7945417095736635, 'recall': 0.48385073690660446, 'f1': 0.6014422092863535, 'auc': 0.7815343495428304, 'prauc': 0.7887976614202195}
Test:      {'precision': 0.7826725403779605, 'recall': 0.5014111006569414, 'f1': 0.6112385273476407, 'auc': 0.7796190959392415, 'prauc': 0.7841237282655599}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 73.95it/s, loss=0.5820]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.02it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.82it/s]


Validation: {'precision': 0.6321347749544075, 'recall': 0.7883349012204819, 'f1': 0.7016466598645303, 'auc': 0.7600604911006063, 'prauc': 0.7799094290524086}
Test:      {'precision': 0.6306306306290524, 'recall': 0.7902163687651609, 'f1': 0.7014613729316871, 'auc': 0.7556651652173125, 'prauc': 0.7787847846552611}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 73.48it/s, loss=0.5448]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.04it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.82it/s]


Validation: {'precision': 0.7445036642213708, 'recall': 0.7008466603929105, 'f1': 0.7220158244320469, 'auc': 0.8048197961813824, 'prauc': 0.8159363209104846}
Test:      {'precision': 0.7445207719962561, 'recall': 0.71370335528155, 'f1': 0.7287864183108187, 'auc': 0.8036364469154855, 'prauc': 0.8152444695423164}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 75.36it/s, loss=0.5024]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.88it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.93it/s]


Validation: {'precision': 0.7913354530970138, 'recall': 0.6243336469092997, 'f1': 0.6979842194317264, 'auc': 0.8159527167493111, 'prauc': 0.8246127382073223}
Test:      {'precision': 0.7879137798276062, 'recall': 0.6418940106596366, 'f1': 0.7074477227151602, 'auc': 0.8134687003428217, 'prauc': 0.8226412354418582}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 74.96it/s, loss=0.4903]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.49it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.30it/s]


Validation: {'precision': 0.7338807785865759, 'recall': 0.7566635308850528, 'f1': 0.7450980342145538, 'auc': 0.8182505968881937, 'prauc': 0.8310175769588293}
Test:      {'precision': 0.731437125746313, 'recall': 0.7660708686084476, 'f1': 0.748353494770638, 'auc': 0.8195021812037466, 'prauc': 0.8302526018820823}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 74.22it/s, loss=0.4344]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.97it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.55it/s]


Validation: {'precision': 0.6751787538287151, 'recall': 0.8291000313551926, 'f1': 0.7442645974429356, 'auc': 0.8054897537584389, 'prauc': 0.8145985291775396}
Test:      {'precision': 0.676410777832546, 'recall': 0.8344308560651162, 'f1': 0.7471570917815047, 'auc': 0.8071177658755745, 'prauc': 0.814957659476986}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 74.50it/s, loss=0.4153]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.07it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.93it/s]


Validation: {'precision': 0.786773547091035, 'recall': 0.6155534650341312, 'f1': 0.690710762137555, 'auc': 0.808460856864488, 'prauc': 0.8183980593596518}
Test:      {'precision': 0.7946464242876763, 'recall': 0.6237064910610733, 'f1': 0.6988756099682724, 'auc': 0.8139536136206573, 'prauc': 0.8170431678627195}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 74.62it/s, loss=0.3964]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.06it/s]


Validation: {'precision': 0.6857798165120138, 'recall': 0.843838193788511, 'f1': 0.7566427618163379, 'auc': 0.8150445866412992, 'prauc': 0.8156754550351684}
Test:      {'precision': 0.6812358703826144, 'recall': 0.8504233301948874, 'f1': 0.7564853507074328, 'auc': 0.8187493425181969, 'prauc': 0.8199745872306908}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 74.68it/s, loss=0.3642]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.73it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.52it/s]


Validation: {'precision': 0.7269746646774159, 'recall': 0.7648165569119949, 'f1': 0.7454156429226997, 'auc': 0.8098841650583364, 'prauc': 0.8134170477287388}
Test:      {'precision': 0.7271105826375532, 'recall': 0.767011602380787, 'f1': 0.7465283026466408, 'auc': 0.8131593504356619, 'prauc': 0.8130077695699329}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 75.26it/s, loss=0.3236]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.26it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.88it/s]


Validation: {'precision': 0.6936542669565272, 'recall': 0.7952336155509714, 'f1': 0.7409788116756005, 'auc': 0.8052552485117608, 'prauc': 0.8038251964008714}
Test:      {'precision': 0.6917900403749885, 'recall': 0.8058952649708188, 'f1': 0.7444959394069357, 'auc': 0.803324831999243, 'prauc': 0.7969819436242922}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7338807785865759, 'recall': 0.7566635308850528, 'f1': 0.7450980342145538, 'auc': 0.8182505968881937, 'prauc': 0.8310175769588293}
Corresponding test performance:
{'precision': 0.731437125746313, 'recall': 0.7660708686084476, 'f1': 0.748353494770638, 'auc': 0.8195021812037466, 'prauc': 0.8302526018820823}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 74.41it/s, loss=0.6870]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 209.31it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 214.00it/s]


Validation: {'precision': 0.6999012833145775, 'recall': 0.6669802445886893, 'f1': 0.6830443109929999, 'auc': 0.7543075268045827, 'prauc': 0.7501676622706805}
Test:      {'precision': 0.6936879205360664, 'recall': 0.6788962057049893, 'f1': 0.6862123563296261, 'auc': 0.7534090651703513, 'prauc': 0.7495235371765588}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 75.87it/s, loss=0.5844]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.72it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.72it/s]


Validation: {'precision': 0.7659574468057485, 'recall': 0.6660395108163498, 'f1': 0.7125125746932041, 'auc': 0.8049072583995629, 'prauc': 0.8131075566908845}
Test:      {'precision': 0.7513264945144983, 'recall': 0.6660395108163498, 'f1': 0.7061170162923521, 'auc': 0.797073255930423, 'prauc': 0.8090805447101558}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 74.94it/s, loss=0.5508]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.98it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.84it/s]


Validation: {'precision': 0.7517309594436145, 'recall': 0.7149576669780027, 'f1': 0.7328833122621181, 'auc': 0.813102282367024, 'prauc': 0.8230359077989766}
Test:      {'precision': 0.7366201751516814, 'recall': 0.7121354656609842, 'f1': 0.724170913366466, 'auc': 0.810411491750082, 'prauc': 0.8250729560405805}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 74.89it/s, loss=0.5056]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.51it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.77it/s]


Validation: {'precision': 0.7948207171283075, 'recall': 0.6255879586057523, 'f1': 0.7001228236349343, 'auc': 0.8226726384121001, 'prauc': 0.8335160501704542}
Test:      {'precision': 0.7907884465230649, 'recall': 0.6353088742532602, 'f1': 0.7045731127756392, 'auc': 0.8173197191187404, 'prauc': 0.8315397047517701}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 75.15it/s, loss=0.4783]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.21it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 222.63it/s]


Validation: {'precision': 0.6438956842819205, 'recall': 0.8748824082757138, 'f1': 0.7418239780969295, 'auc': 0.8048708367406423, 'prauc': 0.8149899769858131}
Test:      {'precision': 0.6459584295597092, 'recall': 0.8770774537445059, 'f1': 0.7439819076015226, 'auc': 0.8042607337522091, 'prauc': 0.8142506049030374}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 75.03it/s, loss=0.4691]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.37it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.56it/s]


Validation: {'precision': 0.7300871656124734, 'recall': 0.7616807776708634, 'f1': 0.7455494118200895, 'auc': 0.8144335567965377, 'prauc': 0.8245645740588591}
Test:      {'precision': 0.7203065134078388, 'recall': 0.7663844465325608, 'f1': 0.7426314140241207, 'auc': 0.8090442819342372, 'prauc': 0.8207399460542688}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 77.09it/s, loss=0.4313]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 221.15it/s]


Validation: {'precision': 0.7045514338313482, 'recall': 0.83976168077504, 'f1': 0.7662374771534464, 'auc': 0.8285889726871719, 'prauc': 0.8396998838558178}
Test:      {'precision': 0.6948405496481855, 'recall': 0.8403888366232662, 'f1': 0.7607152945034681, 'auc': 0.823355817323494, 'prauc': 0.8357866323423395}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 75.79it/s, loss=0.4050]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.80it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.02it/s]


Validation: {'precision': 0.764461378591048, 'recall': 0.6920664785177421, 'f1': 0.7264647744701324, 'auc': 0.8194632623015534, 'prauc': 0.8227873390665678}
Test:      {'precision': 0.7681912681886064, 'recall': 0.6952022577588737, 'f1': 0.729876538219912, 'auc': 0.8211464267971715, 'prauc': 0.8243331403492542}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 74.50it/s, loss=0.3883]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.32it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 226.63it/s]


Validation: {'precision': 0.7811594202870248, 'recall': 0.6760740043879709, 'f1': 0.7248276971583772, 'auc': 0.8185298630977661, 'prauc': 0.8288286087358245}
Test:      {'precision': 0.776622445318119, 'recall': 0.6792097836291026, 'f1': 0.7246570709650939, 'auc': 0.8154834007551036, 'prauc': 0.8265372183864781}


Epoch 010: 100%|██████████| 98/98 [00:01<00:00, 76.08it/s, loss=0.3465]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.90it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.74it/s]


Validation: {'precision': 0.7744923405743694, 'recall': 0.6817184070220078, 'f1': 0.7251500950845867, 'auc': 0.8209886516139215, 'prauc': 0.8249979183355665}
Test:      {'precision': 0.7798696596640845, 'recall': 0.6754468485397446, 'f1': 0.723911942595146, 'auc': 0.8237968900921959, 'prauc': 0.8264528121440461}


Epoch 011: 100%|██████████| 98/98 [00:01<00:00, 75.48it/s, loss=0.3101]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.05it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.64it/s]


Validation: {'precision': 0.7431221020069765, 'recall': 0.7538413295680344, 'f1': 0.7484433324823597, 'auc': 0.8185588497145898, 'prauc': 0.8206062759660742}
Test:      {'precision': 0.7374336559452469, 'recall': 0.7406710567552817, 'f1': 0.7390488060114789, 'auc': 0.8200275626438596, 'prauc': 0.8230003052151633}


Epoch 012: 100%|██████████| 98/98 [00:01<00:00, 74.97it/s, loss=0.2869]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 221.71it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.74it/s]


Validation: {'precision': 0.6759213759197152, 'recall': 0.8626528692353006, 'f1': 0.7579556363444635, 'auc': 0.8126731598697823, 'prauc': 0.815388937228499}
Test:      {'precision': 0.6720116618059475, 'recall': 0.867356538096998, 'f1': 0.7572895227991833, 'auc': 0.8142977943341554, 'prauc': 0.8193069882398392}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7045514338313482, 'recall': 0.83976168077504, 'f1': 0.7662374771534464, 'auc': 0.8285889726871719, 'prauc': 0.8396998838558178}
Corresponding test performance:
{'precision': 0.6948405496481855, 'recall': 0.8403888366232662, 'f1': 0.7607152945034681, 'auc': 0.823355817323494, 'prauc': 0.8357866323423395}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 73.69it/s, loss=0.7084]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 213.91it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 210.52it/s]


Validation: {'precision': 0.6448505803886272, 'recall': 0.8187519598594584, 'f1': 0.721470014410954, 'auc': 0.7609450601630545, 'prauc': 0.7443702048706773}
Test:      {'precision': 0.6475673005664422, 'recall': 0.8222013170247031, 'f1': 0.7245095280887573, 'auc': 0.7582633823028297, 'prauc': 0.7389078821399594}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 72.55it/s, loss=0.5935]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 214.56it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 214.39it/s]


Validation: {'precision': 0.7035347776490778, 'recall': 0.7739103167112765, 'f1': 0.7370464337130408, 'auc': 0.8053012653939281, 'prauc': 0.8092839549066576}
Test:      {'precision': 0.703473595027666, 'recall': 0.7811226089658793, 'f1': 0.7402674541496654, 'auc': 0.801031636136861, 'prauc': 0.8038230258590394}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 76.12it/s, loss=0.5298]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.66it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.40it/s]


Validation: {'precision': 0.738683788119619, 'recall': 0.721542803384379, 'f1': 0.730012685353703, 'auc': 0.8082638784856908, 'prauc': 0.8180151114082946}
Test:      {'precision': 0.7380497131907647, 'recall': 0.7262464722460764, 'f1': 0.7321005165722166, 'auc': 0.8073080266385201, 'prauc': 0.8175260246726147}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 74.75it/s, loss=0.4966]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.46it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.89it/s]


Validation: {'precision': 0.7275440167092583, 'recall': 0.7645029789878818, 'f1': 0.745565744236262, 'auc': 0.8134792088592746, 'prauc': 0.8249269997050096}
Test:      {'precision': 0.7195301027879015, 'recall': 0.7682659140772397, 'f1': 0.7430997826888862, 'auc': 0.8134533989481298, 'prauc': 0.8251497896285864}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 65.13it/s, loss=0.4738]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 223.29it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.62it/s]


Validation: {'precision': 0.8190224570637648, 'recall': 0.5832549388504759, 'f1': 0.6813186764575273, 'auc': 0.8086395490586683, 'prauc': 0.8219827775043709}
Test:      {'precision': 0.8113695090404334, 'recall': 0.5907808090291917, 'f1': 0.6837234572878453, 'auc': 0.8083911539818608, 'prauc': 0.8222444701022528}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 74.60it/s, loss=0.4330]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.47it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.05it/s]


Validation: {'precision': 0.7096204766086802, 'recall': 0.7563499529609397, 'f1': 0.7322404321613275, 'auc': 0.798281229261633, 'prauc': 0.8047359490285735}
Test:      {'precision': 0.7149084568418631, 'recall': 0.7714016933183713, 'f1': 0.7420814429687859, 'auc': 0.8035226931259996, 'prauc': 0.8128829168985335}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 75.44it/s, loss=0.4024]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 225.46it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 221.88it/s]


Validation: {'precision': 0.7093802345038822, 'recall': 0.7968015051715371, 'f1': 0.7505538275356729, 'auc': 0.8145613089050001, 'prauc': 0.8228356490292968}
Test:      {'precision': 0.7044267253211316, 'recall': 0.8033866415779135, 'f1': 0.7506592390861415, 'auc': 0.81610411391083, 'prauc': 0.8262567289941591}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 74.28it/s, loss=0.3748]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 218.17it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.42it/s]


Validation: {'precision': 0.708564747147885, 'recall': 0.7601128880502975, 'f1': 0.7334341856242127, 'auc': 0.7973784744881098, 'prauc': 0.8046652185150485}
Test:      {'precision': 0.7062049062028681, 'recall': 0.7673251803049003, 'f1': 0.7354974401521689, 'auc': 0.7999996476652538, 'prauc': 0.808399242091816}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 74.36it/s, loss=0.3402]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 225.50it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 226.23it/s]


Validation: {'precision': 0.7422380336327871, 'recall': 0.7196613358397, 'f1': 0.7307753492418205, 'auc': 0.8065044862943538, 'prauc': 0.8135782970857952}
Test:      {'precision': 0.7341972187081726, 'recall': 0.7284415177148685, 'f1': 0.7313080384418175, 'auc': 0.8091164602236722, 'prauc': 0.8208913534582816}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7275440167092583, 'recall': 0.7645029789878818, 'f1': 0.745565744236262, 'auc': 0.8134792088592746, 'prauc': 0.8249269997050096}
Corresponding test performance:
{'precision': 0.7195301027879015, 'recall': 0.7682659140772397, 'f1': 0.7430997826888862, 'auc': 0.8134533989481298, 'prauc': 0.8251497896285864}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 75.19it/s, loss=0.6739]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 225.47it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.19it/s]


Validation: {'precision': 0.7825912555123041, 'recall': 0.6117905299447733, 'f1': 0.686730019711816, 'auc': 0.7912333117214745, 'prauc': 0.8070631389230827}
Test:      {'precision': 0.7622887864794076, 'recall': 0.6224521793646207, 'f1': 0.6853098517722549, 'auc': 0.7871697302172849, 'prauc': 0.802554270979667}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 75.30it/s, loss=0.5812]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.70it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 226.10it/s]


Validation: {'precision': 0.6981981981961015, 'recall': 0.7290686735630948, 'f1': 0.7132995808261977, 'auc': 0.7833183577519205, 'prauc': 0.7999703611876453}
Test:      {'precision': 0.6897670303724867, 'recall': 0.733458764500679, 'f1': 0.7109422442426729, 'auc': 0.776784764240238, 'prauc': 0.7942632897240067}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 74.24it/s, loss=0.5274]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.31it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 225.77it/s]


Validation: {'precision': 0.9147509578456442, 'recall': 0.2994669175280669, 'f1': 0.4512166275125645, 'auc': 0.8028229347738829, 'prauc': 0.8163338165146984}
Test:      {'precision': 0.9097022094052863, 'recall': 0.2969582941351616, 'f1': 0.44775413340303377, 'auc': 0.798892309891496, 'prauc': 0.8129924297175033}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 73.66it/s, loss=0.5082]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 222.78it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.94it/s]


Validation: {'precision': 0.7327044025134192, 'recall': 0.7306365631836607, 'f1': 0.7316690168221748, 'auc': 0.8060646131281342, 'prauc': 0.8198914571198107}
Test:      {'precision': 0.7264325323452667, 'recall': 0.739416745058829, 'f1': 0.7328671278652475, 'auc': 0.8088750102554578, 'prauc': 0.8196527921959437}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 75.18it/s, loss=0.4684]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 212.54it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 212.94it/s]


Validation: {'precision': 0.6946141032740294, 'recall': 0.784571966131124, 'f1': 0.7368576006708656, 'auc': 0.8039228688732867, 'prauc': 0.8157751586511478}
Test:      {'precision': 0.6892593604791221, 'recall': 0.7908435246133871, 'f1': 0.7365654155821492, 'auc': 0.802553067904469, 'prauc': 0.8125388197168355}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 74.21it/s, loss=0.4483]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 215.00it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.20it/s]


Validation: {'precision': 0.7111307420473759, 'recall': 0.7572906867332792, 'f1': 0.7334851886245809, 'auc': 0.801913347799514, 'prauc': 0.8181260117449737}
Test:      {'precision': 0.7030844623790629, 'recall': 0.7648165569119949, 'f1': 0.7326524431892797, 'auc': 0.8005742049692286, 'prauc': 0.8138197017435439}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 74.64it/s, loss=0.4144]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 225.79it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 224.71it/s]


Validation: {'precision': 0.7150355047832201, 'recall': 0.7262464722460764, 'f1': 0.7205973814324103, 'auc': 0.7939034463329117, 'prauc': 0.8080383259679054}
Test:      {'precision': 0.7092952612372135, 'recall': 0.7322044528042264, 'f1': 0.7205678086080501, 'auc': 0.7942108884523293, 'prauc': 0.8082859150046365}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 74.65it/s, loss=0.4092]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 225.41it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.96it/s]


Validation: {'precision': 0.7603423680429374, 'recall': 0.6685481342092551, 'f1': 0.7114967412221651, 'auc': 0.8040099291972999, 'prauc': 0.8146837112254438}
Test:      {'precision': 0.7536433032590089, 'recall': 0.6810912511737814, 'f1': 0.7155328561535681, 'auc': 0.8058019969326744, 'prauc': 0.8124746080000023}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 75.46it/s, loss=0.3759]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 224.33it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 223.40it/s]

Validation: {'precision': 0.7292782855318964, 'recall': 0.72561931639785, 'f1': 0.7274441949348717, 'auc': 0.7958944300382131, 'prauc': 0.8057866634544943}
Test:      {'precision': 0.7176724137908939, 'recall': 0.7309501411077738, 'f1': 0.7242504222158178, 'auc': 0.7951361194958391, 'prauc': 0.8034014652546613}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7327044025134192, 'recall': 0.7306365631836607, 'f1': 0.7316690168221748, 'auc': 0.8060646131281342, 'prauc': 0.8198914571198107}
Corresponding test performance:
{'precision': 0.7264325323452667, 'recall': 0.739416745058829, 'f1': 0.7328671278652475, 'auc': 0.8088750102554578, 'prauc': 0.8196527921959437}





In [21]:
# Summarize every collected metric as "mean ± std" over the values stored
# for it in `final_metrics`. np.std is called with its defaults, i.e. the
# population standard deviation (ddof=0).
print("\nFinal Metrics:")
for metric_name, run_values in final_metrics.items():
    avg, spread = np.mean(run_values), np.std(run_values)
    print(f"{metric_name}: {avg:.4f} ± {spread:.4f}")


Final Metrics:
precision: 0.7001 ± 0.0380
recall: 0.7974 ± 0.0504
f1: 0.7432 ± 0.0109
auc: 0.8139 ± 0.0069
prauc: 0.8266 ± 0.0058
