# Plot yeast RBD DMS escape maps

## Import modules and read data
Import Python modules:

In [1]:
import itertools
import os

import altair as alt

import numpy

import pandas as pd

import sklearn.manifold

Disable max rows specifier for Altair:

In [2]:
# Turn off Altair's max-rows check so charts can be built from large data frames.
# The return value is bound to `_`, presumably so the notebook does not echo it.
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data, reduce to site-level data:

In [3]:
# Read the raw table and keep only the site-level escape value per condition.
# Rows whose escape is exactly 0 are dropped here; missing sites are filled
# back in with 0 when the data are pivoted for the similarity calculation.
dms_data = pd.read_csv('./processed_data/escape_data.csv', low_memory=False)
dms_data = (
    dms_data
    .assign(condition_alias=dms_data['condition_alias'].fillna(''))
    .rename(columns={"site_mean_escape": "escape"})
    .loc[
        lambda df: df["escape"] != 0,
        ['condition', 'condition_alias', 'condition_type', 'condition_subtype',
         'eliciting_virus', 'known_to_neutralize', 'study', 'lab', 'site', "escape"],
    ]
    .drop_duplicates()
)

# every integer site from the smallest to the largest one present in the data
sites = [*range(dms_data['site'].min(), dms_data['site'].max() + 1)]

# A condition name that occurs in more than one study is ambiguous, so append
# the lab in parentheses; names found in exactly one study are left unchanged.
dms_data = dms_data.assign(
    condition=dms_data['condition'].where(
        dms_data.groupby('condition')['study'].transform('nunique') == 1,
        dms_data['condition'] + ' (' + dms_data['lab'] + ')',
    )
)

# A condition can still map to several studies after the lab suffix is added.
# Keep, for each condition, the study that sorts last by the number of
# ";"-separated entries in `known_to_neutralize`.
print(f"Before de-duplicating we have {len(dms_data.groupby(['condition', 'study']))} conditions")
dms_data = pd.merge(
    (
        dms_data
        .assign(n_known_to_neutralize=dms_data["known_to_neutralize"].str.count(";") + 1)
        .sort_values("n_known_to_neutralize")
        .groupby("condition", as_index=False)
        .agg({"study": "last"})
    ),
    # merging back on the shared columns (condition, study) keeps only the chosen study's rows
    dms_data,
)
print(f"After de-duplicating we have {len(dms_data.groupby(['condition', 'study']))} conditions")

# each (condition, site) pair must now occur exactly once
assert len(dms_data.groupby(['condition', 'site'])) == len(dms_data)

# Both columns hold ";"-separated strings; replace each with a tuple of its parts.
dms_data = dms_data.assign(**{
    column: dms_data[column].str.split(";").map(tuple)
    for column in ("known_to_neutralize", "eliciting_virus")
})

dms_data

Before de-duplicating we have 1800 conditions
After de-duplicating we have 1622 conditions


Unnamed: 0,condition,study,condition_alias,condition_type,condition_subtype,eliciting_virus,known_to_neutralize,lab,site,escape
0,1-57 (Xie_XS),2022_Cao_Omicron,,antibody,class 3,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Xie_XS,338,0.059400
1,1-57 (Xie_XS),2022_Cao_Omicron,,antibody,class 3,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Xie_XS,370,0.042170
2,1-57 (Xie_XS),2022_Cao_Omicron,,antibody,class 3,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Xie_XS,396,0.028930
3,1-57 (Xie_XS),2022_Cao_Omicron,,antibody,class 3,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Xie_XS,444,0.048420
4,1-57 (Xie_XS),2022_Cao_Omicron,,antibody,class 3,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Xie_XS,445,0.148500
...,...,...,...,...,...,...,...,...,...,...
46786,subject K (day 29),2021_Greaney_HAARVI_sera,,serum,convalescent serum,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Bloom_JD,527,0.003514
46787,subject K (day 29),2021_Greaney_HAARVI_sera,,serum,convalescent serum,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Bloom_JD,528,0.006132
46788,subject K (day 29),2021_Greaney_HAARVI_sera,,serum,convalescent serum,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Bloom_JD,529,0.019900
46789,subject K (day 29),2021_Greaney_HAARVI_sera,,serum,convalescent serum,"(SARS-CoV-2,)","(Wuhan-Hu-1,)",Bloom_JD,530,0.014330


Get a data frame with just the conditions and their citations:

In [4]:
# One row per condition with its descriptive columns, ordered by type,
# subtype and name, and re-indexed from 0.
conditions_df = (
    dms_data
    .loc[:, ['condition_type', 'condition_subtype', 'condition', 'condition_alias',
             'eliciting_virus', 'study', 'lab', 'known_to_neutralize']]
    .sort_values(by=['condition_type', 'condition_subtype', 'condition'])
    .drop_duplicates(ignore_index=True)
)

conditions_df

