# Compare DMS to natural sequence evolution

In [1]:
# this cell is tagged parameters for papermill parameterization
# Every name is declared here as a `None` placeholder; the concrete values are
# assigned in the "Parameters" cell that follows.
dms_summary_csv = None  # input: CSV of DMS measurements (read with `pd.read_csv`)
growth_rates_csv = None  # input: CSV of clade growth rates (read with `pd.read_csv`)
pango_consensus_seqs_json = None  # input: JSON of Pango clade definitions (read with `json.load`)
starting_clade = None  # clade from which the clade tree is traversed
dms_clade = None  # clade whose mutations are the baseline for `muts_from_dms_clade`
n_random = None  # number of randomizations of the DMS data
exclude_clades = None  # clades dropped from `pango_df`
pango_dms_phenotypes_csv = None  # output: CSV of per-clade DMS phenotypes
pango_randomized_dms_phenotypes_csv = None  # output: CSV of per-clade phenotypes from randomized DMS data
pango_by_date_html = None  # output: HTML chart of phenotypes versus designation date
pango_affinity_vs_escape_html = None  # output: HTML chart of ACE2 affinity versus sera escape

In [2]:
# Parameters
# Concrete values for the placeholders declared in the previous cell.
starting_clade = "XBB"
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "data/2023-09-18_Murrell_growth_estimates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_randomized_dms_phenotypes_csv = (
    "results/compare_natural/pango_randomized_dms_phenotypes.csv"
)
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
n_random = 10
exclude_clades = []

import os
# The paths above are relative, so they are resolved against the directory set here.
# NOTE(review): this relative chdir is not idempotent -- re-running the cell
# without restarting the kernel moves up one more directory each time.
os.chdir("../")

In [3]:
import collections
import json
import re

import altair as alt

import numpy

import pandas as pd

import statsmodels.api

# Turn off Altair's limit on the number of rows in a chart's data; the return
# value is bound to `_` so that the cell displays nothing.
_ = alt.data_transformers.disable_max_rows()

## Read Pango clades and mutations

In [4]:
# Parse the Pango clade definitions; the result is indexed by clade name below.
with open(pango_consensus_seqs_json) as pango_json_handle:
    pango_clades = json.load(pango_json_handle)

def n_child_clades(c):
    """Count all descendants (children, grandchildren, ...) of Pango clade `c`.

    Descendants are found by following the ``"children"`` lists in the global
    ``pango_clades`` dict; a clade reachable along several paths is counted
    once per path.
    """
    n_descendants = 0
    # explicit stack of clades that still have to be counted and expanded
    to_visit = list(pango_clades[c]["children"])
    while to_visit:
        descendant = to_visit.pop()
        n_descendants += 1
        to_visit.extend(pango_clades[descendant]["children"])
    return n_descendants

def build_records(c, recs):
    """Append the record for clade `c` and, recursively, all its descendants.

    `recs` maps column name to a list and is extended in place with one entry
    per visited clade: its name, number of descendants, designation date, and
    spike mutations relative to the reference (the text after ``"S:"``).
    """
    clade_info = pango_clades[c]

    recs["clade"].append(c)
    recs["n_child_clades"].append(n_child_clades(c))
    recs["date"].append(clade_info["designationDate"])

    # keep only spike ("S:") substitutions and deletions, dropping the gene prefix
    spike_muts = []
    for field in ("aaSubstitutions", "aaDeletions"):
        for mut in clade_info[field]:
            if mut.startswith("S:"):
                spike_muts.append(mut.split(":")[1])
    recs["muts_from_ref"].append(spike_muts)

    # depth-first: a clade's record always precedes those of its descendants
    for c_child in clade_info["children"]:
        build_records(c_child, recs)
        
# collect one record per clade, starting from `starting_clade`
records = collections.defaultdict(list)
build_records(starting_clade, records)

# one row per clade, minus any explicitly excluded clades
pango_df = pd.DataFrame(records).query("clade not in @exclude_clades")

# mutations (relative to the reference) of the two clades used as baselines
_muts_from_ref_by_clade = pango_df.set_index("clade")["muts_from_ref"]
starting_clade_mutations_from_ref = _muts_from_ref_by_clade.at[starting_clade]
dms_clade_mutations_from_ref = _muts_from_ref_by_clade.at[dms_clade]

