# Compare DMS to natural sequence evolution
The basic approach is to find all pairs of parent-descendant Pango clades that differ by only a single spike amino-acid substitution, and then compare the differences in growth rates to changes in spike phenotypes measured by DMS.

This approach of comparing clade pairs is better than trying to compare all clades based on their phenotypes as it avoids phylogenetic correlations because it only utilizes the new mutation that has appeared in each parent / descendant clade pair rather than all mutations in clades (the latter approach is confounded by phylogeny due to clades sharing mutations by ancestry).

In [1]:
# this cell is tagged parameters for papermill parameterization
# Every value defaults to `None` here; concrete values are assigned in the
# "Parameters" cell that follows.
dms_summary_csv = None  # path of CSV with DMS measurements (read with `pd.read_csv`)
growth_rates_csv = None  # path of CSV with per-clade growth rates
pango_consensus_seqs_json = None  # path of JSON with Pango clade information
starting_clades = None  # clade names from which the clade tree is walked
dms_clade = None  # name of the clade used in the DMS; used to build column names
exclude_clades = None  # clades dropped from the table of Pango clades
muts_to_toggle = None  # dict of spike mutation -> initial value of its exclusion toggle

In [2]:
# Parameters
# Concrete values for the parameters declared (as `None`) in the previous cell.
starting_clades = ["XBB"]
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "MultinomialLogisticGrowth/model_fits/rates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
exclude_clades = []
muts_to_toggle = {"L455F": True}

# Change to the parent directory; presumably so the relative paths above
# resolve from the project root -- verify. NOTE(review): re-running this cell
# moves up one more directory each time.
import os
os.chdir("../")


In [3]:
import collections
import functools
import itertools
import json
import math
import operator

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

# lift Altair's limit on the number of rows a chart's data may have;
# the return value is discarded
_ = alt.data_transformers.disable_max_rows()

## Read Pango clades and identify all pairs separated by a single spike mutation
First, read all Pango clades and get their new mutations relative to parents:

In [4]:
# load the Pango clade information; it is indexed by clade name below
with open(pango_consensus_seqs_json) as f:
    pango_clades = json.load(f)

def build_records(c, recs):
    """Add clade `c` and all of its descendants to `recs`.

    `recs` maps column names to lists and is extended in place, one entry per
    clade in every list. Clades are visited depth-first, parents before their
    children and children in their listed order. A clade already present in
    `recs["clade"]` is skipped, along with its subtree. Nothing is returned.
    """
    # explicit stack instead of recursion; the top of the stack is visited next
    to_visit = [c]
    while to_visit:
        clade = to_visit.pop()
        if clade in recs["clade"]:
            continue
        recs["clade"].append(clade)
        info = pango_clades[clade]
        recs["date"].append(info["designationDate"])
        recs["parent"].append(info["parent"])
        # the "new" and "reverted" lists each merge substitutions and deletions,
        # dropping empty entries
        for column, fields in [
            ("all_new_muts_from_ref", ["aaSubstitutionsNew", "aaDeletionsNew"]),
            (
                "all_new_muts_reverted_from_ref",
                ["aaSubstitutionsReverted", "aaDeletionsReverted"],
            ),
        ]:
            collected = []
            for field in fields:
                collected.extend(mut for mut in info[field] if mut)
            recs[column].append(collected)
        # push children reversed so the first child is popped (visited) first
        to_visit.extend(reversed(info["children"]))
        
# collect one record per clade, walking down from each starting clade
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

# one row per clade, dropping any clades listed in `exclude_clades`
pango_df = pd.DataFrame(records).query("clade not in @exclude_clades")

Now get all clade pairs that differ by just a single spike mutation between the parent and descendant clade.
Also note whether they have any additional mutations:

In [5]:
def consolidate_reversions(r):
    """Merge a clade's reverted and new mutations into its net changes.

    `r` is indexed by "all_new_muts_reverted_from_ref" and
    "all_new_muts_from_ref", each a list of "gene:<wt><site><mutant>" strings.
    With no reversions the list of new mutations is returned unchanged.
    Otherwise a new list is built: first one entry per reverted site (going to
    the new mutant if that site is also newly mutated, else back to the
    reverted mutation's wildtype), then the remaining new mutations.
    """
    reverted = r["all_new_muts_reverted_from_ref"]
    new = r["all_new_muts_from_ref"]
    if not reverted:
        return new

    def by_site(mutations):
        """Index mutations as (gene, site) -> (wt, mutant)."""
        indexed = {}
        for mutation in mutations:
            fields = mutation.split(":")
            gene, change = fields[0], fields[1]
            indexed[(gene, change[1: -1])] = (change[0], change[-1])
        return indexed

    reverted_by_site = by_site(reverted)
    new_by_site = by_site(new)

    changes = []
    for key, (ref_aa, parent_aa) in reverted_by_site.items():
        gene, site = key
        if key in new_by_site:
            # site reverted and re-mutated: parent residue -> new mutant
            child_aa = new_by_site.pop(key)[1]
        else:
            # plain reversion: parent residue -> wildtype of reverted mutation
            child_aa = ref_aa
        changes.append(f"{gene}:{parent_aa}{site}{child_aa}")
    # whatever is left are new mutations at sites that were not reverted
    changes.extend(
        f"{gene}:{wt_aa}{site}{mutant_aa}"
        for (gene, site), (wt_aa, mutant_aa) in new_by_site.items()
    )
    return changes

# Table of parent/descendant clade pairs separated by exactly one spike
# mutation, one row per descendant clade.
pango_pair_df = (
    pango_df
    .assign(
        # net changes from the parent, as "gene:mutation" strings
        all_new_muts=lambda x: x.apply(consolidate_reversions, axis=1),
        # spike mutations with the gene prefix removed; compare the whole gene
        # name rather than just the first character so that another gene whose
        # name begins with "S" cannot be mistaken for spike
        spike_new_muts=lambda x: x["all_new_muts"].map(
            lambda ms: [m.split(":")[1] for m in ms if m.split(":")[0] == "S"]
        ),
        n_new_spike_muts=lambda x: x["spike_new_muts"].map(len),
        n_new_all_muts=lambda x: x["all_new_muts"].map(len),
    )
    .query("n_new_spike_muts == 1")
    .assign(
        new_spike_mut=lambda x: x["spike_new_muts"].map(lambda ms: ms[0]),
        all_new_muts=lambda x: x["all_new_muts"].map(lambda ms: "; ".join(ms)),
        # True when the single spike mutation is the only change of any kind
        only_new_mut_is_spike=lambda x: x["n_new_spike_muts"] == x["n_new_all_muts"],
    )
    [["clade", "parent", "date", "new_spike_mut", "all_new_muts", "only_new_mut_is_spike"]]
    .reset_index(drop=True)
)

