## Analyze selection data using soluble Ephrin-B2 or -B3

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every parameter defaults to None; papermill injects overriding values in a
# separate cell after this one. If nipah_config is still None, the fallback
# cell below fills in hard-coded config, input and filtered-data paths.
# NOTE(review): the output-image parameters are not given fallback values
# there, so the chart .save() calls receive None when run outside papermill.
#input configs
altair_config = None  # path to a Python theme script that is exec'd
nipah_config = None  # path to the YAML analysis config (cutoffs, site lists)

#input files
entropy_file = None  # NOTE(review): not read by any visible cell
func_scores_E2_file = None  # cell-entry effects CSV, EFNB2 selection
binding_E2_file = None  # binding effects CSV, EFNB2
func_scores_E3_file = None  # cell-entry effects CSV, EFNB3 selection
binding_E3_file = None  # binding effects CSV, EFNB3

#output files
filtered_E2_binding_data = None  # filtered EFNB2 binding+entry CSV
filtered_E3_binding_data = None  # filtered EFNB3 binding+entry CSV
filtered_E2_binding_low_effect = None  # EFNB2 mutants below the entry cutoff
filtered_E3_binding_low_effect = None  # EFNB3 mutants below the entry cutoff

#output images
entry_binding_combined_corr_plot = None  # per-mutant binding vs entry scatter
entry_binding_combined_corr_plot_agg = None  # per-site summed binding vs entry
E2_E3_correlation = None  # per-mutant EFNB2 vs EFNB3 binding scatter
E2_E3_correlation_site = None  # per-site median EFNB2 vs EFNB3 binding
combined_E2_E3_site_corr = None  # the two E2/E3 scatters side by side
binding_by_site_plot = None  # summed binding along the sequence
entry_binding_corr_heatmap = None  # binned binding vs entry heatmaps
binding_corr_heatmap = None  # binned EFNB2 vs EFNB3 binding heatmap

In [None]:
import math
import os
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import yaml

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Move into the project directory so the relative "results/..." and
# "data/..." paths used later resolve. normpath strips the trailing slash:
# os.getcwd() never returns one, so comparing against the raw literal could
# never detect that we are already in the right place.
_project_dir = os.path.normpath(
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
)
if os.path.normpath(os.getcwd()) == _project_dir:
    print("Already in correct directory")
else:
    os.chdir(_project_dir)
    print("Setup in correct directory")

In [None]:
# Fallback for running the notebook by hand: when no parameters were injected,
# nipah_config is still None and every path below is hard-coded.
# NOTE(review): the output-image parameters (entry_binding_combined_corr_plot,
# binding_by_site_plot, ...) get no fallback here, so the chart .save() calls
# later in the notebook receive None in this mode.
if nipah_config is None:
    print('loading hard paths')

    # configuration files
    altair_config = "data/custom_analyses_data/theme.py"
    nipah_config = "nipah_config.yaml"

    # input tables
    entropy_file = 'results/entropy/entropy.csv'
    binding_E2_file = "results/receptor_affinity/averages/EFNB2_monomeric_mut_effect.csv"
    binding_E3_file = "results/receptor_affinity/averages/EFNB3_dimeric_mut_effect.csv"
    func_scores_E2_file = "results/func_effects/averages/CHO_EFNB2_low_func_effects.csv"
    func_scores_E3_file = "results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"

    # filtered-data outputs
    filtered_E2_binding_data = "results/filtered_data/E2_binding_filtered.csv"
    filtered_E3_binding_data = "results/filtered_data/E3_binding_filtered.csv"
    filtered_E2_binding_low_effect = "results/filtered_data/E2_binding_low_effect_filter.csv"
    filtered_E3_binding_low_effect = "results/filtered_data/E3_binding_low_effect_filter.csv"

In [None]:
# Echo the resolved filtered-EFNB2 output path so the run log records it.
print(f"{filtered_E2_binding_data}")

### Run config files to set up the Altair theme and load config variables