def mutations_from(muts, from_muts):
    """Get mutations from another sequence.

    Both arguments list mutations relative to the same reference, each written
    as ``<parent><site><mutant>`` with ``-`` allowed for a deletion (for
    example ``"P486S"`` or ``"Y144-"``).

    Parameters
    ----------
    muts : iterable of str
        Mutations of the sequence of interest.
    from_muts : iterable of str
        Mutations of the sequence that differences are expressed relative to.

    Returns
    -------
    list of str
        Mutations that convert the `from_muts` sequence into the `muts`
        sequence, sorted by site.

    Raises
    ------
    AssertionError
        If a differing mutation is not of the form described above, or two
        differing mutations at one site come from the same input.
    ValueError
        If more than two differing mutations share a site.
    """
    muts = set(muts)
    from_muts = set(from_muts)
    new_muts = muts.symmetric_difference(from_muts)
    # raw string: "\-" and "\d" are invalid escapes in an ordinary string literal
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)

    # group the differing mutations by site
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1:-1])].append(m)

    new_muts_list = []
    for site in sorted(new_muts_d):
        ms = new_muts_d[site]
        if len(ms) == 1:
            m = ms[0]
            if m in muts:
                # only the sequence of interest is mutated here: keep as is
                new_muts_list.append(m)
            else:
                # only the other sequence is mutated here: this is a reversion
                assert m in from_muts
                new_muts_list.append(m[-1] + m[1:-1] + m[0])
        elif len(ms) == 2:
            # both are mutated, differently: go from one mutant to the other
            m, from_m = ms
            if m not in muts:
                from_m, m = m, from_m
            assert m in muts and from_m in from_muts
            new_muts_list.append(from_m[-1] + m[1:])
        else:
            raise ValueError(
                f"more than two differing mutations at site {site}: {sorted(ms)}"
            )
    return new_muts_list

# Express each clade's mutations relative to the starting clade (as a "; "
# joined string) and relative to the DMS clade (as a list), and parse the dates.
pango_df = (
    pango_df
    .assign(
        muts_from_start_clade=lambda df: df["muts_from_ref"].map(
            lambda muts: "; ".join(
                mutations_from(muts, starting_clade_mutations_from_ref)
            )
        ),
        muts_from_dms_clade=lambda df: df["muts_from_ref"].map(
            lambda muts: mutations_from(muts, dms_clade_mutations_from_ref)
        ),
        date=lambda df: pd.to_datetime(df["date"]),
    )
    .drop(columns=["muts_from_ref"])
    .sort_values(by="date")
    .reset_index(drop=True)
)

pango_df

Unnamed: 0,clade,n_child_clades,date,muts_from_start_clade,muts_from_dms_clade
0,XBB,623,2022-09-17,,"[V252G, P486S]"
1,XBB.1,541,2022-10-03,G252V,[P486S]
2,XBB.1.1,0,2022-10-15,G252V,[P486S]
3,XBB.2,64,2022-10-15,D253G,"[V252G, D253G, P486S]"
4,XBB.3,5,2022-10-15,,"[V252G, P486S]"
...,...,...,...,...,...
619,GK.2.3,0,2023-09-17,G252V; K356T; L455F; F456L; S486P; V511I,"[K356T, L455F, F456L, V511I]"
620,JG.3,0,2023-09-17,Q52H; G252V; L455F; F456L; S486P; S704L,"[Q52H, L455F, F456L, S704L]"
621,GK.4,0,2023-09-17,G252V; L455F; F456L; A475V; S486P,"[L455F, F456L, A475V]"
622,XBB.1.5.106,0,2023-09-17,G252V; S486P; A623V,[A623V]


## Assign DMS phenotypes to Pango clades

First define a function that assigns DMS phenotypes to mutations:

In [5]:
# load the DMS measurements, shortening two of the phenotype column names
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# DMS phenotypes analyzed in the rest of the notebook
phenotypes = ["sera escape", "ACE2 affinity", "cell entry"]
assert all(p in dms_summary.columns for p in phenotypes)

