In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTFineTune

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the random seed once, up front, so reruns of the notebook are comparable.
# NOTE(review): which RNGs set_random_seed covers (python / numpy / torch / CUDA)
# is not visible from this notebook — confirm in heterogt.utils.seed.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Experiment configuration: every tunable value read by the cells below lives here.
config = Namespace(
    # --- data / task selection ---
    dataset="MIMIC-III",
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,  # position in `tasks` of the task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # Two equal-valued but independent mask dicts: keys 1-4 map to [6, 7],
    # keys 6 and 7 map to [2, 3, 4, 5, 6, 7].
    # NOTE(review): the integers are presumably token-type ids — confirm against the model code.
    attn_mask_dicts=[
        {
            **{key: [6, 7] for key in (1, 2, 3, 4)},
            **{key: [2, 3, 4, 5, 6, 7] for key in (6, 7)},
        }
        for _ in range(2)
    ],
    # --- model size ---
    d_model=64,
    num_heads=4,
    # --- optimisation ---
    batch_size=32,
    lr=1e-3,
    epochs=500,
    early_stop_patience=5,
    # A group code is generated once this many diagnosis codes of a patient
    # belong to the same ICD group.
    group_code_thre=5,
    # --- pre-training settings (used to locate the pre-trained checkpoint) ---
    use_pretrained_model=True,
    pretrain_mask_rate=0.7,
    pretrain_cls_ontology_weight=5e-2,
    pretrain_visit_ontology_weight=5e-2,
    pretrain_adm_type_weight=5e-2,
)

In [5]:
# Full processed cohort; the tokenizer vocabulary is built from this file.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"

curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)

# The two next-diagnosis tasks have dedicated files; every other task shares
# the generic downstream file (the lookup's default).
finetune_data_path = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}.get(curr_task, f"./data_process/{config.dataset}-processed/mimic_downstream.pkl")

Current task: stay


In [6]:
# Load the full processed cohort. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
# NOTE: pickle can execute arbitrary code while loading — only open files
# produced by this project's own preprocessing.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# Corpora handed to EHRTokenizer further down; one list per vocabulary source.
group_code_sentences = [expand_level3()[1]]  # second item returned by expand_level3(), wrapped as one sentence
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# One single-token sentence per distinct age value; the nested [[...]] shape is required.
age_sentences = [[str(age)] for age in set(ehr_full_data["AGE"].values.tolist())]
adm_type_sentences = ehr_full_data["ADMISSION_TYPE"].values.tolist()

# Largest number of distinct admissions (HADM_ID) recorded for any single patient (SUBJECT_ID).
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from the corpora collected from the full dataset.
# The first six arguments are positional, so their order must not change.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
    adm_types_sentences=adm_type_sentences,
)

# Store the vocabulary sizes on the config for later cells.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # diagnosis vocabulary only
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pre-split fine-tuning data (train / validation / test). Using `with`
# guarantees the file handle is closed as soon as loading finishes.
# NOTE: pickle can execute arbitrary code while loading — only use trusted files.
with open(finetune_data_path, "rb") as finetune_data_file:
    train_data, val_data, test_data = pickle.load(finetune_data_file)

# Sanity check: share of positive examples per binary label in the test split.
# The element-wise `== True` comparison is kept exactly as written so the
# result does not depend on the column's dtype.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

# A stay counts as positive when it is strictly longer than 7 days.
print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# Wrap each split in a FineTuneEHRDataset. All three use identical settings,
# so they are built in one pass (train, then validation, then test).
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(
        split_frame,
        tokenizer,
        token_type=config.token_type,
        task=curr_task,
        max_num_adms=config.max_num_adms,
        group_code_thre=config.group_code_thre,
    )
    for split_frame in (train_data, val_data, test_data)
)

In [10]:
# For every training sample, count the entries equal to 6 in the first row of
# its token-type tensor, then report the average over the training set.
# NOTE(review): 6 is presumably the group-code token-type id — confirm against the dataset code.
num_group_code = [
    (sample_token_types[0] == 6).sum().item()
    # Each sample is unpacked into exactly six fields; only the second is used.
    for _, sample_token_types, _, _, _, _ in (
        train_dataset[sample_idx] for sample_idx in range(len(train_dataset))
    )
]
print("Mean group token numer per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
# One DataLoader per split. Only the training split is shuffled; each loader
# receives its own collate function from batcher(...).
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(
        split_dataset,
        collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=False),
        shuffle=reshuffle,
        batch_size=config.batch_size,
    )
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [12]:
# Smoke test: exhaust every loader once so that any collation error surfaces
# here rather than in the middle of training.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass  # nothing to do; iterating is the check
print("All pass!")

