In [None]:
import json
import logging
import math
import os
import random
import sys

import numpy as np
import torch
from tqdm import tqdm, trange

from gctpyhealth.process_eicu_dataset import get_eicu_datasets
from gctpyhealth.utils import *
# Keep the imports below AFTER the star import so their bindings cannot be
# shadowed by anything gctpyhealth.utils happens to export.
from gctpyhealth.gct import GCT

from tensorboardX import SummaryWriter
import torchsummary as summary


In [None]:
class Args:
    """Attribute bag holding the training hyper-parameters assigned below."""


args = Args()
args.learning_rate = 0.00022
args.max_steps = 10  # short run; originally 1000000
# NOTE(review): the do_* flags are not consulted by the cells in this notebook.
args.do_train = True
args.do_eval = True
args.do_test = True
args.warmup = 0.05  # fraction of the optimisation steps used for LR warm-up
args.intermediate_size = 256  # NOTE(review): not referenced by the cells in this notebook
args.eps = 1e-8  # NOTE(review): not passed to the optimizer in this notebook
args.max_grad_norm = 1.0  # gradient-norm clip threshold — confirm the training loop applies it
args.eval_batch_size = 32  # NOTE(review): the eval/test loaders use `batch_size` instead
args.logging_steps = 100  # log the running training loss every N optimisation steps
args.num_train_epochs = 0  # only consulted when max_steps <= 0
args.seed = 42

# Seed every RNG in play (Python, NumPy, PyTorch) so weight initialisation and
# any other randomness are reproducible on Restart & Run All.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [None]:
# Prediction target and cross-validation fold for this run.
label_key, fold = "expired", 0
# Directory holding the cached splits, and directory receiving logs/results.
data_dir, output_dir = "data", "eicu_output/model_pyhealth"
# Batch size shared by every dataloader built in this notebook.
batch_size = 32

In [None]:
# Configure root logging, create the output and TensorBoard log directories,
# and open the TensorBoard writer used by the training/eval cells.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

# exist_ok=True avoids the check-then-create race and keeps this cell
# idempotent when it is re-run.
os.makedirs(output_dir, exist_ok=True)
logging_dir = os.path.join(output_dir, 'logging')
os.makedirs(logging_dir, exist_ok=True)
tb_writer = SummaryWriter(log_dir=logging_dir)

In [None]:
# Load the eICU tables through PyHealth, then print its statistics and info.
from pyhealth.datasets import eICUDataset

print('Loading eICU dataset')
eicu_ds = eICUDataset(
    # NOTE(review): path is relative to the notebook's working directory.
    root='../../eicu_csv',
    tables=[
        "admissionDx",
        "diagnosisString",
        "treatment",
    ],
    dev=True,  # NOTE(review): presumably PyHealth's reduced dev subset — confirm
    refresh_cache=False,
)

eicu_ds.stat()
eicu_ds.info()

In [None]:
# Fetch the cached train/eval/test splits together with their prior guides.
datasets, prior_guides = get_eicu_datasets(data_dir, fold=fold)
(train_dataset, eval_dataset, test_dataset) = datasets
(train_priors, eval_priors, test_priors) = prior_guides

# Wrap each split's prior guides in an eICUPriorDataset (train, eval, test order).
train_priors_dataset, eval_priors_dataset, test_priors_dataset = [
    eICUPriorDataset(split_priors)
    for split_priors in (train_priors, eval_priors, test_priors)
]

# One loader per split for the samples, plus a parallel loader for the prior
# guides (batched with priors_collate_fn). No shuffle argument is passed, so —
# assuming torch's DataLoader — batch i of the two loaders line up.
train_dataloader, eval_dataloader, test_dataloader = [
    DataLoader(split, batch_size=batch_size)
    for split in (train_dataset, eval_dataset, test_dataset)
]
train_priors_dataloader, eval_priors_dataloader, test_priors_dataloader = [
    DataLoader(split_priors, batch_size=batch_size, collate_fn=priors_collate_fn)
    for split_priors in (train_priors_dataset, eval_priors_dataset, test_priors_dataset)
]

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
n_gpu = torch.cuda.device_count()
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
    logger.info('***** Using CUDA device *****')