print("Number of clade pairs differing by one spike mutation:")
display(
    pango_pair_df
    .groupby("only_new_mut_is_spike")
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

Number of clade pairs differing by one spike mutation:


Unnamed: 0_level_0,n_clade_pairs
only_new_mut_is_spike,Unnamed: 1_level_1
False,179
True,186


## Assign changes in DMS phenotypes to Pango clade pairs
For each clade pair, we compute the change in DMS phenotype:

In [6]:
# read the DMS data, renaming two phenotype columns to shorter names
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
# stop here if any phenotype of interest is not a column of the DMS data
assert set(phenotypes).issubset(dms_summary.columns)

# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()

# dict that maps (site, wt, mutant) to DMS phenotypes
# (each value is itself a dict of phenotype name -> measurement)
dms_data_dict = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)

def mut_dms(m):
    """Get DMS phenotypes for a mutation, and also whether it is in the RBD.

    `m` is a spike mutation written as parent amino acid, site number, new
    amino acid. The returned dict has one key per entry of `phenotypes`
    followed by "new_mut_is_RBD". It is always a newly built dict, so the
    entries of `dms_data_dict` are never modified.

    The phenotype change is taken relative to the DMS wildtype at the site:
    the measured value if the parent is the wildtype, its negative if the
    mutation goes back to the wildtype, and the difference of the two measured
    values otherwise. Phenotypes are `pd.NA` when the site or a needed
    mutation is not in the DMS data; "new_mut_is_RBD" is `pd.NA` only when the
    site itself is not in the DMS data.
    """
    null_d = {k: pd.NA for k in phenotypes}
    site = int(m[1: -1])
    if site not in dms_wt:
        d = null_d
        is_rbd = pd.NA
    else:
        parent = m[0]
        mut = m[-1]
        wt = dms_wt[site]
        try:
            if parent == wt:
                # measured directly; copy so the lookup table is not mutated
                measured = dms_data_dict[(site, parent, mut)]
                d = {p: measured[p] for p in phenotypes}
            elif mut == wt:
                # reversion to the DMS wildtype: negate the forward effect
                measured = dms_data_dict[(site, mut, parent)]
                d = {p: -measured[p] for p in phenotypes}
            else:
                # neither residue is the DMS wildtype: difference of effects
                parent_d = dms_data_dict[(site, wt, parent)]
                mut_d = dms_data_dict[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
        except KeyError:
            # a needed mutation was not measured in the DMS
            d = null_d
        is_rbd = site_to_region[site] == "RBD"
    d["new_mut_is_RBD"] = is_rbd
    assert list(d) == phenotypes + ["new_mut_is_RBD"]
    return d

# add the DMS phenotype changes (and the RBD flag) returned by `mut_dms`
# as extra columns of each clade pair
pango_pair_dms_df = (
    pango_pair_df
    # to add multiple columns: https://stackoverflow.com/a/46814360
    .apply(
        lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["new_spike_mut"]))]),
        axis=1,
    )
    # remove any clade pairs for which we don't have DMS data for all phenotypes
    .query(" and ".join(f"`{p}`.notnull()" for p in phenotypes))
    .reset_index(drop=True)
)

print("Number of clade pairs with DMS data for all phenotypes:")
display(
    pango_pair_dms_df
    .groupby("only_new_mut_is_spike")
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

Number of clade pairs with DMS data for all phenotypes:


Unnamed: 0_level_0,n_clade_pairs
only_new_mut_is_spike,Unnamed: 1_level_1
False,169
True,183


## Assign changes in growth rate to Pango clade pairs
For each clade pair, we compute the change in growth rate (descendant clade minus parent clade).

In [7]:
# read the growth rates, renaming columns to the names used below
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "n_sequences", "R": "growth_rate"}
)

# every clade with a growth rate must be a known Pango clade
if (invalid_clades := set(growth_rates["clade"]) - set(pango_clades)):
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# attach the growth rate and sequence count of both the descendant clade and
# its parent; the merges use the default join, so pairs lacking a growth rate
# for either clade are dropped
pango_pair_dms_growth_df = (
    pango_pair_dms_df
    .merge(growth_rates, on="clade", validate="one_to_one")
    .merge(
        growth_rates.rename(
            columns={
                "clade": "parent",
                "growth_rate": "parent_growth_rate",
                "n_sequences": "parent_n_sequences",
            }
        ),
        on="parent",
        validate="many_to_one",
    )
    # descendant minus parent
    .assign(change_in_growth_rate=lambda x: x["growth_rate"] - x["parent_growth_rate"])
)

print("Number of clade pairs with growth rates:")
display(
    pango_pair_dms_growth_df
    .groupby("only_new_mut_is_spike")
    .aggregate(n_clade_pairs=pd.NamedAgg("clade", "count"))
)

Number of clade pairs with growth rates:


Unnamed: 0_level_0,n_clade_pairs
only_new_mut_is_spike,Unnamed: 1_level_1
False,45
True,70


## Phenotype correlations
Plot correlations among the changes in DMS phenotypes across all clade pairs.
Do this both with and without stratifying by whether the mutation is in the RBD.

