# Summary:
This is an implementation of the longitudinal doubly robust (DR) U-statistic estimating equation $U$ with a compound-symmetric working correlation matrix $R(\alpha)$.


# UGEE code

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from itertools import combinations
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd

In [None]:
class KeyGen:
    """Stateful holder of a JAX PRNG key that hands out a fresh subkey per call."""

    def __init__(self, seed=0):
        # Root key derived from the integer seed.
        self.key = jax.random.PRNGKey(seed)

    def next(self):
        """Split the stored key, keep one half, and return the other half."""
        carried, fresh = jax.random.split(self.key)
        self.key = carried
        return fresh

rng = KeyGen(42)

In [None]:
# ------------------------------
# helper matrices for all pairs
# ------------------------------
# design for g(w_it,w_jt;γt) and g(w_jt,w_it;γt)
def make_Xg(a, b):
    """Stack an intercept column, ``a`` and ``b`` along axis 2: [1, a, b].

    With ``a`` and ``b`` of shape (m, T, 1) the result is the (m, T, 3)
    design matrix for the pairwise outcome model.
    """
    intercept = jnp.ones_like(a)
    return jnp.concatenate((intercept, a, b), axis=2)

def data_pairwise(y, z, w):
    """Expand per-subject arrays into all unordered subject pairs (i < j).

    Args:
        y: outcomes with subjects on axis 0 -- presumably (n, T); confirm
            against caller.
        z: treatment indicator per subject -- presumably (n,) with 0/1
            values; confirm against caller.
        w: covariate per subject and time point. Must be 2-D, i.e. (n, T):
            ``w[:, 0][:, None]`` below is concatenated with an (n, 1) ones
            column, which only works for 2-D ``w``.

    Returns:
        dict with, for the m = n*(n-1)/2 pairs:
          'Xg_ij', 'Xg_ji': (m, T, 3) outcome-model designs [1, w_i, w_j]
                            and [1, w_j, w_i];
          'Wz_i', 'Wz_j':   (m, 2) propensity designs [1, w_{., first time}];
          'yi', 'yj', 'zi', 'zj', 'wi', 'wj': per-pair copies of the inputs;
          'i', 'j':         subject indices of each pair (i < j);
          'w':              the original per-subject covariate matrix.
    """
    n = y.shape[0]
    tri_u, tri_v = jnp.triu_indices(n, k=1)                   # i < j indices

    wi, wj = w[tri_u], w[tri_v]                                # (m, T)
    zi, zj = z[tri_u], z[tri_v]
    yi, yj = y[tri_u], y[tri_v]

    # The propensity model uses only the covariate at the first time slot.
    w_init = w[:, 0][:, None]                                  # (n, 1)
    Wz = jnp.concatenate([jnp.ones((n, 1)), w_init], axis=1)   # (n, 2)
    Wz_i, Wz_j = Wz[tri_u], Wz[tri_v]                          # (m, 2)

    # Designs for g(w_i, w_j; gamma) and g(w_j, w_i; gamma) at every time point.
    Xg_ij = make_Xg(wi[..., None], wj[..., None])              # (m, T, 3)
    Xg_ji = make_Xg(wj[..., None], wi[..., None])              # (m, T, 3)

    return {
        'Xg_ij': Xg_ij,
        'Xg_ji': Xg_ji,
        'Wz_i': Wz_i,
        'Wz_j': Wz_j,
        'yi': yi,
        'yj': yj,
        'zi': zi,
        'zj': zj,
        'wi': wi,
        'wj': wj,
        'i': tri_u,
        'j': tri_v,
        'w': w
    }

In [None]:
# inital values
from sklearn.linear_model import LogisticRegression

