In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    # gt_labels is ground truth labels, predicted_scores is the output logits
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        # standard binary classification evaluation
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()      
    #   The threshold should be 0 because:
	# 	Your model outputs logits, not probabilities. logit > 0 ≡ probability > 0.5
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "auc":roc_auc, "prauc":pr_auc}
    else:
        # —— Multi-label classification evaluation (per-class over the batch) —— #
        # predicted_scores: list[Tensor] -> [B, C] logits
        logits = torch.cat(predicted_scores, dim=0)           # [B, C]
        y_true_t = torch.cat(gt_labels, dim=0)                # [B, C]
        # 1) 连续分数用于 AUC/PR-AUC：sigmoid(logits)
        #    CPU + float16 无 sigmoid 实现，先升到 fp32
        logits_for_sigmoid = logits.float() if (logits.device.type == "cpu" and logits.dtype == torch.float16) else logits
        prob_t = torch.sigmoid(logits_for_sigmoid)            # [B, C] in [0,1]

        # 2) 阈值化用于 P/R/F1：logits > 0 等价于 prob >= 0.5
        y_pred_t = (logits > 0).to(torch.int32)               # [B, C]

        # 转 numpy
        y_true = y_true_t.cpu().numpy().astype(np.int32)      # [B, C]
        y_pred = y_pred_t.cpu().numpy().astype(np.int32)      # [B, C]
        scores = prob_t.cpu().numpy()                         # [B, C] float

        # 3) 按类计算 Precision/Recall/F1
        p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # 4) 按类计算 AUC/PR-AUC（遇到单一类别则设为 NaN）
        C = y_true.shape[1]
        aucs, praucs = [], []
        for c in range(C):
            yt, ys = y_true[:, c], scores[:, c]
            if yt.max() == yt.min():           # 全 0 或全 1，曲线不定义
                aucs.append(np.nan)
                praucs.append(np.nan)
                continue
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

        # 5) 取宏平均（忽略 NaN）
        ave_precision = float(np.mean(p_cls))
        ave_recall    = float(np.mean(r_cls))
        ave_f1        = float(np.mean(f1_cls))
        ave_auc       = float(np.nanmean(aucs))   if np.any(~np.isnan(aucs))   else np.nan
        ave_prauc     = float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else np.nan

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc), np.mean(ave_recall), np.mean(ave_precision)
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [4]:
args = {
    "seed": 0,
    "dataset": "MIMIC-IV", 
    "task": "next_diag_6m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 50,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [5]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-IV-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [6]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [7]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [8]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [9]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [10]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([13, 55])
input_types shape: torch.Size([13, 55])
visit_positions shape: torch.Size([13])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [12]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [13]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
        test_metric = evaluate(model, test_dataloader, device, task_type=task_type)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric
            best_test_metric = test_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    return best_test_metric

In [15]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [16]:
final_metrics = {"recall":[], "precision":[],"f1":[],"auc":[],"prauc":[]}

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        task_type=task_type)
    
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.68it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.14it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.50it/s]


Epoch: 001, Average Loss: 0.3153
Validation: {'recall': 0.3107983166217798, 'precision': 0.38024995317433546, 'f1': 0.32417999309935475, 'auc': 0.7475655567815238, 'prauc': 0.39840748929128034}
Test:       {'recall': 0.314453173496191, 'precision': 0.388531159524259, 'f1': 0.3316489330255231, 'auc': 0.7489721851843888, 'prauc': 0.40485847061435526}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.02it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.84it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.56it/s]


Epoch: 002, Average Loss: 0.2823
Validation: {'recall': 0.3435766682354864, 'precision': 0.4229683209890889, 'f1': 0.35784876160519796, 'auc': 0.7641411715419534, 'prauc': 0.42450161547428883}
Test:       {'recall': 0.3499294838407045, 'precision': 0.42497994040347953, 'f1': 0.3656377561227064, 'auc': 0.7603995104898126, 'prauc': 0.4324305564084084}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.90it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.11it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.54it/s]


Epoch: 003, Average Loss: 0.2721
Validation: {'recall': 0.36053057911370595, 'precision': 0.41590350045152946, 'f1': 0.3767361170971377, 'auc': 0.7753998308439219, 'prauc': 0.4307427742547531}
Test:       {'recall': 0.36678315917740867, 'precision': 0.4327940946686339, 'f1': 0.3846384759350725, 'auc': 0.7687292428592754, 'prauc': 0.43963814174884575}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.10it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.29it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.06it/s]


Epoch: 004, Average Loss: 0.2654
Validation: {'recall': 0.3688053004144903, 'precision': 0.4147413124069892, 'f1': 0.3837727023085764, 'auc': 0.7735801134247372, 'prauc': 0.4324149225904602}
Test:       {'recall': 0.37558412627102855, 'precision': 0.42942829754091805, 'f1': 0.3939263314270211, 'auc': 0.7670757702347227, 'prauc': 0.44057355869255715}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.98it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.48it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.89it/s]


Epoch: 005, Average Loss: 0.2581
Validation: {'recall': 0.34647813992445425, 'precision': 0.42330806943558386, 'f1': 0.35859953004610035, 'auc': 0.7669910827726577, 'prauc': 0.4301068945079291}
Test:       {'recall': 0.3510129400932442, 'precision': 0.4344784712644706, 'f1': 0.3676547497367012, 'auc': 0.7649344110608027, 'prauc': 0.44148007587394056}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.87it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.20it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.95it/s]


