In [1]:
import os
import pickle
from argparse import Namespace

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import PreTrainEHRDataset, batcher, expand_level3
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTPreTrain

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the RNG seed before any data shuffling/masking or model construction happens.
# Which generators are seeded (python / numpy / torch / cuda) is decided by
# heterogt.utils.seed.set_random_seed — TODO confirm it covers all of them.
RANDOM_SEED = 123
set_random_seed(RANDOM_SEED)

[INFO] Random seed set to 123


In [3]:
# Use the default CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# Hyper-parameters for pretraining. Later cells add derived fields
# (max_num_adms, label_vocab_size, global_vocab_size, age_vocab_size, group_code_vocab_size).
config = Namespace(
    dataset="MIMIC-III",
    # Downstream tasks and the selected index; no cell in this notebook reads them.
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,  # index into `tasks` of the task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # Two identical attention-mask specs; the comprehension builds a fresh dict (and fresh
    # lists) per entry, so mutating one spec never affects the other. Presumably one spec per
    # 'tf' layer of the model (layer_types has two) — TODO confirm in HeteroGTPreTrain.
    attn_mask_dicts=[
        {1: [6, 7], 2: [6, 7], 3: [6, 7], 4: [6, 7], 6: [2, 3, 4, 5, 6, 7], 7: [2, 3, 4, 5, 6, 7]}
        for _ in range(2)
    ],
    d_model=64,
    num_heads=4,
    batch_size=32,
    lr=1e-3,
    # NOTE(review): the saved training log further down stops at epoch 20, so this cell was
    # edited (execution count "None") after that run was recorded.
    epochs=25,
    # NOTE(review): not read by the training loop — no early stopping is performed.
    early_stop_patience=5,
    # A group code is generated when group_code_thre diag codes belong to the same group ICD
    # code (exact threshold comparison lives in PreTrainEHRDataset).
    group_code_thre=5,
    pretrain_mask_rate=0.7,
    cls_ontology_weight=5e-2,
    visit_ontology_weight=5e-2,
)

In [5]:
# Both pickles live in the dataset-specific processed-data folder.
full_data_path = "./data_process/{}-processed/mimic.pkl".format(config.dataset)  # builds the tokenizer
pretrain_data_path = "./data_process/{}-processed/mimic_pretrain.pkl".format(config.dataset)  # pretraining records

In [6]:
# Load the full processed cohort; below it feeds the tokenizer vocabularies and sizes the
# admission axis. `with` closes the file handle (the bare open() leaked it).
# NOTE: pickle.load can execute arbitrary code — only load pickles produced by this
# project's own preprocessing.
with open(full_data_path, "rb") as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# Group-code vocabulary from the project's ICD grouping helper (presumably level-3 ICD-9
# groups — TODO confirm against expand_level3); wrapped as a single "sentence".
group_code_sentences = [expand_level3()[1]]
# One entry per row for each code type (each cell presumably holds that row's code list).
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Every distinct age value becomes its own one-token sentence, hence the nested [[...]].
# NOTE(review): set() iteration order is the order ages reach the tokenizer (which presumably
# assigns ids in input order). If AGE holds strings, that order varies with PYTHONHASHSEED
# between interpreter runs, so age ids could differ from a tokenizer rebuilt elsewhere.
# Kept as-is for compatibility with existing checkpoints; consider sorted(...) here AND in
# the fine-tuning notebook together.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
# Largest number of distinct admissions of any patient; passed to the dataset and model as
# max_num_adms. Cast to a plain int so the config holds no numpy scalar.
max_admissions = int(ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max())
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build one shared tokenizer over ages, group codes and the four clinical code types.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)
# Per-type vocabulary sizes handed to the model as label_vocab_size.
config.label_vocab_size = {
    code_type: tokenizer.token_number(code_type) for code_type in ("diag", "med", "lab", "pro")
}
config.global_vocab_size = len(tokenizer.vocab.id2word)  # size of the combined id2word table
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")
print("Age vocabulary size: {}".format(config.age_vocab_size))
print("Group code vocabulary size: {}".format(config.group_code_vocab_size))

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pretraining records. `with` closes the file handle (the bare open() leaked it).
# NOTE: pickle.load can execute arbitrary code — only load project-generated pickles.
with open(pretrain_data_path, "rb") as pretrain_data_file:
    ehr_pretrain_data = pickle.load(pretrain_data_file)

