# Plot yeast RBD DMS escape maps

## Import modules and read data
Import Python modules:

In [None]:
import itertools
import os
import urllib.parse

import altair as alt

import numpy

import pandas as pd

import sklearn.manifold

Disable max rows specifier for Altair:

In [None]:
# Turn off Altair's limit on the number of data rows a chart may embed.
# The result is bound to `_`, presumably so the notebook cell shows no output.
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data, reduce to site-level data:

In [None]:
# Load the escape table, blank out missing aliases, and shorten two column names.
dms_data = pd.read_csv('./processed_data/escape_data.csv', low_memory=False)
dms_data = (
    dms_data
    .assign(condition_alias=dms_data['condition_alias'].fillna(''))
    .rename(columns={'eliciting_virus': 'virus',
                     'neutralizes_Omicron': 'Omicron'})
)

# Site-level metric columns, mapped to the labels used for them in the plots.
metric_cols = {
    'site_total_escape': 'sum of mutations at site',
    'site_mean_escape': 'mean of mutations at site',
}

# Keep only the condition descriptors, the site, and the site-level metrics,
# then collapse the rows that become identical once other columns are gone.
dms_data = dms_data.loc[
    :,
    ['condition', 'condition_alias', 'condition_type', 'condition_subtype',
     'virus', 'Omicron', 'study', 'lab', 'site', *metric_cols],
].drop_duplicates()

# Fill missing sites as 0.
#
# Every condition gets one row for each site in the observed site range.
# `merge_ordered` only carries the `left_by` columns over to the rows it adds,
# so every per-condition descriptor (including 'Omicron') must be listed there.
# Otherwise the added rows get NaN in that column and the `fillna(0)` below
# would overwrite the condition's real value with 0. With all descriptors in
# `left_by`, only the metric columns are NaN on added rows, and those become 0.
sites = list(range(dms_data['site'].min(), dms_data['site'].max() + 1))
# no nulls may exist beforehand, so every NaN after the merge is a filled-in site
assert dms_data.notnull().all().all()
dms_data = (pd.merge_ordered(dms_data,
                             pd.DataFrame({'site': sites}),
                             on='site',
                             left_by=['condition', 'condition_alias', 'condition_type', 'condition_subtype',
                                      'virus', 'Omicron', 'study', 'lab'],
                             )
            .fillna(0)
            )

# A condition name that occurs in more than one study is ambiguous, so for
# those conditions append the lab in parentheses to the name.
dms_data = dms_data.assign(
    condition=lambda x: x['condition'].where(
        x.groupby('condition')['study'].transform('nunique') == 1,
        x['condition'] + ' (' + x['lab'] + ')',
    )
)

# After renaming, each condition name must belong to exactly one study.
dup_conditions = (
    dms_data
    .groupby('condition', as_index=False)
    .agg(n_studies=('study', 'nunique'))
    .loc[lambda counts: counts['n_studies'] > 1]
)
if not dup_conditions.empty:
    raise ValueError('duplicate studies for some conditions:\n' + str(dup_conditions))

# Each (condition, site) pair must appear on exactly one row.
assert dms_data.groupby(['condition', 'site']).ngroups == len(dms_data)

dms_data

Make a tidy version of `dms_data` that is melted to have the two site metrics in one column, and gets rid of some columns we don't need for escape line plots:

In [None]:
# Long-format table: one row per condition, site and metric. The metric label
# goes in `metric` and its value in `escape`; three descriptor columns that the
# line plots do not use are dropped afterwards.
dms_data_tidy = pd.melt(
    dms_data.rename(columns=metric_cols),
    id_vars=[col for col in dms_data.columns if col not in metric_cols],
    value_vars=list(metric_cols.values()),
    var_name='metric',
    value_name='escape',
).drop(columns=['condition_type', 'condition_alias', 'study'])

dms_data_tidy

## Perform multidimensional scaling
Steps:
 1. Calculate similarities between escape maps for each antibody.
 2. Convert similarities to dissimilarities.
 3. Do multi-dimensional scaling on dissimilarities.


First, compute the dissimilarity between all pairs of escape profiles in a data frame.
We calculate similarity as the dot product of the escape profile site-level metric for each pair of conditions, normalizing each profile so its dot product with itself is one.
Then we compute the dissimilarity as just one minus the similarity:

In [None]:
def escape_similarity(df):
    """Compute similarity between all pairs of conditions in `df`.

    Each condition's escape profile (its `escape` value at every site, with
    sites absent for that condition treated as 0) is scaled to unit Euclidean
    length. The similarity of two conditions is the dot product of their
    scaled profiles, so every condition has similarity 1 with itself.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'condition', 'site' and 'escape'; any other
        columns are ignored.

    Returns
    -------
    pandas.DataFrame
        Square data frame whose index and columns are both the conditions,
        in order of first appearance in `df`.

    Raises
    ------
    AssertionError
        If the relevant columns contain nulls, or if scaling produces a null
        (for example when a condition's profile is all zeros).

    """
    df = df[['condition', 'site', 'escape']].drop_duplicates()
    assert not df.isnull().any().any(), df

    conditions = df['condition'].unique()
    pivoted_df = (
        df
        .pivot_table(index='site',
                     columns='condition',
                     values='escape',
                     fill_value=0)
        # for normalization: https://stackoverflow.com/a/58113206
        # to get norm: https://stackoverflow.com/a/47953601
        .transform(lambda x: x / numpy.linalg.norm(x, axis=0))
        )
    # make sure no sites have null values
    assert pivoted_df.notnull().all().all(), pivoted_df

    # Sites are rows and conditions are columns, so the matrix of all pairwise
    # dot products is a single product of the transposed profile matrix with
    # itself. Columns are first put in first-appearance order.
    profiles = pivoted_df[list(conditions)].to_numpy(dtype=float)
    return pd.DataFrame(profiles.T @ profiles,
                        columns=conditions, index=conditions)

