# Make the interactive escape calculator plot

## Import modules and read data
Import Python modules:

In [1]:
import os

import altair as alt

import numpy

import pandas as pd

import yaml

Disable Altair's maximum-rows limit so the full datasets can be embedded in the chart:

In [2]:
# Lift Altair's limit on how many data rows it will embed in a chart, since the
# escape data used below can be large. Assigning the result to `_` presumably
# just keeps the notebook from echoing the returned object.
_ = alt.data_transformers.disable_max_rows()

Read configuration:

In [3]:
# Read the plot configuration. The encoding is given explicitly so parsing does
# not depend on the platform's default (locale) encoding.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

Read the data.
We don't need the antibody names in the plot, so to make the embedded data smaller we encode each antibody as an integer:

In [4]:
antibody_sources = pd.read_csv("results/antibody_sources.csv")

# each antibody must be listed once so it gets exactly one integer code
assert len(antibody_sources) == antibody_sources["antibody"].nunique()

# map antibody name -> integer code (its row index in antibody_sources)
encoding = (
    antibody_sources
    .rename_axis("encoding")
    .reset_index()
    .set_index("antibody")
    ["encoding"]
    .to_dict()
)

antibody_sources["antibody"] = antibody_sources["antibody"].map(encoding)

# Convert the IC50s to weights of -ln(IC50 / 10) (natural log), so antibodies
# with an IC50 of 10 get a weight of zero and more potent (lower IC50)
# antibodies get larger weights.
antibody_ic50s = (
    pd.read_csv("results/antibody_IC50s.csv")
    .assign(
        antibody=lambda x: x["antibody"].map(encoding),
        ic50_weight=lambda x: -numpy.log(x["IC50"] / 10),
    )
    .drop(columns="IC50")
)
# Antibodies absent from antibody_sources.csv map to NaN. Check for them here,
# because the groupby below drops NaN keys and would otherwise hide them from
# the cross-file consistency check at the end of this cell.
assert antibody_ic50s["antibody"].notna().all(), "unknown antibody in IC50s"
# weights are non-negative only if every IC50 is <= 10 (NaN IC50s also fail)
assert (antibody_ic50s["ic50_weight"] >= 0).all()
# collapse to one row per antibody holding lists (one entry per virus) so the
# plot can flatten them back out after its lookup
antibody_ic50s = (
    antibody_ic50s
    .groupby("antibody", as_index=False)
    .aggregate(list)
)

antibody_binding = (
    pd.read_csv("results/antibody_binding.csv")
    .assign(antibody=lambda x: x["antibody"].map(encoding))
)
# same unknown-antibody check as for the IC50s, before groupby can drop them
assert antibody_binding["antibody"].notna().all(), "unknown antibody in binding"
antibody_binding = (
    antibody_binding
    .groupby("antibody", as_index=False)
    .aggregate(list)
)

escape = (
    pd.read_csv("results/escape.csv")
    .assign(antibody=lambda x: x["antibody"].map(encoding))
)
assert escape["antibody"].notna().all(), "unknown antibody in escape"

# calculations assume max escape is 1 for each antibody
assert (escape.groupby("antibody")["escape"].max() == 1).all()

# every file must describe exactly the same set of antibodies
assert (
    set(escape["antibody"])
    == set(antibody_binding["antibody"])
    == set(antibody_ic50s["antibody"])
    == set(antibody_sources["antibody"])
)

Specify which sites to use:

In [5]:
# every site from the configured start through the configured end (inclusive)
sites = [*range(config["sites"]["start"], config["sites"]["end"] + 1)]
# no escape measurement may fall outside that range
assert escape["site"].isin(sites).all()

## Make an "escape calculator" plot

First, make the selections and parameters used by the plot:

In [6]:
# dropdown choosing the virus whose IC50s are used to weight antibodies
virus_selection = alt.selection_point(
    fields=["virus"],
    bind=alt.binding_select(
        options=sorted(antibody_ic50s.explode("virus")["virus"].unique()),
        name="virus against which neutralization measured",
    ),
    value=[{"virus": config["init_virus"]}],
)

# Sites clicked to mutate. Named "mut" because the plot's Vega expressions
# refer to it as `mut.site`. The initial site -1 presumably matches no real
# site, and toggle="true" makes each click add/remove a site.
mut_selection = alt.selection_point(
    name="mut",
    fields=["site"],
    value=[{"site": -1}],
    empty=False,
    toggle="true",
)

# slider: exponent applied to each antibody's retained binding fraction
mut_escape_strength = alt.param(
    name="mut_escape_strength",
    bind=alt.binding_range(
        min=1,
        max=10,
        name="mutation escape strength",
    ),
    value=config["init_mutation_escape_strength"],
)

