In [None]:
import pandas as pd
import altair as alt
import numpy as np

In [None]:
# Input locations consumed by main(); edit these before running the notebook.
# NOTE(review): `file` is an absolute path on one user's machine, unlike the
# repo-relative paths below -- confirm whether it should live under ../Data.
file = '/Users/ivan/Downloads/BARD1.snvscores.tsv' # tab-separated SGE SNV scores (read by read_scores)
ref_path = '../Data/SNV_filtering_inputs/20240809_BARD1_SNVlib_ref_seqs_intron_annotated.xlsx' # Excel workbook of reference bases per target (read by get_reference)
coord_file = '../Data/SNV_filtering_inputs/20250415_BARD1_filter_entry.xlsx' # Excel workbook whose 'targets' sheet lists target/start/end (read by get_region_coords)
#coords = [(214809,214809500),(214797050,214797156)] # unused: hard-coded genomic coordinates per exon; main() reads coordinates from coord_file instead

In [None]:
def get_region_coords(file): #Reads input file to get coordinates for each SGE target
    """Load the SGE target table and list each target with its coordinates.

    Parameters
    ----------
    file : path (or file-like) of an Excel workbook containing a 'targets'
        sheet with 'target', 'start' and 'end' columns.

    Returns
    -------
    list of ``(target, (start, end))`` tuples, one per sheet row, in row order.
    """
    targets = pd.read_excel(file, sheet_name = 'targets')

    # Iterate the underlying arrays so the elements keep their numpy scalar types.
    names = targets['target'].to_numpy()
    starts = targets['start'].to_numpy()
    ends = targets['end'].to_numpy()

    return [(name, (first, last)) for name, first, last in zip(names, starts, ends)]

In [None]:
def read_scores(file,region): #reads scores
    """Read the tab-separated SGE score table and keep the rows for one target.

    Parameters
    ----------
    file : path or file-like object holding the tab-separated score table.
    region : target name; only rows whose 'target' equals it are returned.

    Returns
    -------
    DataFrame with the columns exon, target, pos (as string), pos_id
    ("<pos>:<allele>"), Consequence, snv_score_minmax, amino_acid_change and
    functional_consequence. The original row index is kept.
    """
    wanted_columns = ['exon','target','pos', 'pos_id', 'Consequence', 'snv_score_minmax', 'amino_acid_change', 'functional_consequence']

    scores = (
        pd.read_csv(file, sep = '\t')
        .rename(columns = {'simplified_consequence': 'Consequence', 'score': 'snv_score_minmax'})
        .assign(pos = lambda frame: frame['pos'].astype(str))
        .assign(pos_id = lambda frame: frame['pos'] + ':' + frame['allele'])
    )

    in_region = scores['target'].isin([region])
    return scores[wanted_columns].loc[in_region]

In [None]:
def get_reference(ref, coords,region): #pulls out reference sequence
    """Return the reference bases of one target that fall inside `coords`.

    Parameters
    ----------
    ref : path (or file-like) of the Excel workbook with the reference
        sequence; its default sheet must have 'target', 'Reference' and 'pos'
        columns.
    coords : ``(start, end)`` genomic coordinates, inclusive. Antisense
        targets are entered with start > end; either order is accepted.
    region : target name used to filter the 'target' column.

    Returns
    -------
    DataFrame with columns ['target', 'Reference', 'pos'] restricted to
    `region` and to positions within the inclusive coordinate range, in the
    workbook's row order.
    """
    start, end = coords

    # The input lists antisense-strand targets with start > end. Normalising
    # the bounds keeps that case working and stops a start < end pair from
    # silently producing an empty range (and therefore an empty reference).
    low, high = min(start, end), max(start, end)
    wanted_positions = list(range(low, high + 1))

    ref_df = pd.read_excel(ref)
    ref_df = ref_df.loc[ref_df['target'].isin([region]), ['target', 'Reference', 'pos']]

    return ref_df.loc[ref_df['pos'].isin(wanted_positions)]
    

