# Example Pipeline for the LaLonde Data

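This notebook is a proof of concept for generating synthetic causal datasets that mimic an observed dataset (here, the filtered LaLonde NSW data) while fixing the true average treatment effect (ATE). It fits a frugal flow to the observed data and samples unconfounded synthetic data with a known ATE. It then compares the result with data generated by CREDENCE.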

In [1]:
import contextlib
import sys
import os
sys.path.append("../")  # go to parent dir
# sys.path.append("../data/analysis/")  # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
# jnp.set_printoptions(precision=2)
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import rankdata

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.vectors import StrVector
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage

from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow
from frugal_flows.sample_outcome import sample_outcome
from frugal_flows.sample_marginals import from_quantiles_to_marginal_cont, from_quantiles_to_marginal_discr
from frugal_flows.train_quantile_propensity_score import train_quantile_propensity_score
from frugal_flows.bijections import UnivariateNormalCDF
from frugal_flows.benchmarking import FrugalFlowModel
from benchmarking import compare_datasets
from frugal_flows.sample_outcome import sample_outcome
from frugal_flows.sample_marginals import from_quantiles_to_marginal_cont, from_quantiles_to_marginal_discr
from frugal_flows.train_quantile_propensity_score import train_quantile_propensity_score


import data.template_causl_simulations as causl_py
import data.analysis.validationMethods as valMethods
import wandb

# Activate automatic conversion of rpy2 objects to pandas objects
pandas2ri.activate()
base = importr('base')
utils = importr('utils')

# Import the R library causl, installing it first if it is missing.
try:
    causl = importr('causl')
except Exception:
    # ('causl') is just a parenthesised string; the trailing comma makes it a
    # one-element tuple, which is what StrVector expects (a sequence of names).
    package_names = ('causl',)
    utils.install_packages(StrVector(package_names))
    # Retry the import so `causl` is actually bound after a fresh install.
    # If the package is not available from the configured repository, this
    # raises instead of silently leaving `causl` undefined.
    causl = importr('causl')

# Global run settings.
# NOTE(review): N, B and sampling_size are not used by the cells shown here
# (sampling below passes a literal 600). Confirm whether they are still needed.
seed, N, B, sampling_size = 0, 2000, 50, 1000

# Derive one PRNG key plus 19 spare subkeys from `seed`.
_split_keys = jr.split(jr.PRNGKey(seed), 20)
keys, subkeys = _split_keys[0], list(_split_keys[1:])

def clean_ate(value):
    """Collapse an ATE estimate to a single number.

    Lists, tuples and NumPy arrays are reduced with ``np.mean`` (as before).
    Other array-likes that expose ``__array__`` (e.g. JAX arrays, pandas
    Series) are converted to NumPy first and then averaged. Previously they
    were returned unchanged because they are not ``np.ndarray`` instances.
    Scalars, including NumPy scalars, and any other value are returned as-is.

    Args:
        value: A scalar ATE or a collection of ATE estimates.

    Returns:
        The mean of ``value`` when it is array-like, otherwise ``value``.
    """
    if isinstance(value, (list, tuple, np.ndarray)):
        return np.mean(value)
    # np.generic scalars also expose __array__; keep passing them through untouched.
    if hasattr(value, "__array__") and not isinstance(value, np.generic):
        return np.mean(np.asarray(value))
    return value

lalonde = pd.read_csv('../data/filtered_lalonde_dataset.csv')

outcome_col = 're78'
treatment_col = 'treatment'
standardised_outcome_col = f'{outcome_col}_standardised'

# Outcome moments are taken from the control arm only; they are used to
# standardise the outcome and to express the target ATE on that scale.
Y_control = lalonde.loc[lalonde[treatment_col] == 0, outcome_col]
Y_control_mean = Y_control.mean()
Y_control_std = Y_control.std()



def rescale_outcome(x, mean, std):
    """Undo standardisation: map ``x`` from the ``(y - mean) / std`` scale back to ``y``.

    Args:
        x: Standardised value(s); anything supporting ``*`` and ``+``.
        mean: Mean used for standardisation.
        std: Standard deviation used for standardisation.

    Returns:
        ``x * std + mean``.
    """
    stretched = x * std
    return stretched + mean
    
cont_columns = ['age']
disc_columns = ['education', 'black', 'hispanic', 'married', 'nodegree']

# Standardise the outcome with the control-arm moments computed above.
lalonde[standardised_outcome_col] = (lalonde[outcome_col] - Y_control_mean) / Y_control_std

# z-score each continuous covariate in place with its own sample mean/std
# (pandas default ddof=1).
for col in cont_columns:
    mean, std = lalonde[col].mean(), lalonde[col].std()
    lalonde[col] = (lalonde[col] - mean) / std

# Treatment and standardised outcome as (n, 1) column arrays.
X = jnp.array(lalonde[treatment_col].to_numpy())[:, None]
Y = jnp.array(lalonde[standardised_outcome_col].to_numpy())[:, None]

# Covariates: continuous first, then discrete.
covariate_colnames = cont_columns + disc_columns
Z_disc = jnp.array(lalonde[disc_columns].to_numpy())
Z_cont = jnp.array(lalonde[cont_columns].to_numpy())

# Analysis frame: standardised outcome, treatment, then covariates.
lalonde_rescaled = lalonde[[standardised_outcome_col, treatment_col] + covariate_colnames]

In [None]:
# Hyperparameters for the marginal flows.
marginal_hyperparam_dict = dict(
    learning_rate=5e-4,
    RQS_knots=8,
    flow_layers=3,
    nn_depth=5,
    nn_width=10,
    max_patience=100,
    max_epochs=20000,
)

# Hyperparameters for the frugal (copula) flow; also reused for the propensity flow.
hyperparam_dict = dict(
    learning_rate=0.006335,
    RQS_knots=4,
    flow_layers=9,
    nn_depth=10,
    nn_width=50,
    max_patience=100,
    max_epochs=20000,
)

# Hyperparameters merged into the causal-model arguments for the outcome margin.
causal_margin_hyperparams_dict = dict(
    learning_rate=0.005,
    RQS_knots=8,
    flow_layers=10,
    nn_depth=20,
    nn_width=50,
    max_epochs=20000,
    max_patience=200,
)
seed = 1  # NOTE: overrides the seed = 0 set in the setup cell
true_ATE = 1000  # target ATE for the synthetic data, in original re78 units

# Frugal flow over outcome Y, treatment X and discrete/continuous covariates,
# without a confounding copula.
benchmark_flow = FrugalFlowModel(
    Y=Y,
    X=X,
    Z_disc=Z_disc,
    Z_cont=Z_cont,
    confounding_copula=None,
)
benchmark_flow.train_benchmark_model(
    training_seed=jr.PRNGKey(seed),
    marginal_hyperparam_dict=marginal_hyperparam_dict,
    frugal_hyperparam_dict=hyperparam_dict,
    prop_flow_hyperparam_dict=hyperparam_dict,  # propensity flow reuses the frugal-flow settings
    causal_model='location_translation',
    causal_model_args={'ate': 0, **causal_margin_hyperparams_dict},
)

  4%|████▊                                                                                                                    | 797/20000 [00:36<28:16, 11.32it/s, train=0.8558722663016056, val=1.2966533465543875]

### Unconfounded Data

In [None]:
# Draw 600 synthetic rows from the trained flow without confounding. The ATE is
# passed on the standardised-outcome scale (true_ATE divided by the control std),
# matching how the outcome was standardised. The generated columns are then
# relabelled to match the observed analysis frame.
# NOTE(review): the module-level `sampling_size` is not used here.
sim_data_df = benchmark_flow.generate_samples(
    key=jr.PRNGKey(10 * seed),
    sampling_size=600,
    copula_param=0.,
    outcome_causal_model='location_translation',
    outcome_causal_args={'ate': true_ATE / Y_control_std},
    with_confounding=False,
).set_axis(lalonde_rescaled.columns, axis=1)

In [None]:
sim_data_df.loc[:, 'treatment'].value_counts()

In [None]:
lalonde_rescaled.loc[:, 'treatment'].value_counts()

In [None]:
# Overlay the distribution of one column in the synthetic and observed data.
plotted_col = standardised_outcome_col  # set to e.g. 'education' to inspect a covariate instead
fig, ax = plt.subplots()
for frame, label in (
    (sim_data_df, 'Y from sim_data_df'),
    (lalonde_rescaled, 'Standardized Outcome from lalonde_rescaled'),
):
    frame[plotted_col].hist(ax=ax, density=True, alpha=0.5, label=label, bins=20)
ax.legend()
ax.set_xlabel('Value')
ax.set_ylabel('Density')
plt.show()

In [None]:
lalonde_rescaled.head(n=5)

In [None]:
# Compare covariate distributions of the observed data with the simulated control arm.
alphas = [0.1]  # other values tried: 0.25, 0.5, 0.75, 1.0, 1.5, 2.0
k = 3
cols = ["age", "education", "black", "hispanic", "married", "nodegree"]
lalonde_control = lalonde_rescaled.loc[lalonde_rescaled['treatment'] == 0]
sim_data_df_control = sim_data_df.loc[sim_data_df['treatment'] == 0]

# .sample(300) raises ValueError when a frame has fewer than 300 rows, which
# can happen for the randomly sized simulated control arm; cap the size instead.
n_compare = min(300, len(lalonde_rescaled), len(sim_data_df_control))
# NOTE(review): the observed side uses both arms while the simulated side is
# control-only (lalonde_control is not used here). Confirm this is intended.
compare_datasets(
    lalonde_rescaled.loc[:, cols].sample(n_compare, random_state=seed).values,
    sim_data_df_control.loc[:, cols].sample(n_compare, random_state=seed).values,
    alphas=alphas,
    k=k,  # use the configured k rather than a duplicated literal
    n_permutations=1000,
)

In [None]:
lalonde_control.loc[:, cols + [standardised_outcome_col]].corr()

In [None]:
# Control arm only, to match the observed control-arm correlation matrix; the
# full simulated data mixes in the treated units' ATE shift in the outcome.
sim_data_df_control[cols + [standardised_outcome_col]].corr()

In [None]:
import seaborn as sns

# Side-by-side correlation heatmaps (covariates + standardised outcome) of the
# control arm in the observed and generated data. Both panels use control units
# only, so the outcome column is comparable across the two panels.
sns.set(font_scale=1.25)
fig, ax = plt.subplots(ncols=2, figsize=(18, 6))
sns.heatmap(lalonde_control[cols + [standardised_outcome_col]].corr(), ax=ax[0], square=True)
ax[0].set_title('Observed NSW Data')
sns.heatmap(sim_data_df_control[cols + [standardised_outcome_col]].corr(), ax=ax[1], square=True)
ax[1].set_title('Generated NSW Data')
plt.show()

# Comparison with CREDENCE

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl
import tqdm
sns.set()
import os
import importlib

# The local CREDENCE modules are imported after switching into the checkout's
# src/ directory (presumably they are not installed as a package). Remember the
# starting directory so it is always restored, even if an import fails.
# Otherwise the kernel would be left inside src/: later relative paths
# (checkpoints, figures) would resolve in the wrong place, and re-running this
# cell would fail because the relative chdir target no longer exists.
notebook_dir = os.getcwd()
os.chdir('./credence-to-causal-estimation/credence-v2/src/')
try:
    import credence
    import autoencoder
    # Reload so edits to the local modules take effect without a kernel restart.
    importlib.reload(autoencoder)
    importlib.reload(credence)
finally:
    os.chdir(notebook_dir)

In [None]:
# CREDENCE generator configured on the same analysis frame. Column names come from
# the shared variables so they stay in sync with the preprocessing cell.
v5000 = credence.Credence(
    data=lalonde_rescaled,  # dataframe
    post_treatment_var=[standardised_outcome_col],  # list of post treatment variables
    treatment_var=[treatment_col],  # list of treatment variable(s)
    categorical_var=[treatment_col] + disc_columns,  # list of variables which are categorical
    numerical_var=[standardised_outcome_col] + cont_columns,  # list of variables which are numerical
)

In [None]:
%%time
gen = v5000.fit(effect_rigidity=0,bias_rigidity=5000,kld_rigidity=0.01,max_epochs=500)
v5000.trainer_treat.save_checkpoint("nsw_treat_5000.ckpt");
v5000.trainer_pre.save_checkpoint("nsw_pre_5000.ckpt");
v5000.trainer_post.save_checkpoint("nsw_post_5000.ckpt");

In [None]:
sim_data_credence_bias5000, sim_data_credence_prime5000 = v5000.sample(data=lalonde_rescaled)
# NOTE(review): row 0 is overwritten by hand. This is presumably so these category
# values appear at least once in the generated data (e.g. to avoid a constant
# column in the correlation heatmaps). This alters the generated data; TODO confirm.
for forced_col, forced_value in (('education', 2), ('hispanic', 1)):
    sim_data_credence_bias5000.loc[0, forced_col] = forced_value
# Use the sampled 'Y0' column as the outcome compared against the observed data
# (presumably the untreated potential outcome; verify against CREDENCE).
sim_data_credence_bias5000[standardised_outcome_col] = sim_data_credence_bias5000['Y0']

In [None]:
%%time
v1000 = credence.Credence(
    data=lalonde_rescaled, # dataframe 
    post_treatment_var=['re78_standardised'], # list of post treatment variables
    treatment_var=['treatment'], # list of treatment variable(s)
    categorical_var=['treatment'] + disc_columns, # list of variables which are categorical
    numerical_var= ['re78_standardised'] + cont_columns # list of variables which are numerical
)

gen = v1000.fit(effect_rigidity=0,bias_rigidity=1000,kld_rigidity=0.01,max_epochs=500)
v1000.trainer_treat.save_checkpoint("nsw_treat_1000.ckpt");
v1000.trainer_pre.save_checkpoint("nsw_pre_1000.ckpt");
v1000.trainer_post.save_checkpoint("nsw_post_1000.ckpt");

In [None]:
sim_data_credence_bias1000, sim_data_credence_prime1000 = v1000.sample(data=lalonde_rescaled)
# NOTE(review): row 0 is overwritten by hand, presumably so that 'hispanic' == 1
# appears at least once in the generated data. This alters the data; TODO confirm.
forced_hispanic_1000 = 1
sim_data_credence_bias1000.loc[0, 'hispanic'] = forced_hispanic_1000
# Use the sampled 'Y0' column as the outcome compared against the observed data
# (presumably the untreated potential outcome; verify against CREDENCE).
outcome_y0_1000 = sim_data_credence_bias1000['Y0']
sim_data_credence_bias1000[standardised_outcome_col] = outcome_y0_1000

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(font_scale=1.25)

# Correlation matrices (covariates + standardised outcome) for the observed data
# and each generator, in panel order.
data_list = [
    lalonde_rescaled[cols + [standardised_outcome_col]].corr(),
    sim_data_df[cols + [standardised_outcome_col]].corr(),
    sim_data_credence_bias5000[cols + [standardised_outcome_col]].corr(),
    sim_data_credence_bias1000[cols + [standardised_outcome_col]].corr()
]

# Titles for each subplot, in the same order as data_list
titles = [
    'Observed NSW Data',
    'Frugal Flows',
    'CREDENCE (bias = 5000) -- Original',
    'CREDENCE (bias = 1000)'
]

# One equally sized panel per matrix. The panel count is derived from data_list,
# so adding or removing a panel does not break the colour-bar placement below.
n_panels = len(data_list)
last_panel = n_panels - 1
fig, ax = plt.subplots(ncols=n_panels, figsize=(22, 6))

# Shared colour scale so colours are comparable across panels
vmin = min(d.min().min() for d in data_list)
vmax = max(d.max().max() for d in data_list)

# A single colour bar in its own axis to the right of the panels
cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])

for i, (panel_ax, data, title) in enumerate(zip(ax, data_list, titles)):
    is_last = i == last_panel
    sns.heatmap(
        data, ax=panel_ax, square=True, vmin=vmin, vmax=vmax,
        cbar=is_last, cbar_ax=cbar_ax if is_last else None,
        # Only the leftmost panel carries row labels; the others share them.
        yticklabels=(i == 0),
    )
    panel_ax.set_title(title)

plt.subplots_adjust(wspace=0.1)

fig.savefig('Lalonde_NSW_0_1000_0.png')
plt.show()


In [None]:
alphas = [0.1]  # other values tried: 0.25, 0.5, 0.75, 1.0, 1.5, 2.0
k = 3
# A distinct name avoids clobbering the covariate-only `cols` used by the
# correlation cells. Otherwise re-running them afterwards would select the
# outcome column twice.
metric_cols = [standardised_outcome_col] + ["age", "education", "black", "hispanic", "married", "nodegree"]
n = 250

# Generated datasets to score, keyed by the column label used in the summary table.
generated_datasets = {
    'Frugal Flows': sim_data_df,
    'CREDENCE (Bias = 1000)': sim_data_credence_bias1000,
    'CREDENCE (Bias = 5000)': sim_data_credence_bias5000,
}
# Seeded sampling makes the scores reproducible. Using the same random_state on
# the observed frame means every generator is scored against the same observed
# rows. A fresh array is drawn for each call in case compare_datasets modifies
# its inputs.
metric_dict = {
    label: compare_datasets(
        lalonde_rescaled.loc[:, metric_cols].sample(n, random_state=seed).values,
        generated_df.loc[:, metric_cols].sample(n, random_state=seed).values,
        alphas=alphas,
        k=k,
        n_permutations=1000,
    )
    for label, generated_df in generated_datasets.items()
}

In [None]:
pd.DataFrame(metric_dict)