# Mixed Data Frugal Flows

In this notebook we demonstrate the ability for Frugal Flows to identify Marginal Causal Effects when dealing with a mix of discrete and continous variables.

In [36]:
import sys
import os
sys.path.append("../") # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import numpy as np
from scipy.stats import rankdata
import scipy.stats as ss
import seaborn as sns
from sklearn.model_selection import KFold

# from data.create_sim_data import *
import data.template_causl_simulations as causl_py
from data.run_all_simulations import plot_simulation_results
from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow
from frugal_flows.bijections import UnivariateNormalCDF

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage

# Activate automatic conversion of rpy2 objects to pandas objects
pandas2ri.activate()

# Import the R library causl
try:
    causl = importr('causl')
except Exception as e:
    package_names = ('causl')
    utils.install_packages(StrVector(package_names))

jax.config.update("jax_enable_x64", True)

hyperparams_dict = {
    'learning_rate': 3e-3,
    'RQS_knots': 8,
    'flow_layers': 5,
    'nn_width': 50,
    'nn_depth': 4,    
    'max_patience': 50,
    'max_epochs': 10000
}

SEED = 123
NUM_ITER = 20
NUM_SAMPLES = 20000
TRUE_PARAMS = {'ate': 1, 'const': 1, 'scale': 1}
CAUSAL_PARAMS = [1, 1]

In [37]:
results = {'dequantised': {'ate': [], 'const': []}, 'copula': {'ate': [], 'const': []},}
for i in range(5):
    Z_disc, Z_cont, X, Y = causl_py.generate_discrete_samples(N=NUM_SAMPLES, causal_params=CAUSAL_PARAMS, seed=i).values()
    ### QUANTISED
    uz_samples = causl_py.generate_uz_samples(seed=i, Z_cont=jnp.hstack([Z_disc, Z_cont]))['uz_samples']
    
    for d in range(2):
        uz_samples = uz_samples.at[:, d].set(uz_samples[:, d] + jr.normal(key=jr.PRNGKey(d*(i+1)), shape=(NUM_SAMPLES,))*0.05)
    
    ff, losses = train_frugal_flow(
        key=jr.PRNGKey(i+1000),
        y=Y,
        u_z=uz_samples,
        **hyperparams_dict,
        condition=X
    )
    causal_margin = ff.bijection.bijections[-1].bijection.bijections[0]
    results['dequantised']['ate'].append(causal_margin.ate)
    results['dequantised']['const'].append(causal_margin.const)

    #### ORIGINAL
    uz_samples = causl_py.generate_uz_samples(seed=i, Z_cont=Z_cont, Z_disc=Z_disc)['uz_samples']
        
    ff, losses = train_frugal_flow(
        key=jr.PRNGKey(i+1000),
        y=Y,
        u_z=uz_samples,
        **hyperparams_dict,
        condition=X
    )
    causal_margin = ff.bijection.bijections[-1].bijection.bijections[0]
    results['copula']['ate'].append(causal_margin.ate)
    results['copula']['const'].append(causal_margin.const)
    print(results['dequantised'])
    print(results['copula'])

  1%|▊                                                                                             | 82/10000 [01:39<3:21:16,  1.22s/it, train=-0.41128846990938156, val=-0.3239833555386186 (Max patience reached)]
  1%|▋                                                                                                | 73/10000 [01:25<3:13:26,  1.17s/it, train=1.3576480959655415, val=1.4213770076457195 (Max patience reached)]


{'ate': [Array(1.03476165, dtype=float64)], 'const': [Array(0.98425454, dtype=float64)]}
{'ate': [Array(1.03967894, dtype=float64)], 'const': [Array(0.97304505, dtype=float64)]}


  1%|▊                                                                                            | 81/10000 [01:37<3:19:23,  1.21s/it, train=-0.43198191293233595, val=-0.37127043399561294 (Max patience reached)]
  1%|▋                                                                                                | 77/10000 [01:31<3:15:49,  1.18s/it, train=1.3319012118971758, val=1.4075347249864625 (Max patience reached)]


