# Compare DMS data to natural sequence data

This notebook analyzes natural sequences for antibody escape predicted by DMS and shows the resulting neutralization validations. Additional analysis is performed by correlating natural sequence diversity with DMS data.

In [None]:
# Imports
import functools
import os
import warnings

import altair as alt
import dmslogo
import matplotlib.colors
import matplotlib.pyplot as plt
import neutcurve
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
from Bio import AlignIO, SeqIO

# Shared plotting palette of 13 hex colors. Later cells pick colors from it
# by position, so the order of the entries is significant.
tol_muted_adjusted = [
    "#000000", "#CC6677", "#1f78b4", "#DDCC77",
    "#117733", "#882255", "#88CCEE", "#44AA99",
    "#999933", "#AA4499", "#EE7733", "#CC3311",
    "#DDDDDD",
]

# Create color palette
def color_gradient_hex(start, end, n):
    """Return ``n`` hex color strings running from ``start`` to ``end``.

    Builds an ``n``-entry two-color colormap and reads out every entry of its
    lookup table. Adapted from the color helper in polyclonal.
    """
    gradient_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        name="_", colors=[start, end], N=n
    )
    # Integer inputs index the colormap's lookup table directly.
    lut_indices = list(range(n))
    rgba_rows = gradient_cmap(lut_indices)
    return list(map(matplotlib.colors.rgb2hex, rgba_rows))

# Global figure styling: 300 dpi for display and saved figures, the "ticks"
# style and the shared palette. sns.set applies seaborn's default theme
# (style and palette included), so the overrides must follow it.
sns.set(rc={"savefig.dpi": 300, "figure.dpi": 300})
sns.set_style(style="ticks")
sns.set_palette(palette=tol_muted_adjusted)

# Silence every warning for the rest of the notebook run
warnings.simplefilter(action="ignore")

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# Every parameter defaults to None and is meant to be supplied by papermill.
# The commented-out cell below shows example values for interactive runs.

# Filtered per-antibody escape CSVs (e.g. "<antibody>_filtered_mut_effect.csv");
# the antibody name is parsed from the file name prefix before the first "_"
filtered_escape_377H = None
filtered_escape_89F = None
filtered_escape_2510C = None
filtered_escape_121F = None
filtered_escape_256A = None
filtered_escape_372D = None

# Per-antibody contact CSVs (read with "position" and "distance" columns)
contacts_89F = None
contacts_377H = None
contacts_256A = None
contacts_2510C = None
contacts_121F = None
contacts_372D = None

# Functional-effects CSV, plus two settings that are passed through to
# create_line_and_logoplots (which does not currently use them)
func_scores = None
min_times_seen = None
n_selections = None

# Natural sequence inputs: per-site variation CSV and GPC protein alignment (FASTA)
natural_sequence_variation = None
natural_GPC_sequence_alignment = None

# Neutralization data and validation plot output locations
fraction_infected_natural_isolates = None
out_dir = None
neuts_image_path = None
corr_image_path = None

# Escape-profile plot output locations
out_dir_escape = None
escape_image_path = None

In [None]:
# # Uncomment for running interactive
# filtered_escape_377H = "../results/filtered_antibody_escape_CSVs/377H_filtered_mut_effect.csv"
# filtered_escape_89F = "../results/filtered_antibody_escape_CSVs/89F_filtered_mut_effect.csv"
# filtered_escape_2510C = "../results/filtered_antibody_escape_CSVs/2510C_filtered_mut_effect.csv"
# filtered_escape_121F = "../results/filtered_antibody_escape_CSVs/121F_filtered_mut_effect.csv"
# filtered_escape_256A = "../results/filtered_antibody_escape_CSVs/256A_filtered_mut_effect.csv"
# filtered_escape_372D = "../results/filtered_antibody_escape_CSVs/372D_filtered_mut_effect.csv"

# contacts_89F = "../data/antibody_contacts/antibody_contacts_89F.csv"
# contacts_377H = "../data/antibody_contacts/antibody_contacts_377H.csv"
# contacts_256A = "../data/antibody_contacts/antibody_contacts_256A.csv"
# contacts_2510C = "../data/antibody_contacts/antibody_contacts_2510C.csv"
# contacts_121F = "../data/antibody_contacts/antibody_contacts_121F.csv"
# contacts_372D = "../data/antibody_contacts/antibody_contacts_372D.csv"

# func_scores = "../results/func_effects/averages/293T_entry_func_effects.csv"
# min_times_seen = 2
# n_selections = 8

# natural_sequence_variation = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/GPC_protein_variation.csv"
# natural_GPC_sequence_alignment = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/LASV_GPC_protein_alignment.fasta"

# fraction_infected_natural_isolates = "../data/validation_frac_infected_natural_isolates.csv"
# out_dir = "../results/validation_plots/"
# neuts_image_path = "../results/validation_plots/validation_neut_curves_natural_isolates.svg"
# corr_image_path = "../results/validation_plots/natural_isolate_validation_correlation.svg"

# out_dir_escape = "../results/antibody_escape_profiles/"
# escape_image_path = "../results/antibody_escape_profiles/natural_isolate_escape_profiles.svg"

## Identify natural sequences that would potentially escape antibody neutralization

To identify natural isolates that could potentially escape any of the mapped antibodies, we first identify high-confidence escape mutations. We filter for mutations in the top 5% of escape scores and then further restrict this list to mutations at sites in the top 5% of **summed escape**. Escape scores are clipped at a lower bound of 0 to focus on escape mutations rather than sensitizing mutations. Any natural sequence carrying one of these mutations is flagged as a potential escape isolate.