# Pairwise similarity matrix computed separately for each site metric.
similarities = dms_data_tidy.groupby(['metric']).apply(escape_similarity)

# Dissimilarity is one minus similarity, floored at zero.
dissimilarities = similarities.rsub(1).clip(lower=0)

dissimilarities.round(3)

Now do the multidimensional scaling [as described here](https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html#sphx-glr-auto-examples-manifold-plot-mds-py) to get the x and y coordinates for each antibody / serum.
For each metric, we do this for several random number seeds (different seeds will give different MDS layouts):

In [None]:
mds_coords = []
seeds = [1, 2]
for seed in seeds:
    for metric, mat in dissimilarities.groupby('metric'):
        # embed the precomputed dissimilarities for this metric in two dimensions
        mds = sklearn.manifold.MDS(
            n_components=2,
            metric=True,
            max_iter=3000,
            eps=1e-6,
            random_state=seed,
            dissimilarity='precomputed',
            n_jobs=1,
        )
        locs = mds.fit_transform(mat)
        mds_coords.append(
            # shift each layout so its smallest x and smallest y are both zero
            pd.DataFrame(locs - locs.min(axis=0), columns=['x', 'y'])
            .assign(metric=metric,
                    seed=seed,
                    condition=mat.columns)
        )

# Stack all layouts and attach the per-condition descriptors to each point.
mds_coords = pd.merge(
    pd.concat(mds_coords, ignore_index=True),
    dms_data[['condition', 'condition_type', 'condition_subtype', 'study', 'lab',
              'virus', 'Omicron']].drop_duplicates(),
    on='condition',
    how='left',
    validate='many_to_one',
)
mds_coords

Get a data frame with just the conditions and their citations:

In [None]:
# One row per distinct condition, ordered by type, then subtype, then name.
conditions_df = dms_data.loc[
    :,
    ['condition_type', 'condition_subtype', 'condition', 'condition_alias',
     'virus', 'study', 'lab', 'Omicron'],
]
conditions_df = conditions_df.sort_values(
    by=['condition_type', 'condition_subtype', 'condition'],
)
conditions_df = conditions_df.drop_duplicates().reset_index(drop=True)

## Read information on studies and merge into conditions data frame

In [None]:
# Table of studies; it is joined to `conditions_df` on the 'study' column in
# the next cell.
studies = pd.read_csv('processed_data/studies.csv')

studies

In [None]:
conditions_df = pd.merge(
    # Drop 'citation' / 'url' if already present (presumably from an earlier
    # run of this cell) so the merge does not create suffixed duplicates.
    conditions_df.drop(columns=['citation', 'url'], errors='ignore'),
    studies,
    on='study',
    how='left',
    validate='many_to_one',  # each study may appear only once in `studies`
)

Add `dms-view` links:

In [None]:
# Template for the dms-view link; `{condition}` is replaced per condition.
dms_view_base_url = 'https://dms-view.github.io/?data-url=https%3A%2F%2Fmedia.githubusercontent.com%2Fmedia%2Fjbloomlab%2FSARS2_RBD_Ab_escape_maps%2Fmain%2Fprocessed_data%2Fescape_data.csv&condition={condition}&site_metric=site_total_escape&mutation_metric=mut_escape&selected_sites=&protein-data-color=&protein-other-color=pink&markdown-url=https%3A%2F%2Fraw.githubusercontent.com%2Fjbloomlab%2FSARS2_RBD_Ab_escape_maps%2Fmain%2Fdms-view%2Fmanifest.md&pdb-url=https%3A%2F%2Fraw.githubusercontent.com%2Fjbloomlab%2FSARS2_RBD_Ab_escape_maps%2Fmain%2Fdms-view%2F6M0J.pdb'

# The condition name becomes a query-string value, so percent-encode it.
# Unencoded, characters such as '+', '&' or '#' in a name would be read as
# URL syntax and the link would point at the wrong (or no) condition.
conditions_df['dms_view_url'] = conditions_df['condition'].map(
    lambda c: dms_view_base_url.format(condition=urllib.parse.quote(str(c), safe=''))
)

conditions_df.head()

## Make interactive plots
First make plot to select condition(s) to show:

In [None]:
# Subtypes in the order they first appear in `conditions_df`.
condition_subtypes = conditions_df['condition_subtype'].unique().tolist()

# define colors from here: https://vega.github.io/vega/docs/schemes/
# similar to Greaney et al antibody class papers
condition_subtype_colors = {
    'class 1': '#E52794',
    'class 2': '#6A0DAD',
    'class 3': '#66CCEE',
    'class 4': '#E69F00',
    # greens from https://www.rapidtables.com/web/color/green-color.html
    'convalescent serum': '#006400',
    'Moderna vaccine serum': '#98FB98',
    'B.1.351 convalescent plasma': '#808000',
}
# Fail early if the data has a subtype with no assigned color.
if set(condition_subtypes).difference(condition_subtype_colors):
    raise ValueError('missing colors for some condition subtypes')
# Multi-selection over condition subtypes; charts that filter on it show only
# the selected subtypes. `empty='none'` means nothing matches when the
# selection is cleared.
select_condition_subtype = alt.selection_multi(fields=['condition_subtype'],
                                               # initialize to show antibodies but not sera
                                               init=[{'condition_subtype': subtype} for subtype in
                                                     conditions_df.query('condition_type == "antibody"')
                                                     ['condition_subtype'].unique()],
                                               resolve='union',
                                               empty='none',
                                               )
# Shared color encoding: selected subtypes get their assigned color (the scale
# range is built in the same order as the domain); unselected ones are white.
condition_subtype_color = alt.condition(select_condition_subtype,
                                   alt.Color('condition_subtype:N',
                                             legend=None,
                                             scale=alt.Scale(domain=condition_subtypes,
                                                             range=[condition_subtype_colors[c]
                                                                    for c in condition_subtypes]),
                                                             ),
                                   alt.value('white'),
                                   )

# Point size for the MDS plot; the subtype legend uses 0.7 times this value.
circle_size = 110

# Clickable legend: one circle per (condition_type, condition_subtype) pair.
# It hosts `select_condition_subtype`, so clicking circles chooses which
# subtypes are displayed.
legend_condition_type = (
    alt.Chart(conditions_df[['condition_type', 'condition_subtype']].drop_duplicates())
    .mark_circle(size=0.7 * circle_size,
                 stroke='black',
                 strokeWidth=1)
    .encode(x=alt.X('condition_type:N',
                    # the x-axis title is used to show multi-line usage instructions
                    axis=alt.Axis(title=['',
                                         'On each subplot, you can:',
                                         ' - click to select one item',
                                         ' - shift-click to select additional items',
                                         ' - double-click to clear selected items',
                                         ' - mouseover to see antibody/serum name',
                                         ],
                                  titleAlign='left',
                                  titleFontSize=14,
                                  titleFontWeight='normal',
                                  titleFontStyle='italic',
                                  labelFontSize=12),
                    ),
            y=alt.Y('condition_subtype:N',
                    sort=condition_subtypes,
                    axis=alt.Axis(title=None,
                                  labelFontSize=12,
                                  orient='right'),
                    ),
            color=condition_subtype_color,
            )
    .add_selection(select_condition_subtype)
    .properties(title={'text': 'choose antibody/serum types to display',
                       'align': 'left',
                       'anchor': 'start'})
    )

# Preview the legend on its own, with the view border made transparent.
(legend_condition_type).configure_view(strokeOpacity=0)

In [None]:
# Dropdown for the eliciting virus; the `None` option is labelled 'all'.
eliciting_viruses = sorted(dms_data_tidy['virus'].unique(), reverse=True)
eliciting_virus_dropdown = alt.binding_select(
    options=[None, *eliciting_viruses],
    labels=['all', *eliciting_viruses],
)
eliciting_virus_selection = alt.selection_single(
    name='eliciting',
    fields=['virus'],
    bind=eliciting_virus_dropdown,
)

# Dropdown for the lab; starts on 'Bloom_JD'.
labs = sorted(dms_data_tidy['lab'].unique())
lab_dropdown = alt.binding_select(
    options=[None, *labs],
    labels=['all', *labs],
)
lab_selection = alt.selection_single(
    name='source',
    fields=['lab'],
    bind=lab_dropdown,
    init={'lab': 'Bloom_JD'},
)

# Dropdown on the 'Omicron' column; the `None` option is labelled 'either'.
omicron_selection = alt.selection_single(
    name='neutralizes',
    fields=['Omicron'],
    bind=alt.binding_select(
        options=[None, True, False],
        labels=['either', 'yes', 'no'],
    ),
)

# Click-to-toggle selection of individual conditions; nothing is selected
# when it is empty.
highlight_condition = alt.selection_multi(
    on='click',
    fields=['condition'],
    nearest=False,
    empty='none',
    toggle=True,
    resolve='union',
)

cell_height = 17  # size of cells in heat map

# build zoom bar to zoom in condition legend
# (an interval brush along y; `condition_base` below filters on it)
legend_condition_zoom_brush = alt.selection_interval(
                encodings=['y'],
                mark=alt.BrushConfig(stroke='black', strokeWidth=2),
                )
# Narrow strip with one rectangle per condition, colored by subtype. It hosts
# the brush and is filtered by the subtype, Omicron, virus and lab selections.
legend_condition_zoom_bar = (
    alt.Chart(conditions_df)
    .mark_rect()
    .encode(y=alt.Y('condition:N',
                    title='antibody / sera zoom bar',
                    sort=conditions_df['condition'].unique(),
                    axis=alt.Axis(ticks=False,
                                  labels=False,
                                  titleFontSize=12)
                    ),
            color=condition_subtype_color,
            )
    .add_selection(legend_condition_zoom_brush)
    .transform_filter(select_condition_subtype)
    .transform_filter(omicron_selection)
    .transform_filter(eliciting_virus_selection)
    .transform_filter(lab_selection)
    .properties(height=175,
                width=15)
    )

# Shared base for the per-condition legend columns: same data, selections,
# filters (including the zoom brush) and row height, so the columns built from
# it have matching rows.
condition_base = (
    alt.Chart(conditions_df)
    .add_selection(select_condition_subtype,
                   highlight_condition,
                   omicron_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   legend_condition_zoom_brush)
    .transform_filter(select_condition_subtype)
    .transform_filter(legend_condition_zoom_brush)
    .transform_filter(omicron_selection)
    .transform_filter(eliciting_virus_selection)
    .transform_filter(lab_selection)
    .properties(height={'step': cell_height},
                width=cell_height,
                )
    )

# Column of colored boxes, one per condition, labelled on the right.
# Conditions in `highlight_condition` get a thicker outline.
legend_condition_heatmap = (
    condition_base
    .encode(y=alt.Y('condition:N',
                    sort=conditions_df['condition'].unique(),
                    title=None,
                    axis=alt.Axis(orient='right',
                                  labelFontSize=11,
                                  ),
                    ),
            color=condition_subtype_color,
            strokeWidth=alt.condition(~highlight_condition,
                                      alt.value(0.5),
                                      alt.value(3)),
            # NOTE(review): both branches are 'black', so the stroke color does
            # not depend on the selection; only the stroke width does.
            stroke=alt.condition(~highlight_condition,
                                 alt.value('black'),
                                 alt.value('black')),
            )
    .mark_rect()
    )

# Citation text for each condition row, linked to the `url` column.
condition_citations = (
    condition_base
    .mark_text(
        align='left',
        fontSize=11,
        fontStyle='normal',
        color='darkblue',
    )
    .encode(
        y=alt.Y(
            'condition:N',
            sort=conditions_df['condition'].unique(),
            title=None,
            axis=None,
        ),
        text=alt.Text('citation:N'),
        href=alt.Href('url:N'),
    )
)

# Constant 'dms-view' text for each condition row, linked to `dms_view_url`.
condition_dms_view = (
    condition_base
    .mark_text(
        text='dms-view',
        align='left',
        fontSize=11,
        fontStyle='normal',
        color='darkblue',
    )
    .encode(
        y=alt.Y(
            'condition:N',
            sort=conditions_df['condition'].unique(),
            title=None,
            axis=None,
        ),
        href=alt.Href('dms_view_url:N'),
    )
)

# Alias text for each condition row, taken from the `condition_alias` column.
# The mark carries no fixed `text` value: the previous `text='dms-view'` was
# copied from the dms-view column and does not belong on this one.
condition_alias = (
    condition_base
    .encode(y=alt.Y('condition:N',
                    sort=conditions_df['condition'].unique(),
                    title=None,
                    axis=None,
                    ),
            text='condition_alias:N',
            )
    .mark_text(align='left',
               fontSize=11,
               fontStyle='normal',
               )
    )

# Full condition legend: the zoom bar next to the tightly spaced columns for
# boxes, citations, dms-view links and aliases.
legend_condition = (
    (legend_condition_zoom_bar | alt.hconcat(legend_condition_heatmap,
                                             condition_citations,
                                             condition_dms_view,
                                             condition_alias,
                                             spacing=2)
     )
    # two-line title; the duplicated word in "by by clicking" is fixed
    .properties(title={'text': ['select antibody/serum by clicking box; shift-click',
                                'citation or dms-view text to open that information']})
    )

legend_condition.configure_view(strokeOpacity=0)

Next make MDS plot:

In [None]:
# build selections to select metric, normalization, and random seed
metric_select_binding = alt.binding_select(options=mds_coords['metric'].unique())
metric_selection = alt.selection_single(name='escape',
                                        fields=['metric'],
                                        bind=metric_select_binding,
                                        init={'metric': 'sum of mutations at site'})

seed_select_binding = alt.binding_select(options=mds_coords['seed'].unique())
seed_selection = alt.selection_single(name='multidimensional scaling random',
                                      fields=['seed'],
                                      bind=seed_select_binding,
                                      init={'seed': 2},
                                      )

# size, but scaled so a unit on x and y mean the same; note
# padding added here so sizes correct
size = 180
pad = 0.04
x_extent = mds_coords['x'].max() - mds_coords['x'].min()
y_extent = mds_coords['y'].max() - mds_coords['y'].min()
y_min = mds_coords['y'].min() - pad * y_extent
y_max = mds_coords['y'].max() + pad * y_extent
x_min = mds_coords['x'].min() - pad * x_extent
x_max = mds_coords['x'].max() + pad * x_extent

# Scatter plot of the MDS layout: one circle per condition, with fixed padded
# domains and no axis decoration. Highlighted conditions are opaque with a
# black outline; the rest are half-transparent. Only the chosen metric and
# seed are shown, and the subtype, Omicron, virus and lab selections filter
# the points.
mds_plot = (
    alt.Chart(mds_coords)
    .encode(x=alt.X('x:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(x_min, x_max),
                                    ),
                    axis=alt.Axis(labels=False,
                                  title=None,
                                  ticks=False,
                                  grid=False,
                                  ),
                    ),
            y=alt.Y('y:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(y_min, y_max),
                                    ),
                    axis=alt.Axis(labels=False,
                                  title=None,
                                  ticks=False,
                                  grid=False,
                                  ),
                    ),
            opacity=alt.condition(~highlight_condition, alt.value(0.5), alt.value(1)),
            stroke=alt.condition(~highlight_condition, alt.value(None), alt.value('black')),
            color=condition_subtype_color,
            tooltip=['condition'])
    .mark_circle(size=circle_size)
    # width and height use the same pixels-per-unit factor
    .properties(width=size * x_extent,
                height=size * y_extent,
                title={'text': 'multidimensional scaling of antibodies/sera',
                       'subtitle': ['antibodies/sera with escape mutations at similar',
                                    'sites are positioned nearby in the plot below'],
                       'anchor': 'start',
                       'align': 'left',
                       }
                )
    .add_selection(seed_selection,
                   metric_selection,
                   highlight_condition,
                   select_condition_subtype,
                   omicron_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   )
    .transform_filter(eliciting_virus_selection)
    .transform_filter(omicron_selection)
    .transform_filter(lab_selection)
    .transform_filter(metric_selection)
    .transform_filter(seed_selection)
    .transform_filter(select_condition_subtype)
    )

