In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    # gt_labels is ground truth labels, predicted_scores is the output logits
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        # standard binary classification evaluation
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()      
    #   The threshold should be 0 because:
	# 	Your model outputs logits, not probabilities. logit > 0 ≡ probability > 0.5
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "auc":roc_auc, "prauc":pr_auc}
    else:
        # —— Multi-label classification evaluation (per-class over the batch) —— #
        # predicted_scores: list[Tensor] -> [B, C] logits
        logits = torch.cat(predicted_scores, dim=0)           # [B, C]
        y_true_t = torch.cat(gt_labels, dim=0)                # [B, C]
        # 1) 连续分数用于 AUC/PR-AUC：sigmoid(logits)
        #    CPU + float16 无 sigmoid 实现，先升到 fp32
        logits_for_sigmoid = logits.float() if (logits.device.type == "cpu" and logits.dtype == torch.float16) else logits
        prob_t = torch.sigmoid(logits_for_sigmoid)            # [B, C] in [0,1]

        # 2) 阈值化用于 P/R/F1：logits > 0 等价于 prob >= 0.5
        y_pred_t = (logits > 0).to(torch.int32)               # [B, C]

        # 转 numpy
        y_true = y_true_t.cpu().numpy().astype(np.int32)      # [B, C]
        y_pred = y_pred_t.cpu().numpy().astype(np.int32)      # [B, C]
        scores = prob_t.cpu().numpy()                         # [B, C] float

        # 3) 按类计算 Precision/Recall/F1
        p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # 4) 按类计算 AUC/PR-AUC（遇到单一类别则设为 NaN）
        C = y_true.shape[1]
        aucs, praucs = [], []
        for c in range(C):
            yt, ys = y_true[:, c], scores[:, c]
            if yt.max() == yt.min():           # 全 0 或全 1，曲线不定义
                aucs.append(np.nan)
                praucs.append(np.nan)
                continue
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

        # 5) 取宏平均（忽略 NaN）
        ave_precision = float(np.mean(p_cls))
        ave_recall    = float(np.mean(r_cls))
        ave_f1        = float(np.mean(f1_cls))
        ave_auc       = float(np.nanmean(aucs))   if np.any(~np.isnan(aucs))   else np.nan
        ave_prauc     = float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else np.nan

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc), np.mean(ave_recall), np.mean(ave_precision)
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [4]:
args = {
    "seed": 0,
    "dataset": "MIMIC-III", 
    "task": "next_diag_6m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 50,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [5]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-III-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [6]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [7]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [8]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [9]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [10]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([10, 89])
input_types shape: torch.Size([10, 89])
visit_positions shape: torch.Size([10])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [12]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [13]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
        test_metric = evaluate(model, test_dataloader, device, task_type=task_type)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric
            best_test_metric = test_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    return best_test_metric

In [15]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [16]:
final_metrics = {"recall":[], "precision":[],"f1":[],"auc":[],"prauc":[]}

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        task_type=task_type)
    
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.58it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.95it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.07it/s]


Epoch: 001, Average Loss: 0.4284
Validation: {'recall': 0.23114538060290338, 'precision': 0.34575355150539333, 'f1': 0.26218197097288964, 'auc': 0.6777470058327965, 'prauc': 0.38973807061213545}
Test:       {'recall': 0.23942123049673297, 'precision': 0.34023843788330627, 'f1': 0.2661724601238128, 'auc': 0.6872898617495933, 'prauc': 0.38235237294089985}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.99it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.54it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.24it/s]


Epoch: 002, Average Loss: 0.3876
Validation: {'recall': 0.2923791268234813, 'precision': 0.40832870023311474, 'f1': 0.3048856931975127, 'auc': 0.7040926867746792, 'prauc': 0.4094795119916842}
Test:       {'recall': 0.29076483632850736, 'precision': 0.3533181293115708, 'f1': 0.3030245167737972, 'auc': 0.7144482272245387, 'prauc': 0.40551119754366133}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.24it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.02it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.87it/s]


Epoch: 003, Average Loss: 0.3671
Validation: {'recall': 0.287638091555267, 'precision': 0.42229618047979817, 'f1': 0.310548637251089, 'auc': 0.7223621457127217, 'prauc': 0.43019685461765333}
Test:       {'recall': 0.2874259168293432, 'precision': 0.3923244447962306, 'f1': 0.3100942313246827, 'auc': 0.7279455284466639, 'prauc': 0.42476093690538697}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.28it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.55it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.99it/s]


Epoch: 004, Average Loss: 0.3544
Validation: {'recall': 0.30503688814345814, 'precision': 0.40240004130106605, 'f1': 0.3151473962598291, 'auc': 0.7238887471933437, 'prauc': 0.4338729404955928}
Test:       {'recall': 0.30333624614323007, 'precision': 0.39929212939405384, 'f1': 0.31666979558212893, 'auc': 0.7272747803805857, 'prauc': 0.4308735293278622}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.87it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.95it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.13it/s]


Epoch: 005, Average Loss: 0.3400
Validation: {'recall': 0.3396268838339989, 'precision': 0.45036371973647277, 'f1': 0.36244498647907525, 'auc': 0.730173580971416, 'prauc': 0.4393382370881178}
Test:       {'recall': 0.3390198648305669, 'precision': 0.43858454409034175, 'f1': 0.35756136350195566, 'auc': 0.7308937147881926, 'prauc': 0.43619676398671936}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.70it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.64it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.55it/s]


Epoch: 006, Average Loss: 0.3303
Validation: {'recall': 0.3138167354423889, 'precision': 0.44580269137857836, 'f1': 0.34448985541177163, 'auc': 0.7194917036115686, 'prauc': 0.43868109749269646}
Test:       {'recall': 0.30768752857903114, 'precision': 0.48388847271204516, 'f1': 0.33491940797384206, 'auc': 0.7287678410589011, 'prauc': 0.4401423940921577}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.35it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.72it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.89it/s]


Epoch: 007, Average Loss: 0.3218
Validation: {'recall': 0.3173340244941719, 'precision': 0.4540647599191831, 'f1': 0.348724908146082, 'auc': 0.7206567232811124, 'prauc': 0.4338653283468365}
Test:       {'recall': 0.324882321668813, 'precision': 0.42999998268060297, 'f1': 0.3511025958332296, 'auc': 0.728886971365785, 'prauc': 0.43826711773331195}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.20it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.26it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.98it/s]


