In [1]:
import torch
import pickle
from argparse import Namespace
from torch.utils.data import DataLoader
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import PreTrainEHRDataset, batcher, expand_level3
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTPreTrain
from tqdm import tqdm

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix all RNG seeds up front so masking/shuffling in the runs below are reproducible.
SEED = 123
set_random_seed(SEED)

[INFO] Random seed set to 123


In [3]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cuda


In [4]:
def _default_attn_mask():
    """Build one attention-mask dict: keys 1-5 map to [6, 7]; keys 6 and 7 map to [2, ..., 7].

    A brand-new dict with brand-new lists is returned on every call, so the entries
    of ``attn_mask_dicts`` never share mutable state.
    NOTE(review): the meaning of the integer keys/values is defined inside
    HeteroGTPreTrain (presumably token-type ids that may attend to each other) — confirm.
    """
    mask = {key: [6, 7] for key in range(1, 6)}
    mask.update({key: list(range(2, 8)) for key in (6, 7)})
    return mask


config = Namespace(
    dataset="MIMIC-IV",
    # `tasks` / `task_index` are not read anywhere in this pretraining notebook.
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,  # index into `tasks` of the downstream task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # Two identical mask dicts — presumably one per 'tf' layer in the model's layer_types (TODO confirm).
    attn_mask_dicts=[_default_attn_mask() for _ in range(2)],
    d_model=64,
    num_heads=4,
    batch_size=32,
    lr=1e-3,
    epochs=25,
    early_stop_patience=5,  # not read anywhere in this notebook (no validation loop here)
    # A group code is generated once at least `group_code_thre` diag codes fall under the same group ICD code.
    group_code_thre=5,
    pretrain_mask_rate=0.7,  # passed to PreTrainEHRDataset as `mask_rate`
    # Weights of the auxiliary losses in the total pretraining loss.
    cls_ontology_weight=5e-2,
    visit_ontology_weight=5e-2,
    adm_type_weight=5e-2,
)

In [5]:
# Both pickles live in the same per-dataset processed directory.
processed_dir = f"./data_process/{config.dataset}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"               # full cohort, used to build the tokenizer
pretrain_data_path = f"{processed_dir}/mimic_pretrain.pkl"  # pretraining split

In [6]:
# Load the full processed cohort; it is used here only to build tokenizer vocabularies
# and to find the maximum number of admissions per patient.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open(full_data_path, "rb") as f:  # context manager closes the handle (the bare open() leaked it)
    ehr_full_data = pickle.load(f)

group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Each distinct age becomes its own one-token "sentence", i.e. [[age], [age], ...] — the nested
# list shape matters (presumably it mirrors the per-row code lists above; TODO confirm).
# NOTE(review): set() iteration order is not guaranteed to be stable across processes (e.g. for
# string ages under hash randomization), so age token ids may differ between runs unless the
# tokenizer orders them itself. Do not change the ordering here without changing every notebook
# that rebuilds this tokenizer for the saved checkpoint.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
adm_type_sentences = ehr_full_data["ADMISSION_TYPE"].values.tolist()

# Largest number of distinct admissions (HADM_ID) belonging to a single patient (SUBJECT_ID).
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the shared tokenizer from every code family collected above.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
    adm_types_sentences=adm_type_sentences,
)

# Per-type vocabulary sizes, later passed to HeteroGTPreTrain as `label_vocab_size`.
config.label_vocab_size = {kind: tokenizer.token_number(kind) for kind in ("diag", "med", "lab", "pro")}
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 20
Group code vocabulary size: 19


In [8]:
# load pretrain data
# Load the pretraining split (a separate pickle from the tokenizer cohort).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open(pretrain_data_path, "rb") as f:  # context manager closes the handle (the bare open() leaked it)
    ehr_pretrain_data = pickle.load(f)
# TODO: a step to load "occurrence data" was noted here but is not implemented — confirm whether it is needed.

In [9]:
# Pretraining dataset; `mask_rate` presumably controls the fraction of codes masked per sample (TODO confirm).
pretrain_dataset = PreTrainEHRDataset(
    ehr_pretrain_data=ehr_pretrain_data,
    tokenizer=tokenizer,
    token_type=config.token_type,
    mask_rate=config.pretrain_mask_rate,
    group_code_thre=config.group_code_thre,
    max_num_adms=config.max_num_adms,
)
print("Number of pretrain samples:", len(pretrain_dataset))

# Collate function for pretraining batches, configured for the number of token types in use.
pretrain_collate_fn = batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=True)
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=config.batch_size,
    collate_fn=pretrain_collate_fn,
    shuffle=True,
)

