In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    # gt_labels is ground truth labels, predicted_scores is the output logits
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        # standard binary classification evaluation
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()      
    #   The threshold should be 0 because:
	# 	Your model outputs logits, not probabilities. logit > 0 ≡ probability > 0.5
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "auc":roc_auc, "prauc":pr_auc}
    else:
        # —— Multi-label classification evaluation (per-class over the batch) —— #
        # predicted_scores: list[Tensor] -> [B, C] logits
        logits = torch.cat(predicted_scores, dim=0)           # [B, C]
        y_true_t = torch.cat(gt_labels, dim=0)                # [B, C]
        # 1) 连续分数用于 AUC/PR-AUC：sigmoid(logits)
        #    CPU + float16 无 sigmoid 实现，先升到 fp32
        logits_for_sigmoid = logits.float() if (logits.device.type == "cpu" and logits.dtype == torch.float16) else logits
        prob_t = torch.sigmoid(logits_for_sigmoid)            # [B, C] in [0,1]

        # 2) 阈值化用于 P/R/F1：logits > 0 等价于 prob >= 0.5
        y_pred_t = (logits > 0).to(torch.int32)               # [B, C]

        # 转 numpy
        y_true = y_true_t.cpu().numpy().astype(np.int32)      # [B, C]
        y_pred = y_pred_t.cpu().numpy().astype(np.int32)      # [B, C]
        scores = prob_t.cpu().numpy()                         # [B, C] float

        # 3) 按类计算 Precision/Recall/F1
        p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # 4) 按类计算 AUC/PR-AUC（遇到单一类别则设为 NaN）
        C = y_true.shape[1]
        aucs, praucs = [], []
        for c in range(C):
            yt, ys = y_true[:, c], scores[:, c]
            if yt.max() == yt.min():           # 全 0 或全 1，曲线不定义
                aucs.append(np.nan)
                praucs.append(np.nan)
                continue
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

        # 5) 取宏平均（忽略 NaN）
        ave_precision = float(np.mean(p_cls))
        ave_recall    = float(np.mean(r_cls))
        ave_f1        = float(np.mean(f1_cls))
        ave_auc       = float(np.nanmean(aucs))   if np.any(~np.isnan(aucs))   else np.nan
        ave_prauc     = float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else np.nan

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc), np.mean(ave_recall), np.mean(ave_precision)
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [4]:
args = {
    "seed": 0,
    "dataset": "MIMIC-III", 
    "task": "next_diag_12m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 50,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [5]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-III-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [6]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [7]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [8]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [9]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [10]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([8, 126])
input_types shape: torch.Size([8, 126])
visit_positions shape: torch.Size([8])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [12]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [13]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
        test_metric = evaluate(model, test_dataloader, device, task_type=task_type)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric
            best_test_metric = test_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    return best_test_metric

In [15]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [None]:
final_metrics = {"recall":[], "precision":[],"f1":[],"auc":[],"prauc":[]}

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        task_type=task_type)
    
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 471/471 [00:19<00:00, 24.45it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.45it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.51it/s]


Epoch: 001, Average Loss: 0.4276
Validation: {'recall': 0.26989317814628666, 'precision': 0.345813720440676, 'f1': 0.2898586923862811, 'auc': 0.705322820317638, 'prauc': 0.39683285800696155}
Test:       {'recall': 0.26546533317474275, 'precision': 0.34322389730509084, 'f1': 0.2849973093474003, 'auc': 0.7116808286391488, 'prauc': 0.3933291722614243}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.35it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.03it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.13it/s]


Epoch: 002, Average Loss: 0.3851
Validation: {'recall': 0.295006828829122, 'precision': 0.3887971433838162, 'f1': 0.3218708476714616, 'auc': 0.7231822063322164, 'prauc': 0.4265793498118323}
Test:       {'recall': 0.2897344723354182, 'precision': 0.39935456617571086, 'f1': 0.3126356422912713, 'auc': 0.7329780522491193, 'prauc': 0.4206781130657211}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.52it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 85.58it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.84it/s] 


Epoch: 003, Average Loss: 0.3663
Validation: {'recall': 0.35357143040648664, 'precision': 0.4649480366782543, 'f1': 0.37590821283414527, 'auc': 0.7415779071526492, 'prauc': 0.44987198374183124}
Test:       {'recall': 0.34787246566482016, 'precision': 0.4158977908359775, 'f1': 0.36670690350842494, 'auc': 0.7455871167163122, 'prauc': 0.4339038144091327}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.82it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.75it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.45it/s]


