# Packages and functions

In [1]:
data_dir = '/pollard/data/projects/kgjoni/'

import sys
sys.path.insert(0, f"{data_dir}Akita/akita_imaging/")
import pred_utils_1mb_KG as pred

sys.path.insert(0, f'{data_dir}Akita_variant_scoring/scripts/')
import OLD_utils as utils
import numpy as np
import pandas as pd
import pysam
import os

half_patch_size = 2**19
MB = 2*half_patch_size
pixel_size = 2048
bins = 448



Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
sequence (InputLayer)           [(None, 1048576, 4)] 0                                            
__________________________________________________________________________________________________
stochastic_reverse_complement ( ((None, 1048576, 4), 0           sequence[0][0]                   
__________________________________________________________________________________________________
stochastic_shift (StochasticShi (None, 1048576, 4)   0           stochastic_reverse_complement[0][
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 1048576, 4)   0           stochastic_shift[0][0]           
_______________________________________________________________________________________

Model: "functional_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
sequence (InputLayer)           [(None, 1048576, 4)] 0                                            
__________________________________________________________________________________________________
stochastic_reverse_complement_1 ((None, 1048576, 4), 0           sequence[0][0]                   
__________________________________________________________________________________________________
stochastic_shift_1 (StochasticS (None, 1048576, 4)   0           stochastic_reverse_complement_1[0
__________________________________________________________________________________________________
re_lu_42 (ReLU)                 (None, 1048576, 4)   0           stochastic_shift_1[0][0]         
_______________________________________________________________________________________

In [2]:

import io

def read_vcf(path):
    
    '''
    Read  vcf files into dataframe.
    Source: https://gist.github.com/dceoy/99d976a2c01e7f0ba1c813778f9db744.
    
    '''
    
    with open(path, 'r') as f:
        lines = [l for l in f if not l.startswith('##')]
        
    return pd.read_csv(
        io.StringIO(''.join(lines)),
        dtype={'#CHROM': str, 'POS': int, 'ID': str, 'REF': str, 'ALT': str,
               'QUAL': str, 'FILTER': str, 'INFO': str},
        sep='\t'
    ).rename(columns={'#CHROM': 'CHROM'})




def read_vcf_gz(path):
    
    '''
    Read  gzipped vcf files into dataframe.
    Adapted from: https://gist.github.com/dceoy/99d976a2c01e7f0ba1c813778f9db744.
    
    '''
    
    with io.TextIOWrapper(gzip.open(path,'r')) as f:
        lines =[l for l in f if not l.startswith('##')]

    return pd.read_csv(
            io.StringIO(''.join(lines)),
            dtype={'#CHROM': str, 'POS': int, 'ID': str, 'REF': str, 'ALT': str,
                   'QUAL': str, 'FILTER': str, 'INFO': str},
            sep='\t'
        ).rename(columns={'#CHROM':'CHROM'})



def read_input(in_file):

    '''
    Read and reformat variant dataset. Accepted formats are .vcf .vcf.gz from 4.1 version, 
    .bed file with the following columns: [CHROM, POS, REF, ALT, END, SVTYPE, SVLEN], 
    and .tsv from ANNOVAR annotSV.
    
    '''
    
    
    if 'vcf' in in_file:
        
        # For gzipped files
        if in_file.endswith('.gz'):
            variants = read_vcf_gz(in_file)
        else:
            variants = read_vcf(in_file)
            
        # Read SVs
        if any(['SVTYPE' in x for x in variants.INFO]):
            
            variants['END'] = variants.INFO.str.split('END=').str[1].str.split(';').str[0] # this SVLEN (END-POS) would be 0 for SNPs
            variants.loc[~pd.isnull(variants.END), 'END'] = variants.loc[~pd.isnull(variants.END), 'END'].astype('int')
            variants['SVTYPE'] = variants.INFO.str.split('SVTYPE=').str[1].str.split(';').str[0]
            variants['SVLEN'] = variants.INFO.str.split('SVLEN=').str[1].str.split(';').str[0]
            variants = variants[['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']]

        # Read simple variants 
        else:
            
            variants = variants[['CHROM', 'POS', 'REF', 'ALT']]       
            
 
        
    elif 'bed' in in_file:
        
        colnames = ['CHROM', 'POS', 'REF', 'ALT', 'END', 'SVTYPE', 'SVLEN']
        ncols = len(pd.read_csv(in_file, sep = '\t', nrows = 0, low_memory=False).columns)

        variants = pd.read_csv(in_file, sep = '\t', names = colnames[:ncols], low_memory=False)
        
        
    elif 'tsv' in in_file:
        
        with (gzip.open if in_file.endswith(".gz") else open)(in_file, "rt", encoding="utf-8") as variants:
            variants = (pd.read_csv(in_file, sep = '\t', low_memory=False)
                        .rename(columns = {'SV_chrom':'CHROM', 
                                           'SV_start':'POS',
                                           'SV_end':'END', 
                                           'SV_type':'SVTYPE',
                                           'SV_length':'SVLEN'})
                       [['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']])
            variants['CHROM'] = ['chr' + str(x) for x in variants['CHROM']]
            variants.loc[~pd.isnull(variants.END), 'END'] = variants.loc[~pd.isnull(variants.END), 'END'].astype('int')

            
    else:
        raise ValueError('Input file type not accepted. Make sure it has the right extension.')
        
        
    variants.reset_index(inplace = True, drop = True)
    
    
    return variants


In [3]:
# Get inputs from all arguments
file_format = 'vcf'
var_type = 'SV' 
fasta_path = '/pollard/data/vertebrate_genomes/human/hg38/hg38/hg38.fa'
chrom_lengths_path = 'data/chrom_lengths_hg38'
centromere_coords_path = 'data/centromere_coords_hg38'
scores_to_use = ['mse', 'corr']
revcomp = False
no_revcomp = True
svlen_limit = 700000
seq_len = 1048576
get_seq = True
get_scores = True
shift_by = [0]
out_file = 'score_var_results'
out_dir = 'temp'
augment = False
get_maps = False
get_tracks = False
var_set_size = 1000

sys.path.insert(0, '../scripts')


# Get example input file for each scenario
input_file_types_dir = '/pollard/data/projects/kgjoni/CBTN_collab/input_file_types/'

if file_format == 'df':
    if var_type == 'simple':
        in_file = f'{input_file_types_dir}d2acb167-24d7-4747-a7f8-de98205ad91a.consensus_somatic.norm.annot.public.maf'
    elif var_type == 'SV':
#         in_file = f'{input_file_types_dir}test.tsv'
        in_file = f'{input_file_types_dir}909d602a-480c-4d70-ac31-05192f76eb08.manta.PASS.annotated.tsv.gz'
elif file_format == 'vcf':
    if var_type == 'simple':
        in_file = f'{input_file_types_dir}e024fa3c-f459-44f5-aa0b-fc08eb5aea07.consensus_somatic.public.vcf'
    elif var_type == 'SV':
        in_file = f'{input_file_types_dir}30e05b01-b2be-477c-93eb-8e3625acb55a.somaticSV.vcf.gz'
         
# in_file = 'test/input/subset.somaticSV.vcf'        
in_file = 'test/input/run_test_SV.bed'




# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Adjust inputs from arguments


# Handle argument dependencies

if seq_len != 1048576:
    get_scores = False
    if svlen_limit > 0.66*seq_len:
        svlen_limit = 0.66*seq_len
        
if not revcomp and not no_revcomp:
    raise ValueError('Either revcomp and/or no_revcomp must be True.')
if not get_seq and not get_scores:
    raise ValueError('Either get_seq and/or get_scores must be True.')
    

# Adjust shift input: Remove shifts that are outside of allowed range
max_shift = 0.4*seq_len
shift_by = [x for x in shift_by if x > -max_shift and x < max_shift]


# Adjust input for taking the reverse compliment
revcomp_decision = []

if no_revcomp:
    revcomp_decision.append(False)
if revcomp:
    revcomp_decision.append(True)
revcomp_decision_i = revcomp_decision

    
# Adjust input for taking the average score from augmented sequences
if augment:
    shift_by = [-1,0,1]
    revcomp_decision = [True, False]
#     scores_to_use = [x for x in scores_to_use if x in ['mse', 'corr']]


# Create dictionaries to save sequences, maps, and disruption score tracks, if specified
if get_seq:
    sequences = {}
    
if get_maps:
    variant_maps = {}
    if get_scores == False:
        get_scores = True
        print('Must get scores to get maps: get_scores was set to True.')
    
if get_tracks:
    variant_tracks = {}
    if get_scores == False:
        get_scores = True
        print('Must get scores to get maps: get_scores was set to True.')


    






# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Get necessary files if they are not there

import os
from pathlib import Path

chrom_lengths_path = 'data/chrom_lengths_hg38'
if not Path(chrom_lengths_path).is_file():
    os.system('wget -P ./data/ https://raw.githubusercontent.com/ketringjoni/Akita_variant_scoring/main/data/chrom_lengths_hg38')
    print('Chromosome lengths file downloaded as data/chrom_lengths_hg38.')

centromere_coords_path = 'data/centromere_coords_hg38'
if not Path(centromere_coords_path).is_file():
    os.system('wget -P ./data/ https://raw.githubusercontent.com/ketringjoni/Akita_variant_scoring/main/data/centromere_coords_hg38')
    print('Centromere coordinates file downloaded as data/centromere_coords_hg38.')

if fasta_path == 'data/hg38.fa' and not Path(fasta_path).is_file():
    os.system('wget -P ./data/ https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz')
    os.system('gunzip data/hg38.fa.gz')
    print('Fasta file downloaded as data/hg38.fa.')

    
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
out_file = os.path.join(out_dir, out_file)
 
    
    

    
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Read in (and adjust) data 
    

    

import pandas as pd
import pysam

chrom_lengths = pd.read_table(chrom_lengths_path, header = None, names = ['CHROM', 'chrom_max'])
centromere_coords = pd.read_table(centromere_coords_path, sep = '\t')
fasta_open = pysam.Fastafile(fasta_path)


# Assign necessary values to variables across module

# Module 1: reading utilities
import reading_utils
reading_utils.var_set_size = var_set_size


# Module 2: get_seq utilities
import get_seq_utils
get_seq_utils.fasta_open = fasta_open
get_seq_utils.chrom_lengths = chrom_lengths
get_seq_utils.centromere_coords = centromere_coords

get_seq_utils.svlen_limit = svlen_limit
get_seq_utils.seq_length = seq_len
get_seq_utils.half_patch_size = round(seq_len/2)


# Module 2: get_scores utilities
if get_scores:
    import get_scores_utils
    get_scores_utils.chrom_lengths = chrom_lengths
    get_scores_utils.centromere_coords = centromere_coords

    

    
    
    




# Another way to do the filtering out:
# filtered_out = [
#     # Exclude mitochondrial variants
#     list(variants[variants.CHROM == 'chrM'].var_index),
#     # Exclude variants larger than limit
#     [y for x,y in zip(variants.SVLEN, variants.var_index) if not pd.isnull(x) and abs(int(x)) > svlen_limit]]


# add this to print in for loop for testing: str(i+1)+'/'+str(len(variants))



# if yoou want to make gzip an argument:

# parser.add_argument('--gz_input',
#                     help = 'Use a gzipped input file', 
#                     action = 'store_const',
#                     const = True, 
#                     default = 'No',
#                     required = False)






# Variants that should bring up each error:
# centromeric variant: chr1:122845747-123875879
# N composition > 5%: chr2:89274001-89283500
# Variant larger than prediction window: chr1:30651501-35842000
# Insertions or other types that are not supported

# Other
# Less than prediction window but too large to get score (nan): chr1:122845747-123875879
# BND variants: 7: t[p[; 10: [p[t; 11: t]p]; 9: ]p]t
# Different variant types:
    # chrom_start chr2:400000-450000
    # chrom_centro_left chr10:37800000-37900000
    # chrom_centro_right  chr14:17600000-17700000
    # chrom_end chr1:248746420-248766420
    
    
# chr1	8144964	8145340	CTTT	C	DEL	-3
# chr2	54802149	NaN	C	[chr19:3989901[C	BND	NaN
# chr10	132139208	132139430	G	<DUP:TANDEM>	DUP	222
# chr4	186172375	186177065	C	<DEL>	DEL	-4690
# chr7	66286935	66286945	T	<INS>	INS	NaN




Model: "functional_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
sequence (InputLayer)           [(None, 1048576, 4)] 0                                            
__________________________________________________________________________________________________
stochastic_reverse_complement_2 ((None, 1048576, 4), 0           sequence[0][0]                   
__________________________________________________________________________________________________
stochastic_shift_2 (StochasticS (None, 1048576, 4)   0           stochastic_reverse_complement_2[0
__________________________________________________________________________________________________
re_lu_84 (ReLU)                 (None, 1048576, 4)   0           stochastic_shift_2[0][0]         
______________________________________________________________________________________

# Scoring pipeline

In [305]:
i = 42
shift = 0
rev_comp = False

variant = variants.iloc[i]

CHR = variant.CHROM
POS = variant.POS
REF = variant.REF
ALT = variant.ALT

if 'END' in variants.columns:
    END = variant.END
    SVTYPE = variant.SVTYPE
else:
    END = np.nan
    SVTYPE = np.nan
    
sequences_i = get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)

In [None]:
i = 547

get_seq = False
get_scores = True

augment = True
shift_by = [0]
revcomp_decision = [False]
scores_to_use = ['corr', 'mse']

get_tracks = False
get_maps = False

sequences = {}
variant_maps = {}
variant_tracks = {}
variant_scores = pd.DataFrame({'var_index':variants.index})


for i in variants.index:

    variant = variants.iloc[i]

    var_index = variant.var_index
    CHR = variant.CHROM
    POS = variant.POS
    REF = variant.REF
    ALT = variant.ALT


    if 'SVTYPE' in variants.columns:
        END = variant.END
        SVTYPE = variant.SVTYPE
        SVLEN = variant.SVLEN
    else:
        END = np.nan
        SVTYPE = np.nan
        SVLEN = 0

    for shift in shift_by:

        if augment: # if getting augmented score, take reverse complement only with 0 shift
            if shift != 0 & True in revcomp_decision:
                revcomp_decision_i = [False]
            else:
                revcomp_decision_i = revcomp_decision

        for revcomp in revcomp_decision_i:

            try:

                if revcomp:
                    revcomp_annot = '_revcomp'
                else:
                    revcomp_annot = ''

                sequences_i = get_seq_utils.get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)


                if get_seq:

                    # Get relative position of variant in sequence
                    var_rel_pos = str(sequences_i[-1]).replace(', ', '_')

                    for ii in range(len(sequences_i[:-1][:3])): 
                        sequences[f'{var_index}_{shift}{revcomp_annot}_{ii}_{var_rel_pos}'] = sequences_i[:-1][ii]

                if get_scores:

                    scores = get_scores_utils.get_scores(POS, SVTYPE, SVLEN, 
                                                         sequences_i, scores_to_use, 
                                                         shift, revcomp, 
                                                         get_tracks, get_maps)

                    if get_tracks:
                        for track in [x for x in scores.keys() if 'track' in x]:
                            variant_tracks[f'{var_index}_{track}_{shift}{revcomp_annot}'] = scores[track]
                            del scores[track]

                    if get_maps:
                        variant_maps[f'{var_index}_{shift}{revcomp_annot}'] = scores['maps']
                        del scores['maps']

                    for score in scores:
                        variant_scores.loc[variant_scores.var_index == var_index, 
                                           f'{score}_{shift}{revcomp_annot}'] = scores[score]


                print(str(var_index) + ' (' + str(shift) + f' shift{revcomp_annot})')

            except Exception as e: 

                print(str(var_index) + ' (' + str(shift) + f' shift{revcomp_annot})' + ': Error:', e)

                pass




In [44]:
# Check output

out_file = 'temp/score_var_results'

pd.read_csv(f'{out_file}_scores', sep = '\t')

# pd.read_csv(f'{out_file}_filtered_out', sep = '\t')

# np.load(f'{out_file}_tracks.npy', allow_pickle="TRUE").item()

# np.load(f'{out_file}_maps.npy', allow_pickle="TRUE").item()

# fasta_open2 = pysam.Fastafile(f'{out_file}_sequences.fa')
# fasta_open2.references
# fasta_open2.fetch(fasta_open2.references[0], 0, 100).upper()

Unnamed: 0,var_index,mse_0,corr_0
0,0,0.016132,0.973913
1,1,0.007965,0.984667
2,2,0.005242,0.995228
3,3,0.009628,0.958351
4,8,0.010467,0.976499
5,9,0.000834,0.979311
6,10,0.001901,0.950111
7,11,0.000189,0.968983
8,12,0.003523,0.94152
9,13,,


# Tests

## Edge cases

In [93]:
from itertools import product

In [95]:
# Generate test set for simple variants

# Features to test:
SVTYPE2 = ['SNP', 'del', 'ins']
var_pos_coord = {'chrom_mid':['chr1', 244397848],
                  'chrom_start':['chr2', 400000],
                  'chrom_centro_left':['chr10', 37900000], #37900000
                  'chrom_centro_right':['chr3', 94400000], #14: 17600000
                  'chrom_end':['chr1', 248766420],
                  'centromere':['chr11', 52000000]}

# Create variants dataframe
run_test_simple = pd.DataFrame(tuple(product(var_pos_coord.items(), SVTYPE2)), 
                         columns = ['var_pos_coord', 'SVTYPE2'])

# Split up var_pos_coord
run_test_simple[['var_position', 'coord']] = pd.DataFrame(run_test_simple['var_pos_coord'].tolist(), 
                                                          index=run_test_simple.index)
run_test_simple[['CHROM', 'POS']] = pd.DataFrame(run_test_simple['coord'].tolist(), index=run_test_simple.index)
run_test_simple.drop(['var_pos_coord', 'coord'], axis = 1, inplace = True)


# Get REF alleles and remaining ALT alleles
nt = ['A', 'T', 'C', 'G']
for i in run_test_simple.index:
    
    variant = run_test_simple.loc[i]

    CHROM = variant.CHROM
    POS = variant.POS
    SVTYPE2 = variant.SVTYPE2
    
    if SVTYPE2 == 'del':
        REF = fasta_open.fetch(CHROM, POS - 1, POS - 1 + 50).upper()
        run_test_simple.loc[i,'REF'] = REF
        run_test_simple.loc[i,'ALT'] = REF[0]
    elif SVTYPE2 == 'ins':
        ALT = fasta_open.fetch(CHROM, POS - 1, POS - 1 + 50).upper()
        run_test_simple.loc[i,'ALT'] = ALT
        run_test_simple.loc[i,'REF'] = ALT[0]
    elif SVTYPE2 == 'SNP':
        REF = fasta_open.fetch(CHROM, POS - 1, POS).upper()
        run_test_simple.loc[i,'REF'] = REF
        run_test_simple.loc[i,'ALT'] = [x for x in nt if x != REF][0]
        
run_test_simple[['CHROM', 'POS', 'POS', 'REF', 'ALT']].to_csv('test_data/test_set_edge_cases/test_set_edge_simple.bed', 
                                                              sep = '\t', index = False, header = False)

In [96]:
# Generate test set for SVs

# Features to test:
SVTYPE2 = ['del', 'DEL', 'DUP', 'INV']
SVLEN = [500,5000]
var_pos_coord = {'chrom_mid':['chr1', 244397848],
                  'chrom_start':['chr2', 400000],
                  'chrom_centro_left':['chr10', 37600000],
                  'chrom_centro_right':['chr3', 94400000],
                  'chrom_end':['chr1', 248766420],
                  'centromere':['chr11', 52000000]}

# Create variants dataframe
run_test_SV = pd.DataFrame(tuple(product(var_pos_coord.items(), SVLEN, SVTYPE2)), 
                         columns = ['var_pos_coord', 'SVLEN', 'SVTYPE2'])

# Split up var_pos_coord
run_test_SV[['var_position', 'coord']] = pd.DataFrame(run_test_SV['var_pos_coord'].tolist(), index=run_test_SV.index)
run_test_SV[['CHROM', 'POS']] = pd.DataFrame(run_test_SV['coord'].tolist(), index=run_test_SV.index)
run_test_SV.drop(['var_pos_coord', 'coord'], axis = 1, inplace = True)

# Adust SVTYPE and ALT alleles
run_test_SV[['SVTYPE']] = run_test_SV[['SVTYPE2']]
run_test_SV.loc[run_test_SV.SVTYPE2 == 'del', 'SVTYPE'] = 'DEL'
run_test_SV.loc[run_test_SV.SVTYPE2 == 'DUP', 'ALT'] = '<DUP:TANDEM>'
run_test_SV.loc[run_test_SV.SVTYPE2 == 'INV', 'ALT'] = '<INV>'
run_test_SV.loc[run_test_SV.SVTYPE2 == 'DEL', 'ALT'] = '<DEL>'

# Remove del alleles that are 5000bp long
run_test_SV['END'] = run_test_SV.POS+ run_test_SV.SVLEN
run_test_SV = run_test_SV[~((run_test_SV.SVLEN == 5000) & (run_test_SV.SVTYPE2 == 'del'))]


# Get REF alleles and remaining ALT alleles
run_test_SV.reset_index(inplace = True, drop = True)

for i in run_test_SV.index:
    
    variant = run_test_SV.loc[i]

    CHROM = variant.CHROM
    POS = variant.POS
    SVLEN = variant.SVLEN
    SVTYPE2 = variant.SVTYPE2
    
    if SVTYPE2 == 'del':
        REF = fasta_open.fetch(CHROM, POS - 1, POS - 1 + SVLEN).upper()
        run_test_SV.loc[i,'REF'] = REF
        run_test_SV.loc[i,'ALT'] = REF[0]
    else:
        run_test_SV.loc[i,'REF'] = fasta_open.fetch(CHROM, POS - 1, POS).upper()
        

In [97]:
# Generate test set for BNDs

# Features to test:

# Create all possible pairs of var_positions with coordinates
var_pos_coord_1 = {'chrom_mid':['chr1', 244397848],
                  'chrom_start':['chr2', 400000],
                  'chrom_centro_left':['chr10', 37600000], 
                  'chrom_centro_right':['chr11', 56200000],
                  'chrom_end':['chr1', 248766420],
                  'centromere':['chr11', 52000000]}
var_pos_coord_2 = {'chrom_mid':['chr6', 73541678],
                  'chrom_start':['chr10', 400000],
                  'chrom_centro_left':['chr5', 45700000],
                  'chrom_centro_right':['chr19', 28400000],
                  'chrom_end':['chr20', 64044167],
                  'centromere':['chr15', 17500000]}
var_pos_coord_pairs = list(product(var_pos_coord_1.items(), var_pos_coord_2.items()))

ALT_type = ['t[p[', 't]p]', ']p]t', '[p[t']
ALT_allele = ['SNP', 'ins']

# Create variants dataframe
run_test_BND = pd.DataFrame(tuple(product(var_pos_coord_pairs, ALT_type, ALT_allele)), 
                         columns = ['var_pos_coord', 'ALT_type', 'ALT_allele'])

# Add remaining columns
run_test_BND['END'] = np.nan
run_test_BND['SVLEN'] = np.nan
run_test_BND['SVTYPE'] = 'BND'

# Split up var_pos_coord
run_test_BND[['var_pos_coord1', 'var_pos_coord2']] = pd.DataFrame(run_test_BND['var_pos_coord'].tolist(), index=run_test_BND.index)
run_test_BND[['var_position1', 'coord1']] = pd.DataFrame(run_test_BND['var_pos_coord1'].tolist(), index=run_test_BND.index)
run_test_BND[['var_position2', 'coord2']] = pd.DataFrame(run_test_BND['var_pos_coord2'].tolist(), index=run_test_BND.index)
run_test_BND[['CHROM', 'POS']] = pd.DataFrame(run_test_BND['coord1'].tolist(), index=run_test_BND.index)
run_test_BND[['CHR2', 'POS2']] = pd.DataFrame(run_test_BND['coord2'].tolist(), index=run_test_BND.index)
run_test_BND.drop(['var_pos_coord', 'var_pos_coord1', 'var_pos_coord2', 'coord1', 'coord2'], axis = 1, inplace = True)

# Get REF alleles and remaining ALT alleles
for i in run_test_BND.index:
    
    variant = run_test_BND.loc[i]

    CHROM = variant.CHROM
    POS = variant.POS
    CHR2 = variant.CHR2
    POS2 = variant.POS2
    ALT_type = variant.ALT_type
    ALT_allele = variant.ALT_allele
    
    REF = fasta_open.fetch(CHROM, POS-1, POS).upper()
    run_test_BND.loc[i,'REF'] = REF
        
    if ALT_allele == 'SNP':
        ALT = REF
    elif ALT_allele == 'ins':
        ALT = fasta_open.fetch(CHROM, POS-1, POS-1 + 3).upper()
        
        
    if ALT_type == 't[p[':
        ALT = f'{ALT}[{CHR2}:{POS2}['      
    elif ALT_type == 't]p]':
        ALT = f'{ALT}]{CHR2}:{POS2}]'
    elif ALT_type == ']p]t':
        ALT = f']{CHR2}:{POS2}]{ALT}'
    elif ALT_type == '[p[t':
        ALT = f'[{CHR2}:{POS2}[{ALT}' 
            
    run_test_BND.loc[i,'ALT'] = ALT




In [98]:
run_test_SV = pd.concat([run_test_SV[['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']],
                        run_test_BND[['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']]],
                        axis = 0)
run_test_SV.to_csv('test_data/test_set_edge_cases/test_set_edge_SV.bed', sep = '\t', index = False, header = False)

In [None]:
# test this out - temp code

variants = run_test_BND

run_test_BND[run_test_BND.var_position1 == 'chrom_start']

i = 48
for i in run_test_BND[run_test_BND.var_position1 == 'chrom_start'].index.drop([58, 59, 62, 63, 64, 65, 68, 69, 74, 75, 78, 79, 80, 81, 84, 85]):
    shift = 0
    revcomp = False

    variant = variants.loc[i]

    # var_index = variant.var_index
    CHR = variant.CHROM
    POS = variant.POS
    REF = variant.REF
    ALT = variant.ALT

    if 'END' in variants.columns:
        END = variant.END
        SVTYPE = variant.SVTYPE
    else:
        END = np.nan
        SVTYPE = np.nan

    sequences_i = get_seq_utils.get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)

    print(i, fasta_open.fetch(CHR, 0, MB).upper() == sequences_i[0])

### Run test

In [4]:
test_path = 'test_data/test_set_edge_cases/test_set_edge_'

run_test_simple = pd.read_csv(f'{test_path}simple.bed', sep ='\t')
run_test_SV = pd.read_csv(f'{test_path}SV.bed', sep ='\t')

len(run_test_simple), len(run_test_SV)

(17, 329)

In [5]:
variants = reading_utils.read_input(f'{test_path}SV.bed', 0)

variants['var_index'] = variants.index
variant_scores = pd.DataFrame({'var_index':variants.var_index})

In [8]:
# Test all variants

get_seq = False
get_scores = True

shift_by = [-10000,0,10000]
revcomp_decision = [True, False]
scores_to_use = ['corr', 'mse']

for i in [76]:#variants.index:

    variant = variants.iloc[i]

    var_index = variant.var_index
    CHR = variant.CHROM
    POS = variant.POS
    REF = variant.REF
    ALT = variant.ALT


    if 'SVTYPE' in variants.columns:
        END = variant.END
        SVTYPE = variant.SVTYPE
        SVLEN = variant.SVLEN
    else:
        END = np.nan
        SVTYPE = np.nan
        SVLEN = 0

    for shift in shift_by:

        if augment: # if getting augmented score, take reverse complement only with 0 shift
            if shift != 0 & True in revcomp_decision:
                revcomp_decision_i = [False]
            else:
                revcomp_decision_i = revcomp_decision

        for revcomp in revcomp_decision_i:

            try:

                if revcomp:
                    revcomp_annot = '_revcomp'
                else:
                    revcomp_annot = ''

                sequences_i = get_seq_utils.get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)


                if get_seq:

                    # Get relative position of variant in sequence
                    var_rel_pos = str(sequences_i[-1]).replace(', ', '_')

                    for ii in range(len(sequences_i[:-1][:3])): 
                        sequences[f'{var_index}_{shift}{revcomp_annot}_{ii}_{var_rel_pos}'] = sequences_i[:-1][ii]

                if get_scores:

                    scores = get_scores_utils.get_scores(POS, SVTYPE, SVLEN, 
                                                         sequences_i, scores_to_use, 
                                                         shift, revcomp, 
                                                         get_tracks, get_maps)

                    if get_tracks:
                        for track in [x for x in scores.keys() if 'track' in x]:
                            variant_tracks[f'{var_index}_{track}_{shift}{revcomp_annot}'] = scores[track]
                            del scores[track]

                    if get_maps:
                        variant_maps[f'{var_index}_{shift}{revcomp_annot}'] = scores['maps']
                        del scores['maps']

                    for score in scores:
                        variant_scores.loc[variant_scores.var_index == var_index, 
                                           f'{score}_{shift}{revcomp_annot}'] = scores[score]


                print(str(var_index) + ' (' + str(shift) + f' shift{revcomp_annot})')

            except Exception as e: 

                print(str(var_index) + ' (' + str(shift) + f' shift{revcomp_annot})' + ': Error:', e)

                pass




76 (-10000 shift)
76 (0 shift)
76 (10000 shift)


In [None]:
# 76 seq length is off?, 100, 104-107, 110, 116, 117, 120-123, 126, 127

In [6]:
# Test one variant at a time

i = 42

get_seq = False
get_scores = True

shift_by = [-10000,0,10000]
revcomp_decision = [True, False]
scores_to_use = ['corr', 'mse']

variant = variants.loc[i]

var_index = variant.var_index
CHR = variant.CHROM
POS = variant.POS
REF = variant.REF
ALT = variant.ALT

if 'END' in variants.columns:
    END = variant.END
    SVTYPE = variant.SVTYPE
else:
    END = np.nan
    SVTYPE = np.nan

for shift in shift_by:

    for revcomp in revcomp_decision:


        if revcomp:
            revcomp_annot = '_revcomp'
        else:
            revcomp_annot = ''

        sequences_i = utils.get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)

        if get_seq:

            for ii in range(len(sequences_i)): # note that for BNDs, the last entry in seq is BND_rel_pos
                sequences[f'{var_index}_{shift}{revcomp_annot}_{ii}'] = sequences_i[ii]

        if get_scores:

            scores = utils.get_scores(CHR, POS, REF, ALT, sequences_i, SVTYPE, scores_to_use, shift, revcomp)



In [None]:
# Run test sets

# python scripts/score_var.py --in test/input/run_test_simple.bed --fa /pollard/data/vertebrate_genomes/human/hg38/hg38/hg38.fa \
# --shift_by -10000 0 10000 --revcomp --get_seq --dir test/output --file run_test_simple  

# python score_var.py --in test/input/run_test_SV.bed --fa /pollard/data/vertebrate_genomes/human/hg38/hg38/hg38.fa \
# --shift_by -10000 0 10000 --revcomp --dir test/output --file run_test_SV  

# Test sequence length

from Bio.Seq import Seq
fasta_path_edited = 'test/output/run_test_simple_sequences.fa'
fasta_open_edited = pysam.Fastafile(fasta_path_edited)

len(fasta_open_edited.fetch('0_-10000_0', 0, MB+2).upper())

## Sequences

Only the test sequences below were generated in this way, the rest were generated and annotated on SnapGene

In [None]:
# Test 4

CHR = 'chr11'
POS = 56200000


# left side
fasta_open.fetch(CHR, 55800000, POS - 1).upper()

# deletion
fasta_open.fetch(CHR, POS-1, POS-1 + 100000).upper()

# right side
fasta_open.fetch(CHR, POS-1 + 100000, 55800000+MB).upper()

# ALT right side
fasta_open.fetch(CHR, 55800000+MB, 55800000+MB+100000).upper()

In [None]:
# Test 7: t[p[, chrom_start, chrom_centro_right

# you are getting one extra bp on the left, yoou need the code to shift it right by 1
CHR = 'chr2'
POS = 400000

CHR2 = 'chr19'
POS2 = 28400000 # closer


# right side
fasta_open.fetch(CHR2, POS2 - 1, 28100000 + MB).upper()

# right side REF (left end of it)
fasta_open.fetch(CHR2, 28100000, POS2 - 1).upper()



# left side
left_side = POS2 - 1 - 28100000
fasta_open.fetch(CHR, POS - left_side, POS).upper()

# left side REF (right end of it)
fasta_open.fetch(CHR, POS, POS + MB - left_side).upper()

In [None]:
# Test 8: ]p]t, chrom_centro_left, chrom_centro_left

# you are getting one extra bp on the left, yoou need the code to shift it right by 1
CHR = 'chr10'
POS = 37900000 # closer

CHR2 = 'chr5'
POS2 = 45700000


# right side
fasta_open.fetch(CHR, POS - 1, 38000000).upper()
# not including centromere start (0-based)

# right side REF (left end of it)
fasta_open.fetch(CHR, 38000000 - MB, POS - 1).upper()



# left side
fasta_open.fetch(CHR2, POS2 - (MB-distance), POS2).upper()


# left side REF (right end of it)
distance = 38000000 - (POS - 1)
fasta_open.fetch(CHR2, POS2, POS2 + distance).upper()

In [None]:
# Test 9: [p[t, chrom_centro_left, chrom_centro_left

# you are getting one extra bp on the left, yoou need the code to shift it right by 1
CHR = 'chr1'
POS = 248766420 # closer

CHR2 = 'chr10'
POS2 = 400000


# right side
fasta_open.fetch(CHR, POS - 1, 248956422).upper()
# not including centromere start (0-based)

# right side REF (left end of it)
fasta_open.fetch(CHR, 248956422 - MB, POS - 1).upper()


# left side
distance = 248956422 - (POS - 1)
str(Seq(fasta_open.fetch(CHR2, POS2 - 1, POS2 - 1 + (MB - distance)).upper()).reverse_complement())

# left side REF (right end of it)
str(Seq(fasta_open.fetch(CHR2, POS2 - 1 - distance, POS2 - 1).upper()).reverse_complement())

In [7]:
variants = pd.read_csv('test_data/test_set_sequences/test_set_sequences.txt', sep ='\t')
variants

Unnamed: 0,test_index,ID,CHROM,POS,END,REF,ALT,SVTYPE,SVLEN,reason_for_test
0,,test_0,chr1,8145000,,GAGGTCAGCACCATCCTGGTTAACAAGGTGAAGCCCCATCTCTACT...,G,,,chrom_mid
1,,test_1,chr1,8145000,8145500.0,G,<DEL>,DEL,-500.0,chrom_mid
2,11.0,test_2,chr2,400000,450000.0,C,<DEL>,DEL,-50000.0,chrom_start
3,,test_3,chr10,37800000,37900000.0,G,<DEL>,DEL,-100000.0,chrom_centro_left
4,,test_4,chr11,56200000,56300000.0,A,<DEL>,DEL,-100000.0,chrom_centro_right
5,,test_5,chr1,248746420,248766420.0,T,<DEL>,DEL,-20000.0,chrom_end
6,2.0,test_6,chr1,244397848,,A,A]chr6:73541678],BND,,chrom_mid_chrom_mid
7,72.0,test_7,chr2,400000,,C,C[chr19:28400000[,BND,,chrom_start_chrom_centro_right
8,116.0,test_8,chr10,37900000,,T,]chr5:45700000]T,BND,,chrom_centro_left_chrom_centro_left
9,206.0,test_9,chr1,248766420,,A,[chr10:400000[A,BND,,chrom_end_chrom_start


### Run test

In [7]:
n = 10
shift = 0
revcomp = False

for i in [7]:#range(0,9):

    variant = variants.iloc[i]

    var_index = variant.ID
    CHR = variant.CHROM
    POS = variant.POS
    REF = variant.REF
    ALT = variant.ALT

    if 'END' in variants.columns:
        END = variant.END
        SVTYPE = variant.SVTYPE
    else:
        END = np.nan
        SVTYPE = np.nan

    sequences_i = get_seq_utils.get_sequences_SV(CHR, POS, REF, ALT, END, SVTYPE, shift, revcomp)

    if var_index == 'test_0':
        var_index = 'test_1'

    print(i, SVTYPE)

    if SVTYPE != 'BND':

        # non_BND

        custom_REF = pd.read_csv(f'test/sequences/test{var_index[-1]}_REF.txt', header = None).loc[0].values[0]
        print(len(custom_REF) == MB, sequences_i[0]== custom_REF)

        print(sequences_i[0][:n], custom_REF[:n])


        custom_ALT = pd.read_csv(f'test/sequences/test{var_index[-1]}_ALT.txt', header = None).loc[0].values[0]
        print(len(custom_ALT) == MB, sequences_i[1] == custom_ALT)

        print(sequences_i[1][:n], custom_ALT[:n])

    elif SVTYPE == 'BND':

        # BND

        custom_REF_L = pd.read_csv(f'test/sequences/test{var_index[-1]}_REF_L.txt', header = None).loc[0].values[0]
        print(len(custom_REF_L) == MB, sequences_i[0]== custom_REF_L)

        print(sequences_i[0][:n], custom_REF_L[:n])


        custom_REF_R = pd.read_csv(f'test/sequences/test{var_index[-1]}_REF_R.txt', header = None).loc[0].values[0]
        print(len(custom_REF_R) == MB, sequences_i[1]== custom_REF_R)

        print(sequences_i[1][:n], custom_REF_R[:n])


        custom_ALT = pd.read_csv(f'test/sequences/test{var_index[-1]}_ALT.txt', header = None).loc[0].values[0]
        print(len(custom_ALT) == MB, sequences_i[2] == custom_ALT)

        print(sequences_i[2][:n], custom_ALT[:n])

7 BND
True True
TCCAAATATG TCCAAATATG
True True
TAGAACTGGC TAGAACTGGC
True True
TCCAAATATG TCCAAATATG


# Time and memory

In [None]:
# To measure time and memory- paste the following in the corresponding lines in SuPreMo.py 


# Before everything

import tracemalloc
tracemalloc.start()


# Before import sys

import time
start_time = time.time()
first_size, first_peak = tracemalloc.get_traced_memory()


# After closing log file

end_time = time.time()
print(end_time - start_time)
second_size, second_peak = tracemalloc.get_traced_memory()


# At the end of the script

third_size, third_peak = tracemalloc.get_traced_memory()
print(f"first_size={first_size}, first_peak={first_peak}")
print(f"second_size={second_size}, second_peak={second_peak}")
print(f"third_size={third_size}, third_peak={third_peak}")

# Extra code

In [None]:
# To combine tracks files from different sets (currently they accumulate in memory)

tracks = np.load(f'{out_file}_tracks_0.npy', allow_pickle="TRUE").item()

for var_set in [1]:
    tracks_n = np.load(f'{out_file}_tracks_{var_set}.npy', allow_pickle="TRUE").item()

    tracks.update(tracks_n)
    

# Old version of some code

In [None]:
# before changing var_rel_pos


def get_sequences(CHR, POS, REF, ALT, shift, revcomp: bool):
  
    '''
    Get reference and alternate sequence for prediction from REF and ALT alleles by incorporating ALT into the reference genome.
    Requires ALT allele to be a sequence and not a symbolic allele.
    Use positive sign for a right shift and negative for a left shift.
    revcomp: Take the reverse compliment of the resulting sequence.
    
    '''

    # Get reference sequence
    
    REF_len = len(REF)

    REF_half_left = math.ceil((seq_length - REF_len)/2) - shift # if the REF allele is odd, shift right
    REF_half_right = math.floor((seq_length - REF_len)/2) + shift

    
    # Annotate whether variant position with respect to chromosome arms ends
    if len(REF) <= len(ALT):
        var_position = get_variant_position(CHR, POS, REF_len, REF_half_left, REF_half_right)
  
    elif len(REF) > len(ALT):       
        ALT_len = len(ALT)
        ALT_half_left = math.ceil((seq_length - ALT_len)/2) - shift
        ALT_half_right = math.floor((seq_length - ALT_len)/2) + shift   
        var_position = get_variant_position(CHR, POS, ALT_len, ALT_half_left, ALT_half_right)
    

    # Get last coordinate of chromosome
    chrom_max = int(chrom_lengths[chrom_lengths.CHROM == CHR[3:]]['chrom_max'])
    
    # Get centromere coordinates
    centro_start = int(centromere_coords[centromere_coords.chrom == CHR]['start'])
    centro_stop = int(centromere_coords[centromere_coords.chrom == CHR]['end'])
    
    
    # Get start and end of reference sequence
    if var_position == "chrom_mid":
        REF_start = POS - REF_half_left
        REF_stop = REF_start + seq_length 
    elif var_position == "centromere":
        raise ValueError('Centromeric variant.')
    else:
        REF_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0, shift)
        REF_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, seq_length, shift)
        print("Warning: Variant not centered; too close to chromosome arm ends.")
        
        
    # Get reference sequence
    REF_seq = fasta_open.fetch(CHR, REF_start, REF_stop).upper()


    # Error if N composition is more than 5% of sequence
    if Counter(REF_seq)['N']/seq_length*100 > 5:
        raise ValueError('N composition greater than 5%.')



    # Error if reference sequence does not match given REF

    if var_position == "chrom_mid":
        var_rel_pos_REF = REF_half_left - 1 # subtract 1 bc POS is not included in variant

    elif var_position == "chrom_start": 
        var_rel_pos_REF = POS - abs(shift) - 1

    elif var_position == "chrom_centro_right": 
        var_rel_pos_REF = POS - centro_stop - abs(shift) - 1

    elif var_position in ["chrom_end", "chrom_centro_left"]: 
        var_rel_pos_REF = -(REF_stop - POS) - 1


    if REF_seq[var_rel_pos_REF : var_rel_pos_REF + REF_len] != REF:
        raise ValueError('Reference allele does not match hg38.')
            
            
            
    # Error if reference sequence is not the right length      
    if len(REF_seq) != seq_length:
        raise ValueError('Reference sequence generated is not the right length.')





    # For SNPs, MNPs, Insertions: 
    if len(REF) <= len(ALT):

        # Create alternate sequence: change REF sequence at position from REF to ALT

        ALT_seq = REF_seq

        ALT_seq = ALT_seq[:var_rel_pos_REF] + ALT + ALT_seq[var_rel_pos_REF + REF_len:]


        var_rel_pos_ALT = var_rel_pos_REF
        
        # Chop off ends of alternate sequence if it's longer 
        if len(ALT_seq) > len(REF_seq):
            to_remove = (len(ALT_seq) - len(REF_seq))/2

            if to_remove == 0.5:
                ALT_seq = ALT_seq[1:]
                var_rel_pos_ALT = var_rel_pos_REF - 1
                
            else:
                ALT_seq = ALT_seq[math.ceil(to_remove) : -math.floor(to_remove)]
                var_rel_pos_ALT = var_rel_pos_REF - math.ceil(to_remove)
                
            


    # For Deletions
    elif len(REF) > len(ALT):


        del_len = len(REF) - len(ALT)
        
        to_add_left = math.ceil(del_len/2)
        to_add_right = math.floor(del_len/2) 

        # Get start and end of reference sequence
        if var_position == "chrom_mid":
            ALT_start = REF_start - to_add_left
            ALT_stop = REF_stop + to_add_right

        elif var_position in ["chrom_start", "chrom_centro_right"]: 
            ALT_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0, shift)
            ALT_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, seq_length + del_len, shift)
            
        elif var_position in ["chrom_centro_left", "chrom_end"]: 
            ALT_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0 - del_len, shift)
            ALT_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, seq_length, shift)
        
        
        
        # Get alternate sequence
        ALT_seq = fasta_open.fetch(CHR, ALT_start, ALT_stop).upper()
        
        
        
        # Error if alternate sequence does not match REF at POS

        if var_position == "chrom_mid":
            var_rel_pos_ALT = REF_half_left + to_add_left - 1

        elif var_position == "chrom_start": 
            var_rel_pos_ALT = POS - abs(shift) - 1
            
        elif var_position == "chrom_centro_right": 
            var_rel_pos_ALT = POS - centro_stop - abs(shift) - 1
            
        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            var_rel_pos_ALT = -(REF_stop - POS) - 1
                
                
        if ALT_seq[var_rel_pos_ALT : var_rel_pos_ALT + REF_len] != REF:
            raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')


    
        # Change alternate sequence to match ALT at POS
        ALT_seq = ALT_seq[:var_rel_pos_ALT] + ALT + ALT_seq[var_rel_pos_ALT + REF_len:]

            
    if len(ALT_seq) != seq_length:
        raise ValueError('Alternate sequence generated is not the right length.')
         
            
    # Take reverse compliment of sequence
    if revcomp:
        REF_seq, ALT_seq = [str(Seq(x).reverse_complement()) for x in [REF_seq, ALT_seq]]

        
    return REF_seq, ALT_seq, [var_rel_pos_REF, var_rel_pos_ALT]