Epoch: 008, Average Loss: 0.3118
Validation: {'recall': 0.3616933210533435, 'precision': 0.4370116785233031, 'f1': 0.3787826980323399, 'auc': 0.7253339161660031, 'prauc': 0.4399506799166433}
Test:       {'recall': 0.36123177090774344, 'precision': 0.41995065215632804, 'f1': 0.37666058816677783, 'auc': 0.7279796895169899, 'prauc': 0.43980906401201697}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.14it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.79it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.63it/s]


Epoch: 009, Average Loss: 0.3020
Validation: {'recall': 0.35715097587351996, 'precision': 0.4425431543969903, 'f1': 0.3821765418566149, 'auc': 0.7260226798011742, 'prauc': 0.43647079739248185}
Test:       {'recall': 0.3596972517432993, 'precision': 0.4511841210132281, 'f1': 0.3843873113383624, 'auc': 0.7318661934121269, 'prauc': 0.4362336160085961}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.02it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.01it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.46it/s]


Epoch: 010, Average Loss: 0.2923
Validation: {'recall': 0.3675380831574877, 'precision': 0.42473603193971704, 'f1': 0.37939997596215425, 'auc': 0.717458406715113, 'prauc': 0.430444528521608}
Test:       {'recall': 0.37373143592981584, 'precision': 0.43630630404604237, 'f1': 0.385519348033328, 'auc': 0.7241442990657428, 'prauc': 0.43019742605666983}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 25.05it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.09it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.44it/s]


Epoch: 011, Average Loss: 0.2843
Validation: {'recall': 0.3167358287240928, 'precision': 0.4549114552889426, 'f1': 0.3564562690893337, 'auc': 0.7167075745559246, 'prauc': 0.4301426335583466}
Test:       {'recall': 0.3215356059135549, 'precision': 0.46057142682273383, 'f1': 0.3619580728566577, 'auc': 0.7247227642192162, 'prauc': 0.42631277040198623}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.44it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.05it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.10it/s]


Epoch: 012, Average Loss: 0.2756
Validation: {'recall': 0.3560999765952515, 'precision': 0.41714487980055726, 'f1': 0.3744675748890352, 'auc': 0.7160084342833685, 'prauc': 0.429311312319535}
Test:       {'recall': 0.34999357035716244, 'precision': 0.4202558811086051, 'f1': 0.3688989157915165, 'auc': 0.7254565569843887, 'prauc': 0.4275947350471101}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.86it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.11it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.00it/s]


Epoch: 013, Average Loss: 0.2680
Validation: {'recall': 0.37578006097413635, 'precision': 0.4596911049572111, 'f1': 0.38453766235348635, 'auc': 0.7160660345310501, 'prauc': 0.42540952533366516}
Test:       {'recall': 0.3805883936375351, 'precision': 0.41933127701054435, 'f1': 0.3889947895082822, 'auc': 0.7235827279771906, 'prauc': 0.4225118729003307}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.91it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.96it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.27it/s]


Epoch: 014, Average Loss: 0.2618
Validation: {'recall': 0.3836737878842294, 'precision': 0.4737698917710779, 'f1': 0.3936005288448346, 'auc': 0.7204812317656608, 'prauc': 0.42864157504998635}
Test:       {'recall': 0.38029033069931895, 'precision': 0.4529387048100389, 'f1': 0.3916076567100867, 'auc': 0.7276741622453327, 'prauc': 0.43232704488491625}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.88it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.81it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.91it/s]


Epoch: 015, Average Loss: 0.2542
Validation: {'recall': 0.36973388327359535, 'precision': 0.504461399446223, 'f1': 0.38629943801174893, 'auc': 0.7225854268807239, 'prauc': 0.4298180743719617}
Test:       {'recall': 0.3585424725299995, 'precision': 0.42047429909730255, 'f1': 0.3771642583184589, 'auc': 0.7217795280615363, 'prauc': 0.42722418849931126}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.80it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.08it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.20it/s]


Epoch: 016, Average Loss: 0.2439
Validation: {'recall': 0.4010783002410451, 'precision': 0.4128248889514868, 'f1': 0.38804391897849905, 'auc': 0.716680824413219, 'prauc': 0.4221291928508815}
Test:       {'recall': 0.40246101341226737, 'precision': 0.45130126846531937, 'f1': 0.3955899546695243, 'auc': 0.7186296838304449, 'prauc': 0.4231828308685497}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.57it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.75it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.27it/s]


Epoch: 017, Average Loss: 0.2363
Validation: {'recall': 0.372635262080155, 'precision': 0.46295100861997557, 'f1': 0.37865969007103956, 'auc': 0.7119260952659903, 'prauc': 0.42211869334435087}
Test:       {'recall': 0.3730859580186685, 'precision': 0.42795478579111657, 'f1': 0.3777629021818767, 'auc': 0.7154708330846081, 'prauc': 0.41730234910521313}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.28it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.82it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.74it/s]


Epoch: 018, Average Loss: 0.2336
Validation: {'recall': 0.336885111666798, 'precision': 0.45039682644822965, 'f1': 0.37003878202188506, 'auc': 0.7119617085440706, 'prauc': 0.4227634838678747}
Test:       {'recall': 0.33351628523141597, 'precision': 0.4340852512638188, 'f1': 0.3653875709639647, 'auc': 0.7195178020027375, 'prauc': 0.42174113398548224}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.00it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.82it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.63it/s]


Epoch: 019, Average Loss: 0.2257
Validation: {'recall': 0.36507768974519705, 'precision': 0.45825419784715415, 'f1': 0.38527985311231255, 'auc': 0.7067838075065497, 'prauc': 0.417102410802526}
Test:       {'recall': 0.3711657880397795, 'precision': 0.4151230118539479, 'f1': 0.38826448246238265, 'auc': 0.7192815932877082, 'prauc': 0.416528584555322}

