In [263]:
import numpy as np
import pandas as pd
from collections import namedtuple
import numba
from scipy.stats import pearsonr
import scipy.stats as ss
import math

In [267]:
# Container for a plink bfile: per-sample ids (iid), per-variant columns
# (snp, chr, bp, a1, a2), the two dimensions and the packed genotype matrix (bed).
Plink = namedtuple("Plink", ["iid", "snp", "chr", "bp", "a1", "a2", "n_samples", "n_snps", "bed"])

def read_plink(bfile):
    """
    Load a plink bfile (.bim/.fam/.bed triple) without decoding the genotypes.

    Args:
        bfile = path prefix; "{bfile}.bim", "{bfile}.fam" and "{bfile}.bed" are read.
    Returns:
        Plink namedtuple. `bed` is a read-only np.memmap of uint8 with shape
        (n_snps, ceil(n_samples/4)): one row per variant, four samples packed per byte.
    Raises:
        ValueError if the first three bytes of the bed file are not [0x6c, 0x1b, 0x01].
    """
    bim_file = f"{bfile}.bim"
    bim = pd.read_csv(bim_file, sep='\t', header=None, usecols=[0,1,3,4,5], names=["CHR","SNP","BP","A1","A2"])
    fam_file = f"{bfile}.fam"
    # sep=r'\s+' is the supported spelling of the deprecated delim_whitespace=True.
    fam = pd.read_csv(fam_file, sep=r'\s+', header=None, usecols=[1], names=["IID"])

    bed_file = f"{bfile}.bed"
    magic_bits = np.fromfile(bed_file, count=3, dtype=np.uint8) # read only the 3-byte header
    # [108,27,1] are integers corresponding to bytes([0x6c, 0x1b, 0x01]).
    # array_equal also rejects a header shorter than 3 bytes instead of failing
    # on a shape-mismatched elementwise comparison.
    if not np.array_equal(magic_bits, [108,27,1]):
        raise ValueError(f"{bed_file} file is not a valid bed file!")
    n_snps = len(bim)
    n_samples = len(fam)
    n_cols = (n_samples + 3)//4 # bytes per variant: 4 samples per byte, last byte padded
    bed = np.memmap(bed_file, dtype=np.uint8, offset=3, mode='r', shape=(n_snps,n_cols))
    return Plink(iid=fam.IID.values, snp=bim.SNP.values, chr=bim.CHR.values, bp=bim.BP.values,
                 a1=bim.A1.values, a2=bim.A2.values, n_samples=n_samples, n_snps=n_snps, bed=bed)

@numba.jit(nopython=True, nogil=True)
def get_byte_map():
    """
    Build the lookup table that decodes one plink bed byte into four a1 genotypes.

    A byte holds four 2-bit fields, lowest-order pair first. Field values
    0, 1, 2, 3 are translated to genotypes 2, -1, 1, 0 respectively.
    Return a 256 x 4 int8 array T where T[byte][k] is the genotype stored in
    the k-th 2-bit field of `byte`.
    """
    code_to_geno = np.array([2, -1, 1, 0],dtype=np.int8)
    table = np.empty((256,4), dtype=np.int8)
    for byte in range(256):
        remaining = byte
        for slot in range(4):
            # the two lowest bits are the current field; shift the next one into place
            table[byte,slot] = code_to_geno[remaining & 3]
            remaining = remaining >> 2
    return table

