# Summarize escape across all sera alongside functional effects

In [1]:
import functools
import operator
import re

import altair as alt

import pandas as pd

import polyclonal.alphabets
from polyclonal.plot import color_gradient_hex

# lift Altair's limit on the number of data rows a chart may embed;
# the return value is assigned to `_` so the cell shows no output
_ = alt.data_transformers.disable_max_rows()

The next cell is tagged as `parameters` for `papermill` parameterization:

In [2]:
# Placeholder defaults for the notebook parameters; real values are expected
# to be injected by papermill in the cell that follows this one.
site_numbering_map_csv = None  # CSV with `reference_site`, `sequential_site`, `region` columns
func_effects_csv = None  # CSV of mutation functional effects
sera = None  # dict mapping serum selection name -> CSV of mutation escape
chart = None  # output path for the merged interactive chart
csv_file = None  # output path for the summary CSV
receptor_affinity_csv = None  # CSV of mutation effects on ACE2 affinity

In [3]:
# Parameters
# serum selection name -> CSV of that selection's averaged mutation escape
sera = {
    "sera_493C_highACE2": "results/antibody_escape/averages/sera_493C_highACE2_mut_effect.csv",
    "sera_498C_highACE2": "results/antibody_escape/averages/sera_498C_highACE2_mut_effect.csv",
    "sera_500C_highACE2": "results/antibody_escape/averages/sera_500C_highACE2_mut_effect.csv",
    "sera_503C_highACE2": "results/antibody_escape/averages/sera_503C_highACE2_mut_effect.csv",
    "sera_493C_mediumACE2": "results/antibody_escape/averages/sera_493C_mediumACE2_mut_effect.csv",
    "sera_498C_mediumACE2": "results/antibody_escape/averages/sera_498C_mediumACE2_mut_effect.csv",
    "sera_500C_mediumACE2": "results/antibody_escape/averages/sera_500C_mediumACE2_mut_effect.csv",
    "sera_501C_mediumACE2": "results/antibody_escape/averages/sera_501C_mediumACE2_mut_effect.csv",
    "sera_503C_mediumACE2": "results/antibody_escape/averages/sera_503C_mediumACE2_mut_effect.csv",
    "sera_287C_mediumACE2": "results/antibody_escape/averages/sera_287C_mediumACE2_mut_effect.csv",
    "sera_288C_mediumACE2": "results/antibody_escape/averages/sera_288C_mediumACE2_mut_effect.csv",
    "sera_343C_mediumACE2": "results/antibody_escape/averages/sera_343C_mediumACE2_mut_effect.csv",
    "sera_497C_mediumACE2": "results/antibody_escape/averages/sera_497C_mediumACE2_mut_effect.csv",
    "sera_505C_mediumACE2": "results/antibody_escape/averages/sera_505C_mediumACE2_mut_effect.csv",
}
# input files (paths are relative to the working directory set below)
site_numbering_map_csv = "data/site_numbering_map.csv"
func_effects_csv = "results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv"
receptor_affinity_csv = "results/receptor_affinity/averages/monomeric_ACE2_mut_effect.csv"
# output files
chart = "results/summaries/escape_summary_nolegend.html"
csv_file = "results/summaries/escape_summary.csv"

# NOTE(review): the two lines below are not parameters. They move the working
# directory up one level, presumably so the relative paths above resolve.
# Because the path is relative, re-running this cell in the same kernel moves
# up one more directory each time; restart the kernel before re-running.
import os
os.chdir("../")

Some configuration for the plots:

In [4]:
# --- filters applied to mutations when reading escape and other properties ---
times_seen = 3  # only include mutations with times_seen >= this
frac_models = 1  # only include mutations in >= this fraction of models / selections
escape_stat = "escape_median"  # for each sera, use this escape value (mean or median)
init_site_escape_stat = "sum"  # default site escape stat to show
init_floor_escape_at_zero = False  # default on whether to floor escape at zero

# initial minimum values of other properties to show
# (keys are the property names; each value is the starting position of that
# property's slider)
other_props_init_min = {
    "functional effect": -3,
    "ACE2 affinity": -1,
}

# amino-acid characters to include
alphabet = polyclonal.AAS_WITHGAP

# for heatmap colors
escape_negative_color = "#0072B2"  # french blue
escape_positive_color = "#E69F00"  # orange
escape_max_at_least = 1  # upper end of the escape color scale is at least this
escape_min_at_least = -1  # lower end of the escape color scale when escape is not floored

