# Compare DMS to natural sequence evolution

In [None]:
# this cell is tagged parameters for papermill parameterization
# Every name below is only a placeholder set to None; the values actually used
# by the notebook are assigned in the "Parameters" cell that follows.
dms_summary_csv = None  # input: CSV of DMS measurements (read with `pd.read_csv`)
growth_rates_csv = None  # input: CSV of clade growth rates (read with `pd.read_csv`)
pango_consensus_seqs_json = None  # input: JSON of Pango clade records (read with `json.load`)
starting_clades = None  # clades from which records are collected, along with their children
dms_clade = None  # clade that all other clades' mutations are expressed relative to
n_random = None  # number of randomizations of the DMS data
exclude_clades = None  # clade names dropped from the clade table
pango_dms_phenotypes_csv = None  # output: CSV of DMS phenotypes per clade
pango_by_date_html = None  # output: chart of clade phenotypes versus date
pango_affinity_vs_escape_html = None  # output: ACE2 affinity versus sera escape scatter chart
pango_dms_vs_growth_regression_html = None  # output: regression chart, full spike
pango_dms_vs_growth_regression_by_domain_html = None  # output: regression chart, RBD / non-RBD
pango_dms_vs_growth_corr_html = None  # output: simple correlation chart, full spike
pango_dms_vs_growth_corr_by_domain_html = None  # output: simple correlation chart, RBD / non-RBD
exclude_clades_with_muts = None  # clades with any of these mutations (relative to `dms_clade`) are dropped

In [None]:
# Parameters
# Concrete values for the placeholders declared in the cell above.
starting_clades = ["BA.2", "BA.5", "XBB"]
dms_clade = "XBB.1.5"
# input files
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "MultinomialLogisticGrowth/model_fits/rates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
# output files written by later cells
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
pango_dms_vs_growth_regression_html = (
    "results/compare_natural/pango_dms_vs_growth_regression.html"
)
pango_dms_vs_growth_regression_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_regression_by_domain.html"
)
pango_dms_vs_growth_corr_html = "results/compare_natural/pango_dms_vs_growth_corr.html"
pango_dms_vs_growth_corr_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_corr_by_domain.html"
)
# number of randomizations, and (here empty) exclusion lists
n_random = 10
exclude_clades = []
exclude_clades_with_muts = []

# Move the working directory up one level; the relative paths above are
# resolved against the working directory when files are opened later.
import os
os.chdir("../")


In [None]:
import collections
import itertools
import json
import math
import re

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

_ = alt.data_transformers.disable_max_rows()

## Read Pango clades and mutations

In [None]:
# parse the Pango clade records, which are looked up by clade name below
with open(pango_consensus_seqs_json) as json_file:
    pango_clades = json.loads(json_file.read())

def n_child_clades(c):
    """Count all clades below Pango clade `c` (children, their children, ...)."""
    children = pango_clades[c]["children"]
    total = len(children)
    for child in children:
        # each child contributes itself (counted above) plus everything below it
        total += n_child_clades(child)
    return total

def build_records(c, recs):
    """Append Pango clade `c`, then recursively its children, to `recs`.

    `recs` maps column names to lists that grow in parallel. A clade that is
    already listed under "clade" is skipped, so each clade is recorded once.
    """
    seen_clades = recs["clade"]
    if c not in seen_clades:
        seen_clades.append(c)
        recs["n_child_clades"].append(n_child_clades(c))
        info = pango_clades[c]
        recs["date"].append(info["designationDate"])
        # keep only spike ("S:") substitutions and deletions, without the gene prefix
        spike_muts = []
        for field in ("aaSubstitutions", "aaDeletions"):
            for mut in info[field]:
                if mut.startswith("S:"):
                    spike_muts.append(mut.split(":")[1])
        recs["muts_from_ref"].append(spike_muts)
        for c_child in info["children"]:
            build_records(c_child, recs)
        
# collect one record per clade reachable from any of the starting clades
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

# tabulate the records, leaving out explicitly excluded clades
all_clades_df = pd.DataFrame(records)
pango_df = all_clades_df[~all_clades_df["clade"].isin(exclude_clades)]