In [None]:
# Apply the Altair theme by executing the theme script in this namespace.
# NOTE: exec runs arbitrary code; altair_config must point at a trusted,
# project-local file. Compiling with the filename makes any traceback from
# the theme script name the right file.
if altair_config:
    with open(altair_config, 'r', encoding='utf-8') as file:
        exec(compile(file.read(), altair_config, 'exec'))

# Load the analysis settings (filter cutoffs, site lists) used by later cells.
with open(nipah_config, encoding='utf-8') as f:
    config = yaml.safe_load(f)

### Make the E2/E3 dataframes, filter separately, then merge

In [None]:
# Load the per-mutant binding and cell-entry effect tables for both ephrins,
# in the order EFNB2 binding, EFNB2 entry, EFNB3 binding, EFNB3 entry.
e2, e2_func, e3, e3_func = (
    pd.read_csv(table_path)
    for table_path in (
        binding_E2_file,
        func_scores_E2_file,
        binding_E3_file,
        func_scores_E3_file,
    )
)

In [None]:
def _well_measured_entry_mask(df):
    """Return the row mask shared by the binding and low-effect filters.

    Keeps rows whose mutant is not ``*`` or ``-``, whose site is not 603, and
    whose cell-entry measurement meets the ``func_times_seen_cutoff`` and
    ``func_std_cutoff`` thresholds from the global ``config``.
    """
    return (
        (df['mutant'] != '*') &
        (df['mutant'] != '-') &
        (df['site'] != 603) &
        (df['times_seen_cell_entry'] >= config['func_times_seen_cutoff']) &
        (df['effect_std'] <= config['func_std_cutoff'])
    )


def merge_func_binding_dfs(func,binding,name):
    """Merge cell-entry and binding effects for one ephrin, filter, and save.

    Parameters
    ----------
    func : pandas.DataFrame
        Cell-entry effects keyed by ``site``/``mutant``/``wildtype`` and
        carrying ``times_seen`` (suffixed to ``times_seen_cell_entry``).
    binding : pandas.DataFrame
        Binding effects keyed the same way, with ``Ephrin binding_*``
        summary columns and ``times_seen`` (-> ``times_seen_binding``).
        ``effect``, ``effect_std`` and ``frac_models`` must come from one of
        the two inputs.
    name : str
        ``'EFNB2'`` or ``'EFNB3'``; selects which output CSVs are written.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_filter, df_low_effect_filter)``. The first holds mutants that
        pass all entry and binding filters. The second holds well-measured
        mutants whose entry effect is below
        ``config['min_func_effect_for_binding']``. Both are independent
        copies, so callers may add columns freely.

    Raises
    ------
    ValueError
        If ``name`` is not ``'EFNB2'`` or ``'EFNB3'``.
    """
    # Resolve output paths first so an unexpected name fails before any work.
    # It can then never silently overwrite another ephrin's files.
    output_paths = {
        'EFNB2': (filtered_E2_binding_data, filtered_E2_binding_low_effect),
        'EFNB3': (filtered_E3_binding_data, filtered_E3_binding_low_effect),
    }
    if name not in output_paths:
        raise ValueError(f"name must be 'EFNB2' or 'EFNB3', got {name!r}")
    filtered_path, low_effect_path = output_paths[name]

    # The outer merge keeps mutants measured in only one assay. NaN
    # comparisons are False, so df_filter keeps only mutants measured in
    # both. The low-effect set checks entry columns only.
    df = pd.merge(
        binding,
        func,
        on=['site','mutant','wildtype'],
        suffixes=['_binding','_cell_entry'],
        validate='one_to_one',
        how='outer'
    ).round(3)
    df = df.rename(columns={'Ephrin binding_mean':'binding_mean','Ephrin binding_std':'binding_std','Ephrin binding_median':'binding_median'})

    # Only save relevant columns
    df = df[['site','wildtype','mutant','binding_median','binding_std','times_seen_binding','effect','effect_std','times_seen_cell_entry','frac_models']]

    entry_ok = _well_measured_entry_mask(df)
    min_effect = config['min_func_effect_for_binding']

    # Mutants that enter cells well enough for binding to be scored and
    # whose binding measurement passes its own quality cutoffs.
    df_filter = df[
        entry_ok &
        (df['effect'] >= min_effect) &
        (df['times_seen_binding'] >= config['min_times_seen_binding']) &
        (df['binding_std'] <= config['max_binding_std']) &
        (df['frac_models'] >= config['frac_models'])
    ].copy()

    # Well-measured mutants below the entry cutoff, kept for heatmaps later.
    # `<` (not `~(>=)`) keeps NaN effects excluded, as in the filter above.
    df_low_effect_filter = df[entry_ok & (df['effect'] < min_effect)].copy()

    print(name)
    df_filter.to_csv(filtered_path,index=False)
    df_low_effect_filter.to_csv(low_effect_path,index=False)

    return df_filter,df_low_effect_filter