All pass!


In [13]:
# F1 is the evaluation metric regardless of the task.
eval_metric = "f1"

if curr_task in ("death", "stay", "readmission"):
    # Binary tasks use BCE-with-logits directly.
    task_type, loss_fn = "binary", F.binary_cross_entropy_with_logits
else:
    # All remaining tasks are tagged "l2r"; the loss is the same BCE-with-logits
    # call, wrapped so that it accepts exactly two arguments.
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

In [14]:
# Pull one collated training batch and report the size of each component.
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = next(iter(train_dataloader))

for caption, tensor in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age IDs shape:", age_ids),
):
    print(caption, tensor.shape)
# The group-code dicts are reported by count rather than by shape.
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [None]:
# Load the pre-trained weights (onto the CPU) when the config asks for them.
if config.use_pretrained_model:
    # The experiment name is the pre-training hyper-parameters joined by "-";
    # it doubles as the checkpoint folder name.
    pretrain_exp_name = (
        f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}"
        f"-{config.pretrain_cls_ontology_weight}"
        f"-{config.pretrain_visit_ontology_weight}"
        f"-{config.pretrain_adm_type_weight}"
    )
    print(pretrain_exp_name)

    save_path = "./pretrained_models/" + pretrain_exp_name
    checkpoint_path = f"{save_path}/pretrained_model.pt"
    state_dict = torch.load(checkpoint_path, map_location="cpu")

MIMIC-III-0.7-64-0.05-0.05


In [None]:
# Fine-tune 15 times, each time from a freshly constructed model, and collect
# the value returned by train_with_early_stopping for every run.
final_metrics = []
for run_idx in range(15):
    model = HeteroGTFineTune(
        tokenizer=tokenizer,
        token_types=config.token_type,
        d_model=config.d_model,
        num_heads=config.num_heads,
        layer_types=['gnn', 'tf', 'gnn', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        attn_mask_dicts=config.attn_mask_dicts,
        use_cls_cat=True,
    ).to(device)

    # Start from the pre-trained weights when enabled; the same state_dict is reused for every run.
    if config.use_pretrained_model:
        model.load_weight(state_dict)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )
    final_metrics.append(best_test_metric)



Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 14.40it/s, loss=0.5842]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.80it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.77it/s]


Validation: {'precision': 0.6517064048605617, 'recall': 0.8742552524274875, 'f1': 0.7467523722303702, 'auc': 0.8079216153653785, 'prauc': 0.8117240811937516}
Test:      {'precision': 0.652869238004109, 'recall': 0.8704923173381296, 'f1': 0.7461362671064863, 'auc': 0.8041327859058048, 'prauc': 0.8079571539953866}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.04it/s, loss=0.5131]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.86it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.64it/s]


Validation: {'precision': 0.7278812572737838, 'recall': 0.7842583882070108, 'f1': 0.7550188629291992, 'auc': 0.8304403484060727, 'prauc': 0.8381760565503973}
Test:      {'precision': 0.7281385281364268, 'recall': 0.7911571025375003, 'f1': 0.7583408426167831, 'auc': 0.8332821441280647, 'prauc': 0.8436798149769993}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.88it/s, loss=0.4606]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 22.00it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.72it/s]


Validation: {'precision': 0.7973068745542264, 'recall': 0.7055503292546079, 'f1': 0.7486275112364106, 'auc': 0.8444473648753943, 'prauc': 0.8517508851704716}
Test:      {'precision': 0.7850335333540945, 'recall': 0.6973973032276658, 'f1': 0.738625036529468, 'auc': 0.8356130901418249, 'prauc': 0.8463979957381762}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 14.91it/s, loss=0.4245]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.80it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.54it/s]


Validation: {'precision': 0.7724628312838641, 'recall': 0.7494512386304502, 'f1': 0.7607830604133323, 'auc': 0.8450142868352835, 'prauc': 0.8515017568206431}
Test:      {'precision': 0.7573780677205425, 'recall': 0.7645029789878818, 'f1': 0.7609238401912428, 'auc': 0.8425061671163968, 'prauc': 0.8521701190920243}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.85it/s, loss=0.3809]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.22it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.17it/s]


