# Plot yeast RBD DMS escape calculator

## Import modules and read data
Import Python modules:

In [1]:
import itertools
import os

import altair as alt

import numpy

import pandas as pd

Disable Altair's maximum-rows limit so that the full data set can be embedded in the charts:

In [2]:
# Lift Altair's default cap on the number of rows a chart may embed. Assigning the
# return value to `_` keeps it from being echoed as the cell's output.
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data and reduce it to site-level data:

In [3]:
# Raw escape data as read from disk.
escape_data_raw = pd.read_csv('./processed_data/escape_data.csv', low_memory=False)

# Columns that make up the site-level table used in the rest of the notebook.
site_level_columns = [
    'condition', 'eliciting_virus', 'known_to_neutralize', "IC50s", 'study', 'site', "escape",
]

# Keep only antibody conditions, using the site-level total escape as "escape".
# Entries with 0 escape are dropped here because they are imputed later on.
antibody_site_escape = (
    escape_data_raw
    .rename(columns={"site_total_escape": "escape"})
    .query("condition_type == 'antibody'")
    .query("escape != 0")
)

# Selecting only the site-level columns leaves repeated rows, so collapse them.
dms_data = antibody_site_escape[site_level_columns].drop_duplicates()

def _n_condition_study_groups(df):
    """Return the number of distinct (condition, study) groups in `df`."""
    return len(df.groupby(['condition', 'study']))


print(f"Before de-duplicating we have {_n_condition_study_groups(dms_data)} conditions")

# When a condition appears in several studies, keep the study whose
# `known_to_neutralize` entry lists the most items (counted as number of ";" plus one).
# Sorting ascending and taking the last study per condition selects that study.
conditions_by_detail = (
    dms_data
    .assign(n_known_to_neutralize=lambda x: x["known_to_neutralize"].str.count(";") + 1)
    .sort_values("n_known_to_neutralize")
)
study_to_keep = (
    conditions_by_detail
    .groupby("condition", as_index=False)
    .aggregate({"study": "last"})
)

# Merging on the shared columns (condition, study) retains only rows from the chosen study.
dms_data = study_to_keep.merge(dms_data)

print(f"After de-duplicating we have {_n_condition_study_groups(dms_data)} conditions")

# Every integer site from the smallest to the largest site present in the data (inclusive).
first_site = dms_data['site'].min()
last_site = dms_data['site'].max()
sites = list(range(first_site, last_site + 1))


def _parse_ic50s(ic50_str):
    """Split a ';'-separated IC50 string into a tuple of floats, mapping "NA" to `pd.NA`."""
    return tuple(pd.NA if token == "NA" else float(token) for token in ic50_str.split(";"))


# Turn the ';'-separated string columns into tuples and drop `study`, which is
# not used any further.
dms_data = dms_data.assign(
    known_to_neutralize=dms_data["known_to_neutralize"].str.split(";").map(tuple),
    IC50s=dms_data["IC50s"].map(_parse_ic50s),
    eliciting_virus=dms_data["eliciting_virus"].str.split(";").map(tuple),
).drop(columns="study")

# The IC50s are next converted to a negative log IC50, scaled so that an IC50 of
# 10 counts as not neutralized, since that was the cutoff in the Cao et al data.
# Before doing so, check to make sure the largest IC50 is indeed ~10.
ic50_ceil = 10
all_ic50s = dms_data["IC50s"].explode()
max_ic50 = max(all_ic50s[all_ic50s.notnull()])
assert max_ic50 <= ic50_ceil
assert numpy.allclose(max_ic50, ic50_ceil, rtol=1e-3)

def _neg_log_ic50s(ic50s):
    """Map a tuple of IC50s to negative log IC50s normalized by `ic50_ceil`.

    Null entries become 0, the same value produced by an IC50 equal to the ceiling.
    """
    return tuple(
        0 if pd.isnull(ic50) else -numpy.log(float(ic50) / ic50_ceil)
        for ic50 in ic50s
    )


# Replace the raw IC50 tuples with their normalized negative logs.
dms_data = (
    dms_data
    .assign(neg_log_IC50=dms_data["IC50s"].map(_neg_log_ic50s))
    .drop(columns="IC50s")
)

# Each row must have exactly one neg_log_IC50 value per `known_to_neutralize` entry.
n_known_per_row = dms_data["known_to_neutralize"].map(len)
n_values_per_row = dms_data["neg_log_IC50"].map(len)
assert all(n_known_per_row == n_values_per_row)

# All values must be non-negative, and the smallest must be exactly 0.
neg_log_ic50s = dms_data["neg_log_IC50"].explode().sort_values()
assert min(neg_log_ic50s) == 0
assert all(neg_log_ic50s >= 0)

dms_data

Before de-duplicating we have 1736 conditions
After de-duplicating we have 1535 conditions


Unnamed: 0,condition,eliciting_virus,known_to_neutralize,site,escape,neg_log_IC50
0,1-57,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",338,0.05792,"(7.338538195074591, 7.338538195074591)"
1,1-57,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",359,0.01558,"(7.338538195074591, 7.338538195074591)"
2,1-57,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",370,0.03169,"(7.338538195074591, 7.338538195074591)"
3,1-57,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",394,0.01253,"(7.338538195074591, 7.338538195074591)"
4,1-57,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",396,0.02160,"(7.338538195074591, 7.338538195074591)"
...,...,...,...,...,...,...
32850,XGv-422,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",465,0.12270,"(3.233989462678249, 0.5344354894051243, 3.2754..."
32851,XGv-422,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",466,0.14140,"(3.233989462678249, 0.5344354894051243, 3.2754..."
32852,XGv-422,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",468,0.05000,"(3.233989462678249, 0.5344354894051243, 3.2754..."
32853,XGv-422,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",484,0.02634,"(3.233989462678249, 0.5344354894051243, 3.2754..."


## Make an "escape calculator" plot
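Give each condition an integer encoding and keep the condition-level data in a separate lookup table. The plot joins this table back in with `transform_lookup`, which keeps the data set embedded in the chart smaller: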

In [4]:
# Condition-level lookup table: one row per distinct combination of the columns below,
# numbered consecutively from 0 in the `encoding` column.
condition_level_columns = ["eliciting_virus", "known_to_neutralize", "condition"]
condition_table = (
    dms_data[condition_level_columns]
    .drop_duplicates()
    .reset_index(drop=True)
)
encoding = condition_table.assign(encoding=condition_table.index)

# Site-level table that carries only the integer encoding instead of the
# condition-level columns; the merge joins on the columns the two frames share.
dms_data_encoded = pd.merge(dms_data, encoding)[["encoding", "site", "escape"]]
assert not dms_data_encoded.duplicated().any()

# The condition name is not kept in the lookup table.
encoding = encoding.drop(columns="condition")

display(encoding, dms_data_encoded)

Unnamed: 0,eliciting_virus,known_to_neutralize,encoding
0,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",0
1,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",1
2,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",2
3,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, Omicron BA.2.12.1, any)",3
4,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, any)",4
...,...,...,...
1530,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, any)",1530
1531,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",1531
1532,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, Omicron BA.1, Omicron BA.2, Omicr...",1532
1533,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(Wuhan-Hu-1, SARS-CoV-1, any)",1533


Unnamed: 0,encoding,site,escape
0,0,338,0.05792
1,0,359,0.01558
2,0,370,0.03169
3,0,394,0.01253
4,0,396,0.02160
...,...,...,...
32850,1534,465,0.12270
32851,1534,466,0.14140
32852,1534,468,0.05000
32853,1534,484,0.02634


Now make the bar plot showing the fraction of antibodies that remain bound:

In [5]:
# Dropdown options: the sorted distinct entries found across all `known_to_neutralize` tuples.
known_to_neutralize_options = (
    sorted(encoding.explode("known_to_neutralize")["known_to_neutralize"].unique())
)
# Selection on the `known_to_neutralize` field, bound to a dropdown and initialised to 'any'.
known_to_neutralize_selection = alt.selection_point(
    fields=['known_to_neutralize'],
    bind=alt.binding_select(
        options=known_to_neutralize_options,
        labels=known_to_neutralize_options,
        name="known to neutralize",
   ),
    value=[{'known_to_neutralize': 'any'}]
)

# Selection of mutated sites. The name 'mut' matters: the Vega expression strings in the
# chart transforms refer to the selected sites as `mut.site`.
# NOTE(review): the initial site of -1 is presumably a placeholder that matches no real
# site -- confirm.
mut_selection = alt.selection_point(name='mut',
                                    fields=['site'],
                                    value=[{'site': -1}],
                                    empty=True,
                                    )

# Slider from 1 to 10 (initially 2). Its value is read in the Vega expression strings as
# `mut_escape_strength.mutation_escape_strength`, so both names must stay in sync with them.
mut_escape_strength_slider = alt.binding_range(min=1, max=10,
                                               name='mutation_escape_strength')
mut_escape_strength_selection = alt.selection_point(name='mut_escape_strength',
                                                    fields=['mutation_escape_strength'],
                                                    bind=mut_escape_strength_slider,
                                                    value=[{'mutation_escape_strength': 2}])

# Dropdown options: the sorted distinct entries found across all `eliciting_virus` tuples.
eliciting_viruses = sorted(encoding.explode("eliciting_virus")["eliciting_virus"].unique())
# Selection on the `eliciting_virus` field, bound to a dropdown and initialised to 'SARS-CoV-2'.
eliciting_virus_selection = alt.selection_point(
    fields=['eliciting_virus'],
    bind=alt.binding_select(
        options=eliciting_viruses,
        labels=eliciting_viruses,
        name="eliciting virus",
    ),
    value=[{'eliciting_virus': 'SARS-CoV-2'}]
)

# Transform pipeline shared by the bar chart and the line/point chart. It starts from the
# compact (encoding, site, escape) table and pulls the condition-level fields back in by
# lookup, applies the two dropdown filters, and computes per-site binding retained.
# The order of the transforms is significant.
plot_base = (
    alt.Chart(dms_data_encoded)
    # attach each condition's `known_to_neutralize` tuple via its integer encoding
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=["known_to_neutralize"])
    )
    # one row per tuple entry, so the dropdown selection can filter on a single value
    .transform_flatten(["known_to_neutralize"])
    .transform_filter(known_to_neutralize_selection)
    # collapse back to one row per (encoding, site); only the groupby fields and `escape`
    # are carried forward (presumably each (encoding, site) has a single escape value, so
    # the mean only de-duplicates -- TODO confirm)
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["encoding", "site"],
    )
    # same lookup / flatten / filter / collapse sequence for `eliciting_virus`
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=["eliciting_virus"])
    )
    .transform_flatten(["eliciting_virus"])
    .transform_filter(eliciting_virus_selection)
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["encoding", "site"],
    )
    .transform_joinaggregate(
        # get maximum escape across any site for this condition
        condition_escape_max='max(escape)',
        groupby=['encoding'],
    )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        # Binding retained at this site: 1 when the site is not in the `mut` selection,
        # otherwise 1 - escape / condition_escape_max.
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
    )
)

# Single stacked bar showing the mean fraction of antibodies (conditions) still bound
# versus escaped, given the currently selected mutated sites.
frac_bound_bar = (
    plot_base
    # per condition: multiply the per-site binding retained over all of its sites
    .transform_aggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding'],
        )
    # raise to the power chosen with the mutation-escape-strength slider
    .transform_calculate(
        binding_retained_exp='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength)'
        )
    # no groupby: average over all conditions that passed the filters, giving a single row
    .transform_aggregate(
        mean_binding_retained='mean(binding_retained_exp)',
        )
    .transform_calculate(
        bound='datum.mean_binding_retained',
        escaped='1 - datum.bound',
        )
    # long form: one row per binding state, with its fraction
    .transform_fold(
        ['bound', 'escaped'],
        ['binding state', 'fraction of antibodies']
        )
    .encode(x=alt.X('fraction of antibodies:Q',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.value(1),
            fill=alt.Color('binding state:N',
                            scale=alt.Scale(
                                domain=['bound', 'escaped'],
                                range=['lightgray', '#56B4E9'],
                                reverse=True,
                                ),
                            ),
            order=alt.Order('binding state:N'),
            tooltip=['binding state:N',
                     alt.Tooltip('fraction of antibodies:Q',
                                 format='.2g'),]
            )
    .mark_bar(stroke='black',
              size=20)
    # register all four selections on this chart so it can also be displayed on its own
    .add_parameter(
        mut_selection,
        mut_escape_strength_selection,
        known_to_neutralize_selection,
        eliciting_virus_selection,
    )
    .properties(width=300, height=10)
    )

frac_bound_bar

Now make the line plot:

In [6]:
# Base chart for the per-site escape plot: for every site, the escape averaged over
# conditions both without mutations ("unmutated") and after applying the selected
# mutations ("mutated"). The line and point layers are both derived from it.
escape_mut_base = (
    plot_base
    .encode(x=alt.X('site:Q',
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(zero=False, nice=False),
                    ),
            y=alt.Y('mean_escape_value:Q',
                    axis=alt.Axis(grid=False, title='escape'),
                    ),
            )
    # per-condition product of the per-site binding retained, added to every row
    # (joinaggregate keeps the per-site rows, which are still needed below)
    .transform_joinaggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding'],
        )
    # down-weight each condition's escape by its retained binding raised to the slider value
    .transform_calculate(
        escape_after_mut='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength) * datum.escape'
        )
    # number of distinct conditions among the rows remaining at this point, on every row
    .transform_joinaggregate(n_conditions="distinct(encoding)")
    # per-site sums; n_conditions is the same on every row, so its mean just carries it along
    .transform_aggregate(
        sum_mutated='sum(escape_after_mut)',
        sum_unmutated='sum(escape)',
        n_conditions="mean(n_conditions)",
        groupby=['site'],
        )
    # divide by the total number of conditions rather than the number with a row at the
    # site, so conditions without a row at a site contribute zero to that site's mean
    .transform_calculate(
        mutated="datum.sum_mutated / datum.n_conditions",
        unmutated="datum.sum_unmutated / datum.n_conditions",
    )
    # long form: one row per (site, escape_type)
    .transform_fold(['unmutated', 'mutated'],
                    ['escape_type', 'mean_escape_value'])
    # sites in `sites` that have no row get a value of 0 for each escape type
    .transform_impute(
        impute="mean_escape_value",
        key="site",
        value=0,
        groupby=["escape_type"],
        keyvals=sites,
    )
    # label the "mutated" rows at selected sites as "mutated site" so the points can
    # be styled differently there
    .transform_calculate(
        color_val='if((indexof(mut.site, datum.site) >= 0) & (datum.escape_type == "mutated"), '
                  '"mutated site", datum.escape_type)'
        )
    .properties(width=800, height=225)
    )

# Scales shared by the line and point layers. All three use the same category domain,
# including the "mutated site" category that only `color_val` can take.
mut_escape_color_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=['#999999', '#56B4E9', '#D55E00']
        )
mut_escape_point_size_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[30, 60, 100],
        )
mut_escape_opacity_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[0.5, 0.7, 1],
        )

# Line layer: one line per escape type ("unmutated" / "mutated"), styled by `escape_type`.
escape_mut_lines = (
    escape_mut_base
    .encode(color=alt.Color('escape_type:N',
                            scale=mut_escape_color_scale,
                            ),
            opacity=alt.Opacity('escape_type:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            )
    .mark_line()
    )

# Point layer: styled by `color_val`, so points at selected sites get the "mutated site"
# color, size and opacity. This layer carries the legend (with relabelled entries), the
# tooltips, and the selections.
escape_mut_points = (
    escape_mut_base
    .encode(color=alt.Color(
                    'color_val:N',
                    scale=mut_escape_color_scale,
                    legend=alt.Legend(
                            title=None,
                            labelExpr='if(datum.value == "unmutated", '
                                      '   "escape when no mutations", '
                                      '   if(datum.value == "mutated", '
                                      '      "escape with mutations", '
                                      '      "mutated site"))'
                            ),
                    ),
            opacity=alt.Opacity('color_val:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            size=alt.Size('color_val:N',
                          scale=mut_escape_point_size_scale,
                          ),
            tooltip=['site:O',
                     alt.Tooltip('mutated:Q',
                                 format='.2g'),
                     alt.Tooltip('unmutated:Q',
                                 format='.2g'),
                     ],
            )
    .mark_point(filled=True)
    .add_parameter(mut_selection,
                   mut_escape_strength_selection,
                   known_to_neutralize_selection,
                   eliciting_virus_selection,
                   )
    )

# Layer the lines and points in one panel, then stack the fraction-bound bar below it.
escape_site_panel = escape_mut_lines + escape_mut_points
escape_panels = escape_site_panel & frac_bound_bar

# Legend settings applied to the whole compound chart.
legend_config = dict(orient='bottom', labelFontSize=12, title=None)

# NOTE(review): 'independent' is passed positionally to `resolve_legend`; confirm which
# legend channel that binds to in the installed Altair version.
escape_chart = (
    escape_panels
    .configure_view(strokeOpacity=0)
    .configure_legend(**legend_config)
    .resolve_legend('independent')
)

# Save the chart as HTML, creating the output directory if it does not exist yet.
escape_calc_chartfile = 'docs/_includes/escape_calc_chart.html'
chart_dir = os.path.dirname(escape_calc_chartfile)
os.makedirs(chart_dir, exist_ok=True)
print(f"Saving chart to {escape_calc_chartfile}")
escape_chart.save(escape_calc_chartfile)

escape_chart

Saving chart to docs/_includes/escape_calc_chart.html


Write the escape calculator data to a file:

In [7]:
escape_calc_data_file = 'processed_data/escape_calculator_data.csv'
data_dir = os.path.dirname(escape_calc_data_file)
os.makedirs(data_dir, exist_ok=True)

print(f"Writing escape calculator data to {escape_calc_data_file}")


def _format_neg_log_ic50s(values):
    """Join values into a ';'-separated string, each formatted to 5 significant digits."""
    return ";".join(f"{value:.5g}" for value in values)


# Turn the tuple-valued columns back into ';'-separated strings before writing the CSV.
escape_calc_table = dms_data.assign(
    known_to_neutralize=dms_data["known_to_neutralize"].map(";".join),
    neg_log_IC50=dms_data["neg_log_IC50"].map(_format_neg_log_ic50s),
    eliciting_virus=dms_data["eliciting_virus"].map(";".join),
)
escape_calc_table.to_csv(escape_calc_data_file, index=False)

Writing escape calculator data to processed_data/escape_calculator_data.csv
