# Compare ACE2 affinity across different experiments

First, get parameters specifying data to analyze:

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# Every name is a `None` placeholder here; presumably papermill overrides them
# with real values at run time.

# XBB.1.5 full-spike inputs (not referenced by name in the cells visible below).
# NOTE(review): the data-reading cell uses `xbb_spike_csv`, `ba2_spike_csv` and
# `xbb_rbd_csv`, none of which are declared in this cell -- confirm they are
# injected by papermill, or reconcile them with the names declared here.
xbb_spike_affinity = None
xbb_spike_func_effects = None
xbb_spike_escape = None

# CSV read with `pd.read_csv`; its `position`, `delta_bind` and `target`
# columns are renamed to `site`, `ACE2 affinity` and `experiment`.
starr_rbd_affinity = None

# BA.2 full-spike inputs (not referenced by name in the cells visible below).
ba2_spike_affinity = None
ba2_spike_func_effects = None

# XBB.1.5 RBD inputs (not referenced by name in the cells visible below).
xbb_rbd_affinity_monomeric = None
xbb_rbd_affinity_dimeric = None
xbb_rbd_func_effects = None

# Output path for the concatenated affinity table (written with `to_csv`).
merged_affinity_csv = None

# Initial value of the "minimum spike-mediated entry" slider.
init_min_func_effect = None
# Upper / lower bounds passed to `Series.clip` for the plotted ACE2 affinities.
clip_affinity_upper = None
clip_affinity_lower = None

# Output paths for the four charts, each written with `Chart.save`.
affinity_corr_html = None
affinity_dist_html = None
affinity_entry_corr_html = None
affinity_escape_corr_html = None

In [None]:
import functools
import itertools
import math
import operator
import os

import altair as alt

import pandas as pd

# Lift Altair's limit on the number of rows a chart's data may contain.  The
# result is bound to `_`, presumably so it is not echoed as the cell's output.
_ = alt.data_transformers.disable_max_rows()

## Read the data sets

In [None]:
# Full-spike lentivirus data sets for XBB.1.5 and BA.2.  Each table is tagged
# with its experiment label and a "monomeric" ACE2 type, and its
# `sequential_site` column is removed.
# NOTE(review): `xbb_spike_csv` and `ba2_spike_csv` are not declared in the
# papermill parameters cell -- presumably they are injected at run time; confirm.
xbb_spike = (
    pd.read_csv(xbb_spike_csv)
    .assign(
        **{
            "experiment": "XBB.1.5 full spike lentivirus",
            "ACE2_type": "monomeric",
        }
    )
    .drop(columns=["sequential_site"])
)

ba2_spike = (
    pd.read_csv(ba2_spike_csv)
    .assign(
        **{
            "experiment": "BA.2 full spike lentivirus",
            "ACE2_type": "monomeric",
        }
    )
    .drop(columns=["sequential_site"])
)

# XBB.1.5 RBD data.  The CSV carries one affinity column per ACE2 type
# ("monomeric ACE2 affinity" and "dimeric ACE2 affinity"); shorten those column
# names to the bare type and reshape to long form so that each row holds a
# single (ACE2_type, ACE2 affinity) measurement.
# NOTE(review): `xbb_rbd_csv` is not declared in the papermill parameters cell
# -- presumably it is injected at run time; confirm.
rbd_ace2_types = ["monomeric", "dimeric"]

xbb_rbd_wide = pd.read_csv(xbb_rbd_csv).rename(
    columns={f"{kind} ACE2 affinity": kind for kind in rbd_ace2_types}
)

xbb_rbd = pd.melt(
    xbb_rbd_wide,
    id_vars=["site", "wildtype", "mutant", "spike mediated entry", "region"],
    value_vars=rbd_ace2_types,
    var_name="ACE2_type",
    value_name="ACE2 affinity",
)
xbb_rbd["experiment"] = "XBB.1.5 RBD pseudovirus"

