In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from itertools import combinations
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd

# Define Functions

In [None]:
class KeyGen:
    """Stateful wrapper around a JAX PRNG key that hands out fresh subkeys.

    Every call to :meth:`next` splits the stored key in two: one half becomes
    the new internal state, the other half is returned to the caller.
    """

    def __init__(self, seed=0):
        # Root key built from the integer seed.
        self.key = jax.random.PRNGKey(seed)

    def next(self):
        """Advance the internal key and return the subkey split off from it."""
        carry, fresh = jax.random.split(self.key)
        self.key = carry
        return fresh

# Shared key generator used by the cells below.
rng = KeyGen(42)

In [None]:
# -----------------------------------------------------------------------------
# Longitudinal data-generator
# -----------------------------------------------------------------------------
def simulate_longitudinal_data(n,
                              T,
                              eta_true,      # length T+1 (propensity)
                              beta_true,     # length 3   (outcome)
                              error_func):
    """
    Simulate one data set with n subjects observed at T time points.

    • Treatment model :  π_i    = σ([1, w_{i1}, …, w_{iT}]·η),  z_i ~ Bernoulli(π_i)
                         (one treatment draw per subject, shared by all time points)
    • Outcome model   :  y_{it} = [1, z_i, w_{it}]·β + ε_{it}
                         (the same β is used at every time point)

    Parameters
    ----------
    n, T       : number of subjects / time points
    eta_true   : propensity coefficients, multiplied against the (n, T+1) matrix Wz
    beta_true  : outcome coefficients for [intercept, treatment, w_it]
    error_func : callable ``error_func(key, shape=(n, T))`` returning the noise ε

    Returns
    -------
    y   : (n, T)    outcomes
    z   : (n,)      treatment indicator as float32
    w   : (n, T)    time-varying confounder
    Wz  : (n, T+1)  propensity design [1, w_i1, …, w_iT]
    z_T : (n, T)    z repeated across time
    X   : (n, T, 3) outcome design [1, z_i, w_it]

    Random keys are taken from the module-level ``rng`` in the order
    w → z → ε, so the function is not pure: each call advances ``rng``.
    """
    # --- time-varying confounder --------------------------------------------
    w = jax.random.normal(rng.next(), shape=(n, T))
    intercept = jnp.ones((n, 1))
    Wz = jnp.concatenate([intercept, w], axis=1)

    # --- treatment assignment (one draw per subject) -------------------------
    pi_true = jax.nn.sigmoid(Wz @ eta_true)
    draws = jax.random.bernoulli(rng.next(), p=pi_true)
    z = draws.astype(jnp.float32)

    # --- outcome -------------------------------------------------------------
    z_T = jnp.broadcast_to(z[:, None], (n, T))
    eps = error_func(rng.next(), shape=(n, T))
    # Stack the three regressors on a new trailing axis: [1, z_i, w_it].
    X = jnp.stack([jnp.ones((n, T)), z_T, w], axis=2)
    y = jnp.einsum('ntp,p->nt', X, beta_true) + eps

    return y, z, w, Wz, z_T, X


In [None]:
# ------------------------------
# helper matrices for all pairs
# ------------------------------
# design for g(w_it,w_jt;γt) and g(w_jt,w_it;γt)
def make_Xg(a,b):
    """Return the design tensor [1, a, b], concatenated along axis 2.

    `a` and `b` must have at least three axes and matching shapes; the
    intercept block is built with the same shape as `a`.
    """
    return jnp.concatenate([jnp.ones_like(a), a, b], axis=2)  # [1, w_it, w_jt]

