# Escape at key sites: logo plots and binding / escape correlations
Make logo plots of serum escape at key sites, and examine the relationship between escape and other phenotypes such as ACE2 binding.

First get input files / parameters from `papermill` and import Python modules:

In [1]:
# this cell is tagged as `parameters` for papermill parameterization
# Each name below is only a `None` placeholder; a real value must be assigned
# before the data-loading cells run.
dms_csv = None  # path of CSV with averaged DMS measurements (read with `pd.read_csv`)
per_antibody_csv = None  # path of CSV with per-antibody escape values (read with `pd.read_csv`)
pango_consensus_seqs_json = None  # location of Pango consensus sequences JSON; not used in the cells shown here

In [10]:
# Parameters
# Concrete values for the names declared as `None` in the `parameters` cell.
pango_consensus_seqs_json = "https://raw.githubusercontent.com/corneliusroemer/pango-sequences/c64ef05e53debaa9cc65dd56d6eb83e31517179c/data/pango-consensus-sequences_summary.json"
dms_csv = "results/summaries/summary.csv"
per_antibody_csv = "results/summaries/per_antibody_escape.csv"

# Rank cutoffs used to pick key sites: the outer key is the site metric, the
# inner key says whether the rank is taken over the max for any antibody or
# over the average of antibodies. Sites with rank <= the cutoff are kept.
key_sites_by_rank = {
    "total_abs_escape": {
        "any antibody": 10,
        "average of antibodies": 10,
    },
    "total_positive_escape": {
        "any antibody": 10,
        "average of antibodies": 10,
    },
}

# sites that are always kept as key sites, regardless of their rank
key_sites_manual = []

import os

# The input paths above are relative (presumably to the project root -- confirm).
# Only step up one directory when the DMS summary is not reachable from the
# current working directory, so that re-running this cell does not climb one
# directory further on every execution.
if not os.path.exists(dms_csv):
    os.chdir("../")

In [3]:
import altair as alt

import dmslogo

import matplotlib.pyplot as plt

import numpy

import pandas as pd

# Lift Altair's default cap on the number of rows a chart may embed, since the
# data frames charted below may exceed it. Assigning to `_` keeps the return
# value from being displayed as the cell output.
_ = alt.data_transformers.disable_max_rows()

## Read input data
Keep only mutations with all phenotypes measured:

In [4]:
# read averages for all DMS measurements, keeping only mutations for which
# both serum escape and cell entry were measured
# NOTE(review): the filter used to test `sera escape` twice; the duplicate was
# probably meant to test a third phenotype column (the notebook intro mentions
# ACE2 binding) -- confirm the column name and add it to the query if so.
dms_df = (
    pd.read_csv(dms_csv)
    .rename(
        columns={"human sera escape": "sera escape", "spike mediated entry": "cell entry"}
    )
    .query("`sera escape`.notnull() and `cell entry`.notnull()")
)

# read per antibody values, merge with averages to create escape_df
per_antibody_df = pd.read_csv(per_antibody_csv)

assert per_antibody_df["antibody_set"].nunique() == 1, "code expects 1 antibody_set"

# The merge below joins on every column the two frames share, so the only
# shared columns must be the ones that identify a mutation.
if (
    (intersection := set(dms_df.columns).intersection(per_antibody_df.columns))
    != {"site", "wildtype", "mutant"}
):
    raise ValueError(f"unexpected {intersection=}")

# "average" is used below as the label for the across-antibody average, so it
# must not already be the name of a real antibody. Compare against the column
# *values*: `x in series` tests the Series index labels, not its values.
assert "average" not in set(per_antibody_df["antibody"])

# Stack the averaged escape (labelled antibody="average") on top of the
# per-antibody escape, then attach the remaining per-mutation DMS columns.
escape_df = (
    pd.concat(
        [
            dms_df[["site", "wildtype", "mutant", "sera escape"]].rename(
                columns={"sera escape": "escape"}
            ).assign(antibody="average"),
            per_antibody_df.drop(columns="antibody_set"),
        ],
        ignore_index=True,
    )
    # each (site, wildtype, mutant) must occur at most once in `dms_df`
    .merge(dms_df.drop(columns="sera escape"), validate="many_to_one")
    .assign(wildtype_site=lambda x: x["wildtype"] + x["site"].astype(str))
)

## Determine key sites to plot
Get the key sites with the most site escape, and plot their site escape values in an interactive chart:

In [13]:
# Summarize escape per site. For every antibody (and for the "average"
# pseudo-antibody) sum the mutation-level escape at each site in three ways,
# then collapse over antibodies by taking the max within each `is_average`
# group, reshape to long form, and rank sites within each group / metric.
site_metric_cols = ["total_abs_escape", "total_positive_escape", "total_negative_escape"]
site_id_cols = ["is_average", "site", "sequential_site"]

