In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    # gt_labels is ground truth labels, predicted_scores is the output logits
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        # standard binary classification evaluation
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()      
    #   The threshold should be 0 because:
	# 	Your model outputs logits, not probabilities. logit > 0 ≡ probability > 0.5
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "auc":roc_auc, "prauc":pr_auc}
    else:
        # —— Multi-label classification evaluation (per-class over the batch) —— #
        # predicted_scores: list[Tensor] -> [B, C] logits
        logits = torch.cat(predicted_scores, dim=0)           # [B, C]
        y_true_t = torch.cat(gt_labels, dim=0)                # [B, C]
        # 1) 连续分数用于 AUC/PR-AUC：sigmoid(logits)
        #    CPU + float16 无 sigmoid 实现，先升到 fp32
        logits_for_sigmoid = logits.float() if (logits.device.type == "cpu" and logits.dtype == torch.float16) else logits
        prob_t = torch.sigmoid(logits_for_sigmoid)            # [B, C] in [0,1]

        # 2) 阈值化用于 P/R/F1：logits > 0 等价于 prob >= 0.5
        y_pred_t = (logits > 0).to(torch.int32)               # [B, C]

        # 转 numpy
        y_true = y_true_t.cpu().numpy().astype(np.int32)      # [B, C]
        y_pred = y_pred_t.cpu().numpy().astype(np.int32)      # [B, C]
        scores = prob_t.cpu().numpy()                         # [B, C] float

        # 3) 按类计算 Precision/Recall/F1
        p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # 4) 按类计算 AUC/PR-AUC（遇到单一类别则设为 NaN）
        C = y_true.shape[1]
        aucs, praucs = [], []
        for c in range(C):
            yt, ys = y_true[:, c], scores[:, c]
            if yt.max() == yt.min():           # 全 0 或全 1，曲线不定义
                aucs.append(np.nan)
                praucs.append(np.nan)
                continue
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

        # 5) 取宏平均（忽略 NaN）
        ave_precision = float(np.mean(p_cls))
        ave_recall    = float(np.mean(r_cls))
        ave_f1        = float(np.mean(f1_cls))
        ave_auc       = float(np.nanmean(aucs))   if np.any(~np.isnan(aucs))   else np.nan
        ave_prauc     = float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else np.nan

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc), np.mean(ave_recall), np.mean(ave_precision)
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [4]:
args = {
    "seed": 0,
    "dataset": "MIMIC-IV", 
    "task": "next_diag_12m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 50,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [5]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-IV-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [6]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [7]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [8]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [9]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [10]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([6, 74])
input_types shape: torch.Size([6, 74])
visit_positions shape: torch.Size([6])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [12]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [13]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
        test_metric = evaluate(model, test_dataloader, device, task_type=task_type)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric
            best_test_metric = test_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    return best_test_metric

In [15]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [16]:
final_metrics = {"recall":[], "precision":[],"f1":[],"auc":[],"prauc":[]}

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        task_type=task_type)
    
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 31.84it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.72it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.74it/s]


Epoch: 001, Average Loss: 0.3188
Validation: {'recall': 0.30400487725334835, 'precision': 0.36527294486685385, 'f1': 0.3153289428481232, 'auc': 0.7439815824070245, 'prauc': 0.4011660662817574}
Test:       {'recall': 0.3088630376980107, 'precision': 0.3630025407440793, 'f1': 0.31823511801658033, 'auc': 0.7425757417279394, 'prauc': 0.40171050321507784}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.50it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.20it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.37it/s]


Epoch: 002, Average Loss: 0.2899
Validation: {'recall': 0.34400978246743436, 'precision': 0.42931828298240426, 'f1': 0.36360487735100133, 'auc': 0.7683573443290065, 'prauc': 0.4298140658195271}
Test:       {'recall': 0.3471777078086956, 'precision': 0.4254150335884723, 'f1': 0.36366718271681847, 'auc': 0.7647166311532175, 'prauc': 0.42803426282084955}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.42it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.18it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.94it/s]


Epoch: 003, Average Loss: 0.2807
Validation: {'recall': 0.3533889153077711, 'precision': 0.43796004320410464, 'f1': 0.37480750059609697, 'auc': 0.7696451616050156, 'prauc': 0.4329722037790962}
Test:       {'recall': 0.35378258233034643, 'precision': 0.43255367422202173, 'f1': 0.37292455279138625, 'auc': 0.7660939623137896, 'prauc': 0.43034738861335975}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.41it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.29it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.33it/s]


Epoch: 004, Average Loss: 0.2747
Validation: {'recall': 0.3583150904209103, 'precision': 0.4284209229278367, 'f1': 0.38076954866714247, 'auc': 0.7775809556858219, 'prauc': 0.4378140203815382}
Test:       {'recall': 0.36078116363742363, 'precision': 0.4283781912902353, 'f1': 0.379594940754301, 'auc': 0.7742232531135368, 'prauc': 0.4368254067039848}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.46it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.64it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.94it/s]


