In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed

In [2]:
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
config = Namespace(
    dataset = "MIMIC-III",
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]"],
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,
    early_stop_patience = 5,   
)

In [5]:
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
if curr_task == "next_diag_6m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl"
elif curr_task == "next_diag_12m":
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"

Current task: stay


In [6]:
ehr_full_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
age_gender_sentences = ["[PAD]"] + [str(c) + "_" + gender for c in set(ehr_full_data["AGE"].values.tolist()) for gender in ["M", "F"]]
token_type_sentences = ["[PAD]"] + config.token_type
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
task_sentences = config.tasks
tokenizer = EHRTokenizer(token_type_sentences, age_gender_sentences, task_sentences, diag_sentences, 
                         med_sentences, lab_sentences, pro_sentences, special_tokens=config.special_tokens)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_gender_vocab_size = tokenizer.token_number("age_gender")
print(f"Age and gender vocabulary size: {config.age_gender_vocab_size}")

Age and gender vocabulary size: 37


In [8]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
# example label percentage
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
train_dataset = FineTuneEHRDataset(train_data, tokenizer, token_type=config.token_type, task=curr_task)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, token_type=config.token_type, task=curr_task)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, token_type=config.token_type, task=curr_task)

In [10]:
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=True,
    batch_size=config.batch_size,
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=batcher(tokenizer, config.task_index, n_token_type=len(config.token_type), is_pretrain = False),
    shuffle=False,
    batch_size=config.batch_size,
)

In [11]:
if curr_task in ["death", "stay", "readmission"]:
    eval_metric = "prauc"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [12]:
input_ids, token_types, adm_index, age_gender_ids, task_index, labels = next(iter(train_dataloader))
print("Input IDs shape:", input_ids.shape)
print("Token Types shape:", token_types.shape)
print("Admission Index shape:", adm_index.shape)
print("Age/Sex IDs shape:", age_gender_ids.shape)
print("Task Index:", task_index)
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 256])
Token Types shape: torch.Size([32, 256])
Admission Index shape: torch.Size([32, 256])
Age/Sex IDs shape: torch.Size([32, 7])
Task Index: 2
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch as HeteroBatch
from torch_geometric.nn import HeteroConv, GATConv
from heterogt.model.layer import TransformerEncoder

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [14]:
class DiseaseOccHetGNN(nn.Module):
    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.act = nn.GELU()

        # 第1层
        self.conv1 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')

        # 第2层
        self.conv2 = HeteroConv({
            ('visit','contains','occ'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('occ','contained_by','visit'): GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
            ('visit','next','visit'):       GATConv(d_model, d_model, heads=heads, concat=False, add_self_loops=False),
        }, aggr='mean')
        self.lin = nn.Linear(d_model, d_model)

    def forward(self, hg):
        # x_dict: {'visit': [N_visit, d], 'occ': [N_occ, d]}
        x_dict = {'visit': hg['visit'].x, 'occ': hg['occ'].x}

        # 第1层：HeteroConv → Linear → GELU → Dropout
        x_dict = self.conv1(x_dict, hg.edge_index_dict)

        # 第2层：HeteroConv → Linear（末层通常不再加激活/随你需要）
        x_dict = self.conv2(x_dict, hg.edge_index_dict)
        x_dict = {k: self.lin(v) for k, v in x_dict.items()}

        return x_dict  # {'visit': [N_visit, d], 'occ': [N_occ, d]}

In [15]:
# multi-class classification task
class MultiPredictionHead(nn.Module):
    def __init__(self, hidden_size, label_size):
        super(MultiPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, label_size)
            )

    def forward(self, input):
        return self.cls(input)
    
class BinaryPredictionHead(nn.Module):
    def __init__(self, hidden_size):
        super(BinaryPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(hidden_size, hidden_size), 
                nn.ReLU(), 
                nn.Linear(hidden_size, 1)
            )
    def forward(self, input):
        return self.cls(input)

In [16]:
for i in range(len(train_dataset)):
    age_gender_ids = train_dataset[i][3]
    if len(age_gender_ids[0]) > 3:
        print(age_gender_ids)
        break
exp_i = i
id_seq = torch.concat([train_dataset[exp_i][0][0], torch.zeros(5, dtype=train_dataset[exp_i][0][0].dtype)], dim=0)
type_seq = torch.concat([train_dataset[exp_i][1][0], torch.zeros(5, dtype=train_dataset[exp_i][1][0].dtype)], dim=0)
visit_seq = torch.concat([train_dataset[exp_i][2][0], torch.zeros(5, dtype=train_dataset[exp_i][2][0].dtype)], dim=0)
age_sex = torch.concat([train_dataset[exp_i][3][0], torch.zeros(3, dtype=train_dataset[exp_i][3][0].dtype)], dim=0)

tensor([[33, 33, 33,  5]])