#Call filtering function (the *_missing frames hold the low-entry-effect mutants)
df_E2_filter,df_E2_filter_missing = merge_func_binding_dfs(e2_func,e2,'EFNB2')
df_E3_filter,df_E3_filter_missing = merge_func_binding_dfs(e3_func,e3,'EFNB3')

#Now that they are filtered, merge EFNB2 and EFNB3
df_binding_effect_merge = pd.merge(
    df_E2_filter,
    df_E3_filter,
    on=['site','wildtype','mutant'],
    suffixes=['_E2','_E3'],
    how='outer'
)

#display stats
display(df_binding_effect_merge.describe().round(3))

# Make a concat df of E2/E3 data for plotting later. `assign` returns new
# frames, so tagging never writes into a DataFrame that may be a slice.
df_E2_filter = df_E2_filter.assign(selection='EFNB2')
df_E3_filter = df_E3_filter.assign(selection='EFNB3')
df_binding_effect_concat = pd.concat([df_E2_filter,df_E3_filter])

### Make interactive plots of the correlation between binding and cell entry for EFNB2 and EFNB3

In [None]:
def plot_corr_binding_entry_updated(df,flag):
    """Scatter binding vs. cell entry, one panel per ``selection`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Concatenated per-mutant data with ``selection``, ``site``,
        ``wildtype``, ``mutant``, ``effect`` (cell entry),
        ``binding_median`` and ``times_seen_binding`` columns.
    flag : bool
        If truthy, ``effect`` and ``binding_median`` are summed per site and
        one point is drawn per site. Otherwise every mutant is drawn, and a
        "times seen" slider hides mutants with fewer binding observations.

    Returns
    -------
    alt.HConcatChart
        The per-selection panels side by side.
    """
    # Hover highlighting: per mutant in the mutant-level view, per site in
    # the aggregated view.
    variant_selector = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site","mutant"],
        value=0
    )
    variant_selector_agg = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site"],
        value=0
    )
    slider = alt.binding_range(min=2, max=10, step=1, name="times seen")
    selector = alt.param(name="SelectorName", value=2, bind=slider)

    charts = []

    for cell in df['selection'].unique():
        tmp_df = df[df['selection'] == cell]
        if flag:
            # Both axes are per-site sums, so both titles say "Summed".
            agg_df = tmp_df.groupby('site')[['binding_median','effect']].sum().reset_index()
            chart = alt.Chart(agg_df).mark_point(stroke='black',filled=True).encode(
                x=alt.X('effect', title=f'Summed {cell} Cell Entry', axis=alt.Axis(grid=True)),
                y=alt.Y('binding_median', title=f'Summed {cell} Binding', axis=alt.Axis(grid=True)),
                opacity=alt.condition(variant_selector_agg, alt.value(1), alt.value(0.2)),
                size=alt.condition(variant_selector_agg,alt.value(100),alt.value(50)),
                strokeWidth=alt.condition(variant_selector_agg,alt.value(1),alt.value(0)),
                color=alt.condition(variant_selector_agg,alt.value('orange'),alt.value('black')),
                tooltip=['site', 'binding_median','effect'],
            ).properties(
                width=200,
                height=200,
            ).add_params(variant_selector_agg)
        else:
            chart = alt.Chart(tmp_df).mark_point(stroke='black',filled=True).encode(
                x=alt.X('effect', title=f'{cell} Cell Entry', axis=alt.Axis(grid=True)),
                y=alt.Y('binding_median', title=f'{cell} Binding', axis=alt.Axis(grid=True)),
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.1)),
                size=alt.condition(variant_selector,alt.value(50),alt.value(20)),
                strokeWidth=alt.condition(variant_selector,alt.value(1),alt.value(0)),
                color=alt.condition(variant_selector,alt.value('orange'),alt.value('black')),
                tooltip=['site', 'wildtype', 'mutant','binding_median','times_seen_binding','effect'],
            ).properties(
                width=200,
                height=200,
            ).add_params(variant_selector,selector).transform_filter(
                alt.datum.times_seen_binding >= selector
            )
        charts.append(chart)

    combined_chart = alt.hconcat(*charts,title=alt.Title('Correlation between binding and entry'))
    return combined_chart

