In [1]:
import os
import pickle
from argparse import Namespace

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from heterogt.utils.tokenizer import EHRTokenizer
from heterogt.utils.dataset import PreTrainEHRDataset, batcher, expand_level3
from heterogt.utils.seed import set_random_seed
from heterogt.model.model import HeteroGTPretrain

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Fix the random seed up front so the run is repeatable; the helper logs the value it set.
# Which RNG libraries it seeds is decided inside heterogt.utils.seed (not visible here).
set_random_seed(123)

[INFO] Random seed set to 123


In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Experiment settings for pretraining, gathered in one mutable namespace; later cells
# attach derived values (vocabulary sizes, maximum admissions) to this same object.
config = Namespace()

# Data and task selection.
config.dataset = "MIMIC-III"
config.tasks = ["death", "readmission", "stay", "next_diag_6m", "next_diag_12m"]
config.task_index = 2  # index of the task to train
config.token_type = ["diag", "med", "lab", "pro"]
config.special_tokens = ["[PAD]", "[CLS]"]

# Two identical attention-mask dictionaries, built as separate (non-aliased) objects.
# NOTE(review): presumably each maps a token-type id to the type ids it may attend to,
# with one dictionary per transformer layer — confirm against HeteroGTPretrain.
config.attn_mask_dicts = [
    {1: [6, 7], 2: [6, 7], 3: [6, 7], 4: [6, 7], 6: [2, 3, 4, 5, 6, 7], 7: [2, 3, 4, 5, 6, 7]}
    for _ in range(2)
]

# Model size.
config.d_model = 64
config.num_heads = 4

# Optimisation schedule.
config.batch_size = 32
config.lr = 1e-3
config.epochs = 5
config.early_stop_patience = 5

# Dataset construction.
# A group code is generated once group_code_thre diagnosis codes fall under the same
# group ICD code.
config.group_code_thre = 5
config.pretrain_mask_rate = 0.7

In [5]:
# Locations of the processed pickles, both under the dataset-specific folder.
# The full file is used to build the tokenizer; the second one is used for pretraining.
full_data_path = "./data_process/{}-processed/mimic.pkl".format(config.dataset)
pretrain_data_path = "./data_process/{}-processed/mimic_pretrain.pkl".format(config.dataset)

In [6]:
# Load the full processed table and derive the token "sentences" the tokenizer is built from.
# SECURITY NOTE: pickle can execute arbitrary code while loading; only open files that
# were produced by this project's own preprocessing.
# The context manager closes the file handle deterministically (the previous bare
# open(...) call left it open until garbage collection).
with open(full_data_path, 'rb') as full_data_file:
    ehr_full_data = pickle.load(full_data_file)

# Second element of expand_level3()'s return value, wrapped in a list so it forms a
# single sentence — presumably the group-code vocabulary; TODO confirm in heterogt.utils.dataset.
group_code_sentences = [expand_level3()[1]]

# One entry per table row for each code column.
diag_sentences = ehr_full_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_full_data["NDC"].values.tolist()
lab_sentences = ehr_full_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_full_data["PRO_CODE"].values.tolist()

# Every distinct age becomes its own one-token sentence: the nested [[...]] shape is required.
# NOTE(review): iteration order of set() decides the order the ages are handed to the
# tokenizer; it is kept as-is so the vocabulary matches other code that builds it the
# same way — confirm before sorting.
age_sentences = [[str(c)] for c in set(ehr_full_data["AGE"].values.tolist())]

# Largest number of distinct admissions (HADM_ID) recorded for any single patient (SUBJECT_ID).
max_admissions = ehr_full_data.groupby("SUBJECT_ID")["HADM_ID"].nunique().max()
config.max_num_adms = max_admissions
print(f"Max admissions per patient: {config.max_num_adms}")

Max admissions per patient: 8


In [7]:
# Build the shared tokenizer from every group of sentences prepared above.
tokenizer = EHRTokenizer(
    age_sentences,
    group_code_sentences,
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    special_tokens=config.special_tokens,
)

# Size of each per-type vocabulary (tokenizer.<type>_voc.id2word), keyed by token type.
config.label_vocab_size = {
    code_type: len(getattr(tokenizer, f"{code_type}_voc").id2word)
    for code_type in ("diag", "med", "lab", "pro")
}

# Size of the global vocabulary, plus the token counts reported for ages and group codes.
config.global_vocab_size = len(tokenizer.vocab.id2word)
config.age_vocab_size = tokenizer.token_number("age")
config.group_code_vocab_size = tokenizer.token_number("group")

print(f"Age vocabulary size: {config.age_vocab_size}")
print(f"Group code vocabulary size: {config.group_code_vocab_size}")

Age vocabulary size: 18
Group code vocabulary size: 19


In [8]:
# Load the pretraining data.
# SECURITY NOTE: pickle can execute arbitrary code while loading; only open files that
# were produced by this project's own preprocessing.
# The context manager closes the file handle deterministically (the previous bare
# open(...) call left it open until garbage collection).
with open(pretrain_data_path, 'rb') as pretrain_data_file:
    ehr_pretrain_data = pickle.load(pretrain_data_file)
# NOTE(review): a step that loads occurrence data was apparently planned here, but this
# cell contains no code for it.