def data_pairwise(y, z, w):
    """Expand subject-level arrays into arrays indexed by subject pairs i<j.

    Parameters
    ----------
    y : (n, T) outcomes
    z : (n,)   per-subject treatment indicator
    w : (n, T) time-varying covariate

    Returns
    -------
    dict with, for m = n(n-1)/2 pairs:
      'Xg_ij', 'Xg_ji' : (m, T, 3)    [1, w_it, w_jt] and [1, w_jt, w_it]
      'Wz_i', 'Wz_j'   : (m, T, 1+T)  propensity design [1, w_1..w_T], repeated over t
      'yi', 'yj'       : (m, T)
      'zi', 'zj'       : (m, T)       treatment repeated over t
      'wi', 'wj'       : (m, T)
      'i', 'j'         : (m,)         subject indices of each pair
    """
    n, T = y.shape[0], y.shape[1]
    tri_u, tri_v = jnp.triu_indices(n, k=1)                   # i<j indices

    wi, wj = w[tri_u], w[tri_v]                               # (m,T)
    z_T = jnp.broadcast_to(z[:, None], (n, T))
    zi, zj = z_T[tri_u], z_T[tri_v]                           # (m,T)
    yi, yj = y[tri_u], y[tri_v]                               # (m,T)

    Wz = jnp.concatenate([jnp.ones((n, 1)), w], axis=1)       # (n,1+T)
    # The trailing width is read off Wz itself. The previous version used
    # 1+T*p with `p` taken from a notebook global, which is undefined when
    # the cells are run top to bottom on a fresh kernel.
    Wz_T = jnp.broadcast_to(Wz[:, None, :], (n, T, Wz.shape[1]))
    Wz_i, Wz_j = Wz_T[tri_u], Wz_T[tri_v]                     # (m,T,1+T)

    wi_col, wj_col = wi[..., None], wj[..., None]             # (m,T,1)
    Xg_ij, Xg_ji = make_Xg(wi_col, wj_col), make_Xg(wj_col, wi_col)  # (m,T,3)

    return {
        'Xg_ij': Xg_ij,
        'Xg_ji': Xg_ji,
        'Wz_i': Wz_i,
        'Wz_j': Wz_j,
        'yi': yi,
        'yj': yj,
        'zi': zi,
        'zj': zj,
        'wi': wi,
        'wj': wj,
        'i': tri_u,
        'j': tri_v
    }

In [None]:

def safe_sigmoid(x):
    """Logistic function applied after clipping the input to [-10, 10]."""
    clipped = jnp.clip(x, -10.0, 10.0)
    return jax.nn.sigmoid(clipped)

# Not jitted on its own; it is traced inside the jitted callers.
def compute_h_f(theta, V_inv, data):
    """Build the pairwise response vector h and its model mean f.

    Parameters
    ----------
    theta : dict with keys "delta", "beta", "gamma".
            "beta" multiplies the Wz designs (propensity model) and "gamma"
            multiplies the Xg designs (pairwise outcome model).
    V_inv : accepted for a uniform signature with the other helpers; not
            used inside this function.
    data  : dict as produced by ``data_pairwise``.

    Returns
    -------
    h, f : arrays of shape (m, 1+2*T), laid out as
           [T doubly-robust columns | 1 treatment column | T product columns].
    """
    delta, beta, gamma = theta["delta"], theta["beta"], theta["gamma"]
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']

    def _prob(design, coef):
        # Clipped logistic of the linear predictor, one value per (pair, time).
        return safe_sigmoid(jnp.einsum('ntp,p->nt', design, coef))

    # Fitted propensities and pairwise outcome probabilities.
    pi_i = _prob(data['Wz_i'], beta)
    pi_j = _prob(data['Wz_j'], beta)
    g_ij = _prob(data['Xg_ij'], gamma)
    g_ji = _prob(data['Xg_ji'], gamma)

    # Pairwise indicators; ties (yi == yj) count towards I_ij.
    wins_ij = (yi >= yj).astype(jnp.float32)
    wins_ji = 1. - wins_ij

    # --- h: observed-data side ------------------------------------------------
    ipw_ij = zi*(1-zj)/(2*pi_i*(1-pi_j)) * (wins_ij - g_ij)
    ipw_ji = zj*(1-zi)/(2*pi_j*(1-pi_i)) * (wins_ji - g_ji)
    h1 = ipw_ij + ipw_ji + 0.5*(g_ij + g_ji)
    # z is constant over time, so only the first time column is kept.
    h2 = 0.5*(zi + zj)[:, 0]
    h3 = 0.5*(zi*(1-zj)*wins_ij + zj*(1-zi)*wins_ji)
    h = jnp.concatenate([h1, h2[:, None], h3], axis=1)                # (m,1+2*T)

    # --- f: model side --------------------------------------------------------
    f1 = jnp.full_like(h1, delta)
    f2 = 0.5*(pi_i + pi_j)[:, 0]
    f3 = 0.5*(pi_i*(1-pi_j)*g_ij + pi_j*(1-pi_i)*g_ji)
    f = jnp.concatenate([f1, f2[:, None], f3], axis=1)                # (m,1+2*T)
    return h, f