entry_binding_corr_plot = plot_corr_binding_entry_updated(df_binding_effect_concat,False)
entry_binding_corr_plot.display()
entry_binding_corr_plot.save(entry_binding_combined_corr_plot)

entry_binding_corr_plot_agg = plot_corr_binding_entry_updated(df_binding_effect_concat,True)
entry_binding_corr_plot_agg.display()
entry_binding_corr_plot_agg.save(entry_binding_combined_corr_plot_agg)

In [None]:
def plot_entry_binding_corr_heatmap(df):
    """Build binned binding-vs-entry count heatmaps, one panel per selection.

    ``effect`` (cell entry) is binned on x and ``binding_median`` on y, each
    into at most 60 bins. Cell colour is the number of mutants. All panels
    share their x, y and colour scales.
    """
    def _heatmap_for(cell):
        cell_rows = df[df['selection'] == cell]
        x_enc = alt.X('effect',title='Cell Entry',axis=alt.Axis(values=[-2,-1,0,1])).bin(maxbins=60)
        y_enc = alt.Y('binding_median',title='Binding',axis=alt.Axis(values=[-4,-2,0,2])).bin(maxbins=60)
        count_enc = alt.Color('count()',title='Count').scale(scheme='greenblue')
        return alt.Chart(cell_rows,title=f'{cell}').mark_rect().encode(
            x=x_enc,
            y=y_enc,
            color=count_enc,
        )

    panels = [_heatmap_for(cell) for cell in df['selection'].unique()]
    overall_title = alt.Title('Correlation between binding and entry')
    return alt.hconcat(*panels, title=overall_title).resolve_scale(
        y='shared', x='shared', color='shared'
    )

entry_binding_corr_heat = plot_entry_binding_corr_heatmap(df_binding_effect_concat)
entry_binding_corr_heat.display()
entry_binding_corr_heat.save(entry_binding_corr_heatmap)

