In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Echo the value of every top-level expression in a cell, not only the last one.
# This is why bare calls further down (e.g. load_dotenv()) print their results.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"



In [2]:
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
import yaml
import pickle
import json
import torch

In [3]:
import plotly.express as px
import plotly.graph_objects as go


In [4]:
from synehrgy.Dataset import MyDataset,MyDatasetRaw,ClinicalDataset
from synehrgy.utils import *
from synehrgy.config import HydraConfig
from synehrgy.models import SynEHRgy

  from .autonotebook import tqdm as notebook_tqdm


# Configs

In [5]:
from hydra import initialize, compose, core
from omegaconf import DictConfig, OmegaConf



In [None]:

# Clear Hydra's global singleton first so this cell can be re-run on a live
# kernel without initialize() complaining that Hydra is already set up.
_hydra_singleton = core.global_hydra.GlobalHydra.instance()
_hydra_singleton.clear()

# Point Hydra at the local "configs" directory and compose the training config.
initialize(config_path="configs", version_base=None)
cfg = compose(config_name="configTrain.yaml")

In [8]:
# Wrap the composed Hydra config in the project's HydraConfig object; later cells
# read settings such as dataset_folder and n_ctx from this wrapper.
config = HydraConfig(cfg)


# Data Loading

In [9]:
# Build the training and validation splits from the processed data folder
# (constructed in that order: "train" first, then "val").
train_dataset, eval_dataset = (
    ClinicalDataset(config.dataset_folder, split=split_name)
    for split_name in ("train", "val")
)

# Peek at the first training record.
train_dataset[0]

Loaded train dataset with 23468 patients
Loaded val dataset with 5155 patients


