# Plot yeast RBD DMS escape calculator

## Import modules and read data
Import Python modules:

In [1]:
import itertools
import os
import re

import altair as alt

import numpy

import pandas as pd

Disable max rows specifier for Altair:

In [2]:
_ = alt.data_transformers.disable_max_rows()

Read the deep mutational scanning data, reduce data suitable for plotting.
Also write the mutation- and site-level data to a file.
Note that these data are from [a single study by Cao et al](https://www.biorxiv.org/content/10.1101/2022.09.15.507787v1), and we classify each antibody as neutralizing or not against each virus:

In [3]:
escape_data = (
    pd.read_csv("data/2022_Cao_convergent/use_res_clean.csv")
    .rename(columns={"antibody": "condition"})
    .query("mut_escape != 0")
)

ic50_ceil = 10  # max IC50
antibody_data = (
    pd.read_csv("data/2022_Cao_convergent/antibody_info.csv")
    .rename(columns={"Antibody  Name": "condition"})
    .melt(
        id_vars=["condition", "source"],
        value_vars=["D614G", "BA.1", "BA.2", "BA.2.75", "BA.5"],
        var_name="target_virus",
        value_name="IC50",
    )
    .query("IC50 != '--'")  # neut data not available for this antibody
    .assign(
        IC50=lambda x: x["IC50"].map(
            lambda ic: ic50_ceil if ic == f">{ic50_ceil}" else float(ic)
        ),
        neg_log_IC50=lambda x: -numpy.log(x["IC50"] / ic50_ceil),
    )
)
# check to make sue largest IC50 is indeed <= the ceiling
all_ic50s = antibody_data["IC50"]
assert pd.notnull(all_ic50s).all()
max_ic50 = max(all_ic50s[pd.notnull(all_ic50s)])
assert max_ic50 <= ic50_ceil and numpy.allclose(max_ic50, ic50_ceil, rtol=1e-3)

# Now merge antibody and escape data
n_missing_antibody = len(set(escape_data["condition"]) - set(antibody_data["condition"]))
n_missing_escape = len(set(antibody_data["condition"]) - set(escape_data["condition"]))
n_both = len(set(antibody_data["condition"]).union(escape_data["condition"]))
print(f"There are {n_missing_antibody} antibodies with escape but not antibody data")
print(f"There are {n_missing_escape} antibodies with antibody but not escape data")
print(f"There are {n_both} antibodies with both antibody and escape data")

# first, merge mutation level
dms_data_all_mut = (
    antibody_data
    .merge(escape_data, on="condition")
)
all_mut_file = "processed_data/escape_data_mutation.csv"
print(f"Writing mutation-level data to {all_mut_file}")
dms_data_all_mut.to_csv(all_mut_file, index=False, float_format="%.4g")

# now collapse to site level
assert dms_data_all_mut["condition"].nunique() == len(dms_data_all_mut.groupby(["condition", "source"]))
dms_data_all_site = (
    dms_data_all_mut
    .groupby(
        ["condition", "site", "target_virus", "IC50", "neg_log_IC50", "group", "source"],
        as_index=False,
    )
    .aggregate(escape=pd.NamedAgg("mut_escape", "sum"))
)
all_site_file = "processed_data/escape_data_site.csv"
print(f"Writing site-level data to {all_mut_file}")
dms_data_all_site.to_csv(all_site_file, index=False, float_format="%.4g")

# now get DMS data used by calculator
def eliciting_virus(source):
    """Assign eliciting virus tuple from source."""
    if source == "SARS convalescents":
        return ("SARS-CoV-1 then SARS-CoV-2",)
    elif source.startswith("WT"):
        return ("SARS-CoV-2", "pre-Omicron SARS-CoV-2")
    else:
        m = re.fullmatch("(?P<omicron>BA\.\d) convalescents", source)
        assert m, source
        omicron = m.group("omicron")
        return ("SARS-CoV-2", f"pre-Omicron SARS-CoV-2 then Omicron {omicron}") 
    
dms_data = (
    dms_data_all_site
    .query("neg_log_IC50 > 0")  # only keep neutralizing ones for this target virus
    .assign(
        known_to_neutralize=lambda x: x["target_virus"].map(
            {
                "D614G": "Wuhan-Hu-1",
                "BA.1": "Omicron BA.1",
                "BA.2": "Omicron BA.2",
                "BA.2.75": "Omicron BA.2.75",
                "BA.5": "Omicron BA.5",
            }
        )
    )
    .groupby(["condition", "site", "escape", "source"], as_index=False)
    .aggregate(
        neg_log_IC50=pd.NamedAgg("neg_log_IC50", lambda s: tuple(round(s, 4) for s in [max(s), *s])),
        known_to_neutralize=pd.NamedAgg("known_to_neutralize", lambda s: tuple(["any", *s])),
    )
    .assign(eliciting_virus=lambda x: x["source"].map(eliciting_virus))
    .drop(columns="source")
)

print(f"\nCounts for different eliciting viruses:")
display(dms_data.groupby("eliciting_virus").aggregate(n=pd.NamedAgg("condition", "nunique")))

There are 4 antibodies with escape but not antibody data
There are 0 antibodies with antibody but not escape data
There are 3051 antibodies with both antibody and escape data
Writing mutation-level data to processed_data/escape_data_mutation.csv
Writing site-level data to processed_data/escape_data_mutation.csv

Counts for different eliciting viruses:


Unnamed: 0_level_0,n
eliciting_virus,Unnamed: 1_level_1
"(SARS-CoV-1 then SARS-CoV-2,)",449
"(SARS-CoV-2, pre-Omicron SARS-CoV-2)",534
"(SARS-CoV-2, pre-Omicron SARS-CoV-2 then Omicron BA.1)",504
"(SARS-CoV-2, pre-Omicron SARS-CoV-2 then Omicron BA.2)",435
"(SARS-CoV-2, pre-Omicron SARS-CoV-2 then Omicron BA.5)",189


Specify which sites to use:

In [4]:
sites = list(range(331, 531 + 1))
assert dms_data["site"].isin(sites).all()

## Make an "escape calculator" plot
Encode condition level data and then transform_lookup when making plot to make data set size smaller:

In [5]:
encoding = (
    dms_data
    [["eliciting_virus", "known_to_neutralize", "neg_log_IC50", "condition"]]
    .drop_duplicates()
    .reset_index(drop=True)
    .assign(encoding=lambda x: x.index)
)

dms_data_encoded = (
    dms_data
    .merge(encoding)
    [["encoding", "site", "escape"]]
)
assert len(dms_data_encoded) == len(dms_data_encoded.drop_duplicates())

encoding = encoding.drop(columns="condition")

display(encoding)
display(dms_data_encoded)

Unnamed: 0,eliciting_virus,known_to_neutralize,neg_log_IC50,encoding
0,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(7.3385, 7.3385)",0
1,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(8.6797, 8.6797)",1
2,"(SARS-CoV-1 then SARS-CoV-2,)","(any, Omicron BA.1, Wuhan-Hu-1)","(6.8401, 2.3238, 6.8401)",2
3,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(1.5141, 1.5141)",3
4,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(1.6451, 1.6451)",4
...,...,...,...,...
2106,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(1.5465, 1.5465)",2106
2107,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Omicron BA.1, Omicron BA.2, Omicron BA.2...","(2.7318, 1.6094, 2.6409, 2.6437, 2.7318, 2.1804)",2107
2108,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Omicron BA.2, Omicron BA.5, Wuhan-Hu-1)","(6.0533, 3.3932, 3.4082, 6.0533)",2108
2109,"(SARS-CoV-2, pre-Omicron SARS-CoV-2)","(any, Wuhan-Hu-1)","(5.4284, 5.4284)",2109


