### This script analyzes filtered mAb escape data

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
binding_data = None
HENV103_filter = None
HENV117_filter = None
HENV26_filter = None
HENV32_filter = None
m102_filter = None
nAH1_filter = None

altair_config = None
nipah_config = None
escape_bubble_plot = None
bubble_1_mut_plot = None
# Output path for the shared-escape heatmap produced by the overlap-escape cell.
overlap_escape_plot = None
mab_line_escape_plot = None
aggregate_mab_and_binding = None
aggregate_mab_and_niv_polymorphism = None
binding_vs_escape = None

mab_plot_top = None
mab_plot_all = None

In [None]:
# Report how the notebook was launched: papermill injects binding_data, a manual run leaves it None.
print("this is being run manually" if binding_data is None else "papermill!")

In [None]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import Bio.SeqIO
import yaml
import matplotlib

matplotlib.rcParams["svg.fonttype"] = "none"

from Bio import PDB
import dmslogo
from dmslogo.colorschemes import CBPALETTE
from dmslogo.colorschemes import ValueToColorMap

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root on the cluster; all relative paths below assume this is the cwd.
_project_dir = "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"

# os.getcwd() never returns a trailing slash, so compare normalized absolute
# paths instead of raw strings (a raw comparison could never match).
if os.path.realpath(os.getcwd()) == os.path.realpath(_project_dir):
    print("Already in correct directory")
else:
    os.chdir(_project_dir)
    print("Setup in correct directory")

### For running interactively

In [None]:
if binding_vs_escape is None:
    # Interactive run: fill in the default input paths (relative to the project root).
    altair_config, nipah_config = (
        "data/custom_analyses_data/theme.py",
        "nipah_config.yaml",
    )

    binding_data = "results/filtered_data/E2_binding_filtered.csv"

    # Per-antibody filtered escape tables.
    (
        HENV103_filter,
        HENV117_filter,
        HENV26_filter,
        HENV32_filter,
        m102_filter,
        nAH1_filter,
    ) = (
        "results/filtered_data/HENV103_escape_filtered.csv",
        "results/filtered_data/HENV117_escape_filtered.csv",
        "results/filtered_data/HENV26_escape_filtered.csv",
        "results/filtered_data/HENV32_escape_filtered.csv",
        "results/filtered_data/m102_escape_filtered.csv",
        "results/filtered_data/nAH1_escape_filtered.csv",
    )
    # Output figure paths are deliberately not set here.

In [None]:
# Optional Altair theme. NOTE: this executes arbitrary Python from `altair_config`
# in the notebook namespace, so only point it at trusted files.
if altair_config:
    with open(altair_config, "r") as theme_file:
        exec(theme_file.read())

# Shared analysis settings (cutoffs, plot sizes, ...).
with open(nipah_config) as config_file:
    config = yaml.safe_load(config_file)

# Make logo plots

### Filtering parameters

In [None]:
# Mutants with low cell-entry scores (CHO-EFNB3 functional selection), seen often
# enough to trust, kept for masking later in the script. Site 603, deletions ("-")
# and stops ("*") are excluded.
func_scores_E3 = pd.read_csv(
    "../Nipah_Malaysia_RBP_DMS/results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"
)
func_scores_E3_low_effect = func_scores_E3.loc[
    lambda scores: (scores["effect"] < config["min_func_effect_for_ab"])
    & (scores["times_seen"] > config["func_times_seen_cutoff"])
    & (scores["site"] != 603)
    & ~scores["mutant"].isin(["-", "*"])
]
display(func_scores_E3_low_effect)

### Read in filtered antibody escape files and combine.

In [None]:
# One filtered escape table per antibody.
HENV103, HENV117, HENV26, HENV32, m102, nAH1 = (
    pd.read_csv(csv_path)
    for csv_path in (
        HENV103_filter,
        HENV117_filter,
        HENV26_filter,
        HENV32_filter,
        m102_filter,
        nAH1_filter,
    )
)

# Stack every antibody and keep only the columns used downstream.
combined_df = pd.concat([HENV103, HENV117, HENV26, HENV32, m102, nAH1])[
    [
        "site",
        "wildtype",
        "mutant",
        "mutation",
        "effect",
        "escape_median",
        "escape_std",
        "times_seen_ab",
        "show_site",
        "ab",
    ]
]
display(combined_df)

# Top sites only, further restricted to mutations at or above the escape cutoff.
filtered_df = combined_df.query("show_site == True").loc[
    lambda top: top["escape_median"] >= config["min_escape_cutoff"]
]
display(filtered_df)

In [None]:
def identify_escape_sites(df, ab):
    """Return the distinct ``site`` values recorded for antibody ``ab``.

    Sites keep their order of first appearance in ``df``. An antibody with no
    rows yields an empty list.
    """
    ab_sites = df.loc[df["ab"] == ab, "site"]
    return list(ab_sites.unique())


# Antibodies whose top escape sites are collected below.
# NOTE(review): this rebinds the built-in `abs` for the rest of the notebook;
# rename once no later cell relies on this name.
abs = [
    "HENV-26",
    "HENV-103",
    "HENV-32",
    "HENV-117",
    "m102.4",
    "nAH1.3",
]

# antibody name -> list of its top escape sites (needed later for heatmaps)
sites_dict = dict()
for ab in abs:
    sites_dict.update({ab: identify_escape_sites(filtered_df, ab)})

display(sites_dict)

### Plot bubble chart showing mAb escape for individual mutants by functional score (for either E2 or E3)

In [None]:
order_ab = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]


def generate_chart(df):
    """Build a jittered bubble chart of top escape mutations per antibody.

    Each point is one mutation. The x axis is the antibody (ordered by
    ``order_ab``), the y axis is the mutation's cell-entry ``effect``, and point
    size is ``escape_median``. Hovering a point highlights every point at the
    same ``site``.

    NOTE(review): the size legend reads "Mean Escape" but encodes ``escape_median``.
    """
    hover_site = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )

    x_enc = alt.X(
        "ab:O",
        sort=order_ab,
        title="Antibody",
        axis=alt.Axis(labelAngle=-90, grid=False),
    )
    y_enc = alt.Y(
        "effect:Q",
        title="Cell Entry of Top Escape",
        axis=alt.Axis(grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]),
    )
    size_enc = alt.Size(
        "escape_median", legend=alt.Legend(title="Mean Escape By Mutation")
    )
    hover_fields = [
        "site",
        "wildtype",
        "mutant",
        "ab",
        "effect",
        "escape_median",
        "escape_std",
    ]

    base = alt.Chart(
        df,
        title=alt.Title(
            "Top Antibody Escape Mutations",
            subtitle="Hover over points to see escape at same site",
        ),
    )
    points = base.mark_point(stroke="black").encode(
        x=x_enc,
        y=y_enc,
        size=size_enc,
        xOffset="random:Q",
        tooltip=hover_fields,
        color=alt.Color("ab").legend(None),
        opacity=alt.condition(hover_site, alt.value(1), alt.value(0.2)),
        strokeWidth=alt.condition(hover_site, alt.value(2), alt.value(0)),
    )
    # Roughly Gaussian x-jitter (a scaled Box-Muller transform) within each column.
    jittered = points.transform_calculate(
        random="sqrt(-1*log(random()))*cos(2*PI*random())"
    )
    sized = jittered.properties(
        width=config["bubble_width"], height=config["bubble_height"]
    )
    return sized.add_params(hover_site)