In [11]:
# highlight the clade under the mouse
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# slider start: log10 of the smallest sequence count in either clade of any
# pair, truncated to one decimal place
n_sequences_init = int(
    10 * math.log10(
        pango_pair_dms_growth_df[["n_sequences", "parent_n_sequences"]].min(axis=None)
    )
) / 10
n_sequences_slider = alt.param(
    value=n_sequences_init,
    bind=alt.binding_range(
        name="min log10 number sequences in both clades",
        min=n_sequences_init,
        # slider stops one log10 unit below the largest descendant clade
        max=math.log10(pango_pair_dms_growth_df["n_sequences"].max() / 10),
        step=0.1,
    ),
)

# radio buttons choosing all spike, RBD-only, or non-RBD-only mutations
new_mut_is_rbd = alt.selection_point(
    fields=["new_mut_is_RBD"],
    value=None,
    bind=alt.binding_radio(
        options=[None, True, False],
        labels=["all spike", "RBD only", "non-RBD only"],
        name="mutations to include",
    ),
)

# radio buttons choosing whether pairs may also differ outside of spike
only_new_mut_is_spike = alt.selection_point(
    fields=["only_new_mut_is_spike"],
    value=None,
    bind=alt.binding_radio(
        options=[True, False, None],
        labels=["only a spike mutation", "spike & non-spike mutations", "either"],
        name="include clade pairs separated by",
    ),
)

# one yes/no radio per mutation in `muts_to_toggle`, selecting on a
# `not_<mut>` field; the dict value is the radio's initial value
toggle_muts = {
    mut: alt.selection_point(
        fields=[f"not_{mut}"],
        value=init_value,
        bind=alt.binding_radio(
            name=f"exclude clade pairs separated by {mut}",
            labels=["yes", "no"],
            options=[True, None],
        ),
    )
    for (mut, init_value) in muts_to_toggle.items()
}

# Base chart shared by every phenotype scatter plot: applies the sequence-count
# slider and the radio-button filters, then registers all interactive params.
phenotype_scatter_base = (
    alt.Chart(pango_pair_dms_growth_df)
    # a pair is kept only if BOTH clades have enough sequences, so filter on
    # the smaller of the two counts
    .transform_calculate(
        min_n_sequences=alt.expr.min(
            alt.datum["n_sequences"], alt.datum["parent_n_sequences"]
        )
    )
    .transform_filter(
        # natural log divided by log(10) gives log10, matching the slider units
        alt.expr.log(alt.datum["min_n_sequences"]) / math.log(10) >= n_sequences_slider
    )
    .transform_filter(new_mut_is_rbd)
    .transform_filter(only_new_mut_is_spike)
)

# The per-mutation toggles are only applied when there is at least one:
# `functools.reduce` without an initial value raises `TypeError` on an empty
# sequence, which would happen if `muts_to_toggle` were empty.
if toggle_muts:
    phenotype_scatter_base = (
        phenotype_scatter_base
        # `not_<mut>` is true for pairs NOT separated by that mutation
        .transform_calculate(
            **{f"not_{mut}": alt.datum["new_spike_mut"] != mut for mut in toggle_muts}
        )
        # a pair must pass every toggle
        .transform_filter(functools.reduce(operator.and_, toggle_muts.values()))
    )

phenotype_scatter_base = phenotype_scatter_base.add_params(
    clade_selection,
    n_sequences_slider,
    new_mut_is_rbd,
    only_new_mut_is_spike,
    *toggle_muts.values(),
)

# one panel per pair of phenotypes: the scatter points plus a text label
# giving the correlation coefficient
phenotype_scatter_charts = []
for pheno1, pheno2 in itertools.combinations(phenotypes, 2):
    phenotype_scatter_points = (
        phenotype_scatter_base
        .encode(
            alt.X(pheno1, scale=alt.Scale(nice=False, padding=5)),
            y=alt.Y(pheno2, scale=alt.Scale(nice=False, padding=5)),
            # show every column in the tooltip, with float columns given a
            # ".2f" format
            tooltip=[
                alt.Tooltip(
                    c,
                    format=".2f" if pango_pair_dms_growth_df[c].dtype == float else {},
                )
                for c in pango_pair_dms_growth_df.columns
            ],
            # enlarge and outline the clade under the mouse
            size=alt.condition(clade_selection, alt.value(60), alt.value(35)),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        )
        .mark_circle(stroke="red", color="black", strokeOpacity=1, fillOpacity=0.35)
    )
    phenotype_scatter_r = (
        phenotype_scatter_base
        .transform_regression(pheno1, pheno2, params=True)
        .transform_calculate(
            # r is the square root of the regression's rSquared, given the
            # sign of the fitted slope (second coefficient)
            r=alt.expr.if_(
                alt.datum["coef"][1] >= 0,
                alt.expr.sqrt(alt.datum["rSquared"]),
                -alt.expr.sqrt(alt.datum["rSquared"]),
            ),
            label='"r = " + format(datum.r, ".2f")',
        )
        .mark_text(align="left", color="orange", fontWeight=500, fontSize=11, opacity=1)
        # fixed pixel position for the label
        .encode(x=alt.value(3), y=alt.value(7), text=alt.Text("label:N"))
        .properties(width=110, height=110)
    )
    phenotype_scatter_charts.append(phenotype_scatter_points + phenotype_scatter_r)

# lay the panels out side by side under a shared title
phenotype_scatter_chart = (
    alt.hconcat(*phenotype_scatter_charts, spacing=9)
    .configure_axis(grid=False)
    .properties(        
        title=alt.TitleParams(
            "Changes in DMS phenotypes for mutations separating clade pairs",
            anchor="middle",
            fontSize=12,
            dy=-3,
        ),
    )    
)

phenotype_scatter_chart

## All mutation phenotype correlation
Now make plots like those in the prior section, but for individual mutations rather than clade pairs.
You can select all mutations or just those found in Pango clades:

In [None]:
# get all mutations in Pango clades that are mutations or reversions
# relative to DMS measured mutations
# each non-null mutation is listed together with its reverse (the first and
# last characters swapped around the site)
pango_muts_or_rev = []
for m in set(pango_df.explode("muts_from_dms_clade")["muts_from_dms_clade"]):
    if pd.isnull(m):
        continue
    pango_muts_or_rev.extend((m, f"{m[-1]}{m[1: -1]}{m[0]}"))