Early stopping triggered after 19 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3836737878842294, 'precision': 0.4737698917710779, 'f1': 0.3936005288448346, 'auc': 0.7204812317656608, 'prauc': 0.42864157504998635}
Corresponding test performance:
{'recall': 0.38029033069931895, 'precision': 0.4529387048100389, 'f1': 0.3916076567100867, 'auc': 0.7276741622453327, 'prauc': 0.43232704488491625}
[INFO] Random seed set to 1181241943
Training with seed: 1181241943


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.17it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.52it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.13it/s]


Epoch: 001, Average Loss: 0.4352
Validation: {'recall': 0.22168400312371955, 'precision': 0.37707876478557434, 'f1': 0.25384465354307884, 'auc': 0.6884465802130874, 'prauc': 0.39079463629189415}
Test:       {'recall': 0.22594643718323804, 'precision': 0.3720689636479436, 'f1': 0.2576824590211382, 'auc': 0.6883319310748982, 'prauc': 0.3856484578021599}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.02it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.64it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.97it/s]


Epoch: 002, Average Loss: 0.3885
Validation: {'recall': 0.28891040389025324, 'precision': 0.4053312787791996, 'f1': 0.3080907554692515, 'auc': 0.7037072992045345, 'prauc': 0.41552654791828403}
Test:       {'recall': 0.29368495307637493, 'precision': 0.39516816263644955, 'f1': 0.3089120183652173, 'auc': 0.7153100894248907, 'prauc': 0.40580311557491994}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.05it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.67it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.04it/s]


Epoch: 003, Average Loss: 0.3720
Validation: {'recall': 0.30418119376491104, 'precision': 0.3917527313742057, 'f1': 0.3214817913808148, 'auc': 0.7186361890168084, 'prauc': 0.42129485622424045}
Test:       {'recall': 0.3038847953151189, 'precision': 0.3675239746757995, 'f1': 0.3195146228822576, 'auc': 0.7266618329890027, 'prauc': 0.4160806041366986}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.85it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.77it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.52it/s]


Epoch: 004, Average Loss: 0.3568
Validation: {'recall': 0.2928378229145636, 'precision': 0.4356285830036686, 'f1': 0.3105711272176513, 'auc': 0.7182330461162876, 'prauc': 0.4259435895537362}
Test:       {'recall': 0.27917195955761265, 'precision': 0.4236639142519591, 'f1': 0.3023647063358992, 'auc': 0.7202076264981846, 'prauc': 0.4169212266313804}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.15it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.43it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.11it/s]


Epoch: 005, Average Loss: 0.3440
Validation: {'recall': 0.3680832887269289, 'precision': 0.4183472769361021, 'f1': 0.3849719058379726, 'auc': 0.731968125686428, 'prauc': 0.4442182407458499}
Test:       {'recall': 0.36042328344851726, 'precision': 0.4200239929470501, 'f1': 0.380180677536377, 'auc': 0.7372494507596582, 'prauc': 0.4401901196384688}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.87it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.14it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.90it/s]


Epoch: 006, Average Loss: 0.3311
Validation: {'recall': 0.307877038902636, 'precision': 0.48763263628346354, 'f1': 0.35729698351801253, 'auc': 0.7332143671740262, 'prauc': 0.4503331236476251}
Test:       {'recall': 0.2960256540076435, 'precision': 0.44172444099869324, 'f1': 0.34133301039852076, 'auc': 0.7358750520890887, 'prauc': 0.4466255017482887}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.96it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.59it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.48it/s]


Epoch: 007, Average Loss: 0.3197
Validation: {'recall': 0.32819850162226616, 'precision': 0.4513987079962241, 'f1': 0.3667019918036718, 'auc': 0.7307838959617786, 'prauc': 0.4452066667648054}
Test:       {'recall': 0.32018856030264586, 'precision': 0.4859760332094406, 'f1': 0.3581267006783755, 'auc': 0.731552776557123, 'prauc': 0.4451374341296066}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.99it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.55it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.65it/s]


Epoch: 008, Average Loss: 0.3091
Validation: {'recall': 0.35164813406341383, 'precision': 0.44949264317910603, 'f1': 0.37811681087632204, 'auc': 0.7341995952043778, 'prauc': 0.44441924406066635}
Test:       {'recall': 0.3474174477556626, 'precision': 0.4403929845725011, 'f1': 0.37441858645366105, 'auc': 0.7288528507963279, 'prauc': 0.44239217120168317}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.94it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.87it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.69it/s]


Epoch: 009, Average Loss: 0.3029
Validation: {'recall': 0.3785575782896682, 'precision': 0.41813066468960186, 'f1': 0.39351828406800965, 'auc': 0.7326787578054434, 'prauc': 0.4406527796174559}
Test:       {'recall': 0.3714748362663613, 'precision': 0.499717861352593, 'f1': 0.38903767627259833, 'auc': 0.7289876124481958, 'prauc': 0.4395819419747742}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.13it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.00it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.33it/s]


Epoch: 010, Average Loss: 0.2892
Validation: {'recall': 0.39993577554697335, 'precision': 0.3966908837494446, 'f1': 0.3902946365557567, 'auc': 0.723143363862198, 'prauc': 0.4320966905347589}
Test:       {'recall': 0.4012305771798758, 'precision': 0.4111720750813534, 'f1': 0.39269321609827273, 'auc': 0.7173127201224757, 'prauc': 0.43033794752071536}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.04it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.13it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.66it/s]


Epoch: 011, Average Loss: 0.2792
Validation: {'recall': 0.3532680078314761, 'precision': 0.43361712592018975, 'f1': 0.3724913263257718, 'auc': 0.7246777271032111, 'prauc': 0.4338543876103834}
Test:       {'recall': 0.3538900306114818, 'precision': 0.44133924813431363, 'f1': 0.37350424196273774, 'auc': 0.7226309387193162, 'prauc': 0.43406281689085857}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.15it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.01it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.67it/s]


Epoch: 012, Average Loss: 0.2719
Validation: {'recall': 0.34809313986686574, 'precision': 0.43310130293136473, 'f1': 0.3810643704679453, 'auc': 0.7242653390291696, 'prauc': 0.4366912201847485}
Test:       {'recall': 0.35273607500312426, 'precision': 0.4405234965211793, 'f1': 0.38799120207060206, 'auc': 0.7237959395370673, 'prauc': 0.43192425260996636}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.41it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.90it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.59it/s]