# per-property heatmap settings, keyed by the same property names as above:
# colors for positive / negative values, and bounds that the color-scale
# maximum / minimum are compared against when the scale domain is built
other_prop_heatmap_params = {
    "functional effect": {
        "positive_color": "#009E73",  # green
        "negative_color": "#CC79A7",  # wild orchid
        "max_at_least": 1,
        "min_at_least": 0,
    },
    "ACE2 affinity": {
        "positive_color": "#56B4E9",  # light blue
        "negative_color": "#D55E00",  # red
        "max_at_least": 1,
        "min_at_least": 0,
    }
}

Get just the sera to keep.
Here we keep just the medium ACE2 sera and discard the high ACE2 sera:

In [5]:
# Decide which serum selections to plot. Medium-ACE2 selections are kept and
# relabelled as "serum <name>"; high-ACE2 selections are discarded; any other
# key is treated as a configuration error.
#
# The patterns are raw strings so that `\w` is passed to the regex engine
# verbatim instead of being parsed as an (invalid) string escape sequence.
keep_regex = re.compile(r"sera_(?P<name>\w+)_mediumACE2")
discard_regex = re.compile(r"sera_\w+_highACE2")

sera_to_keep = {}  # original serum key -> display label
sera_to_discard = []  # original serum keys that are dropped
for serum in sera:
    if (m := keep_regex.fullmatch(serum)):
        sera_to_keep[serum] = "serum " + m.group("name")
    elif discard_regex.fullmatch(serum):
        sera_to_discard.append(serum)
    else:
        raise ValueError(f"unrecognized {serum=}")

# two different keys must not collapse onto the same display label
assert len(sera_to_keep) == len(set(sera_to_keep.values()))
print(f"{sera_to_keep.keys()=}\n{sera_to_discard=}")

sera_to_keep.keys()=dict_keys(['sera_493C_mediumACE2', 'sera_498C_mediumACE2', 'sera_500C_mediumACE2', 'sera_501C_mediumACE2', 'sera_503C_mediumACE2', 'sera_287C_mediumACE2', 'sera_288C_mediumACE2', 'sera_343C_mediumACE2', 'sera_497C_mediumACE2', 'sera_505C_mediumACE2'])
sera_to_discard=['sera_493C_highACE2', 'sera_498C_highACE2', 'sera_500C_highACE2', 'sera_503C_highACE2']


Read the escape data:

In [6]:
# Load the escape CSV of every kept serum, tagging rows with the display label.
per_serum_frames = []
for serum_key, serum_csv in sera.items():
    if serum_key not in sera_to_keep:
        continue
    per_serum_frames.append(
        pd.read_csv(serum_csv).assign(serum=sera_to_keep[serum_key])
    )

# Stack the sera and expose the chosen statistic under the generic name `escape`.
escape_unfiltered = pd.concat(per_serum_frames).rename(columns={escape_stat: "escape"})

# Apply the quality / alphabet filters, then keep only the columns used downstream.
tidy_columns = ["epitope", "serum", "site", "wildtype", "mutant", "escape"]
escape_tidy = (
    escape_unfiltered
    .query("frac_models >= @frac_models")
    .query("times_seen >= @times_seen")
    .query("(mutant in @alphabet) and (wildtype in @alphabet)")
)[tidy_columns]

assert escape_tidy["epitope"].nunique() == 1, "averaging only works for one epitope"

# Wide format: one row per mutation, one column per serum.
escape_wide = escape_tidy.pivot_table(
    index=["site", "wildtype", "mutant"],
    columns="serum",
    values="escape",
).reset_index()

# `site_mutant` (e.g. site followed by the mutant residue) is the lookup key
# used when joining against the other properties.
escape = escape_wide.assign(
    site_mutant=escape_wide["site"].astype(str) + escape_wide["mutant"]
)

assert escape["site_mutant"].nunique() == len(escape)

Get just the sera to keep and rename them:

Read the site numbering map:

In [7]:
# Load the site numbering map, exposing the reference site under the name
# `site` so it can be joined to the other data frames, and keep only the
# columns that the plots need.
numbering_columns = ["site", "sequential_site", "region"]
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)[numbering_columns]

Read other properties:

In [8]:
# property name -> data frame of that property's per-mutation values
other_props = {}

