In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from itertools import combinations, product
import statsmodels.api as sm
import pandas as pd
from scipy.stats import mannwhitneyu

# Define Functions

In [3]:
class KeyGen:
    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def next(self):
        self.key, subkey = jax.random.split(self.key)
        return subkey

rng = KeyGen(42)

In [4]:
def simulation_linear_data(n, ate, error_func):
    eta_true = jnp.array([ 0.0, -0.6])
    w   = jax.random.normal(rng.next(), (n, 1))   # covariates
    Wt  = jnp.concatenate([jnp.ones((n,1)), w], axis=1)
    pi_true    = jax.nn.sigmoid(Wt @ eta_true)                # true propensity
    z          = jax.random.bernoulli(rng.next(), pi_true).astype(jnp.float32)

    # beta_true = jnp.array([0.0, 0.0, 1.0])
    error = error_func(rng.next(), (n,))
    beta_true = jnp.array([ 0.0, ate, 1.0])
    X = jnp.concatenate([jnp.ones((n,1)), z[:,None]*1.0,w], axis=1)
    y = X @ beta_true + error
    return y, z, w, Wt

In [5]:
def simulation_wrong_pi_data(n, ate, error_func):
    eta_true = jnp.array([ 0.0, -0.2, 0.6])
    w   = jax.random.normal(rng.next(), (n, 1))   # covariates
    w2   = jax.random.normal(rng.next(), (n, 1))   # covariates
    Wt  = jnp.concatenate([jnp.ones((n,1)), w, w2], axis=1)
    pi_true    = jax.nn.sigmoid(Wt @ eta_true)                # true propensity
    z          = jax.random.bernoulli(rng.next(), pi_true).astype(jnp.float32)

    # beta_true = jnp.array([0.0, 0.0, 1.0])
    error = error_func(rng.next(), (n,))
    beta_true = jnp.array([ 0.0, ate, 1.0])
    X = jnp.concatenate([jnp.ones((n,1)), z[:,None]*1.0,w], axis=1)
    y = X @ beta_true + error
    return y, z, w, Wt

In [6]:
def simulation_wrong_g_data(n, ate, error_func):
    eta_true = jnp.array([ 0.0, -0.6])
    w   = jax.random.normal(rng.next(), (n, 1))   # covariates
    Wt  = jnp.concatenate([jnp.ones((n,1)), w], axis=1)
    pi_true    = jax.nn.sigmoid(Wt @ eta_true)                # true propensity
    z          = jax.random.bernoulli(rng.next(), pi_true).astype(jnp.float32)

    # beta_true = jnp.array([0.0, 0.0, 1.0])
    error = error_func(rng.next(), (n,))
    beta_true = jnp.array([ 0.0, ate, 1.0, -0.1])
    X = jnp.concatenate([jnp.ones((n,1)), z[:,None]*1.0, w, jnp.square(w)], axis=1)
    y = X @ beta_true + error
    return y, z, w, Wt

In [7]:
# ------------------------------
# helper matrices for all pairs
# ------------------------------
# design for g(w_i,w_j;γ) and g(w_j,w_i;γ)
def make_Xg(a,b):
    return jnp.concatenate([jnp.ones_like(a), a, b], axis=1)  # [1, w_i, w_j]

def data_pairwise(y, z, w):
    n = y.size
    Wt  = jnp.concatenate([jnp.ones((n,1)), w], axis=1)

    tri_u, tri_v = jnp.triu_indices(n, k=1)                   # i<j indices
    m            = tri_u.size                                 # #pairs

    wi, wj    = w[tri_u],            w[tri_v]                 # (m,p)
    zi, zj    = z[tri_u],            z[tri_v]
    yi, yj    = y[tri_u],            y[tri_v]

    Wt_i, Wt_j = Wt[tri_u],          Wt[tri_v]                # (m,p+1)
    Xg_ij, Xg_ji = make_Xg(wi,wj), make_Xg(wj,wi)               # (m,2p+1)

    return {
        'Xg_ij': Xg_ij,
        'Xg_ji': Xg_ji,
        'Wt_i': Wt_i,
        'Wt_j': Wt_j,
        'yi': yi,
        'yj': yj,
        'zi': zi,
        'zj': zj,
        'wi': wi,
        'wj': wj,
        'i': tri_u,
        'j': tri_v
    }

