In [1]:
import torch
import pickle
import numpy as np
from argparse import Namespace
from torch.utils.data import DataLoader
import torch.nn.functional as F
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import FineTuneEHRDataset, batcher, expand_level3
from heterogt.utils.train import train_with_early_stopping
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGT

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Seed the project's RNGs once, before any data shuffling or model initialisation happens.
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Use the GPU when one is visible; the chosen device is passed to the model and trainer below.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Experiment configuration. Later cells add derived fields (vocabulary sizes, max_num_adms).
data_settings = dict(
    dataset="MIMIC-III",
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,  # index into `tasks` of the task to train (2 -> "stay")
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
)
train_settings = dict(
    batch_size=32,
    lr=1e-3,
    epochs=500,  # upper bound; early stopping normally ends training sooner
    early_stop_patience=5,
)
# group_code_thre: when group_code_thre diagnosis codes belong to the same group ICD code,
# the group code token is generated (exact >= / > semantics live in FineTuneEHRDataset).
config = Namespace(**data_settings, **train_settings, group_code_thre=5)

In [5]:
# Resolve input files. The full cohort file is only used to build the tokenizer; the
# fine-tuning splits come from a task-specific file, where the two next-diagnosis
# horizons have their own files and every other task shares the downstream file.
full_data_path = f"./data_process/{config.dataset}-processed/mimic.pkl"  # for tokenizer
curr_task = config.tasks[config.task_index]
print("Current task:", curr_task)
next_diag_paths = {
    "next_diag_6m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_6m.pkl",
    "next_diag_12m": f"./data_process/{config.dataset}-processed/mimic_nextdiag_12m.pkl",
}
finetune_data_path = next_diag_paths.get(
    curr_task, f"./data_process/{config.dataset}-processed/mimic_downstream.pkl"
)

Current task: stay


In [6]:
# Load the full processed cohort; it is only used to build the tokenizer vocabularies.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(full_data_path, 'rb') as fh:  # `with` closes the handle (the bare open() leaked it)
    ehr_full_data = pickle.load(fh)

# NOTE(review): expand_level3()[1] presumably yields the group (ICD level-3) codes; it is
# wrapped in a list so it forms a single "sentence" — TODO confirm against the helper.
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Each distinct age becomes its own one-token sentence (hence the nested [[...]]).
# Sorted so the age vocabulary order does not depend on set iteration order, which for
# string values changes between interpreter sessions under hash randomization.
age_sentences = [[str(c)] for c in sorted(set(ehr_full_data["AGE"].values.tolist()), key=str)]
# Largest number of distinct admissions for any patient; passed as max_num_adms to the
# datasets and the model. Cast to a plain int rather than keeping a numpy scalar.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the joint vocabulary over all code types from the full cohort, then record the
# vocabulary sizes that the model consumes.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)
config.label_vocab_size = len(tokenizer.diag_voc.id2word)  # label vocabulary: diagnosis codes only
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size, config.group_code_vocab_size = (
    tokenizer.token_number("age"),
    tokenizer.token_number("group"),
)
print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the (train, val, test) splits for the current task.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(finetune_data_path, 'rb') as fh:  # `with` closes the handle (the bare open() leaked it)
    train_data, val_data, test_data = pickle.load(fh)

# Example label prevalence in the test split, as a class-balance sanity check.
# `== True` is deliberate: it is an element-wise comparison on a (presumably pandas) column.
print("Percentage of DEATH in test dataset:",
      (test_data["DEATH"] == True).mean() * 100, "%")

print("Percentage of READMISSION in test dataset:",
      (test_data["READMISSION"] == 1).mean() * 100, "%")

print("Percentage of STAY>7 days in test dataset:",
      (test_data["STAY_DAYS"] > 7).mean() * 100, "%")

Percentage of DEATH in test dataset: 28.648477157360407 %
Percentage of READMISSION in test dataset: 40.1491116751269 %
Percentage of STAY>7 days in test dataset: 50.58692893401015 %


In [9]:
def make_finetune_dataset(split_data):
    """Wrap one data split in a FineTuneEHRDataset configured for the current task."""
    return FineTuneEHRDataset(
        split_data,
        tokenizer,
        token_type=config.token_type,
        task=curr_task,
        max_num_adms=config.max_num_adms,
        group_code_thre=config.group_code_thre,
    )


train_dataset, val_dataset, test_dataset = [
    make_finetune_dataset(split) for split in (train_data, val_data, test_data)
]