def get_init(data, z=None):
    """Starting values for theta from two ordinary logistic regressions.

    * beta: logistic regression (``fit_intercept=False``; the design already
      has a ones column) of the per-subject treatment ``z`` on
      [1, w at the first time point].
    * gamma: logistic regression of the pairwise ordering indicator
      z_i(1-z_j)I(y_i >= y_j) + z_j(1-z_i)I(y_j > y_i) on the matching
      pairwise design, pooled over all pairs and time points (pairs with
      equal treatment contribute all-zero rows, as before).
    * delta: fixed at 0.5.

    Args:
        data: pairwise dict from ``data_pairwise`` (uses 'yi', 'yj', 'zi',
            'zj', 'Xg_ij', 'Xg_ji', 'w', 'i', 'j').
        z: optional per-subject treatment vector of length n. When omitted
            it is rebuilt from the pairwise copies in ``data``. Previously
            this function read an undeclared global ``z``, which raised
            NameError when no such global existed and could silently
            belong to a different dataset.

    Returns:
        dict with "delta", "beta", "gamma" as jnp arrays.
    """
    I_ij = (data['yi'] >= data['yj']).astype(jnp.float32)[:,:,None]
    I_ji = 1. - I_ij
    zi = data['zi'][:,None,None]
    zj = data['zj'][:,None,None]
    h3 = zi*(1-zj)*I_ij + zj*(1-zi)*I_ji
    Xg = zi*(1-zj)* data['Xg_ij'] + zj*(1-zi)* data['Xg_ji']
    w = data['w']
    w_init = w[:,0][:, None]
    n = w_init.shape[0]

    if z is None:
        # data['zi'] == z[data['i']] and data['zj'] == z[data['j']], so
        # scattering both copies back recovers every subject's treatment.
        z = (jnp.zeros(n, dtype=data['zi'].dtype)
             .at[data['i']].set(data['zi'])
             .at[data['j']].set(data['zj']))

    # Pool pairs and time points into one long regression problem.
    h3_flat = h3.reshape(-1, h3.shape[-1])
    Xg_flat = Xg.reshape(-1, Xg.shape[-1])

    z_logistic = LogisticRegression(random_state=0, fit_intercept=False).fit(
        jnp.concatenate([jnp.ones((n, 1)), w_init], axis=1), z)
    g_logistic = LogisticRegression(random_state=0, fit_intercept=False).fit(
        Xg_flat, h3_flat.reshape(-1))
    return {
        "delta": jnp.array([0.5]),
        "beta": jnp.array(z_logistic.coef_[0]),
        "gamma": jnp.array(g_logistic.coef_[0]),
    }

In [None]:
def safe_sigmoid(x):
    """Logistic sigmoid evaluated on ``x`` clipped to [-10, 10].

    Outputs therefore stay within [sigmoid(-10), sigmoid(10)]. The clip also
    zeroes the derivative for |x| > 10.
    """
    bounded = jnp.clip(x, -10.0, 10.0)
    return jax.nn.sigmoid(bounded)

