# ephrin_binding.ipynb
Analyze receptor selection data using soluble bat Ephrin-B2 or -B3
- Written by Brendan Larsen

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every parameter defaults to None. When `nipah_config` is still None, a later
# cell fills in hard-coded input paths so the notebook can be run interactively.
# input configs
altair_config = None
nipah_config = None

# input files
binding_E2_file = None
binding_E3_file = None

# output images
# Destination paths handed to the charts' `.save(...)` calls further down.
entry_binding_combined_corr_plot = None
entry_binding_combined_corr_plot_agg = None
E2_E3_correlation = None
E2_E3_correlation_site = None
binding_by_site_plot = None
entry_binding_corr_heatmap = None
binding_corr_heatmap = None
binding_region_bubble_plot = None
combined_contact_ranked_bar_output = None

In [None]:
import math
import os
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import yaml

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that the relative paths used in this notebook are resolved against.
_project_dir = (
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
)

# os.getcwd() does not end in a path separator, so compare normalized paths;
# comparing against the raw string (with its trailing "/") can never match.
if os.path.normpath(os.getcwd()) == os.path.normpath(_project_dir):
    print("Already in correct directory")
else:
    os.chdir(_project_dir)
    print("Setup in correct directory")

### hard paths for running in interactive mode

In [None]:
if nipah_config is None:
    # No parameters were injected (e.g. the notebook is run interactively
    # rather than through snakemake), so fall back to hard-coded relative paths.
    # Only the inputs are set here; the output-image parameters stay None.
    print("loading hard paths")
    altair_config = "data/custom_analyses_data/theme.py"
    nipah_config = "nipah_config.yaml"

    # input files
    binding_E2_file = "results/filtered_data/binding/e2_binding_filtered.csv"
    binding_E3_file = "results/filtered_data/binding/e3_binding_filtered.csv"

### Run the config files to set up the Altair theme and load the config variables

In [None]:
# The theme file is plain Python source that is executed in this notebook's
# namespace. NOTE(review): `exec` runs arbitrary code, so `altair_config` must
# only ever point at a trusted file.
if altair_config:
    with open(altair_config, "r") as file:
        exec(file.read())

# Parse the YAML config into `config`; later cells read config["contact_sites"].
with open(nipah_config) as f:
    config = yaml.safe_load(f)

### Import the filtered binding and entry data for bEFNB2 and bEFNB3

In [None]:
# import binding data
# One CSV per receptor; the first three rows of each are shown as a sanity check.
df_E2_filter = pd.read_csv(binding_E2_file)
display(df_E2_filter.head(3))
df_E3_filter = pd.read_csv(binding_E3_file)
display(df_E3_filter.head(3))

### Merge data

In [None]:
## Merge the filtered EFNB2 and EFNB3 DataFrames for combined analysis.
## This is a wide table: one row per (site, wildtype, mutant), with the remaining
## shared columns suffixed _E2 / _E3. The outer join keeps mutants that are
## present in only one of the two datasets (their other-receptor columns are NaN).
df_binding_effect_merge = pd.merge(
    df_E2_filter,
    df_E3_filter,
    on=["site", "wildtype", "mutant"],
    suffixes=["_E2", "_E3"],
    how="outer",
)
display(df_binding_effect_merge.head(3))
## Add a 'selection' column to distinguish between EFNB2 and EFNB3 data.
## This is done after the merge above, so the merged table has no 'selection' column.
df_E2_filter["selection"] = "bEFNB2"
df_E3_filter["selection"] = "bEFNB3"

## Concatenate the DataFrames for plotting or further analysis.
## This is a long table: the rows of both datasets stacked, each keeping its original index.
df_binding_effect_concat = pd.concat([df_E2_filter, df_E3_filter])
display(df_binding_effect_concat.head(3))

In [None]:
# What are the top 5 highest and lowest binding mutants for EFNB2 and EFNB3?
def find_highest_lowest(df, name):
    """Report the weakest and strongest binding mutants of one selection.

    Prints a short summary and displays five-row tables of the lowest and
    highest ``binding_mean`` mutants, first over all mutants and then only over
    mutants whose ``effect`` is greater than 0. Finally displays the ten sites
    with the largest summed ``binding_mean``.

    Args:
        df: DataFrame with ``site``, ``binding_mean`` and ``effect`` columns.
        name: Label printed to identify the dataset.
    """

    def _five_extremes(frame, ascending):
        # Five rows taken from one end of the binding_mean ordering.
        return frame.sort_values(by="binding_mean", ascending=ascending).head(5)

    print(f"We are analyzing {name}\n")
    print(f"The total number of mutants was: {df.shape[0]}\n")

    weakest = _five_extremes(df, True)
    print("These are the lowest binding mutants detected:")
    display(weakest)

    strongest = _five_extremes(df, False)
    print("\nThese are the highest binding mutants detected:\n")
    display(strongest)

    # Repeat the ranking for mutants with a positive entry score only.
    weakest_entering = _five_extremes(df[df["effect"] > 0], True)
    print("These are the lowest binding mutants detected with positive entry scores:")
    display(weakest_entering)

    strongest_entering = _five_extremes(df[df["effect"] > 0], False)
    print(
        "\nThese are the highest binding mutants detected with positive entry scores:\n"
    )
    display(strongest_entering)

    # Per-site total of binding_mean, largest totals first.
    site_totals = df.groupby("site")["binding_mean"].sum().reset_index()
    display(site_totals.sort_values(by="binding_mean", ascending=False).head(10))