# geno_idx = [n_nonmiss, n2, n1, {ii2}, {ii1}, {empty}]
# len(geno_idx) = 3 + n_samples
# len(ii2) = n2, {ii2} = [i20, i21, ...], i20 = index of the first occurrence of genotype 2
# len(ii1) = n1, {ii1} = [i10, i11, ...], i10 = index of the first occurrence of genotype 1
# {empty} = unused tail that fills the remaining space
# geno_idx = np.empty(3+n_samples, dtype=np.int32)
# ii1_tmp, ii2_tmp = np.empty(n_samples, dtype=np.int32), scratch buffers for the indices of genotypes 1 and 2 respectively
@numba.jit(nopython=True, nogil=True)
def get_geno_idx(i_geno, bed, n_samples, geno_idx, ii1_tmp, ii2_tmp, byte_map):
    """
    Decode row `i_geno` of `bed` and write genotype counts and carrier indices to `geno_idx`.

    On return:
        geno_idx[0] = number of samples whose genotype is not -1 (non-missing)
        geno_idx[1] = n2, number of samples with genotype 2
        geno_idx[2] = n1, number of samples with genotype 1
        geno_idx[3:3+n2] = indices of samples with genotype 2, ascending
        geno_idx[3+n2:3+n2+n1] = indices of samples with genotype 1, ascending
    Entries past 3+n2+n1 are left untouched.

    Args:
        i_geno = row index into bed
        bed = 2D array of packed genotype bytes (4 samples per byte)
        n_samples = number of samples to decode from the row
        geno_idx = output integer array with at least 3 + n_samples elements
        ii1_tmp, ii2_tmp = scratch integer arrays with at least n_samples elements
        byte_map = 256 x 4 decoding table as returned by get_byte_map()
    """
    i = 0 # current sample index
    n2 = 0
    n1 = 0
    n_nonmiss = 0
    for b in bed[i_geno]:
        # The inner `break` below leaves only the inner loop. Stop here as well so that
        # n_samples == 0 or surplus bytes in the row never push `i` past the buffers.
        if i >= n_samples:
            break
        for g in byte_map[b]:
            if g != -1:
                n_nonmiss += 1
                if g == 2:
                    ii2_tmp[n2] = i
                    n2 += 1
                elif g == 1:
                    ii1_tmp[n1] = i
                    n1 += 1
            i += 1
            if i == n_samples:
                # remaining 2-bit fields of this byte are padding
                break
    geno_idx[0] = n_nonmiss
    geno_idx[1] = n2
    geno_idx[2] = n1
    geno_idx[3:3+n2] = ii2_tmp[:n2]
    geno_idx[3+n2:3+n2+n1] = ii1_tmp[:n1]

    
@numba.jit(nopython=True, nogil=True)
def get_geno(geno, i, bed, byte_map=None):
    """
    Decode row `i` of `bed` into `geno` in place.

    Args:
        geno = np.empty(n_samples, dtype=np.int8) = array to fill; its length
               determines how many samples are decoded
        i = int number, row index into bed
        bed = 2D array of packed genotype bytes (4 samples per byte)
        byte_map = 256 x 4 decoding table; built with get_byte_map() when None
    """
    if byte_map is None:
        byte_map = get_byte_map()
    n_samples = len(geno)
    i_g = 0
    for b in bed[i]:
        for g in byte_map[b]:
            # Check before writing and leave the function (not just the inner loop):
            # padding fields, surplus bytes or an empty `geno` must never be written
            # past the end of the array, which nopython mode would not detect.
            if i_g >= n_samples:
                return
            geno[i_g] = g
            i_g += 1

