# Design SARS-CoV-2 spike RBD mutants
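This notebook designs SARS-CoV-2 spikes with mutant receptor-binding domains (RBDs). For each parent lineage, it combines RBD and full-spike deep mutational scanning measurements with mutational fitness estimates to score candidate RBD mutations.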

## Setup
Import Python modules:

In [1]:
import collections
import json
import os
import subprocess
import tempfile
import urllib
import urllib.request

import altair as alt

import Bio.Entrez
import Bio.SeqIO

import numpy

import pandas as pd

import ruamel.yaml as yaml

Read configuration:

In [2]:
# Load the top-level run configuration used by every cell below.
yaml_loader = yaml.YAML()
with open("config.yaml") as config_file:
    config = yaml_loader.load(config_file)

## Get the reference RBD

In [3]:
# NCBI asks Entrez users to identify themselves with a real email address;
# read it from the config when provided, otherwise fall back to a placeholder.
Bio.Entrez.email = config.get("entrez_email", "example@example.com")

print(f"Getting reference spike from accession {config['ref_spike']}")
with Bio.Entrez.efetch(id=config["ref_spike"], rettype="gb", retmode="text", db="protein") as f:
    ref_spike = Bio.SeqIO.read(f, "gb")
print(f"Got spike of length {len(ref_spike)}")

# RBD boundaries as inclusive, 1-based (start, end) sites in the reference spike
rbd_coords = config["rbd_coords"]
rbd_start, rbd_end = rbd_coords
# guard against a start <= 0, which would otherwise silently wrap around to the
# end of the sequence through negative indexing
if not 1 <= rbd_start <= rbd_end <= len(ref_spike):
    raise ValueError(
        f"{rbd_coords=} is not within the reference spike of length {len(ref_spike)}"
    )

# map each 1-based RBD site to its reference amino acid
ref_rbd = {r: ref_spike[r - 1] for r in range(rbd_start, rbd_end + 1)}

Getting reference spike from accession YP_009724390
Got spike of length 1273


## Get RBD mutations in each Pango lineage

In [4]:
# Download the Pango lineage consensus summary (JSON keyed by lineage name).
pango_json = config["pango_json"]
print(f"Reading Pango clade definitions from {pango_json}")
with urllib.request.urlopen(pango_json) as response:
    pango_lineages = json.loads(response.read())
print(f"Read definitions for {len(pango_lineages)} lineages")


def parse_spike_rbd_muts(lineage_d, coords=None, ref=None):
    """Parse spike RBD mutations from the definition dict for one Pango lineage.

    Parameters
    ----------
    lineage_d : dict
        Lineage entry with "aaSubstitutions" and "aaDeletions" lists of strings
        like "S:N501Y" or "S:Y144-" (gene, colon, wildtype, site, mutant).
    coords : sequence of two int, optional
        Inclusive (start, end) RBD sites. Defaults to the global `rbd_coords`.
    ref : dict, optional
        Maps each RBD site to its reference amino acid. Defaults to the global
        `ref_rbd`.

    Returns
    -------
    list of tuple
        ``(wildtype, site, mutant)`` for each spike mutation inside the RBD,
        substitutions first and then deletions, in input order.

    Raises
    ------
    ValueError
        If a mutation's wildtype amino acid disagrees with the reference.
    """
    if coords is None:
        coords = rbd_coords
    if ref is None:
        ref = ref_rbd
    start, end = coords

    rbd_muts = []
    for mut in lineage_d["aaSubstitutions"] + lineage_d["aaDeletions"]:
        if not mut.startswith("S:"):
            continue  # only spike mutations are relevant
        spike_mut = mut.split(":")[1]
        wt = spike_mut[0]
        r = int(spike_mut[1:-1])
        m = spike_mut[-1]
        if start <= r <= end:
            # explicit check (not `assert`) so it still runs under `python -O`
            if ref[r] != wt:
                raise ValueError(
                    f"{mut} has wildtype {wt} but the reference has {ref[r]} at site {r}"
                )
            rbd_muts.append((wt, r, m))
    return rbd_muts
            
# Map each Pango lineage name to its list of (wildtype, site, mutant) RBD mutations.
pango_rbd_muts = {
    lineage: parse_spike_rbd_muts(lineage_def)
    for lineage, lineage_def in pango_lineages.items()
}

Reading Pango clade definitions from https://raw.githubusercontent.com/corneliusroemer/pango-sequences/main/data/pango-consensus-sequences_summary.json
Read definitions for 3916 lineages


## Get and import the escape calculator

In [5]:
# get and import the module
_ = urllib.request.urlretrieve(
    config["escape_calculator_module_url"],
    "escapecalculator.py",
)

import escapecalculator

## Now make designs for each parent

In [20]:
repetition_downweight = config["repetition_downweight"]
allow_reversions_to_ref = config["allow_reversions_to_ref"]
categories = config["categories"]


def normalize_to_parent(dms, parent_aas, value_cols):
    """Re-express DMS values relative to the parent amino acid at each site.

    For each column in `value_cols`, subtract the value measured for the parent
    amino acid at the same site, so the parent amino acid itself scores 0.
    Sites that lack a measurement for the parent amino acid are dropped.

    Parameters
    ----------
    dms : pandas.DataFrame
        Columns "site", "amino_acid", and every entry of `value_cols`.
    parent_aas : pandas.DataFrame
        One row per site, with columns "site" and "parent_amino_acid".
    value_cols : list of str
        Columns of `dms` to normalize.

    Returns
    -------
    pandas.DataFrame
        Same columns as `dms`, restricted to sites with a parent measurement,
        with `value_cols` shifted so the parent amino acid is 0.
    """
    parent_cols = {col: f"parent_{col}" for col in value_cols}
    parent_values = (
        dms
        .merge(parent_aas[["site", "parent_amino_acid"]], on="site", validate="many_to_one")
        .query("parent_amino_acid == amino_acid")
        [["site", *value_cols]]
        .rename(columns=parent_cols)
    )
    # the single parent row per site fans out to every measured amino acid there
    normalized = parent_values.merge(dms, on="site", validate="one_to_many")
    for col, parent_col in parent_cols.items():
        normalized[col] = normalized[col] - normalized[parent_col]
    return normalized.drop(columns=list(parent_cols.values()))


for parent, parent_d in config["parent_specs"].items():
    nmutants = parent_d["nmutants"]
    print(f"\nMaking {nmutants=} mutant designs for {parent=} in each category")
    with open(parent_d["specs"]) as f:
        parent_config = yaml.YAML().load(f)

    # reference and parent amino acid at each RBD site; sites without a parent
    # RBD mutation inherit the reference amino acid
    parent_aas = (
        pd.Series(ref_rbd).rename_axis("site").rename("ref_amino_acid").reset_index()
        .merge(
            pd.DataFrame(
                pango_rbd_muts[parent],
                columns=["ref_amino_acid", "site", "parent_amino_acid"],
            ),
            how="outer",
            on=["ref_amino_acid", "site"],
        )
        .assign(
            parent_amino_acid=lambda x: x["parent_amino_acid"].where(
                x["parent_amino_acid"].notnull(), x["ref_amino_acid"])
        )
    )
    assert len(parent_aas) == parent_aas["site"].nunique() == len(ref_rbd)

    # TODO: the escape calculator is not yet used in these designs
    # escape_calc = escapecalculator.EscapeCalculator(virus=parent_config["escapecalculator"]["virus"])

    # RBD sites that differ between the parent and the virus used to initialize
    # the escape calculator (sorted so the ordering is reproducible)
    parent_escape_calc_virus_diff_sites = sorted({
        site
        for (_wt, site, _mut) in set(pango_rbd_muts[parent]).symmetric_difference(
            pango_rbd_muts[parent_config["escapecalculator"]["virus"]]
        )
    })

    # RBD deep mutational scanning for the configured target, clipped from above
    rbd_dms_raw = (
        pd.read_csv(parent_config["rbd_dms"]["data"])
        .query("target == @parent_config['rbd_dms']['target']")
        .rename(columns={"position": "site", "mutant": "amino_acid"})
        [["site", "amino_acid", "delta_bind", "delta_expr"]]
        .assign(
            delta_bind=lambda x: x["delta_bind"].clip(upper=parent_config["rbd_dms"]["clip"]),
            delta_expr=lambda x: x["delta_expr"].clip(upper=parent_config["rbd_dms"]["clip"]),
        )
    )
    rbd_dms = (
        normalize_to_parent(rbd_dms_raw, parent_aas, ["delta_bind", "delta_expr"])
        .rename(columns={"delta_bind": "rbd_delta_bind", "delta_expr": "rbd_delta_expr"})
    )
    assert len(rbd_dms)

    # all-clade fitness estimates for spike, clipped from above
    fitness = (
        pd.read_csv(parent_config["fitness_estimates"]["fitness"])
        .query("gene == 'S'")
        .query("expected_count >= @parent_config['fitness_estimates']['fitness_min_count']")
        .rename(columns={"aa_site": "site", "aa": "amino_acid"})
        .assign(fitness=lambda x: x["fitness"].clip(upper=parent_config["fitness_estimates"]["clip"]))
        [["site", "amino_acid", "fitness"]]
    )
    assert len(fitness)

    # fitness effects specific to the configured clade
    by_clade_fitness = (
        pd.read_csv(parent_config["fitness_estimates"]["by_clade"])
        .query("gene == 'S'")
        .query("clade == @parent_config['fitness_estimates']['clade']")
        .query(
            "(expected_count >= @parent_config['fitness_estimates']['clade_min_count'])"
            " or (actual_count >= @parent_config['fitness_estimates']['clade_min_count'])"
        )
        .rename(columns={"aa_site": "site", "mutant_aa": "amino_acid", "delta_fitness": "clade_fitness"})
        # clip `clade_fitness` itself so the clip survives the column selection below
        .assign(
            clade_fitness=lambda x: x["clade_fitness"].clip(
                upper=parent_config["fitness_estimates"]["clip"]
            )
        )
        [["site", "amino_acid", "clade_fitness"]]
    )
    assert len(by_clade_fitness)

    # full-spike deep mutational scanning, normalized to the parent amino acid
    spike_dms_raw = (
        pd.read_csv(parent_config["spike_dms"]["csv"])
        .rename(columns={"mutant": "amino_acid"})
        [["site", "amino_acid", "human sera escape", "spike mediated entry", "ACE2 binding"]]
    )
    spike_dms = (
        normalize_to_parent(
            spike_dms_raw,
            parent_aas,
            ["human sera escape", "spike mediated entry", "ACE2 binding"],
        )
        .rename(
            columns={
                "human sera escape": "spike_escape",
                "spike mediated entry": "spike_entry",
                "ACE2 binding": "spike_ACE2_binding",
            }
        )
        [["site", "amino_acid", "spike_escape", "spike_entry", "spike_ACE2_binding"]]
    )
    assert len(spike_dms)

    # now add each phenotype and its weight
    parent_phenotypes = (
        parent_aas
        # add fitness estimates, only keeping mutations with estimates
        .merge(fitness, on="site", validate="one_to_many", how="inner")
        .assign(fitness_weight=parent_config["fitness_estimates"]["weights"]["fitness"])
        # add clade fitness estimates, clipping lower at zero and also setting missing to zero
        .merge(by_clade_fitness, on=["site", "amino_acid"], validate="one_to_one", how="left")
        .assign(
            clade_fitness=lambda x: x["clade_fitness"].where(x["clade_fitness"] > 0, 0),
            clade_fitness_weight=parent_config["fitness_estimates"]["weights"]["by_clade_effect"],
        )
        # add RBD DMS, only keeping mutations with measurements
        .merge(rbd_dms, on=["site", "amino_acid"], validate="one_to_one", how="inner")
        .assign(
            rbd_delta_bind_weight=parent_config["rbd_dms"]["weights"]["delta_bind"],
            rbd_delta_expr_weight=parent_config["rbd_dms"]["weights"]["delta_expr"],
        )
        # add spike DMS, only keeping mutations with measurements
        .merge(spike_dms, on=["site", "amino_acid"], validate="one_to_one", how="inner")
        .assign(
            spike_escape_weight=parent_config["spike_dms"]["weights"]["sera escape"],
            spike_entry_weight=parent_config["spike_dms"]["weights"]["spike mediated entry"],
            spike_ACE2_binding_weight=parent_config["spike_dms"]["weights"]["ACE2 binding"],
        )
    )
    print(f"Have phenotypes for {len(parent_phenotypes)} mutations at {parent_phenotypes['site'].nunique()} sites")




Making nmutants=5 mutant designs for parent='JN.1' in each category
Have phenotypes for 977 mutations at 192 sites


In [16]:
# Display the per-parent specs loaded inside the loop above; after the loop this
# holds the specs for the last parent in config["parent_specs"].
parent_config

{'escapecalculator': {'virus': 'XBB', 'weight': 1}, 'rbd_dms': {'data': 'https://media.githubusercontent.com/media/tstarrlab/SARS-CoV-2-RBD_DMS_Omicron-XBB-BQ/main/results/final_variant_scores/final_variant_scores.csv', 'target': 'Omicron_XBB15', 'clip': 2, 'weights': {'delta_bind': 1, 'delta_expr': 1}}, 'fitness_estimates': {'fitness': 'https://raw.githubusercontent.com/jbloomlab/SARS2-mut-fitness/main/results_gisaid_2024-04-24/aa_fitness/aa_fitness.csv', 'by_clade': 'https://media.githubusercontent.com/media/jbloomlab/SARS2-mut-fitness/main/results_gisaid_2024-04-24/aa_fitness/aamut_fitness_by_clade.csv', 'clade': '23A', 'clade_min_count': 5, 'fitness_min_count': 10, 'clip': 4, 'weights': {'fitness': 1, 'by_clade_effect': 1}}, 'spike_dms': {'csv': 'https://raw.githubusercontent.com/dms-vep/SARS-CoV-2_XBB.1.5_spike_DMS/main/results/summaries/summary.csv', 'weights': {'sera escape': 1, 'spike mediated entry': 1, 'ACE2 binding': 1}}}

In [None]:
# Greedy design of vaccine "cocktails": for each cocktail in
# config["mutations_per_design"], start from the parent sequences and append
# designed variants (parent + n_muts greedily chosen mutations) until the
# cocktail has `components_per_cocktail` members; then write a FASTA file.
#
# NOTE(review): this cell cannot run on a fresh kernel as the notebook stands.
#   - `parents`, `parent_wts`, `escape_calc` and `ref_to_parent_site` are not
#     defined in any earlier cell visible here (the `escape_calc` line in the
#     per-parent loop is commented out).
#   - It reads `delta_expr`/`delta_bind` from `rbd_dms` and `fitness_score` from
#     `fitness`, but the loop above builds those frames with `rbd_delta_*` and
#     `fitness` columns instead.
#   - It reads `config["escapecalculator"]` and `config["rbd_dms"]`, while the
#     loop above reads those settings from each parent's `parent_config`;
#     confirm config.yaml still has these top-level keys.
# Looks like a leftover from an earlier design workflow — TODO port or delete.
components_per_cocktail = config["components_per_cocktail"]
mutations_per_design = config["mutations_per_design"]
parent_existing_mut_sites = config["escapecalculator"]["parent_existing_mut_sites"]
rbd_dms_clip = config["rbd_dms"]["clip"]
repetition_downweight = config["repetition_downweight"]

if len(parents) >= components_per_cocktail:
    raise ValueError("nothing to design if as many parents and components")
print(f"All cocktails have {components_per_cocktail} components")

for cocktail, n_muts in mutations_per_design.items():
    print(f"\nDesigning {cocktail} with {n_muts} mutations per design")
    
    # every cocktail starts with the unmodified parents
    # (presumably `parents` maps name -> Bio.SeqRecord; TODO confirm)
    cocktail_seqs = [(name, str(seq.seq)) for name, seq in parents.items()]
    cocktail_mutation_counts = collections.defaultdict(int)  # counts of mutations in cocktail so far
    
    # cycle through parents round-robin until the cocktail is full
    i = 0
    while len(cocktail_seqs) < components_per_cocktail:
        design_parent_name, design_parent_seq = list(parents.items())[i % len(parents)]
        design_parent_seq = str(design_parent_seq.seq)
        i += 1
        
        # sites already mutated in this parent; chosen sites are appended below
        mut_sites = list(parent_existing_mut_sites[design_parent_name])
        
        # normalize RBD DMS so parental value is 0
        parent_normalized_rbd_dms = (
            rbd_dms
            .assign(parent_aa=lambda x: x["site"].map(parent_wts[design_parent_name]))
            .query("amino_acid == parent_aa")
            .rename(columns={"delta_expr": "parent_expr", "delta_bind": "parent_bind"})
            .drop(columns="amino_acid")
            .merge(rbd_dms, on="site", validate="one_to_many")
            .assign(
                delta_expr=lambda x: x["delta_expr"] - x["parent_expr"],
                delta_bind=lambda x: x["delta_bind"] - x["parent_bind"],
            )
            [["site", "amino_acid", "delta_expr", "delta_bind"]]
        )
        
        design_mutations = []
        for _ in range(n_muts):
            # get top scoring mutation
            # score = retained escape (presumably after escaping antibodies at
            # `mut_sites`; depends on escape_per_site) * exp(delta_expr) and
            # exp(delta_bind) each capped at rbd_dms_clip * fitness_score *
            # repetition_downweight ** (times already used in this cocktail).
            # Series.map honors defaultdict.__missing__, so unused mutations
            # count as 0 and get weight 1.
            mut_scores = (
                escape_calc.escape_per_site(mut_sites)
                .merge(parent_normalized_rbd_dms, on="site")
                .merge(fitness, on=["site", "amino_acid"])
                .assign(
                    parent_aa=lambda x: x["site"].map(parent_wts[design_parent_name]),
                    mutation=lambda x: x["parent_aa"] + x["site"].astype(str) + x["amino_acid"],
                    repetition_weight=lambda x: repetition_downweight**x["mutation"].map(cocktail_mutation_counts),
                    score=lambda x: (
                        x["retained_escape"]
                        * numpy.clip(numpy.exp(x["delta_expr"]), a_min=None, a_max=rbd_dms_clip)
                        * numpy.clip(numpy.exp(x["delta_bind"]), a_min=None, a_max=rbd_dms_clip)
                        * x["fitness_score"]
                        * x["repetition_weight"]
                    ),
                )
                .sort_values("score", ascending=False)
                .query("parent_aa != amino_acid")
            )
            # NOTE(review): .iloc[0] raises IndexError if no candidate remains.
            mutation = mut_scores["mutation"].iloc[0]
            site = mut_scores["site"].iloc[0]
            design_mutations.append(mutation)
            cocktail_mutation_counts[mutation] += 1
            # NOTE(review): nothing here excludes a site already chosen for this
            # design; unless escape_per_site drives such sites' scores to zero, a
            # second pick at the same site would fail the assert below, because
            # that position no longer holds the parent amino acid.
            mut_sites.append(site)
        
        # apply mutations; mutation names use reference numbering, so convert to
        # parent-sequence positions (1-based) via ref_to_parent_site
        design_seq = list(design_parent_seq)
        for mut in design_mutations:
            r_parent = ref_to_parent_site[design_parent_name][int(mut[1: -1])]
            assert design_seq[r_parent - 1] == mut[0]
            design_seq[r_parent - 1] = mut[-1]
        cocktail_seqs.append(
            (
                design_parent_name + "_" + "_".join(design_mutations),
                "".join(design_seq),
            )
        )
        
    print("Components of cocktail are:\n " + "\n ".join(n for n, _ in cocktail_seqs))
    
    # NOTE(review): open() does not create directories; vax_designs/ must exist.
    cocktail_vax_file = f"vax_designs/{cocktail}-vax.fa"
    print(f"Writing {cocktail} of {len(cocktail_seqs)} to {cocktail_vax_file}")
    with open(cocktail_vax_file, "w") as f:
        for name, seq in cocktail_seqs:
            f.write(f">{name}\n{seq}\n")