# RBD measurements read from `starr_rbd_affinity`.  Only the Omicron BA.2 and
# XBB.1.5 targets are kept; they are relabelled as yeast-display experiments
# and every row is marked as monomeric ACE2 in the RBD region.
starr_experiment_labels = {
    "Omicron_BA2": "BA.2 RBD yeast display",
    "Omicron_XBB15": "XBB.1.5 RBD yeast display",
}
starr_keep_columns = [
    "site",
    "wildtype",
    "mutant",
    "ACE2 affinity",
    "region",
    "experiment",
    "ACE2_type",
]

starr_rbd = (
    pd.read_csv(starr_rbd_affinity)
    .rename(
        columns={
            "position": "site",
            "delta_bind": "ACE2 affinity",
            "target": "experiment",
        }
    )
    .query("experiment in ['Omicron_BA2', 'Omicron_XBB15']")
    .assign(
        experiment=lambda df: df["experiment"].map(starr_experiment_labels),
        ACE2_type="monomeric",
        region="RBD",
    )
    [starr_keep_columns]
)

# Stack the four per-experiment tables into one long table with a fresh index.
merged_affinity_data = pd.concat([xbb_spike, ba2_spike, xbb_rbd, starr_rbd], ignore_index=True)

print(f"Writing the merged affinity data to {merged_affinity_csv}")
# `os.path.dirname` returns "" for a bare file name and `os.makedirs("")`
# raises, so only create the output directory when the path actually has one.
merged_affinity_dir = os.path.dirname(merged_affinity_csv)
if merged_affinity_dir:
    os.makedirs(merged_affinity_dir, exist_ok=True)
merged_affinity_data.to_csv(merged_affinity_csv, float_format="%.4f", index=False)

# make some tweaks to merged affinity for plotting:
#  - non-monomeric rows get " (dimeric)" appended to their experiment label
#  - "." in experiment labels becomes "_" (the plotting cells turn "_" back
#    into "." for titles; presumably the dot-free names are needed because the
#    labels are later used as chart field names -- confirm)
#  - "pseudovirus" is relabelled as "lentivirus"
#  - a `mutation` string is built as wildtype + site + mutant
#  - `region` is collapsed to "RBD" versus "not RBD"
#  - ACE2 affinities are clipped to [clip_affinity_lower, clip_affinity_upper]
merged_affinity = (
    merged_affinity_data
    .assign(
        experiment=lambda x: (
            x["experiment"]
            .where(
                x["ACE2_type"] == "monomeric",
                x["experiment"] + " (dimeric)",
            )
            # `regex=False` is explicit on purpose: with pandas < 2.0 the default
            # is `regex=True`, where "." matches *every* character and would
            # turn the whole label into underscores.
            .str.replace(".", "_", regex=False)
            .str.replace("pseudovirus", "lentivirus", regex=False)
        ),
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
        region=lambda x: x["region"].map(lambda r: "RBD" if r == "RBD" else "not RBD"),
        **{
            "ACE2 affinity": lambda x: x["ACE2 affinity"].clip(
                lower=clip_affinity_lower, upper=clip_affinity_upper,
            )
        },
    )
    .drop(columns="ACE2_type")
)

merged_affinity.head()

## Correlate affinities across datasets

In [None]:
# Experiment labels in order of first appearance; reused below as pivoted
# column names and as the facet sort order.
experiments = merged_affinity["experiment"].drop_duplicates().tolist()

# Slider for the minimum spike-mediated entry; it ranges from the smallest
# observed value up to 0 and starts at `init_min_func_effect`.
entry_slider_binding = alt.binding_range(
    name="minimum spike-mediated entry",
    min=merged_affinity["spike mediated entry"].min(),
    max=0,
)
func_effects_slider = alt.param(
    bind=entry_slider_binding,
    value=init_min_func_effect,
)

# Radio toggle (off by default) restricting plots to mutations that have a
# measurement in every experiment.
common_only_binding = alt.binding_radio(
    name="show only mutations measured in all datasets",
    options=[True, False],
)
common_mutations_only = alt.param(
    bind=common_only_binding,
    value=False,
)

