In [1]:
import sys
sys.path.insert(0, '../../')
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from statsmodels.discrete.discrete_model import Probit
from models.imavae import IMAVAE

# To make this notebook's output stable across runs
np.random.seed(2020)
torch.manual_seed(2020)
torch.cuda.manual_seed(2020)
torch.cuda.manual_seed_all(2020)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False



# 1. Preprocess and run simulation on Jobs II Data

In [2]:
def simulate_jobs2_data(eta, alpha, n):
    df = pd.read_csv('../../data/jobs_ii_data.csv')

    # Preprocessing
    df = df.drop(['control','job_dich','job_disc'], axis=1)   # drop redundant features
    df['job_seek'] = df['job_seek'].apply(lambda x: 1 if x >= 3 else 0)
    continuous_cols = ['econ_hard','depress1','age','depress2']
    binary_cols = ['sex','nonwhite','work1','comply','treat']
    categorical_cols = ['occp','marital','educ','income']
    for col in continuous_cols:
        scaler = StandardScaler()
        df[col] = scaler.fit_transform(np.array(df[col]).reshape(-1,1)).squeeze()
    for col in binary_cols:
        if df[col].dtype == object:
            df[col] = np.array(pd.get_dummies(df[col]).iloc[:,0])
    for col in categorical_cols:
        df = pd.concat([df,pd.get_dummies(df[col])], axis=1)
        df = df.drop([col], axis=1)

    # Run simulation so that the ground truth of direct and indirect effects are zero
    T, M = np.array(df['treat'], dtype=np.float32), np.array(df['job_seek'], dtype=np.float32)
    cols_x = [col for col in df.columns if col not in ['treat','job_seek','depress2']]
    X = np.array(df[cols_x], dtype=np.float32)
    beta_pop = Probit(T,X).fit(maxiter=500).params
    TX = np.concatenate([T.reshape(-1,1),X], axis=1)
    coeffs_m = Probit(M,TX).fit(maxiter=500).params
    gamma_pop, omega_pop = coeffs_m[0], coeffs_m[1:]
    df = df[(df['treat'] == 0) & (df['job_seek'] == 0)]
    df_pseudo = df.sample(n=n,replace=True,random_state=42)
    df_pseudo['treat'] = df_pseudo.apply(
        lambda row: 1 if np.array(row[cols_x])@beta_pop + np.random.RandomState(42).rand() > 0 else 0, axis=1)
    df_pseudo['job_seek'] = df_pseudo.apply(
        lambda row: eta*(row['treat']*gamma_pop + np.array(row[cols_x])@omega_pop) + alpha + np.random.RandomState(42).rand(), axis=1)
    df_pseudo = df_pseudo.reset_index(drop=True)
    print("Treatment/control ratio: {:.4f}".format(np.sum(df_pseudo['treat'])/len(df_pseudo)))
    print("Mediation ratio: {:.4f}".format(np.sum(df_pseudo['job_seek'] > 3)/len(df_pseudo)))

    # Make sure both train and test splits have same ratio of treatment and control samples
    df_control, df_treatment = df_pseudo[df_pseudo['treat'] == 0], df_pseudo[df_pseudo['treat'] == 1]
    df_control_train = df_control.sample(frac=0.8,random_state=42)
    df_treatment_train = df_treatment.sample(frac=0.8,random_state=42)
    df_control_test = df_control.drop(df_control_train.index)
    df_treatment_test = df_treatment.drop(df_treatment_train.index)
    df_train = pd.concat(
        [df_control_train,df_treatment_train], axis=0).sample(frac=1,random_state=42).reset_index(drop=True)
    df_test = pd.concat(
        [df_control_test,df_treatment_test], axis=0).sample(frac=1,random_state=42).reset_index(drop=True)
    
    # Note that here X, T, Y, W corresponds to the variables in our paper's causal graph
    X_train, X_test = np.array(df_train['job_seek']).reshape(-1,1), np.array(df_test['job_seek']).reshape(-1,1)
    T_train, T_test = np.array(df_train['treat']).reshape(-1,1), np.array(df_test['treat']).reshape(-1,1)
    Y_train, Y_test = np.array(df_train['depress2']).reshape(-1,1), np.array(df_test['depress2']).reshape(-1,1)
    df_train = df_train.drop(['treat','job_seek','depress2'], axis=1)
    df_test = df_test.drop(['treat','job_seek','depress2'], axis=1)
    W_train, W_test = np.array(df_train).astype(np.float32), np.array(df_test).astype(np.float32)
    return X_train, T_train, W_train, Y_train, X_test, T_test, W_test, Y_test, df_pseudo

# 2. Estimate mediation effects with IMAVAE

In [3]:
eta, alpha, n = 1, 0.9, 500
X_train, T_train, W_train, Y_train, X_test, T_test, W_test, Y_test, df_pseudo = simulate_jobs2_data(eta, alpha, n)
imavae = IMAVAE(n_components=5, n_sup_networks=5, n_hidden_layers=2, hidden_dim=10, 
                optim_name='Adam', weight_decay=0.01, recon_weight=1., elbo_weight=0.1, sup_weight=1.)
WT_train = np.concatenate([T_train,W_train], axis=1)
_ = imavae.fit(
    X_train, WT_train, Y_train, X_val=X_train, aux_val=WT_train, y_val=Y_train, 
    lr=1e-5, n_epochs=500, pretrain=False, verbose=1
)
WT_test = np.concatenate([T_test,W_test], axis=1)
acme_c_mean, acme_c_std = imavae.acme_score(WT_test, treatment=False)
acme_t_mean, acme_t_std = imavae.acme_score(WT_test, treatment=True)
ade_c_mean, ade_c_std = imavae.ade_score(WT_test, treatment=False)
ade_t_mean, ade_t_std = imavae.ade_score(WT_test, treatment=True)
ate_mean, ate_std = imavae.ate_score(WT_test)
print("ACME (control) = {:.4f} +/- {:.4f}".format(acme_c_mean, acme_c_std))
print("ACME (treatment) = {:.4f} +/- {:.4f}".format(acme_t_mean, acme_t_std))
print("ADE (control) = {:.4f} +/- {:.4f}".format(ade_c_mean, ade_c_std))
print("ADE (treatment) = {:.4f} +/- {:.4f}".format(ade_t_mean, ade_t_std))
print("ATE = {:.4f} +/- {:.4f}".format(ate_mean, ate_std))

Optimization terminated successfully.
         Current function value: 0.378116
         Iterations 93
Optimization terminated successfully.
         Current function value: 0.212122
         Iterations 7
Treatment/control ratio: 0.6300
Mediation ratio: 0.1080
Beginning Training


Epoch: 499, Best Epoch: 499, Best Recon MSE: 0.00018003, Best Pred Metric [0.005


Saving the last epoch with training MSE: 0.00018003 and Pred Metric: [0.005642338, 0.005642338, 0.005642338, 0.005642338, 0.005642338]
Loaded the best model from Epoch: 499 with MSE: 0.00018003 and Pred Metric: [0.005642338, 0.005642338, 0.005642338, 0.005642338, 0.005642338]
ACME (control) = 0.0701 +/- 0.0111
ACME (treatment) = 0.0741 +/- 0.0118
ADE (control) = -0.2453 +/- 0.0000
ADE (treatment) = -0.2453 +/- 0.0000
ATE = -0.1750 +/- 0.0140