# @jax.jit
def compute_h_f(theta, V_inv, data):
    """Return the stacked pairwise responses ``h`` and their model means ``f``.

    Column groups, per subject pair (i, j):
      * h1 / f1 (T columns): augmented inverse-probability-weighted kernel
        for ``delta`` versus the constant ``delta``;
      * h2 / f2 (1 column): 0.5*(z_i + z_j) versus 0.5*(pi_i + pi_j), with
        pi = sigmoid(Wz @ beta) the propensity model;
      * h3 / f3 (T columns): 0.5*(z_i(1-z_j)I_ij + z_j(1-z_i)I_ji) versus its
        value implied by the propensity and outcome (g) models.

    Args:
        theta: dict with "delta", "beta" (applied to Wz_i / Wz_j) and
            "gamma" (applied to Xg_ij / Xg_ji).
        V_inv: not used here; accepted so this function shares the
            (theta, V_inv, data) signature of the other helpers.
        data: pairwise dict as produced by ``data_pairwise``.

    Returns:
        (h, f), both of shape (m, 2*T + 1) for m pairs and T time points.
    """
    Wz_i, Wz_j = data['Wz_i'], data['Wz_j']
    Xg_ij, Xg_ji = data['Xg_ij'], data['Xg_ji']
    yi, yj = data['yi'], data['yj']
    zi, zj = data['zi'], data['zj']

    delta, beta, gamma = theta["delta"], theta["beta"], theta["gamma"]

    # Model predictions: propensities per pair member, outcome model per time.
    pi_i  = safe_sigmoid(jnp.sum(Wz_i * beta, axis=1))
    pi_j  = safe_sigmoid(jnp.sum(Wz_j * beta, axis=1))
    g_ij = safe_sigmoid(jnp.einsum('ntp,p->nt', Xg_ij, gamma)) # equivalent jnp.sum(Xg_ij * gamma, axis=2)
    g_ji = safe_sigmoid(jnp.einsum('ntp,p->nt', Xg_ji, gamma))

    # Ordering indicator with ties counted as 1/2, so I_ij + I_ji == 1.
    I_ij = (yi > yj).astype(jnp.float32) + 0.5*(yi == yj).astype(jnp.float32)
    I_ji = 1. - I_ij

    # Column vectors so per-pair quantities broadcast across the T columns.
    zi = jnp.reshape(zi, (-1, 1))
    zj = jnp.reshape(zj, (-1, 1))
    pi_i = jnp.reshape(pi_i, (-1, 1))
    pi_j = jnp.reshape(pi_j, (-1, 1))

    # h vector (three components) for all pairs.
    num1 = zi*(1-zj)/(2*pi_i*(1-pi_j)) * (I_ij - g_ij)
    num2 = zj*(1-zi)/(2*pi_j*(1-pi_i)) * (I_ji - g_ji)
    h1   = num1 + num2 + 0.5*(g_ij + g_ji)
    h2   = 0.5*(zi + zj)
    h3   = 0.5*(zi*(1-zj)*I_ij + zj*(1-zi)*I_ji)
    h    = jnp.concatenate([h1, h2, h3], axis=1)                # (m, 1+2*T)

    # f vector: model means matching h column-for-column.
    f1   = jnp.full_like(h1, delta)
    f2   = 0.5*(pi_i + pi_j)
    f3   = 0.5*(pi_i*(1-pi_j)*g_ij + pi_j*(1-pi_i)*g_ji)
    f    = jnp.concatenate([f1,f2,f3], axis=1)
    return h, f

# @jax.jit
def compute_f(theta, V_inv, data):
    """Return only the model-mean vector ``f`` produced by ``compute_h_f``.

    Kept separate so ``jax.jacfwd`` can differentiate ``f`` alone with
    respect to ``theta``.
    """
    return compute_h_f(theta, V_inv, data)[1]

def compute_u_ij(theta, V_inv, data):
    """Per-pair estimating functions u_ij = D_ij^T V_inv (h_ij - f_ij).

    D_ij is the Jacobian of f with respect to (delta, beta, gamma), with the
    parameter blocks concatenated in that fixed order.

    Returns:
        Array of shape (m, d): one d-dimensional estimating function per pair.
    """
    jac = jax.jacfwd(compute_f, argnums=0)(theta, V_inv, data)
    D_pairs = jnp.concatenate([jac[name] for name in ('delta', 'beta', 'gamma')], axis=2)

    h_vec, f_vec = compute_h_f(theta, V_inv, data)
    weights = jnp.swapaxes(D_pairs, 1, 2) @ V_inv
    return jnp.einsum('npc,nc->np', weights, h_vec - f_vec)

@jax.jit
def U_n(theta, V_inv, data):
    """Jit-compiled estimating equation: the mean of u_ij over all pairs."""
    per_pair = compute_u_ij(theta, V_inv, data)
    return per_pair.mean(axis=0)

In [None]:
def compute_h_f_fisher(theta, V_inv, data):
    """Fisher-scoring alias of ``compute_h_f``; returns ``(h, f)`` unchanged."""
    pair = compute_h_f(theta, V_inv, data)
    return pair

def compute_h_fisher(theta, V_inv, data):
    """Return only ``h``, so it can be differentiated on its own."""
    h_vec = compute_h_f_fisher(theta, V_inv, data)[0]
    return h_vec