# spike mutations (relative to the reference) of the clade used for the DMS
dms_clade_mutations_from_ref = pango_df.set_index("clade").loc[
    dms_clade, "muts_from_ref"
]

def mutations_from(muts, from_muts):
    """Get mutations from another sequence.

    Both `muts` and `from_muts` are iterables of mutation strings such as
    ``A123B`` or ``A123-`` that are defined relative to the same reference.

    Returns a list, sorted by site, of the mutations that convert the sequence
    described by `from_muts` into the sequence described by `muts`:

    - a mutation only in `muts` is kept as is;
    - a mutation only in `from_muts` is reverted (``A123B`` becomes ``B123A``);
    - when both have a different mutation at a site, the result goes from the
      `from_muts` mutant to the `muts` mutant.

    Mutations shared by both are dropped. Malformed mutation strings trip an
    assertion.
    """
    # sets give O(1) membership tests below
    muts = set(muts)
    from_muts = set(from_muts)
    new_muts = muts.symmetric_difference(from_muts)
    # raw string: `\-` and `\d` are invalid escapes in a normal string literal
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)

    # group the differing mutations by site number
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1:-1])].append(m)

    new_muts_list = []
    for _, ms in sorted(new_muts_d.items()):
        if len(ms) == 1:
            m = ms[0]
            if m in muts:
                new_muts_list.append(m)
            else:
                # only in `from_muts`, so it must be reverted
                assert m in from_muts
                new_muts_list.append(m[-1] + m[1:-1] + m[0])
        else:
            # one mutation from each input at this site
            m, from_m = ms
            if m not in muts:
                from_m, m = m, from_m
            assert m in muts and from_m in from_muts
            new_muts_list.append(from_m[-1] + m[1:])
    return new_muts_list

# express each clade's mutations relative to the DMS clade and parse the dates
pango_df = pango_df.assign(
    muts_from_dms_clade=[
        mutations_from(clade_muts, dms_clade_mutations_from_ref)
        for clade_muts in pango_df["muts_from_ref"]
    ],
    date=pd.to_datetime(pango_df["date"]),
)
# the reference-based mutations are no longer needed; order clades by date
pango_df = (
    pango_df.drop(columns="muts_from_ref").sort_values("date").reset_index(drop=True)
)

# drop clades carrying any of the excluded mutations
if exclude_clades_with_muts:
    pango_df = pango_df[
        pango_df["muts_from_dms_clade"].map(
            lambda ms: not any(mut in ms for mut in exclude_clades_with_muts)
        )
    ]

pango_df

## Assign DMS phenotypes to Pango clades

First define a function that assigns DMS phenotypes to mutations:

In [None]:
# read the DMS data, shortening two of the phenotype column names
dms_summary = pd.read_csv(dms_summary_csv)
dms_summary = dms_summary.rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# DMS phenotypes of interest; each must be a column of the DMS data
phenotypes = ["sera escape", "ACE2 affinity", "cell entry"]
assert all(pheno in dms_summary.columns for pheno in phenotypes)

# plotting color for each phenotype
phenotype_colors = {
    "sera escape": "red",
    "ACE2 affinity": "blue",
    "cell entry": "purple",
}
assert set(phenotype_colors) == set(phenotypes)


# dict that maps site to wildtype in DMS
dms_wt = dict(zip(dms_summary["site"], dms_summary["wildtype"]))

# dict that maps site to region in DMS
site_to_region = dict(zip(dms_summary["site"], dms_summary["region"]))

