# About

Here I draft the core function for logistic regression using `pytorch`. It solves the logistic regression problem for multiple SNPs simultaneously (a typical use case in GWAS).

# Implementation

I first implement a naive convergence checker. 
I will modify the code later so that it checks convergence and updates in a per-SNP manner.

In [1]:
import torch 
import numpy as np
# import scipy.stats as stats

In [2]:
def batchIRLS(X, y, C, tol=1e-8, maxiter=100):
    '''
    Fit one logistic regression per column of X with batched IRLS
    (Newton-Raphson), i.e. logit(y) ~ C + X[:, m] for every SNP m.
    All p models are updated together using batched tensor operations.

    Input
        X: Tensor(n, p). One column per SNP.
        y: Tensor(n,) or Tensor(n, 1). Response shared by all p models.
        C: Tensor(n, k). Covariates shared by all p models. No intercept is
           added here, so include a column of ones in C if one is wanted.
        tol: iterate while the pooled relative squared change of the
             coefficients (see naive_convergence_checker) is above tol.
        maxiter: maximum number of iterations.
    Output:
        B_hat: Tensor(k + 1, p). Column m holds the estimates of model m: the
               k covariate effects (in the order of C) followed by the effect
               of X[:, m].
        B_se: Tensor(k + 1, p). The corresponding SE's of the estimates.
    A warning is printed if the tolerance is not reached within maxiter.
    '''
    # get dimensions
    p = X.shape[1]
    k = C.shape[1]

    # `torch.solve(B, A)` was removed from recent PyTorch releases in favour
    # of `torch.linalg.solve(A, B)` (note the reversed argument order).
    # Support both so the function runs on old and new installations.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        solve = torch.linalg.solve
    else:
        def solve(A, B):
            return torch.solve(B, A)[0]

    def fitted_and_information(Wcx):
        # Wcx: Tensor(p, k + 1), row m = coefficients of model m.
        # Returns Mu: Tensor(n, p) and X^T S X: Tensor(k + 1, k + 1, p).
        Wc = Wcx[:, :k]
        Wx = Wcx[:, k]
        # compute w^T x (X * Wx scales column m of X by Wx[m])
        CX_Wcx = torch.einsum('nk,pk->np', C, Wc) + X * Wx
        # compute mu and S
        Mu = logistic_func(CX_Wcx)
        S = Mu * (1 - Mu)
        # compute X^T S X block by block: the design of model m is
        # [C, X[:, m]], so only the last row/column differs between models
        C_S_X = torch.einsum('nk,np,np->kp', C, S, X)
        XSX = torch.empty(k + 1, k + 1, p, dtype=X.dtype, device=X.device)
        XSX[:k, :k, :] = torch.einsum('nk,np,nj->kjp', C, S, C)
        XSX[:k, k, :] = C_S_X
        XSX[k, :k, :] = C_S_X
        XSX[k, k, :] = torch.einsum('np,np,np->p', X, S, X)
        return Mu, XSX

    # initialize all coefficients at zero, matching the dtype/device of X
    Wcx = torch.zeros(p, k + 1, dtype=X.dtype, device=X.device)

    diff = tol + 1
    niter = 0

    while diff > tol and niter < maxiter:

        # Wcx is re-bound (not modified in place) below, so no copy is needed
        Wcx_old = Wcx

        Mu, XSX = fitted_and_information(Wcx)

        # score: [C, X]^T (y - mu). Use the argument `y`, not a global.
        RES = mat_vec_add(-Mu, y)
        CX_RES = torch.cat((
            torch.einsum('nk,np->kp', C, RES),
            torch.einsum('np,np->p', X, RES).unsqueeze(0)
        ), dim=0)  # Tensor(k + 1, p)

        # right-hand side of the Newton system (X^T S X) w_new = RHS
        RHS = torch.einsum('jkp,pk->jp', XSX, Wcx) + CX_RES

        # solve the p systems of size (k + 1) in one batched call and update
        Wcx = solve(
            XSX.permute(2, 0, 1),     # Tensor(p, k + 1, k + 1)
            RHS.t().unsqueeze(-1)     # Tensor(p, k + 1, 1)
        )[:, :, 0]

        diff = naive_convergence_checker(Wcx, Wcx_old)
        niter += 1

    # compute SE
    ## X^T S X has to be re-evaluated at the final estimates first
    _, XSX = fitted_and_information(Wcx)
    ## VAR = (X^T S X)^{-1}, obtained by solving against the identity
    EYE = torch.eye(k + 1, dtype=X.dtype, device=X.device).expand(p, k + 1, k + 1)
    VAR = solve(XSX.permute(2, 0, 1), EYE)  # Tensor(p, k + 1, k + 1)

    # return B_hat, B_se
    # `not (diff <= tol)` also triggers when diff is NaN (diverged fit)
    if not (diff <= tol):
        print(f'Warning: not converged! diff = {diff} > tol = {tol}')
    return Wcx.T, torch.sqrt(torch.diagonal(VAR, dim1=1, dim2=2)).T