# (property name, CSV path, column holding the value, column counting replicates)
prop_sources = [
    ("functional effect", func_effects_csv, "effect", "n_selections"),
    ("ACE2 affinity", receptor_affinity_csv, "affinity_median", "n_models"),
]

for prop, prop_csv, stat_col, n_reps_col in prop_sources:
    print(f"Reading {prop=} from {stat_col=} in {prop_csv=}")
    seen_enough = (
        pd.read_csv(prop_csv)
        .rename(columns={stat_col: prop})
        .query("times_seen >= @times_seen")
    )
    # fraction of replicates (relative to the best-covered mutation) with a value
    with_frac = seen_enough.assign(
        frac_models=seen_enough[n_reps_col] / seen_enough[n_reps_col].max()
    )
    other_props[prop] = with_frac.query("frac_models >= @frac_models")[
        ["site", "wildtype", "mutant", prop]
    ]

# add wildtype effects of zero
site_wts = pd.concat([escape, *other_props.values()])[["site", "wildtype"]].drop_duplicates()
for prop in other_props:
    # one row per site with mutant == wildtype and a property value of 0
    wt_rows = site_wts.assign(mutant=site_wts["wildtype"], **{prop: 0})
    with_wt = pd.concat([other_props[prop], wt_rows], ignore_index=True)
    with_wt = with_wt.assign(
        site_mutant=with_wt["site"].astype(str) + with_wt["mutant"]
    )
    other_props[prop] = (
        with_wt
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)")
    )
    assert other_props[prop]["site_mutant"].nunique() == len(other_props[prop])

Reading prop='functional effect' from stat_col='effect' in prop_csv='results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv'
Reading prop='ACE2 affinity' from stat_col='affinity_median' in prop_csv='results/receptor_affinity/averages/monomeric_ACE2_mut_effect.csv'


Now make a site summary escape plot for all the sera.
Make plots with both the sera faceted, and the individual sera all overlaid:

In [29]:
# radio buttons: when True, escape values are clipped at zero before plotting
floor_escape_at_zero = alt.param(
    value=init_floor_escape_at_zero,
    name="floor_escape_at_zero",
    bind=alt.binding_radio(options=[True, False], name="floor escape at zero"),
)

# dropdown choosing which per-site summary statistic of mutation escape to show
site_stats = ["mean", "sum", "max", "min"]
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    bind=alt.binding_select(
        options=site_stats,
        name="site escape statistic",
    ),
    value=init_site_escape_stat,
)

# point selection on `site` triggered by mouseover; with empty=False nothing
# counts as selected until the mouse is over a point
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# one range slider per other property (minimum allowed value of that property),
# running from the smallest value in the data up to 0
other_prop_sliders = {}
for prop, prop_df in other_props.items():
    other_prop_sliders[prop] = alt.param(
        value=other_props_init_min[prop],
        # param names cannot contain spaces, so they are replaced by underscores
        name=prop.replace(" ", "_") + "_slider",
        bind=alt.binding_range(
            name=f"minimum mutation {prop}",
            min=prop_df[prop].min(),
            max=0,
        ),
    )

# interval brush along x used to zoom in on a range of sites; empty=True so
# that all sites are shown when nothing is brushed
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0),
    empty=True,
)

site_escape_width = 800  # width of site escape chart
site_escape_overlaid_height = 140  # height of overlaid site escape plots
site_escape_faceted_height = 80  # height of faceted site escape plots

# Base chart over the wide per-mutation escape frame: shared y-encoding and
# tooltips, restricted to the sites inside the zoom brush. The `escape`,
# `serum`, `sequential_site` and `region` fields it refers to are created by
# the transforms added further down.
site_escape_base = (
    alt.Chart(escape)
    .encode(
        y=alt.Y(
            "escape:Q",
            scale=alt.Scale(nice=False, padding=4),
            axis=alt.Axis(grid=False),
        ),
        tooltip=[
            "site",
            alt.Tooltip("escape:Q", format=".2f"),
            "wildtype",
            "sequential_site:Q",
            "serum:N",
            "region:N",
        ],
    )
    .transform_filter(site_brush)
)

site_escape_lines = site_escape_base.mark_line(size=0.75)

# points are enlarged and given a visible red outline for the moused-over site
site_escape_points = site_escape_base.encode(
    strokeWidth=alt.condition(site_selection, alt.value(1.5), alt.value(0)),
    size=alt.condition(site_selection, alt.value(45), alt.value(15)),
).mark_circle(filled=True, stroke="red")