Unnamed: 0,condition_type,condition_subtype,condition,condition_alias,eliciting_virus,study,lab,known_to_neutralize
0,antibody,class 1,15033 (Xie_XS),,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)"
1,antibody,class 1,2H2,,"(SARS-CoV-2,)",2022_Cao_BA2-4-5,Xie_XS,"(Wuhan-Hu-1, Omicron BA.12.1)"
2,antibody,class 1,B38 (Xie_XS),,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)"
3,antibody,class 1,BD-236,,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)"
4,antibody,class 1,BD-319,,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)"
...,...,...,...,...,...,...,...,...
1617,serum,convalescent serum,subject I (day 26),,"(SARS-CoV-2,)",2021_Greaney_HAARVI_sera,Bloom_JD,"(Wuhan-Hu-1,)"
1618,serum,convalescent serum,subject J (day 121),,"(SARS-CoV-2,)",2021_Greaney_HAARVI_sera,Bloom_JD,"(Wuhan-Hu-1,)"
1619,serum,convalescent serum,subject J (day 15),,"(SARS-CoV-2,)",2021_Greaney_HAARVI_sera,Bloom_JD,"(Wuhan-Hu-1,)"
1620,serum,convalescent serum,subject K (day 103),,"(SARS-CoV-2,)",2021_Greaney_HAARVI_sera,Bloom_JD,"(Wuhan-Hu-1,)"


## Perform multidimensional scaling
Steps:
 1. Calculate similarities between escape maps for each antibody.
 2. Convert similarities to dissimilarities.
 3. Do multi-dimensional scaling on dissimilarities.


First, compute the dissimilarity between all pairs of escape profiles in a data frame.
We calculate similarity as the dot product of the escape profile site-level metric for each pair of conditions, normalizing each profile so its dot product with itself is one.
Then we compute the dissimilarity as just one minus the similarity:

In [5]:
def escape_similarity(df):
    """Compute similarity between all pairs of conditions in `df`.

    Each condition's site-level escape values form a vector (sites without a
    value count as 0).  Every vector is scaled to unit Euclidean length, and
    the similarity of two conditions is the dot product of their scaled
    vectors, so a condition's similarity with itself is 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'condition', 'site' and 'escape', with no
        null values in any of them.

    Returns
    -------
    pandas.DataFrame
        Square data frame whose index and columns are both the conditions,
        holding the pairwise similarities.

    Raises
    ------
    AssertionError
        If any of the three columns contains a null value.
    """
    df = df[['condition', 'site', 'escape']].drop_duplicates()
    assert not df.isnull().any().any(), df

    pivoted_df = df.pivot_table(index='site',
                                columns='condition',
                                values='escape',
                                fill_value=0)
    conditions = pivoted_df.columns.tolist()

    # one row per condition, one column per site
    profiles = pivoted_df.to_numpy(dtype=float).T
    # scale each profile to unit length; an all-zero profile yields NaN
    # rather than raising, so suppress the floating-point warnings here
    with numpy.errstate(divide='ignore', invalid='ignore'):
        profiles = profiles / numpy.linalg.norm(profiles, axis=1, keepdims=True)

    # a single matrix product gives every pairwise dot product at once,
    # instead of looping over all pairs of conditions in Python
    return pd.DataFrame(profiles @ profiles.T,
                        columns=conditions, index=conditions)

similarities = escape_similarity(dms_data)

# every pairwise similarity must be defined; `.all().all()` checks all cells,
# whereas `.any().any()` would pass as long as a single cell is non-null
assert similarities.notnull().all().all()

# dissimilarity is one minus similarity, floored at 0
dissimilarities = (1 - similarities).clip(lower=0)

dissimilarities.round(3)

Unnamed: 0,1-57 (Xie_XS),15033 (Xie_XS),2-15 (Xie_XS),2H2,3C1,ADG-2,B38 (Xie_XS),BD-236,BD-254,BD-319,...,subject G (day 18),subject G (day 94),subject H (day 152),subject H (day 61),subject I (day 102),subject I (day 26),subject J (day 121),subject J (day 15),subject K (day 103),subject K (day 29)
1-57 (Xie_XS),0.000,1.000,1.000,1.000,0.999,0.934,1.000,1.000,1.000,1.000,...,0.377,0.811,0.653,0.750,0.714,0.737,0.708,0.915,0.955,0.971
15033 (Xie_XS),1.000,0.000,0.910,0.945,1.000,0.990,0.736,0.886,0.854,0.247,...,0.965,0.785,0.779,0.814,0.582,0.549,0.789,0.937,0.933,0.969
2-15 (Xie_XS),1.000,0.910,0.000,0.819,1.000,0.928,1.000,0.984,0.454,0.863,...,0.835,0.833,0.646,0.750,0.427,0.495,0.838,0.929,0.901,0.959
2H2,1.000,0.945,0.819,0.000,0.971,0.985,0.761,0.523,0.630,0.845,...,0.766,0.758,0.716,0.708,0.468,0.479,0.726,0.794,0.572,0.893
3C1,0.999,1.000,1.000,0.971,0.000,0.293,1.000,1.000,1.000,1.000,...,0.977,0.887,0.973,0.971,0.984,0.980,0.900,0.956,0.965,0.958
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
subject I (day 26),0.737,0.549,0.495,0.479,0.980,0.916,0.626,0.548,0.500,0.506,...,0.427,0.485,0.296,0.336,0.028,0.000,0.408,0.592,0.491,0.687
subject J (day 121),0.708,0.789,0.838,0.726,0.900,0.825,0.686,0.704,0.840,0.770,...,0.436,0.041,0.292,0.270,0.425,0.408,0.000,0.460,0.492,0.622
subject J (day 15),0.915,0.937,0.929,0.794,0.956,0.948,0.876,0.845,0.932,0.936,...,0.473,0.476,0.400,0.206,0.643,0.592,0.460,0.000,0.101,0.083
subject K (day 103),0.955,0.933,0.901,0.572,0.965,0.952,0.780,0.698,0.905,0.936,...,0.479,0.540,0.385,0.203,0.542,0.491,0.492,0.101,0.000,0.098


