# Example Pipeline for e401k

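This notebook is a proof of concept for generating synthetic causal datasets from an external (observed) dataset. A frugal flow is fitted to the `e401k` data and used to simulate data with a known average treatment effect (ATE). The result is then compared against data generated by CREDENCE.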

In [1]:
import contextlib
import sys
import os
sys.path.append("../")  # go to parent dir
# sys.path.append("../data/analysis/")  # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
# jnp.set_printoptions(precision=2)
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import rankdata

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.vectors import StrVector
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage

from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow
from frugal_flows.sample_outcome import sample_outcome
from frugal_flows.sample_marginals import from_quantiles_to_marginal_cont, from_quantiles_to_marginal_discr
from frugal_flows.train_quantile_propensity_score import train_quantile_propensity_score
from frugal_flows.bijections import UnivariateNormalCDF
from frugal_flows.benchmarking import FrugalFlowModel
from frugal_flows.sample_outcome import sample_outcome
from frugal_flows.sample_marginals import from_quantiles_to_marginal_cont, from_quantiles_to_marginal_discr
from frugal_flows.train_quantile_propensity_score import train_quantile_propensity_score
import torch
from benchmarking import compare_datasets


import data.template_causl_simulations as causl_py
import data.analysis.validationMethods as valMethods
import wandb

# Activate automatic conversion of rpy2 objects to pandas objects
pandas2ri.activate()
base = importr('base')
utils = importr('utils')

# Import the R library causl, installing it first if it is missing.
try:
    causl = importr('causl')
except Exception:
    # The trailing comma matters: ('causl') is just the string 'causl', and
    # StrVector would iterate it character by character ('c', 'a', 'u', ...).
    package_names = ('causl',)
    utils.install_packages(StrVector(package_names))
    # Import again so that `causl` is actually bound after a fresh install.
    causl = importr('causl')

# Global settings. NOTE: `seed` is reassigned to 7 in the hyperparameter cell,
# and that later value is the one used for training and sampling.
seed = 0
N = 2000
B = 50
sampling_size = 1000
keys, *subkeys = jr.split(jr.PRNGKey(seed), 20)

def clean_ate(value):
    """Collapse an ATE estimate to a single value.

    Lists, tuples and numpy arrays are reduced with ``np.mean``; any other
    value (scalars, None, strings, ...) is returned unchanged.
    """
    is_sequence = isinstance(value, (list, tuple, np.ndarray))
    return np.mean(value) if is_sequence else value

In [2]:
# Hyperparameters for the marginal flows (passed as marginal_hyperparam_dict).
marginal_hyperparam_dict = dict(
    learning_rate=5e-4,
    RQS_knots=8,
    flow_layers=3,
    nn_depth=5,
    nn_width=10,
    max_patience=100,
    max_epochs=20000,
)
# Hyperparameters for the frugal flow (passed as frugal_hyperparam_dict).
hyperparam_dict = dict(
    learning_rate=0.00261635,
    RQS_knots=5,
    flow_layers=2,
    nn_depth=3,
    nn_width=34,
    max_patience=100,
    max_epochs=20000,
)
# Flow architecture shared by the causal outcome margin and the propensity flow.
causal_margin_hyperparams_dict = dict(
    RQS_knots=4,
    flow_layers=8,
    nn_depth=10,
    nn_width=50,
)
# NOTE: overrides seed=0 from the setup cell; this value seeds training and sampling.
seed = 7

# Load data
e401k = pd.read_csv('../data/filtered_401k_data.csv')

# Preprocess data: standardise the outcome by the control-group mean/std so the
# synthetic ATE can be expressed in control-group standard deviations.
outcome_col = 'net_tfa'
treatment_col = 'e401'
standardised_outcome_col = f'{outcome_col}_standardised'
Y_control = e401k.loc[e401k[treatment_col] == 0, outcome_col]
Y_control_mean = Y_control.mean()
Y_control_std = Y_control.std()
e401k[standardised_outcome_col] = (e401k[outcome_col] - Y_control_mean) / Y_control_std

# Trim outcome outliers: keep rows with LOWER < standardised outcome < UPPER.
# `.copy()` gives an independent frame, so later column writes do not raise
# SettingWithCopyWarning.
STD_OUTCOME_LOWER, STD_OUTCOME_UPPER = -2, 3
e401k_filtered = e401k.loc[
    (e401k[standardised_outcome_col] > STD_OUTCOME_LOWER)
    & (e401k[standardised_outcome_col] < STD_OUTCOME_UPPER)
].copy()
X = jnp.array(e401k_filtered[treatment_col].values)[:, None]
Y = jnp.array(e401k_filtered[standardised_outcome_col].values)[:, None]
covariate_colnames = [
    col for col in e401k_filtered.columns
    if col not in [outcome_col, standardised_outcome_col, treatment_col]
]

# The frugal flow treats every covariate as discrete, including the nominally
# continuous 'age' and 'inc', so there are no continuous covariates to standardise.
cont_columns = []
disc_columns = ['age', 'inc', 'educ', 'fsize', 'marr', 'twoearn', 'db', 'pira', 'hown', 'p401']

Z_cont = None
Z_disc = jnp.array(e401k_filtered[disc_columns].values)
e401k_rescaled = e401k_filtered[
    [standardised_outcome_col, treatment_col] + covariate_colnames
]

true_ATE = 1000
benchmark_flow = FrugalFlowModel(Y=Y, X=X, Z_disc=Z_disc, Z_cont=Z_cont, confounding_copula=None)
# Snapshot of the frame the frugal flow is fitted on; a later cell re-standardises
# the covariates for CREDENCE and rebinds e401k_rescaled.
e401k_for_frugal_flow = e401k_rescaled.copy()

In [None]:
# Fit the benchmark frugal flow on the observed data. The causal outcome model
# gets ate=0 plus the shared architecture in causal_margin_hyperparams_dict,
# which is also used for the propensity flow.
benchmark_flow.train_benchmark_model(
    training_seed=jr.PRNGKey(seed),
    causal_model='location_translation',
    causal_model_args={'ate': 0, **causal_margin_hyperparams_dict},
    prop_flow_hyperparam_dict=causal_margin_hyperparams_dict,
    marginal_hyperparam_dict=marginal_hyperparam_dict,
    frugal_hyperparam_dict=hyperparam_dict,
)

  1%|▋                                                                                                                    | 121/20000 [02:25<6:38:42,  1.20s/it, train=-1.8997623618224506, val=-1.4560157590190013]

In [None]:
# Rich-display the benchmark FrugalFlowModel instance.
benchmark_flow

In [None]:
def rescale_outcome(x, mean, std):
    """Map a standardised outcome back to the original scale.

    Inverse of ``(value - mean) / std``: returns ``x * std + mean``.
    """
    scaled = x * std
    return scaled + mean

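### Unconfounded Data

Synthetic samples are drawn from the trained frugal flow with the confounding copula parameter set to 0 and a known ATE (`true_ATE`, rescaled to standardised-outcome units).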

In [None]:
# Draw 6000 synthetic rows from the trained frugal flow. true_ATE is on the
# original outcome scale, so it is divided by the control-group std to match the
# standardised outcome the flow was trained on.
sim_data_df = benchmark_flow.generate_samples(
    key=jr.PRNGKey(10 * seed),
    sampling_size=6000,
    copula_param=0.,
    with_confounding=True,
    outcome_causal_model='location_translation',
    outcome_causal_args=dict(ate=true_ATE / Y_control_std),
)
# Reuse the observed column names. Assumes generate_samples emits columns in the
# same order as e401k_rescaled -- TODO confirm.
sim_data_df.columns = e401k_rescaled.columns

In [None]:
import seaborn as sns

sns.set(font_scale=1.25)
fig, ax = plt.subplots(ncols=2, figsize=(18, 6))

# Correlation structure of the outcome and covariates, observed vs generated.
# The colour scale is pinned to the full correlation range [-1, 1]; otherwise each
# panel auto-scales and equal colours would not mean equal correlations.
corr_cols = [standardised_outcome_col] + covariate_colnames
panels = [
    (e401k_rescaled, 'Observed e401k Data'),
    (sim_data_df, 'Generated e401k Data'),
]
for axis, (frame, title) in zip(ax, panels):
    sns.heatmap(frame[corr_cols].corr(), ax=axis, square=True, vmin=-1, vmax=1)
    axis.set_title(title)
plt.show()

In [None]:
# alphas = [0.1, 0.5, 1.0, 2.0]
# k = 3
# compare_datasets(e401k_rescaled.sample(2000).values, sim_data_df.sample(2000).values, alphas=alphas, k=3, n_permutations=2000)

## Comparison with CREDENCE

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl
import tqdm
sns.set()
import os
import importlib

# The CREDENCE modules are imported with the working directory switched to their
# source folder. try/finally guarantees the original working directory is
# restored even if an import or reload fails; otherwise every later relative path
# (e.g. '../data/...') would silently resolve against the wrong directory.
_cwd_before_credence = os.getcwd()
os.chdir('./credence-to-causal-estimation/credence-v2/src/')
try:
    import credence
    import autoencoder
    # Reload so edits to the CREDENCE sources are picked up without a kernel restart.
    importlib.reload(autoencoder)
    importlib.reload(credence)
finally:
    os.chdir(_cwd_before_credence)

In [None]:
# Load data
e401k = pd.read_csv('../data/filtered_401k_data.csv')

# Preprocess data: standardise the outcome by the control-group mean/std and keep
# rows with -2 < standardised outcome < 3.
outcome_col = 'net_tfa'
treatment_col = 'e401'
standardised_outcome_col = f'{outcome_col}_standardised'
Y_control = e401k.loc[e401k[treatment_col] == 0, outcome_col]
Y_control_mean = Y_control.mean()
Y_control_std = Y_control.std()
e401k[standardised_outcome_col] = (e401k[outcome_col] - Y_control_mean) / Y_control_std
# `.copy()`: the continuous covariates are overwritten below; writing into the
# frame returned by a boolean .loc filter raises SettingWithCopyWarning.
e401k_filtered = e401k.loc[
    (e401k[standardised_outcome_col] > -2) & (e401k[standardised_outcome_col] < +3)
].copy()
X = jnp.array(e401k_filtered[treatment_col].values)[:, None]
Y = jnp.array(e401k_filtered[standardised_outcome_col].values)[:, None]
covariate_colnames = [
    col for col in e401k_filtered.columns
    if col not in [outcome_col, standardised_outcome_col, treatment_col]
]

# For CREDENCE, age/inc/educ are numerical (z-scored below); the rest are categorical.
cont_columns = ['age', 'inc', 'educ']
disc_columns = ['fsize', 'marr', 'twoearn', 'db', 'pira', 'hown', 'p401']

for col in cont_columns:
    mean = e401k_filtered[col].mean()
    std = e401k_filtered[col].std()
    e401k_filtered[col] = (e401k_filtered[col] - mean) / std

# Frugal-flow style covariate arrays, kept for parity with the earlier cell.
# NOTE(review): with Z_cont=None these omit age/inc/educ entirely.
Z_cont = None
Z_disc = jnp.array(e401k_filtered[disc_columns].values)
e401k_rescaled = e401k_filtered[
    [standardised_outcome_col, treatment_col] + covariate_colnames
]

true_ATE = 1000
# `benchmark_flow` is deliberately NOT rebuilt here. A new FrugalFlowModel would
# replace the already-trained model with an untrained one (which, with
# Z_cont=None, would also lack age/inc/educ), so re-running the frugal-flow
# sampling cell would silently sample from an untrained model.

In [None]:
# Column split used by CREDENCE: numerical covariates first, categorical second.
cont_columns, disc_columns = (
    ['age', 'inc', 'educ'],
    ['fsize', 'marr', 'twoearn', 'db', 'pira', 'hown', 'p401'],
)

In [None]:
%%time
v1000 = credence.Credence(
    data=e401k_rescaled, # dataframe 
    post_treatment_var=[standardised_outcome_col], # list of post treatment variables
    treatment_var=[treatment_col], # list of treatment variable(s)
    categorical_var=[treatment_col] + disc_columns, # list of variables which are categorical
    numerical_var= [standardised_outcome_col] + cont_columns # list of variables which are numerical
)
gen = v1000.fit(effect_rigidity=1000,bias_rigidity=1000,kld_rigidity=0.01,max_epochs=250);
# v.trainer_treat.save_checkpoint("e401k_treat_0.ckpt");
# v.trainer_pre.save_checkpoint("e401k_pre_0.ckpt");
# v.trainer_post.save_checkpoint("e401k_post_0.ckpt");

In [None]:
# Draw synthetic data from the bias_rigidity=1000 CREDENCE model (sample returns a pair).
credence1000_draws = v1000.sample(data=e401k_rescaled)
sim_data_credence1000, sim_data_credence_prime1000 = credence1000_draws
# Put the sampled 'Y0' column under the observed outcome name for comparison.
# NOTE(review): 'Y0' is presumably the control potential outcome, while the observed
# outcome column mixes treated and control units -- confirm this is intended.
sim_data_credence1000[standardised_outcome_col] = sim_data_credence1000['Y0']

In [None]:
%%time
v5000 = credence.Credence(
    data=e401k_rescaled, # dataframe 
    post_treatment_var=[standardised_outcome_col], # list of post treatment variables
    treatment_var=[treatment_col], # list of treatment variable(s)
    categorical_var=[treatment_col] + disc_columns, # list of variables which are categorical
    numerical_var= [standardised_outcome_col] + cont_columns # list of variables which are numerical
)
gen = v5000.fit(effect_rigidity=1000,bias_rigidity=5000,kld_rigidity=0.01,max_epochs=250);
# v.trainer_treat.save_checkpoint("e401k_treat_0.ckpt");
# v.trainer_pre.save_checkpoint("e401k_pre_0.ckpt");
# v.trainer_post.save_checkpoint("e401k_post_0.ckpt");

In [None]:
# Draw synthetic data from the bias_rigidity=5000 CREDENCE model. Both outputs
# carry the rigidity suffix, matching the naming of the bias_rigidity=1000 cell.
sim_data_credence5000, sim_data_credence_prime5000 = v5000.sample(data=e401k_rescaled)
sim_data_credence_prime = sim_data_credence_prime5000  # backward-compatible alias
# Put the sampled 'Y0' column under the observed outcome name for comparison.
# NOTE(review): 'Y0' is presumably the control potential outcome, while the observed
# outcome column mixes treated and control units -- confirm this is intended.
sim_data_credence5000[standardised_outcome_col] = sim_data_credence5000['Y0']

In [None]:
# sns.set(font_scale=1.25)
# fig,ax = plt.subplots(ncols=2,figsize=(18,6))
# sns.heatmap(e401k_rescaled[[standardised_outcome_col] + covariate_colnames].corr(),ax=ax[0],square=True)
# ax[0].set_title('Observed NSW Data')
# sns.heatmap(sim_data_credence[[standardised_outcome_col] + covariate_colnames].corr(),ax=ax[1],square=True)
# ax[1].set_title('Generated NSW Data')
# fig.savefig('Lalonde_NSW_0_1000_0.png')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(font_scale=1.25)

# Correlation matrices (covariates + standardised outcome) for each dataset.
corr_cols = covariate_colnames + [standardised_outcome_col]
data_list = [
    e401k_rescaled[corr_cols].corr(),
    sim_data_df[corr_cols].corr(),
    # bias_rigidity=5000 CREDENCE samples, matching the panel title below
    sim_data_credence5000[corr_cols].corr(),
]

# Titles for each subplot
titles = [
    'Observed e401k Data',
    'Frugal Flows',
    'CREDENCE (bias = 5000)',
]

n_panels = len(data_list)
last_panel = n_panels - 1  # only the right-most panel draws the shared colour bar

# Create the subplots
fig, ax = plt.subplots(ncols=n_panels, figsize=(22, 6), gridspec_kw={'width_ratios': [1] * n_panels})

# Common vmin/vmax so the same colour means the same correlation in every panel
vmin = min(d.min().min() for d in data_list)
vmax = max(d.max().max() for d in data_list)

# Dedicated axis for the shared colour bar
cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])

for i, data in enumerate(data_list):
    sns.heatmap(data, ax=ax[i], square=True, vmin=vmin, vmax=vmax,
                cbar=(i == last_panel), cbar_ax=cbar_ax if i == last_panel else None,
                yticklabels=(i == 0))
    ax[i].set_title(titles[i])
    if i != 0:
        ax[i].set_yticklabels([])

plt.subplots_adjust(wspace=0.1)

# Distinct filename so this figure is not overwritten by another figure saved as
# 'e401k_0_1000_0.png'.
fig.savefig('e401k_0_1000_0_3panel.png')
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(font_scale=1.25)

# Correlation matrices (covariates + standardised outcome) for each dataset.
corr_cols = covariate_colnames + [standardised_outcome_col]
data_list = [
    e401k_rescaled[corr_cols].corr(),
    sim_data_df[corr_cols].corr(),
    sim_data_credence5000[corr_cols].corr(),
    sim_data_credence1000[corr_cols].corr(),
]

# Titles for each subplot
titles = [
    'Observed e401k Data',
    'Frugal Flows',
    'CREDENCE (bias = 5000) -- Original',
    'CREDENCE (bias = 1000)',
]

n_panels = len(data_list)
last_panel = n_panels - 1  # only the right-most panel draws the shared colour bar

# Create the subplots
fig, ax = plt.subplots(ncols=n_panels, figsize=(22, 6), gridspec_kw={'width_ratios': [1] * n_panels})

# Common vmin/vmax so the same colour means the same correlation in every panel
vmin = min(d.min().min() for d in data_list)
vmax = max(d.max().max() for d in data_list)

# Dedicated axis for the shared colour bar
cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])

for i, data in enumerate(data_list):
    sns.heatmap(data, ax=ax[i], square=True, vmin=vmin, vmax=vmax,
                cbar=(i == last_panel), cbar_ax=cbar_ax if i == last_panel else None,
                yticklabels=(i == 0))
    ax[i].set_title(titles[i])
    if i != 0:
        ax[i].set_yticklabels([])

plt.subplots_adjust(wspace=0.1)

# Save the figure
fig.savefig('e401k_0_1000_0.png')
plt.show()


In [None]:
alphas = [0.1]  # other values tried: 0.25, 0.5, 0.75, 1.0, 1.5, 2.0
k = 3
cols = covariate_colnames + [standardised_outcome_col]
n = 600

# One seeded generator makes every subsample reproducible across runs, while
# successive draws still differ, so observed and simulated subsamples are not
# forced onto the same row positions.
subsample_rng = np.random.default_rng(seed)

# Each generator is compared against the observed frame it was fitted on: the
# frugal flow against the snapshot taken before covariates were z-scored, and
# CREDENCE against the re-standardised frame.
comparisons = {
    'Frugal Flows': (e401k_for_frugal_flow, sim_data_df),
    'CREDENCE (Bias = 1000)': (e401k_rescaled, sim_data_credence1000),
    'CREDENCE (Bias = 5000)': (e401k_rescaled, sim_data_credence5000),
}
metric_dict = {}
for name, (observed, simulated) in comparisons.items():
    metric_dict[name] = compare_datasets(
        observed.loc[:, cols].sample(n, random_state=subsample_rng).values,
        simulated.loc[:, cols].sample(n, random_state=subsample_rng).values,
        alphas=alphas,
        k=k,
        n_permutations=1000,
    )

In [None]:
# One column per generator, one row per metric returned by compare_datasets.
pd.DataFrame(metric_dict)