# Report the extreme binders separately for each receptor dataset.
find_highest_lowest(df_E2_filter, "bEFNB2")
find_highest_lowest(df_E3_filter, "bEFNB3")

### Find the top overlapping binders for both bEFNB2 and bEFNB3

In [None]:
def find_good_binding_for_both(df, top_n=50):
    """Display the mutants ranked in the top binders for both receptors.

    Takes the ``top_n`` rows with the highest ``binding_mean_E2`` and the
    ``top_n`` rows with the highest ``binding_mean_E3`` and displays the
    mutants that appear in both lists, in descending ``binding_mean_E2`` order.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``mutant``,
            ``binding_mean_E2`` and ``binding_mean_E3`` columns.
        top_n: Number of top-ranked mutants taken per receptor (default 50).
    """
    keys = ["site", "wildtype", "mutant"]
    top_e2 = df.sort_values(by="binding_mean_E2", ascending=False).head(top_n)
    top_e3 = df.sort_values(by="binding_mean_E3", ascending=False).head(top_n)
    # Both lists are subsets of the same table, so join on the key columns only;
    # merging the full frames would repeat every other column with _x/_y suffixes.
    combo = pd.merge(top_e2, top_e3[keys], on=keys, how="inner")
    display(combo)

# Uses the wide (merged) table, which has one binding column per receptor.
find_good_binding_for_both(df_binding_effect_merge)

In [None]:
# Compare E2 and E3 binders
def find_highest_lowest(df):
    """Display the mutants and sites whose binding differs most between receptors.

    NOTE: this definition rebinds ``find_highest_lowest``, replacing the
    two-argument version defined in an earlier cell of this notebook.

    The absolute difference ``|binding_mean_E2 - binding_mean_E3|`` is computed
    on a copy, so the DataFrame passed in is left unmodified.

    Args:
        df: Merged DataFrame with ``site``, ``binding_mean_E2`` and
            ``binding_mean_E3`` columns.
    """
    # assign() returns a new frame instead of adding a column to the caller's table.
    diff_df = df.assign(
        binding_diff=(df["binding_mean_E2"] - df["binding_mean_E3"]).abs()
    )
    print(
        "These are the mutants with the biggest difference between EFNB2 and EFNB3:\n"
    )
    display(diff_df.sort_values(by="binding_diff", ascending=False).head(10))

    # calculate aggregate differences (per-site means)
    agg_df = (
        diff_df.groupby("site")[["binding_mean_E2", "binding_mean_E3", "binding_diff"]]
        .mean()
        .reset_index()
    )
    print("These are the sites with the biggest difference between EFNB2 and EFNB3:\n")
    display(agg_df.sort_values(by="binding_diff", ascending=False).head(5))


# Calls the one-argument version defined in this cell on the merged table.
find_highest_lowest(df_binding_effect_merge)
# The call below uses the two-argument signature of the earlier definition and
# is left disabled:
# find_highest_lowest(df_E3_filter,'EFNB3')

### Make plots showing correlation between binding and entry for EFNB2 and EFNB3