Epoch: 005, Average Loss: 0.2694
Validation: {'recall': 0.3655177111637975, 'precision': 0.42875559421977033, 'f1': 0.3840017392246293, 'auc': 0.7727061810063354, 'prauc': 0.44018260423449523}
Test:       {'recall': 0.36267107299479384, 'precision': 0.425123175907119, 'f1': 0.37918923570954177, 'auc': 0.7719402544117828, 'prauc': 0.43696623372260557}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.54it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.64it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.18it/s]


Epoch: 006, Average Loss: 0.2635
Validation: {'recall': 0.3595110058621229, 'precision': 0.4325965072701603, 'f1': 0.38269685040886503, 'auc': 0.7664617089058912, 'prauc': 0.43566889286045846}
Test:       {'recall': 0.3660974901024001, 'precision': 0.43283160500318296, 'f1': 0.38516854828782726, 'auc': 0.774917027587193, 'prauc': 0.4373423919275247}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.75it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.30it/s]


Epoch: 007, Average Loss: 0.2577
Validation: {'recall': 0.3486039694151043, 'precision': 0.4535403749089561, 'f1': 0.37647256266493734, 'auc': 0.772645042337585, 'prauc': 0.4358035186781322}
Test:       {'recall': 0.3481312241380632, 'precision': 0.456159362814091, 'f1': 0.37353152710414655, 'auc': 0.7681041198158717, 'prauc': 0.43427918347263356}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.42it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.84it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.71it/s]


Epoch: 008, Average Loss: 0.2521
Validation: {'recall': 0.3625984331377543, 'precision': 0.44651779023192617, 'f1': 0.3876412717911968, 'auc': 0.7737926196610897, 'prauc': 0.4383500833599568}
Test:       {'recall': 0.36187558891419064, 'precision': 0.45016625583260644, 'f1': 0.3846676858010159, 'auc': 0.7720256718150552, 'prauc': 0.4369278692307562}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.41it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.01it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.81it/s]


Epoch: 009, Average Loss: 0.2475
Validation: {'recall': 0.3819782939617672, 'precision': 0.4417913719798543, 'f1': 0.3950099644270712, 'auc': 0.773338891272222, 'prauc': 0.43419721694068447}
Test:       {'recall': 0.3803735855070387, 'precision': 0.41741386323244295, 'f1': 0.3902599324123891, 'auc': 0.773925660206874, 'prauc': 0.434371920150807}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.91it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.14it/s]


Epoch: 010, Average Loss: 0.2429
Validation: {'recall': 0.3768281392111917, 'precision': 0.4921310408446118, 'f1': 0.39161035299225067, 'auc': 0.7715950206508976, 'prauc': 0.4310648756998166}
Test:       {'recall': 0.3793248709323762, 'precision': 0.43245157413508534, 'f1': 0.39142971388128583, 'auc': 0.7692874181668168, 'prauc': 0.4328848990666558}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.31it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.24it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.86it/s]


Epoch: 011, Average Loss: 0.2374
Validation: {'recall': 0.35560660427052426, 'precision': 0.4545822596152249, 'f1': 0.3755275525375346, 'auc': 0.768923681828383, 'prauc': 0.43182016160464665}
Test:       {'recall': 0.3563162801697566, 'precision': 0.49595982644292724, 'f1': 0.3758537728856418, 'auc': 0.7672970885022051, 'prauc': 0.4319159583325982}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.25it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.86it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.39it/s]


Epoch: 012, Average Loss: 0.2319
Validation: {'recall': 0.37560250648477006, 'precision': 0.5102775906631031, 'f1': 0.3946031000230466, 'auc': 0.7700938739322467, 'prauc': 0.4321513908521503}
Test:       {'recall': 0.37836271664574117, 'precision': 0.5155488821932891, 'f1': 0.3928286719856084, 'auc': 0.7686447974188177, 'prauc': 0.43382507930620207}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.53it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.90it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.52it/s]


Epoch: 013, Average Loss: 0.2276
Validation: {'recall': 0.3681449950754701, 'precision': 0.5108106037379126, 'f1': 0.3867357268283681, 'auc': 0.768053934680401, 'prauc': 0.43087644745172027}
Test:       {'recall': 0.3691594912197652, 'precision': 0.4608749923724251, 'f1': 0.38609248036856436, 'auc': 0.7680573273403577, 'prauc': 0.43060876093155337}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.42it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.19it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.81it/s]


Epoch: 014, Average Loss: 0.2210
Validation: {'recall': 0.3750785034360704, 'precision': 0.515648571924332, 'f1': 0.39336838281016395, 'auc': 0.7636251729619703, 'prauc': 0.4267980333529472}
Test:       {'recall': 0.3749092895825606, 'precision': 0.5020161138457868, 'f1': 0.38964306428725615, 'auc': 0.7668147327631079, 'prauc': 0.4307901720358765}

Early stopping triggered after 14 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3819782939617672, 'precision': 0.4417913719798543, 'f1': 0.3950099644270712, 'auc': 0.773338891272222, 'prauc': 0.43419721694068447}
Corresponding test performance:
{'recall': 0.3803735855070387, 'precision': 0.41741386323244295, 'f1': 0.3902599324123891, 'auc': 0.773925660206874, 'prauc': 0.434371920150807}
[INFO] Random seed set to 1181241943
Training with seed: 1181241943


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.54it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.33it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.36it/s]


