# Compare DMS data to natural sequence data

This notebook searches natural sequences for antibody escape mutations predicted by DMS and shows the resulting neutralization validations. Additional analysis correlates natural sequence diversity with the DMS data.

In [None]:
# Imports
import os
import warnings
import dmslogo
import neutcurve
import numpy as np
import scipy as sp
import pandas as pd
import altair as alt
import seaborn as sns
import matplotlib.colors
from Bio import SeqIO, AlignIO 
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker

# Plotting colors. Order matters: later cells index into this list
# (e.g. [0], [10], [12]) and it is also set as the seaborn palette.
tol_muted_adjusted = (
    "#000000 #CC6677 #1f78b4 #DDCC77 #117733 #882255 #88CCEE "
    "#44AA99 #999933 #AA4499 #EE7733 #CC3311 #DDDDDD"
).split()

# Create color palette
def color_gradient_hex(start, end, n):
    """Return ``n`` hex color strings forming a linear gradient from ``start`` to ``end``.

    Adapted from the color helper in ``polyclonal``: builds an ``n``-entry
    matplotlib colormap between the two colors and evaluates it at the integer
    positions 0..n-1.
    """
    gradient = matplotlib.colors.LinearSegmentedColormap.from_list(
        name="_", colors=[start, end], N=n
    )
    rgba_rows = gradient(list(range(n)))
    return [matplotlib.colors.rgb2hex(rgba) for rgba in rgba_rows]

# Seaborn style settings: 300 dpi for display and saved figures, and
# svg.fonttype "none" for SVG output
seaborn_rc_overrides = {
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "svg.fonttype": "none",
}
sns.set(rc=seaborn_rc_overrides)
sns.set_style("ticks")
sns.set_palette(tol_muted_adjusted)

# Suppress all warnings for the rest of the session
warnings.simplefilter("ignore")

# Allow Altair charts with more rows than its default limit
_ = alt.data_transformers.disable_max_rows()

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# All values default to None and are expected to be injected by papermill
# (see the commented-out cell below for example values when running interactively).

# Paths to per-antibody filtered escape CSVs (read by determine_escape and
# create_line_and_logoplots)
filtered_escape_377H = None
filtered_escape_89F = None
filtered_escape_2510C = None
filtered_escape_121F = None
filtered_escape_256A = None
filtered_escape_372D = None

# Paths to per-antibody contact CSVs (used to shade contact sites in logo plots)
contacts_89F = None
contacts_377H = None
contacts_256A = None
contacts_2510C = None
contacts_121F = None
contacts_372D = None

# Functional-effects CSV path plus settings passed through to
# create_line_and_logoplots
func_scores = None
min_times_seen = None
n_selections = None

# Natural-sequence inputs: GPC protein variation table and the FASTA protein
# alignment read with AlignIO
natural_sequence_variation = None
natural_GPC_sequence_alignment = None

# Neutralization (fraction infected) CSV and validation-plot output paths
fraction_infected_natural_isolates = None
out_dir = None
neuts_image_path = None
corr_image_path = None

# Output directory and SVG paths for the escape logo-plot figures
out_dir_escape = None
escape_top10_image_path = None
escape_all_image_path = None

# Output directory and SVG paths for natural-isolate escape figures
out_dir_natural = None
natural_escape = None
total_natural_escape = None

# Output paths for interactive HTML (Altair) charts
html_func_vs_natural = None
html_nat_mut_freqs_vs_escape = None
html_nat_mut_freqs_vs_escape_all_abs = None
html_natural_vs_escape = None
html_natural_vs_escape_all_abs = None
html_arevirumab_comparisons = None

In [None]:
# # Uncomment for running interactive
# filtered_escape_377H = "../results/filtered_antibody_escape_CSVs/377H_filtered_mut_effect.csv"
# filtered_escape_89F = "../results/filtered_antibody_escape_CSVs/89F_filtered_mut_effect.csv"
# filtered_escape_2510C = "../results/filtered_antibody_escape_CSVs/2510C_filtered_mut_effect.csv"
# filtered_escape_121F = "../results/filtered_antibody_escape_CSVs/121F_filtered_mut_effect.csv"
# filtered_escape_256A = "../results/filtered_antibody_escape_CSVs/256A_filtered_mut_effect.csv"
# filtered_escape_372D = "../results/filtered_antibody_escape_CSVs/372D_filtered_mut_effect.csv"

# contacts_89F = "../data/antibody_contacts/antibody_contacts_89F.csv"
# contacts_377H = "../data/antibody_contacts/antibody_contacts_377H.csv"
# contacts_256A = "../data/antibody_contacts/antibody_contacts_256A.csv"
# contacts_2510C = "../data/antibody_contacts/antibody_contacts_2510C.csv"
# contacts_121F = "../data/antibody_contacts/antibody_contacts_121F.csv"
# contacts_372D = "../data/antibody_contacts/antibody_contacts_372D.csv"

# func_scores = "../results/func_effects/averages/293T_entry_func_effects.csv"
# min_times_seen = 2
# n_selections = 8

# natural_sequence_variation = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/GPC_protein_variation.csv"
# natural_GPC_sequence_alignment = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/LASV_GPC_protein_alignment.fasta"

# fraction_infected_natural_isolates = "../data/validation_frac_infected_natural_isolates.csv"
# out_dir = "../results/validation_plots/"
# neuts_image_path = "../results/validation_plots/validation_neut_curves_natural_isolates.svg"
# corr_image_path = "../results/validation_plots/natural_isolate_validation_correlation.svg"

# out_dir_escape = "../results/antibody_escape_profiles/"
# escape_top10_image_path = "../results/antibody_escape_profiles/natural_isolate_top10_escape_profiles.svg"
# escape_all_image_path = "../results/antibody_escape_profiles/natural_isolate_all_escape_profiles.svg"

# out_dir_natural = "../results/natural_isolate_escape/"
# natural_escape = "../results/natural_isolate_escape/natural_isolate_escape.svg"
# total_natural_escape = "../results/natural_isolate_escape/total_natural_site_escape.svg"

# html_func_vs_natural = "../results/natural_isolate_escape/func_vs_natural.html"
# html_nat_mut_freqs_vs_escape = "../results/natural_isolate_escape/nat_mut_freqs_vs_escape.html"
# html_nat_mut_freqs_vs_escape_all_abs = "../results/natural_isolate_escape/nat_mut_freqs_vs_escape_all_abs.html"
# html_natural_vs_escape = "../results/natural_isolate_escape/natural_vs_escape.html"
# html_natural_vs_escape_all_abs = "../results/natural_isolate_escape/natural_vs_escape_all_abs.html"
# html_arevirumab_comparisons = "../results/antibody_escape_profiles/arevirumab_comparisons.html"

## Identify natural sequences that would potentially escape antibody neutralization

To identify natural isolates that could potentially escape any of the mapped antibodies, we first identify high-confidence escape mutations using the following filters:
1. First, keep the top 10% of escape mutations for each antibody selection.
2. Second, further filter this list by keeping only mutations at sites of **strong escape** (i.e., sites whose summed escape is at least 10-fold greater than the median summed escape across all sites).
3. Finally, search all high-quality sequences for any of these high-confidence escape mutations.

Escape scores are clipped at a lower bound of 0 to focus on escape mutations rather than sensitizing mutations, and mutations flagged for poor cell entry are excluded. Any sequence carrying one of the resulting mutations is flagged as a potential escape isolate. This conservative approach will exclude many true escape mutations, but the mutations that remain should be higher-confidence.

In [None]:
def determine_escape(percentile_escape, sequence, escape_file, strain, print_results=False, strong_escape_site=False):
    """
    Sum the DMS escape of every high-confidence escape mutation present in `sequence`.

    A mutation is high-confidence when it:
      1. is not flagged `poor_cell_entry`,
      2. has a clipped (>= 0) `escape_median` at or above the
         `percentile_escape` quantile of all such mutations, and
      3. lies in a strong escape site. With `strong_escape_site=True`, a strong
         site's summed clipped escape is >= 10x the median summed site escape.
         Otherwise it is >= the `percentile_escape` quantile of summed site escape.

    Parameters
    ----------
    percentile_escape : float
        Quantile in [0, 1] used as the mutation-level (and optionally
        site-level) cutoff.
    sequence : str
        Sequence where `sequence[site - 1]` is the residue at DMS `site`
        (assumes it is aligned to the DMS numbering -- TODO confirm).
    escape_file : str
        CSV with `mutation` (e.g. "A123C"), `site`, `escape_median`, and
        `poor_cell_entry` columns.
    strain : str
        Label used only when printing.
    print_results : bool
        If True, print each matching escape mutation.
    strong_escape_site : bool
        Selects the site-level cutoff (see above).

    Returns
    -------
    Summed `escape_median` of matching mutations (0 if none match).
    """
    # Drop poor-cell-entry mutations and clip sensitizing (negative) scores to 0.
    # `.assign` creates the clipped column on a new frame instead of writing into
    # the result of a filter.
    escape_df = (
        pd.read_csv(escape_file)
        .query("poor_cell_entry == False")
        .assign(escape_median=lambda df: df["escape_median"].clip(lower=0))
    )

    # Mutation-level cutoff: keep mutations at or above the requested quantile
    mut_cutoff = escape_df["escape_median"].quantile(percentile_escape)
    is_top_mut = escape_df["escape_median"] >= mut_cutoff
    top_mutations = escape_df.loc[is_top_mut, "mutation"].tolist()
    top_mut_escape = escape_df.loc[is_top_mut, "escape_median"].tolist()

    # Site-level cutoff on summed escape per site
    site_escape = escape_df.groupby("site")["escape_median"].sum()
    if strong_escape_site:
        site_cutoff = site_escape.median() * 10
    else:
        site_cutoff = site_escape.quantile(percentile_escape)
    top_escape_sites = set(site_escape.index[site_escape >= site_cutoff].tolist())

    escape = 0
    for mutation, mut_escape in zip(top_mutations, top_mut_escape):
        # Mutation strings are wildtype + site + mutant, e.g. "A123C"
        site = int(mutation[1:-1])
        if site not in top_escape_sites:
            continue
        if sequence[site - 1] == mutation[-1]:
            if print_results:
                print(f"{strain:<75} with \t {mutation[0]}{site}{mutation[-1]} \t DMS score: {mut_escape}")
            escape += mut_escape

    return escape

# Load every aligned natural GPC sequence into a (strain, sequence) table.
# Build it in one shot from records; growing a DataFrame row by row is quadratic.
natural_seqs_df = pd.DataFrame(
    [
        (str(record.id), str(record.seq))
        for record in AlignIO.read(natural_GPC_sequence_alignment, "fasta")
    ],
    columns=["strain", "sequence"],
)

