In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the random seed once, before any data loading or model construction.
# Which RNGs get seeded is decided inside heterogt.utils.seed (not visible here).
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration. Later cells attach derived values
# (vocabulary sizes, max admissions) to this same namespace.
config_values = {
    "dataset": "MIMIC-III",
    "tasks": ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    # Index into `tasks` selecting the task to train.
    "task_index": 2,
    "token_type": ["diag", "med", "lab", "pro"],
    "special_tokens": ["[PAD]", "[CLS]"],
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 500,
    "early_stop_patience": 5,
    # A group code is generated once at least this many diagnosis codes
    # fall under the same ICD group code.
    "group_code_thre": 5,
}
config = Namespace(**config_values)

In [5]:
# Resolve the pickle paths for the selected dataset and task.
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"  # full dataset, used to build the tokenizer

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The two next-diagnosis tasks have dedicated files; every other task
# shares the generic downstream file.
task_specific_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_file = task_specific_files.get(curr_task, "mimic_downstream.pkl")
finetune_data_path = f"{processed_dir}/{finetune_file}"

Current task: stay


In [6]:
# Load the full dataset and collect the per-type "sentences" the tokenizer is
# built from, then record the largest number of admissions any patient has.
#
# NOTE(security): pickle.load can execute arbitrary code from the file; only
# load pickles produced by this project's own preprocessing.
# The context manager closes the file handle deterministically; the previous
# `pickle.load(open(...))` left it open until garbage collection.
with open(full_data_path, 'rb') as full_data_file:
    ehr_full_data = pickle.load(full_data_file)  # presumably a pandas DataFrame (column indexing + groupby below)

# Second element of expand_level3()'s return value, wrapped as a single sentence.
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# One single-token sentence per distinct age value; the nested [[...]] shape is required.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
# Largest count of distinct HADM_ID values observed for any SUBJECT_ID.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from every sentence collection, then store the
# vocabulary sizes on the config for later cells.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# The label vocabulary is taken from the diagnosis vocabulary only.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pre-split (train, val, test) fine-tuning data for the current task.
#
# NOTE(security): pickle.load can execute arbitrary code from the file; only
# load pickles produced by this project's own preprocessing.
# The context manager closes the file handle deterministically; the previous
# `pickle.load(open(...))` left it open until garbage collection.
with open(finetune_data_path, 'rb') as finetune_data_file:
    train_data, val_data, test_data = pickle.load(finetune_data_file)

# Sanity check: positive-label rates in the test split.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# All three splits are wrapped with the same options; only the data differs.
dataset_options = dict(
    token_type=config.token_type,
    task=curr_task,
    max_num_adms=config.max_num_adms,
    group_code_thre=config.group_code_thre,
)
train_dataset = FineTuneEHRDataset(train_data, tokenizer, **dataset_options)
val_dataset = FineTuneEHRDataset(val_data, tokenizer, **dataset_options)
test_dataset = FineTuneEHRDataset(test_data, tokenizer, **dataset_options)

