# Plot yeast RBD DMS escape calculator

## Import modules and read data
Import Python modules:

In [1]:
import itertools
import os

import altair as alt

import numpy

import pandas as pd

Disable max rows specifier for Altair:

In [2]:
# Turn off Altair's max-rows check so the larger data frames used below can be plotted.
# The return value is bound to `_` so that the cell does not echo it as output.
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data and reduce it to site-level data:

In [3]:
# Load the escape table, blank out missing condition aliases, and shorten the
# eliciting-virus column name to `virus`.
dms_data = pd.read_csv('./processed_data/escape_data.csv', low_memory=False)
dms_data = dms_data.assign(condition_alias=dms_data['condition_alias'].fillna(''))
dms_data = dms_data.rename(columns={'eliciting_virus': 'virus'})

# Keep only antibody conditions and the site-level columns. Rows with zero escape
# are dropped here because zeros are imputed later when plotting.
site_level_cols = [
    'condition', 'virus', 'known_to_neutralize', 'study', 'lab', 'site', "site_mean_escape",
]
is_antibody = dms_data["condition_type"] == 'antibody'
dms_data = (
    dms_data
    .loc[is_antibody, site_level_cols]
    .rename(columns={"site_mean_escape": "escape"})
    .drop_duplicates()
)
dms_data = dms_data.loc[dms_data["escape"] != 0]

# A condition may be reported by several studies; for each condition keep the study
# whose `known_to_neutralize` entry lists the most ";"-separated items (it sorts last).
n_conditions_before = len(dms_data.groupby(['condition', 'study']))
print(f"Before de-duplicating we have {n_conditions_before} conditions")
study_to_keep = (
    dms_data
    .assign(n_known_to_neutralize=dms_data["known_to_neutralize"].str.count(";") + 1)
    .sort_values("n_known_to_neutralize")
    .groupby("condition", as_index=False)
    .agg({"study": "last"})
)
# Merging on the shared columns (condition, study) keeps only rows from the chosen study.
dms_data = study_to_keep.merge(dms_data)
n_conditions_after = len(dms_data.groupby(['condition', 'study']))
print(f"After de-duplicating we have {n_conditions_after} conditions")

# Every integer site from the smallest to the largest observed one (inclusive).
first_site = dms_data['site'].min()
last_site = dms_data['site'].max()
sites = list(range(first_site, last_site + 1))

# Turn the ";"-delimited `known_to_neutralize` string into a tuple and drop the
# columns that are not needed from here on.
known_to_neutralize_tuples = dms_data["known_to_neutralize"].str.split(";").map(tuple)
dms_data = (
    dms_data
    .assign(known_to_neutralize=known_to_neutralize_tuples)
    .drop(columns=["lab", "study"])
)

dms_data

Before de-duplicating we have 1736 conditions
After de-duplicating we have 1535 conditions


Unnamed: 0,condition,virus,known_to_neutralize,site,escape
0,1-57,SARS-CoV-2,"(Wuhan-Hu-1,)",338,0.06084
1,1-57,SARS-CoV-2,"(Wuhan-Hu-1,)",359,0.01637
2,1-57,SARS-CoV-2,"(Wuhan-Hu-1,)",370,0.03328
3,1-57,SARS-CoV-2,"(Wuhan-Hu-1,)",394,0.01316
4,1-57,SARS-CoV-2,"(Wuhan-Hu-1,)",396,0.02269
...,...,...,...,...,...
33508,XGv-422,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",465,0.22970
33509,XGv-422,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",466,0.26460
33510,XGv-422,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",468,0.09359
33511,XGv-422,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",484,0.04930


## Make an "escape calculator" plot
Encode the condition-level data as integer codes, and then use `transform_lookup` when making the plot so that the data set embedded in the chart is smaller:

In [6]:
# Columns that describe a condition (everything except the per-site measurement);
# `lookup_cols` are the ones that get joined back onto the chart data by encoding.
encoded_cols = [col for col in dms_data.columns if col not in ("site", "escape")]
lookup_cols = [col for col in encoded_cols if col != "condition"]

