# Mixed Data Frugal Flows

In this notebook we demonstrate the ability for Frugal Flows to identify Marginal Causal Effects when dealing with a mix of discrete and continous variables.

In [1]:
import sys
import os
sys.path.append("../") # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats as ss
import seaborn as sns
from sklearn.model_selection import KFold

from data.create_sim_data import simulate_data
from frugal_flows.causal_flows import get_independent_quantiles, train_frugal_flow
from frugal_flows.bijections import UnivariateNormalCDF

jax.config.update("jax_enable_x64", True)

## Checking for the Causal Effect

Generating Normalised Data

In [2]:
def transform_var(Z, col_idx, inv_cdf):
    return inv_cdf(
            jax.scipy.special.ndtr(Z[:, col_idx])
        )

In [3]:
N = 10000

# marginal_Z = {
#     'Z1': ss.norm(loc=0, scale=1),
#     'Z2': ss.norm(loc=0, scale=1),
#     'Z3': ss.norm(loc=3, scale=5),
#     'Z4': ss.norm(loc=-1, scale=2),
#     # 'Z5': ss.norm(loc=0, scale=1)
# }
# corr_matrix = np.array([
#     [1, 0.8, 0.6, 0.2, 0.1],
#     [0.8, 1, 0.4, 0.2, 0.1],
#     [0.6, 0.4, 1, 0.1, 0.1],
#     [0.2, 0.2, 0.1, 1, 0.1],
#     [0.1, 0.1, 0.1, 0.1, 1]
# ])
# treatment_type = "D"
# outcome_type = "C"
# prop_score_weights = [1, 1, 1, 1]  # Check propscore weights are of same dim as Z
# causal_params = [1, 1]
# data_xdyc = simulate_data(N, corr_matrix, marginal_Z, prop_score_weights, "D", causal_params, "C")
# df_Z = scipy.stats.zscore(data_xdyc[['Z1', 'Z2', 'Z3', 'Z4']].values)

# Y = jnp.array(data_xdyc[['Y']].values)
# X = jnp.array(data_xdyc[['X']].values)
# Z = jnp.array(df_Z)

keys = jr.split(jr.PRNGKey(0), 3)


Z = jr.multivariate_normal(
    keys[1], 
    jnp.array([0.,0.,0.,0.,0.]), 
    jnp.array([
        [1, 0.8, 0.6, 0.2, 0.1],
        [0.8, 1, 0.4, 0.2, 0.1],
        [0.6, 0.4, 1, 0.1, 0.1],
        [0.2, 0.2, 0.1, 1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 1]
    ]), 
    shape=(N,)
)


p = 1 / (
    1 + jnp.exp(-jnp.sum((Z) * 1 * jnp.ones(shape=(Z.shape[0], Z.shape[1])), axis=1))
)
X = jr.bernoulli(key=jr.PRNGKey(1), p=p).astype(int)[:, None]
Y = (jax.random.normal(keys[2], shape=(N,1)) + X + jnp.expand_dims(Z.sum(1), axis=1))

poisson_icdf = lambda x: scipy.stats.poisson.ppf(x, mu=5)
gamma_icdf = lambda x: scipy.stats.gamma.ppf(x, a=4)
bernoulli_icdf = lambda x: scipy.stats.bernoulli.ppf(x, p=0.3)

icdf_transforms = [poisson_icdf]#, bernoulli_icdf]#, gamma_icdf]
for i, icdf in enumerate(icdf_transforms):
    Z = Z.at[:, i].set(transform_var(Z, i, icdf))

data_xdyc = pd.DataFrame(
    jnp.concat([Y, X, Z], axis=1),
    columns=['Y', 'X'] + [f"Z{i+1}" for i in range(Z.shape[1])]
)

res = get_independent_quantiles(
    key = jr.PRNGKey(3),
    z_discr=jnp.expand_dims(Z[:, 0],axis=1), #impose discrete
    # z_discr=Z[:, 0],
    z_cont=Z[:, 1:],
    RQS_knots=8,
    flow_layers=4,
    nn_width=10,
    nn_depth=4,
    max_epochs = 1000,
    max_patience=100,
    learning_rate=5e-3,
    return_z_cont_flow = True,
 )

 16%|████████████████                                                                                    | 160/1000 [00:47<04:06,  3.40it/s, train=5.663252613317186, val=5.658494163536159 (Max patience reached)]