In [None]:
# one row per DMS mutation, labelled as <wildtype><site><mutant>
mut_scatter_df = (
    dms_summary
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .drop(columns=["wildtype", "site", "mutant", "sequential_site"])
)

# highlight the mutation under the mouse
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# radio button: show all mutations (default) or only those in Pango clades
only_muts_in_pango = alt.param(
    value=False,
    bind=alt.binding_radio(
        options=[True, False],
        name=f"only mutations in Pango clades descended from {', '.join(starting_clades)}",
    ),
)

mut_scatter_base = (
    alt.Chart(mut_scatter_df)
    .add_params(mut_selection, only_muts_in_pango)
    # `keep_all` is the negation of the radio button value
    .transform_calculate(keep_all=alt.expr.if_(only_muts_in_pango, False, True))
    # keep a mutation if it is in the Pango list or if everything is kept
    .transform_filter(
        {
            "or": [
                alt.FieldOneOfPredicate(field="mutation", oneOf=pango_muts_or_rev),
                alt.datum["keep_all"],
            ]
        }
    )
    .properties(width=110, height=110)
)

# one panel per pair of phenotypes: scatter points plus a correlation label
mut_scatter_charts = []
for pheno1, pheno2 in itertools.combinations(phenotypes, 2):
    pheno_mut_scatter_chart = (
        mut_scatter_base
        .encode(
            x=alt.X(
                pheno1,
                scale=alt.Scale(nice=False, padding=5),
            ),
            y=alt.Y(
                pheno2,
                scale=alt.Scale(nice=False, padding=5),
            ),
            # show every column in the tooltip, with float columns given a
            # ".2f" format
            tooltip=[
                alt.Tooltip(
                    c,
                    format=".2f" if mut_scatter_df[c].dtype == float else {},
                )
                for c in mut_scatter_df.columns
            ],
            opacity=alt.condition(mut_selection, alt.value(1), alt.value(0.2)),
            size=alt.condition(mut_selection, alt.value(60), alt.value(30)),
            strokeWidth=alt.condition(mut_selection, alt.value(2), alt.value(0)),
        )
        .mark_circle(stroke="red", color="black")
    )
    pheno_mut_r = (
        mut_scatter_base
        .transform_regression(pheno1, pheno2, params=True)
        .transform_calculate(
            # r is the square root of the regression's rSquared, given the
            # sign of the fitted slope (second coefficient)
            r=alt.expr.if_(
                alt.datum["coef"][1] >= 0,
                alt.expr.sqrt(alt.datum["rSquared"]),
                -alt.expr.sqrt(alt.datum["rSquared"]),
            ),
            label='"r = " + format(datum.r, ".2f")',
        )
        .mark_text(align="left", color="orange", fontWeight=500, fontSize=11, opacity=1)
        # fixed pixel position for the label
        .encode(
            x=alt.value(3),
            y=alt.value(105),
            text=alt.Text("label:N"),
        )
    )
    mut_scatter_charts.append(pheno_mut_scatter_chart + pheno_mut_r)

# lay the panels out side by side under a shared title
mut_scatter_chart = (
    alt.hconcat(*mut_scatter_charts, spacing=9)
    .configure_axis(grid=False)
    .properties(        
        title=alt.TitleParams(
            "Correlations in DMS phenotypes for mutations",
            anchor="middle",
            fontSize=12,
            dy=-3,
        ),
    )    
)

mut_scatter_chart

## Correlate with clade growth

In [None]:
# re-read the growth rates; unlike the earlier cell, "R" keeps its name and
# the sequence count column is called "number sequences"
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)

# every clade with a growth rate must be a known Pango clade
if (invalid_clades := set(growth_rates["clade"]) - set(pango_clades)):
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# NOTE(review): `pango_dms_df` and `pango_dms_df_rand` are not defined in the
# visible part of this notebook -- confirm they are created elsewhere.
pango_dms_growth_df = pango_dms_df.merge(growth_rates, on="clade", validate="many_to_one")

pango_dms_growth_df_rand = pango_dms_df_rand.merge(growth_rates, on="clade", validate="many_to_one")

print(
    f"{growth_rates['clade'].nunique()} clades have growth rates estimates.\n"
    f"{pango_dms_df['clade'].nunique()} clades have DMS estimates.\n"
    f"{pango_dms_growth_df['clade'].nunique()} clades have growth and DMS estimates"
)

# correlation of growth rate "R" with each phenotype column, computed
# separately for each DMS phenotype
print("Simple correlations:")
display(
    pango_dms_growth_df
    .groupby("DMS_phenotype")
    [["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]]
    .corr()
    [["R"]]
)

Plot growth rate versus date, with point sizes scaled by the log of the number of sequences in each clade:

In [None]:
# growth rate "R" against date, one circle per row of the data frame, with
# circle size on a log scale of the number of sequences
(
    alt.Chart(pango_dms_growth_df)
    .encode(
        x="date",
        y="R",
        size=alt.Size("number sequences", scale=alt.Scale(type="log")),
        tooltip=pango_dms_growth_df.columns.tolist(),
    )
    .mark_circle(opacity=0.25, color="black")
)

Now perform weighted least squares (WLS) regression, weighting clades by the log of their number of sequences:

In [None]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Pivot clade phenotype data so each DMS phenotype / domain is a column.

    The three phenotype columns are renamed to "full spike", "RBD" and
    "non RBD", then pivoted on "DMS_phenotype" so the result has one column
    named "<phenotype> (<domain>)" for each combination. The pivot index
    (clade, growth rate, date, mutation lists, number of sequences) is
    returned as ordinary columns.
    """
    missing_col = f"muts_from_{dms_clade}_missing_data"

    def merge_missing(s):
        """Join the unique mutations in a clade's rows, keeping first-seen order."""
        unique_muts = dict.fromkeys([m for ms in s.str.split("; ") for m in ms if m])
        return "; ".join(unique_muts)

    renamed = df.rename(
        columns={
            "phenotype": "full spike",
            "phenotype_RBD_only": "RBD",
            "phenotype_nonRBD_only": "non RBD",
        }
    )
    # group muts missing data from all phenotypes
    renamed = renamed.assign(
        muts_from_DMS_clade_missing_data=(
            renamed.groupby("clade")[missing_col].transform(merge_missing)
        ),
    )
    renamed = renamed.rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})

    index_cols = [
        "clade",
        "R",
        "date",
        "muts_from_DMS_clade",
        "muts_from_DMS_clade_missing_data",
        "number sequences",
    ]
    ols_vars = renamed.pivot_table(
        index=index_cols,
        columns="DMS_phenotype",
        values=["full spike", "RBD", "non RBD"],
    )

    # flatten column names
    assert all(len(c) == 2 for c in ols_vars.columns.values)
    flat_names = []
    for domain, pheno in ols_vars.columns.values:
        flat_names.append(f"{pheno} ({domain})")
    ols_vars.columns = flat_names
    return ols_vars.reset_index()


