# ACE2 binding effects of non-RBD mutations in natural sequences
Examine how mutations outside the RBD that occur in natural sequences affect ACE2 binding.

In [None]:
# this cell is tagged as `parameters` for `papermill` parameterization
# NOTE(review): `dms_summary_csv` is not read in the visible part of this
# notebook; it is kept so existing papermill invocations keep working.
dms_summary_csv = None
# JSON file of Pango clade records, loaded with `json.load` below
pango_consensus_seqs_json = None
# CSV files of DMS measurements, read with `pd.read_csv` below; declared here
# so that they are documented, overridable parameters rather than undefined names
xbb15_dms_csv = None
ba2_dms_csv = None

In [None]:
import collections
import json

import altair as alt

import numpy

import pandas as pd

Get spike mutations relative to reference and new spike mutations relative to parent in Pango clades descended from starting clades:

In [None]:
# Pango clades whose descendants are analyzed
starting_clades = ["XBB", "BA.2", "BA.5"]

# parse the JSON file into a dict keyed by clade name
with open(pango_consensus_seqs_json) as f:
    pango_clades = json.loads(f.read())

def build_records(c, recs):
    """Append clade `c` and all of its descendants to the record lists in `recs`.

    `recs` maps column name to a list and is modified in place; one entry is
    appended to each of `clade`, `date`, `new_spike_muts`, and `spike_muts`
    per clade. Clade information is read from the module-level `pango_clades`.
    A clade already present in `recs["clade"]` is skipped, so each clade is
    recorded once even if it is reached more than once.
    """
    if c in recs["clade"]:
        return

    clade_info = pango_clades[c]

    def spike_muts(fields):
        # keep only spike ("S:") entries and drop the gene prefix
        muts = []
        for field in fields:
            for mut in clade_info[field]:
                if mut.startswith("S:"):
                    muts.append(mut.split(":")[1])
        return muts

    recs["clade"].append(c)
    recs["date"].append(clade_info["designationDate"])
    recs["new_spike_muts"].append(
        spike_muts(["aaSubstitutionsNew", "aaDeletionsNew"])
    )
    recs["spike_muts"].append(spike_muts(["aaSubstitutions", "aaDeletions"]))

    # depth-first walk through the children
    for c_child in clade_info["children"]:
        build_records(c_child, recs)
        
# one data frame of descendant clades per starting clade
pango_dfs = []
for starting_clade in starting_clades:
    records = collections.defaultdict(list)
    build_records(starting_clade, records)
    pango_dfs.append(
        pd.DataFrame(records)
        # drop the starting clade itself, keeping only its descendants
        .loc[lambda df: df["clade"] != starting_clade]
        .assign(parent_clade=starting_clade)
    )

pango_df = pd.concat(pango_dfs)

Get the counts of how many times each mutation newly occurs in a clade:

In [None]:
# one row per (parent clade, mutation), with the number and names of the
# descendant clades in which that mutation newly appears
new_mut_counts = (
    pango_df
    .explode("new_spike_muts")
    # clades with no new spike mutations explode to a null entry
    .dropna(subset=["new_spike_muts"])
    .rename(columns={"new_spike_muts": "mutation"})
    .groupby(["parent_clade", "mutation"], as_index=False)
    .aggregate(
        n_clades=("clade", "count"),
        clades=("clade", "unique"),
    )
    .assign(
        # site number is everything between the first and last character
        site=lambda x: x["mutation"].str.slice(1, -1).astype(int),
        clades=lambda x: x["clades"].map("; ".join),
    )
)

Add DMS phenotypes:

In [None]:
# XBB.1.5 DMS measurements, with phenotype columns renamed to short names
xbb15_dms = pd.read_csv(xbb15_dms_csv)
xbb15_dms = xbb15_dms.rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# BA.2 DMS measurements; "ACE2 affinity" is additionally renamed so that both
# data frames share the "ACE2 binding" column name
ba2_dms = pd.read_csv(ba2_dms_csv)
ba2_dms = ba2_dms.rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
        "ACE2 affinity": "ACE2 binding",
    }
)