In [17]:
class HeteroGT(nn.Module):
    def __init__(self, tokenizer, d_model, num_heads, layer_types, max_num_adms, device, task, label_vocab_size):
        super(HeteroGT, self).__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.max_num_adms = max_num_adms
        self.global_vocab_size = len(self.tokenizer.vocab.word2id)
        self.age_gender_vocab_size = len(self.tokenizer.age_gender_voc.word2id)
        self.n_type = len(self.tokenizer.token_type_voc.word2id)
        self.d_model = d_model
        self.num_attn_heads = num_heads
        self.layer_types = layer_types
        self.seq_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="all")[0] #0
        self.type_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="type")[0] #0
        self.adm_pad_id = 0
        self.age_gender_pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"], voc_type="age_gender")[0] #0
        self.node_type_id_dict = {'diag': 1, 'med': 2, 'lab': 3, 'pro': 4, 'visit': 5}
        self.graph_node_types = ['diag']
        
        # embedding layers
        self.token_emb = nn.Embedding(self.global_vocab_size, d_model, padding_idx=self.seq_pad_id) # already contains [PAD]
        self.type_emb = nn.Embedding(self.n_type + 1, d_model, padding_idx=self.type_pad_id) # n_type already have PAD, + 1 for visit
        self.adm_index_emb = nn.Embedding(self.max_num_adms + 1, d_model, padding_idx=self.adm_pad_id) # +1 for pad
        self.age_gender_emb = nn.Embedding(self.age_gender_vocab_size, d_model, padding_idx=self.age_gender_pad_id) # already contains [PAD]
        self.task_emb = nn.Embedding(5, d_model, padding_idx=None)  # 5 task in total, task embedding
        
        # stack together
        self.stack_layers = nn.ModuleList(self.make_gnn_layer() if layer_type == 'gnn' else self.make_tf_layer()
            for layer_type in self.layer_types
        )

        # prediction head
        if task in ["death", "stay", "readmission"]:
            self.cls_head = BinaryPredictionHead(self.d_model)
        else:
            self.cls_head = MultiPredictionHead(self.d_model, label_vocab_size)

    def make_tf_layer(self):
        assert self.d_model % self.num_attn_heads == 0, "Invalid model and attention head dimensions"
        layer_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_attn_heads, batch_first=True, norm_first=True)
        tf_wrapper = nn.TransformerEncoder(layer_layer, num_layers=1, enable_nested_tensor=False)
        return tf_wrapper

    def make_gnn_layer(self):
        return DiseaseOccHetGNN(d_model=self.d_model, heads=self.num_attn_heads)
    
    def forward(self, input_ids, token_types, adm_index, age_gender_ids, task_id):
        """Forward pass for the model.

        Args:
            input_ids (Tensor): Input token IDs. Shape of [B, L]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            adm_index (Tensor): Admission index IDs. Shape of [B, L]
            age_gender_ids (Tensor): Age and gender IDs. Shape of [B, V]
            task_id (Tensor): Task ID. Shape of [1]

        Returns:
            Tensor: Output logits. Shape of [B, label_size]
        """
        B, L = input_ids.shape
        V = age_gender_ids.shape[1]
        num_visits = adm_index.max(dim=1).values
        
        task_id = torch.full((B,), task_id, dtype=torch.long, device=self.device) # [1] -> [B]
        # 基础表示
        task_id_embed = self.task_emb(task_id).unsqueeze(1) # [B, 1, d]
        seq_embed = self.token_emb(input_ids)  # [B, L, d]
        visit_embed = self.age_gender_emb(age_gender_ids) #[B, V, d], the virtual node representation in gnn

        # run through layers
        for i, layer_type in enumerate(self.layer_types):
            if layer_type == 'gnn':
                hg_batch = self.build_graph_batch(seq_embed, token_types, self.graph_node_types, visit_embed, adm_index) # num_visits is a 1d tensor of [B]
                gnn_out = self.stack_layers[i](hg_batch)['visit']  # extract virtual visit node representations
                visit_embed = self.process_gnn_out(gnn_out, num_visits, V) # [B, V, d]
            elif layer_type == 'tf':
                x, src_key_padding_mask, attn_mask = self.prepare_tf_input(task_id_embed, seq_embed, visit_embed, i, input_ids, 
                                                                            adm_index, token_types, num_visits)
                h = self.stack_layers[i](src=x, src_key_padding_mask=src_key_padding_mask, mask=attn_mask) # [B, 1+L+V, d]
                task_id_embed, seq_embed, visit_embed = self.process_tf_out(h, L, V) # # [B, 1, d], [B, L, d], [B, V, d]
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

        logits = self.cls_head(task_id_embed.squeeze())  # [B, label_size]
        return logits

    def build_graph_batch(self, seq_embed, token_types, graph_node_types, visit_embed, adm_index):
        """Build a batch of heterogeneous graphs from the input sequences.

        Args:
            seq_embed (Tensor): Sequence embeddings. Shape of [B, L, d]
            token_types (Tensor): Token type IDs. Shape of [B, L]
            graph_node_types: a list controls what types of tokens are connected to the virtual visit nodes. e.g. ['diag']
            visit_embed (Tensor): Visit embeddings. Shape of [B, V, d]
        Returns:
            A batch of heterogeneous graphs.
        """
        B, L = seq_embed.shape[0], seq_embed.shape[1]
        V = visit_embed.shape[1]
        graph_node_type_ids = [self.node_type_id_dict[t] for t in graph_node_types]
        graphs = [] # contains heterogeneous graphs for each patient
        for p in range(B):
            hg_p = self.build_patient_graph(seq_embed[p], token_types[p], visit_embed[p], adm_index[p], graph_node_type_ids)
            graphs.append(hg_p)
        hg_batch = HeteroBatch.from_data_list(graphs).to(self.device)
        return hg_batch

    def build_patient_graph(self, seq_embed_p, token_types_p, visit_embed_p, adm_index_p, graph_node_type_ids):
        """Build a heterogeneous graph for a single patient.

        Args:
            seq_embed_p (Tensor): Sequence embeddings for patient p. Shape [L, d]
            token_types_p (Tensor): Token type IDs for patient p. Shape [L]
            visit_embed_p (Tensor): Visit embeddings for patient p. Shape [V, d]
            graph_node_type_ids (list): List of graph node type IDs that the graph uses.
            adm_index_p (Tensor): Admission index for patient p. Shape [L]

        Returns:
            A heterogeneous graph for patient p.
        """
        hg = HeteroData()
        occ_mask = torch.isin(token_types_p, torch.tensor(graph_node_type_ids, device=token_types_p.device)) # [L], a mask for the token types needed in the graph
        occ_pos = torch.nonzero(occ_mask, as_tuple=False).view(-1) # [L], seq position index for the token types needed in the graph
        num_occ = occ_pos.numel() # int, number of occurrences of the token types needed in the graph
        
        # build visit virtual nodes
        nonpad = adm_index_p != self.adm_pad_id
        adm_index_used_p = adm_index_p[nonpad] # adm_index非pad部分
        adm_ids_unique, adm_lid_nonpad = torch.unique(adm_index_used_p, return_inverse=True)
        num_visit_p = adm_ids_unique.numel()  # int, number of visits for patient
        adm_lid_full = torch.full_like(token_types_p, fill_value=-1) # [L]
        adm_lid_full[nonpad] = adm_lid_nonpad
        hg['visit'].x = visit_embed_p[:num_visit_p, :]
        hg['visit'].num_nodes = num_visit_p
        
        # build medical code nodes
        gid_occ_embed = seq_embed_p[occ_pos, :]
        hg['occ'].x = gid_occ_embed
        hg['occ'].num_nodes = num_occ

        # build edges between occ nodes and virtual visit nodes
        occ_adm_lid = adm_lid_full[occ_pos]
        assert (occ_adm_lid != -1).all(), "occ_adm_lid contains -1"
        e_v2o = torch.stack([occ_adm_lid, torch.arange(num_occ, device=self.device)], dim=0)
        e_o2v = torch.stack([torch.arange(num_occ, device=self.device), occ_adm_lid], dim=0)
        hg['visit','contains','occ'].edge_index = e_v2o
        hg['occ','contained_by','visit'].edge_index = e_o2v
        
        # build forward edges between virtual visit nodes
        if num_visit_p > 1:
            src = torch.arange(0, num_visit_p - 1, device=self.device)
            dst = torch.arange(1, num_visit_p, device=self.device)
            e_next = torch.stack([src, dst], dim=0) # [2, num_visit_p-1]
        else:
            e_next = torch.empty(2, 0, dtype=torch.long, device=self.device)
        hg['visit','next','visit'].edge_index = e_next
        return hg

    def process_gnn_out(self, gnn_out, num_visits, V):
        """Process the output of the GNN layer.

        Args:
            gnn_out (Tensor): The output of the GNN layer. Shape [sum(num_visits), d]
            num_visits (Tensor): A tensor containing the number of visits for each patient.
            V (int): The maximum number of visits.

        Returns:
            Tensor: The processed visit embeddings. Shape [B, V, d]
        """
        B = len(num_visits)
        # 计算每个批次的累积偏移量
        cumsum = torch.cumsum(num_visits, dim=0)  # [B]
        offsets = torch.cat([torch.tensor([0], device=self.device), cumsum[:-1]])  # [B]

        # 创建索引以从 gnn_out 中提取所有批次的嵌入
        indices = torch.arange(sum(num_visits), device=self.device)  # [N]
        batch_indices = torch.repeat_interleave(torch.arange(B, device=self.device), num_visits)  # [N]
        visit_pos = indices - offsets[batch_indices]  # [N]，每个嵌入的相对位置

        # 创建目标张量 visit_emb_pad，初始化为零
        visit_emb_pad = torch.zeros(B, V, self.d_model, device=self.device, dtype=gnn_out.dtype)  # [B, V, d]

        # 创建掩码，选择有效位置 (visit_pos < V 且 visit_pos < num_visits)
        mask = (visit_pos < V) & (visit_pos < num_visits[batch_indices])  # [N]
        valid_indices = indices[mask]  # [N_valid]
        valid_batch_indices = batch_indices[mask]  # [N_valid]
        valid_visit_pos = visit_pos[mask]  # [N_valid]

        # 使用 scatter 将 gnn_out 的值分配到 visit_emb_pad
        visit_emb_pad[valid_batch_indices, valid_visit_pos] = gnn_out[valid_indices]
        return visit_emb_pad
    
    def prepare_tf_input(self, task_id_embed, seq_embed, visit_embed, layer_i, input_ids, adm_index, token_types, num_visits):
        """Prepare the input for the Transformer layer.
        Args:
            task_id_emb (Tensor): Task ID embeddings. Shape [B, 1, d]
            seq_embed (Tensor): Sequence embeddings. Shape [B, L, d]
            visit_embed (Tensor): Visit embeddings. Shape [B, V, d]
            layer_i (int): The current layer index.
            adm_index (tensor): The admission index. Shape [B, L]
            token_types (Tensor): Token types. Shape [B, L]

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Processed inputs for the Transformer layer.
        """
        
        B, L, d = seq_embed.shape
        V = visit_embed.shape[1]

        # Part 1: prepare main seq embedding x
        # important: initiate new tensor to ensure safe autograd
        x = torch.empty(B, 1 + L + V, d, device=seq_embed.device, dtype=seq_embed.dtype)
        x[:, 0:1, :] = task_id_embed
        x[:, 1:1 + L, :] = seq_embed
        x[:, 1 + L:, :] = visit_embed
        
        # we already have token_types for main seq, just prepare token types for visit nodes
        # here it is out of the if branch because it is needed in the mask making
        arange_V = torch.arange(1, V + 1, device=self.device, dtype=torch.long)[None, :]  # [1, V]
        n_v = num_visits.view(B, 1)  # [B, 1]
        visit_index = torch.where(arange_V <= n_v, arange_V, torch.full((B, V), self.adm_pad_id, device=self.device, dtype=torch.long))  # [B, V]
        visit_type_id = torch.full((B, V), self.node_type_id_dict['visit'], dtype=torch.long, device=self.device)  # [B, V]
        visit_type_id_mask = (visit_index != self.adm_pad_id).long() # [B, V]
        visit_type_id = visit_type_id * visit_type_id_mask # [B, V]
        token_types = torch.cat([token_types, visit_type_id], dim=1)  # [B, L+V]

        # if it is the first time transformer going through, we need extra information of admission index and token types
        if (layer_i == 0) or (layer_i == 1 and self.layer_types[0] == 'gnn'):
            adm_index = torch.cat([adm_index, visit_index], dim=1)  # [B, L+V]
            # transform into embedding and add
            adm_index_embed = self.adm_index_emb(adm_index)
            token_type_embed = self.type_emb(token_types)
            x_non_task = x[:, 1:, :]
            x_non_task.add_(adm_index_embed).add_(token_type_embed)
            x[:, 1:, :] = x_non_task
        else:
            x = x
            
        # part 2: prepare mask (src_key_padding_mask and attn_mask)
        task_pad_mask = torch.zeros((B, 1), dtype=torch.bool, device=self.device) # [B, 1]
        seq_pad_mask = (input_ids == self.seq_pad_id) # [B, L], bool
        visit_pad_mask = (visit_index == self.adm_pad_id) # [B, V], bool
        src_key_padding_mask = torch.cat([task_pad_mask, seq_pad_mask, visit_pad_mask], dim=1)  # [B, 1+L+V]
        attn_mask = self.build_attn_mask(torch.cat([torch.full((B, 1), -1, device=self.device), token_types], dim=1), forbid_map=None, num_heads=self.num_attn_heads)
        assert attn_mask.dtype == src_key_padding_mask.dtype, f"attn_mask dtype ({attn_mask.dtype}) and src_key_padding_mask dtype ({src_key_padding_mask.dtype}) must match"
        return x, src_key_padding_mask, attn_mask
    
    def process_tf_out(self, h, L, V):
        return h[:, 0:1, :], h[:, 1:1 + L, :], h[:, 1 + L:, :]  # [B, 1, d], [B, L, d], [B, V, d]
        
    @staticmethod
    def build_attn_mask(token_types, forbid_map, num_heads):
        B, L = token_types.shape
        device = token_types.device
        
        if forbid_map == None:
            mask = torch.zeros((B, L, L), dtype=torch.bool, device=device)
        else:
            # 收集所有出现的 token 类型
            observed = torch.unique(token_types)
            for q_t, ks in forbid_map.items():
                observed = torch.unique(torch.cat([observed, torch.tensor([q_t] + list(ks), device=device)]))
            type_list = observed.sort().values
            t2i = {t.item(): i for i, t in enumerate(type_list)}  # Map token types to indices
            T = len(type_list)

            # 构造禁止矩阵 (T, T)，单向关系
            ban_table = torch.zeros((T, T), dtype=torch.bool, device=device)
            for q_t, ks in forbid_map.items():
                if q_t in t2i:
                    qi = t2i[q_t]
                    for k_t in ks:
                        if k_t in t2i:
                            ban_table[qi, t2i[k_t]] = True  # 只设置 q -> k 的禁止

            # 向量化映射 token_types 到类型索引
            mapping = torch.zeros_like(type_list, dtype=torch.long, device=device)
            for t, i in t2i.items():
                mapping[type_list == t] = i
            q_idx = mapping[torch.searchsorted(type_list, token_types.unsqueeze(-1))]
            k_idx = mapping[torch.searchsorted(type_list, token_types.unsqueeze(-2))]

            # 查询 ban_table 得到 (B, L, L)
            mask = ban_table[q_idx, k_idx].to(torch.bool)
        
        # 扩展到 num_heads
        mask = mask.unsqueeze(1).expand(B, num_heads, L, L)
        mask = mask.reshape(B * num_heads, L, L)
        return mask