Epoch: 001, Average Loss: 0.3159
Validation: {'recall': 0.32735027149288326, 'precision': 0.40035424433854316, 'f1': 0.3453917835564502, 'auc': 0.7640963725885292, 'prauc': 0.4249476193531671}
Test:       {'recall': 0.3317619786615912, 'precision': 0.3996618400319702, 'f1': 0.34810567583486335, 'auc': 0.7544490159976847, 'prauc': 0.4241343984013025}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.39it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.65it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.10it/s]


Epoch: 002, Average Loss: 0.2886
Validation: {'recall': 0.3689579605034323, 'precision': 0.42722305506664154, 'f1': 0.3890011998259381, 'auc': 0.7707377269530485, 'prauc': 0.43336442454080704}
Test:       {'recall': 0.37327909464515435, 'precision': 0.423394167057387, 'f1': 0.38969325240633723, 'auc': 0.7718521913829857, 'prauc': 0.4342380547914469}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.55it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.39it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.18it/s]


Epoch: 003, Average Loss: 0.2800
Validation: {'recall': 0.3497571413840749, 'precision': 0.43815067232668664, 'f1': 0.3741585283309403, 'auc': 0.7709139233238009, 'prauc': 0.4354340399091867}
Test:       {'recall': 0.3542981324070589, 'precision': 0.4406743878712084, 'f1': 0.37841670324971755, 'auc': 0.7726564855759733, 'prauc': 0.43924454749750746}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.28it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.16it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.90it/s]


Epoch: 004, Average Loss: 0.2745
Validation: {'recall': 0.35923377681905727, 'precision': 0.44489330349675416, 'f1': 0.3805561843456681, 'auc': 0.7776025262664644, 'prauc': 0.44027563941086784}
Test:       {'recall': 0.35860839467592454, 'precision': 0.44353931849311884, 'f1': 0.3793068289796657, 'auc': 0.7774193115002802, 'prauc': 0.4388702145874249}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.42it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.69it/s]


Epoch: 005, Average Loss: 0.2697
Validation: {'recall': 0.35500939613127913, 'precision': 0.44013738411282355, 'f1': 0.3735294371894293, 'auc': 0.777102679528344, 'prauc': 0.43980613989790945}
Test:       {'recall': 0.35722511983112815, 'precision': 0.44881861476245966, 'f1': 0.3752318985779991, 'auc': 0.7763172521714822, 'prauc': 0.44303552679605646}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.55it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.38it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.78it/s]


Epoch: 006, Average Loss: 0.2635
Validation: {'recall': 0.35458178052853806, 'precision': 0.4344720835124389, 'f1': 0.38187893155784697, 'auc': 0.7772086507596314, 'prauc': 0.4421121579017226}
Test:       {'recall': 0.35634378193758054, 'precision': 0.4353830080408731, 'f1': 0.3823989887858088, 'auc': 0.7775837615733635, 'prauc': 0.4428403255471563}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.59it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.33it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.44it/s]


Epoch: 007, Average Loss: 0.2588
Validation: {'recall': 0.3636679892537201, 'precision': 0.4962711487905405, 'f1': 0.3814336629441798, 'auc': 0.775914464300008, 'prauc': 0.44165614080236026}
Test:       {'recall': 0.36721476265509484, 'precision': 0.49050759180612186, 'f1': 0.3834672008990802, 'auc': 0.7752110489264017, 'prauc': 0.44228250578734996}

Early stopping triggered after 7 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3689579605034323, 'precision': 0.42722305506664154, 'f1': 0.3890011998259381, 'auc': 0.7707377269530485, 'prauc': 0.43336442454080704}
Corresponding test performance:
{'recall': 0.37327909464515435, 'precision': 0.423394167057387, 'f1': 0.38969325240633723, 'auc': 0.7718521913829857, 'prauc': 0.4342380547914469}
[INFO] Random seed set to 958682846
Training with seed: 958682846


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.44it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 118.02it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.72it/s]


Epoch: 001, Average Loss: 0.3155
Validation: {'recall': 0.32440179491062765, 'precision': 0.4113919081684192, 'f1': 0.3406541310364875, 'auc': 0.7578182215314236, 'prauc': 0.4151894535728151}
Test:       {'recall': 0.3275673574574307, 'precision': 0.42194896439734547, 'f1': 0.3421771055440178, 'auc': 0.7574101388064709, 'prauc': 0.4159008315753333}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.46it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.23it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.74it/s]


Epoch: 002, Average Loss: 0.2890
Validation: {'recall': 0.33507445242813677, 'precision': 0.4482924019253132, 'f1': 0.35882447978196663, 'auc': 0.7727735401326501, 'prauc': 0.43440358072161606}
Test:       {'recall': 0.34023263101277984, 'precision': 0.4498532762530016, 'f1': 0.3613098554720157, 'auc': 0.7721234957597756, 'prauc': 0.4314426567992499}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.46it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.07it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.38it/s]


