# Logistic Sampling

We present an example showing how outcomes from a logistic model can be sampled exactly.

In [1]:
import itertools
import random
import sys
import os
sys.path.append("../") # add the parent dir (relative to the working directory) to the import path; presumably needed for the project imports below

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import numpy as np
from scipy.stats import rankdata
from jax.scipy.special import expit
import scipy.stats as ss
import seaborn as sns
from sklearn.model_selection import KFold

# Show arrays with two significant decimals; the array outputs further down
# (e.g. the coefficient tables) are rendered with this setting.
jnp.set_printoptions(precision=2)

# from data.create_sim_data import *
import data_processing_and_simulations.causl_sim_data_generation as causl_py
from data_processing_and_simulations.run_all_simulations import plot_simulation_results
from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow, train_copula_flow
from frugal_flows.bijections import UnivariateNormalCDF
from frugal_flows.benchmarking import FrugalFlowModel

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage
import wandb

# Activate automatic conversion of rpy2 objects to pandas objects
pandas2ri.activate()

# Import the R library causl, installing it first if it cannot be loaded.
try:
    causl = importr('causl')
except Exception:
    # Only needed on the install path, so imported locally.
    from rpy2.robjects.vectors import StrVector

    # The trailing comma matters: ('causl') is a plain string, not a tuple.
    package_names = ('causl',)
    # R's `utils` package provides install.packages (exposed by rpy2 as
    # install_packages).
    # NOTE(review): install.packages only searches the configured R
    # repositories; if causl is distributed elsewhere (e.g. GitHub) it has to
    # be installed manually - TODO confirm.
    utils = importr('utils')
    utils.install_packages(StrVector(package_names))
    # Bind the package after installing so `causl` is defined on both paths.
    causl = importr('causl')


# Hyperparameters handed to the flow training call; the same dictionary is
# reused for every hyperparameter argument of `train_benchmark_model`.
hyperparams_dict = dict(
    learning_rate=0.005,
    RQS_knots=8,
    flow_layers=5,
    nn_width=30,
    nn_depth=4,
    max_patience=100,
    max_epochs=10_000,
)

# Switch JAX to 64-bit precision (it defaults to 32-bit); the float64 dtype
# shown in the outputs below relies on this flag.
jax.config.update("jax_enable_x64", True)

In [2]:
%%time
mixed_cont_rscript = f"""
library(causl)
forms <- list(list(Z1 ~ 1), X ~ Z1, Y ~ X, ~ 1)
fams <- list(1, 5, 1, 1)
pars <- list(Z1 = list(beta=0, phi=2),
             X = list(beta=c(0,2)),
             Y = list(beta=c(0,2), phi=1),
             cop = list(beta=matrix(c(0.8), nrow=1)))



set.seed(1234)
n <- 1e3

data_samples <- rfrugalParam(n, formulas = forms, family = fams, pars = pars)
"""
rcode_compiled = SignatureTranslatedAnonymousPackage(mixed_cont_rscript, "powerpack")
df = rcode_compiled.data_samples

R[write to console]: Inversion method selected: using pair-copula parameterization



CPU times: user 10.3 ms, sys: 3.21 ms, total: 13.5 ms
Wall time: 13 ms


In [3]:
def _as_column(name):
    """Return column `name` of `df` as a JAX array with a trailing axis added."""
    return jnp.array(df[name].values)[:, None]


# Wrap the simulated outcome, treatment and continuous covariate in the model.
logistic_flow = FrugalFlowModel(
    Y=_as_column('Y'),
    X=_as_column('X'),
    Z_cont=_as_column('Z1'),
)

In [4]:
# Keyword arguments for training; the one shared `hyperparams_dict` is passed
# for the marginal, frugal and propensity-flow hyperparameters alike.
benchmark_training_kwargs = dict(
    training_seed=jr.PRNGKey(0),
    marginal_hyperparam_dict=hyperparams_dict,
    frugal_hyperparam_dict=hyperparams_dict,
    prop_flow_hyperparam_dict=hyperparams_dict,
    causal_model='gaussian',
    causal_model_args={'ate': jnp.array([0.]), 'const': 0., 'scale': 1.},
)
logistic_flow.train_benchmark_model(**benchmark_training_kwargs)

  2%| | 239/10000 [00:13<08:56, 18.21it/s, train=1.7414300
  2%| | 225/10000 [00:10<07:39, 21.28it/s, train=1.2491213
  1%| | 130/10000 [00:06<07:53, 20.83it/s, train=-0.411746


In [5]:
# Parameters of the logistic outcome model used when sampling.
logistic_outcome_args = {'ate': jnp.array([2.]), 'const': -1.}
sampling_key = jr.PRNGKey(1)

# Draw 1000 synthetic rows with a logistic-regression outcome.
synthetic_samples = logistic_flow.generate_samples(
    key=sampling_key,
    sampling_size=1000,
    copula_param=0,
    outcome_causal_model='logistic_regression',
    outcome_causal_args=logistic_outcome_args,
    with_confounding=True,
)

Y shape: (1000, 1)
X shape: (1000, 1)
Z shape: (1000, 1)


In [6]:
# Mean outcome within each treatment group; groupby orders the groups by key,
# so the X=0 group comes first and the X=1 group second.
mean_outcome_by_arm = synthetic_samples.groupby('X')['Y'].mean()
Y0, Y1 = mean_outcome_by_arm.values
# Ratio of the observed outcome rates between the two groups.
print(Y1 / Y0)

5.226337448559671


In [7]:
# Reference ratio expit(-1 + 2) / expit(-1), built from the logistic outcome
# arguments used for sampling ('const': -1., 'ate': 2.).
# NOTE(review): presumably the target risk ratio to compare the observed
# group-mean ratio against - confirm.
expit(1)/expit(-1)

Array(2.72, dtype=float64, weak_type=True)

In [8]:
# Display the sampled data set (columns Y, X and Z_1).
synthetic_samples

Unnamed: 0,Y,X,Z_1
0,1.0,1.0,0.535241
1,1.0,1.0,2.026317
2,1.0,1.0,0.962257
3,1.0,1.0,-1.259690
4,0.0,0.0,-1.058330
...,...,...,...
995,0.0,1.0,-1.060602
996,0.0,0.0,-2.467215
997,0.0,0.0,-0.571012
998,1.0,0.0,0.731236


In [9]:
import pandas as pd
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr

# Activate the pandas2ri conversion
pandas2ri.activate()

# Import necessary R libraries
base = importr('base')
stats = importr('stats')
# `survey` supplies svydesign/svyglm, which the R analysis script relies on.
survey = importr('survey')

# Convert the synthetic samples to an R object and bind it to the name `dat`
# in R's global environment, which is the name the R analysis script reads.
r_df = pandas2ri.py2rpy(synthetic_samples)
ro.globalenv['dat'] = r_df

In [10]:
# Define the R code as a string. The script works on the `dat` object and:
#   1. fits a binomial GLM of X on Z_1 (glmX) and takes its fitted
#      probabilities `ps`;
#   2. builds the weights X/ps + (1-X)/(1-ps);
#   3. fits a weighted quasibinomial regression of Y on X with svyglm (glmY);
#   4. fits an unweighted binomial GLM of Y on X for comparison (glmY_OR);
# and ends with a named list holding the three coefficient tables.
r_code = """
library(survey)

glmX <- glm(X ~ Z_1, family=binomial, data=dat)
glmX_coefficients <- summary(glmX)$coefficients

ps <- predict(glmX, type="response")
wt <- dat$X/ps + (1-dat$X)/(1-ps)
glmY <- svyglm(Y ~ X, family=quasibinomial(), design = svydesign(~1, weights=wt, data=dat))
glmY_coefficients <- summary(glmY)$coefficients

glmY_OR <- glm(Y ~ X, family=binomial, data=dat)
glmY_OR_coefficients <- summary(glmY_OR)$coefficients

list(glmX_coefficients = glmX_coefficients, glmY_coefficients = glmY_coefficients, glmY_OR_coefficients = glmY_OR_coefficients)
"""

# Execute the R code; `result` holds the list produced by the script's final
# expression, whose entries are looked up by name with `rx2`.
result = ro.r(r_code)

In [11]:
# Full coefficient table of the unweighted regression of Y on X (glmY_OR).
result.rx2('glmY_OR_coefficients')

array([[-1.66e+00,  1.21e-01, -1.37e+01,  8.08e-43],
       [ 3.27e+00,  1.71e-01,  1.91e+01,  2.21e-81]])

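In the coefficient tables, the first column is the estimate and the second column is the standard error. The true values are -1 (intercept) and +2 (coefficient on X). The weighted GLM (shown next) recovers these values to within roughly one standard error, whereas the unweighted GLM (shown above and again below) does not.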

In [12]:
# Weighted (svyglm) fit of Y on X: keep only the first two columns of the
# coefficient table (estimate and standard error).
result.rx2('glmY_coefficients')[:, :2]

array([[-0.84,  0.22],
       [ 1.64,  0.32]])

In [13]:
# Unweighted fit of Y on X: first two columns (estimate and standard error),
# for comparison with the weighted fit.
result.rx2('glmY_OR_coefficients')[:, :2]

array([[-1.66,  0.12],
       [ 3.27,  0.17]])