# Compare DMS to natural sequence evolution
The basic approach is to find all pairs of parent-descendant Pango clades that differ by only a single spike amino-acid substitution, and then compare the differences in growth rates to changes in spike phenotypes measured by DMS.

Comparing clade pairs is better than comparing all clades by their phenotypes because it avoids phylogenetic correlations: it uses only the new mutation that appeared in each parent / descendant clade pair, rather than all mutations in each clade (the latter approach is confounded by phylogeny because clades share mutations by ancestry).

In [None]:
# this cell is tagged parameters for papermill parameterization
# All values are placeholders that papermill is expected to override; every
# name the notebook reads as a parameter is declared here so the full set of
# required parameters is visible in one place.

# path to CSV of DMS measurements (read with `pd.read_csv`)
dms_summary_csv = None
# path to CSV of per-clade growth rates (read with `pd.read_csv`)
growth_rates_csv = None
# path to JSON of Pango clade information (read with `json.load`)
pango_consensus_seqs_json = None
# path to which the merged clade-pair growth / DMS table is written as CSV
growth_dms_csv = None
# iterable of Pango clades from which the clade tree is traversed
starting_clades = None
# clades dropped from the analysis (used in a `clade not in ...` query)
exclude_clades = None
# minimum number of sequences required in both clades of a pair
min_sequences = None
# dict mapping a spike mutation to the initial value of its "exclude" toggle
muts_to_toggle = None

In [None]:
import collections
import functools
import itertools
import json
import math
import operator

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

# Disable Altair's maximum-rows check on chart data; the return value is bound
# to `_` so the notebook cell does not echo it as output.
_ = alt.data_transformers.disable_max_rows()

## Read Pango clades and identify all pairs separated by a single spike mutation
First, read all Pango clades and get their new mutations relative to parents:

In [None]:
# Load the Pango clade information, keyed by clade name. The encoding is given
# explicitly because JSON is UTF-8 while `open` otherwise uses the platform
# default, which can differ between machines.
with open(pango_consensus_seqs_json, encoding="utf-8") as f:
    pango_clades = json.load(f)

def build_records(c, recs):
    """Add clade `c` and, recursively, all of its descendants to `recs`.

    `recs` maps column names to lists (one entry per clade) and is modified in
    place; it must create missing lists on first access (e.g. a
    ``collections.defaultdict(list)``). Clades already in ``recs["clade"]``
    are skipped. Empty mutation strings are dropped.
    """
    if c in recs["clade"]:
        return
    recs["clade"].append(c)
    info = pango_clades[c]
    recs["date"].append(info["designationDate"])
    recs["parent"].append(info["parent"])

    # each output column pools substitutions and deletions of one kind
    pooled_fields = {
        "all_new_muts_from_ref": ("aaSubstitutionsNew", "aaDeletionsNew"),
        "all_new_muts_reverted_from_ref": (
            "aaSubstitutionsReverted",
            "aaDeletionsReverted",
        ),
    }
    for column, fields in pooled_fields.items():
        pooled = []
        for field in fields:
            pooled.extend(mut for mut in info[field] if mut)
        recs[column].append(pooled)

    # depth-first traversal into the children
    for child in info["children"]:
        build_records(child, recs)
        
# Traverse the clade tree from every requested root, collecting one record per clade.
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

# Tabulate the records, then drop the clades flagged for exclusion.
pango_df = pd.DataFrame(records)
pango_df = pango_df.query("clade not in @exclude_clades")

Now get all clade pairs that differ by just a single spike mutation between the parent and descendant clade.
Also note whether they have any additional mutations:

In [None]:
def consolidate_reversions(r):
    """Merge reverted and new mutations of a clade into the actual changes.

    `r` provides ``"all_new_muts_reverted_from_ref"`` and
    ``"all_new_muts_from_ref"``, each a list of ``gene:<wt><site><mutant>``
    strings. Without reversions the list of new mutations is returned as is.
    Otherwise each reversion becomes ``<mutant><site><wt>``, or
    ``<mutant><site><new mutant>`` if there is a new mutation at the same
    (gene, site); the remaining new mutations follow unchanged.
    """
    reverted = r["all_new_muts_reverted_from_ref"]
    new = r["all_new_muts_from_ref"]
    if not reverted:
        return new

    def by_gene_site(mutations):
        # map (gene, site) -> (wt, mutant) for each "gene:<wt><site><mutant>"
        parsed = {}
        for mutation in mutations:
            fields = mutation.split(":")
            gene, change = fields[0], fields[1]
            parsed[(gene, change[1: -1])] = (change[0], change[-1])
        return parsed

    reverted_changes = by_gene_site(reverted)
    remaining_new = by_gene_site(new)

    changes = []
    for key, (ref_aa, parent_aa) in reverted_changes.items():
        gene, site = key
        # a new mutation at a reverted site replaces the reversion's end state
        replacement = remaining_new.pop(key, None)
        child_aa = ref_aa if replacement is None else replacement[1]
        changes.append(f"{gene}:{parent_aa}{site}{child_aa}")
    changes.extend(
        f"{gene}:{wt}{site}{mutant}"
        for (gene, site), (wt, mutant) in remaining_new.items()
    )
    return changes

# Keep only clades that differ from their parent by exactly one spike mutation,
# and record whether that spike mutation is the only new mutation of any gene.
pango_pair_df = (
    pango_df
    .assign(
        # actual changes relative to the parent, with reversions folded in
        all_new_muts=lambda x: x.apply(consolidate_reversions, axis=1),
        # mutations are "gene:<wt><site><mutant>"; compare the whole gene name
        # (not just its first character) so only gene "S" counts as spike
        spike_new_muts=lambda x: x["all_new_muts"].map(
            lambda ms: [m.split(":")[1] for m in ms if m.split(":")[0] == "S"]
        ),
        n_new_spike_muts=lambda x: x["spike_new_muts"].map(len),
        n_new_all_muts=lambda x: x["all_new_muts"].map(len),
    )
    .query("n_new_spike_muts == 1")
    .assign(
        # exactly one spike mutation remains after the query above
        new_spike_mut=lambda x: x["spike_new_muts"].map(lambda ms: ms[0]),
        all_new_muts=lambda x: x["all_new_muts"].map(lambda ms: "; ".join(ms)),
        only_new_mut_is_spike=lambda x: x["n_new_spike_muts"] == x["n_new_all_muts"],
    )
    [["clade", "parent", "date", "new_spike_mut", "all_new_muts", "only_new_mut_is_spike"]]
    .reset_index(drop=True)
)

print("Number of clade pairs differing by one spike mutation:")
display(
    pango_pair_df
    .groupby("only_new_mut_is_spike")
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

## Assign changes in DMS phenotypes to Pango clade pairs
For each clade pair, we compute the change in DMS phenotype:

In [None]:
# read the DMS data
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
# validate with an explicit error rather than `assert`, which is skipped
# when Python runs with optimizations enabled
if (missing_phenotypes := set(phenotypes) - set(dms_summary.columns)):
    raise ValueError(f"DMS summary lacks phenotype columns {missing_phenotypes}")

# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()

# dict that maps (site, wt, mutant) to DMS phenotypes
dms_data_dict = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)

