# Compare DMS to natural sequence evolution

In [1]:
# this cell is tagged parameters for papermill parameterization
# Every parameter defaults to None here; the "# Parameters" cell that follows
# assigns the actual values used for a run.
dms_summary_csv = None  # input: CSV of DMS measurements (read with `pd.read_csv`)
growth_rates_csv = None  # input: CSV of clade growth estimates (read with `pd.read_csv`)
pango_consensus_seqs_json = None  # input: JSON of Pango clade data (read with `json.load`)
starting_clades = None  # list of clades; each one and all its descendants are analyzed
dms_clade = None  # clade relative to which mutations of every other clade are expressed
n_random = None  # number of randomizations of the DMS data to generate
exclude_clades = None  # list of clades dropped from the clade table
pango_dms_phenotypes_csv = None  # output: CSV of per-clade DMS phenotypes
pango_by_date_html = None  # output: chart of phenotypes versus designation date
pango_affinity_vs_escape_html = None  # output: ACE2 affinity versus sera escape scatter chart
pango_dms_vs_growth_html = None  # output: regression chart using full-spike phenotypes
pango_dms_vs_growth_by_domain_html = None  # output: regression chart with RBD / non-RBD split

In [2]:
# Parameters
# Concrete values for this run; they override the None defaults of the
# parameters cell above.
starting_clades = ["BA.2", "BA.5", "XBB"]
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "data/2023-09-18_Murrell_growth_estimates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
pango_dms_vs_growth_html = "results/compare_natural/pango_dms_vs_growth.html"
pango_dms_vs_growth_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_by_domain.html"
)
n_random = 1
exclude_clades = []

# Move the working directory up one level so the relative paths above resolve.
# NOTE(review): presumably the notebook lives one directory below the
# repository root -- confirm. Because the path is relative, re-running this
# cell in a live kernel moves up one more directory each time.
import os
os.chdir("../")

In [39]:
import collections
import itertools
import json
import math
import re

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import statsmodels.api

# Remove Altair's limit on the number of data rows a chart may embed; the
# return value is assigned to `_` so the cell does not display it.
_ = alt.data_transformers.disable_max_rows()

## Read Pango clades and mutations

In [4]:
# Load the Pango clade data: a dict keyed by clade name whose entries are read
# below for "children", "designationDate", "aaSubstitutions" and "aaDeletions".
# JSON is UTF-8 by specification, so do not rely on the platform's default
# text encoding when opening the file.
with open(pango_consensus_seqs_json, encoding="utf-8") as f:
    pango_clades = json.load(f)

def n_child_clades(c):
    """Count all descendants of Pango clade `c` (children, grandchildren, ...).

    Descendants are found by recursively following the "children" lists in the
    module-level `pango_clades` dict.
    """
    total = 0
    for child in pango_clades[c]["children"]:
        # the child itself plus everything below it
        total += 1 + n_child_clades(child)
    return total

def build_records(c, recs):
    """Add clade `c` and then, depth first, all of its descendants to `recs`.

    `recs` maps column name to a list of values and is extended in place with
    one entry per clade under the keys "clade", "n_child_clades", "date" and
    "muts_from_ref". A clade already present in `recs["clade"]` is skipped,
    along with its subtree.
    """
    if c not in recs["clade"]:
        clade_info = pango_clades[c]
        recs["clade"].append(c)
        recs["n_child_clades"].append(n_child_clades(c))
        recs["date"].append(clade_info["designationDate"])
        # keep only spike ("S:") changes and strip the gene prefix
        spike_muts = []
        for field in ["aaSubstitutions", "aaDeletions"]:
            for mut in clade_info[field]:
                if mut.startswith("S:"):
                    spike_muts.append(mut.split(":")[1])
        recs["muts_from_ref"].append(spike_muts)
        for child in clade_info["children"]:
            build_records(child, recs)
        
# Collect one record per clade for every starting clade and its descendants.
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

