# Escape sites related to antibody contact distance

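This notebook plots the summed escape at each site, stratified by distance to the antibody's residues, for each antibody. It also plots each site's mean effect on cell entry using the same distance stratification.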

In [None]:
# Imports
import os
import warnings
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Color cycle applied to seaborn plots (hex codes, in cycle order)
tol_muted_adjusted = [
    "#000000", "#CC6677", "#1f78b4", "#DDCC77", "#117733",
    "#882255", "#88CCEE", "#44AA99", "#999933", "#AA4499",
    "#EE7733", "#CC3311", "#DDDDDD",
]

# Global figure styling: 300 dpi for display and for saved files, and SVG
# text kept as text (not converted to paths) so labels remain editable.
sns.set(
    rc={
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "svg.fonttype": "none",
    }
)
sns.set_style("ticks")
# Must come after sns.set(), which resets the palette to seaborn's default
sns.set_palette(tol_muted_adjusted)

# Silence every Python warning for the remainder of the notebook
warnings.simplefilter("ignore")

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# Every value defaults to None and is expected to be injected by papermill;
# the commented-out cell below shows example values for interactive runs.

# Paths to per-antibody contact CSVs (read for `position` and `distance` columns)
contacts_89F = None
contacts_377H = None
contacts_256A = None
contacts_2510C = None
contacts_121F = None
contacts_372D = None

# Paths to per-antibody filtered escape CSVs (read for `site` and `escape_median` columns)
filtered_escape_377H = None
filtered_escape_89F = None
filtered_escape_2510C = None
filtered_escape_121F = None
filtered_escape_256A = None
filtered_escape_372D = None

# Path to the functional-effects CSV used for the cell-entry plot
func_scores = None

# Minimum `times_seen` and `n_selections` a mutation needs for its functional
# score to be kept (comparisons are `>=`)
min_times_seen = None
n_selections = None

# Output directory (created if missing) and output figure paths
out_dir = None
saved_image_path = None
func_distance_image_path = None

In [None]:
# # Uncomment for running interactive
# contacts_89F = "../data/antibody_contacts/antibody_contacts_89F.csv"
# contacts_377H = "../data/antibody_contacts/antibody_contacts_377H.csv"
# contacts_256A = "../data/antibody_contacts/antibody_contacts_256A.csv"
# contacts_2510C = "../data/antibody_contacts/antibody_contacts_2510C.csv"
# contacts_121F = "../data/antibody_contacts/antibody_contacts_121F.csv"
# contacts_372D = "../data/antibody_contacts/antibody_contacts_372D.csv"

# filtered_escape_377H = "../results/filtered_antibody_escape_CSVs/377H_filtered_mut_effect.csv"
# filtered_escape_89F = "../results/filtered_antibody_escape_CSVs/89F_filtered_mut_effect.csv"
# filtered_escape_2510C = "../results/filtered_antibody_escape_CSVs/2510C_filtered_mut_effect.csv"
# filtered_escape_121F = "../results/filtered_antibody_escape_CSVs/121F_filtered_mut_effect.csv"
# filtered_escape_256A = "../results/filtered_antibody_escape_CSVs/256A_filtered_mut_effect.csv"
# filtered_escape_372D = "../results/filtered_antibody_escape_CSVs/372D_filtered_mut_effect.csv"

# func_scores = "../results/func_effects/averages/293T_entry_func_effects.csv"

# min_times_seen = 2
# n_selections = 8

# out_dir = "../results/antibody_escape_profiles/"
# saved_image_path = "../results/antibody_escape_profiles/antibody_escape_by_distance.svg"
# func_distance_image_path = "../results/antibody_escape_profiles/func_effect_by_distance.svg"

In [None]:
# Pair each antibody's contacts CSV with its filtered escape CSV, then split the
# pairs into two aligned lists. The pair order sets the panel order in the figures.
contacts, escape = map(list, zip(
    (contacts_2510C, filtered_escape_2510C),
    (contacts_121F, filtered_escape_121F),
    (contacts_377H, filtered_escape_377H),
    (contacts_256A, filtered_escape_256A),
    (contacts_372D, filtered_escape_372D),
    (contacts_89F, filtered_escape_89F),
))

In [None]:
# Functions
def plot_func_scores_vs_distance(contacts_file, escape_file, ax, i, func_scores_file):
    """
    Plot site-mean effects on cell entry stratified by distance to one
    antibody, highlighting sites of strong escape.

    Only contact sites and strong-escape sites are drawn.

    Parameters
    ----------
    contacts_file : str
        Path to the antibody contacts CSV with `position` and `distance`
        columns. The antibody name is parsed from the file name
        (e.g. ".../antibody_contacts_89F.csv" -> "89F").
    escape_file : str
        Path to the filtered escape CSV with `site` and `escape_median` columns.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    i : int
        Panel index. Panel 0 keeps its y-axis; panel 5 carries the legend.
    func_scores_file : str
        Path to the functional-effects CSV with `site`, `mutant`, `effect`,
        `times_seen` and `n_selections` columns.

    Notes
    -----
    Reads the notebook-level globals `min_times_seen` and `n_selections`
    (set in the parameters cell) to filter functional scores.
    """
    antibody_name = contacts_file.split("/")[-1].split("_")[2][:-4]

    # Load data as dataframes
    contacts_df = pd.read_csv(contacts_file)
    escape_df = pd.read_csv(escape_file)
    func_scores = pd.read_csv(func_scores_file)

    # Keep one distance per contact position and rename `position` -> `site`
    # to match the escape data
    contacts_df = (
        contacts_df
        .groupby(["position"])
        .aggregate({"distance" : "first"})
        .reset_index()
        .sort_values(by=["distance", "position"])
        .rename(columns={"position" : "site"})
        .reset_index(drop=True)
    )

    # Floor escape scores at 0 so negative scores do not lower site sums
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Calculate site sums
    escape_df = (
        escape_df
        .groupby(["site"])
        .aggregate({"escape_median" : "sum"})
        .reset_index()
    )

    # Site means of functional scores, excluding stop codons and mutations
    # below the `min_times_seen` / `n_selections` thresholds (notebook globals)
    func_scores = func_scores.loc[
        (func_scores["mutant"] != "*")
        &
        (func_scores["times_seen"] >= min_times_seen)
        &
        (func_scores["n_selections"] >= n_selections)
    ]
    func_scores = (
        func_scores
        .groupby(["site"])
        .aggregate({"effect" : "mean"})
        .reset_index()
    )

    # Attach site-mean functional effects to the escape sites
    escape_df = escape_df.merge(
        func_scores,
        how="left",
        on=["site"],
        validate="one_to_one",
    )

    # Merge with contacts. Sites absent from the contacts file get the sentinel
    # distance 100 (plotted as "distal"). Only `distance` is filled: a blanket
    # fillna would also turn missing functional effects into 100.
    merged_df = (
        escape_df.merge(
            contacts_df,
            how="left",
            on="site",
            validate="one_to_one",
        )
        .fillna({"distance": 100})
    )

    # Strong escape: summed site escape more than 10x the median summed site escape
    cutoff = escape_df["escape_median"].median() * 10
    merged_df["strong escape"] = merged_df["escape_median"] > cutoff

    # Re-map distances to x positions (contact=0, proximal=1.5, distal=3) and sort
    merged_df["distance"] = merged_df["distance"].map({
        100 : 3,
        4 : 0,
        8 : 1.5,
    })
    merged_df = merged_df.sort_values(by="distance")

    # Reproducible jitter on x. A private RandomState(0) gives the same draws as
    # seeding numpy's legacy global RNG with 0, without mutating global state.
    jitter = np.random.RandomState(0).normal(0, 0.1, merged_df["distance"].shape)
    merged_df["jittered_x"] = merged_df["distance"] + jitter

    # Show every contact site, but only strong-escape sites elsewhere
    merged_df = merged_df.loc[
        (merged_df["distance"] == 0)
        |
        merged_df["strong escape"]
    ]

    # Plot functional effect by distance category
    chart = sns.scatterplot(
        data=merged_df,
        x="jittered_x",
        y="effect",
        hue="strong escape",
        edgecolor=None,
        linewidth=0.5,
        palette={False : "#00000026", True : "#EE7733CC"},
        s=20,
        ax=ax,
    )

    # Panel title, e.g. "25.10C", colored per antibody
    title_colors = {
        "2510C": "#44AA99",
        "121F": "#999933",
        "377H": "#AA4499",
        "256A": "#AA4499",
        "372D": "#AA4499",
        "89F": "#117733",
    }
    if antibody_name in title_colors:
        chart.set_title(
            antibody_name[:2] + "." + antibody_name[2:],
            fontsize=8,
            color=title_colors[antibody_name],
        )

    chart.set_xticks([0, 1.5, 3])
    chart.set_xticklabels(
        labels=["contact", "proximal", "distal"],
        rotation=90,
        horizontalalignment="center",
        fontsize=8,
    )
    chart.set_ylabel("site mean\neffect on cell entry", fontsize=8)
    chart.set_ylim(-5, 1)
    chart.set_xlim(-0.75, 3.75)
    chart.set_yticks([-4, -2, 0])
    chart.set_yticklabels(labels=["-4", "-2", "0"], fontsize=8)
    chart.set(xlabel=None)

    # Make only one legend (on the last panel)
    if i == 5:
        sns.move_legend(
            chart,
            "upper left",
            bbox_to_anchor=(1, 1),
            fontsize=8,
            markerscale=1,
            handletextpad=0.1,
            title="site of\nstrong\nescape",
            title_fontproperties={"size" : 8},
            frameon=False,
            borderaxespad=0.1,
            reverse=True,
        )
        # `Legend.legendHandles` was renamed `legend_handles` in matplotlib 3.7
        # and the old name removed in 3.9; prefer the new name.
        legend = chart.legend_
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        # Give legend markers the same edge settings as the scatter points
        for ha in handles:
            ha.set_edgecolor(None)
            ha.set_linewidths(0.5)
    else:
        ax.get_legend().remove()

    # Uniform spine width; only the first panel keeps its y-axis
    for axis in ["top", "bottom", "left", "right"]:
        chart.spines[axis].set_linewidth(1)
    if i != 0:
        chart.spines["left"].set_linewidth(0)
        chart.set_yticks([])
        chart.set_yticklabels([])
        chart.set_ylabel("")
    chart.tick_params(axis="both", length=4, width=1)

    chart.grid(False)
    sns.despine()

Plot functional scores stratified by distance to the antibody, highlighting strong escape sites (i.e., sites whose summed escape is more than 10-fold greater than the median summed escape across all sites) for all antibodies. Sites without strong escape are shown only for contact sites, because proximal and distal sites have too many such sites to display clearly.

In [None]:
# One panel per antibody, in the order defined by `contacts` / `escape`
fig, axes = plt.subplots(1, 6, figsize=(6,2))
for panel_idx, (contacts_file, escape_file) in enumerate(zip(contacts, escape)):
    plot_func_scores_vs_distance(
        contacts_file, escape_file, axes[panel_idx], panel_idx, func_scores
    )

# Create the output dir (and any missing parents); no-op if it already exists
os.makedirs(out_dir, exist_ok=True)

# Save this figure explicitly rather than pyplot's "current" figure
fig.savefig(func_distance_image_path)

In [None]:
# Functions
def plot_escape_vs_contact_distance(contacts_file, escape_file, ax, i, protein_length=491):
    """
    Plot summed site escape stratified by distance to one antibody,
    highlighting and labeling sites of strong escape.

    Parameters
    ----------
    contacts_file : str
        Path to the antibody contacts CSV with `position` and `distance`
        columns. The antibody name is parsed from the file name
        (e.g. ".../antibody_contacts_89F.csv" -> "89F").
    escape_file : str
        Path to the filtered escape CSV with `site` and `escape_median` columns.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    i : int
        Panel index. Panel 0 keeps its y-axis; panel 5 carries the legend.
    protein_length : int, optional
        Number of sites in the protein (default 491, the LASV protein size).
        Used to count distal sites, so that sites absent from the escape data
        are counted as distal.

    Side effects
    ------------
    Prints, for this antibody, the percentage of contact / proximal / distal
    sites with strong escape and summary statistics of summed site escape.
    """
    antibody_name = contacts_file.split("/")[-1].split("_")[2][:-4]

    # Load data as dataframes
    contacts_df = pd.read_csv(contacts_file)
    escape_df = pd.read_csv(escape_file)

    # Keep one distance per contact position and rename `position` -> `site`
    # to match the escape data
    contacts_df = (
        contacts_df
        .groupby(["position"])
        .aggregate({"distance" : "first"})
        .reset_index()
        .sort_values(by=["distance", "position"])
        .rename(columns={"position" : "site"})
        .reset_index(drop=True)
    )

    # Floor escape scores at 0 so negative scores do not lower site sums
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Calculate site sums
    escape_df = (
        escape_df
        .groupby(["site"])
        .aggregate({"escape_median" : "sum"})
        .reset_index()
    )

    # Merge with contacts; sites absent from the contacts file get the
    # sentinel distance 100 (plotted as "distal")
    merged_df = (
        escape_df.merge(
            contacts_df,
            how="left",
            on="site",
            validate="one_to_one",
        )
        .fillna({"distance": 100})
    )

    # Strong escape: summed site escape more than 10x the median summed site escape
    cutoff = escape_df["escape_median"].median() * 10
    merged_df["strong escape"] = merged_df["escape_median"] > cutoff

    # Re-map distances to x positions (contact=0, proximal=1.5, distal=3) and sort
    merged_df["distance"] = merged_df["distance"].map({
        100 : 3,
        4 : 0,
        8 : 1.5,
    })
    merged_df = merged_df.sort_values(by="distance")

    # Reproducible jitter on x. RandomState(0) yields the same draws as
    # `np.random.seed(0)` followed by `np.random.normal`, without reseeding
    # numpy's global RNG as a side effect.
    jitter = np.random.RandomState(0).normal(0, 0.1, merged_df["distance"].shape)
    merged_df["jittered_x"] = merged_df["distance"] + jitter

    chart = sns.scatterplot(
        data=merged_df,
        x="jittered_x",
        y="escape_median",
        hue="strong escape",
        edgecolor=None,
        linewidth=0.5,
        palette={False : "#00000026", True : "#EE7733CC"},
        s=20,
        ax=ax,
    )

    # Panel title, e.g. "25.10C", colored per antibody
    title_colors = {
        "2510C": "#44AA99",
        "121F": "#999933",
        "377H": "#AA4499",
        "256A": "#AA4499",
        "372D": "#AA4499",
        "89F": "#117733",
    }
    if antibody_name in title_colors:
        chart.set_title(
            antibody_name[:2] + "." + antibody_name[2:],
            fontsize=8,
            color=title_colors[antibody_name],
        )

    chart.set_xticks([0, 1.5, 3])
    chart.set_xticklabels(
        labels=["contact", "proximal", "distal"],
        rotation=90,
        horizontalalignment="center",
        fontsize=8,
    )
    chart.set_ylabel("site escape", fontsize=8)
    chart.set_ylim(-2.5, 95)
    chart.set_xlim(-0.75, 3.75)
    chart.set_yticks([0, 20, 40, 60, 80])
    chart.set_yticklabels(labels=["0", "20", "40", "60", "80"], fontsize=8)
    chart.set(xlabel=None)

    # Make only one legend (on the last panel)
    if i == 5:
        sns.move_legend(
            chart,
            "upper left",
            bbox_to_anchor=(1, 1),
            fontsize=8,
            markerscale=1,
            handletextpad=0.1,
            title="site of\nstrong\nescape",
            title_fontproperties={"size" : 8},
            frameon=False,
            borderaxespad=0.1,
            reverse=True,
        )
        # `Legend.legendHandles` was renamed `legend_handles` in matplotlib 3.7
        # and the old name removed in 3.9; prefer the new name.
        legend = chart.legend_
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        # Give legend markers the same edge settings as the scatter points
        for ha in handles:
            ha.set_edgecolor(None)
            ha.set_linewidths(0.5)
    else:
        ax.get_legend().remove()

    # Uniform spine width; only the first panel keeps its y-axis
    for axis in ["top", "bottom", "left", "right"]:
        chart.spines[axis].set_linewidth(1)
    if i != 0:
        chart.spines["left"].set_linewidth(0)
        chart.set_yticks([])
        chart.set_yticklabels([])
        chart.set_ylabel("")
    chart.tick_params(axis="both", length=4, width=1)

    chart.grid(False)
    sns.despine()

    # Dashed line at the strong-escape cutoff
    chart.axhline(
        y=cutoff,
        color="#000000",
        linestyle="--",
        alpha=0.5,
        linewidth=1,
    )

    # Per-category counts: (label, x position, n strong escape, n total)
    is_strong = merged_df["strong escape"]
    is_contact = merged_df["distance"] == 0
    is_proximal = merged_df["distance"] == 1.5
    is_distal = merged_df["distance"] == 3
    n_contact = int(is_contact.sum())
    n_proximal = int(is_proximal.sum())
    # Distal total is derived from the protein length, so sites missing from the
    # escape data count as distal. NOTE(review): this also counts contact/proximal
    # sites missing from the escape data as distal — confirm that is intended.
    n_distal = protein_length - n_contact - n_proximal
    categories = [
        ("Contacts", 0, int((is_contact & is_strong).sum()), n_contact),
        ("Proximal", 1.5, int((is_proximal & is_strong).sum()), n_proximal),
        ("Distal", 3, int((is_distal & is_strong).sum()), n_distal),
    ]

    print(antibody_name)
    for label, _, n_strong, n_total in categories:
        # Guard against categories with no sites (avoids ZeroDivisionError)
        pct = f"{(n_strong/n_total)*100:.1f}%" if n_total > 0 else "n/a"
        print(f"{label} with strong escape: {pct}")
    # Summary stats of summed escape across all sites
    print(f"Cumulative escape across all sites: {escape_df['escape_median'].sum()}")
    print(f"Median escape across all sites: {escape_df['escape_median'].median()}")
    print(f"Mean escape across all sites: {escape_df['escape_median'].mean()}")
    print()

    # Counts at the top of each category: strong escape (orange) over
    # not-strong escape (black)
    for _, x_center, n_strong, n_total in categories:
        chart.text(
            x_center,
            91,
            f"{n_strong}",
            fontsize=7,
            horizontalalignment="center",
            color="#EE7733",
        )
        chart.text(
            x_center,
            85,
            f"{n_total-n_strong}",
            fontsize=7,
            horizontalalignment="center",
            color="#000000",
        )

    # Hand-tuned (dx, dy) label offsets for selected strong-escape sites
    site_label_offsets = {
        "89F": {119: (0.05, 2), 125: (0.2, 1), 129: (0.05, 2), 138: (0.05, 2), 150: (0.05, 2)},
        "377H": {398: (0.25, 0), 401: (0.05, 2), 402: (0.05, 2), 404: (0.05, 2)},
        "256A": {401: (0.05, 2), 404: (0.05, 2)},
        "2510C": {76: (0.05, 2), 99: (0.05, 2), 101: (-0.75, 2), 228: (0.05, 2)},
        "121F": {89: (0.05, 2), 92: (0.05, 2), 111: (-1.25, 1), 127: (0.1, 2), 160: (-1.25, 1)},
        "372D": {149: (0.05, 2), 395: (0.05, 2), 397: (0.05, 2), 398: (-1.25, 2)},
    }
    offsets = site_label_offsets.get(antibody_name, {})
    # Walk rows in index order so labels are drawn in a stable order; iterate
    # columns separately so each keeps its own dtype (sites print as integers).
    label_rows = merged_df.sort_index()
    for x_pos, y_pos, site in zip(
        label_rows["jittered_x"], label_rows["escape_median"], label_rows["site"]
    ):
        if site not in offsets:
            continue
        dx, dy = offsets[site]
        chart.text(
            x_pos + dx,
            y_pos + dy,
            f"{site}",
            fontsize=6,
            horizontalalignment="left",
            color="#EE7733",
        )

Plot summed site escape scores stratified by distance to the antibody, highlighting strong escape sites (i.e., sites whose summed escape is more than 10-fold greater than the median across all sites) for all antibodies.

In [None]:
# One panel per antibody, in the order defined by `contacts` / `escape`
fig, axes = plt.subplots(1, 6, figsize=(6,2))
for panel_idx, (contacts_file, escape_file) in enumerate(zip(contacts, escape)):
    plot_escape_vs_contact_distance(contacts_file, escape_file, axes[panel_idx], panel_idx)

# Create the output dir (and any missing parents); no-op if it already exists
os.makedirs(out_dir, exist_ok=True)

# Save this figure explicitly rather than pyplot's "current" figure
fig.savefig(saved_image_path)