# ephrin_binding.ipynb
Analyze effect of RBP mutations on receptor binding from DMS selection data using soluble bat Ephrin-B2 or -B3
- Written by Brendan Larsen

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every value defaults to None. When `nipah_config` is still None, a later cell
# fills in hard-coded input paths for interactive use.
# input configs
altair_config = None  # path to a Python file that is exec'd to set up the Altair theme
nipah_config = None  # path to the YAML project config

# input files
binding_E2_file = None  # filtered bEFNB2 binding CSV
binding_E3_file = None  # filtered bEFNB3 binding CSV

# output images
# Paths the charts are saved to; the save steps are skipped while these are None.
entry_binding_combined_corr_plot = None
E2_E3_correlation = None
E2_E3_correlation_site = None
binding_by_site_plot = None
entry_binding_corr_heatmap = None
binding_region_bubble_plot = None
combined_contact_ranked_bar_output = None

### Import libraries

In [None]:
import math
import os
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import yaml

### Set working directory

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that the relative paths used in this notebook are resolved against.
project_dir = (
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
)

# os.getcwd() does not include a trailing separator, so comparing it to the
# literal above (which ends in "/") could never succeed. Normalise both sides.
if os.path.normpath(os.getcwd()) == os.path.normpath(project_dir):
    print("Already in correct directory")
else:
    os.chdir(project_dir)
    print("Setup in correct directory")

### Set hard paths for running in interactive mode

In [None]:
# Interactive fallback: when `nipah_config` is still None (e.g. the notebook is
# not being run through snakemake), use hard-coded repository-relative paths.
if nipah_config is None:
    print("loading hard paths")

    # config inputs
    altair_config, nipah_config = (
        "data/custom_analyses_data/theme.py",
        "nipah_config.yaml",
    )

    # filtered binding inputs (bEFNB2, bEFNB3)
    binding_E2_file, binding_E3_file = (
        "results/filtered_data/binding/e2_binding_filtered.csv",
        "results/filtered_data/binding/e3_binding_filtered.csv",
    )

### Run config files to set up the Altair theme and load config variables

In [None]:
# Execute the Altair theme file in this namespace (when one is configured),
# then parse the YAML project config into `config`.
if altair_config:
    with open(altair_config, "r") as file:
        # NOTE(review): exec() runs whatever code the file contains; only point
        # `altair_config` at a trusted, repository-controlled theme file.
        exec(file.read())

# safe_load only builds plain Python objects from the YAML document.
with open(nipah_config) as f:
    config = yaml.safe_load(f)

# Import and merge data

### Import the filtered binding and entry data for bEFNB2 and bEFNB3

In [None]:
# Read the filtered binding tables: bEFNB2 first, then bEFNB3.
df_E2_filter, df_E3_filter = (
    pd.read_csv(csv_path) for csv_path in (binding_E2_file, binding_E3_file)
)

### Merge data

In [None]:
## Wide table: one row per (site, wildtype, mutant); the remaining columns of
## each input get an _E2 / _E3 suffix. The outer join keeps mutants that are
## present in only one of the two tables.
df_binding_effect_merge = df_E2_filter.merge(
    df_E3_filter,
    how="outer",
    on=["site", "wildtype", "mutant"],
    suffixes=["_E2", "_E3"],
)
display(df_binding_effect_merge.head(3))

## Long table: label each input with its receptor, then stack them. The label
## is added after the merge above, so the wide table has no 'selection' column.
df_E2_filter["selection"] = "bEFNB2"
df_E3_filter["selection"] = "bEFNB3"
df_binding_effect_concat = pd.concat([df_E2_filter, df_E3_filter])

# Calculate stats

In [None]:
# What are the top 5 highest and lowest binding mutants for EFNB2 and EFNB3?
def find_highest_lowest(df, name):
    """Print and display the extreme binding mutants of one selection.

    Shows, in order: the 5 lowest and 5 highest `binding_mean` mutants, the
    same two tables restricted to rows with `effect` > 0, and the 10 sites
    with the largest summed `binding_mean`.

    NOTE: a later cell defines a one-argument function with this same name,
    which replaces this one once that cell has run.

    Args:
        df: Table with `site`, `binding_mean` and `effect` columns.
        name: Label printed in the heading.
    """
    print(f"We are analyzing {name}\n")
    print(f"The total number of mutants was: {len(df)}\n")

    weakest = df.sort_values(by="binding_mean").head(5)
    print("These are the lowest binding mutants detected:")
    display(weakest)

    strongest = df.sort_values(by="binding_mean", ascending=False).head(5)
    print("\nThese are the highest binding mutants detected:\n")
    display(strongest)

    # What about mutants with positive entry scores?
    positive_entry = df[df["effect"] > 0]
    weakest_positive = positive_entry.sort_values(by="binding_mean").head(5)
    print("These are the lowest binding mutants detected with positive entry scores:")
    display(weakest_positive)

    strongest_positive = positive_entry.sort_values(
        by="binding_mean", ascending=False
    ).head(5)
    print(
        "\nThese are the highest binding mutants detected with positive entry scores:\n"
    )
    display(strongest_positive)

    # Per-site sum (not mean) of the binding scores.
    site_totals = df.groupby("site")["binding_mean"].sum().reset_index()
    print('These are the sites with the highest summed binding score:\n')
    display(site_totals.sort_values(by="binding_mean", ascending=False).head(10))


find_highest_lowest(df_E2_filter, "bEFNB2")
find_highest_lowest(df_E3_filter, "bEFNB3")