In [9]:
# Pretraining dataset: presumably masks a `mask_rate` fraction of codes per sample and adds
# group codes past `group_code_thre` (see PreTrainEHRDataset — TODO confirm).
pretrain_dataset = PreTrainEHRDataset(
    ehr_pretrain_data=ehr_pretrain_data,
    tokenizer=tokenizer,
    token_type=config.token_type,
    mask_rate=config.pretrain_mask_rate,
    group_code_thre=config.group_code_thre,
    max_num_adms=config.max_num_adms,
)
print(f"Number of pretrain samples: {len(pretrain_dataset)}")
# Shuffled mini-batches collated by the project's batcher in pretraining mode.
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=batcher(tokenizer, n_token_type=len(config.token_type), is_pretrain=True),
)

Number of pretrain samples: 23146


In [10]:
# Encoder stack alternates 'gnn' and 'tf' layers (presumably graph and transformer layers —
# see HeteroGTPreTrain); the model is moved onto `device` before the optimizer sees it.
model = HeteroGTPreTrain(
    tokenizer=tokenizer,
    token_types=config.token_type,
    d_model=config.d_model,
    num_heads=config.num_heads,
    layer_types=["gnn", "tf", "gnn", "tf"],
    max_num_adms=config.max_num_adms,
    device=device,
    label_vocab_size=config.label_vocab_size,
    attn_mask_dicts=config.attn_mask_dicts,
    use_cls_cat=True,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

In [11]:
# Unified loss-weight table: every per-token-type loss gets weight 1.0; the two ontology
# losses use the configured weights. Insertion order fixes the logging order below.
loss_weights = dict.fromkeys(config.token_type, 1.0)
loss_weights["cls_ontology"] = float(config.cls_ontology_weight)
loss_weights["visit_ontology"] = float(config.visit_ontology_weight)
loss_types = [*loss_weights]
print(loss_weights)

{'diag': 1.0, 'med': 1.0, 'lab': 1.0, 'pro': 1.0, 'cls_ontology': 0.05, 'visit_ontology': 0.05}


In [12]:
# Pretraining loop: optimise the weighted sum of every loss the model returns, and report
# per-epoch means of both the raw (unweighted) losses and their weighted contributions.
# NOTE(review): config.early_stop_patience is never consulted — all config.epochs epochs run.
for epoch in range(1, 1 + config.epochs):
    model.train()
    epoch_total_loss = 0.0
    epoch_raw_loss = {t: 0.0 for t in loss_types}   # unweighted per-loss sums
    epoch_contrib = {t: 0.0 for t in loss_types}    # weighted per-loss sums
    n_batches = 0  # explicit counter: the old `step + 1` relied on a leaked loop variable
    step_iter = tqdm(pretrain_dataloader, desc=f"Epoch {epoch:03d}", unit="batch")

    for batch in step_iter:
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss_dict = model(*batch)   # {loss_type: single-element loss tensor}

        # Weighted total loss drives the backward pass.
        total_loss = sum(loss_weights[t] * loss_dict[t] for t in loss_types)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Accumulate for logging: read each loss to the host once, then weight it there.
        epoch_total_loss += total_loss.item()
        for t in loss_types:
            raw_value = loss_dict[t].item()
            epoch_raw_loss[t] += raw_value                     # raw loss
            epoch_contrib[t] += loss_weights[t] * raw_value    # weighted contribution
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("pretrain_dataloader yielded no batches; check the pretraining dataset.")

    avg_total_loss = epoch_total_loss / n_batches
    avg_loss_dict = {t: v / n_batches for t, v in epoch_raw_loss.items()}
    avg_contrib_dict = {t: v / n_batches for t, v in epoch_contrib.items()}

    # Log strings
    raw_loss_str = ", ".join(f"{t}: {avg_loss_dict[t]:.4f}" for t in loss_types)
    contrib_str = ", ".join(f"{t}: {avg_contrib_dict[t]:.4f}" for t in loss_types)

    print(f"[Epoch {epoch:03d}] Weighted total loss: {avg_total_loss:.4f}\n"
          f"  Raw losses      | {raw_loss_str}\n"
          f"  Contributions   | {contrib_str}")

Epoch 001:   0%|          | 0/724 [00:00<?, ?batch/s]

Epoch 001: 100%|██████████| 724/724 [00:45<00:00, 15.76batch/s]


[Epoch 001] Weighted total loss: 0.2842
  Raw losses      | diag: 0.0264, med: 0.1061, lab: 0.0836, pro: 0.0254, cls_ontology: 0.4257, visit_ontology: 0.4275
  Contributions   | diag: 0.0264, med: 0.1061, lab: 0.0836, pro: 0.0254, cls_ontology: 0.0213, visit_ontology: 0.0214


Epoch 002: 100%|██████████| 724/724 [00:50<00:00, 14.30batch/s]


[Epoch 002] Weighted total loss: 0.2363
  Raw losses      | diag: 0.0174, med: 0.0920, lab: 0.0728, pro: 0.0158, cls_ontology: 0.3820, visit_ontology: 0.3853
  Contributions   | diag: 0.0174, med: 0.0920, lab: 0.0728, pro: 0.0158, cls_ontology: 0.0191, visit_ontology: 0.0193


Epoch 003: 100%|██████████| 724/724 [00:48<00:00, 15.02batch/s]


[Epoch 003] Weighted total loss: 0.2300
  Raw losses      | diag: 0.0170, med: 0.0885, lab: 0.0718, pro: 0.0150, cls_ontology: 0.3750, visit_ontology: 0.3773
  Contributions   | diag: 0.0170, med: 0.0885, lab: 0.0718, pro: 0.0150, cls_ontology: 0.0187, visit_ontology: 0.0189


Epoch 004: 100%|██████████| 724/724 [00:48<00:00, 14.79batch/s]


[Epoch 004] Weighted total loss: 0.2266
  Raw losses      | diag: 0.0168, med: 0.0869, lab: 0.0712, pro: 0.0146, cls_ontology: 0.3714, visit_ontology: 0.3725
  Contributions   | diag: 0.0168, med: 0.0869, lab: 0.0712, pro: 0.0146, cls_ontology: 0.0186, visit_ontology: 0.0186


Epoch 005: 100%|██████████| 724/724 [00:51<00:00, 14.15batch/s]


[Epoch 005] Weighted total loss: 0.2233
  Raw losses      | diag: 0.0166, med: 0.0849, lab: 0.0707, pro: 0.0143, cls_ontology: 0.3683, visit_ontology: 0.3695
  Contributions   | diag: 0.0166, med: 0.0849, lab: 0.0707, pro: 0.0143, cls_ontology: 0.0184, visit_ontology: 0.0185


Epoch 006: 100%|██████████| 724/724 [00:51<00:00, 14.08batch/s]


[Epoch 006] Weighted total loss: 0.2205
  Raw losses      | diag: 0.0164, med: 0.0833, lab: 0.0703, pro: 0.0140, cls_ontology: 0.3651, visit_ontology: 0.3670
  Contributions   | diag: 0.0164, med: 0.0833, lab: 0.0703, pro: 0.0140, cls_ontology: 0.0183, visit_ontology: 0.0183


Epoch 007: 100%|██████████| 724/724 [00:47<00:00, 15.30batch/s]


[Epoch 007] Weighted total loss: 0.2181
  Raw losses      | diag: 0.0162, med: 0.0818, lab: 0.0699, pro: 0.0138, cls_ontology: 0.3623, visit_ontology: 0.3637
  Contributions   | diag: 0.0162, med: 0.0818, lab: 0.0699, pro: 0.0138, cls_ontology: 0.0181, visit_ontology: 0.0182


Epoch 008: 100%|██████████| 724/724 [00:52<00:00, 13.78batch/s]


[Epoch 008] Weighted total loss: 0.2162
  Raw losses      | diag: 0.0161, med: 0.0808, lab: 0.0696, pro: 0.0136, cls_ontology: 0.3602, visit_ontology: 0.3619
  Contributions   | diag: 0.0161, med: 0.0808, lab: 0.0696, pro: 0.0136, cls_ontology: 0.0180, visit_ontology: 0.0181


Epoch 009: 100%|██████████| 724/724 [00:48<00:00, 14.94batch/s]


[Epoch 009] Weighted total loss: 0.2149
  Raw losses      | diag: 0.0160, med: 0.0803, lab: 0.0693, pro: 0.0134, cls_ontology: 0.3585, visit_ontology: 0.3602
  Contributions   | diag: 0.0160, med: 0.0803, lab: 0.0693, pro: 0.0134, cls_ontology: 0.0179, visit_ontology: 0.0180


Epoch 010: 100%|██████████| 724/724 [00:51<00:00, 14.04batch/s]


[Epoch 010] Weighted total loss: 0.2134
  Raw losses      | diag: 0.0159, med: 0.0794, lab: 0.0691, pro: 0.0132, cls_ontology: 0.3569, visit_ontology: 0.3584
  Contributions   | diag: 0.0159, med: 0.0794, lab: 0.0691, pro: 0.0132, cls_ontology: 0.0178, visit_ontology: 0.0179


Epoch 011: 100%|██████████| 724/724 [00:49<00:00, 14.64batch/s]


[Epoch 011] Weighted total loss: 0.2123
  Raw losses      | diag: 0.0158, med: 0.0788, lab: 0.0689, pro: 0.0131, cls_ontology: 0.3558, visit_ontology: 0.3581
  Contributions   | diag: 0.0158, med: 0.0788, lab: 0.0689, pro: 0.0131, cls_ontology: 0.0178, visit_ontology: 0.0179


Epoch 012: 100%|██████████| 724/724 [00:47<00:00, 15.10batch/s]


[Epoch 012] Weighted total loss: 0.2112
  Raw losses      | diag: 0.0157, med: 0.0783, lab: 0.0687, pro: 0.0129, cls_ontology: 0.3545, visit_ontology: 0.3564
  Contributions   | diag: 0.0157, med: 0.0783, lab: 0.0687, pro: 0.0129, cls_ontology: 0.0177, visit_ontology: 0.0178


Epoch 013: 100%|██████████| 724/724 [00:51<00:00, 14.02batch/s]


[Epoch 013] Weighted total loss: 0.2104
  Raw losses      | diag: 0.0157, med: 0.0778, lab: 0.0686, pro: 0.0128, cls_ontology: 0.3541, visit_ontology: 0.3560
  Contributions   | diag: 0.0157, med: 0.0778, lab: 0.0686, pro: 0.0128, cls_ontology: 0.0177, visit_ontology: 0.0178


Epoch 014: 100%|██████████| 724/724 [00:50<00:00, 14.28batch/s]


[Epoch 014] Weighted total loss: 0.2097
  Raw losses      | diag: 0.0156, med: 0.0776, lab: 0.0683, pro: 0.0128, cls_ontology: 0.3542, visit_ontology: 0.3556
  Contributions   | diag: 0.0156, med: 0.0776, lab: 0.0683, pro: 0.0128, cls_ontology: 0.0177, visit_ontology: 0.0178


Epoch 015: 100%|██████████| 724/724 [00:53<00:00, 13.62batch/s]


[Epoch 015] Weighted total loss: 0.2090
  Raw losses      | diag: 0.0155, med: 0.0773, lab: 0.0682, pro: 0.0126, cls_ontology: 0.3520, visit_ontology: 0.3535
  Contributions   | diag: 0.0155, med: 0.0773, lab: 0.0682, pro: 0.0126, cls_ontology: 0.0176, visit_ontology: 0.0177


Epoch 016: 100%|██████████| 724/724 [00:47<00:00, 15.18batch/s]


[Epoch 016] Weighted total loss: 0.2081
  Raw losses      | diag: 0.0154, med: 0.0769, lab: 0.0681, pro: 0.0125, cls_ontology: 0.3504, visit_ontology: 0.3520
  Contributions   | diag: 0.0154, med: 0.0769, lab: 0.0681, pro: 0.0125, cls_ontology: 0.0175, visit_ontology: 0.0176


Epoch 017: 100%|██████████| 724/724 [00:54<00:00, 13.25batch/s]


[Epoch 017] Weighted total loss: 0.2075
  Raw losses      | diag: 0.0154, med: 0.0766, lab: 0.0680, pro: 0.0124, cls_ontology: 0.3504, visit_ontology: 0.3520
  Contributions   | diag: 0.0154, med: 0.0766, lab: 0.0680, pro: 0.0124, cls_ontology: 0.0175, visit_ontology: 0.0176


Epoch 018: 100%|██████████| 724/724 [00:48<00:00, 14.96batch/s]


[Epoch 018] Weighted total loss: 0.2067
  Raw losses      | diag: 0.0153, med: 0.0762, lab: 0.0678, pro: 0.0123, cls_ontology: 0.3493, visit_ontology: 0.3502
  Contributions   | diag: 0.0153, med: 0.0762, lab: 0.0678, pro: 0.0123, cls_ontology: 0.0175, visit_ontology: 0.0175


Epoch 019: 100%|██████████| 724/724 [00:50<00:00, 14.22batch/s]


[Epoch 019] Weighted total loss: 0.2064
  Raw losses      | diag: 0.0153, med: 0.0761, lab: 0.0678, pro: 0.0123, cls_ontology: 0.3484, visit_ontology: 0.3499
  Contributions   | diag: 0.0153, med: 0.0761, lab: 0.0678, pro: 0.0123, cls_ontology: 0.0174, visit_ontology: 0.0175


Epoch 020: 100%|██████████| 724/724 [00:52<00:00, 13.79batch/s]

[Epoch 020] Weighted total loss: 0.2056
  Raw losses      | diag: 0.0152, med: 0.0759, lab: 0.0676, pro: 0.0122, cls_ontology: 0.3470, visit_ontology: 0.3486
  Contributions   | diag: 0.0152, med: 0.0759, lab: 0.0676, pro: 0.0122, cls_ontology: 0.0174, visit_ontology: 0.0174





In [13]:
# Persist the pretrained weights under a directory name encoding the key hyper-parameters.
exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}-{config.cls_ontology_weight}-{config.visit_ontology_weight}"
)
save_path = os.path.join("./pretrained_models", exp_name)
os.makedirs(save_path, exist_ok=True)
# Save CPU tensors so the checkpoint loads without a GPU. Module.cpu() moves the model in
# place, so move it back to `device` afterwards; otherwise any later cell that feeds it
# `device` tensors (or keeps stepping `optimizer`) would hit a device mismatch.
model.cpu()
try:
    torch.save(model.state_dict(), os.path.join(save_path, "pretrained_model.pt"))
finally:
    model.to(device)
print("Save model:", exp_name)

Save model: MIMIC-III-0.7-64-0.05-0.05
