# Plot yeast RBD DMS escape maps

## Import modules and read data
Import Python modules:

In [1]:
import itertools

import altair as alt

import numpy

import pandas as pd

import sklearn.manifold

Disable max rows specifier for Altair:

In [2]:
# Lift Altair's row-count limit on embedded datasets; the tidy data frame
# charted below has tens of thousands of rows. Assigning to `_` keeps the
# notebook from echoing the call's return value.
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data, and reduce to site-level data, calculating the max, mean, and total site-based metrics:

In [3]:
# Load the mutation-level measurements; the code below relies on the columns
# 'condition', 'condition_type', 'condition_subtype', 'study', 'site', and
# 'mut_escape'.
dms_mut_data = pd.read_csv('./results/merged_data/yeast_RBD_DMS_data.csv')

# calculate site metrics and fill missing sites as 0
# `sites` is every integer from the smallest to the largest site in the data,
# so sites with no measured mutations are still represented after the merge
sites = list(range(dms_mut_data['site'].min(), dms_mut_data['site'].max() + 1))
dms_data = (
    dms_mut_data
    # `dropna=False` keeps groups whose key columns are missing instead of
    # silently dropping them, so the assert below can catch them
    .groupby(['condition', 'condition_type', 'condition_subtype', 'study', 'site'],
             as_index=False, dropna=False)
    # collapse the per-mutation escape values at each site to three summaries
    .aggregate(site_total_escape=pd.NamedAgg('mut_escape', 'sum'),
               site_max_escape=pd.NamedAgg('mut_escape', 'max'),
               site_mean_escape=pd.NamedAgg('mut_escape', 'mean')
               )
    )
# nothing should be null yet; nulls only appear (and are filled) in the merge
assert dms_data.notnull().all().all()
# merge each condition (the `left_by` group) against the full list of sites so
# every condition has a row for every site; metrics at added sites become 0
dms_data = (pd.merge_ordered(dms_data,
                             pd.DataFrame({'site': sites}),
                             on='site',
                             left_by=['condition', 'study', 'condition_type', 'condition_subtype'],
                             )
            .fillna(0)
            )

# check no duplicated conditions
# (later cells key on 'condition' alone, so a condition name may appear in
# only one study)
dup_conditions = (dms_data
                  .groupby('condition', as_index=False)
                  .aggregate(n_studies=pd.NamedAgg('study', 'nunique'))
                  .query('n_studies > 1')
                  )
if len(dup_conditions):
    raise ValueError('duplicate studies for some conditions:\n' + str(dup_conditions))

# display the site-level data frame
dms_data

Unnamed: 0,condition,condition_type,condition_subtype,study,site,site_total_escape,site_max_escape,site_mean_escape
0,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,331,0.028500,0.001785,0.001781
1,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,332,0.033839,0.001781,0.001781
2,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,333,0.032058,0.001781,0.001781
3,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,334,0.032058,0.001781,0.001781
4,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,335,0.033839,0.001781,0.001781
...,...,...,...,...,...,...,...,...
10045,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,527,0.005310,0.002243,0.000312
10046,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,528,0.009810,0.002504,0.000545
10047,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,529,0.031831,0.009543,0.001768
10048,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,530,0.024201,0.007600,0.001274


Make a tidy version of `dms_data` that is melted to have the three site metrics (sum, max, and mean) in one column:

In [4]:
# Map each site-metric column of `dms_data` to the human-readable label that
# identifies it in the 'metric' column of the tidy data frame.
tidy_cols = {
    'site_total_escape': 'sum of mutations at site',
    'site_max_escape': 'max of any mutation at site',
    'site_mean_escape': 'mean of mutations at site',
}

# Relabel the metric columns, then melt them into long form: one row per
# condition / site / metric, with the value in 'antibody escape'. All columns
# that are not metrics are carried along as identifiers.
dms_data_tidy = dms_data.rename(columns=tidy_cols).melt(
    id_vars=[col for col in dms_data.columns if col not in tidy_cols],
    value_vars=tidy_cols.values(),
    var_name='metric',
    value_name='antibody escape',
)