# Antibody escape dataframes and percentile cutoffs to use
antibody_files = [
    (filtered_escape_2510C, 0.9),
    (filtered_escape_121F, 0.9), 
    (filtered_escape_377H, 0.9),
    (filtered_escape_256A, 0.9),
    (filtered_escape_372D, 0.9),
    (filtered_escape_89F, 0.9), 
]

# Print, for each antibody, every natural sequence carrying a high-confidence
# escape mutation (determine_escape does the printing; its sums aren't needed here)
for antibody, percentile in antibody_files:
    antibody_name = antibody.split("/")[-1].split("_")[0]
    print(f"{antibody_name} potentially escaped by:")
    for strain, sequence in zip(natural_seqs_df["strain"], natural_seqs_df["sequence"]):
        determine_escape(
            percentile,
            sequence,
            antibody,
            strain,
            print_results=True,
            strong_escape_site=True,
        )
    print()

The following were chosen for validation:
- 8.9F
    - Natural isolate: Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14
        - Corresponding single mutant: K126N
- 12.1F
    - Natural isolate: LM395-SLE-2009_KM822115_2009-XX-XX
        - Corresponding single mutant: N89D
- 25.10C
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX 
        - Corresponding single mutant: E228D
- 37.7H
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX 
        - Corresponding single mutant: H398K
    - Natural isolate: LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18
        - Corresponding single mutant: D401E 
- 37.2D
    - Natural isolate: GA391_OL774861_reverse_complement_1977-XX-XX
        - Corresponding single mutant: H398K

*The D401E mutation present in the chosen natural isolates was not a high-confidence escape mutation in the DMS data for the 25.6A antibody selection, so 25.6A is omitted from the validation neutralization assays. In addition, it has been noted that antibodies 37.7H and 25.6A are very similar, so both would be expected to give similar results.

In [None]:
# Isolates chosen to validate for escape; Josiah is the DMS reference strain
josiah_strain = "Josiah_NC_004296_reverse_complement_2018-08-13"
chosen_isolates = [
    josiah_strain,
    "GA391_OL774861_reverse_complement_1977-XX-XX",
    "LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18",
    "Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14",
    "LM395-SLE-2009_KM822115_2009-XX-XX",
]

# Create subset of df for chosen isolates (rows keep alignment order)
validation_isolates = (
    natural_seqs_df.loc[natural_seqs_df["strain"].isin(chosen_isolates)]
    .reset_index(drop=True)
)

# Get Josiah sequence for comparison. Select positionally: after reset_index,
# Josiah only has label 0 if it happens to be first in the alignment.
josiah_matches = validation_isolates.loc[validation_isolates["strain"] == josiah_strain, "sequence"]
assert not josiah_matches.empty, f"{josiah_strain} not found in {natural_GPC_sequence_alignment}"
josiah_sequence = josiah_matches.iloc[0]

# Rename dictionary for virus names
rename_dict = {
    "WT" : "unmutated Josiah strain",
    "OL774861" : "GA391 strain",
    "KM822115" : "LM395 strain",
    "MK107922" : "IRR007 strain",
    "MH157037" : "ISTH1024 strain",
    "Josiah_NC_004296_reverse_complement_2018-08-13" : "unmutated Josiah strain",
    "GA391_OL774861_reverse_complement_1977-XX-XX" : "GA391 strain",
    "LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18" : "IRR007 strain",
    "Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14" : "ISTH1024 strain",
    "LM395-SLE-2009_KM822115_2009-XX-XX" : "LM395 strain",
    "N89D" : "N89D Josiah strain",
    "K126N" : "K126N Josiah strain",
    "E228D" : "E228D Josiah strain",
    "H398K" : "H398K Josiah strain",
    "D401E" : "D401E Josiah strain",
}

# Rename viruses to display names
validation_isolates["strain"] = validation_isolates["strain"].replace(rename_dict)

# Count non-Josiah residues (and the specific validated mutant) at each
# validation site, plus per-position variants vs Josiah across the alignment
site_228_count = 0
site_89_count = 0
site_398_count = 0
site_401_count = 0
site_126_count = 0
site_228D_count = 0
site_89D_count = 0
site_398K_count = 0
site_401E_count = 0
site_126N_count = 0
# Positions follow the alignment columns, whose length matches the Josiah sequence
n_alignment_sites = len(josiah_sequence)
all_position_variants = [0] * n_alignment_sites
all_position_variants_list = [[] for _ in range(n_alignment_sites)]
for seq in natural_seqs_df["sequence"]:
    if seq[227] != "E":
        site_228_count += 1
        if seq[227] == "D":
            site_228D_count += 1
    if seq[88] != "N":
        site_89_count += 1
        if seq[88] == "D":
            site_89D_count += 1
    if seq[397] != "H":
        site_398_count += 1
        if seq[397] == "K":
            site_398K_count += 1
    if seq[400] != "D":
        site_401_count += 1
        if seq[400] == "E":
            site_401E_count += 1
    if seq[125] != "K":
        site_126_count += 1
        if seq[125] == "N":
            site_126N_count += 1
    # Gaps are not counted as variants
    for i, (josiah_aa, seq_aa) in enumerate(zip(josiah_sequence, seq)):
        if seq_aa != josiah_aa and seq_aa != "-":
            all_position_variants[i] += 1
            all_position_variants_list[i].append(f"{josiah_aa}{i + 1}{seq_aa}")

# Print counts of each variant for chosen validations
print(f"Total number of sequences analyzed: {len(natural_seqs_df)}")
print(f"Total number of non-Josiah strain variants at position 228: {site_228_count} (E228D variants: {site_228D_count})")
print(f"Total number of non-Josiah strain variants at position 89: {site_89_count} (N89D variants: {site_89D_count})")
print(f"Total number of non-Josiah strain variants at position 398: {site_398_count} (H398K variants: {site_398K_count})")
print(f"Total number of non-Josiah strain variants at position 401: {site_401_count} (D401E variants: {site_401E_count})")
print(f"Total number of non-Josiah strain variants at position 126: {site_126_count} (K126N variants: {site_126N_count})")
print()

# Print counts of variants for top escape sites for Arevirumab-3
arevirumab_top_escape_sites = [
    92, 111, 117, 119, 120, 121, 123, 124, 125, 127, 129, 133, 134,
    135, 138, 147, 148, 149, 150, 153, 160, 248, 253, 254, 395,
]
print("Variants at top escape sites for Arevirumab-3 antibodies:")
for i in arevirumab_top_escape_sites:
    print(f"Number variants at {i}: {all_position_variants[i-1]} {all_position_variants_list[i-1]}")

## Escape profiles for the sites that differ in the chosen isolates from the Josiah DMS strain

Create escape logo plots for all sites that differ between each isolate and the Josiah strain, as well as for the 10 sites with the highest total escape among those differing sites.

In [None]:
def _logo_plot_style(antibody_name, n_candidate_sites):
    """Return per-antibody logo-plot styling, or None if the antibody is not recognized.

    The tuple is (ymin, ymax, yticks, ytick_labels, display_name, title_color).
    37.7H is keyed on the number of candidate sites too, because it is plotted
    once per validated strain and the two panels need different y scales.
    """
    if antibody_name == "2510C":
        return -6.875, 61.875, [0, 25, 50], ["0", "25", "50"], "25.10C", "#44AA99"
    if antibody_name == "121F":
        return -1.25, 11.25, [0, 5, 10], ["0", "5", "10"], "12.1F", "#999933"
    if antibody_name == "377H" and n_candidate_sites == 33:
        return -1.25, 11.25, [0, 5, 10], ["0", "5", "10"], "37.7H", "#AA4499"
    if antibody_name == "377H" and n_candidate_sites == 31:
        return -5, 45, [0, 20, 40], ["0", "20", "40"], "37.7H", "#AA4499"
    if antibody_name == "256A":
        return -5, 45, [0, 20, 40], ["0", "20", "40"], "25.6A", "#AA4499"
    if antibody_name == "372D":
        return -0.625, 5.625, [0, 2.5, 5], ["0", "2.5", "5"], "37.2D", "#AA4499"
    if antibody_name == "89F":
        return -1.25, 11.25, [0, 5, 10], ["0", "5", "10"], "8.9F", "#117733"
    return None