# DMS phenotypes of interest ("sera escape" is deliberately left out)
phenotypes = [
    "ACE2 binding",
    "cell entry",
]

# both data frames must provide every phenotype of interest
assert set(phenotypes) <= set(xbb15_dms.columns)
assert set(phenotypes) <= set(ba2_dms.columns)

def mut_dms(m, dms_data, site_to_region, dms_wt, *, phenotype_names=None):
    """Get DMS phenotypes for a mutation.

    Parameters
    ----------
    m
        Mutation string whose first character is the parent amino acid, last
        character is the mutant amino acid, and middle is the integer site
        (e.g. ``"A10G"``); may be null.
    dms_data
        Dict keyed by ``(site, wildtype, mutant)`` whose values are dicts
        mapping phenotype name to measured value. It is never modified.
    site_to_region
        Dict mapping site to region name; must contain every site in `dms_wt`.
    dms_wt
        Dict mapping site to the wildtype amino acid of the DMS.
    phenotype_names
        Phenotypes to return; defaults to the module-level `phenotypes`.

    Returns
    -------
    dict
        A new dict with one entry per phenotype followed by ``"is_RBD"``.
        Phenotypes are ``pd.NA`` when the mutation is null, its site is not
        in `dms_wt`, or a needed measurement is missing; ``"is_RBD"`` is
        ``pd.NA`` in the first two of those cases.
    """
    names = list(phenotypes if phenotype_names is None else phenotype_names)

    site = None if pd.isnull(m) else int(m[1:-1])
    if site is None or site not in dms_wt:
        d = {k: pd.NA for k in names}
        d["is_RBD"] = pd.NA
        return d

    parent = m[0]
    mut = m[-1]
    wt = dms_wt[site]

    # Always build a fresh dict: the entries of `dms_data` are shared with the
    # caller, so returning one directly and then adding "is_RBD" to it would
    # silently modify the caller's data.
    d = None
    if parent == wt:
        # mutation is measured directly in the DMS
        entry = dms_data.get((site, parent, mut))
        if entry is not None:
            d = {p: entry[p] for p in names}
    elif mut == wt:
        # reversion to the DMS wildtype: negate the effect of the forward mutation
        entry = dms_data.get((site, mut, parent))
        if entry is not None:
            d = {p: -entry[p] for p in names}
    else:
        # neither residue is the DMS wildtype: difference of the two effects
        # that are each measured relative to wildtype
        parent_d = dms_data.get((site, wt, parent))
        mut_d = dms_data.get((site, wt, mut))
        if parent_d is not None and mut_d is not None:
            d = {p: mut_d[p] - parent_d[p] for p in names}

    if d is None:
        d = {k: pd.NA for k in names}
    d["is_RBD"] = site_to_region[site] == "RBD"
    return d

# Site -> region annotation accumulated over ALL DMS datasets. The per-strain
# mapping inside the loop has its own name: rebinding `site_to_region` there
# would discard what earlier strains contributed, so the final `region`
# column would reflect only the last dataset.
site_to_region = {}

for strain, dms_df in [("XBB.1.5", xbb15_dms), ("BA.2", ba2_dms)]:

    # dict that maps site to wildtype in this strain's DMS
    dms_wt = dms_df.set_index("site")["wildtype"].to_dict()

    # dict that maps site to region in this strain's DMS
    strain_site_to_region = dms_df.set_index("site")["region"].to_dict()

    # dict that maps (site, wildtype, mutant) to {phenotype: value}
    dms_data = (
        dms_df
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )

    # later strains overwrite earlier ones for sites present in both
    site_to_region.update(strain_site_to_region)

    # look up each mutation once, then pull out the individual phenotypes
    strain_dms = new_mut_counts["mutation"].map(
        lambda m: mut_dms(m, dms_data, strain_site_to_region, dms_wt)
    )
    for phenotype in phenotypes:
        new_mut_counts[f"{strain} {phenotype}"] = strain_dms.map(
            lambda d: d[phenotype]
        )