In [None]:
def overall_stats(df,effect,name,*,intolerant_cutoff=-0.25,high_cutoff=1,low_cutoff=-1):
    """Print summary statistics for one ephrin's filtered binding data.

    Parameters
    ----------
    df : pandas.DataFrame
        Filtered per-mutant data with ``site``, ``wildtype``, ``mutant``,
        ``binding_median`` and ``effect`` (cell entry) columns.
    effect : str
        Column used for the per-site intolerance test and the above/below
        cutoff counts. The calls below pass ``'binding_median'``.
    name : str
        Label used in the printed messages.
    intolerant_cutoff : float, keyword-only
        A site is reported when every one of its mutants has
        ``df[effect] < intolerant_cutoff``.
    high_cutoff, low_cutoff : float, keyword-only
        Mutants with ``df[effect] > high_cutoff`` / ``< low_cutoff`` are
        counted as above / below cutoff.
    """
    # Spread of binding scores
    quantile_2 = df['binding_median'].quantile(.02)
    quantile_98 = df['binding_median'].quantile(.98)
    print(f'The 2% quantile for {name} is: {quantile_2}')
    print(f'The 98% quantile for {name} is: {quantile_98}')

    # Sites at which every mutant scores below intolerant_cutoff
    intolerant = df.groupby('site').filter(lambda group: (group[effect] < intolerant_cutoff).all())
    intolerant_sites = pd.Series(intolerant['site'].unique())
    contact_intolerant = intolerant_sites[intolerant_sites.isin(config['contact_sites'])]
    print(f'Contact sites where every mutant has {effect} < {intolerant_cutoff} for {name}: {contact_intolerant.tolist()}')
    print(f'There are {len(intolerant_sites)} sites where every mutant has {effect} < {intolerant_cutoff} for {name}')
    print(f'These are the sites for {name} where every mutant has {effect} < {intolerant_cutoff}: {intolerant_sites.tolist()}')

    # Sites with the lowest / highest median binding (medians computed once)
    site_medians = df.groupby('site')['binding_median'].median().reset_index()
    lowest = site_medians.sort_values(by='binding_median').head(5)
    highest = site_medians.sort_values(by='binding_median',ascending=False).head(5)
    print(f'For {name}, these are the sites with lowest median binding scores: {lowest}')
    print(f'For {name}, these are the sites with highest median binding scores: {highest}')

    # Mutant counts relative to the cutoffs
    total_mutants = df.shape[0]
    upper_cutoff = df[df[effect] > high_cutoff].sort_values(by='binding_median',ascending=False)
    # The literal 'effect' column is the cell-entry score, regardless of the
    # `effect` argument.
    median_upper = upper_cutoff['effect'].median()
    print(f'The median entry score for top binders was: {median_upper}')

    mutants_above_cutoff_tolerated = upper_cutoff.loc[
        upper_cutoff['effect'] > 0,
        ['site','effect','binding_median','wildtype','mutant'],
    ]
    print(f'The mutants with positive entry scores and good binding are: {mutants_above_cutoff_tolerated.head(5)}')

    lower_cutoff = df[df[effect] < low_cutoff]

    print(f'For {name}, there are a total of : {total_mutants} binding mutants')
    print(f'For {name}, there are {upper_cutoff.shape[0]} mutants above cutoff, and {mutants_above_cutoff_tolerated.shape[0]} that have good entry scores')
    print(f'For {name}, there are {lower_cutoff.shape[0]} mutants below cutoff')

    total_sites = df['site'].nunique()

    print(f'The total number of sites are: {total_sites}')


overall_stats(df_E2_filter,'binding_median','E2')
overall_stats(df_E3_filter,'binding_median','E3')

### Find contact-site mutants with opposite effects on EFNB2 vs. EFNB3 binding

In [None]:
#find contact-site mutants whose binding effect differs in sign between ephrins
def find_biggest_differences(df,*,threshold=0.1):
    """Display contact-site mutants with opposite EFNB2/EFNB3 binding effects.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged per-mutant table with ``site``, ``binding_median_E2`` and
        ``binding_median_E3`` columns.
    threshold : float, keyword-only, default 0.1
        A mutant counts as "good" for an ephrin when its binding score is
        above ``threshold``, and "bad" when below ``-threshold``.

    Displays two tables: good-for-EFNB2/bad-for-EFNB3, sorted by EFNB2
    binding, and the reverse, sorted by EFNB3 binding. Both are descending.
    """
    contact = df[df['site'].isin(config['contact_sites'])]
    efnb2_good_efnb3_bad = contact[
        (contact['binding_median_E2'] > threshold) &
        (contact['binding_median_E3'] < -threshold)
    ].sort_values(by='binding_median_E2',ascending=False)
    display(efnb2_good_efnb3_bad)
    efnb2_bad_efnb3_good = contact[
        (contact['binding_median_E2'] < -threshold) &
        (contact['binding_median_E3'] > threshold)
    ].sort_values(by='binding_median_E3',ascending=False)
    display(efnb2_bad_efnb3_good)