In [None]:
@functools.lru_cache(maxsize=32)
def _top_escape_mutations(escape_file, percentile_escape):
    """
    Return ((mutation, escape score), ...) for high-confidence escape mutations.

    A mutation qualifies when its escape score (clipped at 0) is at or above
    the ``percentile_escape`` quantile over all mutations AND its site's
    summed escape is at or above the same quantile over all sites.

    Cached so each (file, percentile) pair is read and ranked once instead of
    once per sequence. Re-running this cell clears the cache.
    """
    escape_df = pd.read_csv(escape_file)

    # Clip lower scores to 0 to focus on escape rather than sensitizing mutations
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Mutations in the top percentile of escape
    mut_cutoff = escape_df["escape_median"].quantile(percentile_escape)
    top_muts = escape_df.loc[
        escape_df["escape_median"] >= mut_cutoff, ["mutation", "escape_median"]
    ]

    # Sites in the top percentile of summed escape
    site_escape = escape_df.groupby("site")["escape_median"].sum()
    site_cutoff = site_escape.quantile(percentile_escape)
    top_sites = set(site_escape.index[site_escape >= site_cutoff])

    # Keep only top mutations located at top sites; mutations look like "K126N"
    return tuple(
        (mutation, score)
        for mutation, score in zip(top_muts["mutation"], top_muts["escape_median"])
        if int(mutation[1:-1]) in top_sites
    )


def determine_escape(percentile_escape, sequence, escape_file, strain, print_results=False):
    """
    Return the summed DMS escape of the high-confidence escape mutations in a sequence.

    Parameters
    ----------
    percentile_escape : float
        Quantile (e.g. 0.95) used as the cutoff for both per-mutation escape
        and summed per-site escape.
    sequence : str
        Aligned protein sequence; site ``n`` is read from ``sequence[n - 1]``.
    escape_file : str
        Antibody escape CSV with ``site``, ``mutation`` and ``escape_median`` columns.
    strain : str
        Strain name, used only when printing.
    print_results : bool
        If True, print one line per escape mutation found in ``sequence``.

    Returns
    -------
    float
        Sum of the escape scores of the matching mutations (0 if none match).
    """
    escape = 0
    for mutation, score in _top_escape_mutations(escape_file, percentile_escape):
        site = int(mutation[1:-1])
        # The sequence carries the escape mutant amino acid at this site
        if sequence[site - 1] == mutation[-1]:
            if print_results:
                print(f"{strain:<75} with \t {mutation[0]}{site}{mutation[-1]} \t DMS score: {score}")
            escape += score
    return escape

# Load the natural GPC protein alignment into a strain/sequence table
# (built in one call instead of appending row by row)
natural_seqs_df = pd.DataFrame(
    [
        (str(record.id), str(record.seq))
        for record in AlignIO.read(natural_GPC_sequence_alignment, "fasta")
    ],
    columns=["strain", "sequence"],
)

# Antibody escape dataframes and percentile cutoffs to use
antibody_files = [
    (filtered_escape_2510C, 0.95),
    (filtered_escape_121F, 0.95),
    (filtered_escape_377H, 0.95),
    (filtered_escape_256A, 0.95),
    (filtered_escape_372D, 0.95),
    (filtered_escape_89F, 0.95),
]

# Print every natural sequence carrying a high-confidence escape mutation.
# determine_escape is called only for its printed report, so a plain loop is
# used instead of DataFrame.apply.
for antibody, percentile in antibody_files:
    antibody_name = antibody.split("/")[-1].split("_")[0]
    print(f"{antibody_name} potentially escaped by:")
    for row in natural_seqs_df.itertuples(index=False):
        determine_escape(
            percentile,
            row.sequence,
            antibody,
            row.strain,
            print_results=True,
        )
    print()
    print()

The following were chosen for validation:
- 8.9F
    - Natural isolate: Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14
        - Corresponding single mutant: K126N (DMS score: 0.7906)
- 12.1F
    - Natural isolate: LM395-SLE-2009_KM822115_2009-XX-XX
        - Corresponding single mutant: N89D (DMS score: 2.515)
- 25.10C
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX 
        - Corresponding single mutant: E228D (DMS score: 3.098)
- 37.7H
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX 
        - Corresponding single mutant: H398K (DMS score: 2.949)
    - Natural isolate: LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18
        - Corresponding single mutant: D401E (DMS score: 1.3)
- 37.2D
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX
        - Corresponding single mutant: H398K (DMS score: 1.078)

*The D401E mutation present in the chosen natural isolates did not have high confidence in the DMS data for the 25.6A antibody selection, so 25.6A is omitted from the validation neutralization assays.

In [None]:
# Isolates chosen to validate for escape; Josiah is the DMS reference strain
josiah_strain = "Josiah_NC_004296_reverse_complement_2018-08-13"
chosen_isolates = [
    josiah_strain,
    "GA391_OL774861_reverse_complement_1977-XX-XX",
    "LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18",
    "Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14",
    "LM395-SLE-2009_KM822115_2009-XX-XX",
]

# Create subset of df for chosen isolates (rows keep alignment order)
validation_isolates = (
    natural_seqs_df.loc[natural_seqs_df["strain"].isin(chosen_isolates)]
    .reset_index(drop=True)
)

# Get Josiah sequence for comparison. Select positionally: the Josiah row's
# label depends on where it falls in the alignment, so a label lookup such as
# `.at[0, ...]` only works when Josiah happens to be the first match.
josiah_matches = validation_isolates.loc[
    validation_isolates["strain"] == josiah_strain, "sequence"
]
if josiah_matches.empty:
    raise ValueError(f"{josiah_strain} not found in {natural_GPC_sequence_alignment}")
josiah_sequence = josiah_matches.iloc[0]