def ols_unique_var_explained(var_endog, var_ols, vars, weight_var, full_r2):
    """Get unique variance explained by fitting model after removing each variable.

    For each variable in `vars`, a weighted least squares model is refit with
    that variable left out (a constant is always included); the drop in R^2
    relative to `full_r2` is that variable's unique variance explained.

    https://blog.minitab.com/en/adventures-in-statistics-2/how-to-identify-the-most-important-predictor-variables-in-regression-models

    Parameters
    ----------
    var_endog
        Dependent variable passed as `endog` to the model.
    var_ols
        Data frame holding the columns named in `vars`.
    vars
        Names of the explanatory variables of the full model.
    weight_var
        Weights passed to the model.
    full_r2
        R^2 of the full model that uses all of `vars`.

    Returns
    -------
    dict
        Maps each variable in `vars` to `full_r2` minus the R^2 of the model
        fit without it.

    """
    unique_var = {}
    for vremove in vars:
        # use the arguments (not same-named notebook globals) so the result
        # depends only on what the caller passes in
        remaining_vars = [v for v in vars if v != vremove]
        vremove_ols_model = statsmodels.api.WLS(
            endog=var_endog,
            exog=statsmodels.api.add_constant(var_ols[remaining_vars]),
            weights=weight_var,
        )
        vremove_res_ols = vremove_ols_model.fit()
        unique_var[vremove] = full_r2 - vremove_res_ols.rsquared
    return unique_var

# one row per clade, with a column for each phenotype / domain combination
ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# Fit and plot twice: once with full-spike phenotypes as predictors, once with
# RBD and non-RBD phenotypes as separate predictors.
# NOTE(review): the four `pango_dms_vs_growth_*_html` output paths and
# `phenotype_colors` are not defined in the visible part of this notebook --
# confirm they are set elsewhere.
# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars, regression_chartfile, corr_chartfile in [
    (
        "full spike",
        [f"{c} (full spike)" for c in phenotypes],
        pango_dms_vs_growth_regression_html,
        pango_dms_vs_growth_corr_html
    ),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_regression_by_domain_html,
        pango_dms_vs_growth_corr_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    weights = numpy.log(ols_vars["number sequences"])**2  # weight by log n sequences, so pass log**2
    ols_model = statsmodels.api.WLS(
        endog=ols_vars[["R"]],
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
        weights=weights,
    )
    res_ols = ols_model.fit()
    # drop in R^2 when each predictor is left out of the model
    unique_var = ols_unique_var_explained(ols_vars[["R"]], ols_vars, exog_vars, weights, full_r2=res_ols.rsquared)

    # add the fitted model's predictions as a column
    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    plot_size=160
    
    # highlight the clade under the mouse
    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider start: log10 of the smallest sequence count, truncated to one
    # decimal place
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            # slider stops one log10 unit below the largest clade
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )
    
    # chart settings shared by every panel below: the two filters, point size
    # by number of sequences, mouseover highlighting, and the tooltip
    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            # natural log divided by log(10) gives log10, matching the slider
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 200],
                ),
                legend=alt.Legend(symbolStrokeWidth=0, symbolFillColor="gray"),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],  
            ],
        )
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    # build two panels per predictor: actual vs predicted growth (colored by
    # the predictor) and actual growth vs the predictor itself
    growth_charts = []
    simple_corr_charts = []
    # `exog_vars` lists the phenotypes once per domain, so cycling through
    # `phenotypes` pairs each predictor with its bare phenotype name
    for i, (dms_pheno, pheno) in enumerate(zip(
        exog_vars,
        itertools.cycle(phenotypes)
    )):
        assert dms_pheno.startswith(pheno)
        base_pheno_chart = (
            base_growth_chart
            .encode(
                y=alt.Y(
                    "R",
                    title="actual clade growth rate",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # only the first panel of each row gets a y-axis
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
            )
        )

        growth_charts.append(
            base_pheno_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                        tickCount=3,
                    ),
                    scale=alt.Scale(
                        range=polyclonal.plot.color_gradient_hex("lightgray", phenotype_colors[pheno], 40),
                        nice=False,
                    ),
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.6)
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=[
                        f"unique variance explained: {unique_var[dms_pheno] * 100:.1f}%",
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}"
                    ],
                    subtitleFontSize=11,
                ),
            )
        )

        # get real and randomized P-value for simple correlations
        pheno_r, _ = scipy.stats.pearsonr(fitted_df["R"], fitted_df[dms_pheno])
        # NOTE(review): `pivot_for_ols_vars(rand_df)` is evaluated twice per
        # randomization here, and again for every predictor
        rand_rs = [
            scipy.stats.pearsonr(pivot_for_ols_vars(rand_df)["R"], pivot_for_ols_vars(rand_df)[dms_pheno])[0]
            for _, rand_df in pango_dms_growth_df_rand.groupby("randomize")
        ]
        # one-sided: fraction of randomizations with r at least the actual r
        rand_p = sum(pheno_r <= r for r in rand_rs) / len(rand_rs)
            
        simple_corr_charts.append(
            base_pheno_chart
            # constant field holding the phenotype name, used to pick the color
            .transform_calculate(color_phenotype=f"'{pheno}'")
            .encode(
                x=alt.X(
                    dms_pheno,
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    "color_phenotype:N",
                    scale=alt.Scale(
                        range=list(phenotype_colors.values()),
                        domain=list(phenotype_colors.keys()),
                    ),
                    legend=None,
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.3, color=phenotype_colors[pheno])
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=f"Pearson r = {pheno_r:.2f} (P = {rand_p:.1g})",
                    subtitleFontSize=11,
                ),
            )
        )
            
    # r of the full model is the square root of its R^2
    actual_r = math.sqrt(res_ols.rsquared)
    assert len(growth_charts) % len(phenotypes) == 0
    # arrange panels in rows of `len(phenotypes)`, with the date slider below
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (Pearson r = {actual_r:.2f})",
                subtitle=f"Showing all Pango clades descended from {', '.join(starting_clades)}",
                anchor="middle",
                fontSize=13,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    # same layout for the simple-correlation panels
    simple_corr_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *simple_corr_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    )
                    for i in range(len(simple_corr_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                "Correlations of DMS phenotypes vs clade growth",
                subtitle=f"Showing all Pango clades descended from {', '.join(starting_clades)}",
                anchor="middle",
                fontSize=13,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )
    
    display(growth_chart)
    print(f"Saving to {regression_chartfile}")
    growth_chart.save(regression_chartfile)

    display(simple_corr_chart)
    print(f"Saving to {corr_chartfile}")
    simple_corr_chart.save(corr_chartfile)

    # fit randomized models and compute P-value for WLS based on R values
    print("Computing P-value for WLS randomizations")
    rand_r = []
    for randomseed, rand_df in pango_dms_growth_df_rand.groupby("randomize"):
        rand_ols_vars = pivot_for_ols_vars(rand_df)
        rand_ols_model = statsmodels.api.WLS(
            endog=rand_ols_vars[["R"]],
            exog=statsmodels.api.add_constant(rand_ols_vars[exog_vars]),
            # weight by log n sequences, so pass log**2
            weights=numpy.log(rand_ols_vars["number sequences"])**2,
        )
        rand_res_ols = rand_ols_model.fit()
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    # number of randomizations whose r is at least the actual r
    n_rand_ge = sum(r >= actual_r for r in rand_r)
    # when no randomization reaches the actual r, report an upper bound
    pval = f"= {n_rand_ge / len(rand_r)}" if n_rand_ge else f"< {1 / len(rand_r)}"
    
    # histogram of the randomized r values
    rand_r_hist = (
        alt.Chart(pd.DataFrame({"r": rand_r}))
        .encode(
            x=alt.X(
                "r",
                title="Pearson r",
                bin=alt.BinParams(step=0.02, extent=(0, 1)),
                scale=alt.Scale(domain=(0, 1)),
                axis=alt.Axis(values=[0, 0.2, 0.4, 0.6, 0.8, 1]),
            ),
            y=alt.Y("count()", title="number of randomizations"),
        )
        .mark_bar(color="black", opacity=0.65, align="right")
        .properties(width=250, height=130)
    )
    
    # dashed red line marking the actual r
    actual_r_line = (
        alt.Chart(pd.DataFrame({"r": [actual_r]}))
        .encode(x="r")
        .mark_rule(size=2, color="red", strokeDash=[4, 2])
    )
    
    pval_chart = (
        (rand_r_hist + actual_r_line)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                f"P {pval}",
                subtitle=f"{n_rand_ge} of {len(rand_r)} randomizations 	\u2265 actual r of {actual_r:.2f}",
            ),
        )
    )
    
    display(pval_chart)

