In [1]:
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
import pandas as pd
from jax.random import PRNGKey as pkey
from jax.scipy.special import expit
from jaxopt import LBFGS
import statsmodels.api as sm


In [2]:
# --- Simulation settings ---
beta_true = 1.                  # log-odds effect per unit of the exposure level
num_covs = 1
num_subjects_per_group = 1000
num_groups = 3
total_subjects = num_subjects_per_group*num_groups
num_exposures = num_groups - 1  # number of non-reference exposure levels
intercept_val = 0.2
confounding_constant = 1.       # shift of the covariate mean per exposure level

# Exposure level of every subject as a column vector: 0,...,0,1,...,1,2,...,2
assignment = jnp.repeat(jnp.arange(num_groups),num_subjects_per_group).reshape(-1,1)

# Standard-normal covariates shifted by the exposure level, which makes them
# correlated with the exposure.
covs = jax.random.normal(pkey(12),(total_subjects,num_covs)) + confounding_constant*assignment
beta_covs = jax.random.normal(pkey(13),(num_covs,))
full_beta = jnp.hstack([beta_true,intercept_val,beta_covs])


#This is just for simulating
# Column order matches full_beta: exposure level, intercept, covariates.
X_sim = jnp.hstack([assignment,jnp.ones((total_subjects,1)),covs])

# One Bernoulli draw per subject with success probability expit(X_sim @ full_beta).
Y_obs = jax.random.binomial(pkey(1),1,expit(X_sim@full_beta))

In [3]:
def neg_log_lik(z,p,n):
    """Binomial negative log-likelihood on the logit scale, summed over rows.

    Parameters
    ----------
    z : linear predictor (log-odds) for each observation.
    p : observed success proportion (or 0/1 outcome) for each observation.
    n : weight / number of trials for each observation; a scalar broadcasts.

    Returns
    -------
    Scalar ``-sum(n * (p*log(sigmoid(z)) + (1-p)*log(1-sigmoid(z))))``.
    """
    # log(sigmoid(z)) and log(1-sigmoid(z)) == log(sigmoid(-z)) are evaluated
    # with log_sigmoid. Forming expit(z) first saturates to exactly 0 or 1 for
    # large |z|, and the resulting 0*log(0) term turns the whole sum into nan.
    log_p1 = jax.nn.log_sigmoid(z)
    log_p0 = jax.nn.log_sigmoid(-z)
    loss = -1*n*(p*log_p1 + (1-p)*log_p0)
    return jnp.sum(loss)

@jax.jit
def loss(beta,X,Y):
    """Unit-weight logistic negative log-likelihood.

    ``X @ beta`` is used as the linear predictor and ``Y`` as the observed
    outcome; every row gets a weight of 1.
    """
    linear_predictor = X@beta
    return neg_log_lik(linear_predictor,Y,1.)

In [4]:
# One-hot encode the exposure level; drop_first=True drops the first level's column.
dummies = pd.get_dummies(pd.Series(assignment[:,0]),prefix='treat',drop_first=True).astype('float64')
# Design matrix column order: intercept, covariate(s), exposure dummies.
X_reg = jnp.hstack([jnp.ones((total_subjects,1)),covs,dummies.values])

# One label per covariate column. A single covariate keeps the plain name
# 'covariate'; several are numbered so the label count matches X_reg's width.
covariate_names = ['covariate'] if num_covs == 1 else [f'covariate_{i}' for i in range(num_covs)]
df_reg = pd.DataFrame(X_reg,columns = ['intercept']+covariate_names+list(dummies.columns))

In [5]:
# Fit the logistic regression by minimising `loss` with L-BFGS, starting from zeros.
solver = LBFGS(loss,jit = True)
beta_init = jnp.zeros(X_reg.shape[1])
result = solver.run(beta_init,X_reg,Y_obs)
beta_reg = result.params

# Inverse Hessian of the loss at the fitted parameters, used as the covariance estimate.
hess_reg = jax.hessian(loss,argnums=0)(beta_reg,X_reg,Y_obs)
cov_reg = jnp.linalg.inv(hess_reg)

# Keep the last `num_exposures` entries: coefficients, covariance block, variances.
L = beta_reg[-num_exposures:]
cov_L = cov_reg[-num_exposures:,-num_exposures:]
V = jnp.diag(cov_reg)[-num_exposures:]

In [6]:
# Sanity check (disabled): the fit above matches statsmodels; uncomment the lines below to re-verify.
# model = sm.Logit(Y_obs, df_reg)
# results = model.fit()
# results.summary()

In [7]:
# Set up the inverse GLM: collapse the data to one row per exposure group.

df_summary = pd.DataFrame({
    "exposure":assignment[:,0],
    "Y":Y_obs,
    "cov":covs[:,0]
})

