# Summarize results across assays
This notebook summarizes the results across assays.

In [1]:
import functools
import operator
import re

import altair as alt

import pandas as pd

import polyclonal.alphabets
from polyclonal.plot import color_gradient_hex

# Turn off Altair's maximum-rows check so that charts can be built from large data frames;
# the return value is bound to `_` so the cell does not display it.
_ = alt.data_transformers.disable_max_rows()

## Get configuration parameters

The next cell is tagged as `parameters` for `papermill` parameterization:

In [2]:
# Placeholder defaults for the `papermill` parameters; real values are assigned in a later cell.
# CSV read into `site_numbering_map`
site_numbering_map_csv = None
# output path for the overlaid chart (not referenced in the cells shown here -- TODO confirm where it is saved)
chart_overlaid = None
# output path for the faceted chart (not referenced in the cells shown here -- TODO confirm where it is saved)
chart_faceted = None
# CSV file the merged summary data frame is written to
output_csv_file = None
# nested configuration dict (filters, alphabet, per-assay plot settings)
config = None
# dict mapping "<assay> <antibody or condition>" keys to input CSV paths
input_csvs = None

In [3]:
# Parameters
# Concrete parameter values for this run; they overwrite the `None` placeholders above.
config = {
    # mutation-level filters applied when reading each assay's CSV
    "min_times_seen": 3,
    "min_frac_models": 1,
    # characters kept as valid wildtype / mutant identities (re-ordered by `biochem_order_aas` below)
    "alphabet": [
        "A",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "K",
        "L",
        "M",
        "N",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "V",
        "W",
        "Y",
        "-",
    ],
    # initial states of the interactive chart controls
    "init_floor_escape_at_zero": True,
    "init_site_escape_stat": "mean",
    # antibody sets: keyed by set name; `antibody_list` maps antibody id -> display name
    "antibody_escape": {
        "monoclonal antibodies": {
            "stat": "escape_median",
            "negative_color": "#0072B2",
            "positive_color": "#E69F00",
            "max_at_least": 1,
            "min_at_least": -1,
            "antibody_list": {
                "REGN10933": "REGN10933", "S2M11": "S2M-11",
            },
        }
    },
    # functional effects: keyed by display name; `condition` selects the input CSV
    "func_effects": {
        "spike mediated entry": {
            "condition": "293T_ACE2_entry",
            "effect_type": "func_effects",
            "positive_color": "#009E73",
            "negative_color": "#CC79A7",
            "max_at_least": 1,
            "min_at_least": 0,
            "init_min_value": -3,
        }
    },
    # other assays: keyed by assay, then by display name; `stat` is the CSV column to use
    "other_assays": {
        "receptor_affinity": {
            "mock receptor affinity": {
                "condition": "pretending_S2M11_is_receptor",
                "stat": "receptor affinity_median",
                "positive_color": "#FF715B",
                "negative_color": "#F3C13A",
                "max_at_least": 1,
                "min_at_least": 0,
                "init_min_value": -10,
            }
        }
    },
}
# input CSVs, looked up below with keys of the form "<assay> <antibody or condition>"
input_csvs = {
    "antibody_escape REGN10933": "results/antibody_escape/averages/REGN10933_mut_effect.csv",
    "antibody_escape S2M11": "results/antibody_escape/averages/S2M11_mut_effect.csv",
    "func_effects 293T_ACE2_entry": "results/func_effects/averages/293T_ACE2_entry_func_effects.csv",
    "receptor_affinity pretending_S2M11_is_receptor": "results/receptor_affinity/averages/pretending_S2M11_is_receptor_mut_effect.csv",
}
site_numbering_map_csv = "data/site_numbering_map.csv"
chart_faceted = "results/summaries/summary_faceted_nolegend.html"
chart_overlaid = "results/summaries/summary_overlaid_nolegend.html"
output_csv_file = "results/summaries/summary.csv"

