# Test how various models score the phenotypes of BA.2.86 and its subsequent descendants

In [55]:
# This cell is tagged as `parameters` for `papermill` parameterization
clade_phenotypes_csv = "SARS2-spike-predictor-phenos/results/clade_phenotypes.csv"
mutation_phenotypes_csv = "SARS2-spike-predictor-phenos/results/mutation_phenotypes.csv"
gisaid_mutation_counts_csv = "data/GISAID_alignment_counts_2024-03-01.csv"
gisaid_min_counts = 10

Import Python modules:

In [2]:
import altair as alt

import pandas as pd

Read the clade phenotypes and get the mutations of each clade relative to its parent.
We only analyze BA.2.86 and descendant clades with at least one spike mutation relative to their parent:

In [37]:
def relative_mutations(muts, reference_muts):
    """Get mutation in `muts` relative `reference_muts`."""
    if pd.isnull(muts):
        muts = []
    else:
        muts = [(m[0], int(m[1: -1]), m[-1]) for m in muts.split()]
    if pd.isnull(reference_muts):
        reference_muts = []
    else:
        reference_muts = [(m[0], int(m[1: -1]), m[-1]) for m in reference_muts.split()]        
    shared_muts = set(muts).intersection(reference_muts)
    sites = {
        r: (wt, m) for (wt, r, m) in [tup for tup in muts if tup not in shared_muts]
    }
    reference_sites = {
        r: (wt, m) for (wt, r, m) in [tup for tup in reference_muts if tup not in shared_muts]
    }
    muts = []
    for r, (wt, m) in sites.items():
        if r in reference_sites:
            assert wt == reference_sites[r][0]
            muts.append((r, reference_sites[r][1], m))
        else:
            muts.append((r, wt, m))
    for r, (wt, m) in reference_sites.items():
        if r in sites:
            assert wt == sites[r][0]
            pass  # already counted
        else:
            muts.append((r, m, wt))
    return [(wt, r, m) for (r, wt, m) in sorted(muts)]

clade_phenotypes = (
    pd.read_csv(clade_phenotypes_csv)
    [
        ["clade", "parent", "date", "spike muts from Wuhan-Hu-1", "descendant of BA.2.86"]
    ]
    .query("(clade == 'BA.2.86') or `descendant of BA.2.86`")
)

ba_2_86_spike_muts_from_wuhan_hu_1 = clade_phenotypes.set_index("clade").at[
    "BA.2.86", "spike muts from Wuhan-Hu-1"
]

clade_phenotypes = (
    clade_phenotypes
    .merge(
        clade_phenotypes
        [["clade", "spike muts from Wuhan-Hu-1"]]
        .rename(
            columns={
                "clade": "parent",
                "spike muts from Wuhan-Hu-1": "parent spike muts from Wuhan-Hu-1",
            }
        ),
        on="parent",
        validate="many_to_one",
        how="left",
    )
    .assign(
        spike_muts_from_parent=lambda x: x.apply(
            lambda row: relative_mutations(
                row["spike muts from Wuhan-Hu-1"],
                row["parent spike muts from Wuhan-Hu-1"],
            ),
            axis=1,
        ),
        spike_muts_from_ba_2_86=lambda x: x.apply(
            lambda row: relative_mutations(
                row["spike muts from Wuhan-Hu-1"],
                ba_2_86_spike_muts_from_wuhan_hu_1,
            ),
            axis=1,
        ),            
        has_spike_muts_from_parent=lambda x: x["spike_muts_from_parent"].map(lambda ms: len(ms) > 0),
    )
    .query("has_spike_muts_from_parent")
    .drop(columns=["spike muts from Wuhan-Hu-1", "parent spike muts from Wuhan-Hu-1", "has_spike_muts_from_parent"])
)

Get all amino acids at each site observed at least a threshold number of times in GISAID sequences as well as the "wildtype" amino acid at each site.
We will randomize from these amino acids:

In [65]:
gisaid_mutation_counts = (
    pd.read_csv(gisaid_mutation_counts_csv)
    .assign(meets_threshold=lambda x: x["count"] >= gisaid_min_counts)
    .query("meets_threshold")
)

gisaid_muts = (
    list(gisaid_mutation_counts[["site", "mutant"]].itertuples(index=False, name=None))
    + list(set(gisaid_mutation_counts[["site", "wildtype"]].itertuples(index=False, name=None)))
)
assert len(gisaid_muts) == len(set(gisaid_muts))

print(f"Retained {len(gisaid_muts)} natural mutations to randomize among")

Retained 5839 natural mutations to randomize among


Now get phenotype changes of each clade:

In [50]:
mutation_phenotypes = pd.read_csv(mutation_phenotypes_csv)
assert (mutation_phenotypes["ref_clade"] == "XBB.1.5").all()

