In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.distributed as dist

import h5py
import tables
import numpy as np
import pickle
from functools import reduce

In [3]:
# Compute device for the model, the loss and every batch.
# Use CUDA when PyTorch can see a GPU and fall back to the CPU otherwise, so the
# notebook no longer needs a hand-edited toggle to run on a machine without CUDA.
# (Original note: the development machine had 4 GPUs.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
import os

# Make the repository checkout the working directory; the relative folders
# used by this notebook (e.g. 'data/raw/') are resolved against it.
REPO_ROOT = '/home/aisinai/work/repos/nis_patient_encoding/'
os.chdir(REPO_ROOT)

In [5]:
from data.data_loader import NISDatabase
from data.cohort_builder import CohortBuilder

from utils.experiments import *
from utils.feature_utils import *
from utils.code_mappings import *

from model.autoencoder.autoencoder import AutoEncoder
from model.autoencoder.loss import CustomLoss
from trainer.trainer import Trainer

In [6]:
import matplotlib.pyplot as plt
import seaborn as sns

In [7]:
# Folder constants, relative to the current working directory.
FIGURE_FOLDER = 'figures/dm/'  # presumably where figures are saved; not used in the cells shown here
DATA_FOLDER = 'data/raw/'      # prefix prepended to the HDF5 file name when opening the databases

# Playground

#### Autoencoder 4.0

In [8]:
# Open two database handles on the same HDF5 file.
# NOTE(review): control_db is created with exactly the same file and 'TRAIN'
# argument as case_db — confirm this is intended and not a copy-paste slip.
nis_h5_path = DATA_FOLDER + 'NIS_2012_2014_proto_emb_v2.h5'
case_db = NISDatabase(nis_h5_path, 'TRAIN', num_workers=4)
control_db = NISDatabase(nis_h5_path, 'TRAIN', num_workers=4)

In [9]:
# Declares, per input feature group, which representation type it uses.
# 'rep_func' is None for every feature in this notebook.
#
# Every scalar categorical feature now uses the same 'one-hot' spelling. AGE
# previously said 'one_hot', which any string comparison on 'type' would treat
# as a different (unknown) kind from its five siblings.
# NOTE(review): the consumer of this mapping is not visible here — confirm that
# 'one-hot' (the majority spelling) is the value it expects.
#
# DATA_FOLDER is intentionally not re-assigned here: it is already defined with
# the same value in the earlier configuration cell.
INPUT_FEATURES = {
    'AGE' : {'type': 'one-hot', 'rep_func': None},
    'FEMALE' : {'type': 'one-hot', 'rep_func': None},
    'HCUP_ED' : {'type': 'one-hot', 'rep_func': None},
    'TRAN_IN' : {'type': 'one-hot', 'rep_func': None},
    'ELECTIVE' : {'type': 'one-hot', 'rep_func': None},
    'ZIPINC_QRTL' : {'type': 'one-hot', 'rep_func': None},
    'DXn' : {'type': 'embedding', 'rep_func': None},
    'ECODEn' : {'type': 'embedding', 'rep_func': None},
    'PRn' : {'type': 'embedding', 'rep_func': None},
    'CHRONn' : {'type': 'embedding', 'rep_func': None},
}

# Build configuration later passed to AutoEncoder(...).
# Encoder and decoder share the same settings (total_layers=1, scale=4,
# leaky_relu); the latent section asks for 64 dimensions.
# 'output_dims' starts as None — presumably filled in by calc_output_dims()
# further down; confirm against utils.
DEFAULT_BUILD = dict(
    encoding=dict(total_layers=1, scale=4, activation='leaky_relu'),
    latent=dict(dimensions=64),
    decoding=dict(scale=4, activation='leaky_relu', total_layers=1, output_dims=None),
)

# One row per embedded feature family:
#   (feature key, header prefix as bytes, number of classes, embedding dimensions)
# NOTE(review): 'CHRONBn' is configured here but has no entry in INPUT_FEATURES.
_EMBEDDING_SPECS = (
    ('CHRONn',  b'CHRON',  12583, 256),
    ('DXn',     b'DX',     12583, 256),
    ('PRn',     b'PR',     4445,  64),
    ('ECODEn',  b'ECODE',  1186,  32),
    ('CHRONBn', b'CHRONB', 19,    16),
)

# Expand the table into the nested dict form; every feature gets its own
# independent inner dict.
EMBEDDING_DICTIONARY = {
    feature_key: {
        'header_prefix': prefix,
        'num_classes': class_count,
        'dimensions': embed_dims,
    }
    for feature_key, prefix, class_count, embed_dims in _EMBEDDING_SPECS
}