site_escape_lines_and_points = (
    (site_escape_lines + site_escape_points)
    # wide -> long: one row per (mutation, serum) with the value in `escape_orig`
    .transform_fold(fold=list(sera_to_keep.values()), as_=["serum", "escape_orig"])
    # floor escape at zero if selected
    .transform_calculate(
        escape=alt.expr.if_(
            floor_escape_at_zero,
            alt.expr.max(alt.datum["escape_orig"], 0),
            alt.datum["escape_orig"],
        )
    )
)

# filter on other properties
# (each property is looked up by `site_mutant`, then mutations whose value is
# below that property's slider are dropped; the filters are applied in
# sequence, so a mutation must pass all of them)
for prop, prop_df in other_props.items():
    slider = other_prop_sliders[prop]
    site_escape_lines_and_points = (
        site_escape_lines_and_points
        .transform_lookup(
            lookup="site_mutant",
            from_=alt.LookupData(
                prop_df,
                key="site_mutant",
                fields=[prop],
            ),
        )
        .transform_filter(alt.datum[prop] >= slider)
    )

# compute site statistics
site_escape_lines_and_points = (
    site_escape_lines_and_points
    # compute site statistics from mutation statistics
    # (one output field per entry of `site_stats`, named after the statistic)
    .transform_aggregate(
        **{stat: f"{stat}(escape)" for stat in site_stats},
        groupby=["site", "serum", "wildtype"],
    )
    # filter on site statistic of interest
    # (fold the statistic columns into rows, then keep the one in the dropdown)
    .transform_fold(fold=site_stats, as_=["site escape statistic", "escape"])
    .transform_filter(site_escape_selection)
    # get sequential sites and regions
    .transform_lookup(
        lookup="site",
        from_=alt.LookupData(
            site_numbering_map,
            key="site",
            fields=["sequential_site", "region"],
        ),
    )
)

# One small panel per serum, stacked in a single column, each with its own
# y-scale. Sites are ordered along x by their sequential site number.
site_escape_faceted = (
    site_escape_lines_and_points
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site:Q"),
            axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
        ),
        color=alt.value("gray"),
    )
    .properties(height=site_escape_faceted_height, width=site_escape_width)
    .facet(
        facet=alt.Facet(
            "serum:N",
            title="individual sera",
            header=alt.Header(
                labelOrient="right",
                labelFontSize=10,
                labelPadding=3,
                titleOrient="right",
                titlePadding=3,
            ),
        ),
        columns=1,
        spacing=0,
    )
    .resolve_scale(y="independent")
)

# All sera drawn in one panel as semi-transparent gray lines; `detail` keeps
# each serum as a separate line without mapping it to a visual channel.
site_escape_overlaid = (
    site_escape_lines_and_points
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site:Q"),
            axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
        ),
        opacity=alt.value(0.4),
        color=alt.value("gray"),
        detail="serum:N",
    )
    .properties(
        height=site_escape_overlaid_height,
        width=site_escape_width,
        title=alt.TitleParams(
            "individual sera", fontSize=11, fontWeight="bold", orient="right",
        ),
    )
)

# Mean of the selected site statistic across sera, drawn in black with no
# x-axis of its own.
site_mean_escape = (
    site_escape_lines_and_points
    # average missing values as zero
    .transform_calculate(
        escape=alt.expr.if_(
            alt.expr.isValid(alt.datum["escape"]),
            alt.datum["escape"],
            0,
        ),
    )
    # take mean over sera
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["site", "wildtype", "sequential_site", "region"],
    )
    # constant label so the tooltip's `serum` field is populated for the mean
    .transform_calculate(serum="'mean of all sera'")
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site:Q"),
            axis=None,
        ),
        color=alt.value("black")
    )
    .properties(
        title=alt.TitleParams(
            "mean of sera", fontSize=11, fontWeight="bold", orient="right",
        ),
    )
)

# Thin bar of sites colored by region; the zoom brush is attached to this bar
# when the charts are assembled below.
region_bar = (
    alt.Chart(site_numbering_map)
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site:Q"),
            axis=None,
        ),
        color=alt.Color(
            "region",
            # fix the legend order to the order regions first appear in the map
            scale=alt.Scale(domain=site_numbering_map["region"].unique()),
        ),
        tooltip=["site", "region", "sequential_site"],
    )
    .mark_rect()
    .properties(
        width=site_escape_width,
        height=11,
        title=alt.TitleParams(
            "site zoom bar", fontSize=11, fontWeight="bold", orient="top",
        ),
    )
)       

