In [1]:
import bgen_reader
import pandas as pd
import numpy as np
from tqdm import tqdm
from rpy2.robjects.packages import importr
rbgen = importr("rbgen")
import rpy2.robjects as ro
import rpy2.rlike.container as rlc
from rpy2.robjects import pandas2ri
pandas2ri.activate()

# Load data

In [2]:
filedir = '/vol/bmd/meliao/data/haplotype/hap/'
chr_num = '16'
gwas_name = '20002_1081.gwas.imputed_v3.male'

In [3]:
# bgen = bgen_reader.read_bgen(
#     f'{filedir}/ukb_hap_chr{chr_num}_v2.bgen', 
#     samples_filepath='/vol/bmd/meliao/data/haplotype/link_files/ukb1952_v2_s487398.sample',
#     metafile_filepath=f'/vol/bmd/yanyul/UKB/ukb_hap_bgen_reader_metafile/ukb_hap_v2.chr{chr_num}.metafile'
# )

In [4]:
gwas = pd.read_csv(
    f'/vol/bmd/yanyul/UKB/neale_lab_gwas/{gwas_name}.tsv.bgz', 
    header = 0, 
    sep = '\t', 
    compression = 'gzip'
)

In [5]:
snp_map = pd.read_csv(
    '/vol/bmd/yanyul/UKB/haplotype_imputation/snp_map_for_neale_lab_gwas.with_sign.tsv.gz',
    sep='\t',
    header=0,
    compression='gzip',
    dtype={3:'str'}
)

# Workflow

1. **In GWAS**, extract variants that occur in my genotype.
2. **In GWAS**, perform LD clumping.
3. Calculate PRS.

## Extract variants

In [6]:
# # don't run
# snp_map = snp_map[ (snp_map['assigned_id'] != 'not_shown') & (snp_map['assigned_id'] != 'ambiguious') ]

In [7]:
# # don't run
# # extract variants from GWAS
# gwas_sub = gwas[ gwas['variant'].isin(snp_map['assigned_id']) ]

## LD clump by `plink`

## Naive PRS

In [8]:
variant_annot = snp_map

In [9]:
ld_clump = pd.read_csv(
    f'/vol/bmd/yanyul/UKB/haplotype_imputation/ld_clump/tmp_{gwas_name}/gwas_clump_x_chr{chr_num}.valid.snp',
    header=None
)

In [10]:
gwas_x_ld_clump = gwas[ gwas['variant'].isin(ld_clump[0] )]

In [11]:
variant_annot = pd.merge(
    variant_annot, 
    gwas_x_ld_clump, 
    left_on=['assigned_id'], 
    right_on='variant',
    how='inner'
)

In [12]:
variant_annot = variant_annot.sort_values(
    by=['pos']
).reset_index()

In [13]:
variant_annot['signed_beta'] = variant_annot[['assigned_sign', 'beta']].apply(lambda x: x.beta if x.assigned_sign == '+' else -x.beta, axis=1)

In [36]:
from rpy2.robjects import pandas2ri
pandas2ri.activate()
import gc

class UKBhapReader:
    '''
    Reader of BGEN for one chromosome of ukb_hap.
    '''
    def __init__(self, bgen_path, bgen_bgi_path, sample_path):
        self.rbgen = importr("rbgen")
        self.bgen_path = bgen_path
        self.bgi_path = bgen_bgi_path
        self.sample_path = sample_path
        
    def extract_variant_by_position(self, chrom, start, end, max_entries_per_sample=4):
        '''
        max_entries_per_sample: number of entries per sample in BGEN
        '''
        range_pd = pd.DataFrame(
            {
                'chromosome': [ chrom ],
                'start': [start],
                'end': [end]
            }
        )
