# Compare escape in spike and RBD deep mutational scans
Compare serum antibody escape measured in full-spike and RBD-only deep mutational scanning libraries.

In [None]:
# this cell is tagged `parameters` for `papermill` parameterization
# Every value below is a `None` placeholder; the real values are expected to be
# supplied through `papermill`.
# CSV read for its `reference_site`, `sequential_site`, and `region` columns
site_numbering_map_csv = None
# CSV read for its `site`, `mutant`, and `effect` columns
func_effects_csv = None
# iterables of escape CSV paths; serum names are parsed from base names of the
# form `sera_<serum>_mediumACE2...`
spike_escape_csvs = None
rbd_escape_csvs = None
# initial value of the "minimum spike-mediated entry" slider
init_min_func_effect = None
# initial value of the "minimum times seen" slider
init_min_times_seen = None
# initial value of the "floor escape at zero" radio button
init_floor_at_zero = None
# initial site escape statistic; must be one of "mean", "sum", "max", "min"
init_site_escape_stat = None
# output HTML path for the combined correlation / site-escape chart
corr_chart_html = None
# output HTML path for the RBD versus non-RBD escape distribution chart
dist_chart_html = None

Import Python modules:

In [None]:
import collections
import os
import re

import altair as alt

import pandas as pd

# Turn off Altair's max-rows check for chart data; assigning the return value
# to `_` keeps it from being echoed as the cell output.
_ = alt.data_transformers.disable_max_rows()

## Read in the data for sera measured in both spike and RBD

In [None]:
# Collect escape measurements for every serum that has a CSV in BOTH the full
# spike and the RBD-only libraries.

# Raw string so that `\d` is a regex digit class rather than an invalid string
# escape (which triggers a `SyntaxWarning` on recent Python versions).
# NOTE(review): the `A-z` range also matches `[`, `\`, `]`, `^`, `_` and the
# backtick; it is kept as-is so exactly the same file names match -- confirm
# whether `A-Za-z` was intended.
spike_serum_regex = re.compile(r"sera_(?P<serum>[\dA-z]+)_mediumACE2")

escape_dfs = []
for spike_escape_csv in spike_escape_csvs:
    m = spike_serum_regex.match(os.path.basename(spike_escape_csv))
    if not m:
        # file name does not follow the `sera_<serum>_mediumACE2` pattern
        continue
    serum = m.group("serum")
    for rbd_escape_csv in rbd_escape_csvs:
        if os.path.basename(rbd_escape_csv).startswith(f"sera_{serum}_mediumACE2"):
            print(f"Serum {serum} has matches in both library types")
            escape_dfs += [
                pd.read_csv(f).assign(serum="serum " + serum, library_type=libtype)
                for (f, libtype) in [
                    (spike_escape_csv, "full spike"),
                    (rbd_escape_csv, "RBD only"),
                ]
            ]

# Fail with an informative message rather than pandas' generic
# "No objects to concatenate" error when nothing matched.
if not escape_dfs:
    raise ValueError(
        "no serum has escape CSVs for both the full spike and RBD libraries"
    )

escape = (
    pd.concat(escape_dfs, ignore_index=True)
    .rename(columns={"escape_median": "escape"})
    [["epitope", "serum", "library_type", "site", "wildtype", "mutant", "escape", "times_seen"]]
    # add the sequential site number, used later to order sites on plot axes
    .merge(
        pd.read_csv(site_numbering_map_csv)
        .rename(columns={"reference_site": "site"})
        [["site", "sequential_site"]],
        validate="many_to_one",
    )
    # add the functional effect of each mutation, used later for filtering
    .merge(
        pd.read_csv(func_effects_csv)[["site", "mutant", "effect"]],
        validate="many_to_one",
    )
)

assert escape["epitope"].nunique() == 1, "code only works for one epitope"
escape = escape.drop(columns="epitope")

assert escape["escape"].notnull().all()

escape

## Plot correlations

In [None]:
# height (and width, for square panels) in pixels used for the plot panels
plotheight = 160

# library types in order of first appearance in the data
library_types = escape["library_type"].unique().tolist()