In [18]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['gnn', 'tf', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 26.88it/s, loss=0.6880]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.43it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.15it/s]


Validation: {'precision': 0.7847478474736486, 'recall': 0.40012543116839094, 'f1': 0.5300103797406505, 'auc': 0.7576357128783671, 'prauc': 0.7472932178426162}
Test:      {'precision': 0.7750865051858415, 'recall': 0.4214487300080858, 'f1': 0.5460085268178382, 'auc': 0.7484954047999065, 'prauc': 0.7422041965813613}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.30it/s, loss=0.6012]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.03it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.28it/s]


Validation: {'precision': 0.6229180546288632, 'recall': 0.8795860771374112, 'f1': 0.7293291683109395, 'auc': 0.799053368430751, 'prauc': 0.8118308200358381}
Test:      {'precision': 0.6175618453295798, 'recall': 0.8689244277175637, 'f1': 0.7219906152557334, 'auc': 0.7857528915357618, 'prauc': 0.8061668395729323}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 28.19it/s, loss=0.5382]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.99it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.26it/s]


Validation: {'precision': 0.8086220789653158, 'recall': 0.6293508936951102, 'f1': 0.7078116685014195, 'auc': 0.819469642371461, 'prauc': 0.8309523043381795}
Test:      {'precision': 0.785941223189889, 'recall': 0.6205707118199418, 'f1': 0.6935342512433141, 'auc': 0.8068569374963194, 'prauc': 0.8225610117208685}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.41it/s, loss=0.4998]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.13it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.40it/s]