{'ate': [Array(1.03476165, dtype=float64), Array(1.03289475, dtype=float64)], 'const': [Array(0.98425454, dtype=float64), Array(1.00024772, dtype=float64)]}
{'ate': [Array(1.03967894, dtype=float64), Array(1.00896918, dtype=float64)], 'const': [Array(0.97304505, dtype=float64), Array(0.97387995, dtype=float64)]}


  1%|▊                                                                                              | 85/10000 [01:38<3:11:48,  1.16s/it, train=-0.448144863433613, val=-0.30107616924868064 (Max patience reached)]
  1%|▋                                                                                                 | 70/10000 [01:22<3:14:59,  1.18s/it, train=1.327935106881753, val=1.4442452896004767 (Max patience reached)]


{'ate': [Array(1.03476165, dtype=float64), Array(1.03289475, dtype=float64), Array(1.02363634, dtype=float64)], 'const': [Array(0.98425454, dtype=float64), Array(1.00024772, dtype=float64), Array(0.97006195, dtype=float64)]}
{'ate': [Array(1.03967894, dtype=float64), Array(1.00896918, dtype=float64), Array(1.07165726, dtype=float64)], 'const': [Array(0.97304505, dtype=float64), Array(0.97387995, dtype=float64), Array(0.97393777, dtype=float64)]}


  1%|▌                                                                                            | 67/10000 [01:17<3:12:10,  1.16s/it, train=-0.40931635079750145, val=-0.39728510125942956 (Max patience reached)]
  1%|▋                                                                                                | 67/10000 [01:44<4:18:51,  1.56s/it, train=1.3606759368008245, val=1.4085474897134342 (Max patience reached)]


{'ate': [Array(1.03476165, dtype=float64), Array(1.03289475, dtype=float64), Array(1.02363634, dtype=float64), Array(1.11135723, dtype=float64)], 'const': [Array(0.98425454, dtype=float64), Array(1.00024772, dtype=float64), Array(0.97006195, dtype=float64), Array(0.92754435, dtype=float64)]}
{'ate': [Array(1.03967894, dtype=float64), Array(1.00896918, dtype=float64), Array(1.07165726, dtype=float64), Array(1.13494542, dtype=float64)], 'const': [Array(0.97304505, dtype=float64), Array(0.97387995, dtype=float64), Array(0.97393777, dtype=float64), Array(0.91517848, dtype=float64)]}


  1%|▋                                                                                              | 69/10000 [01:23<3:20:35,  1.21s/it, train=-0.4213740631867528, val=-0.3353653037739957 (Max patience reached)]
  1%|▊                                                                                                 | 77/10000 [01:29<3:11:40,  1.16s/it, train=1.344677016803782, val=1.4272019121431216 (Max patience reached)]

{'ate': [Array(1.03476165, dtype=float64), Array(1.03289475, dtype=float64), Array(1.02363634, dtype=float64), Array(1.11135723, dtype=float64), Array(1.06021987, dtype=float64)], 'const': [Array(0.98425454, dtype=float64), Array(1.00024772, dtype=float64), Array(0.97006195, dtype=float64), Array(0.92754435, dtype=float64), Array(0.96298439, dtype=float64)]}
{'ate': [Array(1.03967894, dtype=float64), Array(1.00896918, dtype=float64), Array(1.07165726, dtype=float64), Array(1.13494542, dtype=float64), Array(1.01175074, dtype=float64)], 'const': [Array(0.97304505, dtype=float64), Array(0.97387995, dtype=float64), Array(0.97393777, dtype=float64), Array(0.91517848, dtype=float64), Array(0.98298227, dtype=float64)]}





In [38]:
results

{'dequantised': {'ate': [Array(1.03476165, dtype=float64),
   Array(1.03289475, dtype=float64),
   Array(1.02363634, dtype=float64),
   Array(1.11135723, dtype=float64),
   Array(1.06021987, dtype=float64)],
  'const': [Array(0.98425454, dtype=float64),
   Array(1.00024772, dtype=float64),
   Array(0.97006195, dtype=float64),
   Array(0.92754435, dtype=float64),
   Array(0.96298439, dtype=float64)]},
 'copula': {'ate': [Array(1.03967894, dtype=float64),
   Array(1.00896918, dtype=float64),
   Array(1.07165726, dtype=float64),
   Array(1.13494542, dtype=float64),
   Array(1.01175074, dtype=float64)],
  'const': [Array(0.97304505, dtype=float64),
   Array(0.97387995, dtype=float64),
   Array(0.97393777, dtype=float64),
   Array(0.91517848, dtype=float64),
   Array(0.98298227, dtype=float64)]}}