# Tabulate the records, dropping any explicitly excluded clades.
pango_df = pd.DataFrame(records)
pango_df = pango_df.query("clade not in @exclude_clades")

# Spike mutations (relative to the reference) carried by the DMS clade itself.
pango_df_by_clade = pango_df.set_index("clade")
dms_clade_mutations_from_ref = pango_df_by_clade.at[dms_clade, "muts_from_ref"]

def mutations_from(muts, from_muts):
    """Re-express a sequence's mutations relative to another sequence.

    Both arguments are collections of mutation strings of the form
    ``<parent><site><mutant>`` (for instance "A83V", with "-" marking a
    deletion), each defined relative to the same reference.

    Returns a list, sorted by site, of the mutations that convert the sequence
    described by `from_muts` into the one described by `muts`:

    - a mutation only in `muts` is kept as is;
    - a mutation only in `from_muts` is reverted (parent and mutant swapped);
    - different mutations at the same site in both are combined so the parent
      is the residue in `from_muts` and the mutant is the residue in `muts`.
    """
    muts_set = set(muts)
    from_muts_set = set(from_muts)

    # mutations present in exactly one of the two sequences
    new_muts = muts_set.symmetric_difference(from_muts_set)
    # raw string: "\-" and "\d" are invalid escapes in an ordinary string literal
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)

    # group the differing mutations by site
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1:-1])].append(m)

    new_muts_list = []
    for _, ms in sorted(new_muts_d.items()):
        if len(ms) == 1:
            m = ms[0]
            if m in muts_set:
                new_muts_list.append(m)
            else:
                assert m in from_muts_set
                # only in `from_muts`: revert it
                new_muts_list.append(m[-1] + m[1:-1] + m[0])
        else:
            # one mutation from each sequence at this site; work out which is which
            m, from_m = ms
            if m not in muts_set:
                from_m, m = m, from_m
            assert m in muts_set and from_m in from_muts_set
            new_muts_list.append(from_m[-1] + m[1:])
    return new_muts_list

# Express every clade's mutations relative to the DMS clade, parse the
# designation dates, and order the clades chronologically.
pango_df = (
    pango_df
    .assign(
        muts_from_dms_clade=pango_df["muts_from_ref"].map(
            lambda clade_muts: mutations_from(clade_muts, dms_clade_mutations_from_ref)
        ),
        date=pd.to_datetime(pango_df["date"]),
    )
    .drop(columns="muts_from_ref")
    .sort_values("date")
    .reset_index(drop=True)
)

pango_df

Unnamed: 0,clade,n_child_clades,date,muts_from_dms_clade
0,BA.2,377,2021-12-07,"[A83V, -144Y, Q146H, E183Q, E213G, V252G, H339..."
1,BA.2.1,0,2022-02-25,"[A83V, -144Y, Q146H, E183Q, E213G, V252G, H339..."
2,BA.2.2,1,2022-02-25,"[A83V, -144Y, Q146H, E183Q, E213G, V252G, H339..."
3,BA.2.3,52,2022-02-25,"[A83V, -144Y, Q146H, E183Q, E213G, V252G, H339..."
4,BA.2.4,0,2022-03-25,"[A83V, -144Y, Q146H, E183Q, E213G, V252G, H339..."
...,...,...,...,...
1534,XBB.1.5.106,0,2023-09-17,[A623V]
1535,JG.3,0,2023-09-17,"[Q52H, L455F, F456L, S704L]"
1536,EG.5.1.9,0,2023-09-17,"[Q52H, L176F, F456L]"
1537,GK.4,0,2023-09-17,"[L455F, F456L, A475V]"


## Assign DMS phenotypes to Pango clades

First, define a function that assigns DMS phenotypes to mutations:

In [5]:
# read the DMS data
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
assert set(phenotypes).issubset(dms_summary.columns)

# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()