## Analyze all clades that differ from another clade by just one amino-acid mutation in spike
Get all clades that differ from their parent by just one spike mutation, also indicating whether the spike mutation is the only difference or there are also other non-spike differences:

In [None]:
def consolidate_reversions(r):
    """Fold reverted mutations into the new mutations to get the net changes.

    ``r`` is a row (or any mapping) with the keys
    ``"all_new_muts_reverted_from_ref"`` and ``"all_new_muts_from_ref"``, each a
    list of ``gene:<wt><site><mutant>`` strings.

    If nothing was reverted, the list of new mutations is returned as is.
    Otherwise a new list of strings in the same format is returned: each
    reverted ``gene:XsiteY`` becomes ``gene:YsiteX``, or ``gene:YsiteZ`` when
    the new mutations contain a change to ``Z`` at the same gene and site.
    The remaining new mutations follow, unchanged.
    """
    reversions = r["all_new_muts_reverted_from_ref"]
    gained = r["all_new_muts_from_ref"]
    if not reversions:
        return gained

    def by_site(mut_strs):
        """Index mutation strings as ``{(gene, site): (wt, mutant)}``."""
        indexed = {}
        for mut_str in mut_strs:
            fields = mut_str.split(":")
            gene, change = fields[0], fields[1]
            # first and last characters are the amino acids, the middle is the site
            indexed[(gene, change[1:-1])] = (change[0], change[-1])
        return indexed

    reverted_at = by_site(reversions)
    gained_at = by_site(gained)

    net_changes = []
    for (gene, site), (rev_wt, rev_mutant) in reverted_at.items():
        if (gene, site) in gained_at:
            # reverted and then mutated again at the same site: only the final
            # amino acid of the new mutation matters
            _, end_aa = gained_at.pop((gene, site))
        else:
            # plain reversion: the change is the reverted mutation run backwards
            end_aa = rev_wt
        net_changes.append(f"{gene}:{rev_mutant}{site}{end_aa}")

    # new mutations at sites that were not touched by a reversion
    net_changes.extend(
        f"{gene}:{wt}{site}{mutant}"
        for (gene, site), (wt, mutant) in gained_at.items()
    )
    return net_changes

