In [None]:
import pandas as pd
import altair as alt
from scipy.stats import spearmanr
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

In [None]:
# (label, csv path) pairs for the ThermoMPNN tables.
# The labels double as dictionary keys further down, so they must match the
# labels used in `vampseq` exactly.
thermompnn = [
    ('F9 (AF)', '../Data/ThermoMPNN/vampseq/FIX_AF.csv'),
    ('G6PD (7UAG)', '../Data/ThermoMPNN/vampseq/G6PD_7UAG.csv'),
    ('TSC2 (7DL2)', '../Data/ThermoMPNN/vampseq/TSC2_7DL2.csv'),
]

# (label, csv path) pairs for the VAMP-seq tables, in the same label order.
vampseq = [
    ('F9 (AF)', '../Data/filtered_ppj_data/VAMPseq/F9_ab102.csv'),
    ('G6PD (7UAG)', '../Data/filtered_ppj_data/VAMPseq/G6PD_scores_consequence.csv'),
    ('TSC2 (7DL2)', '../Data/filtered_ppj_data/VAMPseq/TSC2_Lib1_scores_consequences.csv'),
]

# Turn off Altair's max-rows check so charts can be built from the full
# merged table rather than a truncated or sampled one.
alt.data_transformers.disable_max_rows()

In [None]:
def read_vampseq(files):
    """Load VAMP-seq tables and give every gene a common 'Mutation' key.

    Parameters
    ----------
    files : iterable of (gene_name, path)
        Label of the gene and the CSV file to read for it.

    Returns
    -------
    dict
        Maps each gene_name to a DataFrame restricted to rows whose 'type'
        is 'Missense', with a 'Mutation' column of the form
        <WT><position><Mut>.

    Notes
    -----
    Files whose label contains 'F9' use a different column layout
    ('wt_aa', 'var_aa', antibody columns); those columns are renamed or
    dropped here. All other files are expected to provide 'WT', 'Mut' and
    'position' directly.
    """
    vampseq_data = {}
    for gene_name, path in files:
        data = pd.read_csv(path)
        # Keeping only 'Missense' rows already excludes 'Deletion' rows, so a
        # separate deletion filter is unnecessary. .copy() detaches the result
        # so the column assignments below act on an independent frame.
        data = data.loc[data['type'].isin(['Missense'])].copy()

        if 'F9' in gene_name:
            # Look the antibody up by position, not by label: after filtering,
            # the row labelled 2 may no longer exist.
            # Assumes one antibody per file -- TODO confirm.
            ab = data['antibody_nonnum'].iloc[0] if not data.empty else None
            if ab == 'ab001':
                data = data.rename(columns = {'variant': 'Mutation','wt_aa': 'WT', 'var_aa': 'Mut'})
                data = data.drop(columns = ['antibody_nonnum', 'antibody_label2', 'Unnamed: 8'])
            else:
                data['Mutation'] = data['wt_aa'] + data['position'].astype(str) + data['var_aa']
                # 'variant' is deliberately not renamed to 'Mutation' here:
                # the key was just built above, and renaming would create a
                # second 'Mutation' column that breaks merging on that name.
                data = data.rename(columns = {'wt_aa': 'WT', 'var_aa': 'Mut'})
                data = data.drop(columns = ['antibody_nonnum', 'antibody_label2'])
        else:
            data.loc[data['Mut'] == 'Stop', 'Mut'] = '*'
            # Rows without a position cannot be keyed; drop them before the
            # integer cast (NaN cannot be converted to int).
            data = data.dropna(subset = ['position']).copy()
            data['position'] = data['position'].astype(int)
            data['Mutation'] = data['WT'] + data['position'].astype(str) + data['Mut']

        vampseq_data[gene_name] = data

    return vampseq_data

In [None]:
def read_thermo(files, offsets=None):
    """Load ThermoMPNN tables and build a 'Mutation' key for each gene.

    Parameters
    ----------
    files : iterable of (gene_name, path)
        Label of the gene/structure and the CSV file to read for it. Each
        file must provide 'pos', 'wtAA', 'mutAA' and 'ddG (kcal/mol)'.
    offsets : dict, optional
        Extra or overriding {gene_name: offset} entries. The offset is added
        to the file's 'pos' column before the mutation string is built.
        When omitted, only the built-in table below is used.

    Returns
    -------
    dict
        Maps each gene_name to a two-column DataFrame:
        'Mutation' (<wtAA><pos + offset><mutAA>) and 'ddG (kcal/mol)'.

    Raises
    ------
    KeyError
        If a gene_name has no offset in the built-in table or in `offsets`.
    """
    # Built-in offsets. Original author's note: ThermoMPNN always starts
    # with residue 0.
    offset_dict = {'F9 (AF)': 1,
                   'G6PD (7UAG)': 27,
                   'TSC2 (7DL2)': 127
                  }
    if offsets is not None:
        offset_dict.update(offsets)

    thermo_data = {}
    for gene_name, path in files:
        df = pd.read_csv(path)
        if gene_name not in offset_dict:
            raise KeyError(
                'No residue offset configured for {!r}; '
                'supply one through the `offsets` argument.'.format(gene_name)
            )
        df['pos'] = df['pos'] + offset_dict[gene_name]
        df['Mutation'] = df['wtAA'] + df['pos'].astype(str) + df['mutAA']
        # Keep only the merge key and the predicted value.
        df = df[['Mutation', 'ddG (kcal/mol)']]

        thermo_data[gene_name] = df

    return thermo_data

In [None]:
def _spearman_complete_pairs(df):
    """Spearman r and p-value over rows where both score and ddG are present."""
    pairs = df[['average_score', 'ddG (kcal/mol)']].dropna()
    if len(pairs) < 2:
        # A rank correlation is undefined for fewer than two observations.
        return float('nan'), float('nan')
    return spearmanr(pairs['average_score'], pairs['ddG (kcal/mol)'])