# NOTE(review): this import sits mid-notebook rather than in the top import cell, and the
# working directory is changed to a hard-coded location, so every relative path used by the
# notebook is resolved against "../test_example" once this has run. Looks like a
# development-only convenience -- confirm it is intended before running in a pipeline.
import os
os.chdir("../test_example")

Get the `min_times_seen` and `min_frac_models` filters, and the ordered alphabet:

In [4]:
# Mutation-level filters taken straight from the config, plus the alphabet
# in the order returned by `polyclonal.alphabets.biochem_order_aas`.
min_times_seen, min_frac_models = (
    config["min_times_seen"],
    config["min_frac_models"],
)
alphabet = polyclonal.alphabets.biochem_order_aas(config["alphabet"])

print(f"Using {min_times_seen=} and {min_frac_models=}")

Using min_times_seen=3 and min_frac_models=1


## Read the data

Read the site numbering map:

In [5]:
# Load the numbering map, calling the reference numbering simply "site".
numbering_raw = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)
# Keep every "*site" column (in file order) followed by the region annotation.
numbering_site_cols = [col for col in numbering_raw.columns if col.endswith("site")]
site_numbering_map = numbering_raw[numbering_site_cols + ["region"]]

Read the escape data:

In [6]:
# Build one wide data frame per antibody set: rows are mutations, with one
# escape column per antibody (named by the antibody's display name).
escape = {}
for antibody_set, antibody_set_d in config["antibody_escape"].items():
    name_map = antibody_set_d["antibody_list"]
    # display names become column names, so they must be distinct
    assert len(name_map) == len(set(name_map.values()))

    # read each antibody's CSV, tag it with the display name, and rename the
    # configured statistic to a common "escape" column
    per_antibody = []
    for antibody, antibody_name in name_map.items():
        antibody_df = pd.read_csv(input_csvs[f"antibody_escape {antibody}"])
        per_antibody.append(
            antibody_df.assign(antibody=antibody_name).rename(
                columns={antibody_set_d["stat"]: "escape"}
            )
        )

    # apply the model / times-seen / alphabet filters to the stacked data
    kept = (
        pd.concat(per_antibody)
        .query("frac_models >= @min_frac_models")
        .query("times_seen >= @min_times_seen")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)")
    )

    # long -> wide: one column per antibody
    wide = kept.pivot_table(
        index=["epitope", "site", "wildtype", "mutant"],
        columns="antibody",
        values="escape",
    ).reset_index()
    escape[antibody_set] = wide.assign(
        site_mutant=lambda x: x["site"].astype(str) + x["mutant"]
    )

    # only single-epitope data are supported, after which the column is redundant
    assert escape[antibody_set]["epitope"].nunique() == 1, "can only have 1 epitope"
    escape[antibody_set] = escape[antibody_set].drop(columns="epitope")

Read other properties (functional effects and measurements from other assays):

In [7]:
# Map each property name to a data frame with columns site, wildtype, mutant and
# a column named after the property holding the per-mutation value.
other_props = {}

# functional effects: the value is the "effect" column; `frac_models` is derived
# as the fraction of the maximum number of selections before filtering
for name, name_d in config["func_effects"].items():
    csv_file = input_csvs[f"func_effects {name_d['condition']}"]
    other_props[name] = (
        pd.read_csv(csv_file)
        .rename(columns={"effect": name})
        .assign(frac_models=lambda x: x["n_selections"] / x["n_selections"].max())
        .query("times_seen >= @min_times_seen")
        .query("frac_models >= @min_frac_models")
        [["site", "wildtype", "mutant", name]]
    )

# other assays: the value is the column named by the assay's configured "stat"
for assay, assay_d in config["other_assays"].items():
    for name, name_d in assay_d.items():
        assert name not in other_props, f"{name} multiply defined"
        csv_file = input_csvs[f"{assay} {name_d['condition']}"]
        other_props[name] = (
            pd.read_csv(csv_file)
            .rename(columns={name_d["stat"]: name})
            .query("times_seen >= @min_times_seen")
            .query("frac_models >= @min_frac_models")
            [["site", "wildtype", "mutant", name]]
        )