Epoch: 006, Average Loss: 0.2526
Validation: {'recall': 0.3664328387662868, 'precision': 0.42741240224490196, 'f1': 0.38066787230397753, 'auc': 0.7740028779155258, 'prauc': 0.4382382749467568}
Test:       {'recall': 0.37247006089720436, 'precision': 0.43987054297832867, 'f1': 0.3895806988177147, 'auc': 0.7731731183787927, 'prauc': 0.44430724292480267}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.74it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.82it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.60it/s]


Epoch: 007, Average Loss: 0.2469
Validation: {'recall': 0.3651476629519075, 'precision': 0.5187381795128991, 'f1': 0.3857668094583328, 'auc': 0.7772209361745345, 'prauc': 0.4388130293069752}
Test:       {'recall': 0.370412651373075, 'precision': 0.45601840702961244, 'f1': 0.3934391290998803, 'auc': 0.7759381770641881, 'prauc': 0.44473890130623267}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.15it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.29it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.70it/s]


Epoch: 008, Average Loss: 0.2407
Validation: {'recall': 0.3795909168180261, 'precision': 0.446100274234484, 'f1': 0.38707644764428584, 'auc': 0.7667174494745337, 'prauc': 0.43093696569197504}
Test:       {'recall': 0.38616625740703, 'precision': 0.5223646710992795, 'f1': 0.39869015987944234, 'auc': 0.7646024058312169, 'prauc': 0.4379884430411721}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.38it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.97it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.76it/s]


Epoch: 009, Average Loss: 0.2361
Validation: {'recall': 0.36967507996908733, 'precision': 0.43687979361083773, 'f1': 0.3827825996875991, 'auc': 0.7687883678419601, 'prauc': 0.43457155522134544}
Test:       {'recall': 0.3760139852816523, 'precision': 0.479973343953589, 'f1': 0.39319413285142507, 'auc': 0.7678841355340142, 'prauc': 0.4412061747593829}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.75it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.12it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.50it/s]


Epoch: 010, Average Loss: 0.2306
Validation: {'recall': 0.37149164993634387, 'precision': 0.4554689521991736, 'f1': 0.38991818121744326, 'auc': 0.7713887238991863, 'prauc': 0.4343049888624957}
Test:       {'recall': 0.3790413032324566, 'precision': 0.5171100089019132, 'f1': 0.4021297744458965, 'auc': 0.7708963930866, 'prauc': 0.43930988855152797}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.26it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.67it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.86it/s]


Epoch: 011, Average Loss: 0.2235
Validation: {'recall': 0.3660067841383659, 'precision': 0.48733241270503125, 'f1': 0.38515115191572985, 'auc': 0.7674375228423949, 'prauc': 0.4350184465855127}
Test:       {'recall': 0.3771503301636223, 'precision': 0.5288408926004139, 'f1': 0.39956516730050473, 'auc': 0.7714133311033166, 'prauc': 0.4430214665901975}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.10it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.62it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.51it/s]


Epoch: 012, Average Loss: 0.2193
Validation: {'recall': 0.389153391723821, 'precision': 0.43848020100933804, 'f1': 0.390581765736691, 'auc': 0.769229386120454, 'prauc': 0.4319568291098246}
Test:       {'recall': 0.39792416282357157, 'precision': 0.46391979818516355, 'f1': 0.40618757906242664, 'auc': 0.7715707323144284, 'prauc': 0.43716246092007194}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.34it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.54it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.27it/s]


Epoch: 013, Average Loss: 0.2123
Validation: {'recall': 0.3706547756049716, 'precision': 0.45189957078645815, 'f1': 0.3910060171915102, 'auc': 0.7639757924995607, 'prauc': 0.42627251984509623}
Test:       {'recall': 0.38152108687614367, 'precision': 0.4606364332455642, 'f1': 0.40408251078642593, 'auc': 0.7668686470430078, 'prauc': 0.42997446748514817}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.14it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.55it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.43it/s]


Epoch: 014, Average Loss: 0.2070
Validation: {'recall': 0.3731595778747899, 'precision': 0.44513620361801226, 'f1': 0.38911752778727976, 'auc': 0.7615266914912749, 'prauc': 0.4253857993487167}
Test:       {'recall': 0.37883853345897733, 'precision': 0.468272822991547, 'f1': 0.4010891064074056, 'auc': 0.7642339339368689, 'prauc': 0.43157038145014776}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.13it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 118.80it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 116.85it/s]


Epoch: 015, Average Loss: 0.2022
Validation: {'recall': 0.3895464220473808, 'precision': 0.4490455534576136, 'f1': 0.39746714915601444, 'auc': 0.7652227557223089, 'prauc': 0.4247125376975229}
Test:       {'recall': 0.3957872265780997, 'precision': 0.44979899629578196, 'f1': 0.40846129247016294, 'auc': 0.7627784301673326, 'prauc': 0.4292292781184471}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.07it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.00it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.13it/s]


