# Compare escape in high and medium ACE2 cells
Compare antibody escape measured in cells expressing high versus medium levels of ACE2:

In [None]:
# this cell is tagged `parameters` for `papermill` parameterization
# Every name below is only a `None` placeholder in this cell; the notebook
# reads the real values from these names in later cells.
site_numbering_map_csv = None  # CSV read for its "reference_site" and "sequential_site" columns
func_effects_csv = None  # CSV read for its "site", "mutant" and "effect" columns
escape_csvs = None  # iterable of paths to escape CSVs; serum and cell type are parsed from each file name
init_min_func_effect = None  # initial value of the "minimum spike-mediated entry" slider
init_min_times_seen = None  # initial value of the "minimum times seen" slider
init_floor_at_zero = None  # initial value of the "floor escape at zero" radio button
init_site_escape_stat = None  # initially selected site escape statistic; must be one of the supported statistics
chart_html = None  # path the final chart is saved to

In [None]:
# Parameters
# Concrete values for the placeholders declared in the `parameters` cell
# (presumably injected by `papermill` -- confirm before editing by hand).
# NOTE: `chart_html` is not assigned here, so it keeps its earlier value.
init_min_func_effect = -2
init_min_times_seen = 3
init_floor_at_zero = False
init_site_escape_stat = "mean"
# Escape CSVs for all sera; only sera that have a file for every cell type are
# analyzed later, the rest are reported as skipped.
escape_csvs = [
    "results/antibody_escape/averages/sera_493C_highACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_498C_highACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_500C_highACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_503C_highACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_493C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_498C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_500C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_501C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_503C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_287C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_288C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_343C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_497C_mediumACE2_mut_effect.csv",
    "results/antibody_escape/averages/sera_505C_mediumACE2_mut_effect.csv",
]
site_numbering_map_csv = "data/site_numbering_map.csv"
func_effects_csv = "results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv"

import os
# Move the working directory up one level; the relative paths above are only
# opened in later cells, so they are resolved from that parent directory.
os.chdir("../")

Import Python modules:

In [None]:
import collections
import os
import re

import altair as alt

import pandas as pd

# Lift Altair's limit on the number of data rows a chart may hold; the return
# value is assigned to `_` so the notebook does not display it.
_ = alt.data_transformers.disable_max_rows()

## Read in the data

In [None]:
# Get sera shared in both high and medium cells
# `cell_types` maps the cell-type token used in the CSV file names to the
# label shown on the plots.
cell_types = {
    "high": "high ACE2 cells",
    "medium": "medium ACE2 cells",
}

# File names look like `sera_<serum>_<cell>ACE2...`. The pattern is a raw string
# so that `\d` reaches the regex engine unchanged (in a normal string it is an
# invalid escape sequence), and it is compiled once rather than on every loop
# iteration. The serum character class is spelled out instead of using the
# range `A-z`, which also spans the punctuation characters that sit between
# `Z` and `a`; `_` is kept so previously matched names still match.
cell_alternatives = "|".join(map(re.escape, cell_types))
serum_cell_regex = re.compile(
    rf"sera_(?P<serum>[\dA-Za-z_]+)_(?P<cell>{cell_alternatives})ACE2"
)

# serum -> {cell type -> path of escape CSV}; files whose names do not match
# the pattern are ignored
sera_all = collections.defaultdict(dict)
for escape_csv in escape_csvs:
    if m := serum_cell_regex.match(os.path.basename(escape_csv)):
        sera_all[m.group("serum")][m.group("cell")] = escape_csv

# Keep only sera measured in every cell type, reading one data frame per
# (serum, cell type) pair and tagging it with both.
sera_to_analyze = {}
escape = []
for serum, serum_d in sera_all.items():
    if set(cell_types).issubset(serum_d):
        sera_to_analyze[serum] = serum_d
        print(f"Analyzing {serum}")
        escape.extend(
            pd.read_csv(serum_d[cell_type]).assign(serum="serum " + serum, cell_type=cell_type)
            for cell_type in cell_types
        )
    else:
        print(f"Skipping {serum} as not data for all cell types")

# Fail with a specific message when no serum had data for all cell types,
# rather than with the generic error `pd.concat` raises for an empty list.
if len(escape) == 0:
    raise ValueError("no sera have escape data for all cell types")