escape_bubble = generate_chart(filtered_df)
escape_bubble.display()
# Save only when an output path for *this* figure was provided (previously the
# check looked at an unrelated parameter, mab_line_escape_plot).
if escape_bubble_plot is not None:
    escape_bubble.save(escape_bubble_plot)

### Now summarize by number of mutations between wildtype and mutant codons

In [None]:
# Wildtype NiV-M nucleotide sequence. It differs from the library's 'wt'
# sequence, because that one was codon optimized.
niv_m_wt = str(
    Bio.SeqIO.read("data/custom_analyses_data/alignments/wild_type_seq.fasta", "fasta").seq
)

# Standard genetic code: DNA codon -> one-letter amino acid ("*" = stop codon).
# NOTE: insertion order matters. find_closest_codon() scans this dict in order
# and keeps the first codon with the fewest nucleotide differences, so
# reordering entries can change which codon is reported when several tie.
codon_table = {
    "ATA": "I",
    "ATC": "I",
    "ATT": "I",
    "ATG": "M",
    "ACA": "T",
    "ACC": "T",
    "ACG": "T",
    "ACT": "T",
    "AAC": "N",
    "AAT": "N",
    "AAA": "K",
    "AAG": "K",
    "AGC": "S",
    "AGT": "S",
    "AGA": "R",
    "AGG": "R",
    "CTA": "L",
    "CTC": "L",
    "CTG": "L",
    "CTT": "L",
    "CCA": "P",
    "CCC": "P",
    "CCG": "P",
    "CCT": "P",
    "CAC": "H",
    "CAT": "H",
    "CAA": "Q",
    "CAG": "Q",
    "CGA": "R",
    "CGC": "R",
    "CGG": "R",
    "CGT": "R",
    "GTA": "V",
    "GTC": "V",
    "GTG": "V",
    "GTT": "V",
    "GCA": "A",
    "GCC": "A",
    "GCG": "A",
    "GCT": "A",
    "GAC": "D",
    "GAT": "D",
    "GAA": "E",
    "GAG": "E",
    "GGA": "G",
    "GGC": "G",
    "GGG": "G",
    "GGT": "G",
    "TCA": "S",
    "TCC": "S",
    "TCG": "S",
    "TCT": "S",
    "TTC": "F",
    "TTT": "F",
    "TTA": "L",
    "TTG": "L",
    "TAC": "Y",
    "TAT": "Y",
    "TAA": "*",
    "TAG": "*",
    "TGC": "C",
    "TGT": "C",
    "TGA": "*",
    "TGG": "W",
}


def find_closest_codon(wt_codon, mutant_aa, table=None):
    """Find the codon for ``mutant_aa`` reachable with the fewest nucleotide changes.

    Args:
        wt_codon: wildtype codon (3-letter nucleotide string).
        mutant_aa: one-letter amino-acid code to reach ("*" for stop).
        table: codon -> amino-acid mapping. Defaults to the module-level
            ``codon_table``.

    Returns:
        ``(closest_codon, n_mutations)``. Ties go to the codon that appears
        first in ``table``. If no codon in ``table`` encodes ``mutant_aa``
        (e.g. a deletion "-"), returns ``(None, 3)``.
    """
    if table is None:
        table = codon_table

    candidates = [codon for codon, aa in table.items() if aa == mutant_aa]
    if not candidates:
        # Nothing encodes this residue; keep the (None, 3) sentinel.
        return None, 3

    def n_differences(codon):
        return sum(wt_nt != mut_nt for wt_nt, mut_nt in zip(wt_codon, codon))

    # min() returns the first minimal candidate, so ties keep table order. Unlike
    # a search seeded at 3 with a strict '<', this also reports codons that are
    # a full 3 nucleotides away instead of returning None.
    closest = min(candidates, key=n_differences)
    return closest, n_differences(closest)


def extract_codon(site):
    """Return the wildtype codon (3 nt of ``niv_m_wt``) for 1-based amino-acid ``site``."""
    start = 3 * (site - 1)
    end = start + 3
    return niv_m_wt[start:end]


def extract_codon_niv_b(site):
    """Return the 3-nt codon for 1-based amino-acid ``site``.

    NOTE(review): despite its name, this slices ``niv_m_wt`` exactly like
    ``extract_codon``. Confirm whether a NiV-B sequence was intended.
    """
    first_nt = (site - 1) * 3
    return niv_m_wt[first_nt : first_nt + 3]


def apply_codon_to_df(df, extract_func):
    """Annotate each mutation with its wildtype codon and closest mutant codon.

    Args:
        df: table with ``site`` and ``mutant`` columns.
        extract_func: maps a site to its wildtype codon string.

    Returns:
        A copy of ``df`` with added ``wt_codon``, ``closest_mutant_codon`` and
        ``min_mutations`` columns. The input frame is left untouched, which
        avoids pandas SettingWithCopyWarning when ``df`` is a filtered slice.
    """
    df = df.copy()
    df["wt_codon"] = df["site"].apply(extract_func)
    # Compute each (codon, n_mutations) pair once instead of once per column.
    closest = [
        find_closest_codon(wt_codon, mutant)
        for wt_codon, mutant in zip(df["wt_codon"], df["mutant"])
    ]
    df["closest_mutant_codon"] = [codon for codon, _ in closest]
    df["min_mutations"] = [n_mut for _, n_mut in closest]
    return df


# Add wildtype codon, closest mutant codon and nucleotide-change count to both tables.
combined_df, filtered_df = (
    apply_codon_to_df(frame, extract_codon) for frame in (combined_df, filtered_df)
)

In [None]:
def generate_chart_all(df):
    """Build an interactive bubble chart of escape mutations with filter widgets.

    The layout matches the top-escape bubble chart (antibody on x, cell-entry
    ``effect`` on y, size = ``escape_median``) but uses filled points. It adds
    a radio button that selects the ``min_mutations`` value to show, and a
    slider that sets the ``escape_median`` threshold (points must be strictly
    above it).
    """
    site_hover = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    n_mut_param = alt.param(
        name="MutationSelector",
        value=1,
        bind=alt.binding_radio(
            options=[1, 2, 3], labels=["1", "2", "3"], name="Min Mutations:"
        ),
    )
    escape_param = alt.param(
        name="SelectorName",
        value=0.2,
        bind=alt.binding_range(min=0.2, max=1.6, step=0.1, name="median_escape"),
    )

    encoding = dict(
        x=alt.X(
            "ab:O",
            sort=order_ab,
            title="Antibody",
            axis=alt.Axis(labelAngle=-90, grid=False),
        ),
        y=alt.Y(
            "effect:Q",
            title="Cell Entry of Top Escape",
            axis=alt.Axis(grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]),
        ),
        size=alt.Size(
            "escape_median", legend=alt.Legend(title="Mean Escape By Mutation")
        ),
        xOffset="random:Q",
        tooltip=[
            "site",
            "wildtype",
            "mutant",
            "ab",
            "effect",
            "escape_median",
            "escape_std",
        ],
        color=alt.Color("ab").legend(None),
        opacity=alt.condition(site_hover, alt.value(1), alt.value(0.4)),
        strokeWidth=alt.condition(site_hover, alt.value(2), alt.value(0)),
    )

    chart = alt.Chart(
        df,
        title=alt.Title(
            "Antibody Escape Mutations",
            subtitle="Hover over points to see escape at same site",
        ),
    )
    chart = chart.mark_point(filled=True, stroke="black").encode(**encoding)
    # Roughly Gaussian x-jitter (scaled Box-Muller transform).
    chart = chart.transform_calculate(
        random="sqrt(-1*log(random()))*cos(2*PI*random())"
    )
    chart = chart.properties(width=200, height=250)
    chart = chart.add_params(site_hover, n_mut_param, escape_param)
    return chart.transform_filter(
        (alt.datum.min_mutations == n_mut_param)
        & (alt.datum.escape_median > escape_param)
    )