def mut_dms(m, dms_data):
    """Get DMS phenotypes for a mutation.

    Parameters
    ----------
    m : str or null
        Mutation such as ``A123B`` (parent, site, mutant), or a null value.
    dms_data : dict
        Maps ``(site, wildtype, mutant)`` to a dict of phenotype values keyed
        by the names in `phenotypes`. It is not modified.

    Returns
    -------
    dict
        Keyed by each name in `phenotypes` followed by ``"is_RBD"``. The
        phenotype values are ``pd.NA`` when `m` is null, its site is not in
        the DMS, or the required measurement is absent from `dms_data`.
        ``"is_RBD"`` is ``pd.NA`` in the first two cases and otherwise says
        whether the site's region is ``"RBD"``.
    """
    null_d = {k: pd.NA for k in phenotypes}
    if pd.isnull(m) or int(m[1:-1]) not in dms_wt:
        d = null_d
        d["is_RBD"] = pd.NA
    else:
        parent = m[0]
        site = int(m[1:-1])
        mut = m[-1]
        wt = dms_wt[site]
        if parent == wt:
            # mutation away from the DMS wildtype: use the measurement directly
            try:
                # copy so that adding "is_RBD" below does not alter `dms_data`
                d = dict(dms_data[(site, parent, mut)])
            except KeyError:
                d = null_d
        elif mut == wt:
            # reversion to the DMS wildtype: negate the forward measurement
            try:
                d = {k: -v for (k, v) in dms_data[(site, mut, parent)].items()}
            except KeyError:
                d = null_d
        else:
            # neither residue is the DMS wildtype: difference of the two effects
            try:
                parent_d = dms_data[(site, wt, parent)]
                mut_d = dms_data[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
            except KeyError:
                d = null_d
        d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d

Now assign phenotypes to Pango clades.
We do this both using the actual DMS data and using DMS data randomized among the measured mutations:

In [None]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    `dms_data_dict` maps ``(site, wildtype, mutant)`` to a dict of phenotype
    values, as consumed by `mut_dms`.

    Returns a data frame with one row per clade in `pango_df` and DMS
    phenotype. Each row holds the sum of that clade's mutation effects over
    the full spike (``phenotype``), over RBD mutations only
    (``phenotype_RBD_only``) and over non-RBD mutations only
    (``phenotype_nonRBD_only``). Mutations without DMS data contribute zero
    and are listed in the ``muts_from_<dms_clade>_missing_data`` column.
    """
    pango_dms_df = (
        pango_df
        # put one mutation in each column
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        # one row per clade, mutation and phenotype
        .melt(
            id_vars=["clade", "date", "n_child_clades", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        # NOTE: `assign` evaluates these in order, so `mutation_missing` sees
        # `mutation_effect` before its nulls are filled with 0 further down.
        .assign(
            muts_from_dms_clade=lambda x: x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(
                lambda ms: "; ".join([m for m in ms if not pd.isnull(m)])
            ),
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"]
                .transform(lambda ms: "; ".join([m for m in ms if not pd.isnull(m)]))
            ),
            # Cast explicitly rather than relying on `fillna` to downcast the
            # object columns: that implicit downcast is deprecated in pandas,
            # and `~` on an object column of bools gives -2 / -1, not booleans.
            mutation_effect=lambda x: x["mutation_effect"].fillna(0).astype(float),
            is_RBD=lambda x: x["is_RBD"].fillna(False).astype(bool),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        # sum the effects of each clade's mutations for each phenotype
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )

    # sanity checks: no clade was lost, and the RBD / non-RBD split is exhaustive
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )

    return pango_dms_df

# columns that together identify one mutation in the DMS data
dms_index_cols = ["site", "wildtype", "mutant"]

# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = dms_summary.set_index(dms_index_cols)[phenotypes].to_dict(
    orient="index"
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)
print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
numpy.random.seed(0)
pango_dms_dfs_rand = []
for irandom in range(1, n_random + 1):
    # shuffle the values of each phenotype column independently across DMS rows
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        dms_summary_rand[phenotype] = numpy.random.permutation(
            dms_summary_rand[phenotype].values
        )
    dms_data_dict_rand = dms_summary_rand.set_index(dms_index_cols)[phenotypes].to_dict(
        orient="index"
    )
    pango_dms_df_irand = get_pango_dms_df(dms_data_dict_rand)
    # label the rows with the number of this randomization
    pango_dms_df_irand["randomize"] = irandom
    pango_dms_dfs_rand.append(pango_dms_df_irand)
# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)

## Plot phenotypes of Pango clades
Plot phenotypes of Pango clades versus their designation dates:

In [None]:
# phenotype column -> label of the spike region it covers
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}

# long format: one row per clade, DMS phenotype and spike region
non_region_cols = [col for col in pango_dms_df.columns if col not in region_cols]
pango_chart_df = pango_dms_df.melt(
    id_vars=non_region_cols,
    value_vars=list(region_cols),
    var_name="spike_region",
    value_name="phenotype value",
)
pango_chart_df["spike_region"] = pango_chart_df["spike_region"].map(region_cols)
pango_chart_df = pango_chart_df.rename(
    columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"}
)

# columns cannot have "." in them for Altair
col_renames = {
    dotted: dotted.replace(".", "_")
    for dotted in pango_chart_df.columns
    if "." in dotted
}
# reverse lookup, used to show the original names in tooltips
col_renames_rev = {renamed: dotted for dotted, renamed in col_renames.items()}
pango_chart_df = pango_chart_df.rename(columns=col_renames)

# point selection keyed on clade, updated on mouseover (nothing selected when empty)
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# encodings shared by every panel: tooltips for all columns (shown under their
# original names), emphasis of the selected clade, and one color per phenotype
base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_chart_df.columns
        ],
        # the selected clade is drawn opaque, larger and with a visible stroke
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        color=alt.Color(
            "DMS_phenotype",
            legend=None,
            scale=alt.Scale(
                range=list(phenotype_colors.values()),
                domain=list(phenotype_colors.keys()),
            ),
        ),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# one row of panels per phenotype, with one column per spike region
phenotype_pango_charts = []
for phenotype in phenotypes:
    # only the first row gets column headers, and only the last row an x-axis
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])
    phenotype_pango_charts.append(
        base_pango_chart
        .transform_filter(alt.datum["DMS_phenotype"] == phenotype)
        .encode(
            x=alt.X(
                "date",
                title="designation date of clade" if last_row else None,
                axis=(
                    alt.Axis(titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90)
                    if last_row
                    else None
                ),
                scale=alt.Scale(nice=False, padding=3),
            ),
            y=alt.Y(
                "phenotype value",
                title=phenotype,
                axis=alt.Axis(titleFontSize=12),
                scale=alt.Scale(nice=False, padding=3),
            ),
            column=alt.Column(
                "spike_region",
                sort=list(region_cols),
                title=None,
                header=(
                    alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=4)
                    if first_row
                    else None
                ),
                spacing=4,
            ),
        )
    )

# stack the per-phenotype rows, attach the selection, and add an overall title
pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(        
        title=alt.TitleParams(
            f"DMS predicted phenotypes of Pango clades descended from {', '.join(starting_clades)}",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
    )
)

print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

# last expression of the cell, so the chart is shown as the cell output
pango_chart

## Pango clade affinity versus escape scatter plot

In [None]:
# columns that hold phenotype names or values rather than clade metadata
phenotype_value_cols = {
    "DMS_phenotype",
    "phenotype",
    "phenotype_RBD_only",
    "phenotype_nonRBD_only",
}
clade_metadata_cols = [
    col for col in pango_dms_df.columns if col not in phenotype_value_cols
]

# wide format: one row per clade, one column of full-spike values per phenotype
pango_scatter_df = pango_dms_df.pivot_table(
    index=clade_metadata_cols,
    values="phenotype",
    columns="DMS_phenotype",
).reset_index()
pango_scatter_df = pango_scatter_df.rename(
    columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"}
)
pango_scatter_df = pango_scatter_df.rename(columns=col_renames)

pango_scatter_df

# Scatter plot of DMS predicted ACE2 affinity (x) versus sera escape (y), one
# point per clade; the clade under the mouse is emphasized.
pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        y=alt.Y(
            "sera escape",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        # show every column, under its original (pre-rename) name where renamed
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                # name all starting clades, not just the leftover loop variable
                # `starting_clade`, which holds only the last one
                f"for Pango clades descended from {', '.join(starting_clades)}",
            ],
            anchor="middle",
            fontSize=14,
            dy=-5,
        ),
    )
    .properties(width=300, height=300)
)

print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

# last expression of the cell, so the chart is shown as the cell output
pango_scatter_chart

## Correlate with clade growth

In [None]:
# read the growth rates, renaming two columns to the names used below
growth_rates = pd.read_csv(growth_rates_csv)
growth_rates = growth_rates.rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)

# every clade with a growth rate must be a known Pango clade
invalid_clades = set(growth_rates["clade"]).difference(pango_clades)
if invalid_clades:
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# attach growth rates to the actual and the randomized DMS phenotypes;
# only clades present in both tables are kept
pango_dms_growth_df = pd.merge(
    pango_dms_df, growth_rates, on="clade", validate="many_to_one"
)

pango_dms_growth_df_rand = pd.merge(
    pango_dms_df_rand, growth_rates, on="clade", validate="many_to_one"
)

n_growth_clades = growth_rates["clade"].nunique()
n_dms_clades = pango_dms_df["clade"].nunique()
n_shared_clades = pango_dms_growth_df["clade"].nunique()
print(
    "\n".join(
        [
            f"{n_growth_clades} clades have growth rates estimates.",
            f"{n_dms_clades} clades have DMS estimates.",
            f"{n_shared_clades} clades have growth and DMS estimates",
        ]
    )
)

print("Simple correlations:")
# correlation of R with each phenotype summary, separately for each DMS phenotype
corr_cols = ["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]
simple_corrs = pango_dms_growth_df.groupby("DMS_phenotype")[corr_cols].corr()
display(simple_corrs[["R"]])

Plot clade growth rate (R) versus designation date, with point sizes scaled by the log of the number of sequences in each clade:

In [None]:
# point size follows the number of sequences on a log scale
n_sequences_size = alt.Size("number sequences", scale=alt.Scale(type="log"))

# growth rate (R) versus date for every row of the merged table
growth_by_date_chart = (
    alt.Chart(pango_dms_growth_df)
    .mark_circle(opacity=0.25, color="black")
    .encode(
        x="date",
        y="R",
        size=n_sequences_size,
        tooltip=list(pango_dms_growth_df.columns),
    )
)
growth_by_date_chart

Now fit a weighted least-squares (WLS) regression, weighting clades by the log of their number of sequences:

In [None]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Pivot per-clade, per-phenotype rows into one row per clade.

    The returned frame has the clade metadata columns followed by one column
    per phenotype and spike domain, named ``"<phenotype> (<domain>)"``.
    """
    missing_col = f"muts_from_{dms_clade}_missing_data"

    def _merge_missing(s):
        # union of the "; "-separated entries, in first-seen order, blanks dropped
        merged = {}
        for entries in s.str.split("; "):
            for m in entries:
                if m:
                    merged[m] = None
        return "; ".join(merged)

    renamed = df.rename(
        columns={
            "phenotype": "full spike",
            "phenotype_RBD_only": "RBD",
            "phenotype_nonRBD_only": "non RBD",
        }
    )
    # group muts missing data from all phenotypes
    renamed = renamed.assign(
        muts_from_DMS_clade_missing_data=(
            renamed.groupby("clade")[missing_col].transform(_merge_missing)
        ),
    )
    renamed = renamed.rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})

    clade_cols = [
        "clade",
        "R",
        "date",
        "muts_from_DMS_clade",
        "muts_from_DMS_clade_missing_data",
        "number sequences",
    ]
    ols_vars = renamed.pivot_table(
        index=clade_cols,
        columns="DMS_phenotype",
        values=["full spike", "RBD", "non RBD"],
    )

    # flatten column names
    assert all(len(c) == 2 for c in ols_vars.columns.values)
    flat_names = []
    for domain, pheno in ols_vars.columns.values:
        flat_names.append(f"{pheno} ({domain})")
    ols_vars.columns = flat_names
    return ols_vars.reset_index()

ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# Pivot each randomization once up front: the pivoted tables do not depend on
# which model or phenotype is being examined, so there is no need to rebuild
# them inside the loops below.
rand_ols_vars_list = [
    pivot_for_ols_vars(rand_df)
    for _, rand_df in pango_dms_growth_df_rand.groupby("randomize")
]