Number of pretrain samples: 42496


In [10]:
# Layer stack alternating 'gnn' and 'tf' layers (presumably graph and transformer layers).
pretrain_layer_types = ['gnn', 'tf', 'gnn', 'tf']

model = HeteroGTPreTrain(
    tokenizer=tokenizer,
    token_types=config.token_type,
    d_model=config.d_model,
    num_heads=config.num_heads,
    layer_types=pretrain_layer_types,
    max_num_adms=config.max_num_adms,
    device=device,
    label_vocab_size=config.label_vocab_size,
    attn_mask_dicts=config.attn_mask_dicts,
    use_cls_cat=True,
)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

In [11]:
# 统一权重表：未指定的一律 1.0
# Unified weight table: the per-token-type losses get weight 1.0;
# the auxiliary losses take their weights from config.
loss_weights = dict.fromkeys(config.token_type, 1.0)
loss_weights["cls_ontology"] = float(config.cls_ontology_weight)
loss_weights["visit_ontology"] = float(config.visit_ontology_weight)
loss_weights["adm_type"] = float(config.adm_type_weight)

# Every loss the training loop reads from the model output, in table order.
loss_types = [*loss_weights]
print(loss_weights)

{'diag': 1.0, 'med': 1.0, 'lab': 1.0, 'pro': 1.0, 'cls_ontology': 0.05, 'visit_ontology': 0.05, 'adm_type': 0.05}


In [12]:
# Pretraining loop: optimize the weighted sum of all losses returned by the model and log
# per-epoch averages of both the raw (unweighted) losses and their weighted contributions.
for epoch in range(1, 1 + config.epochs):
    model.train()
    avg_total_loss = 0.0
    avg_loss_dict = {t: 0.0 for t in loss_types}      # unweighted per-loss sums -> averages
    avg_contrib_dict = {t: 0.0 for t in loss_types}   # weighted contributions -> averages
    steps = 0  # counted explicitly so an empty dataloader cannot leave it unbound
    step_iter = tqdm(pretrain_dataloader, desc=f"Epoch {epoch:03d}", unit="batch")

    for batch in step_iter:
        # Move tensor fields to the training device; leave non-tensor fields untouched.
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss_dict = model(*batch)   # {loss_type: single-element loss tensor}; must contain every key in loss_types

        # Total loss (weighted)
        total_loss = sum(loss_weights[t] * loss_dict[t] for t in loss_types)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Accumulate. Each loss is read off the device once; its weighted contribution is
        # derived from that same value instead of a second tensor multiply + .item() sync.
        avg_total_loss += total_loss.item()
        for t in loss_types:
            raw_loss = loss_dict[t].item()
            avg_loss_dict[t] += raw_loss                      # raw loss
            avg_contrib_dict[t] += loss_weights[t] * raw_loss  # weighted contribution
        steps += 1

    if steps == 0:
        raise ValueError("pretrain_dataloader yielded no batches; cannot compute epoch loss averages")

    avg_total_loss /= steps
    avg_loss_dict = {t: v / steps for t, v in avg_loss_dict.items()}
    avg_contrib_dict = {t: v / steps for t, v in avg_contrib_dict.items()}

    # Log strings
    raw_loss_str = ", ".join(f"{t}: {avg_loss_dict[t]:.4f}" for t in loss_types)
    contrib_str  = ", ".join(f"{t}: {avg_contrib_dict[t]:.4f}" for t in loss_types)

    print(f"[Epoch {epoch:03d}] Weighted total loss: {avg_total_loss:.4f}\n"
          f"  Raw losses      | {raw_loss_str}\n"
          f"  Contributions   | {contrib_str}")

Epoch 001: 100%|██████████| 1328/1328 [01:16<00:00, 17.38batch/s]


[Epoch 001] Weighted total loss: 0.2012
  Raw losses      | diag: 0.0227, med: 0.0594, lab: 0.0453, pro: 0.0185, cls_ontology: 0.4284, visit_ontology: 0.4298, adm_type: 0.2459
  Contributions   | diag: 0.0227, med: 0.0594, lab: 0.0453, pro: 0.0185, cls_ontology: 0.0214, visit_ontology: 0.0215, adm_type: 0.0123


Epoch 002: 100%|██████████| 1328/1328 [01:18<00:00, 16.85batch/s]


[Epoch 002] Weighted total loss: 0.1672
  Raw losses      | diag: 0.0175, med: 0.0499, lab: 0.0379, pro: 0.0120, cls_ontology: 0.3995, visit_ontology: 0.4003, adm_type: 0.1978
  Contributions   | diag: 0.0175, med: 0.0499, lab: 0.0379, pro: 0.0120, cls_ontology: 0.0200, visit_ontology: 0.0200, adm_type: 0.0099