# property names and antibody-set names share a namespace downstream
assert not set(other_props).intersection(escape), "multiply defined names"

# add wildtype effects of zero
site_wts = pd.concat([*escape.values(), *other_props.values()])[["site", "wildtype"]].drop_duplicates()
# each site must have exactly one wildtype identity across all data sets
assert len(site_wts) == site_wts["site"].nunique()
for prop in other_props:
    # raw-string pattern avoids the invalid "\w" escape in a normal string literal;
    # names are restricted to word characters and spaces
    if not re.fullmatch(r"[\w ]+", prop):
        # f-string so the offending name is actually interpolated into the message
        raise ValueError(f"non-alphanumeric name for property: {prop}")
    other_props[prop] = (
        pd.concat(
            [
                other_props[prop],
                # one wildtype "mutation" per site with a value of zero
                site_wts.assign(
                    mutant=lambda x: x["wildtype"],
                    **{prop: 0},
                )
            ],
            ignore_index=True,
        )
        .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)")
    )
    # `site_mutant` is used as a lookup key later, so it must be unique
    assert other_props[prop]["site_mutant"].nunique() == len(other_props[prop])

Get from the config the plot parameters for each plot (essentially, this "flattens" some aspects of `config` to make these easier to access below):

In [8]:
# Flatten the per-assay sections of `config` into one name -> parameters dict,
# in the order: antibody sets, functional effects, then each other assay.
plot_params = {}

param_groups = [
    config["antibody_escape"],
    config["func_effects"],
    *config["other_assays"].values(),
]
for param_group in param_groups:
    for name, params in param_group.items():
        # names must be unique across all sections
        assert name not in plot_params
        plot_params[name] = params

## Set up selections for interactive charts

In [12]:
# chart geometry shared by the plots below
site_escape_width = 800  # width of site escape chart
site_escape_overlaid_height = 130  # height of overlaid site escape plots
site_escape_faceted_height = 80  # height of faceted site escape plots
cell_size = 9  # heatmap cell size

# radio button: whether to floor escape values at zero
floor_escape_at_zero = alt.param(
    name="floor_escape_at_zero",
    value=config["init_floor_escape_at_zero"],
    bind=alt.binding_radio(name="floor escape at zero", options=[True, False]),
)

# dropdown: which site-level summary statistic to show
site_stats = ["mean", "sum", "max", "min"]
assert config["init_site_escape_stat"] in site_stats
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    value=config["init_site_escape_stat"],
    bind=alt.binding_select(name="site escape statistic", options=site_stats),
)

# site under the mouse
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# x-interval brush used to zoom on a range of sites
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0),
    empty=True,
)

# one slider per other property, running from the smallest observed value up to
# zero and starting at the configured initial minimum (never below the data minimum)
other_prop_sliders = {}
for prop, prop_df in other_props.items():
    lowest = prop_df[prop].min()
    other_prop_sliders[prop] = alt.param(
        name=prop.replace(" ", "_") + "_slider",
        value=max(lowest, plot_params[prop]["init_min_value"]),
        bind=alt.binding_range(
            name=f"minimum mutation {prop}",
            min=lowest,
            max=0,
        ),
    )

# region zoom bar
region_bar_title = alt.TitleParams(
    "site zoom bar", fontSize=11, fontWeight="bold", orient="top",
)
region_bar = (
    alt.Chart(site_numbering_map)
    .mark_rect()
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site:Q"),
            axis=None,
        ),
        color=alt.Color(
            "region",
            scale=alt.Scale(domain=site_numbering_map["region"].unique()),
        ),
        tooltip=site_numbering_map.columns.tolist(),
    )
    .properties(width=site_escape_width, height=11, title=region_bar_title)
)

## Antibody escape site summary plots
Now make site summary escape plots for each antibody set.
Do this with the antibodies (or sera) both faceted and overlaid.

In [34]:
# For each antibody set, build two versions of the site-level escape chart
# ("faceted": one panel per antibody; "overlaid": all antibodies in one panel),
# each stacked under the region zoom bar and a mean-escape panel.
site_escape_charts = {"faceted": {}, "overlaid": {}}