# slider: minimum functional effect, ranging from the smallest observed effect to 0
func_effects_binding = alt.binding_range(
    name="minimum spike-mediated entry",
    min=escape["effect"].min(),
    max=0,
)
func_effects_slider = alt.param(
    value=init_min_func_effect,
    bind=func_effects_binding,
)

# slider: minimum `times_seen`, with the upper end capped at 15
times_seen_binding = alt.binding_range(
    name="minimum times seen",
    min=0,
    max=min(escape["times_seen"].max(), 15),
)
times_seen_slider = alt.param(
    value=init_min_times_seen,
    bind=times_seen_binding,
)

# radio button: whether negative escape values are clipped to zero
floor_binding = alt.binding_radio(
    options=[True, False],
    name="floor escape at zero",
)
floor_at_zero = alt.param(
    value=init_floor_at_zero,
    bind=floor_binding,
)

# point selection keyed on the `mutation` field, triggered on mouseover
mut_selection = alt.selection_point(
    fields=["mutation"],
    on="mouseover",
    empty=False,
)

# make base chart
# expression giving the escape value, clipped at zero when `floor_at_zero` is set
floored_escape_expr = alt.expr.if_(
    floor_at_zero,
    alt.expr.max(0, alt.datum["escape"]),
    alt.datum["escape"],
)

chart_base = alt.Chart(escape).add_params(
    times_seen_slider,
    func_effects_slider,
    floor_at_zero,
    mut_selection,
)
# keep only mutations passing the slider-controlled thresholds
chart_base = chart_base.transform_filter(alt.datum["times_seen"] >= times_seen_slider)
chart_base = chart_base.transform_filter(alt.datum["effect"] >= func_effects_slider)
chart_base = chart_base.transform_calculate(escape_floored=floored_escape_expr)

# plot correlations
assert len(library_types) == 2, "current code only works for 2 library types"
x_library = library_types[0]
y_library = library_types[1]

# mutation label: wildtype + site + mutant
mutation_label_expr = (
    alt.datum["wildtype"] + alt.expr.format(alt.datum["site"], "d") + alt.datum["mutant"]
)

# one row per serum and mutation, with one escape column per library type
corr_base = chart_base.transform_calculate(mutation=mutation_label_expr).transform_pivot(
    groupby=["serum", "mutation", "site"],
    op="max",  # "sum" converts NaN to zero
    pivot="library_type",
    value="escape_floored",
)

corr_tooltips = ["serum", "mutation:N"] + [
    alt.Tooltip(lib_type, type="quantitative", format=".3f")
    for lib_type in library_types
]

# scatter of escape in one library type against the other; the hovered
# mutation is drawn larger, opaque, and outlined
corr_scatter = corr_base.encode(
    x=alt.X(
        x_library,
        type="quantitative",
        scale=alt.Scale(nice=False, padding=3),
    ),
    y=alt.Y(
        y_library,
        type="quantitative",
        scale=alt.Scale(nice=False, padding=3),
    ),
    tooltip=corr_tooltips,
    size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
    opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.25)),
    strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
).mark_circle(color="black", stroke="red")

# correlation coefficient: square root of the regression R^2, signed by the slope
signed_r_expr = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)
r_label_expr = "r = " + alt.expr.format(alt.datum["r"], ".2f")

corr_r = (
    corr_base
    .transform_regression(x_library, y_library, params=True)
    .transform_calculate(r=signed_r_expr)
    .transform_calculate(r_text=r_label_expr)
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
    .mark_text(size=14, align="left", color="blue")
)

# one square panel per serum, stacked in a single column
corr_facet_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
corr_layers = (corr_scatter + corr_r).properties(width=plotheight, height=plotheight)
corr = corr_layers.facet(
    alt.Facet("serum", title=None, header=corr_facet_header),
    columns=1,
).resolve_scale(x="independent", y="independent")

## Escape line plots

In [None]:
# statistics that can be used to summarize escape at each site
site_stats = ["mean", "sum", "max", "min"]
assert init_site_escape_stat in site_stats

# dropdown choosing which site statistic is shown
site_stat_dropdown = alt.binding_select(
    options=site_stats,
    name="site escape statistic",
)
site_escape_stat_selection = alt.selection_point(
    fields=["site escape statistic"],
    bind=site_stat_dropdown,
    value=init_site_escape_stat,
)