Epoch: 004, Average Loss: 0.3533
Validation: {'recall': 0.3312768267931646, 'precision': 0.42773350113403485, 'f1': 0.35714579766371646, 'auc': 0.74491524504738, 'prauc': 0.4503387560473915}
Test:       {'recall': 0.316599921265119, 'precision': 0.4214203807143826, 'f1': 0.3423422811534969, 'auc': 0.7488074453788721, 'prauc': 0.43304028181864723}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.04it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 85.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.43it/s]


Epoch: 005, Average Loss: 0.3451
Validation: {'recall': 0.3485143712796533, 'precision': 0.44415458096314847, 'f1': 0.3764183994648496, 'auc': 0.7449907865104535, 'prauc': 0.4522949602159444}
Test:       {'recall': 0.33704349970057396, 'precision': 0.43272890704459677, 'f1': 0.364857919765654, 'auc': 0.7491872446532997, 'prauc': 0.43165166152638934}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.44it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.10it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.95it/s]


Epoch: 006, Average Loss: 0.3352
Validation: {'recall': 0.3260259773196733, 'precision': 0.44690467054085026, 'f1': 0.3611258615255917, 'auc': 0.7438173133687122, 'prauc': 0.4457981202705501}
Test:       {'recall': 0.313074094846698, 'precision': 0.42651135051293115, 'f1': 0.3456291885468587, 'auc': 0.7453079191286038, 'prauc': 0.4319309937823921}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.24it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.67it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.80it/s] 


Epoch: 007, Average Loss: 0.3255
Validation: {'recall': 0.3563386684462747, 'precision': 0.43610293673861594, 'f1': 0.381569620473295, 'auc': 0.7385098737553538, 'prauc': 0.44348497417510857}
Test:       {'recall': 0.349442273591669, 'precision': 0.4167147322559863, 'f1': 0.37220094281680144, 'auc': 0.7480472005170269, 'prauc': 0.43067110254489477}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.45it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.28it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.33it/s]


Epoch: 008, Average Loss: 0.3141
Validation: {'recall': 0.3734369816029107, 'precision': 0.44723756023349875, 'f1': 0.39932827717620933, 'auc': 0.7371345371536449, 'prauc': 0.4448868164296111}
Test:       {'recall': 0.3557500386658983, 'precision': 0.42934574680314674, 'f1': 0.38220033483418586, 'auc': 0.737427058226523, 'prauc': 0.4287723139941431}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.44it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.71it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.80it/s]


Epoch: 009, Average Loss: 0.3088
Validation: {'recall': 0.3685778791863314, 'precision': 0.527173345161227, 'f1': 0.3929520158987889, 'auc': 0.7401006650755617, 'prauc': 0.4408852890434408}
Test:       {'recall': 0.3490252904697349, 'precision': 0.4410185778420589, 'f1': 0.3711953221210251, 'auc': 0.7415496121885412, 'prauc': 0.4269985919178576}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.26it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.23it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.47it/s]


Epoch: 010, Average Loss: 0.3007
Validation: {'recall': 0.3946200605125117, 'precision': 0.42079710575369345, 'f1': 0.39539350131199624, 'auc': 0.7365590178064578, 'prauc': 0.4433860614778082}
Test:       {'recall': 0.3763817151680511, 'precision': 0.4105250482625704, 'f1': 0.3813619216497778, 'auc': 0.7362263491545362, 'prauc': 0.42280484124693823}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.15it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.39it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.52it/s]


Epoch: 011, Average Loss: 0.2938
Validation: {'recall': 0.40273849589826066, 'precision': 0.4132000693737272, 'f1': 0.3983273728734938, 'auc': 0.7323464185884477, 'prauc': 0.436709946590424}
Test:       {'recall': 0.3804848527764018, 'precision': 0.39323512244783265, 'f1': 0.3798897801024387, 'auc': 0.7293700276089334, 'prauc': 0.41851505454654114}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.15it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.11it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.72it/s]


