# Design SARS-CoV-2 spike RBD mutants
This notebook chooses and designs SARS-CoV-2 spikes with mutant receptor-binding domains (RBDs).

## Setup
Import Python modules:

In [1]:
import collections
import importlib
import json
import os
import subprocess
import tempfile
import urllib
import urllib.request

import altair as alt

import Bio.Entrez
import Bio.SeqIO

import numpy

import pandas as pd

import ruamel.yaml as yaml

Read configuration:

In [13]:
# Notebook-wide settings (accessions, URLs, thresholds) all live in config.yaml.
config_loader = yaml.YAML()
with open("config.yaml") as config_file:
    config = config_loader.load(config_file)

## Get the reference RBD

In [3]:
# Placeholder contact address for NCBI Entrez; replace with a real email for heavy use.
Bio.Entrez.email = "example@example.com"

ref_spike_accession = config["ref_spike"]
print(f"Getting reference spike from accession {ref_spike_accession}")
efetch_kwargs = dict(db="protein", id=ref_spike_accession, rettype="gb", retmode="text")
with Bio.Entrez.efetch(**efetch_kwargs) as f:
    ref_spike = Bio.SeqIO.read(f, "gb")
print(f"Got spike of length {len(ref_spike)}")

# Reference amino acid at each RBD site, keyed by 1-based site number over the
# inclusive range given by `rbd_coords`.
rbd_coords = config["rbd_coords"]
rbd_first, rbd_last = rbd_coords[0], rbd_coords[1]
ref_rbd = {site: ref_spike[site - 1] for site in range(rbd_first, rbd_last + 1)}

Getting reference spike from accession YP_009724390
Got spike of length 1273


## Get RBD mutations in each Pango lineage

In [4]:
pango_json = config["pango_json"]
print(f"Reading Pango clade definitions from {pango_json}")
# Download the whole summary JSON and decode it in one step.
with urllib.request.urlopen(pango_json) as response:
    pango_lineages = json.loads(response.read())
print(f"Read definitions for {len(pango_lineages)} lineages")


def parse_spike_rbd_muts(lineage_d, rbd_range=None, ref_aas=None):
    """Parse spike RBD mutations from dict for a lineage.

    Parameters
    ----------
    lineage_d : dict
        Entry for one lineage. Its ``aaSubstitutions`` and ``aaDeletions`` lists
        hold strings of the form ``"<gene>:<wildtype><site><mutant>"`` such as
        ``"S:N501Y"`` (deletions presumably end in ``"-"``). A missing or null
        list is treated as empty.
    rbd_range : sequence of two ints, optional
        Inclusive first and last RBD site. Defaults to the notebook-level
        ``rbd_coords``.
    ref_aas : dict, optional
        Maps RBD site to reference amino acid. Defaults to the notebook-level
        ``ref_rbd``.

    Returns
    -------
    list of tuple
        ``(wildtype, site, mutant)`` for every spike (gene ``S``) mutation whose
        site lies in the RBD, in listed order: substitutions first, then
        deletions.

    Raises
    ------
    AssertionError
        If a mutation's wildtype does not match the reference amino acid.
    """
    # look the defaults up at call time so later changes to the globals are seen
    if rbd_range is None:
        rbd_range = rbd_coords
    if ref_aas is None:
        ref_aas = ref_rbd
    rbd_start, rbd_end = rbd_range[0], rbd_range[1]

    all_muts = (lineage_d.get("aaSubstitutions") or []) + (lineage_d.get("aaDeletions") or [])
    rbd_muts = []
    for mut in all_muts:
        gene, _, aa_mut = mut.partition(":")
        if gene != "S":
            continue
        wt, r, m = aa_mut[0], int(aa_mut[1:-1]), aa_mut[-1]
        if rbd_start <= r <= rbd_end:
            assert wt == ref_aas[r], f"{mut}: wildtype {wt} != reference {ref_aas[r]} at site {r}"
            rbd_muts.append((wt, r, m))
    return rbd_muts
            
pango_rbd_muts = dict(
    (pango, parse_spike_rbd_muts(lineage_def))
    for pango, lineage_def in pango_lineages.items()
)

Reading Pango clade definitions from https://raw.githubusercontent.com/corneliusroemer/pango-sequences/main/data/pango-consensus-sequences_summary.json
Read definitions for 3916 lineages


## Get and import the escape calculator

In [14]:
# get and import the module
_ = urllib.request.urlretrieve(
    config["escape_calculator_module_url"],
    "escapecalculator.py",
)

import escapecalculator

# `import` is a no-op when the module is already loaded in this kernel (e.g. when
# re-running this cell after the URL changed), so reload to use the file just downloaded.
escapecalculator = importlib.reload(escapecalculator)

## Now make designs for each parent

In [10]:
repetition_downweight = config["repetition_downweight"]
allow_reversions_to_ref = config["allow_reversions_to_ref"]
categories = config["categories"]

# Each parent has its own YAML spec file; load and show it. Note that the names
# `parent` and `nmutants` appear verbatim in the printed message via `{x=}`.
for parent in config["parent_specs"]:
    parent_d = config["parent_specs"][parent]
    nmutants = parent_d["nmutants"]
    print(f"Making {nmutants=} mutant designs for {parent=} in each category")
    parent_specs_file = parent_d["specs"]
    with open(parent_specs_file) as specs_handle:
        parent_config = yaml.YAML().load(specs_handle)

    print(parent_config)

Making nmutants=5 mutant designs for parent='XBB.1.5' in each category
{'escapecalculator': {'module_url': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d476dee/escapecalculator.py', 'kwargs': {'antibody_ic50s': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d476dee/results/antibody_IC50s.csv', 'antibody_binding': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d476dee/results/antibody_binding.csv', 'antibody_sources': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d476dee/results/antibody_sources.csv', 'antibody_reweighting': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d476dee/results/antibody_reweighting.csv', 'config': 'https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/6d56b7d0b353ddf7b107eb63b50124826d4

## Now design mutated cocktails
First, set up the escape calculator:

In [None]:
# get and import the module
# NOTE(review): this reads `config["escapecalculator"]`, while the earlier escape
# calculator cell reads `config["escape_calculator_module_url"]`; the per-parent spec
# printed above has an `escapecalculator` section, so confirm the top-level config does too.
_ = urllib.request.urlretrieve(
    config["escapecalculator"]["module_url"],
    "escapecalculator.py",
)
import escapecalculator

# An earlier cell may already have imported a (possibly different) version of this
# module into the kernel, in which case `import` silently keeps the old code; reload
# so that the file just downloaded is the one actually used.
escapecalculator = importlib.reload(escapecalculator)

escape_calc = escapecalculator.EscapeCalculator(**config["escapecalculator"]["kwargs"])

print(f"Using the following virus: {escape_calc.virus=}")

Now get the RBD deep mutational scanning data:

In [None]:
rbd_dms_target = config["rbd_dms"]["target"]

# Keep only measurements for the configured target, renamed to the `site` /
# `amino_acid` column names used by the fitness table below.
rbd_dms = (
    pd.read_csv(config["rbd_dms"]["data"])
    .query("target == @rbd_dms_target")
    .rename(columns={"position": "site", "mutant": "amino_acid"})
    [["site", "amino_acid", "delta_bind", "delta_expr"]]
)

# A misspelled target would otherwise give an empty table silently, and the design
# step would only fail much later with an unhelpful IndexError.
if rbd_dms.empty:
    raise ValueError(f"no RBD DMS rows for {rbd_dms_target=} in {config['rbd_dms']['data']}")

rbd_dms

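Now get the fitness estimates.
Note that these implicitly include the nucleotide mutation spectrum, as there won't be estimates for inaccessible mutations.
We compute fitness as the average of the all-clade fitness and the clade-specific fitness (for the clade set in the config) for mutations whose clade-specific estimate is favorable (positive) and based on sufficient counts; otherwise we use just the all-clade fitness. The fitness score is the exponential of this value, clipped at a ceiling: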

In [None]:
# thresholds and file locations from the config, hoisted so the query strings
# below only reference plain local names
fitness_config = config["fitness_estimates"]
drop_min_expected_count = fitness_config["drop_min_expected_count"]
fitness_clade_name = fitness_config["clade"]
clade_min_count = fitness_config["clade_min_count"]
fitness_clip = fitness_config["clip"]

# first get the all-clade spike fitness estimates, dropping poorly sampled mutations
fitness_all = (
    pd.read_csv(fitness_config["fitness"])
    .query("gene == 'S'")
    .query("expected_count >= @drop_min_expected_count")
    .rename(columns={"aa_site": "site", "aa": "amino_acid"})
    [["site", "amino_acid", "fitness"]]
)

# now get clade specific estimates for favorable mutations w sufficient counts
fitness_clade = (
    pd.read_csv(fitness_config["by_clade"])
    .query("clade == @fitness_clade_name")
    .query("gene == 'S'")
    .rename(columns={"aa_site": "site", "mutant_aa": "amino_acid"})
    .query("delta_fitness > 0")
    .assign(max_count=lambda x: numpy.maximum(x["expected_count"], x["actual_count"]))
    .query("max_count >= @clade_min_count")
    [["site", "amino_acid", "delta_fitness"]]
)

# average in favorable clade-specific estimates; mutations without one keep the
# all-clade value. `validate` guards against duplicated clade rows silently
# duplicating fitness rows. The score is exp(average), capped at `fitness_clip`.
fitness = (
    fitness_all
    .merge(fitness_clade, on=["site", "amino_acid"], how="left", validate="many_to_one")
    .assign(
        avg=lambda x: numpy.where(
            x["delta_fitness"].isnull(),
            x["fitness"],
            (x["fitness"] + x["delta_fitness"]) / 2,
        ),
        fitness_score=lambda x: numpy.clip(
            numpy.exp(x["avg"]), a_max=fitness_clip, a_min=None,
        )
    )
    [["site", "amino_acid", "fitness_score"]]
)

Now iterate through the cocktails and design the mutants according to the following criteria:

 - We include all the parents
 - For the remaining components, we:
   1. Pick a parent
   2. Choose the specified number of mutations, at each step choosing the mutation that maximizes the product of:
     + the escape according to the [RBD escape calculator](https://jbloomlab.github.io/SARS2-RBD-escape-calc/), subject to the condition that no mutation is repeated in the cocktail
     + the exponential of the change in ACE2 affinity (relative to the parent) as measured in [RBD deep mutational scanning](https://journals.plos.org/plospathogens/article?id=10.1371/journal.ppat.1010951), with a ceiling applied
     + the exponential of the change in expression (relative to the parent) as measured in [RBD deep mutational scanning](https://journals.plos.org/plospathogens/article?id=10.1371/journal.ppat.1010951), with a ceiling applied
     + a fitness score calculated as the exponential of the fitness estimate from [Bloom and Neher](https://www.biorxiv.org/content/10.1101/2023.01.30.526314v1), with clipping
     + ensuring only one mutation is chosen at each site
   3. Pick the next parent (repeating a previously used one if needed), and pick a new set of mutations. Mutations are downweighted each time they are included in a cocktail component, to encourage new combinations of mutations in subsequent variants.
   

In [None]:
components_per_cocktail = config["components_per_cocktail"]
mutations_per_design = config["mutations_per_design"]
parent_existing_mut_sites = config["escapecalculator"]["parent_existing_mut_sites"]
rbd_dms_clip = config["rbd_dms"]["clip"]
repetition_downweight = config["repetition_downweight"]

# NOTE(review): `parents`, `parent_wts` and `ref_to_parent_site` are not defined in any
# cell above, so an earlier cell must build them before this one can run. Judging by
# their use below: `parents` presumably maps parent name -> SeqRecord, `parent_wts`
# maps parent name -> {site: parental amino acid}, and `ref_to_parent_site` maps parent
# name -> {site used in mutation names: 1-based position in the parent sequence}.

if len(parents) >= components_per_cocktail:
    raise ValueError("nothing to design if as many parents and components")
print(f"All cocktails have {components_per_cocktail} components")


def _normalize_rbd_dms_to_parent(parent_name):
    """Re-center RBD DMS effects so the parent's own amino acid scores 0 at every site.

    Returns a data frame with columns ``site``, ``amino_acid``, ``delta_expr`` and
    ``delta_bind`` holding every DMS amino acid at each site; sites where the parental
    amino acid has no DMS measurement are dropped by the inner merge.
    """
    return (
        rbd_dms
        .assign(parent_aa=lambda x: x["site"].map(parent_wts[parent_name]))
        .query("amino_acid == parent_aa")
        .rename(columns={"delta_expr": "parent_expr", "delta_bind": "parent_bind"})
        .drop(columns="amino_acid")
        .merge(rbd_dms, on="site", validate="one_to_many")
        .assign(
            delta_expr=lambda x: x["delta_expr"] - x["parent_expr"],
            delta_bind=lambda x: x["delta_bind"] - x["parent_bind"],
        )
        [["site", "amino_acid", "delta_expr", "delta_bind"]]
    )


for cocktail, n_muts in mutations_per_design.items():
    print(f"\nDesigning {cocktail} with {n_muts} mutations per design")

    # every cocktail starts with the unmutated parents themselves
    cocktail_seqs = [(name, str(seq.seq)) for name, seq in parents.items()]
    cocktail_mutation_counts = collections.defaultdict(int)  # counts of mutations in cocktail so far

    i = 0
    while len(cocktail_seqs) < components_per_cocktail:
        # cycle through the parents in order, reusing them if more designs are needed
        design_parent_name, design_parent_seq = list(parents.items())[i % len(parents)]
        design_parent_seq = str(design_parent_seq.seq)
        i += 1

        # copy, so appending newly chosen sites below leaves the config list untouched
        mut_sites = list(parent_existing_mut_sites[design_parent_name])

        # normalize RBD DMS so parental value is 0
        parent_normalized_rbd_dms = _normalize_rbd_dms_to_parent(design_parent_name)

        design_mutations = []
        for _ in range(n_muts):
            # score all candidate mutations and take the top one
            mut_scores = (
                escape_calc.escape_per_site(mut_sites)
                .merge(parent_normalized_rbd_dms, on="site")
                .merge(fitness, on=["site", "amino_acid"])
                .assign(
                    parent_aa=lambda x: x["site"].map(parent_wts[design_parent_name]),
                    mutation=lambda x: x["parent_aa"] + x["site"].astype(str) + x["amino_acid"],
                    # defaultdict makes unseen mutations map to 0, i.e. weight 1
                    repetition_weight=lambda x: repetition_downweight**x["mutation"].map(cocktail_mutation_counts),
                    score=lambda x: (
                        x["retained_escape"]
                        * numpy.clip(numpy.exp(x["delta_expr"]), a_min=None, a_max=rbd_dms_clip)
                        * numpy.clip(numpy.exp(x["delta_bind"]), a_min=None, a_max=rbd_dms_clip)
                        * x["fitness_score"]
                        * x["repetition_weight"]
                    ),
                )
                .sort_values("score", ascending=False)
                .query("parent_aa != amino_acid")
            )
            if mut_scores.empty:
                raise ValueError(
                    f"no candidate mutations left for {design_parent_name} in {cocktail}"
                )
            mutation = mut_scores["mutation"].iloc[0]
            site = mut_scores["site"].iloc[0]
            design_mutations.append(mutation)
            cocktail_mutation_counts[mutation] += 1
            # pass this site to the next escape calculation as already mutated
            mut_sites.append(site)

        # apply the chosen mutations; their site numbers are converted to positions in
        # the parent sequence via `ref_to_parent_site`
        design_seq = list(design_parent_seq)
        for mut in design_mutations:
            r_parent = ref_to_parent_site[design_parent_name][int(mut[1: -1])]
            assert design_seq[r_parent - 1] == mut[0]
            design_seq[r_parent - 1] = mut[-1]
        cocktail_seqs.append(
            (
                design_parent_name + "_" + "_".join(design_mutations),
                "".join(design_seq),
            )
        )

    print("Components of cocktail are:\n " + "\n ".join(n for n, _ in cocktail_seqs))

    cocktail_vax_file = f"vax_designs/{cocktail}-vax.fa"
    print(f"Writing {cocktail} of {len(cocktail_seqs)} to {cocktail_vax_file}")
    # open() does not create missing directories
    os.makedirs(os.path.dirname(cocktail_vax_file), exist_ok=True)
    with open(cocktail_vax_file, "w") as f:
        for name, seq in cocktail_seqs:
            f.write(f">{name}\n{seq}\n")