Epoch 003: 100%|██████████| 1328/1328 [01:16<00:00, 17.26batch/s]


[Epoch 003] Weighted total loss: 0.1605
  Raw losses      | diag: 0.0170, med: 0.0480, lab: 0.0366, pro: 0.0108, cls_ontology: 0.3873, visit_ontology: 0.3902, adm_type: 0.1824
  Contributions   | diag: 0.0170, med: 0.0480, lab: 0.0366, pro: 0.0108, cls_ontology: 0.0194, visit_ontology: 0.0195, adm_type: 0.0091


Epoch 004: 100%|██████████| 1328/1328 [01:18<00:00, 16.94batch/s]


[Epoch 004] Weighted total loss: 0.1564
  Raw losses      | diag: 0.0166, med: 0.0467, lab: 0.0360, pro: 0.0102, cls_ontology: 0.3802, visit_ontology: 0.3817, adm_type: 0.1766
  Contributions   | diag: 0.0166, med: 0.0467, lab: 0.0360, pro: 0.0102, cls_ontology: 0.0190, visit_ontology: 0.0191, adm_type: 0.0088


Epoch 005: 100%|██████████| 1328/1328 [01:16<00:00, 17.31batch/s]


[Epoch 005] Weighted total loss: 0.1538
  Raw losses      | diag: 0.0164, med: 0.0459, lab: 0.0356, pro: 0.0097, cls_ontology: 0.3767, visit_ontology: 0.3772, adm_type: 0.1725
  Contributions   | diag: 0.0164, med: 0.0459, lab: 0.0356, pro: 0.0097, cls_ontology: 0.0188, visit_ontology: 0.0189, adm_type: 0.0086


Epoch 006: 100%|██████████| 1328/1328 [01:17<00:00, 17.09batch/s]


[Epoch 006] Weighted total loss: 0.1518
  Raw losses      | diag: 0.0162, med: 0.0453, lab: 0.0352, pro: 0.0093, cls_ontology: 0.3730, visit_ontology: 0.3734, adm_type: 0.1698
  Contributions   | diag: 0.0162, med: 0.0453, lab: 0.0352, pro: 0.0093, cls_ontology: 0.0187, visit_ontology: 0.0187, adm_type: 0.0085


Epoch 007: 100%|██████████| 1328/1328 [01:18<00:00, 16.91batch/s]


[Epoch 007] Weighted total loss: 0.1501
  Raw losses      | diag: 0.0160, med: 0.0448, lab: 0.0349, pro: 0.0090, cls_ontology: 0.3700, visit_ontology: 0.3698, adm_type: 0.1675
  Contributions   | diag: 0.0160, med: 0.0448, lab: 0.0349, pro: 0.0090, cls_ontology: 0.0185, visit_ontology: 0.0185, adm_type: 0.0084


Epoch 008: 100%|██████████| 1328/1328 [01:16<00:00, 17.26batch/s]


[Epoch 008] Weighted total loss: 0.1487
  Raw losses      | diag: 0.0159, med: 0.0443, lab: 0.0347, pro: 0.0088, cls_ontology: 0.3676, visit_ontology: 0.3673, adm_type: 0.1655
  Contributions   | diag: 0.0159, med: 0.0443, lab: 0.0347, pro: 0.0088, cls_ontology: 0.0184, visit_ontology: 0.0184, adm_type: 0.0083


Epoch 009: 100%|██████████| 1328/1328 [01:19<00:00, 16.76batch/s]


[Epoch 009] Weighted total loss: 0.1476
  Raw losses      | diag: 0.0158, med: 0.0440, lab: 0.0345, pro: 0.0086, cls_ontology: 0.3664, visit_ontology: 0.3654, adm_type: 0.1646
  Contributions   | diag: 0.0158, med: 0.0440, lab: 0.0345, pro: 0.0086, cls_ontology: 0.0183, visit_ontology: 0.0183, adm_type: 0.0082


Epoch 010: 100%|██████████| 1328/1328 [01:14<00:00, 17.73batch/s]


[Epoch 010] Weighted total loss: 0.1465
  Raw losses      | diag: 0.0157, med: 0.0436, lab: 0.0343, pro: 0.0085, cls_ontology: 0.3639, visit_ontology: 0.3628, adm_type: 0.1627
  Contributions   | diag: 0.0157, med: 0.0436, lab: 0.0343, pro: 0.0085, cls_ontology: 0.0182, visit_ontology: 0.0181, adm_type: 0.0081