Validation: {'precision': 0.7371834465696814, 'recall': 0.7485105048581107, 'f1': 0.7428037914815641, 'auc': 0.8220740672865282, 'prauc': 0.8331974153518603}
Test:      {'precision': 0.7277519379822396, 'recall': 0.7359673878935843, 'f1': 0.7318366024191416, 'auc': 0.8135439993114374, 'prauc': 0.8281647342925911}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.49it/s, loss=0.4672]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.63it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.10it/s]


Validation: {'precision': 0.7273542600875117, 'recall': 0.762935089367316, 'f1': 0.744719921538679, 'auc': 0.8227207150018754, 'prauc': 0.8322147391901085}
Test:      {'precision': 0.7170305676835022, 'recall': 0.7723424270907108, 'f1': 0.7436594152945059, 'auc': 0.8146962345985673, 'prauc': 0.8266242768191439}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.64it/s, loss=0.4396]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.21it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.45it/s]


Validation: {'precision': 0.769479353677668, 'recall': 0.6719974913744998, 'f1': 0.7174422447693768, 'auc': 0.8180157399841191, 'prauc': 0.8289995342125205}
Test:      {'precision': 0.7602667602640918, 'recall': 0.6792097836291026, 'f1': 0.7174561063086086, 'auc': 0.8158293431423328, 'prauc': 0.8270217766088352}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.41it/s, loss=0.4263]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.34it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.40it/s]


Validation: {'precision': 0.6942826892517356, 'recall': 0.8225148949488162, 'f1': 0.752978321429441, 'auc': 0.8139062214121376, 'prauc': 0.8108411511064509}
Test:      {'precision': 0.6858187516304889, 'recall': 0.8234556287211557, 'f1': 0.7483613515512759, 'auc': 0.8076155645384138, 'prauc': 0.8100694856862716}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.50it/s, loss=0.3926]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.55it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.27it/s]


Validation: {'precision': 0.7467451952859679, 'recall': 0.7554092191886002, 'f1': 0.7510522163540213, 'auc': 0.8273005502232573, 'prauc': 0.8384697755492235}
Test:      {'precision': 0.7366800123167949, 'recall': 0.7500783944786765, 'f1': 0.7433188265705015, 'auc': 0.818357495946892, 'prauc': 0.8339322898948753}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.32it/s, loss=0.3581]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.03it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.15it/s]


Validation: {'precision': 0.7243795620416806, 'recall': 0.7779868297247476, 'f1': 0.7502267866581646, 'auc': 0.8181696654502335, 'prauc': 0.8166397398605968}
Test:      {'precision': 0.7101910828004917, 'recall': 0.7692066478495792, 'f1': 0.7385217472261157, 'auc': 0.8144299701975138, 'prauc': 0.8218085319130408}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 28.84it/s, loss=0.3294]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.26it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.21it/s]


Validation: {'precision': 0.7175507900656954, 'recall': 0.7974286610197635, 'f1': 0.7553839249091757, 'auc': 0.8243921426073226, 'prauc': 0.8292325250169372}
Test:      {'precision': 0.7108938547466176, 'recall': 0.7980558168679898, 'f1': 0.7519574481094532, 'auc': 0.817102731248619, 'prauc': 0.825935322851401}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.24it/s, loss=0.2986]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.21it/s]


Validation: {'precision': 0.7002224694085089, 'recall': 0.7895892129169345, 'f1': 0.7422254924365845, 'auc': 0.8152734151328627, 'prauc': 0.8228484690689792}
Test:      {'precision': 0.7027932960874224, 'recall': 0.7889620570687083, 'f1': 0.7433889741842311, 'auc': 0.811737277066481, 'prauc': 0.8214172446763262}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.2812]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.54it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.07it/s]


Validation: {'precision': 0.7240875912387618, 'recall': 0.7776732518006345, 'f1': 0.7499243977860761, 'auc': 0.8230905581094254, 'prauc': 0.82789080155528}
Test:      {'precision': 0.7218001168885979, 'recall': 0.7745374725595029, 'f1': 0.7472394444064613, 'auc': 0.8198286948463996, 'prauc': 0.8281532990532633}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.2500]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.87it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.98it/s]


Validation: {'precision': 0.7208297320636105, 'recall': 0.784571966131124, 'f1': 0.7513513463580596, 'auc': 0.8206978309470364, 'prauc': 0.8272762835968857}
Test:      {'precision': 0.7173850984849632, 'recall': 0.7880213232963688, 'f1': 0.7510460201133664, 'auc': 0.818809793093937, 'prauc': 0.8249279785355246}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7467451952859679, 'recall': 0.7554092191886002, 'f1': 0.7510522163540213, 'auc': 0.8273005502232573, 'prauc': 0.8384697755492235}
Corresponding test performance:
{'precision': 0.7366800123167949, 'recall': 0.7500783944786765, 'f1': 0.7433188265705015, 'auc': 0.818357495946892, 'prauc': 0.8339322898948753}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.60it/s, loss=0.6798]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.33it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.75it/s]


Validation: {'precision': 0.6244254760327765, 'recall': 0.8946378174948428, 'f1': 0.7354988350635047, 'auc': 0.7951436414967383, 'prauc': 0.7993171647605224}
Test:      {'precision': 0.6228260869551678, 'recall': 0.8984007525842007, 'f1': 0.7356528389162051, 'auc': 0.7901989030309345, 'prauc': 0.7958369461223991}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.41it/s, loss=0.5909]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.85it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 50.80it/s]