Epoch: 012, Average Loss: 0.2861
Validation: {'recall': 0.4070611038947186, 'precision': 0.40978383483321196, 'f1': 0.40033567812137627, 'auc': 0.7346578418290518, 'prauc': 0.43433149386166303}
Test:       {'recall': 0.38929442685026333, 'precision': 0.3901678277636542, 'f1': 0.3822952386018924, 'auc': 0.7306435929168871, 'prauc': 0.41633508517808826}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.37it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.55it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.79it/s]


Epoch: 013, Average Loss: 0.2760
Validation: {'recall': 0.40291458394488155, 'precision': 0.4678735825717611, 'f1': 0.40231571763312446, 'auc': 0.7310373863063524, 'prauc': 0.4305651650352422}
Test:       {'recall': 0.3821339676674911, 'precision': 0.387840065782902, 'f1': 0.3807315242428201, 'auc': 0.7239568885393158, 'prauc': 0.41086708648338544}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.06it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.31it/s]


Epoch: 014, Average Loss: 0.2669
Validation: {'recall': 0.40508123571834026, 'precision': 0.4250632776085535, 'f1': 0.4048361702414693, 'auc': 0.7302483734667103, 'prauc': 0.43254169327923686}
Test:       {'recall': 0.38388452840056647, 'precision': 0.41249156265928, 'f1': 0.3863972451711193, 'auc': 0.730742963806787, 'prauc': 0.4159418717640798}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.40it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.82it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.88it/s]


Epoch: 015, Average Loss: 0.2606
Validation: {'recall': 0.395506510947395, 'precision': 0.4098909399446538, 'f1': 0.395000327184614, 'auc': 0.7314973585048672, 'prauc': 0.4272636158149047}
Test:       {'recall': 0.37384934478844967, 'precision': 0.3841906244672092, 'f1': 0.3724273654027763, 'auc': 0.7240542977641568, 'prauc': 0.41119789780750315}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.46it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.29it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.79it/s] 


Epoch: 016, Average Loss: 0.2543
Validation: {'recall': 0.37724053896322013, 'precision': 0.43037253170851586, 'f1': 0.39578592174242616, 'auc': 0.7311374879015299, 'prauc': 0.43301445802048033}
Test:       {'recall': 0.36338801831843026, 'precision': 0.4155879830396942, 'f1': 0.3815256650676576, 'auc': 0.7248548315975573, 'prauc': 0.4120706847488652}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.36it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.15it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.42it/s]


Epoch: 017, Average Loss: 0.2458
Validation: {'recall': 0.3640250275019994, 'precision': 0.4277131728130288, 'f1': 0.3825673187587539, 'auc': 0.7288851612214646, 'prauc': 0.4227312314192038}
Test:       {'recall': 0.3419590964865076, 'precision': 0.4171152580072081, 'f1': 0.3620663170487137, 'auc': 0.720882219045587, 'prauc': 0.40673541281236303}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.50it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.13it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.85it/s]


Epoch: 018, Average Loss: 0.2374
Validation: {'recall': 0.36603159019430875, 'precision': 0.43639270025024435, 'f1': 0.3875014828406862, 'auc': 0.7272295247862264, 'prauc': 0.42474027039091516}
Test:       {'recall': 0.35018767127988953, 'precision': 0.4299338301557603, 'f1': 0.3696202019769285, 'auc': 0.7226767314008683, 'prauc': 0.40658552861176195}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.40it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.28it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.41it/s]


Epoch: 019, Average Loss: 0.2373
Validation: {'recall': 0.39737080150625637, 'precision': 0.3992120521940873, 'f1': 0.39110043922672305, 'auc': 0.7261919498088825, 'prauc': 0.42027657352520553}
Test:       {'recall': 0.387522500992038, 'precision': 0.41698756262271175, 'f1': 0.3861241295814361, 'auc': 0.7151976713335165, 'prauc': 0.4032273382450024}

Early stopping triggered after 19 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.40508123571834026, 'precision': 0.4250632776085535, 'f1': 0.4048361702414693, 'auc': 0.7302483734667103, 'prauc': 0.43254169327923686}
Corresponding test performance:
{'recall': 0.38388452840056647, 'precision': 0.41249156265928, 'f1': 0.3863972451711193, 'auc': 0.730742963806787, 'prauc': 0.4159418717640798}
[INFO] Random seed set to 1181241943
Training with seed: 1181241943


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.23it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.76it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.58it/s]