Now do the multidimensional scaling [as described here](https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html#sphx-glr-auto-examples-manifold-plot-mds-py) to get the x and y coordinates for each antibody / serum.
For each metric, we do this for several random number seeds (different seeds will give different MDS layouts):

In [6]:
# Run metric MDS on the precomputed dissimilarities once per random seed,
# shift each layout so its smallest x and smallest y are 0, then stack the
# layouts and attach the per-condition annotations.
seeds = [1, 2]
mds_coords = []
for seed in seeds:
    mds = sklearn.manifold.MDS(
        n_components=2,
        metric=True,
        max_iter=3000,
        eps=1e-6,
        random_state=seed,
        dissimilarity='precomputed',
        n_jobs=1,
    )
    locs = mds.fit_transform(dissimilarities)
    mds_coords.append(
        # subtracting the column-wise minimum moves the layout to start at (0, 0)
        pd.DataFrame(locs - locs.min(axis=0), columns=['x', 'y'])
        .assign(seed=seed, condition=dissimilarities.columns)
    )

mds_coords = pd.merge(
    pd.concat(mds_coords, ignore_index=True),
    conditions_df,
    on='condition',
    how='left',
    validate="many_to_one",
).drop(columns=["study", "condition_alias", "condition_type"])

mds_coords.round(3)

Unnamed: 0,x,y,seed,condition,condition_subtype,eliciting_virus,lab,known_to_neutralize
0,0.870,1.517,1,1-57 (Xie_XS),class 3,"(SARS-CoV-2,)",Xie_XS,"(Wuhan-Hu-1,)"
1,0.168,1.134,1,15033 (Xie_XS),class 1,"(SARS-CoV-2,)",Xie_XS,"(Wuhan-Hu-1,)"
2,0.597,1.372,1,2-15 (Xie_XS),class 2,"(SARS-CoV-2,)",Xie_XS,"(Wuhan-Hu-1,)"
3,0.514,0.958,1,2H2,class 1,"(SARS-CoV-2,)",Xie_XS,"(Wuhan-Hu-1, Omicron BA.12.1)"
4,0.216,0.330,1,3C1,class 4,"(SARS-CoV-2,)",Xie_XS,"(Wuhan-Hu-1,)"
...,...,...,...,...,...,...,...,...
3239,1.191,0.779,2,subject I (day 26),convalescent serum,"(SARS-CoV-2,)",Bloom_JD,"(Wuhan-Hu-1,)"
3240,0.865,0.784,2,subject J (day 121),convalescent serum,"(SARS-CoV-2,)",Bloom_JD,"(Wuhan-Hu-1,)"
3241,0.635,0.887,2,subject J (day 15),convalescent serum,"(SARS-CoV-2,)",Bloom_JD,"(Wuhan-Hu-1,)"
3242,0.823,0.878,2,subject K (day 103),convalescent serum,"(SARS-CoV-2,)",Bloom_JD,"(Wuhan-Hu-1,)"


## Read information on studies and merge into conditions data frame

In [7]:
# Table of studies; it is merged into `conditions_df` on the `study` column
# to supply each condition's `citation` and `url`.
studies = pd.read_csv('processed_data/studies.csv')

studies

Unnamed: 0,study,citation,url
0,2021_Dong_AZ,Dong et al. Nat Micro (2021),https://www.nature.com/articles/s41564-021-009...
1,2021_Greaney_Crowe_Abs,Greaney et al. Cell Host Microbe (2021a),https://www.sciencedirect.com/science/article/...
2,2021_Greaney_HAARVI_sera,Greaney et al. Cell Host Microbe (2021b),https://www.sciencedirect.com/science/article/...
3,2021_Greaney_COV2-2955,Greaney et al. NA (2021),https://github.com/jbloomlab/SARS-CoV-2-RBD_MA...
4,2021_Greaney_Rockefeller,Greaney et al. Nat Comm (2021),https://www.nature.com/articles/s41467-021-244...
5,2021_Greaney_Moderna,Greaney et al. Sci Transl Med (2021),https://stm.sciencemag.org/content/13/600/eabi...
6,2021_Starr_LY-CoV555,Starr et al. Cell Reports Medicine (2021),https://doi.org/10.1016/j.xcrm.2021.100255
7,2021_Starr_Vir,Starr et al. Nature (2021),https://www.nature.com/articles/s41586-021-038...
8,2021_Starr_REGN,Starr et al. Science (2021),https://science.sciencemag.org/content/early/2...
9,2021_Tortorici_S2X259,Tortorici et al. Nature (2021),https://www.nature.com/articles/s41586-021-038...


In [8]:
# Attach citation and URL to every condition.  Any citation/url columns left
# over from an earlier run of this cell are removed first, so re-running it
# does not produce suffixed duplicate columns.
conditions_df = pd.merge(
    conditions_df.drop(columns=['citation', 'url'], errors='ignore'),
    studies,
    how='left',
    on='study',
    validate='many_to_one',
)

conditions_df.head()

Unnamed: 0,condition_type,condition_subtype,condition,condition_alias,eliciting_virus,study,lab,known_to_neutralize,citation,url
0,antibody,class 1,15033 (Xie_XS),,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)",Cao et al. Nature (2022),https://www.nature.com/articles/s41586-021-043...
1,antibody,class 1,2H2,,"(SARS-CoV-2,)",2022_Cao_BA2-4-5,Xie_XS,"(Wuhan-Hu-1, Omicron BA.12.1)",Cao et al. bioRxiv (2022),https://www.biorxiv.org/content/10.1101/2022.0...
2,antibody,class 1,B38 (Xie_XS),,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)",Cao et al. Nature (2022),https://www.nature.com/articles/s41586-021-043...
3,antibody,class 1,BD-236,,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)",Cao et al. Nature (2022),https://www.nature.com/articles/s41586-021-043...
4,antibody,class 1,BD-319,,"(SARS-CoV-2,)",2022_Cao_Omicron,Xie_XS,"(Wuhan-Hu-1,)",Cao et al. Nature (2022),https://www.nature.com/articles/s41586-021-043...


