**@author: James V. Talwar**<br>

# Generate VADEr Left and Right Shifted Patches:

**About:** This notebook generates left- and right-shifted SNP feature patches as a function of a defined radius (equal to half the patch size), enabling Shift Patch Tokenization (SPT) within VADEr. It uses both the genomic (SNP) patch mappings and the location patch mappings generated in `Generate_VADEr_Patches.ipynb`. *Shall we begin...* 




In [1]:
import pandas as pd
import os
from collections import defaultdict
import numpy as np
import joblib
import tqdm
import logging
logging.getLogger().setLevel(logging.INFO)

In [2]:
# Attach a console handler to the root logger so INFO-level messages are echoed.
logger = logging.getLogger()
# Notebook cells are often re-run; adding a fresh handler on every run would stack
# handlers on the root logger and print each message once per run. Tag ours by name
# and reuse it if it is already attached.
_CONSOLE_HANDLER_NAME = "vader_console"
console = next((h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME), None)
if console is None:
    console = logging.StreamHandler()
    console.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(console)

Define the condition path (this must match the `write_path` defined in `Generate_VADEr_Patches.ipynb`):

In [3]:
# Root directory of the patch outputs; should match `write_path` in Generate_VADEr_Patches.ipynb.
conditionPath = "../../Data/Feature_Patches"

Define base paths for the SNP and location mappings (the `{}` placeholder is filled with the patch size in kb):

In [4]:
# Preprocessed patch (window) sizes in kb. They are kept as strings because they are
# substituted verbatim into directory and file names.
snpSetSizes = [str(kb) for kb in (250, 125, 500)]
# p-value labels of the SNP sets, spelled exactly as they appear in file names
# ("5e-04" ... "5e-08"); presumably the SNP-selection thresholds upstream.
pVals = [f"5e-{exponent:02d}" for exponent in range(4, 9)]

# Directory templates: "{}" is replaced by the patch size, e.g. ".../250kb".
baseLocationMappingPath = os.path.join(conditionPath, "Patch_To_Chrom_Mapping/{}kb")
baseSNPMappingPath = os.path.join(conditionPath, "Patches_To_Features/{}kb")

Load chromosome lengths to validate patch correctness (a SNP lying beyond its patch boundary is only permitted near the end of a chromosome):

In [6]:
# Two-column, tab-separated file of chromosome name and length.
# NOTE: points at hg19/GRCh37 -- update if using a different reference build.
chromosome_length_file = "../../Data/Reference_Genome_Build_Sizes/hg19.chrom.sizes"
chromosomeLengths = pd.read_csv(chromosome_length_file, sep="\t", header=None)
chromosomeLengths.columns = ["CHR", "LENGTH"]

# Restrict to chr1-chr22 plus chrX and chrY; any other contigs in the file are dropped.
iteratables = [f"chr{number}" for number in range(1, 23)] + ["chrX", "chrY"]
chromLengthMap = {
    name: length
    for name, length in zip(chromosomeLengths.CHR, chromosomeLengths.LENGTH)
    if name in iteratables
}
chromLengthMap

{'chr1': 249250621,
 'chr2': 243199373,
 'chr3': 198022430,
 'chr4': 191154276,
 'chr5': 180915260,
 'chr6': 171115067,
 'chr7': 159138663,
 'chr8': 146364022,
 'chr9': 141213431,
 'chr10': 135534747,
 'chr11': 135006516,
 'chr12': 133851895,
 'chr13': 115169878,
 'chr14': 107349540,
 'chr15': 102531392,
 'chr16': 90354753,
 'chr17': 81195210,
 'chr18': 78077248,
 'chr19': 59128983,
 'chr20': 63025520,
 'chr21': 48129895,
 'chr22': 51304566,
 'chrX': 155270560,
 'chrY': 59373566}

Define functions to generate left and right shifted patches:

In [7]:
def ExtractOverlaps(df):
    """Pair every patch with its immediate neighbours on the same chromosome.

    Two rows are neighbours when they share a ``CHR`` value and the
    ``ChromosomePatch`` of one equals the ``Left_Overlap`` of the other.
    (The caller fills ``Left_Overlap`` with ``ChromosomePatch - 1``, so that
    other row is the next patch to the right.)

    Args:
        df: DataFrame indexed by patch identifier with ``CHR``,
            ``ChromosomePatch`` and ``Left_Overlap`` columns.

    Returns:
        (leftOverlaps, rightOverlaps), both ``defaultdict(str)``.

        - ``leftOverlaps[p]`` is p's left neighbour. Its second half is needed
          for p's left-shifted patch.
        - ``rightOverlaps[p]`` is p's right neighbour. Its first half is needed
          for p's right-shifted patch.

        Patches without a neighbour on a side are absent from that mapping.

    Raises:
        AssertionError: if more than one patch on a chromosome has the same
            ``Left_Overlap`` as another patch's ``ChromosomePatch``.
    """
    leftOverlaps = defaultdict(str)
    rightOverlaps = defaultdict(str)
    for _, chromDF in df.groupby("CHR"):
        # Index this chromosome's patches by Left_Overlap once. Each neighbour
        # lookup below is then O(1) instead of a rescan of every row.
        patchesByLeftOverlap = defaultdict(list)
        for idx, leftOverlap in zip(chromDF.index, chromDF["Left_Overlap"]):
            patchesByLeftOverlap[leftOverlap].append(idx)

        for idx, patch in zip(chromDF.index, chromDF["ChromosomePatch"]):
            # A patch can never be its own neighbour, so exclude it from the candidates.
            matches = [m for m in patchesByLeftOverlap.get(patch, ()) if m != idx]
            assert len(matches) <= 1, "Invalid mapping - only one patch pair can exist for an overlap"
            if matches:
                rightOverlaps[idx] = matches[0]
                leftOverlaps[matches[0]] = idx

    return leftOverlaps, rightOverlaps

In [8]:
def ExtractShiftedSNPs(snps, midpoint, first_half, radius, chrom_end):
    """Return the SNPs of a patch that fall in its first or second half.

    Inputs: 1) snps: list of SNP identifiers in CHR:POS:REF:ALT format
            2) midpoint: the patch's midpoint (bp)
            3) first_half: if True, keep SNPs strictly before the midpoint
               (first half of the patch); otherwise keep SNPs at or after
               the midpoint (second half)
            4) radius: the patch radius (half the patch size, bp)
            5) chrom_end: length of the SNPs' chromosome (bp). A SNP farther
               than `radius` from the midpoint is tolerated only if it lies
               within `radius` of this chromosome end.

    Returns the selected SNP identifiers in their original order.

    Raises AssertionError if an identifier contains "rs" (i.e. does not look
    like CHR:POS:REF:ALT), and ValueError if a SNP lies outside the patch
    radius and is not near the chromosome end.
    """
    selected = []
    for snp in snps:
        assert "rs" not in snp, "SNP index deviates from expected CHR:POS:REF:ALT format"
        chrom, pos, ref, alt = snp.split(":")
        position = int(pos)
        difference = midpoint - position
        # Out-of-radius SNPs are only accepted at the tail of a chromosome.
        if abs(difference) > radius and abs(chrom_end - position) > radius:
            raise ValueError(f"Issue with patch size! {chrom} {pos} {difference}")

        # The midpoint itself belongs to the second half.
        if first_half:
            keep = difference > 0
        else:
            keep = difference <= 0
        if keep:
            selected.append(snp)

    return selected

def _ChromosomeEnd(patch_mapping, patch, chrom_ends):
    """Look up the length of the chromosome that `patch` lies on."""
    chromosome = str(patch_mapping.loc[patch, "CHR"])
    # Numeric chromosome code 23 is looked up under the "chrX" key of chrom_ends.
    if chromosome == "23":
        chromosome = "X"
    return chrom_ends["chr" + chromosome]


def _PatchMidpoint(patch_mapping, patch, radius):
    """Midpoint (bp) of `patch`: radius + 2 * radius * ChromosomePatch."""
    return radius + 2*radius*patch_mapping.loc[patch, "ChromosomePatch"]


def _BuildShiftedPatch(snp_mapping, radius, patch_mapping, chrom_ends, current, neighbour, neighbour_is_left):
    """Combine the adjoining halves of `current` and `neighbour` into one shifted patch.

    When the neighbour is on the left, this takes the neighbour's second half
    followed by current's first half. Otherwise it takes current's second half
    followed by the neighbour's first half.

    Returns None when the shifted patch adds no SNPs beyond those already in
    `current`; this includes the empty case.
    """
    chrom_end = _ChromosomeEnd(patch_mapping, current, chrom_ends)
    currentMidpoint = _PatchMidpoint(patch_mapping, current, radius)
    neighbourMidpoint = _PatchMidpoint(patch_mapping, neighbour, radius)

    currentPatchSNPs = snp_mapping[current]
    neighbourPatchSNPs = snp_mapping[neighbour]

    currentHalf = ExtractShiftedSNPs(snps = currentPatchSNPs,
                                     midpoint = currentMidpoint,
                                     first_half = neighbour_is_left,
                                     radius = radius,
                                     chrom_end = chrom_end)
    neighbourHalf = ExtractShiftedSNPs(snps = neighbourPatchSNPs,
                                       midpoint = neighbourMidpoint,
                                       first_half = not neighbour_is_left,
                                       radius = radius,
                                       chrom_end = chrom_end)

    # Keep genomic order: the left-most half comes first.
    shifted = neighbourHalf + currentHalf if neighbour_is_left else currentHalf + neighbourHalf

    # Skip patches that would be empty or contain only the current patch's own SNPs.
    if not set(shifted).difference(currentPatchSNPs):
        return None
    return shifted