def naive_convergence_checker(w_new, w_old):
    '''
    Relative squared change between two coefficient tensors, pooled over all
    entries: sum((w_new - w_old)^2) / sum(w_new^2). Returns a 0-dim tensor.
    '''
    squared_change = (w_new - w_old) ** 2
    squared_size = w_new ** 2
    return squared_change.sum() / squared_size.sum()

def logistic_func(u):
    '''Element-wise logistic function 1 / (1 + exp(-u)) of the tensor u.'''
    denominator = torch.exp(-u) + 1
    return denominator.reciprocal()

def mat_vec_add(mat, vec):
    '''
    Add vec[i] to every entry in row i of mat.
    vec is viewed as a column of length vec.shape[0] and expanded to mat's shape.
    '''
    column = vec.view(vec.shape[0], 1)
    return torch.add(mat, column.expand_as(mat))

# Simulate data

Follow the same procedure as in `../analysis/logistic_solver_single.Rmd`, with the same parameter settings. 
But here we simulate multiple SNPs.

In [3]:
# simulation settings
nsnp = 10      # number of SNPs (columns of G)
N = 1000       # number of individuals
K = 13         # number of covariates (intercept not counted)
maf = 0.1      # success probability used when drawing genotypes
sigma_c2 = 1   # variance of the covariates
sigma2 = 4     # variance of the effect sizes
intercept = -3

# simulate data begins here
# index of the SNP (column of G) carrying the true signal
idx_true = 1
# effect sizes: beta[0] is the effect of the causal SNP, beta[1:] are the covariate effects
beta = torch.empty(K + 1).normal_(mean=0, std=sigma2 ** 0.5)
# genotypes: independent Bernoulli(maf) draws
G = torch.empty(N, nsnp).bernoulli_(maf)
# covariates: independent N(0, sigma_c2) draws
covar = torch.empty(N, K).normal_(mean=0, std=sigma_c2 ** 0.5)
# design matrix of the true model: [1, causal SNP, covariates]
X = torch.cat((torch.ones((N, 1)), G[:, idx_true].unsqueeze(1), covar), dim=1)
w = torch.cat((torch.Tensor([intercept]), beta))
# stay within torch instead of passing tensors through np.matmul
mu = logistic_func(torch.matmul(X, w))
# binary outcome, same as rbinom(N, 1, prob = mu) in R
Y = torch.empty(N).bernoulli_(mu)
# covariates with an explicit intercept column, as batchIRLS adds none itself
C_w_inter = torch.cat((torch.ones((N, 1)), covar), dim=1)

# Run

In [4]:
# Fit all SNPs in one call; C_w_inter already carries the intercept column.
bhat, b_se = batchIRLS(G, Y, C_w_inter, tol=1e-8)  # tol=1e-8 ~~ epsilon=1e-8 in glm.control

# Sanity check (compare with R `glm`)

In [5]:
# rpy2 bridge to R, used below to obtain reference fits from R's `glm`
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
stats = importr("stats")  # handle to R's `stats` package (`stats.glm` is called below)
base = importr("base")  # handle to R's `base` package (`base.summary` is called below)
import pandas as pd
pandas2ri.activate()  # switch on rpy2's pandas <-> R conversion

  from pandas.core.index import Index as PandasIndex


In [6]:
def logistic_regression_glm(X, y, C):
    '''
    Reference implementation: for every column j of X, fit
    cbind(y, 1 - y) ~ 1 + C + X[:, j] with R's `glm` (binomial family).

    Returns (bhat, bse), two arrays of shape (C.shape[1] + 2, X.shape[1]).
    Column j holds the first two columns of R's coefficient table for model j
    (used as estimate and standard error); rows are expected to follow the
    formula order: intercept, covariates, then the tested column.
    '''
    n_tested = X.shape[1]
    n_covar = C.shape[1]
    n_coef = n_covar + 1 + 1  # covariates + intercept + tested column

    # pieces shared by every fit: covariate frame, response frame and formula
    covariates = pd.DataFrame({f'covar{j}': C[:, j] for j in range(n_covar)})
    response = pd.DataFrame({'y': y})
    formula = 'cbind(y, 1 - y) ~ 1 + ' + ' + '.join(covariates.columns.tolist()) + ' + xx'

    estimates = np.empty((n_coef, n_tested))
    std_errors = np.empty((n_coef, n_tested))
    for j in range(n_tested):
        tested = pd.DataFrame({'xx': X[:, j]})
        frame = pd.concat((covariates, tested, response), axis=1)
        fit = stats.glm(formula=formula, data=pandas2ri.py2rpy(frame), family='binomial')
        coef_table = base.summary(fit).rx2('coefficients')[:, :2]
        estimates[:, j] = coef_table[:, 0]
        std_errors[:, j] = coef_table[:, 1]
    return estimates, std_errors

