# Compare escape pre and post vaccination or infection

In [None]:
# This cell is tagged `parameters` for `papermill` parameterization; each `None`
# below is a placeholder for a value supplied when the notebook is run.
site_numbering_map_csv = None  # CSV read for its `reference_site` and `sequential_site` columns
func_effects_csv = None  # CSV read for its `site`, `mutant`, `effect`, and `effect_std` columns
escape_csvs = None  # escape CSV paths; only basenames matching `sera_<serum>_<pre|post>` are used
max_effect_std = None  # functional effects with `effect_std` above this value are excluded
init_min_func_effect = None  # initial value of the "minimum spike-mediated entry" slider
init_min_times_seen = None  # initial value of the "minimum times seen" slider
init_floor_at_zero = None  # initial value of the "floor escape at zero" radio button
init_site_escape_stat = None  # initial site escape statistic; must be "mean", "sum", "max", or "min"
chart_html = None  # path the final chart is saved to

Import Python modules:

In [1]:
import collections
import os
import re

import altair as alt

import pandas as pd

# Lift Altair's limit on the number of rows a chart may embed; the escape table
# plotted below is far larger than that limit. The return value is bound to `_`
# so the cell does not display it.
_ = alt.data_transformers.disable_max_rows()

## Read in the data

In [4]:
# Get sera with both pre and post exposure measures.
# Keys are the exposure tags expected in the escape CSV file names; values are
# the human-readable labels used for axis titles and tooltips.
exposure_types = {
    "pre": "pre exposure",
    "post": "post exposure",
}

# Map each serum to {exposure tag: CSV path}. Files whose basename does not
# start with `sera_<serum>_<exposure tag>` are ignored.
sera_all = collections.defaultdict(dict)
for escape_csv in escape_csvs:
    if m := re.match(
        rf"sera_(?P<serum>[\dA-Za-z]+)_(?P<exposure>{'|'.join(exposure_types)})",
        os.path.basename(escape_csv),
    ):
        sera_all[m.group("serum")][m.group("exposure")] = escape_csv

# Keep only sera that have a CSV for every exposure type, and read one data
# frame per (serum, exposure type) pair.
sera_to_analyze = {}
escape_dfs = []  # per-file data frames; kept separate from the final `escape` frame
for serum, serum_d in sera_all.items():
    if set(exposure_types).issubset(serum_d):
        sera_to_analyze[serum] = serum_d
        print(f"Analyzing {serum}")
        escape_dfs += [
            pd.read_csv(serum_d[exposure_type]).assign(
                serum="serum " + serum,
                exposure_type=exposure_type,
            )
            for exposure_type in exposure_types
        ]
    else:
        print(f"Skipping {serum} as not data for all exposure types")

# Fail with an informative message rather than the generic error `pd.concat`
# raises when handed an empty list.
if not escape_dfs:
    raise ValueError(
        "no serum in `escape_csvs` has data for all exposure types "
        f"({', '.join(exposure_types)})"
    )

escape = (
    pd.concat(escape_dfs, ignore_index=True)
    .rename(columns={"escape_median": "escape"})
    [["epitope", "serum", "exposure_type", "site", "wildtype", "mutant", "escape", "times_seen"]]
    # add sequential site numbers; default inner merge, so sites absent from
    # the numbering map are dropped
    .merge(
        pd.read_csv(site_numbering_map_csv)
        .rename(columns={"reference_site": "site"})
        [["site", "sequential_site"]],
        validate="many_to_one",
    )
    # add functional effects; default inner merge, so mutations without an
    # effect passing the `max_effect_std` filter are dropped
    .merge(
        pd.read_csv(func_effects_csv).query("effect_std <= @max_effect_std")[["site", "mutant", "effect"]],
        validate="many_to_one",
    )
)

assert escape["epitope"].nunique() == 1, "code only works for one epitope"
escape = escape.drop(columns="epitope")

assert escape["escape"].notnull().all(), "null values in `escape` column"

escape

Analyzing 527C
Skipping 503C as not data for all exposure types
Analyzing 183C
Analyzing 281C
Analyzing 087C
Analyzing 404C
Analyzing 507C


