# Compare ACE2 affinity across different experiments

First, get the parameters that specify the data to analyze (these are supplied by papermill):

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# Every value defaults to None; papermill injects the real values in a cell
# placed right after this one.

# NOTE(review): the `*_affinity`, `*_func_effects` and `*_escape` names below are
# not read by any cell shown here -- possibly stale; confirm before removing.
xbb_spike_affinity = None
xbb_spike_func_effects = None
xbb_spike_escape = None

starr_rbd_affinity = None  # yeast-display RBD affinity CSV

ba2_spike_affinity = None
ba2_spike_func_effects = None

xbb_rbd_affinity_monomeric = None
xbb_rbd_affinity_dimeric = None
xbb_rbd_func_effects = None

# Input CSVs read by the "Read the data sets" section. Declaring them here
# makes them visible as papermill parameters and avoids a NameError when the
# notebook is run without injected values.
xbb_spike_csv = None
ba2_spike_csv = None
xbb_rbd_csv = None

merged_affinity_csv = None  # output path for the merged affinity table

In [103]:
import functools
import itertools
import math
import operator
import os

import altair as alt

import pandas as pd

# Turn off Altair's maximum-rows check so the full merged data frame can be
# embedded in the charts below; binding the result to `_` keeps the notebook
# from displaying the call's return value.
_ = alt.data_transformers.disable_max_rows()

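## Read the data sets

Read each experiment's ACE2-affinity data, harmonize the column names, and write the merged table to `merged_affinity_csv`.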

In [None]:
# Full-spike pseudovirus data: tag each table with its experiment label and
# the ACE2 form used, so all data sets can be stacked into one long table.
xbb_spike = pd.read_csv(xbb_spike_csv).assign(
    experiment="XBB.1.5 full spike pseudovirus",
    ACE2_type="monomeric",
)

ba2_spike = pd.read_csv(ba2_spike_csv).assign(
    experiment="BA.2 full spike pseudovirus",
    ACE2_type="monomeric",
)

# The XBB.1.5 RBD table has one affinity column per ACE2 form; reshape it to
# long format with an `ACE2_type` column so it matches the other data sets.
xbb_rbd = (
    pd.read_csv(xbb_rbd_csv)
    .rename(
        columns={
            f"{ace2_type} ACE2 affinity": ace2_type
            for ace2_type in ["monomeric", "dimeric"]
        }
    )
    .melt(
        id_vars=["site", "wildtype", "mutant", "spike mediated entry", "sequential_site", "region"],
        var_name="ACE2_type",
        value_name="ACE2 affinity",
        value_vars=["monomeric", "dimeric"],
    )
    .assign(experiment="XBB.1.5 RBD pseudovirus")
)

# Yeast-display RBD measurements: keep only the BA.2 and XBB.1.5 targets,
# rename to the shared column names, and relabel the experiments.
starr_rbd = (
    pd.read_csv(starr_rbd_affinity)
    .rename(columns={"position": "site", "delta_bind": "ACE2 affinity", "target": "experiment"})
    .query("experiment in ['Omicron_BA2', 'Omicron_XBB15']")
    .assign(
        experiment=lambda x: x["experiment"].map(
            {
                "Omicron_BA2": "BA.2 RBD yeast display",
                "Omicron_XBB15": "XBB.1.5 RBD yeast display",
            }
        )
    )
    .assign(ACE2_type="monomeric", region="RBD")
    [["site", "wildtype", "mutant", "ACE2 affinity", "region", "experiment", "ACE2_type"]]
)

merged_affinity_data = pd.concat([xbb_spike, ba2_spike, xbb_rbd, starr_rbd], ignore_index=True)

print(f"Writing the merged affinity data to {merged_affinity_csv}")
# `os.path.dirname` returns "" for a bare filename and `os.makedirs("")`
# raises, so only create the output directory when the path names one.
merged_affinity_dir = os.path.dirname(merged_affinity_csv)
if merged_affinity_dir:
    os.makedirs(merged_affinity_dir, exist_ok=True)
merged_affinity_data.to_csv(merged_affinity_csv, float_format="%.4f", index=False)

In [43]:
# Read the merged table back from the `merged_affinity_csv` parameter (the
# same path the data-reading cell writes), so the notebook never picks up a
# stale file from a hard-coded location.
#
# Each experiment gets a unique label: dimeric-ACE2 measurements are suffixed
# with " (dimeric ACE2)". Periods become underscores because these labels turn
# into Vega-Lite field names after the pivot below, and Vega-Lite treats "." in
# a field name as nested-object access. `regex=False` makes the "." literal.
merged_affinity = (
    pd.read_csv(merged_affinity_csv)
    .assign(
        experiment=lambda x: x["experiment"].where(
            x["ACE2_type"] == "monomeric",
            x["experiment"] + " (dimeric ACE2)",
        ).str.replace(".", "_", regex=False)
    )
    .drop(columns="ACE2_type")
)

merged_affinity.head()

Unnamed: 0,site,wildtype,mutant,human sera escape,spike mediated entry,ACE2 affinity,sequential_site,region,experiment
0,2,F,C,0.0111,0.101,0.0215,2.0,other,XBB_1_5 full spike pseudovirus
1,2,F,L,0.0188,0.0943,-0.2698,2.0,other,XBB_1_5 full spike pseudovirus
2,2,F,S,0.0317,0.0584,-0.0564,2.0,other,XBB_1_5 full spike pseudovirus
3,2,F,F,0.0,0.0,0.0,2.0,other,XBB_1_5 full spike pseudovirus
4,3,V,A,0.024,-0.0415,-0.0498,3.0,other,XBB_1_5 full spike pseudovirus


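## Correlate affinities across datasets

Pairwise scatter plots of ACE2 affinity for every pair of experiments. The controls filter by minimum spike-mediated entry, by spike region, or restrict the plots to mutations measured in all datasets. Hovering over a point highlights that mutation in every panel.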

In [116]:
# Experiment labels; after the pivot below each label becomes a column of
# ACE2-affinity values, and every unordered pair of them gets one scatter panel.
experiments = merged_affinity["experiment"].unique().tolist()

# --- Interactive controls shared by every panel -----------------------------
func_effects_slider = alt.param(
    value=-3,
    bind=alt.binding_range(
        name="minimum spike-mediated entry",
        min=merged_affinity["spike mediated entry"].min(),
        max=0,
    ),
)

common_mutations_only = alt.param(
    value=False,
    bind=alt.binding_radio(
        options=[True, False],
        name="show only mutations measured in all datasets",
    ),
)

# The `None` radio option (labelled "all spike") clears the selection, so all
# regions are shown.
region_selection = alt.selection_point(
    fields=["region"],
    bind=alt.binding_radio(
        options=[None, "RBD", "not RBD"],
        labels=["all spike", "RBD", "not RBD"],
        name="show mutations in this region",
    ),
)

# Hovering a point highlights the same mutation in every panel.
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Rows with no spike-mediated entry value (e.g. the yeast-display data) must
# survive the entry filter. Missing pandas values reach Vega as null, which
# `isNaN` does not flag, so test with `isValid` (false for null/undefined/NaN).
entry_is_missing = ~alt.expr.isValid(alt.datum["spike mediated entry"])

# After pivoting, a mutation is "common" only if every experiment column holds
# a finite affinity value.
measured_in_all = functools.reduce(
    operator.and_,
    [alt.expr.isFinite(alt.datum[expt]) for expt in experiments],
)

corr_base = (
    alt.Chart(
        merged_affinity
        .assign(
            mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
            region=lambda x: x["region"].map(lambda r: "RBD" if r == "RBD" else "not RBD"),
        )
        [["mutation", "spike mediated entry", "ACE2 affinity", "region", "experiment"]]
    )
    .add_params(func_effects_slider, common_mutations_only, region_selection, mut_selection)
    .transform_filter(
        (alt.datum["spike mediated entry"] >= func_effects_slider) | entry_is_missing
    )
    .transform_filter(region_selection)
    .transform_pivot(
        pivot="experiment",
        value="ACE2 affinity",
        groupby=["mutation", "region"],
        op="max",  # with the default "sum", missing measurements would end up as zero
    )
    .transform_filter(
        # `== False` builds a Vega expression on the parameter; it is not a
        # Python truth test, so it must not be rewritten as `is False` / `not`.
        measured_in_all | (common_mutations_only == False)  # noqa: E712
    )
    .mark_point(filled=True, stroke="black")
)

# One panel per unordered pair of experiments. Fields are passed via `field=`
# so free-form experiment labels are never run through shorthand parsing;
# titles turn the "_" placeholders back into ".".
corr_charts = []
for expt1, expt2 in itertools.combinations(experiments, 2):
    corr_charts.append(
        corr_base.encode(
            x=alt.X(field=expt1, type="quantitative", title=expt1.replace("_", ".")),
            y=alt.Y(field=expt2, type="quantitative", title=expt2.replace("_", ".")),
            tooltip=[
                "mutation",
                *[
                    alt.Tooltip(
                        field=expt,
                        type="quantitative",
                        format=".2f",
                        title=expt.replace("_", "."),
                    )
                    for expt in [expt1, expt2]
                ],
                "region",
            ],
            color=alt.Color("region", scale=alt.Scale(domain=["RBD", "not RBD"])),
            size=alt.condition(mut_selection, alt.value(70), alt.value(30)),
            opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.4)),
            strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
        )
        .properties(width=215, height=215)
    )

# Lay the panels out in a grid of `ncols` columns; the last row may be shorter
# (slicing past the end of the list is safe).
ncols = 4
nrows = math.ceil(len(corr_charts) / ncols)
corr_chart = alt.vconcat(
    *[
        alt.hconcat(*corr_charts[irow * ncols : (irow + 1) * ncols], spacing=10)
        for irow in range(nrows)
    ],
    spacing=10,
)
corr_chart.configure_axis(grid=False)

In [105]:
# NOTE(review): leftover interactive lookup of the `alt.vconcat` signature; it
# does not feed the analysis and can likely be deleted from the final notebook.
help(alt.vconcat)

Help on function vconcat in module altair.vegalite.v5.api:

vconcat(*charts, **kwargs)
    Concatenate charts vertically



In [84]:
# Sanity check on the correlation grid: there should be exactly one scatter
# panel per unordered pair of experiments. The count of experiments is shown.
n_experiments = len(experiments)
assert len(corr_charts) == math.comb(n_experiments, 2), "expected one panel per experiment pair"
n_experiments

6