# Summarize escape across all sera

In [1]:
import altair as alt

import pandas as pd

The next cell is tagged as `parameters` for `papermill` parameterization:

In [2]:
# Placeholder defaults for the papermill parameters; the real values are supplied at run time.
# CSV holding the site numbering map (its `reference_site`, `sequential_site` and `region` columns are read).
site_numbering_map_csv = None
# CSV of functional effects (its `site`, `mutant` and `effect` columns are read).
func_effects_csv = None
# Mapping of serum name -> path of that serum's per-mutation escape CSV.
sera = None
# Keep only mutations whose `times_seen` column is at least this value.
times_seen = 3
# Keep only mutations whose `frac_models` column is at least this value.
frac_models = 1
# Name of the column in the escape CSVs that is used as the escape value.
escape_stat = "escape_median"

In [3]:
# Parameters
# Serum name -> CSV of that serum's per-mutation escape values.
sera = {
    "sera1-01": "results/antibody_escape/averages/sera1-01_mut_escape.csv",
    "sera1-02": "results/antibody_escape/averages/sera1-02_mut_escape.csv",
    "sera1-03": "results/antibody_escape/averages/sera1-03_mut_escape.csv",
    "sera1-04": "results/antibody_escape/averages/sera1-04_mut_escape.csv",
    "sera1-05": "results/antibody_escape/averages/sera1-05_mut_escape.csv",
    "sera2-01": "results/antibody_escape/averages/sera2-01_mut_escape.csv",
    "sera2-04": "results/antibody_escape/averages/sera2-04_mut_escape.csv",
    "sera2-05": "results/antibody_escape/averages/sera2-05_mut_escape.csv",
}
site_numbering_map_csv = "data/site_numbering_map.csv"
func_effects_csv = "results/func_effects/averages/293T_entry_func_effects.csv"

import os

# The paths above are resolved against the working directory. Step up one
# level only when they do not already resolve: an unconditional chdir climbs
# one directory further every time this cell is re-run, and would also break
# a run whose working directory already contains the inputs.
if not os.path.exists(site_numbering_map_csv):
    os.chdir("../")

Read the escape data and pivot it to one column per serum, then load the site numbering map and the functional effect data:

In [4]:
# Columns kept in the tidy table (one row per serum and mutation).
tidy_columns = ["epitope", "serum", "site", "wildtype", "mutant", "escape"]

# Load each serum's escape CSV and label its rows with the serum name.
per_serum_escape = [
    pd.read_csv(csv_file).assign(serum=serum_name)
    for serum_name, csv_file in sera.items()
]

# Expose the chosen statistic as "escape" and apply both quality filters.
escape_tidy = (
    pd.concat(per_serum_escape)
    .rename(columns={escape_stat: "escape"})
    .query("(frac_models >= @frac_models) and (times_seen >= @times_seen)")
    .loc[:, tidy_columns]
)

n_epitopes = escape_tidy["epitope"].nunique()
assert n_epitopes == 1, "averaging only works for one epitope"

# Wide table: one row per mutation, one column per serum.
escape = pd.pivot_table(
    escape_tidy,
    values="escape",
    index=["site", "wildtype", "mutant"],
    columns="serum",
).reset_index()

# Site numbering map, keyed on "site" (renamed from "reference_site").
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)
site_numbering_map = site_numbering_map.loc[:, ["site", "sequential_site", "region"]]

# Functional effects, with the "effect" column renamed to "functional effect".
func_effects = pd.read_csv(func_effects_csv).rename(
    columns={"effect": "functional effect"}
)
func_effects = func_effects.loc[:, ["site", "mutant", "functional effect"]]

escape

serum,site,wildtype,mutant,sera1-01,sera1-02,sera1-03,sera1-04,sera1-05,sera2-01,sera2-04,sera2-05
0,2,E,A,0.019930,,,,,,,
1,2,E,C,0.043570,-0.007879,0.05043,0.057150,-0.006042,0.07697,0.05997,0.039850
2,2,E,D,,0.288800,0.15600,-0.087910,0.026780,0.26090,0.01403,0.153000
3,2,E,F,-0.158300,0.127500,0.05859,-0.007849,0.046290,0.07214,0.02264,-0.013010
4,2,E,G,0.001388,0.091050,-0.04513,-0.020370,0.005345,-0.07698,0.02925,-0.292700
...,...,...,...,...,...,...,...,...,...,...,...
2805,567,I,C,0.286900,0.079450,-0.03988,-0.086600,-0.141700,-0.02005,-0.07882,-0.064730
2806,567,I,G,,-0.252700,-0.19740,-0.083310,,-0.19010,-0.25760,-0.662600
2807,567,I,L,,0.094700,-0.05316,-0.003195,-0.146300,-0.03702,-0.20260,-0.082610
2808,567,I,V,-0.028730,-0.127300,-0.07234,-0.065560,-0.008174,0.07227,-0.16330,-0.002466


