In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Seed the random number generators once, up front, via the project helper
# (which libraries it seeds is defined in heterogt.utils.seed — presumably
# Python/NumPy/torch; confirm). The repeated fine-tuning loop further down does
# not call this again, so each run starts from a different RNG state.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Experiment configuration: every value a reader might want to tune lives here.
config = Namespace(
    dataset = "MIMIC-III",
    # Candidate downstream tasks; `task_index` selects the one fine-tuned in this notebook.
    tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"], 
    task_index = 2,  # index into `tasks` of the task to train (2 -> "stay")
    token_type = ["diag", "med", "lab", "pro"],
    special_tokens = ["[PAD]", "[CLS]"],
    # Two attention-mask specifications whose keys and values are token-type ids.
    # NOTE(review): whether each list names the types a key may attend to or is blocked
    # from, and how the two dicts map onto the model's layers, is defined inside
    # HeteroGT — confirm there. Type id 6 is counted as a group-code token later on.
    attn_mask_dicts = [{1:[6,7], 2:[6,7], 3:[6,7], 4:[6,7], 6:[2,3,4,5,6,7], 7:[2,3,4,5,6,7]}, 
                       {1:[6,7], 2:[6,7], 3:[6,7], 4:[6,7], 6:[2,3,4,5,6,7], 7:[2,3,4,5,6,7]}],
    d_model = 64,  # hidden size; also part of the pretrained checkpoint's directory name
    num_heads = 4,
    batch_size = 32,
    lr = 1e-3,
    epochs = 500,  # upper bound only — early stopping normally ends training much sooner
    early_stop_patience = 5,  # epochs without validation improvement before stopping
    # A group ICD code token is generated when `group_code_thre` diagnosis codes belong to
    # the same group code (exact >= vs == semantics: see FineTuneEHRDataset).
    group_code_thre = 5,  # if there are group_code_thre diag codes belongs to the same group ICD code, then the group code is generated
    # Warm-start from ./pretrained_models/<dataset>-<pretrain_mask_rate>-<d_model>/pretrained_model.pt
    use_pretrained_model = True,
    # In this notebook, only used to build the pretrained checkpoint's directory name.
    pretrain_mask_rate = 0.7,
)

In [5]:
# The tokenizer is built from the full processed dataset. Which fine-tuning split
# file is loaded depends on the selected task.
data_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{data_dir}/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
# Next-diagnosis tasks have dedicated split files; every other task shares one file.
_task_split_files = {
    "next_diag_6m": "mimic_nextdiag_6m.pkl",
    "next_diag_12m": "mimic_nextdiag_12m.pkl",
}
finetune_data_path = f"{data_dir}/{_task_split_files.get(curr_task, 'mimic_downstream.pkl')}"

Current task: stay


In [6]:
# Load the full processed dataset. In this notebook it is used to build the tokenizer
# vocabularies and to find the maximum number of admissions per patient.
# NOTE: pickle.load can execute arbitrary code — only load files produced by this
# project's own preprocessing. The `with` block closes the file handle.
with open(full_data_path, "rb") as fh:
    ehr_full_data = pickle.load(fh)
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Each distinct age value becomes its own one-token "sentence", hence the nested [[...]].
# NOTE(review): iterating a set has no guaranteed order. If AGE values are strings, the
# order (and so the age token ids) can change between interpreter sessions. Sorting would
# fix that, but it may change ids relative to the pretrained checkpoint — verify first.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
# Largest number of distinct admissions of any patient. It is passed as `max_num_adms`
# to the datasets and the model.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer's vocabularies from the age, group-code, diagnosis, medication,
# lab and procedure "sentences" extracted above, plus the special tokens.
tokenizer = EHRTokenizer(age_sentences, group_code_sentences, diag_sentences, med_sentences, lab_sentences, 
                         pro_sentences, special_tokens=config.special_tokens)
# Record vocabulary sizes on `config` for later use.
# NOTE(review): label_vocab_size is presumably the output width for the next-diagnosis
# tasks — confirm in HeteroGT.
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # only for diagnosis
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")
print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pre-split (train, val, test) data for the selected task. Each split is
# presumably a pandas DataFrame (column indexing + .mean() below) — confirm against
# the preprocessing. NOTE: pickle.load executes arbitrary code; only load trusted files.
# The `with` block closes the file handle, which the bare open() did not.
with open(finetune_data_path, "rb") as fh:
    train_data, val_data, test_data = pickle.load(fh)
# Sanity check: positive-class rate of each binary label in the test split.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
# All three splits share the tokenizer and the same preprocessing options.
_dataset_kwargs = dict(
    token_type=config.token_type,
    task=curr_task,
    max_num_adms=config.max_num_adms,
    group_code_thre=config.group_code_thre,
)
train_dataset, val_dataset, test_dataset = (
    FineTuneEHRDataset(split, tokenizer, **_dataset_kwargs)
    for split in (train_data, val_data, test_data)
)

