In [1]:
from set_seed_utils import set_random_seed
import os
import random
import numpy as np
import pickle
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTFinetuneEHRDataset, batcher, UniqueIDSampler
from HEART_rep import HBERT_Finetune
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, precision_recall_fscore_support

ModuleNotFoundError: No module named 'typing_extensions'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    # gt_labels is ground truth labels, predicted_scores is the output logits
    for _, batch in enumerate(tqdm(dataloader, desc="Running inference")):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        # standard binary classification evaluation
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()      
    #   The threshold should be 0 because:
	# 	Your model outputs logits, not probabilities. logit > 0 ≡ probability > 0.5
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "auc":roc_auc, "prauc":pr_auc}
    else:
        # —— Multi-label classification evaluation (per-class over the batch) —— #
        # predicted_scores: list[Tensor] -> [B, C] logits
        logits = torch.cat(predicted_scores, dim=0)           # [B, C]
        y_true_t = torch.cat(gt_labels, dim=0)                # [B, C]
        # 1) 连续分数用于 AUC/PR-AUC：sigmoid(logits)
        #    CPU + float16 无 sigmoid 实现，先升到 fp32
        logits_for_sigmoid = logits.float() if (logits.device.type == "cpu" and logits.dtype == torch.float16) else logits
        prob_t = torch.sigmoid(logits_for_sigmoid)            # [B, C] in [0,1]

        # 2) 阈值化用于 P/R/F1：logits > 0 等价于 prob >= 0.5
        y_pred_t = (logits > 0).to(torch.int32)               # [B, C]

        # 转 numpy
        y_true = y_true_t.cpu().numpy().astype(np.int32)      # [B, C]
        y_pred = y_pred_t.cpu().numpy().astype(np.int32)      # [B, C]
        scores = prob_t.cpu().numpy()                         # [B, C] float

        # 3) 按类计算 Precision/Recall/F1
        p_cls, r_cls, f1_cls, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # 4) 按类计算 AUC/PR-AUC（遇到单一类别则设为 NaN）
        C = y_true.shape[1]
        aucs, praucs = [], []
        for c in range(C):
            yt, ys = y_true[:, c], scores[:, c]
            if yt.max() == yt.min():           # 全 0 或全 1，曲线不定义
                aucs.append(np.nan)
                praucs.append(np.nan)
                continue
            aucs.append(roc_auc_score(yt, ys))
            prec_curve, rec_curve, _ = precision_recall_curve(yt, ys)
            praucs.append(auc(rec_curve, prec_curve))

        # 5) 取宏平均（忽略 NaN）
        ave_precision = float(np.mean(p_cls))
        ave_recall    = float(np.mean(r_cls))
        ave_f1        = float(np.mean(f1_cls))
        ave_auc       = float(np.nanmean(aucs))   if np.any(~np.isnan(aucs))   else np.nan
        ave_prauc     = float(np.nanmean(praucs)) if np.any(~np.isnan(praucs)) else np.nan

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc), np.mean(ave_recall), np.mean(ave_precision)
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [4]:
args = {
    "seed": 0,
    "dataset": "MIMIC-III", 
    "task": "next_diag_6m",  # options: death, stay, readmission, next_diag_6m, next_diag_12m
    "encoder": "hi_edge",  # options: hi_edge, hi_node, hi_edge_node
    "batch_size": 4,
    "eval_batch_size": 4,
    "pretrain_mask_rate": 0.7,
    "pretrain_anomaly_rate": 0.05,
    "pretrain_anomaly_loss_weight": 1,
    "pretrain_pos_weight": 1,
    "lr": 1e-4,
    "epochs": 50,
    "num_hidden_layers": 5,
    "num_attention_heads": 6,
    "attention_probs_dropout_prob": 0.2,
    "hidden_dropout_prob": 0.2,
    "edge_hidden_size": 32,
    "hidden_size": 288,  # must be divisible by num_attention_heads
    "intermediate_size": 288,
    "save_model": True,
    "gat": "dotattn",
    "gnn_n_heads": 1,
    "gnn_temp": 1,
    "diag_med_emb": "tree",  # simple, tree
    "early_stop_patience": 5,
}

In [5]:
exp_name = "Pretrain-HBERT" \
    + "-" + str(args["dataset"]) \
    + "-" + str(args["encoder"]) \
    + "-" + str(args["pretrain_mask_rate"]) \
    + "-" + str(args["pretrain_anomaly_rate"]) \
    + "-" + str(args["pretrain_anomaly_loss_weight"]) \
    + "-" + str(args["hidden_size"]) \
    + "-" + str(args["edge_hidden_size"]) \
    + "-" + str(args["num_hidden_layers"]) \
    + "-" + str(args["num_attention_heads"]) \
    + "-" + str(args["attention_probs_dropout_prob"]) \
    + "-" + str(args["hidden_dropout_prob"]) \
    + "-" + str(args["intermediate_size"]) \
    + "-" + str(args["gat"]) \
    + "-" + str(args["gnn_n_heads"]) \
    + "-" + str(args["gnn_temp"]) \
    + "-" + str(args["diag_med_emb"])