Epoch: 016, Average Loss: 0.1972
Validation: {'recall': 0.3677186462545036, 'precision': 0.4686891185937706, 'f1': 0.39014995197430014, 'auc': 0.7640660751169461, 'prauc': 0.4274532908093553}
Test:       {'recall': 0.37151111892015687, 'precision': 0.46501976076518, 'f1': 0.39726595272274534, 'auc': 0.7622667826429503, 'prauc': 0.43207848086121975}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.25it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 118.88it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.88it/s]


Epoch: 017, Average Loss: 0.1908
Validation: {'recall': 0.39530134134858724, 'precision': 0.4295593935432103, 'f1': 0.39693534166785577, 'auc': 0.7632356145929527, 'prauc': 0.4184848628306622}
Test:       {'recall': 0.39875417588053275, 'precision': 0.4347909232760941, 'f1': 0.40472760356641113, 'auc': 0.7587279821790502, 'prauc': 0.4224518957673492}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.20it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.22it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.50it/s]


Epoch: 018, Average Loss: 0.1857
Validation: {'recall': 0.3836917522017553, 'precision': 0.44602302777597647, 'f1': 0.39476169157883684, 'auc': 0.7644056177260408, 'prauc': 0.41913961210429895}
Test:       {'recall': 0.3925496755552237, 'precision': 0.450728038217234, 'f1': 0.40675746694257175, 'auc': 0.7577127690147245, 'prauc': 0.42350557905042746}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.99it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.86it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.33it/s]


Epoch: 019, Average Loss: 0.1824
Validation: {'recall': 0.36487036223103037, 'precision': 0.4430385615132933, 'f1': 0.38141547249510704, 'auc': 0.7561488208325452, 'prauc': 0.4163644722621431}
Test:       {'recall': 0.3742638681433499, 'precision': 0.448697459838155, 'f1': 0.3937466077500002, 'auc': 0.7554246389669479, 'prauc': 0.42168531346969595}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.20it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.42it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.69it/s]


Epoch: 020, Average Loss: 0.1765
Validation: {'recall': 0.3653592468473174, 'precision': 0.4429182879512206, 'f1': 0.38881224599504305, 'auc': 0.7600574076734593, 'prauc': 0.41901562946218274}
Test:       {'recall': 0.3748360329387892, 'precision': 0.4701245971037685, 'f1': 0.401867181969871, 'auc': 0.757281744524247, 'prauc': 0.4249848735539532}

Early stopping triggered after 20 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3895464220473808, 'precision': 0.4490455534576136, 'f1': 0.39746714915601444, 'auc': 0.7652227557223089, 'prauc': 0.4247125376975229}
Corresponding test performance:
{'recall': 0.3957872265780997, 'precision': 0.44979899629578196, 'f1': 0.40846129247016294, 'auc': 0.7627784301673326, 'prauc': 0.4292292781184471}
[INFO] Random seed set to 1181241943
Training with seed: 1181241943


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.00it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.79it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.19it/s]


Epoch: 001, Average Loss: 0.3123
Validation: {'recall': 0.30254250619390805, 'precision': 0.36827751219284877, 'f1': 0.31823114794046226, 'auc': 0.739404669413395, 'prauc': 0.3953740145665008}
Test:       {'recall': 0.30537042073242365, 'precision': 0.43053608118012743, 'f1': 0.32475056926453116, 'auc': 0.7435354206678914, 'prauc': 0.4030639640804748}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.32it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.06it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.55it/s]


Epoch: 002, Average Loss: 0.2836
Validation: {'recall': 0.33365911968243056, 'precision': 0.44047153823263546, 'f1': 0.3533943539055057, 'auc': 0.7715358999834155, 'prauc': 0.42867463378691106}
Test:       {'recall': 0.34324076905241147, 'precision': 0.4428120904001098, 'f1': 0.36619078365039043, 'auc': 0.776908885074006, 'prauc': 0.4325855111408085}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.11it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.90it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.16it/s]


Epoch: 003, Average Loss: 0.2730
Validation: {'recall': 0.3716887238930452, 'precision': 0.41471854539901315, 'f1': 0.3864122613850172, 'auc': 0.7702669250275717, 'prauc': 0.42990760558395946}
Test:       {'recall': 0.37757506485645165, 'precision': 0.42857308118752396, 'f1': 0.396200650039618, 'auc': 0.7720607220291871, 'prauc': 0.44386703821078416}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.30it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.36it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.47it/s]


Epoch: 004, Average Loss: 0.2658
Validation: {'recall': 0.3517761426715449, 'precision': 0.4374364071576812, 'f1': 0.361403825053047, 'auc': 0.7787527891811286, 'prauc': 0.43858037781963555}
Test:       {'recall': 0.362207802418578, 'precision': 0.5101191846447743, 'f1': 0.37596883679296306, 'auc': 0.7780402420327506, 'prauc': 0.44499642894077823}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.28it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.78it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 120.22it/s]


Epoch: 005, Average Loss: 0.2597
Validation: {'recall': 0.35243326561045973, 'precision': 0.4251116114548521, 'f1': 0.3688684597251016, 'auc': 0.7745072512200604, 'prauc': 0.4370310128062254}
Test:       {'recall': 0.3631996069518688, 'precision': 0.5276466781850917, 'f1': 0.38336106605121895, 'auc': 0.7772673804208172, 'prauc': 0.4450804951947245}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.24it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.14it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.89it/s]