@numba.jit(nopython=True, nogil=True)
def get_t_stat(idx2, idx1, pheno_mat, n_pheno, pheno_mean_arr, pheno_std_arr,
               geno_mean, geno_std, n_nonmiss, t_stat):
    """
    Fill t_stat[k] with the t-statistic of the correlation between phenotype k and a genotype.

    The genotype is given implicitly: idx2 / idx1 hold the sample indices with
    genotype 2 / 1, so sum(pheno*geno) = 2*sum(pheno[idx2]) + sum(pheno[idx1]).
    For each phenotype row:
        r = (sum(pheno*geno)/n_nonmiss - pheno_mean*geno_mean)/(pheno_std*geno_std)
        t = r*sqrt((n_nonmiss - 2)/(1 - r*r))

    Args:
        idx2, idx1 = integer index arrays into the columns of pheno_mat
        pheno_mat = 2D array, one phenotype per row
        n_pheno = number of rows of pheno_mat to process
        pheno_mean_arr, pheno_std_arr = per-phenotype mean and standard deviation
        geno_mean, geno_std = genotype mean and standard deviation
        n_nonmiss = sample count used as the denominator of the cross-product mean
        t_stat = output array with at least n_pheno elements
    """
    dof = n_nonmiss - 2
    for k in range(n_pheno):
        trait = pheno_mat[k]
        sum_hom = 2.0*trait[idx2].sum()
        sum_het = trait[idx1].sum()
        cross_mean = (sum_hom + sum_het)/n_nonmiss
        r = (cross_mean - pheno_mean_arr[k]*geno_mean)/(pheno_std_arr[k]*geno_std)
        t_stat[k] = r*math.sqrt(dof/(1 - r*r))
        
        
@numba.jit(nopython=True, parallel=True, nogil=True)
def gen_corr(pheno_mat, pheno_mean, pheno_std, inv_C0reg, bed, n_snps, n_samples,
             mosttest_stat, mosttest_stat_shuf, minp_stat, minp_stat_shuf):
    """
    Compute per-variant multivariate statistics for original and shuffled genotypes.

    For every row of `bed` the vector t of per-phenotype t-statistics is obtained
    with get_t_stat and two summaries are stored:
        mosttest_stat[i] = t @ inv_C0reg @ t
        minp_stat[i]     = 2*norm.cdf(-max|t|)
    The same two summaries are stored in mosttest_stat_shuf / minp_stat_shuf for a
    random genotype assignment: n2 + n1 distinct sample indices are drawn from
    range(n_samples), the first n2 treated as genotype 2 and the rest as genotype 1.

    Variants with fewer than 3 non-missing genotypes or zero genotype variance have
    no defined t-statistic; all four outputs are set to NaN for them.

    Args:
        pheno_mat = 2D float array, one phenotype per row, one sample per column
        pheno_mean, pheno_std = per-phenotype mean and standard deviation
        inv_C0reg = 2D array used as the matrix of the quadratic form
        bed = 2D array of packed genotype bytes (4 samples per byte)
        n_snps = number of rows of bed to process
        n_samples = number of samples per row
        mosttest_stat, mosttest_stat_shuf, minp_stat, minp_stat_shuf =
            output arrays with at least n_snps elements, filled in place

    NOTE(review): pheno_mean / pheno_std are supplied by the caller for all samples,
    while the cross-product is averaged over n_nonmiss; the shuffled indices are also
    drawn from all samples. Presumably an accepted approximation when few genotypes
    are missing - confirm.
    """
    # pheno_mean, pheno_std are passed as arguments since numba does not support kwargs for these functions.
    byte_map = get_byte_map()
    n_pheno = pheno_mat.shape[0]
    for geno_i in numba.prange(n_snps):
        t_stat = np.zeros(n_pheno, dtype=np.float32) # this array can be preallocated by thread
        t_stat_shuf = np.zeros(n_pheno, dtype=np.float32) # this array can be preallocated by thread
        geno_idx = np.empty(3+n_samples, dtype=np.int32) # this array can be preallocated by thread
        ii1_tmp = np.empty(n_samples, dtype=np.int32) # this array can be preallocated by thread
        ii2_tmp = np.empty(n_samples, dtype=np.int32) # this array can be preallocated by thread
        get_geno_idx(geno_i, bed, n_samples, geno_idx, ii1_tmp, ii2_tmp, byte_map)
        n_nonmiss = geno_idx[0]
        n2 = geno_idx[1]
        n1 = geno_idx[2]
        geno_mean = 0.0
        geno_var = 0.0
        if n_nonmiss > 2:
            geno_mean = (n2*2 + n1)/n_nonmiss
            geno_var = (n2*4 + n1)/n_nonmiss - geno_mean*geno_mean
        if geno_var > 0.0:
            geno_std = math.sqrt(geno_var)
            # for original genotypes
            get_t_stat(geno_idx[3:3+n2], geno_idx[3+n2:3+n2+n1], pheno_mat, n_pheno, pheno_mean, pheno_std,
                       geno_mean, geno_std, n_nonmiss, t_stat)
            mosttest_stat[geno_i] = t_stat @ inv_C0reg @ t_stat
            # 2*norm.cdf(-max|t|) == erfc(max|t|/sqrt(2)); erfc keeps precision in the
            # far tail where 1 + erf(x/sqrt(2)) cancels to exactly 0.
            minp_stat[geno_i] = math.erfc(np.max(np.abs(t_stat))/math.sqrt(2.0))

            # for shuffled genotypes
            geno_idx_shuf = np.random.choice(n_samples,n1+n2,replace=False)
            get_t_stat(geno_idx_shuf[:n2], geno_idx_shuf[n2:], pheno_mat, n_pheno, pheno_mean, pheno_std,
                       geno_mean, geno_std, n_nonmiss, t_stat_shuf)
            mosttest_stat_shuf[geno_i] = t_stat_shuf @ inv_C0reg @ t_stat_shuf
            minp_stat_shuf[geno_i] = math.erfc(np.max(np.abs(t_stat_shuf))/math.sqrt(2.0))
        else:
            # too few calls or monomorphic variant: avoid division by zero, report NaN
            mosttest_stat[geno_i] = np.nan
            minp_stat[geno_i] = np.nan
            mosttest_stat_shuf[geno_i] = np.nan
            minp_stat_shuf[geno_i] = np.nan
        
    

