In [None]:
import pandas as pd
import re
import altair as alt

In [None]:
# Workbook holding the variant reclassification evidence (presumably the paper's
# Supplementary Table S2). Relative to the kernel's working directory; main()
# passes it to read_data().
table_s2 = '/'.join(['..', 'Data', 'final_tables', 'variant_reclass_table.xlsx'])

In [None]:
def read_data(path, sheet_name = 'Reclass_evidence'):
    '''Load the missense variants and add integer position columns.

    Reads ``sheet_name`` from the Excel workbook at ``path``. It drops rows with
    no ``cds_num`` and keeps only rows whose ``Variant Type`` is ``'missense'``.
    It then adds two columns:

    * ``aa_pos``  -- the first integer after ``'p.'`` in ``p_variant``
      (``'p.R123Q'``, ``'p.Arg123Gln'`` and ``'p.(R123Q)'`` all give 123).
    * ``cds_pos`` -- the first run of digits in ``c_variant``
      (e.g. ``'c.367C>T'`` gives 367).

    Parameters
    ----------
    path : str or path-like
        Location of the workbook.
    sheet_name : str, optional
        Worksheet to read. Defaults to ``'Reclass_evidence'``.

    Returns
    -------
    pandas.DataFrame
        A new, independent frame (not a view of the sheet that was read) with
        the two extra columns. The original row index is kept.

    Raises
    ------
    ValueError
        If a position cannot be parsed. The extracted NaN cannot be cast to int.
    '''
    raw = pd.read_excel(path, sheet_name = sheet_name)
    df = raw.dropna(subset = ['cds_num'])
    # .copy() so the columns added below are written to an independent frame,
    # not a filtered slice of `raw` (avoids pandas' SettingWithCopyWarning).
    df = df.loc[df['Variant Type'] == 'missense'].copy()

    # expand=False makes str.extract return a Series rather than a 1-column DataFrame.
    df['aa_pos'] = df['p_variant'].str.extract(r'p\.\D*(\d+)', expand = False).astype(int)
    df['cds_pos'] = df['c_variant'].str.extract(r'(\d+)', expand = False).astype(int)

    return df

In [None]:
def make_lollis(df, stick_base, domain_max = 777):
    '''Build the vertical "stick" layer of a lollipop plot.

    Each row of ``df`` becomes one vertical rule at x = ``aa_pos``. The rule
    runs from y = ``stick_base`` to y = ``updated_points_all_vars_adj``. It is
    coloured by ``points_class_updated_points_all_vars`` using the mapping
    Benign -> #1D7AAB, Likely Benign -> #63A1C4, Likely Pathogenic -> #E6B1B8
    and Uncertain -> #A0A0A0.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``aa_pos``, ``updated_points_all_vars_adj`` and
        ``points_class_updated_points_all_vars``. Not modified.
    stick_base : float
        y value where every stick starts (callers in this notebook pass +/-1.5).
    domain_max : int or float, optional
        Upper end of the x (residue) scale. Defaults to 777.

    Returns
    -------
    altair.Chart
        Rule-mark chart, 1750 x 200 px.
    '''
    palette = ['#1D7AAB', '#63A1C4', '#E6B1B8', '#A0A0A0']
    var_classes = ['Benign', 'Likely Benign', 'Likely Pathogenic', 'Uncertain']

    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    stick_df = df.assign(base_rule_value = stick_base)

    sticks = alt.Chart(stick_df).mark_rule().encode(
        x = alt.X('aa_pos:Q',
                  scale = alt.Scale(
                      domain = [0, domain_max]
                  )
                 ),
        y = alt.Y('updated_points_all_vars_adj:Q'),
        y2 = alt.Y2('base_rule_value:Q'),
        color = alt.Color('points_class_updated_points_all_vars:N',
                          scale = alt.Scale(
                              domain = var_classes,
                              range = palette
                          )
                         )
    ).properties(
        width = 1750,
        height = 200
    )

    return sticks