Validation: {'precision': 0.7629513343775104, 'recall': 0.7619943555949765, 'f1': 0.7624725397105058, 'auc': 0.8426719471541322, 'prauc': 0.84886745296349}
Test:      {'precision': 0.7547692307669084, 'recall': 0.7692066478495792, 'f1': 0.7619195477236606, 'auc': 0.8345799439989088, 'prauc': 0.8403733440360198}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.70it/s, loss=0.3399]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.22it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.66it/s]


Validation: {'precision': 0.7619194094101448, 'recall': 0.776732518028295, 'f1': 0.7692546533831677, 'auc': 0.8413482082402571, 'prauc': 0.8426929795993303}
Test:      {'precision': 0.7488556606629574, 'recall': 0.7695202257736924, 'f1': 0.7590473194650182, 'auc': 0.8375492702392403, 'prauc': 0.8392576472704316}


Epoch 007: 100%|██████████| 98/98 [00:05<00:00, 16.56it/s, loss=0.2944]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.08it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.05it/s]


Validation: {'precision': 0.78566308243446, 'recall': 0.6873628096560447, 'f1': 0.733232977123845, 'auc': 0.8280707803950801, 'prauc': 0.8351305182634023}
Test:      {'precision': 0.7756320224691867, 'recall': 0.6926936343659684, 'f1': 0.7318204356297284, 'auc': 0.8286232721126545, 'prauc': 0.8348180749069938}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 15.41it/s, loss=0.2567]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.73it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.72it/s]


Validation: {'precision': 0.7387496224683215, 'recall': 0.767011602380787, 'f1': 0.7526153796148304, 'auc': 0.8324921688409643, 'prauc': 0.8348974918120147}
Test:      {'precision': 0.7355568790925088, 'recall': 0.7745374725595029, 'f1': 0.7545440609854471, 'auc': 0.830112691751894, 'prauc': 0.8330550674916124}


Epoch 009: 100%|██████████| 98/98 [00:06<00:00, 16.32it/s, loss=0.2282]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.04it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.51it/s]


Validation: {'precision': 0.7236514522801559, 'recall': 0.8203198494800241, 'f1': 0.7689594306434423, 'auc': 0.8317328400484322, 'prauc': 0.8324386081781852}
Test:      {'precision': 0.7135802469116226, 'recall': 0.8156161806183267, 'f1': 0.7611940248707798, 'auc': 0.8251824213148429, 'prauc': 0.8242978633719427}


Epoch 010: 100%|██████████| 98/98 [00:06<00:00, 15.32it/s, loss=0.2300]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.53it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.49it/s]


Validation: {'precision': 0.7877758913385814, 'recall': 0.727500783942529, 'f1': 0.7564395124492014, 'auc': 0.8406225883210157, 'prauc': 0.8438752560465965}
Test:      {'precision': 0.7715531700745671, 'recall': 0.7212292254602658, 'f1': 0.7455429447601539, 'auc': 0.8300348761065198, 'prauc': 0.8337709455024602}


Epoch 011: 100%|██████████| 98/98 [00:06<00:00, 15.95it/s, loss=0.1813]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.75it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.75it/s]


Validation: {'precision': 0.6999743128674545, 'recall': 0.8544998432083585, 'f1': 0.7695566174702798, 'auc': 0.8382169503084989, 'prauc': 0.8419689627140141}
Test:      {'precision': 0.695155972607925, 'recall': 0.859517089994169, 'f1': 0.7686483405388663, 'auc': 0.8323278706347612, 'prauc': 0.8363461728400106}


Epoch 012: 100%|██████████| 98/98 [00:06<00:00, 14.97it/s, loss=0.1624]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.75it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.71it/s]


Validation: {'precision': 0.7510809141422141, 'recall': 0.7626215114432029, 'f1': 0.7568072145404905, 'auc': 0.8283645650314497, 'prauc': 0.8300311992839939}
Test:      {'precision': 0.7453473945386311, 'recall': 0.7535277516439213, 'f1': 0.7494152452706951, 'auc': 0.8242140040981564, 'prauc': 0.8278830658249661}


Epoch 013: 100%|██████████| 98/98 [00:06<00:00, 15.98it/s, loss=0.1418]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.74it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.65it/s]


Validation: {'precision': 0.7026372443468712, 'recall': 0.8187519598594584, 'f1': 0.7562635721449648, 'auc': 0.8208910415680142, 'prauc': 0.8234220706951638}
Test:      {'precision': 0.6942522618395576, 'recall': 0.818124804011232, 'f1': 0.7511155844944581, 'auc': 0.8171331327038621, 'prauc': 0.8171551469833455}


Epoch 014: 100%|██████████| 98/98 [00:06<00:00, 15.91it/s, loss=0.1359]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 20.05it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.50it/s]