def create_line_and_logoplots(
    escape_file,  
    contacts_file,
    func_scores, 
    min_times_seen,  
    n_selections, 
    logo_plot,
    output_file = None,
    sites = None,
    name = None,
    amino_acids_to_color = None,
    only_top_10_sites=False,
    n_protein_sites=491,
):
    """
    Draw an antibody escape logo plot for a set of sites onto `logo_plot`.

    Despite the name, only a logo plot is drawn; nothing is written to disk.

    Parameters
    ----------
    escape_file : str
        Filtered escape CSV with `site`, `wildtype`, `mutant`, `escape_median`.
        The antibody name is the file-name prefix before the first "_".
    contacts_file : str
        Contacts CSV; `position` values whose `distance` equals 4 are shaded.
    func_scores : str
        Functional-effects CSV with `site`, `wildtype`, `mutant`, `effect`.
    min_times_seen, n_selections, output_file
        Unused; kept so existing calls keep working.
    logo_plot : matplotlib Axes
        Axes to draw on.
    sites : list of int, optional
        Candidate sites to show; defaults to the 18 sites with the largest
        summed escape.
    name : str
        Isolate name used in the panel title.
    amino_acids_to_color : collection of (site, mutant) tuples, optional
        If given, these letters are drawn orange and all others black;
        otherwise letters are shaded by functional effect.
    only_top_10_sites : bool
        Show only the 10 candidate sites with the largest summed escape.
    n_protein_sites : int
        Sites 1..n_protein_sites missing from the escape data are padded with
        zero-escape placeholder rows.
    """
    antibody_name = escape_file.split("/")[-1].split("_")[0]

    # Read data
    escape_df = pd.read_csv(escape_file)
    func_df = pd.read_csv(func_scores)
    contacts_df = pd.read_csv(contacts_file)

    # Clip sensitizing (negative) escape to 0
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Summed escape per site
    summed_df = (
        escape_df
        .groupby(["site", "wildtype"])
        .aggregate({"escape_median": "sum"})
        .rename(columns={"escape_median": "site_escape"})
        .reset_index()
    )

    # Default candidate sites: the 18 with the largest summed escape
    if sites is None:
        sites = sorted(summed_df.nlargest(18, "site_escape")["site"].tolist())

    if only_top_10_sites:
        # Rank candidates by summed escape. Sites absent from the escape data
        # count as 0 (they are padded with zero escape below) instead of
        # raising KeyError.
        summed_score_dict = dict(zip(summed_df["site"].tolist(), summed_df["site_escape"]))
        ranked_sites = sorted(sites, key=lambda s: summed_score_dict.get(s, 0), reverse=True)
        sites_to_show = ranked_sites[:10]
    else:
        sites_to_show = list(sites)
    escape_df["show_site"] = escape_df["site"].isin(sites_to_show)

    # Shade contact sites in logo plot
    shade_sites = list(contacts_df.loc[contacts_df["distance"] == 4]["position"].unique())

    # # Uncomment to show antibody contacts
    # print(antibody_name)
    # print(f"Contact sites: {shade_sites}")
    # print()

    escape_df["shade_site"] = ["#DDCC77" if site in shade_sites else None for site in escape_df["site"]]
    escape_df["shade_alpha"] = 0.35
    escape_df = escape_df.merge(
        summed_df,
        how="left",
        on=["site", "wildtype"],
        validate="many_to_one",
    )

    # Drop extra columns (remaining columns keep their original order)
    keep_columns = [
        "site",
        "wildtype",
        "mutant",
        "escape_median",
        "show_site",
        "shade_site",
        "shade_alpha",
        "site_escape",
    ]
    escape_df = escape_df.drop(escape_df.columns.difference(keep_columns), axis=1)

    # Pad sites missing from the escape data with never-shown, zero-escape rows.
    # Built by column name (not position) so the CSV's column order is irrelevant.
    present_sites = set(escape_df["site"].tolist())
    missing_sites = [s for s in range(1, n_protein_sites + 1) if s not in present_sites]
    if missing_sites:
        filler_df = pd.DataFrame({
            "site": missing_sites,
            "wildtype": "X",
            "mutant": "X",
            "escape_median": 0,
            "show_site": False,
            "shade_site": None,
            "shade_alpha": np.nan,
            "site_escape": 0,
        })
        escape_df = pd.concat([escape_df, filler_df[escape_df.columns]], ignore_index=True)

    # Sort by site
    escape_df = (
        escape_df
        .sort_values(by="site")
        .astype({"mutant": "str"})
        .reset_index(drop=True)
    )

    # Functional effects: build mutation labels, drop stop codons, clip to [-2, 0]
    func_df = func_df.assign(
        mutation=func_df["wildtype"] + func_df["site"].astype(str) + func_df["mutant"]
    )
    func_df = func_df.astype({"site": "int"})
    func_df = func_df.loc[func_df["mutant"] != "*"]
    func_df = func_df.assign(effect=func_df["effect"].clip(upper=0, lower=-2))
    escape_df = escape_df.merge(
        func_df,
        how="left",
        on=["site", "wildtype", "mutant"],
        validate="one_to_one",
    )
    # Missing functional values are filled with -2 (the clip floor) to make them less visible
    escape_df["effect"] = escape_df["effect"].fillna(-2)

    # Color specified mutants, or color by functional effect
    if amino_acids_to_color is not None:
        escape_df["color"] = [
            "#EE7733" if (site, mutant) in amino_acids_to_color else "#000000"
            for site, mutant in zip(escape_df["site"], escape_df["mutant"])
        ]
    else:
        func_color_map = dmslogo.colorschemes.ValueToColorMap(
            minvalue=func_df["effect"].min(),
            maxvalue=func_df["effect"].max(),
            cmap=matplotlib.colors.ListedColormap(color_gradient_hex("white", "#000000", n=20)),
        )
        escape_df = escape_df.assign(
            color=lambda x: x["effect"].map(func_color_map.val_to_color)
        )

    # Add wildtype to each site for logo plot x tick labels
    escape_df["wt_site"] = escape_df["wildtype"] + escape_df["site"].map(str)

    # Antibody-specific y limits, ticks, and title
    style = _logo_plot_style(antibody_name, len(sites))
    if style is None:
        print("Error! No ylims set!")
        fixed_ymin = None
        fixed_ymax = None
    else:
        fixed_ymin, fixed_ymax, yticks, ytick_labels, display_name, title_color = style

    _, logoplot = dmslogo.draw_logo(
        escape_df.query("show_site == True"),
        x_col="site",
        letter_col="mutant",
        letter_height_col="escape_median",
        ax=logo_plot,
        xtick_col="wt_site",
        color_col="color",
        shade_color_col="shade_site",
        shade_alpha_col="shade_alpha",
        draw_line_at_zero="never",
        fixed_ymin=fixed_ymin,
        fixed_ymax=fixed_ymax,
    )

    logoplot.set(ylabel=None, xlabel=None)
    x_labels = logoplot.get_xticklabels()
    logoplot.set_xticklabels(labels=x_labels, rotation=90, horizontalalignment="center", fontsize=6)
    # Change all spines
    for axis in ["top", "bottom", "left", "right"]:
        logoplot.spines[axis].set_linewidth(1)
    logoplot.tick_params(axis="both", length=2, width=1, pad=1)

    if style is not None:
        logoplot.set_yticks(yticks)
        logoplot.set_yticklabels(labels=ytick_labels, fontsize=6)
        logoplot.set_title(
            display_name + " escape for sites different between Josiah and " + name + "s",
            fontsize=8,
            color=title_color,
        )

In [None]:
def get_site_differences(seq1, seq2):
    """
    Compare two aligned sequences position by position.

    Positions are 1-based; any position where either sequence has a gap
    ("-") is ignored, and comparison stops at the shorter sequence.

    Returns a pair: the list of differing sites, and a tuple of
    (site, residue in seq2) pairs for those sites.
    """
    differing = [
        (position, residue2)
        for position, (residue1, residue2) in enumerate(zip(seq1, seq2), start=1)
        if residue1 != residue2 and "-" not in (residue1, residue2)
    ]
    differing_sites = [position for position, _ in differing]
    return differing_sites, tuple(differing)

# Get (isolate name, differing sites, (site, mutant) pairs) for each isolate vs Josiah
list_of_different_sites = []
for curr_seq_name, curr_seq in zip(validation_isolates["strain"], validation_isolates["sequence"]):
    site_differences, mut_differences = get_site_differences(josiah_sequence, curr_seq)
    list_of_different_sites.append((curr_seq_name, site_differences, mut_differences))

# Antibody contact files and escape files
contacts_files = [
    contacts_2510C,
    contacts_121F,
    contacts_377H,
    # contacts_256A,
    contacts_377H, # duplicate 37.7H because validated with two strains
    contacts_372D,
    contacts_89F,
]

filtered_escape_files = [
    filtered_escape_2510C,
    filtered_escape_121F,
    filtered_escape_377H,
    # filtered_escape_256A,
    filtered_escape_377H, # duplicate 37.7H because validated with two strains
    filtered_escape_372D,
    filtered_escape_89F,
]

# Isolate (display name) compared in each panel, aligned with the lists above.
# Looked up by name rather than by row position so alignment order doesn't matter.
panel_isolates = [
    "GA391 strain",     # 25.10C
    "LM395 strain",     # 12.1F
    "GA391 strain",     # 37.7H
    "IRR007 strain",    # 37.7H
    "GA391 strain",     # 37.2D
    "ISTH1024 strain",  # 8.9F
]
assert len(panel_isolates) == len(filtered_escape_files) == len(contacts_files)
differences_by_isolate = {
    isolate: (isolate_sites, isolate_muts)
    for isolate, isolate_sites, isolate_muts in list_of_different_sites
}

# Set figure size and subplots
fig, axes = plt.subplots(len(panel_isolates), 1, figsize=(3.5, 6))

# One panel per antibody, showing the top 10 differing sites
for ax, escape_file, contacts_file, isolate in zip(axes, filtered_escape_files, contacts_files, panel_isolates):
    isolate_sites, isolate_muts = differences_by_isolate[isolate]
    create_line_and_logoplots(
        escape_file,
        contacts_file,
        func_scores,
        min_times_seen,
        n_selections,
        ax,
        sites=sorted(isolate_sites),
        name=isolate,
        amino_acids_to_color=isolate_muts,
        only_top_10_sites=True,
    )

fig.subplots_adjust(
    left=0,      # the left side of the subplots of the figure
    right=1,     # the right side of the subplots of the figure
    bottom=0,    # the bottom of the subplots of the figure
    top=1,       # the top of the subplots of the figure
    wspace=0,    # the amount of width reserved for blank space between subplots
    hspace=0.8,  # the amount of height reserved for white space between subplots
)

# Common axis labels
fig.text(-0.075, 0.5, "site escape", va="center", rotation="vertical", fontsize=8)
fig.text(0.5, -0.07, "site", ha="center", fontsize=8)

# Make output dir (and any missing parents) if it doesn't exist
os.makedirs(out_dir_escape, exist_ok=True)

# Save fig
fig.savefig(escape_top10_image_path)

In [None]:
# Isolate (display name) compared in each panel, aligned with
# `filtered_escape_files` / `contacts_files`. Looked up by name rather than by
# row position so alignment order doesn't matter.
panel_isolates = [
    "GA391 strain",     # 25.10C
    "LM395 strain",     # 12.1F
    "GA391 strain",     # 37.7H
    "IRR007 strain",    # 37.7H
    "GA391 strain",     # 37.2D
    "ISTH1024 strain",  # 8.9F
]
assert len(panel_isolates) == len(filtered_escape_files) == len(contacts_files)
differences_by_isolate = {
    isolate: (isolate_sites, isolate_muts)
    for isolate, isolate_sites, isolate_muts in list_of_different_sites
}

# Set figure size and subplots
fig, axes = plt.subplots(len(panel_isolates), 1, figsize=(4.5, 6))

# One panel per antibody, showing every differing site
for ax, escape_file, contacts_file, isolate in zip(axes, filtered_escape_files, contacts_files, panel_isolates):
    isolate_sites, isolate_muts = differences_by_isolate[isolate]
    create_line_and_logoplots(
        escape_file,
        contacts_file,
        func_scores,
        min_times_seen,
        n_selections,
        ax,
        sites=sorted(isolate_sites),
        name=isolate,
        amino_acids_to_color=isolate_muts,
        only_top_10_sites=False,
    )

fig.subplots_adjust(
    left=0,      # the left side of the subplots of the figure
    right=1,     # the right side of the subplots of the figure
    bottom=0,    # the bottom of the subplots of the figure
    top=1,       # the top of the subplots of the figure
    wspace=0,    # the amount of width reserved for blank space between subplots
    hspace=0.8,  # the amount of height reserved for white space between subplots
)