In [None]:
def overall_stats(df, effect, name):
    """Print per-site summary statistics for one binding selection.

    Reports the sites at which every mutant scores below -0.25 in the `effect`
    column, which of those are listed in `config["contact_sites"]`, the five
    sites with the highest single `binding_mean` value, and the number of
    distinct sites.

    Args:
        df: Table with `site` and `binding_mean` columns plus the column
            named by `effect`.
        effect: Name of the column compared against the -0.25 cutoff.
        name: Label printed in the heading.
    """
    # Keep only the rows of sites where all mutants fall below the cutoff.
    filtered_df = df.groupby("site").filter(lambda group: (group[effect] < -0.25).all())
    unique = filtered_df["site"].unique()

    # Subset of those sites that are also contact sites.
    unique_series = pd.Series(unique)
    common_elements = unique_series[unique_series.isin(config["contact_sites"])]

    print(f"The dataset we are analyzing is: {name}\n")

    print(
        f"Here are the contact sites that only have negative binding scores: {list(common_elements)}\n"
    )

    print(f"There are {len(unique)} sites with all negative binding score mutants\n")
    print(
        f"These are the sites with all negative binding score mutants: {list(unique)}\n"
    )

    # Best (maximum) binding score observed at each site, highest first.
    best_per_site = (
        df.groupby("site")["binding_mean"]
        .max()
        .reset_index()
        .sort_values(by="binding_mean", ascending=False)
    )
    print("These are the sites with the highest binding mutants:\n")
    display(best_per_site.head(5))

    total_sites = df["site"].unique().shape[0]
    print(f"The total number of sites are: {total_sites}")


overall_stats(df_E2_filter, "binding_mean", "EFNB2")
overall_stats(df_E3_filter, "binding_mean", "EFNB3")

### Find sites with opposite effects on binding

In [None]:
# find sites that are different
def find_biggest_differences(df, threshold=0.5):
    """Display mutants with opposite effects on bEFNB2 and bEFNB3 binding.

    Args:
        df: Wide table with `binding_mean_E2` and `binding_mean_E3` columns.
        threshold: A mutant is "good" for a receptor when its binding score is
            above `threshold` and "bad" when it is below `-threshold`.
    """
    e2_binding = df["binding_mean_E2"]
    e3_binding = df["binding_mean_E3"]

    efnb2_good_efnb3_bad = df[
        (e2_binding > threshold) & (e3_binding < -threshold)
    ].sort_values(by="binding_mean_E2", ascending=False)
    print('mutations good for efnb2 binding, bad for efnb3 binding:\n')
    display(efnb2_good_efnb3_bad)

    efnb2_bad_efnb3_good = df[
        (e2_binding < -threshold) & (e3_binding > threshold)
    ].sort_values(by="binding_mean_E3", ascending=False)
    print('mutations bad for efnb2 binding, good for efnb3 binding:\n')
    display(efnb2_bad_efnb3_good)


find_biggest_differences(df_binding_effect_merge)

### Find the top overlapping binders for both bEFNB2 and bEFNB3

In [None]:
def find_good_binding_for_both(df):
    """Display the mutants that rank in the top 50 for both receptors.

    The 50 rows with the highest `binding_mean_E2` are inner-joined with the
    50 rows with the highest `binding_mean_E3` on (site, wildtype, mutant).

    Args:
        df: Wide table with `binding_mean_E2` and `binding_mean_E3` columns.
    """
    ranked_by_e2 = df.sort_values(by='binding_mean_E2', ascending=False)
    ranked_by_e3 = df.sort_values(by='binding_mean_E3', ascending=False)
    shared_top = pd.merge(
        ranked_by_e2.head(50),
        ranked_by_e3.head(50),
        how='inner',
        on=['site', 'wildtype', 'mutant'],
    )
    display(shared_top)

find_good_binding_for_both(df_binding_effect_merge)

### Find sites with the largest absolute difference in binding

In [None]:
# Compare E2 and E3 binders
def find_highest_lowest(df):
    """Display the mutants and sites whose binding differs most between receptors.

    Adds a `binding_diff` column (absolute difference between
    `binding_mean_E2` and `binding_mean_E3`) to `df` IN PLACE, then shows the
    10 mutants with the largest difference and the 5 sites with the largest
    mean difference.

    NOTE: this rebinds the name of the two-argument `find_highest_lowest`
    defined in an earlier cell.

    Args:
        df: Wide table with `site`, `binding_mean_E2` and `binding_mean_E3`.
    """
    receptor_gap = df["binding_mean_E2"] - df["binding_mean_E3"]
    df["binding_diff"] = receptor_gap.abs()
    print(
        "These are the mutants with the biggest difference between EFNB2 and EFNB3:\n"
    )
    ranked_mutants = df.sort_values(by="binding_diff", ascending=False)
    display(ranked_mutants.head(10))

    # Per-site means of both binding scores and of their absolute difference.
    value_columns = ["binding_mean_E2", "binding_mean_E3", "binding_diff"]
    site_means = df.groupby("site")[value_columns].mean().reset_index()
    print("These are the sites with the biggest difference between EFNB2 and EFNB3:\n")
    display(site_means.sort_values(by="binding_diff", ascending=False).head(5))


find_highest_lowest(df_binding_effect_merge)

# Make plots

### Make plots showing correlation between binding and entry for EFNB2 and EFNB3