# Columns that are one-hot encoded, given as bytes names.
ONE_HOT_LIST = [b'ELECTIVE', b'FEMALE', b'HCUP_ED', b'TRAN_IN', b'ZIPINC_QRTL', b'AGE']

# Class count per one-hot column, keyed by the utf-8 decoded column name.
_ONE_HOT_NUM_CLASSES = {
    'ELECTIVE': 2,      # 0,1
    'FEMALE': 2,        # 0,1
    'TRAN_IN': 3,       # 0-2
    'HCUP_ED': 5,       # 0-4
    'ZIPINC_QRTL': 5,   # original note says 1-4; 5 classes presumably leaves index 0 unused — TODO confirm
    'AGE': 13,          # 0-12 (must be allowed to guess in between as well)
}

# Same key order as ONE_HOT_LIST; each column gets its own {'num_classes': n} dict.
ONE_HOTS = {
    column: {'num_classes': _ONE_HOT_NUM_CLASSES[column]}
    for column in (raw_name.decode('utf-8') for raw_name in ONE_HOT_LIST)
}

# No continuous features are configured.
CONTINUOUS = {}

# Group the three representation families under one dict. The values are the
# very same dict objects defined earlier (not copies).
FEATURE_REPRESENTATIONS = {
    'embedding': EMBEDDING_DICTIONARY,
    'one_hots': ONE_HOTS,
    'continuous': CONTINUOUS,
}

# Attach the feature description to the build config. DEFAULT_BUILD keeps a
# reference to this dict, so the 'one_hots' replacement below is visible
# through DEFAULT_BUILD['features'] without assigning it a second time.
DEFAULT_BUILD['features'] = FEATURE_REPRESENTATIONS

# NOTE(review): the return value is discarded, so this presumably updates the
# embedding dict in place using the database headers — confirm in utils.
find_nlike_features(case_db.headers, FEATURE_REPRESENTATIONS['embedding'])

# Replace the hand-written one-hot description with whatever
# create_onehot_info() returns for this database.
FEATURE_REPRESENTATIONS['one_hots'] = create_onehot_info(
    case_db,
    FEATURE_REPRESENTATIONS['one_hots'],
    FEATURE_REPRESENTATIONS['embedding'],
)

# NOTE(review): called for its side effect on DEFAULT_BUILD (result unused).
calc_output_dims(DEFAULT_BUILD)

In [10]:
# Build the autoencoder from DEFAULT_BUILD and load the train_ae_020 checkpoint
# (load_state is a project method; presumably it restores saved weights).
checkpoint_path = '/home/aisinai/work/repos/nis_patient_encoding/experiments/train_ae_020/1587303827.pth'
gpu_ids = [0, 1, 2, 3]

ae = AutoEncoder(DEFAULT_BUILD)
ae.load_state(checkpoint_path, device=device)

# Wrap the model for data-parallel execution on the listed GPUs, then move it
# to `device`. From here on `ae` refers to the nn.DataParallel wrapper.
ae = nn.DataParallel(ae, device_ids=gpu_ids).to(device)

In [11]:
# Display the device ids held by the DataParallel wrapper.
ae.device_ids

[0, 1, 2, 3]

In [12]:
# Batch size requested from case_db.
batch_size = 4000
case_db.set_batch_size(batch_size)

In [13]:
learning_rate = 1e-4

# Project-defined loss, placed on the same device as the model.
criterion = CustomLoss()
loss = criterion.to(device)

# Adam over all parameters exposed by `ae`.
optimizer = torch.optim.Adam(params=ae.parameters(), lr=learning_rate)

In [15]:
# Short training run: one optimisation step per batch for the first N_BATCHES
# batches yielded by case_db.iterator, printing input and output sizes each step.
N_BATCHES = 11  # the loop previously stopped after index 10, i.e. 11 batches

for i, data in enumerate(case_db.iterator):
    # Named `batch` rather than `input` so the built-in input() is not shadowed
    # for the rest of the notebook session.
    batch = data.to(device)

    # The model returns the reconstruction and the ground truth it is compared to.
    recon, gt = ae(batch)
    print("Outside: input size", batch.size(),
          "output_size", recon['CHRONn'].size())

    l = loss(recon, gt)
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    l.backward()
    optimizer.step()

    # Stop after N_BATCHES batches instead of consuming the whole iterator.
    if i == N_BATCHES - 1:
        break

Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
Outside: input size torch.Size([4000, 115]) output_size torch.Size([4000, 12583])