Epoch 011: 100%|██████████| 1328/1328 [01:17<00:00, 17.05batch/s]


[Epoch 011] Weighted total loss: 0.1458
  Raw losses      | diag: 0.0156, med: 0.0435, lab: 0.0341, pro: 0.0084, cls_ontology: 0.3626, visit_ontology: 0.3612, adm_type: 0.1622
  Contributions   | diag: 0.0156, med: 0.0435, lab: 0.0341, pro: 0.0084, cls_ontology: 0.0181, visit_ontology: 0.0181, adm_type: 0.0081


Epoch 012: 100%|██████████| 1328/1328 [01:17<00:00, 17.21batch/s]


[Epoch 012] Weighted total loss: 0.1451
  Raw losses      | diag: 0.0155, med: 0.0433, lab: 0.0340, pro: 0.0082, cls_ontology: 0.3619, visit_ontology: 0.3603, adm_type: 0.1608
  Contributions   | diag: 0.0155, med: 0.0433, lab: 0.0340, pro: 0.0082, cls_ontology: 0.0181, visit_ontology: 0.0180, adm_type: 0.0080


Epoch 013: 100%|██████████| 1328/1328 [01:18<00:00, 16.97batch/s]


[Epoch 013] Weighted total loss: 0.1444
  Raw losses      | diag: 0.0154, med: 0.0430, lab: 0.0338, pro: 0.0082, cls_ontology: 0.3606, visit_ontology: 0.3591, adm_type: 0.1597
  Contributions   | diag: 0.0154, med: 0.0430, lab: 0.0338, pro: 0.0082, cls_ontology: 0.0180, visit_ontology: 0.0180, adm_type: 0.0080


Epoch 014: 100%|██████████| 1328/1328 [01:17<00:00, 17.20batch/s]


[Epoch 014] Weighted total loss: 0.1439
  Raw losses      | diag: 0.0154, med: 0.0429, lab: 0.0337, pro: 0.0081, cls_ontology: 0.3597, visit_ontology: 0.3578, adm_type: 0.1590
  Contributions   | diag: 0.0154, med: 0.0429, lab: 0.0337, pro: 0.0081, cls_ontology: 0.0180, visit_ontology: 0.0179, adm_type: 0.0080


Epoch 015: 100%|██████████| 1328/1328 [01:15<00:00, 17.51batch/s]


[Epoch 015] Weighted total loss: 0.1432
  Raw losses      | diag: 0.0153, med: 0.0427, lab: 0.0335, pro: 0.0080, cls_ontology: 0.3587, visit_ontology: 0.3566, adm_type: 0.1580
  Contributions   | diag: 0.0153, med: 0.0427, lab: 0.0335, pro: 0.0080, cls_ontology: 0.0179, visit_ontology: 0.0178, adm_type: 0.0079


Epoch 016: 100%|██████████| 1328/1328 [01:18<00:00, 16.86batch/s]


[Epoch 016] Weighted total loss: 0.1428
  Raw losses      | diag: 0.0153, med: 0.0426, lab: 0.0334, pro: 0.0080, cls_ontology: 0.3576, visit_ontology: 0.3556, adm_type: 0.1573
  Contributions   | diag: 0.0153, med: 0.0426, lab: 0.0334, pro: 0.0080, cls_ontology: 0.0179, visit_ontology: 0.0178, adm_type: 0.0079


Epoch 017: 100%|██████████| 1328/1328 [01:16<00:00, 17.46batch/s]


[Epoch 017] Weighted total loss: 0.1423
  Raw losses      | diag: 0.0152, med: 0.0424, lab: 0.0333, pro: 0.0079, cls_ontology: 0.3570, visit_ontology: 0.3552, adm_type: 0.1564
  Contributions   | diag: 0.0152, med: 0.0424, lab: 0.0333, pro: 0.0079, cls_ontology: 0.0179, visit_ontology: 0.0178, adm_type: 0.0078


Epoch 018: 100%|██████████| 1328/1328 [01:18<00:00, 16.88batch/s]


[Epoch 018] Weighted total loss: 0.1419
  Raw losses      | diag: 0.0152, med: 0.0422, lab: 0.0332, pro: 0.0079, cls_ontology: 0.3565, visit_ontology: 0.3549, adm_type: 0.1554
  Contributions   | diag: 0.0152, med: 0.0422, lab: 0.0332, pro: 0.0079, cls_ontology: 0.0178, visit_ontology: 0.0177, adm_type: 0.0078


Epoch 019: 100%|██████████| 1328/1328 [01:14<00:00, 17.86batch/s]