# Common axis labels
fig.text(-0.075, 0.5, "site escape", va="center", rotation="vertical", fontsize=8)
fig.text(0.5, -0.07, "site", ha="center", fontsize=8)

# Make output dir (and any missing parents) if it doesn't exist
os.makedirs(out_dir_escape, exist_ok=True)

# Save fig
fig.savefig(escape_all_image_path)

## Pseudovirus neutralization validation assays for chosen natural isolates and corresponding single mutants

The chosen isolates, along with the corresponding single mutations identified as the main drivers of escape, were validated with traditional pseudovirus neutralization assays.

In [None]:
# Use "virus" rather than "strain" as the isolate column name
validation_isolates = validation_isolates.rename(columns={"strain": "virus"})

# Read neutralization (fraction infected) data and map raw virus labels to display names
frac_infected_natural_isolates = (
    pd.read_csv(fraction_infected_natural_isolates)
    .assign(virus=lambda df: df["virus"].replace(rename_dict))
)

# Fit Hill curves with neutcurve, fixing the bottom at 0 and the top at 1
fits = neutcurve.curvefits.CurveFits(
    data=frac_infected_natural_isolates,
    fixbottom=0,
    fixtop=1,
)

In [None]:
# Plot neut data


def neut_panel(antibody, isolate, mutation, mutant_label=None):
    """Build one `plotGrid` panel: unmutated Josiah, a natural isolate, and a Josiah single mutant.

    Parameters
    ----------
    antibody : str
        Antibody name; used as the panel title and as the ``serum`` key in ``fits``.
    isolate : str
        Natural isolate name; its virus key is ``"<isolate> strain"``.
    mutation : str
        Single substitution in Josiah (e.g. ``"E228D"``); its virus key is
        ``"<mutation> Josiah strain"``.
    mutant_label : str, optional
        Legend label for the mutant curve. Defaults to the mutation and "Josiah"
        on two lines.

    Returns
    -------
    tuple
        ``(title, curve_specs)`` in the form passed to ``fits.plotGrid``.
    """
    curves = [
        ("unmutated Josiah strain", tol_muted_adjusted[0], "unmutated\nJosiah"),
        (f"{isolate} strain", tol_muted_adjusted[12], isolate),
        (
            f"{mutation} Josiah strain",
            tol_muted_adjusted[10],
            mutant_label if mutant_label is not None else f"{mutation}\nJosiah",
        ),
    ]
    return (
        antibody,
        [
            {
                "serum": antibody,
                "virus": virus,
                "replicate": "average",
                "color": color,
                "marker": "o",
                "label": label,
            }
            for virus, color, label in curves
        ],
    )


# One row per panel, top to bottom:
# (antibody, natural isolate, Josiah single mutant, optional mutant legend label).
# N89D is starred, matching how it is labeled in the IC50 correlation plot.
neut_panels = [
    ("25.10C", "GA391", "E228D", None),
    ("12.1F", "LM395", "N89D", "*N89D\nJosiah"),
    ("37.7H", "GA391", "H398K", None),
    ("37.7H", "IRR007", "D401E", None),
    ("37.2D", "GA391", "H398K", None),
    ("8.9F", "ISTH1024", "K126N", None),
]

# Panel title color for each antibody
antibody_title_colors = {
    "25.10C": "#44AA99",
    "12.1F": "#999933",
    "37.7H": "#AA4499",
    "37.2D": "#AA4499",
    "8.9F": "#117733",
}

fig, axes = fits.plotGrid(
    {(row, 0): neut_panel(*panel) for row, panel in enumerate(neut_panels)},
    sharex=True,
    sharey=False,
    xlabel="",
    ylabel="",
    attempt_shared_legend=False,
    despine=True,
)

antibody_names = [panel[0] for panel in neut_panels]

for index, antibody_name in enumerate(antibody_names):
    panel_ax = axes[index, 0]
    panel_ax.set_title(
        "antibody " + antibody_name,
        fontsize=8,
        color=antibody_title_colors[antibody_name],
    )
    panel_ax.set_ylim(-0.1, 1.3)
    panel_ax.set_yticks([0, 0.5, 1.0])
    panel_ax.set_yticklabels(labels=[0, 0.5, 1.0], fontsize=6)
    panel_ax.set_xlim(0.0005, 12.5)
    panel_ax.set_xticks([0.001, 0.01, 0.1, 1, 10])
    panel_ax.set_xticklabels(
        labels=["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$", "$10^1$"],
        fontsize=6,
    )
    plt.setp(panel_ax.collections, alpha=0.8, linewidths=0.5, colors="black")  # for vertical error bar segment
    plt.setp(panel_ax.lines, alpha=0.8, markeredgewidth=0.5, markeredgecolor="black", linewidth=1, markersize=4)  # for the lines and markers
    sns.move_legend(
        panel_ax,
        loc="upper left",
        borderaxespad=0,
        frameon=False,
        bbox_to_anchor=(1, 1.3),
        fontsize=6,
        markerscale=1,
        handletextpad=0.1,
        title="Lassa GPC",
        title_fontproperties={"weight": "bold", "size": 6},
        alignment="center",
    )

    # Add edges to legend markers to match scatter plot.
    # Matplotlib >= 3.7 renamed `legendHandles` to `legend_handles` (the old
    # name was removed in 3.9), so look up the new name first.
    legend = panel_ax.get_legend()
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_markeredgecolor("black")
        handle.set_markeredgewidth(0.5)
        handle.set_linewidth(0)
        handle.set_markersize(5)

    # Change all spines
    for spine in ["top", "bottom", "left", "right"]:
        panel_ax.spines[spine].set_linewidth(1)
    panel_ax.tick_params(axis="both", length=2, width=1)

# Adjust spacing of subplots
fig.subplots_adjust(
    left=0,      # the left side of the subplots of the figure
    right=1,     # the right side of the subplots of the figure
    bottom=0,    # the bottom of the subplots of the figure
    top=1,       # the top of the subplots of the figure
    wspace=0,    # the amount of width reserved for blank space between subplots
    hspace=0.8,  # the amount of height reserved for white space between subplots
)

# Common X and Y axis labels
fig.text(-0.4, 0.5, "fraction infectivity", va="center", rotation="vertical", fontsize=8)
fig.text(0.5, -0.05, "concentration (\u03BCg/mL)", ha="center", fontsize=8)

width = 0.9
height = 6
fig.set_size_inches(width, height)

# Make output dir if it doesn't exist (makedirs also creates missing parent dirs)
os.makedirs(out_dir, exist_ok=True)

# Save this figure explicitly rather than whatever pyplot considers "current"
fig.savefig(neuts_image_path)

We further calculate **cumulative summed** escape scores across all mutations present in each isolate, including mutations whose individual effects are small. These total DMS escape scores for each isolate are then compared to the measured IC50 values. The same comparison is made for the corresponding single mutants.

In [None]:
# IC values to calculate; only the IC50 estimate and its bound flag are kept
fitParams = fits.fitParams(ics=[50, 80, 99])[["serum", "virus", "ic50", "ic50_bound"]]

# Merge isolate sequence and neut dfs
validation_df = (
    fitParams.merge(
        validation_isolates,
        how="left",
        on="virus",
        validate="many_to_one"
    )
)


def parse_josiah_mutation(virus):
    """Return the substitution encoded in a name like "N89D Josiah strain", else None.

    Names that do not end in " Josiah strain", or whose prefix is not of the form
    <letter><digits><letter> (e.g. "unmutated Josiah strain"), return None.
    """
    suffix = " Josiah strain"
    if not virus.endswith(suffix):
        return None
    mutation = virus[: -len(suffix)]
    if len(mutation) < 3 or not mutation[1:-1].isdigit():
        return None
    return mutation


def mutate_sequence(sequence, mutation):
    """Apply a single 1-based substitution such as "N89D" to `sequence`.

    Raises
    ------
    ValueError
        If the wildtype residue in `mutation` does not match `sequence` at that
        site, which would mean the mutant sequence is being built incorrectly.
    """
    wildtype, site, mutant = mutation[0], int(mutation[1:-1]), mutation[-1]
    if sequence[site - 1] != wildtype:
        raise ValueError(
            f"{mutation}: expected {wildtype} at site {site}, found {sequence[site - 1]}"
        )
    return sequence[: site - 1] + mutant + sequence[site:]


# Add single mutant sequences. Mutants are identified from their virus names
# rather than by hard-coded row positions, which silently break if the
# order of the fit results changes.
josiah_mutations = validation_df["virus"].map(parse_josiah_mutation)
is_josiah_variant = josiah_mutations.notna()
validation_df.loc[is_josiah_variant, "sequence"] = (
    josiah_mutations[is_josiah_variant].map(
        lambda mutation: mutate_sequence(josiah_sequence, mutation)
    )
)

# Antibody conversion dict
antibody_file_dict = {
    "25.10C" : filtered_escape_2510C,
    "12.1F" : filtered_escape_121F,
    "37.7H" : filtered_escape_377H,
    "25.6A" : filtered_escape_256A,
    "37.2D" : filtered_escape_372D,
    "8.9F" : filtered_escape_89F,
}

# Calculate total escape for each sequence using an additive model
validation_df["total_escape"] = (
    validation_df.apply(lambda x: determine_escape(
        0,
        x["sequence"],
        antibody_file_dict[x["serum"]],
        x["virus"],
    ), axis=1)
)

# Rename columns and reformat columns
validation_df = (
    validation_df.rename(columns={
        "ic50_bound" : "lower bound",
        "serum" : "antibody",
    })
)
validation_df["lower bound"] = (
    validation_df["lower bound"].replace({"interpolated" : False, "lower" : True})
)


def classify_lassa_gpc(virus):
    """Legend category for a virus: unmutated Josiah, Josiah single mutant, or natural isolate."""
    if virus == "unmutated Josiah strain":
        return "unmutated\nJosiah"
    if parse_josiah_mutation(virus) is not None:
        return "Josiah\nvariant"
    return "natural\nisolate"


# Mark strains as natural isolates or single mutants
validation_df["Lassa GPC"] = validation_df["virus"].map(classify_lassa_gpc)

In [None]:
# Set figure size and subplots
fig, axes = plt.subplots(
    1,
    5,
    figsize=(6.25, 1.25),
    sharey=True,
    sharex=True,
)

# Panel title color for each antibody
color_dict = {
    "25.10C" : "#44AA99",
    "12.1F" : "#999933",
    "37.7H" : "#AA4499",
    "37.2D" : "#AA4499",
    "8.9F" : "#117733",
}