def wrong_pi_data_pairwise(y, z, w, Wt):
    n = y.size
    tri_u, tri_v = jnp.triu_indices(n, k=1)                   # i<j indices
    m            = tri_u.size                                 # #pairs

    wi, wj    = w[tri_u],            w[tri_v]                 # (m,p)
    zi, zj    = z[tri_u],            z[tri_v]
    yi, yj    = y[tri_u],            y[tri_v]

    Wt_i, Wt_j = Wt[tri_u],          Wt[tri_v]                # (m,p+1)
    Xg_ij, Xg_ji = make_Xg(wi,wj), make_Xg(wj,wi)               # (m,2p+1)

    return {
        'Xg_ij': Xg_ij,
        'Xg_ji': Xg_ji,
        'Wt_i': Wt_i,
        'Wt_j': Wt_j,
        'yi': yi,
        'yj': yj,
        'zi': zi,
        'zj': zj,
        'wi': wi,
        'wj': wj,
        'i': tri_u,
        'j': tri_v
    }

In [8]:
from sklearn.linear_model import LogisticRegression
import numpy as np

def get_init(data, Wt, z):
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']
    wi, wj = data['wi'], data['wj']
    Xg_ij, Xg_ji = data['Xg_ij'], data['Xg_ji']
    Wt_i, Wt_j = data['Wt_i'], data['Wt_j']

    I_ij = (yi >= yj).astype(jnp.float32)
    I_ji = 1. - I_ij
    h3   = zi*(1-zj)*I_ij + zj*(1-zi)*I_ji
    z_logistic = LogisticRegression(random_state=0, fit_intercept=False).fit(Wt, z)
    u_logistic = LogisticRegression(random_state=0, fit_intercept=False).fit((zi*(1-zj))[:,None]*Xg_ij + (zj*(1-zi))[:,None]*Xg_ji, h3)

    beta = jnp.array(z_logistic.coef_[0])
    gamma = jnp.array(u_logistic.coef_[0])
    u_ij = u_logistic.predict_proba(Xg_ij)[:,1]
    u_ji = u_logistic.predict_proba(Xg_ji)[:,1]
    delta_reg = 0.5 * np.mean(zi*(1-zj)*(I_ij - u_ij) + zj*(1-zi)*(I_ji - u_ji) + (u_ij + u_ji))
    bi = z_logistic.predict_proba(Wt_i)[:,1]
    bj = z_logistic.predict_proba(Wt_j)[:,1]
    delta_ipw = 0.5*jnp.mean(zi*(1-zj)/(bi*(1-bj))*I_ij + zj*(1-zi)/(bj*(1-bi))*I_ji)

    return beta, gamma, jnp.array([delta_reg]), jnp.array([delta_ipw])

In [9]:

def safe_sigmoid(x):
    return jax.nn.sigmoid(jnp.clip(x, -10.0, 10.0))

# @jax.jit
def compute_h_f(theta, V_inv, data):
    # All inputs
    Wt_i, Wt_j = data['Wt_i'], data['Wt_j']
    Xg_ij, Xg_ji = data['Xg_ij'], data['Xg_ji']
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']
    m = zi.size

    delta, beta, gamma = theta["delta"], theta["beta"], theta["gamma"]

    V_inv = V_inv

    # predictions
    # p_delta = safe_sigmoid(delta)
    pi_i  = safe_sigmoid(jnp.sum(Wt_i * beta, axis=1))
    pi_j  = safe_sigmoid(jnp.sum(Wt_j * beta, axis=1))
    g_ij = safe_sigmoid(jnp.sum(Xg_ij * gamma, axis=1))
    g_ji = safe_sigmoid(jnp.sum(Xg_ji * gamma, axis=1))

    # indicators
    I_ij = (yi >= yj).astype(jnp.float32)
    I_ji = 1. - I_ij

    # h vector (3‑component) for all pairs
    num1 = zi*(1-zj)/(2*pi_i*(1-pi_j)) * (I_ij - g_ij)
    num2 = zj*(1-zi)/(2*pi_j*(1-pi_i)) * (I_ji - g_ji)
    h1   = num1 + num2 + 0.5*(g_ij + g_ji)
    h2   = 0.5*(zi + zj)
    h3   = 0.5*(zi*(1-zj)*I_ij + zj*(1-zi)*I_ji)
    h    = jnp.stack([h1,h2,h3], axis=1)                  # (m,3)

    # f vector
    f1   = jnp.full_like(h1, delta)
    f2   = 0.5*(pi_i + pi_j)
    f3   = 0.5*(pi_i*(1-pi_j)*g_ij + pi_j*(1-pi_i)*g_ji)
    f    = jnp.stack([f1,f2,f3], axis=1)
    return h, f