Epoch: 001, Average Loss: 0.4320
Validation: {'recall': 0.24866094657629276, 'precision': 0.40934800341225114, 'f1': 0.2767772773262691, 'auc': 0.7035824236041311, 'prauc': 0.40154788152641063}
Test:       {'recall': 0.24161474740721284, 'precision': 0.3624335041060932, 'f1': 0.2718201475461422, 'auc': 0.7130302779231562, 'prauc': 0.40000172976058074}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.29it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.20it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.91it/s]


Epoch: 002, Average Loss: 0.3842
Validation: {'recall': 0.2993109160827122, 'precision': 0.3763912317713021, 'f1': 0.31980339894166154, 'auc': 0.7161030505965569, 'prauc': 0.41863034898379076}
Test:       {'recall': 0.29438581335564584, 'precision': 0.37822017532031044, 'f1': 0.31472607876159064, 'auc': 0.7274969530704007, 'prauc': 0.4209552588161632}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.58it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.12it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.27it/s]


Epoch: 003, Average Loss: 0.3666
Validation: {'recall': 0.3119754873809347, 'precision': 0.43174867750510004, 'f1': 0.3444826973792375, 'auc': 0.7398064922198958, 'prauc': 0.4433012669675602}
Test:       {'recall': 0.30587088659768075, 'precision': 0.42276689071233975, 'f1': 0.33494365594179554, 'auc': 0.7442215285373383, 'prauc': 0.4340655891877348}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.41it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.07it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.70it/s]


Epoch: 004, Average Loss: 0.3523
Validation: {'recall': 0.3758353016081317, 'precision': 0.4527309150049263, 'f1': 0.39149621465229356, 'auc': 0.7417142002208201, 'prauc': 0.45327044883265466}
Test:       {'recall': 0.36946608041394075, 'precision': 0.42514585051387965, 'f1': 0.38011304912697524, 'auc': 0.7494041616351886, 'prauc': 0.4419667990025218}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.17it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.77it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.30it/s]


Epoch: 005, Average Loss: 0.3414
Validation: {'recall': 0.3842814943033359, 'precision': 0.42333774439246796, 'f1': 0.38891737385989467, 'auc': 0.7402323861125661, 'prauc': 0.44514722853272093}
Test:       {'recall': 0.3800024638654566, 'precision': 0.4156057083958547, 'f1': 0.38385994662275835, 'auc': 0.7506950604499373, 'prauc': 0.43793594629387234}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.32it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.25it/s] 
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.48it/s] 


Epoch: 006, Average Loss: 0.3327
Validation: {'recall': 0.3752809565535212, 'precision': 0.45264554889000386, 'f1': 0.3916554134621053, 'auc': 0.7378692447570486, 'prauc': 0.4515303964487835}
Test:       {'recall': 0.35649296163160676, 'precision': 0.41821138231910937, 'f1': 0.37281039154505113, 'auc': 0.7379876615614829, 'prauc': 0.43184381204105754}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.47it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.16it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.24it/s]


Epoch: 007, Average Loss: 0.3225
Validation: {'recall': 0.36304119317247996, 'precision': 0.41876736103264306, 'f1': 0.372197015208359, 'auc': 0.7385521116902205, 'prauc': 0.44685572236307336}
Test:       {'recall': 0.3425250001239763, 'precision': 0.44158404380311134, 'f1': 0.35023795604502, 'auc': 0.7418013113386004, 'prauc': 0.4244464301669161}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.76it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.07it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.65it/s]


Epoch: 008, Average Loss: 0.3166
Validation: {'recall': 0.36911258881300574, 'precision': 0.48956812852656273, 'f1': 0.3921298226077229, 'auc': 0.7351780640730006, 'prauc': 0.448372045734215}
Test:       {'recall': 0.35744655148745974, 'precision': 0.4136219545995768, 'f1': 0.3780248795134595, 'auc': 0.730714663039511, 'prauc': 0.42579884494861664}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.24it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.82it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.55it/s] 


Epoch: 009, Average Loss: 0.3055
Validation: {'recall': 0.3877753425720482, 'precision': 0.41395469280937724, 'f1': 0.3966525831387539, 'auc': 0.7311527889576874, 'prauc': 0.4466622934179613}
Test:       {'recall': 0.36953859596763644, 'precision': 0.39759673735699286, 'f1': 0.3797031505230422, 'auc': 0.7295882800413237, 'prauc': 0.43036334593477066}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.10it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.57it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.52it/s] 