def compute_f(theta, V_inv, data):
    """Return only the model-mean part f of ``compute_h_f`` (the function differentiated in ``compute_u_ij``)."""
    return compute_h_f(theta, V_inv, data)[1]

def compute_u_ij(theta, V_inv, data):
    """Per-pair estimating function u_ij = D_ijᵀ V⁻¹ (h_ij − f_ij).

    D_ij is the forward-mode Jacobian of f with respect to theta, with the
    parameter blocks laid out as [delta | beta | gamma] on the last axis.
    Returns an array with one row per pair and one column per parameter.
    """
    h, f = compute_h_f(theta, V_inv, data)
    jac = jax.jacfwd(compute_f, argnums=0)(theta, V_inv, data)
    D_ij = jnp.concatenate([jac[name] for name in ('delta', 'beta', 'gamma')], axis=2)

    # (pairs, params, components) @ (components, components)
    G_ij = jnp.swapaxes(D_ij, 1, 2) @ V_inv
    residual = h - f
    return jnp.einsum('npc,nc->np', G_ij, residual)

@jax.jit
def U_n(theta, V_inv, data):
    """Estimating function: the average of ``compute_u_ij`` over all pairs."""
    per_pair = compute_u_ij(theta, V_inv, data)
    return jnp.mean(per_pair, axis=0)


In [None]:
def compute_h_f_fisher(theta, V_inv, data):
    """Thin alias of ``compute_h_f`` used by the Fisher-scoring helpers; returns (h, f)."""
    h, f = compute_h_f(theta, V_inv, data)
    return h, f

def compute_h_fisher(theta, V_inv, data):
    """Return only h from ``compute_h_f_fisher`` (differentiated in ``compute_B_u_ij``)."""
    return compute_h_f_fisher(theta, V_inv, data)[0]

def compute_f_fisher(theta, V_inv, data):
    """Return only f from ``compute_h_f_fisher`` (differentiated in ``compute_B_u_ij``)."""
    return compute_h_f_fisher(theta, V_inv, data)[1]

@jax.jit
def compute_B_u_ij(theta, V_inv, data):
    """Compute the matrix B and the per-pair estimating functions.

    With D = ∂f/∂θ and M = ∂f/∂θ − ∂h/∂θ (both taken per pair):

        G_ij = D_ijᵀ V⁻¹
        B    = mean over pairs of G_ij M_ij        (params × params)
        u_ij = G_ij (h_ij − f_ij)                  (pairs × params)

    Returns ``(B, u_ij)``.
    """
    def _stack(jac):
        # Same parameter layout everywhere: delta | beta | gamma.
        return jnp.concatenate([jac['delta'], jac['beta'], jac['gamma']], axis=2)

    h, f = compute_h_f_fisher(theta, V_inv, data)
    D_ij = _stack(jax.jacfwd(compute_f_fisher, argnums=0)(theta, V_inv, data))
    M0_ij = _stack(jax.jacfwd(compute_h_fisher, argnums=0)(theta, V_inv, data))
    M_ij = D_ij - M0_ij

    G_ij = jnp.swapaxes(D_ij, 1, 2) @ V_inv
    B = jnp.mean(jnp.einsum('npq,nqc->npc', G_ij, M_ij), axis=0)

    residual = h - f
    u_ij = jnp.einsum('npc,nc->np', G_ij, residual)

    return B, u_ij

def compute_B_U(theta, V_inv, data):
    """Return ``(B, U)`` where U is the pair-average of the u_ij from ``compute_B_u_ij``."""
    B, per_pair = compute_B_u_ij(theta, V_inv, data)
    return B, jnp.mean(per_pair, axis=0)