In [10]:
# Sanity check: average number of generated group-code tokens per training sample.
# NOTE(review): assumes token-type id 6 marks group-code tokens and that token_types[0]
# is the row to count in — TODO confirm against FineTuneEHRDataset.
GROUP_TOKEN_TYPE = 6
num_group_code = []
for i in range(len(train_dataset)):
    # sample = (input_ids, token_types, adm_index, age_ids, diag_group_codes, labels)
    token_types = train_dataset[i][1]
    num_group_code.append((token_types[0] == GROUP_TOKEN_TYPE).sum().item())
print("Mean group token number per patient", np.mean(num_group_code))

Mean group token numer per patient 0.7971893963589908


In [11]:
def make_dataloader(dataset, shuffle):
    """Batch `dataset` with a fresh fine-tuning collate function from `batcher`."""
    collate = batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=False)
    return DataLoader(dataset, collate_fn=collate, shuffle=shuffle, batch_size=config.batch_size)


# Only the training loader is shuffled; evaluation order stays fixed.
train_dataloader = make_dataloader(train_dataset, shuffle=True)
val_dataloader = make_dataloader(val_dataset, shuffle=False)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

In [12]:
# Smoke test: iterate every loader once so collate errors surface before training.
# Iterating the shuffled train loader presumably draws from torch's global RNG, so this
# cell also shifts the random state seen by later cells.
for loader in (train_dataloader, val_dataloader, test_dataloader):
    for batch in loader:
        pass
print("All pass!")

All pass!


In [13]:
# Choose the trainer's task type for the selected task. All tasks are trained with
# binary cross-entropy on logits and model selection uses F1; the next-diagnosis tasks
# are passed to the trainer as task_type "l2r". Unknown task names fail loudly instead
# of silently falling into the "l2r" branch.
BINARY_TASKS = ("death", "stay", "readmission")
NEXT_DIAG_TASKS = ("next_diag_6m", "next_diag_12m")

if curr_task in BINARY_TASKS:
    task_type = "binary"
elif curr_task in NEXT_DIAG_TASKS:
    task_type = "l2r"
else:
    raise ValueError(f"Unknown task: {curr_task!r}")

eval_metric = "f1"
loss_fn = F.binary_cross_entropy_with_logits

In [14]:
# Peek at one training batch to confirm the collated shapes.
sample_batch = next(iter(train_dataloader))
input_ids, token_types, adm_index, age_ids, diag_code_group_dicts, labels = sample_batch
for caption, tensor in (
    ("Input IDs shape:", input_ids),
    ("Token Types shape:", token_types),
    ("Admission Index shape:", adm_index),
    ("Age IDs shape:", age_ids),
):
    print(caption, tensor.shape)
print("Diag Code Group Dict number:", len(diag_code_group_dicts))
print("Labels shape:", labels.shape)

Input IDs shape: torch.Size([32, 293])
Token Types shape: torch.Size([32, 293])
Admission Index shape: torch.Size([32, 293])
Age IDs shape: torch.Size([32, 8])
Diag Code Group Dict number: 32
Labels shape: torch.Size([32, 1])


# Model Training: Repeated Runs and Aggregated Metrics

In [15]:
# Two identical attention-mask dicts, each built fresh (no shared dicts or lists).
# NOTE(review): presumably one per 'tf' layer of HeteroGT, with keys = token-type ids and
# values = the token-type ids they may attend to — TODO confirm against HeteroGT.
def _build_attn_mask_dict():
    mask = {token_type: [6, 7] for token_type in (1, 2, 3, 4)}
    mask.update({token_type: [2, 3, 4, 5, 6, 7] for token_type in (6, 7)})
    return mask


attn_mask_dicts = [_build_attn_mask_dict() for _ in range(2)]

In [None]:
# Train HeteroGT from scratch NUM_RUNS times to measure run-to-run variance. Each run gets
# a fresh model and AdamW optimizer. train_with_early_stopping picks the epoch by
# validation `eval_metric` and returns the corresponding test metrics. Runs differ only
# through the global RNG state, which is seeded once at the top of the notebook.
NUM_RUNS = 10