In [None]:
def _binding_entry_scatter(
    data, selector, x_title, y_title, tooltip, sizes, dimmed_opacity
):
    """Build one binding-vs-entry scatter in which the selected point is highlighted.

    Args:
        data: DataFrame with `effect` and `binding_mean` columns plus every
            field named in `tooltip`.
        selector: Altair selection parameter that drives the highlighting.
        x_title: Title of the x axis (`effect`).
        y_title: Title of the y axis (`binding_mean`).
        tooltip: Field names shown in the tooltip.
        sizes: `(selected, unselected)` point sizes.
        dimmed_opacity: Opacity of the points that are not selected.

    Returns:
        The Altair chart with `selector` attached.
    """
    selected_size, unselected_size = sizes
    return (
        alt.Chart(data)
        .mark_point(stroke="black", filled=True)
        .encode(
            x=alt.X("effect", title=x_title, axis=alt.Axis(grid=False)),
            y=alt.Y("binding_mean", title=y_title, axis=alt.Axis(grid=False)),
            # Selected point: opaque, larger, outlined and orange.
            opacity=alt.condition(selector, alt.value(1), alt.value(dimmed_opacity)),
            size=alt.condition(
                selector, alt.value(selected_size), alt.value(unselected_size)
            ),
            strokeWidth=alt.condition(selector, alt.value(1), alt.value(0)),
            color=alt.condition(selector, alt.value("orange"), alt.value("gray")),
            tooltip=tooltip,
        )
        .add_params(selector)
    )


def plot_corr_binding_entry_updated(df, flag):
    """Plot binding against entry, one panel per value of `selection`.

    Args:
        df: Long table with `selection`, `site`, `effect` and `binding_mean`
            columns; the per-mutant plot also reads `wildtype`, `mutant` and
            `times_seen_binding` for its tooltip.
        flag: When truthy, plot per-site means of `binding_mean` and `effect`;
            otherwise plot every mutant individually.

    Returns:
        The horizontally concatenated Altair chart.
    """
    df = df.copy().round(2)

    # Hovering selects the nearest point: by site for the aggregated plot,
    # by (site, mutant) for the per-mutant plot.
    selector_fields = ["site"] if flag else ["site", "mutant"]
    selector = alt.selection_point(
        on="mouseover", nearest=True, empty=False, fields=selector_fields, value=0
    )

    charts = []
    for cell in df["selection"].unique():
        per_selection = df[df["selection"] == cell]

        if flag:
            # Average binding and entry over all mutants at each site.
            site_means = (
                per_selection.groupby("site")[["binding_mean", "effect"]]
                .mean()
                .reset_index()
                .round(2)
            )
            chart = _binding_entry_scatter(
                site_means,
                selector,
                x_title=f"Mean entry in CHO-{cell}",
                y_title=f"Mean {cell} binding",
                tooltip=["site", "binding_mean", "effect"],
                sizes=(100, 50),
                dimmed_opacity=0.2,
            )
        else:
            chart = _binding_entry_scatter(
                per_selection,
                selector,
                x_title=f"Entry in CHO-{cell}",
                y_title=f"{cell} binding",
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "binding_mean",
                    "times_seen_binding",
                    "effect",
                ],
                sizes=(50, 20),
                dimmed_opacity=0.1,
            )
        charts.append(chart)

    # Combine all panels horizontally under one title.
    return alt.hconcat(
        *charts, title=alt.Title("Correlation between binding and entry")
    )

# Generate and display plots for non-aggregated data.
entry_binding_corr_plot = plot_corr_binding_entry_updated(df_binding_effect_concat, False)
entry_binding_corr_plot.display()

# Save the plot
if entry_binding_combined_corr_plot is not None:
    entry_binding_corr_plot.save(entry_binding_combined_corr_plot)

# Repeat for aggregated data.
entry_binding_corr_plot_agg = plot_corr_binding_entry_updated(df_binding_effect_concat, True)
entry_binding_corr_plot_agg.display()

### Same plot as above, but slightly different format

In [None]:
def plot_entry_binding_corr_heatmap(df):
    """Plot a binned entry-vs-binding heatmap, one panel per `selection`.

    Both axes are binned (at most 50 bins each) and each cell is coloured by
    its row count on a log scale. The panels share x, y and colour scales.

    Args:
        df: Long table with `selection`, `effect` and `binding_mean` columns.

    Returns:
        The horizontally concatenated Altair chart.
    """
    charts = []
    for cell in df["selection"].unique():
        per_selection = df[df["selection"] == cell]
        chart = (
            alt.Chart(per_selection, title=f"{cell}")
            .mark_rect()
            .encode(
                x=alt.X(
                    "effect", title="Cell entry", axis=alt.Axis(values=[-2, -1, 0, 1])
                ).bin(maxbins=50),
                y=alt.Y(
                    "binding_mean",
                    title="Receptor binding",
                    axis=alt.Axis(values=[-4, -2, 0, 2]),
                ).bin(maxbins=50),
                color=alt.Color("count()", title="Count").scale(type='log'),
            )
        ).properties(width=200, height=200)
        charts.append(chart)

    return alt.hconcat(*charts).resolve_scale(y="shared", x="shared", color="shared")


entry_binding_corr_heat = plot_entry_binding_corr_heatmap(df_binding_effect_concat)
entry_binding_corr_heat.display()
# Guard on the path that is actually written to.
if entry_binding_corr_heatmap is not None:
    entry_binding_corr_heat.save(entry_binding_corr_heatmap)

# Find correlations between bEFNB2 and bEFNB3 binding