## Make interactive plots
First make plot to select condition(s) to show:

In [9]:
# subtypes in order of first appearance in `conditions_df`
condition_subtypes = conditions_df['condition_subtype'].drop_duplicates().tolist()

# Colors per subtype, similar to the Greaney et al antibody class papers.
# Schemes: https://vega.github.io/vega/docs/schemes/
# Greens (used for the sera / plasma entries):
# https://www.rapidtables.com/web/color/green-color.html
condition_subtype_colors = {
    'class 1': '#E52794',
    'class 2': '#6A0DAD',
    'class 3': '#66CCEE',
    'class 4': '#E69F00',
    'convalescent serum': '#006400',
    'Moderna vaccine serum': '#98FB98',
    'B.1.351 convalescent plasma': '#808000',
}
# every subtype in the data needs a color before any chart is built
if any(subtype not in condition_subtype_colors for subtype in condition_subtypes):
    raise ValueError('missing colors for some condition subtypes')
# Point selection over `condition_subtype`, shared by the legend and the plots.
# Its initial value is the set of subtypes whose `condition_type` is "antibody".
select_condition_subtype = alt.selection_point(fields=['condition_subtype'],
                                               # initialize to show antibodies but not sera
                                               value=[{'condition_subtype': subtype} for subtype in
                                                      conditions_df.query('condition_type == "antibody"')
                                                      ['condition_subtype'].unique()],
                                               resolve='union',
                                               empty=True,
                                               )
# Color encoding reused by several charts: marks whose subtype is selected get
# that subtype's color from `condition_subtype_colors`; all others are white.
condition_subtype_color = alt.condition(select_condition_subtype,
                                   alt.Color('condition_subtype:N',
                                             legend=None,
                                             # domain and range are built from the same list so they stay aligned
                                             scale=alt.Scale(domain=condition_subtypes,
                                                             range=[condition_subtype_colors[c]
                                                                    for c in condition_subtypes]),
                                                             ),
                                   alt.value('white'),
                                   )

# mark size for the MDS circles; the legend circles below use 0.7 times this
circle_size = 110

# Clickable legend: one circle per (condition_type, condition_subtype) pair.
# The x-axis title is used to carry the usage instructions shown to the user.
legend_condition_type = (
    alt.Chart(conditions_df[['condition_type', 'condition_subtype']].drop_duplicates())
    .mark_circle(size=0.7 * circle_size,
                 stroke='black',
                 strokeWidth=1)
    .encode(x=alt.X('condition_type:N',
                    axis=alt.Axis(title=['',
                                         'On each subplot, you can:',
                                         ' - click to select one item',
                                         ' - shift-click to select additional items',
                                         ' - double-click to clear selected items',
                                         ' - mouseover to see antibody/serum name',
                                         ],
                                  titleAlign='left',
                                  titleFontSize=14,
                                  titleFontWeight='normal',
                                  titleFontStyle='italic',
                                  labelFontSize=12),
                    ),
            y=alt.Y('condition_subtype:N',
                    sort=condition_subtypes,
                    axis=alt.Axis(title=None,
                                  labelFontSize=12,
                                  orient='right'),
                    ),
            color=condition_subtype_color,
            )
    .add_parameter(select_condition_subtype)
    .properties(title={'text': 'choose antibody/serum types to display',
                       'align': 'left',
                       'anchor': 'start'})
    )