def compute_f_fisher(theta, V_inv, data):
    """Return only ``f``, so it can be differentiated on its own."""
    f_vec = compute_h_f_fisher(theta, V_inv, data)[1]
    return f_vec

@jax.jit
def compute_B_u_ij(theta, V_inv, data):
    """Return the Fisher-type matrix B and the per-pair estimating functions.

    With D_ij = df/dtheta, H_ij = dh/dtheta and G_ij = D_ij^T V_inv:
        B    = mean_ij G_ij (D_ij - H_ij)
        u_ij = G_ij (h_ij - f_ij)
    Parameter blocks are always ordered delta, beta, gamma.

    Returns:
        (B, u_ij) with B of shape (d, d) and u_ij of shape (m, d).
    """
    order = ('delta', 'beta', 'gamma')

    def _stack(jac):
        # Concatenate the per-parameter Jacobian blocks along the last axis.
        return jnp.concatenate([jac[name] for name in order], axis=2)

    D_ij = _stack(jax.jacfwd(compute_f_fisher, argnums=0)(theta, V_inv, data))
    H_ij = _stack(jax.jacfwd(compute_h_fisher, argnums=0)(theta, V_inv, data))

    G_ij = jnp.swapaxes(D_ij, 1, 2) @ V_inv
    B = jnp.einsum('npq,nqc->npc', G_ij, D_ij - H_ij).mean(axis=0)

    h, f = compute_h_f_fisher(theta, V_inv, data)
    u_ij = jnp.einsum('npc,nc->np', G_ij, h - f)
    return B, u_ij

def compute_B_U(theta, V_inv, data):
    """Return ``(B, U)``, where U is the pair-averaged estimating function."""
    B_mat, per_pair = compute_B_u_ij(theta, V_inv, data)
    return B_mat, per_pair.mean(axis=0)

def compute_B_U_Sig(theta, V_inv, data, compute_sig=False):
    """Return (B, U, Sig): the ingredients of the U-statistic sandwich variance.

    B and U are as returned by ``compute_B_u_ij`` / ``compute_B_U``. Sig
    estimates the covariance of the Hoeffding projection of the pairwise
    estimating function. For each subject k,

        u_k = 1/(n-1) * sum_{l != k} u_kl,      Sig = mean_k u_k u_k^T.

    Only pairs with i < j are stored. The per-pair kernel is symmetric in
    (i, j), so each u_ij is credited to both subject i and subject j.

    Args:
        theta, V_inv, data: as for ``compute_B_u_ij``.
        compute_sig: ignored, because Sig is always computed. It is kept
            for backward compatibility with existing callers.

    Returns:
        B (d, d), U (d,), Sig (d, d), where d is the number of parameters.
    """
    B, u_ij = compute_B_u_ij(theta, V_inv, data)
    U = jnp.mean(u_ij, axis=0)

    # Concrete Python int so it can be used as an array shape.
    n = int(jnp.maximum(jnp.max(data['i']), jnp.max(data['j']))) + 1
    d = u_ij.shape[1]

    # With all i<j pairs (as built by data_pairwise) every subject appears in
    # exactly n-1 pairs, so the projection averages over n-1 (not n).
    u_sum = jnp.zeros((n, d)).at[data['i']].add(u_ij).at[data['j']].add(u_ij)
    u_i = u_sum / max(n - 1, 1)

    Sig = jnp.einsum('np,nq->pq', u_i, u_i) / n
    return B, U, Sig

In [None]:
def compute_delta(theta, V_inv, data, lamb=0.0, option="fisher"):
    """Compute one (optionally ridge-damped) Newton-type step for theta.

    Solves (J + lamb * I) step = -U, where
      * option "jax":    J = dU/dtheta via forward-mode autodiff of U_n;
      * option "fisher": J = -B (Fisher-scoring approximation);
      * option "dir":    not implemented.

    Returns:
        (step, J): the flat parameter step (ordered delta, beta, gamma) and J.

    Raises:
        NotImplementedError: for option "dir".
        ValueError: for any other unknown option.
    """
    if option == 'dir':
        raise NotImplementedError
    if option == "jax":
        U = U_n(theta, V_inv, data)
        jac = jax.jacfwd(U_n, argnums=0)(theta, V_inv, data)
        J = jnp.concatenate([jac[name] for name in ('delta', 'beta', 'gamma')], axis=1)
    elif option == "fisher":
        B, U = compute_B_U(theta, V_inv, data)
        J = -B
    else:
        raise ValueError(f"Unknown option {option}")
    lhs = J + lamb * jnp.eye(J.shape[0])
    return jnp.linalg.solve(lhs, -U), J