find_biggest_differences(df_binding_effect_merge)

### Find correlations between EFNB2 and EFNB3 binding

In [None]:
def plot_entry_binding_corr(df):
    """Build a 2-D count heatmap of EFNB2 vs. EFNB3 binding per mutant.

    Despite the name, no cell-entry data are plotted. ``binding_median_E2``
    is on x and ``binding_median_E3`` on y, each binned into at most 40
    bins, and cells are coloured by mutant count.
    """
    x_enc = alt.X('binding_median_E2',title='EFNB2 binding',axis=alt.Axis(values=[-5,0,2])).bin(maxbins=40)
    y_enc = alt.Y('binding_median_E3',title='EFNB3 binding',axis=alt.Axis(values=[-2,0,2])).bin(maxbins=40)
    count_enc = alt.Color('count()',title='Count').scale(scheme='greenblue')

    heatmap = alt.Chart(df,title='Correlation Between Ephrin Binding Scores').mark_rect()
    return heatmap.encode(x=x_enc, y=y_enc, color=count_enc).properties(
        width=200,
        height=200,
    )

entry_binding_corr_heatmap_1 = plot_entry_binding_corr(df_binding_effect_merge)
entry_binding_corr_heatmap_1.display()
entry_binding_corr_heatmap_1.save(binding_corr_heatmap)

In [None]:
def plot_affinity_solid(df):
    """Scatter EFNB2 vs. EFNB3 binding per mutant, titled with Pearson r.

    Rows with any missing value are dropped first, so only mutants with data
    for both ephrins remain. r is computed once on all of those rows. The
    "times_seen" slider only hides points with fewer EFNB2 binding
    observations (``times_seen_binding_E2``) and does not recompute r.

    Returns
    -------
    alt.Chart
    """
    slider = alt.binding_range(min=1, max=20, step=1, name="times_seen")
    selector = alt.param(name="SelectorName", value=1, bind=slider)
    df = df.dropna()
    # Pearson r from the linear regression of E3 binding on E2 binding
    r_value = float(scipy.stats.linregress(df['binding_median_E2'], df['binding_median_E3']).rvalue)
    chart = alt.Chart(df,title=alt.Title('Correlation between Mutant Binding Scores',subtitle=f'r={r_value:.2f}')).mark_point(color='black',size=30, opacity=0.2,filled=True).encode(
        x=alt.X('binding_median_E2', title=('EFNB2 Binding')),
        y=alt.Y('binding_median_E3', title=('EFNB3 Binding')),
        tooltip=['site', 'wildtype','mutant','binding_median_E2','binding_median_E3','effect_E2','effect_E3'],
    ).properties(
        width=200,
        height=200
    ).add_params(selector).transform_filter(
        alt.datum.times_seen_binding_E2 >= selector
    )
    return chart

E2_E3_corr = plot_affinity_solid(df_binding_effect_merge)
E2_E3_corr.display()
E2_E3_corr.save(E2_E3_correlation)

### Plot the correlation between per-site median EFNB2 and EFNB3 binding