# display the legend on its own
(legend_condition_type).configure_view(strokeOpacity=0)

Encode the conditions as integers and then lookup details.
Needed to avoid some unclear problem when sorting:

In [10]:
# Give every condition an integer id equal to its row position.
encoded_conditions_df = conditions_df.drop(columns="condition_type").reset_index(drop=True)
encoded_conditions_df = encoded_conditions_df.assign(encoding=encoded_conditions_df.index)

# Single-column frame of ids; the charts start from this and look up the
# remaining per-condition columns by `encoding`.
condition_encodings = encoded_conditions_df.loc[:, ["encoding"]]
assert condition_encodings["encoding"].nunique() == len(condition_encodings)

Make plot:

In [29]:
# Drop-down to restrict to one eliciting virus ("all" maps to None);
# starts on SARS-CoV-2.
eliciting_viruses = sorted(dms_data["eliciting_virus"].explode().unique())
eliciting_virus_selection = alt.selection_point(
    fields=['eliciting_virus'],
    value=[{'eliciting_virus': 'SARS-CoV-2'}],
    bind=alt.binding_select(
        name="eliciting virus",
        options=[None, *eliciting_viruses],
        labels=['all', *eliciting_viruses],
    ),
)

# Drop-down to restrict to one lab; starts on Bloom_JD.
labs = sorted(dms_data['lab'].unique())
lab_selection = alt.selection_point(
    fields=['lab'],
    value=[{"lab": "Bloom_JD"}],
    bind=alt.binding_select(
        name="lab",
        options=[None, *labs],
        labels=['all', *labs],
    ),
)

# Drop-down to restrict to conditions known to neutralize one virus;
# no initial value is set.
known_to_neutralize_options = sorted(
    dms_data["known_to_neutralize"].explode().unique()
)
known_to_neutralize_selection = alt.selection_point(
    fields=['known_to_neutralize'],
    bind=alt.binding_select(
        name="known to neutralize",
        options=[None] + known_to_neutralize_options,
        labels=["all"] + known_to_neutralize_options,
    ),
)

# Click / shift-click selection of individual conditions.  It is initialised
# to the empty-string condition, presumably so nothing is highlighted at start.
highlight_condition = alt.selection_point(
    fields=['condition'],
    value=[{"condition": ""}],
    on='click',
    toggle=True,
    nearest=False,
    empty=False,
    resolve='union',
)

cell_height = 17  # size of cells in heat map

# Filtered per-condition data shared by the condition legend charts.
# Starting from the bare `encoding` ids, each tuple-valued column is looked
# up, flattened to one row per element, filtered by its drop-down, and then
# aggregated back to one row per `encoding` before the next lookup.
conditions_data = (
    alt.Chart(condition_encodings)
    # step 1: filter on `known_to_neutralize`
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=encoded_conditions_df,
            key="encoding",
            fields=["known_to_neutralize"],
        )
    )
    .transform_flatten(["known_to_neutralize"])
    .transform_filter(known_to_neutralize_selection)
    # grouping by `encoding` and taking its mean collapses back to one row per id
    .transform_aggregate(encoding="mean(encoding)", groupby=["encoding"])
    # step 2: filter on `eliciting_virus` in the same way
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=encoded_conditions_df,
            key="encoding",
            fields=["eliciting_virus"],
        )
    )
    .transform_flatten(["eliciting_virus"])
    .transform_filter(eliciting_virus_selection)
    .transform_aggregate(encoding="mean(encoding)", groupby=["encoding"])
    # step 3: bring in all remaining columns for the surviving ids
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=encoded_conditions_df,
            key="encoding",
            fields=[
                c for c in encoded_conditions_df.columns
                if c not in {"encoding", "known_to_neutralize", "eliciting_virus"}
            ],
        )
    )
    .transform_filter(select_condition_subtype)
    .transform_filter(lab_selection)
)

# build zoom bar to zoom in condition legend:
# an interval brush along y, drawn with a black outline
legend_condition_zoom_brush = alt.selection_interval(
                encodings=['y'],
                mark=alt.BrushConfig(stroke='black', strokeWidth=2),
                )
# narrow bar with one rect per condition (sorted by `encoding`, colored by
# subtype) on which the brush is drawn
legend_condition_zoom_bar = (
    conditions_data
    .mark_rect()
    .encode(y=alt.Y("condition:N",
                    title='antibody / sera zoom bar',
                    axis=alt.Axis(ticks=False,
                                  labels=False,
                                  titleFontSize=12),
                    scale=alt.Scale(nice=False, zero=False),
                    sort=alt.EncodingSortField("encoding"),
                    ),
            color=condition_subtype_color,
            )
    .add_parameter(legend_condition_zoom_brush)
    .properties(height=175, width=15)
    )

# Common base for the three legend columns (heat map, citation, alias):
# the filtered condition data, further restricted to the zoom-brush interval,
# with one `cell_height`-pixel step per condition.
condition_base = (
    conditions_data
    .add_parameter(select_condition_subtype,
                   highlight_condition,
                   known_to_neutralize_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   legend_condition_zoom_brush)
    .transform_filter(legend_condition_zoom_brush)
    .properties(height={'step': cell_height},
                width=cell_height,
                )
    )