# Assemble one full site chart per layout: zoom bar on top, then the mean
# panel above the per-serum panel(s), with all interactive params attached.
# NOTE: the loop variable `site_chart` stays defined after the loop and holds
# the last per-serum chart (the overlaid one), not an assembled chart.
site_charts = {}
for chart_type, height, site_chart in [
    ("faceted", site_escape_faceted_height, site_escape_faceted),
    ("overlaid", site_escape_overlaid_height, site_escape_overlaid),
]:
    site_charts[chart_type] = (
        alt.vconcat(
            region_bar.add_params(site_brush), 
            (
                alt.vconcat(
                    # the mean panel takes the same height as the per-serum panel
                    site_mean_escape.properties(height=height, width=site_escape_width),
                    site_chart,
                    spacing=3,
                )
                .add_params(
                    site_escape_selection,
                    site_selection,
                    *other_prop_sliders.values(),
                    floor_escape_at_zero,
                )
            ),
            spacing=0,
        )
    )

display(site_charts["overlaid"])

Now prepare to plot the heatmaps.
First, create a data frame that has the average escape across sera (averaging mutations missing for a serum as zero for that serum) and other properties:

In [None]:
# Build the per-mutation table behind the heatmaps: mean escape across the
# kept sera (a mutation missing for a serum counts as zero for that serum),
# plus each other property, plus site numbering information.

# The per-serum columns of the wide `escape` frame are named by the display
# labels in `sera_to_keep`, not by the keys of `sera`.
serum_cols = list(sera_to_keep.values())
# columns that together identify one mutation
mut_keys = ["site", "wildtype", "mutant"]

heatmap_data = (
    pd.concat(
        [
            escape,
            # add wildtype with zero escape
            (
                escape
                [["site", "wildtype"]]
                .drop_duplicates()
                .assign(mutant=lambda x: x["wildtype"])
            ),
        ],
        ignore_index=True,
    )
    # missing serum values (including all of them for wildtype rows) become 0
    .fillna(0)
    .assign(escape=lambda x: x[serum_cols].mean(axis=1))
    .drop(columns=serum_cols + ["site_mutant"])
)

# Merge on the mutation keys only. Merging on every shared column would also
# match on `site_mutant` / `sequential_site` / `region`, which are missing
# after the first outer merge for mutations absent from the first property,
# so the second merge would split such a mutation across two rows.
for prop, prop_df in other_props.items():
    heatmap_data = heatmap_data.merge(
        prop_df[mut_keys + [prop]],
        on=mut_keys,
        validate="one_to_one",
        how="outer",
    )

heatmap_data = (
    heatmap_data
    .merge(site_numbering_map, on="site", validate="many_to_one")
    # wildtype identities always have zero escape, including wildtype rows
    # that entered only through one of the other properties
    .assign(
        escape=lambda x: x["escape"].where(
            x["wildtype"] != x["mutant"],
            0,
        ),
    )
)

# each mutation must appear exactly once
assert not heatmap_data.duplicated(mut_keys).any()

heatmap_data

Write these data to a CSV:

In [None]:
print(f"Writing summary data to {csv_file}")

# mean escape per site, taken over non-wildtype mutations only
site_mean_escape_df = (
    heatmap_data
    .query("wildtype != mutant")
    .groupby("site", as_index=False)
    .aggregate(mean_site_escape=pd.NamedAgg("escape", "mean"))
)

# attach the site mean to every row of that site and write the table out
summary_df = heatmap_data.merge(
    site_mean_escape_df,
    how="outer",
    validate="many_to_one",
)
summary_df.to_csv(csv_file, index=False, float_format="%.4g")

Make heatmaps:

In [None]:
cell_size = 9  # heatmap cell size

# amino acids ordered by biochemical similarity for the heatmap y-axis
sorted_alphabet = polyclonal.alphabets.biochem_order_aas(alphabet)

# Value columns whose missing entries are converted to NaN. The other
# properties are taken from `other_props` so the names always match the
# (renamed) columns actually present in `heatmap_data`.
nan_cols = ["escape", *other_props]