# @jax.jit
def compute_f(theta, V_inv, data):
    _, f = compute_h_f(theta, V_inv, data)
    return f

def compute_u_ij(theta, V_inv, data):
    h, f = compute_h_f(theta, V_inv, data)
    D = jax.jacfwd(compute_f, argnums=0)(theta, V_inv, data)
    D_ij = jnp.concatenate([D['delta'], D['beta'], D['gamma']], axis=2)

    G_ij = jnp.transpose(D_ij, (0, 2, 1)) @ V_inv
    S_ij = h - f
    u_ij = jnp.einsum('npc,nc->np', G_ij, S_ij)

    return u_ij

@jax.jit
def U_n(theta, V_inv, data):
    return jnp.mean(compute_u_ij(theta, V_inv, data), axis=0)


In [10]:
def compute_h_f_fisher(theta, V_inv, data):
    # All inputs
    Wt_i, Wt_j = data['Wt_i'], data['Wt_j']
    Xg_ij, Xg_ji = data['Xg_ij'], data['Xg_ji']
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']
    m = zi.size

    delta, beta, gamma = theta["delta"], theta["beta"], theta["gamma"]

    V_inv = V_inv

    # predictions
    # p_delta = safe_sigmoid(delta)
    pi_i  = safe_sigmoid(jnp.sum(Wt_i * beta, axis=1))
    pi_j  = safe_sigmoid(jnp.sum(Wt_j * beta, axis=1))
    g_ij = safe_sigmoid(jnp.sum(Xg_ij * gamma, axis=1))
    g_ji = safe_sigmoid(jnp.sum(Xg_ji * gamma, axis=1))

    # indicators
    I_ij = (yi >= yj).astype(jnp.float32)
    I_ji = 1. - I_ij

    # h vector (3‑component) for all pairs
    num1 = zi*(1-zj)/(2*pi_i*(1-pi_j)) * (I_ij - g_ij)
    num2 = zj*(1-zi)/(2*pi_j*(1-pi_i)) * (I_ji - g_ji)
    h1   = num1 + num2 + 0.5*(g_ij + g_ji)
    h2   = 0.5*(zi + zj)
    h3   = 0.5*(zi*(1-zj)*I_ij + zj*(1-zi)*I_ji)
    h    = jnp.stack([h1,h2,h3], axis=1)                  # (m,3)

    # f vector
    f1   = jnp.full_like(h1, delta)
    f2   = 0.5*(pi_i + pi_j)
    f3   = 0.5*(pi_i*(1-pi_j)*g_ij + pi_j*(1-pi_i)*g_ji)
    f    = jnp.stack([f1,f2,f3], axis=1)
    return h, f

def compute_h_fisher(theta, V_inv, data):
    h, _ = compute_h_f_fisher(theta, V_inv, data)
    return h

def compute_f_fisher(theta, V_inv, data):
    _, f = compute_h_f_fisher(theta, V_inv, data)
    return f

@jax.jit
def compute_B_u_ij(theta, V_inv, data):
    h, f = compute_h_f_fisher(theta, V_inv, data)
    D = jax.jacfwd(compute_f_fisher, argnums=0)(theta, V_inv, data)
    D_ij = jnp.concatenate([D['delta'], D['beta'], D['gamma']], axis=2)

    M0 = jax.jacfwd(compute_h_fisher, argnums=0)(theta, V_inv, data)
    M0_ij = jnp.concatenate([M0['delta'], M0['beta'], M0['gamma']], axis=2)
    M_ij = D_ij - M0_ij

    G_ij = jnp.transpose(D_ij, (0, 2, 1)) @ V_inv
    B_ij = jnp.einsum('npq,nqc->npc', G_ij, M_ij)
    B = jnp.mean(B_ij, axis=0)

    S_ij = h - f
    u_ij = jnp.einsum('npc,nc->np', G_ij, S_ij)

    return B, u_ij