In [None]:
def plot_corr_binding_entry_updated(df, flag):
    """Build side-by-side scatter plots of binding against entry.

    One panel is drawn per unique value of ``selection``, with ``effect`` on
    the x axis and ``binding_mean`` on the y axis. Points are highlighted on
    mouse-over.

    Args:
        df: Long-format DataFrame with ``selection``, ``site``,
            ``binding_mean`` and ``effect`` columns. The per-mutant view also
            uses ``wildtype``, ``mutant`` and ``times_seen_binding`` in its
            tooltip.
        flag: If truthy, plot the per-site mean of ``binding_mean`` and
            ``effect``; otherwise plot every mutant individually.

    Returns:
        The panels concatenated horizontally into a single Altair chart.
    """
    aggregate = bool(flag)

    # Hovering selects a whole site in the aggregated view and a single
    # (site, mutant) pair in the per-mutant view. The same selection object is
    # attached to every panel.
    selector = alt.selection_point(
        on="mouseover",
        nearest=True,
        empty=False,
        fields=["site"] if aggregate else ["site", "mutant"],
        value=0,
    )

    # The two views differ only in their data, titles, tooltip and point styling.
    if aggregate:
        faded_opacity, selected_size, faded_size = 0.2, 100, 50
        tooltip = ["site", "binding_mean", "effect"]
    else:
        faded_opacity, selected_size, faded_size = 0.1, 50, 20
        tooltip = [
            "site",
            "wildtype",
            "mutant",
            "binding_mean",
            "times_seen_binding",
            "effect",
        ]

    charts = []
    for cell in df["selection"].unique():
        cell_df = df[df["selection"] == cell]
        if aggregate:
            # Average binding and entry over all mutants at each site.
            plot_df = (
                cell_df.groupby("site")[["binding_mean", "effect"]].mean().reset_index()
            )
            x_title = f"Mean entry in CHO-{cell}"
            y_title = f"Mean {cell} binding"
        else:
            plot_df = cell_df
            x_title = f"Entry in CHO-{cell}"
            y_title = f"{cell} binding"

        charts.append(
            alt.Chart(plot_df)
            .mark_point(stroke="black", filled=True)
            .encode(
                x=alt.X("effect", title=x_title, axis=alt.Axis(grid=True)),
                y=alt.Y("binding_mean", title=y_title, axis=alt.Axis(grid=True)),
                # Selected points are opaque, larger, outlined and orange.
                opacity=alt.condition(selector, alt.value(1), alt.value(faded_opacity)),
                size=alt.condition(
                    selector, alt.value(selected_size), alt.value(faded_size)
                ),
                strokeWidth=alt.condition(selector, alt.value(1), alt.value(0)),
                color=alt.condition(selector, alt.value("orange"), alt.value("black")),
                tooltip=tooltip,
            )
            .add_params(selector)
        )

    return alt.hconcat(
        *charts, title=alt.Title("Correlation between binding and entry")
    )

# Generate and display plots for non-aggregated data.
entry_binding_corr_plot = plot_corr_binding_entry_updated(df_binding_effect_concat, False)
entry_binding_corr_plot.display()

# Save the plot only when its output path was supplied.
if entry_binding_combined_corr_plot is not None:
    entry_binding_corr_plot.save(entry_binding_combined_corr_plot)

# Repeat for aggregated data.
entry_binding_corr_plot_agg = plot_corr_binding_entry_updated(df_binding_effect_concat, True)
entry_binding_corr_plot_agg.display()

# Save the aggregated plot; guard on the aggregated plot's own output path so
# that `.save(None)` is never reached.
if entry_binding_combined_corr_plot_agg is not None:
    entry_binding_corr_plot_agg.save(entry_binding_combined_corr_plot_agg)

### Same binding-vs-entry comparison as above, shown as binned 2D heatmaps of mutant counts

In [None]:
def plot_entry_binding_corr_heatmap(df):
    """Draw one binned entry-vs-binding count heatmap per selection.

    Args:
        df: Long-format DataFrame with ``selection``, ``effect`` and
            ``binding_mean`` columns.

    Returns:
        The per-selection heatmaps concatenated horizontally, with shared x, y
        and color scales.
    """

    def _panel(cell):
        # 200x200 heatmap of the rows belonging to a single selection; both
        # axes are binned and the color encodes the log-scaled row count.
        x_axis = alt.X(
            "effect", title="Cell entry", axis=alt.Axis(values=[-2, -1, 0, 1])
        ).bin(maxbins=50)
        y_axis = alt.Y(
            "binding_mean",
            title="Receptor binding",
            axis=alt.Axis(values=[-4, -2, 0, 2]),
        ).bin(maxbins=50)
        counts = alt.Color("count()", title="Count").scale(type='log')
        heatmap = (
            alt.Chart(df[df["selection"] == cell], title=f"{cell}")
            .mark_rect()
            .encode(x=x_axis, y=y_axis, color=counts)
        )
        return heatmap.properties(width=200, height=200)

    panels = [_panel(cell) for cell in df["selection"].unique()]
    return alt.hconcat(*panels).resolve_scale(y="shared", x="shared", color="shared")


entry_binding_corr_heat = plot_entry_binding_corr_heatmap(df_binding_effect_concat)
entry_binding_corr_heat.display()
# Save only when this chart's own output path was supplied.
if entry_binding_corr_heatmap is not None:
    entry_binding_corr_heat.save(entry_binding_corr_heatmap)

### Calculate some stats on binding