# Base chart shared by every heatmap layer: x = site, y = mutant amino acid,
# fixed-size cells, with the property sliders and flooring option attached.
heatmap_base = (
    alt.Chart(heatmap_data)
    .transform_calculate(
        # escape used for coloring, clipped at zero when flooring is selected
        escape_floored=alt.expr.if_(
            floor_escape_at_zero,
            alt.expr.max(alt.datum["escape"], 0),
            alt.datum["escape"],
        ),
        # convert null values to NaN so they show as NaN in tooltips rather than as 0.0
        **{
            col: alt.expr.if_(
                alt.expr.isFinite(alt.datum[col]),
                alt.datum[col],
                alt.expr.NaN,
            )
            for col in nan_cols
        }
    )
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        y=alt.Y(
            "mutant:N",
            title="amino acid",
            sort=sorted_alphabet,
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
    )
    .properties(width=alt.Step(cell_size), height=alt.Step(cell_size))
    .add_params(*other_prop_sliders.values(), floor_escape_at_zero)
)

# mark X for wildtype
heatmap_wildtype = (
    heatmap_base
    .transform_filter(alt.datum["wildtype"] == alt.datum["mutant"])
    .mark_text(text="x", color="black")
)

# gray background for missing values
# (impute a dummy field for every amino acid at every site so that each cell
# of the grid gets a light-gray rectangle, whether or not data exist for it)
heatmap_bg = (
    heatmap_base
    .transform_impute(
        impute="_stat_dummy",
        key="mutant",
        keyvals=sorted_alphabet,
        groupby=["site"],
        value=None,
    )
    .mark_rect(color="#E0E0E0", opacity=0.8)
)

# tooltips shared by all data-bearing heatmap layers: escape plus every other property
tooltips = [
    "site",
    "mutant",
    alt.Tooltip("escape", format=".2f"),
    *[alt.Tooltip(prop, format=".2f") for prop in other_props],
    "wildtype",
    "sequential_site",
    "region",
]

# legend styling shared by all color-scaled heatmaps
legend=alt.Legend(
    orient="left",
    titleOrient="left",
    gradientLength=100,
    gradientThickness=10,
    gradientStrokeColor="black",
    gradientStrokeWidth=0.5,
)

# heatmap for escape
# Cells are kept when the mutation is the wildtype identity or when a property
# value is at or above its slider.
# NOTE(review): the property conditions are combined with OR, so a mutation is
# kept here if ANY property passes, while the filtered layer below selects a
# mutation if ANY property fails. A mutation passing one property and failing
# another is therefore present in both layers. Confirm whether `operator.and_`
# was intended here.
escape_heatmap = (
    heatmap_base
    .transform_filter(
        functools.reduce(
            operator.or_,
            [(alt.datum[prop] >= slider) for prop, slider in other_prop_sliders.items()],
        ) | (alt.datum["wildtype"] == alt.datum["mutant"])
    )
    .encode(
        # turn off x-labels for this heatmap since it is stacked
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            title=None,
            axis=alt.Axis(ticks=False, labels=False),
        ),
        color=alt.Color(
            "escape_floored:Q",
            title="escape",
            legend=legend,
            # diverging scale centered on zero: negative color -> white -> positive color
            scale=alt.Scale(
                zero=True,
                nice=False,
                type="linear",
                domainMid=0,
                domainMax=max(escape_max_at_least, heatmap_data["escape"].max()),
                # lower bound follows the flooring option at render time
                domainMin=alt.ExprRef(
                    f"if(floor_escape_at_zero, 0, {escape_min_at_least})"
                ),
                range=(
                    color_gradient_hex(escape_negative_color, "white", n=20)
                    # [1:] avoids repeating "white" where the two gradients meet
                    + color_gradient_hex("white", escape_positive_color, n=20)[1:]
                ),
            ),
        ),
        tooltip=tooltips,
    )
    .mark_rect(stroke="black")
)

# heatmap for other property (eg, functional effect) filtered escape
# (non-wildtype mutations with any property below its slider are drawn as
# uniform silver cells; tooltips are kept so the values remain inspectable)
escape_filtered_heatmap = (
    heatmap_base
    .transform_filter(
        functools.reduce(
            operator.or_,
            [alt.datum[prop] < slider for prop, slider in other_prop_sliders.items()],
        )
        & (alt.datum["wildtype"] != alt.datum["mutant"])
    )
    # constant field so that all filtered cells map to the single silver color
    .transform_calculate(filtered="''")
    .encode(
        tooltip=tooltips,
        color=alt.Color(
            "filtered:N",
            title=["functionally", "deleterious"],
            scale=alt.Scale(range=["silver"]),
            legend=None,
        ),
    )
    .mark_rect(stroke="black")
)