# point selection keyed on the `site` field, triggered on mouseover
site_selection = alt.selection_point(
    fields=["site"],
    on="mouseover",
    empty=False,
)

# interval brush along x, drawn as an unfilled black outline
brush_style = alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0)
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=brush_style,
    empty=True,
)

# width in pixels of the site zoom bar and the site escape line plots
site_escape_width = 700

# thin gray bar with one rectangle per site, ordered by sequential site number
zoom_bar_sites = escape[["site", "sequential_site"]].drop_duplicates()

zoom_bar_title = alt.TitleParams(
    "site zoom bar",
    fontSize=11,
    fontWeight="bold",
    orient="top",
)

zoom_bar_x = alt.X(
    "site:N",
    sort=alt.SortField("sequential_site"),
    scale=alt.Scale(nice=False, zero=False),
    axis=alt.Axis(labelOverlap=True),
)

site_zoom_bar = (
    alt.Chart(zoom_bar_sites)
    .encode(x=zoom_bar_x, tooltip=["site"])
    .mark_rect(color="gray")
    .properties(width=site_escape_width, height=11, title=zoom_bar_title)
)

# per-site summaries of floored escape, one aggregate per available statistic
site_stat_aggregates = {stat: f"{stat}(escape_floored)" for stat in site_stats}

escape_x = alt.X(
    "site:N",
    sort=alt.SortField("sequential_site"),
    scale=alt.Scale(nice=False, zero=False),
    axis=alt.Axis(labelOverlap=True),
)
escape_color = alt.Color(
    "library_type:N",
    scale=alt.Scale(range=["#E69F00", "#009E73", "#F0E442", "#0072B2"]),
    title=["library type"],
)

escape_base = (
    chart_base
    .add_params(site_escape_stat_selection, site_selection)
    .transform_aggregate(
        groupby=["site", "sequential_site", "library_type", "serum", "wildtype"],
        **site_stat_aggregates,
    )
    # long format: one row per statistic, then keep only the selected statistic
    .transform_fold(fold=site_stats, as_=["site escape statistic", "site escape"])
    .transform_filter(site_escape_stat_selection)
    .encode(
        x=escape_x,
        y=alt.Y("site escape:Q"),
        tooltip=["serum", "site", "wildtype", "library_type:N", "site escape:Q"],
        color=escape_color,
    )
)

# thin lines connecting sites
escape_lines = escape_base.mark_line(size=0.5, opacity=0.65)

# points on top of the lines; the hovered site is enlarged, opaque, and outlined
escape_points = escape_base.encode(
    size=alt.condition(site_selection, alt.value(50), alt.value(10)),
    strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
    opacity=alt.condition(site_selection, alt.value(1), alt.value(0.65)),
).mark_circle(stroke="black")

# site escape line plots: one panel per serum, stacked in a single column
escape_facet_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
escape_layers = (escape_lines + escape_points).properties(
    width=site_escape_width,
    height=plotheight,
)
escape_chart = escape_layers.facet(
    alt.Facet("serum", title=None, header=escape_facet_header),
    columns=1,
).resolve_scale(y="independent", x="independent")

# correlation and line-plot panels side by side, filtered by the zoom-bar brush
comparison_panels = alt.hconcat(corr, escape_chart).transform_filter(site_brush)

chart_title = alt.TitleParams(
    "Comparison of serum antibody escape in pseudovirus mutant libraries of full spike or RBD only",
    anchor="middle",
    fontSize=16,
)

# zoom bar on top, comparison panels underneath
chart = alt.vconcat(site_zoom_bar, comparison_panels, center=True)
chart = chart.add_params(site_brush)
chart = chart.configure_axis(grid=False)
chart = chart.properties(title=chart_title)

print(f"Saving to {corr_chart_html}")
chart.save(corr_chart_html)

chart

## Distribution of escape inside and outside RBD in full spike DMS

In [None]:
# Read escape for all sera measured in the full spike libraries and label each
# mutation as being in the RBD or not.

# Raw string so that `\d` is a regex digit class rather than an invalid string
# escape (which triggers a `SyntaxWarning` on recent Python versions).
# NOTE(review): the `A-z` range also matches `[`, `\`, `]`, `^`, `_` and the
# backtick; it is kept as-is so exactly the same file names match -- confirm
# whether `A-Za-z` was intended.
spike_only_serum_regex = re.compile(r"sera_(?P<serum>[\dA-z]+)_mediumACE2")