Validation: {'precision': 0.7260356818847111, 'recall': 0.7529005957956949, 'f1': 0.7392241329304085, 'auc': 0.8083963528506203, 'prauc': 0.8180017595582374}
Test:      {'precision': 0.7120588235273175, 'recall': 0.7591721542779581, 'f1': 0.7348611271929105, 'auc': 0.8013528144247859, 'prauc': 0.8148386476247935}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.20it/s, loss=0.5448]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.22it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.28it/s]


Validation: {'precision': 0.715200931855809, 'recall': 0.7701473816219186, 'f1': 0.7416578539808972, 'auc': 0.8151022082175501, 'prauc': 0.8266463266439693}
Test:      {'precision': 0.7037246049641543, 'recall': 0.7820633427382188, 'f1': 0.7408287489104072, 'auc': 0.8102990466325103, 'prauc': 0.8225680259707819}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 28.73it/s, loss=0.5117]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.07it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.02it/s]


Validation: {'precision': 0.8450990350388772, 'recall': 0.5217936657242966, 'f1': 0.6452113174951345, 'auc': 0.8190868381770121, 'prauc': 0.8301264246819064}
Test:      {'precision': 0.8307615729177165, 'recall': 0.5233615553448625, 'f1': 0.6421700606649708, 'auc': 0.8183394262077659, 'prauc': 0.8273960329834842}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.13it/s, loss=0.4750]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.19it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.83it/s]


Validation: {'precision': 0.6970478647156864, 'recall': 0.7626215114432029, 'f1': 0.728361779973468, 'auc': 0.7898139722325292, 'prauc': 0.8011154564962826}
Test:      {'precision': 0.6952006735877205, 'recall': 0.776732518028295, 'f1': 0.7337085258188548, 'auc': 0.788255072236173, 'prauc': 0.7979900555109354}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.27it/s, loss=0.4502]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.58it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.25it/s]


Validation: {'precision': 0.7691470054418543, 'recall': 0.6644716211957841, 'f1': 0.7129878819690753, 'auc': 0.8135417034023054, 'prauc': 0.823977445081627}
Test:      {'precision': 0.7621097954762752, 'recall': 0.6660395108163498, 'f1': 0.7108433685142227, 'auc': 0.8124190951339049, 'prauc': 0.8194065919039866}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 26.86it/s, loss=0.4100]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.46it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.79it/s]


Validation: {'precision': 0.7336039974992705, 'recall': 0.7365945437418107, 'f1': 0.7350962240698532, 'auc': 0.8085975511181751, 'prauc': 0.8128054775698808}
Test:      {'precision': 0.7309923664099818, 'recall': 0.7507055503269028, 'f1': 0.7407178167807715, 'auc': 0.8082796148679222, 'prauc': 0.8138800675214192}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.73it/s, loss=0.4092]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.58it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.16it/s]


Validation: {'precision': 0.7603903144172014, 'recall': 0.6597679523340867, 'f1': 0.706514434243682, 'auc': 0.8034765151636176, 'prauc': 0.7903004180397537}
Test:      {'precision': 0.7614321162681268, 'recall': 0.6735653809950657, 'f1': 0.7148086472626239, 'auc': 0.8057739611535842, 'prauc': 0.7976825590528691}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.37it/s, loss=0.3711]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.40it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.28it/s]


Validation: {'precision': 0.7247842903876085, 'recall': 0.7638758231396555, 'f1': 0.7438167888943066, 'auc': 0.8126289012745973, 'prauc': 0.8056518781882118}
Test:      {'precision': 0.7215777262160047, 'recall': 0.7801818751935398, 'f1': 0.7497363216589636, 'auc': 0.8136557397595163, 'prauc': 0.8098307769675127}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8450990350388772, 'recall': 0.5217936657242966, 'f1': 0.6452113174951345, 'auc': 0.8190868381770121, 'prauc': 0.8301264246819064}
Corresponding test performance:
{'precision': 0.8307615729177165, 'recall': 0.5233615553448625, 'f1': 0.6421700606649708, 'auc': 0.8183394262077659, 'prauc': 0.8273960329834842}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.53it/s, loss=0.7024]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.25it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.91it/s]


Validation: {'precision': 0.8049435787167925, 'recall': 0.4697397303215123, 'f1': 0.5932673220760907, 'auc': 0.7701479844655847, 'prauc': 0.7793701858604176}
Test:      {'precision': 0.8005194805153221, 'recall': 0.4832235810583781, 'f1': 0.6026593617481477, 'auc': 0.7645217533988978, 'prauc': 0.7748726529710364}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 28.31it/s, loss=0.5844]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 50.03it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.40it/s]


Validation: {'precision': 0.7281616688372615, 'recall': 0.7005330824687973, 'f1': 0.7140802251418281, 'auc': 0.7940856551010568, 'prauc': 0.8028136869670128}
Test:      {'precision': 0.7170178281986482, 'recall': 0.6936343681383078, 'f1': 0.705132286997851, 'auc': 0.7857631595769365, 'prauc': 0.8017811050552603}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.30it/s, loss=0.5471]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.17it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.57it/s]


Validation: {'precision': 0.7607207207179794, 'recall': 0.6619629978028788, 'f1': 0.7079141465978428, 'auc': 0.8025638135094514, 'prauc': 0.8143737539638263}
Test:      {'precision': 0.755788712008843, 'recall': 0.6550642834723893, 'f1': 0.7018310045981309, 'auc': 0.7927267035007477, 'prauc': 0.8074676341344083}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.51it/s, loss=0.5316]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.48it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.24it/s]


Validation: {'precision': 0.7857421114110411, 'recall': 0.6324866729362418, 'f1': 0.7008339074951448, 'auc': 0.8152428209393694, 'prauc': 0.8278874840802076}
Test:      {'precision': 0.7822517591838067, 'recall': 0.6274694261504313, 'f1': 0.6963633150508929, 'auc': 0.807064311661223, 'prauc': 0.8188244917330683}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.11it/s, loss=0.5105]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.67it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 53.95it/s]


Validation: {'precision': 0.7446054750378596, 'recall': 0.7249921605496238, 'f1': 0.7346679327170182, 'auc': 0.8114571788295415, 'prauc': 0.8164130171635103}
Test:      {'precision': 0.7452830188655001, 'recall': 0.7184070241432474, 'f1': 0.7315982705861293, 'auc': 0.8092617228061751, 'prauc': 0.8183398592522444}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.26it/s, loss=0.4539]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.54it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 51.14it/s]


Validation: {'precision': 0.7982697601226966, 'recall': 0.636563185949713, 'f1': 0.7083042518649438, 'auc': 0.8197754335960817, 'prauc': 0.8259843989416653}
Test:      {'precision': 0.7869429241564232, 'recall': 0.6312323612397892, 'f1': 0.7005394069248995, 'auc': 0.8145119635263072, 'prauc': 0.8246861683555032}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 29.04it/s, loss=0.4444]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.43it/s]