Epoch: 010, Average Loss: 0.2968
Validation: {'recall': 0.35832473102776935, 'precision': 0.4269533670640283, 'f1': 0.3836481912005332, 'auc': 0.7318217353899421, 'prauc': 0.44438310186553975}
Test:       {'recall': 0.3389841644938663, 'precision': 0.41390818214131253, 'f1': 0.3654926153851451, 'auc': 0.7288857717737196, 'prauc': 0.42151931657873914}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.38it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.59it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.06it/s]


Epoch: 011, Average Loss: 0.2894
Validation: {'recall': 0.3863652559555185, 'precision': 0.4349001992848749, 'f1': 0.4010874476121916, 'auc': 0.7293122269680686, 'prauc': 0.43909209313126735}
Test:       {'recall': 0.3750631651196199, 'precision': 0.4418839037308082, 'f1': 0.38921594994469044, 'auc': 0.7295439213834074, 'prauc': 0.42419836357604657}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.44it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 85.85it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.89it/s]


Epoch: 012, Average Loss: 0.2783
Validation: {'recall': 0.3816962312789038, 'precision': 0.43555426074479087, 'f1': 0.3964095694855627, 'auc': 0.7339386826360776, 'prauc': 0.44056548748752145}
Test:       {'recall': 0.36684754644531575, 'precision': 0.44611100631345024, 'f1': 0.383859085407913, 'auc': 0.7289991231142293, 'prauc': 0.426620406742947}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.00it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.31it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.68it/s]


Epoch: 013, Average Loss: 0.2733
Validation: {'recall': 0.39331299465366915, 'precision': 0.4431244064459695, 'f1': 0.3979351639156715, 'auc': 0.729098694101316, 'prauc': 0.43567151065359294}
Test:       {'recall': 0.380345084609381, 'precision': 0.41465801185408807, 'f1': 0.38405601346352525, 'auc': 0.723334910941626, 'prauc': 0.41692980182037304}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.32it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.19it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.46it/s] 


Epoch: 014, Average Loss: 0.2630
Validation: {'recall': 0.3924446757323166, 'precision': 0.47914025961915246, 'f1': 0.4014379092105337, 'auc': 0.7291139711813062, 'prauc': 0.438111609590786}
Test:       {'recall': 0.37588113610267027, 'precision': 0.45454660494095855, 'f1': 0.38992172040044265, 'auc': 0.7228209959989189, 'prauc': 0.42381398524628444}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.38it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.17it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.65it/s]


Epoch: 015, Average Loss: 0.2548
Validation: {'recall': 0.38500848200408755, 'precision': 0.44363559390150675, 'f1': 0.39859815465807535, 'auc': 0.7194183283190262, 'prauc': 0.4322516778261686}
Test:       {'recall': 0.3636628356948714, 'precision': 0.44488330112737784, 'f1': 0.38469542120694467, 'auc': 0.7151661499545546, 'prauc': 0.4133321361646243}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.02it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.73it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.55it/s]


Epoch: 016, Average Loss: 0.2451
Validation: {'recall': 0.3801589705170132, 'precision': 0.4353105199907048, 'f1': 0.3895850522504884, 'auc': 0.7231948245338337, 'prauc': 0.43089484167281084}
Test:       {'recall': 0.3643619824615361, 'precision': 0.4637596541452229, 'f1': 0.37391866837433646, 'auc': 0.7162477322595195, 'prauc': 0.41407803295202056}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.76it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.17it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.28it/s]


Epoch: 017, Average Loss: 0.2396
Validation: {'recall': 0.3989882083498406, 'precision': 0.443485080578828, 'f1': 0.40119351466887065, 'auc': 0.7211643573322459, 'prauc': 0.42993283591891673}
Test:       {'recall': 0.38223438839470325, 'precision': 0.42030152074369814, 'f1': 0.383381332971132, 'auc': 0.7116813629851424, 'prauc': 0.4110959301184354}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.22it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.24it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.09it/s]