# Radio selection over the `region` field; the `None` option is labelled
# "all spike".  No initial `value` is passed.
region_binding = alt.binding_radio(
    name="show mutations in this region",
    options=[None, "RBD", "not RBD"],
    labels=["all spike", "RBD", "not RBD"],
)
region_selection = alt.selection_point(
    bind=region_binding,
    fields=["region"],
)

# Mouse-over selection of a single mutation; `empty=False` means nothing is
# treated as selected until a point is hovered.
mut_selection = alt.selection_point(
    on="mouseover",
    fields=["mutation"],
    empty=False,
)

# Shared base chart for the pairwise affinity correlations.  Rows are filtered
# by the entry slider and the region selection, then pivoted so that every
# experiment becomes its own column of ACE2 affinities per (mutation, region).
corr_columns = ["mutation", "spike mediated entry", "ACE2 affinity", "region", "experiment"]

# Keep rows at or above the slider value, or whose entry value is NaN.
passes_entry_filter = (
    (alt.datum["spike mediated entry"] >= func_effects_slider)
    | alt.expr.isNaN(alt.datum["spike mediated entry"])
)

# AND together one finiteness test per experiment column.
finite_in_every_experiment = functools.reduce(
    operator.and_,
    (alt.expr.isFinite(alt.datum[expt_col]) for expt_col in experiments),
)

corr_base = alt.Chart(merged_affinity[corr_columns])
corr_base = corr_base.add_params(
    func_effects_slider,
    common_mutations_only,
    region_selection,
    mut_selection,
)
corr_base = corr_base.transform_filter(passes_entry_filter)
corr_base = corr_base.transform_filter(region_selection)
corr_base = corr_base.transform_pivot(
    pivot="experiment",
    value="ACE2 affinity",
    groupby=["mutation", "region"],
    op="max",  # if default of "sum", end up with zero
)
# The all-experiments requirement only applies while the toggle is on.
corr_base = corr_base.transform_filter(
    finite_in_every_experiment | (common_mutations_only == False)
)

def _affinity_axis(channel, expt):
    """Build a quantitative `channel` (alt.X / alt.Y) encoding for column *expt*.

    The axis title is the column name with "_" turned back into ".".
    """
    return channel(
        expt,
        type="quantitative",
        title=expt.replace("_", "."),
        scale=alt.Scale(nice=False, padding=4),
    )