def GenerateShiftedPatches(snp_mapping, radius, patch_mapping, left_overlaps, right_overlaps, chrom_ends):
    """Build left- and right-shifted patches from pairs of neighbouring patches.

    - The left-shifted patch of patch k is the second half of k's left
      neighbour followed by the first half of k.
    - The right-shifted patch of k is the second half of k followed by the
      first half of k's right neighbour.

    Args:
        snp_mapping: patch id -> list of SNP identifiers (CHR:POS:REF:ALT).
        radius: patch radius (half the patch size, bp).
        patch_mapping: DataFrame indexed by patch id with CHR and
            ChromosomePatch columns.
        left_overlaps: patch id -> left-neighbour patch id.
        right_overlaps: patch id -> right-neighbour patch id.
        chrom_ends: "chr<name>" -> chromosome length. Numeric chromosome 23
            is looked up as "chrX".

    Returns:
        (leftShiftedPatches, rightShiftedPatches): defaultdict(list) objects
        keyed by patch id. A patch is omitted when its shifted patch would be
        empty or would only contain SNPs already in the patch itself.
    """
    leftShiftedPatches = defaultdict(list)
    rightShiftedPatches = defaultdict(list)

    for current, neighbour in left_overlaps.items():
        shifted = _BuildShiftedPatch(snp_mapping, radius, patch_mapping, chrom_ends,
                                     current, neighbour, neighbour_is_left = True)
        if shifted is not None:
            leftShiftedPatches[current] = shifted

    for current, neighbour in right_overlaps.items():
        shifted = _BuildShiftedPatch(snp_mapping, radius, patch_mapping, chrom_ends,
                                     current, neighbour, neighbour_is_left = False)
        if shifted is not None:
            rightShiftedPatches[current] = shifted

    return leftShiftedPatches, rightShiftedPatches

Generate left- and right-shifted patches for each SNP set and save them to the patch directory:

In [9]:
for windowSize in snpSetSizes:
    # Patch radius in bp (half the window size). Integer arithmetic keeps patch
    # midpoints integral, as ExtractShiftedSNPs expects an integer radius.
    radius = int(windowSize) * 1000 // 2

    # Each base path carries one "{}" placeholder for the window size. Fill it
    # first and build the file names separately. Formatting the joined template
    # (three "{}" fields) with only (p_val, windowSize) fails and would also
    # misplace the arguments.
    snpMappingDir = baseSNPMappingPath.format(windowSize)
    locationMappingDir = baseLocationMappingPath.format(windowSize)

    for p_val in tqdm.tqdm(pVals):
        filePrefix = f"{p_val}_{windowSize}kb"

        # Load the mapping of patch -> SNPs.
        snpMapping = joblib.load(os.path.join(snpMappingDir, f"{filePrefix}_DistanceClumps.joblib"))

        # Load the positional patch mapping and mark each patch's neighbour indices.
        patchMapping = pd.read_csv(os.path.join(locationMappingDir, f"{filePrefix}_PositionalMaps.tsv"), sep = "\t", index_col = 0)
        patchMapping["Left_Overlap"] = patchMapping["ChromosomePatch"] - 1
        patchMapping["Right_Overlap"] = patchMapping["ChromosomePatch"] + 1

        # For this p-value / window-size set, map consecutive (overlapping) left and right patches.
        left, right = ExtractOverlaps(patchMapping)

        # Generate the shifted patches.
        l, r = GenerateShiftedPatches(snp_mapping = snpMapping,
                                      radius = radius,
                                      patch_mapping = patchMapping,
                                      left_overlaps = left,
                                      right_overlaps = right,
                                      chrom_ends = chromLengthMap)

        # Save the shifted patches next to the SNP mapping they were derived from.
        joblib.dump(l, os.path.join(snpMappingDir, f"{filePrefix}_Left_Shifted_Patches.joblib"))
        joblib.dump(r, os.path.join(snpMappingDir, f"{filePrefix}_Right_Shifted_Patches.joblib"))

100%|██████████| 5/5 [00:03<00:00,  1.41it/s]
100%|██████████| 5/5 [00:04<00:00,  1.15it/s]
100%|██████████| 5/5 [00:02<00:00,  1.74it/s]