# box around MDS plot: https://stackoverflow.com/a/62862229/4191652
# Only the top and right edges are drawn here, as two-point line charts that
# use the same fixed domains as the MDS plot.
dummy_lines = {}
for key, x, y in [('top', (x_min, x_max), (y_max, y_max)),
                  ('right', (x_max, x_max), (y_min, y_max)),
                  ]:
    dummy_lines[key] = (
        alt.Chart(pd.DataFrame({'x': x,
                                'y': y})
                  )
        .mark_line(color='black',
                   strokeWidth=0.5)
        .encode(x=alt.X('x:Q',
                        scale=alt.Scale(padding=0,
                                        nice=False,
                                        domain=(x_min, x_max),
                                        ),
                        axis=alt.Axis(labels=False,
                                      title=None,
                                      ticks=False,
                                      grid=False,
                                      ),
                        ),
                y=alt.Y('y:Q',
                        scale=alt.Scale(padding=0,
                                        nice=False,
                                        domain=(y_min, y_max),
                                        ),
                        axis=alt.Axis(labels=False,
                                      title=None,
                                      ticks=False,
                                      grid=False,
                                      ),
                        )
                )
        )
# layer the two edge lines on top of the scatter plot
mds_plot = mds_plot + dummy_lines['top'] + dummy_lines['right']