Epoch: 018, Average Loss: 0.2298
Validation: {'recall': 0.38583473617803904, 'precision': 0.43077148045360025, 'f1': 0.39842117290748313, 'auc': 0.7184234300989976, 'prauc': 0.4245981376728222}
Test:       {'recall': 0.37189511601181263, 'precision': 0.4305240509031529, 'f1': 0.38500280699591205, 'auc': 0.7113570769919627, 'prauc': 0.41178390996552594}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.15it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.06it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.18it/s]


Epoch: 019, Average Loss: 0.2232
Validation: {'recall': 0.39483725275056236, 'precision': 0.423464559160488, 'f1': 0.39657661131800953, 'auc': 0.7090104268595292, 'prauc': 0.4175503462879995}
Test:       {'recall': 0.38599261415008557, 'precision': 0.4208429117314004, 'f1': 0.3880072146085593, 'auc': 0.7068861242384262, 'prauc': 0.4050343962849271}

Early stopping triggered after 19 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.3924446757323166, 'precision': 0.47914025961915246, 'f1': 0.4014379092105337, 'auc': 0.7291139711813062, 'prauc': 0.438111609590786}
Corresponding test performance:
{'recall': 0.37588113610267027, 'precision': 0.45454660494095855, 'f1': 0.38992172040044265, 'auc': 0.7228209959989189, 'prauc': 0.42381398524628444}
[INFO] Random seed set to 958682846
Training with seed: 958682846


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.69it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.84it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.19it/s] 


Epoch: 001, Average Loss: 0.4256
Validation: {'recall': 0.30960399497043156, 'precision': 0.4159952195955146, 'f1': 0.32370786933455614, 'auc': 0.7096929860079244, 'prauc': 0.422917472921039}
Test:       {'recall': 0.30177695245919833, 'precision': 0.39987835908899866, 'f1': 0.31516109505832135, 'auc': 0.7153007455754741, 'prauc': 0.4114563544086954}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.56it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.53it/s] 
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.66it/s]


Epoch: 002, Average Loss: 0.3802
Validation: {'recall': 0.34030208913548, 'precision': 0.4191187132539382, 'f1': 0.3624735606827698, 'auc': 0.7330202087513523, 'prauc': 0.43967679389807685}
Test:       {'recall': 0.33862777258673993, 'precision': 0.4231283095435426, 'f1': 0.36354658160218145, 'auc': 0.7401613408912194, 'prauc': 0.4372668350468156}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.47it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.88it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.55it/s]


Epoch: 003, Average Loss: 0.3622
Validation: {'recall': 0.32580311091906405, 'precision': 0.42895903151931924, 'f1': 0.3535910239336179, 'auc': 0.7447213625853673, 'prauc': 0.44797893575867065}
Test:       {'recall': 0.31323156652828804, 'precision': 0.4087036155947624, 'f1': 0.3410205256476452, 'auc': 0.7441282941645265, 'prauc': 0.43723095712446663}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.48it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.81it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.75it/s]


Epoch: 004, Average Loss: 0.3500
Validation: {'recall': 0.39314613201995596, 'precision': 0.4120710444871665, 'f1': 0.3926505765703537, 'auc': 0.7444932655135622, 'prauc': 0.4466446530382952}
Test:       {'recall': 0.37895908333657935, 'precision': 0.4007386458335647, 'f1': 0.3804176612441808, 'auc': 0.7365953778013731, 'prauc': 0.43209659792710187}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.02it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.29it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.36it/s]


Epoch: 005, Average Loss: 0.3404
Validation: {'recall': 0.3344718593128543, 'precision': 0.46219526962173024, 'f1': 0.3652498020186432, 'auc': 0.7387537814058693, 'prauc': 0.44032696713670627}
Test:       {'recall': 0.3202392078487214, 'precision': 0.4317658767835935, 'f1': 0.34802883620941993, 'auc': 0.7356384711864034, 'prauc': 0.4315086179057614}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.39it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.13it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.63it/s] 


Epoch: 006, Average Loss: 0.3339
Validation: {'recall': 0.3772579584991303, 'precision': 0.4064269937468929, 'f1': 0.3840755730046011, 'auc': 0.7399138325695036, 'prauc': 0.4434717766299724}
Test:       {'recall': 0.3684187193899544, 'precision': 0.39354644561059293, 'f1': 0.37471367542693396, 'auc': 0.7413797125557191, 'prauc': 0.4291271021325422}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.41it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.50it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.38it/s]