In [28]:
causal_margin = ff.bijection.bijections[-1].bijection.bijections[0]

In [31]:
causal_margin.ate

Array(1.15434103, dtype=float64)

## For a single discrete parameter

Generating Normalised Data

In [7]:
Z, X, Y = generate_gaussian_samples(2000, 1).values()

In [None]:
def transform_var(Z, col_idx, inv_cdf):
    return inv_cdf(
            jax.scipy.special.ndtr(Z[:, col_idx])
        )

In [None]:
N = 20000

keys = jr.split(jr.PRNGKey(0), 3)

corr_matrix = jnp.array([
    [1, 0.8, 0.6, 0.2, 0.1],
    [0.8, 1, 0.4, 0.2, 0.1],
    [0.6, 0.4, 1, 0.1, 0.1],
    [0.2, 0.2, 0.1, 1, 0.1],
    [0.1, 0.1, 0.1, 0.1, 1]
])
Z = jr.multivariate_normal(
    keys[1], 
    jnp.array([0.,0.,0.,0.,0.]), 
    corr_matrix, 
    shape=(N,)
)


p = 1 / (
    1 + jnp.exp(-jnp.sum((Z) * 0.5 + jnp.ones(shape=(Z.shape[0], Z.shape[1])), axis=1))
)
X = jr.bernoulli(key=jr.PRNGKey(1), p=p).astype(int)[:, None]
Y = (jax.random.normal(keys[2], shape=(N,1)) + X + jnp.expand_dims(Z.sum(1), axis=1))

poisson_icdf = lambda x: scipy.stats.poisson.ppf(x, mu=5)
gamma_icdf = lambda x: scipy.stats.gamma.ppf(x, a=4)
bernoulli_icdf = lambda x: scipy.stats.bernoulli.ppf(x, p=0.3)

icdf_transforms = [poisson_icdf]#, bernoulli_icdf]#, gamma_icdf]
for i, icdf in enumerate(icdf_transforms):
    Z = Z.at[:, i].set(transform_var(Z, i, icdf))

data_xdyc = pd.DataFrame(
    jnp.concat([Y, X, Z], axis=1),
    columns=['Y', 'X'] + [f"Z{i+1}" for i in range(Z.shape[1])]
)

res = get_independent_quantiles(
    key = jr.PRNGKey(3),
    z_discr=jnp.expand_dims(Z[:, 0],axis=1).astype(int), #impose discrete
    # z_discr=Z[:, 0],
    z_cont=Z[:, 1:],
    RQS_knots=8,
    flow_layers=8,
    nn_width=20,
    nn_depth=8,
    max_epochs = 1000,
    max_patience=100,
    batch_size=200,
    learning_rate=5e-3,
    return_z_cont_flow = True,
 )
# u_z = res['u_z_cont']


In [None]:
u_z = jnp.concat([res['u_z_discr'], res['u_z_cont']], axis=1)

In [None]:
col_names = [f"U_Z{i}" for i in range(Z.shape[1])]
plot_data = pd.DataFrame(u_z, columns=col_names)
plot_data.head()

The correlation matrices have very close to the same entries but the rows and columns are permuted.

In [None]:
u_z[:,0].min()

In [None]:
u_z[:,0]==1

In [None]:
print("True Corr")
display(corr_matrix)
print("Flow Corr")
display(jnp.corrcoef(jax.scipy.special.ndtri(u_z[u_z[:,0]!=1,]).T))
sns.jointplot(x='U_Z1', y='U_Z2', data=plot_data, kind="scatter");

## Check conditional effect

In [None]:
data_xdyc.head()

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder

df_reg = data_xdyc[['Y', 'X']]
encoder = OneHotEncoder(drop='first')
X_encoded = encoder.fit_transform(data_xdyc[['X']])
model = LinearRegression()
model.fit(X_encoded, df_reg['Y'])