# Rename dictionary for virus names
rename_dict = {
    "WT" : "unmutated Josiah",
    "OL774861" : "GA391",
    "KM822115" : "LM395",
    "MK107922" : "IRR007",
    "MH157037" : "ISTH1024",
    "Josiah_NC_004296_reverse_complement_2018-08-13" : "unmutated Josiah",
    "GA391_OL774861_reverse_complement_1977-XX-XX" : "GA391",
    "LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18" : "IRR007",
    "Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14" : "ISTH1024",
    "LM395-SLE-2009_KM822115_2009-XX-XX" : "LM395",
    "N89D" : "N89D Josiah",
    "K126N" : "K126N Josiah",
    "E228D" : "E228D Josiah",
    "H398K" : "H398K Josiah",
    "D401E" : "D401E Josiah",
}

# Replace full alignment strain names with short display names
validation_isolates["strain"] = validation_isolates["strain"].replace(rename_dict)

## Escape profiles for the sites that differ in the chosen isolates from the Josiah DMS strain

In [None]:
# Per-antibody plot settings for create_line_and_logoplots:
# antibody file prefix -> (display name, fixed y min, fixed y max, y ticks, title color)
_ANTIBODY_PLOT_SETTINGS = {
    "2510C": ("25.10C", -6.25, 56.25, [0, 25, 50], "#44AA99"),
    "121F": ("12.1F", -2.5, 22.5, [0, 10, 20], "#999933"),
    "377H": ("37.7H", -7.5, 67.5, [0, 30, 60], "#AA4499"),
    "256A": ("25.6A", -5, 45, [0, 20, 40], "#AA4499"),
    "372D": ("37.2D", -2.5, 22.5, [0, 10, 20], "#AA4499"),
    "89F": ("8.9F", -5, 45, [0, 20, 40], "#117733"),
}

# Number of GPC sites spanned by the line plot; sites absent from the escape
# data are filled in with zero escape so the line covers every site.
_GPC_N_SITES = 491


def _style_axis(ax):
    """Apply the shared spine width and tick style to a plot axis."""
    for spine in ["top", "bottom", "left", "right"]:
        ax.spines[spine].set_linewidth(1)
    ax.tick_params(axis="both", length=4, width=1, pad=1)


def create_line_and_logoplots(
    escape_file,
    contacts_file,
    func_scores,
    min_times_seen,
    n_selections,
    line_plot,
    logo_plot,
    output_file = None,
    sites = None,
    name = None,
):
    """
    Draw a summed-escape line plot and an escape logo plot for one antibody.

    Parameters
    ----------
    escape_file : str
        Filtered antibody escape CSV. The antibody name is the file name
        prefix before the first underscore (e.g. "2510C").
    contacts_file : str
        Antibody contact CSV; positions with ``distance == 4`` are shaded in
        the logo plot.
    func_scores : str
        Functional-effects CSV used to color logo letters (stop codons are
        dropped and effects are clipped to [-2, 0]).
    min_times_seen, n_selections
        Accepted for interface compatibility; not used here.
    line_plot, logo_plot : matplotlib Axes
        Axes to draw the line plot and the logo plot on.
    output_file : optional
        Accepted for interface compatibility; nothing is written.
    sites : list of int, optional
        Sites shown in the logo plot. Defaults to the 18 sites with the
        highest summed escape.
    name : str, optional
        Isolate name used in the logo plot title.
    """
    antibody_name = escape_file.split("/")[-1].split("_")[0]

    # Read data
    escape_df = pd.read_csv(escape_file)
    func_df = pd.read_csv(func_scores)
    contacts_df = pd.read_csv(contacts_file)

    # Clip lower scores to 0 to focus on escape rather than sensitizing mutations
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Summed escape per site
    summed_df = (
        escape_df
        .groupby(["site", "wildtype"])
        .aggregate({"escape_median": "sum"})
        .rename(columns={"escape_median": "site_escape"})
        .reset_index()
    )

    # Sites to show in the logo plot: default to the top 18 summed-escape sites
    if sites is None:
        sites = sorted(summed_df.nlargest(18, "site_escape")["site"].tolist())
    escape_df["show_site"] = escape_df["site"].isin(sites)

    # Shade contact sites in logo plot
    shade_sites = list(contacts_df.loc[contacts_df["distance"] == 4]["position"].unique())

    # Uncomment to print the contact sites for this antibody
    # print(antibody_name)
    # print(f"Contact sites: {shade_sites}")
    # print()

    # np.where with a None branch yields an object column holding None for unshaded sites
    escape_df["shade_site"] = np.where(escape_df["site"].isin(shade_sites), "#DDDDDD", None)
    escape_df["shade_alpha"] = 0.75
    escape_df = escape_df.merge(
        summed_df,
        how="left",
        on=["site", "wildtype"],
        validate="many_to_one",
    )

    # Keep only the columns needed for plotting, in a fixed order
    escape_df = escape_df[[
        "site",
        "wildtype",
        "mutant",
        "escape_median",
        "show_site",
        "shade_site",
        "shade_alpha",
        "site_escape",
    ]]

    # Fill in missing sites (built by column name, so CSV column order does not matter)
    observed_sites = set(escape_df["site"])
    missing_sites = [
        site for site in range(1, _GPC_N_SITES + 1) if site not in observed_sites
    ]
    if missing_sites:
        filler_df = pd.DataFrame({
            "site": missing_sites,
            "wildtype": "X",
            "mutant": "X",
            "escape_median": 0.0,
            "show_site": False,
            "shade_site": None,
            "shade_alpha": np.nan,
            "site_escape": 0.0,
        })
        escape_df = pd.concat([escape_df, filler_df], ignore_index=True)

    # Sort by site
    escape_df = (
        escape_df
        .sort_values(by="site")
        .astype({"mutant": "str"})
        .reset_index(drop=True)
    )

    # Prepare functional scores: remove stop codons, then clip effects to [-2, 0].
    # .copy() so the column assignments act on an independent frame.
    func_df = func_df.loc[func_df["mutant"] != "*"].copy()
    func_df["site"] = func_df["site"].astype("int")
    func_df["effect"] = func_df["effect"].clip(upper=0, lower=-2)

    # Merge functional and escape dfs
    escape_df = escape_df.merge(
        func_df,
        how="left",
        on=["site", "wildtype", "mutant"],
        validate="one_to_one",
    )
    # Missing functional values are filled with -2 (the clip floor) to make them less visible
    escape_df["effect"] = escape_df["effect"].fillna(-2)

    # Add color column for logo plots (white -> black by functional effect)
    func_color_map = dmslogo.colorschemes.ValueToColorMap(
        minvalue=func_df["effect"].min(),
        maxvalue=func_df["effect"].max(),
        cmap=matplotlib.colors.ListedColormap(color_gradient_hex("white", "#000000", n=20)),
    )
    escape_df["color"] = escape_df["effect"].map(func_color_map.val_to_color)

    # Add wildtype to each site for logo plot x tick labels
    escape_df["wt_site"] = escape_df["wildtype"] + escape_df["site"].map(str)

    # Look up antibody-specific y limits, ticks and title color
    settings = _ANTIBODY_PLOT_SETTINGS.get(antibody_name)
    if settings is None:
        print("Error! No ylims set!")
        fixed_ymin = fixed_ymax = None
    else:
        display_name, fixed_ymin, fixed_ymax, yticks, title_color = settings

    # Plot escape profile (line plot)
    _, lineplot = dmslogo.draw_line(
        escape_df,
        x_col="site",
        height_col="site_escape",
        show_col="show_site",
        ax=line_plot,
        show_color="#CC6677",
        linewidth=0.5,
        fixed_ymin=fixed_ymin,
        fixed_ymax=fixed_ymax,
    )
    lineplot.set(ylabel=None, xlabel=None)
    lineplot.set_xlim(1, _GPC_N_SITES)
    lineplot.set_xticks([100, 200, 300, 400])
    lineplot.set_xticklabels(
        labels=["100", "200", "300", "400"],
        rotation=90,
        horizontalalignment="center",
        fontsize=6,
    )
    _style_axis(lineplot)

    # Plot logo plot for the shown sites
    _, logoplot = dmslogo.draw_logo(
        escape_df.query("show_site == True"),
        x_col="site",
        letter_col="mutant",
        letter_height_col="escape_median",
        ax=logo_plot,
        xtick_col="wt_site",
        color_col="color",
        shade_color_col="shade_site",
        shade_alpha_col="shade_alpha",
        draw_line_at_zero="never",
        fixed_ymin=fixed_ymin,
        fixed_ymax=fixed_ymax,
    )
    logoplot.set(ylabel=None, xlabel=None)
    logoplot.set_xticklabels(
        labels=logoplot.get_xticklabels(),
        rotation=90,
        horizontalalignment="center",
        fontsize=6,
    )
    _style_axis(logoplot)

    # Set antibody-specific y axis ticks and titles
    if settings is not None:
        ytick_labels = [str(tick) for tick in yticks]
        for ax in (lineplot, logoplot):
            ax.set_yticks(yticks)
            ax.set_yticklabels(labels=ytick_labels, fontsize=6)
        lineplot.set_title(display_name, fontsize=8, color=title_color)
        logoplot.set_title(
            f"{display_name} escape profile for sites different between Josiah and {name}",
            fontsize=8,
            color=title_color,
        )