def merge(vampseq, thermo):
    """Join VAMP-seq scores with ThermoMPNN ddG per gene and report Spearman stats.

    Parameters
    ----------
    vampseq : dict
        {gene: DataFrame} with at least 'Mutation' and 'average_score'.
    thermo : dict
        {gene: DataFrame} with at least 'Mutation' and 'ddG (kcal/mol)'.
        Must contain every gene key present in `vampseq`.

    Returns
    -------
    pandas.DataFrame
        Inner joins on 'Mutation' for every gene, stacked into one frame with
        a 'gene' column and a fresh 0..n-1 index.

    Side effects
    ------------
    Prints the Spearman correlation and p-value for each gene and for the
    stacked table. Rows with a missing score or ddG are left out of the
    correlation only; they stay in the returned frame.
    """
    merged_dfs = []
    for gene, vampseq_df in vampseq.items():
        df = pd.merge(vampseq_df, thermo[gene], on = 'Mutation', how = 'inner')
        df['gene'] = gene
        merged_dfs.append(df)

        r, p = _spearman_complete_pairs(df)
        print('Analysis for: ', gene, '\n',
              'Spearman r: ', str(r), '\n',
              'P-value: ', str(p)
             )

    # ignore_index: each per-gene frame is indexed from 0, so stacking them
    # as-is would leave duplicate index labels in the result.
    df = pd.concat(merged_dfs, ignore_index = True)

    r, p = _spearman_complete_pairs(df)
    print('All the Data: ', '\n',
          'Spearman r: ', str(r), '\n',
          'P-value: ', str(p)
         )
    return df

In [None]:
def scatter(df):
    """Display an interactive scatter of VAMP-seq score against ddG, one panel per gene.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'average_score', 'ddG (kcal/mol)', 'gene' and
        'Mutation' (the latter is shown as the tooltip).
    """
    encodings = {
        'x': 'average_score:Q',
        'y': 'ddG (kcal/mol):Q',
        'color': 'gene',
        'tooltip': ['Mutation'],
    }

    points = alt.Chart(df).mark_circle().encode(**encodings)
    sized = points.properties(width = 600, height = 400)
    # Facet by gene first, then enable pan/zoom on the faceted chart.
    faceted = sized.facet('gene').interactive()

    faceted.display()

In [None]:
def histo_heatmap(df):
    """Display a binned 2-D histogram (heatmap) of VAMP-seq score against ddG.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'average_score' and 'ddG (kcal/mol)'. Cell colour
        encodes the number of rows falling into each (score, ddG) bin.
    """
    score_axis = alt.X(
        'average_score:Q',
        bin = alt.Bin(maxbins = 50),
        title = 'VAMP-seq Score',
    )
    ddg_axis = alt.Y(
        'ddG (kcal/mol):Q',
        title = 'ddG (kcal/mol)',
        bin = alt.Bin(maxbins = 25),
    )
    count_colour = alt.Color(
        'count():Q',
        scale = alt.Scale(scheme = 'greenblue'),
        legend = alt.Legend(title = '# of Vars.'),
    )

    heatmap = alt.Chart(df).mark_rect().encode(
        x = score_axis,
        y = ddg_axis,
        color = count_colour,
    )
    heatmap = heatmap.properties(title = 'ThermoMPNN vs. VAMP-seq')

    heatmap.display()

In [None]:
def density(df):
    """Show two side-by-side KDE views of VAMP-seq score against ddG.

    Left panel: filled density regions. Right panel: contour lines only.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'average_score' and 'ddG (kcal/mol)'.

    Side effects
    ------------
    Sets the seaborn style to 'whitegrid' (this persists after the call) and
    shows the figure with plt.show().
    """
    # Style must be set before the axes are created for it to apply to them.
    sns.set_style('whitegrid')
    # Explicit axes instead of plt.subplot()/plt.title(), so each call names
    # the panel it draws on rather than relying on the "current axes".
    fig, (filled_ax, contour_ax) = plt.subplots(1, 2, figsize = (12, 5))

    # Settings shared by both panels.
    shared = dict(
        data=df,
        x='average_score',
        y='ddG (kcal/mol)',
        levels=10,      # Number of contour levels
        bw_adjust=0.5,  # Bandwidth adjustment (lower = more detail)
    )

    # 1. Filled density plot
    sns.kdeplot(
        fill=True,     # Creates filled density regions
        cmap='BuPu',   # Color scheme
        thresh=0.05,   # Don't show very low density areas
        ax=filled_ax,
        **shared
    )
    filled_ax.set_title('2D Density Plot')

    # 2. Contour plot
    sns.kdeplot(
        fill=False,    # Just contour lines, no fill
        cmap='plasma',
        linewidths=1.5,
        ax=contour_ax,
        **shared
    )
    # This panel previously had no title, unlike the left one.
    contour_ax.set_title('Contour Plot')

    fig.tight_layout()
    plt.show()

In [None]:
def main():
    """Run the whole comparison: load both data sets, merge them, then plot.

    Reads the file lists from the module-level `vampseq` and `thermompnn`
    tables. Correlation statistics are printed by `merge`; the heatmap and
    the density figure are displayed afterwards.
    """
    scores_by_gene = read_vampseq(vampseq)
    predictions_by_gene = read_thermo(thermompnn)
    combined = merge(scores_by_gene, predictions_by_gene)

    # The per-gene scatter plot is currently switched off:
    # scatter(combined)
    histo_heatmap(combined)
    density(combined)

In [None]:
# Execute the pipeline: prints the Spearman statistics and renders the plots.
main()