Unnamed: 0,encoding,site,escape
0,0,338,0.521154
1,0,359,0.140217
2,0,370,0.285110
3,0,394,0.112732
4,0,396,0.194379
...,...,...,...
43539,2110,465,0.769640
43540,2110,466,0.886558
43541,2110,468,0.313522
43542,2110,484,0.165163


Now make bar plot with antibody fraction bound:

In [6]:
known_to_neutralize_options = (
    sorted(encoding.explode("known_to_neutralize")["known_to_neutralize"].unique())
)
known_to_neutralize_selection = alt.selection_point(
    fields=['known_to_neutralize'],
    bind=alt.binding_select(
        options=known_to_neutralize_options,
        labels=known_to_neutralize_options,
        name="known to neutralize",
   ),
    value=[{'known_to_neutralize': 'Omicron BA.2'}]
)

mut_selection = alt.selection_point(name='mut',
                                    fields=['site'],
                                    value=[{'site': -1}],
                                    empty=True,
                                    toggle="true",
                                    )

mut_escape_strength_slider = alt.binding_range(min=1, max=10,
                                               name='mutation_escape_strength')
mut_escape_strength_selection = alt.selection_point(name='mut_escape_strength',
                                                    fields=['mutation_escape_strength'],
                                                    bind=mut_escape_strength_slider,
                                                    value=[{'mutation_escape_strength': 2}])