In [None]:
# Changed the following to keep the middle of the ALT maps and to have nan there instead of 0s in REF map

def assemple_BND_maps(vector_repr_L, vector_repr_R, BND_rel_pos_map, matrix_len = target_length_cropped, num_diags = 2):
    
    '''
    This applies to BND predcitions.
    
    Get predicted matrix from Akita predictions. 
    Output is a 448x448 array with the contact frequency at each 2048 bp bin corresponding to a 917,504 bp sequence (32 bins are cropped on each end from the prediction).
    
    '''
    
    z = np.zeros((matrix_len,matrix_len))

    # make df out of matrix indices and filter to get top left and bottom right parts
    indices = np.triu_indices(matrix_len, num_diags)
    indices = pd.DataFrame(np.column_stack(indices), columns = ['rows', 'cols'])

    indices_L = tuple(indices.query('cols < @BND_rel_pos_map').T.apply(np.array, axis=1))
    indices_R = tuple(indices.query('rows >= @BND_rel_pos_map').T.apply(np.array, axis=1))
    
    z[indices_L] = vector_repr_L
    z[indices_R] = vector_repr_R
    
    for i in range(-num_diags+1,num_diags):
        set_diag(z, np.nan, i)
        
    return z + z.T

def get_masked_BND_maps(matrices, rel_pos_map):

    # Get REF and ALT vectors, excluding diagonal 
    indexes_left = np.triu_indices(rel_pos_map, 2)
    indexes_right = np.triu_indices(target_length_cropped - rel_pos_map, 2)

    REF_L = get_left_BND_map(matrices[0], rel_pos_map)[indexes_left]
    REF_R = get_right_BND_map(matrices[1], rel_pos_map)[indexes_right]
    ALT_L = get_left_BND_map(matrices[2], rel_pos_map)[indexes_left]
    ALT_R = get_right_BND_map(matrices[2], rel_pos_map)[indexes_right]

    return (assemple_BND_maps(REF_L, REF_R, rel_pos_map),
            assemple_BND_maps(ALT_L, ALT_R, rel_pos_map))