In [9]:
# Wrap the pretraining data in the masking dataset, using the settings held in config.
pretrain_dataset = PreTrainEHRDataset(
    ehr_pretrain_data=ehr_pretrain_data,
    tokenizer=tokenizer,
    token_type=config.token_type,
    mask_rate=config.pretrain_mask_rate,
    group_code_thre=config.group_code_thre,
    max_num_adms=config.max_num_adms,
)
print("Number of pretrain samples:", len(pretrain_dataset))

# Shuffled mini-batches; the project's batcher(...) supplies the collate function.
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=batcher(
        tokenizer,
        n_token_type=len(config.token_type),
        is_pretrain=True,
    ),
)

Number of pretrain samples: 23146


In [10]:
# Smoke test: iterate the whole DataLoader once, discarding every batch, so that any
# dataset/collate error surfaces before the model is built and training starts.
# NOTE(review): with shuffle=True this pass presumably draws from the global RNG, so
# removing it may change the batch order seen during training — confirm before deleting.
for batch in pretrain_dataloader:
    pass
print("All pass!")

All pass!


In [None]:
# Instantiate the pretraining model from the tokenizer and the settings in config.
model = HeteroGTPretrain(
    tokenizer=tokenizer,
    token_types=config.token_type,
    d_model=config.d_model,
    num_heads=config.num_heads,
    layer_types=['gnn', 'tf', 'gnn', 'tf'],
    max_num_adms=config.max_num_adms,
    device=device,
    label_vocab_size=config.label_vocab_size,
    attn_mask_dicts=config.attn_mask_dicts,
    use_cls_cat=True,
)
# Place the parameters on the selected device before the optimizer is created.
model = model.to(device)

# AdamW over all model parameters, with the learning rate from config.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

In [12]:
# Pretraining loop: one optimizer update per mini-batch; after every epoch the mean
# total loss and the mean loss per token type are printed.
for epoch in range(1, 1 + config.epochs):
    model.train()
    avg_loss, avg_loss_dict = 0., {loss_type: 0. for loss_type in config.token_type}
    # Explicit batch counter, so the averages never depend on a loop variable that is
    # undefined (or stale from an earlier run) when the loader yields nothing.
    num_batches = 0
    step_iter = tqdm(pretrain_dataloader, desc=f"Epoch {epoch:03d}", unit="batch")

    for step, batch in enumerate(step_iter):
        # Move tensor entries to the target device; non-tensor entries pass through unchanged.
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]

        # Clear gradients before the forward/backward pass, so gradients left behind by
        # an interrupted earlier run of this cell cannot leak into this update.
        optimizer.zero_grad()
        loss, loss_dict = model(*batch)

        loss.backward()
        optimizer.step()
        avg_loss += loss.item()

        for loss_type in config.token_type:
            # float() accepts both plain numbers and 0-d tensors, so the running sums
            # stay Python floats and never hold on to tensors.
            # NOTE(review): whether loss_dict holds floats or tensors is decided by the
            # model's forward — not visible here.
            avg_loss_dict[loss_type] += float(loss_dict[loss_type])

        num_batches = step + 1

    if num_batches == 0:
        raise RuntimeError("pretrain_dataloader yielded no batches; cannot compute epoch averages")

    avg_loss /= num_batches
    avg_loss_dict = {loss_type: avg_loss_dict[loss_type] / num_batches for loss_type in config.token_type}
    loss_str = ", ".join([f"{t}: {avg_loss_dict[t]:.4f}" for t in config.token_type])
    print(f"[Epoch {epoch:03d}] Avg loss: {avg_loss:.4f} | {loss_str}")

Epoch 001: 100%|██████████| 724/724 [00:37<00:00, 19.24batch/s]


[Epoch 001] Avg loss: 0.0602 | diag: 0.0265, med: 0.1057, lab: 0.0832, pro: 0.0253


Epoch 002: 100%|██████████| 724/724 [00:37<00:00, 19.10batch/s]


[Epoch 002] Avg loss: 0.0493 | diag: 0.0176, med: 0.0911, lab: 0.0727, pro: 0.0159


Epoch 003: 100%|██████████| 724/724 [00:37<00:00, 19.24batch/s]


[Epoch 003] Avg loss: 0.0477 | diag: 0.0172, med: 0.0868, lab: 0.0717, pro: 0.0151


Epoch 004: 100%|██████████| 724/724 [00:37<00:00, 19.36batch/s]


[Epoch 004] Avg loss: 0.0468 | diag: 0.0169, med: 0.0845, lab: 0.0711, pro: 0.0147


Epoch 005: 100%|██████████| 724/724 [00:37<00:00, 19.16batch/s]

[Epoch 005] Avg loss: 0.0461 | diag: 0.0168, med: 0.0827, lab: 0.0706, pro: 0.0145





In [13]:
# Save the pretrained weights in a directory named after the dataset, the pretraining
# mask rate and the model width.
exp_name = (
    f"{config.dataset}-{config.pretrain_mask_rate}-{config.d_model}"
)
save_path = "./pretrained_models/" + exp_name
# exist_ok=True creates the directory when missing and is a no-op otherwise, replacing
# the separate existence check.
os.makedirs(save_path, exist_ok=True)
# The state dict is written from the CPU copy of the model.
torch.save(model.cpu().state_dict(), f"{save_path}/pretrained_model.pt")
# model.cpu() moves the module in place; move it back so that `model` still lives on
# `device` if any training cell is re-run after saving.
model.to(device)
print("Save model:", exp_name)

Save model: MIMIC-III-0.7-64