# One row per distinct combination of the condition-level columns, numbered 0, 1, 2, ...
encoding = dms_data[encoded_cols].drop_duplicates().reset_index(drop=True)
encoding["encoding"] = encoding.index

# Swap the condition-level columns in the site data for the compact integer code.
dms_data_encoded = pd.merge(dms_data, encoding)[["encoding", "site", "escape"]]
assert not dms_data_encoded.duplicated().any()

# The condition name itself is not needed in the lookup table.
encoding = encoding.drop(columns="condition")

display(encoding, dms_data_encoded)

Unnamed: 0,virus,known_to_neutralize,encoding
0,SARS-CoV-2,"(Wuhan-Hu-1,)",0
1,SARS-CoV-2,"(Wuhan-Hu-1,)",1
2,SARS-CoV-2,"(Wuhan-Hu-1,)",2
3,SARS-CoV-2,"(Wuhan-Hu-1, Omicron BA.12.1)",3
4,SARS-CoV-2,"(Wuhan-Hu-1,)",4
...,...,...,...
1530,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1)",1530
1531,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1, Omicron BA.1, Omicron...",1531
1532,SARS-CoV-2,"(Wuhan-Hu-1, Omicron BA.1, Omicron BA.2, Omicr...",1532
1533,SARS-CoV-2,"(Wuhan-Hu-1, SARS-CoV-1)",1533


Unnamed: 0,encoding,site,escape
0,0,338,0.06084
1,0,359,0.01637
2,0,370,0.03328
3,0,394,0.01316
4,0,396,0.02269
...,...,...,...
33508,1534,465,0.22970
33509,1534,466,0.26460
33510,1534,468,0.09359
33511,1534,484,0.04930


Now make the bar plot showing the fraction of antibodies that remain bound:

In [7]:
# --- interactive controls ---

# Dropdown restricting the data to antibodies known to neutralize one virus/variant.
# The `None` option (labelled "all") leaves the selection empty, so the filter keeps every row.
known_to_neutralize_options = (
    sorted(encoding.explode("known_to_neutralize")["known_to_neutralize"].unique())
)
known_to_neutralize_selection = alt.selection_point(
    fields=['known_to_neutralize'],
    bind=alt.binding_select(
        options=[None, *known_to_neutralize_options],
        labels=["all", *map(str, known_to_neutralize_options)],
        name="known_to_neutralize",
   ),
)

# Point selection on `site` holding the sites treated as mutated; it is referenced as
# `mut.site` in the expressions below. Initialised to site -1 (presumably a placeholder
# that matches no real site -- TODO confirm).
mut_selection = alt.selection_point(name='mut',
                                    fields=['site'],
                                    value=[{'site': -1}],
                                    empty=True,
                                    )

# Slider (1-10, initially 2) giving the exponent applied to the retained binding.
mut_escape_strength_slider = alt.binding_range(min=1, max=10,
                                               name='mutation_escape_strength')
mut_escape_strength_selection = alt.selection_point(name='mut_escape_strength',
                                                    fields=['mutation_escape_strength'],
                                                    bind=mut_escape_strength_slider,
                                                    value=[{'mutation_escape_strength': 2}])

# Dropdown for the virus that elicited the antibodies; initially 'SARS-CoV-2'.
eliciting_viruses = encoding["virus"].unique().tolist()
eliciting_virus_dropdown = alt.binding_select(
            options=[None] + eliciting_viruses,
            labels=['all'] + eliciting_viruses,
            )
eliciting_virus_selection = alt.selection_point(
                                    fields=['virus'],
                                    bind=eliciting_virus_dropdown,
                                    name='eliciting',
                                    value=[{'virus': 'SARS-CoV-2'}]
                                    )