def update_theta(theta, step):
    """Add a flat parameter step to the parameter dict (the dict is updated in place).

    ``step`` is laid out in the fixed order delta, beta, gamma. That is the
    order in which the Jacobian blocks are concatenated when the step is
    computed. The dict is therefore walked in that order regardless of its
    insertion order; any additional keys follow in insertion order.
    Previously insertion order was used, which silently misassigned slices
    for e.g. {"beta": ..., "delta": ..., "gamma": ...}.

    Args:
        theta: dict of 1-D parameter arrays.
        step: flat array whose length equals the total parameter count.

    Returns:
        The same dict object, with updated entries.

    Raises:
        ValueError: if ``step`` length differs from the total parameter count.
    """
    canonical = ("delta", "beta", "gamma")
    keys = [k for k in canonical if k in theta] + [k for k in theta if k not in canonical]

    total = sum(theta[k].size for k in keys)
    if total != step.shape[0]:
        raise ValueError(f"step has length {step.shape[0]}, expected {total}")

    start = 0
    for k in keys:
        size = theta[k].size
        # Rebind rather than "+=" so caller-owned (e.g. NumPy) arrays are not
        # modified in place.
        theta[k] = theta[k] + step[start:start + size]
        start += size
    return theta


In [None]:
def average_pairwise_corr(X_sub):
    """Mean Pearson correlation over all distinct column pairs of ``X_sub``.

    Args:
        X_sub: 2-D array with observations in rows and variables in columns.

    Returns:
        Average of the strictly upper-triangular entries of the column
        correlation matrix. Returns 0.0 when there are fewer than two
        columns. In that case there are no pairs, and ``np.corrcoef`` returns
        a 0-d value that ``np.triu`` cannot handle.
    """
    T = X_sub.shape[1]
    if T < 2:
        return 0.0
    corr_matrix = np.corrcoef(X_sub, rowvar=False)  # shape (T, T)
    rows, cols = np.triu_indices(T, k=1)
    return corr_matrix[rows, cols].mean()

def compound_symmetric_corr(T, rho):
    """T x T exchangeable correlation matrix: 1 on the diagonal, ``rho`` elsewhere."""
    return rho * np.ones((T, T)) + (1 - rho) * np.eye(T)

def get_V_inv(A_vec, rho_h1, rho_h3):
    """Build the working covariance V = D R D and its inverse.

    D = diag(sqrt(A_vec)). R is block-diagonal with three blocks:
      * a T x T compound-symmetric block with correlation ``rho_h1``;
      * a 1 x 1 identity block;
      * a T x T compound-symmetric block with correlation ``rho_h3``.
    Here T = (len(A_vec) - 1) // 2.

    Returns:
        (V_inv, V) as NumPy arrays, with V_inv = D^{-1} R^{-1} D^{-1}.
    """
    T = (len(A_vec) - 1) // 2
    size = 2 * T + 1

    R = np.zeros((size, size))
    R[:T, :T] = compound_symmetric_corr(T, rho_h1)
    R[T, T] = 1.0
    R[T + 1:, T + 1:] = compound_symmetric_corr(T, rho_h3)

    sd = np.sqrt(A_vec)
    inv_sd = np.diag(1 / sd)
    V_inv = inv_sd @ np.linalg.inv(R) @ inv_sd
    V = np.diag(sd) @ R @ np.diag(sd)
    return V_inv, V