final_metrics = []
for run_idx in range(NUM_RUNS):
    model = HeteroGT(
        tokenizer,
        d_model=64,
        num_heads=4,
        layer_types=['gnn', 'tf', 'gnn', 'tf'],
        max_num_adms=config.max_num_adms,
        device=device,
        task=curr_task,
        label_vocab_size=config.label_vocab_size,
        attn_mask_dicts=attn_mask_dicts,
        use_cls_cat=True,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    best_test_metric = train_with_early_stopping(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        device,
        config.early_stop_patience,
        task_type,
        config.epochs,
        val_long_seq_idx=None,
        test_long_seq_idx=None,
        eval_metric=eval_metric,
        return_model=False,
    )
    final_metrics.append(best_test_metric)

Epoch 001: 100%|██████████| 98/98 [00:07<00:00, 13.88it/s, loss=0.6937]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.19it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.25it/s]


Validation: {'precision': 0.6182373472935903, 'recall': 0.8886798369366928, 'f1': 0.7291907837864662, 'auc': 0.7807708510983115, 'prauc': 0.7824556976905184}
Test:      {'precision': 0.6115810019505279, 'recall': 0.8842897459991086, 'f1': 0.7230769182412495, 'auc': 0.7717335423181714, 'prauc': 0.7736575545837455}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.63it/s, loss=0.6070]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.96it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.79it/s]


Validation: {'precision': 0.6828229027944532, 'recall': 0.8040137974261399, 'f1': 0.738479257703903, 'auc': 0.8046067420358893, 'prauc': 0.8168990861575831}
Test:      {'precision': 0.6787148594359338, 'recall': 0.7949200376268583, 'f1': 0.7322356969353889, 'auc': 0.7979962722983853, 'prauc': 0.8132044772530707}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.93it/s, loss=0.5582]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.39it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.65it/s]


Validation: {'precision': 0.8195187165738881, 'recall': 0.5766698024440995, 'f1': 0.6769740426363543, 'auc': 0.8108900056857977, 'prauc': 0.8198505626676416}
Test:      {'precision': 0.8082311733764964, 'recall': 0.5788648479128916, 'f1': 0.6745843181746299, 'auc': 0.8027187659022512, 'prauc': 0.816319156690563}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.88it/s, loss=0.5411]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.74it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.83it/s]


Validation: {'precision': 0.8223483195075411, 'recall': 0.5907808090291917, 'f1': 0.6875912360076676, 'auc': 0.8159373442974082, 'prauc': 0.821034688942649}
Test:      {'precision': 0.8089792460787422, 'recall': 0.5989338350561338, 'f1': 0.6882882833970951, 'auc': 0.8105847397781399, 'prauc': 0.8234294421384032}


Epoch 005: 100%|██████████| 98/98 [00:06<00:00, 15.61it/s, loss=0.4940]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.94it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.42it/s]


Validation: {'precision': 0.7980732177233215, 'recall': 0.6494198808383524, 'f1': 0.7161134113711425, 'auc': 0.8228280207445713, 'prauc': 0.827766555808411}
Test:      {'precision': 0.784125766868159, 'recall': 0.6412668548114103, 'f1': 0.7055373419513613, 'auc': 0.8131946342438165, 'prauc': 0.8223019330661526}


Epoch 006: 100%|██████████| 98/98 [00:06<00:00, 16.05it/s, loss=0.4506]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.01it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.83it/s]


Validation: {'precision': 0.7232884560277191, 'recall': 0.7917842583857266, 'f1': 0.755988018960052, 'auc': 0.8225153973191046, 'prauc': 0.8247692616973521}
Test:      {'precision': 0.7188847369913226, 'recall': 0.7842583882070108, 'f1': 0.7501499650132063, 'auc': 0.8152602721936915, 'prauc': 0.8176290976106607}


Epoch 007: 100%|██████████| 98/98 [00:06<00:00, 15.96it/s, loss=0.4260]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.23it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.37it/s]


Validation: {'precision': 0.7920651788849024, 'recall': 0.7011602383170237, 'f1': 0.7438456370652912, 'auc': 0.8352364529262869, 'prauc': 0.8370042955094231}
Test:      {'precision': 0.7746773630946122, 'recall': 0.6964565694553263, 'f1': 0.7334874454740645, 'auc': 0.8225070429199088, 'prauc': 0.8304401361340079}


Epoch 008: 100%|██████████| 98/98 [00:06<00:00, 16.30it/s, loss=0.3905]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.86it/s]
Running inference: 100%|██████████| 197/197 [00:08<00:00, 21.93it/s]


Validation: {'precision': 0.7384848484826106, 'recall': 0.7641894010637686, 'f1': 0.7511172703882681, 'auc': 0.8202993026432881, 'prauc': 0.8187441901823803}
Test:      {'precision': 0.7263189448419475, 'recall': 0.7597993101261844, 'f1': 0.7426819873374262, 'auc': 0.8120598646933908, 'prauc': 0.8146962436197458}