In [None]:
def get_site_differences(seq1, seq2):
    """
    Return the 1-based positions at which two aligned sequences differ.

    Positions where either sequence has a gap ("-") are skipped, and the
    comparison stops at the end of the shorter sequence.
    """
    return [
        position
        for position, (res1, res2) in enumerate(zip(seq1, seq2), start=1)
        if res1 != res2 and "-" not in (res1, res2)
    ]

# Get list of sequence differences compared to Josiah, as (isolate, sites) pairs
list_of_different_sites = [
    (strain, get_site_differences(josiah_sequence, sequence))
    for strain, sequence in zip(validation_isolates["strain"], validation_isolates["sequence"])
]
# Look up differing sites by isolate name rather than by row position, so the
# panels do not depend on the order of the isolates in the alignment
different_sites_by_isolate = dict(list_of_different_sites)

# One figure row per validated (antibody, natural isolate) pair:
# (escape CSV, contacts CSV, isolate display name).
# 25.6A is omitted; 37.7H appears twice because it was validated with two strains.
escape_panels = [
    (filtered_escape_2510C, contacts_2510C, "GA391"),
    (filtered_escape_121F, contacts_121F, "LM395"),
    (filtered_escape_377H, contacts_377H, "GA391"),
    (filtered_escape_377H, contacts_377H, "IRR007"),
    (filtered_escape_372D, contacts_372D, "GA391"),
    (filtered_escape_89F, contacts_89F, "ISTH1024"),
]

# Per-row file lists, in the same order as the figure rows
filtered_escape_files = [escape_file for escape_file, _, _ in escape_panels]
contacts_files = [contacts_file for _, contacts_file, _ in escape_panels]

# Set figure size and subplots: a narrow line-plot column and a wide logo-plot column
fig, axes = plt.subplots(
    len(escape_panels),
    2,
    gridspec_kw={"width_ratios": [1, 8]},
    figsize=(7, 6.25),
)
# Adjust spacing of subplots
fig.subplots_adjust(
    bottom=0,
    top=1,
    wspace=0.1,
    hspace=0.9,
)

# Draw each row's line plot and logo plot
for (line_ax, logo_ax), (escape_file, contacts_file, isolate) in zip(axes, escape_panels):
    create_line_and_logoplots(
        escape_file,
        contacts_file,
        func_scores,
        min_times_seen,
        n_selections,
        line_ax,
        logo_ax,
        sites=sorted(different_sites_by_isolate[isolate]),
        name=isolate,
    )

# Common X and Y axis labels
fig.text(0.5, -0.1, "site", ha="center", fontsize=8)
fig.text(-0.05, 0.5, "site escape", va="center", rotation="vertical", fontsize=8)