per_antibody_site_df = (
    escape_df
    .assign(
        is_average=lambda x: numpy.where(
            x["antibody"] == "average", "average of antibodies", "any antibody"
        ),
    )
    .groupby(["is_average", "antibody", "site", "sequential_site"], as_index=False)
    .aggregate(
        total_abs_escape=("escape", lambda s: s.abs().sum()),
        total_positive_escape=("escape", lambda s: s.clip(lower=0).sum()),
        total_negative_escape=("escape", lambda s: s.clip(upper=0).abs().sum()),
    )
)

site_escape_df = (
    per_antibody_site_df
    .groupby(site_id_cols, as_index=False)
    .aggregate(dict.fromkeys(site_metric_cols, "max"))
    .melt(id_vars=site_id_cols, var_name="site metric", value_name="site escape")
)
# rank 1 is the site with the largest value; ties share the smallest rank
site_escape_df = site_escape_df.assign(
    rank=(
        site_escape_df
        .groupby(["is_average", "site metric"])["site escape"]
        .rank(ascending=False, method="min")
        .astype(int)
    )
)

# Collect the key sites: the manually specified ones plus every site whose
# rank is within the cutoff for each configured metric / group.
print(f"Keeping the following manually specified sites: {key_sites_manual}")
key_sites = set(key_sites_manual)
for site_metric, site_metric_d in key_sites_by_rank.items():
    metric_mask = site_escape_df["site metric"] == site_metric
    for is_average, rank in site_metric_d.items():
        keep_mask = (
            metric_mask
            & (site_escape_df["is_average"] == is_average)
            & (site_escape_df["rank"] <= rank)
        )
        new_sites = set(site_escape_df.loc[keep_mask, "site"])
        print(f"Adding sites with {site_metric} / {is_average} rank <= {rank}: {new_sites}")
        key_sites = key_sites.union(new_sites)
print(f"Overall keeping the following {len(key_sites)} sites: {key_sites}")

# flag the rows that belong to a key site so the chart can color them
site_escape_df = site_escape_df.assign(key_site=site_escape_df["site"].isin(key_sites))

# Interactive chart of the sites being kept.
# Hovering over a point highlights that site in every facet.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# Drop-down that chooses which site metric is shown.
site_metric_selection = alt.selection_point(
    fields=["site metric"],
    value="total_positive_escape",
    bind=alt.binding_select(
        name="site metric",
        options=site_escape_df["site metric"].unique(),
    ),
)

# show every column in the tooltip, with float columns rounded to 2 decimals
site_tooltips = [
    alt.Tooltip(col, format=".2f") if col_dtype == float else col
    for col, col_dtype in site_escape_df.dtypes.items()
]

site_escape_chart = (
    alt.Chart(site_escape_df)
    .mark_circle(stroke="black")
    .encode(
        x=alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            scale=alt.Scale(nice=False, zero=False),
        ),
        y=alt.Y("site escape"),
        color=alt.Color("key_site"),
        row=alt.Row("is_average", title=None),
        tooltip=site_tooltips,
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
        opacity=alt.condition(site_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(site_selection, alt.value(70), alt.value(30)),
    )
    .add_params(site_selection, site_metric_selection)
    .transform_filter(site_metric_selection)
    .properties(
        width=600,
        height=150,
        title="Escape at each site for average of antibodies or max for any antibody",
    )
    .resolve_scale(y="independent")
    .configure_axis(grid=False)
)

site_escape_chart

Keeping the following manually specified sites: []
Adding sites with total_abs_escape / any antibody rank <= 10: {385, 417, 420, 357, 486, 456, 368, 371, 405, 473}
Adding sites with total_abs_escape / average of antibodies rank <= 10: {385, 417, 357, 486, 371, 373, 405, 375, 376, 473}
Adding sites with total_positive_escape / any antibody rank <= 10: {450, 483, 420, 357, 485, 456, 371, 473, 475, 447}
Adding sites with total_positive_escape / average of antibodies rank <= 10: {450, 420, 357, 455, 456, 371, 440, 473, 475, 447}
Overall keeping the following 20 sites: {385, 450, 455, 456, 405, 473, 475, 417, 483, 420, 357, 486, 485, 440, 368, 371, 373, 375, 376, 447}


In [None]:
# Sanity check on the ranks: recompute the per-(is_average, site metric) ranks
# without the integer cast and show any rows whose rank is not a whole number,
# i.e. rows for which the `.astype(int)` applied when building
# `site_escape_df` would have changed the value. An empty frame means the cast
# altered no rank for this metric.
(
    site_escape_df
    .assign(
        rank=lambda x: (
            x.groupby(["is_average", "site metric"])
            ["site escape"]
            .rank(ascending=False, method="min")
        )
    )
    .sort_values("rank")
    # site metrics are named `total_*`; the former `mag_positive_escape` label
    # matched no rows, which made this check vacuously empty
    .query("`site metric` == 'total_positive_escape'")
    .assign(int_rank=lambda x: x["rank"].astype(int))
    .query("rank != int_rank")
)