In [None]:
import pandas as pd
import altair as alt
from scipy.stats import spearmanr
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# SGE score spreadsheets, one per gene (gene name -> .xlsx path).
sge_genes = {
    'BARD1': '../Data/filtered_ppj_data/SGE/BARD1.xlsx',
    'RAD51D': '../Data/filtered_ppj_data/SGE/RAD51D.xlsx',
    'XRCC2': '../Data/filtered_ppj_data/SGE/XRCC2.xlsx',
    'PALB2': '../Data/filtered_ppj_data/SGE/PALB2.xlsx',
}

# ThermoMPNN prediction CSVs per gene, as (csv_path, position_offset) pairs.
# NOTE(review): read_thermo special-cases the BARD1_ARD file and does not apply
# its 425 offset there -- confirm that value is intentionally unused.
thermompnn = {
    'BARD1': [
        ('../Data/ThermoMPNN/sge/BARD1_RING.csv', 26),
        ('../Data/ThermoMPNN/sge/BARD1_ARD.csv', 425),
        ('../Data/ThermoMPNN/sge/BARD1_BRCT.csv', 568),
    ],
    'RAD51D': [('../Data/ThermoMPNN/sge/RAD51D_All.csv', 2)],
    'XRCC2': [('../Data/ThermoMPNN/sge/XRCC2_All.csv', 20)],
    'PALB2': [('../Data/ThermoMPNN/sge/PALB2_WD40.csv', 854)],
}

# Genes to analyse, in the insertion order of sge_genes.
genes = list(sge_genes)

# Plot configuration: lift Altair's default max-rows limit so the full merged
# table can be charted, and use a sans-serif font stack for matplotlib figures.
alt.data_transformers.disable_max_rows()
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica Neue', 'Helvetica', 'Arial'],
})

In [None]:
def read_sge(scores):
    """Load the SGE score spreadsheet for every gene in ``scores``.

    Parameters
    ----------
    scores : dict[str, str]
        Maps gene name -> path of an Excel file with at least the columns
        'amino_acid_change', 'score' and 'consequence'.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Gene name -> frame with columns ['Mutation', 'score', 'consequence'],
        where 'Mutation' is the renamed 'amino_acid_change' column.
        NOTE(review): the merge in main() pairs rows on 'Mutation', so these
        strings must match the ThermoMPNN wtAA+pos+mutAA format -- confirm.
    """
    # Iterate over the argument itself rather than the module-level `genes`
    # list, so the function works for whatever mapping it is given.
    sge_dfs = {}
    for gene, path in scores.items():
        sge_dfs[gene] = (
            pd.read_excel(path)[['amino_acid_change', 'score', 'consequence']]
            .rename(columns={'amino_acid_change': 'Mutation'})
        )
    return sge_dfs

In [None]:
def read_thermo(files):
    """Load ThermoMPNN ddG predictions per gene and key them by mutation string.

    Parameters
    ----------
    files : dict[str, list[tuple[str, int]]]
        Maps gene name -> list of (csv_path, offset) pairs. Each CSV must have
        'pos', 'wtAA', 'mutAA' and 'ddG (kcal/mol)' columns.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Gene name -> frame with columns ['Mutation', 'ddG (kcal/mol)'], where
        'Mutation' is wtAA + position + mutAA (e.g. wtAA 'A', position 5,
        mutAA 'G' -> 'A5G'). A gene with no files gets an empty frame.
    """
    out_cols = ['Mutation', 'ddG (kcal/mol)']
    thermo_dfs = {}

    for gene, thermo_files in files.items():
        gene_frames = []
        for path, offset in thermo_files:
            df = pd.read_csv(path)
            if 'BARD1_ARD' in str(path):
                # This file is renumbered differently: rows at pos 0-2 are
                # dropped and the rest shifted down by 2; the configured offset
                # is not used. NOTE(review): confirm against the ARD model.
                df = df.loc[~df['pos'].isin([0, 1, 2])]
                pos = df['pos'] - 2
            else:
                pos = df['pos'] + offset
            # assign() returns a new frame, so we never write into the
            # .loc-filtered slice above (avoids SettingWithCopyWarning).
            df = df.assign(Mutation=df['wtAA'] + pos.astype(str) + df['mutAA'])
            gene_frames.append(df[out_cols])

        thermo_dfs[gene] = (
            pd.concat(gene_frames, ignore_index=True)
            if gene_frames
            else pd.DataFrame(columns=out_cols)
        )

    return thermo_dfs


