In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import CodeARmodel, VQVAE
import random
import hparams as hp
import numpy as np
import pandas as pd
from utils import *
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook
import warnings
warnings.filterwarnings("ignore")

GAP_TIME = 6
WINDOW_SIZE = 24
ID_COLS = ['subject_id', 'hadm_id', 'icustay_id']
DATA_FILEPATH = "./Dataset/all_hourly_data.h5"

# Time-series features and per-stay static table from the HDF5 store.
X = pd.read_hdf(DATA_FILEPATH, 'vitals_labs')
statics = pd.read_hdf(DATA_FILEPATH, 'patients')

# Keep only rows whose max_hours exceeds WINDOW_SIZE + GAP_TIME and build the
# four label columns (mort_hosp, mort_icu, los_3, los_7).
# `.copy()` makes Y an independent frame, so the column assignments and the
# in-place drop below never operate on a view of `statics` (chained assignment;
# the resulting warning would otherwise be hidden by the global warning filter).
Y = statics[statics.max_hours > WINDOW_SIZE + GAP_TIME][['mort_hosp', 'mort_icu', 'los_icu']].copy()
Y['los_3'] = Y['los_icu'] > 3
Y['los_7'] = Y['los_icu'] > 7
Y.drop(columns=['los_icu'], inplace=True)
# DataFrame.astype returns a new frame; the result must be assigned back,
# otherwise the conversion is silently discarded.
Y = Y.astype(float)

df_X, df_Y = aggregate_data(X, Y)

# Subject-level 80/10/10 train/dev/test split: rows are assigned by
# subject_id, so all rows of one subject land in exactly one split.
train_frac, dev_frac, test_frac = 0.8, 0.1, 0.1
X_subj_idx, Y_subj_idx = [df.index.get_level_values('subject_id') for df in (df_X, df_Y)]
X_subjects = set(X_subj_idx)
# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# and this is a data-integrity check that must always run.
if X_subjects != set(Y_subj_idx):
    raise ValueError("Subject ID pools differ!")

# Fixed seed so the permutation (and hence the split) is repeatable, given the
# same iteration order of X_subjects.
np.random.seed(0)
subjects, N = np.random.permutation(list(X_subjects)), len(X_subjects)
N_train, N_dev, N_test = int(train_frac * N), int(dev_frac * N), int(test_frac * N)
train_subj = subjects[:N_train]
dev_subj = subjects[N_train:N_train + N_dev]
# Test takes the remainder, so any int() rounding slack ends up here;
# N_test is informational only.
test_subj = subjects[N_train + N_dev:]

[(df_X_train, df_X_dev, df_X_test), (df_Y_train, df_Y_dev, df_Y_test)] = [
    [df[df.index.get_level_values('subject_id').isin(s)] for s in (train_subj, dev_subj, test_subj)]
    for df in (df_X, df_Y)
]

# NaN-aware per-column statistics of the training split's 'mean' columns.
# Later cells use them to de-standardize generated output (std * pred + mean).
idx = pd.IndexSlice
train_mean_block = df_X_train.loc[:, idx[:, ['mean']]].to_numpy()
df_X_stds = np.nanstd(train_mean_block, axis=0)
df_X_means = np.nanmean(train_mean_block, axis=0)

# Output tree for the synthetic dataset: ./synthetic_dataset/codear/{sequences,labels}.
# os.makedirs(exist_ok=True) creates any missing parents (including
# ./synthetic_dataset itself) and also completes a partially created tree;
# a chain of os.mkdir calls guarded only on "codear" failed in the first case
# and skipped the sub-directories in the second.
for _out_dir in ("./synthetic_dataset/codear/sequences", "./synthetic_dataset/codear/labels"):
    os.makedirs(_out_dir, exist_ok=True)

# CodeAR model: its checkpoint stores the weights under the 'state_dict' key.
# Tensors are mapped to CPU on load; the module itself lives on the GPU.
checkpoint_dict = torch.load("./training_log/codear/Gen_checkpoint_149000.pt", map_location='cpu')
model = CodeARmodel(hp).cuda()
model.load_state_dict(checkpoint_dict['state_dict'])
model.eval()

# VQ-VAE: this file is loaded directly as the state dict.
vqvae_state = torch.load("./training_log/codear/vqvae.pt", map_location='cpu')
vqvae = VQVAE(hp).cuda()
vqvae.load_state_dict(vqvae_state)
vqvae.eval()

# For every subject (all splits), sample codes from the CodeAR model conditioned
# on that subject's label row, decode them with the VQ-VAE, de-standardize with
# the training-split stats, and save codes / sequence / label as .npy files.

# Create the codes directory once, before the loop (makedirs also creates a
# missing ./Dataset/codes parent).
os.makedirs("./Dataset/codes/codear", exist_ok=True)

# Group label rows by subject once. Filtering df_Y with a boolean mask per
# subject scanned the whole frame on every iteration (quadratic overall).
# groupby keeps the original row order within each group, so label[0] below
# is the same row as before. int64 is the dtype torch.LongTensor expects
# (plain `int` can map to int32, e.g. on Windows with NumPy < 2).
labels_by_subj = {
    s: grp.to_numpy()
    for s, grp in df_Y.astype(np.int64).groupby(level='subject_id')
}

with torch.no_grad():
    for k, subj in enumerate(tqdm_notebook(X_subjects)):
        label = torch.LongTensor(labels_by_subj[subj])
        code = model.inference(label.cuda())
        preds = vqvae.decode(code)
        # Inverse of per-feature standardization: std * x + mean.
        seq = df_X_stds*preds[0].detach().cpu().numpy()+df_X_means
        if k==0:
            print(seq.shape)
            print()

        np.save(f"./Dataset/codes/codear/codes_{subj}.npy", code[0].detach().cpu().numpy())
        np.save(f"./synthetic_dataset/codear/sequences/{subj}.npy", seq)
        # Save a NumPy array explicitly rather than handing np.save a torch tensor.
        np.save(f"./synthetic_dataset/codear/labels/{subj}.npy", label[0].numpy()) # [4, ] (4-tasks)



  0%|          | 0/23944 [00:00<?, ?it/s]

(24, 104)