def mut_dms(m, dms_data):
    """Get DMS phenotypes for a mutation.

    Parameters
    ----------
    m : str or null
        Mutation as ``<parent><site><mutant>`` (e.g. "F456L"), or a null value
        for a clade that has no mutations.
    dms_data : dict
        Maps ``(site, wildtype, mutant)`` to a dict of phenotype values. It is
        only read; entries are never modified.

    Returns
    -------
    dict
        A new dict with one entry per name in `phenotypes` followed by
        "is_RBD". All values are `pd.NA` when `m` is null or its site is not
        in the DMS; the phenotypes are `pd.NA` when the needed measurement is
        missing.

    The effect is expressed relative to the mutation's own parent residue:
    if the parent is the DMS wildtype the measurement is used directly, if the
    mutant is the DMS wildtype the measurement of the reverse change is
    negated, and otherwise the effect is the difference between the two
    measurements made from the DMS wildtype.
    """
    null_d = {k: pd.NA for k in phenotypes}
    if pd.isnull(m) or int(m[1:-1]) not in dms_wt:
        d = null_d
        d["is_RBD"] = pd.NA
    else:
        parent = m[0]
        site = int(m[1:-1])
        mut = m[-1]
        wt = dms_wt[site]
        try:
            if parent == wt:
                # build a fresh dict rather than handing back (and then
                # writing "is_RBD" into) the caller's own entry
                measured = dms_data[(site, parent, mut)]
                d = {p: measured[p] for p in phenotypes}
            elif mut == wt:
                measured = dms_data[(site, mut, parent)]
                d = {p: -measured[p] for p in phenotypes}
            else:
                parent_d = dms_data[(site, wt, parent)]
                mut_d = dms_data[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
        except KeyError:
            d = null_d
        d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d

Now assign phenotypes to pango clades.
We do this both using the actual DMS data and randomizing the DMS data among measured mutations:

In [6]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    Parameters
    ----------
    dms_data_dict : dict
        Maps ``(site, wildtype, mutant)`` to a dict of phenotype values, as
        expected by `mut_dms`.

    Returns
    -------
    pandas.DataFrame
        One row per clade in `pango_df` and DMS phenotype, sorted by date, with
        the summed mutation effects over the full spike (`phenotype`), the RBD
        (`phenotype_RBD_only`) and the rest of spike
        (`phenotype_nonRBD_only`), plus the "; "-joined mutations relative to
        `dms_clade` and the subset of those lacking DMS data. Mutations with
        no DMS data contribute zero to the sums.
    """
    pango_dms_df = (
        pango_df
        # put one mutation in each column
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        .melt(
            id_vars=["clade", "date", "n_child_clades", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        .assign(
            muts_from_dms_clade=lambda x: x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(
                lambda ms: "; ".join([m for m in ms if not pd.isnull(m)])
            ),
            # the mutation itself where it has no measured effect, otherwise NA
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"]
                .transform(lambda ms: "; ".join([m for m in ms if not pd.isnull(m)]))
            ),
            # These columns are object dtype after the row-wise apply. Cast
            # explicitly instead of relying on `fillna` to downcast them: on
            # an object column `~True` evaluates to -2, which would corrupt
            # the non-RBD sums.
            mutation_effect=lambda x: x["mutation_effect"].fillna(0).astype(float),
            is_RBD=lambda x: x["is_RBD"].fillna(False).astype(bool),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )

    # sanity checks: no clade was lost, and the two regions add up to the whole
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )

    return pango_dms_df

# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)
print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
pango_dms_dfs_rand = []
numpy.random.seed(0)  # fixed seed so the randomizations are reproducible
for irandom in range(1, n_random + 1):
    # Randomize the non-null DMS data for each phenotype: measured values are
    # shuffled only among the mutations that have a measurement, so missing
    # values stay where they are.
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        is_measured = dms_summary_rand[phenotype].notnull()
        dms_summary_rand.loc[is_measured, phenotype] = numpy.random.permutation(
            dms_summary_rand.loc[is_measured, phenotype].values
        )
    dms_data_dict_rand = (
        dms_summary_rand
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )
    pango_dms_dfs_rand.append(get_pango_dms_df(dms_data_dict_rand).assign(randomize=irandom))
# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)

Saving Pango DMS phenotypes to results/compare_natural/pango_dms_phenotypes.csv


## Plot phenotypes of Pango clades
Plot phenotypes of Pango clades versus their designation dates:

In [7]:
# column of `pango_dms_df` -> label for the spike region it summarizes
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}

# Tidy the data so there is one row per clade, DMS phenotype and spike region.
pango_chart_df = pango_dms_df.melt(
    id_vars=[c for c in pango_dms_df.columns if c not in region_cols],
    value_vars=list(region_cols),
    var_name="spike_region",
    value_name="phenotype value",
)
pango_chart_df = pango_chart_df.assign(
    spike_region=pango_chart_df["spike_region"].map(region_cols)
).rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})

# columns cannot have "." in them for Altair
col_renames = {}
for c in pango_chart_df.columns:
    if "." in c:
        col_renames[c] = c.replace(".", "_")
col_renames_rev = {renamed: original for (original, renamed) in col_renames.items()}
pango_chart_df = pango_chart_df.rename(columns=col_renames)

# hovering over a point selects (highlights) every point of the same clade
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# encodings shared by all panels; tooltips show the original column names
base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev.get(c, c))
            for c in pango_chart_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        color=alt.Color("DMS_phenotype", legend=None),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# One row of panels per phenotype. Only the bottom row gets a date axis and
# only the top row gets the spike-region column headers.
phenotype_pango_charts = []
for phenotype in phenotypes:
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])

    if last_row:
        date_title = "designation date of clade"
        date_axis = alt.Axis(titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90)
    else:
        date_title = None
        date_axis = None

    if first_row:
        region_header = alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=4)
    else:
        region_header = None

    row_chart = base_pango_chart.transform_filter(
        alt.datum["DMS_phenotype"] == phenotype
    ).encode(
        x=alt.X(
            "date",
            title=date_title,
            axis=date_axis,
            scale=alt.Scale(nice=False, padding=3),
        ),
        y=alt.Y(
            "phenotype value",
            title=phenotype,
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=3),
        ),
        column=alt.Column(
            "spike_region",
            sort=list(region_cols),
            title=None,
            header=region_header,
            spacing=4,
        ),
    )
    phenotype_pango_charts.append(row_chart)

pango_chart_title = alt.TitleParams(
    f"DMS predicted phenotypes of Pango clades descended from {', '.join(starting_clades)}",
    anchor="middle",
    fontSize=16,
    dy=-5,
)
pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(title=pango_chart_title)
)

print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

pango_chart

Saving chart to results/compare_natural/pango_dms_phenotypes_by_date.html


## Pango clade affinity versus escape scatter plot

In [8]:
# One row per clade with a column for each DMS phenotype (full-spike values).
pango_scatter_df = (
    pango_dms_df
    .pivot_table(
        index=[
            c
            for c in pango_dms_df
            if c not in {"DMS_phenotype", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"}
        ],
        values="phenotype",
        columns="DMS_phenotype",
    )
    .reset_index()
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
    # columns cannot have "." in them for Altair
    .rename(columns=col_renames)
)

pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        y=alt.Y(
            "sera escape",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        # tooltips show the original (un-renamed) column names
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                # name every starting clade; the leftover loop variable
                # `starting_clade` would name only the last one
                f"for Pango clades descended from {', '.join(starting_clades)}",
            ],
            anchor="middle",
            fontSize=14,
            dy=-5,
        ),
        width=300,
        height=300,
    )
)

print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

pango_scatter_chart

Saving chart to results/compare_natural/pango_affinity_vs_escape.html


## Correlate with clade growth

In [9]:
# Read the clade growth estimates, renaming two columns to the names used below.
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)

# Every clade with a growth estimate must be a clade in the Pango data.
invalid_clades = set(growth_rates["clade"]).difference(pango_clades)
if invalid_clades:
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# Attach growth estimates to the actual and randomized DMS phenotypes; only
# clades present in both tables are kept.
pango_dms_growth_df = pd.merge(
    pango_dms_df, growth_rates, on="clade", validate="many_to_one"
)

pango_dms_growth_df_rand = pd.merge(
    pango_dms_df_rand, growth_rates, on="clade", validate="many_to_one"
)

n_growth_clades = growth_rates["clade"].nunique()
n_dms_clades = pango_dms_df["clade"].nunique()
n_shared_clades = pango_dms_growth_df["clade"].nunique()
print(
    f"{n_growth_clades} clades have growth rates estimates.\n"
    f"{n_dms_clades} clades have DMS estimates.\n"
    f"{n_shared_clades} clades have growth and DMS estimates"
)

# Correlation of growth rate (R) with each phenotype summary, per DMS phenotype.
print("Simple correlations:")
correlation_cols = ["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]
display(
    pango_dms_growth_df.groupby("DMS_phenotype")[correlation_cols].corr()[["R"]]
)

966 clades have growth rates estimates.
1539 clades have DMS estimates.
900 clades have growth and DMS estimates
Simple correlations:


Unnamed: 0_level_0,Unnamed: 1_level_0,R
DMS_phenotype,Unnamed: 1_level_1,Unnamed: 2_level_1
ACE2 affinity,R,1.0
ACE2 affinity,phenotype,-0.480301
ACE2 affinity,phenotype_RBD_only,-0.305674
ACE2 affinity,phenotype_nonRBD_only,-0.309355
cell entry,R,1.0
cell entry,phenotype,0.797371
cell entry,phenotype_RBD_only,0.820969
cell entry,phenotype_nonRBD_only,0.448803
sera escape,R,1.0
sera escape,phenotype,0.934447


Plot clade growth rate (R) versus designation date, with point sizes scaled by the log of the number of sequences in the clade:

In [10]:
# `pango_dms_growth_df` has one row per clade *and* DMS phenotype, but growth
# rate, date and sequence count are per clade. Drop the phenotype-specific
# columns and de-duplicate so each clade is drawn once instead of being
# stacked once per phenotype.
growth_by_date_df = (
    pango_dms_growth_df
    .drop(
        columns=[
            "DMS_phenotype",
            "phenotype",
            "phenotype_RBD_only",
            "phenotype_nonRBD_only",
            f"muts_from_{dms_clade}_missing_data",
        ]
    )
    .drop_duplicates()
    .reset_index(drop=True)
)
assert growth_by_date_df["clade"].is_unique

# columns cannot have "." in them for Altair; keep original names as tooltip titles
growth_by_date_fields = {c.replace(".", "_"): c for c in growth_by_date_df.columns}

(
    alt.Chart(growth_by_date_df.rename(columns=lambda c: c.replace(".", "_")))
    .encode(
        x="date",
        y="R",
        size=alt.Size("number sequences", scale=alt.Scale(type="log")),
        tooltip=[
            alt.Tooltip(field, title=original_name)
            for field, original_name in growth_by_date_fields.items()
        ],
    )
    .mark_circle(opacity=0.25)
)

Now fit a weighted least-squares (WLS) regression of clade growth on the DMS phenotypes, weighting clades by the log of their number of sequences:

In [48]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Reshape per-clade, per-phenotype rows into one row per clade for regression.

    `df` is expected to have the columns of the merged DMS / growth data frame
    ("clade", "R", "date", "number sequences", "DMS_phenotype", the three
    phenotype summary columns, and the two `muts_from_<dms_clade>` columns).

    Returns a data frame with one row per clade and one column per phenotype
    and spike domain, named "<phenotype> (<domain>)", where domain is
    "full spike", "RBD" or "non RBD". The mutations missing DMS data are
    merged across the phenotypes of each clade.
    """
    missing_col = f"muts_from_{dms_clade}_missing_data"

    def merge_missing(joined_muts):
        """Union the "; "-separated mutations of all rows, in first-seen order."""
        unique_muts = {}
        for muts in joined_muts.str.split("; "):
            for mut in muts:
                if mut:
                    unique_muts[mut] = None
        return "; ".join(unique_muts)

    domain_df = df.rename(
        columns={
            "phenotype": "full spike",
            "phenotype_RBD_only": "RBD",
            "phenotype_nonRBD_only": "non RBD",
        }
    )
    # group muts missing data from all phenotypes
    domain_df = domain_df.assign(
        muts_from_DMS_clade_missing_data=(
            domain_df.groupby("clade")[missing_col].transform(merge_missing)
        ),
    ).rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})

    ols_vars = domain_df.pivot_table(
        index=[
            "clade",
            "R",
            "date",
            "muts_from_DMS_clade",
            "muts_from_DMS_clade_missing_data",
            "number sequences",
        ],
        columns="DMS_phenotype",
        values=["full spike", "RBD", "non RBD"],
    )

    # flatten column names from (domain, phenotype) pairs
    flat_names = []
    for col in ols_vars.columns.values:
        assert len(col) == 2
        domain, pheno = col
        flat_names.append(f"{pheno} ({domain})")
    ols_vars.columns = flat_names
    return ols_vars.reset_index()

ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# width and height (pixels) of each regression panel
plot_size = 180