In [None]:
def scatter_plot(final_df, title='ThermoMPNN vs. SGE Scores'):
    """Print Spearman's rho for SGE score vs. ddG and show an interactive scatter.

    Parameters
    ----------
    final_df : pandas.DataFrame
        Must contain 'score', 'ddG (kcal/mol)', 'consequence' and 'Mutation'
        columns. If a 'gene' column is present it is added to the tooltip.
    title : str, optional
        Chart title. The default is gene-neutral because main() pools rows
        from every gene, not only BARD1.
    """
    # nan_policy='omit' keeps one missing score/ddG from turning rho into NaN.
    r, p = spearmanr(final_df['score'], final_df['ddG (kcal/mol)'],
                     nan_policy='omit')

    print("Spearman rho: ", str(r), '\n',
          'P-value: ', str(p)
         )

    # Mutation strings alone are ambiguous across genes; show gene if known.
    tooltip = ['Mutation'] + (['gene'] if 'gene' in final_df.columns else [])

    plot = alt.Chart(final_df).mark_circle().encode(
        x = alt.X('score:Q', title = 'SGE Score'),
        y = alt.Y('ddG (kcal/mol):Q', title = 'ddG (kcal/mol)'),
        color = 'consequence:N',
        tooltip = tooltip
    ).properties(
        width = 600,
        height = 400,
        title = title
    ).interactive()

    plot.display()

In [None]:
def histo_heatmap(df, title='ThermoMPNN vs. SGE Scores'):
    """Show a 2D binned-count heatmap of SGE score vs. predicted ddG.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'score' and 'ddG (kcal/mol)' columns.
    title : str, optional
        Chart title. The default is gene-neutral because the frame built in
        main() pools rows from every gene, not only BARD1.
    """
    plot = alt.Chart(df).mark_rect().encode(
        # Up to 50 score bins by 25 ddG bins; colour is the row count per cell.
        x = alt.X('score:Q',
                  title = 'SGE Score',
                  bin = alt.Bin(maxbins = 50)),
        y = alt.Y('ddG (kcal/mol):Q',
                  bin = alt.Bin(maxbins = 25),
                  title = 'ddG (kcal/mol)'),
        color = alt.Color('count():Q',
                          scale = alt.Scale(scheme = 'greenblue'),
                          legend = alt.Legend(title = '# of Vars.')
                         )
    ).properties(
        title = title
    )

    plot.display()

In [None]:
def density(df, xlim=(-0.4, 0.1), ylim=(-1.5, 4), save_path=None):
    """Draw a filled 2D KDE of SGE score vs. predicted ddG.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'score' and 'ddG (kcal/mol)' columns.
    xlim, ylim : tuple of float, optional
        Axis view limits. They are applied after the KDE is drawn, so they
        only crop the view; rows outside the window still shape the density.
    save_path : str or path-like, optional
        If given, the figure is written there at 300 dpi before being shown.
    """
    # axes_style scopes the 'white' style to this figure instead of changing
    # the global seaborn style for every later plot in the session.
    with sns.axes_style('white'):
        fig, ax = plt.subplots(figsize=(7, 7))
        sns.kdeplot(
            data=df,
            x='score',
            y='ddG (kcal/mol)',
            fill=True,
            cmap='viridis',
            levels=10,
            thresh=0.05,
            bw_adjust=0.5,
            ax=ax
        )

    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.set_xlabel('SGE Score', weight = 'bold', fontsize=20)
    ax.set_ylabel('Predicted ddG (kcal/mol)', fontsize=20, weight = 'bold')
    ax.tick_params(axis = 'both', labelsize = 18)
    ax.set_title('2D Density Plot', fontsize=14, weight = 'bold', pad=15)

    fig.tight_layout()
    # Save before show(): inline backends may close the figure once shown.
    if save_path is not None:
        fig.savefig(save_path, dpi = 300)
    plt.show()

In [None]:
def main():
    """Pair per-gene SGE scores with ThermoMPNN ddG predictions and plot them.

    Loads both data sources and drops synonymous SGE variants. Each gene is
    inner-joined on 'Mutation' and tagged with a 'gene' column. The pooled
    table is printed, then the scatter and density plots are drawn.
    """
    sge_dfs = read_sge(sge_genes)
    thermo_dfs = read_thermo(thermompnn)

    per_gene = []
    for gene in genes:
        sge_df = sge_dfs[gene]
        sge_nonsyn = sge_df.loc[~sge_df['consequence'].isin(['synonymous_variant'])]

        merged = pd.merge(sge_nonsyn, thermo_dfs[gene], on = 'Mutation', how = 'inner')
        # Without this tag, rows from different genes are indistinguishable
        # once pooled (mutation strings can repeat across genes).
        per_gene.append(merged.assign(gene = gene))

    # ignore_index avoids duplicate index labels from the per-gene merges.
    final_df = pd.concat(per_gene, ignore_index = True)
    print(final_df)
    scatter_plot(final_df)
    # histo_heatmap(final_df)  # optional binned-count alternative to density()
    density(final_df)

In [None]:
# Run the whole pipeline: load SGE and ThermoMPNN tables, merge per gene, print and plot.
main()