def compute_B_U(theta, V_inv, data):
    B, u_ij = compute_B_u_ij(theta, V_inv, data)
    U = jnp.mean(u_ij, axis=0)
    return B, U


def compute_B_U_Sig(theta, V_inv, data, compute_sig=False):
    B, u_ij = compute_B_u_ij(theta, V_inv, data)
    U = jnp.mean(u_ij, axis=0)

    n = jnp.maximum(jnp.max(data['i']), jnp.max(data['j'])) + 1
    d = u_ij.shape[1]
    u_i = jnp.zeros((n,d)).at[data['i']].add(u_ij).at[data['j']].add(u_ij)/n
    sig_i = jnp.einsum('np,nq->npq', u_i, u_i)

    # Sig_ij = jnp.einsum('np,nq->npq', u_ij, u_ij)
    Sig = jnp.mean(sig_i, axis=0)

    return B, U, Sig

In [11]:

def compute_u(theta, V_inv, data):
    # All inputs
    Wt_i, Wt_j = data['Wt_i'], data['Wt_j']
    Xg_ij, Xg_ji = data['Xg_ij'], data['Xg_ji']
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']
    m = zi.size

    delta, beta, gamma = theta["delta"], theta["beta"], theta["gamma"]

    V_inv = V_inv

    # predictions
    # p_delta = safe_sigmoid(delta)
    pi_i  = safe_sigmoid(jnp.sum(Wt_i * beta, axis=1))
    pi_j  = safe_sigmoid(jnp.sum(Wt_j * beta, axis=1))
    g_ij = safe_sigmoid(jnp.sum(Xg_ij * gamma, axis=1))
    g_ji = safe_sigmoid(jnp.sum(Xg_ji * gamma, axis=1))

    # indicators
    I_ij = (yi >= yj).astype(jnp.float32)
    I_ji = 1. - I_ij

    # h vector (3‑component) for all pairs
    num1 = zi*(1-zj)/(2*pi_i*(1-pi_j)) * (I_ij - g_ij)
    num2 = zj*(1-zi)/(2*pi_j*(1-pi_i)) * (I_ji - g_ji)
    h1   = num1 + num2 + 0.5*(g_ij + g_ji)
    h2   = 0.5*(zi + zj)
    h3   = 0.5*(zi*(1-zj)*I_ij + zj*(1-zi)*I_ji)
    h    = jnp.stack([h1,h2,h3], axis=1)                  # (m,3)

    # f vector
    f1   = jnp.full_like(h1, delta)
    f2   = 0.5*(pi_i + pi_j)
    f3   = 0.5*(pi_i*(1-pi_j)*g_ij + pi_j*(1-pi_i)*g_ji)
    f    = jnp.stack([f1,f2,f3], axis=1)

    # derivatives (D_ij rows) -------------------------
    # D_ij is (3, 3p+3)
    # helpers
    # pi and g both have logit link:
    grad_pi_i  = (pi_i*(1-pi_i))[:,None] * Wt_i              # (m,p+1)
    grad_pi_j  = (pi_j*(1-pi_j))[:,None] * Wt_j
    grad_g_ij = (g_ij*(1-g_ij))[:,None] * Xg_ij              # (m,2p+1)
    grad_g_ji = (g_ji*(1-g_ji))[:,None] * Xg_ji
    A = (1-pi_j)*g_ij - pi_j*g_ji
    B = (1-pi_i)*g_ji - pi_i*g_ij
    C = pi_i*(1-pi_j)
    D = pi_j*(1-pi_i)

    # row 1: derivative only w.r.t delta
    row1 = jnp.concatenate([jnp.ones((m,1)),
                            jnp.zeros((m,len(beta)+len(gamma)))], axis=1)

    # row 2: derivative w.r.t beta
    row2 = jnp.concatenate([jnp.zeros((m,1)),
                            0.5*(grad_pi_i + grad_pi_j),
                            jnp.zeros((m,len(gamma)))], axis=1)

    # row 3: derivative w.r.t beta and gamma
    row3beta = grad_pi_i * A[:,None] + grad_pi_j * B[:,None]   # (m,p+1)
    row3gamma = grad_g_ij * C[:,None] + grad_g_ji * D[:,None] # (m,2p+1)
    row3  = jnp.concatenate([jnp.zeros((m,1)), 0.5*row3beta, 0.5*row3gamma], axis=1)

    # D_ij^T (h-f)  --------------> contribution to U_n(theta)
    # NOTE: confirmed that this is correct
    diff            = h - f                                     # (m,3)
    D_ij            = jnp.stack([row1, row2, row3], axis=2)
    G_ij = D_ij @ V_inv
    S_ij = h - f
    u_ij = jnp.einsum('npc,nc->np', G_ij, S_ij)

    return u_ij