def compute_B_U_Sig(theta, V_inv, data, compute_sig=False):
    """Return ``(B, U, Sig)`` for the sandwich variance.

    Sig is the average outer product of subject-level scores u_i, where u_i
    sums u_ij over every pair that contains subject i and divides by n.

    ``compute_sig`` is accepted but never read: Sig is always computed.

    NOTE(review): each subject appears in n-1 pairs but the sum is divided
    by n — confirm this is the intended normaliser.
    """
    B, u_ij = compute_B_u_ij(theta, V_inv, data)
    U = jnp.mean(u_ij, axis=0)

    idx_i, idx_j = data['i'], data['j']
    # Number of subjects, recovered from the largest index in the pair lists.
    n = jnp.maximum(jnp.max(idx_i), jnp.max(idx_j)) + 1
    d = u_ij.shape[1]

    # Scatter-add each pair's contribution onto both of its members.
    totals = jnp.zeros((n,d)).at[idx_i].add(u_ij).at[idx_j].add(u_ij)
    u_i = totals/n
    outer = jnp.einsum('np,nq->npq', u_i, u_i)
    Sig = jnp.mean(outer, axis=0)

    return B, U, Sig

In [None]:
def compute_delta(theta, V_inv, data, lamb=0.0, option="fisher"):
    """Compute one Newton-type step for the estimating equation U_n(θ) = 0.

    Parameters
    ----------
    option : "jax"    – J is the forward-mode Jacobian of ``U_n``;
             "fisher" – J = −B with B from ``compute_B_U``;
             "dir"    – not implemented.
    lamb   : multiple of the identity added to J before solving.

    Returns
    -------
    step : solution of (J + lamb·I) step = −U
    J    : the (undamped) matrix used

    Raises
    ------
    NotImplementedError for option "dir"; ValueError for any other unknown option.

    NOTE(review): with option="fisher", J = −B, so a positive ``lamb`` is
    added to −B — confirm that this is the intended damping sign.
    """
    if option == 'dir':
        raise NotImplementedError

    if option == "jax":
        U = U_n(theta, V_inv, data)
        jac = jax.jacfwd(U_n, argnums=0)(theta, V_inv, data)
        J = jnp.concatenate([jac['delta'], jac['beta'], jac['gamma']], axis=1)
    elif option == "fisher":
        B, U = compute_B_U(theta, V_inv, data)
        J = -B
    else:
        raise ValueError(f"Unknown option {option}")

    ridge = lamb * jnp.eye(J.shape[0])
    step = jnp.linalg.solve(J + ridge, -U)
    return step, J

def update_theta(theta, step):
    """Return a new parameter dict equal to ``theta`` plus the flat ``step``.

    ``step`` is consumed in the iteration order of ``theta``: the first
    ``v.size`` entries go to the first key, and so on. That order must match
    the layout used when the step was computed (delta | beta | gamma
    elsewhere in this notebook).

    The input dict is left untouched. The previous version assigned into
    ``theta`` itself, so any other name bound to the same dict (for example
    ``theta_init``) was silently updated as well.
    """
    updated = {}
    start = 0
    for k, v in theta.items():
        stop = start + v.size
        # reshape keeps this correct for parameter blocks that are not 1-D
        updated[k] = v + step[start:stop].reshape(v.shape)
        start = stop
    return updated


In [None]:
def solve_longitudinal_ugee(data, theta_init, max_iter=100, tol=1e-6, lamb=0.0, option="fisher", verbose=True):
    """Solve U_n(θ) = 0 by repeated Newton-type steps and return a sandwich variance.

    Parameters
    ----------
    data       : pair-level dict from ``data_pairwise``
    theta_init : dict of starting values (copied, not modified)
    max_iter   : maximum number of steps
    tol        : stop once the Euclidean norm of the step falls below this
    lamb, option : forwarded to ``compute_delta``
    verbose    : print the step norm every 10 iterations and the final status

    Returns
    -------
    theta : dict of estimates
    J     : matrix used in the last step (None if ``max_iter`` is 0)
    Var   : 4 · B⁻¹ Σ B⁻ᵀ evaluated at the final theta. It is not divided by
            the sample size; callers do that themselves.
    """
    T = data['yi'].shape[1]
    V_inv = jnp.eye(1+2*T)
    theta = {k: v.copy() for k, v in theta_init.items()}

    J = None
    step_norm = None
    for i in range(max_iter):
        step, J = compute_delta(theta, V_inv, data, lamb, option)
        step_norm = jnp.linalg.norm(step)
        if i % 10 == 0 and verbose:
            jax.debug.print("Step {i} gradient norm: {x}", i=i, x=step_norm)
        theta = update_theta(theta, step)
        if step_norm < tol:
            if verbose:
                print(f"converged after {i} iterations")
            break
    else:
        # Reached only when the loop ran out of iterations without `break`,
        # so a run that converges on its very last step is no longer also
        # reported as "did not converge".
        if verbose:
            print(f"did not converge, norm step = {step_norm}")

    B, U, Sig = compute_B_U_Sig(theta, V_inv, data, compute_sig=True)
    B_inv = jnp.linalg.inv(B)
    # The update is θ ← θ + B⁻¹U, hence θ̂ − θ ≈ B⁻¹U and the variance is
    # B⁻¹ Σ B⁻ᵀ. B is not symmetric in general, so the order of the
    # transposes matters.
    Var = 4 * B_inv @ Sig @ B_inv.T
    return theta, J, Var