In [None]:
# This code did not work well - the subplots would not always align based on the size of the gene subplot


# import itertools

bin_size = 2048
target_length_cropped = 448

map_length = bin_size * target_length_cropped


def pcolormesh_45deg(ax, mat, lines, linewidth, *args, **kwargs):
    #https://stackoverflow.com/questions/12848581/is-there-a-way-to-rotate-a-matplotlib-plot-by-45-degrees
    n = mat.shape[0]
    # create rotation/scaling matrix
    t = np.array([[1,0.5],[-1,0.5]])
    # create coordinate matrix and transform it
    A = np.dot(np.array([(i[1],i[0]) for i in itertools.product(range(n,-1,-1),range(0,n+1,1))]),t)
    # plot
    im = ax.pcolormesh(A[:,1].reshape(n+1,n+1),A[:,0].reshape(n+1,n+1),np.flipud(mat),*args, **kwargs)
    im.set_rasterized(True)
    if lines is not None:
        for line in lines:
            ax.plot([line,line + (448-line)/2], [0,448-line], color = 'gray', linewidth=linewidth, linestyle = 'dashed')
            ax.plot([line,line/2], [0,line], color = 'gray', linewidth=linewidth, linestyle = 'dashed')
    ax.set_ylim(0,n)
    ax.set_ylim(0,target_length_cropped)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    _ = ax.set_xticks([])
    _ = ax.set_yticks([])
    ax.plot([0, n/2], [0, n], 'k-',linewidth=1)
    ax.plot([n/2, n], [n, 0], 'k-',linewidth=1)
    
    ax.set_aspect(.5)
    return im