# Interactive chart restricted to the top-site mutations in filtered_df.
all_escape = generate_chart_all(filtered_df)
all_escape.display()

In [None]:
# Make combined figure: static top-escape bubbles beside the interactive version.
combined_bubble_plots = alt.hconcat(escape_bubble, all_escape)
combined_bubble_plots.display()

In [None]:
def plot_escape_and_mutations_away(df):
    """Build a bubble chart of top escape mutations, filterable by nucleotide changes.

    A radio button keeps only mutations whose ``min_mutations`` (nucleotide
    changes from the wildtype codon) equals the selected value. Hovering
    highlights every point at the same ``site``.
    """
    site_hover = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    n_mut_param = alt.param(
        name="MutationSelector",
        value=1,
        bind=alt.binding_radio(options=[1, 2, 3], name="Min Mutations:"),
    )

    x_enc = alt.X(
        "ab:O",
        sort=order_ab,
        title=None,
        axis=alt.Axis(labelAngle=-90, grid=False),
    )
    y_enc = alt.Y(
        "effect:Q",
        title="Cell Entry of Escape Mutants",
        axis=alt.Axis(grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]),
    )

    base = alt.Chart(
        df,
        title=alt.Title(
            "Top Antibody Escape Mutations",
            subtitle="By # of nucleotide mutations away",
        ),
    )
    bubbles = base.mark_point(filled=True, stroke="black").encode(
        x=x_enc,
        y=y_enc,
        size=alt.Size("escape_median", legend=alt.Legend(title="Escape of Mutant")),
        xOffset="random:Q",
        tooltip=["ab", "effect", "escape_median", "site", "mutant"],
        opacity=alt.condition(site_hover, alt.value(1), alt.value(0.4)),
        strokeWidth=alt.condition(site_hover, alt.value(2), alt.value(0)),
        color=alt.Color("ab:N").legend(None),
    )
    # Box-Muller transform: standard-normal x-jitter.
    bubbles = bubbles.transform_calculate(
        random="sqrt(-2*log(random()))*cos(2*PI*random())"
    )
    bubbles = bubbles.properties(
        width=config["bubble_width"], height=config["bubble_height"]
    ).add_params(site_hover, n_mut_param)
    return bubbles.transform_filter((alt.datum.min_mutations == n_mut_param))


bubble_plot_1_mut_away = plot_escape_and_mutations_away(filtered_df)
bubble_plot_1_mut_away.display()
# Save only when an output path for *this* figure was provided (previously the
# check looked at an unrelated parameter, mab_line_escape_plot).
if bubble_1_mut_plot is not None:
    bubble_plot_1_mut_away.save(bubble_1_mut_plot)

In [None]:
def find_overlapping_escape(df):
    """Build a heatmap of escape mutations shared by two or more antibodies.

    Keeps (site, mutant) pairs that appear for at least two distinct ``ab``
    values. It draws one cell per (mutation, antibody) colored by
    ``escape_median``. A slider hides mutations whose ``effect`` is below
    the chosen value, and a radio button selects ``min_mutations``.

    Args:
        df: escape table with ``site``, ``mutant``, ``mutation``, ``ab``,
            ``effect``, ``escape_median`` and ``min_mutations`` columns.
    """
    slider = alt.binding_range(
        min=config["min_func_effect_for_ab"], max=0, step=0.25, name="effect"
    )
    selector = alt.param(name="SelectorName", value=-4, bind=slider)

    radio = alt.binding_radio(options=[1, 2, 3], name="Min Mutations:")
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    # Number of distinct antibodies escaped by each (site, mutant) pair.
    n_abs = df.groupby(["site", "mutant"])["ab"].nunique().reset_index()
    shared = n_abs[n_abs["ab"] >= 2]

    # Merge back to recover the full rows for the shared mutations.
    df_result = pd.merge(df, shared[["site", "mutant"]], on=["site", "mutant"])
    # Numeric position inside the mutation label (e.g. "D555N" -> 555) so the
    # x axis sorts by sequence position. Raw string: "\d" is a regex escape.
    df_result["mutation_number"] = (
        df_result["mutation"].str.extract(r"(\d+)", expand=False).astype(int)
    )

    base = (
        alt.Chart(df_result, title=alt.Title("Shared antibody escape mutations"))
        .mark_rect()
        .encode(
            x=alt.X(
                "mutation:O",
                title="Site",
                sort=alt.EncodingSortField(field="mutation_number"),
                axis=alt.Axis(labelAngle=-90, grid=False),
            ),
            # Rows are antibodies (custom order), so title the axis accordingly.
            y=alt.Y(
                "ab:O", title="Antibody", sort=order_ab, axis=alt.Axis(grid=False)
            ),
            color="escape_median",
            tooltip=[
                "site",
                "wildtype",
                "mutant",
                "escape_median",
                "min_mutations",
            ],
        )
        .properties(height=200, width=400)
        .add_params(selector, mutation_selector)
        .transform_filter(
            (alt.datum.effect >= selector)
            & (alt.datum.min_mutations == mutation_selector)
        )
    )
    return base


overlap_escape = find_overlapping_escape(filtered_df)
overlap_escape.display()
# `overlap_escape_plot` may not be declared as a notebook parameter, so look it
# up defensively; a bare reference would raise NameError when it is missing.
_overlap_escape_path = globals().get("overlap_escape_plot")
if _overlap_escape_path is not None:
    overlap_escape.save(_overlap_escape_path)

### Line plots of escape