Epoch: 007, Average Loss: 0.3246
Validation: {'recall': 0.3704351188997152, 'precision': 0.41611023252536694, 'f1': 0.37978312676411385, 'auc': 0.742977406678344, 'prauc': 0.44592336314124786}
Test:       {'recall': 0.36121559664336916, 'precision': 0.403661183786596, 'f1': 0.3706577900299459, 'auc': 0.7307841232716855, 'prauc': 0.4272036475884257}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.49it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.59it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.26it/s]


Epoch: 008, Average Loss: 0.3155
Validation: {'recall': 0.3634891845119672, 'precision': 0.4223749777619212, 'f1': 0.37807465717405964, 'auc': 0.7412080918550407, 'prauc': 0.4476768434792537}
Test:       {'recall': 0.349695659689636, 'precision': 0.4241366262012635, 'f1': 0.36654831815836736, 'auc': 0.7278788803052441, 'prauc': 0.42469978288750787}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.24it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.51it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.04it/s]


Epoch: 009, Average Loss: 0.3080
Validation: {'recall': 0.3667602651994251, 'precision': 0.44591462533473514, 'f1': 0.3918491832884586, 'auc': 0.7371252304576452, 'prauc': 0.44700132146009786}
Test:       {'recall': 0.34859150662964283, 'precision': 0.4292604330761935, 'f1': 0.37336923966998126, 'auc': 0.7251953261219032, 'prauc': 0.42495281715686867}

Early stopping triggered after 9 epochs (no improvement for 5 epochs).

Best validation performance:
{'recall': 0.39314613201995596, 'precision': 0.4120710444871665, 'f1': 0.3926505765703537, 'auc': 0.7444932655135622, 'prauc': 0.4466446530382952}
Corresponding test performance:
{'recall': 0.37895908333657935, 'precision': 0.4007386458335647, 'f1': 0.3804176612441808, 'auc': 0.7365953778013731, 'prauc': 0.43209659792710187}
[INFO] Random seed set to 3163119785
Training with seed: 3163119785


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.55it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.34it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.70it/s]


Epoch: 001, Average Loss: 0.4293
Validation: {'recall': 0.26305164974441314, 'precision': 0.3751089563218069, 'f1': 0.28345588856249176, 'auc': 0.7088769301509484, 'prauc': 0.4134457617972622}
Test:       {'recall': 0.25559197789115007, 'precision': 0.3577857901236528, 'f1': 0.2775054817162571, 'auc': 0.7119259353241429, 'prauc': 0.4071122360912201}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.51it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.59it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.63it/s]


Epoch: 002, Average Loss: 0.3848
Validation: {'recall': 0.32816414999271226, 'precision': 0.4132292292563778, 'f1': 0.3339707390926645, 'auc': 0.7255167526743427, 'prauc': 0.42676979826445416}
Test:       {'recall': 0.3296897597220673, 'precision': 0.36760980974679947, 'f1': 0.3340632128127313, 'auc': 0.7352165782730601, 'prauc': 0.41475824259879196}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.19it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.70it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.03it/s]


Epoch: 003, Average Loss: 0.3681
Validation: {'recall': 0.33799786253783143, 'precision': 0.4172028829422867, 'f1': 0.346996611117995, 'auc': 0.7319341403781418, 'prauc': 0.4329973727753873}
Test:       {'recall': 0.33127047471088866, 'precision': 0.40243609848060097, 'f1': 0.34110149780842824, 'auc': 0.7385741315596283, 'prauc': 0.42570238840242985}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.16it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.68it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.39it/s]


Epoch: 004, Average Loss: 0.3561
Validation: {'recall': 0.3817743271933691, 'precision': 0.41638045164307674, 'f1': 0.36870957527474646, 'auc': 0.7350615098049885, 'prauc': 0.4421172700399689}
Test:       {'recall': 0.37598229167188973, 'precision': 0.3914147886956456, 'f1': 0.363594842388359, 'auc': 0.7386492881014379, 'prauc': 0.43038635191733937}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.44it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.64it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.55it/s]