# Analysis

In [None]:
# NOTE(review): `error` is not referenced anywhere in the visible notebook; this looks like an accidental auto-import.
from logging import error
n = 100   # number of subjects
T = 2     # number of time points
p = 1     # covariates per time point; sets the parameter lengths in theta_init below
eta_true = jnp.array([0.0, 0.5, 0.2])      # length T+1  (propensity): [intercept, w_1, ..., w_T]
beta_true = jnp.array([0.0, 0.5, -0.6])     # length 3  (outcome): [intercept, treatment, w_it]
error_func = lambda key, shape: jax.random.normal(key, shape)   # standard-normal outcome noise
y, z, w, Wz, z_T, X = simulate_longitudinal_data(n, T, eta_true, beta_true, error_func)

# All i<j subject pairs, in the layout consumed by compute_h_f.
data = data_pairwise(y, z, w)
# Identity weight matrix; 1+2*T matches the width of h and f built in compute_h_f.
V_inv = jnp.eye(1+2*T)

In [None]:
# Starting values: delta fixed at 0.5, beta and gamma drawn at random from the shared `rng`.
theta_init = {
    "delta": jnp.array([0.5]),
    "beta": jax.random.normal(rng.next(), (1+T*p, )),   # multiplied against the Wz designs in compute_h_f
    "gamma": jax.random.normal(rng.next(), (1+2*p, )),  # multiplied against the Xg designs [1, w_it, w_jt]
}

In [None]:
# Manual single Newton step, followed by the sandwich variance at the updated theta.
T = data['yi'].shape[1]
V_inv = jnp.eye(1+2*T)
# Work on a copy: binding `theta` to `theta_init` itself lets an in-place
# update change the starting values used by the full solve further down.
theta = dict(theta_init)
step, J = compute_delta(theta, V_inv, data, lamb=0.0, option='fisher')
theta = update_theta(theta, step)
B, U, Sig = compute_B_U_Sig(theta, V_inv, data, compute_sig=True)
B_inv = jnp.linalg.inv(B)
# θ̂ − θ ≈ B⁻¹U, so the variance is B⁻¹ Σ B⁻ᵀ (B is not symmetric in general).
Var = 4 * B_inv @ Sig @ B_inv.T

In [None]:
# Diagonal of the one-step sandwich matrix computed in the previous cell.
jnp.diag(Var)

In [None]:
# Full iterative solve from theta_init; rebinds theta, J and Var from the one-step cell above.
theta, J, Var = solve_longitudinal_ugee(data, theta_init, option = "fisher", verbose=True)

# Simulation

In [None]:
def make_longitudinal_df(y, treatment, w):
    """
    Reshape wide (n, T) arrays into a long data frame with one row per
    subject and time point (rows ordered subject by subject).

    Parameters
    ----------
    y          : array (n, T)     outcome for n subjects at T time-points
    treatment  : array (n,)       per-subject treatment indicator, repeated T times
    w          : array (n, T)     single covariate at each time-point; must have
                                  the same shape as y

    Returns
    -------
    df         : pandas DataFrame with columns
                 ['y', 'treatment', 'confounder', 'time', 'id'] and n*T rows
    """
    n, T = y.shape
    assert y.shape == w.shape

    subject_ids = np.arange(n)
    time_points = np.arange(T)
    columns = {
        "y": y.reshape(-1),
        "treatment": np.repeat(treatment, T),
        "confounder": w.reshape(-1),
        "time": np.tile(time_points, n),
        "id": np.repeat(subject_ids, T),
    }
    return pd.DataFrame(columns)