print(exp_name)

Pretrain-HBERT-MIMIC-IV-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [6]:
pretrained_weight_path = "./pretrained_models/" + exp_name + f"/pretrained_model.pt"
finetune_exp_name = f"Finetune-{args['task']}-" + exp_name
save_path = "./saved_model/" + finetune_exp_name
if args["save_model"] and not os.path.exists(save_path):
    os.makedirs(save_path)

In [7]:
args["predicted_token_type"] = ["diag", "med", "pro", "lab"]
args["special_tokens"] = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
args["max_visit_size"] = 15

full_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic.pkl"

if args["task"] == "next_diag_6m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_6m.pkl"
elif args["task"] == "next_diag_12m":
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_nextdiag_12m.pkl"
else:
    finetune_data_path = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed/mimic_downstream.pkl"

In [8]:
ehr_data = pickle.load(open(full_data_path, 'rb'))
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["AGE"].values.tolist()) for gender in ["M", "F"]]
age_set = [[c] for c in set(ehr_data["AGE"].values.tolist())]    

In [9]:
tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args["special_tokens"])
tokenizer.build_tree()

In [10]:
train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))

train_dataset = HBERTFinetuneEHRDataset(
    train_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

val_dataset = HBERTFinetuneEHRDataset(
    val_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

test_dataset = HBERTFinetuneEHRDataset(
    test_data, tokenizer, 
    token_type=args["predicted_token_type"], 
    task=args["task"]
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_sampler=UniqueIDSampler(train_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

val_dataloader = DataLoader(
    val_dataset, 
    batch_sampler=UniqueIDSampler(val_dataset.get_ids(), batch_size=args["batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False), 
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_sampler=UniqueIDSampler(test_dataset.get_ids(), batch_size=args["eval_batch_size"]),
    collate_fn=batcher(pad_id=tokenizer.vocab.word2id["[PAD]"], is_train=False),
)

In [11]:
# examine a batch
batch = next(iter(train_dataloader))  # 取第一个 batch
input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels = batch

# 打印每个张量的形状
print("input_ids shape:", input_ids.shape)
print("input_types shape:", input_types.shape)
print("visit_positions shape:", visit_positions.shape)
print("labeled_batch_idx shape:", len(labeled_batch_idx)) # it is a list
print("labels shape:", labels.shape)

input_ids shape: torch.Size([13, 73])
input_types shape: torch.Size([13, 73])
visit_positions shape: torch.Size([13])
labeled_batch_idx shape: 4
labels shape: torch.Size([4, 18])


In [12]:
args["vocab_size"] = len(args["special_tokens"]) + \
                     len(tokenizer.diag_voc.id2word) + \
                     len(tokenizer.pro_voc.id2word) + \
                     len(tokenizer.med_voc.id2word) + \
                     len(tokenizer.lab_voc.id2word) + \
                     len(tokenizer.age_voc.id2word) + \
                     len(tokenizer.gender_voc.id2word) + \
                     len(tokenizer.age_gender_voc.id2word)
args["label_vocab_size"] = 18  # only for diagnosis

In [13]:
if args["task"] in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
def train_with_early_stopping(model, 
                              train_dataloader, 
                              val_dataloader, 
                              test_dataloader,
                              optimizer, 
                              loss_fn, 
                              device, 
                              args,
                              task_type="binary", 
                              eval_metric="f1"):
    best_score = 0.
    best_val_metric = None
    best_test_metric = None
    epochs_no_improve = 0

    for epoch in range(1, 1 + args["epochs"]):
        model.train()
        ave_loss = 0.

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training Batches")):
            batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

            labels = batch[-1].float()
            output_logits = model(*batch[:-1])
            
            loss = loss_fn(output_logits.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            ave_loss += loss.item()
            

        ave_loss /= (step + 1)

        # Evaluation
        val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
        test_metric = evaluate(model, test_dataloader, device, task_type=task_type)

        # Logging
        print(f"Epoch: {epoch:03d}, Average Loss: {ave_loss:.4f}")
        print(f"Validation: {val_metric}")
        print(f"Test:       {test_metric}")

        # Check for improvement
        current_score = val_metric[eval_metric]
        if current_score > best_score:
            best_score = current_score
            best_val_metric = val_metric
            best_test_metric = test_metric
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= args["early_stop_patience"]:
            print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {args['early_stop_patience']} epochs).")
            break

    print("\nBest validation performance:")
    print(best_val_metric)
    print("Corresponding test performance:")
    print(best_test_metric)
    return best_test_metric

In [15]:
random.seed(42)
seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]
print(seeds)

[2746317213, 1181241943, 958682846, 3163119785, 1812140441]


In [None]:
final_metrics = {"recall":[], "precision":[],"f1":[],"auc":[],"prauc":[]}

for seed in seeds:
    args["seed"] = seed
    set_random_seed(args["seed"])
    print(f"Training with seed: {args['seed']}")
    
    # Initialize model, optimizer, and loss function
    model = HBERT_Finetune(args, tokenizer)
    model.load_weight(torch.load(pretrained_weight_path, weights_only=True))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    
    best_test_metric = train_with_early_stopping(
        model, 
        train_dataloader, 
        val_dataloader, 
        test_dataloader,
        optimizer, 
        loss_fn, 
        device, 
        args,
        task_type=task_type)
    
    for key in final_metrics.keys():
        final_metrics[key].append(best_test_metric[key])

[INFO] Random seed set to 2746317213
Training with seed: 2746317213


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.95it/s]
Running inference: 100%|██████████| 1262/1262 [00:11<00:00, 114.67it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.56it/s]


Epoch: 001, Average Loss: 0.3197
Validation: {'recall': 0.3007773923330686, 'precision': 0.3647787645354272, 'f1': 0.31544565347330966, 'auc': 0.7379152556979457, 'prauc': 0.39525755178521377}
Test:       {'recall': 0.30522303924297467, 'precision': 0.3602773560554896, 'f1': 0.31732646321009594, 'auc': 0.7384882780284925, 'prauc': 0.39626206574251693}


Training Batches: 100%|██████████| 1713/1713 [00:54<00:00, 31.22it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 118.17it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 112.24it/s]


Epoch: 002, Average Loss: 0.2909
Validation: {'recall': 0.34462728008663607, 'precision': 0.42177981710830026, 'f1': 0.3640686838751352, 'auc': 0.7642106694954057, 'prauc': 0.42730383828134477}
Test:       {'recall': 0.3456495380746115, 'precision': 0.41852354687854687, 'f1': 0.36178888532328696, 'auc': 0.7620042701046204, 'prauc': 0.4261701762971364}


Training Batches: 100%|██████████| 1713/1713 [00:54<00:00, 31.19it/s]
Running inference: 100%|██████████| 1262/1262 [00:11<00:00, 113.74it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 110.46it/s]


Epoch: 003, Average Loss: 0.2815
Validation: {'recall': 0.3558329504502366, 'precision': 0.43454657183462114, 'f1': 0.3769541027718112, 'auc': 0.7647197259836899, 'prauc': 0.42844325800134797}
Test:       {'recall': 0.3574212052604057, 'precision': 0.4317759474113656, 'f1': 0.3763816507957056, 'auc': 0.7657669417546761, 'prauc': 0.43146566065169045}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 31.02it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 115.34it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 110.34it/s]


Epoch: 004, Average Loss: 0.2750
Validation: {'recall': 0.35205846633456384, 'precision': 0.426714887118372, 'f1': 0.37646728338881424, 'auc': 0.7760638242644452, 'prauc': 0.4388323267688841}
Test:       {'recall': 0.3555217877085291, 'precision': 0.42857926348916614, 'f1': 0.3772825971604029, 'auc': 0.7742477496121296, 'prauc': 0.43794055194227727}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.84it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 115.42it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.52it/s]


Epoch: 005, Average Loss: 0.2702
Validation: {'recall': 0.36041870019201994, 'precision': 0.43167200791479554, 'f1': 0.3802747570594262, 'auc': 0.7679324183564604, 'prauc': 0.43825098796504436}
Test:       {'recall': 0.3562345159209734, 'precision': 0.42559682484177014, 'f1': 0.37440983524979754, 'auc': 0.7720378632118867, 'prauc': 0.4356890849898309}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 31.04it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 116.15it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.44it/s]


Epoch: 006, Average Loss: 0.2635
Validation: {'recall': 0.35781693056505026, 'precision': 0.43150604685184696, 'f1': 0.3810295591315545, 'auc': 0.7653517136286047, 'prauc': 0.4327044978536449}
Test:       {'recall': 0.361496659025029, 'precision': 0.4315416516169055, 'f1': 0.38238085516873777, 'auc': 0.7704510990885173, 'prauc': 0.4381393685292376}


Training Batches: 100%|██████████| 1713/1713 [00:54<00:00, 31.15it/s]
Running inference: 100%|██████████| 1262/1262 [00:11<00:00, 114.69it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 110.84it/s]


Epoch: 007, Average Loss: 0.2588
Validation: {'recall': 0.35271138335247054, 'precision': 0.42019761362582364, 'f1': 0.3767298563274455, 'auc': 0.7671503288858815, 'prauc': 0.4327465074553387}
Test:       {'recall': 0.35348601095054516, 'precision': 0.41668360311590835, 'f1': 0.37440129232861313, 'auc': 0.7682753853961732, 'prauc': 0.4302674802184633}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.91it/s]
Running inference: 100%|██████████| 1262/1262 [00:11<00:00, 114.48it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 112.78it/s]