Epoch: 013, Average Loss: 0.2647
Validation: {'recall': 0.4241333317650804, 'precision': 0.4092230111423209, 'f1': 0.4018270728706678, 'auc': 0.7213992993321974, 'prauc': 0.4323747643203064}
Test:       {'recall': 0.4277422817655751, 'precision': 0.43969412208464426, 'f1': 0.40779835799986386, 'auc': 0.7143280597358114, 'prauc': 0.4330436844605197}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.58it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.74it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.29it/s]


Epoch: 014, Average Loss: 0.2555
Validation: {'recall': 0.35465139561808506, 'precision': 0.4336041674491663, 'f1': 0.3795174577926257, 'auc': 0.7259634550022839, 'prauc': 0.4328742646791045}
Test:       {'recall': 0.35845376551822955, 'precision': 0.4430846919096023, 'f1': 0.385243833027977, 'auc': 0.7224619696810347, 'prauc': 0.43306295379438026}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.90it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.00it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.42it/s]


Epoch: 015, Average Loss: 0.2453
Validation: {'recall': 0.3673659119091902, 'precision': 0.42457972285587436, 'f1': 0.3851194676675821, 'auc': 0.7175635266718914, 'prauc': 0.4266286117283214}
Test:       {'recall': 0.37036086387447087, 'precision': 0.43928397849159306, 'f1': 0.3924343633055441, 'auc': 0.7194338839101042, 'prauc': 0.43200008038630644}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.73it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.34it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.12it/s]


Epoch: 016, Average Loss: 0.2401
Validation: {'recall': 0.3524409907512752, 'precision': 0.42214021877842367, 'f1': 0.37842585569928505, 'auc': 0.7221381312188861, 'prauc': 0.42986204836769565}
Test:       {'recall': 0.34726554462382747, 'precision': 0.4472550001855456, 'f1': 0.37900964892967465, 'auc': 0.7181950278788806, 'prauc': 0.4260842943290414}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.88it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.17it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.81it/s]


Epoch: 017, Average Loss: 0.2313
Validation: {'recall': 0.37117681339360886, 'precision': 0.41255054037019584, 'f1': 0.38262070654871727, 'auc': 0.7177140895584525, 'prauc': 0.4214158426150457}
Test:       {'recall': 0.3717592655432336, 'precision': 0.42990807854824387, 'f1': 0.38886211193143905, 'auc': 0.7124926078603013, 'prauc': 0.42479608455488566}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.98it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.67it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.42it/s]


Epoch: 018, Average Loss: 0.2245
Validation: {'recall': 0.3815126148787405, 'precision': 0.41275596351086624, 'f1': 0.38843069843198613, 'auc': 0.7157475975811851, 'prauc': 0.4246306886820078}
Test:       {'recall': 0.3888929969299678, 'precision': 0.41991445293223506, 'f1': 0.3957173714101873, 'auc': 0.717870651515083, 'prauc': 0.4314076407924949}

Early stopping triggered after 18 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.4241333317650804, 'precision': 0.4092230111423209, 'f1': 0.4018270728706678, 'auc': 0.7213992993321974, 'prauc': 0.4323747643203064}
Corresponding test performance:
{'recall': 0.4277422817655751, 'precision': 0.43969412208464426, 'f1': 0.40779835799986386, 'auc': 0.7143280597358114, 'prauc': 0.4330436844605197}
[INFO] Random seed set to 958682846
Training with seed: 958682846


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.83it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.43it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.40it/s]


Epoch: 001, Average Loss: 0.4296
Validation: {'recall': 0.23967510287254498, 'precision': 0.37306899671319194, 'f1': 0.26917375024896345, 'auc': 0.6919244679522254, 'prauc': 0.4037472353353413}
Test:       {'recall': 0.24219645117769273, 'precision': 0.36393270878962247, 'f1': 0.2704087102602043, 'auc': 0.6854780633424914, 'prauc': 0.3955456540572267}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.66it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.12it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.18it/s]


Epoch: 002, Average Loss: 0.3824
Validation: {'recall': 0.2926882243158516, 'precision': 0.39109371704093426, 'f1': 0.3197575521530686, 'auc': 0.7135288199404072, 'prauc': 0.43123672002298885}
Test:       {'recall': 0.29595588144666507, 'precision': 0.3880811173465894, 'f1': 0.3217462124303272, 'auc': 0.7125132648833667, 'prauc': 0.42281715041681933}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.68it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.45it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.14it/s]


Epoch: 003, Average Loss: 0.3659
Validation: {'recall': 0.27880189706131486, 'precision': 0.43518859030798995, 'f1': 0.31390502877291415, 'auc': 0.7195215989072603, 'prauc': 0.4265127001808003}
Test:       {'recall': 0.28216401305501626, 'precision': 0.3760214257124152, 'f1': 0.3140966217397282, 'auc': 0.7185626934404432, 'prauc': 0.4274848204839759}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.96it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.89it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.26it/s]


Epoch: 004, Average Loss: 0.3524
Validation: {'recall': 0.3506885680024173, 'precision': 0.38083125323096956, 'f1': 0.3595463727927219, 'auc': 0.7413820955822542, 'prauc': 0.43930784270522405}
Test:       {'recall': 0.3496708321675232, 'precision': 0.3813353819152572, 'f1': 0.35949168581649815, 'auc': 0.7330633840159657, 'prauc': 0.43491767873827486}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.39it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.87it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.08it/s]


Epoch: 005, Average Loss: 0.3411
Validation: {'recall': 0.3336675498770467, 'precision': 0.42927239190640526, 'f1': 0.35453629276232923, 'auc': 0.7380780011623693, 'prauc': 0.44177797390225365}
Test:       {'recall': 0.33892511173486795, 'precision': 0.4440816783019768, 'f1': 0.3596497171873321, 'auc': 0.7266197049973858, 'prauc': 0.44073991381681255}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.18it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.09it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.66it/s]


Epoch: 006, Average Loss: 0.3316
Validation: {'recall': 0.330698110145782, 'precision': 0.4184852998503386, 'f1': 0.35891035568775354, 'auc': 0.7368152512003894, 'prauc': 0.4368587202929057}
Test:       {'recall': 0.3260999336763477, 'precision': 0.4254305927957899, 'f1': 0.35612027155816295, 'auc': 0.7249697408644993, 'prauc': 0.4352652091511892}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.13it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.50it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.88it/s]