print(f"Intercept: {model.intercept_}")
print(f"Coefficients: {model.coef_}")

The data is **really** confounded!

### Training the Frugal FLow

In [None]:
def return_fits(frugal_flow):
    causal_margin = frugal_flow.bijection.bijections[-1].bijection.bijections[0]
    return {
        'ate': causal_margin.ate,
        'const': causal_margin.const,
        'scale': causal_margin.scale
    }

In [None]:
frugal_flow, losses = train_frugal_flow(key=jr.PRNGKey(1),
    y=Y,
    u_z=u_z,
    # u_z=res['u_z_cont'],
    RQS_knots=10,
    flow_layers=8,
    nn_width=10,
    nn_depth=8,
    max_epochs = 1000,
    max_patience=100,
    batch_size=200,
    learning_rate=5e-3,
    condition=X
 )

In [None]:
return_fits(frugal_flow)

Nice model fit!

## Diagnostics

In [None]:
keys = jr.split(jr.PRNGKey(0), 3)
frugal_flow_samples_0 = frugal_flow.sample(keys[0], condition=jnp.zeros((5000,1))) #
frugal_flow_samples_1 = frugal_flow.sample(keys[1], condition=jnp.ones((5000,1))) #

### No correlation between $Y$ and $Z_1$ or $Z_2$

In [None]:
frugal_flow_samples_0 = frugal_flow_samples_0.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_0[:, 1:]))
frugal_flow_samples_1 = frugal_flow_samples_1.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_1[:, 1:]))

True Correlation Matrix:

Flow outputs:

In [None]:
jnp.corrcoef(frugal_flow_samples_0.T)

In [None]:
jnp.corrcoef(frugal_flow_samples_1.T)

In [None]:
plt.scatter(*frugal_flow_samples_0[:,:2].T, label="Y-Z1", s=3)
plt.scatter(*frugal_flow_samples_0[:,[0,2]].T, label="Y-Z2", s=3)

# plt.scatter(*jax.scipy.special.ndtr(z).T, label="target", s=2)
plt.xlabel('u_y')
plt.ylabel('u_z1')
plt.legend()
plt.show()

## For a multiple discrete parameter

Generating Normalised Data

In [None]:
def transform_var(Z, col_idx, inv_cdf):
    return inv_cdf(
            jax.scipy.special.ndtr(Z[:, col_idx])
        )

In [None]:
N = 5000

# marginal_Z = {
#     'Z1': ss.norm(loc=0, scale=1),
#     'Z2': ss.norm(loc=0, scale=1),
#     'Z3': ss.norm(loc=3, scale=5),
#     'Z4': ss.norm(loc=-1, scale=2),
#     # 'Z5': ss.norm(loc=0, scale=1)
# }
# corr_matrix = np.array([
#     [1, 0.8, 0.6, 0.2, 0.1],
#     [0.8, 1, 0.4, 0.2, 0.1],
#     [0.6, 0.4, 1, 0.1, 0.1],
#     [0.2, 0.2, 0.1, 1, 0.1],
#     [0.1, 0.1, 0.1, 0.1, 1]
# ])
# treatment_type = "D"
# outcome_type = "C"
# prop_score_weights = [1, 1, 1, 1]  # Check propscore weights are of same dim as Z
# causal_params = [1, 1]
# data_xdyc = simulate_data(N, corr_matrix, marginal_Z, prop_score_weights, "D", causal_params, "C")
# df_Z = scipy.stats.zscore(data_xdyc[['Z1', 'Z2', 'Z3', 'Z4']].values)

# Y = jnp.array(data_xdyc[['Y']].values)
# X = jnp.array(data_xdyc[['X']].values)
# Z = jnp.array(df_Z)

keys = jr.split(jr.PRNGKey(0), 3)


corr_matrix = jnp.array([
    [1, 0.8, 0.6, 0.2, 0.1],
    [0.8, 1, 0.4, 0.2, 0.1],
    [0.6, 0.4, 1, 0.1, 0.1],
    [0.2, 0.2, 0.1, 1, 0.1],
    [0.1, 0.1, 0.1, 0.1, 1]
])
Z = jr.multivariate_normal(
    keys[1], 
    jnp.array([0.,0.,0.,0.,0.]), 
    corr_matrix, 
    shape=(N,)
)

