In [1]:
import torch
import pickle
from argparse import Namespace
from torch.utils.data import DataLoader
from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import PreTrainEHRDataset, batcher, expand_level3
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTPreTrain
from tqdm import tqdm

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the random seed before any data shuffling or model initialisation so the
# run is repeatable. NOTE(review): which RNGs (Python / NumPy / PyTorch / CUDA)
# the helper seeds is defined in heterogt.utils.seed — confirm there.
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Experiment settings for the pre-training run, collected in one namespace.
# Later cells attach derived values (vocabulary sizes, max admissions) to it.
config = Namespace(
    # --- data / task selection ---
    dataset="MIMIC-III",
    tasks=["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"],
    task_index=2,  # index of the task to train
    token_type=["diag", "med", "lab", "pro"],
    special_tokens=["[PAD]", "[CLS]"],
    # Two attention-mask dicts holding the same rule: keys 1-4 map to [6, 7]
    # and keys 6-7 map to [2, 3, 4, 5, 6, 7]. The comprehension builds fresh
    # dicts and lists on every pass, so the two entries share no objects.
    # NOTE(review): the meaning of the integer ids is defined by the model — confirm.
    attn_mask_dicts=[
        {
            **{key: [6, 7] for key in (1, 2, 3, 4)},
            **{key: [2, 3, 4, 5, 6, 7] for key in (6, 7)},
        }
        for _ in range(2)
    ],
    # --- model size ---
    d_model=64,
    num_heads=4,
    # --- optimisation ---
    batch_size=32,
    lr=1e-3,
    epochs=25,
    early_stop_patience=5,
    # A group code is generated once this many diag codes belong to the same
    # group ICD code.
    group_code_thre=5,
    pretrain_mask_rate=0.7,
    # --- weights of the auxiliary pre-training losses ---
    cls_ontology_weight=5e-2,
    visit_ontology_weight=5e-2,
    adm_type_weight=5e-2,
)

In [5]:
# Both pickles live in the dataset-specific "processed" directory.
processed_dir = f"./data_process/{config.dataset}-processed"
# Full cohort: used to build the tokenizer vocabulary.
full_data_path = f"{processed_dir}/mimic.pkl"
# Pre-training split: used to build the pre-training dataset.
pretrain_data_path = f"{processed_dir}/mimic_pretrain.pkl"

In [6]:
# Load the full EHR table and derive the per-field "sentences" that the
# tokenizer cell consumes, plus the maximum number of admissions per patient.
#
# SECURITY: pickle.load can execute arbitrary code while unpickling — only
# point full_data_path at files produced by this project's own preprocessing.
with open(full_data_path, "rb") as full_data_file:  # close the handle deterministically
    ehr_full_data = pickle.load(full_data_file)

# Second element of expand_level3()'s return value, wrapped as a single sentence.
group_code_sentences = [expand_level3()[1]]
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()
# Each distinct age becomes its own one-token sentence; the nested [[...]] shape
# matters (presumably the tokenizer expects a list of token lists — confirm).
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]
adm_type_sentences = ehr_full_data["ADMISSION_TYPE"].values.tolist()

# Largest number of distinct admissions (HADM_ID) recorded for any one patient.
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the tokenizer from every sentence collection gathered from the full data.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
    adm_types_sentences=adm_type_sentences,
)

# Record vocabulary sizes on the config so later cells can read them.
# One entry per code type, in the order diag, med, lab, pro.
config.label_vocab_size = {
    code_type: tokenizer.token_number(code_type)
    for code_type in ("diag", "med", "lab", "pro")
}
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pre-training split.
# SECURITY: pickle.load can execute arbitrary code while unpickling — only
# load files produced by this project's own preprocessing.
with open(pretrain_data_path, "rb") as pretrain_file:  # close the handle deterministically
    ehr_pretrain_data = pickle.load(pretrain_file)