#         breakpoint()
        cached_data = self.rbgen.bgen_load(
            self.bgen_path,
            index_filename=self.bgi_path,
            ranges=range_pd, 
            max_entries_per_sample=4
        )
        return cached_data
    
    @staticmethod
    def _get_varid(chrm, pos, a1, a2):
        return f'{chrm}:{pos}:{a1}:{a2}'
    
    @staticmethod
    def _check_ncol(mat, ncol):
        '''
        check if mat has expected number of columns
        '''
        if mat.shape[1] != ncol:
            raise ValueError(f'ukb_hap does not have {ncol} columns.')
    
    def _hap_to_count(self, hap):
        '''
        expect hap has two columns.
        '''
        self._check_ncol(hap, 2)
        colsum = np.sum(hap, axis=1)
        if (colsum != 1).sum() > 0:
            raise ValueError('some rows have colsum != 1 which is not allowed.')
        return hap[:,1]
    
    @staticmethod
    def _next_pos(curr_pos, max_pos, n_jump):
        next_pos = min(curr_pos + n_jump, max_pos)
        return next_pos
    
    def ukb_hap_to_haplo(self, ukb_hap):
        '''
        expect ukb_hap has 4 columns
        ukb hap encoding law:
            0,1,0,1 -> 1|1
            0,1,1,0 -> 1|0
            1,0,0,1 -> 0|1
            1,0,1,0 -> 0|0
        i.e. the 1st, 2nd columns encode haplotype 1 and 
        the 3rd, 4th columns encode haplotype 2.
        And 0,1->1; 1,0->0. 
        Other combinations are not allowed.
        '''
        self._check_ncol(ukb_hap, 4)
        return self._hap_to_count(ukb_hap[:, :2]), self._hap_to_count(ukb_hap[:, 2:])
    
    def retrieve_from_list(self, chrom, pos, non_effect_allele, effect_allele, 
                           n_var_cached=10, max_entries_per_sample=4):
        '''
        Retrieve generator of variants.
        '''
        niter = 0
        snp_list = pd.DataFrame({
            'chr': chrom,
            'pos': pos,
            'non_effect_allele': non_effect_allele,
            'effect_allele': effect_allele
        })
        # sort by position so that we can retrive by left to right.
        snp_list = snp_list.sort_values('pos').reset_index()
        # set desired variant id 
        set_snp_list = set(
            snp_list.apply(
                lambda x: self._get_varid(x.chr, x.pos, x.non_effect_allele, x.effect_allele),
                axis=1
            ).tolist()
        )
        nsnp = snp_list.shape[0]
        curr_pos_in_snp_list = 0
        next_pos_in_snp_list = 0
        while next_pos_in_snp_list < nsnp:
            # modified from 
            # https://github.com/liangyy/predixcan_prediction/blob/a85d52d89de9fe237a1217b5627c7e8d9f700f7e/bgen/bgen_dosage.py#L80
            next_pos_in_snp_list = self._next_pos(curr_pos_in_snp_list, nsnp, n_var_cached)
            if niter > 0:
                cached_data_struct = cached_data.__sexp__
                del cached_data
                del cached_data_struct
                gc.collect()
            cached_data = self.extract_variant_by_position(
                chrom=snp_list.chr[curr_pos_in_snp_list], 
                start=snp_list.pos[curr_pos_in_snp_list], 
                end=snp_list.pos[next_pos_in_snp_list - 1]
            )
            all_variants = pandas2ri.ri2py(cached_data[0])
            if all_variants.shape[0] == 0:
                return
            all_variants['my_var_id'] = all_variants.apply(
                lambda x: self._get_varid(x.chromosome, x.position, x.allele0, x.allele1),
                axis=1
            )
            all_probs = pandas2ri.ri2py(cached_data[4])
            niter += 1
            for row_idx, (rsid, row) in enumerate(all_variants.iterrows()):
                if row.my_var_id in set_snp_list:
                    dosage_row = row.rename({'chromosome': 'chr'})
                    h1, h2 = self.ukb_hap_to_haplo(all_probs[row_idx, :, :])
                    dosage_row['haplo_dosage_1'] = h1
                    dosage_row['haplo_dosage_2'] = h2
                    yield dosage_row

In [None]:
# class LDclumpPRSmodel:
#     def __init__(self, gwas_file, ld_clump, var_col, chr_col, pos_col, ea_col, nea_col, pval_col, beta_col):
#     def annot_with_snp_map(self, snp_map_df):
        
        

In [37]:
reader = UKBhapReader(
    bgen_path=f'{filedir}/ukb_hap_chr{chr_num}_v2.bgen', 
    bgen_bgi_path=f'/vol/bmd/meliao/data/haplotype/hap_bgi/ukb_hap_chr{chr_num}_v2.bgen.bgi', 
    sample_path=''
)