In [7]:
# Reference fit with R's glm; `covar` has no intercept column because the glm formula adds one.
bhat0, b_se0 = logistic_regression_glm(G, Y, covar)

In [8]:
# batchIRLS estimates and standard errors must agree with the glm reference (decimal=5).
np.testing.assert_array_almost_equal(bhat, bhat0, decimal=5)
np.testing.assert_array_almost_equal(b_se, b_se0, decimal=5)

# Time used

In [9]:
%timeit batchIRLS(G, Y, C_w_inter, tol=1e-8)
%timeit logistic_regression_glm(G, Y, covar)

20.8 ms ± 746 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
486 ms ± 16.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Conclusion

The `batchIRLS` implementation is consistent with `glm` up to 4 decimal places, and it runs faster.

# Implementing the masking scheme

As a first attempt, we apply the mask inside the iteration.

In [10]:
def batchIRLS_w_mask(X, y, C, tol=1e-8, maxiter=100):
    '''
    Fit one logistic regression per column of X with batched IRLS
    (Newton-Raphson), i.e. logit(y) ~ C + X[:, m] for every SNP m.
    Convergence is tracked per SNP: a SNP whose relative squared coefficient
    change is <= tol is masked out and no longer updated.

    Input
        X: Tensor(n, p). One column per SNP.
        y: Tensor(n,) or Tensor(n, 1). Response shared by all p models.
        C: Tensor(n, k). Covariates shared by all p models. No intercept is
           added here, so include a column of ones in C if one is wanted.
        tol: per-SNP relative tolerance (see snp_level_convergence_checker).
        maxiter: maximum number of iterations.
    Output:
        B_hat: Tensor(k + 1, p). Column m holds the estimates of model m: the
               k covariate effects (in the order of C) followed by the effect
               of X[:, m].
        B_se: Tensor(k + 1, p). The corresponding SE's of the estimates.
    Raises:
        ValueError if tol is outside [0, 1].
    A warning is printed if some SNP has not reached tol within maxiter.
    '''
    if tol > 1 or tol < 0:
        raise ValueError('The tol is relative tolerence so we expect tol in (0, 1).')
    # get dimensions
    p = X.shape[1]
    k = C.shape[1]

    # `torch.solve(B, A)` was removed from recent PyTorch releases in favour
    # of `torch.linalg.solve(A, B)` (note the reversed argument order).
    # Support both so the function runs on old and new installations.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        solve = torch.linalg.solve
    else:
        def solve(A, B):
            return torch.solve(B, A)[0]

    def fitted_and_information(X_sub, Wcx_sub):
        # X_sub: Tensor(n, q), Wcx_sub: Tensor(q, k + 1) for a subset of q SNPs.
        # Returns Mu: Tensor(n, q) and X^T S X: Tensor(k + 1, k + 1, q).
        q = X_sub.shape[1]
        Wc = Wcx_sub[:, :k]
        Wx = Wcx_sub[:, k]
        # compute w^T x (X_sub * Wx scales column m of X_sub by Wx[m])
        CX_Wcx = torch.einsum('nk,pk->np', C, Wc) + X_sub * Wx
        # compute mu and S
        Mu = logistic_func(CX_Wcx)
        S = Mu * (1 - Mu)
        # compute X^T S X block by block: the design of model m is
        # [C, X[:, m]], so only the last row/column differs between models
        C_S_X = torch.einsum('nk,np,np->kp', C, S, X_sub)
        XSX = torch.empty(k + 1, k + 1, q, dtype=X.dtype, device=X.device)
        XSX[:k, :k, :] = torch.einsum('nk,np,nj->kjp', C, S, C)
        XSX[:k, k, :] = C_S_X
        XSX[k, :k, :] = C_S_X
        XSX[k, k, :] = torch.einsum('np,np,np->p', X_sub, S, X_sub)
        return Mu, XSX

    # initialize all coefficients at zero, matching the dtype/device of X
    Wcx = torch.zeros(p, k + 1, dtype=X.dtype, device=X.device)
    # every SNP is active in the first iteration, whatever the value of tol
    mask = torch.ones(p, dtype=torch.bool, device=X.device)

    max_diff = tol + 1
    niter = 0

    while max_diff > tol and niter < maxiter:

        # take a copy of current Wcx (it is updated in place below)
        Wcx_old = Wcx.clone()

        # restrict all the work to the SNPs that have not converged yet
        X_act = X[:, mask]
        Wcx_act = Wcx[mask, :]

        Mu, XSX = fitted_and_information(X_act, Wcx_act)

        # score: [C, X]^T (y - mu). Use the argument `y`, not a global.
        RES = mat_vec_add(-Mu, y)
        CX_RES = torch.cat((
            torch.einsum('nk,np->kp', C, RES),
            torch.einsum('np,np->p', X_act, RES).unsqueeze(0)
        ), dim=0)  # Tensor(k + 1, q)

        # right-hand side of the Newton system (X^T S X) w_new = RHS
        RHS = torch.einsum('jkp,pk->jp', XSX, Wcx_act) + CX_RES

        # solve the active systems in one batched call and update them only
        Wcx[mask, :] = solve(
            XSX.permute(2, 0, 1),     # Tensor(q, k + 1, k + 1)
            RHS.t().unsqueeze(-1)     # Tensor(q, k + 1, 1)
        )[:, :, 0]

        # converged SNPs were not touched, so their diff is 0 and they stay masked out
        max_diff, diffs = snp_level_convergence_checker(Wcx, Wcx_old)
        mask = diffs > tol
        niter += 1

    # compute SE
    ## X^T S X has to be re-evaluated at the final estimates for all SNPs
    _, XSX = fitted_and_information(X, Wcx)
    ## VAR = (X^T S X)^{-1}, obtained by solving against the identity
    EYE = torch.eye(k + 1, dtype=X.dtype, device=X.device).expand(p, k + 1, k + 1)
    VAR = solve(XSX.permute(2, 0, 1), EYE)  # Tensor(p, k + 1, k + 1)

    # return B_hat, B_se
    # `not (max_diff <= tol)` also triggers when max_diff is NaN (diverged fit)
    if not (max_diff <= tol):
        print(f'Warning: not converged! max_diff = {max_diff} > tol = {tol}')
    return Wcx.T, torch.sqrt(torch.diagonal(VAR, dim1=1, dim2=2)).T