In [None]:
def reverse_complement(seq_string):
    """Return the reverse complement of a DNA sequence.

    A<->T and C<->G are swapped (case is preserved) and the sequence is
    reversed. 'N' and any other character that is not a nucleotide letter
    (for example a gap) is carried through unchanged instead of being
    turned into 'A'.

    Parameters
    ----------
    seq_string : str of nucleotide characters; may be empty.

    Returns
    -------
    str of the same length as `seq_string`.
    """
    # str.translate leaves characters without a mapping untouched, so unknown
    # symbols can no longer masquerade as a real base.
    complement = str.maketrans("ACGTNacgtn", "TGCANtgcan")
    return seq_string[::-1].translate(complement)

In [None]:
def reverse_comp_ref(x_ref): #reverse complements reference for antisense gene
    """Flip a reference table to the opposite strand.

    The 'Reference' values are concatenated, upper-cased and passed through
    reverse_complement(); the rows are put in reverse order with a fresh
    0..n-1 index; the reverse-complemented characters are then written back
    one per row. Assumes each 'Reference' cell holds exactly one character,
    otherwise the lengths will not match -- TODO confirm.

    Parameters
    ----------
    x_ref : DataFrame with a 'Reference' column of strings.

    Returns
    -------
    A new DataFrame (rows reversed, index reset) with 'Reference' replaced.
    """
    forward_bases = ''.join(x_ref['Reference'].tolist()).upper()
    flipped_bases = list(reverse_complement(forward_bases))

    flipped = x_ref[::-1].reset_index(drop = True)
    flipped['Reference'] = flipped_bases

    return flipped

In [None]:
def row_enumerate_ref(ref_df): #enumerates each row for heat map and each column for the base pair number
    """Add heat-map coordinates to the reference table, in place.

    'Row' is set to the reference base when it is one of A, C, G or T and is
    left as None otherwise. 'Column' is the 1-based position of the row
    within the frame.

    Parameters
    ----------
    ref_df : DataFrame with a 'Reference' column; it is modified in place.

    Returns
    -------
    The same `ref_df` object, now with 'Row' and 'Column' columns.
    """
    ref_df['Row'] = None
    for nucleotide in ('A', 'C', 'G', 'T'):
        matches = ref_df['Reference'] == nucleotide
        ref_df.loc[matches, 'Row'] = nucleotide

    n_rows = len(ref_df)
    ref_df['Column'] = list(range(1, n_rows + 1))

    return ref_df
    

In [None]:
def reverse_posid(string): #to reverse complement pos_id for antisense gene to get base change on coding strand
    """Reverse-complement the allele part of a "<pos>:<allele>" identifier.

    The identifier is split on ':'; the second field is replaced by its
    reverse complement and the fields are joined back with ':'. All other
    fields are left untouched.

    Parameters
    ----------
    string : identifier containing at least one ':' separator.

    Returns
    -------
    str, the identifier with its second field reverse-complemented.
    """
    fields = string.split(':')
    fields[1] = reverse_complement(fields[1])
    return ':'.join(fields)

In [None]:
def process_data(df): #groups consequence of SNVs, enumerates row for base change
    """Group SNV consequences into display labels and derive the heat-map row.

    Works on a re-indexed copy of `df`. The 'Consequence' values are
    relabelled by the ordered rules below, 'rev_pos_id' is the pos_id with
    its allele reverse-complemented, and 'Row' is the last character of
    'rev_pos_id'. The rows are finally returned in reverse order with a
    fresh index.

    Parameters
    ----------
    df : DataFrame with string columns 'Consequence' and 'pos_id'.

    Returns
    -------
    New DataFrame with relabelled 'Consequence' plus 'rev_pos_id' and 'Row'.
    """
    df = df.reset_index(drop = True)

    # (substring match?, pattern, label). The rules run strictly in this
    # order and each one sees the labels written by the rules before it.
    relabel_rules = [
        (True, 'missense', 'Missense'),
        (False, 'synonymous_variant', 'Synonymous'),
        (False, 'intron_variant', 'Intron'),
        (False, 'stop_gained', 'Stop Gained'),
        (False, 'stop_lost', 'Stop Lost'),
        (True, 'splic', 'Splice'),
        (True, 'UTR', 'UTR Variant'),
        (False, 'start_lost', 'Start Lost'),
    ]
    for is_substring, pattern, label in relabel_rules:
        if is_substring:
            hits = df['Consequence'].str.contains(pattern)
        else:
            hits = df['Consequence'] == pattern
        df.loc[hits, 'Consequence'] = label

    df['rev_pos_id'] = df['pos_id'].transform(reverse_posid)
    df['Row'] = df['rev_pos_id'].transform(lambda rev_id: rev_id[-1])

    return df[::-1].reset_index(drop = True)