# heatmaps for other properties
# For each property, build a colored heatmap of that property and a silver
# layer for mutations filtered out by the sliders of the *other* properties
# (a property's own slider does not hide cells in its own heatmap).
# NOTE(review): the `functools.reduce` calls below have no initial value, so
# they raise TypeError when there is only one property (the list of "other"
# conditions is then empty).
other_prop_heatmaps = {}
other_prop_filtered_heatmaps = {}
for prop in other_props:
    # only the bottom heatmap in the stack shows x-axis labels
    last_heatmap = (prop == list(other_props)[-1])
    params = other_prop_heatmap_params[prop]
    slider_name = prop.replace(" ", "_") + "_slider"  # name given when defining slider
    other_prop_heatmaps[prop] = (
        heatmap_base
        .transform_filter(
            functools.reduce(
                operator.or_,
                [
                    (alt.datum[other_prop] >= slider)
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
            )
            | (alt.datum["wildtype"] == alt.datum["mutant"])
        )
        .encode(
            # turn off x-labels for all but last heatmap since stacked
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=alt.Axis() if last_heatmap else alt.Axis(ticks=False, labels=False),
            ),
            color=alt.Color(
                prop,
                legend=legend,
                # diverging scale centered on zero; clamp=True maps values
                # outside the domain to the end colors
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    clamp=True,
                    domainMid=0,
                    domainMax=max(params["max_at_least"], heatmap_data[prop].max()),
                    # lower bound tracks this property's slider at render time
                    domainMin=alt.ExprRef(f"min({slider_name}, {params['min_at_least']})"),
                    range=(
                        color_gradient_hex(params["negative_color"], "white", n=20)
                        # [1:] avoids repeating "white" where the two gradients meet
                        + color_gradient_hex("white", params["positive_color"], n=20)[1:]
                    ),
                ),
            ),
            tooltip=tooltips,
        )
        .mark_rect(stroke="black")
    )
    other_prop_filtered_heatmaps[prop] = (
        heatmap_base
        .transform_filter(
            functools.reduce(
                operator.or_,
                [
                    alt.datum[other_prop] < slider
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
            )
            & (alt.datum["wildtype"] != alt.datum["mutant"])
        )
        # constant field so that all filtered cells map to the single silver color
        .transform_calculate(filtered="''")
        .encode(
            tooltip=tooltips,
            color=alt.Color(
                "filtered:N",
                title=["functionally", "deleterious"],
                scale=alt.Scale(range=["silver"]),
                legend=None,
            ),
        )
        .mark_rect(stroke="black")
    )

# Stack the escape heatmap above one heatmap per other property. Each panel is
# a layered chart: gray background, colored values, silver filtered cells, and
# the wildtype "x" marks. Color scales are independent so that every panel
# keeps its own scale and legend.
heatmap = alt.vconcat(
    heatmap_bg + escape_heatmap + escape_filtered_heatmap + heatmap_wildtype,
    *[
        (
            heatmap_bg
            + other_prop_heatmaps[other_prop]
            + other_prop_filtered_heatmaps[other_prop]
            + heatmap_wildtype
        )
        for other_prop in other_props
     ],
    spacing=1,
).resolve_scale(color="independent")

heatmap

Make merged chart with everything:

In [None]:
# Combine the assembled site-escape chart with the heatmaps and save to HTML.
#
# The top panel must be the fully assembled chart from `site_charts` (zoom bar
# with the site brush, mean-of-sera panel, per-serum lines and all interactive
# params). The bare name `site_chart` is only the leftover loop variable from
# building `site_charts`: it holds the per-serum layer alone, without the zoom
# bar that defines `site_brush`. The overlaid layout is used here, matching
# the chart displayed earlier in the notebook.
merged_chart = alt.vconcat(
    site_charts["overlaid"],
    # the heatmaps follow the same site zoom as the line plots
    heatmap.transform_filter(site_brush),
    spacing=5,
).configure_legend(orient="left")

print(f"Saving to {chart}")
merged_chart.save(chart)

merged_chart