def divide_fill_zero(a, b):
    '''
    Element-wise a / b, with the result set to 0 wherever b is exactly 0.
    Returns a new tensor; a and b are left unchanged.
    '''
    zero_denominator = b == 0
    ratio = torch.div(a, b)
    ratio[zero_denominator] = 0
    return ratio

def snp_level_convergence_checker(w_new, w_old):
    '''
    Per-row relative squared change between two coefficient tensors:
    sum_j (w_new - w_old)^2 / sum_j w_new^2 for every row, where rows whose
    denominator is 0 get a value of 0 (via divide_fill_zero).

    Returns (largest row value, tensor of all row values).
    '''
    change_per_row = ((w_new - w_old) ** 2).sum(dim=1)
    size_per_row = (w_new ** 2).sum(dim=1)
    ratios = divide_fill_zero(change_per_row, size_per_row)
    return ratios.max(), ratios

In [11]:
%timeit bhat, b_se = batchIRLS(G, Y, C_w_inter, tol=1e-8)

21.4 ms ± 922 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
%timeit bhat1, b_se1 = batchIRLS_w_mask(G, Y, C_w_inter, tol=1e-8)

29.3 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Run the masked and the unmasked solver on the same data so their outputs can be compared below.
bhat1, b_se1 = batchIRLS_w_mask(G, Y, C_w_inter, tol=1e-8)
bhat, b_se = batchIRLS(G, Y, C_w_inter, tol=1e-8)

In [14]:
# Masked solver vs. the R glm reference (decimal=5).
np.testing.assert_array_almost_equal(bhat1, bhat0, decimal=5)
np.testing.assert_array_almost_equal(b_se1, b_se0, decimal=5)

In [15]:
# Masked solver vs. the unmasked batchIRLS (decimal=5).
np.testing.assert_array_almost_equal(bhat1, bhat, decimal=5)
np.testing.assert_array_almost_equal(b_se1, b_se, decimal=5)