In [None]:
def overall_stats(df, effect, name, cutoff=-0.25):
    """Print summary statistics about uniformly low-scoring sites.

    A site counts as uniformly low when every one of its rows has
    ``df[effect] < cutoff``. The function prints which of those sites are also
    listed in ``config["contact_sites"]``, how many such sites exist and which
    they are, displays the five sites with the highest single-mutant
    ``binding_mean``, and prints the number of distinct sites.

    Args:
        df: DataFrame with ``site`` and ``binding_mean`` columns plus the
            column named by ``effect``.
        effect: Name of the column compared against ``cutoff``.
        name: Label printed to identify the dataset.
        cutoff: Threshold every mutant at a site must fall below (default -0.25).
    """
    # Keep only the sites where all mutants fall below the cutoff.
    all_low = df.groupby("site").filter(lambda group: (group[effect] < cutoff).all())
    low_sites = all_low["site"].unique()

    # Of those, the ones that are also receptor contact sites.
    low_sites_series = pd.Series(low_sites)
    contact_low_sites = low_sites_series[
        low_sites_series.isin(config["contact_sites"])
    ]

    print(f"The dataset we are analyzing is: {name}\n")

    print(
        f"Here are the contact sites that only have negative binding scores: {list(contact_low_sites)}\n"
    )

    print(f"There are {len(low_sites)} sites with all negative binding score mutants\n")
    print(
        f"These are the sites with all negative binding score mutants: {list(low_sites)}\n"
    )

    # Rank sites by their single best-binding mutant (per-site maximum).
    site_max = (
        df.groupby("site")["binding_mean"]
        .max()
        .reset_index()
        .sort_values(by="binding_mean", ascending=False)
    )
    print("These are the sites with the highest binding mutants:\n")
    display(site_max.head(5))

    total_sites = df["site"].unique().shape[0]

    print(f"The total number of sites are: {total_sites}")


# The column tested against the cutoff is the binding score itself.
overall_stats(df_E2_filter, "binding_mean", "EFNB2")
overall_stats(df_E3_filter, "binding_mean", "EFNB3")

### Find sites with opposite effects on binding

In [None]:
# find sites that are different
def find_biggest_differences(df, threshold=0.5):
    """Display mutants with opposite binding effects on the two receptors.

    Two tables are shown: mutants with ``binding_mean_E2 > threshold`` and
    ``binding_mean_E3 < -threshold``, sorted by descending E2 binding, and the
    reverse case, sorted by descending E3 binding.

    Args:
        df: Merged DataFrame with ``binding_mean_E2`` and ``binding_mean_E3``
            columns.
        threshold: Magnitude a score must exceed to count as good (positive
            side) or bad (negative side). Default 0.5.
    """
    e2_binding = df["binding_mean_E2"]
    e3_binding = df["binding_mean_E3"]

    efnb2_good_efnb3_bad = df[
        (e2_binding > threshold) & (e3_binding < -threshold)
    ].sort_values(by="binding_mean_E2", ascending=False)
    print('mutations good for efnb2 binding, bad for efnb3 binding:\n')
    display(efnb2_good_efnb3_bad)

    efnb2_bad_efnb3_good = df[
        (e2_binding < -threshold) & (e3_binding > threshold)
    ].sort_values(by="binding_mean_E3", ascending=False)
    print('mutations bad for efnb2 binding, good for efnb3 binding:\n')
    display(efnb2_bad_efnb3_good)


# Uses the wide (merged) table, which has one binding column per receptor.
find_biggest_differences(df_binding_effect_merge)

### Find correlations between EFNB2 and EFNB3 binding

In [None]:
def plot_entry_binding_corr(df):
    """Draw a binned count heatmap of bEFNB2 binding against bEFNB3 binding.

    Args:
        df: Merged DataFrame with ``binding_mean_E2`` and ``binding_mean_E3``
            columns.

    Returns:
        A 200x200 Altair heatmap whose color encodes the log-scaled number of
        rows falling in each bin.
    """
    e2_axis = alt.X(
        "binding_mean_E2",
        title="bEFNB2 binding",
        axis=alt.Axis(values=[-4,-2, 0, 2]),
    ).bin(maxbins=50)
    e3_axis = alt.Y(
        "binding_mean_E3",
        title="bEFNB3 binding",
        axis=alt.Axis(values=[-2, 0, 2]),
    ).bin(maxbins=50)
    counts = alt.Color("count()", title="Count").scale(type='log')

    heatmap = alt.Chart(df, title="Correlation Between Mutant Binding Scores").mark_rect()
    heatmap = heatmap.encode(x=e2_axis, y=e3_axis, color=counts)
    return heatmap.properties(width=200, height=200)


entry_binding_corr_heatmap_1 = plot_entry_binding_corr(df_binding_effect_merge)
entry_binding_corr_heatmap_1.display()
# Save only when this chart's own output path was supplied.
if binding_corr_heatmap is not None:
    entry_binding_corr_heatmap_1.save(binding_corr_heatmap)

### Plot correlations with mutants we did BLI on