Validation: {'precision': 0.6786582401539167, 'recall': 0.8755095641239401, 'f1': 0.7646172756477492, 'auc': 0.8289649949175459, 'prauc': 0.8359226808722187}
Test:      {'precision': 0.669856459328541, 'recall': 0.8780181875168453, 'f1': 0.7599402854941187, 'auc': 0.8231532248444317, 'prauc': 0.82781455950079}


Epoch 015: 100%|██████████| 98/98 [00:06<00:00, 15.39it/s, loss=0.1226]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 20.90it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.60it/s]


Validation: {'precision': 0.7156673114100147, 'recall': 0.812166823453082, 'f1': 0.7608695602351027, 'auc': 0.8266170789548221, 'prauc': 0.8321287697147818}
Test:      {'precision': 0.7107843137235546, 'recall': 0.8184383819353451, 'f1': 0.7608220326264094, 'auc': 0.8241726802657812, 'prauc': 0.8293722922623437}


Epoch 016: 100%|██████████| 98/98 [00:06<00:00, 15.68it/s, loss=0.1183]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.12it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.56it/s]


Validation: {'precision': 0.6992933786948462, 'recall': 0.837880213230361, 'f1': 0.7623395100170683, 'auc': 0.8231364745180509, 'prauc': 0.8280727319946368}
Test:      {'precision': 0.6883720930214771, 'recall': 0.8353715898374557, 'f1': 0.7547811255161346, 'auc': 0.8172676239098386, 'prauc': 0.8237102966665484}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6999743128674545, 'recall': 0.8544998432083585, 'f1': 0.7695566174702798, 'auc': 0.8382169503084989, 'prauc': 0.8419689627140141}
Corresponding test performance:
{'precision': 0.695155972607925, 'recall': 0.859517089994169, 'f1': 0.7686483405388663, 'auc': 0.8323278706347612, 'prauc': 0.8363461728400106}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 15.79it/s, loss=0.5955]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.96it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.82it/s]


Validation: {'precision': 0.6818650180301858, 'recall': 0.830040765127532, 'f1': 0.7486918349553978, 'auc': 0.8047694589369155, 'prauc': 0.8050971936873919}
Test:      {'precision': 0.6784434203771673, 'recall': 0.8309814988998715, 'f1': 0.7470049281004016, 'auc': 0.8012937228544954, 'prauc': 0.8037019928895356}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.88it/s, loss=0.5103]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.31it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.46it/s]


Validation: {'precision': 0.6769342711601336, 'recall': 0.875195986199827, 'f1': 0.7634026209000625, 'auc': 0.8339832464387906, 'prauc': 0.839498247539988}
Test:      {'precision': 0.6768815886230035, 'recall': 0.8657886484764322, 'f1': 0.759768844825224, 'auc': 0.8284497220833855, 'prauc': 0.8374176753315108}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.25it/s, loss=0.4628]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.89it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.83it/s]


Validation: {'precision': 0.7137139807878503, 'recall': 0.8388209470027005, 'f1': 0.7712267501116915, 'auc': 0.8437406842187916, 'prauc': 0.8520408845216904}
Test:      {'precision': 0.7047015319579908, 'recall': 0.8366259015339084, 'f1': 0.7650179161813894, 'auc': 0.8350638506060409, 'prauc': 0.8454124275242547}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.55it/s, loss=0.4149]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 20.96it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.63it/s]


Validation: {'precision': 0.7821612349887405, 'recall': 0.7149576669780027, 'f1': 0.7470511090312184, 'auc': 0.8384291504288863, 'prauc': 0.844922650407526}
Test:      {'precision': 0.7765456989221219, 'recall': 0.7246785826255105, 'f1': 0.7497161345006977, 'auc': 0.8386738724155618, 'prauc': 0.8464062052256919}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 16.12it/s, loss=0.3672]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.59it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.51it/s]


Validation: {'precision': 0.6883279180187745, 'recall': 0.8635936030076401, 'f1': 0.7660639728085109, 'auc': 0.8320067309235156, 'prauc': 0.8326177119584728}
Test:      {'precision': 0.6838198911413067, 'recall': 0.8667293822487716, 'f1': 0.7644862348683208, 'auc': 0.8264673861342184, 'prauc': 0.8300133587875385}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 15.94it/s, loss=0.3379]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.33it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.01it/s]