Epoch 009: 100%|██████████| 98/98 [00:06<00:00, 16.13it/s, loss=0.3774]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.96it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.08it/s]


Validation: {'precision': 0.7509079903124973, 'recall': 0.7779868297247476, 'f1': 0.7642076031926538, 'auc': 0.8335723096526237, 'prauc': 0.8366192888043034}
Test:      {'precision': 0.7360024081856231, 'recall': 0.7666980244566739, 'f1': 0.7510367021108223, 'auc': 0.8229557160524215, 'prauc': 0.8321821314783768}


Epoch 010: 100%|██████████| 98/98 [00:06<00:00, 15.99it/s, loss=0.3403]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.86it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.86it/s]


Validation: {'precision': 0.7354497354475736, 'recall': 0.784571966131124, 'f1': 0.7592171092496182, 'auc': 0.8220537716310744, 'prauc': 0.8168955324389202}
Test:      {'precision': 0.7221095334664674, 'recall': 0.7814361868899924, 'f1': 0.7506024046440781, 'auc': 0.8156597191288071, 'prauc': 0.8136602337787939}


Epoch 011: 100%|██████████| 98/98 [00:06<00:00, 15.89it/s, loss=0.3102]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.15it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.42it/s]


Validation: {'precision': 0.7499225286620703, 'recall': 0.7588585763538449, 'f1': 0.7543640847733851, 'auc': 0.8265832696079896, 'prauc': 0.8235547766108247}
Test:      {'precision': 0.7423088638417841, 'recall': 0.7641894010637686, 'f1': 0.7530902298565768, 'auc': 0.8219437603278122, 'prauc': 0.8247757466506267}


Epoch 012: 100%|██████████| 98/98 [00:06<00:00, 15.62it/s, loss=0.2753]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.81it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.75it/s]


Validation: {'precision': 0.836444007854438, 'recall': 0.5340232047647099, 'f1': 0.6518660239491149, 'auc': 0.8217621974126255, 'prauc': 0.8203683172264054}
Test:      {'precision': 0.8373435996109849, 'recall': 0.5456255879568968, 'f1': 0.6607176713164107, 'auc': 0.816913879824658, 'prauc': 0.8204478420416552}


Epoch 013: 100%|██████████| 98/98 [00:06<00:00, 15.82it/s, loss=0.2655]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.30it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.19it/s]


Validation: {'precision': 0.7136715391209124, 'recall': 0.780809031041766, 'f1': 0.7457322501740747, 'auc': 0.803292598345261, 'prauc': 0.7824360326023453}
Test:      {'precision': 0.7064846416362159, 'recall': 0.7789275634970871, 'f1': 0.7409395923251186, 'auc': 0.7991573662877054, 'prauc': 0.7880339026180059}


Epoch 014: 100%|██████████| 98/98 [00:06<00:00, 15.35it/s, loss=0.2454]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.89it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.86it/s]


Validation: {'precision': 0.7065277015888466, 'recall': 0.8077767325154978, 'f1': 0.7537673688314125, 'auc': 0.8158211466462585, 'prauc': 0.8077874876075928}
Test:      {'precision': 0.7003522080717953, 'recall': 0.8105989338325161, 'f1': 0.7514534833965283, 'auc': 0.8116454686983297, 'prauc': 0.8098451794687016}

Early stopping triggered (no improvement for 5 epochs).

Best validation performance:
{'precision': 0.7509079903124973, 'recall': 0.7779868297247476, 'f1': 0.7642076031926538, 'auc': 0.8335723096526237, 'prauc': 0.8366192888043034}
Corresponding test performance:
{'precision': 0.7360024081856231, 'recall': 0.7666980244566739, 'f1': 0.7510367021108223, 'auc': 0.8229557160524215, 'prauc': 0.8321821314783768}


Epoch 001: 100%|██████████| 98/98 [00:06<00:00, 15.39it/s, loss=0.7032]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.01it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.36it/s]


Validation: {'precision': 0.8037552998134238, 'recall': 0.41611790529816206, 'f1': 0.5483471029406352, 'auc': 0.7669860318667914, 'prauc': 0.7696888149509129}
Test:      {'precision': 0.7948717948669423, 'recall': 0.4082784571953331, 'f1': 0.5394655020397823, 'auc': 0.7602900620964823, 'prauc': 0.7563828272213458}