In [None]:
def plot_entry_binding_corr(df):
    """Plot a binned heatmap of bEFNB2 binding against bEFNB3 binding.

    Despite the function name, both axes are binding scores. Each axis is
    binned (at most 50 bins) and cells are coloured by count on a log scale.

    Args:
        df: Wide table with `binding_mean_E2` and `binding_mean_E3` columns.

    Returns:
        A 200x200 Altair chart.
    """
    e2_axis = alt.X(
        "binding_mean_E2",
        title="bEFNB2 binding",
        axis=alt.Axis(values=[-4,-2, 0, 2]),
    ).bin(maxbins=50)
    e3_axis = alt.Y(
        "binding_mean_E3",
        title="bEFNB3 binding",
        axis=alt.Axis(values=[-2, 0, 2]),
    ).bin(maxbins=50)
    count_color = alt.Color("count()", title="Count").scale(type='log')

    heatmap = (
        alt.Chart(df, title="Correlation Between Mutant Binding Scores")
        .mark_rect()
        .encode(x=e2_axis, y=e3_axis, color=count_color)
    )
    return heatmap.properties(width=200, height=200)


entry_binding_corr_heatmap_1 = plot_entry_binding_corr(df_binding_effect_merge)
entry_binding_corr_heatmap_1.display()

# Plot correlations of binding mutants in scatterplots, and color different subsets

### First, find mutations that are outliers in the correlation between bEFNB2 and bEFNB3 binding

In [None]:
def find_outliers(df, outlier_threshold=4.5):
    """Flag mutants that fall far from the bEFNB2-vs-bEFNB3 best-fit line.

    Rows containing any missing value are dropped, a straight line is fitted
    to `binding_mean_E3` as a function of `binding_mean_E2`, and a mutant is
    an outlier when its absolute residual exceeds `outlier_threshold` times
    the standard deviation of the residuals (as computed by `np.std`).

    Args:
        df: Wide table with `site`, `binding_mean_E2` and `binding_mean_E3`.
        outlier_threshold: Cutoff in residual standard deviations.

    Returns:
        A tuple `(df, outliers_list)`: a copy of the complete rows with added
        `predicted_y`, `residuals` and `is_outlier` columns, and the list of
        unique sites that have at least one outlier mutant.
    """
    df = df.dropna().copy()

    # Best-fit line of E3 binding on E2 binding.
    slope, intercept = np.polyfit(df['binding_mean_E2'], df['binding_mean_E3'], 1)

    df['predicted_y'] = slope * df['binding_mean_E2'] + intercept
    df['residuals'] = df['binding_mean_E3'] - df['predicted_y']

    std_residual = np.std(df['residuals'])
    df['is_outlier'] = df['residuals'].abs() > (std_residual * outlier_threshold)

    outliers = df[df['is_outlier']]
    print(f'Here are the outlier mutations outside of a {outlier_threshold} standard deviation:\n')
    display(outliers)
    outliers_list = list(outliers['site'].unique())
    print(f' Here are the sites: \n {outliers_list}')
    return df,outliers_list

residuals_df,outliers_list = find_outliers(df_binding_effect_merge)

### Find mutations in the top quantile of binding for bEFNB2 and bEFNB3 both

In [None]:
def find_top_for_both(df):
    """Mark mutants in the top 1% of binding for both bEFNB2 and bEFNB3.

    Adds a boolean `meets_threshold` column to `df` IN PLACE; it is True where
    both `binding_mean_E2` and `binding_mean_E3` are at or above their own
    0.99 quantile.

    Args:
        df: Table with `site`, `binding_mean_E2` and `binding_mean_E3`.

    Returns:
        A tuple `(df, top_mutants)`: the same (now modified) frame and the
        list of unique sites that have at least one qualifying mutant.
    """
    quantile_threshold = 0.99
    score_columns = ('binding_mean_E2', 'binding_mean_E3')
    e2_cutoff, e3_cutoff = (
        df[column].quantile(quantile_threshold) for column in score_columns
    )

    above_e2 = df['binding_mean_E2'] >= e2_cutoff
    above_e3 = df['binding_mean_E3'] >= e3_cutoff
    df['meets_threshold'] = above_e2 & above_e3

    qualifying_sites = df.loc[df['meets_threshold'], 'site'].unique()
    return df, list(qualifying_sites)

# `find_top_for_both` adds `meets_threshold` to the frame it is given.
top_residuals_df, top_mutants_list = find_top_for_both(residuals_df)
print(f' The sites with the top binding mutations are : \n {top_mutants_list}')

# Compact view of just the qualifying mutations.
cleaned_df = top_residuals_df.loc[
    top_residuals_df['meets_threshold'],
    ['wildtype', 'site', 'mutant', 'binding_mean_E2', 'binding_mean_E3'],
].round(2)
print('Here are the specific mutations:\n')
display(cleaned_df)

## Plot correlations between individual mutational effects on binding, and color different subsets of sites