In [None]:
def plot_affinity_BLI_mutants(df):
    """Scatter bEFNB2 against bEFNB3 binding, highlighting the BLI mutants.

    Rows containing any NaN are dropped and values rounded to two decimals
    before plotting. A fixed panel of mutants is redrawn in orange and labelled
    with site + mutant, and the Pearson r of the two binding columns is
    written near the top-left of the plot.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``mutant``,
            ``binding_mean_E2``, ``binding_mean_E3``, ``effect_E2`` and
            ``effect_E3`` columns.

    Returns:
        A 300x300 layered Altair chart.
    """
    # (site, mutant) pairs to highlight and label.
    highlighted = [
        (244, 'W'),
        (305, 'W'),
        (492, 'L'),
        (507, 'I'),
        (530, 'F'),
        (553, 'W'),
        (555, 'Y'),
        (559, 'R'),
        (588, 'V'),
    ]

    df = df.dropna().round(2).copy()
    df['site_mutant'] = df['site'].astype(str) + df['mutant'].astype(str)
    r_value = scipy.stats.linregress(
        df["binding_mean_E2"], df["binding_mean_E3"]
    ).rvalue

    # make chart
    chart = alt.Chart(df,title=alt.Title("Correlation Between Mutant Binding Scores")
    ).mark_point(
        color="gray",
        size=30,
        opacity=0.4,
        filled=True
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=4)),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=4)),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
        ],
    )

    # Build the filter predicate once: an OR over (site AND mutant) terms.
    is_highlighted = None
    for site, mutant in highlighted:
        term = (alt.datum.site == site) & (alt.datum.mutant == mutant)
        is_highlighted = term if is_highlighted is None else is_highlighted | term

    highlight = chart.transform_filter(is_highlighted).mark_point(
        color="orange", size=40, opacity=1, filled=True, stroke='black', strokeWidth=1
    )

    text_on_point = chart.transform_filter(is_highlighted).mark_text(
        align='center',baseline='top',dy=-15,fontSize=14
    ).encode(text='site_mutant')

    # Anchor the r label at the (integer-truncated) smallest x and largest y value.
    label_x = int(df["binding_mean_E2"].min())
    label_y = int(df["binding_mean_E3"].max())
    text = (
        alt.Chart({"values": [{"x": label_x, "y": label_y, "text": f"r = {r_value:.2f}"}]})
        .mark_text(align="left", baseline="top", dx=-10, dy=-20)
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )
    chart_and_text = chart + text + text_on_point + highlight
    return chart_and_text.properties(height=300,width=300)


E2_E3_corr = plot_affinity_BLI_mutants(df_binding_effect_merge)
E2_E3_corr.display()
# Save only when this chart's own output path was supplied.
if E2_E3_correlation is not None:
    E2_E3_corr.save(E2_E3_correlation)

In [None]:
def plot_affinity_neut(df):
    """Scatter bEFNB2 against bEFNB3 binding, highlighting a fixed mutant panel.

    Rows containing any NaN are dropped and values rounded to two decimals
    before plotting. Six hard-coded mutants are redrawn in blue and labelled
    with site + mutant, and the Pearson r of the two binding columns is
    written near the top-left of the plot.
    NOTE(review): the function name suggests this panel relates to
    neutralization experiments -- confirm.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``mutant``,
            ``binding_mean_E2``, ``binding_mean_E3``, ``effect_E2`` and
            ``effect_E3`` columns.

    Returns:
        A 300x300 layered Altair chart.
    """
    # (site, mutant) pairs to highlight and label.
    highlighted = [
        (333, 'Q'),
        (492, 'R'),
        (507, 'I'),
        (530, 'F'),
        (553, 'W'),
        (555, 'K'),
    ]

    df = df.dropna().round(2).copy()
    df['site_mutant'] = df['site'].astype(str) + df['mutant'].astype(str)
    r_value = scipy.stats.linregress(
        df["binding_mean_E2"], df["binding_mean_E3"]
    ).rvalue

    # make chart
    chart = alt.Chart(df,title=alt.Title("Correlation Between Mutant Binding Scores")
    ).mark_point(
        color="gray",
        size=30,
        opacity=0.4,
        filled=True
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=4)),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=4)),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
        ],
    )

    # Build the filter predicate once: an OR over (site AND mutant) terms.
    is_highlighted = None
    for site, mutant in highlighted:
        term = (alt.datum.site == site) & (alt.datum.mutant == mutant)
        is_highlighted = term if is_highlighted is None else is_highlighted | term

    highlight = chart.transform_filter(is_highlighted).mark_point(
        color="blue", size=40, opacity=1, filled=True, stroke='black', strokeWidth=1
    )

    text_on_point = chart.transform_filter(is_highlighted).mark_text(
        align='center',baseline='top',dy=-15,fontSize=14
    ).encode(text='site_mutant')

    # Anchor the r label at the (integer-truncated) smallest x and largest y value.
    label_x = int(df["binding_mean_E2"].min())
    label_y = int(df["binding_mean_E3"].max())
    text = (
        alt.Chart({"values": [{"x": label_x, "y": label_y, "text": f"r = {r_value:.2f}"}]})
        .mark_text(align="left", baseline="top", dx=-10, dy=-20)
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )
    chart_and_text = chart + text + text_on_point + highlight
    return chart_and_text.properties(height=300,width=300)


# Rebinds `E2_E3_corr`, the name the previous cell used for the BLI chart.
E2_E3_corr = plot_affinity_neut(df_binding_effect_merge)
E2_E3_corr.display()
# This chart is displayed only; the save below is disabled.
#if entry_binding_combined_corr_plot is not None:
#    E2_E3_corr.save(E2_E3_correlation)