Epoch: 006, Average Loss: 0.2539
Validation: {'recall': 0.36702640773891876, 'precision': 0.4573746191153124, 'f1': 0.3853387179659515, 'auc': 0.7773627826393792, 'prauc': 0.4369807315557662}
Test:       {'recall': 0.3788904652151236, 'precision': 0.47754873607004394, 'f1': 0.40183322676895883, 'auc': 0.7761372540115951, 'prauc': 0.4448238928198472}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.25it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.90it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.34it/s]


Epoch: 007, Average Loss: 0.2481
Validation: {'recall': 0.3631050957596733, 'precision': 0.4743226312208957, 'f1': 0.37722353927879715, 'auc': 0.7767007994565921, 'prauc': 0.43685491850609026}
Test:       {'recall': 0.36620594364494674, 'precision': 0.4355638975281512, 'f1': 0.38405508363397467, 'auc': 0.7716747128092639, 'prauc': 0.44211193829054063}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.09it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.52it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.06it/s]


Epoch: 008, Average Loss: 0.2417
Validation: {'recall': 0.36093288272646895, 'precision': 0.5165493864724322, 'f1': 0.38425510763913573, 'auc': 0.7785622074177038, 'prauc': 0.44271354475193636}
Test:       {'recall': 0.36912380388943344, 'precision': 0.4995863292670235, 'f1': 0.3955343114323976, 'auc': 0.7775307013406011, 'prauc': 0.44886375608848605}

Early stopping triggered after 8 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3716887238930452, 'precision': 0.41471854539901315, 'f1': 0.3864122613850172, 'auc': 0.7702669250275717, 'prauc': 0.42990760558395946}
Corresponding test performance:
{'recall': 0.37757506485645165, 'precision': 0.42857308118752396, 'f1': 0.396200650039618, 'auc': 0.7720607220291871, 'prauc': 0.44386703821078416}
[INFO] Random seed set to 958682846
Training with seed: 958682846


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.06it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.49it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.72it/s]


Epoch: 001, Average Loss: 0.3121
Validation: {'recall': 0.3171324613873552, 'precision': 0.4190058931376719, 'f1': 0.33366128662505734, 'auc': 0.751980282672505, 'prauc': 0.40909515106226885}
Test:       {'recall': 0.3235236294894534, 'precision': 0.4234887284435266, 'f1': 0.3435728663772153, 'auc': 0.7519683394495777, 'prauc': 0.4175941019152105}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.94it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.88it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.44it/s]


Epoch: 002, Average Loss: 0.2812
Validation: {'recall': 0.3560737639819665, 'precision': 0.40898311779611185, 'f1': 0.3638317295669944, 'auc': 0.7645168875524286, 'prauc': 0.426070031615296}
Test:       {'recall': 0.36455599914628295, 'precision': 0.4199875450872499, 'f1': 0.37646015371164915, 'auc': 0.7659305416514905, 'prauc': 0.43373515163931864}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.01it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.30it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.31it/s]


Epoch: 003, Average Loss: 0.2723
Validation: {'recall': 0.3560054714312549, 'precision': 0.42202850138093556, 'f1': 0.3690031283775433, 'auc': 0.7728583732873463, 'prauc': 0.4321518689040884}
Test:       {'recall': 0.3643238660599865, 'precision': 0.435560789407, 'f1': 0.3794604286606018, 'auc': 0.7734762340905958, 'prauc': 0.43748328708028955}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.95it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.18it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.43it/s]


Epoch: 004, Average Loss: 0.2648
Validation: {'recall': 0.3764536075970779, 'precision': 0.4753201416959631, 'f1': 0.38163636595748557, 'auc': 0.7794796603895112, 'prauc': 0.4387285920598118}
Test:       {'recall': 0.3840224533444467, 'precision': 0.42674332975786633, 'f1': 0.3891844534738254, 'auc': 0.7731702058649818, 'prauc': 0.443072142318843}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.18it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.78it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.92it/s]


Epoch: 005, Average Loss: 0.2592
Validation: {'recall': 0.3475810130949795, 'precision': 0.4468566943037098, 'f1': 0.36483185256560063, 'auc': 0.7738019844460706, 'prauc': 0.4367668984504056}
Test:       {'recall': 0.3550020151746344, 'precision': 0.46764512479767045, 'f1': 0.3753036296836744, 'auc': 0.768535837923655, 'prauc': 0.44273726876549374}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.11it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.23it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.77it/s]


Epoch: 006, Average Loss: 0.2530
Validation: {'recall': 0.3545518794407875, 'precision': 0.459320594724057, 'f1': 0.3789069272901542, 'auc': 0.7794768340979663, 'prauc': 0.43953508688345005}
Test:       {'recall': 0.35920542608135203, 'precision': 0.453382044853274, 'f1': 0.38570588149038626, 'auc': 0.7772198164778457, 'prauc': 0.4454894421712642}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.34it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.17it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.87it/s]


Epoch: 007, Average Loss: 0.2473
Validation: {'recall': 0.36569899323671395, 'precision': 0.45582540711858893, 'f1': 0.37605468307279183, 'auc': 0.7796047296218587, 'prauc': 0.43741796706762615}
Test:       {'recall': 0.37178669824491917, 'precision': 0.47868145241051707, 'f1': 0.3863448503024831, 'auc': 0.7702646968064937, 'prauc': 0.4425528761438364}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.96it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.67it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.47it/s]