In [None]:
def plot_line_escape(df):
    """Build stacked per-antibody line plots of escape summed over mutants at each site.

    There is one panel per antibody, colored by epitope group. Only the bottom
    panel shows x-axis labels and the "Site" title. Hovering a point enlarges
    it and turns it black, along with the other points at the same site.
    """
    hover_site = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=0
    )
    # Total escape_median at each site, per antibody.
    per_site = df.groupby(["site", "ab"])["escape_median"].sum().reset_index()

    # Panel order (top to bottom) and epitope color for each antibody.
    epitope_colors = {
        "m102.4": "#1f4e79",
        "HENV-26": "#1f4e79",
        "HENV-117": "#1f4e79",
        "HENV-103": "#ff7f0e",
        "HENV-32": "#ff7f0e",
        "nAH1.3": "#2ca02c",
    }
    ab_order = list(epitope_colors)

    panels = []
    for position, ab in enumerate(ab_order):
        ab_df = per_site[per_site["ab"] == ab]
        color = epitope_colors[ab]
        bottom = position == len(ab_order) - 1

        # Only the bottom panel carries tick labels and the axis title.
        x_axis = alt.Axis(
            values=[100, 200, 300, 400, 500, 600],
            tickCount=6,
            labelAngle=-90,
            grid=True,
            labelExpr="datum.value % 100 === 0 ? datum.value : ''",
            title="Site" if bottom else None,
            labels=bottom,
        )
        line = (
            alt.Chart(ab_df)
            .mark_line(size=1, color=color)
            .encode(
                x=alt.X("site:O", axis=x_axis),
                y=alt.Y("escape_median", title=f"{ab}", axis=alt.Axis(grid=True)),
            )
            .properties(
                width=config["large_line_width"], height=config["large_line_height"]
            )
        )
        dots = line.mark_point(color="black", size=10, filled=True).encode(
            x=alt.X("site:O", axis=x_axis),
            y=alt.Y("escape_median", title=f"{ab}", axis=alt.Axis(grid=True)),
            size=alt.condition(hover_site, alt.value(100), alt.value(15)),
            color=alt.condition(hover_site, alt.value("black"), alt.value(color)),
            tooltip=["site", "escape_median"],
        ).add_params(hover_site)
        panels.append(line + dots)

    # Tight vertical stacking with a shared x (site) scale.
    return (
        alt.vconcat(*panels, spacing=1)
        .resolve_scale(y="independent", x="shared", color="independent")
        .properties(
            title=alt.Title(
                "Summed Antibody Escape by Site", subtitle="Colored by epitope"
            )
        )
    )


# Summed escape uses every mutation in combined_df, not just the top sites.
tmp_line = plot_line_escape(combined_df)
tmp_line.display()

if mab_line_escape_plot is not None:
    tmp_line.save(mab_line_escape_plot)

### Now calculate atomic distances between escape sites and closest amino acid in heavy and light chains