Epoch: 007, Average Loss: 0.3210
Validation: {'recall': 0.33329158375402657, 'precision': 0.4618507746258405, 'f1': 0.3686413125446719, 'auc': 0.7360796546586038, 'prauc': 0.4411561660014764}
Test:       {'recall': 0.32865985298442457, 'precision': 0.4345514318103715, 'f1': 0.3630956874588552, 'auc': 0.7314676090631677, 'prauc': 0.4456841193776148}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.19it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.72it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.78it/s]


Epoch: 008, Average Loss: 0.3121
Validation: {'recall': 0.35591644339385514, 'precision': 0.4452603500883849, 'f1': 0.38515924970105625, 'auc': 0.7360543317434051, 'prauc': 0.4383437074469725}
Test:       {'recall': 0.35680749284936997, 'precision': 0.44610936599258444, 'f1': 0.38659774038556277, 'auc': 0.7306168299731708, 'prauc': 0.43920996577046995}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.68it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.81it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 81.98it/s]


Epoch: 009, Average Loss: 0.3028
Validation: {'recall': 0.3644490045381201, 'precision': 0.4204180862288999, 'f1': 0.37775808831898616, 'auc': 0.7398064733349735, 'prauc': 0.43293119917864925}
Test:       {'recall': 0.35307135086645913, 'precision': 0.41679868615260873, 'f1': 0.3683619836487247, 'auc': 0.7201521093102197, 'prauc': 0.4356552059107349}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.29it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.03it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.17it/s]


Epoch: 010, Average Loss: 0.2934
Validation: {'recall': 0.3921638239105138, 'precision': 0.4219972460778758, 'f1': 0.3962335658511365, 'auc': 0.7316027419724512, 'prauc': 0.4391381261474791}
Test:       {'recall': 0.3854281705600965, 'precision': 0.4670789507367801, 'f1': 0.3905318868999555, 'auc': 0.7229351967792206, 'prauc': 0.44242541124266865}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.93it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.49it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.70it/s]


Epoch: 011, Average Loss: 0.2845
Validation: {'recall': 0.3836378026130504, 'precision': 0.42262340776405904, 'f1': 0.3945035470352224, 'auc': 0.7334919979258271, 'prauc': 0.4404852535769189}
Test:       {'recall': 0.37556688554763795, 'precision': 0.40560337251296136, 'f1': 0.38777763159060863, 'auc': 0.728349259145225, 'prauc': 0.4427745703221845}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.21it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.44it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.31it/s]


Epoch: 012, Average Loss: 0.2800
Validation: {'recall': 0.3704709299249027, 'precision': 0.43752314907169915, 'f1': 0.39410705854132, 'auc': 0.733617253366997, 'prauc': 0.44340157205577824}
Test:       {'recall': 0.36321542770776244, 'precision': 0.45636891055825957, 'f1': 0.38838055512849834, 'auc': 0.7334966349976031, 'prauc': 0.4471489360111646}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.13it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.56it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.07it/s]


Epoch: 013, Average Loss: 0.2697
Validation: {'recall': 0.38490283684925114, 'precision': 0.42759155357641915, 'f1': 0.3957864397496555, 'auc': 0.7279695867875904, 'prauc': 0.42966646322818075}
Test:       {'recall': 0.3779649322794411, 'precision': 0.43352564782842595, 'f1': 0.39056045430546116, 'auc': 0.7293772821968849, 'prauc': 0.43805846411069393}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.15it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.38it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.46it/s]


Epoch: 014, Average Loss: 0.2611
Validation: {'recall': 0.34746468007456166, 'precision': 0.40416858704667863, 'f1': 0.3678813974635397, 'auc': 0.7272106939480146, 'prauc': 0.4313850739476911}
Test:       {'recall': 0.35562240404665, 'precision': 0.4396350625099884, 'f1': 0.37865591116065145, 'auc': 0.7259596623950791, 'prauc': 0.4390603228015069}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.49it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.18it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.95it/s]


Epoch: 015, Average Loss: 0.2553
Validation: {'recall': 0.37893909380982815, 'precision': 0.4369682879896606, 'f1': 0.39333258695746637, 'auc': 0.72331898070999, 'prauc': 0.428099872996316}
Test:       {'recall': 0.3751811260683041, 'precision': 0.4255858390199572, 'f1': 0.3926980118016835, 'auc': 0.7248854486508622, 'prauc': 0.43678179715045806}

Early stopping triggered after 15 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3921638239105138, 'precision': 0.4219972460778758, 'f1': 0.3962335658511365, 'auc': 0.7316027419724512, 'prauc': 0.4391381261474791}
Corresponding test performance:
{'recall': 0.3854281705600965, 'precision': 0.4670789507367801, 'f1': 0.3905318868999555, 'auc': 0.7229351967792206, 'prauc': 0.44242541124266865}
[INFO] Random seed set to 3163119785
Training with seed: 3163119785


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.03it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.71it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.03it/s]


Epoch: 001, Average Loss: 0.4359
Validation: {'recall': 0.20199063961104155, 'precision': 0.3597654475153329, 'f1': 0.22637643505208127, 'auc': 0.6982014326098199, 'prauc': 0.39570170999970267}
Test:       {'recall': 0.20449576441276912, 'precision': 0.3617861856023239, 'f1': 0.22907524747405017, 'auc': 0.7061678906259007, 'prauc': 0.3935216992939736}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.97it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.00it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.80it/s]


Epoch: 002, Average Loss: 0.3826
Validation: {'recall': 0.30683179142490863, 'precision': 0.40169215680018017, 'f1': 0.3321017597776989, 'auc': 0.7253150856900324, 'prauc': 0.430300334417487}
Test:       {'recall': 0.3092156276939385, 'precision': 0.40066609683170074, 'f1': 0.3330050804142383, 'auc': 0.7274223894408091, 'prauc': 0.4202860987526912}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.28it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.54it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.09it/s]