In [10]:
# Average number of group-code tokens per training patient. This is a quick check that
# `group_code_thre` actually produces group tokens.
# NOTE(review): only token_types[0] is inspected. Presumably that holds the sample's full
# token-type sequence — confirm against FineTuneEHRDataset.__getitem__.
GROUP_TOKEN_TYPE = 6  # token-type id counted as a group-code token
num_group_code = []
for i in range(len(train_dataset)):
    token_types = train_dataset[i][1]  # items are (input_ids, token_types, adm_index, age_ids, group_codes, labels)
    num_group_code.append((token_types[0] == GROUP_TOKEN_TYPE).sum().item())
print("Mean group token number per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader that collates with its own fine-tuning `batcher`."""
    return DataLoader(
        dataset,
        collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain = False),
        shuffle=shuffle,
        batch_size=config.batch_size,
    )


# Only the training split is shuffled; validation and test keep a fixed order.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [12]:
# Smoke test: iterate every split once so collate errors surface before training starts.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass
print("All pass!")

All pass!


In [13]:
# Every task is trained with BCE-with-logits, and model selection uses F1. Only the
# evaluation mode passed to train_with_early_stopping differs:
#   - "binary": single-label tasks (death / stay / readmission)
#   - "l2r":    next-diagnosis tasks (presumably multi-label over the diagnosis
#               vocabulary, cf. config.label_vocab_size — confirm in HeteroGT)
eval_metric = "f1"
loss_fn = F.binary_cross_entropy_with_logits
if curr_task in ("death", "stay", "readmission"):
    task_type = "binary"
elif curr_task in ("next_diag_6m", "next_diag_12m"):
    task_type = "l2r"
else:
    # Fail fast instead of silently treating an unknown or mistyped task as "l2r".
    raise ValueError(f"Unsupported task: {curr_task!r}")

In [14]:
# Peek at one (shuffled) training batch to confirm the collated tensor shapes.
sample_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = sample_batch
for caption, tensor in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age IDs shape:", age_ids),
):
    print(caption, tensor.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Walkthrough

In [15]:
# load pretrained model
if config.use_pretrained_model:
    pretrain_exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}"
)
    print(pretrain_exp_name)
    save_path = "./pretrained_models/" + pretrain_exp_name
    state_dict = torch.load(f"{save_path}/pretrained_model.pt", map_location="cpu")

MIMIC-III-0.7-64


In [None]:
# Fine-tune N_RUNS independently initialised models to estimate run-to-run variance.
# Each run builds a fresh HeteroGT (optionally warm-started from the pretrained
# checkpoint) and a fresh optimizer. The test metrics returned by
# train_with_early_stopping for each run (per its log, those at the best-validation
# epoch) are collected in `final_metrics`.
N_RUNS = 10
final_metrics = []
for _ in range(N_RUNS):
    model = HeteroGT(
        tokenizer=tokenizer,
        token_types=config.token_type,
        d_model=config.d_model,
        num_heads=config.num_heads,
        layer_types=["gnn", "tf", "gnn", "tf"],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        # The masks are defined on `config`. A bare `attn_mask_dicts` name is never
        # defined in this notebook and raises NameError on a fresh kernel.
        attn_mask_dicts=config.attn_mask_dicts,
        use_cls_cat=True,
    ).to(device)
    if config.use_pretrained_model:
        model.load_weight(state_dict)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model, train_dataloader, val_dataloader, test_dataloader,
        optimizer, loss_fn, device, config.early_stop_patience, task_type, config.epochs,
        val_long_seq_idx=None, test_long_seq_idx=None, eval_metric=eval_metric, return_model=False,
    )
    final_metrics.append(best_test_metric)



Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 16.04it/s, loss=0.6085]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.11it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.04it/s]


Validation: {'precision': 0.6642874723094653, 'recall': 0.8463468171814164, 'f1': 0.744346382274101, 'auc': 0.7999988244595604, 'prauc': 0.8094198590348712}
Test:      {'precision': 0.647897535039517, 'recall': 0.8407024145473795, 'f1': 0.7318138343066659, 'auc': 0.7946282037924304, 'prauc': 0.8075749247729611}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 16.14it/s, loss=0.5531]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.15it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.11it/s]


Validation: {'precision': 0.6949920085223895, 'recall': 0.818124804011232, 'f1': 0.7515483170819328, 'auc': 0.8156240677939195, 'prauc': 0.8245096525466373}
Test:      {'precision': 0.6898917920277385, 'recall': 0.8196926936317978, 'f1': 0.7492118035759486, 'auc': 0.8115886924706568, 'prauc': 0.8220086897453106}