def calculate_V_inv(theta, V_inv, data):
    """Moment-estimate the working covariance from the current residuals h - f.

    Variances are column-wise sample variances (ddof=1) of the residuals.
    The two compound-symmetry correlations are the average pairwise column
    correlations of the first T and of the last T residual columns, where
    T = (number of columns - 1) // 2.

    Returns:
        (V_inv, V) as jnp arrays.
    """
    h, f = compute_h_f_fisher(theta, V_inv, data)
    resid = np.array(h - f)
    T = (resid.shape[1] - 1) // 2

    variances = np.var(resid, axis=0, ddof=1)
    rho_first = average_pairwise_corr(resid[:, :T])
    rho_last = average_pairwise_corr(resid[:, -T:])

    inv_np, cov_np = get_V_inv(variances, rho_first, rho_last)
    return jnp.array(inv_np), jnp.array(cov_np)

In [None]:
def solve_longitudinal_ugee(data, theta_init, max_iter=100, tol=1e-6, lamb=0.0, option="fisher", verbose=True):
    """Solve the UGEE with an independence working covariance (V = I).

    Iterates ``compute_delta`` / ``update_theta`` until the step norm drops
    below ``tol`` or ``max_iter`` steps have been taken.

    Args:
        data: pairwise dict from ``data_pairwise``.
        theta_init: starting parameter dict (copied, not modified).
        max_iter: maximum number of steps (must be >= 1).
        tol: convergence threshold on the step norm.
        lamb, option: forwarded to ``compute_delta``.
        verbose: print progress and convergence messages.

    Returns:
        theta: fitted parameter dict.
        J: Jacobian (or -B for Fisher scoring) from the last iteration.
        Var: 4 * B^{-1} Sig B^{-T}. NOTE(review): with Sig built from
            per-subject projections this is presumably the asymptotic
            covariance of sqrt(n) * (theta_hat - theta); confirm callers
            divide by n.

    Raises:
        ValueError: if ``max_iter`` < 1 (previously a NameError).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    T = data['yi'].shape[1]
    V_inv = jnp.eye(1 + 2 * T)
    theta = {k: v.copy() for k, v in theta_init.items()}

    converged = False
    for i in range(max_iter):
        step, J = compute_delta(theta, V_inv, data, lamb, option)
        step_norm = jnp.linalg.norm(step)
        if i % 10 == 0 and verbose:
            jax.debug.print("Step {i} gradient norm: {x}", i=i, x=step_norm)
        theta = update_theta(theta, step)
        if step_norm < tol:
            converged = True
            if verbose:
                print(f"converged after {i} iterations")
            break
    # A flag (not "i == max_iter-1") so convergence on the final iteration
    # is not misreported.
    if not converged and verbose:
        print(f"did not converge, norm step = {step_norm}")

    B, U, Sig = compute_B_U_Sig(theta, V_inv, data, compute_sig=True)
    B_inv = jnp.linalg.inv(B)
    # Linearising U(theta_hat) = 0 with dU/dtheta ~ -B gives
    # theta_hat - theta ~ B^{-1} U(theta). B is generally not symmetric, so
    # the sandwich must be B^{-1} Sig B^{-T}.
    Var = 4 * B_inv @ Sig @ B_inv.T
    return theta, J, Var

def solve_longitudinal_ugee_inner(data, theta_init, V_inv, max_iter=10, tol=1e-6, lamb=0.0, option="fisher", verbose=True):
    """Run Newton/Fisher iterations for a fixed working covariance ``V_inv``.

    One unconditional step is taken first. It is followed by up to
    ``max_iter`` further steps, which stop as soon as the step norm drops
    below ``tol``.

    Args:
        data: pairwise dict from ``data_pairwise``.
        theta_init: starting parameter dict (copied, not modified).
        V_inv: working inverse covariance, held fixed.
        max_iter, tol, lamb, option, verbose: as in ``compute_delta`` and
            the other solvers.

    Returns:
        (theta, converged): the final parameters and whether the tolerance
        was reached.
    """
    theta = {k: v.copy() for k, v in theta_init.items()}

    # Initial unconditional step.
    step, _ = compute_delta(theta, V_inv, data, lamb, option)
    theta = update_theta(theta, step)

    for i in range(max_iter):
        step, _ = compute_delta(theta, V_inv, data, lamb, option)
        step_norm = jnp.linalg.norm(step)
        if i % 10 == 0 and verbose:
            jax.debug.print("Step {i} gradient norm: {x}", i=i, x=step_norm)
        theta = update_theta(theta, step)
        if step_norm < tol:
            if verbose:
                print(f"converged after {i} iterations")
            return theta, True
        if i == max_iter-1 and verbose:
            print(f"did not converge, norm step = {step_norm}")
    return theta, False

def solve_longitudinal_ugee_with_cov(data, theta_init, max_iter=100, tol=1e-6, lamb=0.0, option="fisher", verbose=True):
    """Two-stage UGEE fit with an estimated compound-symmetric working covariance.

    Stage 1 fits with V = I via ``solve_longitudinal_ugee_inner``. Stage 2
    estimates V from the stage-1 residuals (``calculate_V_inv``) and then
    iterates to convergence with that V held fixed.

    Args:
        data: pairwise dict from ``data_pairwise``.
        theta_init: starting parameter dict (not modified).
        max_iter: maximum number of steps per stage (must be >= 1).
        tol: convergence threshold on the step norm; used in both stages.
        lamb, option: forwarded to ``compute_delta`` in both stages.
        verbose: print stage-2 progress and convergence messages.

    Returns:
        theta: fitted parameter dict.
        J: Jacobian (or -B for Fisher scoring) from the last stage-2 step.
        Var: 4 * B^{-1} Sig B^{-T} under the estimated V. NOTE(review):
            presumably the asymptotic covariance of
            sqrt(n) * (theta_hat - theta); confirm callers divide by n.

    Raises:
        ValueError: if ``max_iter`` < 1 (previously a NameError).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    T = data['yi'].shape[1]
    V_inv = jnp.eye(1 + 2 * T)

    # Stage 1: forward tol and lamb so the caller's settings apply here too
    # (they were previously dropped in favour of the inner defaults).
    theta, _ = solve_longitudinal_ugee_inner(
        data, theta_init, V_inv, max_iter=max_iter, tol=tol, lamb=lamb,
        option=option, verbose=False)

    # Stage 2: re-estimate the working covariance, then iterate with it fixed.
    V_inv, _ = calculate_V_inv(theta, V_inv, data)
    converged = False
    for i in range(max_iter):
        step, J = compute_delta(theta, V_inv, data, lamb, option)
        step_norm = jnp.linalg.norm(step)
        if i % 10 == 0 and verbose:
            jax.debug.print("Step {i} gradient norm: {x}", i=i, x=step_norm)
        theta = update_theta(theta, step)
        if step_norm < tol:
            converged = True
            if verbose:
                print(f"converged after {i} iterations")
            break
    if not converged and verbose:
        print(f"did not converge, norm step = {step_norm}")

    B, U, Sig = compute_B_U_Sig(theta, V_inv, data, compute_sig=True)
    B_inv = jnp.linalg.inv(B)
    # theta_hat - theta ~ B^{-1} U(theta), and B is generally asymmetric, so
    # the sandwich is B^{-1} Sig B^{-T}.
    Var = 4 * B_inv @ Sig @ B_inv.T
    return theta, J, Var

In [None]:
def make_longitudinal_matrix(df, id_col, value_col, time_col):
    """Reshape long-format data to one row per subject and one column per time.

    Missing (subject, time) cells are filled with that subject's mean over
    its observed time points. A subject with no observed values stays NaN.

    Returns:
        DataFrame whose first column is ``id_col``, followed by one column per
        distinct ``time_col`` value, sorted by ``id_col``.
    """
    ordered = df.sort_values(by=[id_col, time_col])
    wide = ordered.pivot(index=[id_col], columns=time_col, values=value_col)

    # Row-wise mean imputation: align each subject's mean along the index.
    row_means = wide.mean(axis=1)
    wide = wide.where(wide.notna(), row_means, axis=0)

    return wide.reset_index().sort_values(by=id_col)