dms_data_tidy

Unnamed: 0,condition,condition_type,condition_subtype,study,site,metric,antibody escape
0,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,331,sum of mutations at site,0.028500
1,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,332,sum of mutations at site,0.033839
2,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,333,sum of mutations at site,0.032058
3,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,334,sum of mutations at site,0.032058
4,C002,antibody,not clinical antibody,2021_Greaney_Rockefeller,335,sum of mutations at site,0.033839
...,...,...,...,...,...,...,...
30145,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,527,mean of mutations at site,0.000312
30146,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,528,mean of mutations at site,0.000545
30147,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,529,mean of mutations at site,0.001768
30148,subject K (day 29),serum,convalescent serum,2021_Greaney_HAARVI_sera,530,mean of mutations at site,0.001274


## Perform multidimensional scaling
Steps:
 1. Calculate similarities between escape maps for each antibody.
 2. Convert similarities to dissimilarities.
 3. Do multi-dimensional scaling on dissimilarities.


First, compute the dissimilarity between all pairs of escape profiles in a data frame.
We calculate similarity as the dot product of the escape profile site-level metric for each pair of conditions, normalizing each profile so its dot product with itself is one.
Then we compute the dissimilarity as just one minus the similarity:

In [5]:
def escape_similarity(df):
    """Compute similarity between all pairs of conditions in `df`.

    Each condition's escape profile (its 'antibody escape' value at every
    site) is scaled to unit Euclidean length, and the similarity of two
    conditions is the dot product of their scaled profiles. Every condition
    therefore has similarity one with itself.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have columns 'condition', 'site', and 'antibody escape'; any
        other columns are ignored. A site that is absent for a condition is
        treated as having escape 0.

    Returns
    -------
    pandas.DataFrame
        Square data frame of similarities whose index and columns are both
        the conditions, in order of first appearance in `df`.

    Raises
    ------
    ValueError
        If the columns used contain null values, or if some condition's
        profile cannot be normalized (for instance it is 0 at every site).

    """
    df = df[['condition', 'site', 'antibody escape']].drop_duplicates()
    if df.isnull().any().any():
        raise ValueError('null values in condition, site, or antibody escape')

    # order of first appearance; used to label the returned matrix
    conditions = df['condition'].unique()

    # sites-by-conditions array of escape values, with columns put in the
    # same order as `conditions` (the pivot itself sorts them)
    profiles = (
        df
        .pivot_table(index='site',
                     columns='condition',
                     values='antibody escape',
                     fill_value=0)
        [list(conditions)]
        .to_numpy(dtype=float)
        )

    # scale every column (condition) to unit length; an all-zero profile has
    # norm 0 and yields NaN, which is reported below rather than propagated
    norms = numpy.linalg.norm(profiles, axis=0)
    with numpy.errstate(divide='ignore', invalid='ignore'):
        unit_profiles = profiles / norms
    not_normalizable = numpy.isnan(unit_profiles).any(axis=0)
    if not_normalizable.any():
        bad = [cond for cond, is_bad in zip(conditions, not_normalizable) if is_bad]
        raise ValueError(f"cannot normalize escape profile for conditions: {bad}")

    # entry (i, j) is the dot product over sites of unit profiles i and j, so
    # one matrix product gives every pairwise similarity at once
    return pd.DataFrame(unit_profiles.T @ unit_profiles,
                        columns=conditions, index=conditions)

# Build one condition-by-condition similarity matrix per metric; the results
# are stacked into one data frame with 'metric' as the outer index level.
similarities = (
    dms_data_tidy
    .groupby('metric')
    .apply(escape_similarity)
    )

# Dissimilarity is one minus similarity. Clipping at 0 removes any negative
# values (presumably tiny floating-point overshoots of similarity above 1;
# verify if escape values can be negative).
dissimilarities = (1 - similarities).clip(lower=0)