Epoch: 008, Average Loss: 0.2533
Validation: {'recall': 0.35805268184817346, 'precision': 0.43135345637115013, 'f1': 0.385709356906013, 'auc': 0.7699442689857897, 'prauc': 0.439661821177309}
Test:       {'recall': 0.35847844716614446, 'precision': 0.45787639873732516, 'f1': 0.3836960949581462, 'auc': 0.7727912168690695, 'prauc': 0.43846818297941736}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.97it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 115.90it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.91it/s]


Epoch: 009, Average Loss: 0.2481
Validation: {'recall': 0.38699107437366975, 'precision': 0.40816316738166614, 'f1': 0.3913316115379248, 'auc': 0.7681913714568831, 'prauc': 0.43365787321627497}
Test:       {'recall': 0.3882809267732091, 'precision': 0.4685259184119012, 'f1': 0.39247000519484604, 'auc': 0.7725057718138274, 'prauc': 0.4333600073937445}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.86it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 115.78it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.55it/s]


Epoch: 010, Average Loss: 0.2423
Validation: {'recall': 0.38109675661980513, 'precision': 0.4790821379827234, 'f1': 0.39446188063947557, 'auc': 0.7657170625623232, 'prauc': 0.4343072929419598}
Test:       {'recall': 0.3835344790605882, 'precision': 0.48146741575583857, 'f1': 0.3930246352338243, 'auc': 0.7654567320226501, 'prauc': 0.4326165998060909}