# show the plot with legend
# NOTE(review): the second `configure_view` call may replace the first rather
# than add to it, which would make `stroke='black'` ineffective - confirm.
(legend_condition_type | mds_plot).configure_view(stroke='black').configure_view(strokeOpacity=0)

Next make line plots:

In [None]:
# Shared pixel width of the zoom bar and the escape plots.
width = 800

# build zoom bar to zoom in on sites
# (an interval brush along x; charts that filter on it show only the brushed sites)
zoom_brush = alt.selection_interval(
                encodings=['x'],
                mark=alt.BrushConfig(stroke='black', strokeWidth=2))
# Thin gray strip with one rectangle per site that hosts the brush.
zoom_bar = (
    alt.Chart(dms_data_tidy[['site']].drop_duplicates())
    .mark_rect(color='lightgray')
    .encode(x=alt.X('site:O',
                    title=None,
                    ),
            )
    .add_selection(zoom_brush)
    .properties(width=width,
                height=15,
                title='site zoom bar')
    )

# build base for escape plots
# The data gets a constant `mean_over` column set to 'all displayed types';
# the radio-button selection defined further down matches against it. The
# base applies every dropdown filter, the subtype selection and the site zoom.
escape_base = (
    alt.Chart(dms_data_tidy.assign(mean_over='all displayed types'))
    .encode(x=alt.X('site:O',
                    axis=alt.Axis(grid=False),
                    ),
            )
    .transform_filter(eliciting_virus_selection)
    .transform_filter(omicron_selection)
    .transform_filter(lab_selection)
    .transform_filter(metric_selection)
    .transform_filter(select_condition_subtype)
    .transform_filter(zoom_brush)
    .properties(width=width,
                height=200,
                )
    )