by_exposure = df_summary.groupby('exposure')
P = by_exposure['Y'].mean().values   # mean outcome per exposure group
N = by_exposure.size().values        # number of rows per exposure group

# Group-level design: an intercept column plus one dummy per non-reference group.
intercept_col = jnp.ones((num_groups,1))
exposure_dummies = jnp.identity(num_groups)[:,1:]
inv_GLM_design = jnp.hstack([intercept_col,exposure_dummies])

# Fitted exposure coefficients as a per-group offset; the first group gets 0.
L_offset = jnp.hstack([jnp.zeros(1),L])

@jax.jit
def loss_offset(beta,X,Y,offset,N):
    """Weighted logistic negative log-likelihood with a fixed offset.

    The linear predictor is ``X @ beta + offset``; ``Y`` is passed as the
    observed proportion and ``N`` as the per-row weight.
    """
    linear_predictor = X@beta+offset
    return neg_log_lik(linear_predictor,Y,N)

In [8]:
# Fit the group-level model; L_offset enters as a fixed offset, so only the
# coefficients on inv_GLM_design are estimated.
solver = LBFGS(loss_offset,jit = True,tol = 1e-5)
beta_init = jnp.zeros(inv_GLM_design.shape[1])
result = solver.run(beta_init,inv_GLM_design,P,L_offset,N)
inv_glm_sol = result.params

In [9]:
# Swap the roles back: the exposure coefficients L become parameters again and
# the dummy coefficients fitted by the inverse GLM move into the offset.

CC_recover_beta = jnp.hstack([inv_glm_sol[0],L])
CC_recover_offset = jnp.hstack([jnp.zeros(1),inv_glm_sol[1:]])

# Hessian of the offset loss w.r.t. beta at that point; its inverse is cov_iglm.
hessian_wrt_beta = jax.hessian(loss_offset,argnums = 0)
H = hessian_wrt_beta(CC_recover_beta,inv_GLM_design,P,CC_recover_offset,N)

cov_iglm = jnp.linalg.inv(H)

In [10]:
# Rescale the exposure block of cov_iglm so that its diagonal equals V while
# its implied correlation is left unchanged.
cov_iglm_exposure = cov_iglm[-num_exposures:,-num_exposures:]
sd_iglm = jnp.sqrt(jnp.diag(cov_iglm_exposure))
rescale = jnp.diag(jnp.sqrt(V)/sd_iglm)
CC_iglm = rescale@cov_iglm_exposure@rescale
CC_iglm

Array([[0.01169197, 0.00515849],
       [0.00515849, 0.02400847]], dtype=float64)

In [11]:
# Exposure-coefficient covariance block from the full regression fit, shown for comparison with CC_iglm.
cov_L

Array([[0.01169197, 0.00742085],
       [0.00742085, 0.02400847]], dtype=float64)

In [12]:
# Fitted inverse-GLM intercept minus the simulated intercept, followed by the fitted dummy coefficients.
jnp.hstack([inv_glm_sol[0] - intercept_val,inv_glm_sol[1:]])

Array([0.00471194, 0.14893243, 0.29368091], dtype=float64)

In [13]:
# Mean covariate value within each exposure group, multiplied by the simulated covariate coefficient.
(df_summary.groupby('exposure')['cov'].mean()*beta_covs).values

array([0.00890422, 0.18378993, 0.35276649])

In [14]:
# Inverse-GLM intercept followed by the exposure coefficients L from the full regression.
CC_recover_beta

Array([0.20471194, 0.93548627, 1.91708523], dtype=float64)

In [42]:
# Alternative parameterisation: fold the fitted inverse-GLM linear predictor into
# the first design column (evaluated with coefficient 1 below) instead of an offset.
pseudocov = inv_GLM_design[:,1:]@inv_glm_sol[1:]
pseudo_design = jnp.hstack([jnp.ones((num_groups,1))*inv_glm_sol[0]+pseudocov.reshape(-1,1),inv_GLM_design[:,1:]])
iglm_cc_beta = jnp.hstack([jnp.ones(1),L])


# The zero offset has one entry per group and the exposure block is selected
# with num_exposures, so the cell keeps working when num_groups is changed.
H = jax.hessian(loss_offset,argnums = 0)(iglm_cc_beta,pseudo_design,P,jnp.zeros(num_groups),N)
alt_cov_iglm = jnp.linalg.inv(H)[-num_exposures:,-num_exposures:]
sd_iglm_alt = jnp.sqrt(jnp.diag(alt_cov_iglm))

In [43]:
# Rescale alt_cov_iglm so that its diagonal equals V, keeping its implied correlation.
rescale_alt = jnp.diag(jnp.sqrt(V)/sd_iglm_alt)
CC_iglm = rescale_alt@alt_cov_iglm@rescale_alt
CC_iglm

Array([[0.01169197, 0.01101055],
       [0.01101055, 0.02400847]], dtype=float64)