for antibody_set, escape_df in escape.items():

    # display names of the antibodies; these are the escape columns of `escape_df`
    antibodies = list(config["antibody_escape"][antibody_set]["antibody_list"].values())

    # shared y-encoding and tooltip, restricted to the sites selected by the zoom brush
    site_escape_base = (
        alt.Chart(escape_df)
        .encode(
            y=alt.Y(
                "escape:Q",
                scale=alt.Scale(nice=False, padding=4),
                axis=alt.Axis(grid=False),
            ),
            tooltip=[
                "site",
                alt.Tooltip("escape:Q", format=".2f"),
                "antibody:N",
                *[f"{c}:N" for c in site_numbering_map.columns if c != "site"],
            ],
        )
        .transform_filter(site_brush)
    )

    site_escape_lines = site_escape_base.mark_line(size=0.75)

    # points are enlarged and outlined in red for the site under the mouse
    site_escape_points = site_escape_base.encode(
        strokeWidth=alt.condition(site_selection, alt.value(1.5), alt.value(0)),
        size=alt.condition(site_selection, alt.value(45), alt.value(15)),
    ).mark_circle(filled=True, stroke="red")

    site_escape_lines_and_points = (
        (site_escape_lines + site_escape_points)
        # wide -> long: one row per (mutation, antibody)
        .transform_fold(fold=antibodies, as_=["antibody", "escape_orig"])
        # floor escape at zero if selected
        .transform_calculate(
            escape=alt.expr.if_(
                floor_escape_at_zero,
                alt.expr.max(alt.datum["escape_orig"], 0),
                alt.datum["escape_orig"],
            )
        )
    )

    # filter on other properties
    for prop, prop_df in other_props.items():
        # https://github.com/altair-viz/altair/issues/2600
        slider = other_prop_sliders[prop]
        site_escape_lines_and_points = (
            site_escape_lines_and_points
            # attach the property value to each mutation, then keep mutations at
            # or above the slider value
            .transform_lookup(
                lookup="site_mutant",
                from_=alt.LookupData(
                    prop_df,
                    key="site_mutant",
                    fields=[prop],
                ),
            )
            .transform_filter(alt.datum[prop] >= slider)
        )

    # compute site statistics
    site_escape_lines_and_points = (
        site_escape_lines_and_points
        # compute site statistics from mutation statistics
        .transform_aggregate(
            **{stat: f"{stat}(escape)" for stat in site_stats},
            groupby=["site", "antibody", "wildtype"],
        )
        # filter on site statistic of interest
        .transform_fold(fold=site_stats, as_=["site escape statistic", "escape"])
        .transform_filter(site_escape_selection)
        # get sequential sites and regions
        .transform_lookup(
            lookup="site",
            from_=alt.LookupData(
                site_numbering_map,
                key="site",
                fields=[c for c in site_numbering_map.columns if c != "site"],
            ),
        )
    )

    # long set names are split over two title lines
    if len(antibody_set) < 14:
        individual_title = f"individual {antibody_set}"
        mean_title = f"mean {antibody_set}"
    else:
        individual_title = ["individual", antibody_set]
        mean_title = ["mean", antibody_set]
    # with a single antibody the "mean" panel is just that antibody
    if len(antibodies) == 1:
        mean_title = antibody_set

    site_escape_faceted = (
        site_escape_lines_and_points
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site:Q"),
                axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
            ),
            color=alt.value("gray"),
        )
        .properties(height=site_escape_faceted_height, width=site_escape_width)
        .facet(
            facet=alt.Facet(
                "antibody:N",
                title=individual_title,
                header=alt.Header(
                    labelOrient="right",
                    labelFontSize=10,
                    labelPadding=3,
                    titleOrient="right",
                    titlePadding=3,
                ),
            ),
            columns=1,
            spacing=0,
        )
        .resolve_scale(y="independent")
    )

    site_escape_overlaid = (
        site_escape_lines_and_points
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site:Q"),
                axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
            ),
            opacity=alt.value(0.4),
            color=alt.value("gray" if len(antibodies) > 1 else "black"),
            # group marks by antibody so each antibody gets its own line; the data
            # carry an "antibody" field (from the fold / aggregate above), not "serum"
            detail="antibody:N",
        )
        .properties(
            height=site_escape_overlaid_height,
            width=site_escape_width,
            title=alt.TitleParams(
                individual_title, fontSize=11, fontWeight="bold", orient="right",
            ),
        )
    )

    site_mean_escape = (
        site_escape_lines_and_points
        # average missing values as zero
        .transform_calculate(
            escape=alt.expr.if_(
                alt.expr.isValid(alt.datum["escape"]),
                alt.datum["escape"],
                0,
            ),
        )
        # take mean over antibodies
        .transform_aggregate(
            escape="mean(escape)",
            groupby=["wildtype", *site_numbering_map.columns],
        )
        # label used in the tooltip's "antibody" field
        .transform_calculate(
            antibody="'mean escape'" if len(antibodies) > 1 else f"'{antibodies[0]}'"
        )
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site:Q"),
                axis=None,
            ),
            color=alt.value("black")
        )
        .properties(
            title=alt.TitleParams(
                mean_title, fontSize=11, fontWeight="bold", orient="right",
            ),
        )
    )

    # stack: zoom bar, mean panel, and (only when there are several antibodies)
    # the per-antibody panel
    for chart_type, height, site_chart in [
        ("faceted", site_escape_faceted_height, site_escape_faceted),
        ("overlaid", site_escape_overlaid_height, site_escape_overlaid),
    ]:
        site_escape_charts[chart_type][antibody_set] = alt.vconcat(
            region_bar.add_params(site_brush),
            (
                alt.vconcat(
                    *(
                        [
                            site_mean_escape.properties(
                                height=height, width=site_escape_width,
                            ),
                            site_chart,
                        ]
                        if len(antibodies) > 1
                        else [
                            site_mean_escape.properties(
                                height=height, width=site_escape_width,
                            )
                        ]
                    ),
                    spacing=3,
                )
                .add_params(
                    site_escape_selection,
                    site_selection,
                    *other_prop_sliders.values(),
                    floor_escape_at_zero,
                )
            ),
            spacing=0,
        )