# the escape line plot
# Highlighted conditions are drawn thicker and fully opaque. This layer also
# hosts the dropdown, subtype and zoom selections.
escape_lines = (
    escape_base
    .encode(size=alt.condition(~highlight_condition, alt.value(0.9), alt.value(1.5)),
            opacity=alt.condition(~highlight_condition, alt.value(0.4), alt.value(1)),
            )
    .add_selection(omicron_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   metric_selection,
                   select_condition_subtype,
                   zoom_brush,
                   )
    .mark_line()
    )

# escape point plot
# Points are shown only for highlighted conditions (filtered on the selection).
escape_points = (
    escape_base
    .encode(fill=condition_subtype_color,
            tooltip=['condition:N', 'site:O'],
            )
    .mark_point(size=40)
    .transform_filter(highlight_condition)
    # needs to be add_selection within chart: https://github.com/altair-viz/altair/issues/2368#issuecomment-742377146
    .add_selection(highlight_condition)
    )

# combine point and line plots
# `detail` on condition gives one line per condition; color and y are set
# once here for both layers.
escape_lines_points = (
    (escape_lines + escape_points)
    .encode(detail='condition:N',  # https://github.com/altair-viz/altair/issues/985
            color=condition_subtype_color,
            y=alt.Y('escape:Q',
                    axis=alt.Axis(grid=False),
                    ),
            )
    .properties(title={'text': 'escape from individual antibodies/sera'})
    )

# radio buttons to specify if mean for only selected antibodies or all antibody/serum types
# The selection is on `mean_over`, which `escape_base` sets to
# 'all displayed types' on every row. Choosing the other option therefore
# matches no rows through this selection.
mean_radio = alt.binding_radio(options=['all displayed types',
                                        'just selected antibodies/sera'])
mean_selection = alt.selection_single(fields=['mean_over'],
                                      bind=mean_radio,
                                      name='calculate',
                                      init={'mean_over': 'all displayed types'})
# plot of mean values
escape_mean = (
    escape_base
    .mark_line(color='darkgray',
               point={'color': 'darkgray',
                      'size': 60},
               )
    .encode(tooltip=['site:O',
                     alt.Tooltip('mean(escape):Q',
                                 format='.2g',
                                 title='escape'),
                     ],
            y=alt.Y('mean(escape):Q',
                    axis=alt.Axis(grid=False,
                                  title='escape',
                                  ),
                    ),
            )
    # keep rows that are highlighted, or that are in a displayed subtype while
    # the radio button is on 'all displayed types'
    .transform_filter(highlight_condition | (select_condition_subtype & mean_selection))
    .add_selection(highlight_condition,
                   mean_selection,
                   )
    .properties(title={'text': 'mean escape over selected antibodies/sera or ' +
                               'antibody/serum types (choose with radio button below)'
                       })
    )