In [None]:
def plot_affinity_highB3(df):
    """Scatter bEFNB2 against bEFNB3 binding for receptor contact sites only.

    Rows containing any NaN are dropped, values rounded to two decimals, and
    the data restricted to sites listed in ``config["contact_sites"]``. Nine
    hard-coded mutants are redrawn in blue and labelled with site + mutant,
    and the Pearson r of the two binding columns (computed on the restricted
    data) is written near the top-left of the plot.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``mutant``,
            ``binding_mean_E2``, ``binding_mean_E3``, ``effect_E2`` and
            ``effect_E3`` columns.

    Returns:
        A 300x300 layered Altair chart.
    """
    # (site, mutant) pairs to highlight and label.
    highlighted = [
        (580, 'S'),
        (458, 'Y'),
        (492, 'L'),
        (491, 'A'),
        (588, 'P'),
        (492, 'R'),
        (530, 'F'),
        (492, 'K'),
        (239, 'H'),
    ]

    df = df.dropna().round(2)
    # Copy after filtering so the column added below is written to an
    # independent frame rather than to a slice of another DataFrame.
    df = df[df['site'].isin(config['contact_sites'])].copy()
    df['site_mutant'] = df['site'].astype(str) + df['mutant'].astype(str)
    r_value = scipy.stats.linregress(
        df["binding_mean_E2"], df["binding_mean_E3"]
    ).rvalue

    # make chart
    chart = alt.Chart(df,title=alt.Title("Correlation Between Mutant Binding Scores")
    ).mark_point(
        color="gray",
        size=30,
        opacity=0.4,
        filled=True
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=4)),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=4)),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
        ],
    )

    # Build the filter predicate once: an OR over (site AND mutant) terms.
    is_highlighted = None
    for site, mutant in highlighted:
        term = (alt.datum.site == site) & (alt.datum.mutant == mutant)
        is_highlighted = term if is_highlighted is None else is_highlighted | term

    highlight = chart.transform_filter(is_highlighted).mark_point(
        color="blue", size=40, opacity=1, filled=True, stroke='black', strokeWidth=1
    )

    text_on_point = chart.transform_filter(is_highlighted).mark_text(
        align='center',baseline='top',dy=-15,fontSize=14
    ).encode(text='site_mutant')

    # Anchor the r label at the (integer-truncated) smallest x and largest y value.
    label_x = int(df["binding_mean_E2"].min())
    label_y = int(df["binding_mean_E3"].max())
    text = (
        alt.Chart({"values": [{"x": label_x, "y": label_y, "text": f"r = {r_value:.2f}"}]})
        .mark_text(align="left", baseline="top", dx=-10, dy=-20)
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )
    chart_and_text = chart + text + text_on_point + highlight
    return chart_and_text.properties(height=300,width=300)


# Rebinds `E2_E3_corr` once more, this time to the contact-site chart.
E2_E3_corr = plot_affinity_highB3(df_binding_effect_merge)
E2_E3_corr.display()
# This chart is displayed only; the save below is disabled.
#if entry_binding_combined_corr_plot is not None:
#    E2_E3_corr.save(E2_E3_correlation)

### Plot bEFNB2 vs bEFNB3 binding for individual sites of interest, drawing each mutant as its amino-acid letter

In [None]:
def plot_affinity_individual_mutants(df,mutant):
    """Plot bEFNB2 against bEFNB3 binding for every mutant at a single site.

    Each mutant is drawn as the text of its ``mutant`` value at its
    (``binding_mean_E2``, ``binding_mean_E3``) position. Rows containing any
    NaN are dropped and values rounded to two decimals first.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``mutant``,
            ``binding_mean_E2``, ``binding_mean_E3``, ``effect_E2`` and
            ``effect_E3`` columns.
        mutant: The site to plot. Despite its name, this value is compared
            against the ``site`` column.

    Returns:
        A 300x300 Altair text chart.
    """
    df = df.dropna().round(2).copy()
    df = df[df['site'] == mutant]

    chart = alt.Chart(df,title=alt.Title(f"Correlation Between Mutant Binding Scores at Site {mutant}")
    ).mark_text(
        size=15
    ).encode(
        x=alt.X("binding_mean_E2", title=("bEFNB2 binding"),axis=alt.Axis(tickCount=4)),
        y=alt.Y("binding_mean_E3", title=("bEFNB3 binding"),axis=alt.Axis(tickCount=4)),
        # The text channel takes an alt.Text definition, not an alt.Color one.
        text=alt.Text('mutant'),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "binding_mean_E2",
            "binding_mean_E3",
            "effect_E2",
            "effect_E3",
        ],
    )

    return chart.properties(height=300,width=300)

In [None]:
# Per-mutant binding comparison at site 492.
indiv_mutant_492 = plot_affinity_individual_mutants(df_binding_effect_merge,492)
indiv_mutant_492.display()

In [None]:
# Per-mutant binding comparison at site 491.
indiv_mutant_491 = plot_affinity_individual_mutants(df_binding_effect_merge,491)
indiv_mutant_491.display()

