In [None]:
import pandas as pd
import altair as alt
from scipy.stats import spearmanr

In [None]:
# Excel workbook holding the BARD1 SGE scores (read by read_sge).
bard1_scores = '../Data/filtered_ppj_data/SGE/BARD1.xlsx'
# (domain label, ThermoMPNN CSV path) pairs, one per BARD1 domain.
# The labels must match the keys used inside read_thermo, which looks up a
# per-domain residue offset by label.
thermompnn = [('RING (1JM7)', '../Data/ThermoMPNN/sge/BARD1_RING.csv'),
              ('ARD (3C5R)', '../Data/ThermoMPNN/sge/BARD1_ARD.csv'),
              ('BRCT (3FA2)', '../Data/ThermoMPNN/sge/BARD1_BRCT.csv')
             ]
# Turn off Altair's max-rows check so the full merged table can be charted.
alt.data_transformers.disable_max_rows()

In [None]:
def read_sge(scores):
    """Load the SGE score sheet and keep only the columns used downstream.

    Parameters
    ----------
    scores : path or file-like
        Location of the Excel workbook, passed straight to ``pd.read_excel``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Mutation`` (renamed from ``amino_acid_change``), ``score``
        and ``consequence``, in that order.

    Raises
    ------
    KeyError
        If any of the three expected columns is missing from the sheet.
    """
    wanted = ['amino_acid_change', 'score', 'consequence']
    sheet = pd.read_excel(scores)

    # Rename so the key column matches the one produced by read_thermo.
    return sheet.loc[:, wanted].rename(columns = {'amino_acid_change': 'Mutation'})

In [None]:
def read_thermo(files):
    """Load per-domain ThermoMPNN predictions and build a ``Mutation`` key.

    Each CSV is expected to provide the columns ``pos``, ``wtAA``, ``mutAA``
    and ``ddG (kcal/mol)``. Positions are renumbered per domain and combined
    into a ``<wtAA><position><mutAA>`` string (e.g. ``M26A``) so the result
    can be merged with the SGE table on ``Mutation``.

    Parameters
    ----------
    files : iterable of (str, path)
        ``(domain label, CSV path)`` pairs. Labels must be one of
        ``'RING (1JM7)'``, ``'ARD (3C5R)'`` or ``'BRCT (3FA2)'``.

    Returns
    -------
    pandas.DataFrame
        Two columns, ``Mutation`` and ``ddG (kcal/mol)``, with the rows of
        every domain stacked in the order given. Each domain keeps its own
        row index, so index labels may repeat across domains.

    Raises
    ------
    KeyError
        If a domain label (other than the ARD one) is not in the offset table.
    ValueError
        If ``files`` is empty (nothing to concatenate).
    """
    offset_dict = {'RING (1JM7)': 26,
                   'ARD (3C5R)': 425,
                   'BRCT (3FA2)': 568
                  } #ThermoMPNN always starts with residue 0

    thermodata = []

    for domain, path in files:
        df = pd.read_csv(path)
        if domain == 'ARD (3C5R)':
            # The first three ARD rows (pos 0-2) are dropped and the rest are
            # shifted down by 2.
            # NOTE(review): the 425 entry in offset_dict is never applied to
            # this domain, so ARD positions come out starting at 1 rather than
            # in the 425+ range. Confirm whether the ARD CSV is numbered
            # differently or whether the offset should be added here.
            df = df.loc[~(df['pos'].isin([0,1,2]))]
            new_pos = df['pos'] - 2
        else:
            new_pos = df['pos'] + offset_dict[domain]

        # .assign returns a new frame, so the filtered ARD slice is never
        # written to in place (avoids pandas' SettingWithCopyWarning).
        df = df.assign(pos = new_pos)
        df = df.assign(Mutation = df['wtAA'] + df['pos'].astype(str) + df['mutAA'])
        thermodata.append(df[['Mutation', 'ddG (kcal/mol)']])

    df = pd.concat(thermodata)

    return df

In [None]:
def scatter_plot(final_df):
    """Print the Spearman correlation and show an interactive scatter chart.

    Parameters
    ----------
    final_df : pandas.DataFrame
        Must contain ``score``, ``ddG (kcal/mol)``, ``consequence`` and
        ``Mutation`` columns.

    Returns
    -------
    None
        The statistics are printed and the chart is displayed as side effects.
    """
    rho, p_value = spearmanr(final_df['score'], final_df['ddG (kcal/mol)'])

    print("Spearman rho: ", str(rho), '\n',
          'P-value: ', str(p_value)
         )

    # One circle per row: SGE score on x, predicted ddG on y, coloured by
    # consequence, with the mutation label shown on hover.
    channels = {
        'x': 'score:Q',
        'y': 'ddG (kcal/mol)',
        'color': 'consequence:N',
        'tooltip': ['Mutation'],
    }
    layout = {
        'width': 600,
        'height': 400,
        'title': 'ThermoMPNN vs. BARD1 SGE Scores',
    }

    chart = alt.Chart(final_df).mark_circle().encode(**channels)
    chart = chart.properties(**layout)

    # .interactive() enables pan/zoom on the rendered chart.
    chart.interactive().display()

In [None]:
def histo_heatmap(df):
    """Show a 2D binned heatmap of SGE score against ThermoMPNN ddG.

    Both axes are binned and each rectangle is coloured by the number of rows
    falling into that (score, ddG) bin.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``score`` and ``ddG (kcal/mol)`` columns.

    Returns
    -------
    None
        The chart is displayed as a side effect.
    """
    score_axis = alt.X('score:Q',
                       title = 'SGE Score',
                       bin = alt.Bin(maxbins = 50))
    ddg_axis = alt.Y('ddG (kcal/mol):Q',
                     bin = alt.Bin(maxbins = 25),
                     title = 'ddG (kcal/mol)')
    # Fill colour encodes how many variants land in each bin.
    bin_count = alt.Color('count():Q',
                          scale = alt.Scale(scheme = 'greenblue'),
                          legend = alt.Legend(title = '# of Vars.'))

    heatmap = alt.Chart(df).mark_rect().encode(x = score_axis, y = ddg_axis, color = bin_count)
    heatmap = heatmap.properties(title = 'ThermoMPNN vs. BARD1 SGE Scores')

    heatmap.display()

In [None]:
def main():
    """Load both data sets, join them on ``Mutation`` and draw both charts.

    Reads the module-level ``bard1_scores`` and ``thermompnn`` settings. Only
    mutations present in both tables survive the inner join.
    """
    sge_scores = read_sge(bard1_scores)
    thermo_ddg = read_thermo(thermompnn)

    merged = sge_scores.merge(thermo_ddg, how = 'inner', on = 'Mutation')

    # Scatter first, then the binned heatmap, both on the same merged table.
    for draw in (scatter_plot, histo_heatmap):
        draw(merged)

In [None]:
# Run the full analysis: load, merge, print the correlation and render both charts.
main()