# combine zoom bar, lines, and points
# (the per-condition plot and the mean plot share one x scale)
escape_plot = (zoom_bar & (escape_lines_points & escape_mean).resolve_scale(x='shared'))

escape_plot

Now combine the antibody MDS and escape plots:

In [None]:
# Layout: subtype legend and MDS plot on top, escape plots below them, and the
# per-condition legend as a column on the right.
chart = (
    alt.hconcat(
        alt.vconcat(
            alt.hconcat(legend_condition_type, mds_plot),
            escape_plot,
        ),
        legend_condition,
    )
    .configure(padding=dict(left=5, right=60, top=5, bottom=5))
    .configure_view(strokeOpacity=0)
)

# Write the chart to an HTML file, creating the directory if needed.
chartfile = 'docs/_includes/chart.html'
os.makedirs(os.path.dirname(chartfile), exist_ok=True)
print(f"Saving chart to {chartfile}")
chart.save(chartfile)

chart

## Make an "escape calculator" plot
This is **only** for monoclonal antibodies, so get just those data:

In [None]:
# Subtypes belonging to monoclonal antibodies, in order of first appearance.
condition_subtypes = (
    dms_data
    .loc[dms_data['condition_type'] == 'antibody', 'condition_subtype']
    .unique()
    .tolist()
)
print(f"Including following condition subtypes: {condition_subtypes}")

# Tidy escape data for those subtypes only; the subtype column is then
# dropped and duplicate rows removed.
escape_calc_data = (
    dms_data_tidy
    .loc[dms_data_tidy['condition_subtype'].isin(condition_subtypes)]
    .drop(columns='condition_subtype')
    .drop_duplicates()
)

escape_calc_data

Now make bar plot with antibody fraction bound:

In [None]:
# Multi-selection of mutated sites, named 'mut' so the calculate expressions
# below can refer to it as `mut.site`. It starts on site -1, presumably a
# placeholder that matches no real site, so nothing is mutated initially.
mut_selection = alt.selection_multi(name='mut',
                                    fields=['site'],
                                    init=[{'site': -1}],
                                    empty='none',
                                    )

# Slider (1 to 10, starting at 2) read in the calculate expressions as
# `mut_escape_strength.mutation_escape_strength`; it is used as an exponent there.
mut_escape_strength_slider = alt.binding_range(min=1, max=10,
                                               name='mutation_escape_strength')
mut_escape_strength_selection = alt.selection_single(name='mut_escape_strength',
                                                     fields=['mutation_escape_strength'],
                                                     bind=mut_escape_strength_slider,
                                                     init={'mutation_escape_strength': 2})

# Separate virus / lab dropdown selections for the escape calculator. They
# reuse the dropdown bindings defined earlier; the virus one starts on
# 'SARS-CoV-2'.
# NOTE(review): the names 'eliciting' and 'source' are also used by the
# selections of the first chart - confirm the two charts are never combined.
eliciting_virus_selection2 = alt.selection_single(
                                    fields=['virus'],
                                    bind=eliciting_virus_dropdown,
                                    name='eliciting',
                                    init={'virus': 'SARS-CoV-2'}
                                    )

lab_selection2 = alt.selection_single(fields=['lab'],
                                      bind=lab_dropdown,
                                      name='source',
                                      )

# Stacked bar showing the mean fraction of antibodies still bound versus
# escaped, given the selected mutated sites.
#
# Steps, per metric:
#   1. keep antibodies matching the virus, lab and Omicron dropdowns;
#   2. per antibody, take the maximum escape over sites;
#   3. per site, retained binding is 1 - escape / max at mutated sites and 1
#      elsewhere;
#   4. per antibody, multiply over sites and raise to the slider exponent;
#   5. average over antibodies ('bound'); 'escaped' is one minus that.
#
# The Omicron dropdown is attached to this chart, so it is also applied as a
# filter here. Without that, the bar would ignore the dropdown.
frac_bound_bar = (
    alt.Chart(escape_calc_data)
    .transform_filter(eliciting_virus_selection2)
    .transform_filter(lab_selection2)
    .transform_filter(omicron_selection)
    # get maximum escape across any site for this condition
    .transform_joinaggregate(
        condition_escape_max='max(escape)',
        groupby=['condition', 'metric'],
        )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
        )
    .transform_aggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['condition', 'metric']
        )
    .transform_calculate(
        binding_retained_exp='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength)'
        )
    .transform_aggregate(
        mean_binding_retained='mean(binding_retained_exp)',
        groupby=['metric'],
        )
    .transform_calculate(
        bound='datum.mean_binding_retained',
        escaped='1 - datum.bound',
        )
    # reshape to one row per binding state so the two can be stacked
    .transform_fold(
        ['bound', 'escaped'],
        ['binding state', 'fraction of antibodies']
        )
    .encode(x=alt.X('fraction of antibodies:Q',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.value(1),
            fill=alt.Color('binding state:N',
                            scale=alt.Scale(
                                domain=['bound', 'escaped'],
                                range=['lightgray', '#56B4E9'],
                                reverse=True,
                                ),
                            ),
            order=alt.Order('binding state:N'),
            tooltip=['binding state:N',
                     alt.Tooltip('fraction of antibodies:Q',
                                 format='.2g'),]
            )
    .mark_bar(stroke='black',
              size=20)
    # `metric` survives the aggregations as a groupby key, so it can be filtered here
    .transform_filter(metric_selection)
    .add_selection(metric_selection,
                   mut_selection,
                   mut_escape_strength_selection,
                   omicron_selection,
                   eliciting_virus_selection2,
                   lab_selection2,
                   )
    .properties(width=300,
                height=10)
    )