def fit_gee(y, z, w):
    """Fit a Gaussian GEE of y on treatment and confounder, clustered by subject.

    The data are first reshaped to long format with ``make_longitudinal_df``.
    The mean model is ``y ~ 1 + treatment + confounder`` and the working
    correlation is exchangeable. Returns the fitted statsmodels results object.
    """
    long_df = make_longitudinal_df(y, z, w)
    model = smf.gee(
        "y ~ 1 + treatment + confounder",
        groups="id",                                  # one cluster per subject
        data=long_df,
        family=sm.families.Gaussian(),
        cov_struct=sm.cov_struct.Exchangeable(),
    )
    return model.fit()

In [None]:
def fit_ols(y, z, w):
    """OLS of the last-time-point outcome on [1, z, last-time-point w].

    Only the final column of ``y`` and ``w`` is used. Returns the fitted
    statsmodels results object, with coefficients in the order
    intercept, treatment, covariate.
    """
    n = y.shape[0]
    design = jnp.concatenate([jnp.ones((n,1)), z[:,None], w[:, -1:]], axis=1)
    response = y[:, -1]
    model = sm.OLS(np.asarray(response), np.asarray(design), hasconst=True)
    return model.fit()
# print(ols_res.summary())

In [None]:
class ErrorFunc:
    """Pairs a display name with a noise sampler ``f(key, shape)``."""

    # Class-level defaults; __init__ always sets both on the instance.
    name = ''
    f = None

    def __init__(self, name, f):
        self.name, self.f = name, f

In [None]:
def run_simulation(n, error_func: ErrorFunc, ate=0.0, n_sim=100, verbose=False, max_iter=100):
    """Monte-Carlo rejection rates of the UGEE, GEE and OLS tests.

    For each of ``n_sim`` replications a data set with ``n`` subjects and
    T = 2 time points is simulated with treatment coefficient ``ate``, and
    three tests are run:

    * UGEE : |delta_hat − 0.5| / sqrt(Var[0,0] / n) against a normal critical value
    * GEE  : p-value of the treatment coefficient from ``fit_gee``
    * OLS  : p-value of the treatment coefficient from ``fit_ols``

    Returns
    -------
    list of two dicts (significance levels 0.05 and 0.01) with keys
    'n', 'ate', 'errfunc', 'siglevel', 'ugee', 'gee', 'ols', where the last
    three are the fraction of replications in which the test rejected.
    """
    eta_true = jnp.array([ 0.0, -0.3, -0.6])
    beta_true = jnp.array([0.0, ate, 1.0])
    T = 2
    p = 1

    # The original line was `jax.clear_caches` without parentheses, which only
    # evaluates the attribute and does nothing. It is now called once per
    # run: shapes are fixed within a run, so clearing inside the loop would
    # force a recompilation on every replication.
    jax.clear_caches()

    z_vec = []
    pval_vec = []
    pval_ols_vec = []
    for _sim in range(n_sim):
        y, x, w, *_ = simulate_longitudinal_data(n, T, eta_true, beta_true, error_func.f)
        data = data_pairwise(y, x, w)
        theta_init = {
            "delta": jnp.array([0.5]),
            "beta": jax.random.normal(rng.next(), (1+T*p, )),
            "gamma": jax.random.normal(rng.next(), (1+2*p, )),
        }
        theta, J, Var = solve_longitudinal_ugee(data, theta_init, max_iter = max_iter, option = "fisher", verbose = verbose)

        # Var is on the sqrt(n) scale, hence the division by n.
        se = np.sqrt(Var[0][0]/n)
        z = abs(theta['delta']-0.5)/se
        z_vec.append(z)

        gee_res = fit_gee(y, x, w)
        pval_vec.append(gee_res.pvalues["treatment"])

        ols_res = fit_ols(y, x, w)
        # index 1 = treatment column of the OLS design [1, z, w]
        pval_ols_vec.append(ols_res.pvalues[1])

    z_array = jnp.concat(z_vec)
    gee_pvals = jnp.asarray(pval_vec)
    ols_pvals = jnp.asarray(pval_ols_vec)

    # (significance level, two-sided standard-normal critical value)
    summaries = []
    for siglevel, crit in ((0.05, 1.96), (0.01, 2.575829)):
        summaries.append({
            'n': n,
            'ate': ate,
            'errfunc': error_func.name,
            'siglevel': siglevel,
            'ugee': np.mean(jnp.abs(z_array) > crit),
            'gee': (gee_pvals < siglevel).mean(),
            'ols': (ols_pvals < siglevel).mean(),
        })
    return summaries