Epoch: 003, Average Loss: 0.2808
Validation: {'recall': 0.3508682331943837, 'precision': 0.43440748150299524, 'f1': 0.37322172401724496, 'auc': 0.7671046711146751, 'prauc': 0.43757071957155835}
Test:       {'recall': 0.35360698367699706, 'precision': 0.4333521509254482, 'f1': 0.37414581326956964, 'auc': 0.7703326736625057, 'prauc': 0.4361593356105439}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.17it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.93it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.10it/s]


Epoch: 004, Average Loss: 0.2737
Validation: {'recall': 0.3574461299823215, 'precision': 0.43335231702899724, 'f1': 0.37895804466259075, 'auc': 0.7724109376735488, 'prauc': 0.43965896742440963}
Test:       {'recall': 0.3626224148477665, 'precision': 0.48404811788834334, 'f1': 0.38090253937510943, 'auc': 0.7756415426380382, 'prauc': 0.43961523270287745}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.43it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.89it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.30it/s]


Epoch: 005, Average Loss: 0.2681
Validation: {'recall': 0.37460528825724165, 'precision': 0.4164152808935696, 'f1': 0.384419748828825, 'auc': 0.7763993973802675, 'prauc': 0.4416076767619892}
Test:       {'recall': 0.3795238330194267, 'precision': 0.4120967713246136, 'f1': 0.3848109563232687, 'auc': 0.7777165744219784, 'prauc': 0.4434065096052515}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.56it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.53it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.15it/s]


Epoch: 006, Average Loss: 0.2619
Validation: {'recall': 0.36733723168570126, 'precision': 0.45768605681104596, 'f1': 0.3907748510615082, 'auc': 0.7693723925428722, 'prauc': 0.43972825781006575}
Test:       {'recall': 0.37206549684652945, 'precision': 0.4444316028024072, 'f1': 0.39033293359225335, 'auc': 0.7739847054395372, 'prauc': 0.4419293823108361}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.46it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.93it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.07it/s]


Epoch: 007, Average Loss: 0.2573
Validation: {'recall': 0.3906412833372142, 'precision': 0.4300378629919487, 'f1': 0.396817483166321, 'auc': 0.769628003834016, 'prauc': 0.4376061917632526}
Test:       {'recall': 0.39677361037564607, 'precision': 0.45558139646509743, 'f1': 0.399790111337718, 'auc': 0.7698023160504386, 'prauc': 0.4443057175750769}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.47it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.16it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.38it/s]


Epoch: 008, Average Loss: 0.2523
Validation: {'recall': 0.37068846346251905, 'precision': 0.47126085888973773, 'f1': 0.3894471655678479, 'auc': 0.777599190373905, 'prauc': 0.4411238861361733}
Test:       {'recall': 0.37235789104029204, 'precision': 0.48020396057449033, 'f1': 0.3889684927419831, 'auc': 0.7750787646768992, 'prauc': 0.44201399520541734}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.27it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.78it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.97it/s]


Epoch: 009, Average Loss: 0.2472
Validation: {'recall': 0.36041467965828416, 'precision': 0.44078059790541446, 'f1': 0.3835489160987193, 'auc': 0.7745686738032421, 'prauc': 0.43975122411377837}
Test:       {'recall': 0.3656437079606548, 'precision': 0.5070351219319982, 'f1': 0.38601742697659847, 'auc': 0.7712361995034374, 'prauc': 0.44078605110716457}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.30it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.15it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.57it/s]


Epoch: 010, Average Loss: 0.2413
Validation: {'recall': 0.37466719797606346, 'precision': 0.4570169612361469, 'f1': 0.39653363941437925, 'auc': 0.7748636102672918, 'prauc': 0.4369401332556483}
Test:       {'recall': 0.38045497661953587, 'precision': 0.49096257339353727, 'f1': 0.3991055705614903, 'auc': 0.7722347047182044, 'prauc': 0.43812454824872105}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.27it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.05it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.00it/s]


Epoch: 011, Average Loss: 0.2368
Validation: {'recall': 0.3702879334804947, 'precision': 0.49440796836988476, 'f1': 0.39029086200079044, 'auc': 0.7744214589205879, 'prauc': 0.4397026692955117}
Test:       {'recall': 0.37174613487052105, 'precision': 0.47112962265117514, 'f1': 0.3888818000795413, 'auc': 0.7717083129358825, 'prauc': 0.43933979018452013}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.38it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.14it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.34it/s]


Epoch: 012, Average Loss: 0.2310
Validation: {'recall': 0.3667227319069471, 'precision': 0.4875936180143263, 'f1': 0.3937533854008288, 'auc': 0.765925396394256, 'prauc': 0.4351731583317113}
Test:       {'recall': 0.37009834072687076, 'precision': 0.49530825240498544, 'f1': 0.39278265309548044, 'auc': 0.7662159508620756, 'prauc': 0.4332994901659315}