frac_bound_bar

Now make the line plot:

In [None]:
# Base chart for the escape-calculator line plot: per site, the mean escape
# over antibodies before ('unmutated') and after ('mutated') the selected
# mutations.
#
# The site zoom filter is applied last. The per-antibody maximum escape and
# the product of retained binding must be computed over all sites. Filtering
# to the zoom window first would change the normalisation and would ignore
# mutated sites that lie outside the window.
escape_mut_base = (
    alt.Chart(escape_calc_data)
    .encode(x=alt.X('site:O',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.Y('mean_escape_value:Q',
                    axis=alt.Axis(grid=False,
                                  title='escape',
                                  ),
                    ),
            )
    # choose the metric and the set of antibodies
    .transform_filter(metric_selection)
    .transform_filter(lab_selection2)
    .transform_filter(omicron_selection)
    .transform_filter(eliciting_virus_selection2)
    # get maximum escape across any site for this condition
    .transform_joinaggregate(
        condition_escape_max='max(escape)',
        groupby=['condition', 'metric'],
        )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
        )
    # joinaggregate keeps the per-site rows and attaches the per-condition product
    .transform_joinaggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['condition', 'metric']
        )
    .transform_calculate(
        escape_after_mut='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength) * datum.escape'
        )
    .transform_aggregate(
        mutated='mean(escape_after_mut):Q',
        unmutated='mean(escape):Q',
        groupby=['site', 'metric'],
        )
    .transform_fold(['unmutated', 'mutated'],
                    ['escape_type', 'mean_escape_value'])
    .transform_calculate(
        # label the 'mutated' value at a selected site separately for coloring
        color_val='if((indexof(mut.site, datum.site) >= 0) & (datum.escape_type == "mutated"), '
                  '"mutated site", datum.escape_type)'
        )
    # restrict to the zoom window only after all values are computed;
    # `site` is still present because it is an aggregate groupby key
    .transform_filter(zoom_brush)
    .properties(width=width,
                height=225,
                )
    )

# Visual style of each category shown in the escape plot, given as
# category -> (color, point size, opacity).
_mut_escape_styles = {
    'unmutated': ('#999999', 30, 0.5),
    'mutated': ('#56B4E9', 60, 0.7),
    'mutated site': ('#D55E00', 100, 1),
    }

# One scale per visual channel; all three share the same category domain.
(mut_escape_color_scale,
 mut_escape_point_size_scale,
 mut_escape_opacity_scale) = (
    alt.Scale(domain=list(_mut_escape_styles), range=list(channel_values))
    for channel_values in zip(*_mut_escape_styles.values())
    )

# Line layer: one line per escape type ("unmutated" / "mutated"), colored and
# faded with the shared scales; the opacity legend is suppressed.
escape_mut_lines = escape_mut_base.mark_line().encode(
    color=alt.Color('escape_type:N', scale=mut_escape_color_scale),
    opacity=alt.Opacity('escape_type:N',
                        scale=mut_escape_opacity_scale,
                        legend=None),
    )

# Point layer: same data as the lines, but styled by `color_val` so that the
# "mutated" value at a mutated site gets its own color, size and opacity.
# This layer also carries the interactive selections.
escape_mut_points = (
    escape_mut_base
    .mark_point(filled=True)
    .encode(
        color=alt.Color(
            'color_val:N',
            scale=mut_escape_color_scale,
            # untitled legend with human-readable labels for the raw categories
            legend=alt.Legend(
                title=None,
                labelExpr=(
                    'if(datum.value == "unmutated", '
                    '   "escape when no mutations", '
                    '   if(datum.value == "mutated", '
                    '      "escape with mutations", '
                    '      "mutated site"))'
                    ),
                ),
            ),
        opacity=alt.Opacity('color_val:N',
                            scale=mut_escape_opacity_scale,
                            legend=None),
        size=alt.Size('color_val:N', scale=mut_escape_point_size_scale),
        tooltip=['site:O'] + [
            alt.Tooltip(shorthand, format='.2g')
            for shorthand in ('mutated:Q', 'unmutated:Q')
            ],
        )
    .add_selection(
        metric_selection,
        mut_selection,
        zoom_brush,
        mut_escape_strength_selection,
        omicron_selection,
        eliciting_virus_selection2,
        lab_selection2,
        )
    )

# Assemble the full escape calculator: zoom bar on top, the layered
# line + point plot in the middle, and the fraction-bound bar at the bottom.
escape_chart = (
    alt.vconcat(
        zoom_bar,
        alt.layer(escape_mut_lines, escape_mut_points),
        frac_bound_bar,
        )
    .configure_view(strokeOpacity=0)
    .configure_legend(orient='bottom',
                      labelFontSize=12,
                      title=None)
    # NOTE(review): the positional argument is presumably bound to the first
    # channel accepted by the legend-resolve map, which may differ between
    # Altair versions -- consider naming the channel explicitly.
    .resolve_legend('independent')
    )

# Save the escape calculator chart as HTML under `docs/_includes`.
escape_calc_chartfile = 'docs/_includes/escape_calc_chart.html'

# make sure the destination directory exists before writing
os.makedirs(
    os.path.dirname(escape_calc_chartfile),
    exist_ok=True,
    )