Epoch: 003, Average Loss: 0.3618
Validation: {'recall': 0.3192530372427512, 'precision': 0.4157893743488537, 'f1': 0.35007606478479275, 'auc': 0.7147347219476604, 'prauc': 0.4265599545767117}
Test:       {'recall': 0.3216958632676046, 'precision': 0.4135755663291925, 'f1': 0.35135795785472584, 'auc': 0.7283337380406061, 'prauc': 0.42300564839126115}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.79it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.03it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.99it/s]


Epoch: 004, Average Loss: 0.3484
Validation: {'recall': 0.2679770335455859, 'precision': 0.4170179251181796, 'f1': 0.30132047776927984, 'auc': 0.7267753078775526, 'prauc': 0.4298674119650877}
Test:       {'recall': 0.2672112470727391, 'precision': 0.4134720803127936, 'f1': 0.2972200709926798, 'auc': 0.7232054736779502, 'prauc': 0.42059254663282997}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.56it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.99it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 81.03it/s]


Epoch: 005, Average Loss: 0.3393
Validation: {'recall': 0.335548560872398, 'precision': 0.41523591691805994, 'f1': 0.35222313419566237, 'auc': 0.7297896526483455, 'prauc': 0.4296717932857969}
Test:       {'recall': 0.3398973998807963, 'precision': 0.418050376151333, 'f1': 0.3544648889467057, 'auc': 0.7286785791982541, 'prauc': 0.42673754815867515}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.42it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 86.73it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.06it/s]


Epoch: 006, Average Loss: 0.3285
Validation: {'recall': 0.3242038355846393, 'precision': 0.40432592951229096, 'f1': 0.3428555296736199, 'auc': 0.7257408058007648, 'prauc': 0.4280823982097602}
Test:       {'recall': 0.32916856870550504, 'precision': 0.4039894202759871, 'f1': 0.3489015407948048, 'auc': 0.7184452695417122, 'prauc': 0.4259947848387908}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.00it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.87it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.00it/s]


Epoch: 007, Average Loss: 0.3184
Validation: {'recall': 0.3534444634372703, 'precision': 0.40781821749470454, 'f1': 0.3693238931275513, 'auc': 0.721883158065831, 'prauc': 0.43052509579150117}
Test:       {'recall': 0.36326984961760617, 'precision': 0.4162715328780522, 'f1': 0.37828061615524117, 'auc': 0.7298321660528438, 'prauc': 0.431194849071547}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.20it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.14it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.79it/s]


Epoch: 008, Average Loss: 0.3077
Validation: {'recall': 0.372105893918952, 'precision': 0.4190547628810171, 'f1': 0.3778116348765353, 'auc': 0.7121434558089806, 'prauc': 0.42773298322830505}
Test:       {'recall': 0.3703626989132079, 'precision': 0.4036302872373372, 'f1': 0.38123031377591476, 'auc': 0.7233750357275777, 'prauc': 0.4316825664289017}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.91it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.64it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.54it/s]


Epoch: 009, Average Loss: 0.3004
Validation: {'recall': 0.3682225055448371, 'precision': 0.4179059373692863, 'f1': 0.3833615630956964, 'auc': 0.7271630742365692, 'prauc': 0.4358111396165386}
Test:       {'recall': 0.3639489640301037, 'precision': 0.4268237906248835, 'f1': 0.3822431532425681, 'auc': 0.7277779507078027, 'prauc': 0.4348666010246191}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.05it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.48it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.04it/s]


Epoch: 010, Average Loss: 0.2924
Validation: {'recall': 0.38162935224696654, 'precision': 0.43584324455567824, 'f1': 0.39520398449878674, 'auc': 0.7253998685706644, 'prauc': 0.4329957587244385}
Test:       {'recall': 0.3840915667751847, 'precision': 0.44333290335682196, 'f1': 0.39833176118211605, 'auc': 0.7231734055856731, 'prauc': 0.4323095750501424}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.10it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.77it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.05it/s]


Epoch: 011, Average Loss: 0.2789
Validation: {'recall': 0.34326864533092605, 'precision': 0.41101051278217055, 'f1': 0.3538548722968559, 'auc': 0.7187154468677832, 'prauc': 0.4240876886073611}
Test:       {'recall': 0.34101589828690154, 'precision': 0.42524757570598576, 'f1': 0.3564194458976924, 'auc': 0.718758593188872, 'prauc': 0.4257736722982422}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.39it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.81it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.14it/s]


Epoch: 012, Average Loss: 0.2741
Validation: {'recall': 0.36441405070687405, 'precision': 0.4304622464991817, 'f1': 0.3781323469679367, 'auc': 0.7207292207534154, 'prauc': 0.42474779997274104}
Test:       {'recall': 0.3675266786045378, 'precision': 0.40139610542005694, 'f1': 0.3751988476279937, 'auc': 0.7180613934049164, 'prauc': 0.426918019346385}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.00it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.84it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.32it/s]


Epoch: 013, Average Loss: 0.2676
Validation: {'recall': 0.38079532295643925, 'precision': 0.4302409362715802, 'f1': 0.38725602025630584, 'auc': 0.7154205806174884, 'prauc': 0.4229180369000509}
Test:       {'recall': 0.37410148467859977, 'precision': 0.4193712666565912, 'f1': 0.38368814585533506, 'auc': 0.7190643205011292, 'prauc': 0.42429833572715503}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.22it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.80it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.35it/s]


Epoch: 014, Average Loss: 0.2602
Validation: {'recall': 0.3886287051316756, 'precision': 0.40404380306633514, 'f1': 0.38649863984064564, 'auc': 0.7151332390906389, 'prauc': 0.41951543123283125}
Test:       {'recall': 0.4024753129816599, 'precision': 0.41292221937086837, 'f1': 0.39715109220791195, 'auc': 0.7199765841889186, 'prauc': 0.4242926138415099}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.18it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.58it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.96it/s]


Epoch: 015, Average Loss: 0.2512
Validation: {'recall': 0.35506265704735246, 'precision': 0.4296043825714046, 'f1': 0.38047264070281916, 'auc': 0.7158407441432106, 'prauc': 0.42594778732487054}
Test:       {'recall': 0.3519607955118179, 'precision': 0.43373531176118374, 'f1': 0.37764085978996226, 'auc': 0.7214495584039189, 'prauc': 0.4286338475332516}