# radio buttons: 1 weights antibodies by -log IC50, 0 weights them equally
ic50_weight = alt.param(
    name="weight_by_neg_log_IC50",
    bind=alt.binding_radio(
        options=[1, 0],
        labels=["yes", "no"],
        name="weight escape by negative log IC50",
    ),
    value=int(config["init_weight_by_neg_log_IC50"]),
)

# Set up the source selection after getting its initial values from config:
# with "include" the configured sources start selected, with "exclude" every
# other source starts selected.
source_list = sorted(antibody_sources["source"].unique())
init_source_list = config["init_sources"]["sources"]
assert len(init_source_list) == len(set(init_source_list))
# every configured source must actually exist in the data
assert set(source_list).issuperset(init_source_list)
if config["init_sources"]["include_exclude"] == "exclude":
    init_source_list = [s for s in source_list if s not in init_source_list]
elif config["init_sources"]["include_exclude"] != "include":
    raise ValueError(f"invalid {config['init_sources']['include_exclude']}")
source_selection = alt.selection_point(
    fields=["source"],
    empty=True,
    toggle="true",
    value=[{"source": s} for s in init_source_list],
)

assert set(config["studies"]) == set(antibody_sources["study"])
# Compute the studies once so dropdown labels and options stay aligned; the
# extra None option is labeled "any".
study_options = list(antibody_sources["study"].unique())
study_selection = alt.selection_point(
    fields=["study"],
    bind=alt.binding_select(
        labels=["any", *[config["studies"][s] for s in study_options]],
        options=[None, *study_options],
        name="study",
    ),
    **(
        {"value": [{"study": config["init_study"]}]}
        if config["init_study"] != "any"
        else {}
    ),
)

# distinct values across all antibodies' binds lists, computed once so the
# dropdown labels and options stay aligned
binds_options = list(antibody_binding.explode("binds")["binds"].unique())
binding_selection = alt.selection_point(
    fields=["binds"],
    bind=alt.binding_select(
        labels=["any", *binds_options],
        options=[None, *binds_options],
        name="antibody known to bind",
    ),
    **(
        {"value": [{"binds": config["init_binds"]}]}
        if config["init_binds"] != "any"
        else {}
    ),
)

# all selections/parameters, attached to each sub-chart below
params = [
    mut_selection,
    mut_escape_strength,
    ic50_weight,
    binding_selection,
    study_selection,
    source_selection,
    virus_selection,
]

Now make the base plot.
To keep the embedded data small, we add per-antibody properties via transform lookups rather than merging them into the escape data:

In [7]:
# Base chart over the per-site escape data. Per-antibody properties are kept in
# the small lookup tables and joined in here by the integer antibody code; the
# transforms are order-dependent (each filter needs its lookup + flatten first).
plot_base_no_source_selection = (
    alt.Chart(escape)
    # attach each antibody's lists of viruses and matching IC50 weights...
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_ic50s,
            key="antibody",
            fields=["virus", "ic50_weight"],
        ),
    )
    # ...expand them in parallel to one row per virus, then keep only the
    # virus chosen in the dropdown
    .transform_flatten(["virus", "ic50_weight"])
    .transform_filter(virus_selection)
    # attach each antibody's list of binds values, expand, and filter on it.
    # NOTE(review): an antibody with several binds entries gets one copy of
    # each row per entry; with the "any" option nothing removes those copies,
    # so downstream products/sums may count it more than once. Confirm each
    # antibody has a single binds entry, or that this is intended.
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_binding,
            key="antibody",
            fields=["binds"],
        ),
    )
    .transform_flatten(["binds"])
    .transform_filter(binding_selection)
    # attach source and study (one value per antibody, so no flatten needed)
    .transform_lookup(
        lookup="antibody",
        from_=alt.LookupData(
            data=antibody_sources,
            key="antibody",
            fields=["source", "study"],
        ),
    )
    .transform_filter(study_selection)
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        # binding retained at this site: 1 - escape if the site is among the
        # clicked ("mut") sites, otherwise 1
        site_binding_retained="1 - if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)",
        # antibody weight: its IC50 weight when weighting is on, else 1
        weight=alt.expr.if_(ic50_weight == 1, alt.datum["ic50_weight"], 1),
    )
)

# Same chart additionally filtered to the selected sources. The unfiltered
# version above is used by the source bar plot so every source stays visible.
plot_base = plot_base_no_source_selection.transform_filter(source_selection)

Make a plot of the antibody sources, which both shows how many antibodies of each type come from each source and serves as a clickable legend:

In [8]:
# Horizontal bars counting antibodies per source, split by whether they
# neutralize the selected virus. Built on the chart that is NOT filtered by
# source so every source stays visible; clicking a bar changes
# `source_selection`, making the plot double as the source legend.
source_selection_barplot = (
    plot_base_no_source_selection
    # one row per antibody carrying its mean IC50 weight
    .transform_aggregate(
        groupby=["antibody", "source"],
        mean_ic50_weight="mean(ic50_weight)",
    )
    # a positive weight corresponds to an IC50 below 10
    .transform_calculate(
        neutralizing=alt.expr.if_(
            alt.datum["mean_ic50_weight"] > 0,
            "neutralizes virus",
            "does not neutralize virus",
        )
    )
    # count antibodies for each (source, neutralizing) pair
    .transform_aggregate(
        groupby=["source", "neutralizing"],
        n_antibodies="distinct(antibody)",
    )
    .mark_bar()
    .properties(
        height=alt.Step(10),
        width=225,
        title="click bars to select antibody sources",
    )
    .add_params(*params)
    .encode(
        y=alt.Y(
            "source:N",
            title=None,
            scale=alt.Scale(domain=source_list),
        ),
        x=alt.X(
            "n_antibodies:Q",
            title="number of antibodies",
            axis=alt.Axis(grid=False),
        ),
        color=alt.Color(
            "neutralizing:N",
            legend=alt.Legend(orient="right", labelFontSize=12, title=None, offset=5),
            scale=alt.Scale(
                domain=["neutralizes virus", "does not neutralize virus"],
                range=["#CC79A7", "#009E73"],
            ),
        ),
        # unselected sources are faded rather than hidden
        opacity=alt.condition(source_selection, alt.value(1), alt.value(0.2)),
        order=alt.Order("neutralizing:N", sort="descending"),
        tooltip=["source:N", "n_antibodies:Q", "neutralizing:N"],
    )
)

Now build the bar plot of the fraction of antibodies still bound versus escaped:

In [9]:
# Single stacked bar: (weighted) fraction of antibodies still bound vs. escaped
# given the currently mutated sites.
frac_bound_bar = (
    plot_base
    # per antibody: binding retained after all mutated sites (product over
    # sites of site_binding_retained)
    .transform_aggregate(
        binding_retained="product(site_binding_retained)",
        groupby=["antibody", "weight"],
    )
    # apply the mutation escape strength exponent and the antibody weight
    .transform_calculate(
        binding_retained_exp=(
            alt.datum["weight"]
            * alt.expr.pow(alt.datum["binding_retained"], mut_escape_strength)
        ),
    )
    # collapse all antibodies into a single row of weighted totals
    .transform_aggregate(
        sum_binding_retained="sum(binding_retained_exp)",
        sum_weight="sum(weight)",
    )
    # Weighted mean retained binding. `escaped` reads `bound`, so it must stay
    # listed after it. NOTE(review): if every remaining antibody has weight 0
    # this divides by zero -- confirm that case cannot occur or is acceptable.
    .transform_calculate(
        bound=alt.datum["sum_binding_retained"] / alt.datum["sum_weight"],
        escaped=1 - alt.datum["bound"],
    )
    # reshape into two rows so both fractions stack within one bar
    .transform_fold(
        ["bound", "escaped"],
        ["binding state", "fraction of antibodies"],
    )
    .encode(
        x=alt.X("fraction of antibodies:Q", axis=alt.Axis(grid=False)),
        y=alt.value(1),
        # use the Fill channel class so the definition matches its channel
        fill=alt.Fill(
            "binding state:N",
            scale=alt.Scale(
                domain=["bound", "escaped"],
                range=["lightgray", "#56B4E9"],
                reverse=True,
            ),
            legend=alt.Legend(orient="bottom", labelFontSize=12, title=None, offset=6),
        ),
        order=alt.Order("binding state:N"),
        tooltip=[
            "binding state:N",
            alt.Tooltip("fraction of antibodies:Q", format=".2g"),
        ],
    )
    .mark_bar(stroke="black", size=20)
    .add_params(*params)
    .properties(width=275, height=10, title="total neutralization or binding remaining")
)

Now make the line plot:

In [10]:
# Shared base for the per-site escape line and point layers: for each site it
# computes the mean weighted escape without mutations ("unmutated") and after
# applying the selected mutations ("mutated").
escape_mut_base = (
    plot_base
    .encode(
        x=alt.X(
            "site:Q",
            axis=alt.Axis(grid=False),
            scale=alt.Scale(zero=False, nice=False),
        ),
        y=alt.Y(
            "mean_escape_value:Q",
            axis=alt.Axis(
                grid=False,
                title="escape (arbitrary units)",
                labels=False,
                ticks=False,
            ),
        ),
    )
    # per antibody: binding retained after all mutated sites, attached to
    # every row of that antibody
    .transform_joinaggregate(
        binding_retained="product(site_binding_retained)",
        groupby=["antibody", "weight"],
    )
    .transform_calculate(
        escape_weighted=alt.datum["weight"] * alt.datum["escape"],
        # escape this antibody still contributes once the mutations are made
        escape_after_mut=(
            alt.expr.pow(alt.datum["binding_retained"], mut_escape_strength)
            * alt.datum["escape_weighted"]
        ),
    )
    # we don't actually have the correct denominator here, but it should
    # just affect relative scale of escape values
    .transform_joinaggregate(n_antibodies="distinct(antibody)")
    .transform_aggregate(
        sum_mutated="sum(escape_after_mut)",
        sum_unmutated="sum(escape_weighted)",
        n_antibodies="mean(n_antibodies)",
        groupby=["site"],
    )
    .transform_calculate(
        mutated=alt.datum["sum_mutated"] / alt.datum["n_antibodies"],
        unmutated=alt.datum["sum_unmutated"] / alt.datum["n_antibodies"],
    )
    # long format: one row per (site, escape_type)
    .transform_fold(
        ["unmutated", "mutated"],
        ["escape_type", "mean_escape_value"],
    )
    # sites absent after filtering get zero so each series covers every site
    .transform_impute(
        impute="mean_escape_value",
        key="site",
        value=0,
        groupby=["escape_type"],
        keyvals=sites,
    )
    # Points of the "mutated" series at clicked sites get their own
    # "mutated site" category; the conditions are combined with logical &&.
    .transform_calculate(
        color_val=(
            'if((indexof(mut.site, datum.site) >= 0) '
            + '&& (datum.escape_type == "mutated"), '
            + '"mutated site", datum.escape_type)'
        ),
    )
    .properties(
        width=800,
        height=225,
        title="escape at each site (click sites to mutate)",
    )
)

# One scale per visual channel, all keyed on the same three escape categories
# (a fresh domain list is built for each scale).
mut_escape_color_scale, mut_escape_point_size_scale, mut_escape_opacity_scale = (
    alt.Scale(domain=["unmutated", "mutated", "mutated site"], range=channel_range)
    for channel_range in (
        ["#999999", "#56B4E9", "#E69F00"],  # color
        [30, 60, 100],  # point size
        [0.5, 0.7, 1],  # opacity
    )
)

# Lines tracing the per-site means, one series per escape type, with color and
# opacity taken from the shared escape-type scales.
escape_mut_lines = escape_mut_base.mark_line().encode(
    opacity=alt.Opacity(
        "escape_type:N",
        legend=None,
        scale=mut_escape_opacity_scale,
    ),
    color=alt.Color("escape_type:N", scale=mut_escape_color_scale),
)

# Filled points at every site. Color, size and opacity come from `color_val`,
# so clicked sites in the "mutated" series stand out; this layer carries the
# selections, which is where clicking a site to mutate happens.
escape_mut_points = (
    escape_mut_base
    .mark_point(filled=True)
    .add_params(*params)
    .encode(
        size=alt.Size("color_val:N", scale=mut_escape_point_size_scale),
        opacity=alt.Opacity(
            "color_val:N",
            legend=None,
            scale=mut_escape_opacity_scale,
        ),
        color=alt.Color(
            "color_val:N",
            scale=mut_escape_color_scale,
            legend=alt.Legend(
                title=None,
                orient="bottom",
                labelFontSize=12,
                offset=10,
                # rename the raw category values to human-readable labels
                labelExpr=(
                    'if(datum.value == "unmutated", '
                    + '   "escape when no mutations", '
                    + '   if(datum.value == "mutated", '
                    + '      "escape with mutations", '
                    + '      "mutated site"))'
                ),
            ),
        ),
        tooltip=[
            "site:O",
            alt.Tooltip("mutated:Q", format=".2g"),
            alt.Tooltip("unmutated:Q", format=".2g"),
        ],
    )
)

# Full dashboard: the per-site escape plot (lines layered with points) stacked
# above a centered row holding the total-escape bar and the source selector.
escape_chart = (
    alt.vconcat(
        alt.layer(escape_mut_lines, escape_mut_points),
        alt.hconcat(
            frac_bound_bar,
            source_selection_barplot,
            center=True,
        ),
    )
    .configure_view(strokeOpacity=0)
    .resolve_scale(color="independent")
    .resolve_legend("independent")
)

escape_chart

Save the interactive chart to an HTML file:

In [11]:
# Write the interactive chart to an HTML file in results/.
escape_chart.save("results/escape_chart.html")