Validation: {'precision': 0.6828912144057365, 'recall': 0.8798996550615243, 'f1': 0.7689777972149227, 'auc': 0.8391149828255552, 'prauc': 0.845782295424889}
Test:      {'precision': 0.6771768129986001, 'recall': 0.8755095641239401, 'f1': 0.7636761438759915, 'auc': 0.8339024546155097, 'prauc': 0.840479523553211}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 16.21it/s, loss=0.2978]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.08it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.80it/s]


Validation: {'precision': 0.73234200743285, 'recall': 0.8030730636538004, 'f1': 0.7660783677274337, 'auc': 0.8320524463850574, 'prauc': 0.8388377166774185}
Test:      {'precision': 0.722713023651403, 'recall': 0.7952336155509714, 'f1': 0.7572409624621225, 'auc': 0.831485891258424, 'prauc': 0.839452907188735}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 16.31it/s, loss=0.2579]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.45it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.10it/s]


Validation: {'precision': 0.7551659361278799, 'recall': 0.7563499529609397, 'f1': 0.7557574758060324, 'auc': 0.8296823760063556, 'prauc': 0.8334341644852257}
Test:      {'precision': 0.7504719949630256, 'recall': 0.7478833490098844, 'f1': 0.7491754308387172, 'auc': 0.8269309076596064, 'prauc': 0.8326703087043212}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7137139807878503, 'recall': 0.8388209470027005, 'f1': 0.7712267501116915, 'auc': 0.8437406842187916, 'prauc': 0.8520408845216904}
Corresponding test performance:
{'precision': 0.7047015319579908, 'recall': 0.8366259015339084, 'f1': 0.7650179161813894, 'auc': 0.8350638506060409, 'prauc': 0.8454124275242547}


Epoch 001: 100%|██████████| 98/98 [00:05<00:00, 16.37it/s, loss=0.5856]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.06it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.83it/s]


Validation: {'precision': 0.768904593636859, 'recall': 0.6823455628702341, 'f1': 0.7230436899813261, 'auc': 0.8209340442439261, 'prauc': 0.8203439181918517}
Test:      {'precision': 0.7591420534431816, 'recall': 0.6770147381603104, 'f1': 0.7157301458510411, 'auc': 0.8142933649830602, 'prauc': 0.8202340545607971}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.81it/s, loss=0.5124]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.14it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.80it/s]


Validation: {'precision': 0.7234744365015847, 'recall': 0.8253370962658346, 'f1': 0.7710560957956978, 'auc': 0.8403032331682453, 'prauc': 0.8485476227608946}
Test:      {'precision': 0.706713780916807, 'recall': 0.8153026026942136, 'f1': 0.7571345320063563, 'auc': 0.8313030295251485, 'prauc': 0.8436109378043297}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.76it/s, loss=0.4661]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.85it/s]
Running inference:   0%|          | 0/197 [00:00<?, ?it/s]

In [None]:
def topk_avg_performance_formatted(performances, k=5):
    """Print the mean ± std of every metric over the k best-ranked runs.

    Each run is ranked separately on ``f1``, ``auc`` and ``prauc`` (rank 1 is
    the highest value). The three ranks are averaged per run, and the ``k``
    runs with the lowest average rank are selected. For every metric key of
    the first run, the mean and population standard deviation (``ddof=0``)
    over the selected runs are printed with four decimals.

    Args:
        performances: Non-empty sequence of dicts mapping metric name to a
            number. Every dict must contain ``"f1"``, ``"auc"`` and
            ``"prauc"``; the keys of the first dict decide which metrics are
            reported.
        k: Number of top runs to aggregate. When ``k`` exceeds the number of
            runs, all runs are used.

    Raises:
        ValueError: If ``performances`` is empty or ``k`` is smaller than 1.
    """
    # Fail with a clear message instead of an IndexError on performances[0]
    # (e.g. when the training loop was interrupted before any run finished).
    if len(performances) == 0:
        raise ValueError("performances is empty: no completed runs to summarise")
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    metrics = ["f1", "auc", "prauc"]
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # Rank the runs per metric (a larger value gets a better, i.e. smaller, rank).
    ranks = {m: (-scores[m]).argsort().argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Keep the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks)[:k]
    reported_metrics = list(performances[0].keys())
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in reported_metrics}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in reported_metrics}

    # Print the summary.
    print("Final Metrics:")
    for m in reported_metrics:
        print(f"{m}: {final_avg[m]:.4f}±{final_std[m]:.4f}")

In [None]:
# Summarise the five best-ranked runs collected in final_metrics.
topk_avg_performance_formatted(final_metrics, k=5)