Unnamed: 0,serum,exposure_type,site,wildtype,mutant,escape,times_seen,sequential_site,effect
0,serum 527C,pre,2,F,L,0.060840,5.500,2,-0.17060
1,serum 527C,pre,2,F,S,0.091130,3.500,2,-0.11940
2,serum 527C,pre,3,V,A,0.049460,4.500,3,0.10420
3,serum 527C,pre,3,V,F,0.044430,4.000,3,-0.59310
4,serum 527C,pre,3,V,G,0.132400,2.000,3,-0.52040
...,...,...,...,...,...,...,...,...,...
86747,serum 507C,post,1252,S,F,-0.003998,5.667,1247,-0.54730
86748,serum 507C,post,1252,S,P,-0.022720,7.500,1247,-0.03440
86749,serum 507C,post,1252,S,Y,-0.017220,3.500,1247,0.12710
86750,serum 507C,post,1253,*,Q,0.003006,6.000,1248,0.03426


## Plot correlations

In [5]:
plotheight = 160

# ---- interactive controls ----

# slider: keep mutations whose functional effect is at least this value
func_effects_binding = alt.binding_range(
    name="minimum spike-mediated entry",
    min=escape["effect"].min(),
    max=0,
)
func_effects_slider = alt.param(value=init_min_func_effect, bind=func_effects_binding)

# slider: keep mutations seen at least this many times (slider range capped at 15)
times_seen_binding = alt.binding_range(
    name="minimum times seen",
    min=0,
    max=min(escape["times_seen"].max(), 15),
)
times_seen_slider = alt.param(value=init_min_times_seen, bind=times_seen_binding)

# radio button: whether negative escape values are clipped to zero
floor_at_zero = alt.param(
    value=init_floor_at_zero,
    bind=alt.binding_radio(name="floor escape at zero", options=[True, False]),
)

# hovering over a point selects its mutation
mut_selection = alt.selection_point(on="mouseover", fields=["mutation"], empty=False)

# ---- base chart shared by all panels ----

# escape value, clipped at zero when the radio button says so
escape_floored_expr = alt.expr.if_(
    floor_at_zero,
    alt.expr.max(0, alt.datum["escape"]),
    alt.datum["escape"],
)

chart_base = (
    alt.Chart(escape)
    .add_params(times_seen_slider, func_effects_slider, floor_at_zero, mut_selection)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    .transform_filter(alt.datum["effect"] >= func_effects_slider)
    .transform_calculate(escape_floored=escape_floored_expr)
)

# ---- correlation panels ----

assert len(exposure_types) == 2, "current code only works for 2 exposure types"
# first exposure type goes on the x-axis, second on the y-axis
x_exposure, y_exposure = exposure_types


def _exposure_axis(channel, exposure_type):
    """Build a quantitative `channel` (`alt.X` or `alt.Y`) for one exposure type."""
    return channel(
        exposure_type,
        type="quantitative",
        title=exposure_types[exposure_type],
        scale=alt.Scale(nice=False, padding=3),
    )


def _highlight(selected, unselected):
    """Encode `selected` for the hovered mutation and `unselected` otherwise."""
    return alt.condition(mut_selection, alt.value(selected), alt.value(unselected))


# mutation label such as wildtype + site + mutant
mutation_expr = alt.datum["wildtype"] + alt.expr.format(alt.datum["site"], "d") + alt.datum["mutant"]

# one row per serum and mutation, with one column per exposure type
corr_base = chart_base.transform_calculate(mutation=mutation_expr).transform_pivot(
    pivot="exposure_type",
    value="escape_floored",
    groupby=["serum", "mutation", "site"],
    op="max",  # "sum" converts NaN to zero
)

exposure_tooltips = [
    alt.Tooltip(exposure_type, type="quantitative", title=label, format=".3f")
    for exposure_type, label in exposure_types.items()
]

corr_scatter = corr_base.mark_circle(color="black", stroke="red").encode(
    x=_exposure_axis(alt.X, x_exposure),
    y=_exposure_axis(alt.Y, y_exposure),
    tooltip=["serum", "mutation:N", *exposure_tooltips],
    size=_highlight(70, 35),
    opacity=_highlight(1, 0.25),
    strokeWidth=_highlight(2, 0),
)

# correlation coefficient: square root of the regression's rSquared, signed by
# the second regression coefficient
signed_r_expr = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)