@jax.jit
def U_n_dir(theta, V_inv, data):
    return jnp.mean(compute_u(theta, V_inv, data), axis=0)


In [12]:
def compute_delta(theta, V_inv, data, lamb=0.0, option="fisher"):
    if option == 'dir':
        U  = U_n_dir(theta, V_inv, data)
        J_dict  = jax.jacfwd(U_n_dir, argnums=0)(theta, V_inv, data)
        J = jnp.concatenate([J_dict['delta'], J_dict['beta'], J_dict['gamma']], axis=1)
    elif option == "jax":
        U  = U_n(theta, V_inv, data)
        J_dict  = jax.jacfwd(U_n, argnums=0)(theta, V_inv, data)
        J = jnp.concatenate([J_dict['delta'], J_dict['beta'], J_dict['gamma']], axis=1)
    elif option == "fisher":
        B, U = compute_B_U(theta, V_inv, data)
        J = -B
    else:
        raise ValueError(f"Unknown option {option}")
    step = jnp.linalg.solve(J+ lamb * jnp.eye(J.shape[0]), -U)
    return step, J

def update_theta(theta, step):
    start = 0
    for k,v in theta.items():
        theta[k] += step[start:start+v.size]
        start += v.size
    return theta


In [20]:
def solve_ugee(data, theta_init, max_iter=100, tol=1e-6, lamb=0.0, option="fisher", verbose=True):
    V_inv = jnp.eye(3)
    theta = {k: v.copy() for k, v in theta_init.items()}
    for i in range(max_iter):
        step, J = compute_delta(theta, V_inv, data, lamb, option)
        # jax.debug.print("Step {i}: {x}", i=i, x=step)
        if i % 10 == 0 and verbose:
            jax.debug.print("Step {i} gradient norm: {x}", i=i, x=jnp.linalg.norm(step))
        theta = update_theta(theta, step)
        if jnp.linalg.norm(step) < tol:
            if verbose:
                print(f"converged after {i} iterations")
            break
    if i == max_iter-1 and verbose:
      print(f"did not converge, norm step = {jnp.linalg.norm(step)}")
    B, U, Sig = compute_B_U_Sig(theta, V_inv, data, compute_sig=True)
    B_inv = jnp.linalg.inv(B)
    Var = 4 * B_inv.T @ Sig @ B_inv
    # Var = 4 * B_inv @ Sig @ B_inv.T
    return theta, J, Var


# Simulation