def _fit_growth_wls(vars_df, exog_vars):
    """Fit clade growth rate `R` on `exog_vars` (plus a constant) by WLS.

    `vars_df` is a data frame as returned by `pivot_for_ols_vars`. The same
    fit is used for the actual and the randomized DMS data, so it lives in
    one place. Returns the fitted statsmodels results object.
    """
    model = statsmodels.api.WLS(
        endog=vars_df[["R"]],
        exog=statsmodels.api.add_constant(vars_df[exog_vars]),
        # weight by log n sequences, so pass log**2
        weights=numpy.log(vars_df["number sequences"]) ** 2,
    )
    return model.fit()


# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
# Fit two models: one on full-spike phenotypes, one with RBD and non-RBD
# phenotypes as separate variables. Each gets its own chart file.
for name, exog_vars, chartfile in [
    ("full spike", [f"{c} (full spike)" for c in phenotypes], pango_dms_vs_growth_html),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    res_ols = _fit_growth_wls(ols_vars, exog_vars)
    display(res_ols.summary())

    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider for the minimum log10 number of sequences; the starting value is
    # the smallest log10 count truncated to one decimal place
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )

    # encodings, filters and interactions shared by every panel
    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 250],
                ),
                legend=alt.Legend(symbolStrokeWidth=0),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate (R)", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],
            ],
        )
        .mark_circle(stroke="black", fillOpacity=0.6)
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    # One panel per explanatory variable: same x and y in every panel, but
    # points colored by that variable. Color schemes repeat per phenotype.
    growth_charts = [
        (
            base_growth_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                y=alt.Y(
                    "R",
                    title="actual clade growth rate (R)",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # only the first panel of each row shows the y-axis
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                    ),
                    scale=alt.Scale(
                        scheme=color_scheme,
                        nice=False,
                    ),
                )
            )
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=(
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}, "
                        + f"P: {res_ols.pvalues[dms_pheno]:.1g}"
                    ),
                    subtitleFontSize=11,
                ),
            )
        )
        for i, (dms_pheno, color_scheme) in enumerate(zip(
            exog_vars,
            itertools.cycle(["lightgreyred", "lightgreyteal", "purples"]),
        ))
    ]

    actual_r = math.sqrt(res_ols.rsquared)
    # lay the panels out in rows of len(phenotypes), with the date slider below
    assert len(growth_charts) % len(phenotypes) == 0
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (R = {actual_r:.2f})",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    display(growth_chart)
    print(f"Saving to {chartfile}")
    growth_chart.save(chartfile)

    # fit randomized models and compute P-value based on R values
    rand_r = []
    for _randomize, rand_df in pango_dms_growth_df_rand.groupby("randomize"):
        rand_res_ols = _fit_growth_wls(pivot_for_ols_vars(rand_df), exog_vars)
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    # fraction of randomizations that fit at least as well as the actual data;
    # its resolution is limited to 1 / (number of randomizations)
    pval = sum(r >= actual_r for r in rand_r) / len(rand_r)
    print(f"The P-value compared to randomized DMS data is {pval}")



