In [1]:
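# Example reproducing an error when inspecting a trained frugal flow (see the AttributeError in the Location Translation section)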

In [2]:
import itertools
import random
import sys
import os
sys.path.append("../") # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import numpy as np
from scipy.stats import rankdata
from jax.scipy.special import expit
import scipy.stats as ss
import seaborn as sns
from sklearn.model_selection import KFold

# Limit printed array values to 2 decimal places to keep cell outputs compact.
jnp.set_printoptions(precision=2)

# from data.create_sim_data import *
import data_processing_and_simulations.causl_sim_data_generation as causl_py
from data_processing_and_simulations.run_all_simulations import plot_simulation_results
from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow, train_copula_flow
from frugal_flows.bijections import UnivariateNormalCDF
from frugal_flows.benchmarking import FrugalFlowModel

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage
import wandb

# Activate automatic conversion of rpy2 objects to pandas objects.
pandas2ri.activate()

# Import the R library causl. If that fails (presumably because it is not
# installed), install it through R's `utils` package and import it again, so
# that `causl` is always bound once this cell finishes.
try:
    causl = importr('causl')
except Exception:
    r_utils = importr('utils')
    # Non-interactive R sessions need a CRAN mirror selected before installing.
    r_utils.chooseCRANmirror(ind=1)
    # NOTE(review): confirm causl is installable from CRAN in this environment;
    # if not, install it from its source repository instead.
    r_utils.install_packages(ro.StrVector(['causl']))
    causl = importr('causl')


# Flow / training hyperparameters shared by every model trained in this notebook.
hyperparams_dict = dict(
    learning_rate=5e-3,
    RQS_knots=8,
    flow_layers=5,
    nn_width=30,
    nn_depth=4,
    max_patience=100,
    max_epochs=10000,
)

# Switch JAX to 64-bit floats; this must happen before the data arrays are built.
jax.config.update("jax_enable_x64", True)

# R script evaluated through rpy2 in the next statements. It simulates n = 1e3
# rows with causl::rfrugalParam (seeded with 1234) from the formula, family and
# parameter lists below, and leaves the result in the data frame `data_samples`.
# Matrix-valued columns are flattened so the pandas conversion sees plain
# numeric columns: a one-column matrix keeps its name, and wider matrices are
# split into `<name>_1`, `<name>_2`, ... so that no column is dropped.
mixed_cont_rscript = """
library(causl)
forms <- list(list(Z1 ~ 1), X ~ Z1, Y ~ X, ~ 1)
fams <- list(1, 5, 1, 1)
pars <- list(Z1 = list(beta=0, phi=2),
             X = list(beta=c(0,2)),
             Y = list(beta=c(0,2), phi=1),
             cop = list(beta=matrix(c(0.8), nrow=1)))

set.seed(1234)
n <- 1e3

data_samples <- rfrugalParam(n, formulas = forms, family = fams, pars = pars)
data_samples <- as.data.frame(data_samples)
# Convert multi-dimensional columns to separate one-dimensional columns.
# names() is evaluated once, so columns added inside the loop are not revisited.
for (col_name in names(data_samples)) {
    mat <- data_samples[[col_name]]
    if (is.matrix(mat)) {
        if (ncol(mat) == 1) {
            data_samples[[col_name]] <- mat[, 1]
        } else {
            data_samples[[col_name]] <- NULL
            for (i in seq_len(ncol(mat))) {
                data_samples[[paste0(col_name, "_", i)]] <- mat[, i]
            }
        }
    }
}
"""
# Evaluate the R script as an anonymous package and pull out its top-level
# `data_samples` frame (converted to pandas via the pandas2ri activation above).
rcode_compiled = SignatureTranslatedAnonymousPackage(mixed_cont_rscript, "powerpack")
df = rcode_compiled.data_samples

# Outcome, treatment and continuous confounder, each as an (n, 1) column vector.
Y, X, Z_cont = (jnp.array(df[column].values)[:, None] for column in ('Y', 'X', 'Z1'))

R[write to console]: Inversion method selected: using pair-copula parameterization



## Gaussian Flow

In [3]:
# Frugal flow on the simulated data: one continuous confounder (Z1), no
# discrete confounders, and `confounding_copula` left unset.
gaussian_flow = FrugalFlowModel(
    Y=Y,
    X=X,
    Z_cont=Z_cont,
    Z_disc=None,
    confounding_copula=None,
)

In [4]:
# Fit the marginal, frugal and propensity flows (all with the shared
# hyperparameters) using a 'gaussian' causal model for the outcome.
gaussian_flow.train_benchmark_model(
    training_seed=jr.PRNGKey(0),
    causal_model='gaussian',
    causal_model_args=dict(ate=jnp.array([-7.]), const=3., scale=5),
    marginal_hyperparam_dict=hyperparams_dict,
    frugal_hyperparam_dict=hyperparams_dict,
    prop_flow_hyperparam_dict=hyperparams_dict,
)

  2%|█▎                                                       | 239/10000 [00:12<08:24, 19.34it/s, train=1.7414076769779299, val=1.8218569844313255 (Max patience reached)]
  3%|█▉                                                       | 345/10000 [00:15<07:21, 21.87it/s, train=1.2258003139344593, val=1.4791259531597616 (Max patience reached)]
  1%|▋                                                    | 121/10000 [00:05<07:24, 22.21it/s, train=-0.43901179122748935, val=-0.05080161866032616 (Max patience reached)]


In [5]:
# Draw 1000 synthetic samples, swapping in a 'causal_cdf' outcome model with
# ATE 2 while keeping the learned confounding.
causal_cdf_args = {'ate': jnp.array([2.]), 'const': -1., 'scale': 5.}
gaussian_samples = gaussian_flow.generate_samples(
    key=jr.PRNGKey(1),
    sampling_size=1000,  # a plain int; `(1000)` is not a tuple
    copula_param=0,
    outcome_causal_model='causal_cdf',
    outcome_causal_args=causal_cdf_args,
    with_confounding=True,
)

Y shape: (1000, 1)
X shape: (1000, 1)
Z shape: (1000, 1)


## Location Translation

In [14]:
# NOTE(review): this cell ran as In [14], i.e. after `loc_translation_flow` was
# created and trained in the cells further down. On a fresh kernel
# (Restart & Run All) it raises NameError — move it after the training cell.
loc_translation_flow.frugal_flow.shape

(2,)

In [18]:
# NOTE(review): this lookup raises the AttributeError shown in its output. One
# object on this attribute path is a `MaskedAutoregressiveFirstUniform`, which
# has no `.bijections`. Inspect the intermediate objects to find the intended
# path. Like the cell above, it also depends on `loc_translation_flow`, which
# is only created further down, so it fails under Restart & Run All.
loc_translation_flow.frugal_flow.bijection.bijections[1].bijection.bijection.bijections[0]

AttributeError: 'MaskedAutoregressiveFirstUniform' object has no attribute 'bijections'

In [6]:
# Same data and confounder setup as the Gaussian model; only the causal model
# used during training differs.
loc_translation_flow = FrugalFlowModel(
    Y=Y,
    X=X,
    Z_cont=Z_cont,
    Z_disc=None,
    confounding_copula=None,
)

In [7]:
# Train with a 'location_translation' causal model.
# NOTE(review): the flow hyperparameters are also merged into the causal-model
# arguments here (unlike the Gaussian cell) — confirm this is intended.
location_args = {'ate': 0., **hyperparams_dict}
loc_translation_flow.train_benchmark_model(
    training_seed=jr.PRNGKey(0),
    causal_model='location_translation',
    causal_model_args=location_args,
    marginal_hyperparam_dict=hyperparams_dict,
    frugal_hyperparam_dict=hyperparams_dict,
    prop_flow_hyperparam_dict=hyperparams_dict,
)

  2%|█▎                                                       | 239/10000 [00:07<05:15, 30.99it/s, train=1.7414076769779299, val=1.8218569844313255 (Max patience reached)]
  2%|▉                                                         | 155/10000 [00:19<20:47,  7.89it/s, train=1.267292955951178, val=1.7333332475228291 (Max patience reached)]
  1%|▋                                                    | 121/10000 [00:05<07:00, 23.48it/s, train=-0.43901179122748935, val=-0.05080161866032616 (Max patience reached)]


In [8]:
# Draw 1000 synthetic samples with a location-translation outcome model (ATE 2),
# keeping the learned confounding. The samples get their own descriptive name
# because the logistic section reassigns the generic `synthetic_samples`.
loc_translation_samples = loc_translation_flow.generate_samples(
    key=jr.PRNGKey(1),
    sampling_size=1000,  # a plain int; `(1000)` is not a tuple
    copula_param=0,
    outcome_causal_model='location_translation',
    outcome_causal_args={'ate': 2.},
    with_confounding=True,
)
# Generic alias kept for backward compatibility with cells that read it.
synthetic_samples = loc_translation_samples



Y shape: (1000, 1)
X shape: (1000, 1)
Z shape: (1000, 1)


## Logistic Flow

In [9]:
# Reuse the (n, 1) outcome/treatment/confounder arrays built from `df` in the
# setup cell instead of rebuilding identical copies, as the other model cells do.
# Z_disc / confounding_copula are left to FrugalFlowModel's defaults, as before.
logistic_flow = FrugalFlowModel(
    Y=Y,
    X=X,
    Z_cont=Z_cont,
)

In [10]:
# NOTE(review): despite the section title, training uses the 'gaussian' causal
# model; 'logistic_regression' is only chosen at sampling time — confirm intended.
gaussian_args = dict(ate=jnp.array([0.]), const=0., scale=1.)
logistic_flow.train_benchmark_model(
    training_seed=jr.PRNGKey(0),
    causal_model='gaussian',
    causal_model_args=gaussian_args,
    marginal_hyperparam_dict=hyperparams_dict,
    frugal_hyperparam_dict=hyperparams_dict,
    prop_flow_hyperparam_dict=hyperparams_dict,
)

  2%|█▎                                                       | 239/10000 [00:07<05:17, 30.78it/s, train=1.7414076769779299, val=1.8218569844313255 (Max patience reached)]
  2%|█▏                                                       | 211/10000 [00:08<06:31, 25.02it/s, train=1.2475766119435763, val=1.4532597539683303 (Max patience reached)]
  1%|▋                                                    | 121/10000 [00:05<07:08, 23.08it/s, train=-0.43901179122748935, val=-0.05080161866032616 (Max patience reached)]


In [11]:
# Draw 1000 synthetic samples with a 'logistic_regression' outcome model
# (ATE 2, constant -1), keeping the learned confounding. They are stored under
# a descriptive name so they do not share an identity with other sections' samples.
logistic_samples = logistic_flow.generate_samples(
    key=jr.PRNGKey(1),
    sampling_size=1000,  # a plain int; `(1000)` is not a tuple
    copula_param=0,
    outcome_causal_model='logistic_regression',
    outcome_causal_args={'ate': jnp.array([2.]), 'const': -1.},
    with_confounding=True,
)
# Generic alias kept for backward compatibility with cells that read it.
synthetic_samples = logistic_samples

Y shape: (1000, 1)
X shape: (1000, 1)
Z shape: (1000, 1)