[Epoch 019] Weighted total loss: 0.1416
  Raw losses      | diag: 0.0151, med: 0.0422, lab: 0.0332, pro: 0.0079, cls_ontology: 0.3559, visit_ontology: 0.3541, adm_type: 0.1551
  Contributions   | diag: 0.0151, med: 0.0422, lab: 0.0332, pro: 0.0079, cls_ontology: 0.0178, visit_ontology: 0.0177, adm_type: 0.0078


Epoch 020: 100%|██████████| 1328/1328 [01:18<00:00, 16.96batch/s]


[Epoch 020] Weighted total loss: 0.1412
  Raw losses      | diag: 0.0151, med: 0.0420, lab: 0.0331, pro: 0.0078, cls_ontology: 0.3554, visit_ontology: 0.3533, adm_type: 0.1545
  Contributions   | diag: 0.0151, med: 0.0420, lab: 0.0331, pro: 0.0078, cls_ontology: 0.0178, visit_ontology: 0.0177, adm_type: 0.0077


Epoch 021: 100%|██████████| 1328/1328 [01:17<00:00, 17.19batch/s]


[Epoch 021] Weighted total loss: 0.1409
  Raw losses      | diag: 0.0151, med: 0.0419, lab: 0.0330, pro: 0.0078, cls_ontology: 0.3546, visit_ontology: 0.3529, adm_type: 0.1544
  Contributions   | diag: 0.0151, med: 0.0419, lab: 0.0330, pro: 0.0078, cls_ontology: 0.0177, visit_ontology: 0.0176, adm_type: 0.0077


Epoch 022: 100%|██████████| 1328/1328 [01:18<00:00, 16.91batch/s]


[Epoch 022] Weighted total loss: 0.1405
  Raw losses      | diag: 0.0151, med: 0.0418, lab: 0.0329, pro: 0.0078, cls_ontology: 0.3539, visit_ontology: 0.3517, adm_type: 0.1536
  Contributions   | diag: 0.0151, med: 0.0418, lab: 0.0329, pro: 0.0078, cls_ontology: 0.0177, visit_ontology: 0.0176, adm_type: 0.0077


Epoch 023: 100%|██████████| 1328/1328 [01:16<00:00, 17.35batch/s]


[Epoch 023] Weighted total loss: 0.1402
  Raw losses      | diag: 0.0150, med: 0.0416, lab: 0.0329, pro: 0.0077, cls_ontology: 0.3536, visit_ontology: 0.3514, adm_type: 0.1537
  Contributions   | diag: 0.0150, med: 0.0416, lab: 0.0329, pro: 0.0077, cls_ontology: 0.0177, visit_ontology: 0.0176, adm_type: 0.0077


Epoch 024: 100%|██████████| 1328/1328 [01:17<00:00, 17.22batch/s]


[Epoch 024] Weighted total loss: 0.1399
  Raw losses      | diag: 0.0150, med: 0.0415, lab: 0.0328, pro: 0.0077, cls_ontology: 0.3533, visit_ontology: 0.3511, adm_type: 0.1526
  Contributions   | diag: 0.0150, med: 0.0415, lab: 0.0328, pro: 0.0077, cls_ontology: 0.0177, visit_ontology: 0.0176, adm_type: 0.0076


Epoch 025: 100%|██████████| 1328/1328 [01:18<00:00, 17.02batch/s]

[Epoch 025] Weighted total loss: 0.1398
  Raw losses      | diag: 0.0150, med: 0.0416, lab: 0.0328, pro: 0.0077, cls_ontology: 0.3527, visit_ontology: 0.3504, adm_type: 0.1523
  Contributions   | diag: 0.0150, med: 0.0416, lab: 0.0328, pro: 0.0077, cls_ontology: 0.0176, visit_ontology: 0.0175, adm_type: 0.0076





In [13]:
import os

# Experiment name encodes the main pretraining hyperparameters so different runs don't overwrite each other.
exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}-{config.cls_ontology_weight}-{config.visit_ontology_weight}-{config.adm_type_weight}"
)
save_path = "./pretrained_models/" + exp_name
os.makedirs(save_path, exist_ok=True)

# Save a CPU-resident state_dict (device-agnostic checkpoint). model.cpu() relocates the live
# model in place, so move it back to `device` afterwards; otherwise any later cell (or a re-run of
# the training loop) would feed device batches to a CPU model.
try:
    torch.save(model.cpu().state_dict(), f"{save_path}/pretrained_model.pt")
finally:
    model.to(device)
print("Save model:", exp_name)

Save model: MIMIC-IV-0.7-64-0.05-0.05-0.05