# One panel per unordered pair of experiments: a scatter of the two affinity
# columns plus a text layer reporting the correlation coefficient.
corr_charts = []
for expt1, expt2 in itertools.combinations(experiments, 2):
    pair_tooltips = [
        alt.Tooltip(expt, type="quantitative", format=".2f", title=expt.replace("_", "."))
        for expt in (expt1, expt2)
    ]

    # Points are enlarged, made opaque and outlined while hovered.
    scatter_layer = (
        corr_base
        .mark_point(filled=True, stroke="red", color="black")
        .encode(
            x=_affinity_axis(alt.X, expt1),
            y=_affinity_axis(alt.Y, expt2),
            tooltip=["mutation", *pair_tooltips, "region"],
            size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.25)),
            strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
        )
        .properties(width=170, height=170)
    )

    # r is the square root of the regression's rSquared, negated unless the
    # second fitted coefficient is positive.
    signed_r = alt.expr.if_(
        alt.datum["coef"][1] > 0,
        alt.expr.sqrt(alt.datum["rSquared"]),
        -alt.expr.sqrt(alt.datum["rSquared"]),
    )
    r_layer = (
        corr_base
        .transform_regression(expt1, expt2, params=True)
        .transform_calculate(r=signed_r)
        .transform_calculate(r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"))
        .mark_text(size=14, align="left", color="blue")
        .encode(
            text="r_text:N",
            x=alt.value(5),
            y=alt.value(10),
        )
    )

    corr_charts.append(scatter_layer + r_layer)
# Arrange the pairwise panels in a grid that is at most `ncols` wide.
ncols = 5
nrows = math.ceil(len(corr_charts) / ncols)

corr_chart_rows = [
    alt.hconcat(*corr_charts[first:first + ncols], spacing=7)
    for first in range(0, len(corr_charts), ncols)
]

corr_chart_title = alt.TitleParams(
    "Correlations among ACE2 affinities measured in different experiments",
    anchor="middle",
    fontSize=16,
)

corr_chart = (
    alt.vconcat(*corr_chart_rows, spacing=7)
    .properties(title=corr_chart_title)
    .configure_axis(grid=False)
)

print(f"Saving to {affinity_corr_html}")
corr_chart.save(affinity_corr_html)

corr_chart

## Distribution of effects on ACE2 affinity for RBD and non-RBD mutations

In [None]:
# Data for the histograms: only the two full-spike lentivirus experiments.
dist_columns = ["mutation", "spike mediated entry", "ACE2 affinity", "region", "experiment"]
dist_df = merged_affinity.query(
    "experiment in ['XBB_1_5 full spike lentivirus', 'BA_2 full spike lentivirus']"
)[dist_columns]

# Header styling shared by the row and column facets.
dist_header_style = {"labelFontSize": 12, "labelFontStyle": "bold", "labelPadding": 3}

# Binned counts of ACE2 affinity (bin width 0.25), coloured by region.
dist_x = alt.X("ACE2 affinity", bin=alt.BinParams(step=0.25), title="ACE2 affinity")
dist_y = alt.Y("count()", title="number of mutations")
dist_color = alt.Color("region", legend=None)

# Rows split by region, columns by experiment; the column labels get their
# first two "_" characters turned back into ".".
dist_row = alt.Row(
    "region",
    title=None,
    header=alt.Header(orient="right", **dist_header_style),
    spacing=8,
)
dist_column = alt.Column(
    "experiment",
    title=None,
    header=alt.Header(
        labelExpr="replace(replace(datum.label, '_', '.'), '_', '.')",
        **dist_header_style,
    ),
)

dist_title = alt.TitleParams(
    "Effects of RBD and non-RBD mutations on ACE2 affinity",
    anchor="middle",
    fontSize=16,
)

dist_chart = (
    alt.Chart(dist_df)
    .mark_bar()
    .add_params(func_effects_slider)
    # only mutations at or above the slider's minimum spike-mediated entry
    .transform_filter(alt.datum["spike mediated entry"] >= func_effects_slider)
    .encode(x=dist_x, y=dist_y, color=dist_color, row=dist_row, column=dist_column)
    .configure_axis(grid=False)
    .resolve_scale(y="independent")
    .properties(width=225, height=115, title=dist_title)
)

print(f"Saving to {affinity_dist_html}")
dist_chart.save(affinity_dist_html)

dist_chart

## Correlation of ACE2 affinity and pseudovirus entry

In [None]:
# Rows with both an ACE2 affinity and a spike-mediated entry measurement.
affinity_entry_columns = ["mutation", "spike mediated entry", "ACE2 affinity", "region", "experiment"]
affinity_entry_df = merged_affinity.query(
    "`ACE2 affinity`.notnull() and `spike mediated entry`.notnull()"
)[affinity_entry_columns]

# Base chart: apply the entry slider and the region selection.
affinity_entry_corr_base = (
    alt.Chart(affinity_entry_df)
    .add_params(func_effects_slider, region_selection, mut_selection)
    .transform_filter(alt.datum["spike mediated entry"] >= func_effects_slider)
    .transform_filter(region_selection)
)

# Scatter layer; the hovered mutation is enlarged, made opaque and outlined.
entry_points_layer = (
    affinity_entry_corr_base
    .mark_circle(color="black", stroke="red")
    .encode(
        x=alt.X("ACE2 affinity", scale=alt.Scale(nice=False, padding=3)),
        y=alt.Y("spike mediated entry", scale=alt.Scale(nice=False, padding=3)),
        tooltip=list(affinity_entry_df.columns),
        size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
        opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.25)),
        strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
    )
)