Epoch: 008, Average Loss: 0.2426
Validation: {'recall': 0.3695372486760389, 'precision': 0.4563643739756747, 'f1': 0.38864689228402016, 'auc': 0.7731747908550693, 'prauc': 0.43573674798521206}
Test:       {'recall': 0.37520469658261274, 'precision': 0.5418831459553578, 'f1': 0.3976968401104308, 'auc': 0.7677686345449957, 'prauc': 0.43735483553831406}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.11it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.89it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.83it/s]


Epoch: 009, Average Loss: 0.2370
Validation: {'recall': 0.3576370721836164, 'precision': 0.44293350370633816, 'f1': 0.36970842277401483, 'auc': 0.769865053575705, 'prauc': 0.4313557984501611}
Test:       {'recall': 0.36321466313870254, 'precision': 0.4330970697166905, 'f1': 0.3783818394598351, 'auc': 0.7602481238937514, 'prauc': 0.43289635371492763}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.99it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.45it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.42it/s]


Epoch: 010, Average Loss: 0.2320
Validation: {'recall': 0.3655282454829726, 'precision': 0.5159036795741452, 'f1': 0.3878215416460627, 'auc': 0.7774724210229665, 'prauc': 0.43542867343824654}
Test:       {'recall': 0.36856235940091203, 'precision': 0.46520040001809515, 'f1': 0.3922724045275264, 'auc': 0.768103787277655, 'prauc': 0.43460724753670066}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.90it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.73it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.08it/s]


Epoch: 011, Average Loss: 0.2266
Validation: {'recall': 0.35974233711595666, 'precision': 0.4740332938136668, 'f1': 0.3810054011191585, 'auc': 0.7719285176208202, 'prauc': 0.42875487429859294}
Test:       {'recall': 0.37018261459554036, 'precision': 0.4656258811218237, 'f1': 0.39148762746615606, 'auc': 0.7666114616000351, 'prauc': 0.4326258613652352}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.22it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.17it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.51it/s]


Epoch: 012, Average Loss: 0.2197
Validation: {'recall': 0.3631831250354717, 'precision': 0.4683263783615163, 'f1': 0.38709253597951837, 'auc': 0.7730637718522839, 'prauc': 0.43333125623953234}
Test:       {'recall': 0.36216121788601313, 'precision': 0.474361952219323, 'f1': 0.38685331486002916, 'auc': 0.7728021355192297, 'prauc': 0.4334036004041045}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.20it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.20it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.90it/s]


Epoch: 013, Average Loss: 0.2135
Validation: {'recall': 0.3622926405388338, 'precision': 0.481672169890789, 'f1': 0.38035137717896295, 'auc': 0.7690065614369133, 'prauc': 0.429538748603529}
Test:       {'recall': 0.3672665776083314, 'precision': 0.44717338599154427, 'f1': 0.3866071279585185, 'auc': 0.7691853153427486, 'prauc': 0.42871809928562055}

Early stopping triggered after 13 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3695372486760389, 'precision': 0.4563643739756747, 'f1': 0.38864689228402016, 'auc': 0.7731747908550693, 'prauc': 0.43573674798521206}
Corresponding test performance:
{'recall': 0.37520469658261274, 'precision': 0.5418831459553578, 'f1': 0.3976968401104308, 'auc': 0.7677686345449957, 'prauc': 0.43735483553831406}
[INFO] Random seed set to 3163119785
Training with seed: 3163119785


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.26it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.10it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.81it/s]


Epoch: 001, Average Loss: 0.3113
Validation: {'recall': 0.30080436044721254, 'precision': 0.3675801847878004, 'f1': 0.3273202190040707, 'auc': 0.7420261322346433, 'prauc': 0.40442970338019424}
Test:       {'recall': 0.3067806726333806, 'precision': 0.3764413660933295, 'f1': 0.33554279485062233, 'auc': 0.7482153803947481, 'prauc': 0.40610247315487213}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.04it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.80it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.88it/s]


Epoch: 002, Average Loss: 0.2822
Validation: {'recall': 0.3136548574007386, 'precision': 0.4371867153499418, 'f1': 0.33005105104553145, 'auc': 0.7594504866596864, 'prauc': 0.4253499273200369}
Test:       {'recall': 0.31738293870797246, 'precision': 0.44532090185751255, 'f1': 0.33685088482541836, 'auc': 0.7604251180750666, 'prauc': 0.43239243974592}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.00it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.59it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 120.06it/s]


Epoch: 003, Average Loss: 0.2722
Validation: {'recall': 0.31898154069042217, 'precision': 0.44037459807316975, 'f1': 0.34284957463306126, 'auc': 0.7705321939095406, 'prauc': 0.43576782110987616}
Test:       {'recall': 0.3224916006366181, 'precision': 0.4491260983207672, 'f1': 0.3493716953298172, 'auc': 0.7740643545903824, 'prauc': 0.4381306390971104}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.08it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.14it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 120.22it/s]