Early stopping triggered after 12 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3906412833372142, 'precision': 0.4300378629919487, 'f1': 0.396817483166321, 'auc': 0.769628003834016, 'prauc': 0.4376061917632526}
Corresponding test performance:
{'recall': 0.39677361037564607, 'precision': 0.45558139646509743, 'f1': 0.399790111337718, 'auc': 0.7698023160504386, 'prauc': 0.4443057175750769}
[INFO] Random seed set to 3163119785
Training with seed: 3163119785


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.45it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.79it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.72it/s]


Epoch: 001, Average Loss: 0.3162
Validation: {'recall': 0.31606766058219393, 'precision': 0.35970115780550216, 'f1': 0.32675496302578444, 'auc': 0.7458416432309648, 'prauc': 0.40145020493809147}
Test:       {'recall': 0.32107716012364895, 'precision': 0.3627289164720363, 'f1': 0.3305476757226382, 'auc': 0.7440811758128132, 'prauc': 0.40510303471006837}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.24it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.87it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.58it/s]


Epoch: 002, Average Loss: 0.2909
Validation: {'recall': 0.31907910522504424, 'precision': 0.4311020962017784, 'f1': 0.34113142822220144, 'auc': 0.7609468425463981, 'prauc': 0.4284044921128164}
Test:       {'recall': 0.32602517095855793, 'precision': 0.4334906367612138, 'f1': 0.34555248948244965, 'auc': 0.7631902360906071, 'prauc': 0.4306852435397058}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.47it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.99it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.35it/s]


Epoch: 003, Average Loss: 0.2805
Validation: {'recall': 0.3116740353046658, 'precision': 0.45173160902951576, 'f1': 0.34583800247529684, 'auc': 0.7712970990094087, 'prauc': 0.4403805896140567}
Test:       {'recall': 0.3146473729281103, 'precision': 0.4465395934381164, 'f1': 0.3470047627528537, 'auc': 0.7718677886145601, 'prauc': 0.44019329990497763}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.42it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.37it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.91it/s]


Epoch: 004, Average Loss: 0.2742
Validation: {'recall': 0.3662052682361436, 'precision': 0.48365236707614734, 'f1': 0.38033852058618206, 'auc': 0.774444650178322, 'prauc': 0.4370662781157996}
Test:       {'recall': 0.3607065700843213, 'precision': 0.5024360798559754, 'f1': 0.37295526462453255, 'auc': 0.7746482776359361, 'prauc': 0.43662006282421983}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.25it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.62it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.14it/s]


Epoch: 005, Average Loss: 0.2689
Validation: {'recall': 0.3609075013565847, 'precision': 0.4819664103864829, 'f1': 0.38237921736817904, 'auc': 0.7783221945537648, 'prauc': 0.4457499794073091}
Test:       {'recall': 0.36180991931186707, 'precision': 0.5117936287686011, 'f1': 0.3800769242696088, 'auc': 0.7785043422242381, 'prauc': 0.4459915441979394}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.88it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.66it/s]


Epoch: 006, Average Loss: 0.2635
Validation: {'recall': 0.36112232828486107, 'precision': 0.48739686907534874, 'f1': 0.3840178287055006, 'auc': 0.7743730743858398, 'prauc': 0.44044117033966246}
Test:       {'recall': 0.36212685117659216, 'precision': 0.43148281785583487, 'f1': 0.382374854442068, 'auc': 0.7728595234343901, 'prauc': 0.44014435580966577}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.44it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.21it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.23it/s]


Epoch: 007, Average Loss: 0.2591
Validation: {'recall': 0.36195116550034434, 'precision': 0.4817080755961823, 'f1': 0.38050952765190377, 'auc': 0.771909024007104, 'prauc': 0.440722301886908}
Test:       {'recall': 0.36098882309407426, 'precision': 0.41937925881204674, 'f1': 0.37610456596500286, 'auc': 0.7750185674537096, 'prauc': 0.4417297577785429}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.37it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.97it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.22it/s]


Epoch: 008, Average Loss: 0.2533
Validation: {'recall': 0.33721177096263794, 'precision': 0.46946216549113523, 'f1': 0.3662054507654972, 'auc': 0.7712995636134786, 'prauc': 0.44149352519091783}
Test:       {'recall': 0.33851689343956154, 'precision': 0.4855594868702324, 'f1': 0.36336857636480635, 'auc': 0.7731670335705441, 'prauc': 0.43918817974859903}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.45it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.86it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.79it/s]


Epoch: 009, Average Loss: 0.2485
Validation: {'recall': 0.3464320351910893, 'precision': 0.4727249086299286, 'f1': 0.37761972243357167, 'auc': 0.7704328437569207, 'prauc': 0.44030880989770727}
Test:       {'recall': 0.34901583453994633, 'precision': 0.4885945013226357, 'f1': 0.3784771733059633, 'auc': 0.7727557213150819, 'prauc': 0.44236161071759794}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.35it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.52it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.60it/s]