# annotate with the region from any DMS dataset that covers the site
new_mut_counts["region"] = new_mut_counts["site"].map(site_to_region)

Look at non-RBD mutations observed at least 3 times in XBB-descended clades that substantially increase ACE2 binding:

In [None]:
# widen columns so long cell contents are shown in full
pd.options.display.max_colwidth = 1000

# XBB-descended clades, scored with the XBB.1.5 DMS: non-RBD mutations that
# arose in at least 3 clades and have ACE2 binding of at least 0.1
(
    new_mut_counts
    .sort_values("XBB.1.5 ACE2 binding", ascending=False)
    .query("parent_clade == 'XBB'")
    .query("region != 'RBD'")
    .query("n_clades >= 3")
    .query("`XBB.1.5 ACE2 binding` >= 0.1")
    .reset_index(drop=True)
    .pipe(display)
)

Same for BA.2- and BA.5-descended clades, using the BA.2 DMS measurements:

In [None]:
# widen columns so long cell contents are shown in full
pd.options.display.max_colwidth = 1000

# BA.2- and BA.5-descended clades, both scored with the BA.2 DMS: non-RBD
# mutations that arose in at least 3 clades and have ACE2 binding of at least 0.1
(
    new_mut_counts
    .sort_values("BA.2 ACE2 binding", ascending=False)
    .query("parent_clade in ['BA.5', 'BA.2']")
    .query("region != 'RBD'")
    .query("n_clades >= 3")
    .query("`BA.2 ACE2 binding` >= 0.1")
    .reset_index(drop=True)
    .pipe(display)
)

Also look at how well the libraries cover the mutations of interest: for each parent clade, what fraction of the new mutations in its descendant clades have a cell entry measurement in the corresponding (BA.2 or XBB.1.5) library?

In [None]:
# For every new mutation, record whether a cell entry measurement exists.
# NOTE(review): only parent clade "BA.2" is checked against the BA.2 DMS;
# every other parent clade, including "BA.5", is checked against XBB.1.5,
# whereas the ACE2-binding table for BA.5 uses the BA.2 DMS. Confirm intended.
has_measurement_df = (
    new_mut_counts
    # keep only mutations at sites that have a region annotation
    # (per the original comment, this excludes the cytoplasmic tail)
    .loc[lambda x: x["region"].notnull()]
    .assign(
        # cap counts at 6 so all larger counts share one value
        n_clades=lambda x: x["n_clades"].clip(upper=6),
        has_entry_measurement=lambda x: numpy.where(
            x["parent_clade"] != "BA.2",
            x["XBB.1.5 cell entry"].notnull(),
            x["BA.2 cell entry"].notnull(),
        ),
    )
    [["parent_clade", "mutation", "n_clades", "has_entry_measurement"]]
)

print("Mutations that appear in any Pango clades that have measurements")

# per parent clade, split by whether the mutation arose in more than one clade:
# number of mutations and how many of them have a cell entry measurement
(
    has_measurement_df
    .assign(in_multiple_clades=lambda x: x["n_clades"] > 1)
    .groupby(["parent_clade", "in_multiple_clades"])
    .aggregate(
        n_mutations=("mutation", "count"),
        has_entry_measurement=("has_entry_measurement", "sum"),
    )
    .pipe(display)
)

# bar chart of mutation counts by number of clades, one panel per parent
# clade, colored by whether a cell entry measurement exists
has_measurement_chart = (
    alt.Chart(has_measurement_df)
    .mark_bar()
    .encode(
        x=alt.X("n_clades", bin=alt.BinParams(step=1)),
        y=alt.Y("count()"),
        color="has_entry_measurement",
        column="parent_clade",
        tooltip=["n_clades", "count()"],
    )
)

has_measurement_chart