# display rounded values for readability; `dissimilarities` itself is unchanged
dissimilarities.round(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,C002,C105,C110,C121,C135,C144,COV-021,COV-047,COV-057,COV-072,...,subject G (day 18),subject G (day 94),subject H (day 152),subject H (day 61),subject I (day 102),subject I (day 26),subject J (day 121),subject J (day 15),subject K (day 103),subject K (day 29)
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
max of any mutation at site,C002,0.000,0.731,0.736,0.192,0.987,0.132,0.352,0.374,0.558,0.537,...,0.692,0.646,0.446,0.558,0.243,0.278,0.590,0.750,0.621,0.786
max of any mutation at site,C105,0.731,0.000,0.991,0.814,0.985,0.696,0.715,0.669,0.770,0.668,...,0.855,0.706,0.762,0.771,0.669,0.648,0.677,0.792,0.659,0.759
max of any mutation at site,C110,0.736,0.991,0.000,0.722,0.480,0.764,0.739,0.623,0.596,0.716,...,0.392,0.697,0.406,0.511,0.520,0.573,0.542,0.824,0.890,0.882
max of any mutation at site,C121,0.192,0.814,0.722,0.000,0.972,0.084,0.326,0.314,0.522,0.537,...,0.671,0.675,0.527,0.581,0.220,0.301,0.623,0.772,0.629,0.791
max of any mutation at site,C135,0.987,0.985,0.480,0.972,0.000,0.985,0.823,0.731,0.679,0.756,...,0.534,0.693,0.640,0.642,0.710,0.702,0.619,0.842,0.922,0.896
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
sum of mutations at site,subject I (day 26),0.206,0.733,0.705,0.301,0.934,0.295,0.148,0.224,0.299,0.264,...,0.464,0.466,0.281,0.241,0.029,0.000,0.400,0.581,0.429,0.587
sum of mutations at site,subject J (day 121),0.607,0.693,0.546,0.717,0.784,0.659,0.362,0.477,0.572,0.305,...,0.386,0.039,0.276,0.219,0.396,0.400,0.000,0.290,0.400,0.392
sum of mutations at site,subject J (day 15),0.695,0.777,0.792,0.771,0.915,0.686,0.508,0.500,0.584,0.334,...,0.612,0.254,0.501,0.358,0.575,0.581,0.290,0.000,0.283,0.176
sum of mutations at site,subject K (day 103),0.509,0.671,0.870,0.626,0.928,0.423,0.373,0.168,0.248,0.070,...,0.608,0.418,0.472,0.316,0.425,0.429,0.400,0.283,0.000,0.235


Now do the multidimensional scaling [as described here](https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html#sphx-glr-auto-examples-manifold-plot-mds-py) to get the x and y coordinates for each antibody / serum.
For each metric, we do this for three different random number seeds (different seeds will give different MDS layouts):

In [6]:
# Collect one data frame of 2-D coordinates per (seed, metric) combination.
mds_coords = []
for seed, (metric, mat) in itertools.product([1, 2, 3], dissimilarities.groupby('metric')):
    # use multidimensional scaling to get locations of antibodies
    # `mat` is the square dissimilarity matrix for this metric, so it is
    # passed as 'precomputed' rather than as raw feature vectors
    mds = sklearn.manifold.MDS(n_components=2,
                               metric=True,
                               max_iter=3000,
                               eps=1e-6,
                               random_state=seed,
                               dissimilarity='precomputed',
                               n_jobs=1)
    locs = mds.fit_transform(mat)
    # `assign` evaluates keywords in order: xmin / ymin are taken from the raw
    # coordinates first, then x / y are shifted so each layout's minimum is 0
    mds_coords.append(pd.DataFrame(locs, columns=['x', 'y'])
                      .assign(metric=metric,
                              seed=seed,
                              condition=mat.columns,
                              xmin=lambda df: df['x'].min(),
                              ymin=lambda df: df['y'].min(),
                              x=lambda df: df['x'] - df['xmin'],
                              y=lambda df: df['y'] - df['ymin'],
                              )
                      )
# Stack all layouts and attach each condition's type / subtype / study.
# `validate='many_to_one'` makes the merge fail if a condition maps to more
# than one row of annotations.
mds_coords = (
    pd.concat(mds_coords,
              ignore_index=True)
    .merge(dms_data_tidy
           [['condition', 'condition_type', 'condition_subtype', 'study']]
           .drop_duplicates(),
           on='condition',
           how='left',
           validate='many_to_one')
    )
mds_coords

Unnamed: 0,x,y,metric,seed,condition,xmin,ymin,condition_type,condition_subtype,study
0,0.887426,0.283506,max of any mutation at site,1,C002,-0.683226,-0.637242,antibody,not clinical antibody,2021_Greaney_Rockefeller
1,1.135609,1.079556,max of any mutation at site,1,C105,-0.683226,-0.637242,antibody,not clinical antibody,2021_Greaney_Rockefeller
2,0.210789,0.251851,max of any mutation at site,1,C110,-0.683226,-0.637242,antibody,not clinical antibody,2021_Greaney_Rockefeller
3,0.916591,0.277475,max of any mutation at site,1,C121,-0.683226,-0.637242,antibody,not clinical antibody,2021_Greaney_Rockefeller
4,0.000000,0.461736,max of any mutation at site,1,C135,-0.683226,-0.637242,antibody,not clinical antibody,2021_Greaney_Rockefeller
...,...,...,...,...,...,...,...,...,...,...
445,0.416048,0.773541,sum of mutations at site,3,subject I (day 26),-0.609636,-0.696282,serum,convalescent serum,2021_Greaney_HAARVI_sera
446,0.676483,0.529526,sum of mutations at site,3,subject J (day 121),-0.609636,-0.696282,serum,convalescent serum,2021_Greaney_HAARVI_sera
447,0.838871,0.515844,sum of mutations at site,3,subject J (day 15),-0.609636,-0.696282,serum,convalescent serum,2021_Greaney_HAARVI_sera
448,0.745220,0.743969,sum of mutations at site,3,subject K (day 103),-0.609636,-0.696282,serum,convalescent serum,2021_Greaney_HAARVI_sera


## Make interactive plots
First make MDS dot plot:

In [20]:
# build drop down menu to select metric and random seed
# (`init` sets which option is selected when the chart first renders)
metric_select_binding = alt.binding_select(options=mds_coords['metric'].unique())
metric_selection = alt.selection_single(name='escape',
                                        fields=['metric'],
                                        bind=metric_select_binding,
                                        init={'metric': 'sum of mutations at site'})
seed_select_binding = alt.binding_select(options=mds_coords['seed'].unique())
seed_selection = alt.selection_single(name='random',
                                      fields=['seed'],
                                      bind=seed_select_binding,
                                      init={'seed': 1},
                                      )

# selection that follows the mouse: the point under the cursor is selected
# by its 'condition'
highlight_antibody = (
    alt.selection(type='single',
                  on='mouseover',
                  fields=['condition'],
                  nearest=False)
    )

# subtypes ordered by condition type and then subtype; this order is used for
# both the color scale domain and the legend's y-axis
condition_subtypes = (mds_coords
                      .sort_values(['condition_type', 'condition_subtype'])
                      ['condition_subtype']
                      .unique()
                      .tolist()
                      )
# define colors from here: https://vega.github.io/vega/docs/schemes/
condition_subtype_colors = {'clinical antibody': '#0072B2',
                            'not clinical antibody': '#56B4E9',
                            'convalescent serum': '#FD5602',
                            'Moderna serum': '#FFAF42',
                            }
# fail early if the data contain a subtype with no color assigned above
if not set(condition_subtypes).issubset(condition_subtype_colors):
    raise ValueError('missing colors for some condition subtypes')
# multi-selection over (type, subtype) pairs, added to both the legend and the
# MDS plot below; selected marks are colored by subtype, all others are white
select_condition_type = alt.selection_multi(fields=['condition_type', 'condition_subtype'])
condition_type_color=alt.condition(select_condition_type,
                                   alt.Color('condition_subtype:N',
                                             legend=None,
                                             scale=alt.Scale(domain=condition_subtypes,
                                                             range=[condition_subtype_colors[c] for c in condition_subtypes]),
                                                             ),
                                   alt.value('white'),
                                   )

# size, but scaled so a unit on x and y mean the same; note
# padding added here so sizes correct
# (`size` is pixels per MDS unit: width and height below are `size` times the
# data extents, and both axis domains are widened by the same fraction `pad`)
size = 180
pad = 0.04
x_extent = mds_coords['x'].max() - mds_coords['x'].min()
y_extent = mds_coords['y'].max() - mds_coords['y'].min()
y_min = mds_coords['y'].min() - pad * y_extent
y_max = mds_coords['y'].max() + pad * y_extent
x_min = mds_coords['x'].min() - pad * x_extent
x_max = mds_coords['x'].max() + pad * x_extent
circle_size = 110

# scatter plot of the MDS layout, one circle per condition
mds_plot = (
    alt.Chart(mds_coords)
    .encode(x=alt.X('x:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(x_min, x_max),
                                    ),
                    axis=alt.Axis(labels=False),
                    ),
            y=alt.Y('y:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(y_min, y_max),
                                    ),
                    axis=alt.Axis(labels=False),
                    ),
            # the hovered point is fully opaque with a black outline
            opacity=alt.condition(~highlight_antibody, alt.value(0.75), alt.value(1)),
            stroke=alt.condition(~highlight_antibody, alt.value(None), alt.value('black')),
            color=condition_type_color,
            tooltip=['condition'])
    .mark_circle(size=circle_size)
    .properties(width=size * x_extent,
                height=size * y_extent)
    .add_selection(seed_selection,
                   metric_selection,
                   highlight_antibody,
                   select_condition_type,
                   )
    # show only the layout for the chosen metric and seed, and only the
    # condition categories currently selected
    .transform_filter(metric_selection)
    .transform_filter(seed_selection)
    .transform_filter(select_condition_type)
    )

