In [None]:
import pandas as pd
import altair as alt

In [None]:
# Tab-separated SGE score tables, one per gene (relative paths).
rad51d_data = '../Data/filtered_ppj_data/SGE/RAD51D.tsv'
xrcc2_data = '../Data/filtered_ppj_data/SGE/XRCC2.tsv'

# Genes to plot; each name is also a key of `annotation_dfs`.
genes = ['RAD51D', 'XRCC2']

# Domain-track segments for each protein. Every row is one box on the track:
# `start`/`end` are residue coordinates, `label` is the text drawn on the box
# (empty for unlabelled stretches) and `color` is the box fill.
rad_annotation_df = pd.DataFrame([
    {'start': 13, 'end': 61, 'label': 'N-Terminal Domain', 'color': '#B9DBF4'},
    {'start': 101, 'end': 310, 'label': 'RecA', 'color': '#C8DBC8'},
    # Grey, unlabelled stretches outside / between the two domains above.
    {'start': 1, 'end': 13, 'label': '', 'color': 'grey'},
    {'start': 61, 'end': 101, 'label': '', 'color': 'grey'},
    {'start': 310, 'end': 328, 'label': '', 'color': 'grey'},
    # Add more annotations as needed
])

# NOTE(review): the 1-42 segment carries a domain label but uses the same grey
# as the unlabelled RAD51D stretches -- confirm which of the two is intended.
x_annotation_df = pd.DataFrame([
    {'start': 1, 'end': 42, 'label': 'N-Terminal Domain', 'color': 'grey'},
    {'start': 42, 'end': 280, 'label': 'RecA', 'color': '#C8DBC8'},
    # Add more annotations as needed
])

# Gene name -> annotation frame, looked up by gene when a chart is built.
annotation_dfs = {
    'RAD51D': rad_annotation_df,
    'XRCC2': x_annotation_df,
}

In [None]:
def read_data(rad, xrcc):
    """Load the RAD51D and XRCC2 score tables and reshape them for plotting.

    Parameters
    ----------
    rad, xrcc : str or path-like
        Tab-separated files for RAD51D and XRCC2. Each must provide the
        columns ``amino_acid_change``, ``consequence`` and ``score``.

    Returns
    -------
    dict
        ``{'RAD51D': DataFrame, 'XRCC2': DataFrame}``. Each frame has the
        columns ``AApos``, ``og_AA``, ``AA_change``, ``score`` and ``AAsub``
        and holds one row per retained variant, followed by one
        ``'Mis. Min.'`` row and one ``'Mis. Mean'`` row per position. The
        summary rows aggregate ``score`` over every variant whose
        ``consequence`` is not ``'stop_gained'`` and have no ``AAsub`` value.

    Raises
    ------
    ValueError
        If a retained ``amino_acid_change`` value has no integer between its
        first and last character.
    """
    raw_dfs = {
        'RAD51D': pd.read_csv(rad, sep='\t'),
        'XRCC2': pd.read_csv(xrcc, sep='\t'),
    }

    dfs = {}
    for gene, raw in raw_dfs.items():
        # Rows without a parsable substitution are dropped: either the value
        # contains '-' or it is missing altogether. `na=True` folds missing
        # values into the same mask instead of letting NaN break the `~`
        # below or the string slicing further down.
        unparsable = raw['amino_acid_change'].str.contains('-', regex=False, na=True)
        df = raw.loc[~unparsable].rename(columns={'amino_acid_change': 'AAsub'})

        # AAsub is read as <first char><integer position><last char>.
        df = df.assign(
            og_AA=df['AAsub'].str[0],                  # original amino acid
            AA_change=df['AAsub'].str[-1],             # substituted amino acid
            AApos=df['AAsub'].str[1:-1].astype(int),   # residue position
        )

        # Summary statistics leave out stop-gained variants.
        missense = df.loc[df['consequence'] != 'stop_gained']

        per_variant = df[['AApos', 'og_AA', 'AA_change', 'score', 'AAsub']]

        summaries = []
        for label, statistic in (('Mis. Min.', 'min'), ('Mis. Mean', 'mean')):
            summary = missense.groupby('AApos')['score'].agg(statistic).reset_index()
            # The label is stored in both amino-acid columns so the summary
            # shows up as its own category wherever either column is used.
            summary['og_AA'] = label
            summary['AA_change'] = label
            summaries.append(summary)

        dfs[gene] = pd.concat([per_variant, *summaries])

    return dfs