Early stopping triggered after 15 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.38162935224696654, 'precision': 0.43584324455567824, 'f1': 0.39520398449878674, 'auc': 0.7253998685706644, 'prauc': 0.4329957587244385}
Corresponding test performance:
{'recall': 0.3840915667751847, 'precision': 0.44333290335682196, 'f1': 0.39833176118211605, 'auc': 0.7231734055856731, 'prauc': 0.4323095750501424}
[INFO] Random seed set to 1812140441
Training with seed: 1812140441


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.31it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.04it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.35it/s]


Epoch: 001, Average Loss: 0.4271
Validation: {'recall': 0.23384360281433858, 'precision': 0.332041277209937, 'f1': 0.26075524185690596, 'auc': 0.6856607730385282, 'prauc': 0.4043709470715211}
Test:       {'recall': 0.23886613085264902, 'precision': 0.33582467716586695, 'f1': 0.2643974762565651, 'auc': 0.6934341062458507, 'prauc': 0.39996559885197186}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.28it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.39it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.69it/s]


Epoch: 002, Average Loss: 0.3814
Validation: {'recall': 0.28388649439270397, 'precision': 0.4028812892113949, 'f1': 0.3110700523426313, 'auc': 0.7240675639717066, 'prauc': 0.4373826351897861}
Test:       {'recall': 0.2832267450771646, 'precision': 0.40488729273881835, 'f1': 0.3105955304228971, 'auc': 0.7274572297748011, 'prauc': 0.42922334594004213}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 25.18it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.56it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.68it/s]


Epoch: 003, Average Loss: 0.3631
Validation: {'recall': 0.2974440225789085, 'precision': 0.4432455854529766, 'f1': 0.3290305241130524, 'auc': 0.7409159731004644, 'prauc': 0.44988276131579863}
Test:       {'recall': 0.3058123257034326, 'precision': 0.455170129955968, 'f1': 0.3377861177636856, 'auc': 0.7352354132889373, 'prauc': 0.44070921394052076}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 25.19it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.62it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.69it/s]


Epoch: 004, Average Loss: 0.3504
Validation: {'recall': 0.36793224406343955, 'precision': 0.41130783529592285, 'f1': 0.3669815397915235, 'auc': 0.7444597787629657, 'prauc': 0.4512618899714812}
Test:       {'recall': 0.3743124592791203, 'precision': 0.426864893590041, 'f1': 0.37569734062657645, 'auc': 0.7251527086417999, 'prauc': 0.44725453508917745}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.73it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.45it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.66it/s]


Epoch: 005, Average Loss: 0.3371
Validation: {'recall': 0.32454030134170003, 'precision': 0.4293529446056388, 'f1': 0.3481073570887928, 'auc': 0.737217995689078, 'prauc': 0.4435145797084531}
Test:       {'recall': 0.3207400005693252, 'precision': 0.42712768296075854, 'f1': 0.34670746443649414, 'auc': 0.7303621669847463, 'prauc': 0.4441498833297093}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.00it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.22it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.64it/s]


Epoch: 006, Average Loss: 0.3284
Validation: {'recall': 0.3637422024272424, 'precision': 0.42056174697700005, 'f1': 0.38572335070463626, 'auc': 0.7298781839263105, 'prauc': 0.44610682231571264}
Test:       {'recall': 0.36301898444664965, 'precision': 0.41763677687562356, 'f1': 0.3840664118784052, 'auc': 0.7347237629499338, 'prauc': 0.4476740864724891}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.18it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.00it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 86.06it/s]


Epoch: 007, Average Loss: 0.3161
Validation: {'recall': 0.33679239045124676, 'precision': 0.4459410742474318, 'f1': 0.3704826493973351, 'auc': 0.7287630710230165, 'prauc': 0.44648573346747167}
Test:       {'recall': 0.34309804442547265, 'precision': 0.4349848928173865, 'f1': 0.3692345045181998, 'auc': 0.7337566123845154, 'prauc': 0.446154014597387}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.13it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.32it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.20it/s]


Epoch: 008, Average Loss: 0.3062
Validation: {'recall': 0.37098524613795414, 'precision': 0.4214086637612375, 'f1': 0.38839137537628293, 'auc': 0.7264437663199516, 'prauc': 0.4402008914220015}
Test:       {'recall': 0.38358227355897667, 'precision': 0.4582718243132596, 'f1': 0.39642688802134096, 'auc': 0.730624836178453, 'prauc': 0.44053053560699346}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.77it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.47it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.53it/s]


Epoch: 009, Average Loss: 0.2967
Validation: {'recall': 0.3778782288866706, 'precision': 0.4248490770101758, 'f1': 0.3858904172208736, 'auc': 0.7298593991818095, 'prauc': 0.44114936469985455}
Test:       {'recall': 0.38041981057244795, 'precision': 0.4469359355434947, 'f1': 0.38723371358950165, 'auc': 0.7295331106317361, 'prauc': 0.4401882238823841}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.16it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.83it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.43it/s]


Epoch: 010, Average Loss: 0.2906
Validation: {'recall': 0.35061008115932973, 'precision': 0.4386976192153569, 'f1': 0.3759577224868809, 'auc': 0.72324287170904, 'prauc': 0.43777688378124424}
Test:       {'recall': 0.3566199147593529, 'precision': 0.44386050785857556, 'f1': 0.3816459697920353, 'auc': 0.7218296008051558, 'prauc': 0.4375088360714414}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.52it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.11it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.55it/s]


Epoch: 011, Average Loss: 0.2823
Validation: {'recall': 0.38545103804114655, 'precision': 0.41867149110219826, 'f1': 0.3956056423486355, 'auc': 0.7199924673990441, 'prauc': 0.43246556375850803}
Test:       {'recall': 0.3959546227488132, 'precision': 0.42248494864616315, 'f1': 0.40346925525974675, 'auc': 0.7236181074439276, 'prauc': 0.4411779170150318}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.95it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.36it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.29it/s]


Epoch: 012, Average Loss: 0.2746
Validation: {'recall': 0.38490545602733944, 'precision': 0.4223449549927475, 'f1': 0.3888458450685147, 'auc': 0.7211226499760199, 'prauc': 0.42863204384314507}
Test:       {'recall': 0.38496169571050975, 'precision': 0.42282866610974035, 'f1': 0.39306430464039954, 'auc': 0.7206921236159067, 'prauc': 0.4340160376047528}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.18it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.25it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.01it/s]