# hand-built clickable legend: one circle per (type, subtype) category
legend_condition_type = (
    alt.Chart(mds_coords)
    .mark_circle(size=0.7 * circle_size,
                 stroke='black',
                 strokeWidth=1)
    .encode(x=alt.X('condition_type:N'),
            y=alt.Y('condition_subtype:N',
                    sort=condition_subtypes,
                    ),
            color=condition_type_color,
            )
    .add_selection(select_condition_type)
    .properties(title='shift-click to select multiple categories')
    )

# place the legend to the left of the MDS plot and apply shared styling
mds_plot_w_legend = (
    (legend_condition_type | mds_plot)
    .configure_axis(grid=False,
                    ticks=False,
                    title=None,
                    labelFontSize=12,
                    )
    .configure_view(stroke='black')
    .configure_title(fontSize=10,
                     fontWeight='normal',
                     fontStyle='italic',
                     anchor='end')
    )

mds_plot_w_legend

Next make line plots:

In [None]:
# width in pixels shared by the zoom bar and the escape plot
width = 800

# build zoom bar to zoom in on sites
# (an interval brush along x; the escape plot below is filtered by it)
zoom_brush = alt.selection_interval(
                encodings=['x'],
                mark=alt.BrushConfig(stroke='black', strokeWidth=2))