In [None]:
def plot_affinity_BLI_mutants(df, highlight_conditions, color, subset=None):
    """Scatter bEFNB2 against bEFNB3 binding and highlight selected mutants.

    Every mutant is drawn as a faint gray point. Mutants matching
    `highlight_conditions` are redrawn as larger, outlined points in `color`
    and labelled with their site number followed by the mutant amino acid.
    Dashed guide lines mark x = 0 and y = 0.

    Args:
        df: Wide table with `site`, `wildtype`, `mutant` and the `_E2` / `_E3`
            suffixed `binding_mean`, `effect`, `binding_std` and
            `times_seen_binding` columns used in the tooltip.
        highlight_conditions: Predicate passed to `transform_filter` to pick
            the mutants to highlight.
        color: Fill colour of the highlighted points.
        subset: Optional collection of sites; when given, only those sites
            are plotted.

    Returns:
        A 200x200 layered Altair chart.
    """
    df = df.round(2).copy()
    df['site_mutant'] = df['site'].astype(str) + df['mutant'].astype(str)

    if subset is not None:
        df = df[df['site'].isin(subset)]

    # Base layer: all mutants in gray.
    chart = alt.Chart(df).mark_point(
        color="gray",
        size=30,
        opacity=0.4,
        filled=True
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=4)),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=4)),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
            "binding_std_E2",
            "binding_std_E3",
            "times_seen_binding_E2",
            "times_seen_binding_E3"
        ],
    )

    # Highlighted mutants as coloured, outlined circles.
    highlight = chart.transform_filter(highlight_conditions).mark_point(
        color=color, size=60, opacity=1, filled=True, stroke='black', strokeWidth=1
    )
    # Text label placed above each highlighted circle.
    text_on_point = chart.transform_filter(highlight_conditions).mark_text(
            align='center',baseline='top',dy=-20,fontSize=16
    ).encode(text='site_mutant')

    # Vertical line at x=0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=1,strokeDash=[2,4]).encode(x='x:Q')
    # Horizontal line at y=0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=1,strokeDash=[2,4]).encode(y='y:Q')

    layered = chart + text_on_point + highlight + vline + hline
    return layered.properties(height=200,width=200)


### Sites selected for BLI validation

In [None]:
# (site, mutant amino acid) pairs to highlight on the correlation plot.
_bli_mutations = [
    (244, 'W'),
    (305, 'W'),
    (492, 'L'),
    (507, 'I'),
    (530, 'F'),
    (553, 'W'),
    (555, 'Y'),
    (559, 'R'),
    (588, 'V'),
]

# OR together one "site AND mutant" predicate per pair, in list order.
highlight_conditions = None
for _site, _aa in _bli_mutations:
    _term = (alt.datum.site == _site) & (alt.datum.mutant == _aa)
    highlight_conditions = (
        _term if highlight_conditions is None else highlight_conditions | _term
    )

E2_E3_corr_BLI_mutants = plot_affinity_BLI_mutants(
    df_binding_effect_merge, highlight_conditions, '#e49444'
)
E2_E3_corr_BLI_mutants.display()

### Sites selected for neutralization validations

In [None]:
# (site, mutant amino acid) pairs to highlight on the correlation plot.
_neut_mutations = [
    (333, 'Q'),
    (492, 'R'),
    (507, 'I'),
    (530, 'F'),
    (553, 'W'),
    (555, 'K'),
]

# OR together one "site AND mutant" predicate per pair, in list order.
highlight_conditions_neut = None
for _site, _aa in _neut_mutations:
    _term = (alt.datum.site == _site) & (alt.datum.mutant == _aa)
    highlight_conditions_neut = (
        _term
        if highlight_conditions_neut is None
        else highlight_conditions_neut | _term
    )

neut_muts = plot_affinity_BLI_mutants(
    df_binding_effect_merge, highlight_conditions_neut, '#5778a4'
)
neut_muts.display()

### Find outliers

In [None]:
# Mutants to highlight on the bEFNB2-vs-bEFNB3 correlation plot that is saved
# to `E2_E3_correlation`.
highlight_conditions_top_sites = (
        (alt.datum.site == 580) & (alt.datum.mutant == 'S') |
        (alt.datum.site == 211) & (alt.datum.mutant == 'F') |
        (alt.datum.site == 553) & (alt.datum.mutant == 'W') |
        (alt.datum.site == 589) & (alt.datum.mutant == 'G') |
        (alt.datum.site == 306) & (alt.datum.mutant == 'R') |
        (alt.datum.site == 492) & (alt.datum.mutant == 'L') |
        (alt.datum.site == 492) & (alt.datum.mutant == 'R') |
        (alt.datum.site == 588) & (alt.datum.mutant == 'P') |
        (alt.datum.site == 530) & (alt.datum.mutant == 'F') |
        (alt.datum.site == 546) & (alt.datum.mutant == 'H') |
        (alt.datum.site == 555) & (alt.datum.mutant == 'K')
)

E2_E3_corr = plot_affinity_BLI_mutants(residuals_df,highlight_conditions_top_sites,'#af7aa1')
E2_E3_corr.display()
# Guard on the path that is actually written to.
if E2_E3_correlation is not None:
    E2_E3_corr.save(E2_E3_correlation)

### Find top mutations for both bEFNB2 and bEFNB3

In [None]:
# Only the rows flagged by `meets_threshold` are passed in, so every plotted
# point also satisfies the highlight predicate.
highlight_conditions = alt.datum.meets_threshold == True
E2_E3_corr = plot_affinity_BLI_mutants(
    top_residuals_df.loc[top_residuals_df['meets_threshold']],
    highlight_conditions,
    '#d1615d',
)
E2_E3_corr.display()

# Plot correlations of mutants for individual sites of interest with letters of each mutant plotted