In [10]:
# Count, per training sample, how many entries in the first row of
# token_types equal 6, then report the mean.
# NOTE(review): 6 is presumably the token-type id of group-code tokens -
# confirm against the dataset/tokenizer implementation.
num_group_code = []
for sample_idx in range(len(train_dataset)):
    # Each sample is a 6-tuple; only the token-type tensor is needed here.
    _, token_types, _, _, _, _ = train_dataset[sample_idx]
    group_mask = token_types[0] == 6
    num_group_code.append(group_mask.sum().item())
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def _build_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the shared batch size and a fresh collate function."""
    collate = batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=False)
    return DataLoader(
        dataset,
        collate_fn=collate,
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled; validation and test keep dataset order.
train_dataloader = _build_loader(train_dataset, shuffle=True)
val_dataloader = _build_loader(val_dataset, shuffle=False)
test_dataloader = _build_loader(test_dataset, shuffle=False)

In [12]:
# Smoke test: walk every loader once so any collate or dataset error
# surfaces here rather than in the middle of training.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass  # nothing to do; iterating is the check
print("All pass!")

All pass!


In [13]:
# Choose the task type and loss for the current task. Both task families
# are evaluated with F1 and trained with BCE-with-logits.
eval_metric = "f1"
if curr_task not in ("death", "stay", "readmission"):
    # Next-diagnosis tasks.
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)
else:
    # Single-outcome binary tasks.
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits

In [14]:
# Pull one training batch and report the shape of each component.
first_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = first_batch

shape_report = (
    ("Input IDs", input_ids),
    ("Token Types", token_types),
    ("Admission Index", adm_index),
    ("Age IDs", age_ids),
)
for title, tensor in shape_report:
    print(f"{title} shape:", tensor.shape)
# The group dicts arrive as a plain sequence, so report its length instead of a shape.
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [None]:
# Two identical but independent mask dictionaries.
# NOTE(review): presumably one dict per 'tf' layer, mapping a token-type id to
# the token-type ids it is masked against / allowed to attend to - confirm the
# exact semantics in HeteroGT.
# Every dict and every list is built fresh, so no objects are shared.
attn_mask_dicts = [
    {
        **{type_id: [6, 7] for type_id in (1, 2, 3, 4)},
        **{type_id: [2, 3, 4, 5, 6, 7] for type_id in (6, 7)},
    }
    for _ in range(2)
]

In [None]:
# Train 10 freshly initialised models and keep what each run returns.
# The seed is set only once at the top of the notebook, so every run starts
# from a different RNG state.
# NOTE(review): the return value is presumably the test-metric dict at the
# best-validation epoch (it is later indexed by metric name) - confirm in
# heterogt.utils.train.
final_metrics = []
for run_idx in range(10):
    # Rebuilt on every iteration so each model receives fresh argument objects.
    model_options = dict(
        d_model=128,
        num_heads=4,
        layer_types=['gnn', 'tf', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        attn_mask_dicts=attn_mask_dicts,
        use_cls_cat=True,
    )
    model = HeteroGT(tokenizer, **model_options).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )
    final_metrics.append(best_test_metric)

Epoch 001: 100%|██████████| 98/98 [00:05<00:00, 19.35it/s, loss=0.7380]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.19it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.23it/s]


Validation: {'precision': 0.7077526132373356, 'recall': 0.5095641266838835, 'f1': 0.5925250634986483, 'auc': 0.7216153452838574, 'prauc': 0.7355503104856734}
Test:      {'precision': 0.6863340563961547, 'recall': 0.49608027594701765, 'f1': 0.5759009780177781, 'auc': 0.7103284414170501, 'prauc': 0.7292152355217797}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 20.29it/s, loss=0.6012]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 28.30it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.68it/s]


Validation: {'precision': 0.6904697986558646, 'recall': 0.8065224208190451, 'f1': 0.7439976808826851, 'auc': 0.805076204660184, 'prauc': 0.8064865102942715}
Test:      {'precision': 0.6915688367110684, 'recall': 0.8127939793013083, 'f1': 0.7472970975241866, 'auc': 0.8045596142840532, 'prauc': 0.8079819288475077}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 20.29it/s, loss=0.5501]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.05it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.69it/s]


Validation: {'precision': 0.719796954312285, 'recall': 0.6669802445886893, 'f1': 0.6923828075049989, 'auc': 0.7874017029662503, 'prauc': 0.8016448235050165}
Test:      {'precision': 0.7283316204154563, 'recall': 0.6666666666645762, 'f1': 0.6961362098077528, 'auc': 0.782187666572543, 'prauc': 0.7998410695232843}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 20.01it/s, loss=0.5252]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.12it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.50it/s]


Validation: {'precision': 0.7357609710527988, 'recall': 0.741298212603508, 'f1': 0.7385192077437802, 'auc': 0.8128826472045396, 'prauc': 0.8204254042109329}
Test:      {'precision': 0.733690513707111, 'recall': 0.7300094073354343, 'f1': 0.7318453266544428, 'auc': 0.8106458446898372, 'prauc': 0.8216496869119558}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 19.95it/s, loss=0.5040]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.25it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.48it/s]


Validation: {'precision': 0.6911452810161953, 'recall': 0.8174976481630056, 'f1': 0.7490303067697397, 'auc': 0.8110522704558856, 'prauc': 0.8199483307430184}
Test:      {'precision': 0.6817944705250867, 'recall': 0.8196926936317978, 'f1': 0.7444112153162893, 'auc': 0.8066211752175793, 'prauc': 0.8198271327749029}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 20.06it/s, loss=0.5030]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 29.16it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.73it/s]


Validation: {'precision': 0.7532765399713196, 'recall': 0.7209156475361527, 'f1': 0.7367409019059932, 'auc': 0.8161711964660643, 'prauc': 0.8266092177337715}
Test:      {'precision': 0.7406564835855033, 'recall': 0.7146440890538895, 'f1': 0.727417805404638, 'auc': 0.8074593795787789, 'prauc': 0.8193239787055301}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 19.72it/s, loss=0.4751]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 29.78it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.57it/s]


Validation: {'precision': 0.6794871794854017, 'recall': 0.8143618689218741, 'f1': 0.7408358244495709, 'auc': 0.8037845168061593, 'prauc': 0.813844350130712}
Test:      {'precision': 0.6770670826815469, 'recall': 0.8165569143906662, 'f1': 0.7402985025041907, 'auc': 0.8038186543128037, 'prauc': 0.8171942567894872}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 19.96it/s, loss=0.4707]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.45it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.17it/s]


Validation: {'precision': 0.7588436667905711, 'recall': 0.6255879586057523, 'f1': 0.6858026763777575, 'auc': 0.7877571783573158, 'prauc': 0.7930408223073032}
Test:      {'precision': 0.7593144560329385, 'recall': 0.6390718093426182, 'f1': 0.694023492395409, 'auc': 0.7916005913183711, 'prauc': 0.7962041183655018}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 20.27it/s, loss=0.4602]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.26it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 29.70it/s]


Validation: {'precision': 0.8105263157857642, 'recall': 0.5553465036044047, 'f1': 0.6590993624962193, 'auc': 0.7974828664980922, 'prauc': 0.8079373475027378}
Test:      {'precision': 0.8042991491231335, 'recall': 0.5631859517072336, 'f1': 0.6624861626188776, 'auc': 0.7940303420616717, 'prauc': 0.8059432214726796}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 19.89it/s, loss=0.4453]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.38it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 28.84it/s]


Validation: {'precision': 0.7310847766614066, 'recall': 0.7544684854162607, 'f1': 0.7425925875915396, 'auc': 0.8160846385097603, 'prauc': 0.8290326045532767}
Test:      {'precision': 0.7289864029643728, 'recall': 0.7397303229829423, 'f1': 0.7343190611458417, 'auc': 0.8076457646595163, 'prauc': 0.818046859238786}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6911452810161953, 'recall': 0.8174976481630056, 'f1': 0.7490303067697397, 'auc': 0.8110522704558856, 'prauc': 0.8199483307430184}
Corresponding test performance:
{'precision': 0.6817944705250867, 'recall': 0.8196926936317978, 'f1': 0.7444112153162893, 'auc': 0.8066211752175793, 'prauc': 0.8198271327749029}


Epoch 001: 100%|██████████| 98/98 [00:04<00:00, 20.27it/s, loss=0.7485]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.16it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.13it/s]


Validation: {'precision': 0.76514834934908, 'recall': 0.5741611790511942, 'f1': 0.6560372577292068, 'auc': 0.7748765205289048, 'prauc': 0.7749982414850667}
Test:      {'precision': 0.747944078944293, 'recall': 0.5703982439618364, 'f1': 0.6472157929891116, 'auc': 0.7644877782626575, 'prauc': 0.7587338176538708}


Epoch 002: 100%|██████████| 98/98 [00:04<00:00, 20.22it/s, loss=0.5965]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 29.79it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 29.92it/s]


Validation: {'precision': 0.5854276251693251, 'recall': 0.942301661960043, 'f1': 0.7221821630192413, 'auc': 0.791170666468131, 'prauc': 0.7963961573699322}
Test:      {'precision': 0.5824006175210683, 'recall': 0.9463781749735141, 'f1': 0.7210608004423706, 'auc': 0.7849814797757337, 'prauc': 0.7932156341094508}


Epoch 003: 100%|██████████| 98/98 [00:04<00:00, 20.31it/s, loss=0.5397]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.32it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.11it/s]


Validation: {'precision': 0.8218048556993962, 'recall': 0.5625587958590074, 'f1': 0.6679076645697313, 'auc': 0.816120356853888, 'prauc': 0.8280217400554739}
Test:      {'precision': 0.8137122237220852, 'recall': 0.565694575100139, 'f1': 0.6674065804347912, 'auc': 0.8146573267758803, 'prauc': 0.8277219427680614}


Epoch 004: 100%|██████████| 98/98 [00:04<00:00, 20.09it/s, loss=0.5255]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.43it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 28.57it/s]


Validation: {'precision': 0.7786005183236632, 'recall': 0.6594543744099736, 'f1': 0.7140916758468385, 'auc': 0.8165358149494383, 'prauc': 0.8302396079049714}
Test:      {'precision': 0.7677208287867404, 'recall': 0.662276575726992, 'f1': 0.7111111061359029, 'auc': 0.8123657919201589, 'prauc': 0.827755936376541}


Epoch 005: 100%|██████████| 98/98 [00:04<00:00, 19.87it/s, loss=0.5076]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 29.83it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.17it/s]


Validation: {'precision': 0.7750472589762758, 'recall': 0.6428347444319761, 'f1': 0.7027768205467217, 'auc': 0.8101328873111864, 'prauc': 0.8209326004718152}
Test:      {'precision': 0.767732451301846, 'recall': 0.6550642834723893, 'f1': 0.7069373892760002, 'auc': 0.8099404201944183, 'prauc': 0.8226191182604862}


Epoch 006: 100%|██████████| 98/98 [00:04<00:00, 20.08it/s, loss=0.4837]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.29it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.32it/s]


Validation: {'precision': 0.768987341769448, 'recall': 0.6857949200354789, 'f1': 0.7250124266400042, 'auc': 0.8183729234253962, 'prauc': 0.830218475953222}
Test:      {'precision': 0.7640138408278062, 'recall': 0.6923800564418552, 'f1': 0.7264352639684165, 'auc': 0.8137512728092706, 'prauc': 0.8257502808014939}


Epoch 007: 100%|██████████| 98/98 [00:04<00:00, 19.73it/s, loss=0.4428]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.97it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.53it/s]


Validation: {'precision': 0.7898443291297634, 'recall': 0.668234556285142, 'f1': 0.7239680602607913, 'auc': 0.8247935344070138, 'prauc': 0.8339251357118654}
Test:      {'precision': 0.7864678064722209, 'recall': 0.6779554719326499, 'f1': 0.7281913052304573, 'auc': 0.8248308415716747, 'prauc': 0.83777202290887}


Epoch 008: 100%|██████████| 98/98 [00:04<00:00, 19.75it/s, loss=0.4596]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.92it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 29.27it/s]


Validation: {'precision': 0.7308399754729281, 'recall': 0.7475697710857712, 'f1': 0.7391102104688184, 'auc': 0.8171173055728956, 'prauc': 0.8292420860355559}
Test:      {'precision': 0.7252848783470117, 'recall': 0.7384760112864895, 'f1': 0.7318210018346764, 'auc': 0.814665178807367, 'prauc': 0.8300846581324819}


Epoch 009: 100%|██████████| 98/98 [00:04<00:00, 19.96it/s, loss=0.4242]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 31.09it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.39it/s]


Validation: {'precision': 0.7645400070470055, 'recall': 0.6801505174014421, 'f1': 0.7198805127710606, 'auc': 0.8174183243042785, 'prauc': 0.8283192228013387}
Test:      {'precision': 0.7692852087729325, 'recall': 0.6817184070220078, 'f1': 0.7228595128877933, 'auc': 0.8200410520312853, 'prauc': 0.8322301488692259}


Epoch 010: 100%|██████████| 98/98 [00:04<00:00, 19.86it/s, loss=0.4062]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.56it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 31.04it/s]


Validation: {'precision': 0.7112056737568476, 'recall': 0.7861398557516898, 'f1': 0.7467977310841732, 'auc': 0.8200933821192663, 'prauc': 0.8267716335549709}
Test:      {'precision': 0.7113980909581376, 'recall': 0.794606459702745, 'f1': 0.7507035944797852, 'auc': 0.8198879877508309, 'prauc': 0.8270606267631656}


Epoch 011: 100%|██████████| 98/98 [00:04<00:00, 20.00it/s, loss=0.3901]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.63it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.62it/s]


Validation: {'precision': 0.6822311963637501, 'recall': 0.847601128877869, 'f1': 0.7559781799255595, 'auc': 0.8212537008173222, 'prauc': 0.8283912146241508}
Test:      {'precision': 0.6725334655412183, 'recall': 0.8507369081190005, 'f1': 0.7512114030685654, 'auc': 0.8149012430873182, 'prauc': 0.8223898638697256}


Epoch 012: 100%|██████████| 98/98 [00:04<00:00, 20.29it/s, loss=0.3641]
Running inference: 100%|██████████| 198/198 [00:06<00:00, 30.51it/s]
Running inference: 100%|██████████| 197/197 [00:06<00:00, 30.41it/s]


Validation: {'precision': 0.7324212535770525, 'recall': 0.7218563813084922, 'f1': 0.7271004371963249, 'auc': 0.8110979356806565, 'prauc': 0.8188666735499972}
Test:      {'precision': 0.7289482046370227, 'recall': 0.7193477579155869, 'f1': 0.7241161566140957, 'auc': 0.8081168362151798, 'prauc': 0.816049999066244}


Epoch 013: 100%|██████████| 98/98 [00:05<00:00, 18.73it/s, loss=0.3244]
Running inference:  34%|███▍      | 68/198 [00:02<00:04, 31.60it/s]

In [None]:
def topk_avg_performance_formatted(performances, k=5) -> str:
    """Summarise the k best runs as "mean ± std" per metric.

    Runs are ranked separately on "f1", "auc" and "prauc" (higher is better,
    rank 1 is best). The three ranks are averaged per run and the k runs with
    the lowest average rank are selected. Mean and population standard
    deviation (ddof=0) are then reported for every metric key found in the
    first run's dict, not only the three ranking metrics.

    Args:
        performances: Non-empty sequence of dicts mapping metric name to a
            numeric score. Each dict must contain "f1", "auc" and "prauc",
            plus every key present in ``performances[0]``.
        k: Number of top runs to aggregate. Values larger than the number of
            runs use all runs; ``None`` also uses all runs.

    Returns:
        A multi-line string: a "Final Metrics:" header followed by one
        ``"<metric>: <mean> ± <std>"`` line per metric, formatted to 4 decimals.

    Raises:
        ValueError: If ``performances`` is empty or ``k`` is smaller than 1.
        KeyError: If a run is missing one of the required metric keys.
    """
    if len(performances) == 0:
        raise ValueError("performances must contain at least one run")
    # A non-positive k would slice from the wrong end ([:-1] drops the
    # last-ranked run, [:0] selects nothing), so reject it explicitly.
    if k is not None and k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    rank_metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in rank_metrics}

    # Rank the runs per metric: negating makes the largest score sort first,
    # and the double argsort turns the sort order into 0-based ranks (+1 for
    # 1-based). A stable sort breaks ties by original run order, so the
    # selection is deterministic.
    ranks = {
        m: (-scores[m]).argsort(kind="stable").argsort(kind="stable") + 1
        for m in rank_metrics
    }
    avg_ranks = np.mean(np.stack([ranks[m] for m in rank_metrics], axis=1), axis=1)

    # Select the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]

    # Build the report over every metric recorded for the first run.
    lines = ["Final Metrics:"]
    for m in performances[0].keys():
        selected = [performances[i][m] for i in topk_idx]
        lines.append(f"{m}: {np.mean(selected):.4f} ± {np.std(selected, ddof=0):.4f}")

    return "\n".join(lines)

In [None]:
# The helper returns a multi-line string. Printing it renders the report line
# by line; a bare expression would show the escaped repr ('Final Metrics:\n...').
print(topk_avg_performance_formatted(final_metrics, 5))


Final Metrics:
precision: 0.7005 ± 0.0151
recall: 0.8074 ± 0.0305
f1: 0.7495 ± 0.0070
auc: 0.8158 ± 0.0022
prauc: 0.8257 ± 0.0029