Epoch: 013, Average Loss: 0.2651
Validation: {'recall': 0.36883197886392904, 'precision': 0.4472865772864977, 'f1': 0.39164790386919557, 'auc': 0.7187364509388066, 'prauc': 0.42959818974566155}
Test:       {'recall': 0.3757351824063535, 'precision': 0.43112305187999045, 'f1': 0.3972971700331858, 'auc': 0.7208400626121928, 'prauc': 0.4423897133934378}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.08it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 87.93it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 84.96it/s]


Epoch: 014, Average Loss: 0.2549
Validation: {'recall': 0.39759947223860714, 'precision': 0.40882644549779484, 'f1': 0.3972103083865821, 'auc': 0.7097861656618483, 'prauc': 0.4247173089711626}
Test:       {'recall': 0.40465978272048, 'precision': 0.4208534876609396, 'f1': 0.4074856204481106, 'auc': 0.7221631632116036, 'prauc': 0.43775318356000054}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.26it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.48it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.15it/s]


Epoch: 015, Average Loss: 0.2479
Validation: {'recall': 0.3796366062662843, 'precision': 0.4226052595496501, 'f1': 0.38420440653874555, 'auc': 0.711521609427917, 'prauc': 0.42522383184745344}
Test:       {'recall': 0.3861714782738759, 'precision': 0.41942565950024646, 'f1': 0.39567156762706684, 'auc': 0.7172385122212859, 'prauc': 0.4331274794812248}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.08it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.70it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.24it/s]


Epoch: 016, Average Loss: 0.2434
Validation: {'recall': 0.4075740451667787, 'precision': 0.41014831225735177, 'f1': 0.4007362639743379, 'auc': 0.7087210958100117, 'prauc': 0.4218310683695448}
Test:       {'recall': 0.4183709938285431, 'precision': 0.398851903930544, 'f1': 0.4025342504565379, 'auc': 0.7168860779228332, 'prauc': 0.42983381638633567}


Training Batches: 100%|██████████| 380/380 [00:15<00:00, 24.89it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 88.33it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.61it/s]


Epoch: 017, Average Loss: 0.2319
Validation: {'recall': 0.39055308098354685, 'precision': 0.40397357378956217, 'f1': 0.3847598234125124, 'auc': 0.7113487534882049, 'prauc': 0.41686238661141545}
Test:       {'recall': 0.39786750090148615, 'precision': 0.43484375840477835, 'f1': 0.39561065164423725, 'auc': 0.715403912761899, 'prauc': 0.42623982435330343}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.97it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.18it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 85.09it/s]


Epoch: 018, Average Loss: 0.2270
Validation: {'recall': 0.38436384147176894, 'precision': 0.416235569384142, 'f1': 0.39028464786547234, 'auc': 0.7099476647235049, 'prauc': 0.4223166761595395}
Test:       {'recall': 0.3881847721674066, 'precision': 0.4182108821756942, 'f1': 0.3948873791021787, 'auc': 0.7222656484641765, 'prauc': 0.43376112750069246}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 25.98it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.14it/s]
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.77it/s]


Epoch: 019, Average Loss: 0.2169
Validation: {'recall': 0.3829304054558744, 'precision': 0.4218510707855134, 'f1': 0.38985040157520084, 'auc': 0.7029814919131617, 'prauc': 0.41746421244631526}
Test:       {'recall': 0.3852276089691994, 'precision': 0.4395277805829206, 'f1': 0.3940052568062642, 'auc': 0.7187584874432917, 'prauc': 0.4298804869028679}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.30it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 90.05it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 83.60it/s]


Epoch: 020, Average Loss: 0.2105
Validation: {'recall': 0.37756705991773104, 'precision': 0.42119789269435554, 'f1': 0.3919411394397875, 'auc': 0.7070526879816909, 'prauc': 0.4207865565561912}
Test:       {'recall': 0.371764181856147, 'precision': 0.41639846275465375, 'f1': 0.38615461152576797, 'auc': 0.711645534190279, 'prauc': 0.42598205940442424}


Training Batches: 100%|██████████| 380/380 [00:14<00:00, 26.23it/s]
Running inference: 100%|██████████| 281/281 [00:03<00:00, 89.76it/s] 
Running inference: 100%|██████████| 289/289 [00:03<00:00, 82.99it/s]

Epoch: 021, Average Loss: 0.2028
Validation: {'recall': 0.41350864312206675, 'precision': 0.40596395497488463, 'f1': 0.3989199207475267, 'auc': 0.7051007618477592, 'prauc': 0.42010103593606724}
Test:       {'recall': 0.41805043022548916, 'precision': 0.4122756890009231, 'f1': 0.404057096257003, 'auc': 0.7125744001495973, 'prauc': 0.4245997025549339}

Early stopping triggered after 21 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.4075740451667787, 'precision': 0.41014831225735177, 'f1': 0.4007362639743379, 'auc': 0.7087210958100117, 'prauc': 0.4218310683695448}
Corresponding test performance:
{'recall': 0.4183709938285431, 'precision': 0.398851903930544, 'f1': 0.4025342504565379, 'auc': 0.7168860779228332, 'prauc': 0.42983381638633567}





In [17]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")


Final Metrics:
recall: 0.3992 ± 0.0198
precision: 0.4404 ± 0.0228
f1: 0.3982 ± 0.0065
auc: 0.7210 ± 0.0048
prauc: 0.4340 ± 0.0044


In [18]:
# Assuming final_metrics and args are defined
# Create the results directory if it doesn't exist
os.makedirs(f"./results/{args['dataset']}", exist_ok=True)

# Define the file name
file_name = f"{args['dataset']}-{args['task']}.txt"
file_path = os.path.join(f"./results/{args['dataset']}", file_name)

# Save the metrics to a text file
with open(file_path, 'w') as f:
    f.write("Final Metrics:\n")
    for key in final_metrics.keys():
        mean_value = np.mean(final_metrics[key])
        std_value = np.std(final_metrics[key])
        f.write(f"{key}: {mean_value:.4f} ± {std_value:.4f}\n")