In [None]:
def draw_plot(df, show_sticks = False):
    '''Plot each missense variant's reclassification score against residue position.

    Processing steps:

    1. Drop rows whose ``updated_points_all_vars`` is 0 or
       ``'indeterminate_func_class'``, then cast the remaining scores to float.
    2. Relabel scores in (-2, 5] as ``'Uncertain'``.
    3. Shift every score 1.5 units away from zero into
       ``updated_points_all_vars_adj``, presumably to leave room for the domain
       cartoon drawn between y = -1.5 and 1.5.
    4. Plot the shifted scores at ``aa_pos``.

    Marker outlines are coloured by class. Markers are hollow (white fill) when
    ``functional_data_used == 'no_func_available'``. Dashed horizontal rules
    mark the classification cut-offs.

    Parameters
    ----------
    df : pandas.DataFrame
        Variant table with ``aa_pos``, ``updated_points_all_vars``,
        ``points_class_updated_points_all_vars`` and ``functional_data_used``.
        Not modified.
    show_sticks : bool, optional
        If True, also draw lollipop sticks (via ``make_lollis``) for Likely
        Pathogenic and (Likely) Benign variants. The sticks sit beneath the
        markers. Defaults to False, which produces the same figure as before.

    Returns
    -------
    altair.LayerChart
        Points plus cut-off rules (plus sticks if requested), with a shared x scale.
    '''
    palette = ['#1D7AAB', '#63A1C4', '#E6B1B8', '#A0A0A0']
    var_classes = ['Benign', 'Likely Benign', 'Likely Pathogenic', 'Uncertain']
    class_scale = alt.Scale(domain = var_classes, range = palette)

    # Remove zero scores and variants without a functional classification; the
    # remaining values are numeric and can be cast to float.
    df = df.loc[~df['updated_points_all_vars'].isin([0, 'indeterminate_func_class'])].copy()
    scores = df['updated_points_all_vars'].astype(float)
    df['updated_points_all_vars'] = scores

    # Scores in (-2, 5] are shown as Uncertain, overriding the sheet's label.
    df.loc[(scores <= 5) & (scores > -2), 'points_class_updated_points_all_vars'] = 'Uncertain'

    # Push scores >= 0 up by 1.5 and negative scores down by 1.5.
    df['updated_points_all_vars_adj'] = scores + scores.ge(0).map({True: 1.5, False: -1.5})

    points_plot = alt.Chart(df).mark_point(
        filled = True,
        strokeWidth = 1
    ).encode(
        x = alt.X('aa_pos:Q',
                  title = 'Amino Acid Position',
                  axis = alt.Axis(titleFontSize = 20, labelFontSize = 18),
                  scale = alt.Scale(domain = [0, 777])
                 ),
        # Labels/ticks are hidden: the plotted values are the +/-1.5-shifted scores.
        y = alt.Y('updated_points_all_vars_adj:Q',
                  title = '',
                  axis = alt.Axis(labels = False, ticks = False),
                  scale = alt.Scale(domain = [-15, 17.5])
                 ),
        stroke = alt.Stroke('points_class_updated_points_all_vars:N',
                            scale = class_scale,
                            legend = alt.Legend(title = 'Classification',
                                                titleFontSize = 18,
                                                labelFontSize = 16)
                           ),
        # White (hollow) markers flag variants with no functional data available.
        fill = alt.condition(
            alt.datum.functional_data_used == 'no_func_available',
            alt.value('white'),
            alt.Fill('points_class_updated_points_all_vars:N', scale = class_scale)
        ),
        tooltip = ['updated_points_all_vars']
    ).properties(
        width = 1750,
        height = 200
    )

    layers = [points_plot]

    if show_sticks:
        class_col = df['points_class_updated_points_all_vars']
        path_stick_df = df.loc[class_col.isin(['Likely Pathogenic'])].copy()
        ben_stick_df = df.loc[class_col.isin(['Likely Benign', 'Benign'])].copy()
        # Sticks are layered first so the markers render on top of them.
        layers = [make_lollis(path_stick_df, 1.5), make_lollis(ben_stick_df, -1.5)] + layers

    # Dashed cut-off rules, in shifted coordinates, each in its class colour:
    #   -3.5 = -2 - 1.5  (Likely Benign colour)
    #   -8.5 = -7 - 1.5  (Benign colour)
    #    7.5 =  6 + 1.5  (Likely Pathogenic colour)
    cutoffs = [(-3.5, '#63A1C4'), (-8.5, '#1D7AAB'), (7.5, '#E6B1B8')]
    for y_cut, colour in cutoffs:
        layers.append(
            alt.Chart(pd.DataFrame({'updated_points_all_vars_adj': [y_cut]}))
            .mark_rule(color = colour, strokeDash = [8, 8], strokeWidth = 1)
            .encode(y = 'updated_points_all_vars_adj')
        )

    return alt.layer(*layers).resolve_scale(x = 'shared')