In [None]:
# Per-mutant binding comparison at site 580.
indiv_mutant_580 = plot_affinity_individual_mutants(df_binding_effect_merge,580)
indiv_mutant_580.display()

In [None]:
# Per-mutant binding comparison at site 589.
indiv_mutant_589 = plot_affinity_individual_mutants(df_binding_effect_merge,589)
indiv_mutant_589.display()

In [None]:
# Per-mutant binding comparison at site 559.
indiv_mutant_559 = plot_affinity_individual_mutants(df_binding_effect_merge,559)
indiv_mutant_559.display()

In [None]:
# Per-mutant binding comparison at site 239.
indiv_mutant_239 = plot_affinity_individual_mutants(df_binding_effect_merge,239)
indiv_mutant_239.display()

### Plot correlations between summary statistics for each site

In [None]:
def plot_affinity_solid_mean(df):
    """Scatter the per-site mean bEFNB2 binding against mean bEFNB3 binding.

    Rows containing any NaN are dropped, then binding and entry columns are
    averaged per site (keeping the first ``wildtype`` value) and rounded to two
    decimals. The Pearson r of the two per-site binding means is shown as the
    chart subtitle.

    Args:
        df: Merged DataFrame with ``site``, ``wildtype``, ``binding_mean_E2``,
            ``binding_mean_E3``, ``effect_E2`` and ``effect_E3`` columns.

    Returns:
        An Altair scatter chart with one point per site.
    """
    means = (
        df.dropna()
        .groupby("site")
        .agg(
            {
                "effect_E2": "mean",
                "effect_E3": "mean",
                "binding_mean_E2": "mean",
                "binding_mean_E3": "mean",
                "wildtype": "first",
            }
        )
        .reset_index().round(2)
    )
    r_value = float(
        scipy.stats.linregress(
            means["binding_mean_E2"], means["binding_mean_E3"]
        ).rvalue
    )
    # r is reported in the subtitle, so no separate text layer is needed.
    chart = (
        alt.Chart(
            means,
            title=alt.Title(
                "Correlation between Aggregate Mutant Binding Scores",
                subtitle=f"r={r_value:.2f}",
            ),
        )
        .mark_point(size=50, opacity=0.3,filled=True,color='black')
        .encode(
            x=alt.X(
                "binding_mean_E2",
                title=("Avg EFNB2 Binding"),
                axis=alt.Axis(tickCount=3),
            ),
            y=alt.Y(
                "binding_mean_E3",
                title=("Avg EFNB3 Binding"),
                axis=alt.Axis(tickCount=3),
            ),
            tooltip=[
                "site",
                "wildtype",
                "binding_mean_E2",
                "binding_mean_E3",
                "effect_E2",
                "effect_E3",
            ],
        )
    )
    return chart


E2_E3_site_corr = plot_affinity_solid_mean(df_binding_effect_merge)
E2_E3_site_corr.display()
# Save only when this chart's own output path was supplied.
if E2_E3_correlation_site is not None:
    E2_E3_site_corr.save(E2_E3_correlation_site)

### Make plot showing binding by site

In [None]:
def plot_affinity_by_site(df):
    """Draw the mean binding at each site as bars, one row per selection.

    Args:
        df: Long-format DataFrame with ``site``, ``binding_mean`` and
            ``selection`` columns.

    Returns:
        A row-faceted Altair bar chart with independent y scales. The site
        under the mouse is highlighted.
    """
    hovered_site = alt.selection_point(
        on="mouseover", nearest=True, empty=False, fields=["site"], value=0
    )

    # Positional channels.
    site_axis = (
        alt.X("site").title("Site").axis(tickCount=4).scale(domain=[70, 602])
    )
    mean_axis = alt.Y('mean(binding_mean)').axis(tickCount=3).title('Mean binding')
    selection_rows = alt.Row('selection',title=None)

    # Styling that reacts to the hover selection.
    bar_color = alt.condition(hovered_site, alt.value("orange"), alt.value("black"))
    bar_opacity = alt.condition(hovered_site, alt.value(1), alt.value(0.5))
    bar_stroke = alt.condition(hovered_site, alt.value(1), alt.value(0))

    bars = alt.Chart(df).mark_bar(stroke="black",size=1).encode(
        site_axis,
        mean_axis,
        selection_rows,
        tooltip=['site'],
        color=bar_color,
        opacity=bar_opacity,
        strokeWidth=bar_stroke,
    )
    bars = bars.properties(height=100, width=600).add_params(hovered_site)
    return bars.resolve_scale(y='independent')


binding_by_site = plot_affinity_by_site(df_binding_effect_concat)
binding_by_site.display()
# Save only when this chart's own output path was supplied.
if binding_by_site_plot is not None:
    binding_by_site.save(binding_by_site_plot)