Validation: {'precision': 0.677132262050242, 'recall': 0.8588899341459427, 'f1': 0.7572573906996556, 'auc': 0.8205707821533671, 'prauc': 0.8253567796176086}
Test:      {'precision': 0.6779535343532428, 'recall': 0.8601442458423953, 'f1': 0.7582584608593507, 'auc': 0.8198934237726294, 'prauc': 0.8239119944886921}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.69it/s, loss=0.4075]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.64it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.61it/s]


Validation: {'precision': 0.8091143594118267, 'recall': 0.5901536531809654, 'f1': 0.6825022616657429, 'auc': 0.8135108077881865, 'prauc': 0.8223922334096729}
Test:      {'precision': 0.808721934366111, 'recall': 0.587331451863947, 'f1': 0.6804722930342609, 'auc': 0.8146653801415077, 'prauc': 0.8229188765759832}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 28.51it/s, loss=0.3765]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.00it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.25it/s]


Validation: {'precision': 0.7538247566037767, 'recall': 0.6798369394773288, 'f1': 0.7149216767916682, 'auc': 0.8038719790243396, 'prauc': 0.8136311918771916}
Test:      {'precision': 0.7583304103796621, 'recall': 0.6779554719326499, 'f1': 0.7158940347483868, 'auc': 0.8064063516894704, 'prauc': 0.8180315227380424}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7857421114110411, 'recall': 0.6324866729362418, 'f1': 0.7008339074951448, 'auc': 0.8152428209393694, 'prauc': 0.8278874840802076}
Corresponding test performance:
{'precision': 0.7822517591838067, 'recall': 0.6274694261504313, 'f1': 0.6963633150508929, 'auc': 0.807064311661223, 'prauc': 0.8188244917330683}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.6554]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.29it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.76it/s]


Validation: {'precision': 0.6140237324690097, 'recall': 0.8924427720260507, 'f1': 0.7275051076433626, 'auc': 0.7798004273943523, 'prauc': 0.7902028798795148}
Test:      {'precision': 0.6125244618381985, 'recall': 0.8833490122267691, 'f1': 0.7234206423114945, 'auc': 0.7724288497730208, 'prauc': 0.7852842346875026}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 26.95it/s, loss=0.6009]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.04it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.98it/s]


Validation: {'precision': 0.8223562152095438, 'recall': 0.555973659452631, 'f1': 0.6634237557081281, 'auc': 0.8056821605910818, 'prauc': 0.8095879061169902}
Test:      {'precision': 0.8053571428535475, 'recall': 0.565694575100139, 'f1': 0.6645791073256851, 'auc': 0.7985673565884331, 'prauc': 0.8095100433383737}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.22it/s, loss=0.5445]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.70it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.84it/s]


Validation: {'precision': 0.7380880121373654, 'recall': 0.7626215114432029, 'f1': 0.750154220785574, 'auc': 0.8272560404442176, 'prauc': 0.8380177781785757}
Test:      {'precision': 0.7247542448593246, 'recall': 0.762935089367316, 'f1': 0.7433547154409856, 'auc': 0.8185663297843158, 'prauc': 0.8343791000055634}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s, loss=0.5105]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.80it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.76it/s]


Validation: {'precision': 0.7492063492039708, 'recall': 0.7400439009070554, 'f1': 0.7445969345782155, 'auc': 0.8221043098226251, 'prauc': 0.8310685440304701}
Test:      {'precision': 0.7383265434010081, 'recall': 0.7387895892106028, 'f1': 0.7385579887280929, 'auc': 0.8192686839340892, 'prauc': 0.831973325105668}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 28.98it/s, loss=0.4789]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.31it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.39it/s]


Validation: {'precision': 0.7651322943067592, 'recall': 0.6619629978028788, 'f1': 0.7098184213855472, 'auc': 0.7960207755170895, 'prauc': 0.8118264563843299}
Test:      {'precision': 0.7524254401697721, 'recall': 0.6566321730929551, 'f1': 0.7012726005130581, 'auc': 0.7902074597319135, 'prauc': 0.8064837114083486}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 29.34it/s, loss=0.4538]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.11it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.80it/s]


Validation: {'precision': 0.6856492027317498, 'recall': 0.8494825964225479, 'f1': 0.7588235244665879, 'auc': 0.8230300730372319, 'prauc': 0.8282846975806512}
Test:      {'precision': 0.6779448621536894, 'recall': 0.8482282847260952, 'f1': 0.7535868455964324, 'auc': 0.8189701557369913, 'prauc': 0.828456689491726}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 28.59it/s, loss=0.4304]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.83it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.17it/s]


Validation: {'precision': 0.7456113335363548, 'recall': 0.7591721542779581, 'f1': 0.7523306351472293, 'auc': 0.8213326730212163, 'prauc': 0.8304379374913531}
Test:      {'precision': 0.736987140230444, 'recall': 0.7547820633403739, 'f1': 0.7457784613035906, 'auc': 0.8180460320312551, 'prauc': 0.8265759754998133}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.33it/s, loss=0.4090]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.43it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.76it/s]


Validation: {'precision': 0.7765011119318884, 'recall': 0.6569457510170682, 'f1': 0.7117377222278787, 'auc': 0.8088675235251264, 'prauc': 0.8104992602740348}
Test:      {'precision': 0.7717231222357448, 'recall': 0.6572593289411814, 'f1': 0.7099068536240885, 'auc': 0.815067394086917, 'prauc': 0.8176923736894548}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7380880121373654, 'recall': 0.7626215114432029, 'f1': 0.750154220785574, 'auc': 0.8272560404442176, 'prauc': 0.8380177781785757}
Corresponding test performance:
{'precision': 0.7247542448593246, 'recall': 0.762935089367316, 'f1': 0.7433547154409856, 'auc': 0.8185663297843158, 'prauc': 0.8343791000055634}


Epoch 001: 100%|██████████| 98/98 [00:03<00:00, 29.54it/s, loss=0.7053]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.41it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.34it/s]


Validation: {'precision': 0.7378249725348304, 'recall': 0.6318595170880155, 'f1': 0.6807432382708701, 'auc': 0.7630537988594043, 'prauc': 0.7551447234630357}
Test:      {'precision': 0.7338078291788833, 'recall': 0.646597679521334, 'f1': 0.6874479030023292, 'auc': 0.7626327861574725, 'prauc': 0.7598205039388326}


Epoch 002: 100%|██████████| 98/98 [00:03<00:00, 29.31it/s, loss=0.6142]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.23it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 49.25it/s]


Validation: {'precision': 0.7161003493149696, 'recall': 0.7071182188751737, 'f1': 0.7115809353576887, 'auc': 0.7821583404744903, 'prauc': 0.793305579226219}
Test:      {'precision': 0.7151326053019809, 'recall': 0.7187206020673605, 'f1': 0.7169221094801161, 'auc': 0.7804996811370547, 'prauc': 0.7886449981608827}