# Column of colored boxes, one per condition; highlighted conditions get a
# thicker outline (3 vs 0.5).  The stroke color is black in both branches.
legend_condition_heatmap = (
    condition_base
    .encode(y=alt.Y('condition:N',
                    sort=alt.EncodingSortField("encoding"),
                    title=None,
                    axis=alt.Axis(orient='right',
                                  labelFontSize=11,
                                  ),
                    ),
            color=condition_subtype_color,
            strokeWidth=alt.condition(~highlight_condition,
                                      alt.value(0.5),
                                      alt.value(3)),
            stroke=alt.condition(~highlight_condition,
                                 alt.value('black'),
                                 alt.value('black')),
            )
    .mark_rect()
    )

# Text column showing each condition's citation, hyperlinked to its `url`;
# rows are sorted the same way as the heat map so they line up.
condition_citations = (
    condition_base
    .encode(y=alt.Y('condition:N',
                    sort=alt.EncodingSortField("encoding"),
                    title=None,
                    axis=None,
                    ),
            text='citation:N',
            href='url:N'
            )
    .mark_text(align='left',
               fontSize=11,
               fontStyle='normal',
               color='darkblue',
               )
    )

# Text column showing each condition's `condition_alias`, sorted like the
# heat map.
# NOTE(review): the mark sets a fixed text 'dms-view' while the encoding sets
# text from `condition_alias`; presumably the encoding wins - confirm which
# one is intended.
condition_alias = (
    condition_base
    .encode(y=alt.Y('condition:N',
                    sort=alt.EncodingSortField("encoding"),
                    title=None,
                    axis=None,
                    ),
            text='condition_alias:N',
            )
    .mark_text(text='dms-view',
               align='left',
               fontSize=11,
               fontStyle='normal',
               )
    )

# Full condition legend: the zoom bar on the left, then the heat map,
# citation and alias columns placed side by side with a small gap.
legend_condition = (
    (legend_condition_zoom_bar | alt.hconcat(legend_condition_heatmap,
                                             condition_citations,
                                             condition_alias,
                                             spacing=2)
     )
    # title text fixed: the original read "by by clicking box"
    .properties(title={'text': ['select antibody/serum by clicking box; shift-click',
                                'citation or dms-view text to open that information']})
    )

# display the type legend next to the condition legend
(legend_condition_type | legend_condition).configure_view(strokeOpacity=0)

Next make MDS plot:

In [30]:
# first add conditions encoding
mds_coords = (
    mds_coords
    .drop(columns="encoding", errors="ignore")
    .merge(
        encoded_conditions_df[["condition", "encoding"]],
        validate="many_to_one",
        how="outer",
    )
)
assert mds_coords.notnull().all().all()

In [None]:
# Drop-down choosing which random-seed MDS layout to show (starts on seed 2).
# The parameter gets an identifier-safe name because Vega-Lite parameter
# names should be valid JavaScript identifiers (no spaces); the readable
# label goes on the binding's `name`, as for the other drop-downs.
seed_select_binding = alt.binding_select(options=mds_coords['seed'].unique(),
                                         name='multidimensional scaling random seed',
                                         )
seed_selection = alt.selection_point(name='mds_seed',
                                     fields=['seed'],
                                     bind=seed_select_binding,
                                     value=[{'seed': 2}],
                                     )

# Pixels per MDS unit, the same for x and y so distances look alike in both
# directions.  The axis domains are widened by `pad` times the extent on each
# side so marks at the edges are not clipped.
size = 180
pad = 0.04
x_min, x_max = mds_coords['x'].min(), mds_coords['x'].max()
y_min, y_max = mds_coords['y'].min(), mds_coords['y'].max()
x_extent = x_max - x_min
y_extent = y_max - y_min
x_min, x_max = x_min - pad * x_extent, x_max + pad * x_extent
y_min, y_max = y_min - pad * y_extent, y_max + pad * y_extent