corr_r = (
    corr_base
    .transform_regression(x_exposure, y_exposure, params=True)
    .transform_calculate(r=signed_r_expr)
    .transform_calculate(r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"))
    .mark_text(size=14, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(5),
        y=alt.value(10),
    )
)

# one square panel per serum, stacked in a single column with independent axes
corr_facet_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
corr = (
    alt.layer(corr_scatter, corr_r)
    .properties(width=plotheight, height=plotheight)
    .facet(
        alt.Facet("serum", title=None, header=corr_facet_header),
        columns=1,
    )
    .resolve_scale(x="independent", y="independent")
)

## Escape line plots

In [None]:
# statistics that can be used to summarize escape at each site
site_stats = ["mean", "sum", "max", "min"]
assert init_site_escape_stat in site_stats

# ---- interactive controls ----

# dropdown choosing which site-level statistic is plotted
site_escape_stat_selection = alt.selection_point(
    value=init_site_escape_stat,
    fields=["site escape statistic"],
    bind=alt.binding_select(name="site escape statistic", options=site_stats),
)

# hovering over a point selects its site
site_selection = alt.selection_point(on="mouseover", fields=["site"], empty=False)

# brush along the x-axis, drawn as an unfilled black box
site_brush = alt.selection_interval(
    encodings=["x"],
    empty=True,
    mark=alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0),
)

site_escape_width = 700


def _site_x():
    """Build the x-encoding of sites, ordered by sequential site number."""
    return alt.X(
        "site:N",
        sort=alt.SortField("sequential_site"),
        scale=alt.Scale(nice=False, zero=False),
        axis=alt.Axis(labelOverlap=True),
    )


# ---- zoom bar: one gray rectangle per site ----

unique_sites = escape[["site", "sequential_site"]].drop_duplicates()
zoom_bar_title = alt.TitleParams(
    "site zoom bar",
    fontSize=11,
    fontWeight="bold",
    orient="top",
)
site_zoom_bar = (
    alt.Chart(unique_sites)
    .mark_rect(color="gray")
    .encode(x=_site_x(), tooltip=["site"])
    .properties(width=site_escape_width, height=11, title=zoom_bar_title)
)

# ---- site-level escape, one line per exposure type ----

# compute every statistic per site, then fold to long form and keep the one
# chosen in the dropdown
site_aggregates = {stat: f"{stat}(escape_floored)" for stat in site_stats}
exposure_color = alt.Color(
    "exposure_type:N",
    scale=alt.Scale(range=["#f0c808", "#124e78", "#F0E442", "#0072B2"]),
    title=["serum","pre or post", "exposure"],
)
escape_base = (
    chart_base
    .add_params(site_escape_stat_selection, site_selection)
    .transform_aggregate(
        groupby=["site", "sequential_site", "exposure_type", "serum", "wildtype"],
        **site_aggregates,
    )
    .transform_fold(site_stats, as_=["site escape statistic", "site escape"])
    .transform_filter(site_escape_stat_selection)
    .encode(
        x=_site_x(),
        y=alt.Y("site escape:Q"),
        tooltip=["serum", "site", "wildtype", "exposure_type:N", "site escape:Q"],
        color=exposure_color,
    )
)

escape_lines = escape_base.mark_line(size=0.5, opacity=0.65)
# points are enlarged, outlined, and made opaque for the hovered site
escape_points = escape_base.mark_circle(stroke="black").encode(
    size=alt.condition(site_selection, alt.value(50), alt.value(10)),
    strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
    opacity=alt.condition(site_selection, alt.value(1), alt.value(0.65)),
)

# one wide panel per serum, stacked in a single column with independent axes
escape_facet_header = alt.Header(
    labelFontSize=12,
    labelFontStyle="bold",
    labelPadding=3,
)
escape_chart = (
    alt.layer(escape_lines, escape_points)
    .properties(width=site_escape_width, height=plotheight)
    .facet(
        alt.Facet("serum", title=None, header=escape_facet_header),
        columns=1,
    )
    .resolve_scale(y="independent", x="independent")
)

# ---- assemble: zoom bar on top, correlation and line panels filtered by the brush ----

brushed_panels = alt.hconcat(corr, escape_chart).transform_filter(site_brush)
chart_title = alt.TitleParams(
    "Comparison of serum escape pre and post exposure",
    anchor="middle",
    fontSize=16,
)
chart = (
    alt.vconcat(site_zoom_bar, brushed_panels, center=True)
    .add_params(site_brush)
    .configure_axis(grid=False)
    .properties(title=chart_title)
)

print(f"Saving to {chart_html}")
chart.save(chart_html)

chart