Epoch: 010, Average Loss: 0.2444
Validation: {'recall': 0.3412416417315607, 'precision': 0.49995164850298374, 'f1': 0.3691351790547484, 'auc': 0.7744923719283864, 'prauc': 0.4354974451573902}
Test:       {'recall': 0.3460789127895752, 'precision': 0.49223533541289793, 'f1': 0.37073926263342727, 'auc': 0.7698571627388612, 'prauc': 0.44154736890331386}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.32it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.22it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.57it/s]


Epoch: 011, Average Loss: 0.2376
Validation: {'recall': 0.3773874831983278, 'precision': 0.4510006188793468, 'f1': 0.3934077559232718, 'auc': 0.7749478174156929, 'prauc': 0.4377574169040749}
Test:       {'recall': 0.38143245021397987, 'precision': 0.45954096896147434, 'f1': 0.39356062004087666, 'auc': 0.7717690436358898, 'prauc': 0.4383422619623756}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.44it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.23it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.47it/s]


Epoch: 012, Average Loss: 0.2323
Validation: {'recall': 0.374554534295082, 'precision': 0.45466199400931523, 'f1': 0.395079909218084, 'auc': 0.7714182590586911, 'prauc': 0.4370259381119634}
Test:       {'recall': 0.3760871593738057, 'precision': 0.48897723847307567, 'f1': 0.3965700592898324, 'auc': 0.7677737007109431, 'prauc': 0.43591063552618603}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.41it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.68it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.45it/s]


Epoch: 013, Average Loss: 0.2269
Validation: {'recall': 0.3715024905980674, 'precision': 0.4777766025517926, 'f1': 0.39036266111436024, 'auc': 0.7708021449749103, 'prauc': 0.43222425313373297}
Test:       {'recall': 0.3732385877350427, 'precision': 0.4503515585529393, 'f1': 0.3883158398378834, 'auc': 0.7672288340713934, 'prauc': 0.43471784361255245}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.40it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.77it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.86it/s]


Epoch: 014, Average Loss: 0.2221
Validation: {'recall': 0.3759682060294019, 'precision': 0.5225036918545949, 'f1': 0.39647058525763496, 'auc': 0.7674076692439562, 'prauc': 0.4307785429695088}
Test:       {'recall': 0.37888299914422885, 'precision': 0.5055172013732983, 'f1': 0.39650340063781886, 'auc': 0.7680376408685401, 'prauc': 0.4341260953504286}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.52it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.69it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.72it/s]


Epoch: 015, Average Loss: 0.2175
Validation: {'recall': 0.36581849348624007, 'precision': 0.4493617144229135, 'f1': 0.3870584085562996, 'auc': 0.7697401866458344, 'prauc': 0.4317931074407208}
Test:       {'recall': 0.37102085056162637, 'precision': 0.4758891890540596, 'f1': 0.38904794730002645, 'auc': 0.7663720746948735, 'prauc': 0.427378007260756}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.43it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.92it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.01it/s]


Epoch: 016, Average Loss: 0.2114
Validation: {'recall': 0.3867931697812348, 'precision': 0.4596597262672828, 'f1': 0.4001145885016164, 'auc': 0.7660960693038505, 'prauc': 0.4279589700979775}
Test:       {'recall': 0.3899699879118781, 'precision': 0.43923315800662244, 'f1': 0.39970774360995365, 'auc': 0.7620414606370725, 'prauc': 0.4223139363524872}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.38it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.52it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 119.06it/s]


Epoch: 017, Average Loss: 0.2063
Validation: {'recall': 0.3702177838055803, 'precision': 0.4718256782976192, 'f1': 0.3945362623364692, 'auc': 0.769181447170287, 'prauc': 0.42911287726250674}
Test:       {'recall': 0.3658810618294501, 'precision': 0.450829795499872, 'f1': 0.38814920877324927, 'auc': 0.7657737405807233, 'prauc': 0.4240739179000423}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.35it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.73it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.92it/s]


Epoch: 018, Average Loss: 0.2039
Validation: {'recall': 0.3585253850351547, 'precision': 0.4855809234218771, 'f1': 0.38493491801451785, 'auc': 0.7674842963399873, 'prauc': 0.42572771650447305}
Test:       {'recall': 0.36091289345696925, 'precision': 0.43779796914306696, 'f1': 0.3817780593049364, 'auc': 0.7640571841513376, 'prauc': 0.4217991033503192}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.87it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.06it/s]


Epoch: 019, Average Loss: 0.1986
Validation: {'recall': 0.37500872529864443, 'precision': 0.45867355830501855, 'f1': 0.39413332519585786, 'auc': 0.7620316006713267, 'prauc': 0.42119342024404743}
Test:       {'recall': 0.3728177236257777, 'precision': 0.45449927751026675, 'f1': 0.39143228273229824, 'auc': 0.7618705624530104, 'prauc': 0.41952628381663104}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.29it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.27it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.81it/s]


Epoch: 020, Average Loss: 0.1942
Validation: {'recall': 0.38695881059807963, 'precision': 0.43519889382811267, 'f1': 0.3983397325819449, 'auc': 0.7644810944486551, 'prauc': 0.4233600986690995}
Test:       {'recall': 0.38800216893667905, 'precision': 0.4581540583535988, 'f1': 0.39905858191464505, 'auc': 0.7620746919699717, 'prauc': 0.41847154257984465}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.48it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.48it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.19it/s]