Epoch 003: 100%|██████████| 98/98 [00:05<00:00, 17.03it/s, loss=0.4993]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.17it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.05it/s]


Validation: {'precision': 0.7821482602088421, 'recall': 0.648479147066013, 'f1': 0.7090690847034848, 'auc': 0.8208141290717279, 'prauc': 0.8342775500063317}
Test:      {'precision': 0.7778189910950378, 'recall': 0.6575729068652946, 'f1': 0.7126592983461761, 'auc': 0.815756359516335, 'prauc': 0.830082850211411}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 16.15it/s, loss=0.4878]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.10it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.07it/s]


Validation: {'precision': 0.813736903372861, 'recall': 0.6575729068652946, 'f1': 0.7273673207561985, 'auc': 0.8404218421843994, 'prauc': 0.8493777252503707}
Test:      {'precision': 0.7993920972614005, 'recall': 0.6597679523340867, 'f1': 0.7228998404306879, 'auc': 0.8330054606852306, 'prauc': 0.8457134540843725}


Epoch 005: 100%|██████████| 98/98 [00:05<00:00, 16.96it/s, loss=0.4409]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.15it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.99it/s]


Validation: {'precision': 0.7762879322484958, 'recall': 0.6898714330489499, 'f1': 0.730532952013118, 'auc': 0.8306717389729534, 'prauc': 0.8364994618912593}
Test:      {'precision': 0.7769936485505399, 'recall': 0.6904985888971763, 'f1': 0.7311970728831141, 'auc': 0.8303852478448438, 'prauc': 0.8385976648394879}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 16.14it/s, loss=0.4076]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.09it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.06it/s]


Validation: {'precision': 0.7780033840921219, 'recall': 0.7209156475361527, 'f1': 0.7483723908381501, 'auc': 0.83363184022617, 'prauc': 0.8414342905047226}
Test:      {'precision': 0.7695711368251547, 'recall': 0.7089996864198527, 'f1': 0.7380447150973809, 'auc': 0.8331497669305653, 'prauc': 0.8458033945092314}


Epoch 007: 100%|██████████| 98/98 [00:05<00:00, 16.95it/s, loss=0.3619]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.13it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.00it/s]


Validation: {'precision': 0.7720637027180755, 'recall': 0.7296958294113212, 'f1': 0.7502821165556474, 'auc': 0.8356290030542952, 'prauc': 0.8433847143084795}
Test:      {'precision': 0.776776776774185, 'recall': 0.7300094073354343, 'f1': 0.7526673082904533, 'auc': 0.8385338445207167, 'prauc': 0.8508987826563991}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.6949920085223895, 'recall': 0.818124804011232, 'f1': 0.7515483170819328, 'auc': 0.8156240677939195, 'prauc': 0.8245096525466373}
Corresponding test performance:
{'precision': 0.6898917920277385, 'recall': 0.8196926936317978, 'f1': 0.7492118035759486, 'auc': 0.8115886924706568, 'prauc': 0.8220086897453106}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 16.15it/s, loss=0.6224]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.55it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.01it/s]


Validation: {'precision': 0.6823218997343474, 'recall': 0.8109125117566294, 'f1': 0.7410803790441262, 'auc': 0.8000816146580448, 'prauc': 0.8112115005505669}
Test:      {'precision': 0.6693569382955757, 'recall': 0.806208842894932, 'f1': 0.7314366948986802, 'auc': 0.7907554912628533, 'prauc': 0.8042334924191312}


Epoch 002: 100%|██████████| 98/98 [00:05<00:00, 16.87it/s, loss=0.5453]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.00it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.00it/s]


Validation: {'precision': 0.6655868427593681, 'recall': 0.8375666353062479, 'f1': 0.7417384010617254, 'auc': 0.7990688413561958, 'prauc': 0.8040598389655313}
Test:      {'precision': 0.6631944444427996, 'recall': 0.8385073690785874, 'f1': 0.7406176380518293, 'auc': 0.7898045901164064, 'prauc': 0.7950201369372583}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 16.20it/s, loss=0.5152]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.13it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.09it/s]


Validation: {'precision': 0.8216003497996432, 'recall': 0.5892129194086259, 'f1': 0.6862673435626656, 'auc': 0.8253521672997002, 'prauc': 0.837791060642445}
Test:      {'precision': 0.821865766142584, 'recall': 0.6105362182483207, 'f1': 0.700611725944743, 'auc': 0.8249579840815162, 'prauc': 0.8394433784592801}


Epoch 004: 100%|██████████| 98/98 [00:05<00:00, 16.98it/s, loss=0.4696]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.16it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.03it/s]


