# Make the interactive escape calculator plot

## Import modules and read data
Import Python modules:

In [None]:
import itertools
import os
import re

import altair as alt

import numpy

import pandas as pd

import yaml

Disable Altair's maximum-rows limit so the full datasets can be embedded in the charts:

In [None]:
# Lift Altair's limit on how many data rows a chart may embed, because the
# full escape table is passed to the charts below. The return value is bound
# to `_` only so the notebook cell does not echo it.
_ = alt.data_transformers.disable_max_rows()

Read configuration:

In [None]:
# Load run settings from the YAML config: the site range, the default
# mutation escape strength and the default IC50-weighting choice.
# `safe_load` builds only plain Python objects. Decode explicitly as UTF-8
# rather than relying on the platform's locale-dependent default encoding.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

Read the data.
We don't need the antibody names themselves, so to make the embedded data smaller we encode each antibody as an integer:

In [None]:
antibody_sources = pd.read_csv("results/antibody_sources.csv")

# each antibody must occupy exactly one row, so that its row label can serve
# as a unique integer code
assert len(antibody_sources) == antibody_sources["antibody"].nunique()

# map antibody name -> integer code (the row's index label in the table)
encoding = (
    antibody_sources
    .rename_axis("encoding")
    .reset_index()
    .set_index("antibody")
    ["encoding"]
    .to_dict()
)

antibody_sources["antibody"] = antibody_sources["antibody"].map(encoding)

# convert each IC50 to a weight of -ln(IC50 / 10) (natural log), so an
# antibody with an IC50 of 10 has weight zero and more potent antibodies
# (smaller IC50s) have larger weights
antibody_ic50s = (
    pd.read_csv("results/antibody_IC50s.csv")
    .assign(
        antibody=lambda x: x["antibody"].map(encoding),
        ic50_weight=lambda x: -numpy.log(x["IC50"] / 10),
    )
    .drop(columns="IC50")
)
# weights must be non-negative, i.e. no IC50 may exceed 10
assert (antibody_ic50s["ic50_weight"] >= 0).all()
# collapse to one row per antibody with list-valued columns; the chart looks
# these lists up by antibody and expands them again with `transform_flatten`
antibody_ic50s = (
    antibody_ic50s
    .groupby("antibody", as_index=False)
    .aggregate(list)
)

# binding data are likewise collapsed to one row of lists per antibody
antibody_binding = (
    pd.read_csv("results/antibody_binding.csv")
    .assign(antibody=lambda x: x["antibody"].map(encoding))
    .groupby("antibody", as_index=False)
    .aggregate(list)
)

escape = (
    pd.read_csv("results/escape.csv")
    .assign(antibody=lambda x: x["antibody"].map(encoding))
)

# calculations assume max escape is 1 for each antibody
assert (escape.groupby("antibody")["escape"].max() == 1).all()

# every antibody must appear in all four tables; this also catches names
# missing from `encoding`, which `map` turns into NaN
assert (
    set(escape["antibody"])
    == set(antibody_binding["antibody"])
    == set(antibody_ic50s["antibody"])
    == set(antibody_sources["antibody"])
)

Specify which sites to use:

In [None]:
# Sites shown on the x-axis: every integer from the configured start through
# the configured end, inclusive.
sites = [*range(config["sites"]["start"], config["sites"]["end"] + 1)]
# every site that has escape data must fall inside that range
assert all(escape["site"].isin(sites))

## Make an "escape calculator" plot

First, make the selections and parameters used by the plot:

In [None]:
def _dropdown_selection(field, options, label):
    """Return a point selection on ``field`` bound to a dropdown menu.

    ``options`` are the menu entries and ``label`` is the text shown next to
    the menu.
    """
    return alt.selection_point(
        fields=[field],
        bind=alt.binding_select(options=options, name=label),
    )


# NOTE: the selections below are created in the same order as before; unnamed
# selections may receive auto-generated names that depend on creation order.

# dropdown restricting the IC50 weights to a single virus
virus_selection = _dropdown_selection(
    "virus",
    sorted(antibody_ic50s.explode("virus")["virus"].unique()),
    "virus",
)

# sites clicked on the line plot; the initial value of -1 is presumably chosen
# to lie outside `sites` so no real site starts out selected, and
# `toggle="true"` makes each click toggle a site in or out of the selection
mut_selection = alt.selection_point(
    name="mut",
    fields=["site"],
    value=[{"site": -1}],
    empty=True,
    toggle="true",
)

# slider for the exponent applied to each antibody's retained binding
mut_escape_strength = alt.param(
    name="mut_escape_strength",
    value=config["mutation_escape_strength"],
    bind=alt.binding_range(min=1, max=10, name="mutation escape strength"),
)

# radio buttons: 1 weights antibodies by -log IC50, 0 weights them equally
ic50_weight = alt.param(
    name="weight_by_neg_log_IC50",
    value=int(config["weight_by_neg_log_IC50"]),
    bind=alt.binding_radio(
        options=[1, 0],
        labels=["yes", "no"],
        name="weight by negative log IC50",
    ),
)

# dropdowns filtering antibodies by source, study and binding class
source_selection = _dropdown_selection(
    "source", antibody_sources["source"].unique(), "antibody source"
)
study_selection = _dropdown_selection(
    "study", antibody_sources["study"].unique(), "study"
)
binding_selection = _dropdown_selection(
    "binds", antibody_binding.explode("binds")["binds"].unique(), "binds"
)

# every selection / parameter, attached to the charts that use them
params = [
    mut_selection,
    mut_escape_strength,
    ic50_weight,
    virus_selection,
    source_selection,
    study_selection,
    binding_selection,
]

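Now make the base plot.
To keep the embedded data small, we attach the antibody properties (IC50s, sources and binding) via transform lookups rather than merging them into the escape data: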

In [None]:
# Shared base chart. It starts from the escape rows, attaches antibody
# properties through lookups, and applies the dropdown filters. It ends with
# per-(antibody, site, IC50 weight) escape plus the derived fields
# `site_binding_retained` and `weight`.
plot_base = (
    alt.Chart(escape)
    # attach each antibody's list-valued `virus` / `ic50_weight` columns ...
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_ic50s,
            key="antibody",
            fields=["virus", "ic50_weight"],
        ),
    )
    # ... and expand them to one row per (escape row, virus)
    .transform_flatten(["virus", "ic50_weight"])
    .transform_filter(virus_selection)
    # collapse back to one row per (antibody, site, IC50 weight).
    # NOTE(review): with several viruses passing the filter, viruses that give
    # one antibody an identical weight end up in the same group here.
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["antibody", "site", "ic50_weight"],
    )
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_sources,
            key="antibody",
            fields=["source", "study"],
        ),
    )
    .transform_filter(source_selection)
    .transform_filter(study_selection)
    # groups are unchanged, so this leaves `escape` as-is and just drops the
    # looked-up source / study fields
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["antibody", "site", "ic50_weight"],
    )
    # attach the list of binding classes and expand to one row per class
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_binding,
            key="antibody",
            fields=["binds"],
        ),
    )
    .transform_flatten(["binds"])
    .transform_filter(binding_selection)
    # remove the per-class duplicates created by the flatten
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["antibody", "site", "ic50_weight"],
    )
    .transform_calculate(
        # site_binding_retained: 1 - escape at sites in the `mut` selection,
        # 1 everywhere else.
        # approach based on: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # approach based on: https://stackoverflow.com/a/60894451/4191652
        site_binding_retained="1 - if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)",
        # weight: the IC50 weight when the radio button is "yes" (1), else 1
        weight=alt.expr.if_(ic50_weight == 1, alt.datum["ic50_weight"], 1),
    )
)

Now build the bar plot:

In [None]:
# Horizontal stacked bar showing the weighted fraction of antibodies still
# bound vs. escaped after the currently selected mutations.
frac_bound_bar = (
    plot_base
    # binding each (antibody, weight) condition retains = product over sites
    # of its per-site retained binding
    .transform_aggregate(
        binding_retained="product(site_binding_retained)",
        groupby=["antibody", "weight"],
    )
    # raise to the escape-strength exponent and scale by the condition weight
    .transform_calculate(
        binding_retained_exp=(
            alt.datum["weight"]
            * alt.expr.pow(alt.datum["binding_retained"], mut_escape_strength)
        ),
    )
    # totals over all conditions, for a weighted mean
    .transform_aggregate(
        sum_binding_retained="sum(binding_retained_exp)",
        sum_weight="sum(weight)",
    )
    # The escaped fraction must be named `escaped` so that it matches the
    # fold fields and the color-scale domain below; otherwise the fold finds
    # no such field and the escaped segment is never drawn.
    .transform_calculate(
        bound=alt.datum["sum_binding_retained"] / alt.datum["sum_weight"],
        escaped=1 - alt.datum["bound"],
    )
    # long form: one row per binding state with its fraction
    .transform_fold(
        ["bound", "escaped"],
        ["binding state", "fraction of antibodies"]
    )
    .encode(
        x=alt.X("fraction of antibodies:Q", axis=alt.Axis(grid=False)),
        y=alt.value(1),
        fill=alt.Color(
            "binding state:N",
            scale=alt.Scale(
                domain=["bound", "escaped"],
                range=["lightgray", "#56B4E9"],
                reverse=True,
            ),
        ),
        order=alt.Order("binding state:N"),
        tooltip=[
            "binding state:N",
            alt.Tooltip("fraction of antibodies:Q", format=".2g"),
        ]
    )
    .mark_bar(stroke="black", size=20)
    .add_params(*params)
    .properties(width=300, height=10)
)

frac_bound_bar

Now make the line plot:

In [None]:
# Shared base for the per-site escape line/point plot. Each site gets two
# values, escape without mutations ("unmutated") and escape remaining after
# the selected mutations ("mutated"), plus a `color_val` category that
# highlights the selected sites.
escape_mut_base = (
    plot_base
    .encode(
        x=alt.X(
            "site:Q",
            axis=alt.Axis(grid=False),
            scale=alt.Scale(zero=False, nice=False),
        ),
        y=alt.Y(
            "mean_escape_value:Q",
            axis=alt.Axis(
                grid=False,
                title="escape (arbitrary units)",
                labels=False,
                ticks=False,
            ),
        ),
    )
    # binding each (antibody, weight) condition retains after the selected
    # mutations, attached to every one of that condition's rows
    .transform_joinaggregate(
        binding_retained="product(site_binding_retained)",
        groupby=["antibody", "weight"],
    )
    .transform_calculate(
        escape_weighted=alt.datum["weight"] * alt.datum["escape"],
        # weighted escape scaled by retained binding ** escape strength
        escape_after_mut=(
            alt.expr.pow(alt.datum["binding_retained"], mut_escape_strength)
            * alt.datum["escape_weighted"]
        ),
    )
    # we don't actually have the correct denominator here, but it should
    # just affect relative scale of escape values
    .transform_joinaggregate(n_conditions="distinct(antibody)")
    # per-site sums; `mean` only carries the (constant) count through
    .transform_aggregate(
        sum_mutated="sum(escape_after_mut)",
        sum_unmutated="sum(escape_weighted)",
        n_conditions="mean(n_conditions)",
        groupby=["site"],
    )
    .transform_calculate(
        mutated=alt.datum["sum_mutated"] / alt.datum["n_conditions"],
        unmutated=alt.datum["sum_unmutated"] / alt.datum["n_conditions"],
    )
    # long form: one row per (site, "unmutated" / "mutated")
    .transform_fold(
        ["unmutated", "mutated"],
        ["escape_type", "mean_escape_value"],
    )
    # give sites in `sites` that have no data a value of zero for each type
    .transform_impute(
        impute="mean_escape_value",
        key="site",
        value=0,
        groupby=["escape_type"],
        keyvals=sites,
    )
    # "mutated" rows at sites in the `mut` selection become "mutated site".
    # NOTE(review): the expression combines the tests with bitwise `&` rather
    # than logical `&&`; it works because both operands are booleans.
    .transform_calculate(
        color_val=(
            'if((indexof(mut.site, datum.site) >= 0) '
            + '& (datum.escape_type == "mutated"), '
            + '"mutated site", datum.escape_type)'
        ),
    )
    .properties(width=800, height=225)
    )

# Category order shared by the three escape-type scales below.
_escape_type_domain = ["unmutated", "mutated", "mutated site"]

# color per category: gray, blue, and vermillion for the selected sites
mut_escape_color_scale = alt.Scale(
    domain=list(_escape_type_domain),
    range=["#999999", "#56B4E9", "#D55E00"],
)

# point size grows from unmutated to mutated to selected sites
mut_escape_point_size_scale = alt.Scale(
    domain=list(_escape_type_domain), range=[30, 60, 100]
)

# selected sites are fully opaque; the other categories are partly transparent
mut_escape_opacity_scale = alt.Scale(
    domain=list(_escape_type_domain), range=[0.5, 0.7, 1]
)

# One line per escape type across sites, colored and faded by type; the
# opacity channel gets no legend of its own.
escape_mut_lines = escape_mut_base.mark_line().encode(
    color=alt.Color("escape_type:N", scale=mut_escape_color_scale),
    opacity=alt.Opacity(
        "escape_type:N", scale=mut_escape_opacity_scale, legend=None
    ),
)

# Vega expression giving the legend text for each point color category.
_point_legend_label_expr = (
    'if(datum.value == "unmutated", '
    + '   "escape when no mutations", '
    + '   if(datum.value == "mutated", '
    + '      "escape with mutations", '
    + '      "mutated site"))'
)

# Filled points per site and escape type; selected sites get their own
# color/size/opacity. The chart also carries the interactive parameters.
escape_mut_points = (
    escape_mut_base
    .mark_point(filled=True)
    .encode(
        color=alt.Color(
            "color_val:N",
            scale=mut_escape_color_scale,
            legend=alt.Legend(title=None, labelExpr=_point_legend_label_expr),
        ),
        opacity=alt.Opacity(
            "color_val:N", scale=mut_escape_opacity_scale, legend=None
        ),
        size=alt.Size("color_val:N", scale=mut_escape_point_size_scale),
        tooltip=[
            "site:O",
            alt.Tooltip("mutated:Q", format=".2g"),
            alt.Tooltip("unmutated:Q", format=".2g"),
        ],
    )
    .add_params(*params)
)

# Stack the layered line + point plot above the summary bar, then apply the
# chart-wide styling: no view border, untitled legends at the bottom, and an
# independent legend for each subplot.
escape_chart = (
    alt.vconcat(
        alt.layer(escape_mut_lines, escape_mut_points),
        frac_bound_bar,
    )
    .configure_view(strokeOpacity=0)
    .configure_legend(orient="bottom", labelFontSize=12, title=None)
    .resolve_legend("independent")
)

escape_chart

# write the standalone HTML chart, creating its directory first if needed
escape_calc_chartfile = "docs/_includes/escape_calc_chart.html"
os.makedirs(os.path.dirname(escape_calc_chartfile), exist_ok=True)
print(f"Saving chart to {escape_calc_chartfile}")
escape_chart.save(escape_calc_chartfile)

In [None]:
# Appears to be a scratch check of `transform_lookup` when the secondary data
# contain duplicate keys: here x=1 and x=2 each match two rows.
# NOTE(review): presumably this is why the main plot collapses per-antibody
# values into lists and expands them with `transform_flatten` rather than
# looking up duplicate rows directly. Consider dropping this cell from the
# published notebook.
(
    alt.Chart(pd.DataFrame({"x": [1, 2]}))
    .transform_lookup(
        lookup="x",
        from_=alt.LookupData(
            data=pd.DataFrame({"x": [1, 1, 2, 2], "y": [1, 2, 3, 4]}),
            key="x",
            fields=["y"],
        ),
    )
    .encode(x="x", y="y:Q")
    .mark_point()
)