# For each set of explanatory variables: fit a weighted regression of growth
# rate on the DMS phenotypes, chart predicted versus actual growth and the
# simple per-phenotype correlations, and compute randomization-based P-values.
# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars, regression_chartfile, corr_chartfile in [
    (
        "full spike",
        [f"{c} (full spike)" for c in phenotypes],
        pango_dms_vs_growth_regression_html,
        pango_dms_vs_growth_corr_html
    ),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_regression_by_domain_html,
        pango_dms_vs_growth_corr_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    ols_model = statsmodels.api.WLS(
        endog=ols_vars[["R"]],
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
        # weight by log n sequences, so pass log**2
        weights=numpy.log(ols_vars["number sequences"])**2,
    )
    res_ols = ols_model.fit()
    display(res_ols.summary())

    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    plot_size = 180

    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider on the minimum log10 number of sequences for a clade to be shown;
    # it starts at the smallest value in the data, truncated to one decimal
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )

    # encodings and filters shared by the regression and correlation panels
    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 250],
                ),
                legend=alt.Legend(symbolStrokeWidth=0, symbolFillColor="gray"),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate (R)", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],
            ],
        )
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    growth_charts = []
    simple_corr_charts = []
    # `exog_vars` lists the phenotypes once per domain, so cycling through
    # `phenotypes` pairs each variable with its phenotype name
    for i, (dms_pheno, pheno) in enumerate(zip(
        exog_vars,
        itertools.cycle(phenotypes)
    )):
        assert dms_pheno.startswith(pheno)
        base_pheno_chart = (
            base_growth_chart
            .encode(
                y=alt.Y(
                    "R",
                    title="actual clade growth rate (R)",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # only the first panel in each row of panels gets a y-axis
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
            )
        )

        # predicted versus actual growth, colored by this phenotype's value
        growth_charts.append(
            base_pheno_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                        tickCount=3,
                    ),
                    scale=alt.Scale(
                        range=polyclonal.plot.color_gradient_hex("lightgray", phenotype_colors[pheno], 40),
                        nice=False,
                    ),
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.6)
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=(
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}, "
                        + f"P: {res_ols.pvalues[dms_pheno]:.1g}"
                    ),
                    subtitleFontSize=11,
                ),
            )
        )

        # get real and randomized P-value
        pheno_r, _ = scipy.stats.pearsonr(fitted_df["R"], fitted_df[dms_pheno])
        rand_rs = [
            scipy.stats.pearsonr(rand_vars["R"], rand_vars[dms_pheno])[0]
            for rand_vars in rand_ols_vars_list
        ]
        # Fraction of randomizations with r at least as large as the actual r.
        # If there are none, report an upper bound rather than a P-value of 0,
        # as for the whole-model P-value further down.
        n_rand_rs_ge = sum(pheno_r <= r for r in rand_rs)
        rand_p = (
            f"= {n_rand_rs_ge / len(rand_rs):.1g}"
            if n_rand_rs_ge
            else f"< {1 / len(rand_rs):.1g}"
        )

        # this phenotype's value versus actual growth
        simple_corr_charts.append(
            base_pheno_chart
            .transform_calculate(color_phenotype=f"'{pheno}'")
            .encode(
                x=alt.X(
                    dms_pheno,
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    "color_phenotype:N",
                    scale=alt.Scale(
                        range=list(phenotype_colors.values()),
                        domain=list(phenotype_colors.keys()),
                    ),
                    legend=None,
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.3, color=phenotype_colors[pheno])
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=f"Pearson r = {pheno_r:.2f} (P {rand_p})",
                    subtitleFontSize=11,
                ),
            )
        )

    actual_r = math.sqrt(res_ols.rsquared)
    assert len(growth_charts) % len(phenotypes) == 0
    # lay the panels out as rows of `len(phenotypes)` charts above the date slider
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (Pearson r = {actual_r:.2f})",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    simple_corr_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *simple_corr_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    )
                    for i in range(len(simple_corr_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                "Simple correlations of DMS phenotypes vs clade growth",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    display(growth_chart)
    print(f"Saving to {regression_chartfile}")
    growth_chart.save(regression_chartfile)

    display(simple_corr_chart)
    print(f"Saving to {corr_chartfile}")
    simple_corr_chart.save(corr_chartfile)

    # fit randomized models and compute P-value based on R values
    print("Computing P-value from randomizations")
    rand_r = []
    for rand_ols_vars in rand_ols_vars_list:
        rand_ols_model = statsmodels.api.WLS(
            endog=rand_ols_vars[["R"]],
            exog=statsmodels.api.add_constant(rand_ols_vars[exog_vars]),
            # weight by log n sequences, so pass log**2
            weights=numpy.log(rand_ols_vars["number sequences"])**2,
        )
        rand_res_ols = rand_ols_model.fit()
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    n_rand_ge = sum(r >= actual_r for r in rand_r)
    # with no randomization reaching the actual r, only an upper bound is known
    pval = f"= {n_rand_ge / len(rand_r)}" if n_rand_ge else f"< {1 / len(rand_r)}"

    # histogram of the randomized r values, with the actual r marked by a line
    rand_r_hist = (
        alt.Chart(pd.DataFrame({"r": rand_r}))
        .encode(
            x=alt.X(
                "r",
                title="Pearson r",
                bin=alt.BinParams(step=0.02, extent=(0, 1)),
                scale=alt.Scale(domain=(0, 1)),
                axis=alt.Axis(values=[0, 0.2, 0.4, 0.6, 0.8, 1]),
            ),
            y=alt.Y("count()", title="number of randomizations"),
        )
        .mark_bar(color="black", opacity=0.65, align="right")
        .properties(width=250, height=130)
    )

    actual_r_line = (
        alt.Chart(pd.DataFrame({"r": [actual_r]}))
        .encode(x="r")
        .mark_rule(size=2, color="red", strokeDash=[4, 2])
    )

    pval_chart = (
        (rand_r_hist + actual_r_line)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                f"P {pval}",
                subtitle=f"{n_rand_ge} of {len(rand_r)} randomizations \u2265 actual r of {actual_r:.2f}",
            ),
        )
    )

    display(pval_chart)