In [38]:
var_list = pd.concat(
    [
        pd.DataFrame({
            'pos': [10, 100],
            'allele_ids': ['A,T', 'G,A']
        }),
        variant_annot
    ]
)[:10]
var_generator = reader.retrieve_from_list(
    chrom=[ '' for i in range(var_list.shape[0]) ], 
    pos=var_list.pos.tolist(), 
    non_effect_allele=var_list.allele_ids.map(lambda x: x.split(',')[0]).tolist(), 
    effect_allele=var_list.allele_ids.map(lambda x: x.split(',')[1]).tolist(),
    n_var_cached=100
)

of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.


  import sys


In [39]:
var_list

Unnamed: 0,AC,allele_ids,assigned_id,assigned_sign,beta,chrom,expected_case_minor_AC,id,index,low_confidence_variant,...,minor_allele,n_complete_samples,pos,pval,rsid,se,signed_beta,tstat,variant,ytx
0,,"A,T",,,,,,,,,...,,,10,,,,,,,
1,,"G,A",,,,,,,,,...,,,100,,,,,,,
0,16959.0,"G,A",16:85629:G:A,+,0.000551,16.0,297.87,16.0,0.0,False,...,A,166988.0,85629,0.593783,rs74676015,0.001033,0.000551,0.533363,16:85629:G:A,305.0
1,5827.0,"C,A",16:85667:C:A,+,-0.002301,16.0,102.346,16.0,1.0,False,...,A,166988.0,85667,0.183405,rs117923034,0.00173,-0.002301,-1.33035,16:85667:C:A,89.0
2,6876.0,"A,C",16:89659:A:C,+,-0.001565,16.0,120.771,16.0,2.0,False,...,C,166988.0,89659,0.326478,rs72763640,0.001595,-0.001565,-0.981236,16:89659:A:C,109.0
3,23116.0,"G,A",16:92391:G:A,+,-0.001029,16.0,406.013,16.0,3.0,False,...,A,166988.0,92391,0.249725,rs1088642,0.000894,-0.001029,-1.15102,16:92391:G:A,381.0
4,2887.0,"G,A",16:92688:G:A,+,0.003364,16.0,50.7077,16.0,4.0,False,...,A,166988.0,92688,0.168642,rs117695470,0.002444,0.003364,1.37659,16:92688:G:A,61.0
5,229394.0,"G,A",16:105325:G:A,+,-7.1e-05,16.0,1836.89,16.0,5.0,False,...,G,166988.0,105325,0.883987,rs2858042,0.000489,-7.1e-05,-0.145917,16:105325:G:A,4024.0
6,19.0,"C,T",16:106596:C:T,+,0.035728,16.0,0.333719,16.0,6.0,True,...,T,166988.0,106596,0.234488,rs80023530,0.030052,0.035728,1.18888,16:106596:C:T,1.0
7,12796.0,"A,C",16:119006:A:C,+,0.001227,16.0,224.751,16.0,7.0,False,...,C,166988.0,119006,0.299191,rs117078265,0.001182,0.001227,1.03817,16:119006:A:C,241.0


In [40]:
for i in var_generator:
    print(i)

chr                                                                   
position                                                         85629
rsid                                                        rs74676015
number_of_alleles                                                    2
allele0                                                              G
allele1                                                              A
my_var_id                                                   :85629:G:A
haplo_dosage_1       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
haplo_dosage_2       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Name: 0, dtype: object
chr                                                                   
position                                                         85667
rsid                                                       rs117923034
number_of_alleles                                                    2
allele0                                               

In [24]:
# # don't run: old approach
# # extract a chunk of variants
# var_sub = variant_annot[['pos', 'assigned_id']][:1]

# start = var_sub.pos[0]
# end = var_sub.pos[var_sub.shape[0] - 1]
# range_pd = pd.DataFrame(
#     {
#         'chromosome': [''],
#         'start': [start],
#         'end': [end]
#     }
# )
# bgen = rbgen.bgen_load(
#     f'{filedir}/ukb_hap_chr{chr_num}_v2.bgen',
#     index_filename=f'/vol/bmd/meliao/data/haplotype/hap_bgi/ukb_hap_chr{chr_num}_v2.bgen.bgi',
#     ranges=range_pd, # range_r,
#     max_entries_per_sample=4
# )

In [113]:
def get_effect_size(beta, sign):
    if sign == '+':
        return beta
    elif sign == '-':
        return -beta
    else:
        raise ValueError(f'sign = {sign} is not allowed.')
def _check_ncol(mat, ncol):
    '''
    check if mat has expected number of columns
    '''
    if mat.shape[1] != ncol:
        raise ValueError(f'ukb_hap does not have {ncol} columns.')