In [152]:
# Input/output locations used by the cells below.
pheno = "pheno.txt"            # full or relative path to the phenotype file
bfile = "chr21"                # full or relative path to plink bfile prefix
out = "results"                # prefix for the output files
# The phenotype file is read as tab-separated with a header row.
pheno_df = pd.read_csv(pheno, sep='\t')
print(pheno_df.shape)
pheno_df.head(3)

(10000, 10)


Unnamed: 0,trait1,trait2,trait3,trait4,trait5,trait6,trait7,trait8,trait9,trait10
0,1.27144,0.573959,0.282021,-0.524421,1.5639,1.99793,-1.59026,-0.025822,0.157737,0.620283
1,-2.38677,0.430421,-0.903973,-0.308664,0.751253,-0.233879,1.71025,1.13925,0.178941,-0.79577
2,-1.58732,0.414166,1.52453,1.29193,-0.579384,0.095391,1.1751,-0.523504,-1.48036,-0.072877


In [153]:
# Read .bim/.fam and memory-map the .bed genotypes of the bfile.
plink = read_plink(bfile)

In [168]:
# Reference check: decode the first variant and correlate it with each of the first
# 10 phenotype columns using scipy. Genotypes decoded as -1 (missing) are passed to
# pearsonr unchanged.
byte_map = get_byte_map()
i_geno = 0
# n_samples is taken from the loaded bfile; a bare `n_samples` is not defined in this notebook.
geno = np.empty(plink.n_samples, dtype=np.int8)
get_geno(geno, i_geno, plink.bed, byte_map)
for i in range(10):
    print(f"{i}:\t{pearsonr(geno, pheno_df.values.T[i])}")

0:	(-0.026900267111844935, 0.007141408429096841)
1:	(-0.023174981513127416, 0.020475440309356184)
2:	(-0.02378678765013903, 0.017372963362809832)
3:	(-0.00259158690856959, 0.795537461997438)
4:	(0.025485991420215723, 0.010812697430029918)
5:	(-0.01108464920168129, 0.26770619924493866)
6:	(0.0021757589681088477, 0.8277805592597748)
7:	(-0.006787655271854427, 0.49733525809786233)
8:	(-0.00860538922174816, 0.3895425335762583)
9:	(0.008513740209739561, 0.39461215514420367)