Validation: {'precision': 0.6930614406761307, 'recall': 0.8206334274041373, 'f1': 0.7514716389673122, 'auc': 0.8189625021689726, 'prauc': 0.8307285138298998}
Test:      {'precision': 0.6904512067138237, 'recall': 0.8253370962658346, 'f1': 0.7518925818107639, 'auc': 0.8196312363879246, 'prauc': 0.8312859553342554}


Epoch 005: 100%|██████████| 98/98 [00:05<00:00, 16.50it/s, loss=0.4475]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.37it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.06it/s]


Validation: {'precision': 0.7157636809128348, 'recall': 0.8243963624934951, 'f1': 0.7662489020468327, 'auc': 0.8326542829007395, 'prauc': 0.8401677119849806}
Test:      {'precision': 0.7007240547044765, 'recall': 0.8193791157076846, 'f1': 0.7554206368322705, 'auc': 0.827707252106081, 'prauc': 0.8405723101906195}


Epoch 006: 100%|██████████| 98/98 [00:05<00:00, 16.96it/s, loss=0.3978]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.13it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 22.99it/s]


Validation: {'precision': 0.7004405286325461, 'recall': 0.847601128877869, 'f1': 0.7670261017399431, 'auc': 0.8314291587681903, 'prauc': 0.8373889015668229}
Test:      {'precision': 0.6881409413602732, 'recall': 0.8206334274041373, 'f1': 0.7485697890867592, 'auc': 0.8238585486727802, 'prauc': 0.8360179499001839}


Epoch 007: 100%|██████████| 98/98 [00:05<00:00, 16.98it/s, loss=0.3808]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.36it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.08it/s]


Validation: {'precision': 0.7201959716910175, 'recall': 0.8297271872034189, 'f1': 0.7710913544865147, 'auc': 0.8382288061864374, 'prauc': 0.8438743479141599}
Test:      {'precision': 0.7145981410587354, 'recall': 0.8196926936317978, 'f1': 0.7635460735957874, 'auc': 0.8352034758326048, 'prauc': 0.8464026818815882}


Epoch 008: 100%|██████████| 98/98 [00:05<00:00, 16.90it/s, loss=0.3411]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 23.11it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 23.00it/s]


Validation: {'precision': 0.7438042131327639, 'recall': 0.7529005957956949, 'f1': 0.7483247573478603, 'auc': 0.8280450591683663, 'prauc': 0.839134001537688}
Test:      {'precision': 0.7416487894552816, 'recall': 0.7588585763538449, 'f1': 0.7501549856988905, 'auc': 0.8284085492516158, 'prauc': 0.8414861399584697}


Epoch 009:  88%|████████▊ | 86/98 [00:05<00:00, 16.78it/s, loss=0.3142]

In [None]:
def topk_avg_performance_formatted(performances, k=5, metrics=("f1", "auc", "prauc")):
    """Print (and return) the mean ± std of every metric over the k best runs.

    Runs are ranked separately on each metric in ``metrics`` (higher is better; rank 1
    is best). A run's overall score is its average rank across those metrics, and the
    ``k`` runs with the lowest average rank are kept. Ties are broken in favour of the
    earlier run, so the selection is deterministic.

    Args:
        performances: Non-empty list of metric dicts, one per run. Every dict must
            contain the keys in ``metrics`` and the same keys as ``performances[0]``.
        k: Number of top runs to aggregate. Values above ``len(performances)`` use all runs.
        metrics: Metric names used for ranking.

    Returns:
        ``(final_avg, final_std)``: dicts mapping each key of ``performances[0]`` to its
        mean and population standard deviation (ddof=0) over the selected runs.

    Raises:
        ValueError: If ``performances`` is empty or ``k`` < 1.
    """
    if not performances:
        raise ValueError("performances must contain at least one run")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    metrics = list(metrics)
    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # Rank runs per metric (larger value -> better -> smaller rank, starting at 1).
    # The argsort of an argsort turns sort order into each run's rank position; the
    # stable sort gives tied scores to the earlier run first.
    ranks = {m: (-scores[m]).argsort(kind="stable").argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Keep the k runs with the best (lowest) average rank.
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]
    all_metrics = list(performances[0].keys())
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in all_metrics}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in all_metrics}

    # Print the results.
    print("Final Metrics:")
    for m in all_metrics:
        print(f"{m}: {final_avg[m]:.4f}±{final_std[m]:.4f}")
    return final_avg, final_std

In [None]:
# Aggregate the 5 best of the repeated runs (ranked on f1 / auc / prauc); the trailing
# semicolon keeps Jupyter from echoing any return value.
topk_avg_performance_formatted(final_metrics, k=5);