# Stack the per-serum / per-cell-type data frames, then attach the sequential
# site number and the functional effect of each mutation. Join keys and join
# type are spelled out so the merges cannot silently change if the inputs gain
# columns; `validate` checks that the right-hand tables have unique keys.
escape = (
    pd.concat(escape, ignore_index=True)
    .rename(columns={"escape_median": "escape"})
    [["epitope", "serum", "cell_type", "site", "wildtype", "mutant", "escape", "times_seen"]]
    .merge(
        pd.read_csv(site_numbering_map_csv)
        .rename(columns={"reference_site": "site"})
        [["site", "sequential_site"]],
        on="site",
        how="inner",  # sites missing from the numbering map are dropped
        validate="many_to_one",
    )
    .merge(
        pd.read_csv(func_effects_csv)[["site", "mutant", "effect"]],
        on=["site", "mutant"],
        how="inner",  # mutations without a functional effect are dropped
        validate="many_to_one",
    )
)

# the plots below have no notion of epitope, so require exactly one and drop it
assert escape["epitope"].nunique() == 1, "code only works for one epitope"
escape = escape.drop(columns="epitope")

assert escape["escape"].notnull().all()

# last expression of the cell: displays the data frame in the notebook
escape

## Plot correlations

In [None]:
# height in pixels of every plot panel (the correlation panels also use it as width)
plotheight = 160

# Interactive controls. They are created in a fixed order: slider for the
# functional effect, slider for times seen, radio button for flooring, and
# finally the mouseover selection of a mutation.
_effect_range = alt.binding_range(
    min=escape["effect"].min(),
    max=0,
    name="minimum spike-mediated entry",
)
func_effects_slider = alt.param(bind=_effect_range, value=init_min_func_effect)

_times_seen_range = alt.binding_range(
    min=0,
    max=min(escape["times_seen"].max(), 15),
    name="minimum times seen",
)
times_seen_slider = alt.param(bind=_times_seen_range, value=init_min_times_seen)

_floor_radio = alt.binding_radio(name="floor escape at zero", options=[True, False])
floor_at_zero = alt.param(bind=_floor_radio, value=init_floor_at_zero)

mut_selection = alt.selection_point(on="mouseover", fields=["mutation"], empty=False)

# make base chart
# Shared by all data-driven panels: carries the controls, drops rows below the
# two slider thresholds, and adds `escape_floored`, which is the escape value
# clipped at zero when the radio button is set and the raw value otherwise.
_raw_escape = alt.datum["escape"]
chart_base = alt.Chart(escape).add_params(
    times_seen_slider,
    func_effects_slider,
    floor_at_zero,
    mut_selection,
)
chart_base = chart_base.transform_filter(alt.datum["times_seen"] >= times_seen_slider)
chart_base = chart_base.transform_filter(alt.datum["effect"] >= func_effects_slider)
chart_base = chart_base.transform_calculate(
    escape_floored=alt.expr.if_(
        floor_at_zero,
        alt.expr.max(0, _raw_escape),
        _raw_escape,
    ),
)

# plot correlations
# One scatter panel per serum: escape in the first cell type (x) against escape
# in the second (y), annotated with the correlation coefficient.
assert len(cell_types) == 2, "current code only works for 2 cell types"
(_x_cell, _x_title), (_y_cell, _y_title) = cell_types.items()
_axis_scale = dict(nice=False, padding=3)

# Label each mutation as <wildtype><site><mutant>, then spread the cell types
# into columns so each row holds both escape values for one mutation.
corr_base = chart_base.transform_calculate(
    mutation=alt.datum["wildtype"] + alt.expr.format(alt.datum["site"], "d") + alt.datum["mutant"],
).transform_pivot(
    pivot="cell_type",
    value="escape_floored",
    groupby=["serum", "mutation", "site"],
    op="max",  # "sum" converts NaN to zero
)

_cell_tooltips = [
    alt.Tooltip(cell_key, type="quantitative", title=cell_label, format=".3f")
    for cell_key, cell_label in cell_types.items()
]

# points; the mutation under the mouse is drawn larger, opaque and outlined
corr_scatter = corr_base.mark_circle(color="black", stroke="red").encode(
    x=alt.X(
        _x_cell,
        type="quantitative",
        title=_x_title,
        scale=alt.Scale(**_axis_scale),
    ),
    y=alt.Y(
        _y_cell,
        type="quantitative",
        title=_y_title,
        scale=alt.Scale(**_axis_scale),
    ),
    tooltip=["serum", "mutation:N", *_cell_tooltips],
    size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
    opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.25)),
    strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
)