In [269]:
mosttest_stat = np.empty(plink.n_snps, dtype=np.float32)
mosttest_stat_shuf = np.empty(plink.n_snps, dtype=np.float32)
minp_stat = np.empty(plink.n_snps, dtype=np.float32)
minp_stat_shuf = np.empty(plink.n_snps, dtype=np.float32)
pheno_mat = pheno_df.values.T
pheno_mean = np.mean(pheno_mat, axis=1, dtype=np.float32)
pheno_std = np.std(pheno_mat, axis=1, dtype=np.float32, ddof=0)
pheno_corr_mat = np.corrcoef(pheno_mat, rowvar=True)
inv_C0reg = np.linalg.inv(pheno_corr_mat)

pheno_mat = pheno_mat.astype(np.float32)
inv_C0reg = inv_C0reg.astype(np.float32)

%time gen_corr(pheno_mat, pheno_mean, pheno_std, inv_C0reg, plink.bed, plink.n_snps, plink.n_samples, mosttest_stat, mosttest_stat_shuf, minp_stat, minp_stat_shuf)

mosttest_stat, mosttest_stat_shuf, minp_stat, minp_stat_shuf

CPU times: user 49.2 s, sys: 109 ms, total: 49.3 s
Wall time: 13 s


(array([27.674358 ,  5.3650484, 27.669147 , ...,  9.494259 , 19.795343 ,
        32.702    ], dtype=float32),
 array([13.661152, 12.481163, 10.021618, ...,  4.498459,  6.136265,
        11.794331], dtype=float32),
 array([0.0071296 , 0.1387241 , 0.00672145, ..., 0.08371501, 0.00463058,
        0.00056898], dtype=float32),
 array([0.08895145, 0.05900057, 0.02563081, ..., 0.14381345, 0.18546563,
        0.01343986], dtype=float32))

In [192]:
# Scratch: phenotype correlation matrix (rows of pheno_mat are variables) and its inverse.
mm = np.corrcoef(pheno_mat, rowvar=True)
mm_inv = np.linalg.inv(mm)

In [194]:
# Scratch: confirm that mm_inv is a numerical inverse of mm.
np.allclose(mm @ mm_inv, np.eye(mm.shape[0]))

True

In [199]:
# Scratch: vector-matrix product of the t-statistics with the inverse correlation matrix.
# NOTE(review): `t_stat` is not assigned at top level in any visible cell (it only
# appears as a local inside gen_corr) - presumably left over from an earlier
# interactive session; this cell will fail on a fresh kernel. Confirm or remove.
t_stat @ mm_inv

array([-2.54562806, -3.33233232, -1.32856087, -0.02196962,  2.59537891,
       -0.98409948,  0.66497589, -0.11338269, -1.58169842,  0.75735126])

In [201]:
# Scratch: matrix-vector product in the opposite order; the recorded output equals
# that of `t_stat @ mm_inv`.
# NOTE(review): `t_stat` is not assigned at top level in any visible cell - confirm or remove.
mm_inv @ t_stat

array([-2.54562806, -3.33233232, -1.32856087, -0.02196962,  2.59537891,
       -0.98409948,  0.66497589, -0.11338269, -1.58169842,  0.75735126])

In [208]:
# Scratch: small 2x2 matrix and length-2 vector for experimenting with `@`.
a = np.array([[1,2],[3.0,4.0]])
b = np.array([2,5.0])
a, b

(array([[1., 2.],
        [3., 4.]]),
 array([2., 5.]))

In [210]:
@numba.njit()
def ff(a, b):
    """Return b @ a @ b, evaluated left to right, compiled with numba."""
    left = b@a
    return left@b

In [251]:
# Scratch: length-2 float vector used by the shape check in the next cell.
c = np.array([3.,2])

In [256]:
# Scratch: a scalar divided by an elementwise product keeps the shape of `c`.
((3.0 - 1)/(c*c)).shape

(2,)