In [None]:
def draw_bard1_cartoon():
    '''Draw the BARD1 protein as a row of outlined, coloured boxes spanning x = 0-777.

    Segments, from x = 0 to 777:

    * grey filler: 0-26, 122-425 and 545-568
    * RING: 26-122
    * ARD: 425-545
    * BRCT: 568-777

    Every box spans y = -1.5 to 1.5, and each domain name is printed in bold at
    the centre of its box.

    Returns
    -------
    altair.LayerChart
        Box layer with the text-label layer on top.
    '''
    centre_y = 0
    half_height = 1.5

    # (start, end, fill colour, label) for each segment, left to right.
    segments = [
        (0, 26, '#B0B0B0', ''),
        (26, 122, '#B9DBF4', 'RING'),
        (122, 425, '#B0B0B0', ''),
        (425, 545, '#C8DBC8', 'ARD'),
        (545, 568, '#B0B0B0', ''),
        (568, 777, '#F6BF93', 'BRCT'),
    ]
    rect_data = pd.DataFrame([
        {
            'x': start,
            'x2': end,
            'y': centre_y - half_height,
            'y2': centre_y + half_height,
            'color': fill_colour,
            'label': name,
        }
        for start, end, fill_colour, name in segments
    ])

    # scale=None makes Altair use the 'color' column values as literal colours.
    boxes = alt.Chart(rect_data).mark_rect(
        opacity = 1, stroke = 'black', strokeWidth = 1
    ).encode(
        x = 'x:Q', x2 = 'x2:Q', y = 'y:Q', y2 = 'y2:Q',
        color = alt.Color('color:N', scale = None)
    )

    # limit is the maximum rendered text width in pixels before truncation.
    names = (
        alt.Chart(rect_data)
        .mark_text(
            align = 'center', baseline = 'middle', fontWeight = 'bold',
            fontSize = 18, angle = 0, color = 'black', limit = 1000
        )
        .encode(
            x = alt.X('x_center:Q'),
            y = alt.Y('y_center:Q'),
            text = 'label:N'
        )
        .transform_calculate(
            x_center = '(datum.x + datum.x2) / 2',
            y_center = '(datum.y + datum.y2) / 2'
        )
    )

    return boxes + names

In [None]:
def adj_labels(lolli, rect):
    '''Overlay the score plot on the domain cartoon and draw a hand-made y axis.

    The plotted scores are shifted +/-1.5 away from zero, and the score
    layer's own y-axis labels/ticks are hidden. This function therefore draws
    labels for the raw scores -15, -10, -5, 0, 5 and 10 as text marks plus
    horizontal ticks at x = 0. Each is placed at its shifted height; 0 sits on
    the cartoon. It then fixes the x ticks (every 50 residues) and the x/y
    domains, and removes grid lines and the view border.

    Parameters
    ----------
    lolli : altair chart
        Score layer (as built by ``draw_plot``).
    rect : altair chart
        Domain cartoon layer (as built by ``draw_bard1_cartoon``).

    Returns
    -------
    altair.LayerChart
        Configured top-level chart ready to display or save.
    '''
    score_ticks = [-15, -10, -5, 0, 5, 10]

    def _shifted(score):
        # Mirror the +/-1.5 offset applied to plotted scores; 0 stays on the cartoon.
        if score < 0:
            return score - 1.5
        if score > 0:
            return score + 1.5
        return 0

    tick_data = pd.DataFrame({
        'x': [0 for _ in score_ticks],
        'y_plot': [_shifted(s) for s in score_ticks],
        'label': [str(s) for s in score_ticks],
    })

    axis_labels = alt.Chart(tick_data).mark_text(
        align = 'right', dx = -5, fontSize = 18
    ).encode(x = 'x:Q', y = 'y_plot:Q', text = 'label:N')

    axis_ticks = alt.Chart(tick_data).mark_tick(
        orient = 'horizontal', size = 5
    ).encode(x = 'x:Q', y = 'y_plot:Q')

    layered = lolli + rect + axis_labels + axis_ticks

    x_axis = alt.X(
        axis = alt.Axis(values = list(range(0, 780, 50))),
        scale = alt.Scale(domain = [0, 778])
    )
    # NOTE(review): this y domain [-18, 10] would clip the '10' label placed at
    # y = 11.5, and differs from the [-15, 17.5] domain set on the score layer.
    # Confirm which domain Vega-Lite actually resolves for the combined chart.
    y_axis = alt.Y(scale = alt.Scale(domain = [-18, 10]))

    return (
        layered
        .encode(x = x_axis, y = y_axis)
        .configure_axis(grid = False)
        .configure_view(stroke = None)
    )


In [None]:
def main(save_path = None):
    '''Build and display the BARD1 missense reclassification figure.

    The steps are: load the variant table from ``table_s2``, draw the score
    plot and the domain cartoon, combine them with the custom y axis, and
    display the result.

    Parameters
    ----------
    save_path : str or None, optional
        If given, the chart is also written there with Altair's ``Chart.save``.
        The format is taken from the extension, e.g. ``.svg``. SVG/PNG export
        needs an Altair export engine such as vl-convert to be installed.
        Defaults to None (display only), so no machine-specific path is baked in.
    '''
    df = read_data(table_s2)
    lollipop_plot = draw_plot(df)
    rect_plot = draw_bard1_cartoon()
    final_plot = adj_labels(lollipop_plot, rect_plot)

    final_plot.display()
    if save_path is not None:
        final_plot.save(save_path)
    

In [None]:
# Build and display the reclassification figure (main() calls .display() itself).
main()