print(f"Saving chart to {escape_calc_chartfile}")
escape_chart.save(escape_calc_chartfile)

# display the chart as the cell output
escape_chart

Write the escape calculator data to a file:

In [None]:
# Write the data frame behind the escape calculator to CSV (no index column).
escape_calc_data_file = 'processed_data/escape_calculator_data.csv'

# make sure the destination directory exists before writing
os.makedirs(
    os.path.dirname(escape_calc_data_file),
    exist_ok=True,
    )

print(f"Writing escape calculator data to {escape_calc_data_file}")
escape_calc_data.to_csv(escape_calc_data_file, index=False)

## Mini-example escape calculator
Now we make a small "example" escape calculator, purely to demonstrate how the calculation works.

First, get example data for a few antibodies:

In [None]:
# antibodies shown in the mini example
example_abs = ['LY-CoV016', 'LY-CoV555', 'REGN10987']

# display color for each example antibody
ab_colors = {'LY-CoV016': '#E52794',
             'LY-CoV555': '#6A0DAD',
             'REGN10987': '#66CCEE',
             }

# Subset the escape calculator data to the example antibodies: keep only
# conditions labeled "(Bloom_JD)", strip that label from the name, and keep
# the SARS-CoV-2 / site-sum rows.
example_escape_calc_data = (
    escape_calc_data
    # Match the label literally. As a regular expression the parentheses
    # would form a capture group (pandas warns about this) and the pattern
    # would match "Bloom_JD" anywhere, with or without parentheses.
    .loc[lambda x: x['condition'].str.contains('(Bloom_JD)', regex=False)]
    # the antibody name is the first whitespace-delimited token
    .assign(condition=lambda x: x['condition'].str.split().str[0])
    .query('condition in @example_abs')
    .query('virus == "SARS-CoV-2"')
    .query('metric == "sum of mutations at site"')
    .reset_index(drop=True)
    .drop(columns=['metric', 'virus'])
    .rename(columns={'condition': 'antibody'})
    )

example_escape_calc_data

Now make the plot:

In [None]:
# Legend-bound multi-selection over antibodies, initialized with every example
# antibody selected. It is named `ab` so that chart expressions can refer to
# the selected antibodies as `ab.antibody`.
antibody_selection = alt.selection_multi(
    name='ab',
    fields=['antibody'],
    bind='legend',
    # NOTE(review): this is the string 'true', not the boolean; presumably it
    # is evaluated as an expression so that every click toggles -- confirm
    # before "simplifying" it to True.
    toggle='true',
    init=[dict(antibody=name) for name in example_abs],
    )

# Common chart (data + dimensions) from which all mini-example layers derive.
example_base = alt.Chart(example_escape_calc_data).properties(
    height=200,
    width=650,
    )

# Base for the per-antibody escape layers: escape at each site, colored by
# antibody, with antibodies that are deselected in the legend made invisible.
example_ab_escape_base = (
    example_base
    .encode(x=alt.X('site:O',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.Y('escape:Q',
                    axis=alt.Axis(grid=False,
                                  title='escape',
                                  ),
                    ),
            color=alt.Color('antibody:N',
                            legend=alt.Legend(orient='top',
                                              title='antibody (thick black line is mean)',
                                              ),
                            # Build the range from `example_abs` so that each
                            # color lines up with its antibody in the domain,
                            # whatever order `ab_colors` is written in.
                            scale=alt.Scale(domain=example_abs,
                                            range=[ab_colors[ab] for ab in example_abs],
                                            ),
                            ),
            # opacity 0 hides antibodies that are not currently selected
            opacity=alt.condition(antibody_selection, alt.value(0.7), alt.value(0)),
            tooltip=['antibody',
                     'site',
                     alt.Tooltip('escape', format='.2f'),
                     ],
            )
    )

# Thin line for each antibody.
example_ab_escape_lines = example_ab_escape_base.mark_line(size=1)

# Small filled points for each antibody; this layer owns the legend-bound
# antibody selection.
example_ab_escape_points = (
    example_ab_escape_base
    .add_selection(antibody_selection)
    .mark_point(filled=True, size=20)
    )

# Base for the black "mean" layers. `escape2` equals the antibody's escape when
# that antibody is in the `ab` selection and 0 otherwise; the plotted value is
# the mean of `escape2` at each site.
example_mean_escape_base = (
    example_base
    .transform_calculate(
        escape2='if(indexof(ab.antibody, datum.antibody) >= 0, datum.escape, 0)',
        )
    .encode(
        alt.X('site:O', axis=alt.Axis(grid=False)),
        alt.Y('mean(escape2):Q', axis=alt.Axis(grid=False, title='escape')),
        color=alt.value('black'),
        )
    )

# Thick line tracing the mean escape.
example_mean_escape_lines = example_mean_escape_base.mark_line(size=2.5)

# Fully opaque filled points on top of the mean line.
example_mean_escape_points = example_mean_escape_base.mark_point(
    opacity=1,
    size=40,
    filled=True,
    )

# Layer the per-antibody points and lines with the black mean line and points.
example_chart = (example_ab_escape_points + example_ab_escape_lines +
                 example_mean_escape_lines + example_mean_escape_points)

# Save the mini example as HTML. Create the output directory first so this
# cell also works when run on its own, as the other save cells do.
example_chartfile = 'docs/_includes/mini_example_escape_calc.html'
os.makedirs(os.path.dirname(example_chartfile), exist_ok=True)
print(f"Saving chart to {example_chartfile}")
example_chart.save(example_chartfile)

# display the chart as the cell output
example_chart