In [None]:
def plot_affinity_by_contact_site(df, sites_to_show):
    """Rank the given sites by their best-binding mutant, one column per selection.

    Args:
        df: Long-format DataFrame with ``site``, ``binding_mean`` and
            ``selection`` columns.
        sites_to_show: Sites to include; all other rows are dropped.

    Returns:
        A column-faceted horizontal bar chart of the maximum ``binding_mean``
        per site, sorted by descending bar length, with independent x scales.
    """
    # Restrict the data to the requested sites.
    contact_df = df[df["site"].isin(sites_to_show)]

    # Hovering a bar highlights that site.
    hovered_site = alt.selection_point(on="mouseover", nearest=True, empty=False, fields=["site"])

    encodings = dict(
        y=alt.Y("site:N", title="Site",sort='-x'),
        x=alt.X("max(binding_mean):Q", title="Max binding mutant at site"),
        color=alt.condition(hovered_site, alt.value("orange"), alt.value("black")),
        strokeWidth=alt.condition(hovered_site, alt.value(2), alt.value(0)),
        column=alt.Column('selection',title=None),
    )
    header_style = dict(
        labelFontSize=20,
        labelAngle=0,
        labelAlign='center',
        labelAnchor='middle',
        labelFont='Helvetica Light',
        labelFontStyle='bold',
        labelPadding=3,
    )

    bars = alt.Chart(contact_df).mark_bar(stroke='black').encode(**encodings)
    bars = bars.add_params(hovered_site).properties(width=200, height=alt.Step(12))
    return bars.resolve_scale(x='independent').configure_header(**header_style)


In [None]:
contact_binding_by_site = plot_affinity_by_contact_site(df_binding_effect_concat, config["contact_sites"])
contact_binding_by_site.display()
# Save only when this chart's own output path was supplied.
if combined_contact_ranked_bar_output is not None:
    contact_binding_by_site.save(combined_contact_ranked_bar_output)

### Make bubble plots of binding across RBP regions (stalk, neck, linker, head, receptor contact sites)

In [None]:
def make_bubble_binding_region(df):
    """Draw a jittered strip plot of mutant binding grouped by RBP region.

    Every row of ``df`` is assigned to each region whose site list contains its
    site, so a site listed in ``config["contact_sites"]`` can appear both under
    its structural region and under "Receptor Contact". Rows whose site is in
    no region are omitted.

    Args:
        df: Long-format DataFrame with ``site``, ``mutant``, ``binding_mean``
            and ``selection`` columns.

    Returns:
        A column-faceted Altair chart (one column per selection) with
        independent y scales. The mutant under the mouse is highlighted.
    """
    # range() excludes its stop value, so each stop is the first residue of the
    # next region; consecutive regions therefore leave no site unassigned, and
    # the final stop of 603 keeps site 602.
    region_sites = {
        "Stalk": list(range(96, 148)),
        "Neck": list(range(148, 166)),
        "Linker": list(range(166, 178)),
        "Head": list(range(178, 603)),
        "Receptor Contact": config["contact_sites"],
    }
    custom_order = ["Stalk", "Neck", "Linker", "Head", "Receptor Contact"]
    value_columns = ["binding_mean", "site", "mutant", "selection"]

    # Tag the rows of each region and stack them into one long table.
    frames = []
    for region, sites in region_sites.items():
        subset = df.loc[df["site"].isin(sites), value_columns]
        if not subset.empty:
            frames.append(subset.assign(region=region))
    if frames:
        region_df = pd.concat(frames, ignore_index=True)[["region"] + value_columns]
    else:
        # No site matched any region: keep the expected columns on an empty table.
        region_df = pd.DataFrame(columns=["region"] + value_columns)

    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["site", "mutant"], value=1
    )

    chart = alt.Chart(region_df).mark_point(stroke="black",filled=True).encode(
        x=alt.X(
            "region:O",
            sort=custom_order,
            title="RBP Region",
            axis=alt.Axis(labelAngle=-90),
        ),
        y=alt.Y(
            "binding_mean",
            title="Binding",
            axis=alt.Axis(tickCount=4),
        ),
        # Horizontal jitter within each region, from the calculated field below.
        xOffset="random:Q",
        tooltip=["region", "binding_mean", "site", "mutant"],
        color=alt.condition(
            variant_selector, alt.value("orange"), alt.value("black")
        ),
        opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
        strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
        size=alt.condition(variant_selector, alt.value(50), alt.value(15)),
        column=alt.Column('selection',title=None)
    ).transform_calculate(random="sqrt(-1*log(random()))*cos(2*PI*random())"
    ).properties(
        width=200,
        height=200
    ).add_params(variant_selector
    ).resolve_scale(y='independent'
    ).configure_header(
        labelFontSize=20,
        labelAngle=0,
        labelAlign='center',
        labelAnchor='middle',
        labelFont='Helvetica Light',
        labelFontStyle='bold',
        labelPadding=3
    )

    return chart


entry_region_bubble = make_bubble_binding_region(df_binding_effect_concat)
entry_region_bubble.display()
# Save only when this chart's own output path was supplied.
if binding_region_bubble_plot is not None:
    entry_region_bubble.save(binding_region_bubble_plot)