In [None]:
# Build the graph convolutional transformer (GCT, imported in the first cell)
# over the PyHealth eICU dataset and move it to the selected device.
model = GCT(
    dataset=eicu_ds,
    feature_keys=[
        'conditions_hash', 'conditions_mask',
        'procedures_hash', 'procedures_mask',
    ],
    # NOTE(review): the model uses "label" while the notebook-level label_key
    # passed to prediction_loop is "expired" — confirm they denote the same target.
    label_key="label",
    mode="binary",
).to(device)


In [None]:
# Derive the training schedule (steps and epochs), build the Adamax optimizer
# and the warm-up scheduler, and reset the training-loop counters.
num_update_steps_per_epoch = len(train_dataloader)
if num_update_steps_per_epoch == 0:
    raise ValueError('train_dataloader is empty; there is nothing to train on')

if args.max_steps > 0:
    # A fixed step budget wins; run just enough epochs to cover it
    # (ceiling division).
    max_steps = args.max_steps
    num_train_epochs = (max_steps + num_update_steps_per_epoch - 1) // num_update_steps_per_epoch
else:
    max_steps = int(num_update_steps_per_epoch * args.num_train_epochs)
    num_train_epochs = args.num_train_epochs
num_train_epochs = int(np.ceil(num_train_epochs))

# Evaluate roughly every half epoch; the loop skips evaluation when this is 0.
args.eval_steps = num_update_steps_per_epoch // 2

optimizer = torch.optim.Adamax(model.parameters(), lr=args.learning_rate)

# args.warmup is the fraction of max_steps spent warming up. Computed
# multiplicatively so warmup == 0 (no warm-up) works, and rounded so the
# scheduler receives a whole number of steps.
if not 0 <= args.warmup <= 1:
    raise ValueError('args.warmup must be a fraction in [0, 1], got {}'.format(args.warmup))
warmup_steps = int(round(max_steps * args.warmup))
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_training_steps=max_steps)

logger.info('***** Running Training *****')
logger.info(' Num examples = {}'.format(len(train_dataloader.dataset)))
logger.info(' Num epochs = {}'.format(num_train_epochs))
logger.info(' Train batch size = {}'.format(batch_size))
logger.info(' Total optimization steps = {}'.format(max_steps))

epochs_trained = 0
global_step = 0
tr_loss = torch.tensor(0.0, device=device)  # running sum of detached step losses
logging_loss_scalar = 0.0  # value of tr_loss at the previous logging step
model.zero_grad()

In [None]:

# Main optimisation loop: one optimizer/scheduler step per (batch, prior-guide
# batch) pair. The running loss is logged every args.logging_steps steps, the
# model is evaluated every args.eval_steps steps, and training stops after
# max_steps updates.

# The sample and prior-guide loaders are zipped together; zip() would silently
# truncate to the shorter one and misalign the pairs, so fail loudly instead.
assert len(train_dataloader) == len(train_priors_dataloader), (
    'train data and prior-guide loaders yield different numbers of batches')