ic50_weight = alt.selection_point(
    name="weight_by_log_IC50",
    bind=alt.binding_radio(
        options=[1, 0],
        labels=["yes", "no"],
        name="weight by log IC50",
    ),
    fields=["choice"],
    value=[{"choice": 1}],
)

eliciting_viruses = sorted(encoding.explode("eliciting_virus")["eliciting_virus"].unique())
eliciting_virus_selection = alt.selection_point(
    fields=['eliciting_virus'],
    bind=alt.binding_select(
        options=eliciting_viruses,
        labels=eliciting_viruses,
        name="eliciting virus",
    ),
    value=[{'eliciting_virus': 'SARS-CoV-2'}]
)

plot_base = (
    alt.Chart(dms_data_encoded)
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(
            data=encoding,
            key="encoding",
            fields=["known_to_neutralize", "neg_log_IC50"],
        )
    )
    .transform_flatten(["known_to_neutralize", "neg_log_IC50"])
    .transform_filter(known_to_neutralize_selection)
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["encoding", "site", "neg_log_IC50"],
    )
    .transform_lookup(
        lookup="encoding",
        from_=alt.LookupData(data=encoding, key="encoding", fields=["eliciting_virus"])
    )
    .transform_flatten(["eliciting_virus"])
    .transform_filter(eliciting_virus_selection)
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["encoding", "site", "neg_log_IC50"],
    )
    .transform_joinaggregate(
        # get maximum escape across any site for this condition
        condition_escape_max='max(escape)',
        groupby=['encoding'],
    )
    .transform_calculate(
        # based on here: https://github.com/altair-viz/altair/issues/2366#issuecomment-812621436
        # based on here: https://stackoverflow.com/a/60894451/4191652
        site_binding_retained='(datum.condition_escape_max - '
                              ' if(indexof(mut.site, datum.site) >= 0, datum.escape, 0)) / '
                              'datum.condition_escape_max',
        encoding_weight="if(weight_by_log_IC50.choice == 1, datum.neg_log_IC50, 1)",
    )
)

frac_bound_bar = (
    plot_base
    .transform_aggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding', "encoding_weight"],
    )
    .transform_calculate(
        binding_retained_exp='datum.encoding_weight * pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength)'
    )
    .transform_aggregate(
        sum_binding_retained='sum(binding_retained_exp)',
        sum_encoding_weight="sum(encoding_weight)",
    )
    .transform_calculate(
        bound='datum.sum_binding_retained / datum.sum_encoding_weight',
        escaped='1 - datum.bound',
    )
    .transform_fold(
        ['bound', 'escaped'],
        ['binding state', 'fraction of antibodies']
    )
    .encode(x=alt.X('fraction of antibodies:Q',
                    axis=alt.Axis(grid=False),
                    ),
            y=alt.value(1),
            fill=alt.Color('binding state:N',
                            scale=alt.Scale(
                                domain=['bound', 'escaped'],
                                range=['lightgray', '#56B4E9'],
                                reverse=True,
                                ),
                            ),
            order=alt.Order('binding state:N'),
            tooltip=['binding state:N',
                     alt.Tooltip('fraction of antibodies:Q',
                                 format='.2g'),]
            )
    .mark_bar(stroke='black',
              size=20)
    .add_parameter(
        mut_selection,
        mut_escape_strength_selection,
        ic50_weight,
        known_to_neutralize_selection,
        eliciting_virus_selection,
    )
    .properties(width=300, height=10)
    )

frac_bound_bar

Now make the line plot:

In [7]:
escape_mut_base = (
    plot_base
    .encode(
        x=alt.X(
            'site:Q',
            axis=alt.Axis(grid=False),
            scale=alt.Scale(zero=False, nice=False),
        ),
        y=alt.Y(
            'mean_escape_value:Q',
            axis=alt.Axis(grid=False, title='escape (arbitrary units)', labels=False, ticks=False),
        ),
    )
    .transform_joinaggregate(
        binding_retained='product(site_binding_retained)',
        groupby=['encoding', "encoding_weight"],
    )
    .transform_calculate(
        escape_weighted="datum.encoding_weight * datum.escape",
        escape_after_mut='pow(datum.binding_retained, mut_escape_strength.mutation_escape_strength) * datum.escape_weighted'
    )
    # we don't actually have the correct denominator here, but it should
    # just affect relative scale of escape values
    .transform_joinaggregate(n_conditions="distinct(encoding)")
    .transform_aggregate(
        sum_mutated='sum(escape_after_mut)',
        sum_unmutated='sum(escape_weighted)',
        n_conditions="mean(n_conditions)",
        groupby=['site'],
    )
    .transform_calculate(
        mutated="datum.sum_mutated / datum.n_conditions",
        unmutated="datum.sum_unmutated / datum.n_conditions",
    )
    .transform_fold(['unmutated', 'mutated'],
                    ['escape_type', 'mean_escape_value'])
    .transform_impute(
        impute="mean_escape_value",
        key="site",
        value=0,
        groupby=["escape_type"],
        keyvals=sites,
    )
    .transform_calculate(
        color_val='if((indexof(mut.site, datum.site) >= 0) & (datum.escape_type == "mutated"), '
                  '"mutated site", datum.escape_type)'
        )
    .properties(width=800, height=225)
    )

mut_escape_color_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=['#999999', '#56B4E9', '#D55E00']
        )
mut_escape_point_size_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[30, 60, 100],
        )
mut_escape_opacity_scale = alt.Scale(
        domain=['unmutated', 'mutated', 'mutated site'],
        range=[0.5, 0.7, 1],
        )

escape_mut_lines = (
    escape_mut_base
    .encode(color=alt.Color('escape_type:N',
                            scale=mut_escape_color_scale,
                            ),
            opacity=alt.Opacity('escape_type:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            )
    .mark_line()
    )

escape_mut_points = (
    escape_mut_base
    .encode(color=alt.Color(
                    'color_val:N',
                    scale=mut_escape_color_scale,
                    legend=alt.Legend(
                            title=None,
                            labelExpr='if(datum.value == "unmutated", '
                                      '   "escape when no mutations", '
                                      '   if(datum.value == "mutated", '
                                      '      "escape with mutations", '
                                      '      "mutated site"))'
                            ),
                    ),
            opacity=alt.Opacity('color_val:N',
                                scale=mut_escape_opacity_scale,
                                legend=None,
                                ),
            size=alt.Size('color_val:N',
                          scale=mut_escape_point_size_scale,
                          ),
            tooltip=['site:O',
                     alt.Tooltip('mutated:Q',
                                 format='.2g'),
                     alt.Tooltip('unmutated:Q',
                                 format='.2g'),
                     ],
            )
    .mark_point(filled=True)
    .add_parameter(
        mut_selection,
        mut_escape_strength_selection,
        ic50_weight,
        known_to_neutralize_selection,
        eliciting_virus_selection,
    )
)

escape_chart = (
    ((escape_mut_lines + escape_mut_points) & frac_bound_bar)
    .configure_view(strokeOpacity=0)
    .configure_legend(orient='bottom',
                      labelFontSize=12,
                      title=None)
    .resolve_legend('independent')
    )

escape_calc_chartfile = 'docs/_includes/escape_calc_chart.html'
os.makedirs(os.path.dirname(escape_calc_chartfile), exist_ok=True)
print(f"Saving chart to {escape_calc_chartfile}")
escape_chart.save(escape_calc_chartfile)

escape_chart

Saving chart to docs/_includes/escape_calc_chart.html


Write the escape calculator data to a file:

In [8]:
escape_calc_data_file = 'processed_data/escape_calculator_data.csv'
os.makedirs(os.path.dirname(escape_calc_data_file), exist_ok=True)

print(f"Writing escape calculator data to {escape_calc_data_file}")

(
    dms_data
    .assign(
        known_to_neutralize=lambda x: x["known_to_neutralize"].map(lambda t: ";".join(t)),
        neg_log_IC50=lambda x: x["neg_log_IC50"].map(lambda t: ";".join(f"{i:.5g}" for i in t)),
        eliciting_virus=lambda x: x["eliciting_virus"].map(lambda t: ";".join(t)),
    )
    .to_csv(escape_calc_data_file, index=False)
)

Writing escape calculator data to processed_data/escape_calculator_data.csv