Epoch: 005, Average Loss: 0.3453
Validation: {'recall': 0.3461652577550773, 'precision': 0.4293283418612565, 'f1': 0.3702861182463607, 'auc': 0.741539814475812, 'prauc': 0.44709834346436556}
Test:       {'recall': 0.32921525723117323, 'precision': 0.4163371711180298, 'f1': 0.3564994052370701, 'auc': 0.7415662784542129, 'prauc': 0.4331092956279315}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.23it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.99it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.53it/s]


Epoch: 006, Average Loss: 0.3359
Validation: {'recall': 0.3578365616628143, 'precision': 0.4228422595143012, 'f1': 0.37609051827284623, 'auc': 0.7364779144422475, 'prauc': 0.44597168815751737}
Test:       {'recall': 0.34801800946138456, 'precision': 0.4116744507318587, 'f1': 0.3676943317387302, 'auc': 0.7438351836444967, 'prauc': 0.4325676920609046}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.28it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.57it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.11it/s]


Epoch: 007, Average Loss: 0.3274
Validation: {'recall': 0.39136093446537157, 'precision': 0.4051723667732108, 'f1': 0.39572962552332913, 'auc': 0.738540324325724, 'prauc': 0.43762057794608}
Test:       {'recall': 0.3882786363479206, 'precision': 0.4035881457563498, 'f1': 0.3937678856938388, 'auc': 0.7428193139717819, 'prauc': 0.43249307086600314}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.19it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 89.24it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.12it/s]


Epoch: 008, Average Loss: 0.3178
Validation: {'recall': 0.33978000009469017, 'precision': 0.4325339541075495, 'f1': 0.373898794020152, 'auc': 0.7378617366031411, 'prauc': 0.4450693431151643}
Test:       {'recall': 0.3237685043371379, 'precision': 0.41996535275316654, 'f1': 0.3599216303723891, 'auc': 0.7422347714344952, 'prauc': 0.4299659373562543}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.51it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 88.05it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.04it/s] 


Epoch: 009, Average Loss: 0.3109
Validation: {'recall': 0.379342235728508, 'precision': 0.4208459545726988, 'f1': 0.38931670345442015, 'auc': 0.7359490666717399, 'prauc': 0.4391305519205812}
Test:       {'recall': 0.36638723993140654, 'precision': 0.41601700075591447, 'f1': 0.37959182371034544, 'auc': 0.7410183803206469, 'prauc': 0.42928044044711605}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.52it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 86.78it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 88.59it/s] 


Epoch: 010, Average Loss: 0.3015
Validation: {'recall': 0.3791522526871503, 'precision': 0.40444593267160045, 'f1': 0.38708810643887614, 'auc': 0.734629427279343, 'prauc': 0.4376544129405095}
Test:       {'recall': 0.3728873036055943, 'precision': 0.4142398437038094, 'f1': 0.38287006127755124, 'auc': 0.7406600657112353, 'prauc': 0.4282757220269783}


Training Batches: 100%|██████████| 471/471 [00:18<00:00, 25.24it/s]
Running inference: 100%|██████████| 353/353 [00:04<00:00, 87.87it/s]
Running inference: 100%|██████████| 353/353 [00:03<00:00, 90.07it/s]


Epoch: 011, Average Loss: 0.2951
Validation: {'recall': 0.3500001514531627, 'precision': 0.4223372998676191, 'f1': 0.36757331657596953, 'auc': 0.7278943706043954, 'prauc': 0.4302380748739163}
Test:       {'recall': 0.3382528628201607, 'precision': 0.3929062584765063, 'f1': 0.35594749491859906, 'auc': 0.7343770650502126, 'prauc': 0.42165025960277014}


Training Batches:  48%|████▊     | 225/471 [00:08<00:09, 24.87it/s]

In [None]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")

In [None]:
# Assuming final_metrics and args are defined
# Create the results directory if it doesn't exist
os.makedirs(f"./results/{args['dataset']}", exist_ok=True)

# Define the file name
file_name = f"{args['dataset']}-{args['task']}.txt"
file_path = os.path.join(f"./results/{args['dataset']}", file_name)

# Save the metrics to a text file
with open(file_path, 'w') as f:
    f.write("Final Metrics:\n")
    for key in final_metrics.keys():
        mean_value = np.mean(final_metrics[key])
        std_value = np.std(final_metrics[key])
        f.write(f"{key}: {mean_value:.4f} ± {std_value:.4f}\n")