Epoch: 004, Average Loss: 0.2648
Validation: {'recall': 0.3350889718434076, 'precision': 0.43239405946034454, 'f1': 0.35197073560473463, 'auc': 0.7715002551613869, 'prauc': 0.4407964995055431}
Test:       {'recall': 0.3448705288746869, 'precision': 0.4530003092600554, 'f1': 0.36449569989681396, 'auc': 0.7725628736972707, 'prauc': 0.44708431872948495}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.87it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.81it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.16it/s]


Epoch: 005, Average Loss: 0.2593
Validation: {'recall': 0.36837687888581816, 'precision': 0.47267747094708984, 'f1': 0.38543056293627714, 'auc': 0.7749699976609968, 'prauc': 0.44287881407636204}
Test:       {'recall': 0.3759554021033185, 'precision': 0.42661293678062084, 'f1': 0.39442017855786515, 'auc': 0.7761417166027218, 'prauc': 0.4496522157378192}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.96it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.41it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.16it/s]


Epoch: 006, Average Loss: 0.2535
Validation: {'recall': 0.3681114592817492, 'precision': 0.44802466795045287, 'f1': 0.3840932846976826, 'auc': 0.7738689792940634, 'prauc': 0.4371012626511279}
Test:       {'recall': 0.3741841986025618, 'precision': 0.4633425185744431, 'f1': 0.3938522839716126, 'auc': 0.7778334902442641, 'prauc': 0.4459843666389563}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.85it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.12it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.67it/s]


Epoch: 007, Average Loss: 0.2478
Validation: {'recall': 0.3653609243035756, 'precision': 0.45637359605078903, 'f1': 0.3759796710386574, 'auc': 0.7732893053427303, 'prauc': 0.4353125585377763}
Test:       {'recall': 0.3731044384378195, 'precision': 0.4508015962506543, 'f1': 0.3874435825664632, 'auc': 0.7750983481615408, 'prauc': 0.4419384962803974}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.21it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.26it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.86it/s]


Epoch: 008, Average Loss: 0.2424
Validation: {'recall': 0.35553041901974775, 'precision': 0.4812249769581672, 'f1': 0.37903131567228476, 'auc': 0.7774739130332279, 'prauc': 0.43698528983006746}
Test:       {'recall': 0.3667504626276749, 'precision': 0.47333368008056126, 'f1': 0.39429296013431386, 'auc': 0.7818764026231855, 'prauc': 0.4501654060203162}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.96it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.44it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.80it/s]


Epoch: 009, Average Loss: 0.2364
Validation: {'recall': 0.3588602050155014, 'precision': 0.4696968518459458, 'f1': 0.37864109122327394, 'auc': 0.778033065170793, 'prauc': 0.4356765714376271}
Test:       {'recall': 0.37230761765001325, 'precision': 0.4804817844531459, 'f1': 0.39695598338734245, 'auc': 0.776267570899123, 'prauc': 0.44192018547766565}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.82it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.22it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.30it/s]


Epoch: 010, Average Loss: 0.2304
Validation: {'recall': 0.3471739275272503, 'precision': 0.45597908342893323, 'f1': 0.366681298204879, 'auc': 0.7731447602495854, 'prauc': 0.4322538093312534}
Test:       {'recall': 0.3600487999618335, 'precision': 0.46183023238266974, 'f1': 0.38308659670037365, 'auc': 0.7713878890634023, 'prauc': 0.43660242418869316}

Early stopping triggered after 10 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.36837687888581816, 'precision': 0.47267747094708984, 'f1': 0.38543056293627714, 'auc': 0.7749699976609968, 'prauc': 0.44287881407636204}
Corresponding test performance:
{'recall': 0.3759554021033185, 'precision': 0.42661293678062084, 'f1': 0.39442017855786515, 'auc': 0.7761417166027218, 'prauc': 0.4496522157378192}
[INFO] Random seed set to 1812140441
Training with seed: 1812140441


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.05it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.37it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.67it/s]


Epoch: 001, Average Loss: 0.3124
Validation: {'recall': 0.3008606428982772, 'precision': 0.4288077904420562, 'f1': 0.31564079484727375, 'auc': 0.7500869387660581, 'prauc': 0.4033941454543553}
Test:       {'recall': 0.30481367050104247, 'precision': 0.37671282407793094, 'f1': 0.32282775857813867, 'auc': 0.7491084195651625, 'prauc': 0.4032556214043679}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.99it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.54it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.53it/s]


Epoch: 002, Average Loss: 0.2844
Validation: {'recall': 0.3429747095107861, 'precision': 0.43400896773804415, 'f1': 0.36639804084460076, 'auc': 0.7702179202664424, 'prauc': 0.4300482013973623}
Test:       {'recall': 0.3458612337554751, 'precision': 0.43871949844026387, 'f1': 0.37256852003856183, 'auc': 0.7735690451740735, 'prauc': 0.43661354873183034}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.09it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.74it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.01it/s]


Epoch: 003, Average Loss: 0.2735
Validation: {'recall': 0.3660939562036283, 'precision': 0.41754528989026213, 'f1': 0.3816936594755378, 'auc': 0.7663945066938499, 'prauc': 0.42862166071668845}
Test:       {'recall': 0.37488073255958754, 'precision': 0.4299491412411468, 'f1': 0.3928976834844307, 'auc': 0.7654521391342823, 'prauc': 0.43854083292022716}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.88it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.43it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.79it/s]