# Manual label placement for points that would otherwise overlap.
# Maps (antibody, point name) -> (x offset, y offset, label prefix, extra text kwargs).
# Offsets are in data units. Unlisted points are labeled at (x, y + 0.5), and the
# unmutated Josiah reference is left unlabeled.
label_overrides = {
    ("12.1F", "N89D"): (15, 0, "*", {"style": "italic"}),
    ("12.1F", "LM395"): (15, 0, "", {}),
    ("37.7H", "D401E"): (0.1, 0, "", {}),
    ("37.7H", "IRR007"): (15, 0.25, "", {}),
    ("37.7H", "GA391"): (15, -0.5, "", {}),
    ("37.2D", "GA391"): (0.5, 0, "", {}),
    ("37.2D", "H398K"): (0.75, -0.25, "", {}),
}

for index, antibody in enumerate(["25.10C", "12.1F", "37.7H", "37.2D", "8.9F"]):

    curr_antibody = (
        validation_df.loc[validation_df["antibody"] == antibody]
        .reset_index(drop=True)
    )

    # Correlation between measured IC50 and summed DMS escape
    r, p = sp.stats.pearsonr(
        x=np.log10(curr_antibody["ic50"]),  # log10 because plotting on log scale
        y=curr_antibody["total_escape"]
    )
    print(f"{antibody}: R={r}")
    print(f"{antibody}: R^2={r**2}")

    corr_chart = sns.scatterplot(
        data=curr_antibody,
        x="ic50",
        y="total_escape",
        hue="Lassa GPC",
        palette={
            "unmutated\nJosiah" : tol_muted_adjusted[0],
            "natural\nisolate" : tol_muted_adjusted[12],
            "Josiah\nvariant" : tol_muted_adjusted[10],
        },
        style="lower bound",
        markers=["o", "s"],
        alpha=0.8,
        edgecolor="black",
        ax=axes[index],
        s=25,
    )
    # Change all spines
    for spine in ["top", "bottom", "left", "right"]:
        corr_chart.spines[spine].set_linewidth(1)
    sns.despine()
    corr_chart.tick_params(axis="both", length=3, width=1)
    corr_chart.set(xlabel=None)
    corr_chart.set_title(
        antibody,
        fontsize=8,
        color=color_dict[antibody]
    )
    corr_chart.set_ylabel(
        "escape score measured\nby DMS (arbitrary units)",
        fontsize=8,
    )
    corr_chart.set_xscale("log")
    corr_chart.xaxis.set_minor_locator(mticker.NullLocator())  # no minor ticks
    corr_chart.set_xlim(0.005, 500)
    corr_chart.set_ylim(-0.5, 8.5)
    corr_chart.set_xticks([0.01, 1, 100])
    corr_chart.set_xticklabels(["$10^{-2}$", "$10^{0}$", "$10^{2}$"], size=6)
    corr_chart.set_yticks([0, 2, 4, 6, 8])
    corr_chart.set_yticklabels(corr_chart.get_yticks(), size=6)
    if index == 2:
        # Only the middle panel keeps a legend, moved to the right of the figure
        sns.move_legend(
            corr_chart,
            "upper left",
            bbox_to_anchor=(3.5,1),
            fontsize=6,
            markerscale=1,
            handletextpad=0.1,
            frameon=False,
            borderaxespad=0.75,
        )
        # Add edges to legend markers to match scatter plot.
        # Matplotlib >= 3.7 renamed `legendHandles` to `legend_handles` (the old
        # name was removed in 3.9), so look up the new name first.
        legend = corr_chart.get_legend()
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        for handle in handles:
            handle.set_edgecolor("black")
            handle.set_linewidths(0.5)
        # Bold legend entries 0 and 4 (presumably seaborn's "Lassa GPC" and
        # "lower bound" section headers -- confirm if hue/style levels change)
        legend.get_texts()[0].set_weight("bold")
        legend.get_texts()[4].set_weight("bold")
    else:
        corr_chart.get_legend().remove()

    # Add correlation to chart
    corr_chart.text(
        0.01,
        7.5,
        f"r={r:.2f}",
        horizontalalignment="left",
        weight="bold",
        fontsize=6,
    )

    # Label points
    for x_pos, y_pos, virus in zip(
        curr_antibody["ic50"], curr_antibody["total_escape"], curr_antibody["virus"]
    ):
        name = virus.split(" ")[0]
        if (antibody, name) in label_overrides:
            dx, dy, prefix, text_kwargs = label_overrides[(antibody, name)]
        elif name != "unmutated":
            dx, dy, prefix, text_kwargs = 0, 0.5, "", {}
        else:
            continue
        corr_chart.text(
            x_pos + dx,
            y_pos + dy,
            f"{prefix}{name}",
            horizontalalignment="left",
            fontsize=6,
            **text_kwargs,
        )

    # Set square ratio
    corr_chart.set_box_aspect(1)

# Common X axis labels
fig.text(0.5, -0.1, "IC50 (\u03BCg/mL) measured by pseudovirus neutralization", ha="center", rotation="horizontal", fontsize=8)

# Make output dir if it doesn't exist (makedirs also creates missing parent dirs)
os.makedirs(out_dir, exist_ok=True)

# Save fig
fig.savefig(corr_image_path)

## Further analysis

First, create a dataframe of amino-acid-level measurements together with each mutation's frequency (relative to the Josiah strain) in the natural sequence alignment. In addition, create a site-level dataframe with site-level measurements, cumulative mutation frequencies, and the effective number of amino acids at each position. The formula for calculating the effective number of amino acids is described in *Biophysical Models of Protein Evolution: Understanding the Patterns of Evolutionary Sequence Divergence*.

In [None]:
def get_natural_sequence_counts(site, amino_acid, natural_seqs_df):
    """
    Count the sequences in `natural_seqs_df["sequence"]` that carry
    `amino_acid` at the 1-based position `site`.
    """
    position = site - 1
    return sum(
        1
        for sequence in natural_seqs_df["sequence"]
        if sequence[position] == amino_acid
    )
            

# Load data as dataframe
functional_scores = pd.read_csv(func_scores)

# Filter for minimum selections, times seen and no stop codons
functional_scores = (
    functional_scores.query(
        "n_selections >= @n_selections and times_seen >= @min_times_seen and mutant != '*'"
    )
    .drop(columns=["times_seen", "effect_std", "n_selections"])
    .reset_index(drop=True)
)

# Create a wildtype df (mutant == wildtype, effect 0) for sites 1..491 of
# `josiah_sequence`, so every Josiah site has a reference row
josiah_length = 491
josiah_df = pd.DataFrame(
    zip(
        range(1, josiah_length + 1),
        josiah_sequence,
        josiah_sequence,
        [0] * josiah_length,
        [True] * josiah_length,
    ),
    columns=["site", "wildtype", "mutant", "effect", "Josiah reference"],
)
AA_level_df = (
    pd.concat([functional_scores, josiah_df], ignore_index=True)
    .sort_values(by="site")
    .reset_index(drop=True)
    # DMS rows have no "Josiah reference" value; fill only that column so that
    # any genuinely missing measurement is not turned into False (== 0)
    .fillna({"Josiah reference": False})
)

# Get natural sequence counts of each mutant and calculate mutation frequencies
# compared to the Josiah reference
AA_level_df["natural_counts"] = (
    AA_level_df.apply(lambda x: get_natural_sequence_counts(x["site"], x["mutant"], natural_seqs_df), axis=1)
)
# Denominator is the total number of natural sequences (previously approximated
# by the largest per-mutation count, which is only correct if some site is fully
# conserved across all sequences)
number_sequences = len(natural_seqs_df)
# Frequencies are left missing for the Josiah reference rows
AA_level_df["mutation_frequency"] = (
    (AA_level_df["natural_counts"] / number_sequences)
    .where(AA_level_df["Josiah reference"].eq(False))
)


# Add escape to dataframe for each antibody
for antibody_file, _ in antibody_files:

    antibody_name = antibody_file.split("/")[-1].split("_")[0]

    # Load data as dataframe
    escape_df = pd.read_csv(antibody_file)

    # Clip lower scores to 0
    escape_df["escape_median"] = escape_df["escape_median"].clip(lower=0)

    # Rename escape column to include antibody name
    escape_df = escape_df.rename(columns={"escape_median" : "escape_" + antibody_name})

    # Merge dataframes
    AA_level_df = (
        AA_level_df.merge(
            escape_df[["site", "wildtype", "mutant", "escape_" + antibody_name]],
            how="left",
            on=["site", "wildtype", "mutant"],
            validate="one_to_one",
        )
    )

# Antibody identifiers used in the escape column names
antibody_ids = ["2510C", "121F", "377H", "256A", "372D", "89F"]

# Fill in missing wildtype escape values with 0
is_wildtype = AA_level_df["wildtype"] == AA_level_df["mutant"]
for antibody_id in antibody_ids:
    AA_level_df.loc[is_wildtype, "escape_" + antibody_id] = 0

# Calculate total escape across the Arevirumab-3 antibodies (NaN if any is missing)
AA_level_df["cocktail_total_escape"] = (
    AA_level_df[[
        "escape_121F",
        "escape_372D",
        "escape_89F"
    ]].sum(axis=1, skipna=False)
)

# Create a site level dataframe
site_level_df = (
    AA_level_df.groupby(["site", "wildtype"])
    .aggregate({
        "effect" : "mean",
        "escape_2510C" : "sum",
        "escape_121F" : "sum",
        "escape_377H" : "sum",
        "escape_256A" : "sum",
        "escape_372D" : "sum",
        "escape_89F" : "sum",
        "cocktail_total_escape" : "sum",
        "natural_counts" : "sum",
        "mutation_frequency" : "sum",
    })
    .reset_index()
)

# Calculate total escape averages per site
# for Arevirumab-3
site_level_df["cocktail_avg_site_escape"] = (
    site_level_df[[
        "escape_121F",
        "escape_372D",
        "escape_89F"
    ]].mean(axis=1, skipna=False)
)

# Add site level escape to AA level df
AA_level_df = (
    AA_level_df.merge(
        site_level_df[[
            "site",
            "wildtype",
            "escape_2510C",
            "escape_121F",
            "escape_377H",
            "escape_256A",
            "escape_372D",
            "escape_89F",
            "cocktail_total_escape",
            "cocktail_avg_site_escape",
        ]].rename(columns={
            "escape_2510C" : "site_escape_2510C",
            "escape_121F" : "site_escape_121F",
            "escape_377H" : "site_escape_377H",
            "escape_256A" : "site_escape_256A",
            "escape_372D" : "site_escape_372D",
            "escape_89F" : "site_escape_89F",
            "cocktail_total_escape" : "cocktail_total_site_escape",
        }),
        how="left",
        on=["site", "wildtype"],
        validate="many_to_one",
    )
)