# Clades whose net change relative to their parent includes exactly one spike
# mutation, with a flag for whether that spike mutation is the only change.
pango_w_one_mut_from_parent = (
    pango_df
    .drop(columns=["n_child_clades", "muts_from_dms_clade"])
    .assign(
        # net changes after folding reversions into the new mutations
        all_new_muts=lambda x: x.apply(consolidate_reversions, axis=1),
        # spike changes without the gene prefix; compare the whole gene name
        # rather than only the first character of the "gene:mutation" string
        spike_new_muts=lambda x: x["all_new_muts"].map(
            lambda ms: [m.split(":")[1] for m in ms if m.split(":")[0] == "S"]
        ),
        n_new_spike_muts=lambda x: x["spike_new_muts"].map(len),
        n_new_all_muts=lambda x: x["all_new_muts"].map(len),
    )
    .query("n_new_spike_muts == 1")
    .assign(
        # exactly one spike mutation remains after the query, so take it
        new_spike_mut=lambda x: x["spike_new_muts"].map(lambda ms: ms[0]),
        # collapse the list of all changes to a single "; "-delimited string
        all_new_muts=lambda x: x["all_new_muts"].map(lambda ms: "; ".join(ms)),
        # True when the single spike mutation is the only change of any kind
        only_new_mut_is_spike=lambda x: x["n_new_spike_muts"] == x["n_new_all_muts"],
    )
    [["clade", "parent", "date", "new_spike_mut", "all_new_muts", "only_new_mut_is_spike"]]
    .reset_index(drop=True)
)

Add differences between clade and parent in DMS phenotype and growth rates:

In [None]:
# Show the table of per-clade DMS phenotypes and growth rates that is merged
# with the one-mutation clade pairs in the next cell.
pango_dms_growth_df

In [None]:
# For each clade that differs from its parent by one spike mutation, attach the
# clade's and the parent's DMS phenotype and growth rate, compute the
# differences, and drop pairs where the new spike mutation is at a site listed
# as missing DMS data for either the clade or the parent.
pango_dms_w_one_mut_from_parent = (
    pango_w_one_mut_from_parent
    # values for the clade itself (several rows per clade are allowed on the right)
    .merge(
        (
            pango_dms_growth_df
            .rename(
                columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"}
            )
            [["clade", "DMS_phenotype", "phenotype", "R", "number sequences", "muts_missing_data"]]
        ),
        on="clade",
        validate="one_to_many",
    )
    # values for the parent, matched on the same DMS phenotype
    .merge(
        (
            pango_dms_growth_df
            .rename(
                columns={
                    "clade": "parent",
                    "phenotype": "parent_phenotype",
                    "R": "parent_R",
                    "number sequences": "parent number sequences",
                    f"muts_from_{dms_clade}_missing_data": "parent_muts_missing_data",
                }
            )
            [["parent", "DMS_phenotype", "parent_phenotype", "parent_R", "parent number sequences", "parent_muts_missing_data"]]
        ),
        on=["parent", "DMS_phenotype"],
        validate="many_to_one",
    )
    .assign(
        delta_phenotype=lambda x: x["phenotype"] - x["parent_phenotype"],
        delta_R=lambda x: x["R"] - x["parent_R"],
        # Sites with missing data in the clade or its parent. Mutation lists
        # elsewhere in this notebook are delimited by "; ", so strip whitespace
        # around each entry before slicing off the wildtype and mutant letters;
        # otherwise a leading space would shift the slice and the site would
        # never match. Stripping is a no-op if the delimiter is a bare ";".
        site_missing_data=lambda x: x.apply(
            lambda r: {
                m.strip()[1:-1]
                for m in (
                    r["muts_missing_data"].split(";")
                    + r["parent_muts_missing_data"].split(";")
                )
                if m.strip()
            },
            axis=1,
        ),
        # is the site of the new spike mutation among those missing data?
        missing_data=lambda x: x.apply(
            lambda r: r["new_spike_mut"][1:-1] in r["site_missing_data"],
            axis=1,
        ),
    )
    .query("not missing_data")
    .reset_index(drop=True)
    .drop(columns=["phenotype", "parent_phenotype", "R", "parent_R", "muts_missing_data", "parent_muts_missing_data", "site_missing_data", "missing_data"])
)

pango_dms_w_one_mut_from_parent.head()

In [None]:
# Ordinary least squares fit of the change in growth rate on the changes in
# all DMS phenotypes, using one row per clade / parent pair.

# query string that keeps only rows with a value for every phenotype
notnull_all_phenotypes = " and ".join(f"`{p}`.notnull()" for p in phenotypes)

pivoted_one_mut_from_parent = (
    pango_dms_w_one_mut_from_parent
    # pairs whose new spike mutation is L455F are left out of the fit
    .query("new_spike_mut != 'L455F'")
    # optional extra filter, currently disabled: .query("only_new_mut_is_spike")
    # wide format: one column of delta_phenotype per DMS phenotype
    .pivot_table(
        columns="DMS_phenotype",
        values="delta_phenotype",
        index=["clade", "delta_R"],
    )
    .reset_index()
    .query(notnull_all_phenotypes)
)

ols_model = statsmodels.api.OLS(
    exog=statsmodels.api.add_constant(pivoted_one_mut_from_parent[phenotypes]),
    endog=pivoted_one_mut_from_parent[["delta_R"]],
)

res_ols = ols_model.fit()
display(res_ols.summary())

In [None]:
# Radio toggle: when True, only clade pairs whose single spike mutation is the
# sole difference from the parent are shown.
only_spike_muts = alt.param(
    value=False,
    bind=alt.binding_radio(
        options=[True, False],
        name="only show changes for clades with no non-spike differences",
    ),
)

# Shared base chart (data, filters and interactive parameters) for the scatter
# and correlation-label layers built from it.
delta_growth_dms_corr_base = (
    alt.Chart(pango_dms_w_one_mut_from_parent)
    # keep_all is the negation of the toggle, so every row passes the next
    # filter when the toggle is off
    .transform_calculate(keep_all=alt.expr.if_(only_spike_muts, False, True))
    .transform_filter(
        {
            "or": [
                alt.datum["only_new_mut_is_spike"],
                alt.datum["keep_all"],
            ]
        }
    )
    # Require log10 of the sequence counts of BOTH the clade and its parent to
    # reach the slider value. The two conditions must be joined with `&`:
    # Python's `and` cannot be overloaded, so it does not build a combined
    # expression and only one of the two conditions would reach the filter.
    .transform_filter(
        (alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider)
        & (alt.expr.log(alt.datum["parent number sequences"]) / math.log(10) >= n_sequences_slider)
    )
    .transform_filter(select_date)
    .add_params(clade_selection, n_sequences_slider, only_spike_muts)
    .properties(width=plot_size, height=plot_size)
)