Epoch: 004, Average Loss: 0.2664
Validation: {'recall': 0.36127031140426713, 'precision': 0.4213475444678633, 'f1': 0.3711188416341764, 'auc': 0.7706324237842387, 'prauc': 0.43429582078839324}
Test:       {'recall': 0.370095322163819, 'precision': 0.42767735297662873, 'f1': 0.3818892856527666, 'auc': 0.7712701228918919, 'prauc': 0.4393867211727298}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.00it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.22it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.77it/s]


Epoch: 005, Average Loss: 0.2596
Validation: {'recall': 0.35769083230105575, 'precision': 0.484891189976197, 'f1': 0.37366115612204637, 'auc': 0.774484980020751, 'prauc': 0.4347238544052857}
Test:       {'recall': 0.36580909968645603, 'precision': 0.45481163374325423, 'f1': 0.383742167339552, 'auc': 0.7753042379460979, 'prauc': 0.4462617796425365}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.25it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.98it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.91it/s]


Epoch: 006, Average Loss: 0.2536
Validation: {'recall': 0.34706698071862013, 'precision': 0.4758970453832892, 'f1': 0.3731644223639731, 'auc': 0.775699806999706, 'prauc': 0.44161970378615434}
Test:       {'recall': 0.35126065381904376, 'precision': 0.4894603953680746, 'f1': 0.3788952980436378, 'auc': 0.7754652968470186, 'prauc': 0.447850214204292}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.11it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.88it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.59it/s]


Epoch: 007, Average Loss: 0.2482
Validation: {'recall': 0.34850889242229033, 'precision': 0.4800597155894006, 'f1': 0.3591218831928109, 'auc': 0.7731578236581432, 'prauc': 0.43887517861943387}
Test:       {'recall': 0.35197213020863366, 'precision': 0.44061864309618604, 'f1': 0.36800003508101603, 'auc': 0.7742105133219586, 'prauc': 0.4446625206715653}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.17it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.78it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.72it/s]


Epoch: 008, Average Loss: 0.2421
Validation: {'recall': 0.3637203035538064, 'precision': 0.4902845416024502, 'f1': 0.3834492164503367, 'auc': 0.7750074608768741, 'prauc': 0.4409221964043846}
Test:       {'recall': 0.365141688898628, 'precision': 0.446057315433477, 'f1': 0.3867878986640451, 'auc': 0.7785747752359486, 'prauc': 0.4477618878771021}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.03it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.27it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.56it/s]


Epoch: 009, Average Loss: 0.2366
Validation: {'recall': 0.3718671064451216, 'precision': 0.5180455235945742, 'f1': 0.3869971838534312, 'auc': 0.776102927548839, 'prauc': 0.43438440341558837}
Test:       {'recall': 0.376720542195493, 'precision': 0.46927333771944585, 'f1': 0.3941748054897613, 'auc': 0.7717468854340456, 'prauc': 0.44338836108707763}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.22it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.49it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.99it/s]


Epoch: 010, Average Loss: 0.2311
Validation: {'recall': 0.3583206100063407, 'precision': 0.48646225245992025, 'f1': 0.37796750169444054, 'auc': 0.7774295051570496, 'prauc': 0.4350762663893474}
Test:       {'recall': 0.36761414316443874, 'precision': 0.47149194465061167, 'f1': 0.3912027929072333, 'auc': 0.7764506535056469, 'prauc': 0.44313650813506483}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.96it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.65it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 120.39it/s]


Epoch: 011, Average Loss: 0.2240
Validation: {'recall': 0.3480340706410589, 'precision': 0.47623407804586015, 'f1': 0.37248963559523923, 'auc': 0.7752153691397299, 'prauc': 0.4360613023928815}
Test:       {'recall': 0.3526651870481333, 'precision': 0.5819604485845019, 'f1': 0.384879567288438, 'auc': 0.775858563002636, 'prauc': 0.4417778129877748}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.94it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 118.53it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.39it/s]


Epoch: 012, Average Loss: 0.2195
Validation: {'recall': 0.36192896417867043, 'precision': 0.46816810192553215, 'f1': 0.38505987229256716, 'auc': 0.773569422351776, 'prauc': 0.4323712034720533}
Test:       {'recall': 0.3691588275930926, 'precision': 0.529304406323264, 'f1': 0.3973289576641958, 'auc': 0.7742772593058898, 'prauc': 0.44164773404592983}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.99it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.53it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.45it/s]


Epoch: 013, Average Loss: 0.2137
Validation: {'recall': 0.3730810868809338, 'precision': 0.44813390921557505, 'f1': 0.3935517721946289, 'auc': 0.7728273346263278, 'prauc': 0.43266516297033175}
Test:       {'recall': 0.376226590326062, 'precision': 0.5219919740550573, 'f1': 0.40108092890071917, 'auc': 0.7685074328098135, 'prauc': 0.43637178692218903}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 33.03it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.92it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.00it/s]


Epoch: 014, Average Loss: 0.2081
Validation: {'recall': 0.36750310885721754, 'precision': 0.4773439621370091, 'f1': 0.3867375113916378, 'auc': 0.7758327244680568, 'prauc': 0.429295713218856}
Test:       {'recall': 0.37398408269567035, 'precision': 0.4986432072667583, 'f1': 0.39991250143819446, 'auc': 0.7691789645283978, 'prauc': 0.43477621478125894}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.22it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 120.73it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.86it/s]


