# Design cocktail SARS-CoV-2 vaccines
This notebook chooses/designs spikes for a SARS-CoV-2 cocktail vaccine.
The design is done by Jesse Bloom, for a project led by Drew Weissman.

We design the following vaccines:

 - **parent-vax**: unmutated spike(s) of leading SARS-CoV-2 variants at time of vaccine design.
 
 - **cocktail-vax**: spike(s) in *parent-vax* plus additional designed spikes with mutations predicted to be likely to occur in future human SARS-CoV-2.
 
 - **conservative-cocktail-vax**: like *cocktail-vax*, but fewer mutations per designed spike.
 
 - **aggressive-cocktail-vax**: like *cocktail-vax*, but more mutations per designed spike.

## Setup
Import Python modules:

In [1]:
import os
import subprocess
import tempfile
import urllib.request

import altair as alt

import Bio.Entrez
import Bio.SeqIO

import numpy

import pandas as pd

import yaml

Read configuration:

In [2]:
with open("config.yaml") as f:
    config = yaml.safe_load(f)
    
print("Here is the configuration:\n")
print(yaml.dump(config))

Here is the configuration:

components_per_cocktail: 5
escapecalculator:
  kwargs:
    antibody_binding: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/results/antibody_binding.csv
    antibody_ic50s: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/results/antibody_IC50s.csv
    antibody_reweighting: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/results/antibody_reweighting.csv
    antibody_sources: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/results/antibody_sources.csv
    config: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/config.yaml
  module_url: https://raw.githubusercontent.com/jbloomlab/SARS2-RBD-escape-calc/683abdd4c3277a3cbc20cddd7ab98ff844f9ef80/escapecalculator.py
  parent_

## Design *parent-vax* 
Get the sequences:

In [3]:
Bio.Entrez.email = "example@example.com"

print(f"Getting reference spike from accession {config['ref_spike']}")
with Bio.Entrez.efetch(id=config["ref_spike"], rettype="gb", retmode="text", db="protein") as f:
    ref_spike = Bio.SeqIO.read(f, "gb")
print(f"Got spike of length {len(ref_spike)}")

parents = {}
for name, acc in config["parent_spikes"].items():
    print(f"\nGetting parent spike {name} from {acc}")
    with Bio.Entrez.efetch(id=acc, rettype="gb", retmode="text", db="protein") as f:
        parents[name] = Bio.SeqIO.read(f, "gb")
    print(f"Got spike of length {len(parents[name])}")

Getting reference spike from accession YP_009724390
Got spike of length 1273

Getting parent spike XBB.1.16 from WCM02109
Got spike of length 1269


For each parental spike, list its mutations relative to the Wuhan-Hu-1 reference. Also build two mappings keyed by reference site number: the parent's amino acid at that site, and the corresponding site number in the parent's own sequence:

In [4]:
parent_wts = {}

ref_to_parent_site = {}

for name, seq in parents.items():
    with tempfile.TemporaryDirectory() as tmpdir:
        f_in = os.path.join(tmpdir, "in.fa")
        f_out = os.path.join(tmpdir, "out.fa")
        Bio.SeqIO.write([ref_spike, seq], f_in, "fasta")
        res = subprocess.run(
            ["muscle", "-align", f_in, "-output", f_out],
            check=True,
            capture_output=True,
        )
        assert res.returncode == 0
        aligned_ref, aligned_seq = list(Bio.SeqIO.parse(f_out, "fasta"))
    muts = []
    r_ref = r_seq = 1
    parent_wts[name] = {}
    ref_to_parent_site[name] = {}
    for a_ref, a_seq in zip(aligned_ref.seq, aligned_seq.seq):
        parent_wts[name][r_ref] = a_seq
        ref_to_parent_site[name][r_ref] = r_seq
        if a_ref != a_seq:
            muts.append(f"{a_ref}{r_ref}{a_seq}")
        if a_ref != "-":
            r_ref += 1
        if a_seq != "-":
            r_seq += 1
    print(f"\n{name} has the following {len(muts)} mutations:\n  " + "\n  ".join(muts))


XBB.1.16 has the following 43 mutations:
  T19I
  L24-
  P25-
  P26-
  A27S
  V83A
  G142D
  Y145-
  H146Q
  E180V
  Q183E
  V213E
  G252V
  G339H
  R346T
  L368I
  S371F
  S373P
  S375F
  T376A
  D405N
  R408S
  K417N
  N440K
  V445P
  G446S
  N460K
  S477N
  T478R
  E484A
  F486P
  F490S
  Q498R
  N501Y
  Y505H
  D614G
  H655Y
  N679K
  P681H
  N764K
  D796Y
  Q954H
  N969K
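
To make the site bookkeeping in the loop above concrete, here is a minimal sketch of the same alignment walk on two made-up aligned strings. The toy sequences and the helper name `walk_alignment` are illustrative assumptions, not part of the notebook. It shows that a reference site that falls in a deletion maps to the next residue present in the parent:

```python
def walk_alignment(aligned_ref, aligned_seq):
    """Return (mutations, ref_to_seq_site) for two equal-length aligned strings."""
    if len(aligned_ref) != len(aligned_seq):
        raise ValueError("aligned sequences must have equal length")
    muts = []
    ref_to_seq_site = {}
    r_ref = r_seq = 1  # 1-based site counters for reference and parent
    for a_ref, a_seq in zip(aligned_ref, aligned_seq):
        ref_to_seq_site[r_ref] = r_seq
        if a_ref != a_seq:
            muts.append(f"{a_ref}{r_ref}{a_seq}")
        # only advance a counter when that sequence has a residue (not a gap)
        if a_ref != "-":
            r_ref += 1
        if a_seq != "-":
            r_seq += 1
    return muts, ref_to_seq_site


# toy alignment (hypothetical): the parent lacks reference sites 3 and 4
toy_muts, toy_map = walk_alignment("MKTLPA", "MR--PS")
print(toy_muts)  # ['K2R', 'T3-', 'L4-', 'A6S']
print(toy_map)   # {1: 1, 2: 2, 3: 3, 4: 3, 5: 3, 6: 4}
```

In this toy case the mutation `A6S` is named in reference numbering, but the residue to edit is at position 4 of the parent sequence, which is why the site mapping is needed when mutations are later applied.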


Write the *parent-vax* to a file:

In [5]:
os.makedirs("vax_designs", exist_ok=True)
parent_vax_file = "vax_designs/parent-vax.fa"
print(f"Writing parent-vax of {len(parents)} to {parent_vax_file}")
with open(parent_vax_file, "w") as f:
    for name, seq in parents.items():
        f.write(f">{name}\n{str(seq.seq)}\n")

Writing parent-vax of 1 to vax_designs/parent-vax.fa


## Now design mutated cocktails
First, set up escape calculator:

In [6]:
# get and import the module
_ = urllib.request.urlretrieve(
    config["escapecalculator"]["module_url"],
    "escapecalculator.py",
)
import escapecalculator

escape_calc = escapecalculator.EscapeCalculator(**config["escapecalculator"]["kwargs"])

print(f"Using the following virus: {escape_calc.virus=}")

Using the following virus: escape_calc.virus='XBB'


Now get the RBD deep mutational scanning data:

In [7]:
rbd_dms_target = config["rbd_dms"]["target"]

rbd_dms = (
    pd.read_csv(config["rbd_dms"]["data"])
    .query("target == @rbd_dms_target")
    .rename(columns={"position": "site", "mutant": "amino_acid"})
    [["site", "amino_acid", "delta_bind", "delta_expr"]]
)
    
rbd_dms

Unnamed: 0,site,amino_acid,delta_bind,delta_expr
4020,331,A,-0.08339,-0.62526
4021,331,C,-0.61624,-1.18984
4022,331,D,-0.14670,-0.53294
4023,331,E,-0.14146,-0.37718
4024,331,F,-0.53604,-1.12351
...,...,...,...,...
8035,531,S,0.04801,0.01163
8036,531,T,0.00000,0.00000
8037,531,V,0.04638,0.03490
8038,531,W,-0.01288,-0.00798


Now get the fitness estimates.
Note that these implicitly account for the nucleotide mutation spectrum, because no estimates exist for inaccessible mutations:

In [8]:
# first get the fitness estimates themselves
fitness = (
    pd.read_csv(config["fitness_estimates"]["fitness"])
    .query("gene == 'S'")
    .query("expected_count >= @config['fitness_estimates']['drop_min_expected_count']")
    .rename(columns={"aa_site": "site", "aa": "amino_acid"})
    [["site", "amino_acid", "fitness"]]
)

# now get clade-specific estimates for favorable mutations with sufficient counts
fitness_clade = (
    pd.read_csv(config["fitness_estimates"]["by_clade"])
    .query("clade == @config['fitness_estimates']['clade']")
    .query("gene == 'S'")
    .rename(columns={"aa_site": "site", "mutant_aa": "amino_acid"})
    .query("delta_fitness > 0")
    .assign(max_count=lambda x: numpy.maximum(x["expected_count"], x["actual_count"]))
    .query("max_count >= @config['fitness_estimates']['clade_min_count']")
    [["site", "amino_acid", "delta_fitness"]]
)

# average in favorable clade-specific estimates
fitness = (
    fitness
    .merge(fitness_clade, how="left")
    .assign(
        avg=lambda x: numpy.where(
            x["delta_fitness"].isnull(),
            x["fitness"],
            (x["fitness"] + x["delta_fitness"]) / 2,
        ),
        fitness_score=lambda x: numpy.clip(
            numpy.exp(x["avg"]), a_max=config["fitness_estimates"]["clip"], a_min=None,
        )
    )
    [["site", "amino_acid", "fitness_score"]]
)
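
The arithmetic in the last step can be illustrated for a single mutation in plain Python. This is only a sketch: the function name `toy_fitness_score` and the clip value of 2.0 are assumptions for illustration, since the actual clip is read from `config["fitness_estimates"]["clip"]`:

```python
import math


def toy_fitness_score(fitness, delta_fitness=None, clip=2.0):
    """Average in a clade-specific estimate if one exists, exponentiate, then cap."""
    if delta_fitness is None:
        avg = fitness  # no clade-specific estimate passed the filters
    else:
        avg = (fitness + delta_fitness) / 2
    return min(math.exp(avg), clip)


# hypothetical input values, chosen only to show the three cases
print(toy_fitness_score(-2.0))       # exp(-2.0): no clade-specific estimate
print(toy_fitness_score(-2.0, 1.0))  # exp(-0.5): averaged with clade estimate
print(toy_fitness_score(1.0, 3.0))   # exp(2.0) exceeds the cap, so 2.0
```

The cap keeps a single very favorable fitness estimate from dominating any product of scores that the fitness score later enters.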

Now iterate through the cocktails and design the mutants according to the following criteria:

 - We include all the parents
 - For the remaining components, we:
   1. Pick a parent
   2. Choose the specified number of mutations, at each step choosing the mutation that maximizes the product of:
     + the escape according to the [RBD escape calculator](https://jbloomlab.github.io/SARS2-RBD-escape-calc/), subject to the condition that no mutation is repeated in the cocktail
     + the exponential of the change in ACE2 affinity as measured in [RBD deep mutational scanning](https://journals.plos.org/plospathogens/article?id=10.1371/journal.ppat.1010951), with a ceiling applied
     + the exponential of the change in expression as measured in [RBD deep mutational scanning](https://journals.plos.org/plospathogens/article?id=10.1371/journal.ppat.1010951), with a ceiling applied
     + a fitness score calculated as the exponential of the fitness estimate from [Bloom and Neher](https://www.biorxiv.org/content/10.1101/2023.01.30.526314v1), with clipping
   3. Pick the next parent (repeating a previously used one if needed), and pick a new set of mutations.
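
The greedy procedure described above can be sketched on toy data without any external dependencies. Everything here is a simplified assumption for illustration: `toy_escape` is a made-up stand-in for the real escape calculator (it simply reduces escape 10-fold at sites already mutated in the current design), and the candidate values and the ceiling of 1.0 are invented:

```python
import math

# hypothetical per-site escape values; a stand-in for the real escape calculator
BASE_ESCAPE = {444: 0.7, 455: 0.9, 456: 0.6}


def toy_escape(site, mutated_sites):
    """Toy rule (an assumption): escape at an already-mutated site drops 10-fold."""
    return BASE_ESCAPE[site] * (0.1 if site in mutated_sites else 1.0)


def greedy_pick(candidates, n_muts, used, clip=1.0):
    """Greedily pick n_muts mutations, never repeating one already in `used`."""
    chosen, mutated_sites = [], []
    for _ in range(n_muts):
        scored = [
            (
                toy_escape(c["site"], mutated_sites)
                * min(math.exp(c["delta_expr"]), clip)  # ceiling on expression term
                * min(math.exp(c["delta_bind"]), clip)  # ceiling on ACE2 binding term
                * c["fitness_score"],
                c,
            )
            for c in candidates
            if c["mutation"] not in used
        ]
        if not scored:
            raise ValueError("ran out of candidate mutations")
        _, best = max(scored, key=lambda pair: pair[0])
        chosen.append(best["mutation"])
        used.add(best["mutation"])  # shared across designs in the cocktail
        mutated_sites.append(best["site"])
    return chosen


# invented candidate table for illustration only
candidates = [
    {"mutation": "L455S", "site": 455, "delta_expr": 0.0, "delta_bind": 0.0, "fitness_score": 1.0},
    {"mutation": "L455F", "site": 455, "delta_expr": 0.0, "delta_bind": -0.1, "fitness_score": 0.8},
    {"mutation": "K444T", "site": 444, "delta_expr": 0.5, "delta_bind": 0.0, "fitness_score": 1.0},
    {"mutation": "F456L", "site": 456, "delta_expr": -1.0, "delta_bind": 0.0, "fitness_score": 1.0},
]

used = set()
design1 = greedy_pick(candidates, 2, used)
design2 = greedy_pick(candidates, 2, used)
print(design1)  # ['L455S', 'K444T']
print(design2)  # ['L455F', 'F456L']
```

Because `used` is shared between the two calls, the second design cannot reuse `L455S` or `K444T`, which mirrors the rule that no mutation is repeated within a cocktail.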
   

In [9]:
components_per_cocktail = config["components_per_cocktail"]
mutations_per_design = config["mutations_per_design"]
parent_existing_mut_sites = config["escapecalculator"]["parent_existing_mut_sites"]
rbd_dms_clip = config["rbd_dms"]["clip"]

if len(parents) >= components_per_cocktail:
    raise ValueError("nothing to design if there are as many parents as components")
print(f"All cocktails have {components_per_cocktail} components")

for cocktail, n_muts in mutations_per_design.items():
    print(f"\nDesigning {cocktail} with {n_muts} mutations per design")
    
    cocktail_seqs = [(name, str(seq.seq)) for name, seq in parents.items()]
    cocktail_mutations = set()  # set of all mutations in cocktail so far, do not repeat
    
    i = 0
    while len(cocktail_seqs) < components_per_cocktail:
        design_parent_name, design_parent_seq = list(parents.items())[i % len(parents)]
        design_parent_seq = str(design_parent_seq.seq)
        i += 1
        
        mut_sites = list(parent_existing_mut_sites[design_parent_name])
        
        # normalize RBD DMS so parental value is 0
        parent_normalized_rbd_dms = (
            rbd_dms
            .assign(parent_aa=lambda x: x["site"].map(parent_wts[design_parent_name]))
            .query("amino_acid == parent_aa")
            .rename(columns={"delta_expr": "parent_expr", "delta_bind": "parent_bind"})
            .drop(columns="amino_acid")
            .merge(rbd_dms, on="site", validate="one_to_many")
            .assign(
                delta_expr=lambda x: x["delta_expr"] - x["parent_expr"],
                delta_bind=lambda x: x["delta_bind"] - x["parent_bind"],
            )
            [["site", "amino_acid", "delta_expr", "delta_bind"]]
        )
        
        design_mutations = []
        for _ in range(n_muts):
            # get top scoring mutation
            mut_scores = (
                escape_calc.escape_per_site(mut_sites)
                .merge(parent_normalized_rbd_dms, on="site")
                .merge(fitness, on=["site", "amino_acid"])
                .assign(
                    parent_aa=lambda x: x["site"].map(parent_wts[design_parent_name]),
                    mutation=lambda x: x["parent_aa"] + x["site"].astype(str) + x["amino_acid"],
                    score=lambda x: (
                        x["retained_escape"]
                        * numpy.clip(numpy.exp(x["delta_expr"]), a_min=None, a_max=rbd_dms_clip)
                        * numpy.clip(numpy.exp(x["delta_bind"]), a_min=None, a_max=rbd_dms_clip)
                        * x["fitness_score"]
                    ),
                )
                .sort_values("score", ascending=False)
                .query("mutation not in @cocktail_mutations")
                .query("parent_aa != amino_acid")
            )           
            mutation = mut_scores["mutation"].iloc[0]
            site = mut_scores["site"].iloc[0]
            design_mutations.append(mutation)
            cocktail_mutations.add(mutation)
            mut_sites.append(site)
        
        design_seq = list(design_parent_seq)
        for mut in design_mutations:
            r_parent = ref_to_parent_site[design_parent_name][int(mut[1: -1])]
            assert design_seq[r_parent - 1] == mut[0]
            design_seq[r_parent - 1] = mut[-1]
        cocktail_seqs.append(
            (
                design_parent_name + "_" + "_".join(design_mutations),
                "".join(design_seq),
            )
        )
        
    print("Components of cocktail are:\n " + "\n ".join(n for n, _ in cocktail_seqs))
    
    cocktail_vax_file = f"vax_designs/{cocktail}-vax.fa"
    print(f"Writing {cocktail} of {len(cocktail_seqs)} to {cocktail_vax_file}")
    with open(cocktail_vax_file, "w") as f:
        for name, seq in cocktail_seqs:
            f.write(f">{name}\n{seq}\n")

All cocktails have 5 components

Designing cocktail with 4 mutations per design
Components of cocktail are:
 XBB.1.16
 XBB.1.16_L455S_K440R_L452M_H505Y
 XBB.1.16_P445A_F456L_T346K_A484S
 XBB.1.16_K444R_A484Q_Y453F_L452R
 XBB.1.16_K444T_R403K_A484K_K356R
Writing cocktail of 5 to vax_designs/cocktail-vax.fa

Designing aggressive-cocktail with 6 mutations per design
Components of cocktail are:
 XBB.1.16
 XBB.1.16_L455S_K440R_L452M_H505Y_K444R_A484S
 XBB.1.16_P445A_F456L_T346K_A484Q_K356R_N405D
 XBB.1.16_K444T_Y453F_L452R_A484K_R403K_K440N
 XBB.1.16_K444M_A484D_T346S_L455F_I468V_K440Y
Writing aggressive-cocktail of 5 to vax_designs/aggressive-cocktail-vax.fa

Designing conservative-cocktail with 2 mutations per design
Components of cocktail are:
 XBB.1.16
 XBB.1.16_L455S_K440R
 XBB.1.16_P445A_H505Y
 XBB.1.16_L452M_K444R
 XBB.1.16_F456L_K444T
Writing conservative-cocktail of 5 to vax_designs/conservative-cocktail-vax.fa
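
As a quick sanity check on the output above, the sketch below parses the component names printed for the *cocktail* design and confirms that no mutation appears twice. The helper `mutations_from_name` is hypothetical, and it assumes the parent name itself contains no underscore (true for `XBB.1.16`):

```python
# component names copied from the printed output for the "cocktail" design
cocktail_names = [
    "XBB.1.16",
    "XBB.1.16_L455S_K440R_L452M_H505Y",
    "XBB.1.16_P445A_F456L_T346K_A484S",
    "XBB.1.16_K444R_A484Q_Y453F_L452R",
    "XBB.1.16_K444T_R403K_A484K_K356R",
]


def mutations_from_name(name):
    """Split 'parent_mut1_mut2...' into its mutation list (empty for a bare parent)."""
    return name.split("_")[1:]


all_muts = [m for name in cocktail_names for m in mutations_from_name(name)]
repeated = sorted({m for m in all_muts if all_muts.count(m) > 1})
print(len(all_muts), repeated)  # 16 []
```

The same check could be run on the names printed for the other two designs.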