train_pbar = trange(epochs_trained, num_train_epochs, desc='Epoch')
for epoch in range(epochs_trained, num_train_epochs):
    epoch_pbar = tqdm(total=len(train_dataloader), desc='Iteration')
    for data, priors_data in zip(train_dataloader, train_priors_dataloader):
        # Re-enter training mode every step: the in-loop evaluation below
        # presumably switches the model to eval mode.
        model.train()
        data, priors_data = prepare_data(data, priors_data, device)

        # [loss, logits, all_hidden_states, all_attentions]
        outputs = model(data, priors_data)
        loss = outputs[0]

        if n_gpu > 1:
            loss = loss.mean()
        loss.backward()

        tr_loss += loss.detach()
        # Apply the configured gradient-norm clip; a falsy/non-positive
        # args.max_grad_norm disables clipping.
        if args.max_grad_norm and args.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        global_step += 1

        if args.logging_steps > 0 and global_step % args.logging_steps == 0:
            logs = {}
            tr_loss_scalar = tr_loss.item()
            # Mean loss over the steps since the previous log line.
            logs['loss'] = (tr_loss_scalar - logging_loss_scalar) / args.logging_steps
            logs['learning_rate'] = scheduler.get_last_lr()[0]
            logging_loss_scalar = tr_loss_scalar
            if tb_writer:
                for k, v in logs.items():
                    if isinstance(v, (int, float)):
                        tb_writer.add_scalar(k, v, global_step)
                tb_writer.flush()
            output = {**logs, **{"step": global_step}}
            print(output)

        if args.eval_steps > 0 and global_step % args.eval_steps == 0:
            metrics = prediction_loop(device, label_key, model, eval_dataloader, eval_priors_dataloader)
            logger.info('**** Checkpoint Eval Results ****')
            for key, value in metrics.items():
                logger.info('{} = {}'.format(key, value))
                if tb_writer:
                    tb_writer.add_scalar(key, value, global_step)
            if tb_writer:
                tb_writer.flush()

        epoch_pbar.update(1)
        if global_step >= max_steps:
            break
    epoch_pbar.close()
    train_pbar.update(1)
    if global_step >= max_steps:
        break


In [None]:

def _report_results(result, output_file, banner):
    """Log every metric in ``result`` under ``banner`` and write it to ``output_file``.

    Each metric is written on its own ``key = value`` line.
    """
    with open(output_file, 'w') as results_file:
        logger.info(banner)
        for key, value in result.items():
            logger.info('{} = {}'.format(key, value))
            results_file.write('{} = {}\n'.format(key, value))


# Close training progress/TensorBoard resources, then evaluate on the eval and
# test splits, logging each metric and writing it to a text file in output_dir.
train_pbar.close()
if tb_writer:
    tb_writer.close()

logger.info('\n\nTraining completed')

eval_results = {}
logger.info('*** Evaluate ***')
logger.info(' Num examples = {}'.format(len(eval_dataloader.dataset)))
eval_result = prediction_loop(device, label_key, model, eval_dataloader, eval_priors_dataloader)
output_eval_file = os.path.join(output_dir, 'eval_results.txt')
_report_results(eval_result, output_eval_file, '*** Eval Results ***')
eval_results.update(eval_result)

logger.info('*** Test ***')
# predict
test_result = prediction_loop(device, label_key, model, test_dataloader, test_priors_dataloader, description='Testing')
output_test_file = os.path.join(output_dir, 'test_results.txt')
_report_results(test_result, output_test_file, '**** Test results ****')
# NOTE(review): if prediction_loop returns the same metric names for both
# splits, these test values overwrite the eval values merged above — confirm
# whether eval_results should keep the two separate.
eval_results.update(test_result)

> Training runs to completion without errors.

#### PyHealth Example Code

Once all model code is implemented in a PyHealth-compatible way, the code below should run without errors.

In [None]:
# pyhealth_test = GraphConvolutionalTransformer(
#         dataset=dataset,
#         label_key="label",
#         mode="binary",
#     )

In [None]:
# from pyhealth.datasets import get_dataloader
# train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
# data_batch = next(iter(train_loader))
#
# ret = model(**data_batch)
# print(ret)  # the output should have this format:
# {
#     'loss': tensor(0.8872, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
#     'y_prob': tensor([[0.5008], [0.6614]], grad_fn=<SigmoidBackward0>),
#     'y_true': tensor([[1.], [0.]]),
#     'logit': tensor([[0.0033], [0.6695]], grad_fn=<AddmmBackward0>)
# }