# --- stacked bar: mean fraction of antibodies bound vs escaped ---
frac_bound_bar = (
    alt.Chart(dms_data_encoded)
    # zero-escape rows were dropped from the data, so fill them back in for every site
    .transform_impute(
        impute="escape",
        key="site",
        value=0,
        groupby=["encoding"],
        keyvals=sites,
    )
    # re-attach the condition-level columns via the integer encoding
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=lookup_cols)
    )
    # one row per known_to_neutralize entry so the dropdown can filter on a single value
    .transform_flatten(["known_to_neutralize"])
    .transform_filter(eliciting_virus_selection)
    .transform_filter(known_to_neutralize_selection)
    # Flattening repeats each (encoding, site) row once per known_to_neutralize entry.
    # When the known_to_neutralize filter keeps more than one entry (the "all" option),
    # those repeats would otherwise be multiplied into the product below several times
    # and overstate escape. Collapse back to exactly one row per (encoding, site); the
    # repeats all carry the same escape value, so max() just picks that value.
    .transform_aggregate(
        escape='max(escape)',
        groupby=['encoding', 'site'],
        )
    # get maximum escape across any site for this condition
    .transform_joinaggregate(
        condition_escape_max='max(escape)',
        groupby=['encoding'],
        )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        # binding retained at this site: reduced by escape / max escape if the site is
        # in the `mut` selection, otherwise 1
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
        )
    # binding retained by each antibody = product over its sites
    .transform_aggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding'],
        )
    .transform_calculate(
        binding_retained_exp='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength)'
        )
    # average over all antibodies that passed the filters
    .transform_aggregate(
        mean_binding_retained='mean(binding_retained_exp)',
        )
    .transform_calculate(
        bound='datum.mean_binding_retained',
        escaped='1 - datum.bound',
        )
    .transform_fold(
        ['bound', 'escaped'],
        ['binding state', 'fraction of antibodies']
        )
    .encode(x=alt.X('fraction of antibodies:Q',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.value(1),
            fill=alt.Color('binding state:N',
                            scale=alt.Scale(
                                domain=['bound', 'escaped'],
                                range=['lightgray', '#56B4E9'],
                                reverse=True,
                                ),
                            ),
            order=alt.Order('binding state:N'),
            tooltip=['binding state:N',
                     alt.Tooltip('fraction of antibodies:Q',
                                 format='.2g'),]
            )
    .mark_bar(stroke='black',
              size=20)
    .add_parameter(
        mut_selection,
        mut_escape_strength_selection,
        known_to_neutralize_selection,
        eliciting_virus_selection,
    )
    .properties(width=300, height=10)
    )

frac_bound_bar

Now make the line plot:

In [8]:
# Shared base for the line and point layers: per-site mean escape across the selected
# antibodies, both without mutations ("unmutated") and after accounting for the sites
# in the `mut` selection ("mutated").
escape_mut_base = (
    alt.Chart(dms_data_encoded)
    # zero-escape rows were dropped from the data, so fill them back in for every site
    .transform_impute(
        impute="escape",
        key="site",
        value=0,
        groupby=["encoding"],
        keyvals=sites,
    )
    # re-attach the condition-level columns via the integer encoding
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=lookup_cols)
    )
    # one row per known_to_neutralize entry so the dropdown can filter on a single value
    .transform_flatten(["known_to_neutralize"])
    .encode(x=alt.X('site:Q',
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(zero=False, nice=False),
                    ),
            y=alt.Y('mean_escape_value:Q',
                    axis=alt.Axis(grid=False, title='escape'),
                    ),
            )
    .transform_filter(known_to_neutralize_selection)
    .transform_filter(eliciting_virus_selection)
    # Flattening repeats each (encoding, site) row once per known_to_neutralize entry.
    # When the known_to_neutralize filter keeps more than one entry (the "all" option),
    # those repeats would otherwise inflate the per-antibody product and weight the
    # per-site means by the number of entries. Collapse back to exactly one row per
    # (encoding, site); the repeats all carry the same escape value, so max() just
    # picks that value.
    .transform_aggregate(
        escape='max(escape)',
        groupby=['encoding', 'site'],
        )
    # get maximum escape across any site for this condition
    .transform_joinaggregate(
        condition_escape_max='max(escape)',
        groupby=['encoding'],
        )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        # binding retained at this site: reduced by escape / max escape if the site is
        # in the `mut` selection, otherwise 1
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
        )
    # joinaggregate (not aggregate) so the per-site rows are kept for the next step
    .transform_joinaggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding'],
        )
    .transform_calculate(
        escape_after_mut='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength) * datum.escape'
        )
    # average over antibodies at each site
    .transform_aggregate(
        mutated='mean(escape_after_mut):Q',
        unmutated='mean(escape):Q',
        groupby=['site'],
        )
    .transform_fold(['unmutated', 'mutated'],
                    ['escape_type', 'mean_escape_value'])
    # points at selected sites on the "mutated" series get their own category
    .transform_calculate(
        color_val='if((indexof(mut.site, datum.site) >= 0) & (datum.escape_type == "mutated"), '
                  '"mutated site", datum.escape_type)'
        )
    .properties(width=800, height=225)
    )