In [None]:
def plot_affinity_individual_mutants(df,mutant):
    """Plot bEFNB2 against bEFNB3 binding for one site, one letter per mutant.

    Each mutant at the site is drawn as its amino-acid letter, coloured by
    `mutant_type_E2`, with dashed guide lines at x = 0 and y = 0.

    Args:
        df: Wide table with `site`, `wildtype`, `mutant`, `mutant_type_E2`
            and the `_E2` / `_E3` suffixed `binding_mean` and `effect` columns.
        mutant: Despite the name, the site number to plot; it is compared
            against the `site` column.

    Returns:
        A 200x200 layered Altair chart, or None (after printing
        'nothing here') when `df` has no rows for the site.
    """
    df = df.round(2).copy()
    df = df[df['site'] == mutant]
    if df.empty:
        print('nothing here')
        return None

    chart = alt.Chart(df).mark_text(
        size=15,
        opacity=1,
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=3),scale=alt.Scale(domain=[-5,2])),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=3),scale=alt.Scale(domain=[-2,2])),
        text=alt.Text('mutant'),
        color=alt.Color(
            'mutant_type_E2',
            legend = None,
            title='Amino acid type',
            scale=alt.Scale(
                domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"])),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
        ],
    )
    # Vertical line at x=0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=1,strokeDash=[2,4]).encode(x='x:Q')
    # Horizontal line at y=0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=1,strokeDash=[2,4]).encode(y='y:Q')
    # Guide lines first so the letters are drawn on top of them.
    final_chart = vline + hline + chart
    return final_chart.properties(height=200,width=200)

### Plot amino acid letter correlation plots for the top sites

In [None]:
# Letter plot for a single site. Sites 492 and 530 are alternative choices:
# change the site number below to plot them instead.
indiv_graph = plot_affinity_individual_mutants(df_binding_effect_merge,553)

# The helper returns None (after printing a notice) when the site has no rows.
if indiv_graph is not None:
    indiv_graph.display()

In [None]:
empty_charts = []
# Hand-picked sites; this replaces the `top_mutants_list` computed earlier by
# `find_top_for_both`.
top_mutants_list = [580,211,553,589,306,492,530]
for site in top_mutants_list:
    indiv_graph = plot_affinity_individual_mutants(df_binding_effect_merge,site)
    # The helper returns None for a site with no rows; a None entry would make
    # the concatenation below fail.
    if indiv_graph is not None:
        empty_charts.append(indiv_graph)

top_sites = alt.vconcat(*empty_charts).resolve_scale(x='shared',y='shared',color='shared')
top_sites.display()

### Plot amino acid letter correlation plots for the outlier sites

In [None]:
# One letter plot per outlier site, stacked vertically with shared scales.
empty_charts = []
for site in outliers_list:
    # NOTE(review): sites below 180 are skipped; the reason is not stated
    # here. Confirm the intended cutoff.
    if site < 180:
        continue
    indiv_graph = plot_affinity_individual_mutants(df_binding_effect_merge,site)
    empty_charts.append(indiv_graph)

outlier_sites = alt.vconcat(*empty_charts).resolve_scale(
    x='shared', y='shared', color='shared'
)
outlier_sites.display()

### Now make for all contact sites

In [None]:
# One letter plot per contact site that has at least one non-null binding
# score for BOTH receptors, stacked vertically with shared x and y scales.
empty_charts = []
for site in config['contact_sites']:
    tmp_df = df_binding_effect_merge[df_binding_effect_merge['site'] == site]
    has_e2 = tmp_df['binding_mean_E2'].notnull().any()
    has_e3 = tmp_df['binding_mean_E3'].notnull().any()
    if not (has_e2 and has_e3):
        continue

    contact_plots = plot_affinity_individual_mutants(tmp_df,site)
    empty_charts.append(contact_plots)

all_contact_plots = alt.vconcat(*empty_charts).resolve_scale(x='shared', y='shared')
all_contact_plots.display()

### Make correlation amino acid letter plots for 580-590 loop

In [None]:
# Display a letter plot for each site in the loop that has at least one
# non-null binding score for both receptors.
# NOTE(review): range(580, 590) covers sites 580-589 and excludes 590;
# confirm whether 590 should be included.
for site in range(580, 590):
    tmp_df = df_binding_effect_merge[df_binding_effect_merge['site'] == site]
    has_e2 = tmp_df['binding_mean_E2'].notnull().any()
    has_e3 = tmp_df['binding_mean_E3'].notnull().any()
    if not (has_e2 and has_e3):
        continue

    test_plots = plot_affinity_individual_mutants(tmp_df,site)
    test_plots.display()

# Plot correlations between each site for mean value

In [None]:
def plot_affinity_solid_mean(df):
    """Scatter per-site mean bEFNB2 binding against per-site mean bEFNB3 binding.

    Rows with any missing value are dropped, the `effect` and `binding_mean`
    columns are averaged per site, and the correlation coefficient of the two
    mean binding columns is shown in the chart subtitle.

    Args:
        df: Wide table with `site`, `wildtype` and the `_E2` / `_E3` suffixed
            `effect` and `binding_mean` columns.

    Returns:
        A 200x200 Altair chart.
    """
    df = df.dropna()
    means = (
        df.groupby("site")
        .agg(
            {
                "effect_E2": "mean",
                "effect_E3": "mean",
                "binding_mean_E2": "mean",
                "binding_mean_E3": "mean",
                "wildtype": "first",
            }
        )
        .reset_index().round(2)
    )
    # Only the correlation coefficient of the fit is used.
    _, _, r_value, _, _ = scipy.stats.linregress(
        means["binding_mean_E2"], means["binding_mean_E3"]
    )
    r_value = float(r_value)
    chart = (
        alt.Chart(
            means,
            title=alt.Title(
                "Correlation between Aggregate Mutant Binding Scores",
                subtitle=f"r={r_value:.2f}",
            ),
        )
        .mark_point(size=50, opacity=0.5,filled=True,color='gray')
        .encode(
            x=alt.X(
                "binding_mean_E2",
                title=("Mean bEFNB2 binding"),
                axis=alt.Axis(tickCount=3),
            ),
            y=alt.Y(
                "binding_mean_E3",
                title=("Mean bEFNB3 binding"),
                axis=alt.Axis(tickCount=3),
            ),
            tooltip=[
                "site",
                "wildtype",
                "binding_mean_E2",
                "binding_mean_E3",
                "effect_E2",
                "effect_E3",
            ],
        )
    )
    return chart.properties(width=200,height=200)