def plot_maps(mat1, mat2, lines, disruption_track, genes_in_map, size):
    
    gene_track = list(genes_in_map[['Start', 'width']].to_records(index = False))
    linestyle = 'dashed'
    linewidth = size/10*1.5
    color = 'gray'
    
    nrows = 8.5
    fig, (ax1,ax2,ax3,ax4,ax5) = plt.subplots(nrows=5,ncols=1,figsize=(size/2,size), 
                                               sharex = True,
                                               gridspec_kw={'height_ratios': [2/nrows, 2/nrows, 2/nrows, 
                                                                              1/nrows, 1.5/nrows]})

    pcolormesh_45deg(ax1, mat1, lines, linewidth, cmap= 'RdBu_r', vmax=2, vmin=-2)
    pcolormesh_45deg(ax2, mat2, lines, linewidth, cmap= 'RdBu_r', vmax=2, vmin=-2)
    pcolormesh_45deg(ax3, mat1 - mat2, lines, linewidth, cmap= 'PRGn', vmax=1, vmin=-1)
    
    ax4.plot(list(range(target_length_cropped)), disruption_track, color = 'black', linewidth = linewidth)
#     plt.ylabel(scoring_method, rotation = 90)
    plt.xlim([0,target_length_cropped])
    ax4.set_xticks([])
    ax4.set_yticks([])
    if lines is not None:
        for line in lines:
            ax4.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)

    
    ax5.broken_barh(gene_track[::3], (14.5, 0.5), facecolors='tab:blue')
    ax5.broken_barh(gene_track[1:][::3], (9.5, 0.5), facecolors='tab:blue')
    ax5.broken_barh(gene_track[2:][::3], (4.5, 0.5), facecolors='tab:blue')
    
    bar_locations = ([14.5, 9.5, 4.5]*math.ceil(len(genes_in_map)/3))[:len(genes_in_map)]
    for i in range(len(genes_in_map)):
        
        bar_location = bar_locations[i]
        
        gene = genes_in_map.loc[i,'Gene']
        location = genes_in_map.loc[i,'Start']
        
        ax5.annotate(gene, (location,bar_location), # annotate gene at the start
                     rotation = 45, 
                     ha = "right", va = "top", # set horizontal and vertical alignment
                     annotation_clip = False, # keep annotation if outside of window
                     fontsize = size/10*14) 
            
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.ylabel('Genes',rotation=90)
    plt.ylim([0,15])
    plt.xlim([0,448])
    plt.axis('off')
    
    fig.tight_layout()
    plt.show()
    