## Heatmaps
First, create a data frame that has the average escape across the antibodies (or sera) in each set (averaging mutations missing for an antibody as zero for that antibody) and other properties:

In [16]:
# Accumulate one "<antibody set> escape" column per antibody set, followed by
# one column per other property, merged on their shared columns.
heatmap_data = None
heatmap_data_cols = []
for antibody_set, escape_df in escape.items():
    antibody_list = list(config["antibody_escape"][antibody_set]["antibody_list"].values())
    escape_col = f"{antibody_set} escape"

    # add wildtype with zero escape
    wildtype_rows = (
        escape_df[["site", "wildtype"]]
        .drop_duplicates()
        .assign(mutant=lambda x: x["wildtype"])
    )
    set_df = (
        pd.concat([escape_df, wildtype_rows])
        # missing measurements count as zero in the mean over antibodies
        .fillna(0)
        .assign(escape=lambda x: x[antibody_list].mean(axis=1))
        .drop(columns=antibody_list + ["site_mutant"])
        .rename(columns={"escape": escape_col})
    )
    heatmap_data_cols.append(escape_col)
    heatmap_data = (
        set_df
        if heatmap_data is None
        else heatmap_data.merge(set_df, how="outer", validate="one_to_one")
    )

for prop, prop_df in other_props.items():
    heatmap_data_cols.append(prop)
    heatmap_data = heatmap_data.merge(prop_df, validate="one_to_one", how="outer")

# drop the numbering columns brought in by the property frames and re-attach
# them from the numbering map itself
numbering_extra_cols = [c for c in site_numbering_map.columns if c != "site"]
heatmap_data = (
    heatmap_data
    .drop(columns=numbering_extra_cols)
    .merge(site_numbering_map, validate="many_to_one")
    .drop(columns="site_mutant")
)