zoom_bar = (
    alt.Chart(dms_data_tidy)
    .mark_rect(color='lightgray')
    .encode(x='site:O')
    .add_selection(zoom_brush)
    .properties(width=width,
                height=15,
                title='zoom bar')
    )

# build drop down menu to select y-axis on escape plot
# https://github.com/altair-viz/altair/issues/965
y_axis_select_binding = alt.binding_select(options=dms_data_tidy['metric'].unique())
y_axis_selection = alt.selection_single(name='y-axis antibody escape',
                                        fields=['metric'],
                                        bind=y_axis_select_binding)

# build drop down menu to select condition types on escape plot
condition_types = dms_data_tidy['condition_type'].unique().tolist()
# NOTE(review): `condition_subtypes` is rebound here (unsorted) but is not
# used again in this cell; it replaces the sorted list built for the MDS plot
condition_subtypes = dms_data_tidy['condition_subtype'].unique().tolist()
# include option for "all" as here: https://stackoverflow.com/a/62557828
# set initial value as here: https://github.com/altair-viz/altair/issues/1121#issuecomment-496017452
condition_type_binding = alt.binding_select(options=condition_types + [None],
                                            labels=condition_types + ['all'])
condition_type_selection = alt.selection_single(name='data to show',
                                                fields=['condition_type'],
                                                bind=condition_type_binding,
                                                init={'condition_type': 'antibody'},
                                                )