escape_spike_dfs = []
for spike_escape_csv in spike_escape_csvs:
    m = spike_only_serum_regex.match(os.path.basename(spike_escape_csv))
    if not m:
        # file name does not follow the `sera_<serum>_mediumACE2` pattern
        continue
    serum = m.group("serum")
    escape_spike_dfs.append(
        pd.read_csv(spike_escape_csv).assign(serum="serum " + serum)
    )

# Fail with an informative message rather than pandas' generic
# "No objects to concatenate" error when nothing matched.
if not escape_spike_dfs:
    raise ValueError(
        "no full spike escape CSV matches the `sera_<serum>_mediumACE2` pattern"
    )

escape_spike = (
    pd.concat(escape_spike_dfs, ignore_index=True)
    .rename(columns={"escape_median": "escape"})
    [["epitope", "serum", "site", "mutant", "escape", "times_seen"]]
    # add the region of each site
    .merge(
        pd.read_csv(site_numbering_map_csv)
        .rename(columns={"reference_site": "site"})
        [["site", "region"]],
        validate="many_to_one",
    )
    # add the functional effect of each mutation, used later for filtering
    .merge(
        pd.read_csv(func_effects_csv)[["site", "mutant", "effect"]],
        validate="many_to_one",
    )
    # collapse every region other than "RBD" into a single "not RBD" label
    .assign(region=lambda x: x["region"].where(x["region"] == "RBD", "not RBD"))
)

assert escape_spike["epitope"].nunique() == 1, "code only works for one epitope"
escape_spike = escape_spike.drop(columns="epitope")

assert escape_spike["escape"].notnull().all()

In [None]:
# Histograms of escape for RBD and non-RBD mutations in the full spike
# libraries: one column per serum, one row per region.
escape_hist = (
    alt.Chart(escape_spike)
    .add_params(times_seen_slider, func_effects_slider, floor_at_zero)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["effect"] >= func_effects_slider)
    # Count the mutations that pass the filters within each serum AND region,
    # i.e. within each facet panel. Grouping by region alone pools all sera,
    # so the bars of a panel would not sum to 100% when there are several sera.
    .transform_joinaggregate(total="count(*)", groupby=["serum", "region"])
    .transform_calculate(
        escape_floored=alt.expr.if_(
            floor_at_zero,
            alt.expr.max(0, alt.datum["escape"]),
            alt.datum["escape"],
        ),
        # each mutation contributes this percentage of its panel
        pct=100 / alt.datum["total"],
    )
    .encode(
        x=alt.X(
            "escape_floored:Q",
            title="escape",
            axis=alt.Axis(labelOverlap=True, titleFontSize=12),
            bin=alt.BinParams(step=0.2),
            scale=alt.Scale(nice=False),
        ),
        # bar height is the summed percentage of the mutations in the bin
        y=alt.Y("sum(pct):Q", title=None, scale=alt.Scale(nice=False, padding=2)),
        color=alt.Color(
            "region",
            legend=None,
            scale=alt.Scale(range=["#E69F00", "#009E73"], domain=["RBD", "not RBD"]),
        ),
        column=alt.Column(
            "serum",
            title=None,
            header=alt.Header(
                labelFontSize=12,
                labelFontStyle="bold",
                labelPadding=3,
            ),
            spacing=2,
        ),
        row=alt.Row(
            "region",
            title=None,
            header=alt.Header(
                labelFontSize=12,
                labelFontStyle="bold",
                # row labels double as the y-axis description
                labelExpr="'% ' + datum.value + ' mutations'",
                orient="left",
                labelPadding=3,
            ),
            spacing=2,
        ),
    )
    .mark_bar()
    .configure_axis(grid=False)
    .properties(
        width=100,
        height=125,
        title=alt.TitleParams(
            "Serum antibody escape for RBD and non-RBD mutations",
            anchor="middle",
            fontSize=16,
        ),
    )
)

print(f"Saving to {dist_chart_html}")
escape_hist.save(dist_chart_html)

escape_hist