E2_E3_site_corr = plot_affinity_solid_mean(df_binding_effect_merge)
E2_E3_site_corr.display()
# Guard on the path that is actually written to.
if E2_E3_correlation_site is not None:
    E2_E3_site_corr.save(E2_E3_correlation_site)

# Make plot showing binding by site

In [None]:
def plot_affinity_by_site(df):
    """Plot the mean binding score at every site, one row per `selection`.

    Each site is assigned to an RBP region by its number and the bars are
    coloured by region. Sites outside all of the regions are left out.

    Args:
        df: Long table with `site`, `binding_mean` and `selection` columns.

    Returns:
        The faceted Altair bar chart.
    """
    # Site-number ranges of the RBP regions (end exclusive).
    barrel_ranges = {
        "Stalk": range(70, 148),
        "Neck": range(148, 166),
        "Linker": range(166, 178),
        "Head": range(178, 603),
    }

    custom_order = ["Stalk", "Neck", "Linker", "Head"] #custom order for color legend

    # Label every row with its region through a site -> region lookup rather
    # than rebuilding the table row by row.
    region_of_site = {
        site: region for region, sites in barrel_ranges.items() for site in sites
    }
    agg_means_df = (
        df[["binding_mean", "site", "selection"]]
        .assign(region=df["site"].map(region_of_site))
        .dropna(subset=["region"])
        .reset_index(drop=True)
        .round(2)
    )

    # Setup interactivity
    variant_selector = alt.selection_point(
        on="mouseover", nearest=True, empty=False, fields=["site"], value=0
    )
    #make chart
    chart = alt.Chart(agg_means_df).mark_bar(stroke='black',size=1.5,binSpacing=0,color='black').encode(
        alt.X("site")
            .title("Site")
            .axis(tickCount=5,labelAngle=-90,grid=True),
        alt.Y('mean(binding_mean)')
            .axis(tickCount=3)
            .title('Mean binding'),
        alt.Row('selection',title=None),
        tooltip=['site'],
        strokeWidth=alt.condition(variant_selector, alt.value(1), alt.value(0)),
        color=alt.Color('region',sort=custom_order,title='Region'),
    ).properties(height=150, width=800).add_params(variant_selector).resolve_scale(y='independent')

    return chart


binding_by_site = plot_affinity_by_site(df_binding_effect_concat)
binding_by_site.display()
# Guard on the path that is actually written to.
if binding_by_site_plot is not None:
    binding_by_site.save(binding_by_site_plot)

# Make bar chart showing max binding score for each contact residue

In [None]:
def plot_affinity_by_contact_site(df, sites_to_show):
    """Plot the best binding score at each selected site, one column per `selection`.

    Bars show the maximum `binding_mean` at each site, sorted by descending
    value, and the bar under the pointer is highlighted in orange.

    Args:
        df: Long table with `site`, `binding_mean` and `selection` columns.
        sites_to_show: Sites to include in the chart.

    Returns:
        The faceted Altair bar chart with its header styling configured.
    """
    contact_df = df[df["site"].isin(sites_to_show)]

    # Hovering selects the nearest bar by site.
    hover = alt.selection_point(
        on="mouseover", nearest=True, empty=False, fields=["site"]
    )

    ranked_bars = (
        alt.Chart(contact_df)
        .mark_bar(stroke='black')
        .encode(
            y=alt.Y("site:N", title="Site", sort='-x'),
            x=alt.X("max(binding_mean):Q", title="Max binding mutant at site"),
            color=alt.condition(hover, alt.value("orange"), alt.value("black")),
            strokeWidth=alt.condition(hover, alt.value(1), alt.value(0)),
            column=alt.Column('selection', title=None),
        )
        .add_params(hover)
        .properties(width=200, height=alt.Step(12))
        .resolve_scale(x='independent', y='shared')
    )

    # Styling of the facet (column) header labels.
    header_style = {
        "labelFontSize": 20,
        "labelAngle": 0,
        "labelAlign": 'center',
        "labelAnchor": 'middle',
        "labelFont": 'Helvetica Light',
        "labelFontStyle": 'bold',
        "labelPadding": 3,
    }
    return ranked_bars.configure_header(**header_style)

In [None]:
# Ranked bar chart of the best binding mutant at each contact site.
contact_binding_by_site = plot_affinity_by_contact_site(df_binding_effect_concat, config["contact_sites"])
contact_binding_by_site.display()
# Guard on the path that is actually written to.
if combined_contact_ranked_bar_output is not None:
    contact_binding_by_site.save(combined_contact_ranked_bar_output)

# Make bubble plots for binding in different areas of receptor pocket