In [None]:
def _domain_track(annotation_df, prot_length, width):
    """Build the strip of labelled domain boxes drawn above a heatmap.

    The strip uses the same x domain (``[1, prot_length + 1]``) and the same
    pixel width as the heatmap, so a box drawn from ``start`` to ``end`` sits
    over the matching residue columns.
    """
    # `.assign` returns a new frame, so the caller's annotation table does not
    # acquire a `center` column as a side effect of plotting.
    track_df = annotation_df.assign(
        center=(annotation_df['start'] + annotation_df['end']) / 2
    )
    x_scale = alt.Scale(domain=[1, prot_length + 1])

    # Domain rectangles; the fill of each box comes from its `color` entry,
    # paired with its label through the colour scale's domain/range.
    boxes = alt.Chart(track_df).mark_rect(
        height=25,
        stroke='black',
        strokeWidth=2,
    ).encode(
        x=alt.X('start:Q', axis=None, scale=x_scale),
        x2='end:Q',
        color=alt.Color(
            'label:N',
            scale=alt.Scale(
                domain=track_df['label'].tolist(),
                range=track_df['color'].tolist(),
            ),
            legend=None,
        ),
        tooltip=['label', 'start', 'end'],
    )

    # Domain text labels, positioned at the midpoint of each box.
    labels = alt.Chart(track_df).mark_text(
        color='black',
        fontSize=20,
        fontWeight='bold',
        baseline='middle',
        dy=-10,  # vertical nudge for the label text
    ).encode(
        x=alt.X('center:Q', axis=None, scale=x_scale),
        text='label:N',
    )

    return alt.layer(boxes, labels).properties(width=width, height=20)


def heatmap(dfs, annotations, gene, show_domains=True):
    """Build the functional-score heatmap for one gene.

    Parameters
    ----------
    dfs : dict
        Gene name -> DataFrame with at least ``AApos``, ``AA_change`` and
        ``score`` columns (the frames produced by ``read_data``).
    annotations : dict
        Gene name -> DataFrame with ``start``, ``end``, ``label`` and
        ``color`` columns describing the domain track. Not modified.
    gene : str
        Key to look up in both ``dfs`` and ``annotations``.
    show_domains : bool, default True
        When True, return the heatmap stacked under the domain track with the
        gene name as title. When False, return only the heatmap.

    Returns
    -------
    alt.VConcatChart or alt.Chart
        The returned chart carries no ``configure_*`` settings so that it can
        still be combined with ``|`` / ``&``; apply configuration on the final
        composed figure.
    """
    df = dfs[gene]
    max_width = 4  # chart width per residue position

    prot_length = int(df['AApos'].max())
    adj_width = max_width * prot_length

    # Row order. The last two entries are the summary labels written into
    # `AA_change` by `read_data`.
    # NOTE(review): 'Stop' only takes effect if stop-gained rows carry that
    # exact `AA_change` value -- confirm against the data.
    order = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P',
             'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'Stop', 'Mis. Min.', 'Mis. Mean']

    score_map = alt.Chart(df).mark_rect().encode(
        x=alt.X(
            'AApos:Q',
            title='Amino Acid Position',
            axis=alt.Axis(
                labelFontSize=16,
                titleFontSize=20,
                values=list(range(0, prot_length, 25)),
            ),
            scale=alt.Scale(domain=[1, prot_length + 1]),
            # One bin per residue, so every position gets its own column.
            bin=alt.Bin(maxbins=prot_length + 1, minstep=1),
        ),
        y=alt.Y(
            'AA_change',
            title='Amino Acid Substitution',
            axis=alt.Axis(labelFontSize=16, titleFontSize=20),
            sort=order,
        ),
        color=alt.Color(
            'score',
            title='Functional Score',
            scale=alt.Scale(
                domain=[-0.2, 0],
                clamp=True,  # scores outside the domain take the end colours
                scheme='bluepurple',
                reverse=True,
            ),
            legend=alt.Legend(titleFontSize=20, labelFontSize=16),
        ),
    ).properties(
        height=800,
        width=adj_width,
    )

    if not show_domains:
        return score_map

    track = _domain_track(annotations[gene], prot_length, adj_width)

    # The track colours labels (nominal) while the heatmap colours scores
    # (quantitative), so the two colour scales are kept separate.
    return alt.vconcat(track, score_map, spacing=-5).resolve_scale(
        color='independent'
    ).properties(
        title=alt.TitleParams(gene,
                              anchor='middle',
                              align='center',
                              fontSize=24)
    )

In [None]:
def main():
    """Load both score tables, build one heatmap per gene and display them stacked.

    Reads the module-level paths, `genes` and `annotation_dfs`. The first
    gene's chart forms the top panel; the second gene's chart is placed below
    it, preceded on its left by an invisible 100-pixel-wide spacer chart.
    """
    gene_frames = read_data(rad51d_data, xrcc2_data)

    charts = [heatmap(gene_frames, annotation_dfs, gene) for gene in genes]

    # Fully transparent single-point chart used purely as horizontal padding.
    blank = alt.Chart(pd.DataFrame({'x': [0]})).mark_point(opacity=0).encode(
        x=alt.X('x:Q', axis=None)
    ).properties(width=100, height=0)

    top_panel = charts[0]
    bottom_panel = blank | charts[1]

    # View configuration is applied once, on the fully composed figure.
    figure = (top_panel & bottom_panel).configure_view(stroke = None)
    figure.display()

In [None]:
# Run the pipeline: read both score tables, build the per-gene heatmaps and display the stacked figure.
main()