In [21]:
def run_simulation(n, error_func, sim_data_func, ate=0.0, n_sim=100, verbose=False, max_iter=100):
  eta_true = jnp.array([ 0.0, -0.6])
  beta_true = jnp.array([0.0, ate, 1.0])
  z_vec = []
  pval_vec = []
  pval_u_vec = []
  for i in range(n_sim):
    y, x, w, _ = sim_data_func(n, ate, error_func.f)
    data = data_pairwise(y, x, w)
    Wt = np.asarray(jnp.concatenate([jnp.ones((n,1)), w], axis=1))
    beta, gamma, delta_reg, delta_ipw = get_init(data, Wt, x)
    theta_init_best = {
        "delta": delta_ipw,
        "beta": beta,
        "gamma": gamma,
    }
    theta, J, Var = solve_ugee(data, theta_init_best, max_iter = max_iter, option = "fisher", verbose = verbose)

    se = np.sqrt(Var[0][0]/n)
    z = abs(theta['delta']-0.5)/se
    z_vec.append(z)

    ols_mod = sm.OLS(np.asarray(y), np.asarray(jnp.concatenate([jnp.ones((n,1)), x[:,None], w], axis=1)), hasconst=True)
    res = ols_mod.fit()
    pval_vec.append(res.pvalues[1])

    stat, pval_u = mannwhitneyu(y[x==0], y[x>0], alternative='two-sided', method='auto')
    pval_u_vec.append(pval_u)

  z_array = jnp.concat(z_vec)
  return [{'n': n, 'ate': ate, 'errfunc': error_func.name, 'siglevel': 0.05, 'ugee': np.mean(jax.numpy.abs(z_array) > 1.96), 'ols': (jnp.asarray(pval_vec) < 0.05).mean(), 'u': (jnp.asarray(pval_u_vec) < 0.05).mean()},
          {'n': n, 'ate': ate, 'errfunc': error_func.name, 'siglevel': 0.01, 'ugee': np.mean(jax.numpy.abs(z_array) > 2.575829), 'ols': (jnp.asarray(pval_vec) < 0.01).mean(), 'u': (jnp.asarray(pval_u_vec) < 0.01).mean()}
           ]

In [22]:
def run_wrong_pi_simulation(n, error_func, ate=0.0, n_sim=100, verbose=False, max_iter=100):
  eta_true = jnp.array([ 0.0, -0.2, 0.6])
  beta_true = jnp.array([0.0, ate, 1.0])
  z_vec = []
  z_mis_vec = []
  pval_vec = []
  pval_u_vec = []
  for i in range(n_sim):
    y, x, w, Wt = simulation_wrong_pi_data(n, ate, error_func.f)

    # correctly specified DR UGEE
    data = wrong_pi_data_pairwise(y, x, w, Wt)
    theta_init_best = {
        "delta": jnp.array([0.5]),
        "beta": jax.random.normal(rng.next(), shape=(3,)),
        "gamma": jax.random.normal(rng.next(), shape=(3,)),
    }
    theta, J, Var = solve_ugee(data, theta_init_best, max_iter = max_iter, option = "fisher", verbose = verbose)
    se = np.sqrt(Var[0][0]/n)
    z = abs(theta['delta']-0.5)/se
    z_vec.append(z)

    # misspecified propensity model DR UGEE
    mis_data = data_pairwise(y, x, w)
    theta_init_mis = {
        "delta": jnp.array([0.5]),
        "beta": jax.random.normal(rng.next(), shape=(2,)),
        "gamma": jax.random.normal(rng.next(), shape=(3,)),
    }
    theta, J, Var = solve_ugee(mis_data, theta_init_mis, max_iter = max_iter, option = "fisher", verbose = verbose)
    se = np.sqrt(Var[0][0]/n)
    z = abs(theta['delta']-0.5)/se
    z_mis_vec.append(z)

    # correctly specified Linear regression
    ols_mod = sm.OLS(np.asarray(y), np.asarray(jnp.concatenate([jnp.ones((n,1)), x[:,None], w], axis=1)), hasconst=True)
    res = ols_mod.fit()
    pval_vec.append(res.pvalues[1])

    # vanilla U
    stat, pval_u = mannwhitneyu(y[x==0], y[x>0], alternative='two-sided', method='auto')
    pval_u_vec.append(pval_u)

  z_array = jnp.concat(z_vec)
  z_mis_array = jnp.concat(z_mis_vec)
  return [{
            'n': n, 'ate': ate, 'errfunc': error_func.name, 'siglevel': 0.05,
            'ugee': np.mean(jax.numpy.abs(z_array) > 1.96), 'misugee': np.mean(jax.numpy.abs(z_mis_array) > 1.96),
            'ols': (jnp.asarray(pval_vec) < 0.05).mean(), 'u': (jnp.asarray(pval_u_vec) < 0.05).mean()},
          {
            'n': n, 'ate': ate, 'errfunc': error_func.name, 'siglevel': 0.01,
            'ugee': np.mean(jax.numpy.abs(z_array) > 2.575829), 'misugee': np.mean(jax.numpy.abs(z_mis_array) > 2.575829),
            'ols': (jnp.asarray(pval_vec) < 0.01).mean(), 'u': (jnp.asarray(pval_u_vec) < 0.01).mean()}
           ]

