# Extracting data for genomic regions of interest

## Aim

To extract the summary statistics and genotypes for specific genomic regions and calculate their LD matrices.

## Pre-requisites

Make sure you install the prerequisites before running this notebook:

```
pip install pybgen
pip install bed-reader
pip install scipy
```

## Input and Output

### Input

- `--region-file`, including a list of regions
    - Each locus will be represented by one line in the region file with 3 columns chr, start, and end. e.g. `7 27723990 28723990`
- `--geno-path`, the path of a genotype inventory, which lists the paths of all genotype files in `bgen` or `plink` format.
    - The inventory is a file with 2 columns: `chr genotype_file_chr.ext`.
    - The first column is the chromosome ID; the second column is the genotype file for that chromosome.
    - A chromosome ID of 0 means that the genotype file contains the genotypes for all chromosomes.
- `--pheno-path`, the path of a phenotype file.
    - The phenotype file should have a column named `IID`, which holds the sample ID.
- `--bgen-sample-path`, the path of the sample file for the `bgen` files.
    - Provide this path when the genotypes are in `bgen` format and no `.sample` file with the same base name sits next to each `.bgen` file.
- `--sumstats-path`, the path of the GWAS file, including all summary statistics (e.g., $\hat{\beta}$, $SE(\hat{\beta})$ and p-values).
    - These summary statistics should contain at least these columns: `chrom, pos, ref, alt, snp_id, bhat, sbhat, p`
- `--unrelated-samples`, the path of a file of unrelated samples with a column named `IID`.
- `--cwd`, the path of output directory
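
To illustrate the two list formats above, the sketch below parses one region line and looks up the genotype file for its chromosome, falling back to the chromosome `0` entry when no chromosome-specific file is listed. The helper names `parse_region` and `lookup_genotype_file` and the file names are hypothetical; the workflow performs the equivalent steps inline.

```python
def parse_region(line):
    """Split a `chr start end` line into (chrom, start, end)."""
    chrom, start, end = line.split()
    # Accept both `7` and `chr7`
    if chrom.startswith("chr"):
        chrom = chrom[3:]
    return chrom, int(start), int(end)


def lookup_genotype_file(inventory_lines, chrom):
    """Return the genotype file listed for `chrom`, or the chromosome-0 entry."""
    inventory = dict(line.split() for line in inventory_lines if line.strip())
    if chrom in inventory:
        return inventory[chrom]
    # Chromosome ID 0 marks a file that holds all chromosomes
    return inventory["0"]


# Hypothetical inventory contents, one `chr file` pair per line
inventory_lines = [
    "21 genotypes_chr21.bgen",
    "22 genotypes_chr22.bgen",
    "0 genotypes_all.bed",
]

region = parse_region("7 27723990 28723990")
print(region)
print(lookup_genotype_file(inventory_lines, region[0]))
print(lookup_genotype_file(inventory_lines, "21"))
```

An inventory with neither a matching chromosome nor a `0` entry raises a `KeyError` in this sketch.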

### Output
- `rg_stat`, the regional summary statistics
    - The row name is the variant ID.
    - It should contain at least the following columns: `CHR, BP, SNP, ALT, REF, BETA, SE, Z, P`.
- `rg_geno`, the regional genotypes
    - The row name is the variant ID, which should match the row names of `rg_stat`.
    - The column names are the sample IIDs, in the order of the samples in the phenotype file.
- `pld`, the regional approximate population LD, calculated from unrelated individuals
- `sld`, the regional approximate sample LD, calculated from the unrelated individuals who have the phenotype
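
The `Z` column is derived rather than read from the input: the workflow converts the two-sided p-value into a z-score magnitude and takes the sign from the effect estimate. Below is a minimal standalone sketch of that conversion for a single variant. It uses the standard library's `statistics.NormalDist` in place of `scipy.stats.norm`, and the function name `p_to_z` is an assumption made for illustration.

```python
from statistics import NormalDist


def p_to_z(pval, beta, two_sided=True):
    """Convert a p-value to a signed z-score, taking the sign from beta."""
    if two_sided:
        pval = pval / 2
    # inv_cdf of a lower-tail probability is negative, so take the magnitude
    z = abs(NormalDist().inv_cdf(pval))
    return -z if beta < 0 else z


print(round(p_to_z(0.05, 0.12), 2))   # 1.96
print(round(p_to_z(0.05, -0.12), 2))  # -1.96
```

A p-value of exactly 0 has no finite z-score, and `inv_cdf` raises an error for it in this sketch.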

## Workflow usage

The examples below use our minimal working example data set. In each subsection, the first command generates the fastGWA results and the second extracts the regions of interest.

### For `bgen` format input

```
sos run LMM.ipynb fastGWA \
    --cwd output \
    --bfile data/genotypes.bed \
    --sampleFile data/imputed_genotypes.sample \
    --genoFile data/imputed_genotypes_chr*.bgen \
    --phenoFile data/phenotypes.txt \
    --formatFile data/fastGWA_template.yml \
    --phenoCol BMI \
    --covarCol SEX \
    --qCovarCol AGE \
    --numThreads 1 \
    --bgenMinMAF 0.001 \
    --bgenMinINFO 0.1 \
    --parts 2 \
    --p-filter 1
```

```
sos run Region_Extraction.ipynb \
    --cwd candidate_loci \
    --region-file data/regions.txt \
    --pheno-path data/phenotypes.txt \
    --geno-path data/genotype_inventory.txt \
    --bgen-sample-path data/imputed_genotypes.sample \
    --sumstats-path output/phenotypes_BMI.fastGWA.snp_stats.gz \
    --unrelated-samples data/unrelated_samples.txt \
    --job-size 1
```

### For PLINK format input

```
sos run LMM.ipynb fastGWA \
    --cwd output \
    --bfile data/genotypes.bed \
    --genoFile data/genotypes21_22.bed \
    --phenoFile data/phenotypes.txt \
    --formatFile data/fastGWA_template.yml \
    --phenoCol BMI \
    --covarCol SEX \
    --qCovarCol AGE \
    --numThreads 1 \
    --bgenMinMAF 0.001 \
    --bgenMinINFO 0.1 \
    --parts 2 \
    --p-filter 1 
```

```
sos run Region_Extraction.ipynb \
    --cwd candidate_loci \
    --region-file data/regions_plink.txt \
    --pheno-path data/phenotypes.txt \
    --geno-path data/genotype_inventory_plink.txt \
    --bgen-sample-path data/imputed_genotypes.sample \
    --sumstats-path output/phenotypes_BMI.fastGWA.snp_stats.gz \
    --unrelated-samples data/unrelated_samples.txt \
    --job-size 1
```

## Workflow codes

In [2]:
[global]
# Work directory where output will be saved to
parameter: cwd = path
# Region specifications
parameter: region_file = path
# Genotype file inventory
parameter: geno_path = path
# Phenotype path
parameter: pheno_path = path
# Sample file path, for bgen format
parameter: bgen_sample_path = path('.')
# Path to summary stats file
parameter: sumstats_path = path
# Path to summary stats format configuration
parameter: format_config_path = path('.')
# Path to samples of unrelated individuals
parameter: unrelated_samples = path
# Number of tasks to run in each job on cluster
parameter: job_size = int
# Specify the container to use
parameter: container_lmm = 'statisticalgenetics/lmm:2.9'
fail_if(not region_file.is_file(), msg = 'Cannot find regions to extract. Please specify them using ``--region-file`` option.')
# Load all regions of interest. Each item in the list will be a region: (chr, start, end)
regions = list(set([tuple(x.strip().split()) for x in open(region_file).readlines() if x.strip()]))

## Some utility functions

In [1]:
[default_1 (export utils script)]
depends: Py_Module('xxhash'), Py_Module('pandas'), Py_Module('dask'), Py_Module('scipy')
parameter: scan_window = 500000
output: f'{cwd:a}/utils.py'
report:container=container_lmm, expand = '${ }', output=f'{cwd:a}/utils.py'

    import numpy as np
    import pandas as pd
    import dask.dataframe as dd
    from xxhash import xxh32 as xxh

    def shorten_id(x):
        return x if len(x) < 30 else f"{x.split('_')[0]}_{xxh(x).hexdigest()}"

    def read_sumstat(file, config_file):
        sumstats = pd.read_csv(file, compression='gzip', header=0, sep='\t', quotechar='"')
        if config_file is not None:
            import yaml
            config = yaml.safe_load(open(config_file, 'r'))
            try:
                sumstats = sumstats.loc[:,list(config.values())]
            except KeyError:
                raise ValueError(f'According to {config_file}, input summary statistics should have the following columns: {list(config.values())}.')
            sumstats.columns = list(config.keys())
        sumstats.SNP = sumstats.SNP.apply(shorten_id)
        sumstats.CHR = sumstats.CHR.astype(int)
        sumstats.POS = sumstats.POS.astype(int)
        return sumstats

    def regional_stats(sumstats, region):
        ss = sumstats[(sumstats.CHR == region[0]) & (sumstats.POS >= region[1]) & (sumstats.POS <= region[2])].copy()
        # assign a plain array so that Z is matched by position, not by a new 0..n-1 index
        ss['Z'] = p2z(ss.P.to_numpy(), ss.BETA.to_numpy())
        return ss

    from scipy.stats import norm
    def p2z(pval,beta,twoside=True):
        if twoside:
            pval = pval/2
        z=np.abs(norm.ppf(pval))
        ind=beta<0
        z[ind]=-z[ind]
        return z

    def plink_slice(p,region):
        bim = pd.DataFrame(dict(chrom = p.chromosome, pos = p.bp_position, a1 = p.allele_2, a0 = p.allele_1))
        fam = pd.DataFrame(dict(fid=p.fid,iid=p.iid))
        # see https://github.com/fastlmm/bed-reader
        variants = (p.chromosome==str(region[0])) & (p.bp_position >= region[1]) & (p.bp_position <= region[2])
        bim = bim[variants]
        bed = p.read(index=np.s_[:,variants]).T
        return bim, fam, bed

    def bgen_region(region,geno,dtype='float16'):
        snps,genos=[],[]
        i=0
        for t,g in geno[0].iter_variants_in_region('0'+str(region[0]) if region[0]<10 else str(region[0]),region[1],region[2]):
            snps.append([int(t.chrom),t.name,0.0,t.pos,t.a1,t.a2,i])
            genos.append(g.astype(dtype))
            i+=1
        return(pd.DataFrame(snps,columns=['chrom','snp','cm','pos','a0','a1','i']),np.array(genos))
    
    def check_unique(idx, variable):
        if idx.duplicated().any():
            raise ValueError(f"{variable} index has duplicated elements!")

    def extract_region(org_region, input_sumstats_path, input_format_config, geno_file, input_pheno_path, input_unrelated_samples, output_sumstats, output_genotype, output_pld, output_sld, output_general):
        import os
        
        # Load the file of summary statistics and standardize it.
        gwas = read_sumstat(input_sumstats_path, input_format_config)
        # Load phenotype file
        pheno = pd.read_csv(input_pheno_path, header=0, delim_whitespace=True, quotechar='"')
        # Load unrelated sample file
        unr = pd.read_csv(input_unrelated_samples, header=0, delim_whitespace=True, quotechar='"')
        
        if geno_file.endswith('.bed'):
            plink = True
            from bed_reader import open_bed
            geno = open_bed(geno_file)
        elif geno_file.endswith('.bgen'):
            plink = False
            from pybgen import PyBGEN
            bgen = PyBGEN(geno_file)
            sample_file = geno_file.replace('.bgen', '.sample')
            if not os.path.isfile(sample_file):
                if not os.path.isfile(${bgen_sample_path:r}):
                    raise ValueError(f"Cannot find the matching sample file ``{sample_file}`` for ``{geno_file}``.\nYou can specify path to sample file for all BGEN files using ``--bgen-sample-path``.")
                else:
                    sample_file = ${bgen_sample_path:r}
            bgen_fam = pd.read_csv(sample_file, header=0, delim_whitespace=True, quotechar='"',skiprows=1)
            bgen_fam.columns = ['fid','iid','missing','sex']
            geno = [bgen,bgen_fam]
        else:
            raise ValueError('Please provide the genotype files in PLINK binary format or BGEN format')

        # extraction starts here
        import gc
        import time
        t = time.localtime()
        # Extract the summary stat
        print(f'{time.strftime("%H:%M:%S", t)}: Extracting summary statistics ...')
        
        # Scan the region in windows of `scan_window` base pairs and run all checks on each window,
        # to reduce the run time and the risk of running out of memory
        
        
        region_inc = ${scan_window} # size of the incrementer we will be doing
        curr_region_lbound = org_region[1] # the current left bound for the regions
        curr_region_rbound = org_region[1] + region_inc # the current right bound for the regions
        
        # WILL NEED THESE FOR LD MATRIX AND ONWARD
        iid_ph = []
        rg_stat_SNP = []
        phenoIID = []
        batch_id = 0

        pop_ld_ind = []
        sample_ld_ind = []

        while curr_region_lbound <= org_region[2]: # since we want to increment, we want to make sure our left bound is less than max right
            # check to see if right bound works
            if curr_region_rbound < org_region[2]:
                sub_region = (org_region[0], curr_region_lbound, curr_region_rbound)
            else:
                sub_region = (org_region[0], curr_region_lbound, org_region[2])
                
            # increment for the next iteration
            curr_region_lbound += region_inc + 1
            curr_region_rbound += region_inc + 1
            
            # call and do checks on rg_stat
            rg_stat = regional_stats(gwas, sub_region) # only the current sub-region
            rg_stat.index = rg_stat.CHR.astype(str) + '_' + rg_stat.POS.astype(str) + '_' + rg_stat.REF.astype(str) + '_' + rg_stat.ALT.astype(str)
            print(f'The regional summary statistics of {sub_region[0]}_{sub_region[1]}_{sub_region[2]} has {len(rg_stat.index)} variants')
            check_unique(rg_stat.index, "Summary statistics")
            
            # geno, pheno, unr, and plink are defined prior to the while loop
            print(f'{time.strftime("%H:%M:%S", t)}: Extracting genotypes in {"plink" if plink else "bgen"} format ...')
            if plink:
                rg_bim,rg_fam,rg_bed = plink_slice(geno,sub_region)
            else:
                rg_bim,rg_bed=bgen_region(sub_region,geno,dtype='float16')
                rg_fam = geno[1]
            print(f'{time.strftime("%H:%M:%S", t)}: Checking SNP and sample IDs ...')
            # FIXME: why do we have duplicates? Let's see in practice how many duplicates are reported. I hope none.
            rg_bim.index = rg_bim.chrom.astype(str) + '_' + rg_bim.pos.astype(str) + '_' + rg_bim.a1.astype(str) + '_' + rg_bim.a0.astype(str)
            check_unique(rg_bim.index, 'SNPs in reference genotype')
            rg_fam.index = rg_fam.iid
            check_unique(rg_fam.index, 'FAM info')
            rg_bed = pd.DataFrame(rg_bed,index=rg_bim.index,columns=rg_fam.index)
            # keep only the first occurrence of each variant ID (a no-op once check_unique has passed)
            rg_bed = rg_bed.loc[~rg_bed.index.duplicated(keep='first')]
            
            print(f'The regional genotype file of {sub_region[0]}_{sub_region[1]}_{sub_region[2]} has {len(rg_bed.index)} variants')
            if not list(rg_stat.index)==list(rg_bed.index):
               # overlapping variants
                com_row_idx = rg_bed.index.intersection(rg_stat.index)
                if len(com_row_idx) == 0:
                    print("Variant IDs in the sub-region summary statistics and the reference genotype do not overlap. This sub-region is skipped.")
                    continue
                print(f'The regional genotype file ({len(rg_bed.index)} variants) and the regional summary statistics ({len(rg_stat.index)} variants) do not match with each other. The overlapping variants ({len(com_row_idx)} variants) will be selected.')
                rg_stat = rg_stat.loc[com_row_idx,:]
                rg_bed = rg_bed.loc[com_row_idx,:]
                
            temp_iid_unr = rg_fam.index.intersection(pd.Index(unr.IID)) # iid_unr

            pheno.index = pheno.IID
            check_unique(pheno.index, "Phenotype")
            temp_iid_ph = pheno.index.intersection(rg_fam.index) # iid_ph

            # mean imputation for missing genotypes, using each variant's (row's) mean across samples
            rg_bed = rg_bed.where(rg_bed.notna(), rg_bed.mean(axis=1), axis=0)

            temp_three_intersec = pd.Index(temp_iid_unr).intersection(temp_iid_ph)
                

            batch_id += 1
            rg_stat.to_pickle(f'{output_sumstats + ".batch_" + str(batch_id) + ".pickle"}')
            rg_bed.loc[:,temp_iid_ph].to_pickle(f'{output_genotype + ".batch_" + str(batch_id) + ".pickle"}')

            pop_ld = rg_bed.loc[:,temp_iid_unr]
            pop_ld_ind.extend(pop_ld.index.to_list())
            pop_ld = pop_ld.to_numpy(dtype='float32')
            np.save(f'{output_general + "pre_pop_ld.batch_" + str(batch_id) + ".npy"}', pop_ld, allow_pickle=True)

            sample_ld = rg_bed.loc[:,temp_three_intersec]
            sample_ld_ind.extend(sample_ld.index.to_list())
            sample_ld = sample_ld.to_numpy(dtype='float32')
            np.save(f'{output_general + "pre_sample_ld.batch_" + str(batch_id) + ".npy"}', sample_ld, allow_pickle=True)
    
            for each in temp_iid_ph: #order based on pheno
                iid_ph.append(each)
            for each in rg_stat.SNP:
                rg_stat_SNP.append(each)
            for each in pheno.IID:
                phenoIID.append(each)
                
            gc.collect()

        # genotypes in the sample of a specific phenotype with ordering match
        if not iid_ph == phenoIID:
            print('Warning: Some samples with phenotype do not have genotypes')    

        # merge data into CSV files
        print(f'{time.strftime("%H:%M:%S", t)}: Merging data batches ...')
        if batch_id == 0:
            raise ValueError("Region data extraction failed because the variant IDs in the region summary statistics and the reference genotype do not overlap at all.")
        rg_stat = pd.concat([pd.read_pickle(f'{output_sumstats + ".batch_" + str(b+1) + ".pickle"}') for b in range(batch_id)])
        rg_stat.to_csv(output_sumstats, sep = "\t", header = True, index = True)
        rg_bed = pd.concat([pd.read_pickle(f'{output_genotype + ".batch_" + str(b+1) + ".pickle"}') for b in range(batch_id)])
        rg_bed.to_csv(output_genotype, sep = "\t", header = True, index = True)

        import pickle
        
        ld = dict()
        
        ld["rg_bed"] = np.concatenate([np.load(f'{output_general + "pre_pop_ld.batch_" + str(b+1) + ".npy"}', allow_pickle=True) for b in range(batch_id)], axis=0)
        ld["index"] = pop_ld_ind
    
        with open(f'{output_general + ".pre_pop_ld.pickle"}', 'wb') as handle:
            pickle.dump(ld, handle)

        ld["rg_bed"] = np.concatenate([np.load(f'{output_general + "pre_sample_ld.batch_" + str(b+1) + ".npy"}', allow_pickle=True) for b in range(batch_id)], axis=0)
        ld["index"] = sample_ld_ind

        with open(f'{output_general + ".pre_sample_ld.pickle"}', 'wb') as handle:
            pickle.dump(ld, handle)


        import os
        for b in range(batch_id):
            os.remove(f'{output_sumstats + ".batch_" + str(b+1) + ".pickle"}')
            os.remove(f'{output_genotype + ".batch_" + str(b+1) + ".pickle"}')
            os.remove(f'{output_general + "pre_pop_ld.batch_" + str(b+1) + ".npy"}')
            os.remove(f'{output_general + "pre_sample_ld.batch_" + str(b+1) + ".npy"}')


    def get_ld(geno_file, output_file):
        import pickle
        with open(geno_file, 'rb') as handle:
            b = pickle.load(handle)
        

        index = b["index"]
        x = b["rg_bed"]

        # Center each variant (row of x) and compute the norm of the centered row
        xc = x - x.mean(axis=1, keepdims=True)
        norms = np.sqrt((xc * xc).sum(axis=1))

        # LD is the Pearson correlation between variants: the dot product of two
        # centered rows divided by the product of their norms
        ld_mat = (xc @ xc.T) / np.outer(norms, norms)
        np.fill_diagonal(ld_mat, 1)

        corr = pd.DataFrame(ld_mat, columns=index)
        corr.to_csv(output_file, sep = "\t", header = True, index = False, mode='w')

        import os
        os.remove(f'{geno_file}')
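
`get_ld` defines the LD between two variants as the Pearson correlation of their genotype vectors across samples. The sketch below computes that quantity for a toy matrix and compares it with `numpy.corrcoef`. The function name `ld_matrix` and the toy genotypes are made up for illustration.

```python
import numpy as np


def ld_matrix(x):
    """Pearson correlation between the rows (variants) of x."""
    xc = x - x.mean(axis=1, keepdims=True)    # center each variant
    norms = np.sqrt((xc * xc).sum(axis=1))    # norm of each centered row
    return (xc @ xc.T) / np.outer(norms, norms)


# Toy genotype dosages: 3 variants (rows) x 5 samples (columns)
geno = np.array([
    [0, 1, 2, 1, 0],
    [0, 1, 2, 2, 0],
    [2, 1, 0, 1, 2],
], dtype=float)

ld = ld_matrix(geno)
print(np.round(ld, 3))
print(np.allclose(ld, np.corrcoef(geno)))
```

A variant with no variation across samples has a zero norm, so its correlations are undefined (`nan`) in this sketch.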


## Extract data

This step runs in parallel for all loci listed in the region file (via `for_each`).
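
Within each locus, `extract_region` does not load the whole interval at once. It walks the interval in consecutive windows of `scan_window` base pairs (500000 by default) and processes each window as one batch. The sketch below reproduces that windowing; the helper name `iter_windows` is hypothetical.

```python
def iter_windows(chrom, start, end, window=500_000):
    """Yield consecutive, non-overlapping (chrom, lo, hi) windows covering [start, end]."""
    lo = start
    while lo <= end:
        hi = min(lo + window, end)
        yield (chrom, lo, hi)
        # The bounds are inclusive, so the next window starts one past `hi`
        lo += window + 1


for sub_region in iter_windows(7, 27_723_990, 28_723_990):
    print(sub_region)
```

For the example locus `7 27723990 28723990` this yields two windows, `(7, 27723990, 28223990)` and `(7, 28223991, 28723990)`.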

In [1]:
[default_2 (extract genotypes)]
depends: Py_Module('bed_reader'), Py_Module('pybgen'), f'{cwd:a}/utils.py'
input: geno_path, pheno_path, sumstats_path, unrelated_samples, for_each = 'regions'
output: sumstats = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.sumstats.gz',
        genotype = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.genotype.gz',
        pld = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.pre_pop_ld.pickle',
        sld = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.pre_sample_ld.pickle'
task: trunk_workers = 1, trunk_size = job_size, walltime = '4h', mem = '60G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
python: container=container_lmm, expand = '${ }', input = f'{cwd:a}/utils.py', stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'
    

    import os
    # output path files that we will need in our final version
    output_sumstats = ${_output['sumstats']:r}
    output_genotype = ${_output['genotype']:r}
    output_pld = ${_output['pld']:r}
    output_sld = ${_output['sld']:r}

    # this general path is used to create other temporary files that we need to calculate the ld matrices later on
    cwd = os.getcwd()
    output_general = '${cwd:a}/${_regions[0]}_${_regions[1]}_${_regions[2]}/${sumstats_path:bn}_${_regions[0]}_${_regions[1]}_${_regions[2]}'

    input_sample_path = ${bgen_sample_path:r}
    input_geno_path = ${_input[0]:r}
    input_pheno_path = ${_input[1]:r}
    input_sumstats_path = ${_input[2]:r}
    input_unrelated_samples = ${_input[3]:r}
    input_format_config = ${format_config_path:r} if ${format_config_path.is_file()} else None

    
    # Load genotype file for the region of interest
    geno_inventory = dict([x.strip().split() for x in open(${_input[0]:r}).readlines() if x.strip()])
    chrom = "${_regions[0]}"
    if chrom.startswith('chr'):
        chrom = chrom[3:]
    if chrom not in geno_inventory:
        geno_file = geno_inventory['0']
    else:
        geno_file = geno_inventory[chrom]


    if not os.path.isfile(geno_file):
        # relative path
        if not os.path.isfile('${_input[0]:ad}/' + geno_file):
            raise ValueError(f"Cannot find genotype file {geno_file}")
        else:
            geno_file = '${_input[0]:ad}/' + geno_file


    region = (int(chrom), ${_regions[1]}, ${_regions[2]})
    rg_info = extract_region(region, input_sumstats_path, input_format_config, geno_file, input_pheno_path, input_unrelated_samples,
                                output_sumstats, output_genotype, output_pld, output_sld, output_general)
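
Each batch keeps only the variants present in both the summary statistics and the genotype data, matched on an ID of the form `chr_pos_ref_alt`. The workflow does this with a pandas index intersection. The plain-Python sketch below shows the same idea on made-up records; the helper names and the variants are illustrative assumptions.

```python
def variant_id(chrom, pos, ref, alt):
    """Build the `chr_pos_ref_alt` key used to match variants."""
    return f"{chrom}_{pos}_{ref}_{alt}"


def common_variants(geno_ids, stat_ids):
    """IDs present in both lists, in genotype order."""
    stat_set = set(stat_ids)
    return [v for v in geno_ids if v in stat_set]


# Made-up variants for illustration
geno_ids = [
    variant_id(7, 100, "A", "G"),
    variant_id(7, 250, "C", "T"),
    variant_id(7, 300, "G", "A"),
]
stat_ids = [
    variant_id(7, 250, "C", "T"),
    variant_id(7, 300, "G", "A"),
    variant_id(7, 420, "T", "C"),
]

shared = common_variants(geno_ids, stat_ids)
print(shared)  # ['7_250_C_T', '7_300_G_A']
```

Because the key includes both alleles, a variant whose reference and alternative alleles are swapped between the two sources does not match in this sketch.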

In [None]:
[default_3 (compute LD sld)]
output: sld = f"{_input['sld']:nn}.sample_ld.gz"
task: trunk_workers = 1, trunk_size = job_size, walltime = '24h', mem = '64G', cores = 4, tags = f'{step_name}_{_output[0]:bn}'
python: container=container_lmm, expand = '${ }', input = f'{cwd:a}/utils.py', stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'
    get_ld(${_input["sld"]:r}, ${_output["sld"]:r})
        

In [None]:
[default_4 (compute LD pld)]
input: output_from('default_2')
output: pld = f"{_input['pld']:nn}.population_ld.gz"
task: trunk_workers = 1, trunk_size = job_size, walltime = '24h', mem = '64G', cores = 4, tags = f'{step_name}_{_output[0]:bn}'
python: container=container_lmm, expand = '${ }', input = f'{cwd:a}/utils.py', stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'
    get_ld(${_input["pld"]:r}, ${_output["pld"]:r})
        