p = 1 / (
    1 + jnp.exp(-jnp.sum((Z) * 0.5 * jnp.ones(shape=(Z.shape[0], Z.shape[1])), axis=1))
)
X = jr.bernoulli(key=jr.PRNGKey(1), p=p).astype(int)[:, None]
Y = (jax.random.normal(keys[2], shape=(N,1)) + X + jnp.expand_dims(Z.sum(1), axis=1))

poisson_icdf = lambda x: scipy.stats.poisson.ppf(x, mu=5)
gamma_icdf = lambda x: scipy.stats.gamma.ppf(x, a=4)
bernoulli_icdf = lambda x: scipy.stats.bernoulli.ppf(x, p=0.3)

icdf_transforms = [poisson_icdf, bernoulli_icdf]#, gamma_icdf]
for i, icdf in enumerate(icdf_transforms):
    Z = Z.at[:, i].set(transform_var(Z, i, icdf))

data_xdyc = pd.DataFrame(
    jnp.concat([Y, X, Z], axis=1),
    columns=['Y', 'X'] + [f"Z{i+1}" for i in range(Z.shape[1])]
)

res = get_independent_quantiles(
    key = jr.PRNGKey(3),
    z_discr=Z[:, [0, 1]].astype(int), #impose discrete
    # z_discr=Z[:, 0],
    z_cont=Z[:, 2:],
    RQS_knots=10,
    flow_layers=8,
    nn_width=10,
    nn_depth=8,
    max_epochs = 1000,
    max_patience=100,
    batch_size=200,
    learning_rate=5e-3,
    return_z_cont_flow = True,
 )

In [None]:
u_z = jnp.concat([res['u_z_discr'], res['u_z_cont']], axis=1)

In [None]:
col_names = [f"U_Z{i}" for i in range(Z.shape[1])]
plot_data = pd.DataFrame(u_z, columns=col_names)
display(data_xdyc.head())
display(plot_data.head())

The correlation matrices have very close to the same entries but the rows and columns are permuted.

In [None]:
print("True Corr")
display(corr_matrix)
print("Flow Corr")
display(jnp.corrcoef(jax.scipy.special.ndtri(u_z[u_z[:,0]!=1,]).T))
sns.jointplot(x='U_Z1', y='U_Z2', data=plot_data, kind="scatter");

## Check conditional effect

In [None]:
data_xdyc.head()

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder

df_reg = data_xdyc[['Y', 'X']]
encoder = OneHotEncoder(drop='first')
X_encoded = encoder.fit_transform(data_xdyc[['X']])
model = LinearRegression()
model.fit(X_encoded, df_reg['Y'])

print(f"Intercept: {model.intercept_}")
print(f"Coefficients: {model.coef_}")

The data is **really** confounded!

### Training the Frugal FLow

In [None]:
def return_fits(frugal_flow):
    causal_margin = frugal_flow.bijection.bijections[-1].bijection.bijections[0]
    return {
        'ate': causal_margin.ate,
        'const': causal_margin.const,
        'scale': causal_margin.scale
    }

In [None]:
frugal_flow, losses = train_frugal_flow(key=jr.PRNGKey(1),
    y=Y,
    u_z=u_z,
    # u_z=res['u_z_cont'],
    learning_rate=5e-3,
    RQS_knots=10,
    flow_layers=8,
    nn_width=10,
    nn_depth=8,
    batch_size=200,
    max_patience=100,
    max_epochs=10000,
    condition=X
 )

In [None]:
return_fits(frugal_flow)

Nice model fit!

## Diagnostics

In [None]:
keys = jr.split(jr.PRNGKey(0), 3)
frugal_flow_samples_0 = frugal_flow.sample(keys[0], condition=jnp.zeros((5000,1))) #
frugal_flow_samples_1 = frugal_flow.sample(keys[1], condition=jnp.ones((5000,1))) #

### No correlation between $Y$ and $Z_1$ or $Z_2$