## Distributions of DMS mutation effects in clades with growth estimates versus all mutations

In [None]:
# Count, for each mutation, the number of clades with growth estimates that
# carry it. `pango_dms_growth_df` has one row per clade and DMS phenotype, so
# reduce it to one row per clade first; otherwise every mutation is counted
# once per phenotype.
muts_in_clades = collections.Counter(
    pango_dms_growth_df
    [["clade", f"muts_from_{dms_clade}"]]
    .drop_duplicates()
    [f"muts_from_{dms_clade}"]
    # clades with no mutations relative to the DMS clade have an empty string
    .pipe(lambda s: s[s != ""])
    .str.split("; ")
    .explode()
)
print(f"There are {len(muts_in_clades)} mutations found in clades with growth estimates")

# long-format table of every measured non-wildtype mutation:
# one row per mutation and phenotype with a non-null value
all_muts_dms = (
    dms_summary
    .query("wildtype != mutant")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    # collapse all regions other than RBD into a single "non RBD" label
    .assign(region=lambda x: x["region"].where(x["region"] == "RBD", "non RBD"))
    .melt(
        id_vars=["mutation", "region"],
        value_vars=phenotypes,
        var_name="DMS_phenotype",
        value_name="phenotype",
    )
    .query("phenotype.notnull()")
)

# Stack two versions of the table: every mutation with a count of 1, and the
# mutations found in clades weighted by the number of clades carrying them.
all_muts_dms = pd.concat(
    [
        all_muts_dms.assign(mutation_type="any", count=1),
        all_muts_dms.query("mutation in @muts_in_clades").assign(
            mutation_type="in Pango clade",
            count=lambda x: x["mutation"].map(muts_in_clades),
        ),
    ]
)

# for each phenotype, a histogram of mutation effects with one row per mutation type
for pheno in phenotypes:

    base_hist = (
        alt.Chart(
            all_muts_dms
            .query("DMS_phenotype == @pheno")
            .drop(columns=["DMS_phenotype", "mutation"])
        )
        .encode(
            x=alt.X("phenotype", bin=alt.BinParams(maxbins=50)),
            y=alt.Y("sum(count)", title="mutations"),
            color=alt.value(phenotype_colors[pheno]),
            row=alt.Row("mutation_type", title=None, spacing=5),
        )
        .properties(width=200, height=75, title=pheno)
        .mark_bar()
        # the two rows have very different totals, so give each its own y scale
        .resolve_scale(y="independent")
    )
    display(base_hist)