# Make output dir (and any missing parents) if it doesn't exist
os.makedirs(out_dir_escape, exist_ok=True)

# Save fig
fig.savefig(escape_image_path)

## Pseudovirus neutralization validation assays for chosen natural isolates and corresponding single mutants

The chosen isolates, along with the corresponding single mutants identified as the main drivers of escape, were validated with traditional pseudovirus neutralization assays.

In [None]:
# Use "virus" as the isolate column name, matching the neutralization data
validation_isolates = validation_isolates.rename(columns={"strain": "virus"})

# Read neutralization data and map virus names to the same display names
frac_infected_natural_isolates = pd.read_csv(fraction_infected_natural_isolates).assign(
    virus=lambda df: df["virus"].replace(rename_dict)
)

# Fit Hill curves with neutcurve, fixing the bottom at 0 and the top at 1
fits = neutcurve.curvefits.CurveFits(
    data=frac_infected_natural_isolates,
    fixtop=1,
    fixbottom=0,
)

In [None]:
# Color for each virus in the neutralization panels (same virus -> same color)
neut_virus_colors = {
    "unmutated Josiah": tol_muted_adjusted[0],
    "GA391": tol_muted_adjusted[1],
    "E228D Josiah": tol_muted_adjusted[2],
    "LM395": tol_muted_adjusted[3],
    "N89D Josiah": tol_muted_adjusted[4],
    "D401E Josiah": tol_muted_adjusted[5],
    "IRR007": tol_muted_adjusted[6],
    "H398K Josiah": tol_muted_adjusted[7],
    "ISTH1024": tol_muted_adjusted[8],
    "K126N Josiah": tol_muted_adjusted[10],
}

# One row per validation: (antibody, viruses plotted in that row)
neut_panels = [
    ("25.10C", ["unmutated Josiah", "GA391", "E228D Josiah"]),
    ("12.1F", ["unmutated Josiah", "LM395", "N89D Josiah"]),
    ("37.7H", ["unmutated Josiah", "GA391", "H398K Josiah"]),
    ("37.7H", ["unmutated Josiah", "IRR007", "D401E Josiah"]),
    ("37.2D", ["unmutated Josiah", "GA391", "H398K Josiah"]),
    ("8.9F", ["unmutated Josiah", "ISTH1024", "K126N Josiah"]),
]

# Title color per antibody
antibody_title_colors = {
    "25.10C": "#44AA99",
    "12.1F": "#999933",
    "37.7H": "#AA4499",
    "37.2D": "#AA4499",
    "8.9F": "#117733",
}

# Plot neut data: one (row, 0) panel per validation, using replicate averages
fig, axes = fits.plotGrid(
    {
        (row, 0): (
            antibody,
            [
                {
                    "serum": antibody,
                    "virus": virus,
                    "replicate": "average",
                    "color": neut_virus_colors[virus],
                    "marker": "o",
                    "label": virus,
                }
                for virus in viruses
            ],
        )
        for row, (antibody, viruses) in enumerate(neut_panels)
    },
    sharex=True,
    sharey=False,
    xlabel="",
    ylabel="",
    attempt_shared_legend=False,
)

antibody_names = [antibody for antibody, _ in neut_panels]

for index, antibody in enumerate(antibody_names):
    ax = axes[index, 0]
    title_color = antibody_title_colors.get(antibody)
    if title_color is not None:
        ax.set_title(antibody, weight="bold", fontsize=8, color=title_color)
    ax.set_ylim(-0.1, 1.3)
    ax.set_yticks([0, 0.5, 1.0])
    ax.set_yticklabels(labels=[0, 0.5, 1.0], fontsize=8)
    ax.set_xlim(0.0005, 12.5)
    ax.set_xticks([0.001, 0.01, 0.1, 1, 10])
    ax.set_xticklabels(
        labels=["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$", "$10^1$"],
        fontsize=8,
    )
    plt.setp(ax.collections, alpha=0.8, linewidths=0.5, colors="black")  # for vertical error bar segment
    plt.setp(ax.lines, alpha=0.8, markeredgewidth=0.5, markeredgecolor="black", linewidth=1)  # for the lines and markers
    sns.move_legend(
        ax,
        loc="upper left",
        borderaxespad=0,
        frameon=False,
        bbox_to_anchor=(1, 1),
        fontsize=8,
        markerscale=1,
        handletextpad=0.1,
        title="Lassa GPC",
        title_fontproperties={"weight": "bold", "size": 8},
        alignment="left",
    )

    # Add edges to legend markers to match scatter plot.
    # matplotlib 3.7 renamed `legendHandles` to `legend_handles` (old name removed in 3.9).
    legend = ax.get_legend()
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for handle in legend_handles:
        handle.set_markeredgecolor("black")
        handle.set_markeredgewidth(0.5)
        handle.set_linewidth(0)

    # Change all spines
    for spine in ["top", "bottom", "left", "right"]:
        ax.spines[spine].set_linewidth(1)
    ax.tick_params(axis="both", length=4, width=1)


# Common X and Y axis labels
fig.text(0.5, -0.1, "concentration (\u03BCg/mL)", ha="center", fontsize=8)
fig.text(-0.05, 0.5, "fraction infectivity", va="center", rotation="vertical", fontsize=8)

width = 2
height = 7
fig.set_size_inches(width, height)

# Make output dir (and any missing parents) if it doesn't exist
os.makedirs(out_dir, exist_ok=True)

# Save fig
fig.savefig(neuts_image_path)

We also calculate a **cumulative (summed)** escape score across all mutations present in each isolate, including mutations with small effects. However, this simple additive estimate probably does not capture the true combined impact of the mutations present in each isolate. These total DMS escape scores are then compared with the measured log2 fold change in IC50.