# wildtype identities always get an escape of exactly zero
not_wildtype = heatmap_data["wildtype"] != heatmap_data["mutant"]
for antibody_set in escape:
    col = f"{antibody_set} escape"
    heatmap_data[col] = heatmap_data[col].where(not_wildtype, 0)

print(f"Writing summary data to {output_csv_file}")
heatmap_data.to_csv(output_csv_file, index=False, float_format="%.4g")

heatmap_data

Writing summary data to results/summaries/summary.csv


Unnamed: 0,site,wildtype,mutant,monoclonal antibodies escape,spike mediated entry,mock receptor affinity,sequential_site,dummy_site,region
0,5,L,C,-0.006785,0.5983,0.010750,5,5,other
1,5,L,M,-0.007225,0.3505,0.013080,5,5,other
2,5,L,V,-0.080340,-0.5474,0.009824,5,5,other
3,5,L,L,0.000000,0.0000,0.000000,5,5,other
4,5,L,R,,-4.3940,,5,5,other
...,...,...,...,...,...,...,...,...,...
2236,1247,C,F,,-0.8762,,1245,1247,other
2237,1247,C,Y,,-3.0870,,1245,1247,other
2238,1247,C,C,0.000000,0.0000,0.000000,1245,1247,other
2239,1248,C,Y,,0.6028,,1246,1248,other


Make heatmaps:

In [41]:
# Name of the antibody set whose heatmap is stacked last (the bottom one): only
# that heatmap keeps its x-axis labels. This must be a *name* because it is
# compared with `antibody_set` below; the stacked heatmaps are built from the
# keys of `escape`, in order.
last_heatmap = list(escape)[-1]

heatmap_base = (
    alt.Chart(heatmap_data)
    .transform_calculate(
        # escape optionally floored at zero, driven by the radio-button parameter
        **{
            f"{antibody_set} escape_floored": alt.expr.if_(
                floor_escape_at_zero,
                alt.expr.max(alt.datum[f"{antibody_set} escape"], 0),
                alt.datum[f"{antibody_set} escape"],
            )
            for antibody_set in escape
        },
        # convert null values to NaN so they show as NaN in tooltips rather than as 0.0
        **{
            col: alt.expr.if_(
                alt.expr.isFinite(alt.datum[col]),
                alt.datum[col],
                alt.expr.NaN,
            )
            for col in heatmap_data_cols
        }
    )
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        y=alt.Y(
            "mutant:N",
            title="amino acid",
            sort=alphabet,
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        # thicker cell outline for the site under the mouse
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(1)),
    )
    .properties(width=alt.Step(cell_size), height=alt.Step(cell_size))
    .add_params(*other_prop_sliders.values(), floor_escape_at_zero)
)

# mark X for wildtype
heatmap_wildtype = (
    heatmap_base
    .transform_filter(alt.datum["wildtype"] == alt.datum["mutant"])
    .mark_text(text="x", color="black")
)

# gray background for missing values
heatmap_bg = (
    heatmap_base
    # impute a dummy field so every (site, alphabet character) cell exists
    .transform_impute(
        impute="_stat_dummy",
        key="mutant",
        keyvals=alphabet,
        groupby=["site"],
        value=None,
    )
    .mark_rect(color="#E0E0E0", opacity=0.8)
)

tooltips = [
    "site",
    "mutant",
    *[alt.Tooltip(c, format=".2f") for c in heatmap_data_cols],
    "wildtype",
    *[c for c in site_numbering_map.columns if c != "site"],
]

legend = alt.Legend(
    orient="left",
    offset=10,
    titleOrient="left",
    gradientLength=100,
    gradientThickness=10,
    gradientStrokeColor="black",
    gradientStrokeWidth=0.5,
)