# Scales shared by the line and point layers so both use the same category styling.
mut_escape_color_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=['#999999', '#56B4E9', '#D55E00']
        )
mut_escape_point_size_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[30, 60, 100],
        )
mut_escape_opacity_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[0.5, 0.7, 1],
        )

escape_mut_lines = (
    escape_mut_base
    .encode(color=alt.Color('escape_type:N',
                            scale=mut_escape_color_scale,
                            ),
            opacity=alt.Opacity('escape_type:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            )
    .mark_line()
    )

# The point layer carries the selection parameters and the legend.
escape_mut_points = (
    escape_mut_base
    .encode(color=alt.Color(
                    'color_val:N',
                    scale=mut_escape_color_scale,
                    legend=alt.Legend(
                            title=None,
                            labelExpr='if(datum.value == "unmutated", '
                                      '   "escape when no mutations", '
                                      '   if(datum.value == "mutated", '
                                      '      "escape with mutations", '
                                      '      "mutated site"))'
                            ),
                    ),
            opacity=alt.Opacity('color_val:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            size=alt.Size('color_val:N',
                          scale=mut_escape_point_size_scale,
                          ),
            tooltip=['site:O',
                     alt.Tooltip('mutated:Q',
                                 format='.2g'),
                     alt.Tooltip('unmutated:Q',
                                 format='.2g'),
                     ],
            )
    .mark_point(filled=True)
    .add_parameter(mut_selection,
                   mut_escape_strength_selection,
                   known_to_neutralize_selection,
                   eliciting_virus_selection,
                   )
    )

# Layer lines and points, and stack the fraction-bound bar underneath.
escape_chart = (
    ((escape_mut_lines + escape_mut_points) & frac_bound_bar)
    .configure_view(strokeOpacity=0)
    .configure_legend(orient='bottom',
                      labelFontSize=12,
                      title=None)
    .resolve_legend('independent')
    )

# Save the chart as HTML, creating the output directory if needed.
escape_calc_chartfile = 'docs/_includes/_temp_escape_calc_chart.html'
os.makedirs(os.path.dirname(escape_calc_chartfile), exist_ok=True)
print(f"Saving chart to {escape_calc_chartfile}")
escape_chart.save(escape_calc_chartfile)

escape_chart

Saving chart to docs/_includes/_temp_escape_calc_chart.html


Write the escape calculator data to a file:

In [9]:
escape_calc_data_file = 'processed_data/escape_calculator_data.csv'
output_dir = os.path.dirname(escape_calc_data_file)
os.makedirs(output_dir, exist_ok=True)

print(f"Writing escape calculator data to {escape_calc_data_file}")

# `known_to_neutralize` holds tuples; join each back into a ";"-delimited string for the CSV.
dms_data_for_csv = dms_data.assign(
    known_to_neutralize=dms_data["known_to_neutralize"].map(";".join)
)
dms_data_for_csv.to_csv(escape_calc_data_file, index=False)

Writing escape calculator data to processed_data/escape_calculator_data.csv