In [None]:
def make_bubble_binding_region(df, contact_sites=None):
    """Plot each mutant's binding score as a jittered point, grouped by region.

    Rows of ``df`` are assigned to regions by their ``site`` value:
    Stalk (70-147), Neck (148-165), Linker (166-177), Head (178-602) and
    Receptor Contact (the sites in ``contact_sites``). A site listed under
    more than one region contributes its rows once per region.

    Args:
        df: DataFrame with ``site``, ``mutant``, ``binding_mean`` and
            ``selection`` columns.
        contact_sites: Optional collection of sites for the
            "Receptor Contact" region. When omitted, the notebook-level
            ``config["contact_sites"]`` is used.

    Returns:
        The configured Altair chart, split into one column per
        ``selection`` value with independent y scales.
    """
    if contact_sites is None:
        # Fall back to the contact sites listed in the notebook config.
        contact_sites = config["contact_sites"]

    # Insertion order of this dict is also the x-axis order of the regions.
    region_sites = {
        "Stalk": list(range(70, 148)),
        "Neck": list(range(148, 166)),
        "Linker": list(range(166, 178)),
        "Head": list(range(178, 603)),
        "Receptor Contact": contact_sites,
    }
    region_order = list(region_sites)
    value_columns = ["binding_mean", "site", "mutant", "selection"]

    # Select each region's rows in one vectorized step and tag them with the
    # region name; the combined frame is built once, after the loop.
    region_frames = []
    for region, sites in region_sites.items():
        in_region = df.loc[df["site"].isin(sites), value_columns].copy()
        in_region.insert(0, "region", region)
        region_frames.append(in_region)
    agg_means_df = pd.concat(region_frames, ignore_index=True)

    # Hovering over (or nearest to) a point selects its site/mutant pair.
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["site", "mutant"], value=1
    )

    chart = (
        alt.Chart(agg_means_df)
        .mark_point(stroke="black", filled=True)
        .encode(
            x=alt.X(
                "region:O",
                sort=region_order,
                title="RBP Region",
                axis=alt.Axis(labelAngle=-90),
            ),
            y=alt.Y(
                "binding_mean",
                title="Binding",
                axis=alt.Axis(tickCount=4),
            ),
            # Per-point random horizontal offset (computed below) so that
            # points within a region are spread out rather than stacked.
            xOffset="random:Q",
            tooltip=["region", "binding_mean", "site", "mutant"],
            # The selected point is drawn orange, opaque, outlined and larger.
            color=alt.condition(
                variant_selector, alt.value("orange"), alt.value("black")
            ),
            opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.1)),
            strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            size=alt.condition(variant_selector, alt.value(70), alt.value(20)),
            column=alt.Column("selection", title=None),
        )
        .transform_calculate(random="sqrt(-1*log(random()))*cos(2*PI*random())")
        .properties(width=300, height=300)
        .add_params(variant_selector)
        .resolve_scale(y="independent")
        .configure_header(
            labelFontSize=20,
            labelAngle=0,
            labelAlign="center",
            labelAnchor="middle",
            labelFont="Helvetica Light",
            labelFontStyle="bold",
            labelPadding=1,
        )
    )

    return chart


# Render the per-region bubble chart in the notebook.
entry_region_bubble = make_bubble_binding_region(df_binding_effect_concat)
entry_region_bubble.display()
# Only write the image when an output path for THIS plot was supplied.
# (The guard previously tested a different output parameter, so the save
# could be skipped, or called with None, depending on an unrelated setting.)
if binding_region_bubble_plot is not None:
    entry_region_bubble.save(binding_region_bubble_plot)

# Plot effects of selected mutants on receptor binding

In [None]:
def plot_effects_of_mutants(df, mutants=None):
    """Display a grouped bar chart of binding scores for selected mutants.

    Rows of ``df`` matching any ``(site, mutant)`` pair in ``mutants`` are
    kept, labelled as ``<wildtype><site><mutant>``, ordered by site, and
    drawn as bars of ``binding_mean`` with one bar per ``selection`` value
    placed side by side within each mutant.

    Args:
        df: DataFrame with ``site``, ``mutant``, ``wildtype``,
            ``binding_mean`` and ``selection`` columns.
        mutants: Optional iterable of ``(site, mutant)`` pairs to plot.
            When omitted, a built-in default set of seven mutants is used.

    Returns:
        None. The chart is displayed as a side effect; nothing is returned
        so that calling this as the last expression of a notebook cell does
        not render the chart a second time.
    """
    if mutants is None:
        # Other candidates that were toggled off in earlier revisions:
        # 580S, 492L, 588P, 492R, 530F, 492K, 239H, 589G.
        mutants = [
            (598, "G"),
            (553, "W"),
            (211, "F"),
            (546, "H"),
            (143, "Q"),
            (331, "E"),
            (566, "C"),
        ]

    # OR together one (site, mutant) match per requested pair.
    selected = pd.Series(False, index=df.index)
    for site, mutant in mutants:
        selected = selected | ((df["site"] == site) & (df["mutant"] == mutant))

    # Copy only the matching rows so adding a column never touches `df`.
    tmp_df = df[selected].copy()
    tmp_df["label"] = (
        tmp_df["wildtype"].astype(str)
        + tmp_df["site"].astype(str)
        + tmp_df["mutant"].astype(str)
    )

    # sort=None on the x encoding keeps this data order on the axis.
    tmp_df = tmp_df.sort_values(by="site")
    binding_chart = alt.Chart(tmp_df).mark_bar().encode(
        alt.X("label:N", title="Mutant", sort=None),
        alt.Y("binding_mean:Q", title="Binding score", axis=alt.Axis(tickCount=4)),
        xOffset="selection:N",
        color=alt.Color("selection:N", title="Receptor selection"),
    )
    binding_chart.resolve_scale(y="independent", x="shared").display()


plot_effects_of_mutants(df_binding_effect_concat)