# Scatter plot of the MDS layout.  The data pipeline mirrors `conditions_data`:
# look up a tuple column, flatten, filter by its drop-down, aggregate back to
# one row per `encoding`, and finally look up the remaining columns.
# NOTE(review): `mds_coords` has one row per (seed, condition), so `encoding`
# is not unique there although it is used as the lookup key; confirm that the
# lookup returns the row for the seed chosen in `seed_selection`.
mds_plot = (
    alt.Chart(condition_encodings)
    # filter on `known_to_neutralize`
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=mds_coords,
            key="encoding",
            fields=["known_to_neutralize"],
        )
    )
    .transform_flatten(["known_to_neutralize"])
    .transform_filter(known_to_neutralize_selection)
    .transform_aggregate(encoding="mean(encoding)", groupby=["encoding"])
    # filter on `eliciting_virus`
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=mds_coords,
            key="encoding",
            fields=["eliciting_virus"],
        )
    )
    .transform_flatten(["eliciting_virus"])
    .transform_filter(eliciting_virus_selection)
    .transform_aggregate(encoding="mean(encoding)", groupby=["encoding"])
    # bring in coordinates and the remaining annotation columns
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=mds_coords,
            key="encoding",
            fields=[
                c for c in mds_coords.columns
                if c not in {"encoding", "known_to_neutralize", "eliciting_virus"}
            ],
        )
    )
    .transform_filter(lab_selection)
    .transform_filter(seed_selection)
    .transform_filter(select_condition_subtype)
    # fixed, padded domains with all axis decoration switched off
    .encode(x=alt.X('x:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(x_min, x_max),
                                    ),
                    axis=alt.Axis(labels=False,
                                  title=None,
                                  ticks=False,
                                  grid=False,
                                  ),
                    ),
            y=alt.Y('y:Q',
                    scale=alt.Scale(padding=0,
                                    nice=False,
                                    domain=(y_min, y_max),
                                    ),
                    axis=alt.Axis(labels=False,
                                  title=None,
                                  ticks=False,
                                  grid=False,
                                  ),
                    ),
            # highlighted conditions are opaque with a black outline; others are faint
            opacity=alt.condition(~highlight_condition, alt.value(0.3), alt.value(1)),
            stroke=alt.condition(~highlight_condition, alt.value(None), alt.value('black')),
            color=condition_subtype_color,
            tooltip=['condition:N'])
    .mark_circle(size=circle_size)
    # width and height are proportional to the data extents so a unit means the same on both axes
    .properties(width=size * x_extent,
                height=size * y_extent,
                title={'text': 'multidimensional scaling of antibodies/sera',
                       'subtitle': ['antibodies/sera with escape mutations at similar',
                                    'sites are positioned nearby in the plot below'],
                       'anchor': 'start',
                       'align': 'left',
                       }
                )
    .add_parameter(seed_selection,
                   highlight_condition,
                   select_condition_subtype,
                   known_to_neutralize_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   )
    )

# box around MDS plot: https://stackoverflow.com/a/62862229/4191652
# Two thin black line charts are drawn along the top and right edges of the
# padded domain, using the same scales and hidden axes as `mds_plot`, and
# then layered onto it.
dummy_lines = {}
for key, x, y in [('top', (x_min, x_max), (y_max, y_max)),
                  ('right', (x_max, x_max), (y_min, y_max)),
                  ]:
    dummy_lines[key] = (
        # two-point data frame holding the end points of this edge
        alt.Chart(pd.DataFrame({'x': x,
                                'y': y})
                  )
        .mark_line(color='black',
                   strokeWidth=0.5)
        .encode(x=alt.X('x:Q',
                        scale=alt.Scale(padding=0,
                                        nice=False,
                                        domain=(x_min, x_max),
                                        ),
                        axis=alt.Axis(labels=False,
                                      title=None,
                                      ticks=False,
                                      grid=False,
                                      ),
                        ),
                y=alt.Y('y:Q',
                        scale=alt.Scale(padding=0,
                                        nice=False,
                                        domain=(y_min, y_max),
                                        ),
                        axis=alt.Axis(labels=False,
                                      title=None,
                                      ticks=False,
                                      grid=False,
                                      ),
                        )
                )
        )
# `mds_plot` becomes a layered chart: the scatter plus the two edge lines
mds_plot = mds_plot + dummy_lines['top'] + dummy_lines['right']

# show the plot with legend: type legend, MDS plot and condition legend side by side
# NOTE(review): `configure_view` is called twice; the second call may replace
# the first rather than add to it - confirm whether stroke='black' is needed.
(
    (legend_condition_type | mds_plot | legend_condition)
    .configure_view(stroke='black')
    .configure_view(strokeOpacity=0)
)

Next make line plots.
First, encode everything other than the actual site / escape values as an integer that we can use in a lookup transform to recover the condition (antibody/sera) level values.
This dramatically shrinks size of the data:

In [None]:
# NOTE(review): `dms_data_tidy` is not defined in any cell visible in this
# notebook; confirm where it comes from (or whether `dms_data` is meant).

# every column except the per-site values is treated as condition-level
encoded_cols = [col for col in dms_data_tidy.columns if col not in {"site", "escape"}]

# one row per distinct combination of condition-level values, numbered from 0
encoding = (
    dms_data_tidy[encoded_cols]
    .drop_duplicates()
    .assign(encoding=lambda x: x.reset_index().index)
)

# slim table holding only the integer id plus site and escape
dms_data_tidy_encoded = dms_data_tidy.merge(encoding)[["encoding", "site", "escape"]]

display(encoding.head())
display(dms_data_tidy_encoded.head())

Now make plot:

In [None]:
# width in pixels shared by the zoom bar and the escape plots
width = 800

# build zoom bar to zoom in on sites:
# an interval brush along x, drawn with a black outline
zoom_brush = alt.selection_interval(
                encodings=['x'],
                mark=alt.BrushConfig(stroke='black', strokeWidth=2))