# selector to highlight specific conditions on escape plots
highlight_escape = (  # https://altair-viz.github.io/gallery/multiline_highlight.html
    alt.selection(type='single',
                  on='click',
                  fields=['condition'],
                  nearest=True)
    )

# the escape plots
# shared encodings: `detail` gives one line per condition without mapping
# condition to a visual channel
escape_plot_base = (
    alt.Chart(dms_data_tidy)
    .encode(x='site:O',
            y='antibody escape:Q',
            detail='condition:N',  # https://github.com/altair-viz/altair/issues/985
            )
    )

escape_plot = (
    # the clicked condition is drawn thicker, black, and fully opaque; all
    # other conditions are thin, gray, and semi-transparent
    (escape_plot_base.mark_line(interpolate='step')
                     .encode(size=alt.condition(~highlight_escape, alt.value(1), alt.value(2)),
                             color=alt.condition(~highlight_escape, alt.value('gray'), alt.value('black')),
                             opacity=alt.condition(~highlight_escape, alt.value(0.4), alt.value(1))
                             ) +
     # don't understand need for this dummy mark circle plot (points never show, opacity = 0),
     # but used in example: https://altair-viz.github.io/gallery/multiline_highlight.html
     (escape_plot_base
      .mark_circle()
      .encode(opacity=alt.value(0))
      .add_selection(highlight_escape)
      )
     )
    .interactive(bind_y=False)  # https://github.com/altair-viz/altair/issues/1512#issuecomment-691720690
    .add_selection(y_axis_selection,
                   condition_type_selection,
                   )
    # keep only the chosen metric and condition type, and only the sites
    # inside the zoom brush
    .transform_filter(y_axis_selection)
    .transform_filter(condition_type_selection)
    .transform_filter(zoom_brush)
    .properties(width=width,
                height=200) 
    )

# combine all the elements into a chart
chart = (
    alt.vconcat(zoom_bar, escape_plot)
    .configure_axis(grid=False)
    .configure_view(strokeWidth=0)
    )

# show the chart
chart

Not exactly what I want: https://stackoverflow.com/questions/61364509/altair-cant-create-combination-of-selections

Double selections: https://stackoverflow.com/questions/59982370/altair-double-dropdown-menu

Currently dealing with multiple selections; see this bug: https://github.com/altair-viz/altair/issues/1759

In [None]:
# scratch check: list the distinct condition subtypes present in the data
dms_data_tidy['condition_subtype'].unique().tolist()

In [None]:
# scratch check: show the condition types offered in the escape-plot dropdown
condition_types

In [None]:
# Scratch example unrelated to the DMS data: a chart built on a vega_datasets
# sample, used to try out a multi-selection bound to the color legend.
from vega_datasets import data

source = data.unemployment_across_industries.url

# selecting entries of the 'series' legend drives the selection
selection = alt.selection_multi(fields=['series'], bind='legend')

alt.Chart(source).mark_area().encode(
    alt.X('yearmonth(date):T', axis=alt.Axis(domain=False, format='%Y', tickSize=0)),
    alt.Y('sum(count):Q', stack='center', axis=None),
    alt.Color('series:N', scale=alt.Scale(scheme='category20b')),
    # selected series stay opaque; the rest are faded
    opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
).add_selection(
    selection
)