# Scatter layer: one point per clade / parent pair and DMS phenotype, with the
# change in DMS phenotype on x and the change in growth rate on y. Points in
# the current clade selection are drawn larger, opaque and outlined.
delta_growth_dms_corr_scatter = (
    delta_growth_dms_corr_base
    .mark_circle(stroke="black", strokeOpacity=1, size=40)
    .encode(
        x=alt.X(
            "delta_phenotype",
            scale=alt.Scale(nice=False, padding=3),
            title="change in DMS phenotype",
        ),
        y=alt.Y(
            "delta_R",
            scale=alt.Scale(nice=False, padding=3),
            title="change in growth rate",
        ),
        # one color per phenotype, in the order given by `phenotypes`
        color=alt.Color(
            "DMS_phenotype",
            legend=None,
            scale=alt.Scale(range=[phenotype_colors[p] for p in phenotypes]),
            sort=phenotypes,
        ),
        size=alt.condition(clade_selection, alt.value(80), alt.value(50)),
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.3)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        tooltip=[
            "clade",
            "parent",
            alt.Tooltip("delta_R", title="change in growth", format=".1f"),
            "DMS_phenotype",
            alt.Tooltip("delta_phenotype", title="change in DMS phenotype", format=".2f"),
            "new_spike_mut",
            "all_new_muts",
            "date",
            "number sequences",
            "parent number sequences",
        ],
    )
)

# Text layer showing the correlation coefficient of the linear regression of
# delta_R on delta_phenotype, placed at a fixed pixel position in the panel.
delta_growth_dms_corr_r = (
    delta_growth_dms_corr_base
    .mark_text(fontSize=12, fontWeight=500, color="black", align="left")
    # params=True returns the fit parameters (coef, rSquared) instead of the line
    .transform_regression(on="delta_phenotype", regression="delta_R", params=True)
    # r is the square root of rSquared, signed to match the slope coef[1]
    .transform_calculate(
        r=alt.expr.if_(
            alt.datum["coef"][1] >= 0,
            alt.expr.sqrt(alt.datum["rSquared"]),
            -alt.expr.sqrt(alt.datum["rSquared"]),
        ),
    )
    .transform_calculate(label='"r = " + format(datum.r, ".2f")')
    .encode(
        text=alt.Text("label:N"),
        x=alt.value(3),
        y=alt.value(8),
    )
)

# Final figure: scatter plus correlation label, faceted into one panel per DMS
# phenotype (each with its own x scale), stacked above the date slider.
delta_growth_dms_corr = (
    alt.vconcat(
        (
            alt.layer(delta_growth_dms_corr_scatter, delta_growth_dms_corr_r)
            .facet(
                alt.Column(
                    "DMS_phenotype",
                    sort=phenotypes,
                    header=alt.Header(
                        title=None,
                        labelPadding=3,
                        labelFontStyle="bold",
                        labelFontSize=13,
                    ),
                ),
                spacing=4,
            )
            .resolve_scale(x="independent")
        ),
        date_slider,
    )
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            "Change in DMS phenotype vs change in growth for clades differing by one spike mutation",
            subtitle=f"Showing all Pango clades descended from {', '.join(starting_clades)}",
            anchor="middle",
            fontSize=13,
            dy=-5,
        ),
    )
)

delta_growth_dms_corr

## Distributions of DMS mutation effects in clades with growth estimates versus all mutations
These are mutations relative to the DMS strain.

In [None]:
# Count how often each mutation (relative to the DMS strain) occurs across the
# rows of the growth table; rows with an empty mutation string are skipped.
# NOTE(review): the count is per table row, so if the table holds several rows
# per clade (e.g. one per DMS phenotype) a clade contributes more than once;
# confirm that is intended.
muts_in_clades = collections.Counter(
    pango_dms_growth_df[f"muts_from_{dms_clade}"]
    .loc[lambda s: s != ""]
    .str.split("; ")
    .explode()
)
print(f"There are {len(muts_in_clades)} mutations found in clades with growth estimates")

# Long-format table of DMS effects: one row per (mutation, DMS phenotype) with
# a measured value, with the region collapsed to "RBD" versus "non RBD".
all_muts_dms = (
    dms_summary
    .query("wildtype != mutant")
    .assign(
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
        region=lambda x: x["region"].where(x["region"] == "RBD", "non RBD"),
    )
    .melt(
        id_vars=["mutation", "region"],
        value_vars=phenotypes,
        var_name="DMS_phenotype",
        value_name="phenotype",
    )
    .query("phenotype.notnull()")
)

# Stack two copies: every mutation with a count of 1, and the subset seen in
# Pango clades with its count taken from `muts_in_clades`.
all_muts_dms = pd.concat(
    [
        all_muts_dms.assign(mutation_type="any", count=1),
        (
            all_muts_dms
            .query("mutation in @muts_in_clades")
            .assign(
                mutation_type="in Pango clade",
                count=lambda x: x["mutation"].map(muts_in_clades),
            )
        ),
    ]
)

# One figure per DMS phenotype: histograms of mutation effects, with a row for
# all mutations and a row for mutations found in Pango clades. Each row has
# its own y scale.
for pheno in phenotypes:
    pheno_muts = (
        all_muts_dms
        .query("DMS_phenotype == @pheno")
        .drop(columns=["DMS_phenotype", "mutation"])
    )

    base_hist = (
        alt.Chart(pheno_muts)
        .mark_bar()
        .encode(
            x=alt.X("phenotype", bin=alt.BinParams(maxbins=50)),
            # bar height is the summed `count` column, not the number of rows
            y=alt.Y("sum(count)", title="mutations"),
            color=alt.value(phenotype_colors[pheno]),
            row=alt.Row("mutation_type", title=None, spacing=5),
        )
        .properties(title=pheno, width=200, height=75)
        .resolve_scale(y="independent")
    )
    display(base_hist)