Epoch 002: 100%|██████████| 98/98 [00:06<00:00, 15.81it/s, loss=0.6256]
Running inference: 100%|██████████| 198/198 [00:09<00:00, 21.79it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.85it/s]


Validation: {'precision': 0.7276154571136425, 'recall': 0.7262464722460764, 'f1': 0.7269303151483821, 'auc': 0.8015810315598448, 'prauc': 0.8139979695950348}
Test:      {'precision': 0.7276178424526175, 'recall': 0.7212292254602658, 'f1': 0.7244094438167133, 'auc': 0.7966955530825012, 'prauc': 0.8083947367054606}


Epoch 003: 100%|██████████| 98/98 [00:06<00:00, 15.79it/s, loss=0.5600]
Running inference: 100%|██████████| 198/198 [00:08<00:00, 22.26it/s]
Running inference: 100%|██████████| 197/197 [00:09<00:00, 21.39it/s]


Validation: {'precision': 0.8014557217921494, 'recall': 0.6215114455922812, 'f1': 0.7001059646995257, 'auc': 0.8171718627061202, 'prauc': 0.8289747019881051}
Test:      {'precision': 0.7898724082903116, 'recall': 0.6211978676681681, 'f1': 0.6954537426554526, 'auc': 0.8085342018888162, 'prauc': 0.8244293927122426}


Epoch 004: 100%|██████████| 98/98 [00:06<00:00, 15.84it/s, loss=0.5230]
Running inference:  70%|██████▉   | 138/198 [00:06<00:02, 21.71it/s]

In [None]:
def topk_avg_performance_formatted(performances, k=5, metrics=("f1", "auc", "prauc")):
    """Print mean ± std of every metric over the k runs with the best average rank.

    Each run is ranked separately on every metric in ``metrics`` (rank 1 = highest
    value). Runs are then ordered by their mean rank across those metrics, and the first
    ``k`` are kept. For every key of ``performances[0]``, the mean and the population
    standard deviation (ddof=0) over the kept runs are printed.

    Args:
        performances: list of per-run metric dicts (in this notebook, the values
            returned by ``train_with_early_stopping``). Every dict must contain all keys
            of ``performances[0]`` and all of ``metrics``.
        k: number of top-ranked runs to aggregate. If it exceeds the number of runs,
            all runs are used.
        metrics: metric names used for ranking. Defaults to the original hard-coded set.

    Raises:
        ValueError: if ``performances`` or ``metrics`` is empty, or ``k`` < 1.

    NOTE(review): when ``performances`` holds test-set scores, keeping the top-k runs by
    those scores is a selection on the test set and gives an optimistic estimate.
    Consider ranking by validation metrics or also reporting all runs.
    """
    if not performances:
        raise ValueError("performances must contain at least one run")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    metrics = tuple(metrics)
    if not metrics:
        raise ValueError("metrics must name at least one metric to rank by")

    scores = {m: np.array([p[m] for p in performances]) for m in metrics}

    # Rank per metric: larger value -> smaller (better) rank. The stable sort resolves
    # ties in favour of the earlier run deterministically.
    ranks = {m: (-scores[m]).argsort(kind="stable").argsort() + 1 for m in metrics}
    avg_ranks = np.mean(np.stack([ranks[m] for m in metrics], axis=1), axis=1)

    # Select the top-k runs by mean rank (ties again go to the earlier run).
    topk_idx = np.argsort(avg_ranks, kind="stable")[:k]
    metric_names = list(performances[0].keys())
    final_avg = {m: np.mean([performances[i][m] for i in topk_idx]) for m in metric_names}
    final_std = {m: np.std([performances[i][m] for i in topk_idx], ddof=0) for m in metric_names}

    # Print the results.
    print("Final Metrics:")
    for m in metric_names:
        print(f"{m}: {final_avg[m]:.4f}±{final_std[m]:.4f}")

In [None]:
# Summarise the repeated runs: mean ± std over the 5 best-ranked entries of final_metrics.
# NOTE(review): final_metrics holds test-set scores, so choosing the top runs by them is
# optimistic; consider also reporting the mean over all runs.
topk_avg_performance_formatted(final_metrics, k=5)

Final Metrics:
precision: 0.7169±0.0089
recall: 0.8033±0.0202
f1: 0.7574±0.0046
auc: 0.8245±0.0014
prauc: 0.8311±0.0042