# Load data as dataframe
natural_variation = pd.read_csv(natural_sequence_variation)

# Drop individual amino acid counts
natural_variation = natural_variation[["site", "entropy", "n_effective"]]

# Merge functional and natural dataframes
site_level_df = (
    site_level_df.merge(
        natural_variation,
        how="left",
        on=["site"],
        validate="one_to_one",
    )
)

# Flag mutations at high-escape sites: a site counts as "top" when its summed
# escape is at least 10x the median summed site escape. The median is taken over
# the mutation-level rows, so each site is weighted by its number of mutations.
for antibody_id in antibody_ids:
    site_escape = AA_level_df["site_escape_" + antibody_id]
    cutoff = site_escape.median() * 10
    AA_level_df["top_site_escape_for_" + antibody_id] = site_escape >= cutoff

# Mutations whose effects were validated by neutralization, per antibody
validation_mutations = {
    "2510C": ["E228D"],
    "121F": ["N89D"],
    "377H": ["H398K", "D401E"],
    "256A": [],
    "372D": ["H398K"],
    "89F": ["K126N"],
}

# Mark validation mutations, using codes like "N89D"
mutation_codes = (
    AA_level_df["wildtype"] + AA_level_df["site"].astype(str) + AA_level_df["mutant"]
)
for antibody_id in antibody_ids:
    AA_level_df["validation_for_" + antibody_id] = mutation_codes.isin(
        validation_mutations[antibody_id]
    )

## Comparisons of functional effects and natural sequence diversity

First, look at correlations between natural sequence diversity metrics and functional scores. 

In [None]:
# Making two lists for values and colors (not referenced in this cell)
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]


def pearson_complete_cases(data, x_col, y_col):
    """Pearson (r, p) between two columns of `data`, using only rows where both are non-missing."""
    complete = data[[x_col, y_col]].dropna()
    return sp.stats.pearsonr(complete[x_col], complete[y_col])


def site_effect_chart(data, x_field, x_title, x_scale, **x_axis_options):
    """Scatter of site mean effect on cell entry (y) against a site-level diversity metric (x).

    `x_axis_options` are forwarded to the x-axis `alt.Axis` (e.g. tick `values`, `format`).
    """
    return alt.Chart(
        data,
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.Y(
            "effect",
            axis=alt.Axis(
                title=["site mean", "effect on cell entry"],
                values=[-4, -3, -2, -1, 0, 1],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-4.1, 1.1])
        ),
        alt.X(
            x_field,
            axis=alt.Axis(
                title=x_title,
                domainWidth=1,
                domainColor="black",
                tickColor="black",
                **x_axis_options,
            ),
            scale=x_scale,
        ),
        tooltip=[
            "site",
            "wildtype",
            x_field,
            "effect",
        ],
    ).properties(
        width=350,
        height=350,
    )


# Site mean effect vs summed site mutation frequency
r, p = pearson_complete_cases(site_level_df, "effect", "mutation_frequency")
print(f"r correlation of average site effect on cell entry and site mutational frequency: {r:.2f}")

natural_vs_effect_mut_freq = site_effect_chart(
    site_level_df,
    "mutation_frequency",
    ["mutation frequencies", "in natural sequences"],
    alt.Scale(type="symlog", constant=0.0005, domain=[-0.0001, 1.1]),
    values=[0, 0.001, 0.01, 0.1, 0.5, 1],
    format="0.3",
)

# Site mean effect vs effective number of amino acids
r, p = pearson_complete_cases(site_level_df, "effect", "n_effective")
print(f"r correlation of average site effect on cell entry and site effective amino acids: {r:.2f}")

natural_vs_effect_neff = site_effect_chart(
    site_level_df,
    "n_effective",
    ["effective amino acids", "in natural sequences"],
    alt.Scale(domain=[0.9, 4.1]),
    values=[1, 2, 3, 4],
)

combined_plot = alt.hconcat(
    natural_vs_effect_mut_freq,
    natural_vs_effect_neff,
    spacing=5,
    title="Correlations of site natural sequence diversity and effect on cell entry"
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir if it doesn't exist (makedirs also creates missing parent dirs)
os.makedirs(out_dir_natural, exist_ok=True)

combined_plot.save(html_func_vs_natural)

combined_plot

## Compare DMS data to natural sequences

Next, we are going to compare DMS antibody escape scores to natural sequence variation. Natural sequence variation is measured by two methods:
1. Mutation frequency relative to the Josiah strain, both for individual amino-acid mutations and as cumulative per-site mutation frequencies.
2. Number of effective amino acids per site as described in *Biophysical Models of Protein Evolution: Understanding the Patterns of Evolutionary Sequence Divergence*.

These diversity metrics will be compared to antibody escape as measured for individual amino acid mutations, site summed escape, and site summed escape across all mapped antibodies.

In [None]:
# Making two lists for values and colors (not referenced in this cell)
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]

# One mutation-frequency vs escape panel per antibody escape file
subplots = []
for antibody_file, _ in antibody_files:
    antibody_name = antibody_file.split("/")[-1].split("_")[0]
    escape_column = "escape_" + antibody_name

    curr_subplot = alt.Chart(AA_level_df, title=antibody_name).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.X(
            "mutation_frequency",
            axis=alt.Axis(
                title="mutation frequency",
                values=[0, 0.001, 0.01, 0.1, 0.5, 1],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
                format="0.3",
            ),
            scale=alt.Scale(type="symlog", constant=0.0005, domain=[-0.0001,1.1])
        ),
        alt.Y(
            escape_column,
            axis=alt.Axis(
                title="escape",
                values=[0,1,2,3,4,5,6],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[0,6.1])
        ),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "mutation_frequency",
            "effect",
            escape_column,
        ],
    ).properties(
        width=250,
        height=150,
    )

    subplots.append(curr_subplot)