{'sid': 10003,
 'hadm_id': [144039],
 'covariates': [[41.10715693176053, 1]],
 'codes': [[['icd_8602',
    'icd_80708',
    'icd_99811',
    'icd_4019',
    'icd_E8798',
    'icd_4582',
    'icd_82523',
    'icd_2899',
    'icd_82525'],
   ['proc_3399', 'proc_3409', 'proc_3404', 'proc_3893']]],
 'ts': [    Capillary refill rate Glascow coma scale eye opening  \
  0                     NaN                  1 No Response   
  1                     NaN                            NaN   
  2                     NaN                            NaN   
  3                     NaN                            NaN   
  4                     NaN                            NaN   
  5                     NaN                            NaN   
  6                     NaN                            NaN   
  7                     NaN                    3 To speech   
  8                     NaN                            NaN   
  9                     NaN                            NaN   
  10            

In [10]:
# Discretize both splits in place, training split first.
for _dataset in (train_dataset, eval_dataset):
    _dataset.discretize()

# Peek at the first training record after discretization.
train_dataset[0]

Discretized data already exists. Loading...
Discretized data already exists. Loading...


{'covars': [([41, 42], [5, 1])],
 'codes': [[892,
   889,
   66,
   0,
   68,
   496,
   2832,
   1455,
   1283,
   4241,
   3715,
   3629,
   3599]],
 'ts': [[([1, 2, 3, 4, 17, 20, 25, 29, 36, 37],
    [0, 0, 0, 0, 9, 5, 8, 9, 4, 8],
    [0]),
   ([8, 11, 13, 14, 15, 16, 19, 21, 22, 26, 29, 31, 32, 33, 34, 35, 36, 39],
    [6, 6, 0, 5, 3, 0, 4, 8, 9, 0, 9, 5, 3, 6, 4, 9, 4, 3],
    [0]),
   ([30, 40], [4, 9], [0]),
   ([17, 20, 25, 29, 36, 37, 38], [9, 6, 8, 9, 4, 8, 8], [1]),
   ([17, 20, 25, 29, 36, 37], [9, 6, 8, 9, 4, 8], [1]),
   ([17, 20, 25, 29, 36, 37, 38], [9, 7, 9, 9, 4, 8, 9], [1]),
   ([17, 20, 25, 29, 36, 37], [9, 7, 8, 9, 4, 8], [1]),
   ([1, 2, 3, 4, 17, 20, 25, 29, 36, 37, 38],
    [1, 0, 1, 0, 9, 6, 7, 9, 4, 8, 9],
    [1]),
   ([17, 20, 25, 29, 36, 37], [9, 8, 7, 9, 4, 7], [0]),
   ([17, 20, 21, 25, 29, 36, 37], [9, 7, 8, 9, 9, 2, 8], [0]),
   ([40], [9], [0]),
   ([1, 2, 3, 4, 17, 20, 25, 29, 36, 37, 38],
    [2, 1, 2, 0, 9, 7, 8, 9, 5, 7, 9],
    [1]),
   ([29, 30,

In [11]:
# Tokenize both splits in place with the context length taken from the config,
# training split first.
for _dataset in (train_dataset, eval_dataset):
    _dataset.tokenize(n_ctx=config.n_ctx)

# Peek at the first training record after tokenization.
train_dataset[0]

[info] Tokenization setting: n_ctx=1024, label_shuffle=False, truncate=True, split=False, ignore_ts=False, ts_shuffle=False


Tokenizing Dataset: 100%|██████████| 23468/23468 [00:12<00:00, 1920.12it/s]


[info] Dataset size: 15.15M tokens
[info] Truncated 14.47M tokens
[info] Tokenization setting: n_ctx=1024, label_shuffle=False, truncate=True, split=False, ignore_ts=False, ts_shuffle=False


Tokenizing Dataset: 100%|██████████| 5155/5155 [00:02<00:00, 1927.47it/s]


[info] Dataset size: 3.34M tokens
[info] Truncated 2.98M tokens


{'input_ids': tensor([5104, 5050, 5061,  ..., 5111, 5111, 5111]),
 'labels': tensor([5104, 5050, 5061,  ..., 5111, 5111, 5111]),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0])}

# Training

In [None]:
# NOTE(review): this cell is disabled dead code. PATH_SAVE_MODEL is not assigned in
# any visible cell (it may come from the wildcard import) and RUN_NAME is only
# assigned later, in the Generation section, so the snippet cannot simply be
# uncommented here. Confirm whether it should be updated or removed.
# if the model name is already in the saved_models folder, increment the version number
# v = 1
# while os.path.exists(f"{PATH_SAVE_MODEL}/{RUN_NAME}[v{v}]_model"):
#     v += 1
# RUN_NAME = RUN_NAME + f"[v{v}]"

In [19]:
# Weights & Biases setup: read the API key from the environment (.env file),
# log in, and open a run for this training session.

import wandb
from dotenv import load_dotenv

load_dotenv()
wandb.login(key=os.environ.get("WANDB_KEY"))

# Log every attribute of the config object except "w_class".
wandb_config = dict(
    (field, value) for field, value in vars(config).items() if field != "w_class"
)
wandb.init(
    project=cfg.wandb.project,
    name=cfg.run_name,
    config=wandb_config,
)


True



True

0,1
train/epoch,▁▂▃▃▄▅▆▆▇█
train/global_step,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇██
train/grad_norm,█▅▃▂▁▁▁▁▁▁
train/learning_rate,█▇▆▆▅▄▃▃▂▁
train/loss,█▅▃▂▂▁▁▁▁▁
train/perplexity,█▂▁▁▁▁▁▁▁▁

0,1
train/epoch,0.54496
train/global_step,100.0
train/grad_norm,0.47241
train/learning_rate,0.00029
train/loss,2.1831
train/perplexity,8.87377


In [20]:
# `random` is imported explicitly here rather than relying on
# `from synehrgy.utils import *` to happen to supply it.
import random

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices) with the same value.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)



<torch._C.Generator at 0x7f5f236c2370>

In [30]:
# Build the model from the wrapped config and move it to the selected device.
model = SynEHRgy(config).to(device)




In [31]:
# Override the epoch count on the Hydra config in place before calling fit().
# NOTE(review): 0.1 looks like a short smoke-run value — confirm before a real training run.
cfg.train.epochs = 0.1

In [32]:
# Train on the tokenized training split, passing the validation split alongside it.
# Note that fit() receives the raw Hydra cfg, not the HydraConfig wrapper.
model.fit(cfg, train_dataset, eval_dataset)

Epoch,Training Loss,Validation Loss,Perplexity
0,6.9687,5.320583,204.503045


# Generation

In [7]:
# Generation settings come from their own Hydra config file.

# Clear Hydra's global singleton so initialize() can be called again on this kernel.
_hydra_singleton = core.global_hydra.GlobalHydra.instance()
_hydra_singleton.clear()

initialize(config_path="configs", version_base=None)
gen_cfg = compose(config_name="configGenerate.yaml")

# Show the composed generation settings.
gen_cfg

hydra.initialize()

{'model': 'TEST', 'fix_covars': False, 'bin_type': 'uniform', 'n_samples': 6000, 'n_resample': 1000, 'batch_size': 256, 'generation': {'top_k': 50, 'top_p': 1.0, 'temperature': 1.0, 'repetition_penalty': 1.0, 'do_sample': True}}

In [8]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU instead of
# failing on a machine without CUDA. Kept as a plain string, as before.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# loading the model

RUN_NAME = "synehrgy-mimic-v2"

# Saved artefacts for a run live side by side under ./saved_models.
config_path = f"./saved_models/{RUN_NAME}_config.yaml"
model_path = f"./saved_models/{RUN_NAME}"

model = SynEHRgy.load_model(config_path, model_path).to(device)

# Rebuild the config wrapper from the YAML saved with the model, replacing any
# `config` defined earlier in the notebook.
config = HydraConfig(OmegaConf.load(config_path))

# Read the dataset metadata through a context manager so the file handle is
# closed deterministically.
# NOTE: pickle can execute arbitrary code while loading; only use metadata
# files produced by this project.
with open(config.dataset_folder + "/metadata2.pkl", "rb") as metadata_file:
    metadata = pickle.load(metadata_file)

Checkpoints found:  checkpoint-3303


In [10]:
# Override generation settings in place for this run: keep fix_covars off and
# set the number of samples to generate to 1000.
gen_cfg.fix_covars=False
gen_cfg.n_samples=1000

In [12]:
# generate synthetic data
# The result is kept under a "_tokenized" name and is handed to ClinicalDataset
# in a later cell, where detokenize() is applied.
synthetic_data_tokenized = model.generate_synthetic_dataset(gen_cfg)

 25%|██▌       | 1/4 [00:20<01:01, 20.55s/it]

[info] Generated 256 synthetic patients


 50%|█████     | 2/4 [00:30<00:29, 14.59s/it]

[info] Generated 512 synthetic patients


 75%|███████▌  | 3/4 [00:41<00:12, 12.70s/it]

[info] Generated 768 synthetic patients


 75%|███████▌  | 3/4 [00:51<00:17, 17.28s/it]


In [14]:
# Folder passed to the dataset wrapper for the synthetic split.
syn_folder = "./data/synthetic"

# Wrap the generated records, together with the metadata loaded alongside the
# model, in a ClinicalDataset for the 'synthetic' split.
synthetic_dataset = ClinicalDataset(
    syn_folder,
    split='synthetic',
    data=synthetic_data_tokenized,
    metadata=metadata,
)



[info] Loaded synthetic dataset. Please note that the dataset is already tokenized and discretized
[info] Loaded synthetic dataset with 1000 patients


In [15]:
# Convert the generated token sequences back into the structured record form.
# The dataset object is modified in place.
synthetic_dataset.detokenize()


# Peek at the first synthetic record.
synthetic_dataset[0]

Detokenizing: 100%|██████████| 1000/1000 [00:00<00:00, 1068.91it/s]

full: 716, truncated: 284
no ihm: 0 / 1000





{'covars': [([41, 42], [14, 1])],
 'codes': [[42, 0, 70, 288, 25]],
 'ts': [[([17, 25, 37], [9, 9, 9], [2]),
   ([20, 29, 36, 38], [3, 5, 4, 4], [0]),
   ([1, 2, 4], [2, 1, 1], [0]),
   ([19, 20, 29, 36], [3, 3, 5, 3], [0]),
   ([17, 25, 37], [8, 6, 8], [0]),
   ([20, 29, 36], [2, 5, 4], [1]),
   ([20, 29, 36], [1, 8, 3], [1]),
   ([20, 29, 36], [2, 8, 6], [1]),
   ([17, 25, 37], [9, 9, 9], [0]),
   ([20, 29, 36], [1, 8, 3], [0]),
   ([1, 2, 4, 20, 29, 36, 38], [2, 1, 1, 1, 7, 3, 2], [1]),
   ([17, 25, 37], [7, 4, 7], [0]),
   ([20, 29, 36], [2, 9, 4], [1]),
   ([17, 25, 37], [8, 5, 8], [0]),
   ([20, 29, 36], [0, 7, 2], [1]),
   ([17, 25, 37], [7, 3, 7], [0]),
   ([20, 29, 36], [1, 9, 4], [1]),
   ([17, 25, 37], [8, 5, 7], [0]),
   ([1, 2, 4, 19, 20, 29, 36], [2, 1, 1, 5, 1, 9, 5], [1]),
   ([17, 25, 37], [8, 6, 9], [0]),
   ([20, 29, 36, 38], [2, 7, 3, 2], [1]),
   ([17, 25, 37], [8, 6, 8], [0]),
   ([17, 20, 25, 29, 36, 37], [7, 2, 3, 7, 4, 6], [1]),
   ([5, 17, 20, 25, 29, 36, 37],

In [16]:
# Persist the detokenized synthetic dataset under the run name.
# NOTE(review): the destination path is decided inside ClinicalDataset.save, which is not visible here.
synthetic_dataset.save(name=RUN_NAME)

# Evaluation