# Compare mutation effects on ACE2 binding vs sera escape at key sites
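This notebook compares how mutations affect ACE2 binding versus sera escape at key sites. It computes the per-site correlation between the two measurements, then plots the sites with a strongly negative correlation and a manually chosen set of sites of strong escape.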

In [1]:
# this cell is tagged as parameters for `papermill` parameterization
# Both names are placeholders set to `None` here; the "# Parameters" cell that
# follows assigns the actual values.
dms_csv = None  # path of the CSV that is passed to `pd.read_csv` when loading the DMS data
logoplot_subdir = None  # presumably a directory for logo plots; not used in the cells visible here -- TODO confirm

In [2]:
# Parameters
# These assignments replace the `None` placeholders from the parameters-tagged cell
# (presumably injected by `papermill` at execution time -- TODO confirm).
# Both paths are relative, so they resolve against the working directory set by
# the `os.chdir` call below.
dms_csv = "results/summaries/summary.csv"
logoplot_subdir = "results/binding_vs_escape/logoplots"


import os

# Move the working directory up one level (presumably from the notebook's folder
# to the repository root -- TODO confirm) so the relative paths above resolve.
# The starting directory is recorded the first time this cell runs and the target
# is always computed from it, so re-running the cell is idempotent: it no longer
# climbs one additional directory level on every execution.
if "notebook_start_dir" not in globals():
    notebook_start_dir = os.getcwd()
os.chdir(os.path.join(notebook_start_dir, os.pardir))

In [17]:
import os

import altair as alt

import dmslogo

import matplotlib
import matplotlib.pyplot as plt

import numpy

import pandas as pd

## Read input data

In [39]:
# thresholds used to filter the DMS measurements
min_cell_entry = -1.5  # mutations with cell entry below this value are discarded
min_mutations_at_site = 7  # sites with fewer retained mutations than this are discarded

dms_df = (
    pd.read_csv(dms_csv)
    # shorten the two verbose column names used in the input CSV
    .rename(
        columns={
            "human sera escape": "sera escape",
            "spike mediated entry": "cell entry",
        }
    )
    # a mutation is only usable if all three phenotypes were measured
    .dropna(subset=["sera escape", "cell entry", "ACE2 binding"])
    # keep mutations that enter cells well enough and are neither stop ("*") nor gap ("-")
    .loc[
        lambda df: (df["cell entry"] >= min_cell_entry)
        & ~df["mutant"].isin(["*", "-"])
    ]
    .assign(
        # label such as "V3A": wildtype residue, site, mutant residue
        mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
        # number of mutations that survived the filters above at each site
        n_mutations_at_site=lambda df: df.groupby("site")["mutant"].transform("count"),
    )
    .loc[lambda df: df["n_mutations_at_site"] >= min_mutations_at_site]
    .reset_index(drop=True)
)

dms_df

Unnamed: 0,site,wildtype,mutant,sera escape,cell entry,ACE2 binding,sequential_site,region,mutation,n_mutations_at_site
0,3,V,A,0.024020,-0.041540,-0.049770,3,other,V3A,7
1,3,V,F,0.112700,-0.105700,0.113900,3,other,V3F,7
2,3,V,G,-0.122400,-0.002075,-0.189000,3,other,V3G,7
3,3,V,I,-0.113900,-0.120400,-0.138400,3,other,V3I,7
4,3,V,L,-0.025990,0.033060,-0.108000,3,other,V3L,7
...,...,...,...,...,...,...,...,...,...,...
3255,1252,S,F,-0.006158,0.100700,-0.012280,1248,other,S1252F,7
3256,1252,S,P,-0.004025,0.077520,0.053890,1248,other,S1252P,7
3257,1252,S,T,0.001764,0.003655,-0.114200,1248,other,S1252T,7
3258,1252,S,Y,-0.006188,0.058880,0.000546,1248,other,S1252Y,7


## Calculate correlation between ACE2 binding and escape for each site

In [50]:
# Compute, for each site, the Pearson correlation between sera escape and ACE2
# binding across the mutations retained at that site.
correlation_df = (
    dms_df
    .groupby("site")
    [["sera escape", "ACE2 binding"]]
    .corr()  # per-site 2x2 correlation matrix, indexed by (site, variable)
    ["ACE2 binding"]
    # Select the "sera escape" row of each matrix by index level rather than via
    # the auto-generated "level_1" column name that `reset_index` would produce.
    .xs("sera escape", level=-1)
    .rename("correlation")
    .dropna()  # correlation is undefined (NaN) when either variable is constant at a site
    .reset_index()
)

# Add correlations to the DMS data frame.
# - Any "correlation" column left over from a previous run of this cell is dropped
#   first, so re-running the cell gives the same result instead of merging on the
#   floating-point correlation values as well.
# - The join key is explicit, and the inner join intentionally discards sites whose
#   correlation is undefined.
dms_df = (
    dms_df
    .drop(columns="correlation", errors="ignore")
    .merge(correlation_df, on="site", how="inner", validate="many_to_one")
)

dms_df

Unnamed: 0,site,wildtype,mutant,sera escape,cell entry,ACE2 binding,sequential_site,region,mutation,n_mutations_at_site,correlation
0,3,V,A,0.024020,-0.041540,-0.049770,3,other,V3A,7,0.884780
1,3,V,F,0.112700,-0.105700,0.113900,3,other,V3F,7,0.884780
2,3,V,G,-0.122400,-0.002075,-0.189000,3,other,V3G,7,0.884780
3,3,V,I,-0.113900,-0.120400,-0.138400,3,other,V3I,7,0.884780
4,3,V,L,-0.025990,0.033060,-0.108000,3,other,V3L,7,0.884780
...,...,...,...,...,...,...,...,...,...,...,...
3255,1252,S,F,-0.006158,0.100700,-0.012280,1248,other,S1252F,7,-0.288667
3256,1252,S,P,-0.004025,0.077520,0.053890,1248,other,S1252P,7,-0.288667
3257,1252,S,T,0.001764,0.003655,-0.114200,1248,other,S1252T,7,-0.288667
3258,1252,S,Y,-0.006188,0.058880,0.000546,1248,other,S1252Y,7,-0.288667


## Plot sites with high inverse correlation between ACE2 binding and escape
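Plot sites with a strongly negative (inverse) correlation between ACE2 binding and sera escape. Note that the sliders at the bottom of the chart control the minimum cell entry of the mutations that are plotted and the correlation threshold that determines which sites are shown: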

In [62]:
# first make base chart

facet_size = 100  # used as both the width and the height of every per-site facet

# slider setting the minimum cell entry a mutation needs in order to be plotted
cell_entry_slider = alt.param(
    value=min_cell_entry,
    bind=alt.binding_range(
        name="minimum cell entry",
        min=dms_df["cell entry"].min(),
        max=0,
    ),
)

# data source shared by both layers: mutations passing the cell entry slider
binding_escape_corr_base = (
    alt.Chart(dms_df)
    .add_params(cell_entry_slider)
    .transform_filter(alt.datum["cell entry"] >= cell_entry_slider)
)

# layer 1: one semi-transparent point per mutation
mutation_points = (
    binding_escape_corr_base
    .mark_circle(color="black", opacity=0.3, size=60)
    .encode(
        x=alt.X("ACE2 binding", scale=alt.Scale(nice=False, padding=10)),
        y=alt.Y("sera escape", scale=alt.Scale(nice=False, padding=10)),
        tooltip=[
            "site",
            "mutation",
            alt.Tooltip("ACE2 binding", format=".2f"),
            alt.Tooltip("sera escape", format=".2f"),
            alt.Tooltip("cell entry", format=".2f"),
        ],
    )
)

# layer 2: text label with the signed correlation coefficient; the regression
# reports only r-squared, so the sign is taken from the fitted slope (`coef[1]`)
r_magnitude = alt.expr.sqrt(alt.datum["rSquared"])
signed_r = alt.expr.if_(alt.datum["coef"][1] > 0, r_magnitude, -r_magnitude)

correlation_label = (
    binding_escape_corr_base
    .transform_regression("ACE2 binding", "sera escape", params=True)
    .transform_calculate(r=signed_r)
    # separate step because the label text reads the `r` field computed just above
    .transform_calculate(r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"))
    .mark_text(size=12, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(3),
        y=alt.value(facet_size - 6),
    )
)

# one facet per site, with an italic "site <number>" header
site_facet = alt.Facet(
    "site",
    title=None,
    header=alt.Header(
        labelFontSize=14,
        labelFontStyle="italic",
        labelPadding=0,
        labelExpr="'site ' + datum.label",
    ),
)

binding_escape_corr_chart = (
    alt.layer(mutation_points, correlation_label)
    .properties(width=facet_size, height=facet_size)
    .facet(facet=site_facet, spacing=8, columns=8)
    .configure_axis(grid=False)
)

# now make chart filtered for strongly negative correlations

# slider for the largest per-site correlation a site may have and still be shown
max_corr_slider = alt.param(
    value=-0.82,
    bind=alt.binding_range(
        name="only show sites with correlation r less than this",
        min=-1,
        max=1,
        step=0.01,
    ),
)

neg_corr_title = alt.TitleParams(
    "Correlation of ACE2 binding and escape filtered by extent of negative correlation",
    anchor="middle",
    fontSize=16,
    dy=-5,
)

# the filter uses the `correlation` column of the data frame, not the `r` value
# that the chart computes for its text labels
binding_escape_neg_corr_chart = (
    binding_escape_corr_chart
    .properties(title=neg_corr_title, autosize=alt.AutoSizeParams(resize=True))
    .add_params(max_corr_slider)
    .transform_filter(alt.datum["correlation"] <= max_corr_slider)
)

binding_escape_neg_corr_chart

## Plot the same correlation for sites of strong escape
We manually specify some sites of strong escape:

In [63]:
# manually chosen sites of strong escape
escape_sites = [357, 420, 440, 444, 452, 456, 473]

high_escape_title = alt.TitleParams(
    "Correlation of ACE2 binding and escape for sites of strong escape",
    anchor="middle",
    fontSize=16,
    dy=-5,
)

# same faceted chart as before, restricted to the sites listed above
binding_escape_high_escape_corr_chart = (
    binding_escape_corr_chart
    .properties(title=high_escape_title, autosize=alt.AutoSizeParams(resize=True))
    .transform_filter(alt.FieldOneOfPredicate(field="site", oneOf=escape_sites))
)

binding_escape_high_escape_corr_chart