In [None]:
def plot_affinity_solid_mean(df):
    """Scatter per-site median EFNB2 vs. EFNB3 binding, titled with Pearson r.

    Rows with any missing value are dropped first. Entry effects and binding
    scores are then reduced to their per-site medians, keeping the first
    ``wildtype`` for the tooltip. r is computed on the per-site medians.

    Returns
    -------
    alt.Chart
    """
    df = df.dropna()
    means = df.groupby('site').agg({
            'effect_E2': 'median',
            'effect_E3': 'median',
            'binding_median_E2': 'median',
            'binding_median_E3': 'median',
            'wildtype': 'first'
        }).reset_index()
    r_value = float(scipy.stats.linregress(means['binding_median_E2'], means['binding_median_E3']).rvalue)
    chart = alt.Chart(means,title=alt.Title('Correlation between Aggregate Mutant Binding Scores',subtitle=f'r={r_value:.2f}')).mark_point(size=50, color='black', opacity=0.3,filled=True).encode(
            x=alt.X('binding_median_E2', title=('Median Ephrin-B2 Binding'), axis=alt.Axis(tickCount=3)),
            y=alt.Y('binding_median_E3', title=('Median Ephrin-B3 Binding'), axis=alt.Axis(tickCount=3)),
            tooltip=['site', 'wildtype','binding_median_E2','binding_median_E3','effect_E2','effect_E3'],
        ).properties(
            width=200,
            height=200
    )
    return chart

E2_E3_site_corr = plot_affinity_solid_mean(df_binding_effect_merge)
E2_E3_site_corr.display()
E2_E3_site_corr.save(E2_E3_correlation_site)

(E2_E3_site_corr | E2_E3_corr).save(combined_E2_E3_site_corr)

### Plot summed binding by site

In [None]:
def plot_affinity_by_site_median(df):
    """Plot the summed binding score at each site, one panel per ephrin.

    Despite the function name, per-site values are sums, which matches the
    chart title 'Summed Binding by Site'. A site with no measured mutant for
    an ephrin gets NaN rather than 0, so it is not drawn in that panel.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged per-mutant table with ``site``, ``binding_median_E2`` and
        ``binding_median_E3`` columns.

    Returns
    -------
    alt.VConcatChart
        EFNB2 panel above EFNB3 panel, with site hover highlighting.
    """
    variant_selector = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site"],
        value=0
    )
    panel_names = {
        'binding_median_E2': 'EFNB2 Binding',
        'binding_median_E3': 'EFNB3 Binding',
    }
    empty_charts = []
    for selection, name in panel_names.items():
        # After the outer E2/E3 merge, some sites have only NaN for one
        # ephrin. A plain sum() would turn those into 0 (looks like neutral
        # binding); min_count=1 leaves them NaN so they are not plotted.
        summed = df.groupby('site')[selection].sum(min_count=1).reset_index()
        chart = alt.Chart(summed).mark_point(size=60, color='black', stroke='black',filled=True).encode(
            x=alt.X('site', title=('Site'), axis=alt.Axis(grid=True, tickCount=4),scale=alt.Scale(domain=[70,602])),
            y=alt.Y(selection, title=(name), axis=alt.Axis(grid=True, tickCount=4)),
            tooltip=['site'],
            color=alt.condition(variant_selector, alt.value('orange'), alt.value('black')),
            opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
            strokeWidth=alt.condition(variant_selector,alt.value(2),alt.value(0))
        ).properties(
            width=500,
            height=100
        ).add_params(variant_selector)
        empty_charts.append(chart)
    combined_chart = alt.vconcat(*empty_charts, spacing=1,title='Summed Binding by Site')
    return combined_chart



binding_by_site = plot_affinity_by_site_median(df_binding_effect_merge)
binding_by_site.display()
binding_by_site.save(binding_by_site_plot)

### Make jittered strip plots of per-site median binding for different site classes (hydrophobic, salt-bridge, H-bond, contact, overall)