In [None]:
# Fit neutralization curves; keep each curve's IC50 and neutcurve's flag for
# whether that IC50 is interpolated or a bound
fitParams = fits.fitParams(ics=[50, 80, 99])[["serum", "virus", "ic50", "ic50_bound"]]


def _unmutated_josiah_ic50(antibody):
    """Return the IC50 from the first fit of `antibody` against unmutated Josiah."""
    josiah_fits = fitParams.query(
        f"serum == '{antibody}' & virus == 'unmutated Josiah'"
    ).reset_index(drop=True)
    return josiah_fits.at[0, "ic50"]


# Reference (unmutated Josiah) IC50 for each neutralization-tested antibody
wt_ic50s = {
    antibody: _unmutated_josiah_ic50(antibody)
    for antibody in ("8.9F", "12.1F", "25.10C", "37.7H", "37.2D")
}

# Fold change of every IC50 relative to the same antibody's Josiah IC50;
# an antibody absent from wt_ic50s raises KeyError
reference_ic50 = fitParams["serum"].map(lambda antibody: wt_ic50s[antibody])
fitParams["ic50 fold change"] = fitParams["ic50"] / reference_ic50
fitParams["log2 ic50 fold change"] = np.log2(fitParams["ic50 fold change"])


# Attach isolate sequences to each neutralization fit; validate="many_to_one"
# enforces at most one isolate row per virus name
validation_df = (
    fitParams.merge(
        validation_isolates,
        how="left",
        on="virus",
        validate="many_to_one"
    )
)


def _josiah_point_mutant(mutation):
    """Return ``josiah_sequence`` carrying one substitution such as ``"N89D"``.

    The site number is 1-based into ``josiah_sequence``. Raises ``ValueError``
    when the wildtype letter in ``mutation`` differs from the Josiah residue at
    that site, which would indicate a numbering mismatch.
    """
    wildtype, site, mutant = mutation[0], int(mutation[1:-1]), mutation[-1]
    if josiah_sequence[site - 1] != wildtype:
        raise ValueError(
            f"{mutation}: Josiah has {josiah_sequence[site - 1]} at site {site}, "
            f"not {wildtype}"
        )
    return josiah_sequence[:site - 1] + mutant + josiah_sequence[site:]


# Add single mutant sequences. Rows for Josiah point mutants (e.g.
# "N89D Josiah") are located by name rather than by hard-coded row position,
# so a change in fit order cannot assign a sequence to the wrong virus.
is_josiah_point_mutant = validation_df["virus"].str.fullmatch(
    r"[A-Z]\d+[A-Z] Josiah", na=False
)
for row, virus in validation_df.loc[is_josiah_point_mutant, "virus"].items():
    validation_df.at[row, "sequence"] = _josiah_point_mutant(virus.split()[0])

# Map each antibody name to its filtered per-mutation escape CSV
antibody_file_dict = {
    "25.10C" : filtered_escape_2510C,
    "12.1F" : filtered_escape_121F,
    "37.7H" : filtered_escape_377H,
    "25.6A" : filtered_escape_256A,
    "37.2D" : filtered_escape_372D,
    "8.9F" : filtered_escape_89F,
}

# Calculate total escape for each antibody/virus pair using an additive model
validation_df["total_escape"] = validation_df.apply(
    lambda row: determine_escape(
        0,
        row["sequence"],
        antibody_file_dict[row["serum"]],
        row["virus"],
    ),
    axis=1,
)

# Rename columns to their display names
validation_df = validation_df.rename(columns={
    "ic50_bound" : "lower bound",
    "serum" : "antibody",
    "virus" : "Lassa GPC"
})

# True only where neutcurve flagged the IC50 as a lower bound ("lower").
# neutcurve can also report "upper", which a replace() map of only
# "interpolated"/"lower" would leave as a string inside a boolean column.
validation_df["lower bound"] = validation_df["lower bound"].eq("lower")

# Pearson correlation between measured log2 IC50 fold change and the
# additive DMS escape prediction
r, p = sp.stats.pearsonr(
    x=validation_df["log2 ic50 fold change"],
    y=validation_df["total_escape"]
)
print(f"R={r}")
print(f"R^2={r**2}")

# Plot predicted vs measured
fig, ax = plt.subplots(figsize=(2.25, 2.25))
corr_chart = sns.scatterplot(
    data=validation_df,
    x="log2 ic50 fold change",
    y="total_escape",
    hue="Lassa GPC",
    palette={
        "unmutated Josiah" : tol_muted_adjusted[0],
        "GA391" : tol_muted_adjusted[1],
        "LM395" : tol_muted_adjusted[3],
        "IRR007" : tol_muted_adjusted[6],
        "ISTH1024" : tol_muted_adjusted[8],
        "N89D Josiah" : tol_muted_adjusted[4],
        "K126N Josiah" : tol_muted_adjusted[10],
        "E228D Josiah": tol_muted_adjusted[2],
        "H398K Josiah" : tol_muted_adjusted[7],
        "D401E Josiah" : tol_muted_adjusted[5],
    },
    style="antibody",
    markers=["o", "s", "D", "v", "^", "P"],
    alpha=0.8,
    edgecolor="black",
    ax=ax,
    s=40,
)
# Change all spines
for axis in ["top", "bottom", "left", "right"]:
    corr_chart.spines[axis].set_linewidth(1)
sns.despine()
corr_chart.tick_params(axis="both", length=4, width=1)
corr_chart.set_xlabel(
    "log2 fold change IC$_{50}$ measured\nby pseudovirus neutralization assay",
    fontsize=8,
)
corr_chart.set_ylabel(
    "escape score predicted\nby DMS (arbitrary units)",
    fontsize=8,
)
corr_chart.set_xlim(-1, 12.5)
corr_chart.set_ylim(-0.5, 5.5)
corr_chart.set_xticks([0, 2, 4, 6, 8, 10, 12])
corr_chart.set_xticklabels(corr_chart.get_xticks(), size=8)
corr_chart.set_yticks([0, 1, 2, 3, 4, 5])
corr_chart.set_yticklabels(corr_chart.get_yticks(), size=8)
sns.move_legend(
    corr_chart,
    "upper left",
    bbox_to_anchor=(1, 1),
    fontsize=8,
    markerscale=1,
    handletextpad=0.1,
    frameon=False,
    borderaxespad=0.75,
)