# text label: r is the square root of the regression's rSquared, signed by
# the second regression coefficient
corr_r = (
    corr_base
    .transform_regression(_x_cell, _y_cell, params=True)
    .transform_calculate(
        r=alt.expr.if_(
            alt.datum["coef"][1] > 0,
            alt.expr.sqrt(alt.datum["rSquared"]),
            -alt.expr.sqrt(alt.datum["rSquared"]),
        ),
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=14, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
)

# stack the per-serum panels in a single column, each with its own axis scales
_corr_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
corr = (
    alt.layer(corr_scatter, corr_r)
    .properties(width=plotheight, height=plotheight)
    .facet(alt.Facet("serum", title=None, header=_corr_header), columns=1)
    .resolve_scale(x="independent", y="independent")
)

## Escape line plots

In [None]:
# statistics that can be used to summarize the escape at a site
site_stats = ["mean", "sum", "max", "min"]
assert init_site_escape_stat in site_stats

# width in pixels shared by the zoom bar and the per-site escape panels
site_escape_width = 700

# drop-down choosing which site statistic is shown
_stat_dropdown = alt.binding_select(
    name="site escape statistic",
    options=site_stats,
)
site_escape_stat_selection = alt.selection_point(
    bind=_stat_dropdown,
    fields=["site escape statistic"],
    value=init_site_escape_stat,
)

# highlights the site under the mouse
site_selection = alt.selection_point(on="mouseover", fields=["site"], empty=False)

# horizontal brush used to zoom in on a range of sites
_brush_style = alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0)
site_brush = alt.selection_interval(mark=_brush_style, encodings=["x"], empty=True)

# thin gray bar with one rectangle per site, ordered by sequential site number
_zoom_sites = escape[["site", "sequential_site"]].drop_duplicates()
_zoom_title = alt.TitleParams(
    "site zoom bar",
    fontSize=11,
    fontWeight="bold",
    orient="top",
)
site_zoom_bar = (
    alt.Chart(_zoom_sites)
    .mark_rect(color="gray")
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            scale=alt.Scale(nice=False, zero=False),
            axis=alt.Axis(labelOverlap=True),
        ),
        tooltip=["site"],
    )
    .properties(width=site_escape_width, height=11, title=_zoom_title)
)

# Per-site escape: compute every statistic in `site_stats` over the mutations
# at each site, fold them into long form, and keep only the statistic chosen in
# the drop-down.
_site_stat_aggs = {stat: f"{stat}(escape_floored)" for stat in site_stats}
_site_x = alt.X(
    "site:N",
    sort=alt.SortField("sequential_site"),
    scale=alt.Scale(nice=False, zero=False),
    axis=alt.Axis(labelOverlap=True),
)
_cell_color = alt.Color(
    "cell_type:N",
    scale=alt.Scale(range=["#E69F00", "#009E73", "#F0E442", "#0072B2"]),
    title=["target 293T", "cell ACE2", "expression"],
)
escape_base = (
    chart_base
    .add_params(site_escape_stat_selection, site_selection)
    .transform_aggregate(
        groupby=["site", "sequential_site", "cell_type", "serum", "wildtype"],
        **_site_stat_aggs,
    )
    .transform_fold(fold=site_stats, as_=["site escape statistic", "site escape"])
    .transform_filter(site_escape_stat_selection)
    .encode(
        x=_site_x,
        y=alt.Y("site escape:Q"),
        tooltip=["serum", "site", "wildtype", "cell_type:N", "site escape:Q"],
        color=_cell_color,
    )
)

# thin lines plus points; the site under the mouse is enlarged and outlined
escape_lines = escape_base.mark_line(size=0.5, opacity=0.65)
escape_points = escape_base.mark_circle(stroke="black").encode(
    size=alt.condition(site_selection, alt.value(50), alt.value(10)),
    strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
    opacity=alt.condition(site_selection, alt.value(1), alt.value(0.65)),
)

# one row per serum, each row with independent axis scales
_escape_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
escape_chart = (
    alt.layer(escape_lines, escape_points)
    .properties(width=site_escape_width, height=plotheight)
    .facet(alt.Facet("serum", title=None, header=_escape_header), columns=1)
    .resolve_scale(y="independent", x="independent")
)

# Final figure: the site zoom bar on top, and below it the correlation panels
# next to the per-site escape panels, both filtered by `site_brush`.
chart = (
    alt.vconcat(
        site_zoom_bar,
        alt.hconcat(corr, escape_chart).transform_filter(site_brush),
        center=True,
    )
    .add_params(site_brush)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            "Comparison of serum antibody escape in 293T cells expressing high versus medium levels of ACE2",
            anchor="middle",
            fontSize=16,
        ),
    )
)

# `chart_html` is `None` unless a value was supplied for it; only write the
# file when a path is actually available instead of failing in `chart.save`.
if chart_html is None:
    print("Not saving chart as `chart_html` is not set")
else:
    print(f"Saving to {chart_html}")
    chart.save(chart_html)

# last expression of the cell: displays the chart in the notebook
chart