In [23]:
class ErrorFunc:
    name = ''
    f = None

    def __init__(self, name, f):
            self.name = name
            self.f = f

## Power and Type I error simulation


In [24]:
normal_error_func = ErrorFunc('normal', lambda key, shape: jax.random.normal(key, shape))
lognormal_error_func = ErrorFunc('lognormal', lambda key, shape: jax.random.lognormal(key, shape=shape))
cauchy_error_func = ErrorFunc('cauchy', lambda key, shape: jax.random.lognormal(key, shape=shape))
laplace_error_func = ErrorFunc('laplace', lambda key, shape: jax.random.laplace(key, shape=shape))

In [25]:
n_sim = 200
res = []
n_list = [50, 200]
ate_list = [0.0, 0.5]
error_func_list = [normal_error_func, lognormal_error_func, cauchy_error_func]
for n, ate, ef in product(n_list, ate_list, error_func_list):
    print(f"***************** n={n}, ate={ate}, errfunc={ef.name}")
    res.extend(run_wrong_pi_simulation(n=n, error_func=ef, ate=ate, n_sim=n_sim, verbose=False))

***************** n=50, ate=0.0, errfunc=normal
***************** n=50, ate=0.0, errfunc=lognormal
***************** n=50, ate=0.0, errfunc=cauchy
***************** n=50, ate=0.5, errfunc=normal
***************** n=50, ate=0.5, errfunc=lognormal
***************** n=50, ate=0.5, errfunc=cauchy
***************** n=200, ate=0.0, errfunc=normal
***************** n=200, ate=0.0, errfunc=lognormal
***************** n=200, ate=0.0, errfunc=cauchy
***************** n=200, ate=0.5, errfunc=normal
***************** n=200, ate=0.5, errfunc=lognormal
***************** n=200, ate=0.5, errfunc=cauchy


In [26]:
pd.DataFrame(res)

Unnamed: 0,n,ate,errfunc,siglevel,ugee,misugee,ols,u
0,50,0.0,normal,0.05,0.035,0.02,0.049999997,0.065
1,50,0.0,normal,0.01,0.01,0.01,0.005,0.024999999
2,50,0.0,lognormal,0.05,0.035,0.015,0.049999997,0.07
3,50,0.0,lognormal,0.01,0.01,0.0,0.005,0.015
4,50,0.0,cauchy,0.05,0.07,0.035,0.02,0.07
5,50,0.0,cauchy,0.01,0.005,0.005,0.0,0.02
6,50,0.5,normal,0.05,0.26,0.21499999,0.35999998,0.065
7,50,0.5,normal,0.01,0.13499999,0.105,0.145,0.015
8,50,0.5,lognormal,0.05,0.29,0.21,0.225,0.12
9,50,0.5,lognormal,0.01,0.16499999,0.105,0.07,0.024999999


In [19]:
pd.DataFrame(res)

Unnamed: 0,n,ate,errfunc,siglevel,ugee,misugee,ols,u
0,50,0.0,normal,0.05,0.08,0.06,0.055,0.08
1,50,0.0,normal,0.01,0.01,0.005,0.005,0.01
2,50,0.0,lognormal,0.05,0.07,0.06,0.049999997,0.099999994
3,50,0.0,lognormal,0.01,0.03,0.03,0.0,0.035
4,50,0.0,cauchy,0.05,0.049999997,0.049999997,0.03,0.044999998
5,50,0.0,cauchy,0.01,0.02,0.01,0.01,0.01
6,50,0.5,normal,0.05,0.31,0.24499999,0.35999998,0.099999994
7,50,0.5,normal,0.01,0.155,0.125,0.13499999,0.035
8,50,0.5,lognormal,0.05,0.285,0.285,0.19,0.11
9,50,0.5,lognormal,0.01,0.155,0.14999999,0.074999996,0.03