In [None]:
def column_enumerate_data(df,refdf): #enumerates column for basepair that was changed
    """Attach the heat-map column number to every SNV, in place.

    The genomic coordinate is the part of 'pos_id' before the first ':'. It
    is looked up in `refdf` (coordinate 'pos' -> 'Column') and the result is
    stored in a float 'Column' column on `df`.

    Parameters
    ----------
    df : DataFrame with a 'pos_id' column of "<pos>:<allele>" strings; it is
        modified in place. Any index is accepted.
    refdf : DataFrame with integer-like 'pos' and 'Column' columns.

    Returns
    -------
    The same `df` object with the 'Column' column added.

    Raises
    ------
    KeyError
        If an SNV coordinate is not present in ``refdf['pos']``.
    """
    # Genomic coordinate -> column number; built positionally so it does not
    # depend on refdf carrying a 0..n-1 index.
    column_by_coord = dict(zip(refdf['pos'], refdf['Column']))

    # One pass over the SNVs instead of re-masking the whole frame per row.
    columns = [
        column_by_coord[int(pos_id.split(':')[0])]
        for pos_id in df['pos_id']
    ]

    # Stored as float, the dtype the column had when it was pre-filled with NaN.
    df['Column'] = np.array(columns, dtype = float)

    return df

In [None]:
def heatmap(data, letters,region):
    """Build, display and return the layered Altair heat map for one target.

    The chart has two layers: rectangles for the SNVs in `data` and a text
    layer that prints the reference base from `letters`.

    Parameters
    ----------
    data : DataFrame of SNVs. The code reads the columns 'Row', 'Column',
        'pos_x' (after the merge below), 'snv_score_minmax', 'Consequence',
        'amino_acid_change', 'functional_consequence' and 'pos_id'.
    letters : DataFrame of reference bases with 'Row', 'Column', 'Reference'
        and 'pos' columns. NOTE: its 'pos' column is converted to str in
        place, so the caller's frame is modified.
    region : target name; the text after the first '_' is used in the title
        (assumes a name such as '<gene>_<exon>' -- TODO confirm).

    Returns
    -------
    The layered Altair chart (also shown via ``.show()`` before returning).
    """
    
    # Filter out the cells that will display letters from the heatmap dataset
    # (rows of `data` whose (Row, Column) pair also exists in `letters`).
    # Presumably both frames carry a 'pos' column, so the merge suffixes them
    # and the SNV position becomes 'pos_x', the field the x encoding uses below.
    heatmap_data = data.merge(letters, on=['Row', 'Column'], how='left', indicator=True)
    heatmap_data = heatmap_data[heatmap_data['_merge'] == 'left_only'].drop(columns=['Reference', '_merge'])

    # Define the rectangle size and spacing
    rect_size = 15
    spacing = 7.5

    # One slot per reference position horizontally, one per distinct Row vertically.
    total_width = (rect_size + spacing) * len(letters) - spacing
    total_height = (rect_size + spacing) * len(data['Row'].unique()) - spacing

    target = region.split('_')
    title_s = 'Exon ' + target[1]

    # Create the background heatmap with borders
    background = alt.Chart(heatmap_data).mark_rect(
        width=rect_size,
        height=rect_size,
        strokeWidth=2
    ).encode(
        x=alt.X('pos_x:O', sort = 'descending', title='Basepair', axis=alt.Axis(labelAngle=270)),
        y=alt.Y('Row:N', title='SNV'),
        # NOTE(review): the predicates below test `is_wt` and `score`, but
        # read_scores() renames 'score' to 'snv_score_minmax' and does not
        # select an 'is_wt' column -- confirm these fields exist in the data.
        # NOTE(review): alt.condition calls are nested here; confirm that the
        # installed Altair version keeps all three conditions rather than
        # only the outermost one.
        color=alt.condition(
        'datum.is_wt == true',
        alt.value('#000000'),
        alt.condition(
            alt.datum.score <= -0.5,
            alt.value('#ff0000'),
            alt.condition(
                alt.datum.score >= 0,
                alt.value('#0000ff'),
                # NOTE(review): the domain has two stops but the range lists
                # three colours -- confirm the intended mapping.
                alt.Color('snv_score_minmax:Q', title = 'SGE Score',
                          scale = alt.Scale(
                              domain = [-0.5, 0],
                              range = ['#ff0000', '#a6a6a6', '#0000ff']
                          )
                         )
            )
        )
    ),
        # The rectangle border colour encodes the consequence category.
        stroke=alt.Stroke('Consequence:N', title='Consequence', legend = alt.Legend(symbolFillColor = 'white')),
        tooltip = [alt.Tooltip('amino_acid_change', title = 'Amino Acid Change: '),
                   alt.Tooltip('Consequence', title = 'Consequence: '),
                   alt.Tooltip('snv_score_minmax', title = 'SGE Score: '),
                   alt.Tooltip('functional_consequence', title = 'Functional Consequence: '),
                   alt.Tooltip('pos_id', title = 'Position ID: ')
                  ]
                   
    ).properties(
        width= total_width,
        height= total_height
    )
    
    # Create the text overlay
    letters['pos'] = letters['pos'].astype(str) #Sets data type of pos column to string (modifies the caller's DataFrame)
    letters = letters.rename(columns = {'pos': 'pos_x'}) #Renames to have same x encoding as heat map

    #Builds text overlay
    text = alt.Chart(letters).mark_text(
        align='center',
        baseline='middle',
        fontSize=14
    ).encode(
        x=alt.X('pos_x:O', sort = 'descending', title='Basepair'),
        y=alt.Y('Row:N', title='SNV'),
        text='Reference:N',
        color=alt.value('black')
    )
    
    # Combine the background and text
    # (the local name `heatmap` shadows this function inside its own body)
    heatmap = alt.layer(
        background, text
    ).properties(
        title = title_s
    )

    # Display the chart
    heatmap.show()

    return heatmap

In [None]:
def main():
    """Render one heat map per SGE target listed in `coord_file`.

    For every ``(region, (start, end))`` entry the scores are read from
    `file`, the reference is read from `ref_path`, flipped to the opposite
    strand and given Row/Column coordinates, the scores are labelled and
    given their Column, and heatmap() builds and shows the chart.

    Returns
    -------
    None. The charts are collected in a local list only.
    """
    all_maps = []

    for region, ref_coords in get_region_coords(coord_file):
        scores = read_scores(file,region)

        reference = get_reference(ref_path, ref_coords, region)
        reference = reverse_comp_ref(reference)
        reference = row_enumerate_ref(reference)

        scores = process_data(scores)
        scores = column_enumerate_data(scores, reference)

        chart = heatmap(scores, reference,region)
        all_maps.append(chart)

    # TODO: all_maps is collected but never combined. An earlier draft stacked
    # the charts with alt.vconcat, applied configure_view(stroke = None) and
    # called .show() on the joined chart; it is disabled for now.

In [None]:
# Run the pipeline: heatmap() shows one chart for each target read from coord_file.
main()