# Bold the hue and style sub-title entries of the legend. They are matched by
# label rather than position so the number of viruses present does not matter.
corr_legend = corr_chart.get_legend()
for legend_text in corr_legend.get_texts():
    if legend_text.get_text() in {"Lassa GPC", "antibody"}:
        legend_text.set_weight("bold")

corr_chart.text(
    0,
    5,
    f"r={r:.2f}",
    horizontalalignment="left",
    weight="bold",
    fontsize=8,
)

# Add edges to legend markers to match scatter plot.
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7 and
# the old name was removed in 3.9, so prefer the new name when present.
corr_handles = getattr(corr_legend, "legend_handles", None)
if corr_handles is None:
    corr_handles = corr_legend.legendHandles
for ha in corr_handles:
    ha.set_edgecolor("black")
    ha.set_linewidths(0.5)

# Set square ratio
corr_chart.set_box_aspect(1)

# Make output dir (and any missing parents) if it doesn't exist
os.makedirs(out_dir, exist_ok=True)

# Save fig; bbox_inches="tight" keeps the legend, which is anchored outside
# the axes, from being cropped
plt.savefig(corr_image_path, bbox_inches="tight")

## Correlation of antibody escape to functional scores

Next, we compare the DMS data to natural sequence diversity. Natural sequence diversity is quantified as the number of effective amino acids at each site, calculated from all available high-quality Lassa GPC sequences. The formula for the number of effective amino acids is described in *Biophysical Models of Protein Evolution: Understanding the Patterns of Evolutionary Sequence Divergence*.

In [None]:
# Load per-mutation effects on cell entry
functional_scores = pd.read_csv(func_scores)

# Keep mutations passing the selection-count and times-seen thresholds,
# drop stop codons, then average the remaining mutation effects per site
functional_scores = (
    functional_scores.query(
        "n_selections >= @n_selections and times_seen >= @min_times_seen and mutant != '*'"
    )
    .drop(columns=["mutant", "times_seen"])
    .groupby(["site", "wildtype"])
    .aggregate({
        "effect" : "mean"
    })
    .reset_index()
)

# Load natural sequence variation
natural_variation = pd.read_csv(natural_sequence_variation)

# Drop individual amino acid counts, keeping per-site diversity summaries
natural_variation = natural_variation[["site", "entropy", "n_effective"]]

# Merge functional and natural dataframes (one row per site)
merged_df = (
    functional_scores.merge(
        natural_variation,
        how="left",
        on=["site"],
        validate="one_to_one",
    )
)

# Add summed per-site escape for each antibody as an "escape_<name>" column,
# where <name> is the file-name prefix before the first underscore
for antibody_file,_ in antibody_files:

    antibody_name = antibody_file.split("/")[-1].split("_")[0]

    # Load data as dataframe
    escape_df = pd.read_csv(antibody_file)

    # Clip negative escape scores to 0 so they do not offset positive ones
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Sum escape over all mutations at each site
    escape_df = (
        escape_df
        .groupby("site")
        .aggregate({"escape_median" : "sum"})
        .reset_index()
    )

    # Rename escape column to include antibody name
    escape_df = escape_df.rename(columns={"escape_median" : "escape_" + antibody_name})

    # Merge dataframes
    merged_df = (
        merged_df.merge(
            escape_df,
            how="left",
            on=["site"],
            validate="one_to_one",
        )
    )

# Per-site summary escape: the MEDIAN across all mapped antibodies of each
# antibody's summed site escape (NaN for an antibody is skipped by median)
merged_df["total_escape"] = (
    merged_df[[
        "escape_2510C",
        "escape_121F",
        "escape_377H",
        "escape_256A",
        "escape_372D",
        "escape_89F"
    ]].median(axis=1)
)

# Flag sites of mutations for chosen validations
validation_sites = [89, 126, 228, 398, 401]
merged_df["site of validation"] = merged_df["site"].isin(validation_sites)

First, we look at the correlation between functional effects and antibody escape, stratified by antibody. As expected, the mutations that lead to escape tend to be more functionally tolerated. Furthermore, the mutations present in the natural isolates also tend to be functionally tolerated, with the notable exception of N89D, which is quite deleterious.


In [None]:
# Site colors: orange for validation sites, faint black for all others
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]

# One scatter of per-site escape vs. mean effect on cell entry per antibody
subplots = []
for antibody_file, _ in antibody_files:
    antibody_name = antibody_file.split("/")[-1].split("_")[0]

    curr_subplot = alt.Chart(merged_df, title=antibody_name).mark_point(
        filled=True,
        color="black",
        size=75
    ).encode(
        alt.X(
            "escape_" + antibody_name,
            axis=alt.Axis(
                title="total site escape",
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domainMin=-1)
        ),
        alt.Y(
            "effect",
            axis=alt.Axis(
                title="effect on cell entry",
                values=[-4,-3,-2,-1,0,1],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-4.5,1])
        ),
        tooltip=[
            "site",
            "wildtype",
            "effect",
            "escape_" + antibody_name,
        ],
        color=alt.Color(
            "site of validation",
            scale=alt.Scale(domain=dom, range=rng),
        ),
    ).properties(
        width=150,
        height=150,
    )

    subplots.append(curr_subplot)