Epoch: 021, Average Loss: 0.1893
Validation: {'recall': 0.37775130859847383, 'precision': 0.45528081870892456, 'f1': 0.3953224375368869, 'auc': 0.7687157856368331, 'prauc': 0.41902460507012385}
Test:       {'recall': 0.3770230546564747, 'precision': 0.43946859140225236, 'f1': 0.3915021454688565, 'auc': 0.7565005125636836, 'prauc': 0.4147433392508717}

Early stopping triggered after 21 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3867931697812348, 'precision': 0.4596597262672828, 'f1': 0.4001145885016164, 'auc': 0.7660960693038505, 'prauc': 0.4279589700979775}
Corresponding test performance:
{'recall': 0.3899699879118781, 'precision': 0.43923315800662244, 'f1': 0.39970774360995365, 'auc': 0.7620414606370725, 'prauc': 0.4223139363524872}
[INFO] Random seed set to 1812140441
Training with seed: 1812140441


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.45it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.46it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.13it/s]


Epoch: 001, Average Loss: 0.3174
Validation: {'recall': 0.3104016759786675, 'precision': 0.3807279535463104, 'f1': 0.3262166718906436, 'auc': 0.7429162915102437, 'prauc': 0.4046052596442727}
Test:       {'recall': 0.3153611061012885, 'precision': 0.42101477736778786, 'f1': 0.3298550805542058, 'auc': 0.739972235777861, 'prauc': 0.40291270462851236}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.51it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.25it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.10it/s]


Epoch: 002, Average Loss: 0.2899
Validation: {'recall': 0.3639259403375217, 'precision': 0.42512215805802933, 'f1': 0.38113426734472167, 'auc': 0.774316528693895, 'prauc': 0.431761720471411}
Test:       {'recall': 0.3663289948787292, 'precision': 0.42124591718734206, 'f1': 0.3794891930042001, 'auc': 0.7714979054670736, 'prauc': 0.4295898406186494}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.60it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.31it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.44it/s]


Epoch: 003, Average Loss: 0.2803
Validation: {'recall': 0.3445067999595121, 'precision': 0.44599998728006934, 'f1': 0.3707775561995387, 'auc': 0.7706212851257298, 'prauc': 0.43398598810688616}
Test:       {'recall': 0.345780982941679, 'precision': 0.44595201580669375, 'f1': 0.3709759920708738, 'auc': 0.7720482373400723, 'prauc': 0.43555723631330523}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.35it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.76it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 115.81it/s]


Epoch: 004, Average Loss: 0.2734
Validation: {'recall': 0.3743897761812983, 'precision': 0.4256643876009124, 'f1': 0.38737442311573417, 'auc': 0.772954043899603, 'prauc': 0.43635939773462634}
Test:       {'recall': 0.3752289537370118, 'precision': 0.4781148175883972, 'f1': 0.3852173107156538, 'auc': 0.7750201926062731, 'prauc': 0.4364094942336081}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.33it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.73it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.83it/s]


Epoch: 005, Average Loss: 0.2690
Validation: {'recall': 0.359969258125211, 'precision': 0.43254825319900836, 'f1': 0.38213124169506585, 'auc': 0.783510703602932, 'prauc': 0.4441173343443618}
Test:       {'recall': 0.35995312735444274, 'precision': 0.4837862676822784, 'f1': 0.3799985237116486, 'auc': 0.7821646047985353, 'prauc': 0.44187928075995725}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.57it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.25it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.53it/s]


Epoch: 006, Average Loss: 0.2631
Validation: {'recall': 0.34743718804001167, 'precision': 0.43241560209372765, 'f1': 0.37287485102818546, 'auc': 0.7732512284559259, 'prauc': 0.4379279504691695}
Test:       {'recall': 0.35527306314110113, 'precision': 0.46232605519924785, 'f1': 0.37868081344349713, 'auc': 0.7759355293247886, 'prauc': 0.44126187579804327}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.41it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.06it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.65it/s]


Epoch: 007, Average Loss: 0.2585
Validation: {'recall': 0.3459686402037569, 'precision': 0.4368447518931846, 'f1': 0.37543430389651067, 'auc': 0.7715264197227006, 'prauc': 0.4356794989779515}
Test:       {'recall': 0.34480368415235696, 'precision': 0.4602446915621313, 'f1': 0.3723553315115294, 'auc': 0.7715923715818828, 'prauc': 0.4335558837181869}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.38it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.81it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.38it/s]


Epoch: 008, Average Loss: 0.2519
Validation: {'recall': 0.3679566163841435, 'precision': 0.4896152887209988, 'f1': 0.3917749472777246, 'auc': 0.770619006789371, 'prauc': 0.43866797779607825}
Test:       {'recall': 0.3691096300433363, 'precision': 0.4882112182443899, 'f1': 0.3915495243803545, 'auc': 0.7688996555208243, 'prauc': 0.43582382900420336}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.33it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.13it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 116.57it/s]


