# Compare phenotypes from DMS (or some other source) to natural sequence evolution


In [1]:
# This cell is tagged parameters for `papermill` parameterization
# Every parameter defaults to `None`; real values must be assigned before the
# analysis cells below run.

# Input CSV of Murrell growth rates; read with `pd.read_csv` and expected to have
# `pango`, `R` and `seq_volume` columns.
murrell_growth_rates_csv = None
# Input CSV of per-clade phenotypes; expected to have `clade`, `clade growth` and
# `clade growth HDI 95` columns among others.
clade_phenotypes_csv = None
# Input CSV with a `clade` column plus the `random_`-prefixed randomized phenotypes.
clade_phenotypes_randomized_csv = None
# NOTE(review): the remaining five names look like output paths (CSV tables and
# HTML charts); they are not read or written in the cells visible here -- confirm.
pair_growth_dms_csv = None
clade_growth_dms_csv = None
pair_corr_html = None
clade_corr_html = None
pair_ols_html = None

In [2]:
import os

# Interactive fallback only. When the notebook is executed through `papermill`,
# the injected parameters have already been assigned by the time this cell runs,
# and unconditionally reassigning them here would silently discard that
# parameterization. The guard also makes the cell safe to re-run: once the
# defaults are filled in, `os.chdir("../")` is not applied a second time.
# NOTE(review): if the pipeline relies on the `chdir` even when parameters are
# supplied, pass paths relative to the starting directory instead -- confirm.
_supplied_params = [
    murrell_growth_rates_csv,
    clade_phenotypes_csv,
    clade_phenotypes_randomized_csv,
    pair_growth_dms_csv,
    clade_growth_dms_csv,
    pair_corr_html,
    clade_corr_html,
    pair_ols_html,
]
if all(param is None for param in _supplied_params):
    # the hard-coded paths below are relative to the parent of the directory
    # the kernel starts in
    os.chdir("../")

    murrell_growth_rates_csv = "MultinomialLogisticGrowth/model_fits/rates.csv"
    clade_phenotypes_csv = "results/SARS2-spike-predictor-phenos/clade_phenotypes.csv"
    clade_phenotypes_randomized_csv = "results/SARS2-spike-predictor-phenos/clade_phenotypes_randomized.csv"
    pair_growth_dms_csv = "results/compare_natural/clade_pair_growth.csv"
    clade_growth_dms_csv = "results/compare_natural/clade_growth.csv"
    pair_corr_html = "results/compare_natural/clade_pair_growth.html"
    clade_corr_html = "results/compare_natural/clade_growth.html"
    pair_ols_html = "results/compare_natural/ols_clade_pair_growth.html"

In [3]:
import collections
import copy
import datetime
import functools
import itertools
import json
import math
import operator
import os
import re
import tempfile

import altair as alt

import baltic

import Bio.Phylo

import dmslogo.colorschemes

import matplotlib
import matplotlib.pyplot as plt

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

# Lift Altair's default cap on the number of data rows that may be embedded in a
# chart; the return value is assigned to `_` so the cell shows no output for it.
_ = alt.data_transformers.disable_max_rows()

# Keep text in SVG figures saved by matplotlib as text rather than converting it
# to paths.
plt.rcParams["svg.fonttype"] = "none"

## Read the Pango clade data
Read the clades with the real and randomized data, and add the Ben Murrell growth data to the Bedford lab growth data already in the CSVs:

In [4]:
# Murrell growth rates: harmonize the column names with the clade table and keep
# only rows whose `seq_volume` is at least 400.
murrell_growth_rates = (
    pd.read_csv(murrell_growth_rates_csv)
    .rename(columns={"pango": "clade", "R": "clade growth (Murrell)"})
    .query("seq_volume >= 400")
)

# Clade phenotypes: label the growth columns already in the CSV as Bedford, then
# left-merge the Murrell rates so that clades lacking them are retained.
clade_phenotypes = pd.read_csv(clade_phenotypes_csv).rename(
    columns={
        "clade growth": "clade growth (Bedford)",
        "clade growth HDI 95": "clade growth HDI 95 (Bedford)",
    }
)
clade_phenotypes = clade_phenotypes.merge(
    murrell_growth_rates,
    how="left",
    validate="one_to_one",
)

# Classify the columns: growth rates, metadata, lineage-membership flags, and
# (everything that is left) the phenotypes to analyze.
growth_cols = ["clade growth (Bedford)", "clade growth (Murrell)"]
other_cols = [
    "clade",
    "parent",
    "date",
    "number spike muts from Wuhan-Hu-1",
    "seq_volume",
    "clade growth HDI 95 (Bedford)",
    *[c for c in clade_phenotypes.columns if c.startswith("spike muts")],
]
descendant_of_cols = [
    c for c in clade_phenotypes.columns if c.startswith("descendant of")
]
pheno_cols = [
    c
    for c in clade_phenotypes.columns
    if c not in growth_cols and c not in other_cols and c not in descendant_of_cols
]
print("Analyzing the following phenotypes:\n " + "\n ".join(pheno_cols))

# Randomized phenotypes: attach only the `random_` columns, matched on clade.
clade_phenotypes_randomized = pd.read_csv(clade_phenotypes_randomized_csv)
random_pheno_cols = [
    c for c in clade_phenotypes_randomized.columns if c.startswith("random_")
]
clade_phenotypes = clade_phenotypes.merge(
    clade_phenotypes_randomized[
        [
            c
            for c in clade_phenotypes_randomized.columns
            if c == "clade" or c.startswith("random_")
        ]
    ],
    on="clade",
    validate="one_to_one",
)

# The first whitespace-delimited token of each randomized column is its seed label;
# `dict.fromkeys` de-duplicates while preserving first-seen order.
random_seeds = list(dict.fromkeys(c.split()[0] for c in random_pheno_cols))
print(f"\nThere are {len(random_seeds)} different randomizations of each phenotype.")

# every phenotype must have a randomized counterpart for every seed
assert all(
    f"{seed} {pheno}" in random_pheno_cols
    for seed, pheno in itertools.product(random_seeds, pheno_cols)
)

Analyzing the following phenotypes:
 spike pseudovirus DMS human sera escape relative to XBB.1.5
 spike pseudovirus DMS ACE2 binding relative to XBB.1.5
 spike pseudovirus DMS spike mediated entry relative to XBB.1.5
 RBD yeast-display DMS ACE2 affinity relative to XBB.1.5
 RBD yeast-display DMS RBD expression relative to XBB.1.5
 RBD yeast-display DMS escape relative to XBB.1.5
 EVEscape relative to XBB.1.5
 Hamming distance relative to XBB.1.5

There are 100 different randomizations of each phenotype.


Look at correlation of Bedford and Murrell growth rates:

In [5]:
# Count the clades by which of the two growth estimates they have.
display(
    clade_phenotypes
    .assign(
        has_Bedford_growth=clade_phenotypes["clade growth (Bedford)"].notna(),
        has_Murrell_growth=clade_phenotypes["clade growth (Murrell)"].notna(),
    )
    .groupby(["has_Bedford_growth", "has_Murrell_growth"])
    .agg(n_clades=("clade", "count"))
)

# Scatter plot of the two growth estimates; only the columns needed by the chart
# are passed to Altair. Left as the final expression so the notebook renders it.
alt.Chart(
    clade_phenotypes[
        ["clade growth (Bedford)", "clade growth (Murrell)", "clade", "date", "seq_volume"]
    ]
).mark_circle().encode(
    x=alt.X("clade growth (Bedford)", scale=alt.Scale(zero=False)),
    y=alt.Y("clade growth (Murrell)", scale=alt.Scale(zero=False)),
    tooltip=["clade", "date", "seq_volume"],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,n_clades
has_Bedford_growth,has_Murrell_growth,Unnamed: 2_level_1
False,False,3021
False,True,693
True,False,63
True,True,34


## Get changes in growth rates between parent-descendant clade pairs

In [6]:
def relative_mutations(muts, reference_muts):
    """Express the mutations in `muts` relative to a background with `reference_muts`.

    Both arguments are space-delimited mutation strings such as ``"N501Y D614G"``
    (one wildtype character, an integer site, one mutant character), or a null
    value meaning no mutations.

    Returns a space-delimited string, ordered by site, of the changes that lead
    from the reference to `muts`:

    - mutations found in both inputs are dropped;
    - mutations only in `muts` are reported unchanged;
    - mutations only in `reference_muts` are reported as reversions;
    - a site mutated differently in both goes from the reference's mutant to the
      mutant in `muts`.

    Raises ``AssertionError`` if the inputs disagree on the wildtype at such a site.
    """

    def parse(mut_str):
        # null -> no mutations; otherwise split into (wildtype, site, mutant) tuples
        if pd.isnull(mut_str):
            return []
        return [(tok[0], int(tok[1:-1]), tok[-1]) for tok in mut_str.split()]

    query = parse(muts)
    reference = parse(reference_muts)
    shared = set(query) & set(reference)

    # site -> (wildtype, mutant), for the mutations that are not common to both
    query_by_site = {
        site: (wt, mut) for wt, site, mut in query if (wt, site, mut) not in shared
    }
    ref_by_site = {
        site: (wt, mut) for wt, site, mut in reference if (wt, site, mut) not in shared
    }

    changes = []  # (site, from, to) so that sorting orders by site
    for site, (wt, mut) in query_by_site.items():
        if site not in ref_by_site:
            changes.append((site, wt, mut))
            continue
        ref_wt, ref_mut = ref_by_site[site]
        assert wt == ref_wt
        changes.append((site, ref_mut, mut))
    for site, (wt, mut) in ref_by_site.items():
        if site in query_by_site:
            # handled in the loop above
            assert wt == query_by_site[site][0]
        else:
            # present only in the reference: report as a reversion
            changes.append((site, mut, wt))

    return " ".join(f"{old}{site}{new}" for site, old, new in sorted(changes))

# Pair each clade with its parent. The right-hand table is the clade table itself
# with every selected column given a "parent " prefix, except that the clade name
# becomes the `parent` merge key.
pair_phenotypes = (
    clade_phenotypes
    .merge(
        clade_phenotypes[
            ["clade", "spike muts from Wuhan-Hu-1", *growth_cols, *pheno_cols, *random_pheno_cols]
        ]
        .add_prefix("parent ")
        .rename(columns={"parent clade": "parent"}),
        on="parent",
        validate="many_to_one",
    )
    .assign(
        # spike mutations of each clade expressed relative to its parent
        spike_muts_from_parent=lambda x: x.apply(
            lambda row: relative_mutations(
                row["spike muts from Wuhan-Hu-1"],
                row["parent spike muts from Wuhan-Hu-1"],
            ),
            axis=1,
        ),
    )
    .drop(columns=["spike muts from Wuhan-Hu-1", "parent spike muts from Wuhan-Hu-1"])
)
print(f"\n{len(pair_phenotypes)} of {len(clade_phenotypes)} clades have parent pairs")

# Replace every growth rate and phenotype by (clade value) - (parent value).
for col in [*growth_cols, *pheno_cols, *random_pheno_cols]:
    pair_phenotypes[col] = pair_phenotypes[col].sub(pair_phenotypes[f"parent {col}"])

# The parent values are no longer needed; the `parent` column itself has no
# trailing space in its name and so is kept.
pair_phenotypes = pair_phenotypes.drop(
    columns=pair_phenotypes.columns[pair_phenotypes.columns.str.startswith("parent ")]
)

# Count the clade pairs by which of the two growth-rate differences they have.
display(
    pair_phenotypes
    .assign(
        has_Bedford_growth=pair_phenotypes["clade growth (Bedford)"].notna(),
        has_Murrell_growth=pair_phenotypes["clade growth (Murrell)"].notna(),
    )
    .groupby(["has_Bedford_growth", "has_Murrell_growth"])
    .agg(n_clade_pairs=("clade", "count"))
)


3699 of 3811 clades have parent pairs


Unnamed: 0_level_0,Unnamed: 1_level_0,n_clade_pairs
has_Bedford_growth,has_Murrell_growth,Unnamed: 2_level_1
False,False,3002
False,True,642
True,False,36
True,True,19


## Univariate correlations of each phenotype with growth rate

In [7]:
# Minimum number of rows (after filtering) needed to attempt a correlation.
min_points_for_corr = 5


def _randomization_p(observed_r2, randomized_r2s):
    """Fraction of the defined randomized R2 values that are >= `observed_r2`.

    A correlation is NaN when one of the two columns is constant (or has too few
    paired values). Comparisons against NaN are always False, so without special
    handling an undefined observed correlation would get P = 0 and look maximally
    significant, and undefined randomized correlations would count as "less
    extreme". Instead, return NaN when the observed value is undefined or no
    randomized value is defined, and ignore undefined randomized values. When no
    value is NaN this equals the plain fraction over all randomizations.
    """
    defined = [r2 for r2 in randomized_r2s if pd.notnull(r2)]
    if pd.isnull(observed_r2) or not defined:
        return float("nan")
    return sum(r2 >= observed_r2 for r2 in defined) / len(defined)


# For every combination of growth metric, data set (absolute clade values or
# differences from the parent) and lineage subset, correlate each phenotype with
# growth and compare against the same correlation for the randomized phenotypes.
corr_dfs = []
for growth_col, (desc, df), descendant_of in itertools.product(
    growth_cols,
    [
        ("absolute clade property", clade_phenotypes),
        ("difference from clade parent", pair_phenotypes),
    ],
    descendant_of_cols,
):
    print(f"Analyzing {growth_col}, {desc}, {descendant_of}")
    # keep rows that have this growth value and are flagged as in this lineage
    df = df[df[growth_col].notnull() & df[descendant_of]]
    if len(df) < min_points_for_corr:
        continue
    records = [
        (
            pheno_col,
            # number of retained rows that also have this phenotype
            df[pheno_col].notnull().sum(),
            # squared correlation of growth with the real phenotype ...
            df[growth_col].corr(df[pheno_col])**2,
            # ... and with each randomization of that phenotype
            [df[growth_col].corr(df[f"{random_seed} {pheno_col}"])**2 for random_seed in random_seeds],
        )
        for pheno_col in pheno_cols
    ]
    corr_dfs.append(
        pd.DataFrame(records, columns=["phenotype", "n", "correlation (R2)", "randomized correlations"])
        .assign(
            growth_type=desc,
            growth_metric=growth_col,
            descendant_of=descendant_of,
            p=lambda x: [
                _randomization_p(observed, randomized)
                for observed, randomized in zip(
                    x["correlation (R2)"], x["randomized correlations"]
                )
            ],
        )
    )
corr_df = pd.concat(corr_dfs, ignore_index=True, sort=False)
display(corr_df.head())

Analyzing clade growth (Bedford), absolute clade property, descendant of BA.2.86
Analyzing clade growth (Bedford), absolute clade property, descendant of XBB
Analyzing clade growth (Bedford), absolute clade property, descendant of BA.2
Analyzing clade growth (Bedford), difference from clade parent, descendant of BA.2.86


  c /= stddev[:, None]
  c /= stddev[None, :]


Analyzing clade growth (Bedford), difference from clade parent, descendant of XBB
Analyzing clade growth (Bedford), difference from clade parent, descendant of BA.2


  c /= stddev[:, None]
  c /= stddev[None, :]


Analyzing clade growth (Murrell), absolute clade property, descendant of BA.2.86
Analyzing clade growth (Murrell), absolute clade property, descendant of XBB
Analyzing clade growth (Murrell), absolute clade property, descendant of BA.2
Analyzing clade growth (Murrell), difference from clade parent, descendant of BA.2.86
Analyzing clade growth (Murrell), difference from clade parent, descendant of XBB
Analyzing clade growth (Murrell), difference from clade parent, descendant of BA.2


Unnamed: 0,phenotype,n,correlation (R2),randomized correlations,growth_type,growth_metric,descendant_of,p
0,spike pseudovirus DMS human sera escape relati...,19,0.067132,"[0.35326370839859367, 0.00047311643080649035, ...",absolute clade property,clade growth (Bedford),descendant of BA.2.86,0.6
1,spike pseudovirus DMS ACE2 binding relative to...,19,0.162751,"[0.4195756382432586, 0.6421752156960275, 0.388...",absolute clade property,clade growth (Bedford),descendant of BA.2.86,0.42
2,spike pseudovirus DMS spike mediated entry rel...,19,0.661357,"[0.05352981140796372, 0.015153065425885978, 0....",absolute clade property,clade growth (Bedford),descendant of BA.2.86,0.04
3,RBD yeast-display DMS ACE2 affinity relative t...,19,0.630444,"[0.5275860746611065, 0.0005641699627895184, 0....",absolute clade property,clade growth (Bedford),descendant of BA.2.86,0.14
4,RBD yeast-display DMS RBD expression relative ...,19,0.019715,"[0.6154264070994924, 0.000895960505335881, 0.5...",absolute clade property,clade growth (Bedford),descendant of BA.2.86,0.92


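Plot each observed correlation (points) over the distribution of correlations for the randomized phenotypes (box plots); the P-values are shown in the tooltips: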

In [8]:
# One dropdown per analysis facet, each initialized to the value in the row of
# `corr_df` labeled 0.
selectors = [
    alt.selection_point(
        bind=alt.binding_select(options=corr_df[col].unique(), name=col),
        fields=[col],
        value=corr_df[col][0],
    )
    for col in ("growth_type", "growth_metric", "descendant_of")
]

# Shared encodings: R2 on x, phenotype (in `pheno_cols` order) on y, and only the
# rows matching all three dropdowns.
base_chart = (
    alt.Chart(corr_df)
    .properties(width=350, height=alt.Step(20))
    .encode(
        x=alt.X("correlation (R2)", scale=alt.Scale(domain=[0, 1])),
        y=alt.Y(
            "phenotype",
            sort=pheno_cols,
            title=None,
            axis=alt.Axis(labelLimit=500, labelFontSize=11),
        ),
        tooltip=[
            alt.Tooltip("correlation (R2)", format=".2f"),
            alt.Tooltip("p", format=".2g"),
            "n",
        ],
    )
    .add_params(*selectors)
    .transform_filter(selectors[0] & selectors[1] & selectors[2])
)

# observed correlation for each phenotype
point_chart = base_chart.mark_circle(size=70, opacity=1)

# Distribution of the randomized correlations: expand the per-row list to one row
# per randomization, then overwrite the x field with the randomized value.
boxplot_chart = (
    base_chart
    .mark_boxplot(color="gray", opacity=0.3, extent="min-max")
    .transform_flatten(["randomized correlations"])
    .transform_calculate(
        as_="correlation (R2)",
        calculate=alt.datum["randomized correlations"],
    )
)

point_chart + boxplot_chart