class PhenotypeAssigner:
    """Assign phenotypes to sets of mutations.

    Parameters
    ----------
    mutation_phenotypes_df : pandas.DataFrame
        Should have columns `site`, `wildtype`, `mutant`, `mutation_effect`.
        
    """
    def __init__(self, mutation_phenotypes_df):
        assert len(mutation_phenotypes_df) == len(
            mutation_phenotypes_df[["site", "mutant"]].drop_duplicates()
        )
        self.sites = sorted(set(mutation_phenotypes_df["site"]))
        assert len(self.sites) == len(
            mutation_phenotypes_df[["site", "wildtype"]].drop_duplicates()
        )
        self.wts = mutation_phenotypes_df.set_index("site")["wildtype"].to_dict()
        self.effects = {
            site: site_df.set_index("mutant")["mutation_effect"].to_dict()
            for site, site_df in mutation_phenotypes_df.groupby("site")
        }
        for site, wt in self.wts.items():
            assert wt not in self.effects[site]
            self.effects[site][wt] = 0.0

    def phenotype(self, muts):
        """Returns phenotype for list of `muts` as `(wildtype, site, mutant)`."""
        pheno = 0.0
        for wt, site, m in muts:
            if (site in self.effects) and (wt in self.effects[site]) and (m in self.effects[site]):
                pheno += self.effects[site][m] - self.effects[site][wt]
        return pheno

pheno_changes_df = []
for clade_set, df, mut_col in [
    (
        "BA.2.86 relative to BA.2",
        clade_phenotypes.query("clade == 'BA.2.86'"),
        "spike_muts_from_parent",
    ),
    (
        "BA.2.86-descended clades with new spike mutations relative to BA.2",
        clade_phenotypes[clade_phenotypes["descendant of BA.2.86"]],
        "spike_muts_from_ba_2_86",
    ),
]:  
    for phenotype, mut_df in mutation_phenotypes.groupby("phenotype"):
        phenos = PhenotypeAssigner(mut_df)

        for clade, muts in df[["clade", mut_col]].itertuples(index=False):
            actual_pheno = phenos.phenotype(muts)
            nmuts = len(muts)

            pheno_changes_df.append((clade_set, phenotype, clade, nmuts, actual_pheno, []))

pheno_changes_df = pd.DataFrame(
    pheno_changes_df,
    columns=["clade_set", "phenotype", "clade", "n_mutations", "actual_phenotype", "randomized_phenotypes"],
)

pheno_changes_df.head(10)
        

Unnamed: 0,clade_set,phenotype,clade,n_mutations,actual_phenotype,randomized_phenotypes
0,BA.2.86 relative to BA.2,EVEscape,BA.2.86,59,29.34,[]
1,BA.2.86 relative to BA.2,RBD yeast-display DMS ACE2 affinity,BA.2.86,59,4.84,[]
2,BA.2.86 relative to BA.2,RBD yeast-display DMS RBD expression,BA.2.86,59,-0.09,[]
3,BA.2.86 relative to BA.2,RBD yeast-display DMS escape,BA.2.86,59,-0.656305,[]
4,BA.2.86 relative to BA.2,spike pseudovirus DMS ACE2 binding,BA.2.86,59,-1.131016,[]
5,BA.2.86 relative to BA.2,spike pseudovirus DMS human sera escape,BA.2.86,59,4.244814,[]
6,BA.2.86 relative to BA.2,spike pseudovirus DMS spike mediated entry,BA.2.86,59,6.47732,[]
7,BA.2.86-descended clades with new spike mutati...,EVEscape,JN.1,1,3.622,[]
8,BA.2.86-descended clades with new spike mutati...,EVEscape,JN.1.1.1,2,5.615,[]
9,BA.2.86-descended clades with new spike mutati...,EVEscape,JN.1.1.2,2,7.456,[]


In [41]:
mutation_phenotypes

Unnamed: 0,ref_clade,phenotype,site,wildtype,mutant,mutation_effect
0,XBB.1.5,spike pseudovirus DMS human sera escape,2,F,C,0.01114
1,XBB.1.5,spike pseudovirus DMS human sera escape,2,F,L,0.01876
2,XBB.1.5,spike pseudovirus DMS human sera escape,2,F,S,0.03169
3,XBB.1.5,spike pseudovirus DMS human sera escape,3,V,A,0.02402
4,XBB.1.5,spike pseudovirus DMS human sera escape,3,V,F,0.11270
...,...,...,...,...,...,...
54289,XBB.1.5,EVEscape,1273,T,R,3.17800
54290,XBB.1.5,EVEscape,1273,T,S,1.94600
54291,XBB.1.5,EVEscape,1273,T,V,2.65500
54292,XBB.1.5,EVEscape,1273,T,W,2.08300