# Text layer: r is sqrt(rSquared) from the regression, negated unless the
# second fitted coefficient is positive.
entry_signed_r = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)
entry_r_layer = (
    affinity_entry_corr_base
    .transform_regression("ACE2 affinity", "spike mediated entry", params=True)
    .transform_calculate(r=entry_signed_r)
    .transform_calculate(r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"))
    .mark_text(size=14, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(170),
    )
)

# One facet per experiment, in the order of `experiments`; facet labels get
# their first two "_" characters turned back into ".".
entry_facet = alt.Facet(
    "experiment",
    sort=experiments,
    title=None,
    header=alt.Header(
        labelFontSize=12,
        labelFontStyle="bold",
        labelExpr="replace(replace(datum.label, '_', '.'), '_', '.')",
        labelPadding=3,
    ),
)
entry_title = alt.TitleParams(
    "Correlation of spike-mediated entry and ACE2 affinity",
    anchor="middle",
    fontSize=16,
)

affinity_entry_corr_chart = (
    (entry_points_layer + entry_r_layer)
    .properties(width=180, height=180)
    .facet(facet=entry_facet)
    .configure_axis(grid=False)
    .resolve_scale(y="independent", x="independent")
    .properties(title=entry_title)
)

print(f"Saving to {affinity_entry_corr_html}")
affinity_entry_corr_chart.save(affinity_entry_corr_html)

affinity_entry_corr_chart

## Correlation of ACE2 affinity with serum escape

In [None]:
# XBB.1.5 full-spike rows that have both a sera-escape value and an ACE2
# affinity.  The `mutation` column (wildtype + site + mutant) is kept rather
# than the bare `mutant` amino acid so that the tooltip, which lists every
# column of this frame, identifies each point the same way the other charts do.
affinity_escape_df = (
    merged_affinity
    .query("experiment == 'XBB_1_5 full spike lentivirus'")
    .query("`human sera escape`.notnull() and `ACE2 affinity`.notnull()")
    [["mutation", "human sera escape", "spike mediated entry", "ACE2 affinity", "region"]]
)

# Base chart: keep mutations at or above the slider's minimum entry value.
affinity_escape_corr_base = (
    alt.Chart(affinity_escape_df)
    .add_params(func_effects_slider)
    .transform_filter(alt.datum["spike mediated entry"] >= func_effects_slider)
)

affinity_escape_corr_chart = (
    (
        # scatter of sera escape against ACE2 affinity
        (
            affinity_escape_corr_base
            .encode(
                x=alt.X("ACE2 affinity", scale=alt.Scale(nice=False, padding=3)),
                y=alt.Y("human sera escape", scale=alt.Scale(nice=False, padding=3)),
                tooltip=affinity_escape_df.columns.tolist(),
            )
            .mark_circle(color="black", opacity=0.25, size=35)
        )
        # text layer: r is sqrt(rSquared) from the regression, negated unless
        # the second fitted coefficient is positive
        + (
            affinity_escape_corr_base
            .transform_regression("ACE2 affinity", "human sera escape", params=True)
            .transform_calculate(
                r=alt.expr.if_(
                    alt.datum["coef"][1] > 0,
                    alt.expr.sqrt(alt.datum["rSquared"]),
                    -alt.expr.sqrt(alt.datum["rSquared"]),
                ),
                r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
            )
            .encode(
                text="r_text:N",
                x=alt.value(5),
                y=alt.value(170),
            )
            .mark_text(size=14, align="left", color="blue")
        )
    )
    .properties(
        width=180,
        height=180,
    )
    # one facet per region ("RBD" / "not RBD")
    .facet(
        facet=alt.Facet(
            "region",
            title=None,
            header=alt.Header(
                labelFontSize=12,
                labelFontStyle="bold",
                labelPadding=3,
            ),
        ),
    )
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            "Correlation of ACE2 affinity and sera escape",
            anchor="middle",
            fontSize=16,
        ),
    )
)

print(f"Saving to {affinity_escape_corr_html}")
affinity_escape_corr_chart.save(affinity_escape_corr_html)

affinity_escape_corr_chart