In [None]:
def make_boxplot_binding_region(df,title):
    """Display a jittered strip plot of per-site median binding by site group.

    Despite the name this draws points, not boxes. Each site's median
    ``binding_median`` appears once in every group it belongs to. The groups
    are the hydrophobic, salt-bridge, hydrogen-bond and contact site lists
    from ``config``, plus all sites 71-602. Points get random horizontal
    jitter.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-mutant data with ``site`` and ``binding_median`` columns.
    title : str
        Ephrin label used in the y-axis title.

    Returns
    -------
    The result of ``chart.display()``; the chart is rendered as a side effect.
    """
    barrel_ranges = {
        'Hydrophobic': config['hydrophobic'],
        'Salt Bridges': config['salt_bridges'],
        'Hydrogen Bonds': config['h_bond_total'],
        'Contact': config['contact_sites'],
        # range() excludes its end; stop at 603 so site 602 (the last site
        # shown in the by-site plot) is included.
        'Overall': list(range(71,603)),
    }
    custom_order = ['Hydrophobic','Salt Bridges','Hydrogen Bonds','Contact','Overall']

    mean_df = df.groupby('site')[['binding_median']].median().reset_index()

    # One row per (group, site) pair, built once rather than per iteration.
    agg_means_df = pd.concat(
        [
            mean_df.loc[mean_df['site'].isin(sites)]
            .rename(columns={'binding_median': 'effect'})
            .assign(barrel=barrel)
            for barrel, sites in barrel_ranges.items()
        ],
        ignore_index=True,
    )[['barrel', 'effect', 'site']]

    chart = alt.Chart(agg_means_df).mark_point(filled=True,size=70,opacity=0.4,color='black').encode(
                x=alt.X('barrel:O', sort=custom_order,title=None,axis=alt.Axis(labelAngle=-90)),
                y=alt.Y('effect',title=f'Median {title} Binding',axis=alt.Axis(grid=True,tickCount=4)),
                xOffset='random:Q',
                tooltip=['barrel', 'effect','site'],
            ).transform_calculate(
                # Gaussian-like jitter via the Box-Muller transform
                random="sqrt(-1*log(random()))*cos(2*PI*random())"
            ).properties(
                height=alt.Step(20),
                width=alt.Step(25)
            )

    return chart.display()

make_boxplot_binding_region(df_E2_filter,'EFNB2')
make_boxplot_binding_region(df_E3_filter,'EFNB3')

### Plot histograms of EFNB2 and EFNB3 binding scores

In [None]:
def effect_histogram(df):
    """Display histograms of per-mutant binding scores, one row per ephrin.

    For each ephrin (``E2``, ``E3``), only mutants whose own cell-entry
    score ``effect_<ephrin>`` is above the cutoff (-2) are counted. Their
    ``binding_median_<ephrin>`` values are binned in steps of 0.1.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged E2/E3 table with ``effect_E2``, ``effect_E3``,
        ``binding_median_E2`` and ``binding_median_E3`` columns.

    Returns
    -------
    The result of ``final_combined_chart.display()``; the chart is rendered
    as a side effect.
    """
    colors = {'E2': '#1f4e79', 'E3': '#ff7f0e'}
    all_charts = []
    for effect in ['E2','E3']:
        func_effect_container = []
        for func_effect in [-2]:
            # Filter into a separate name. Rebinding `df` here would make the
            # E3 pass inherit the E2 entry filter and drop E3-only mutants.
            subset = df[df[f'effect_{effect}'] > func_effect]
            color = colors[effect]
            # Long format keeps the 'Effect' label shown in the tooltip.
            df_melted = subset.melt(value_vars=[f'binding_median_{effect}'], var_name='Effect', value_name='Value')

            histogram = alt.Chart(df_melted).mark_bar(opacity=1,color=color).encode(
                x=alt.X('Value', bin=alt.Bin(step=0.1), title=f'Binding {effect}',axis=alt.Axis(tickCount=4,values=[3,1,0,-1,-5]),scale=alt.Scale(domain=[-6,3])),
                y=alt.Y('count()',stack=None,title='Count'),
                tooltip=['Effect', 'count()']
            ).properties(
                width=150,
                height=alt.Step(10)
            )
            func_effect_container.append(histogram)
        combined_effect_chart = alt.hconcat(*func_effect_container).resolve_scale(y='shared', x='shared', color='independent')
        all_charts.append(combined_effect_chart)
    final_combined_chart = alt.vconcat(*all_charts).resolve_scale(y='independent', x='independent', color='independent')
    return final_combined_chart.display()

effect_histogram(df_binding_effect_merge)