Epoch: 015, Average Loss: 0.2009
Validation: {'recall': 0.3688737853334872, 'precision': 0.4449432575764923, 'f1': 0.39008074513518287, 'auc': 0.7749086049090962, 'prauc': 0.42446757368516524}
Test:       {'recall': 0.376977476091045, 'precision': 0.4668560718513129, 'f1': 0.403279223775456, 'auc': 0.7662084260766184, 'prauc': 0.43200882368520244}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.95it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.42it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 117.27it/s]


Epoch: 016, Average Loss: 0.1964
Validation: {'recall': 0.3833231860337848, 'precision': 0.4512778343448859, 'f1': 0.3983319802815396, 'auc': 0.7694094321080028, 'prauc': 0.4217448260526136}
Test:       {'recall': 0.38752118976417277, 'precision': 0.4615659356192304, 'f1': 0.4048043710944446, 'auc': 0.7664240303100122, 'prauc': 0.4304230745938008}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.35it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 122.14it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.48it/s]


Epoch: 017, Average Loss: 0.1889
Validation: {'recall': 0.38441970770866274, 'precision': 0.4801646593276043, 'f1': 0.3976625267877712, 'auc': 0.7710178927786392, 'prauc': 0.4254993006161156}
Test:       {'recall': 0.391093723714611, 'precision': 0.4886016237431463, 'f1': 0.4089891155269585, 'auc': 0.7685590579917303, 'prauc': 0.4332963954389919}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.14it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 119.37it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.81it/s]


Epoch: 018, Average Loss: 0.1851
Validation: {'recall': 0.3752479642649009, 'precision': 0.4403031672925997, 'f1': 0.3879129707770794, 'auc': 0.7659513762260696, 'prauc': 0.41952917571826387}
Test:       {'recall': 0.38257302072297317, 'precision': 0.477245249064882, 'f1': 0.40068576855590154, 'auc': 0.764411042600761, 'prauc': 0.4336306318360951}


Training Batches: 100%|██████████| 1390/1390 [00:42<00:00, 32.98it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 118.91it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.49it/s]


Epoch: 019, Average Loss: 0.1803
Validation: {'recall': 0.38649972654177156, 'precision': 0.45323484031024397, 'f1': 0.39519707841667, 'auc': 0.7615603456524702, 'prauc': 0.4160379488636293}
Test:       {'recall': 0.39183273529999396, 'precision': 0.4511093519155355, 'f1': 0.4054890530663659, 'auc': 0.7627734680986068, 'prauc': 0.42666082571533825}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.12it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.25it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 118.20it/s]


Epoch: 020, Average Loss: 0.1752
Validation: {'recall': 0.37720000725767, 'precision': 0.4500054539420451, 'f1': 0.39520775332929997, 'auc': 0.7622599034140293, 'prauc': 0.4183739119873169}
Test:       {'recall': 0.38555870730165437, 'precision': 0.4666768766372748, 'f1': 0.40663755227571463, 'auc': 0.7629678572886539, 'prauc': 0.4263905641081964}


Training Batches: 100%|██████████| 1390/1390 [00:41<00:00, 33.22it/s]
Running inference: 100%|██████████| 1033/1033 [00:08<00:00, 121.31it/s]
Running inference: 100%|██████████| 1050/1050 [00:08<00:00, 119.66it/s]

Epoch: 021, Average Loss: 0.1691
Validation: {'recall': 0.3683727690140683, 'precision': 0.4618052251151085, 'f1': 0.3868954476778692, 'auc': 0.7624223225654586, 'prauc': 0.4169395852905235}
Test:       {'recall': 0.37632631852341736, 'precision': 0.47329026094479076, 'f1': 0.3970671887855921, 'auc': 0.7616584715113267, 'prauc': 0.42591130499219304}

Early stopping triggered after 21 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3833231860337848, 'precision': 0.4512778343448859, 'f1': 0.3983319802815396, 'auc': 0.7694094321080028, 'prauc': 0.4217448260526136}
Corresponding test performance:
{'recall': 0.38752118976417277, 'precision': 0.4615659356192304, 'f1': 0.4048043710944446, 'auc': 0.7664240303100122, 'prauc': 0.4304230745938008}





In [17]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
recall: 0.3824 ± 0.0080
precision: 0.4617 ± 0.0422
f1: 0.4003 ± 0.0054
auc: 0.7690 ± 0.0046
prauc: 0.4381 ± 0.0078


In [18]:
# Assuming final_metrics and args are defined
# Create the results directory if it doesn't exist
os.makedirs(f"./results/{args['dataset']}", exist_ok=True)

# Define the file name
file_name = f"{args['dataset']}-{args['task']}.txt"
file_path = os.path.join(f"./results/{args['dataset']}", file_name)

# Save the metrics to a text file
with open(file_path, 'w') as f:
    f.write("Final Metrics:\n")
    for key in final_metrics.keys():
        mean_value = np.mean(final_metrics[key])
        std_value = np.std(final_metrics[key])
        f.write(f"{key}: {mean_value:.4f} ± {std_value:.4f}\n")