# Concatenate every panel (not a fixed count of six)
natural_vs_antibody = alt.hconcat(
    *subplots,
    spacing=5,
    title=["Mutational frequencies vs antibody escape", "for individual amino-acid mutations"],
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir if it doesn't exist (makedirs also creates missing parent dirs)
os.makedirs(out_dir_natural, exist_ok=True)

natural_vs_antibody.save(html_nat_mut_freqs_vs_escape)

natural_vs_antibody

In [None]:
# Mutation-level rows with finite cocktail escape and a mutation frequency
# (Josiah reference rows have no frequency and are dropped)
cocktail_vs_freq = (
    AA_level_df[["cocktail_total_escape", "mutation_frequency"]]
    .replace([np.inf, -np.inf], np.nan)
    .dropna()
)
r, p = sp.stats.pearsonr(
    cocktail_vs_freq["cocktail_total_escape"],
    cocktail_vs_freq["mutation_frequency"],
)
print(f"r correlation of cocktail total escape and natural sequence diversity: {r:.2f}")

natural_vs_antibody = alt.Chart(
    AA_level_df,
    title=["Mutational frequencies vs antibody escape across all Arevirumab-3 antibodies", "for individual amino-acid mutations"]
).mark_point(
    filled=True,
    color="black",
    size=75,
    opacity=0.15,
).encode(
    alt.Y(
        "cocktail_total_escape",
        axis=alt.Axis(
            title=["escape summed", "across all Arevirumab-3 antibodies"],
            domainWidth=1,
            domainColor="black",
            tickColor="black",
        ),
    ),
    alt.X(
        "mutation_frequency",
        axis=alt.Axis(
            title="mutation frequency",
            values=[0, 0.001, 0.01, 0.1, 0.5, 1],
            domainWidth=1,
            domainColor="black",
            tickColor="black",
            format="0.3",
        ),
        scale=alt.Scale(type="symlog", constant=0.0005, domain=[-0.0001,1.1])
    ),
    tooltip=[
        "site",
        "wildtype",
        "mutant",
        "mutation_frequency",
        "escape_121F",
        "escape_89F",
        "escape_372D",
        "cocktail_total_escape",
    ],
).properties(
    width=300,
    height=300,
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir if it doesn't exist (makedirs also creates missing parent dirs)
os.makedirs(out_dir_natural, exist_ok=True)

natural_vs_antibody.save(html_nat_mut_freqs_vs_escape_all_abs)

natural_vs_antibody

In [None]:
def _site_escape_panel(antibody_name, x_field, x_axis, x_scale):
    """Scatter one antibody's summed site escape (y) against a natural-diversity column (x).

    Reads the notebook-level ``site_level_df``. ``x_axis`` and ``x_scale`` are the
    Altair axis and scale used for the diversity column ``x_field``.
    """
    return alt.Chart(site_level_df, title=antibody_name).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.Y(
            "escape_" + antibody_name,
            axis=alt.Axis(
                title="site escape",
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        alt.X(x_field, axis=x_axis, scale=x_scale),
        tooltip=[
            "site",
            "wildtype",
            x_field,
            "escape_" + antibody_name,
            "effect",
        ],
    ).properties(
        width=250,
        height=150,
    )


# Antibody IDs are the leading "_"-separated token of each escape file's basename
antibody_names = [antibody[0].split("/")[-1].split("_")[0] for antibody in antibody_files]

# Top row: site escape vs mutation frequency (symlog keeps zero frequencies visible)
mut_freq_axis = alt.Axis(
    title=["mutation frequencies", "in natural sequences"],
    values=[0, 0.001, 0.01, 0.1, 0.5, 1],
    domainWidth=1,
    domainColor="black",
    tickColor="black",
    format="0.3",
)
mut_freq_scale = alt.Scale(type="symlog", constant=0.0005, domain=[-0.0001, 1.1])
natural_vs_antibody_mut_freq = alt.hconcat(
    *[
        _site_escape_panel(name, "mutation_frequency", mut_freq_axis, mut_freq_scale)
        for name in antibody_names
    ],
    spacing=5,
    title=["Mutational frequencies vs antibody escape", "for summed site escape"],
)

# Bottom row: site escape vs effective number of amino acids
neff_axis = alt.Axis(
    title=["effective amino acids", "in natural sequences"],
    values=[1, 2, 3, 4],
    domainWidth=1,
    domainColor="black",
    tickColor="black",
)
neff_scale = alt.Scale(domain=[0.9, 4])
natural_vs_antibody_neff = alt.hconcat(
    *[
        _site_escape_panel(name, "n_effective", neff_axis, neff_scale)
        for name in antibody_names
    ],
    spacing=5,
    title=["Effective amino acids vs antibody escape", "for summed site escape"]
)

combined_plot = alt.vconcat(
    natural_vs_antibody_mut_freq,
    natural_vs_antibody_neff,
    spacing=5,
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir if it doesn't exist (also creates missing parents; no-op if present)
os.makedirs(out_dir_natural, exist_ok=True)

combined_plot.save(html_natural_vs_escape)

combined_plot

In [None]:
# Making two lists for values and colors
# NOTE(review): not used in this cell; kept in case later cells rely on them
dom = [True, False]
rng = ["#EE7733FF", "#00000026"]


def _site_pearson(y_col, x_col, description):
    """Print and return Pearson ``(r, p)`` between two ``site_level_df`` columns.

    Rows with NaN in either column are dropped before correlating, so both
    vectors stay paired.
    """
    paired = site_level_df[[y_col, x_col]].dropna()
    r, p = sp.stats.pearsonr(paired[y_col], paired[x_col])
    print(f"r correlation of {description}: {r:.2f}")
    return r, p


def _cocktail_site_chart(y_col, y_title, x_col, x_axis, x_scale):
    """Scatter of a cocktail-level site escape column (y) against a natural-diversity column (x)."""
    return alt.Chart(
        site_level_df,
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.Y(
            y_col,
            axis=alt.Axis(
                title=y_title,
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        alt.X(x_col, axis=x_axis, scale=x_scale),
        tooltip=[
            "site",
            "wildtype",
            x_col,
            "escape_121F",
            "escape_89F",
            "escape_372D",
            y_col,
        ],
    ).properties(
        width=300,
        height=300,
    )


summed_title = ["site escape", "summed across all Arevirumab-3 antibodies"]
averaged_title = ["site escape", "averaged across all Arevirumab-3 antibodies"]

mut_freq_axis = alt.Axis(
    title=["mutation frequencies", "in natural sequences"],
    values=[0, 0.001, 0.01, 0.1, 0.5, 1],
    domainWidth=1,
    domainColor="black",
    tickColor="black",
    format="0.3",
)
# symlog keeps zero-frequency sites on the axis while spreading out small frequencies
mut_freq_scale = alt.Scale(type="symlog", constant=0.0005, domain=[-0.0001, 1.1])

neff_axis = alt.Axis(
    title=["effective amino acids", "in natural sequences"],
    values=[1, 2, 3, 4],
    domainWidth=1,
    domainColor="black",
    tickColor="black",
)
neff_scale = alt.Scale(domain=[1, 4])

# Summed / averaged site escape vs site mutation frequency
r, p = _site_pearson(
    "cocktail_total_escape", "mutation_frequency",
    "site escape summed across all Arevirumab-3 antibodies and site mutational frequency",
)
natural_vs_antibody_mut_freq = _cocktail_site_chart(
    "cocktail_total_escape", summed_title, "mutation_frequency", mut_freq_axis, mut_freq_scale,
)

r, p = _site_pearson(
    "cocktail_avg_site_escape", "mutation_frequency",
    "site escape averaged across all Arevirumab-3 antibodies and site mutational frequency",
)
natural_vs_antibody_avg_mut_freq = _cocktail_site_chart(
    "cocktail_avg_site_escape", averaged_title, "mutation_frequency", mut_freq_axis, mut_freq_scale,
)

# Summed / averaged site escape vs effective number of amino acids
r, p = _site_pearson(
    "cocktail_total_escape", "n_effective",
    "site escape summed across all Arevirumab-3 antibodies and site effective amino acids",
)
natural_vs_antibody_neff = _cocktail_site_chart(
    "cocktail_total_escape", summed_title, "n_effective", neff_axis, neff_scale,
)

r, p = _site_pearson(
    "cocktail_avg_site_escape", "n_effective",
    "site escape averaged across all Arevirumab-3 antibodies and site effective amino acids",
)
natural_vs_antibody_avg_neff = _cocktail_site_chart(
    "cocktail_avg_site_escape", averaged_title, "n_effective", neff_axis, neff_scale,
)

combined_plot = alt.hconcat(
    natural_vs_antibody_mut_freq,
    natural_vs_antibody_avg_mut_freq,
    natural_vs_antibody_neff,
    natural_vs_antibody_avg_neff,
    spacing=5,
    title=["Site natural sequence diversity vs site antibody escape", "across all Arevirumab-3 antibodies"]
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir if it doesn't exist (also creates missing parents; no-op if present)
os.makedirs(out_dir_natural, exist_ok=True)

combined_plot.save(html_natural_vs_escape_all_abs)

combined_plot

Compare escape between each pair of the three Arevirumab-3 antibodies (37.2D, 8.9F, and 12.1F).

In [None]:
# Display names for the Arevirumab-3 antibodies, keyed by the ID used in column names
AREVIRUMAB3_DISPLAY_NAMES = {"372D": "37.2D", "89F": "8.9F", "121F": "12.1F"}
arevirumab3_escape_cols = ["escape_372D", "escape_89F", "escape_121F"]

subplots = []
for i, antibody in enumerate(arevirumab3_escape_cols):
    # Pair each antibody with the next one in the list (wrapping around), giving the
    # three comparisons 37.2D vs 8.9F, 8.9F vs 12.1F, and 12.1F vs 37.2D
    other_antibody = arevirumab3_escape_cols[(i + 1) % len(arevirumab3_escape_cols)]

    # A dict lookup raises KeyError on an unknown ID, instead of printing "ERROR!"
    # and silently plotting with the raw column ID
    antibody_name = AREVIRUMAB3_DISPLAY_NAMES[antibody.split("_")[1]]
    other_antibody_name = AREVIRUMAB3_DISPLAY_NAMES[other_antibody.split("_")[1]]

    curr_subplot = alt.Chart(
        AA_level_df,
        title=f"{antibody_name} vs {other_antibody_name}",
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.Y(
            antibody,
            axis=alt.Axis(
                title=f"{antibody_name} escape",
                values=[0, 1, 2, 3, 4, 5, 6],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-0.2, 6])
        ),
        alt.X(
            other_antibody,
            axis=alt.Axis(
                title=f"{other_antibody_name} escape",
                values=[0, 1, 2, 3, 4, 5, 6],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-0.2, 6])
        ),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "effect",
            "escape_372D",
            "escape_89F",
            "escape_121F"
        ],
    ).properties(
        width=250,
        height=250,
    )

    subplots.append(curr_subplot)

combined_plot = alt.hconcat(
    *subplots,
    spacing=5,
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
).configure_view(
    stroke=None
)

combined_plot.save(html_arevirumab_comparisons)

combined_plot

Below, recreate some of the **same** plots, formatted for a manuscript or as static figures.

In [None]:
# Manuscript-format panel: mean effect on cell entry per site against the effective
# number of amino acids at that site in natural sequences. Both axes share the
# same small-font, black-domain styling.
manuscript_axis_style = {
    "domainWidth": 1,
    "domainColor": "black",
    "tickColor": "black",
    "labelFontSize": 8,
    "titleFontSize": 8,
    "labelFontWeight": "normal",
    "titleFontWeight": "normal",
}

entry_effect_y = alt.Y(
    "effect",
    axis=alt.Axis(
        title=["site mean", "effect on cell entry"],
        values=[-4, -3, -2, -1, 0, 1],
        **manuscript_axis_style,
    ),
    scale=alt.Scale(domain=[-4.1, 1.1]),
)
neff_x = alt.X(
    "n_effective",
    axis=alt.Axis(
        title=["effective amino acids", "in natural sequences"],
        values=[1, 2, 3, 4],
        **manuscript_axis_style,
    ),
    scale=alt.Scale(domain=[0.9, 4.1]),
)

natural_vs_effect_neff = (
    alt.Chart(site_level_df)
    .mark_point(filled=True, color="black", opacity=0.15, size=10)
    .encode(
        entry_effect_y,
        neff_x,
        tooltip=["site", "wildtype", "n_effective", "effect"],
    )
    .properties(width=125, height=125)
    .configure_axis(grid=False)
    .configure_view(stroke=None)
)

natural_vs_effect_neff

In [None]:
# Display names for the Arevirumab-3 antibodies, keyed by the ID used in column names
AREVIRUMAB3_DISPLAY_NAMES = {"372D": "37.2D", "89F": "8.9F", "121F": "12.1F"}
arevirumab3_escape_cols = ["escape_372D", "escape_89F", "escape_121F"]

# Manuscript-format version of the pairwise Arevirumab-3 escape comparisons
subplots = []
for i, antibody in enumerate(arevirumab3_escape_cols):
    # Pair each antibody with the next one in the list (wrapping around)
    other_antibody = arevirumab3_escape_cols[(i + 1) % len(arevirumab3_escape_cols)]

    # A dict lookup raises KeyError on an unknown ID, instead of printing "ERROR!"
    # and silently plotting with the raw column ID
    antibody_name = AREVIRUMAB3_DISPLAY_NAMES[antibody.split("_")[1]]
    other_antibody_name = AREVIRUMAB3_DISPLAY_NAMES[other_antibody.split("_")[1]]

    curr_subplot = alt.Chart(
        AA_level_df,
        title=f"{antibody_name} vs {other_antibody_name}",
    ).mark_point(
        filled=True,
        color="black",
        size=15,
        opacity=0.15,
    ).encode(
        alt.Y(
            antibody,
            axis=alt.Axis(
                title=f"{antibody_name} escape",
                values=[0, 1, 2, 3, 4, 5, 6],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-0.2, 6])
        ),
        alt.X(
            other_antibody,
            axis=alt.Axis(
                title=f"{other_antibody_name} escape",
                values=[0, 1, 2, 3, 4, 5, 6],
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
            scale=alt.Scale(domain=[-0.2, 6])
        ),
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "effect",
            "escape_372D",
            "escape_89F",
            "escape_121F"
        ],
    ).properties(
        width=110,
        height=110,
    )

    subplots.append(curr_subplot)

combined_plot = alt.hconcat(
    *subplots,
    spacing=5,
).configure_axis(
    grid=False,
    labelFontSize=8,
    titleFontSize=8,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=8,
    fontWeight="bold"
).configure_view(
    stroke=None
)

combined_plot

In [None]:
# Mark sites to label and color by antibody class with largest escape
# Mark sites to label, and color each by the epitope class of the Arevirumab-3
# antibody with the largest escape at that site.
cocktail_escape_cols = ["escape_121F", "escape_372D", "escape_89F"]
# skipna=False: a site missing escape for any cocktail antibody gets a NaN maximum and
# is labeled "error" below, rather than being assigned from a partial maximum
site_level_df["cocktail_site_max"] = site_level_df[cocktail_escape_cols].max(skipna=False, axis=1)

# Only cocktail antibodies are compared against the maximum. The maximum excludes
# 25.10C, 37.7H and 25.6A, so comparing those to it could only match by coincidental ties.
# np.select takes the first true condition, so ties resolve in the listed order.
site_max = site_level_df["cocktail_site_max"]
site_level_df["site_epitope"] = np.select(
    [
        site_max == 0,
        site_level_df["escape_372D"] == site_max,
        site_level_df["escape_89F"] == site_max,
        site_level_df["escape_121F"] == site_max,
    ],
    ["none", "GPC-B", "GPC-C", "GP1-A"],
    default="error",
)

# Label only the top escape sites among the most naturally variable sites
cutoff = site_level_df["n_effective"].quantile(0.75)
print(f"{cutoff} is the 75th percentile for effective amino acids")
top_escape_sites = (
    site_level_df.query("n_effective >= @cutoff")
    .nlargest(10, "cocktail_avg_site_escape", keep="all")["site"]
    .tolist()
)
print(
    "The following sites are the top 10 by average cocktail escape among sites in the "
    f"top 25% for effective amino acids: {top_escape_sites}"
)
site_level_df["site_label"] = site_level_df["site_epitope"].where(
    site_level_df["site"].isin(top_escape_sites), "none"
)

# Replace zero frequency with a small value so it can be drawn on a log scale.
# These values sit just below the lowest tick, which the static figures label "0".
# NOTE: this overwrites the columns in place, so earlier cells re-run after this
# one would see the substituted values.
site_level_df["mutation_frequency"] = site_level_df["mutation_frequency"].replace(0, 0.0003)
AA_level_df["mutation_frequency"] = AA_level_df["mutation_frequency"].replace(0, 0.0002)

In [None]:
# Set figure size and subplots
fig, axes = plt.subplots(
    1, 
    6, 
    figsize=(7.7, 1), 
)

# Adjust spacing of subplots
fig.subplots_adjust(
    bottom=0, 
    top=1, 
    wspace=0.2, 
    hspace=0,
)

for index,antibody in enumerate(["2510C", "121F", "377H", "256A", "372D", "89F"]):
    chart = sns.scatterplot(
        data=AA_level_df,
        x="mutation_frequency",
        y="escape_"+antibody,
        edgecolor=None,
        linewidth=0.5,
        s=15,
        alpha=0.15,
        ax=axes[index],
    )
    chart.set_xscale("log")
    chart.xaxis.set_minor_locator(mticker.NullLocator()) # no minor ticks
    xticks = [0.0002, 0.001, 0.01, 0.1]
    chart.set_xticks(xticks)
    x_labels=["0", " $10^{-3}$", " $10^{-2}$", " $10^{-1}$"]
    chart.set_xticklabels(labels=x_labels, horizontalalignment="center", fontsize=6)

    chart.set_xlim(0.000125,1.1)
    chart.set(xlabel=None)

    if index == 0:
        chart.set_ylabel(ylabel="escape", fontsize=8)
        chart.set_ylim(-0.25, 6.25)
        yticks = [0, 1, 2, 3, 4, 5, 6]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3", "4", "5", "6"], fontsize=6)
        chart.set_title(
            "25.10C", 
            fontsize=8,
            color="#44AA99",
            horizontalalignment="center",
        )
    if index == 1:
        chart.set(ylabel=None)
        chart.set_ylim(-0.25, 6.25)
        yticks = [0, 1, 2, 3, 4, 5, 6]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3", "4", "5", "6"], fontsize=6)
        chart.set_title(
            "12.1F", 
            fontsize=8,
            color="#999933",
            horizontalalignment="center",
        )
    if index == 2:
        chart.set(ylabel=None)
        chart.set_ylim(-0.29, 7.29)
        yticks = [0, 1, 2, 3, 4, 5, 6, 7]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3", "4", "5", "6", "7"], fontsize=6)
        chart.set_title(
            "37.7H", 
            fontsize=8,
            color="#AA4499",
            horizontalalignment="center",
        )
    if index == 3:
        chart.set(ylabel=None)
        chart.set_ylim(-0.29, 7.29)
        yticks = [0, 1, 2, 3, 4, 5, 6, 7]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3", "4", "5", "6", "7"], fontsize=6)
        chart.set_title(
            "25.6A", 
            fontsize=8,
            color="#AA4499",
            horizontalalignment="center",
        )
    if index == 4:
        chart.set(ylabel=None)
        chart.set_ylim(-0.125, 3.125)
        yticks = [0, 1, 2, 3]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3"], fontsize=6)
        chart.set_title(
            "37.2D", 
            fontsize=8,
            color="#AA4499",
            horizontalalignment="center",
        )
    if index == 5:
        chart.set(ylabel=None)
        chart.set_ylim(-0.25, 6.25)
        yticks = [0, 1, 2, 3, 4, 5, 6]
        chart.set_yticks(yticks)
        chart.set_yticklabels(labels=["0", "1", "2", "3", "4", "5", "6"], fontsize=6)
        chart.set_title(
            "8.9F", 
            fontsize=8,
            color="#117733",
            horizontalalignment="center",
        )

    # Change all spines
    for axis in ["top", "bottom", "left", "right"]:
        chart.spines[axis].set_linewidth(1)
    chart.tick_params(axis="both", length=3, width=1)
    chart.grid(False)
    sns.despine()

    # Label points on each scatter plot
    for i in range(0, AA_level_df.shape[0]):
        x_pos = AA_level_df.at[i, "mutation_frequency"]
        y_pos = AA_level_df.at[i, "escape_"+antibody]
        name = AA_level_df.at[i, "wildtype"] + str(AA_level_df.at[i, "site"]) + AA_level_df.at[i, "mutant"]
        if antibody == "2510C":
            if name == "E228D":
                chart.text(
                    x_pos,
                    y_pos+0.3,
                    f"{name}",
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                )
        if antibody == "121F":
            if name == "N89D":
                chart.text(
                    x_pos,
                    y_pos+0.3,
                    f"*{name}",
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                    style="italic",
                )
        if antibody == "377H":
            if name == "H398K":
                chart.text(
                    x_pos,
                    y_pos+0.3,
                    f"{name}",
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                )
            if name == "D401E":
                chart.text(
                    x_pos,
                    y_pos+0.3,
                    f"{name}",
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                )
        if antibody == "372D":
            if name == "H398K":
                chart.text(
                    x_pos,
                    y_pos+0.15,
                    f"{name}",
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                )
        if antibody == "89F":
            if name == "K126N":
                # Annotate text with arrow
                chart.annotate(
                    text=f"{name}",
                    xy=(x_pos, y_pos),
                    xytext=(x_pos+0.0075, y_pos+0.3),
                    arrowprops=dict(
                        arrowstyle="-|>",
                        lw=0.5,
                        relpos=(0.05,0.5),
                        color="#000000",
                    ),
                    bbox=dict(pad=-5, facecolor="none", edgecolor="none"),
                    fontsize=6,
                    horizontalalignment="left",
                    color="#000000",
                )

# Common X axis labels
fig.text(0.5, -0.3, "frequency of amino-acid mutation in natural sequences", ha="center", rotation="horizontal", fontsize=8)

# Save fig
fig.savefig(natural_escape)

In [None]:
# Set figure size and subplots
fig, ax = plt.subplots(figsize=(4, 2))

chart = sns.scatterplot(
    data=site_level_df,
    y="n_effective",
    x="cocktail_avg_site_escape",
    palette={
        "none" : "#00000026", 
        "GPC-A" : "#44AA99CC",
        "GPC-B" : "#AA4499CC",
        "GPC-C" : "#117733CC",
        "GP1-A" : "#999933CC",
    },
    hue="site_label",
    edgecolor=None,
    linewidth=0.5,
    s=15,
    ax=ax,
    legend=False,
)

chart.set_ylim(0.9, 4.1)
yticks = [1, 2, 3, 4]
chart.set_yticks(yticks)
y_labels=["1", "2", "3", "4"]
chart.set_yticklabels(labels=y_labels, fontsize=8)

chart.set_xlim(-0.5, 20.05)
xticks = [0, 5, 10, 15, 20]
chart.set_xticks(xticks)
chart.set_xticklabels(labels=["0", "5", "10", "15", "20"], fontsize=8)

# Change all spines
for axis in ["top", "bottom", "left", "right"]:
    chart.spines[axis].set_linewidth(1)
chart.tick_params(axis="both", length=3, width=1)
chart.grid(False)
sns.despine()


chart.set_ylabel("effective amino acids\nin natural sequences", fontsize=8)
chart.set_xlabel("site escape averaged across\nArevirumab-3 (12.1F, 8.9F, and 37.2D)", fontsize=8)

# Label points on each scatter plot
for i in range(0, site_level_df.shape[0]):
    y_pos = site_level_df.at[i, "n_effective"]
    x_pos = site_level_df.at[i, "cocktail_avg_site_escape"]
    label = site_level_df.at[i, "site_label"]
    name = site_level_df.at[i, "site"]
    color = None
    if label == "GPC-B":
        color="#AA4499CC"
    if label == "GP1-A":
        color="#999933CC"
    if label == "GPC-A":
        color="#44AA99CC"
    if label == "GPC-C":
        color="#117733CC"

    if label != "none":
        if name in [152]:
            chart.text(
                x_pos,
                y_pos+0.1,
                f"{name}",
                fontsize=6,
                horizontalalignment="right",
                color=color,
            )
        elif name in [126]:
            chart.text(
                x_pos,
                y_pos+0.1,
                f"{name} ",
                fontsize=6,
                horizontalalignment="right",
                color=color,
            )
        elif name in [116]:
            chart.text(
                x_pos+0.25,
                y_pos,
                f"{name}",
                fontsize=6,
                horizontalalignment="left",
                color=color,
            )
        elif name in [111]:
            chart.text(
                x_pos,
                y_pos+0.1,
                f"  {name}",
                fontsize=6,
                horizontalalignment="center",
                color=color,
            )
        elif name in [248]:
            chart.text(
                x_pos+0.15,
                y_pos+0.1,
                f"{name}",
                fontsize=6,
                horizontalalignment="left",
                color=color,
            )
        elif name in [127]:
            chart.text(
                x_pos,
                y_pos+0.1,
                f" {name}",
                fontsize=6,
                horizontalalignment="center",
                color=color,
            )
        else:
            chart.text(
                x_pos,
                y_pos+0.1,
                f"{name}",
                fontsize=6,
                horizontalalignment="left",
                color=color,
            )

# Save fig
fig.savefig(total_natural_escape)