Epoch 003: 100%|██████████| 98/98 [00:03<00:00, 29.18it/s, loss=0.5750]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.51it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.31it/s]


Validation: {'precision': 0.7262055837540412, 'recall': 0.717779868295021, 'f1': 0.7219681388237881, 'auc': 0.7957878276103103, 'prauc': 0.8058590842073038}
Test:      {'precision': 0.7179646292252002, 'recall': 0.72561931639785, 'f1': 0.7217716731014451, 'auc': 0.7909055858647327, 'prauc': 0.8016864141681233}


Epoch 004: 100%|██████████| 98/98 [00:03<00:00, 29.21it/s, loss=0.5460]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.21it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.11it/s]


Validation: {'precision': 0.7114443084433555, 'recall': 0.7290686735630948, 'f1': 0.7201486708541784, 'auc': 0.7868781855764403, 'prauc': 0.8000051594918791}
Test:      {'precision': 0.7025487256350749, 'recall': 0.7347130761971317, 'f1': 0.7182709943871814, 'auc': 0.7799297545183156, 'prauc': 0.793188474694297}


Epoch 005: 100%|██████████| 98/98 [00:03<00:00, 29.28it/s, loss=0.5175]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.12it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.91it/s]


Validation: {'precision': 0.8391198043968747, 'recall': 0.538099717778181, 'f1': 0.6557126433066702, 'auc': 0.811036697056899, 'prauc': 0.8227082092801152}
Test:      {'precision': 0.8221492257117685, 'recall': 0.5493885230462546, 'f1': 0.6586466117366281, 'auc': 0.8080935821219308, 'prauc': 0.8185104856063388}


Epoch 006: 100%|██████████| 98/98 [00:03<00:00, 28.51it/s, loss=0.4780]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.30it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.00it/s]


Validation: {'precision': 0.8246844319737042, 'recall': 0.5531514581356126, 'f1': 0.6621621573538636, 'auc': 0.8148528328866781, 'prauc': 0.8238387845037353}
Test:      {'precision': 0.8149466192134567, 'recall': 0.5744747569753074, 'f1': 0.6739010435195572, 'auc': 0.8100455669493901, 'prauc': 0.8190399273625092}


Epoch 007: 100%|██████████| 98/98 [00:03<00:00, 26.99it/s, loss=0.4529]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.11it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 54.54it/s]


Validation: {'precision': 0.6759776536295684, 'recall': 0.8347444339892294, 'f1': 0.7470183758585147, 'auc': 0.8132402827767551, 'prauc': 0.8226888071037662}
Test:      {'precision': 0.6774357669786888, 'recall': 0.8350580119133426, 'f1': 0.7480337029173696, 'auc': 0.8103052879908715, 'prauc': 0.8186974796773705}


Epoch 008: 100%|██████████| 98/98 [00:03<00:00, 29.25it/s, loss=0.4272]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.42it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.13it/s]


Validation: {'precision': 0.8148308759721821, 'recall': 0.5892129194086259, 'f1': 0.6838944446261648, 'auc': 0.8165856498261959, 'prauc': 0.8265468115077327}
Test:      {'precision': 0.8023748939745446, 'recall': 0.593289432422097, 'f1': 0.6821705377454156, 'auc': 0.8127246700259267, 'prauc': 0.8191824060698862}


Epoch 009: 100%|██████████| 98/98 [00:03<00:00, 29.64it/s, loss=0.3962]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 54.92it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 52.95it/s]


Validation: {'precision': 0.7239632795165434, 'recall': 0.7171527124467948, 'f1': 0.7205418979594044, 'auc': 0.8023421186393591, 'prauc': 0.8141866779144795}
Test:      {'precision': 0.7279296261365759, 'recall': 0.7265600501701895, 'f1': 0.7272441883436103, 'auc': 0.7973156622358055, 'prauc': 0.8084897261902291}


Epoch 010: 100%|██████████| 98/98 [00:03<00:00, 29.37it/s, loss=0.3809]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.46it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.27it/s]


Validation: {'precision': 0.6709824828619676, 'recall': 0.8287864534310794, 'f1': 0.7415824866357659, 'auc': 0.8051672336891, 'prauc': 0.8134261833667953}
Test:      {'precision': 0.6710626426561728, 'recall': 0.8297271872034189, 'f1': 0.7420078469887448, 'auc': 0.799052018198593, 'prauc': 0.809269318415105}


Epoch 011: 100%|██████████| 98/98 [00:03<00:00, 29.47it/s, loss=0.3450]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 51.18it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.47it/s]


Validation: {'precision': 0.6450953678459467, 'recall': 0.8908748824054848, 'f1': 0.7483208169356158, 'auc': 0.8103873867926926, 'prauc': 0.814562758625387}
Test:      {'precision': 0.6458429030761464, 'recall': 0.8817811226062033, 'f1': 0.7455919346638933, 'auc': 0.8050140757731105, 'prauc': 0.8128192927789353}


Epoch 012: 100%|██████████| 98/98 [00:03<00:00, 29.34it/s, loss=0.3370]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 55.52it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.09it/s]


Validation: {'precision': 0.7125625183856648, 'recall': 0.7594857322020713, 'f1': 0.7352762548692722, 'auc': 0.7998026498690981, 'prauc': 0.8071779136430122}
Test:      {'precision': 0.7100088053985618, 'recall': 0.7585449984297318, 'f1': 0.7334748282354997, 'auc': 0.7943219745644513, 'prauc': 0.8007045145144362}


Epoch 013: 100%|██████████| 98/98 [00:03<00:00, 28.57it/s, loss=0.3232]
Running inference: 100%|██████████| 198/198 [00:03<00:00, 53.97it/s]
Running inference: 100%|██████████| 197/197 [00:03<00:00, 55.21it/s]

Validation: {'precision': 0.7357377049156205, 'recall': 0.7036688617099289, 'f1': 0.7193460440464974, 'auc': 0.7993667454077815, 'prauc': 0.8055203376902877}
Test:      {'precision': 0.7302782324035015, 'recall': 0.6995923486964579, 'f1': 0.7146060167809236, 'auc': 0.7920776525647201, 'prauc': 0.8021489926271532}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.8148308759721821, 'recall': 0.5892129194086259, 'f1': 0.6838944446261648, 'auc': 0.8165856498261959, 'prauc': 0.8265468115077327}
Corresponding test performance:
{'precision': 0.8023748939745446, 'recall': 0.593289432422097, 'f1': 0.6821705377454156, 'auc': 0.8127246700259267, 'prauc': 0.8191824060698862}





In [19]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
precision: 0.7754 ± 0.0398
recall: 0.6514 ± 0.0922
f1: 0.7015 ± 0.0385
auc: 0.8150 ± 0.0045
prauc: 0.8267 ± 0.0068