Fitting for full spike:


0,1,2,3
Dep. Variable:,R,R-squared:,0.888
Model:,WLS,Adj. R-squared:,0.888
Method:,Least Squares,F-statistic:,2376.0
Date:,"Wed, 20 Sep 2023",Prob (F-statistic):,0.0
Time:,05:29:52,Log-Likelihood:,-3484.4
No. Observations:,900,AIC:,6977.0
Df Residuals:,896,BIC:,6996.0
Df Model:,3,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,59.7536,0.730,81.888,0.000,58.321,61.186
sera escape (full spike),23.9263,0.554,43.149,0.000,22.838,25.015
ACE2 affinity (full spike),5.7993,1.283,4.520,0.000,3.281,8.317
cell entry (full spike),14.5368,2.211,6.574,0.000,10.197,18.877

0,1,2,3
Omnibus:,8.301,Durbin-Watson:,0.746
Prob(Omnibus):,0.016,Jarque-Bera (JB):,8.505
Skew:,0.194,Prob(JB):,0.0142
Kurtosis:,3.275,Cond. No.,13.1


Saving to results/compare_natural/pango_dms_vs_growth.html
The P-value compared to randomized DMS data is 0.0


Fitting for separate RBD and non-RBD:


0,1,2,3
Dep. Variable:,R,R-squared:,0.897
Model:,WLS,Adj. R-squared:,0.896
Method:,Least Squares,F-statistic:,1291.0
Date:,"Wed, 20 Sep 2023",Prob (F-statistic):,0.0
Time:,05:29:53,Log-Likelihood:,-3449.7
No. Observations:,900,AIC:,6913.0
Df Residuals:,893,BIC:,6947.0
Df Model:,6,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,60.9830,0.727,83.852,0.000,59.556,62.410
sera escape (RBD),28.7096,0.818,35.078,0.000,27.103,30.316
ACE2 affinity (RBD),5.7018,1.388,4.107,0.000,2.977,8.427
cell entry (RBD),-16.8042,4.324,-3.886,0.000,-25.291,-8.317
sera escape (non RBD),45.8783,4.781,9.596,0.000,36.495,55.261
ACE2 affinity (non RBD),11.6336,2.009,5.790,0.000,7.690,15.577
cell entry (non RBD),24.7628,3.033,8.165,0.000,18.810,30.715

0,1,2,3
Omnibus:,20.534,Durbin-Watson:,0.835
Prob(Omnibus):,0.0,Jarque-Bera (JB):,33.125
Skew:,0.179,Prob(JB):,6.41e-08
Kurtosis:,3.869,Cond. No.,28.5


Saving to results/compare_natural/pango_dms_vs_growth_by_domain.html
The P-value compared to randomized DMS data is 0.0


In [23]:
# NOTE(review): leftover scratch cell that uses nothing from the analysis;
# safe to delete.
max(1, 2, 3)

3

In [25]:
# NOTE(review): leftover debugging cell. `fitted_df` is whatever the last
# iteration of the regression loop assigned; consider deleting this cell.
fitted_df.columns

Index(['clade', 'R', 'date', 'muts_from_DMS_clade',
       'muts_from_DMS_clade_missing_data', 'number sequences',
       'ACE2 affinity (RBD)', 'cell entry (RBD)', 'sera escape (RBD)',
       'ACE2 affinity (full spike)', 'cell entry (full spike)',
       'sera escape (full spike)', 'ACE2 affinity (non RBD)',
       'cell entry (non RBD)', 'sera escape (non RBD)',
       'DMS_predicted_growth'],
      dtype='object')