In [None]:
frugal_flow_samples_0 = frugal_flow_samples_0.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_0[:, 1:]))
frugal_flow_samples_1 = frugal_flow_samples_1.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_1[:, 1:]))

True Correlation Matrix:

Flow outputs:

In [None]:
jnp.corrcoef(frugal_flow_samples_0.T)

In [None]:
jnp.corrcoef(frugal_flow_samples_1.T)

In [None]:
plt.hist(frugal_flow_samples_0[:,0], alpha=0.3);
plt.hist(frugal_flow_samples_0[:,1], alpha=0.3);
plt.hist(frugal_flow_samples_0[:,2], alpha=0.3);

In [None]:
plt.scatter(*frugal_flow_samples_0[:,:2].T, label="Y-Z1", s=3)
plt.scatter(*frugal_flow_samples_0[:,[0,2]].T, label="Y-Z2", s=3)

# plt.scatter(*jax.scipy.special.ndtr(z).T, label="target", s=2)
plt.xlabel('u_y')
plt.ylabel('u_z1')
plt.legend()
plt.show()

## For a multiple discrete parameter -- V2

Generating Normalised Data

In [None]:
def transform_var(Z, col_idx, inv_cdf):
    return inv_cdf(
            jax.scipy.special.ndtr(Z[:, col_idx])
        )

In [None]:
N = 10000

# marginal_Z = {
#     'Z1': ss.norm(loc=0, scale=1),
#     'Z2': ss.norm(loc=0, scale=1),
#     'Z3': ss.norm(loc=3, scale=5),
#     'Z4': ss.norm(loc=-1, scale=2),
#     # 'Z5': ss.norm(loc=0, scale=1)
# }
# corr_matrix = np.array([
#     [1, 0.8, 0.6, 0.2, 0.1],
#     [0.8, 1, 0.4, 0.2, 0.1],
#     [0.6, 0.4, 1, 0.1, 0.1],
#     [0.2, 0.2, 0.1, 1, 0.1],
#     [0.1, 0.1, 0.1, 0.1, 1]
# ])
# treatment_type = "D"
# outcome_type = "C"
# prop_score_weights = [1, 1, 1, 1]  # Check propscore weights are of same dim as Z
# causal_params = [1, 1]
# data_xdyc = simulate_data(N, corr_matrix, marginal_Z, prop_score_weights, "D", causal_params, "C")
# df_Z = scipy.stats.zscore(data_xdyc[['Z1', 'Z2', 'Z3', 'Z4']].values)

# Y = jnp.array(data_xdyc[['Y']].values)
# X = jnp.array(data_xdyc[['X']].values)
# Z = jnp.array(df_Z)

keys = jr.split(jr.PRNGKey(0), 3)


corr_matrix = jnp.array([
    [1, 0.8, 0.6, 0.2, 0.1],
    [0.8, 1, 0.4, 0.2, 0.1],
    [0.6, 0.4, 1, 0.1, 0.1],
    [0.2, 0.2, 0.1, 1, 0.1],
    [0.1, 0.1, 0.1, 0.1, 1]
])
Z = jr.multivariate_normal(
    keys[1], 
    jnp.array([0.,0.,0.,0.,0.]), 
    corr_matrix, 
    shape=(N,)
)
Z = Z.at[:, 0].set(jr.bernoulli(key=jr.PRNGKey(9999), p=0.5, shape=(N,)))

p = 1 / (
    1 + jnp.exp(-jnp.sum((Z) * 0.5 * jnp.ones(shape=(Z.shape[0], Z.shape[1])), axis=1))
)
X = jr.bernoulli(key=jr.PRNGKey(1), p=p).astype(int)[:, None]
Y = (jax.random.normal(keys[2], shape=(N,1)) + X + jnp.expand_dims(Z.sum(1), axis=1) - 0.5)

# poisson_icdf = lambda x: scipy.stats.poisson.ppf(x, mu=5)
# gamma_icdf = lambda x: scipy.stats.gamma.ppf(x, a=4)
# bernoulli_icdf = lambda x: scipy.stats.bernoulli.ppf(x, p=0.3)

# icdf_transforms = [poisson_icdf]#, bernoulli_icdf]#, gamma_icdf]
# for i, icdf in enumerate(icdf_transforms):
#     Z = Z.at[:, i].set(transform_var(Z, i, icdf))