# Concatenate every antibody panel rather than a fixed count of six, so that
# adding or removing antibody files neither drops panels nor raises IndexError
func_vs_antibody = alt.hconcat(
    *subplots,
    spacing=5,
    title="Correlations of functional effects and antibody escape",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

func_vs_antibody

## Correlation of antibody escape to natural sequence diversity

Next, we look at the correlation of natural sequence diversity and antibody escape stratified by the different antibodies. The antibody 12.1F (and maybe 37.2D to a lesser extent) tends to be weakly escaped by mutations at sites with higher natural diversity. Of the chosen validations, site 126 tends to be most diverse but sites 228 and 398 also have increased diversity.

In [None]:
# Site colors: orange for validation sites, faint black for all others
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]

# One scatter of per-site escape vs. natural diversity (n_effective) per antibody
subplots = []
for antibody_file, _ in antibody_files:
    antibody_name = antibody_file.split("/")[-1].split("_")[0]

    curr_subplot = alt.Chart(merged_df, title=antibody_name).mark_point(
        filled=True,
        color="black",
        size=75
    ).encode(
        alt.X(
            "escape_" + antibody_name,
            axis=alt.Axis(
                title="total site escape",
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domainMin=-1)
        ),
        alt.Y(
            "n_effective",
            axis=alt.Axis(
                title=["effective amino acids", "in natural sequences"],
                values=[1,2,3,4],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[0.9,4])
        ),
        tooltip=[
            "site",
            "wildtype",
            "n_effective",
            "escape_" + antibody_name,
        ],
        color=alt.Color(
            "site of validation",
            scale=alt.Scale(domain=dom, range=rng),
        ),
    ).properties(
        width=150,
        height=150,
    )

    subplots.append(curr_subplot)

# Concatenate every antibody panel rather than a fixed count of six.
# (The variable name is reused from the functional-effect cell; this chart
# shows natural sequence diversity.)
func_vs_antibody = alt.hconcat(
    *subplots,
    spacing=5,
    title="Correlations of natural sequence diversity and antibody escape",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

func_vs_antibody

## Correlation of summary metrics across all mapped antibodies

Finally, we look at summary correlations between antibody escape, functional effects on cell entry, and natural sequence diversity. For each antibody, escape is first summed over all mutations at each site. The summary antibody escape for a site is then the median of these per-antibody site escape values across all mapped antibodies.

In [None]:
# Site colors: orange for validation sites, faint black for all others
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]

# Per-antibody escape columns listed in every summary tooltip
_summary_escape_fields = [
    "escape_2510C",
    "escape_121F",
    "escape_89F",
    "escape_377H",
    "escape_372D",
    "escape_256A",
]


def _paired_pearsonr(first, second):
    """Run scipy's pearsonr on two merged_df columns, keeping rows where both are present."""
    both = merged_df[[first, second]].dropna()
    return sp.stats.pearsonr(both[first], both[second])


def _styled_axis(title, values):
    """Build an axis with the given title and ticks and a black 1px domain line."""
    return alt.Axis(
        title=title,
        values=values,
        domainWidth=1,
        domainColor="black",
        tickColor="black",
    )


def _total_escape_channel(channel):
    """Encode median site escape across antibodies on `channel` (alt.X or alt.Y)."""
    return channel(
        "total_escape",
        axis=_styled_axis(
            ["median site escape across", "all mapped antibodies"],
            [0,1,2,3,4,5],
        ),
        scale=alt.Scale(domain=[-0.1, 5.1]),
    )


def _effect_channel(channel):
    """Encode mean effect on cell entry on `channel` (alt.X or alt.Y)."""
    return channel(
        "effect",
        axis=_styled_axis("effect on cell entry", [-4,-3,-2,-1,0,1]),
        scale=alt.Scale(domain=[-4.5,1]),
    )


def _n_effective_channel(channel):
    """Encode effective amino acids in natural sequences on `channel`."""
    return channel(
        "n_effective",
        axis=_styled_axis(["effective amino acids", "in natural sequences"], [1,2,3,4]),
        scale=alt.Scale(domain=[0.9,4]),
    )


def _summary_scatter(corr, first_channel, second_channel, leading_fields):
    """Build a 300x300 site scatter titled with `corr`, colored by validation site."""
    return alt.Chart(merged_df, title=f"r = {corr:.2f}").mark_point(
        filled=True,
        color="black",
        size=75,
    ).encode(
        first_channel,
        second_channel,
        tooltip=["site", "wildtype", *leading_fields, *_summary_escape_fields, "total_escape"],
        color=alt.Color(
            "site of validation",
            scale=alt.Scale(domain=dom, range=rng),
        ),
    ).properties(
        width=300,
        height=300,
    )


# Escape vs. effect on cell entry
r, p = _paired_pearsonr("total_escape", "effect")
print(f"r correlation of total escape and effect on cell entry: {r:.2f}")
effect_vs_antibody = _summary_scatter(
    r, _total_escape_channel(alt.X), _effect_channel(alt.Y), ["effect"]
)

# Escape vs. natural sequence diversity
r, p = _paired_pearsonr("total_escape", "n_effective")
print(f"r correlation of total escape and natural sequence diversity: {r:.2f}")
natural_vs_antibody = _summary_scatter(
    r, _total_escape_channel(alt.Y), _n_effective_channel(alt.X), ["n_effective"]
)

# Natural sequence diversity vs. effect on cell entry
r, p = _paired_pearsonr("effect", "n_effective")
print(f"r correlation of natural sequence diversity and effect on cell entry: {r:.2f}")
natural_vs_func = _summary_scatter(
    r, _n_effective_channel(alt.X), _effect_channel(alt.Y), ["effect", "n_effective"]
)

summary_corr_plot = alt.hconcat(
    effect_vs_antibody,
    natural_vs_antibody,
    natural_vs_func,
    spacing=5,
    title="Correlations of summary statistics",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
).configure_legend(
    labelFontSize=16,
    titleFontSize=16,
)

summary_corr_plot