plot_maps(REF_pred, ALT_pred, lines, disruption_track, genes_in_map, 10)

In [None]:
def plot_maps_genes_tracks(REF_pred, ALT_pred, genes_in_map, lines, disruption_track, scoring_method):
    
    '''
    Plot the reference and alternate predicted contact frequency maps with lines at the beginning and end of the variant. For duplications, there will be 3 lines marking the two regions that are duplicates. 
    
    Plot genes that match the regions in the reference map.
    
    Plot disruption score tracks.
    
    '''

    gene_track = list(genes_in_map[['Start', 'width']].to_records(index = False))
    
    vmin = -2
    vmax = 2
    linestyle = 'dashed'
    linewidth = 0.3
    color = 'black'

    plt.rcParams['font.size']= 10
    plot_width = 3
    plot_width1D = 1
    fig, gs = gridspec_inches([plot_width], [plot_width, .15, plot_width, .15, 
                                             plot_width1D, .15, plot_width1D*2])


    plt.subplot(gs[0,0])
    plt.matshow(REF_pred, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.ylabel('Reference matrix',rotation=90)
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
            plt.axhline(y=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.gca().yaxis.tick_right()
    plt.xticks([])

    plt.subplot(gs[2,0])
    plt.matshow(ALT_pred, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.ylabel('Alternate matrix',rotation=90)
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
            plt.axhline(y=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.gca().yaxis.tick_right()
    plt.xticks([])
    
    plt.subplot(gs[4,0])
    plt.plot(list(range(target_length_cropped)), disruption_track)
    plt.ylabel(scoring_method, rotation = 90)
    plt.xlim([0,target_length_cropped])
    plt.gca().yaxis.tick_right()
    plt.xticks([])
    

    ax1 = fig.add_subplot(gs[6,0])
            
    ax1.broken_barh(gene_track[::3], (14.5, 0.5), facecolors='tab:blue')
    ax1.broken_barh(gene_track[1:][::3], (9.5, 0.5), facecolors='tab:blue')
    ax1.broken_barh(gene_track[2:][::3], (4.5, 0.5), facecolors='tab:blue')
    
    bar_locations = ([14.5, 9.5, 4.5]*math.ceil(len(genes_in_map)/3))[:len(genes_in_map)]
    for i in range(len(genes_in_map)):
        
        bar_location = bar_locations[i]
        
        gene = genes_in_map.loc[i,'Gene']
        location = genes_in_map.loc[i,'Start']
        
        ax1.annotate(gene, (location,bar_location), # annotate gene at the start
                     rotation = 45, 
                     ha = "right", va = "top", # set horizontal and vertical alignment
                     annotation_clip = False, # keep annotation if outside of window
                     fontsize = 8) 
            
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.ylabel('Genes',rotation=90)
    plt.ylim([0,15])
    plt.xlim([0,448])
    plt.axis('off')

    plt.show()
    
    
    
    
def plot_maps_genes(REF_pred, ALT_pred, genes_in_map, lines):

        
    '''
    Plot the reference and alternate predicted contact frequency maps with lines at the beginning and end of the variant. For duplications, there will be 3 lines marking the two regions that are duplicates. 
    
    Plot genes that match the regions in the reference map.
    
    '''
    
    
    gene_track = list(genes_in_map[['Start', 'width']].to_records(index = False))
    
    vmin = -2
    vmax = 2
    linestyle = 'dashed'
    linewidth = 0.3
    color = 'black'

    plt.rcParams['font.size']= 10
    plot_width = 3
    plot_width1D = 1
    fig, gs = gridspec_inches([plot_width], [plot_width, .15, plot_width, .15, plot_width1D])


    plt.subplot(gs[0,0])
    plt.matshow(REF_pred, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.ylabel('Reference matrix',rotation=90)
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
            plt.axhline(y=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.gca().yaxis.tick_right()
    plt.xticks([])

    plt.subplot(gs[2,0])
    plt.matshow(ALT_pred, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.ylabel('Alternate matrix',rotation=90)
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
            plt.axhline(y=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.gca().yaxis.tick_right()
    plt.xticks([])

    ax1 = fig.add_subplot(gs[4,0])
            
    ax1.broken_barh(gene_track[::3], (14.5, 0.5), facecolors='tab:blue')
    ax1.broken_barh(gene_track[1:][::3], (9.5, 0.5), facecolors='tab:blue')
    ax1.broken_barh(gene_track[2:][::3], (4.5, 0.5), facecolors='tab:blue')
    
    bar_locations = ([14.5, 9.5, 4.5]*math.ceil(len(genes_in_map)/3))[:len(genes_in_map)]
    for i in range(len(genes_in_map)):
        
        bar_location = bar_locations[i]
        
        gene = genes_in_map.loc[i,'Gene']
        location = genes_in_map.loc[i,'Start']
        
        ax1.annotate(gene, (location,bar_location), # annotate gene at the start
                     rotation = 45, 
                     ha = "right", va = "top", # set horizontal and vertical alignment
                     annotation_clip = False, # keep annotation if outside of window
                     fontsize = 8) 
            
    if lines is not None:
        for line in lines:
            plt.axvline(x=line, color=color, linestyle=linestyle, linewidth = linewidth)
    plt.ylabel('Genes',rotation=90)
    plt.ylim([0,15])
    plt.xlim([0,448])
    plt.axis('off')

    plt.show()

In [None]:
# Removed the following from get_scores_utils.py after changing masking function

get_variant_position(CHR, POS, var_len, half_left, half_right) # same as in get_seq



def get_variant_type(REF, ALT):
    
    '''
    Get the type of variant from a REF and ALT allele of sequences.
    One of 4 variant types:
    1. Deletion
    2. Insertion
    3. SNP: single nucleotide variant
    4. MNP: multiple nucleotide variant
    
    '''

    # Annotate variant as one of the 6 categories below based on REF and ALT allele
    
    if len(REF) > len(ALT):
        variant_type = "Deletion"
    elif len(REF) < len(ALT):
        variant_type = "Insertion"
        
    elif len(REF) == 1 and len(ALT) ==1:
        variant_type = "SNP"
    elif len(REF) == len(ALT) and len(REF) != 1:
        variant_type = "MNP"
    
    return variant_type










In [None]:
# Plotting function for when there might be more than 2 maps

def plot_maps(maps, genes_in_map):

    vmin = -2
    vmax = 2

    plt.rcParams['font.size']= 10
    plot_width = 3
    plot_width1D = 1
    if len(maps) == 2:
        fig, gs = gridspec_inches([plot_width], [plot_width, .15, plot_width, .15, plot_width1D])
#     if len(maps) == 3:
#         fig, gs = gridspec_inches([plot_width], [plot_width, .15, plot_width, .15, plot_width, .15, plot_width1D])


    for i in range(len(maps)):
        
        plt.subplot(gs[i*2,0])
        plt.matshow(maps[i], fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
        plt.gca().yaxis.tick_right()
        plt.xticks([])

   
    gene_track = list(genes_in_map[['Start', 'width']].to_records(index = False))

    ax1 = fig.add_subplot(gs[len(maps)*2+2,0])

    ax1.broken_barh(gene_track[::2], (9.5, 0.5), facecolors='tab:blue')
    ax1.broken_barh(gene_track[1:][::2], (4.5, 0.5), facecolors='tab:blue')

    for i in range(len(genes_in_map)):

        if (i % 2) == 0:
            bar_location = 9.5
        else:
            bar_location = 4.5

        gene = genes_in_map.loc[i,'Gene']
        location = genes_in_map.loc[i,'Start']

        ax1.annotate(gene, (location,bar_location), # annotate gene at the start
                     rotation = 45, 
                     ha = "right", va = "top", # set horizontal and vertical alignment
                     annotation_clip = False, # keep annotation if outside of window
                     fontsize = 6) 

    plt.ylabel('Genes',rotation=90)
    plt.ylim([0,10])
    plt.xlim([0,target_length_cropped])
    plt.axis('off')

    plt.show()

In [133]:
def mask_matrices(CHR, POS, REF, ALT, REF_pred, ALT_pred, shift):

    '''
    This applied to non-BND predicted matrices.
    
    Mask reference and alternate predicted matrices based on the type of variant.
    
    '''
    
    variant_type = get_variant_type(REF, ALT)
    
    # Get last coordinate of chromosome
    chrom_max = int(chrom_lengths[chrom_lengths.CHROM == CHR[3:]]['chrom_max']) 

    # Get centromere coordinate
    centro_start = int(centromere_coords[centromere_coords.chrom == CHR]['start'])
    centro_stop = int(centromere_coords[centromere_coords.chrom == CHR]['end'])

    
    # Insertions: Mask REF, add nans if necessary, and mirror nans to ALT
    if variant_type in ["Insertion", "Deletion"]:
        
        if variant_type == "Deletion":
            # this works the same exact way but the commands are swapped
            REF, ALT = ALT, REF
            REF_pred, ALT_pred = ALT_pred, REF_pred

        # Get REF allele sections
        REF_len = len(REF)

        REF_half_left = math.ceil((seq_length - REF_len)/2) - shift # if the REF allele is odd, shift right
        REF_half_right = math.floor((seq_length - REF_len)/2) + shift

        
        # Get ALT allele sections
        ALT_len = len(ALT)
        
        ALT_half_left = math.ceil((seq_length - ALT_len)/2) - shift
        ALT_half_right = math.floor((seq_length - ALT_len)/2) + shift
        
        
        # Annotate whether variant is close to beginning or end of chromosome
        var_position = get_variant_position(CHR, POS, REF_len, REF_half_left, REF_half_right)


        # Get start and end bins of REF and ALT alleles
        if var_position == "chrom_mid":
            
            var_start = get_bin(REF_half_left - 1)
            var_end = get_bin(REF_half_left - 1 + REF_len)
            
            var_start_ALT = get_bin(ALT_half_left - 1)
            var_end_ALT = get_bin(ALT_half_left - 1 + ALT_len)

        elif var_position == "chrom_start": 
            
            var_start = get_bin(POS - 1 - abs(shift))
            var_end = get_bin(POS - 1 + REF_len - abs(shift))
            
            var_start_ALT = var_start
            var_end_ALT = get_bin(POS - 1 + ALT_len - abs(shift))

        elif var_position == "chrom_centro_left": 
            
            var_start = get_bin(POS - (centro_start - seq_length - abs(shift)) - 1)
            var_end = get_bin(POS - (centro_start - seq_length - abs(shift)) - 1 + REF_len)
            
            var_start_ALT = get_bin(POS - (centro_start - seq_length - abs(shift)) - 1 - ALT_len)
            var_end_ALT = var_end

        elif var_position == "chrom_centro_right": 
            
            var_start = get_bin(POS - centro_stop - 1 - abs(shift))
            var_end = get_bin(POS - centro_stop - 1 + REF_len - abs(shift))
            
            var_start_ALT = var_start
            var_end_ALT = get_bin(POS - centro_stop - 1 + ALT_len - abs(shift))

        elif var_position == "chrom_end": 
            
            var_start = get_bin(POS - (chrom_max - seq_length - abs(shift)) - 1)
            var_end = get_bin(POS - (chrom_max - seq_length - abs(shift)) - 1 + REF_len)
            
            var_start_ALT = get_bin(POS - (chrom_max - seq_length - abs(shift)) - 1 - ALT_len)
            var_end_ALT = var_end


        # Mask REF map: make variant bin(s) nan and add empty bins at the variant if applicable
        
        REF_pred_masked = REF_pred.copy()

        REF_pred_masked[var_start:var_end + 1, :] = np.nan
        REF_pred_masked[:, var_start:var_end + 1] = np.nan
  
        
        # If the ALT allele falls on more bins than the REF allele, adjust ALT allele 
            # (add nan(s) to var and remove outside bin(s))
            # Otherwise don't mask
        
        if var_end_ALT - var_start_ALT > var_end - var_start:
            
        
            # Insert the rest of the nans corresponding to the ALT allele
            to_add = (var_end_ALT - var_start_ALT) - (var_end - var_start)

            for j in range(var_start, var_start + to_add): # range only includes the first variable 
                REF_pred_masked = np.insert(REF_pred_masked, j, np.nan, axis = 0)
                REF_pred_masked = np.insert(REF_pred_masked, j, np.nan, axis = 1)

            # Chop off the outside of the REF matrix 
            to_remove = len(REF_pred_masked) - 448

            if var_position == "chrom_mid":
                # remove less on the left bc that's where you put one less part of the variant with odd number of bp
                REF_pred_masked = REF_pred_masked[math.floor(to_remove/2) : -math.ceil(to_remove/2), 
                                                  math.floor(to_remove/2) : -math.ceil(to_remove/2)]
                
            elif var_position in ["chrom_start", "chrom_centro_right"]: 
                # Remove all from the right
                REF_pred_masked = REF_pred_masked[: -to_remove, 
                                                  : -to_remove]

            elif var_position in ["chrom_end", "chrom_centro_left"]: 
                # Remove all from the left
                REF_pred_masked = REF_pred_masked[to_remove :, 
                                                  to_remove :]

            assert len(REF_pred_masked) == 448, 'Masked reference matrix is not the right size.'
            
            
        

        # Mask ALT map: make all nans in REF_pred also nan in ALT_pred
        
        REF_pred_novalues = REF_pred_masked.copy()

        REF_pred_novalues[np.invert(np.isnan(REF_pred_novalues))] = 0

        ALT_pred_masked = ALT_pred + REF_pred_novalues

        assert len(ALT_pred_masked) == 448, 'Masked alternate matrix is not the right size.'
        
        if variant_type == "Deletion":
            # Swap back
            REF_pred_masked, ALT_pred_masked = ALT_pred_masked, REF_pred_masked
        
    
    # SNPs or MNPs: Mask REF and mirror nans to ALT
    elif variant_type in ['SNP', 'MNP']:
        
        # Get REF allele sections
        REF_len = len(REF)

        REF_half_left = math.ceil((seq_length - REF_len)/2)  - shift # if the REF allele is odd, shift right
        REF_half_right = math.floor((seq_length - REF_len)/2) + shift


        # Annotate whether variant is close to beginning or end of chromosome
        var_position = get_variant_position(CHR, POS, REF_len, REF_half_left, REF_half_right)


        # Get start and end bins of REF and ALT alleles
        if var_position == "chrom_mid":
            
            var_start = get_bin(REF_half_left - 1)
            var_end = get_bin(REF_half_left - 1 + REF_len)

        elif var_position == "chrom_start": 
            
            var_start = get_bin(POS - 1 + abs(shift))
            var_end = get_bin(POS - 1 + REF_len + abs(shift))

        elif var_position == "chrom_centro_left": 
            
            var_start = get_bin(POS - (centro_start - seq_length - abs(shift)) - 1)
            var_end = get_bin(POS - (centro_start - seq_length - abs(shift)) - 1 + REF_len)

        elif var_position == "chrom_centro_right": 
            
            var_start = get_bin(POS - centro_stop - 1 + abs(shift))
            var_end = get_bin(POS - centro_stop - 1 + REF_len + abs(shift))

        elif var_position == "chrom_end": 
            
            var_start = get_bin(POS - (chrom_max - seq_length - abs(shift)) - 1)
            var_end = get_bin(POS - (chrom_max - seq_length - abs(shift)) - 1 + REF_len)
            
            
        # Mask REF map: make variant bin(s) nan 
        
        REF_pred_masked = REF_pred.copy()

        REF_pred_masked[var_start:var_end + 1, :] = np.nan
        REF_pred_masked[:, var_start:var_end + 1] = np.nan

        
        # Mask ALT map: make all nans in REF_pred also nan in ALT_pred
        
        REF_pred_novalues = REF_pred_masked.copy()

        REF_pred_novalues[np.invert(np.isnan(REF_pred_novalues))] = 0

        ALT_pred_masked = ALT_pred + REF_pred_novalues

        assert len(ALT_pred_masked) == 448, 'Masked alternate matrix is not the right size.'
        
    
    
    return REF_pred_masked, ALT_pred_masked

In [None]:
def get_scores(CHR, POS, REF, ALT, sequences, SVTYPE, scores, shift, revcomp, get_tracks, get_maps): 
    
    '''
    Get disruption scores, disruption tracks, and/or predicted maps from variants and the sequences generated from them.
    
    '''
    
    
    var_rel_pos = sequences[-1]
    
    # Error if variant position is too close to end of prediction window
    if any([int(x) <= bin_size*32 or int(x) >= seq_length - bin_size*32 for x in var_rel_pos]):
        raise ValueError('Variant outside prediction window after cropping')
    
    
    # Make prediction
    sequences = [x for x in sequences if type(x) == str]
    matrices = [mat_from_vector(vector) for vector in [vector_from_seq(seq) for seq in sequences]]
    
    if revcomp:
        matrices = [np.flipud(np.fliplr(x)) for x in matrices]
    
                
    # Get processed matrices to score
                     
    if len(REF) > bin_size/2 or len(ALT) > bin_size/2:

        # mask matrices
        matrices = mask_matrices(CHR, POS, REF, ALT, matrices[0], matrices[1], shift)
    
    if SVTYPE == "BND":  
        
        BND_rel_pos_map = round(var_rel_pos[0]/bin_size - 32)
        
        # Get REF and ALT vectors, excluding diagonal 
        indexes_left = np.triu_indices(BND_rel_pos_map, 2)
        indexes_right = np.triu_indices(target_length_cropped - BND_rel_pos_map, 2)

        REF_L = get_left_BND_map(matrices[0], BND_rel_pos_map)[indexes_left]
        REF_R = get_right_BND_map(matrices[1], BND_rel_pos_map)[indexes_right]
        ALT_L = get_left_BND_map(matrices[2], BND_rel_pos_map)[indexes_left]
        ALT_R = get_right_BND_map(matrices[2], BND_rel_pos_map)[indexes_right]

        matrices = (assemple_BND_mats(REF_L, REF_R, BND_rel_pos_map),
                    assemple_BND_mats(ALT_L, ALT_R, BND_rel_pos_map))


    # Get disruption score and correlation for this variant
    scores_results = {}
    
    if get_maps:
        scores_results['maps'] = [matrices[0], matrices[1]]
        
        
    for score in scores:
        scores_results[score] = getattr(scoring_map_methods(matrices[0], matrices[1]), 
                                        score)()
        
        if get_tracks and score in ['corr', 'mse']:
            scores_results[f'{score}_track'] = getattr(scoring_map_methods(matrices[0], matrices[1]), 
                                                       f'{score}_track')()
            

    return scores_results

In [None]:
def adjust_seq_ends_BND(CHR, position, adjust, shift):
           
    '''
    Get start (adjust = 1) or end (adjust = MB) of sequence for prediction based on variant position \
    with respect to chromosome arm ends (defined in get_variant_position function).
    Different from adjust_seq_ends because it does not require centro_start, centro_stop, and chrom_max as input.
    
    '''
    
    if position == 'chrom_start':
        seq_pos = adjust + 1 + abs(shift) # 1 is added so the position is never 0. coordinates are 1-based

    elif position == 'chrom_centro_right':
        seq_pos = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_stop']) + adjust + abs(shift) 

    elif position == 'chrom_end':
        seq_pos = int(chrom_lengths[chrom_lengths.CHROM == CHR[3:]]['chrom_max']) - MB + adjust - abs(shift) 

    elif position == 'chrom_centro_left':
        seq_pos = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_start']) - MB + adjust - abs(shift) 
        
    return seq_pos

In [134]:
def get_sequences(CHR, POS, REF, ALT, shift, revcomp: bool):
  
    '''
    Get reference and alternate sequence for prediction from REF and ALT alleles by incorporating ALT into the reference genome.
    Requires ALT allele to be a sequence and not a symbolic allele.
    Use positive sign for a right shift and negative for a left shift.
    revcomp: Take the reverse compliment of the resulting sequence.
    
    '''

    # Get reference sequence
    
    REF_len = len(REF)

    REF_half_left = math.ceil((MB - REF_len)/2) - shift # if the REF allele is odd, shift right
    REF_half_right = math.floor((MB - REF_len)/2) + shift

    
    # Annotate whether variant position with respect to chromosome arms ends
    if len(REF) <= len(ALT):
        var_position = get_variant_position(CHR, POS, REF_len, REF_half_left, REF_half_right)
  
    elif len(REF) > len(ALT):       
        ALT_len = len(ALT)
        ALT_half_left = math.ceil((MB - ALT_len)/2) - shift
        ALT_half_right = math.floor((MB - ALT_len)/2) + shift   
        var_position = get_variant_position(CHR, POS, ALT_len, ALT_half_left, ALT_half_right)
    

    # Get last coordinate of chromosome
    chrom_max = int(chrom_lengths[chrom_lengths.CHROM == CHR[3:]]['chrom_max'])
    
    # Get centromere coordinates
    centro_start = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_start']) + 1
    centro_stop = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_stop']) + 1
    
    
    # Get start and end of reference sequence
    if var_position == "chrom_mid":
        REF_start = POS - REF_half_left
        REF_stop = REF_start + MB 
    elif var_position == "centromere":
        raise ValueError('Centromeric variant')
    else:
        REF_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0, shift)
        REF_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, MB, shift)
        print("Warning: Variant not centered; too close to chromosome arm ends.")
        
        
    # Get reference sequence
    REF_seq = fasta_open.fetch(CHR, REF_start - 1, REF_stop - 1).upper()


    # Error if N composition is more than 5% of sequence
    if Counter(REF_seq)['N']/MB*100 > 5:
        raise ValueError('N composition greater than 5%')



    # Error if reference sequence does not match given REF

    if var_position == "chrom_mid":
        var_rel_pos_REF = REF_half_left

    elif var_position == "chrom_start": 
        var_rel_pos_REF = POS - abs(shift) - 1

    elif var_position == "chrom_centro_right": 
        var_rel_pos_REF = POS - centro_stop - abs(shift)

    elif var_position in ["chrom_end", "chrom_centro_left"]: 
        var_rel_pos_REF = -(REF_stop - POS)


    if REF_seq[var_rel_pos_REF : var_rel_pos_REF + REF_len] != REF:
        raise ValueError('Reference allele does not match hg38.')
            
            
            
    # Error if reference sequence is not the right length      
    if len(REF_seq) != MB:
        raise ValueError('Reference sequence generated is not the right length.')





    # For SNPs, MNPs, Insertions: 
    if len(REF) <= len(ALT):

        # Create alternate sequence: change REF sequence at position from REF to ALT

        ALT_seq = REF_seq

        ALT_seq = ALT_seq[:var_rel_pos_REF] + ALT + ALT_seq[var_rel_pos_REF + REF_len:]


        var_rel_pos_ALT = var_rel_pos_REF
        
        # Chop off ends of alternate sequence if it's longer 
        if len(ALT_seq) > len(REF_seq):
            to_remove = (len(ALT_seq) - len(REF_seq))/2

            if to_remove == 0.5:
                ALT_seq = ALT_seq[1:]
                var_rel_pos_ALT = var_rel_pos_REF - 1
                
            else:
                ALT_seq = ALT_seq[math.ceil(to_remove) : -math.floor(to_remove)]
                var_rel_pos_ALT = var_rel_pos_REF - math.ceil(to_remove)
                
            


    # For Deletions
    elif len(REF) > len(ALT):


        del_len = len(REF) - len(ALT)
        
        to_add_left = math.ceil(del_len/2)
        to_add_right = math.floor(del_len/2) 

        # Get start and end of reference sequence
        if var_position == "chrom_mid":
            ALT_start = REF_start - to_add_left
            ALT_stop = REF_stop + to_add_right

        elif var_position in ["chrom_start", "chrom_centro_right"]: 
            ALT_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0, shift)
            ALT_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, MB + del_len, shift)
            
        elif var_position in ["chrom_centro_left", "chrom_end"]: 
            ALT_start = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, 0 - del_len, shift)
            ALT_stop = adjust_seq_ends(centro_start, centro_stop, chrom_max, var_position, MB, shift)
        
        
        
        # Get alternate sequence
        ALT_seq = fasta_open.fetch(CHR, ALT_start - 1, ALT_stop - 1).upper()
        
        
        
        # Error if alternate sequence does not match REF at POS

        if var_position == "chrom_mid":
            var_rel_pos_ALT = REF_half_left + to_add_left

        elif var_position == "chrom_start": 
            var_rel_pos_ALT = POS - abs(shift) - 1
            
        elif var_position == "chrom_centro_right": 
            var_rel_pos_ALT = POS - centro_stop - abs(shift)
            
        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            var_rel_pos_ALT = -(REF_stop - POS)
                
                
        if ALT_seq[var_rel_pos_ALT : var_rel_pos_ALT + REF_len] != REF:
            raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')


    
        # Change alternate sequence to match ALT at POS

        if var_position == "chrom_mid":
            ALT_seq = ALT_seq[:var_rel_pos_ALT - 1] + ALT + ALT_seq[var_rel_pos_ALT + REF_len - 1:] 

        elif var_position == "chrom_start": 
            ALT_seq = ALT_seq[:var_rel_pos_ALT + 1] + ALT + ALT_seq[var_rel_pos_ALT + 1 + REF_len:]
            
        elif var_position == "chrom_centro_right": 
            ALT_seq = ALT_seq[:var_rel_pos_ALT] + ALT + ALT_seq[var_rel_pos_ALT + REF_len:]
            
        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            ALT_seq = ALT_seq[:var_rel_pos_ALT] + ALT + ALT_seq[var_rel_pos_ALT + REF_len:]

            
    if len(ALT_seq) != MB:
        raise ValueError('Alternate sequence generated is not the right length.')
         
            
    # Take reverse compliment of sequence
    if revcomp:
        REF_seq, ALT_seq = [str(Seq(x).reverse_complement()) for x in [REF_seq, ALT_seq]]

        
    return REF_seq, ALT_seq, [var_rel_pos_REF, var_rel_pos_ALT]



In [None]:
# removed this completely 

def from_upper_triu(vector_repr, matrix_len, num_diags):
    
    '''
    Get a matrix from a vector representatin of the upper triangle.
    
    '''
    
    z = np.zeros((matrix_len,matrix_len))
    triu_tup = np.triu_indices(matrix_len,num_diags)
    
    z[triu_tup] = vector_repr
    
    for i in range(-num_diags+1,num_diags):
        set_diag(z, np.nan, i)
        
    return z + z.T

In [None]:
parser.add_argument('--chrom',
                    dest = 'chrom_lengths', 
                    help = 'File with lengths of chromosomes in hg38. Columns: chromosome (ex: 1), length; no header.', 
                    type = str,
                    default = 'data/chrom_lengths_hg38',
                    required = False)

parser.add_argument('--centro',
                    dest = 'centromere_coords', 
                    help = 'Centromere coordinates for hg38. Columns: chromosome (ex: chr1), start, end; no header.', 
                    type = str,
                    default = 'data/centromere_coords_hg38',
                    required = False)


parser.add_argument('--revcomp',
                    dest = 'revcomp', 
                    help = '''Make predictions with the reverse compliment of the sequence.
''', 
                    action='store_true',
                    required = False)

parser.add_argument('--no_revcomp',
                    dest = 'no_revcomp', 
                    help = '''Make predictions without taking the reverse compliment of the sequence. If specified, --revcomp must also be specified.
''', 
                    action='store_false',
                    required = False)

if not revcomp and not no_revcomp:
    raise ValueError('Either revcomp and/or no_revcomp must be True.')

In [None]:
print("Warning: Variant not centered; too close to start of chromosome.")
print("Warning: Variant not centered; too close to right end of centromere.")
print("Warning: Variant not centered; too close to end of chromosome.")
print("Warning: Variant not centered; too close to left end of centromere.")

In [None]:
def get_MSE_from_vector(vector1, vector2):

    '''
    Get MSE between two vectors.
    
    '''
    
    # A vecotr of sum of squared differences for each row of matrix 
    sub_vec = [x - y for x,y in zip(vector1, vector2)]
    
    # Get number of bins that are not nan (not including ones that overlap deletion)
    non_nan_values = np.count_nonzero(np.invert(np.isnan(sub_vec)))
    
    MSE = np.nansum([x**2 for x in sub_vec])/non_nan_values

    return MSE

In [None]:
# from read_input

elif 'maf' in in_file:

    # These are formatted like SVs

    variants = (pd.read_csv(in_file, skiprows = 1, sep = '\t')
                .rename(columns = {'Chromosome':'CHROM', 
                                   'Start_Position':'POS',
                                   'End_Position':'END',
                                   'Reference_Allele':'REF',
                                   'Tumor_Seq_Allele2':'ALT'})
               [['CHROM', 'POS', 'END', 'REF', 'ALT']])
    variants['SVLEN'] = (variants.END - variants.POS).astype('int') # this SVLEN (END-POS) would be 0 for SNPs

    # Might need to do this if there are homozygous variants where Tumor_Seq_Allele1 != Reference_Allele
    # would need to only do this for those variants and concat with the rest
    # variants = pd.melt(variants, 
    #                    id_vars = ['CHROM', 'POS', 'REF'], 
    #                    value_vars = ['ALT1', 'ALT2'], 
    #                    var_name = 'allele', 
    #                    value_name='ALT')

    for i in variants[variants.ALT == '-'].index:
        variants.loc[i,'SVTYPE'] = 'DEL'
    for i in variants[variants.REF == '-'].index:
        variants.loc[i,'SVTYPE'] = 'INS'

In [None]:
if SVTYPE == "DEL":

    # Get REF and ALT allele sequences first
    REF = fasta_open.fetch(CHR, POS - 1, END).upper()
    ALT = REF[0]


elif SVTYPE == "DUP":

    # Insert duplicated sequence before POS
    ALT = fasta_open.fetch(CHR, POS - 1, END).upper()
    REF = ALT[0]


elif SVTYPE == "INV":

    REF = fasta_open.fetch(CHR, POS - 1, END).upper()
    ALT = REF[0] + str(Seq(REF[1:]).reverse_complement())

In [None]:
def get_scores(REF_pred, ALT_pred, scores): # change this to score on maps instead of vectors

    scores_results = {}
    for score in scores:
        
        if score == 'corr':
            correlation, corr_pval = spearmanr(REF_pred, ALT_pred, nan_policy='omit')
            if corr_pval >= 0.05:
                correlation = 1
            scores_results[score] = correlation
            
        elif score == 'mse':
            mse = get_MSE_from_vector(REF_pred, ALT_pred)
            
        scores_results[score] = mse
        
    return scores_results

In [None]:
def get_sequences(CHR, POS, REF, ALT, shift):
  
    # Get reference and alternate sequence from REF and ALT allele using reference genome 
    # use positive sign for a right shift and negative for a left shift

    # Get reference sequence
    
    REF_len = len(REF)

    REF_half_left = math.ceil((MB - REF_len)/2) - shift # if the REF allele is odd, shift right
    REF_half_right = math.floor((MB - REF_len)/2) + shift

    
    # Annotate whether variant is near end of chromosome arms
    if len(REF) <= len(ALT): # For SNPs, MNPs, Insertions
        var_position = get_variant_position(CHR, POS, REF_len, REF_half_left, REF_half_right)
  
    elif len(REF) > len(ALT): # For Deletions        
        ALT_len = len(ALT)
        ALT_half_left = math.ceil((MB - ALT_len)/2) - shift
        ALT_half_right = math.floor((MB - ALT_len)/2) + shift   
        var_position = get_variant_position(CHR, POS, ALT_len, ALT_half_left, ALT_half_right)
    

    # Get last coordinate of chromosome
    chrom_max = int(chrom_lengths[chrom_lengths.CHROM == CHR[3:]]['chrom_max']) 
    
    # Get centromere coordinate
    centro_start = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_start'])
    centro_stop = int(centromere_coords[centromere_coords.CHROM == CHR]['centro_stop'])
    
    
    # Get start and end of reference sequence

    if var_position == "chrom_mid":
        REF_start = POS - REF_half_left
        REF_stop = REF_start + MB 

    elif var_position == "chrom_start": 
        REF_start = 0 + abs(shift)
        REF_stop = MB + abs(shift)
        print("Warning: Variant not centered; too close to start of chromosome.")

    elif var_position == "chrom_centro_left": 
        REF_start = centro_start - MB - abs(shift)
        REF_stop = centro_start - abs(shift)
        print("Warning: Variant not centered; too close to left end of centromere.")
        
    elif var_position == "chrom_centro_right": 
        REF_start = centro_stop + abs(shift)
        REF_stop = centro_stop + MB + abs(shift)
        print("Warning: Variant not centered; too close to right end of centromere.")

    elif var_position == "chrom_end": 
        REF_start = chrom_max - MB - abs(shift)
        REF_stop = chrom_max - abs(shift)
        print("Warning: Variant not centered; too close to end of chromosome.")
        
    elif var_position == "centromere":
        raise ValueError('Centromeric variant')

        
        
    # Get reference sequence
    REF_seq = fasta_open.fetch(CHR, REF_start, REF_stop).upper()


    # Error if Ns are more than 5% of sequence
    if Counter(REF_seq)['N']/MB*100 > 5:
        raise ValueError('N composition greater than 5%')



    # Make sure that reference sequence matches given REF

    if var_position == "chrom_mid":
        if REF_seq[(REF_half_left - 1) : (REF_half_left - 1 + REF_len)] != REF:
            raise ValueError('Reference allele does not match hg38.')

    elif var_position == "chrom_start": 
        if REF_seq[(POS - abs(shift) - 1) : (POS - abs(shift) - 1 + REF_len)] != REF:
            raise ValueError('Reference allele does not match hg38.')

    elif var_position == "chrom_centro_right": 
        POS_adj = POS - centro_stop - abs(shift)
        if REF_seq[(POS_adj - 1) : (POS_adj - 1 + REF_len)] != REF:
            raise ValueError('Reference allele does not match hg38.')

    elif var_position in ["chrom_end", "chrom_centro_left"]: 
        if REF_seq[-(REF_stop - POS + 1) : -(REF_stop - POS + 1 - REF_len)] != REF:
            raise ValueError('Reference allele does not match hg38.')


    if len(REF_seq) != MB:
            raise ValueError('Reference sequence generated is not the right length.')





    # For SNPs, MNPs, Insertions: 
    if len(REF) <= len(ALT):

        # Create alternate sequence: change REF sequence at position from REF to ALT

        ALT_seq = REF_seq

        if var_position == "chrom_mid":
            ALT_seq = ALT_seq[:(REF_half_left - 1)] + ALT + ALT_seq[(REF_half_left - 1 + REF_len):]

        elif var_position == "chrom_start": 
            ALT_seq = ALT_seq[:(POS - abs(shift) - 1)] + ALT + ALT_seq[(POS - abs(shift) - 1 + REF_len):]
            
        elif var_position == "chrom_centro_right": 
            POS_adj = POS - centro_stop - abs(shift)
            ALT_seq = ALT_seq[:(POS_adj - 1)] + ALT + ALT_seq[(POS_adj - 1 + REF_len):]

        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            ALT_seq = ALT_seq[:-(REF_stop - POS + 1)] + ALT + ALT_seq[-(REF_stop - POS + 1 - REF_len):]
            
            


        # Chop off ends of alternate sequence if it's longer 
        if len(ALT_seq) > len(REF_seq):
            to_remove = (len(ALT_seq) - len(REF_seq))/2

            if to_remove == 0.5:
                ALT_seq = ALT_seq[1:]
            else:
                ALT_seq = ALT_seq[math.ceil(to_remove) : -math.floor(to_remove)]


    # For Deletions
    elif len(REF) > len(ALT):


        del_len = len(REF) - len(ALT)
        
        to_add_left = math.ceil(del_len/2)
        to_add_right = math.floor(del_len/2) 

        # Get start and end of reference sequence
        if var_position == "chrom_mid":
            ALT_start = REF_start - to_add_left
            ALT_stop = REF_stop + to_add_right

        if var_position == "chrom_start": 
            ALT_start = 0 + abs(shift)
            ALT_stop = MB + del_len + abs(shift)
            
        if var_position == "chrom_centro_left": 
            ALT_start = centro_start - MB - del_len - abs(shift)
            ALT_stop = centro_start - abs(shift)

        if var_position == "chrom_centro_right": 
            ALT_start = centro_stop + abs(shift)
            ALT_stop = centro_stop + MB + del_len + abs(shift)

        if var_position == "chrom_end": 
            ALT_start = chrom_max - MB - del_len - abs(shift)
            ALT_stop = chrom_max - abs(shift)
            
            
        # Get alternate sequence
        ALT_seq = fasta_open.fetch(CHR, ALT_start, ALT_stop).upper()
        
        
        
        # Make sure that alternate sequence matches REF at POS

        if var_position == "chrom_mid":
            if ALT_seq[(REF_half_left - 1 + to_add_left) : (REF_half_left - 1 + to_add_left + REF_len)] != REF:
                raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')

        elif var_position == "chrom_start": 
            if ALT_seq[(POS - abs(shift) - 1) : (POS - abs(shift) - 1 + REF_len)] != REF:
                raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')
            
        elif var_position == "chrom_centro_right": 
            POS_adj = POS - centro_stop
            if ALT_seq[(POS_adj - abs(shift) - 1) : (POS_adj - abs(shift) - 1 + REF_len)] != REF:
                raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')
            
        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            if ALT_seq[-(REF_stop - POS + 1) : -(REF_stop - POS - REF_len + 1)] != REF:
                raise ValueError('Sequence for the alternate allele does not match hg38 at REF position.')


    
        # Change alternate sequence to match ALT at POS

        if var_position == "chrom_mid":
            # [:N] does not include N but [N:] includes N
            ALT_seq = ALT_seq[:(REF_half_left - 1 + to_add_left)] + ALT + ALT_seq[(REF_half_left - 1 + to_add_left + REF_len):] 

        elif var_position == "chrom_start": 
            ALT_seq = ALT_seq[:(POS - abs(shift) - 1)] + ALT + ALT_seq[(POS - abs(shift) - 1 + REF_len):]
            
        elif var_position == "chrom_centro_right": 
            POS_adj = POS - centro_stop
            ALT_seq = ALT_seq[:(POS_adj - abs(shift) - 1)] + ALT + ALT_seq[(POS_adj - abs(shift) - 1 + REF_len):]
            
        elif var_position in ["chrom_end", "chrom_centro_left"]: 
            ALT_seq = ALT_seq[:-(REF_stop - POS + 1)] + ALT + ALT_seq[-(REF_stop - POS - REF_len + 1):]

            
    if len(ALT_seq) != MB:
        raise ValueError('Alternate sequence generated is not the right length.')
         
        
    return REF_seq, ALT_seq



In [None]:
def get_scores_BND(REF_pred_L, REF_pred_R, ALT_pred):
    
    # Get REF and ALT vectors, excluding diagonal 
    indexes = np.triu_indices(bins/2, 2)
    
    REF_UL = upper_left(REF_pred_L)[indexes]
    REF_LR = lower_right(REF_pred_R)[indexes]
    ALT_UL = upper_left(ALT_pred)[indexes]
    ALT_LR = lower_right(ALT_pred)[indexes]
    
    seq_REF = np.append(REF_UL, REF_LR)
    seq_ALT = np.append(ALT_UL, ALT_LR)
    
    # Get disruption score 
    disruption_score = get_DS_from_vector(seq_REF, seq_ALT)
    
    # Get spearman correlation
    correlation, corr_pval = spearmanr(seq_REF, seq_ALT)
    
    if corr_pval >= 0.05:
        correlation = 0
    
    return disruption_score, correlation, corr_pval

In [None]:
# import itertools
# position_pairs = np.array(list(itertools.product(variant_positions, variant_positions)))

# position_pairs = [x+'_'+y for x,y in zip(np.array(position_pairs)[:,0], np.array(position_pairs)[:,1])]
# dict(zip(position_pairs,len(position_pairs)*['test']))



get_BND_ALT_sense_by_pos = {'chrom_mid_chrom_mid': [ignore, 
                                                    ignore, 
                                                    ignore], # M M
                         'chrom_mid_chrom_start': [adjust_right, 
                                                   adjust_right_antisense_left, 
                                                   adjust_right_antisense_right], # M S
                         'chrom_mid_chrom_centro_left': [adjust_right, 
                                                         adjust_right_antisense_left, 
                                                         adjust_right_antisense_right], # M S
                         'chrom_mid_chrom_centro_right': [adjust_right, 
                                                          adjust_right_antisense_left, 
                                                          adjust_right_antisense_right], # M S
                         'chrom_mid_chrom_end': [adjust_right, 
                                                 adjust_right_antisense_left, 
                                                 adjust_right_antisense_right], # M S
                         'chrom_start_chrom_mid': [adjust_left, 
                                                   adjust_left_antisense_left, 
                                                   adjust_left_antisense_right], # S M
                         'chrom_start_chrom_start': [adjust_both_start, 
                                                     raise_error, 
                                                     raise_error], # S S
                         'chrom_start_chrom_centro_left': [raise_error,
                                                           adjust_start_end_antisense_left,
                                                           adjust_start_end_antisense_right], # S E
                         'chrom_start_chrom_centro_right': [adjust_both_start, 
                                                            raise_error, 
                                                            raise_error], # S S
                         'chrom_start_chrom_end': [raise_error,
                                                   adjust_start_end_antisense_left,
                                                   adjust_start_end_antisense_right], # S E
                         'chrom_centro_left_chrom_mid': [adjust_left, 
                                                         adjust_left_antisense_left, 
                                                         adjust_left_antisense_right], # E M
                         'chrom_centro_left_chrom_start': [raise_error,
                                                           adjust_end_start_antisense_left,
                                                           adjust_end_start_antisense_right], # E S
                         'chrom_centro_left_chrom_centro_left': [adjust_both_end, 
                                                                 raise_error, 
                                                                 raise_error], # E E
                         'chrom_centro_left_chrom_centro_right': [raise_error,
                                                                  adjust_end_start_antisense_left,
                                                                  adjust_end_start_antisense_right], # E S
                         'chrom_centro_left_chrom_end': [adjust_both_end, 
                                                         raise_error, 
                                                         raise_error], # E E
                         'chrom_centro_right_chrom_mid': [adjust_left, 
                                                          adjust_left_antisense_left, 
                                                          adjust_left_antisense_right], # S M
                         'chrom_centro_right_chrom_start': [adjust_both_start, 
                                                            raise_error, 
                                                            raise_error], # S S
                         'chrom_centro_right_chrom_centro_left': [raise_error,
                                                                  adjust_start_end_antisense_left,
                                                                  adjust_start_end_antisense_right], # S E
                         'chrom_centro_right_chrom_centro_right': [adjust_both_start, 
                                                                   raise_error, 
                                                                   raise_error], # S S
                         'chrom_centro_right_chrom_end': [raise_error,
                                                          adjust_start_end_antisense_left,
                                                          adjust_start_end_antisense_right], # S E
                         'chrom_end_chrom_mid': [adjust_left, 
                                                 adjust_left_antisense_left, 
                                                 adjust_left_antisense_right], # E M
                         'chrom_end_chrom_start': [raise_error, 
                                                   adjust_end_start_antisense_left,
                                                   adjust_end_start_antisense_right], # E S
                         'chrom_end_chrom_centro_left': [adjust_both_end, 
                                                         raise_error, 
                                                         raise_error], # E E
                         'chrom_end_chrom_centro_right': [raise_error,
                                                          adjust_end_start_antisense_left,
                                                          adjust_end_start_antisense_right], # E S
                         'chrom_end_chrom_end': [adjust_both_end, 
                                                 raise_error, 
                                                 raise_error]} # E E



In [None]:
parser.add_argument('--format',
                    dest = 'file_format', 
                    help = 'Format for input file.', 
                    type = str,
                    choices = ['vcf', 'df'],
                    default = 'vcf',
                    required = False)

parser.add_argument('--type',
                    dest = 'var_type', 
                    help = 'Variant type: simple or SV.', 
                    type = str,
                    choices = ['simple', 'SV'],
                    default = 'SV',
                    required = False)

In [None]:
# it cannot take in boolean values

parser.add_argument('--revcomp',
                    dest = 'revcomp', 
                    nargs = '+', 
                    help = 'Make predictions for the reverse compliment sequence.', 
                    choices = [True, False],
                    default = [False],
                    required = False)

In [None]:
def read_input(in_file, file_format, var_type):

    # Read and reformat variant dataset
    
    if file_format == 'df':
        if var_type == 'simple':
            variants = (pd.read_csv(in_file, skiprows = 1, sep = '\t')
                        .rename(columns = {'Chromosome':'CHROM', 
                                           'Start_Position':'POS',
                                           'End_Position':'END',
                                           'Reference_Allele':'REF',
                                           'Tumor_Seq_Allele2':'ALT'})
                       [['CHROM', 'POS', 'END', 'REF', 'ALT']])
            variants['SVLEN'] = (variants.END - variants.POS).astype('int') # this SVLEN (END-POS) would be 0 for SNPs

            # Might need to do this if there are homozygous variants where Tumor_Seq_Allele1 != Reference_Allele
            # would need to only do this for those variants and concat with the rest
            # variants = pd.melt(variants, 
            #                    id_vars = ['CHROM', 'POS', 'REF'], 
            #                    value_vars = ['ALT1', 'ALT2'], 
            #                    var_name = 'allele', 
            #                    value_name='ALT')

            for i in variants[variants.ALT == '-'].index:
                variants.loc[i,'SVTYPE'] = 'DEL'
            for i in variants[variants.REF == '-'].index:
                variants.loc[i,'SVTYPE'] = 'INS'

            var_type = 'SV' # These are formatted like SVs

        elif var_type == 'SV':
            variants = (pd.read_csv(in_file, sep = '\t', low_memory=False)
                        .rename(columns = {'SV_chrom':'CHROM', 
                                           'SV_start':'POS',
                                           'SV_end':'END', 
                                           'SV_type':'SVTYPE',
                                           'SV_length':'SVLEN'})
                       [['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']])
            variants['CHROM'] = ['chr' + str(x) for x in variants['CHROM']]
            variants.loc[~pd.isnull(variants.END), 'END'] = variants.loc[~pd.isnull(variants.END), 'END'].astype('int')


    elif file_format == 'vcf':
        if in_file.endswith('.gz'):
            variants = read_vcf_gz(in_file)
        else:
            variants = read_vcf(in_file)
            
        if var_type == 'simple':
            variants = variants[['CHROM', 'POS', 'REF', 'ALT']]

        elif var_type == 'SV':
            variants['END'] = variants.INFO.str.split('END=').str[1].str.split(';').str[0] # this SVLEN (END-POS) would be 0 for SNPs
            variants.loc[~pd.isnull(variants.END), 'END'] = variants.loc[~pd.isnull(variants.END), 'END'].astype('int')
            variants['SVTYPE'] = variants.INFO.str.split('SVTYPE=').str[1].str.split(';').str[0]
            variants['SVLEN'] = variants.INFO.str.split('SVLEN=').str[1].str.split(';').str[0]
            variants = variants[['CHROM', 'POS', 'END', 'REF', 'ALT', 'SVTYPE', 'SVLEN']]
            
    variants.reset_index(inplace = True, drop = True)
    
    return variants
     


In [None]:
def get_sequences_BND(CHR, POS, ALT, shift):

    if '[' in ALT:

        if ALT[0] in nt:

            # t[p[

            CHR2 = ALT.split(':')[0].split('[')[1]
            POS2 = int(ALT.split('[')[1].split(':')[1])
            ALT_t = ALT.split('[')[0]

            ALT_left = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS).upper() # don't inlcude POS

            ALT_right = fasta_open.fetch(CHR2, POS2 + 1, POS2 + 1 + half_patch_size + shift).upper() 

            REF_for_left = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS + half_patch_size + shift).upper()
            REF_for_right = fasta_open.fetch(CHR2, POS2 - half_patch_size + shift, POS2 + half_patch_size + shift).upper() 


        elif ALT[0] not in nt:

            #  [p[t

            CHR2 = ALT.split(':')[0].split('[')[1]
            POS2 = int(ALT.split('[')[1].split(':')[1])
            ALT_t = ALT.split('[')[2]

            ALT_left_revcomp = fasta_open.fetch(CHR2, POS2 + 1, POS2 + 1 + half_patch_size - shift).upper() # don't include POS2
            ALT_left = str(Seq(ALT_left_revcomp).reverse_complement())

            ALT_right = fasta_open.fetch(CHR, POS + 1, POS + 1 + half_patch_size + shift).upper()

            REF_for_left_revcomp = fasta_open.fetch(CHR2, POS2 - half_patch_size - shift, POS2 + half_patch_size - shift).upper() 
            REF_for_left = str(Seq(REF_for_left_revcomp).reverse_complement())
            REF_for_right = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS + half_patch_size + shift).upper() 


    elif ']' in ALT:

        if ALT[0] in nt:

            # t]p]

            CHR2 = ALT.split(':')[0].split(']')[1]
            POS2 = int(ALT.split(']')[1].split(':')[1])
            ALT_t = ALT.split(']')[0]

            ALT_left = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS).upper() # don't include POS

            ALT_right_revcomp = fasta_open.fetch(CHR2, POS2 - half_patch_size - shift, POS2).upper() # don't include POS2
            ALT_right = str(Seq(ALT_right_revcomp).reverse_complement())

            REF_for_left = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS + half_patch_size + shift).upper()
            REF_for_right_revcomp = fasta_open.fetch(CHR2, POS2 - half_patch_size - shift, POS2 + half_patch_size - shift).upper()
            REF_for_right = str(Seq(REF_for_right_revcomp).reverse_complement())



        elif ALT[0] not in nt:

            # ]p]t

            CHR2 = ALT.split(':')[0].split(']')[1]
            POS2 = int(ALT.split(']')[1].split(':')[1])
            ALT_t = ALT.split(']')[2]

            ALT_left = fasta_open.fetch(CHR2, POS2 - half_patch_size + shift, POS2).upper() # don't include POS2

            ALT_right = fasta_open.fetch(CHR, POS + 1, POS + 1 + half_patch_size + shift).upper()

            REF_for_left = fasta_open.fetch(CHR2, POS2 - half_patch_size + shift, POS2 + half_patch_size + shift).upper() 
            REF_for_right = fasta_open.fetch(CHR, POS - half_patch_size + shift, POS + half_patch_size + shift).upper() 


    ALT_seq = ALT_left + ALT_t + ALT_right

    # chop off the sides if longer than ~1MB
    if len(ALT_seq) > MB:
        to_remove = (len(ALT_seq) - MB)/2

        if to_remove == 0.5:
            ALT_seq = ALT_seq[1:]
        else:
            ALT_seq = ALT_seq[math.ceil(to_remove) : -math.floor(to_remove)]

        
    return REF_for_left, REF_for_right, ALT_seq