TypeError: Indexer must have integer or boolean type, got indexer with type float64 at position 0, indexer value Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/0)>

In [None]:
col_names = [f"U_Z{i}" for i in range(Z.shape[1])]
plot_data = pd.DataFrame(res['u_z_cont'], columns=col_names)
plot_data.head()

The correlation matrices have very close to the same entries but the rows and columns are permuted.

In [None]:
print("True Corr")
display(jnp.corrcoef(Z.T))
print("Flow Corr")
display(jnp.corrcoef(jax.scipy.special.ndtri(plot_data.values).T))
sns.jointplot(x='U_Z1', y='U_Z2', data=plot_data, kind="scatter");

## Check conditional effect

In [None]:
data_xdyc.head()

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder

df_reg = data_xdyc[['Y', 'X']]
encoder = OneHotEncoder(drop='first')
X_encoded = encoder.fit_transform(data_xdyc[['X']])
model = LinearRegression()
model.fit(X_encoded, df_reg['Y'])

print(f"Intercept: {model.intercept_}")
print(f"Coefficients: {model.coef_}")

The data is indeed confounded.

### Training the Frugal FLow

In [None]:
def return_fits(frugal_flow):
    causal_margin = frugal_flow.bijection.bijections[-1].bijection.bijections[0]
    return {
        'ate': causal_margin.ate,
        'const': causal_margin.const,
        'scale': causal_margin.scale
    }

In [None]:
# frugal_flow, losses =  train_frugal_flow(
#     key=jr.PRNGKey(1),
#     y=jnp.array(Y),
#     u_z=res['u_z_cont'],
#     RQS_knots=4,
#     nn_depth=4,
#     nn_width=10,
#     flow_layers=1,
#     show_progress=True,
#     learning_rate=1e-2,
#     max_epochs=20000,
#     max_patience=100,
#     batch_size=100,
#     condition=jnp.array(X),
#  )

frugal_flow, losses = train_frugal_flow(key=jr.PRNGKey(1),
    y=Y,
    u_z=res['u_z_cont'],
    # u_z=res['u_z_cont'],
    learning_rate=5e-3,
    RQS_knots=8,
    nn_depth=4,
    nn_width=10,
    flow_layers=8,
    max_patience=70,
    max_epochs=10000,
    condition=X
 )

In [None]:
return_fits(frugal_flow)

## Diagnostics

In [None]:
keys = jr.split(jr.PRNGKey(0), 3)
frugal_flow_samples_0 = frugal_flow.sample(keys[0], condition=jnp.zeros((5000,1))) #
frugal_flow_samples_1 = frugal_flow.sample(keys[1], condition=jnp.ones((5000,1))) #

### No correlation between $Y$ and $Z_1$ or $Z_2$

In [None]:
frugal_flow_samples_0 = frugal_flow_samples_0.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_0[:, 1:]))
frugal_flow_samples_1 = frugal_flow_samples_1.at[:, 1:].set(jax.scipy.special.ndtri(frugal_flow_samples_1[:, 1:]))

True Correlation Matrix:

Flow outputs:

In [None]:
jnp.corrcoef(frugal_flow_samples_0.T)

In [None]:
jnp.corrcoef(frugal_flow_samples_1.T)

In [None]:
plt.scatter(*frugal_flow_samples_0[:,:2].T, label="Y-Z1", s=3)
plt.scatter(*frugal_flow_samples_0[:,[0,2]].T, label="Y-Z2", s=3)

# plt.scatter(*jax.scipy.special.ndtr(z).T, label="target", s=2)
plt.xlabel('u_y')
plt.ylabel('u_z1')
plt.legend()
plt.show()

In [None]:
plt.scatter(*frugal_flow_samples_0[:,1:].T, label="zero", s=2)
plt.scatter(*jax.scipy.special.ndtri(res['u_z_cont']).T, label="one", s=3)
plt.xlabel('u_z1')
plt.ylabel('u_z2')
plt.legend()
plt.show()