def mut_dms(m):
    """Get DMS phenotypes for a mutation, and also whether it is in the RBD.

    `m` is a spike mutation written ``<parent><site><mutant>`` with an integer
    site (e.g. ``"N501Y"``). Returns a new dict keyed by each phenotype plus
    ``"new_mut_is_RBD"``. Phenotype values are the change going from the
    parent to the mutant amino acid, or ``pd.NA`` when the DMS data cannot
    supply them. ``"new_mut_is_RBD"`` is ``pd.NA`` if the site is not in the
    DMS data. The module-level DMS lookup tables are never modified.
    """
    null_d = {k: pd.NA for k in phenotypes}
    site = int(m[1: -1])
    if site not in dms_wt:
        d = null_d
        d["new_mut_is_RBD"] = pd.NA
    else:
        parent = m[0]
        mut = m[-1]
        wt = dms_wt[site]
        try:
            if parent == wt:
                # build a fresh dict: the RBD flag added below must not leak
                # into the shared `dms_data_dict` entry
                measured = dms_data_dict[(site, parent, mut)]
                d = {p: measured[p] for p in phenotypes}
            elif mut == wt:
                # reversion to the DMS wildtype: negate the forward effect
                measured = dms_data_dict[(site, mut, parent)]
                d = {p: -measured[p] for p in phenotypes}
            else:
                # neither state is the DMS wildtype: difference of the two
                # effects, each measured relative to the wildtype
                parent_d = dms_data_dict[(site, wt, parent)]
                mut_d = dms_data_dict[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
        except KeyError:
            # no DMS measurement for a required (site, wildtype, mutant)
            d = null_d
        d["new_mut_is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["new_mut_is_RBD"]
    return d

# Attach to each clade pair the values returned by `mut_dms` for its new spike
# mutation (one column per phenotype plus `new_mut_is_RBD`).
pango_pair_dms_df = (
    pango_pair_df
    # to add multiple columns: https://stackoverflow.com/a/46814360
    .apply(
        lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["new_spike_mut"]))]),
        axis=1,
    )
    # remove any clade pairs for which we don't have DMS data for all phenotypes
    .query(" and ".join(f"`{p}`.notnull()" for p in phenotypes))
    .reset_index(drop=True)
)