# thin gray bar with one rect per distinct site, on which the brush is drawn
zoom_bar = (
    alt.Chart(dms_data_tidy_encoded[['site']].drop_duplicates())
    .mark_rect(color='lightgray')
    .encode(x=alt.X('site:Q',
                    title=None,
                    scale=alt.Scale(nice=False, zero=False),
                    ),
            )
    .add_parameter(zoom_brush)
    .properties(width=width,
                height=15,
                title='site zoom bar')
    )

# build base for escape plots: the slim (encoding, site, escape) table with
# the condition-level columns looked up by `encoding`, then filtered by every
# drop-down, the subtype legend and the site zoom brush.
# NOTE(review): `metric_selection` is not defined in any cell visible in this
# notebook - confirm where it comes from.
# NOTE(review): only `known_to_neutralize` is flattened here, although
# `eliciting_virus` is also filtered on - confirm that is intended.
escape_base = (
    alt.Chart(dms_data_tidy_encoded)
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=encoded_cols)
    )
    .transform_flatten(["known_to_neutralize"])
    # constant field `mean_over` = 1 on every row, matched against the radio-button selection below
    .transform_calculate(mean_over="1")
    .encode(x=alt.X('site:Q',
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(nice=False, zero=False)
                    ),
            )
    .transform_filter(eliciting_virus_selection)
    .transform_filter(known_to_neutralize_selection)
    .transform_filter(lab_selection)
    .transform_filter(metric_selection)
    .transform_filter(select_condition_subtype)
    .transform_filter(zoom_brush)
    .properties(width=width,
                height=200,
                )
    )

# the escape line plot: one line per condition; highlighted conditions are
# drawn thicker and fully opaque
escape_lines = (
    escape_base
    .encode(size=alt.condition(~highlight_condition, alt.value(0.9), alt.value(1.5)),
            opacity=alt.condition(~highlight_condition, alt.value(0.4), alt.value(1)),
            )
    .add_parameter(known_to_neutralize_selection,
                   eliciting_virus_selection,
                   lab_selection,
                   metric_selection,
                   select_condition_subtype,
                   zoom_brush,
                   )
    .mark_line()
    )

# escape point plot: points are shown only for highlighted conditions
escape_points = (
    escape_base
    .encode(fill=condition_subtype_color,
            tooltip=['condition:N', 'site:Q'],
            )
    .mark_point(size=40)
    .transform_filter(highlight_condition)
    # needs to be add_parameter within chart: https://github.com/altair-viz/altair/issues/2368#issuecomment-742377146
    .add_parameter(highlight_condition)
    )

# combine point and line plots; the y, color and detail encodings are set on
# the layered chart so both layers share them
escape_lines_points = (
    (escape_lines + escape_points)
    .encode(detail='condition:N',  # https://github.com/altair-viz/altair/issues/985
            color=condition_subtype_color,
            y=alt.Y('escape:Q',
                    axis=alt.Axis(grid=False),
                    ),
            )
    .properties(title={'text': 'escape from individual antibodies/sera'})
    )

# radio button to specify if mean for only selected antibodies or all antibody/serum types.
# Every row has mean_over == 1, so the option 1 matches all rows and 0 matches none.
mean_radio = alt.binding_radio(
    options=[1, 0],
    labels=["all displayed types", "just selected antibodies/sera"],
)
mean_selection = alt.selection_point(fields=['mean_over'],
                                     bind=mean_radio,
                                     name='calculate',
                                     value=[{'mean_over': 1}])
# plot of mean values: mean escape per site over the rows that are either
# highlighted, or of a selected subtype while the radio button is on
# "all displayed types"
escape_mean = (
    escape_base
    .mark_line(color='darkgray',
               point={'color': 'darkgray',
                      'size': 60},
               )
    .encode(tooltip=['site:Q',
                     alt.Tooltip('mean(escape):Q',
                                 format='.2g',
                                 title='escape'),
                     ],
            y=alt.Y('mean(escape):Q',
                    axis=alt.Axis(grid=False,
                                  title='escape',
                                  ),
                    ),
            )
    .transform_filter(highlight_condition | (select_condition_subtype & mean_selection))
    .add_parameter(highlight_condition,
                   mean_selection,
                   )
    .properties(title={'text': 'mean escape over selected antibodies/sera or ' +
                               'antibody/serum types (choose with radio button below)'
                       })
    )

# combine zoom bar, lines, and points: the zoom bar is stacked above the
# per-condition plot and the mean plot, which share their x scale
escape_plot = (zoom_bar & (escape_lines_points & escape_mean).resolve_scale(x='shared'))

escape_plot

Now combine the antibody MDS and escape plots:

In [None]:
# Final layout: type legend and MDS plot on top, escape plots beneath them,
# and the condition legend to the right of both.
chart = (
    (((legend_condition_type | mds_plot) & escape_plot) | legend_condition)
    .configure(padding={'left': 5,
                        'right': 60,
                        'top': 5,
                        'bottom': 5})
    .configure_view(strokeOpacity=0)
    )

# write the chart as HTML, creating the output directory if needed
chartfile = 'docs/_includes/chart.html'
os.makedirs(os.path.dirname(chartfile), exist_ok=True)
print(f"Saving chart to {chartfile}")
chart.save(chartfile)

chart