In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from itertools import combinations, product
import statsmodels.api as sm
import pandas as pd
from scipy.stats import mannwhitneyu

In [2]:
import sys
sys.path.append('/app/python_lib/src')
from robustinfer.drgu import DRGU

# Define Functions

In [3]:
class KeyGen:
    """Stateful wrapper around a JAX PRNG key that hands out fresh subkeys."""

    def __init__(self, seed=0):
        # Root key; it is replaced every time `next` is called.
        self.key = jax.random.PRNGKey(seed)

    def next(self):
        """Split the stored key, keep one half as the new state and return the other."""
        halves = jax.random.split(self.key)
        self.key = halves[0]
        return halves[1]


# Notebook-wide key source with a fixed seed.
rng = KeyGen(42)

In [4]:
def simulation_linear_data(n, ate, error_func, key_gen=None):
    """Simulate a linear outcome with a covariate-dependent binary treatment.

    Data-generating process:
        w  ~ Normal(0, 1), shape (n, 1)
        Wt = [1, w]                        (covariates with an intercept column)
        z  ~ Bernoulli(sigmoid(Wt @ [0.0, -0.6]))
        y  = ate * z + w + error,  error = error_func(key, (n,))

    Args:
        n: Number of observations to draw.
        ate: Coefficient placed on the treatment indicator in the outcome model.
        error_func: Callable ``error_func(key, shape)`` returning the additive
            noise term; it is called once with ``shape == (n,)``.
        key_gen: Optional object exposing ``next()`` that returns a fresh JAX
            PRNG key on every call. When omitted, the notebook-level ``rng``
            is used, which matches the previous behaviour.

    Returns:
        Tuple ``(y, z, w, Wt)``: outcome of shape (n,), float32 treatment
        indicator of shape (n,), covariate of shape (n, 1) and the
        intercept-augmented covariate matrix of shape (n, 2).
    """
    # Explicit key source makes the draw reproducible and testable in isolation;
    # the global is only looked up when the caller did not supply one.
    keys = rng if key_gen is None else key_gen

    # Keys are consumed in a fixed order: covariate, treatment, noise.
    eta_true = jnp.array([0.0, -0.6])
    w = jax.random.normal(keys.next(), (n, 1))  # covariates
    Wt = jnp.concatenate([jnp.ones((n, 1)), w], axis=1)
    pi_true = jax.nn.sigmoid(Wt @ eta_true)  # true propensity
    z = jax.random.bernoulli(keys.next(), pi_true).astype(jnp.float32)

    error = error_func(keys.next(), (n,))
    beta_true = jnp.array([0.0, ate, 1.0])
    # Outcome design matrix: intercept, treatment indicator, covariate.
    X = jnp.concatenate([jnp.ones((n, 1)), z[:, None], w], axis=1)
    y = X @ beta_true + error
    return y, z, w, Wt

# Analysis of a single simulated dataset

In [5]:
class ErrorFunc:
    """Label a noise sampler so it can be referred to by name.

    `f` is stored as-is; elsewhere in this notebook it is invoked as
    ``f(key, shape)``.
    """

    # Class-level fallbacks, shadowed by the instance attributes set in __init__.
    name = ''
    f = None

    def __init__(self, name, f):
        self.name, self.f = name, f

# Noise model for this run: log-normal draws. The lambda adapts the sampler to
# the (key, shape) call convention by passing `shape` as a keyword.
lognormal_error_func = ErrorFunc(
    'lognormal',
    lambda key, shape: jax.random.lognormal(key, shape=shape),
)
# The treatment indicator is bound to `x`; the fourth return value is discarded.
y, x, w, _ = simulation_linear_data(
    n=100,
    ate=2.0,
    error_func=lognormal_error_func.f,
)
    

In [6]:
# Sanity-check the shapes of the response, treatment and covariate arrays.
y.shape, x.shape, w.shape

((100,), (100,), (100, 1))

In [7]:
# Assemble the analysis frame: response `y`, treatment `x`, then one column per
# covariate named w1, w2, ... (iterating over w.T walks the columns of w).
df = pd.DataFrame(
    dict(
        y=y,
        x=x,
        **{f"w{pos}": column for pos, column in enumerate(w.T, start=1)},
    )
)


In [8]:
# Preview the first rows of the assembled frame.
df.head()

Unnamed: 0,y,x,w1
0,1.279513,0.0,0.605764
1,3.452183,1.0,0.799044
2,1.433406,1.0,-0.908927
3,1.577025,1.0,-0.635258
4,10.63337,0.0,-1.222659


In [9]:
# Example usage
model = DRGU(df, covariates=["w1"], treatment="x", response="y")
model.fit()

Step 0 gradient norm: 0.7455853819847107
converged after 6 iterations


In [10]:
# Estimated parameter vector from the fit above; the values match the
# Coefficient column of model.summary().
model.coefficients

Array([ 0.78363335, -0.03352755, -0.965426  ,  2.208195  ,  1.3163865 ,
       -1.1250285 ], dtype=float32)

In [11]:
# Variance matrix of the fitted parameters (one row/column per coefficient).
# The square roots of its diagonal match the Std_Error column of model.summary().
model.variance_matrix

Array([[ 3.0269942e-03,  9.2025424e-05,  6.7656621e-04,  2.2690054e-02,
         4.0019527e-03, -9.8145586e-03],
       [ 9.2025446e-05,  4.5719426e-02,  1.0030746e-02,  2.0557018e-03,
         2.9265408e-03, -6.2924484e-03],
       [ 6.7656656e-04,  1.0030745e-02,  8.2298696e-02, -6.8560508e-03,
         1.7015442e-02,  4.0472071e-03],
       [ 2.2690052e-02,  2.0557120e-03, -6.8560443e-03,  2.8485453e-01,
         1.1894996e-01, -1.2228911e-01],
       [ 4.0019518e-03,  2.9265499e-03,  1.7015440e-02,  1.1895007e-01,
         1.0504008e-01, -3.2049537e-02],
       [-9.8145576e-03, -6.2924516e-03,  4.0472108e-03, -1.2228926e-01,
        -3.2049544e-02,  1.4148051e-01]], dtype=float32)

In [12]:
# Per-parameter table: estimate, null-hypothesis value, standard error, z-score
# and p-value.
model.summary()

Unnamed: 0,Names,Coefficient,Null_Hypothesis,Std_Error,Z_Score,P_Value
0,delta,0.783633,0.5,0.055018,5.155271,2.532643e-07
1,beta_0,-0.033528,0.0,0.213821,-0.156802,0.875401
2,beta_1,-0.965426,0.0,0.286877,-3.36529,0.0007646314
3,gamma_0,2.208195,0.0,0.533718,4.137384,3.512873e-05
4,gamma_1,1.316386,0.0,0.324099,4.061682,4.872047e-05
5,gamma_2,-1.125028,0.0,0.376139,-2.990992,0.002780731