def ukb_hap_to_haplo(ukb_hap):
    '''
    expect ukb_hap has 4 columns
    ukb hap encoding law:
        0,1,0,1 -> 1|1
        0,1,1,0 -> 1|0
        1,0,0,1 -> 0|1
        1,0,1,0 -> 0|0
    i.e. the 1st, 2nd columns encode haplotype 1 and 
    the 3rd, 4th columns encode haplotype 2.
    And 0,1->1; 1,0->0. 
    Other combinations are not allowed.
    '''
    _check_ncol(ukb_hap, 4)
    out = {
        'h1': _hap_to_count(ukb_hap[:, :2]),
        'h2': _hap_to_count(ukb_hap[:, 2:])
    }
    return out
def _hap_to_count(hap):
    '''
    expect hap has two columns.
    '''
    _check_ncol(hap, 2)
    colsum = np.sum(hap, axis=1)
    if (colsum != 1).sum() > 0:
        raise ValueError('some rows have colsum != 1 which is not allowed.')
    return hap[:,1]

In [122]:
prs_thresholds = [5e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 0.05, 0.1, 0.5, 1]
geno = bgen["genotype"][0].compute()

# init out matrix and dynamic prs matrix 
outmat = {
    'h1': np.empty(
        (bgen['samples'].shape[0], len(prs_thresholds))
    ),
    'h2': np.empty(
        (bgen['samples'].shape[0], len(prs_thresholds))
    )
}

prs_collector = {
    'h1': np.zeros((bgen['samples'].shape[0])),
    'h2': np.zeros((bgen['samples'].shape[0]))
}

prs_thresholds.sort()
prs_index = 0  # go through prs cutoffs from stringent to loss

for row_i in tqdm(range(10)):  # variant_annot.shape[0]
#     print(row_i)
    pval = variant_annot.pval[row_i]; print(pval)
    variant_idx = variant_annot.idx[row_i]
    effect_size = get_effect_size(
        variant_annot.beta[row_i],
        variant_annot.assigned_sign[row_i]
    )
    if pval > prs_thresholds[prs_index]:
        for h_index in outmat.keys():
            outmat[h_index][:, prs_index] = prs_collector[h_index]
        prs_index += 1
        if prs_index >= len(prs_thresholds):
            break
    ukb_hap = bgen["genotype"][variant_idx].compute()
    haplo = ukb_hap_to_haplo(ukb_hap['probs'])
    for h_index in haplo.keys():
        prs_collector[h_index] = prs_collector[h_index] + haplo[h_index] * effect_size

# if we run out of variants before going through all p-value cutoffs 
# remove these not used cutoffs
if prs_index != outmat['h1'].shape[1]:
    prs_thresholds = prs_thresholds[:prs_index]
    for h_index in outmat.keys():
            outmat[h_index] = outmat[h_index][:, :prs_index]

  0%|          | 0/10 [00:00<?, ?it/s]

3.42455e-07


 10%|█         | 1/10 [00:00<00:04,  1.93it/s]

3.66874e-05


 20%|██        | 2/10 [00:00<00:03,  2.02it/s]

4.18463e-05


 30%|███       | 3/10 [00:01<00:03,  1.99it/s]

5.7707399999999996e-05


 40%|████      | 4/10 [00:01<00:02,  2.07it/s]

9.34941e-05


 50%|█████     | 5/10 [00:02<00:02,  1.88it/s]

9.67742e-05


 60%|██████    | 6/10 [00:03<00:02,  1.91it/s]

0.00011214700000000001


 70%|███████   | 7/10 [00:03<00:01,  1.88it/s]

0.000143731


 80%|████████  | 8/10 [00:04<00:00,  2.03it/s]

0.00017396799999999999


 90%|█████████ | 9/10 [00:04<00:00,  2.03it/s]

0.000194697


100%|██████████| 10/10 [00:04<00:00,  2.03it/s]


In [133]:
outmat['h2'][:,1].sum()

107.194394

In [120]:
variant_annot.pval

0       3.424550e-07
1       3.668740e-05
2       4.184630e-05
3       5.770740e-05
4       9.349410e-05
            ...     
9388    9.996810e-01
9389    9.997860e-01
9390    9.998070e-01
9391    9.998190e-01
9392    9.998830e-01
Name: pval, Length: 9393, dtype: float64

In [71]:
variant_annot.pval[10]

0.498989

In [69]:
bgen['samples'].shape

(487409,)