# heatmaps for escape
escape_heatmaps = {}
# NOTE(review): this keeps a mutation if ANY property is at or above its slider,
# whereas the site plots apply the slider filters one after another (ALL must
# pass) and the "filtered" layer below covers mutations failing ANY slider.
# Confirm whether `operator.and_` was intended here.
escape_heatmap_base = heatmap_base.transform_filter(
    functools.reduce(
        operator.or_,
        [(alt.datum[prop] >= slider) for prop, slider in other_prop_sliders.items()],
    ) | (alt.datum["wildtype"] == alt.datum["mutant"])
)
for antibody_set in escape:
    # lower end of the color domain when escape is not floored at zero
    domainMin = min(
        plot_params[antibody_set]["min_at_least"],
        heatmap_data[f"{antibody_set} escape"].min(),
    )
    escape_heatmaps[antibody_set] = (
        escape_heatmap_base
        .encode(
            # turn off x-labels for this heatmap since it is stacked
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=(
                    alt.Axis()
                    if last_heatmap == antibody_set
                    else alt.Axis(ticks=False, labels=False)
                ),
            ),
            color=alt.Color(
                f"{antibody_set} escape_floored:Q",
                title=f"{antibody_set} escape" if len(antibody_set) < 14 else [antibody_set, "escape"],
                legend=legend,
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    domainMid=0,
                    domainMax=max(
                        plot_params[antibody_set]["max_at_least"],
                        heatmap_data[f"{antibody_set} escape"].max()
                    ),
                    # the domain minimum follows the floor-at-zero radio button
                    domainMin=alt.ExprRef(
                        f"if(floor_escape_at_zero, 0, {domainMin})"
                    ),
                    # diverging palette: negative color -> white -> positive color
                    range=(
                        color_gradient_hex(plot_params[antibody_set]["negative_color"], "white", n=20)
                        + color_gradient_hex("white", plot_params[antibody_set]["positive_color"], n=20)[1:]
                    ),
                ),
            ),
            tooltip=tooltips,
        )
        .mark_rect(stroke="black")
    )

# heatmap for other property (eg, functional effect) filtered escape
escape_filtered_heatmap = (
    heatmap_base
    # non-wildtype mutations that fall below the slider for at least one property
    .transform_filter(
        functools.reduce(
            operator.or_,
            [alt.datum[prop] < slider for prop, slider in other_prop_sliders.items()],
        )
        & (alt.datum["wildtype"] != alt.datum["mutant"])
    )
    # constant field so that all such cells get the single "silver" color
    .transform_calculate(filtered="''")
    .encode(
        tooltip=tooltips,
        color=alt.Color(
            "filtered:N",
            title="deleterious",
            scale=alt.Scale(range=["silver"]),
            legend=None,
        ),
    )
    .mark_rect(stroke="black")
)

# one layered heatmap per antibody set, stacked vertically; layers from bottom to
# top: missing-value background, escape, filtered-out mutations, wildtype marks
heatmap = alt.vconcat(
    *[
        heatmap_bg
        + escape_heatmap
        + escape_filtered_heatmap
        + heatmap_wildtype
        for escape_heatmap in escape_heatmaps.values()
    ],
    spacing=1,
).resolve_scale(color="independent").add_params(site_selection)

heatmap

In [40]:
# NOTE(review): bare expression that only displays the flattened `plot_params` dict;
# looks like a leftover inspection cell (it was also run out of order) -- confirm it is still wanted.
plot_params

{'monoclonal antibodies': {'stat': 'escape_median',
  'negative_color': '#0072B2',
  'positive_color': '#E69F00',
  'max_at_least': 1,
  'min_at_least': -1,
  'antibody_list': {'REGN10933': 'REGN10933', 'S2M11': 'S2M-11'}},
 'spike mediated entry': {'condition': '293T_ACE2_entry',
  'effect_type': 'func_effects',
  'positive_color': '#009E73',
  'negative_color': '#CC79A7',
  'max_at_least': 1,
  'min_at_least': 0,
  'init_min_value': -3},
 'mock receptor affinity': {'condition': 'pretending_S2M11_is_receptor',
  'stat': 'receptor affinity_median',
  'positive_color': '#FF715B',
  'negative_color': '#F3C13A',
  'max_at_least': 1,
  'min_at_least': 0,
  'init_min_value': -10}}