# NOTE(review): this cell originally ended with a "load occurrence data"
# placeholder comment that had no code behind it; nothing else is loaded here.

In [9]:
# Wrap the pre-training records in a dataset, then in a shuffled DataLoader.
pretrain_dataset = PreTrainEHRDataset(
    ehr_pretrain_data=ehr_pretrain_data,
    tokenizer=tokenizer,
    token_type=config.token_type,
    mask_rate=config.pretrain_mask_rate,
    group_code_thre=config.group_code_thre,
    max_num_adms=config.max_num_adms,
)
print("Number of pretrain samples:", len(pretrain_dataset))

# Collate function produced by the project's batcher factory, in pre-training mode.
pretrain_collate_fn = batcher(
    tokenizer,
    n_token_type=len(config.token_type),
    is_pretrain=True,
)
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=config.batch_size,
    collate_fn=pretrain_collate_fn,
    shuffle=True,
)

Number of pretrain samples: 23146


In [10]:
# Instantiate the pre-training model on the selected device and attach an
# AdamW optimizer over all of its parameters.
model = HeteroGTPreTrain(
    tokenizer=tokenizer,
    token_types=config.token_type,
    d_model=config.d_model,
    num_heads=config.num_heads,
    # Four layers alternating between the 'gnn' and 'tf' layer types.
    layer_types=['gnn', 'tf', 'gnn', 'tf'],
    max_num_adms=config.max_num_adms,
    device=device,
    label_vocab_size=config.label_vocab_size,
    attn_mask_dicts=config.attn_mask_dicts,
    use_cls_cat=True,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

In [11]:
# Unified loss-weight table: every token-type loss gets weight 1.0, and the
# three auxiliary losses take their weights from the config.
loss_weights = {token_kind: 1.0 for token_kind in config.token_type}
loss_weights.update(
    cls_ontology=float(config.cls_ontology_weight),
    visit_ontology=float(config.visit_ontology_weight),
    adm_type=float(config.adm_type_weight),
)
# Ordered list of loss names; the training loop iterates over it.
loss_types = list(loss_weights)
print(loss_weights)

{'diag': 1.0, 'med': 1.0, 'lab': 1.0, 'pro': 1.0, 'cls_ontology': 0.05, 'visit_ontology': 0.05, 'adm_type': 0.05}


In [12]:
# Pre-training loop. Each epoch runs one pass over pretrain_dataloader,
# optimises the weighted sum of all losses, and prints the epoch averages of
# (a) the weighted total, (b) each raw loss, and (c) each weighted contribution.
for epoch in range(1, 1 + config.epochs):
    model.train()
    avg_total_loss = 0.0
    avg_loss_dict = {t: 0.0 for t in loss_types}         # unweighted per-loss sums
    avg_contrib_dict = {t: 0.0 for t in loss_types}      # weighted contribution sums
    steps = 0                                            # batches seen this epoch
    step_iter = tqdm(pretrain_dataloader, desc=f"Epoch {epoch:03d}", unit="batch")

    for batch in step_iter:
        # Move tensors to the training device; leave non-tensor items untouched.
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss_dict = model(*batch)   # {loss_type: tensor}

        # Weighted total loss that is actually optimised.
        total_loss = sum(loss_weights[t] * loss_dict[t] for t in loss_types)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Accumulate running sums for the epoch report.
        steps += 1
        avg_total_loss += total_loss.item()
        for t in loss_types:
            # Read the scalar once and reuse it for both sums, instead of
            # calling .item() a second time on the weighted tensor.
            raw_value = loss_dict[t].item()
            avg_loss_dict[t] += raw_value                      # raw loss
            avg_contrib_dict[t] += loss_weights[t] * raw_value  # weighted contribution

    # An empty dataloader would otherwise surface as an opaque NameError or
    # division by zero; fail with an explicit message instead.
    if steps == 0:
        raise RuntimeError("pretrain_dataloader yielded no batches; nothing to train on")

    avg_total_loss /= steps
    avg_loss_dict = {t: v / steps for t, v in avg_loss_dict.items()}
    avg_contrib_dict = {t: v / steps for t, v in avg_contrib_dict.items()}

    # Log strings.
    raw_loss_str = ", ".join(f"{t}: {avg_loss_dict[t]:.4f}" for t in loss_types)
    contrib_str  = ", ".join(f"{t}: {avg_contrib_dict[t]:.4f}" for t in loss_types)

    print(f"[Epoch {epoch:03d}] Weighted total loss: {avg_total_loss:.4f}\n"
          f"  Raw losses      | {raw_loss_str}\n"
          f"  Contributions   | {contrib_str}")

Epoch 001:   0%|          | 0/724 [00:00<?, ?batch/s]

Epoch 001: 100%|██████████| 724/724 [00:51<00:00, 14.13batch/s]


[Epoch 001] Weighted total loss: 0.2961
  Raw losses      | diag: 0.0264, med: 0.1056, lab: 0.0834, pro: 0.0253, cls_ontology: 0.4266, visit_ontology: 0.4245, adm_type: 0.2573
  Contributions   | diag: 0.0264, med: 0.1056, lab: 0.0834, pro: 0.0253, cls_ontology: 0.0213, visit_ontology: 0.0212, adm_type: 0.0129


Epoch 002: 100%|██████████| 724/724 [00:46<00:00, 15.44batch/s]


[Epoch 002] Weighted total loss: 0.2463
  Raw losses      | diag: 0.0175, med: 0.0921, lab: 0.0726, pro: 0.0160, cls_ontology: 0.3856, visit_ontology: 0.3866, adm_type: 0.1909
  Contributions   | diag: 0.0175, med: 0.0921, lab: 0.0726, pro: 0.0160, cls_ontology: 0.0193, visit_ontology: 0.0193, adm_type: 0.0095


Epoch 003: 100%|██████████| 724/724 [00:48<00:00, 14.96batch/s]


[Epoch 003] Weighted total loss: 0.2391
  Raw losses      | diag: 0.0169, med: 0.0894, lab: 0.0714, pro: 0.0150, cls_ontology: 0.3777, visit_ontology: 0.3806, adm_type: 0.1689
  Contributions   | diag: 0.0169, med: 0.0894, lab: 0.0714, pro: 0.0150, cls_ontology: 0.0189, visit_ontology: 0.0190, adm_type: 0.0084


Epoch 004: 100%|██████████| 724/724 [00:51<00:00, 14.18batch/s]


[Epoch 004] Weighted total loss: 0.2353
  Raw losses      | diag: 0.0167, med: 0.0879, lab: 0.0709, pro: 0.0145, cls_ontology: 0.3716, visit_ontology: 0.3753, adm_type: 0.1588
  Contributions   | diag: 0.0167, med: 0.0879, lab: 0.0709, pro: 0.0145, cls_ontology: 0.0186, visit_ontology: 0.0188, adm_type: 0.0079


Epoch 005: 100%|██████████| 724/724 [00:47<00:00, 15.40batch/s]


[Epoch 005] Weighted total loss: 0.2323
  Raw losses      | diag: 0.0165, med: 0.0865, lab: 0.0704, pro: 0.0142, cls_ontology: 0.3677, visit_ontology: 0.3724, adm_type: 0.1532
  Contributions   | diag: 0.0165, med: 0.0865, lab: 0.0704, pro: 0.0142, cls_ontology: 0.0184, visit_ontology: 0.0186, adm_type: 0.0077


Epoch 006: 100%|██████████| 724/724 [00:49<00:00, 14.69batch/s]


[Epoch 006] Weighted total loss: 0.2298
  Raw losses      | diag: 0.0163, med: 0.0853, lab: 0.0700, pro: 0.0140, cls_ontology: 0.3648, visit_ontology: 0.3697, adm_type: 0.1493
  Contributions   | diag: 0.0163, med: 0.0853, lab: 0.0700, pro: 0.0140, cls_ontology: 0.0182, visit_ontology: 0.0185, adm_type: 0.0075


Epoch 007: 100%|██████████| 724/724 [00:48<00:00, 14.84batch/s]


[Epoch 007] Weighted total loss: 0.2272
  Raw losses      | diag: 0.0161, med: 0.0839, lab: 0.0697, pro: 0.0137, cls_ontology: 0.3611, visit_ontology: 0.3666, adm_type: 0.1473
  Contributions   | diag: 0.0161, med: 0.0839, lab: 0.0697, pro: 0.0137, cls_ontology: 0.0181, visit_ontology: 0.0183, adm_type: 0.0074


Epoch 008: 100%|██████████| 724/724 [00:44<00:00, 16.25batch/s]


[Epoch 008] Weighted total loss: 0.2249
  Raw losses      | diag: 0.0160, med: 0.0824, lab: 0.0694, pro: 0.0135, cls_ontology: 0.3604, visit_ontology: 0.3658, adm_type: 0.1451
  Contributions   | diag: 0.0160, med: 0.0824, lab: 0.0694, pro: 0.0135, cls_ontology: 0.0180, visit_ontology: 0.0183, adm_type: 0.0073


Epoch 009: 100%|██████████| 724/724 [00:51<00:00, 14.07batch/s]


[Epoch 009] Weighted total loss: 0.2231
  Raw losses      | diag: 0.0159, med: 0.0814, lab: 0.0691, pro: 0.0134, cls_ontology: 0.3584, visit_ontology: 0.3642, adm_type: 0.1426
  Contributions   | diag: 0.0159, med: 0.0814, lab: 0.0691, pro: 0.0134, cls_ontology: 0.0179, visit_ontology: 0.0182, adm_type: 0.0071


Epoch 010: 100%|██████████| 724/724 [00:48<00:00, 14.86batch/s]


[Epoch 010] Weighted total loss: 0.2214
  Raw losses      | diag: 0.0158, med: 0.0804, lab: 0.0689, pro: 0.0132, cls_ontology: 0.3572, visit_ontology: 0.3625, adm_type: 0.1417
  Contributions   | diag: 0.0158, med: 0.0804, lab: 0.0689, pro: 0.0132, cls_ontology: 0.0179, visit_ontology: 0.0181, adm_type: 0.0071


Epoch 011: 100%|██████████| 724/724 [00:47<00:00, 15.37batch/s]


[Epoch 011] Weighted total loss: 0.2201
  Raw losses      | diag: 0.0157, med: 0.0799, lab: 0.0686, pro: 0.0130, cls_ontology: 0.3555, visit_ontology: 0.3607, adm_type: 0.1400
  Contributions   | diag: 0.0157, med: 0.0799, lab: 0.0686, pro: 0.0130, cls_ontology: 0.0178, visit_ontology: 0.0180, adm_type: 0.0070


Epoch 012: 100%|██████████| 724/724 [00:47<00:00, 15.11batch/s]


[Epoch 012] Weighted total loss: 0.2192
  Raw losses      | diag: 0.0157, med: 0.0794, lab: 0.0686, pro: 0.0129, cls_ontology: 0.3544, visit_ontology: 0.3595, adm_type: 0.1393
  Contributions   | diag: 0.0157, med: 0.0794, lab: 0.0686, pro: 0.0129, cls_ontology: 0.0177, visit_ontology: 0.0180, adm_type: 0.0070


Epoch 013: 100%|██████████| 724/724 [00:47<00:00, 15.32batch/s]


[Epoch 013] Weighted total loss: 0.2181
  Raw losses      | diag: 0.0156, med: 0.0790, lab: 0.0684, pro: 0.0128, cls_ontology: 0.3541, visit_ontology: 0.3587, adm_type: 0.1365
  Contributions   | diag: 0.0156, med: 0.0790, lab: 0.0684, pro: 0.0128, cls_ontology: 0.0177, visit_ontology: 0.0179, adm_type: 0.0068


Epoch 014: 100%|██████████| 724/724 [00:47<00:00, 15.20batch/s]


[Epoch 014] Weighted total loss: 0.2172
  Raw losses      | diag: 0.0155, med: 0.0785, lab: 0.0682, pro: 0.0127, cls_ontology: 0.3525, visit_ontology: 0.3572, adm_type: 0.1361
  Contributions   | diag: 0.0155, med: 0.0785, lab: 0.0682, pro: 0.0127, cls_ontology: 0.0176, visit_ontology: 0.0179, adm_type: 0.0068


Epoch 015: 100%|██████████| 724/724 [00:47<00:00, 15.11batch/s]


[Epoch 015] Weighted total loss: 0.2163
  Raw losses      | diag: 0.0155, med: 0.0780, lab: 0.0681, pro: 0.0125, cls_ontology: 0.3527, visit_ontology: 0.3557, adm_type: 0.1365
  Contributions   | diag: 0.0155, med: 0.0780, lab: 0.0681, pro: 0.0125, cls_ontology: 0.0176, visit_ontology: 0.0178, adm_type: 0.0068


Epoch 016: 100%|██████████| 724/724 [00:46<00:00, 15.42batch/s]


[Epoch 016] Weighted total loss: 0.2155
  Raw losses      | diag: 0.0154, med: 0.0776, lab: 0.0680, pro: 0.0125, cls_ontology: 0.3516, visit_ontology: 0.3557, adm_type: 0.1338
  Contributions   | diag: 0.0154, med: 0.0776, lab: 0.0680, pro: 0.0125, cls_ontology: 0.0176, visit_ontology: 0.0178, adm_type: 0.0067


Epoch 017: 100%|██████████| 724/724 [00:46<00:00, 15.70batch/s]


[Epoch 017] Weighted total loss: 0.2146
  Raw losses      | diag: 0.0153, med: 0.0772, lab: 0.0678, pro: 0.0124, cls_ontology: 0.3508, visit_ontology: 0.3535, adm_type: 0.1327
  Contributions   | diag: 0.0153, med: 0.0772, lab: 0.0678, pro: 0.0124, cls_ontology: 0.0175, visit_ontology: 0.0177, adm_type: 0.0066


Epoch 018: 100%|██████████| 724/724 [00:43<00:00, 16.63batch/s]


[Epoch 018] Weighted total loss: 0.2140
  Raw losses      | diag: 0.0153, med: 0.0769, lab: 0.0677, pro: 0.0123, cls_ontology: 0.3507, visit_ontology: 0.3536, adm_type: 0.1315
  Contributions   | diag: 0.0153, med: 0.0769, lab: 0.0677, pro: 0.0123, cls_ontology: 0.0175, visit_ontology: 0.0177, adm_type: 0.0066


Epoch 019: 100%|██████████| 724/724 [00:46<00:00, 15.53batch/s]


[Epoch 019] Weighted total loss: 0.2134
  Raw losses      | diag: 0.0153, med: 0.0766, lab: 0.0677, pro: 0.0123, cls_ontology: 0.3496, visit_ontology: 0.3524, adm_type: 0.1299
  Contributions   | diag: 0.0153, med: 0.0766, lab: 0.0677, pro: 0.0123, cls_ontology: 0.0175, visit_ontology: 0.0176, adm_type: 0.0065


Epoch 020: 100%|██████████| 724/724 [00:49<00:00, 14.75batch/s]


[Epoch 020] Weighted total loss: 0.2127
  Raw losses      | diag: 0.0152, med: 0.0762, lab: 0.0676, pro: 0.0122, cls_ontology: 0.3488, visit_ontology: 0.3510, adm_type: 0.1305
  Contributions   | diag: 0.0152, med: 0.0762, lab: 0.0676, pro: 0.0122, cls_ontology: 0.0174, visit_ontology: 0.0176, adm_type: 0.0065


Epoch 021: 100%|██████████| 724/724 [00:49<00:00, 14.61batch/s]


[Epoch 021] Weighted total loss: 0.2122
  Raw losses      | diag: 0.0151, med: 0.0760, lab: 0.0674, pro: 0.0122, cls_ontology: 0.3478, visit_ontology: 0.3508, adm_type: 0.1310
  Contributions   | diag: 0.0151, med: 0.0760, lab: 0.0674, pro: 0.0122, cls_ontology: 0.0174, visit_ontology: 0.0175, adm_type: 0.0066


Epoch 022: 100%|██████████| 724/724 [00:48<00:00, 15.02batch/s]


[Epoch 022] Weighted total loss: 0.2116
  Raw losses      | diag: 0.0151, med: 0.0757, lab: 0.0674, pro: 0.0121, cls_ontology: 0.3474, visit_ontology: 0.3498, adm_type: 0.1296
  Contributions   | diag: 0.0151, med: 0.0757, lab: 0.0674, pro: 0.0121, cls_ontology: 0.0174, visit_ontology: 0.0175, adm_type: 0.0065


Epoch 023: 100%|██████████| 724/724 [00:46<00:00, 15.62batch/s]


[Epoch 023] Weighted total loss: 0.2112
  Raw losses      | diag: 0.0151, med: 0.0757, lab: 0.0673, pro: 0.0120, cls_ontology: 0.3467, visit_ontology: 0.3485, adm_type: 0.1287
  Contributions   | diag: 0.0151, med: 0.0757, lab: 0.0673, pro: 0.0120, cls_ontology: 0.0173, visit_ontology: 0.0174, adm_type: 0.0064


Epoch 024: 100%|██████████| 724/724 [00:46<00:00, 15.61batch/s]


[Epoch 024] Weighted total loss: 0.2106
  Raw losses      | diag: 0.0150, med: 0.0754, lab: 0.0672, pro: 0.0119, cls_ontology: 0.3459, visit_ontology: 0.3480, adm_type: 0.1272
  Contributions   | diag: 0.0150, med: 0.0754, lab: 0.0672, pro: 0.0119, cls_ontology: 0.0173, visit_ontology: 0.0174, adm_type: 0.0064


Epoch 025: 100%|██████████| 724/724 [00:50<00:00, 14.37batch/s]

[Epoch 025] Weighted total loss: 0.2104
  Raw losses      | diag: 0.0150, med: 0.0752, lab: 0.0671, pro: 0.0119, cls_ontology: 0.3463, visit_ontology: 0.3484, adm_type: 0.1290
  Contributions   | diag: 0.0150, med: 0.0752, lab: 0.0671, pro: 0.0119, cls_ontology: 0.0173, visit_ontology: 0.0174, adm_type: 0.0065





In [13]:
import os

# Name the run after the hyperparameters that distinguish pre-training variants:
# dataset, mask rate, model width and the three auxiliary loss weights.
exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}-{config.cls_ontology_weight}-{config.visit_ontology_weight}-{config.adm_type_weight}"
)
save_path = "./pretrained_models/" + exp_name
# exist_ok=True creates the directory tree when missing and is a no-op when it
# already exists, avoiding the check-then-create race of a separate exists() test.
os.makedirs(save_path, exist_ok=True)
# NOTE: model.cpu() is called on the model itself, so (assuming a standard
# torch.nn.Module) the model lives on the CPU after this cell; move it back
# with .to(device) before any further GPU use.
torch.save(model.cpu().state_dict(), f"{save_path}/pretrained_model.pt")
print("Save model:", exp_name)

Save model: MIMIC-III-0.7-64-0.05-0.05-0.05