In [20]:
final_metrics = {"precision":[],"recall":[],"f1":[],"auc":[],"prauc":[]}
for i in range(5):
    model = HeteroGT(tokenizer, d_model=128, num_heads=4, layer_types=['tf', 'tf'], max_num_adms=config.max_num_adms, 
                     device=device, task=curr_task, label_vocab_size=config.label_vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(model, train_dataloader, val_dataloader, test_dataloader,
                                             optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs, 
                                             val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False)
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 68.06it/s, loss=0.6820]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 215.75it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.65it/s]


Validation: {'precision': 0.7018856629730563, 'recall': 0.735340232045358, 'f1': 0.718223578461459, 'auc': 0.7839953987136774, 'prauc': 0.7931937139320092}
Test:      {'precision': 0.6984789740510037, 'recall': 0.7343994982730185, 'f1': 0.7159889891923322, 'auc': 0.7794720716830075, 'prauc': 0.7945419483358044}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 66.59it/s, loss=0.6037]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.67it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 220.35it/s]


Validation: {'precision': 0.6993233303892341, 'recall': 0.745374725616979, 'f1': 0.7216150526835212, 'auc': 0.7909441488680299, 'prauc': 0.7980914381199895}
Test:      {'precision': 0.693349028838454, 'recall': 0.7387895892106028, 'f1': 0.7153484085446866, 'auc': 0.7842883366628967, 'prauc': 0.7975866923870534}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 66.44it/s, loss=0.5491]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 219.19it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.28it/s]


Validation: {'precision': 0.8498312710863339, 'recall': 0.47381624333498334, 'f1': 0.6084155379820778, 'auc': 0.8111334028409296, 'prauc': 0.8205979507758119}
Test:      {'precision': 0.8483491885794721, 'recall': 0.47538413295554915, 'f1': 0.609324754236917, 'auc': 0.800414345661526, 'prauc': 0.8140339092218297}


Epoch 004: 100%|██████████| 98/98 [00:01<00:00, 67.94it/s, loss=0.5196]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 213.14it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 214.74it/s]


Validation: {'precision': 0.7795562022525299, 'recall': 0.6719974913744998, 'f1': 0.7217918441324659, 'auc': 0.8177016094555248, 'prauc': 0.8266051469213356}
Test:      {'precision': 0.7634137558488894, 'recall': 0.6647851991198972, 'f1': 0.7106939273043993, 'auc': 0.8051442886285975, 'prauc': 0.8203849949722504}


Epoch 005: 100%|██████████| 98/98 [00:01<00:00, 66.27it/s, loss=0.5023]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.75it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.63it/s]


Validation: {'precision': 0.7337223587201053, 'recall': 0.749137660706337, 'f1': 0.7413498786289614, 'auc': 0.8165209448652444, 'prauc': 0.8178715041731424}
Test:      {'precision': 0.7231950844831853, 'recall': 0.7381624333623764, 'f1': 0.7306021054886357, 'auc': 0.8055908477526328, 'prauc': 0.8135660699942082}


Epoch 006: 100%|██████████| 98/98 [00:01<00:00, 66.24it/s, loss=0.4693]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.10it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 219.35it/s]


Validation: {'precision': 0.7039343334256895, 'recall': 0.7798682972694266, 'f1': 0.7399583407413305, 'auc': 0.8061090224336319, 'prauc': 0.8165319736217244}
Test:      {'precision': 0.7021037868142999, 'recall': 0.7848855440552371, 'f1': 0.741190400698821, 'auc': 0.8024165130235505, 'prauc': 0.8137837891336622}


Epoch 007: 100%|██████████| 98/98 [00:01<00:00, 66.60it/s, loss=0.4547]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 215.53it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.04it/s]


Validation: {'precision': 0.6383456956605117, 'recall': 0.8905613044813717, 'f1': 0.743650165335743, 'auc': 0.8029474717290059, 'prauc': 0.8075680856879033}
Test:      {'precision': 0.6346499102319689, 'recall': 0.8867983693920138, 'f1': 0.739829949353839, 'auc': 0.8002715997557817, 'prauc': 0.8084463456074463}


Epoch 008: 100%|██████████| 98/98 [00:01<00:00, 66.25it/s, loss=0.4289]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 215.55it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.88it/s]


Validation: {'precision': 0.6648390941581291, 'recall': 0.8745688303516006, 'f1': 0.7554171131839358, 'auc': 0.8148377116186298, 'prauc': 0.8192064442708508}
Test:      {'precision': 0.6575931232075989, 'recall': 0.8635936030076401, 'f1': 0.7466449727228541, 'auc': 0.8105552946600649, 'prauc': 0.817349398606951}


Epoch 009: 100%|██████████| 98/98 [00:01<00:00, 65.02it/s, loss=0.3819]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 208.76it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 209.92it/s]


Validation: {'precision': 0.7322932917293844, 'recall': 0.7359673878935843, 'f1': 0.7341257378816888, 'auc': 0.8107184471130888, 'prauc': 0.8177859452920151}
Test:      {'precision': 0.7350374064814993, 'recall': 0.739416745058829, 'f1': 0.7372205671409314, 'auc': 0.8110307955668236, 'prauc': 0.8186492306309172}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7795562022525299, 'recall': 0.6719974913744998, 'f1': 0.7217918441324659, 'auc': 0.8177016094555248, 'prauc': 0.8266051469213356}
Corresponding test performance:
{'precision': 0.7634137558488894, 'recall': 0.6647851991198972, 'f1': 0.7106939273043993, 'auc': 0.8051442886285975, 'prauc': 0.8203849949722504}


Epoch 001: 100%|██████████| 98/98 [00:01<00:00, 65.72it/s, loss=0.6881]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 210.61it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 218.49it/s]


Validation: {'precision': 0.6672842773936811, 'recall': 0.790529946689274, 'f1': 0.7236974257785992, 'auc': 0.7812312208670686, 'prauc': 0.7945532108231721}
Test:      {'precision': 0.6566422594125088, 'recall': 0.7873941674481424, 'f1': 0.7160986689302955, 'auc': 0.7693681681663544, 'prauc': 0.7822712555957418}


Epoch 002: 100%|██████████| 98/98 [00:01<00:00, 66.33it/s, loss=0.5963]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 217.25it/s]
Running inference: 100%|██████████| 197/197 [00:00<00:00, 217.74it/s]


Validation: {'precision': 0.840367965363418, 'recall': 0.486986516147736, 'f1': 0.6166368823878776, 'auc': 0.7880090152899633, 'prauc': 0.8077791749882255}
Test:      {'precision': 0.8338727076546177, 'recall': 0.4847914706789439, 'f1': 0.6131271022287856, 'auc': 0.7792260413630925, 'prauc': 0.7993256654900069}


Epoch 003: 100%|██████████| 98/98 [00:01<00:00, 67.25it/s, loss=0.5473]
Running inference: 100%|██████████| 198/198 [00:00<00:00, 216.41it/s]
Running inference:  89%|████████▉ | 176/197 [00:00<00:00, 218.52it/s]


KeyboardInterrupt: 

In [None]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")