Now make a plot for all the sera:

In [66]:
# Per-site summary statistics offered in the dropdown.
site_stats = ["mean", "sum", "max", "min"]

# Shorthand aggregates that compute every site statistic from "escape".
site_stat_aggregates = {stat: f"{stat}(escape)" for stat in site_stats}

# Dropdown that picks which per-site statistic is plotted.
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    bind=alt.binding_select(
        options=site_stats,
        name="site escape statistic",
    ),
    value="mean",
)

# Hovering over a point selects its site; selected points get a red outline.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

base = alt.Chart(escape)

site_escape_base = (
    base
    .encode(
        x=alt.X(
            "site:O",
            sort=alt.SortField("sequential_site"),
            scale=alt.Scale(nice=False),
            axis=alt.Axis(labelOverlap=True, grid=False),
        ),
        y=alt.Y(
            "escape:Q",
            title="site escape",
            scale=alt.Scale(nice=False, padding=5),
            axis=alt.Axis(grid=False),
        ),
        tooltip=[
            "site",
            alt.Tooltip("escape:Q", format=".2f"),
            "wildtype",
            "sequential_site:Q",
            "serum:N",
        ],
    )
)

site_escape_lines = site_escape_base.mark_line(size=0.75, color="gray")

site_escape_points = site_escape_base.encode(
    strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
).mark_circle(filled=True, color="gray", stroke="red")

site_escape_width = 600


def _select_stat_and_annotate(chart):
    """Keep the dropdown-selected site statistic and attach site annotations.

    `chart` must already carry one data column per entry of `site_stats`.
    Those columns are folded to long form ("site escape statistic", "escape"),
    filtered by the dropdown selection, and joined to the site numbering map
    to add "sequential_site" and "region".
    """
    return (
        chart
        .transform_fold(fold=site_stats, as_=["site escape statistic", "escape"])
        .transform_filter(site_escape_selection)
        .transform_lookup(
            lookup="site",
            from_=alt.LookupData(
                site_numbering_map,
                key="site",
                fields=["sequential_site", "region"],
            ),
        )
    )


# One facet row per serum: site statistics computed within each serum.
site_escape_chart = (
    _select_stat_and_annotate(
        (site_escape_lines + site_escape_points)
        .transform_fold(fold=list(sera), as_=["serum", "escape"])
        .transform_aggregate(
            **site_stat_aggregates,
            groupby=["site", "serum", "wildtype"],
        )
    )
    .properties(height=65, width=site_escape_width)
    .facet(
        facet=alt.Facet(
            "serum:N",
            title="escape from individual sera",
            header=alt.Header(
                labelOrient="right",
                labelFontSize=11,
                labelFontWeight="bold",
                titlePadding=1,
                titleFontSize=13,
                titleFontWeight="bold",
            ),
        ),
        columns=1,
        spacing=1,
    )
    .add_params(site_escape_selection, site_selection)
)

# Summary panel: average each mutation across sera first, then compute the
# site statistics from those per-mutation means. Aggregating over sera and
# mutations in a single step is only a mean across sera for the "mean"
# statistic; "sum" would be inflated by the number of sera and "max"/"min"
# would report extremes of individual sera.
mean_escape_chart = (
    _select_stat_and_annotate(
        (site_escape_lines + site_escape_points)
        .transform_fold(fold=list(sera), as_=["serum", "escape"])
        # average missing values as zero
        .transform_calculate(
            escape=alt.expr.if_(
                alt.expr.isValid(alt.datum["escape"]),
                alt.datum["escape"],
                0,
            ),
        )
        # mean across sera for each mutation
        .transform_aggregate(
            escape="mean(escape)",
            groupby=["site", "wildtype", "mutant"],
        )
        # site statistics of the per-mutation means
        .transform_aggregate(
            **site_stat_aggregates,
            groupby=["site", "wildtype"],
        )
        .transform_calculate(serum="'mean of all sera'")
    )
    .add_params(site_escape_selection, site_selection)
    .properties(
        height=100,
        width=site_escape_width,
        title=alt.TitleParams("mean escape across sera", fontSize=13, fontWeight="bold"),
    )
)

mean_escape_chart & site_escape_chart

In [60]:
# Leftover development introspection: shows the repr of Altair's `expr.if_`
# helper (used by the plotting cell to zero-fill missing escape values).
# Nothing depends on this cell, so it can be deleted.
alt.expr.if_

<function expr.if(*args)>