In [1]:
import os
import numpy as np
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
from model_synthetic_package import LSTMModel
from train_synthetic_package import trainInitIPTW

# Dataset metadata: these four dimensions are handed out together by get_dim().
n_X_features, n_X_static_features = 100, 5
n_X_t_types, n_classes = 1, 1

In [2]:
import torch

import torch.nn.functional as F

import torch.optim as optim
from torch.utils import data
import torch.nn as nn

In [7]:
def get_dim():
    """Return the dataset dimensions as a 4-tuple.

    Returns:
        ``(n_X_features, n_X_static_features, n_X_t_types, n_classes)``, read
        from the module-level metadata constants at call time.
    """
    dims = (n_X_features, n_X_static_features, n_X_t_types, n_classes)
    return dims


class SyntheticDataset(data.Dataset):
    '''Map-style dataset over per-sample ``.npy`` files of one synthetic cohort.

    Sample ``ID`` is read from four files inside ``data_dir``:
    ``{ID}.x.npy`` -> ``X``, ``{ID}.static.npy`` -> ``X_demo``,
    ``{ID}.a.npy`` -> ``X_treatment`` and ``{ID}.y.npy`` -> ``y``.
    Every array is cast to float32 and returned as a torch tensor.
    '''

    def __init__(self, list_IDs, obs_w, treatment, data_dir=None):
        '''Initialization.

        Args:
            list_IDs: sequence of sample IDs; dataset index ``i`` maps to
                ``list_IDs[i]``.
            obs_w: observation window. Stored on the instance but not used
                by this class.
            treatment: treatment option name. Stored on the instance but not
                used by this class.
            data_dir: directory containing the ``.npy`` files. When omitted,
                the module-level ``data_dir`` in effect at construction time
                is used, so each dataset stays bound to the directory it was
                built for even if that global is reassigned later.
        '''
        self.list_IDs = list_IDs
        self.obs_w = obs_w
        self.treatment = treatment
        if data_dir is None:
            # Backward-compatible fallback for callers that rely on a notebook global.
            data_dir = globals().get('data_dir')
        self.data_dir = data_dir

    def __len__(self):
        '''Denotes the total number of samples'''
        return len(self.list_IDs)

    def _load(self, ID, suffix):
        '''Load ``{ID}.{suffix}.npy`` from ``self.data_dir`` as a float32 tensor.

        Raises:
            RuntimeError: if no data directory was given or found at construction.
        '''
        if self.data_dir is None:
            raise RuntimeError(
                'SyntheticDataset has no data_dir: pass data_dir=... or define '
                'a module-level data_dir before constructing the dataset')
        # os.path.join works whether or not data_dir ends with a separator.
        path = os.path.join(self.data_dir, '{}.{}.npy'.format(ID, suffix))
        return torch.from_numpy(np.load(path).astype(np.float32))

    def __getitem__(self, index):
        '''Generates one sample of data.

        Returns:
            Tuple ``(X, X_demo, X_treatment, y)`` of float32 tensors.
        '''
        ID = self.list_IDs[index]
        X = self._load(ID, 'x')
        X_demo = self._load(ID, 'static')
        X_treatment = self._load(ID, 'a')
        y = self._load(ID, 'y')
        return X, X_demo, X_treatment, y

In [8]:
# Default parameters
treatment_option = 'vaso'
observation_window = 12
epochs = 1
batch_size = 128
lr = .001
weight_decay = .00001
l1 = .00001          # L1 regularisation coefficient
resume = ''          # checkpoint path to resume from; '' means start from a fresh model
cuda_device = 1      # NOTE(review): not referenced by any visible cell

# One synthetic dataset directory per value: ../data/data_synthetic/data_syn_{gamma}/
gamma_h = (.1, .3, .5, .7)
HIDDEN_SIZE = 32
CUDA = False

# Seed the NumPy and PyTorch RNGs so Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs('model_checkpoints', exist_ok=True)

hi


In [11]:
# Train and evaluate one model per gamma. Every hyper-parameter comes from the
# "Default Parameters" cell rather than being re-typed here.
n_X_features, n_X_static_features, n_X_t_types, n_classes = get_dim()