In [None]:
def calculate_min_distances(pdb_path, source_chain_id, target_chain_ids, name):
    """For each residue of one chain, find the closest residue in other chains.

    Args:
        pdb_path: path to a PDB file. Only model 0 is used.
        source_chain_id: chain whose residues are reported (one row each).
        target_chain_ids: chains searched for the closest residue.
        name: label written to the ``ab`` column.

    Returns:
        DataFrame with one row per source residue (HOH/WAT/IPA/NAG skipped):
        ``wildtype`` (residue name), ``site`` (residue number), ``chain``,
        ``residue`` and ``residue_name`` of the closest target residue,
        ``distance`` (minimum atom-atom distance in the PDB's coordinate units),
        ``residues_within_4`` (target residues with any atom pair closer
        than 4) and ``ab``. If the target chains hold no usable atoms,
        ``chain``/``residue``/``residue_name`` are None and ``distance`` is
        inf instead of raising AttributeError.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure_id", pdb_path)
    model = structure[0]

    source_chain = model[source_chain_id]
    target_chains = [model[chain_id] for chain_id in target_chain_ids]

    # NOTE(review): NAG is skipped on the source chain only; confirm that
    # glycans on the target chains should count as contacts.
    skip_source = {"HOH", "WAT", "IPA", "NAG"}
    skip_target = {"HOH", "WAT", "IPA"}

    data = []

    for residueA in source_chain:
        if residueA.resname in skip_source:
            continue

        min_distance = float("inf")
        closest_residueB = None
        closest_chain_id = None
        residues_within_4 = 0

        for target_chain in target_chains:
            target_chain_id = target_chain.get_id()
            for residueB in target_chain:
                if residueB.resname in skip_target:
                    continue

                # Exhaustive atom-pair scan; the strict '<' keeps the first pair
                # found at the minimum distance.
                is_within_4 = False
                for atomA in residueA:
                    for atomB in residueB:
                        distance = atomA - atomB
                        if distance < min_distance:
                            min_distance = distance
                            closest_residueB = residueB
                            closest_chain_id = target_chain_id
                        if distance < 4:
                            is_within_4 = True
                if is_within_4:
                    residues_within_4 += 1

        found = closest_residueB is not None
        data.append(
            {
                "wildtype": residueA.resname,
                "site": residueA.id[1],
                "chain": closest_chain_id,
                "residue": closest_residueB.id[1] if found else None,
                "residue_name": closest_residueB.resname if found else None,
                "distance": min_distance,
                "residues_within_4": residues_within_4,
                "ab": name,
            }
        )

    return pd.DataFrame(data)


def check_file(input_path, source_chain, target_chain, name, output_path):
    """Load cached residue distances, or compute and cache them.

    Args:
        input_path: PDB file passed to ``calculate_min_distances``.
        source_chain: chain ID whose residues are reported.
        target_chain: list of chain IDs searched for the closest residue.
        name: antibody label stored in the ``ab`` column.
        output_path: CSV cache location. It is read if present, otherwise
            written after computing.

    Returns:
        The distances DataFrame.
    """
    if os.path.exists(output_path):
        print("File already exists,loading from disk")
        return pd.read_csv(output_path)

    print(f"File {name} does not exist, running calculation")
    output_df = calculate_min_distances(input_path, source_chain, target_chain, name)
    print(f"done calculating for {output_path}")
    # Ensure the cache directory exists so an expensive computation is not lost
    # to a missing folder at write time.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_df.to_csv(output_path, index=False)
    return output_df


pdb_path_26 = "data/custom_analyses_data/crystal_structures/6vy5.pdb"
source_chain_26 = "A"
target_chains_26 = ["H", "L"]
output_path_26 = "results/distances/df_HENV26_atomic_distances.csv"

pdb_path_32 = "data/custom_analyses_data/crystal_structures/6vy4.pdb"
source_chain_32 = "A"
target_chains_32 = ["H", "L"]
output_path_32 = "results/distances/df_HENV32_atomic_distances.csv"

pdb_path_nah = "data/custom_analyses_data/crystal_structures/7txz.pdb"
source_chain_nah = "A"
target_chains_nah = ["F", "E"]
output_path_nah = "results/distances/df_nAH_atomic_distances.csv"

pdb_path_m102 = "data/custom_analyses_data/crystal_structures/6cmg.pdb"
source_chain_m102 = "A"
target_chains_m102 = ["B", "C"]
output_path_m102 = "results/distances/df_m102_atomic_distances.csv"


df_HENV26 = check_file(
    pdb_path_26, source_chain_26, target_chains_26, "HENV-26", output_path_26
)
df_HENV32 = check_file(
    pdb_path_32, source_chain_32, target_chains_32, "HENV-32", output_path_32
)
df_nah = check_file(
    pdb_path_nah, source_chain_nah, target_chains_nah, "nAH1.3", output_path_nah
)
# Rename chains so heavy/light are always "H"/"L". Assign the result back:
# `df["col"].replace(..., inplace=True)` is chained assignment and does not
# update df under pandas Copy-on-Write.
df_nah["chain"] = df_nah["chain"].replace({"E": "H", "F": "L"})
df_m102 = check_file(
    pdb_path_m102, source_chain_m102, target_chains_m102, "m102.4", output_path_m102
)
df_m102["chain"] = df_m102["chain"].replace({"C": "H", "B": "L"})

In [None]:
def find_close_mab_sites(df, name):
    """Return (and print) the distinct sites whose ``distance`` is at most 4.

    Sites keep their order of first appearance in ``df``. NaN distances are
    not counted as close.
    """
    close_mask = df["distance"] <= 4
    mab_site_list = list(df.loc[close_mask, "site"].unique())
    print(f"Close sites for mAb {name} are: {mab_site_list}")
    return mab_site_list


### First find RBP sites that are close to mAb residues (distance <= 4)
nah_close, HENV26_close, HENV32_close, m102_close = (
    find_close_mab_sites(distance_table, mab_name)
    for distance_table, mab_name in (
        (df_nah, "nAH1.3"),
        (df_HENV26, "HENV-26"),
        (df_HENV32, "HENV-32"),
        (df_m102, "m102.4"),
    )
)

### Now combine the close residues AND the top escape sites identified previously
# (plain list concatenation, so a site present in both lists appears twice)
nah_combined_sites = [*sites_dict["nAH1.3"], *nah_close]
HENV26_combined_sites = [*sites_dict["HENV-26"], *HENV26_close]
HENV32_combined_sites = [*sites_dict["HENV-32"], *HENV32_close]
m102_combined_sites = [*sites_dict["m102.4"], *m102_close]

In [None]:
def make_distance(df):
    """Reshape per-site distances into the long format used for heatmaps.

    Returns a new frame with columns ``site, value, mutant, wildtype, effect``.
    ``value`` holds the distance, ``mutant`` is the literal "distance" (its own
    heatmap row), ``wildtype`` is empty and ``effect`` is "escape_median".
    The input frame is not modified.
    """
    return (
        df[["site", "distance"]]
        .rename(columns={"distance": "value"})
        .assign(mutant="distance", wildtype="", effect="escape_median")
    )


# Long-format distance rows for each antibody with a crystal structure.
distance_nah_df, distance_26_df, distance_32_df, distance_m102_df = map(
    make_distance, (df_nah, df_HENV26, df_HENV32, df_m102)
)

display(distance_m102_df)

### Prepare dataframes for heatmaps

In [None]:
# One-letter codes of the 20 canonical amino acids, alphabetical.
_HEATMAP_AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
# Sites covered by every heatmap (71-602 inclusive).
_HEATMAP_SITES = range(71, 603)


def _make_long_ab_df(ab, extra_frames=()):
    """Build the long-format heatmap table for antibody ``ab``.

    Concatenates (1) ``escape_median`` for ``ab`` over every site x amino-acid
    combination (NaN where ``combined_df`` has no measurement), (2) the
    low-entry functional effects from ``func_scores_E3_low_effect`` and
    (3) any ``extra_frames``. Every row is tagged with ``ab``.
    """
    grid = pd.DataFrame(
        [{"site": site, "mutant": aa} for site in _HEATMAP_SITES for aa in _HEATMAP_AMINO_ACIDS]
    )
    all_sites_df = pd.merge(
        grid, combined_df.query(f'ab == "{ab}"'), on=["site", "mutant"], how="left"
    )
    df_melted = all_sites_df.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )

    df_filtered = func_scores_E3_low_effect.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["effect"],
        var_name="effect",
        value_name="value",
    )

    df_test = pd.concat([df_melted, df_filtered, *extra_frames], ignore_index=True)
    df_test["ab"] = ab
    return df_test


def make_empty_df_with_distance(ab, distance_file):
    """Return the long heatmap table for ``ab`` plus its per-site distance rows."""
    return _make_long_ab_df(ab, [distance_file])


def make_empty_df(ab):
    """Return the long heatmap table for ``ab`` (escape + low-entry effects only)."""
    return _make_long_ab_df(ab)


empty_df_m102 = make_empty_df_with_distance("m102.4", distance_m102_df)
empty_df_HENV26 = make_empty_df_with_distance("HENV-26", distance_26_df)
empty_df_HENV32 = make_empty_df_with_distance("HENV-32", distance_32_df)
empty_df_nah = make_empty_df_with_distance("nAH1.3", distance_nah_df)

# No structures were used for these two antibodies, so they have no distance rows.
empty_df_HENV117 = make_empty_df("HENV-117")
empty_df_HENV103 = make_empty_df("HENV-103")

combined_ab = pd.concat(
    [
        empty_df_m102,
        empty_df_HENV26,
        empty_df_HENV32,
        empty_df_nah,
        empty_df_HENV117,
        empty_df_HENV103,
    ]
)
display(combined_ab)

In [None]:
def plot_distance_only(df, trigger):
    """Build stacked per-antibody escape heatmaps, with a distance row where available.

    Each antibody gets one heatmap, layered from:
      * gray background cells for every ``effect == "escape_median"`` row;
      * escape values on a red-blue scale centred at 0;
      * a green "distance" row (m102.4, HENV-26, HENV-32, nAH1.3 only);
      * dark-gray cells for the ``effect == "effect"`` rows (presumably the
        low functional-effect mutations; confirm against the data source);
      * a white box with an "X" marking the wildtype residue at each site.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with ``site``, ``mutant``, ``wildtype``, ``effect``,
        ``value`` and ``ab`` columns.
    trigger : bool
        If true, show only each antibody's selected sites. Otherwise show
        every site in 71-602, with tick labels at multiples of 10.

    Returns
    -------
    altair.VConcatChart
        One heatmap per antibody, stacked vertically in ``ab_list`` order.
    """
    # Row order of the heatmap (top to bottom); used directly as the y sort.
    custom_order = [
        "distance",
        "R", "K", "H", "D", "E", "Q", "N", "S", "T", "Y", "W", "F",
        "A", "I", "L", "M", "V", "G", "P", "C",
    ]
    all_residues = range(71, 603)
    ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]
    # Antibodies that have a distance-to-mAb row.
    distance_abs = {"m102.4", "HENV-26", "HENV-32", "nAH1.3"}
    # Sites shown for each antibody when ``trigger`` is true.
    site_subsets = {
        "m102.4": m102_combined_sites,
        "HENV-26": HENV26_combined_sites,
        "HENV-32": HENV32_combined_sites,
        "HENV-103": sites_dict["HENV-103"],
        "HENV-117": sites_dict["HENV-117"],
        "nAH1.3": nah_combined_sites,
    }
    strokewidth_size = 0.25

    # Sort rows by site, then by mutant rank. The row order matters below:
    # drop_duplicates keeps the first row per (site, wildtype).
    final_df = df.sort_values("site")
    mutant_rank = {mutant: i for i, mutant in enumerate(custom_order)}
    final_df["mutant_rank"] = final_df["mutant"].map(mutant_rank)
    final_df = final_df.sort_values("mutant_rank").drop(columns=["mutant_rank"])
    sites = sorted(final_df["site"].unique(), key=float)

    charts = []
    for idx, ab in enumerate(ab_list):
        tmp_df = final_df[final_df["ab"] == ab]

        if trigger:
            tmp_df = tmp_df[tmp_df["site"].isin(site_subsets[ab])]
            x_axis = alt.Axis(labelAngle=-90, title="Site")
        else:
            tmp_df = tmp_df[tmp_df["site"].isin(all_residues)]
            # Tick labels every 10 sites; the axis title only under the last panel.
            is_last_plot = idx == len(ab_list) - 1
            x_axis = alt.Axis(
                labelAngle=-90,
                labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                title="Site" if is_last_plot else None,
                labels=True,
            )

        # The escape color domain comes only from escape cells. It excludes
        # the distance row and the "effect" rows.
        effect_df = tmp_df[
            (tmp_df["mutant"] != "distance") & (tmp_df["effect"] != "effect")
        ]
        max_color = effect_df["value"].max()
        min_color = effect_df["value"].min()
        # Widen the lower bound for antibodies with few sensitizing (negative) mutations.
        if min_color > -1:
            min_color -= 1

        color_scale_escape = alt.Scale(
            scheme="redblue", domainMid=0, domain=[min_color, max_color]
        )
        color_scale_distance = alt.Scale(scheme="greens", domain=[0, 15], reverse=True)

        unique_wildtypes_df = tmp_df.drop_duplicates(subset=["site", "wildtype"])

        base = (
            alt.Chart(tmp_df, title=f"{ab}")
            .encode(
                x=alt.X("site:O", title="Site", sort=sites, axis=x_axis),
                y=alt.Y(
                    "mutant",
                    title="Amino Acid",
                    sort=custom_order,
                    axis=alt.Axis(grid=False),
                ),
                tooltip=["site", "wildtype", "mutant", "value"],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Gray background for every escape cell, measured or not.
        chart_empty = base.mark_rect(color="#e6e7e8").transform_filter(
            alt.datum.effect == "escape_median"
        )
        # Escape values (the distance row is left transparent here).
        chart_effect = (
            base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
            .encode(
                color=alt.condition(
                    'datum.mutant != "distance"',
                    alt.Color(
                        "value:Q",
                        scale=color_scale_escape,
                        legend=alt.Legend(title=f"{ab} Escape"),
                    ),
                    alt.value("transparent"),
                ),
            )
            .transform_filter(alt.datum.effect == "escape_median")
        )

        if ab in distance_abs:
            # Distance-to-mAb row, colored on its own green scale.
            chart_distance = (
                base.mark_rect()
                .encode(
                    color=alt.condition(
                        'datum.mutant == "distance"',
                        alt.Color(
                            "value:Q",
                            scale=color_scale_distance,
                            legend=alt.Legend(title="Distance to mAb"),
                        ),
                        alt.value("transparent"),
                    )
                )
                .transform_filter(alt.datum.effect == "escape_median")
            )
        else:
            # No distance data for this antibody: transparent placeholder layer.
            chart_distance = base.mark_rect(color="transparent").transform_filter(
                alt.datum.effect == "escape_median"
            )

        # Dark-gray cells for the ``effect == "effect"`` rows.
        chart_filtered = base.mark_rect(
            color="#939598", stroke="black", strokeWidth=strokewidth_size
        ).transform_filter(alt.datum.effect == "effect")

        # Keep only rows with a non-empty wildtype and a non-null value.
        # ``!= None`` is intentional: Altair turns it into a Vega null check.
        wildtype_filter = (
            (alt.datum.wildtype != "")
            & (alt.datum.wildtype != None)  # noqa: E711
            & (alt.datum.value != None)  # noqa: E711
        )
        # White boxes at the wildtype residue of each site.
        wildtype_layer_box = (
            alt.Chart(unique_wildtypes_df)
            .mark_rect(color="white", stroke="black", strokeWidth=strokewidth_size)
            .encode(
                x=alt.X("site:O", sort=sites),
                y=alt.Y("wildtype", sort=custom_order),
                opacity=alt.value(1),
            )
            .transform_filter(wildtype_filter)
        )
        # "X" text drawn on top of each wildtype box.
        wildtype_layer = (
            alt.Chart(unique_wildtypes_df)
            .mark_text(color="black", text="X", size=8)
            .encode(
                x=alt.X("site:O", sort=sites),
                y=alt.Y("wildtype", sort=custom_order),
                opacity=alt.value(1),
            )
            .transform_filter(wildtype_filter)
        )

        chart = alt.layer(
            chart_empty,
            chart_effect,
            chart_distance,
            chart_filtered,
            wildtype_layer_box,
            wildtype_layer,
        ).resolve_scale(color="independent")
        charts.append(chart)

    return (
        alt.vconcat(*charts, spacing=1)
        .resolve_scale(y="shared", x="independent", color="independent")
        .configure_title(anchor="start", offset=10, orient="top")
    )


# Heatmaps restricted to each antibody's selected sites.
mab_plot = plot_distance_only(combined_ab, trigger=True)
mab_plot.display()
# Written to disk only when the ``mab_line_escape_plot`` parameter is set
# (presumably a papermill run).
if mab_line_escape_plot is not None:
    mab_plot.save(mab_plot_top)

### Make full antibody escape heatmaps

In [None]:
# Heatmaps over every site from 71 to 602.
mab_all = plot_distance_only(combined_ab, trigger=False)
mab_all.display()
# Written to disk only when the ``mab_line_escape_plot`` parameter is set
# (presumably a papermill run).
if mab_line_escape_plot is not None:
    mab_all.save(mab_plot_all)

### Now make heatmaps of antibody escape alongside Ephrin-B2 binding

First, prepare the data:

In [None]:
bind_df = pd.read_csv(binding_data)
# Median Ephrin-B2 binding effect per site, taken across all mutants.
binding_df = bind_df.groupby("site", as_index=False)["binding_median"].median()


def make_empty_binding():
    """Return the "Ephrin-B2 binding" heatmap row covering sites 71-602.

    The site range is left-merged with the module-level ``binding_df``, so
    sites without binding data get a NaN ``value``. Rows are labelled
    ``effect == "escape_median"`` so they pass the heatmap's escape-cell
    filters.
    """
    site_frame = pd.DataFrame({"site": range(71, 603)})
    row = site_frame.merge(binding_df, on="site", how="left").rename(
        columns={"binding_median": "value"}
    )
    row["effect"] = "escape_median"
    row["ab"] = "Ephrin-B2 binding"
    return row


binding_empty = make_empty_binding()

# Median escape per (antibody, site), taken across mutants.
escape_df = combined_df.groupby(["ab", "site"], as_index=False)["escape_median"].median()


def make_empty_df(ab):
    """Return per-site escape for antibody ``ab`` over sites 71-602 in long format.

    Values come from the module-level ``escape_df``. Sites with no entry for
    ``ab`` get NaN. The columns are ``site``, ``effect`` (always
    "escape_median"), ``value`` and ``ab``.
    """
    site_frame = pd.DataFrame({"site": range(71, 603)})
    ab_escape = escape_df[escape_df["ab"] == ab]
    per_site = site_frame.merge(ab_escape, on=["site"], how="left")

    long_df = per_site.melt(
        id_vars=["site"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )
    long_df["ab"] = ab
    return long_df


ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]

# One per-site table per antibody, followed by the Ephrin-B2 binding row.
all_empties_df = pd.concat(
    [make_empty_df(ab_name) for ab_name in ab_list], ignore_index=True
)
all_empties_df = pd.concat([all_empties_df, binding_empty])
display(all_empties_df)

In [None]:
def make_heatmap_with_binding(df):
    """Plot per-site antibody escape alongside Ephrin-B2 binding as wrapped heatmaps.

    Sites 71-602 are split into five consecutive windows. Each window is drawn
    as one heatmap, with rows ordered by ``sort_order``, and the windows are
    stacked vertically. Layers:
      * a gray background cell for every ``effect == "escape_median"`` row;
      * cell colors from ``value``;
      * a separate color layer for the "Ephrin-B2 binding" row;
      * black cells for any "NiV Polymorphism" rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with ``site``, ``ab``, ``effect`` and ``value``
        columns.

    Returns
    -------
    altair.VConcatChart
        The stacked site-window heatmaps.
    """
    # Row order of the heatmap (top to bottom).
    sort_order = [
        "NiV Polymorphism",
        "Ephrin-B2 binding",
        "m102.4",
        "HENV-26",
        "HENV-117",
        "HENV-103",
        "HENV-32",
        "nAH1.3",
    ]
    # Five site windows so the protein wraps onto several rows.
    full_ranges = [
        list(range(start, end))
        for start, end in [(71, 181), (181, 291), (291, 401), (401, 511), (511, 603)]
    ]

    charts = []
    color_scale_effect = alt.Scale(scheme="redblue", domainMid=0)
    color_scale_binding = alt.Scale(scheme="redblue", domainMid=0)

    for idx, subset in enumerate(full_ranges):
        subset_df = df[df["site"].isin(subset)]  # rows for this site window
        is_last_plot = idx == len(full_ranges) - 1
        # Tick labels every 10 sites on every panel; the "Site" title only on the last.
        x_axis = alt.Axis(
            labelAngle=-90,
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=True,
        )

        # Legends are attached to the last panel only.
        effect_legend = alt.Legend(title="Antibody Escape") if is_last_plot else None
        binding_legend = (
            alt.Legend(title="Ephrin-B2 Binding") if is_last_plot else None
        )

        base = (
            alt.Chart(subset_df)
            .encode(
                x=alt.X("site:O", title="Site", axis=x_axis),
                y=alt.Y("ab", title=None, sort=sort_order, axis=alt.Axis(grid=False)),
                tooltip=["site", "value"],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Gray background for every escape cell, measured or not.
        chart_empty = base.mark_rect(color="#e6e7e8").transform_filter(
            alt.datum.effect == "escape_median"
        )

        # Cells colored by their escape / binding value.
        chart_effect = (
            base.mark_rect(stroke="black", strokeWidth=0.25)
            .encode(
                color=alt.condition(
                    'datum.effect == "escape_median"',
                    alt.Color("value:Q", scale=color_scale_effect, legend=effect_legend),
                    alt.value("transparent"),
                )
            )
            .transform_filter(alt.datum.effect == "escape_median")
        )

        # Separate layer for the Ephrin-B2 binding row.
        chart_binding = (
            base.mark_rect(strokeWidth=1.1)
            .encode(
                # NOTE(review): "value" is passed as a literal color string, not
                # as a field reference, so this is not a valid stroke color.
                # Confirm the intended outline color.
                stroke=alt.value("value"),
                color=alt.condition(
                    'datum.effect == "escape_median"',
                    alt.Color(
                        "value:Q", scale=color_scale_binding, legend=binding_legend
                    ),
                    alt.value("transparent"),
                ),
            )
            .transform_filter(alt.datum.ab == "Ephrin-B2 binding")
        )

        # Black cells for any NiV polymorphism rows.
        chart_poly = base.mark_rect(color="black").transform_filter(
            alt.datum.ab == "NiV Polymorphism"
        )

        charts.append(alt.layer(chart_empty, chart_effect, chart_binding, chart_poly))

    return alt.vconcat(
        *charts, spacing=5, title="Heatmap of median mAb escape and Ephrin-B2 binding"
    )


# Assuming `all_empties_df` is your DataFrame and already defined
chart = make_heatmap_with_binding(all_empties_df)
chart.display()
if mab_line_escape_plot is not None:
    chart.save(aggregate_mab_and_binding)

### Now show a heatmap of antibody escape alongside Nipah virus (NiV) polymorphisms

In [None]:
def make_contact():
    """Return the "NiV Polymorphism" heatmap row: one row per site in ``niv_poly``.

    Each row has ``value`` 0.0, ``ab == "NiV Polymorphism"`` and
    ``effect == "median_escape"``. The columns are site, value, ab and effect.
    """
    # NOTE(review): the effect label is "median_escape", whereas the other
    # heatmap rows use "escape_median". Filters on ``effect == "escape_median"``
    # therefore skip these rows. The polymorphism cells are drawn by a filter
    # on ``ab`` instead. Confirm this is intentional.
    poly_rows = pd.DataFrame({"site": niv_poly, "value": [0.0] * len(niv_poly)})
    poly_rows["ab"] = "NiV Polymorphism"
    poly_rows["effect"] = "median_escape"
    return poly_rows


# Sites flagged as NiV polymorphisms. Each becomes a cell in the
# "NiV Polymorphism" row of the heatmap.
niv_poly = [
    82, 89, 135, 172, 228, 236, 274, 288, 299, 325, 328, 329, 335, 339, 344,
    376, 384, 385, 386, 421, 423, 424, 426, 427, 470, 478, 481, 498, 502, 545,
]
contact_df = make_contact()

# Read from the parameterized ``binding_data`` path, not a hard-coded file,
# so papermill runs use the configured input.
bind_df = pd.read_csv(binding_data)
# Maximum Ephrin-B2 binding effect per site, taken across mutants.
binding_df = bind_df.groupby("site")["binding_median"].max().reset_index()


def make_empty_binding():
    """Return the "Ephrin-B2 binding" heatmap row covering sites 71-602.

    The site range is left-merged with the module-level ``binding_df`` (at
    this point the per-site maximum), so sites without binding data get a NaN
    ``value``. Rows are labelled ``effect == "escape_median"``.
    """
    site_frame = pd.DataFrame({"site": range(71, 603)})
    row = site_frame.merge(binding_df, on="site", how="left").rename(
        columns={"binding_median": "value"}
    )
    row["effect"] = "escape_median"
    row["ab"] = "Ephrin-B2 binding"
    return row


binding_empty = make_empty_binding()

# Maximum escape per (antibody, site), taken across mutants.
escape_df = combined_df.groupby(["ab", "site"], as_index=False)["escape_median"].max()


def make_empty_df(ab):
    """Return per-site escape for antibody ``ab`` over sites 71-602 in long format.

    Values come from the module-level ``escape_df`` (at this point the
    per-site maximum). Sites with no entry for ``ab`` get NaN. The columns are
    ``site``, ``effect`` (always "escape_median"), ``value`` and ``ab``.
    """
    site_frame = pd.DataFrame({"site": range(71, 603)})
    ab_escape = escape_df[escape_df["ab"] == ab]
    per_site = site_frame.merge(ab_escape, on=["site"], how="left")

    long_df = per_site.melt(
        id_vars=["site"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )
    long_df["ab"] = ab
    return long_df


ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]

# One per-site table per antibody, followed by the NiV polymorphism row.
all_empties_df = pd.concat(
    [make_empty_df(ab_name) for ab_name in ab_list], ignore_index=True
)
all_empties_df = pd.concat([all_empties_df, contact_df])
display(all_empties_df)

In [None]:
def make_heatmap_with_polymorphisms(df):
    """Plot per-site maximum antibody escape alongside NiV polymorphic sites.

    Sites 71-602 are split into five consecutive windows. Each window is drawn
    as one heatmap, with rows ordered by ``sort_order``, and the windows are
    stacked vertically. Layers:
      * a gray background cell for every ``effect == "escape_median"`` row;
      * escape values on a red-blue scale with a fixed [0, 2] domain;
      * black cells for "NiV Polymorphism" rows.

    The escape legend is shown once, on the first panel; the color scale is
    shared across panels.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with ``site``, ``ab``, ``effect`` and ``value``
        columns.

    Returns
    -------
    altair.VConcatChart
        The stacked site-window heatmaps.
    """
    # Row order of the heatmap (top to bottom).
    sort_order = [
        "NiV Polymorphism",
        "m102.4",
        "HENV-26",
        "HENV-117",
        "HENV-103",
        "HENV-32",
        "nAH1.3",
    ]
    # Five site windows so the protein wraps onto several rows.
    full_ranges = [
        list(range(start, end))
        for start, end in [(71, 181), (181, 291), (291, 401), (401, 511), (511, 603)]
    ]

    charts = []
    color_scale_effect = alt.Scale(scheme="redblue", domainMid=0, domain=[0, 2])

    for idx, subset in enumerate(full_ranges):
        subset_df = df[df["site"].isin(subset)]  # rows for this site window
        is_last_plot = idx == len(full_ranges) - 1
        # Tick labels every 10 sites on every panel; the "Site" title only on the last.
        x_axis = alt.Axis(
            labelAngle=-90,
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=True,
        )

        base = (
            alt.Chart(subset_df)
            .encode(
                x=alt.X("site:O", title="Site", axis=x_axis),
                y=alt.Y("ab", title=None, sort=sort_order, axis=alt.Axis(grid=False)),
                tooltip=["site", alt.Tooltip("value", format=".2f")],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Gray background for every escape cell, measured or not.
        chart_empty = base.mark_rect(color="#e6e7e8").transform_filter(
            alt.datum.effect == "escape_median"
        )

        # Show the (shared) escape legend only once, on the first panel. The
        # previous flags started as True, so the legend branch never ran.
        effect_legend = alt.Legend(title="Antibody Escape") if idx == 0 else None
        chart_effect = (
            base.mark_rect(stroke="black", strokeWidth=0.25)
            .encode(
                color=alt.condition(
                    'datum.effect == "escape_median"',
                    alt.Color("value:Q", scale=color_scale_effect, legend=effect_legend),
                    alt.value("transparent"),
                )
            )
            .transform_filter(alt.datum.effect == "escape_median")
        )

        # Black cells at NiV polymorphic sites.
        chart_poly = base.mark_rect(color="black").transform_filter(
            alt.datum.ab == "NiV Polymorphism"
        )

        chart = alt.layer(chart_empty, chart_effect, chart_poly).resolve_scale(
            color="independent"
        )
        charts.append(chart)

    return alt.vconcat(
        *charts, spacing=5, title="Heatmap of max mAb escape and Nipah Polymorphisms"
    ).resolve_scale(y="shared", x="independent", color="shared")


# Assuming `all_empties_df` is your DataFrame and already defined
chart = make_heatmap_with_polymorphisms(all_empties_df)
chart.display()
if mab_line_escape_plot is not None:
    chart.save(aggregate_mab_and_niv_polymorphism)

### Plot antibody escape against Ephrin-B2 binding to check whether escape mutations act by increasing receptor binding

In [None]:
# Attach each mutation's Ephrin-B2 binding effect to its escape rows, drop
# bookkeeping columns, and round for display.
bookkeeping_columns = [
    "mutation",
    "escape_std",
    "times_seen_ab",
    "show_site",
    "wt_codon",
    "closest_mutant_codon",
    "min_mutations",
]
new_merged_df = (
    combined_df.merge(
        bind_df[["site", "wildtype", "mutant", "binding_median"]],
        on=["site", "wildtype", "mutant"],
        how="left",
    )
    .drop(columns=bookkeeping_columns)
    .round(2)
)

# Antibodies grouped into the rows of the escape-vs-binding figure.
ab_list1 = ["m102.4", "HENV-26", "HENV-117"]
ab_list2 = ["HENV-103", "HENV-32"]
ab_list3 = ["nAH1.3"]


def _escape_vs_binding_row(df, row_abs, color, selector):
    """Return one row (hconcat) of escape-vs-binding scatter plots, one per antibody in ``row_abs``."""
    panels = []
    for ab in row_abs:
        ab_df = df[df["ab"] == ab]
        panels.append(
            alt.Chart(ab_df, title=alt.Title(f"{ab}", anchor="middle"))
            .mark_point(filled=True, size=15, color=color, opacity=0.15, stroke="black")
            .encode(
                alt.X(
                    "binding_median",
                    title="EFNB2 Binding",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                alt.Y(
                    "escape_median",
                    title="Antibody Escape",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "escape_median",
                    "binding_median",
                ],
                # Points at the hovered site are opaque, larger and outlined.
                opacity=alt.condition(selector, alt.value(1), alt.value(0.2)),
                size=alt.condition(selector, alt.value(50), alt.value(20)),
                strokeWidth=alt.condition(selector, alt.value(2), alt.value(0)),
            )
        )
    return alt.hconcat(*panels, spacing=5).resolve_scale(x="shared", y="shared")


def plot_escape_vs_binding(df):
    """Scatter per-mutation antibody escape against Ephrin-B2 binding.

    One panel is drawn per antibody. Panels are grouped into rows by the
    module-level ``ab_list1``, ``ab_list2`` and ``ab_list3``, and each row has
    its own point color. A shared mouseover selection on ``site`` highlights
    the same site across all panels.

    Every antibody in ``ab_list3`` is now drawn. Previously only the last one
    was kept, and an empty list raised ``NameError``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``ab``, ``site``, ``wildtype``, ``mutant``,
        ``escape_median`` and ``binding_median`` columns.

    Returns
    -------
    altair.VConcatChart
        The rows of scatter plots, stacked vertically.
    """
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["site"], value=1
    )
    rows = [
        _escape_vs_binding_row(df, ab_list1, "#1f4e79", variant_selector),
        _escape_vs_binding_row(df, ab_list2, "#ff7f0e", variant_selector),
        _escape_vs_binding_row(df, ab_list3, "#2ca02c", variant_selector),
    ]
    return alt.vconcat(
        *rows,
        title=alt.Title(
            "Antibody Escape versus Binding",
            subtitle="Colored by Epitope. Hover over points to see the same sites",
        ),
    ).add_params(variant_selector)


# Interactive escape-vs-binding scatter grid.
tmp_img_test = plot_escape_vs_binding(new_merged_df)
tmp_img_test.display()
# Written to disk only when the ``mab_line_escape_plot`` parameter is set.
if mab_line_escape_plot is not None:
    tmp_img_test.save(binding_vs_escape)