## Type I error

In [None]:
# Noise samplers for the simulations; each takes (key, shape) and returns an array of that shape.
normal_error_func = ErrorFunc('normal', lambda key, shape: jax.random.normal(key, shape))
lognormal_error_func = ErrorFunc('lognormal', lambda key, shape: jax.random.lognormal(key, shape=shape))
# Previously this sampled from jax.random.lognormal (copy-paste slip), so the
# "cauchy" scenarios were repeating the lognormal ones.
cauchy_error_func = ErrorFunc('cauchy', lambda key, shape: jax.random.cauchy(key, shape=shape))
laplace_error_func = ErrorFunc('laplace', lambda key, shape: jax.random.laplace(key, shape=shape))

In [None]:
p = 1          # covariates per time point (run_simulation also sets its own local p)
n_sim = 1000   # Monte-Carlo replications passed to every run_simulation call below

In [None]:
# Null scenario (ate=0.0): rejection rates with normal errors, n=50.
run_simulation(n=50, error_func=normal_error_func, ate=0.0, n_sim=n_sim, verbose=False)

In [None]:
# Null scenario (ate=0.0): rejection rates with normal errors, n=200.
run_simulation(n=200, error_func=normal_error_func, ate=0.0, n_sim=n_sim, verbose=False)

In [None]:
# Null scenario (ate=0.0): rejection rates with lognormal errors, n=50.
run_simulation(n=50, error_func=lognormal_error_func, ate=0.0, n_sim=n_sim, verbose=False)

In [None]:
# Null scenario (ate=0.0): rejection rates with lognormal errors, n=200.
run_simulation(n=200, error_func=lognormal_error_func, ate=0.0, n_sim=n_sim, verbose=False)

In [None]:
# Null scenario (ate=0.0): rejection rates with cauchy_error_func, n=50.
run_simulation(n=50, error_func=cauchy_error_func, ate=0.0, n_sim=n_sim, verbose=False)

In [None]:
# Null scenario (ate=0.0): rejection rates with cauchy_error_func, n=200.
run_simulation(n=200, error_func=cauchy_error_func, ate=0.0, n_sim=n_sim, verbose=False)

## Power

In [None]:
# Alternative scenario (ate=0.5): rejection rates with normal errors, n=50.
run_simulation(n=50, error_func=normal_error_func, ate=0.5, n_sim=n_sim, verbose=False)

In [None]:
# Alternative scenario (ate=0.5): rejection rates with normal errors, n=200.
run_simulation(n=200, error_func=normal_error_func, ate=0.5, n_sim=n_sim, verbose=False)

In [None]:
# Alternative scenario (ate=0.5): rejection rates with lognormal errors, n=50.
run_simulation(n=50, error_func=lognormal_error_func, ate=0.5, n_sim=n_sim, verbose=False)

In [None]:
# Alternative scenario (ate=0.5): rejection rates with lognormal errors, n=200.
run_simulation(n=200, error_func=lognormal_error_func, ate=0.5, n_sim=n_sim, verbose=False)

In [None]:
# Alternative scenario (ate=0.5): rejection rates with cauchy_error_func, n=50.
run_simulation(n=50, error_func=cauchy_error_func, ate=0.5, n_sim=n_sim, verbose=False)

In [None]:
# Alternative scenario (ate=0.5): rejection rates with cauchy_error_func, n=200.
run_simulation(n=200, error_func=cauchy_error_func, ate=0.5, n_sim=n_sim, verbose=False)