def _load_or_build_model():
    """Return the model to train for the current gamma.

    If ``resume`` names an existing file, that checkpoint is loaded. Otherwise
    a fresh ``LSTMModel`` is built. A missing checkpoint now falls back to a
    fresh model instead of leaving ``model`` undefined, or stale from the
    previous gamma.
    """
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            # torch.load unpickles a whole model object: only resume from trusted files.
            # Map to CPU when CUDA is off so GPU-saved checkpoints still load.
            loaded = torch.load(resume, map_location=None if CUDA else 'cpu')
            if CUDA:
                loaded = loaded.cuda()
            print("=> loaded checkpoint '{}'".format(resume))
            return loaded
        print("=> no checkpoint found at '{}'".format(resume))

    attn_model = 'concat2'
    n_Z_confounders = HIDDEN_SIZE
    return LSTMModel(n_X_features, n_X_static_features, n_X_t_types, n_Z_confounders,
                     attn_model, n_classes, observation_window,
                     batch_size, hidden_size=HIDDEN_SIZE)


for gamma in gamma_h:
    # SyntheticDataset locates its sample files via the module-level `data_dir`.
    data_dir = '../data/data_synthetic/data_syn_{}/'.format(gamma)
    save_model = 'model_checkpoints/mimic-6-7-{}.pt'.format(gamma)

    # Split codes per sample ID: 1 = train, 2 = validation, 0 = test.
    train_test_split = np.loadtxt(os.path.join(data_dir, 'train_test_split.csv'),
                                  delimiter=',', dtype=int)
    train_iids = np.where(train_test_split == 1)[0]
    val_iids = np.where(train_test_split == 2)[0]
    test_iids = np.where(train_test_split == 0)[0]

    train_dataset = SyntheticDataset(train_iids, observation_window, treatment_option)
    val_dataset = SyntheticDataset(val_iids, observation_window, treatment_option)
    test_dataset = SyntheticDataset(test_iids, observation_window, treatment_option)

    # Only training needs shuffling; a fixed order keeps evaluation deterministic.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)

    model = _load_or_build_model()
    adam_optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model = trainInitIPTW(train_loader, val_loader, test_loader,
                          model, epochs=epochs,
                          criterion=F.binary_cross_entropy_with_logits,
                          optimizer=adam_optimizer,
                          l1_reg_coef=l1,
                          use_cuda=CUDA,
                          save_model=save_model)

100%|█████████████████████████████████████████████████████████████████| 22/22 [00:08<00:00,  2.55it/s]

Epoch: 0, IPW train loss: 0.7129472060637041
Epoch: 0, Outcome train loss: 0.20435010980476032





Validation:
PEHE: 0.3938	ATE: 0.2704
RMSE: 0.4050

Best model. Saving...

Test:
PEHE: 0.4160	ATE: 0.2846
RMSE: 0.4162

0.4160161574692131
0.28457241445016035
0.41615012


100%|█████████████████████████████████████████████████████████████████| 22/22 [00:08<00:00,  2.56it/s]

Epoch: 0, IPW train loss: 0.6404329159043052
Epoch: 0, Outcome train loss: 0.15179188549518585





Validation:
PEHE: 0.4603	ATE: 0.3855
RMSE: 0.3669

Best model. Saving...

Test:
PEHE: 0.4546	ATE: 0.3839
RMSE: 0.3595

0.45458821275425715
0.3838738838667387
0.35952643


100%|█████████████████████████████████████████████████████████████████| 22/22 [00:08<00:00,  2.50it/s]

Epoch: 0, IPW train loss: 0.714750967242501
Epoch: 0, Outcome train loss: 0.4049650254574689





Validation:
PEHE: 0.7348	ATE: 0.4782
RMSE: 0.5129

Best model. Saving...

Test:
PEHE: 0.7189	ATE: 0.4647
RMSE: 0.5106

0.7189362853268815
0.4647024287973357
0.5106464


100%|█████████████████████████████████████████████████████████████████| 22/22 [00:08<00:00,  2.45it/s]

Epoch: 0, IPW train loss: 0.6845430162819949
Epoch: 0, Outcome train loss: 0.07627394575286996





Validation:
PEHE: 0.3816	ATE: 0.2690
RMSE: 0.1842

Best model. Saving...

Test:
PEHE: 0.3884	ATE: 0.2675
RMSE: 0.1776

0.3883965835348956
0.267469852318223
0.17756014