# Count the remaining clade pairs, split by whether the spike mutation is the
# only new mutation in the descendant clade.
print("Number of clade pairs with DMS data for all phenotypes:")
display(
    pango_pair_dms_df
    .groupby("only_new_mut_is_spike")
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

## Assign changes in growth rate to Pango clade pairs
For each clade pair, we compute the change in growth rate (growth rate of the descendant clade minus that of its parent).

In [None]:
# Read the per-clade growth rates, renaming columns to the names used below.
growth_rate_column_names = {
    "pango": "clade",
    "seq_volume": "n_sequences",
    "R": "growth_rate",
}
growth_rates = pd.read_csv(growth_rates_csv).rename(columns=growth_rate_column_names)

# Every clade with a growth rate must be present in the Pango clade data.
invalid_clades = set(growth_rates["clade"]) - set(pango_clades)
if invalid_clades:
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# Same table with columns renamed so it can be joined on the parent clade.
parent_growth_rates = growth_rates.rename(
    columns={
        "clade": "parent",
        "growth_rate": "parent_growth_rate",
        "n_sequences": "parent_n_sequences",
    }
)

# Keep clade pairs with growth rates for both the clade and its parent, and
# compute the change in growth rate from parent to descendant.
clade_pairs_with_growth = pango_pair_dms_df.merge(
    growth_rates, on="clade", validate="one_to_one"
)
clade_pairs_with_parent_growth = clade_pairs_with_growth.merge(
    parent_growth_rates, on="parent", validate="many_to_one"
)
pango_pair_dms_growth_df = clade_pairs_with_parent_growth.assign(
    change_in_growth_rate=lambda x: x["growth_rate"] - x["parent_growth_rate"]
)

print(f"Saving data to {growth_dms_csv}")
pango_pair_dms_growth_df.to_csv(growth_dms_csv, float_format="%.5g", index=False)

# Count clade pairs by whether both clades reach the minimum sequence count.
print("Number of clade pairs with growth rates:")
min_sequences_label = f"at least {min_sequences} sequences"
has_min_sequences = (
    pango_pair_dms_growth_df[["n_sequences", "parent_n_sequences"]].min(axis=1)
    >= min_sequences
)
display(
    pango_pair_dms_growth_df
    .assign(**{min_sequences_label: has_min_sequences})
    .groupby([min_sequences_label, "only_new_mut_is_spike"])
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

## Phenotype correlations
Plot the correlations among the changes in the different DMS phenotypes across all clade pairs.
Do this both with and without stratifying by whether the mutation is in the RBD.

In [None]:
# Interactive controls shared by the charts below.

# Select the clade under the mouse; `empty=False` means nothing is treated as
# selected until a point is hovered.
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# Lower bound of the slider: log10 of the smallest sequence count in either
# clade of any pair, truncated to one decimal place.
n_sequences_min = int(
    10 * math.log10(
        pango_pair_dms_growth_df[["n_sequences", "parent_n_sequences"]].min(axis=None)
    )
) / 10
# Slider on a log10 scale for the minimum number of sequences required in
# both clades of a pair; it starts at `min_sequences`.
n_sequences_slider = alt.param(
    value=math.log10(min_sequences),
    bind=alt.binding_range(
        name="min log10 number sequences in both clades",
        min=n_sequences_min,
        max=math.log10(pango_pair_dms_growth_df["n_sequences"].max() / 10),
        step=0.1,
    ),
)

# Radio buttons restricting clade pairs by whether the new mutation is in the
# RBD; the `None` option is labeled "all spike".
new_mut_is_rbd = alt.selection_point(
    fields=["new_mut_is_RBD"],
    value=None,
    bind=alt.binding_radio(
        options=[None, True, False],
        labels=["all spike", "RBD only", "non-RBD only"],
        name="mutations to include",
    ),
)

# Radio buttons restricting clade pairs by whether the spike mutation is the
# only new mutation; the `None` option is labeled "either".
only_new_mut_is_spike = alt.selection_point(
    fields=["only_new_mut_is_spike"],
    value=None,
    bind=alt.binding_radio(
        options=[True, False, None],
        labels=["only a spike mutation", "spike & non-spike mutations", "either"],
        name="include clade pairs separated by",
    ),
)

# One yes / no toggle per mutation in `muts_to_toggle`, keyed by mutation. Each
# selects on a `not_<mutation>` field that the charts compute, and starts at
# the value given for that mutation in `muts_to_toggle`.
toggle_muts = {
    mut: alt.selection_point(
        fields=[f"not_{mut}"],
        value=init_value,
        bind=alt.binding_radio(
            name=f"exclude clade pairs separated by {mut}",
            labels=["yes", "no"],
            options=[True, None],
        ),
    )
    for (mut, init_value) in muts_to_toggle.items()
}

# Base chart shared by the phenotype-versus-phenotype scatter plots: applies
# the sequence-count, RBD, and spike-only filters, then the per-mutation
# exclusion toggles, and registers all the interactive parameters.
phenotype_scatter_base = (
    alt.Chart(pango_pair_dms_growth_df)
    .transform_calculate(
        min_n_sequences=alt.expr.min(
            alt.datum["n_sequences"], alt.datum["parent_n_sequences"]
        )
    )
    # the slider is on a log10 scale, so compare log10 of the smaller count
    .transform_filter(
        alt.expr.log(alt.datum["min_n_sequences"]) / math.log(10) >= n_sequences_slider
    )
    .transform_filter(new_mut_is_rbd)
    .transform_filter(only_new_mut_is_spike)
)
# Only add the toggle filter when there are mutations to toggle: `reduce` with
# no initial value raises TypeError on an empty sequence.
if toggle_muts:
    phenotype_scatter_base = (
        phenotype_scatter_base
        .transform_calculate(
            **{f"not_{mut}": alt.datum["new_spike_mut"] != mut for mut in toggle_muts}
        )
        .transform_filter(functools.reduce(operator.and_, toggle_muts.values()))
    )
phenotype_scatter_base = phenotype_scatter_base.add_params(
    clade_selection,
    n_sequences_slider,
    new_mut_is_rbd,
    only_new_mut_is_spike,
    *toggle_muts.values(),
)

# Tooltip fields shared by all charts: clade identity and mutations, growth
# rates (one decimal place), and sequence counts.
growth_rate_tooltips = [
    alt.Tooltip(p, format=".1f")
    for p in ["change_in_growth_rate", "growth_rate", "parent_growth_rate"]
]
tooltips = (
    ["clade", "parent", "date", "new_spike_mut", "all_new_muts"]
    + growth_rate_tooltips
    + ["n_sequences", "parent_n_sequences"]
)

# One layered chart (scatter points plus a correlation label) for every pair
# of DMS phenotypes.
phenotype_scatter_charts = []
for pheno1, pheno2 in itertools.combinations(phenotypes, 2):
    # scatter of the two phenotype changes; the moused-over clade is drawn
    # larger and with a thicker red outline
    phenotype_scatter_points = (
        phenotype_scatter_base
        .encode(
            alt.X(pheno1, scale=alt.Scale(nice=False, padding=5)),
            y=alt.Y(pheno2, scale=alt.Scale(nice=False, padding=5)),
            tooltip=[
                *[alt.Tooltip(p, format=".2f") for p in phenotypes],
                *tooltips,
            ],
            size=alt.condition(clade_selection, alt.value(70), alt.value(35)),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            stroke=alt.condition(clade_selection, alt.value("red"), alt.value("black")),
        )
        .mark_circle(fill="black", strokeOpacity=1, fillOpacity=0.35)
    )
    # text label "r = ...": r is the square root of the regression's
    # r-squared, given the sign of the fitted slope (`coef[1]`)
    phenotype_scatter_r = (
        phenotype_scatter_base
        .transform_regression(pheno1, pheno2, params=True)
        .transform_calculate(
            r=alt.expr.if_(
                alt.datum["coef"][1] >= 0,
                alt.expr.sqrt(alt.datum["rSquared"]),
                -alt.expr.sqrt(alt.datum["rSquared"]),
            ),
            label='"r = " + format(datum.r, ".2f")',
        )
        .mark_text(align="left", color="purple", fontWeight=500, fontSize=11, opacity=1)
        # fixed pixel position for the label
        .encode(x=alt.value(3), y=alt.value(7), text=alt.Text("label:N"))
        .properties(width=110, height=110)
    )
    phenotype_scatter_charts.append(phenotype_scatter_points + phenotype_scatter_r)

# Lay the per-phenotype-pair charts out side by side under a common title.
phenotype_scatter_title = alt.TitleParams(
    "Changes in DMS phenotypes for mutations separating clade pairs",
    anchor="middle",
    fontSize=14,
    dy=-3,
)
phenotype_scatter_chart = (
    alt.hconcat(*phenotype_scatter_charts, spacing=9)
    .configure_axis(grid=False)
    .properties(title=phenotype_scatter_title)
)

# last expression of the cell, so the notebook displays the chart
phenotype_scatter_chart

## Plot correlations between growth and each phenotype
Correlations of the change in growth rate versus the change in each DMS phenotype for each clade pair:

In [None]:
# Reshape to long ("tidy") form: one row per clade pair and phenotype, keeping
# every non-phenotype column as an identifier.
non_phenotype_columns = [
    col for col in pango_pair_dms_growth_df.columns if col not in phenotypes
]
pango_pair_dms_growth_df_tidy = pd.melt(
    pango_pair_dms_growth_df,
    id_vars=non_phenotype_columns,
    value_vars=phenotypes,
    var_name="phenotype_name",
    value_name="phenotype",
)

# width and height in pixels of each growth-versus-phenotype panel
growth_phenotype_chart_size = 110

# Base chart shared by the growth-versus-phenotype plots: applies the
# sequence-count, RBD, and spike-only filters, then the per-mutation exclusion
# toggles, registers all the interactive parameters, and sets the panel size.
growth_phenotype_corr_base = (
    alt.Chart(pango_pair_dms_growth_df_tidy)
    .transform_calculate(
        min_n_sequences=alt.expr.min(
            alt.datum["n_sequences"], alt.datum["parent_n_sequences"]
        )
    )
    # the slider is on a log10 scale, so compare log10 of the smaller count
    .transform_filter(
        alt.expr.log(alt.datum["min_n_sequences"]) / math.log(10) >= n_sequences_slider
    )
    .transform_filter(new_mut_is_rbd)
    .transform_filter(only_new_mut_is_spike)
)
# Only add the toggle filter when there are mutations to toggle: `reduce` with
# no initial value raises TypeError on an empty sequence.
if toggle_muts:
    growth_phenotype_corr_base = (
        growth_phenotype_corr_base
        .transform_calculate(
            **{f"not_{mut}": alt.datum["new_spike_mut"] != mut for mut in toggle_muts}
        )
        .transform_filter(functools.reduce(operator.and_, toggle_muts.values()))
    )
growth_phenotype_corr_base = (
    growth_phenotype_corr_base
    .add_params(
        clade_selection,
        n_sequences_slider,
        new_mut_is_rbd,
        only_new_mut_is_spike,
        *toggle_muts.values(),
    )
    .properties(width=growth_phenotype_chart_size, height=growth_phenotype_chart_size)
)

# Scatter of phenotype change (y) against change in growth rate (x), filled by
# phenotype name; the moused-over clade is drawn larger with a thicker outline.
growth_phenotype_scatter = (
    growth_phenotype_corr_base
    .encode(
        alt.X(
            "change_in_growth_rate",
            title="change in growth",
            scale=alt.Scale(nice=False, padding=5),
        ),
        alt.Y("phenotype:Q", title=None, scale=alt.Scale(nice=False, padding=5)),
        alt.Fill("phenotype_name:N", sort=phenotypes, legend=None),
        size=alt.condition(clade_selection, alt.value(70), alt.value(35)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
        tooltip=["phenotype_name:O", alt.Tooltip("phenotype:Q", format=".2f"), *tooltips],
    )
    .mark_circle(stroke="black", strokeOpacity=1, fillOpacity=0.45)
)

# Text label "r = ...": r is the square root of the regression's r-squared,
# given the sign of the fitted slope (`coef[1]`).
growth_phenotype_r = (
    growth_phenotype_corr_base
    .transform_regression("change_in_growth_rate", "phenotype", params=True)
    .transform_calculate(
        r=alt.expr.if_(
            alt.datum["coef"][1] >= 0,
            alt.expr.sqrt(alt.datum["rSquared"]),
            -alt.expr.sqrt(alt.datum["rSquared"]),
        ),
        label='"r = " + format(datum.r, ".2f")',
    )
    .mark_text(align="right", color="black", fontWeight=500, fontSize=11, opacity=1)
    # fixed pixel position, offset from the panel size, with right-aligned text
    .encode(
        x=alt.value(growth_phenotype_chart_size - 3),
        y=alt.value(growth_phenotype_chart_size - 5),
        text=alt.Text("label:N"),
    )
)

# Show the growth-versus-phenotype chart twice: first with one row of panels,
# then with rows split by whether the new mutation is in the RBD.
for facet_by_rbd in [False, True]:

    print(f"\n\n{facet_by_rbd=}:\n")
    
    growth_phenotype_chart = (
        (growth_phenotype_scatter + growth_phenotype_r)
        .facet(
            # one column of panels per phenotype, in the order of `phenotypes`
            column=alt.Column(
                "phenotype_name:O",
                sort=phenotypes,
                header=alt.Header(
                    orient="left",
                    labelFontStyle="bold",
                    labelFontSize=11,
                    labelPadding=2,
                    title=None,
                ),
            ),
            # facet rows by `new_mut_is_RBD` only when requested; otherwise an
            # empty `alt.Row()` is passed
            row=(
                alt.Row(
                    "new_mut_is_RBD",
                    sort="descending",
                    header=alt.Header(
                        orient="right",
                        labelFontSize=12,
                        labelFontStyle="bold",
                        title=None,
                        labelExpr="if(datum.value, 'RBD mutation', 'non-RBD mutation')",
                    ),
                )
                if facet_by_rbd
                else alt.Row()
            ),
            spacing=8,
        )
        .properties(
            title=alt.TitleParams(
                "Changes in growth rate versus DMS phenotypes for clade pairs",
                anchor="middle",
                fontSize=14,
                dy=-3,
            ),
        )
        # each panel gets its own y-axis scale
        .resolve_scale(y="independent")
        .configure_axis(grid=False)
    )

    display(growth_phenotype_chart)

## Multiple least squares regression of growth versus DMS phenotypes

In [None]:
def powerset(iterable):
    """Return an iterator over all subsets of `iterable`, as tuples.

    Subsets are ordered by increasing size, starting with the empty tuple:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Adapted from https://stackoverflow.com/a/1482316
    """
    items = tuple(iterable)
    subsets_by_size = (
        itertools.combinations(items, size) for size in range(len(items) + 1)
    )
    return itertools.chain.from_iterable(subsets_by_size)

def ols_unique_var_explained(var_endog, vars, df, full_r2):
    """Get the unique variance explained by each variable in an OLS model.

    For each variable in `vars`, fit an OLS model (with a constant) of
    `var_endog` on the remaining variables, and report how far its R-squared
    falls below `full_r2`, the R-squared of the model with all variables.
    Returns a dict mapping each variable to that drop.

    https://blog.minitab.com/en/adventures-in-statistics-2/how-to-identify-the-most-important-predictor-variables-in-regression-models

    """
    unique_var = {}
    for dropped in vars:
        kept = [v for v in vars if v != dropped]
        reduced_fit = statsmodels.api.OLS(
            endog=df[[var_endog]],
            exog=statsmodels.api.add_constant(df[kept].astype(float)),
        ).fit()
        unique_var[dropped] = full_r2 - reduced_fit.rsquared
    return unique_var

# Fit an OLS model of change in growth rate on the DMS phenotypes for every
# subset of the toggle-able mutations excluded, and plot actual versus
# predicted change in growth.
#
# The loop variable is named `spike_only` so that it does not overwrite the
# Altair selection `only_new_mut_is_spike` used by the interactive charts.
for spike_only in [False]:
    ols_df1 = (
        pango_pair_dms_growth_df.query("only_new_mut_is_spike")
        if spike_only
        else pango_pair_dms_growth_df
    ).copy()
    for muts_excluded in powerset(muts_to_toggle):
        # drop clade pairs separated by an excluded mutation, then require the
        # minimum number of sequences in both clades
        ols_df = (
            ols_df1[
                functools.reduce(
                    operator.and_,
                    [ols_df1["new_spike_mut"] != m for m in muts_excluded],
                )
            ]
            if muts_excluded
            else ols_df1
        ).query("n_sequences >= @min_sequences").query("parent_n_sequences >= @min_sequences")
        n = len(ols_df)
        # printed label keeps the name `only_new_mut_is_spike`
        print(f"\n{min_sequences=}, only_new_mut_is_spike={spike_only!r}, {muts_excluded=}:")

        # https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
        ols_model = statsmodels.api.OLS(
            endog=ols_df[["change_in_growth_rate"]],
            exog=statsmodels.api.add_constant(ols_df[phenotypes].astype(float)),
        )
        res_ols = ols_model.fit()
        ols_df = ols_df.assign(predicted_change_in_growth_rate=res_ols.predict())
        r2 = res_ols.rsquared
        r = math.sqrt(r2)
        unique_var = ols_unique_var_explained("change_in_growth_rate", phenotypes, ols_df, r2)

        # one subtitle line per phenotype: unique variance explained and the
        # fitted coefficient with its standard error
        subtitle = [
            # https://stackoverflow.com/a/53966201
            f"{p}: {unique_var[p] * 100:.0f}% of variance (coef {res_ols.params[p]:.0f} \u00B1 {res_ols.bse[p]:.0f})"
            for p in phenotypes
        ]

        ols_chart = (
            alt.Chart(ols_df)
            .add_params(clade_selection)
            .encode(
                alt.X(
                    "predicted_change_in_growth_rate",
                    title="predicted change in growth",
                    scale=alt.Scale(nice=False, padding=4),
                ),
                alt.Y(
                    "change_in_growth_rate",
                    title="actual change in growth",
                    scale=alt.Scale(nice=False, padding=4),
                ),
                size=alt.condition(clade_selection, alt.value(90), alt.value(55)),
                strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
                stroke=alt.condition(clade_selection, alt.value("red"), alt.value("black")),
                tooltip=tooltips,
            )
            .mark_circle(strokeOpacity=1, fillOpacity=0.35, fill="black", stroke="black")
            .properties(
                width=165,
                height=165,
                title=alt.TitleParams(
                    f"OLS regression r = {r:.2f} (n = {n})",
                    subtitle=subtitle,
                    fontSize=12,
                    subtitleFontSize=9,
                ),
            )
            .configure_axis(grid=False)
        )

        display(ols_chart)