Training Batches: 100%|██████████| 1713/1713 [00:54<00:00, 31.29it/s]
Running inference: 100%|██████████| 1262/1262 [00:11<00:00, 112.75it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.74it/s]


Epoch: 011, Average Loss: 0.2383
Validation: {'recall': 0.35912179539858424, 'precision': 0.4612548052993682, 'f1': 0.3809270807490437, 'auc': 0.7651646314656546, 'prauc': 0.43329227497228456}
Test:       {'recall': 0.36001610787494376, 'precision': 0.4738797946967091, 'f1': 0.3784891328011383, 'auc': 0.7653308650373867, 'prauc': 0.4317126943620484}


Training Batches: 100%|██████████| 1713/1713 [00:55<00:00, 30.85it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 115.46it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.76it/s]


Epoch: 012, Average Loss: 0.2325
Validation: {'recall': 0.37515792557443, 'precision': 0.5442195441965838, 'f1': 0.3953206498139783, 'auc': 0.7702987274128172, 'prauc': 0.43503351640208243}
Test:       {'recall': 0.3731121880115903, 'precision': 0.47901539618953554, 'f1': 0.3893433894590357, 'auc': 0.7673773658225783, 'prauc': 0.431668716969454}


Training Batches: 100%|██████████| 1713/1713 [00:54<00:00, 31.23it/s]
Running inference: 100%|██████████| 1262/1262 [00:10<00:00, 116.50it/s]
Running inference: 100%|██████████| 1266/1266 [00:11<00:00, 111.44it/s]


Epoch: 013, Average Loss: 0.2267
Validation: {'recall': 0.3686892741168608, 'precision': 0.46958043760074347, 'f1': 0.39034121281230993, 'auc': 0.7667904508434823, 'prauc': 0.4307884639652753}
Test:       {'recall': 0.3692808442925225, 'precision': 0.4660526714797797, 'f1': 0.38839417240864205, 'auc': 0.7662438603683294, 'prauc': 0.4291117166106179}


Training Batches:  21%|██        | 359/1713 [00:11<00:38, 35.02it/s]

In [None]:
# print the mean and std of the final metrics
print("\nFinal Metrics:")
for key in final_metrics.keys():
    mean_value = np.mean(final_metrics[key])
    std_value = np.std(final_metrics[key])
    print(f"{key}: {mean_value:.4f} ± {std_value:.4f}")

In [None]:
# Assuming final_metrics and args are defined
# Create the results directory if it doesn't exist
os.makedirs(f"./results/{args['dataset']}", exist_ok=True)

# Define the file name
file_name = f"{args['dataset']}-{args['task']}.txt"
file_path = os.path.join(f"./results/{args['dataset']}", file_name)

# Save the metrics to a text file
with open(file_path, 'w') as f:
    f.write("Final Metrics:\n")
    for key in final_metrics.keys():
        mean_value = np.mean(final_metrics[key])
        std_value = np.std(final_metrics[key])
        f.write(f"{key}: {mean_value:.4f} ± {std_value:.4f}\n")