Epoch: 009, Average Loss: 0.2459
Validation: {'recall': 0.371093726952541, 'precision': 0.4749026568646534, 'f1': 0.3858003111786775, 'auc': 0.7698845677681783, 'prauc': 0.43747889571413023}
Test:       {'recall': 0.3724507119449229, 'precision': 0.4665915052575026, 'f1': 0.3849643638658172, 'auc': 0.7720948047452502, 'prauc': 0.4348636446140717}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.44it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.53it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.89it/s]


Epoch: 010, Average Loss: 0.2422
Validation: {'recall': 0.3824972087027636, 'precision': 0.4461359412507186, 'f1': 0.398098193502882, 'auc': 0.7673362917819405, 'prauc': 0.4348561830561144}
Test:       {'recall': 0.38051427417147843, 'precision': 0.48831206555903583, 'f1': 0.39468806339327006, 'auc': 0.7706130464391662, 'prauc': 0.43265630276496836}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.41it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.69it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.80it/s]


Epoch: 011, Average Loss: 0.2365
Validation: {'recall': 0.36142047430622537, 'precision': 0.46937210302739996, 'f1': 0.3889385129755871, 'auc': 0.7681093293092598, 'prauc': 0.43242809137184257}
Test:       {'recall': 0.36469903187731567, 'precision': 0.487271523402831, 'f1': 0.39034221641228933, 'auc': 0.7676966052953937, 'prauc': 0.4281720642495634}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.49it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 120.42it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.50it/s]


Epoch: 012, Average Loss: 0.2305
Validation: {'recall': 0.3692120375177594, 'precision': 0.47108853825240704, 'f1': 0.38972316441250965, 'auc': 0.7721428285757603, 'prauc': 0.43317341292797656}
Test:       {'recall': 0.3681941875659166, 'precision': 0.4675322129337646, 'f1': 0.38742609083952234, 'auc': 0.7675095416510432, 'prauc': 0.4314966631935205}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.51it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 121.75it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 115.36it/s]


Epoch: 013, Average Loss: 0.2258
Validation: {'recall': 0.3671359131967505, 'precision': 0.5120230161052057, 'f1': 0.3883745376049259, 'auc': 0.7636809548107979, 'prauc': 0.42838400541887234}
Test:       {'recall': 0.36539794002759357, 'precision': 0.4556015166257714, 'f1': 0.38232862924641003, 'auc': 0.7608847905850649, 'prauc': 0.4262153833652067}


Training Batches: 100%|██████████| 1713/1713 [00:52<00:00, 32.54it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 122.84it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 118.94it/s]


Epoch: 014, Average Loss: 0.2200
Validation: {'recall': 0.37440964734777005, 'precision': 0.46759380975384357, 'f1': 0.39090069754976486, 'auc': 0.7641722422253392, 'prauc': 0.42921521970076526}
Test:       {'recall': 0.37380705086149246, 'precision': 0.454376801669109, 'f1': 0.3870828418829024, 'auc': 0.7649455731087325, 'prauc': 0.4269029890577475}


Training Batches: 100%|██████████| 1713/1713 [00:53<00:00, 32.29it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 123.12it/s]
Running inference: 100%|██████████| 1266/1266 [00:10<00:00, 117.83it/s]

Epoch: 015, Average Loss: 0.2168
Validation: {'recall': 0.36542051106749174, 'precision': 0.5085585311029457, 'f1': 0.3891892433974781, 'auc': 0.765421303675298, 'prauc': 0.4285495364477936}
Test:       {'recall': 0.3644083724654385, 'precision': 0.47699942487442176, 'f1': 0.38669939068146575, 'auc': 0.7648033316942922, 'prauc': 0.42545424387692266}

Early stopping triggered after 15 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3824972087027636, 'precision': 0.4461359412507186, 'f1': 0.398098193502882, 'auc': 0.7673362917819405, 'prauc': 0.4348561830561144}
Corresponding test performance:
{'recall': 0.38051427417147843, 'precision': 0.48831206555903583, 'f1': 0.39468806339327006, 'auc': 0.7706130464391662, 'prauc': 0.43265630276496836}





In [17]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
recall: 0.3842 ± 0.0082
precision: 0.4448 ± 0.0255
f1: 0.3948 ± 0.0044
auc: 0.7696 ± 0.0040
prauc: 0.4336 ± 0.0070


In [18]:
# Assuming final_metrics and args are defined
# Create the results directory if it doesn't exist
os.makedirs(f"./results/{args['dataset']}", exist_ok=True)

# Define the file name
file_name = f"{args['dataset']}-{args['task']}.txt"
file_path = os.path.join(f"./results/{args['dataset']}", file_name)

# Save the metrics to a text file
with open(file_path, 'w') as f:
    f.write("Final Metrics:\n")
    for key in final_metrics.keys():
        mean_value = np.mean(final_metrics[key])
        std_value = np.std(final_metrics[key])
        f.write(f"{key}: {mean_value:.4f} ± {std_value:.4f}\n")