# per-site lookups built from the DMS data: wildtype residue and spike region
_dms_by_site = dms_summary.set_index("site")
dms_wt = _dms_by_site["wildtype"].to_dict()
site_to_region = _dms_by_site["region"].to_dict()

def mut_dms(m, dms_data):
    """Get DMS phenotypes for a mutation.

    Parameters
    ----------
    m : str or null
        Mutation written as ``<parent><site><mutant>`` (e.g. ``"P486S"``), or a
        null value.
    dms_data : dict
        Maps ``(site, wildtype, mutant)`` to a dict of phenotype values keyed
        by the entries of the global ``phenotypes`` list, in that order. The
        dict and its entries are left unmodified.

    Returns
    -------
    dict
        A new dict with one value per phenotype followed by ``"is_RBD"``. The
        phenotype values are ``pd.NA`` when `m` is null, its site is not in the
        DMS, or a required measurement is missing; ``"is_RBD"`` is ``pd.NA`` in
        the first two of those cases.
    """
    null_d = {k: pd.NA for k in phenotypes}
    site = None if pd.isnull(m) else int(m[1:-1])
    if site is None or site not in dms_wt:
        d = null_d
        d["is_RBD"] = pd.NA
    else:
        parent = m[0]
        mut = m[-1]
        wt = dms_wt[site]
        try:
            if parent == wt:
                # measured directly; copy so the caller's entry is not modified
                # when "is_RBD" is added below
                d = dict(dms_data[(site, parent, mut)])
            elif mut == wt:
                # reversion to the DMS wildtype: negate the forward effect
                d = {k: -v for (k, v) in dms_data[(site, mut, parent)].items()}
            else:
                # neither residue is the DMS wildtype: difference of the two
                # effects measured relative to the wildtype
                parent_d = dms_data[(site, wt, parent)]
                mut_d = dms_data[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
        except KeyError:
            # a required measurement is missing from the DMS data
            d = null_d
        d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d

Now assign phenotypes to Pango clades.
We do this using both the actual DMS data and DMS data randomized among the measured mutations:

In [6]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    Parameters
    ----------
    dms_data_dict : dict
        Maps ``(site, wildtype, mutant)`` to a dict of phenotype values, in the
        form expected by ``mut_dms``.

    Returns
    -------
    pandas.DataFrame
        One row per clade and DMS phenotype. ``phenotype`` is the sum of the
        effects of the clade's mutations relative to the DMS clade, and
        ``phenotype_RBD_only`` / ``phenotype_nonRBD_only`` split that sum by
        region. Mutations without a measured effect contribute zero and are
        listed in the ``..._missing_data`` column.
    """

    def join_non_null(ms):
        """Join the non-null entries of `ms` with "; "."""
        return "; ".join([m for m in ms if not pd.isnull(m)])

    pango_dms_df = (
        pango_df
        # put one mutation in each row
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        .melt(
            id_vars=["clade", "date", "n_child_clades", "muts_from_start_clade", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        .assign(
            muts_from_dms_clade=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(join_non_null)
            ),
            # must be computed before the missing effects are filled with zero
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"].transform(join_non_null)
            ),
            # Cast explicitly instead of relying on `fillna` to downcast the
            # object columns: without a real bool dtype, `~` below is a bitwise
            # inversion of Python bools (-2 / -1), not a logical negation.
            mutation_effect=lambda x: x["mutation_effect"].fillna(0).astype(float),
            is_RBD=lambda x: x["is_RBD"].fillna(False).astype(bool),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_start_clade",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_start_clade": f"muts_from_{starting_clade}",
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )

    # every clade is retained, and the two regions add up to the full spike
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )

    return pango_dms_df

# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)
print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
pango_dms_dfs_rand = []
for irandom in range(1, n_random + 1):
    # Randomize the non-null DMS data for each phenotype by permuting the
    # measured values among the mutations that have a measurement.
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        has_value = dms_summary_rand[phenotype].notnull()
        # `.to_numpy()` discards the index of the shuffled values so that they
        # are assigned by position. Shuffling rows that still carry their
        # labels only reorders the frame and leaves each mutation with its own
        # measurement.
        dms_summary_rand.loc[has_value, phenotype] = (
            dms_summary_rand.loc[has_value, phenotype]
            .sample(frac=1, random_state=irandom)
            .to_numpy()
        )
    assert dms_summary_rand.shape == dms_summary.shape
    # Build the phenotypes once per randomization, after all phenotypes have
    # been shuffled, so that each value of `randomize` labels exactly one set.
    dms_data_dict_rand = (
        dms_summary_rand
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )
    pango_dms_dfs_rand.append(get_pango_dms_df(dms_data_dict_rand).assign(randomize=irandom))
# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)
print(f"Saving randomized Pango DMS phenotypes to {pango_randomized_dms_phenotypes_csv}")
pango_dms_df_rand.to_csv(pango_randomized_dms_phenotypes_csv, float_format="%.4f", index=False)

Saving Pango DMS phenotypes to results/compare_natural/pango_dms_phenotypes.csv
Saving randomized Pango DMS phenotypes to results/compare_natural/pango_randomized_dms_phenotypes.csv


## Plot phenotypes of Pango clades
Plot phenotypes of Pango clades versus their designation dates:

In [7]:
# maps each phenotype column to the spike region label shown in the chart
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}

# long format: one row per clade, DMS phenotype and spike region
pango_chart_df = pango_dms_df.melt(
    id_vars=[col for col in pango_dms_df.columns if col not in region_cols],
    value_vars=list(region_cols),
    var_name="spike_region",
    value_name="phenotype value",
)
pango_chart_df = pango_chart_df.assign(
    spike_region=lambda df: df["spike_region"].map(region_cols),
    # number of non-empty ";"-separated entries in the mutation string
    n_mutations=lambda df: df[f"muts_from_{starting_clade}"].map(
        lambda mut_str: sum(1 for m in mut_str.split(";") if m)
    ),
).rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})

# sanity check: very many mutations suggests a problem in the input JSON
many_mutations_df = pango_chart_df.query("n_mutations > 12")
if len(many_mutations_df) > 0:
    raise ValueError(
        "check high number of mutations to ensure not bug in JSON like this one:\n"
        + "https://github.com/corneliusroemer/pango-sequences/issues/6\n\n"
        + str(many_mutations_df)
    )

# columns cannot have "." in them for Altair
col_renames = {
    col: col.replace(".", "_") for col in pango_chart_df.columns if "." in col
}
col_renames_rev = {renamed: col for (col, renamed) in col_renames.items()}
pango_chart_df = (
    pango_chart_df
    .drop(columns="n_mutations")
    .rename(columns=col_renames)
)

# hovering over a point selects its clade; nothing is selected initially
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# encodings shared by every panel; the selected clade is drawn larger, opaque
# and with a visible outline
base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        # show every column, labelled with its original (un-renamed) name
        tooltip=[
            alt.Tooltip(col, title=col_renames_rev.get(col, col))
            for col in pango_chart_df.columns
        ],
        color=alt.Color("DMS_phenotype", legend=None),
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# one row of panels per phenotype, one column of panels per spike region
phenotype_pango_charts = []
for phenotype in phenotypes:
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])

    # only the bottom row gets a date axis and its title
    if last_row:
        date_title = "designation date of clade"
        date_axis = alt.Axis(
            titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90
        )
    else:
        date_title = None
        date_axis = None

    # only the top row gets the spike-region column headers
    if first_row:
        region_header = alt.Header(
            labelFontSize=12, labelFontStyle="bold", labelPadding=4
        )
    else:
        region_header = None

    phenotype_chart = (
        base_pango_chart
        .transform_filter(alt.datum["DMS_phenotype"] == phenotype)
        .encode(
            x=alt.X(
                "date",
                title=date_title,
                axis=date_axis,
                scale=alt.Scale(nice=False, padding=3),
            ),
            y=alt.Y(
                "phenotype value",
                title=phenotype,
                axis=alt.Axis(titleFontSize=12),
                scale=alt.Scale(nice=False, padding=3),
            ),
            column=alt.Column(
                "spike_region",
                sort=list(region_cols),
                title=None,
                header=region_header,
                spacing=4,
            ),
        )
    )
    phenotype_pango_charts.append(phenotype_chart)

# Stack the per-phenotype rows vertically, attach the shared hover selection,
# and add an overall title.
pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(        
        title=alt.TitleParams(
            f"DMS predicted phenotypes of all Pango clades descended from {starting_clade}",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
    )
)

# write the chart to the HTML file given by `pango_by_date_html`
print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

# last expression of the cell, so the chart is also rendered in the notebook
pango_chart

Saving chart to results/compare_natural/pango_dms_phenotypes_by_date.html


## Pango clade affinity versus escape scatter plot

In [8]:
# columns that hold phenotype values or labels rather than clade metadata
_phenotype_cols = {"DMS_phenotype", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"}

# wide format: one row per clade, one column per DMS phenotype (full spike)
pango_scatter_df = pango_dms_df.pivot_table(
    index=[col for col in pango_dms_df.columns if col not in _phenotype_cols],
    values="phenotype",
    columns="DMS_phenotype",
).reset_index()
pango_scatter_df = (
    pango_scatter_df
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
    .rename(columns=col_renames)
)

# not the last expression of the cell, so this does not display anything
pango_scatter_df

# one point per clade; hovering highlights a clade via the shared selection
pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            scale=alt.Scale(nice=False, padding=5),
            axis=alt.Axis(titleFontSize=12),
        ),
        y=alt.Y(
            "sera escape",
            scale=alt.Scale(nice=False, padding=5),
            axis=alt.Axis(titleFontSize=12),
        ),
        # show every column, labelled with its original (un-renamed) name
        tooltip=[
            alt.Tooltip(col, title=col_renames_rev.get(col, col))
            for col in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                f"for Pango clades descended from {starting_clade}"
            ],
            fontSize=14,
            anchor="middle",
            dy=-5,
        ),
    )
    .properties(width=300, height=300)
)

# write the chart to HTML, then render it as the cell's output
print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

pango_scatter_chart

Saving chart to results/compare_natural/pango_affinity_vs_escape.html


## Correlate with clade growth

In [9]:
# growth-rate estimates, with the clade column named as in the DMS data frame
growth_rates = pd.read_csv(growth_rates_csv).rename(columns={"pango": "clade"})

# every clade with a growth rate must be a known Pango clade
invalid_clades = set(growth_rates["clade"]) - set(pango_clades)
if invalid_clades:
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# keep only clades that have both DMS phenotypes and a growth rate
pango_dms_growth_df = pango_dms_df.merge(growth_rates, on="clade", validate="many_to_one")

n_growth_clades = growth_rates["clade"].nunique()
n_dms_clades = pango_dms_df["clade"].nunique()
n_shared_clades = pango_dms_growth_df["clade"].nunique()
print(
    f"{n_growth_clades} clades have growth rates estimates.\n"
    f"{n_dms_clades} clades have DMS estimates.\n"
    f"{n_shared_clades} clades have growth and DMS estimates"
)

# correlation of the growth rate `R` with each phenotype column, per DMS phenotype
print("Simple correlations:")
corr_cols = ["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]
display(
    pango_dms_growth_df
    .groupby("DMS_phenotype")
    [corr_cols]
    .corr()
    [["R"]]
)

966 clades have growth rates estimates.
624 clades have DMS estimates.
276 clades have growth and DMS estimates
Simple correlations:


Unnamed: 0_level_0,Unnamed: 1_level_0,R
DMS_phenotype,Unnamed: 1_level_1,Unnamed: 2_level_1
ACE2 affinity,R,1.0
ACE2 affinity,phenotype,0.030709
ACE2 affinity,phenotype_RBD_only,-0.021114
ACE2 affinity,phenotype_nonRBD_only,0.078582
cell entry,R,1.0
cell entry,phenotype,0.460764
cell entry,phenotype_RBD_only,0.524726
cell entry,phenotype_nonRBD_only,0.22195
sera escape,R,1.0
sera escape,phenotype,0.59465


Now perform OLS:

In [13]:
# Column names that depend on the notebook parameters. They are derived from
# `starting_clade` / `dms_clade` rather than hard-coded, so this cell keeps
# working when the notebook is run with other clades.
muts_from_start_col = f"muts_from_{starting_clade}"
muts_from_dms_col = f"muts_from_{dms_clade}"
muts_missing_data_col = f"muts_from_{dms_clade}_missing_data"

# pivot DMS data to get phenotypes
ols_vars = (
    pango_dms_growth_df
    .rename(
        columns={
            "phenotype": "full spike",
            "phenotype_RBD_only": "RBD",
            "phenotype_nonRBD_only": "non RBD",
        }
    )
    .assign(
        # group muts missing data from all phenotypes: per clade, the union of
        # the missing mutations in order of first appearance
        **{
            muts_missing_data_col: lambda x: (
                x.groupby("clade")
                [muts_missing_data_col]
                .transform(
                    lambda s: "; ".join(dict.fromkeys([m for ms in s.str.split("; ") for m in ms if m]))
                )
            ),
        }
    )
    .pivot_table(
        index=[
            "clade",
            "R",
            "date",
            muts_from_start_col,
            muts_from_dms_col,
            muts_missing_data_col,
            "seq_volume",
        ],
        columns="DMS_phenotype",
        values=["full spike", "RBD", "non RBD"],
    )
    .reset_index()
)
# flatten column names: ("RBD", "sera escape") becomes "RBD sera escape"
ols_vars.columns = [" ".join(c).strip() for c in ols_vars.columns.values]

# Regress the growth rate `R` on the DMS phenotypes, first for the full spike
# and then with RBD and non-RBD contributions as separate predictors.
# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars in [
    ("full spike", [f"full spike {c}" for c in phenotypes]),
    (
        "RBD and non-RBD separately",
        [f"{d} {c}" for d in ["RBD", "non RBD"] for c in phenotypes],
    ),
]:
    print(f"Fitting for {name}:\n")
    ols_model = statsmodels.api.OLS(
        endog=ols_vars[["R"]],
        # add an intercept column to the predictors
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
    )
    res_ols = ols_model.fit()
    display(res_ols.summary())

Fitting for full spike:



0,1,2,3
Dep. Variable:,R,R-squared:,0.403
Model:,OLS,Adj. R-squared:,0.396
Method:,Least Squares,F-statistic:,61.18
Date:,"Tue, 19 Sep 2023",Prob (F-statistic):,2.9299999999999997e-30
Time:,11:19:11,Log-Likelihood:,-1023.7
No. Observations:,276,AIC:,2055.0
Df Residuals:,272,BIC:,2070.0
Df Model:,3,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,61.9687,0.714,86.774,0.000,60.563,63.375
full spike sera escape,44.8749,4.922,9.117,0.000,35.184,54.565
full spike ACE2 affinity,-1.8259,2.112,-0.865,0.388,-5.983,2.331
full spike cell entry,23.2526,4.931,4.716,0.000,13.546,32.959

0,1,2,3
Omnibus:,7.725,Durbin-Watson:,0.798
Prob(Omnibus):,0.021,Jarque-Bera (JB):,7.974
Skew:,0.324,Prob(JB):,0.0186
Kurtosis:,3.523,Cond. No.,10.2


Fitting for RBD and non-RBD separately:



0,1,2,3
Dep. Variable:,R,R-squared:,0.613
Model:,OLS,Adj. R-squared:,0.604
Method:,Least Squares,F-statistic:,70.93
Date:,"Tue, 19 Sep 2023",Prob (F-statistic):,1.37e-52
Time:,11:19:11,Log-Likelihood:,-963.97
No. Observations:,276,AIC:,1942.0
Df Residuals:,269,BIC:,1967.0
Df Model:,6,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,60.5230,0.600,100.859,0.000,59.342,61.704
RBD sera escape,91.4703,6.562,13.939,0.000,78.551,104.390
RBD ACE2 affinity,5.4262,2.387,2.273,0.024,0.726,10.126
RBD cell entry,2.1937,9.015,0.243,0.808,-15.555,19.942
non RBD sera escape,-3.9303,5.895,-0.667,0.506,-15.536,7.675
non RBD ACE2 affinity,-9.9285,2.945,-3.371,0.001,-15.727,-4.130
non RBD cell entry,7.5757,5.343,1.418,0.157,-2.944,18.095

0,1,2,3
Omnibus:,22.341,Durbin-Watson:,0.996
Prob(Omnibus):,0.0,Jarque-Bera (JB):,26.663
Skew:,0.64,Prob(JB):,1.62e-06
Kurtosis:,3.826,Cond. No.,22.0