data_xdyc = pd.DataFrame(
    jnp.concat([Y, X, Z], axis=1),
    columns=['Y', 'X'] + [f"Z{i+1}" for i in range(Z.shape[1])]
)

res = get_independent_quantiles(
    key = jr.PRNGKey(3),
    z_discr=Z[:, 0].astype(int)[:, None], #impose discrete
    # z_discr=Z[:, 0],
    z_cont=Z[:, 1:],
    RQS_knots=10,
    flow_layers=8,
    nn_width=10,
    nn_depth=8,
    max_epochs = 1000,
    max_patience=100,
    batch_size=200,
    learning_rate=5e-3,
    return_z_cont_flow = True,
 )

In [None]:
u_z = jnp.concat([res['u_z_discr'], res['u_z_cont']], axis=1)

In [None]:
col_names = [f"U_Z{i}" for i in range(Z.shape[1])]
plot_data = pd.DataFrame(u_z, columns=col_names)
display(data_xdyc.head())
display(plot_data.head())

The correlation matrices have very close to the same entries but the rows and columns are permuted.

In [None]:
print("True Corr")
display(corr_matrix)
print("Flow Corr")
display(jnp.corrcoef(jax.scipy.special.ndtri(u_z[u_z[:,0]!=1,]).T))
sns.jointplot(x='U_Z1', y='U_Z2', data=plot_data, kind="scatter");

## Check conditional effect

In [None]:
data_xdyc.head()

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder

df_reg = data_xdyc[['Y', 'X']]
encoder = OneHotEncoder(drop='first')
X_encoded = encoder.fit_transform(data_xdyc[['X']])
model = LinearRegression()
model.fit(X_encoded, df_reg['Y'])

print(f"Intercept: {model.intercept_}")
print(f"Coefficients: {model.coef_}")

The data is **really** confounded!

### Training the Frugal FLow

In [None]:
def return_fits(frugal_flow):
    causal_margin = frugal_flow.bijection.bijections[-1].bijection.bijections[0]
    return {
        'ate': causal_margin.ate,
        'const': causal_margin.const,
        'scale': causal_margin.scale
    }

In [None]:
frugal_flow, losses = train_frugal_flow(key=jr.PRNGKey(1),
    y=Y,
    u_z=u_z,
    # u_z=res['u_z_cont'],
    learning_rate=1e-3,
    RQS_knots=10,
    flow_layers=8,
    nn_width=10,
    nn_depth=8,
    batch_size=200,
    max_patience=50,
    max_epochs=10000,
    condition=X
 )

In [None]:
return_fits(frugal_flow)

Now I get a good fit!

## Diagnostics

In [None]:
keys = jr.split(jr.PRNGKey(0), 3)
frugal_flow_samples_0 = frugal_flow.sample(keys[0], condition=jnp.zeros((5000,1))) #
frugal_flow_samples_1 = frugal_flow.sample(keys[1], condition=jnp.ones((5000,1))) #

### No correlation between $Y$ and $Z_1$ or $Z_2$

In [None]:
frugal_flow_samples_0 = frugal_flow_samples_0.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_0[:, 1:]))
frugal_flow_samples_1 = frugal_flow_samples_1.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_1[:, 1:]))

True Correlation Matrix:

Flow outputs:

In [None]:
jnp.corrcoef(frugal_flow_samples_0.T)

In [None]:
jnp.corrcoef(frugal_flow_samples_1.T)

In [None]:
plt.hist(frugal_flow_samples_0[:,0], alpha=0.3);
plt.hist(frugal_flow_samples_0[:,1], alpha=0.3);
plt.hist(frugal_flow_samples_0[:,2], alpha=0.3);

In [None]:
plt.scatter(*frugal_flow_samples_0[:,:2].T, label="Y-Z1", s=3)
plt.scatter(*frugal_flow_samples_0[:,[0,2]].T, label="Y-Z2", s=3)

# plt.scatter(*jax.scipy.special.ndtr(z).T, label="target", s=2)
plt.xlabel('u_y')
plt.ylabel('u_z1')
plt.legend()
plt.show()