# Design SARS-CoV-2 spike RBD mutants
This notebook chooses/designs spikes with mutant RBDs.

## Setup
Import Python modules:

In [1]:
import collections
import itertools
import json
import os
import subprocess
import tempfile
import urllib

import altair as alt

import Bio.Entrez
import Bio.SeqIO

import numpy

import pandas as pd

import ruamel.yaml as yaml

Read configuration:

In [2]:
# Parse the YAML configuration that drives every later cell in this notebook.
with open("config.yaml") as config_file:
    config = yaml.YAML().load(config_file)

## Get the reference RBD

In [3]:
# An e-mail address is set on Bio.Entrez before issuing the fetch request.
Bio.Entrez.email = "example@example.com"

# Download the reference spike protein record in GenBank text format.
print(f"Getting reference spike from accession {config['ref_spike']}")
with Bio.Entrez.efetch(
    db="protein",
    id=config["ref_spike"],
    rettype="gb",
    retmode="text",
) as handle:
    ref_spike = Bio.SeqIO.read(handle, "gb")
print(f"Got spike of length {len(ref_spike)}")

# `rbd_coords` holds the first and last RBD site (1-based, inclusive).
rbd_coords = config["rbd_coords"]

# Map each RBD site number to the reference amino acid at that site; the
# `- 1` converts the 1-based site number to a 0-based sequence index.
ref_rbd = {}
for site_number in range(rbd_coords[0], rbd_coords[1] + 1):
    ref_rbd[site_number] = ref_spike[site_number - 1]

Getting reference spike from accession YP_009724390
Got spike of length 1273


## Get RBD mutations in each Pango lineage

In [4]:
# Location of the JSON summary of Pango lineage definitions (from config).
pango_json = config["pango_json"]
print(f"Reading Pango clade definitions from {pango_json}")

# Fetch the JSON and decode it; the result is keyed by lineage name.
with urllib.request.urlopen(pango_json) as response:
    pango_lineages = json.loads(response.read())

print(f"Read definitions for {len(pango_lineages)} lineages")


def parse_spike_rbd_muts(lineage_d):
    """Extract the spike RBD mutations listed for one lineage.

    Parameters
    ----------
    lineage_d : dict
        Lineage definition with list-valued keys "aaSubstitutions" and
        "aaDeletions" whose entries look like "<gene>:<wt><site><mutant>".

    Returns
    -------
    list of tuple
        `(wildtype, site, mutant)` for every spike ("S:") entry whose site lies
        within the notebook-level `rbd_coords` (inclusive), in input order with
        substitutions before deletions.

    Raises
    ------
    AssertionError
        If the wildtype residue of an in-range entry disagrees with the
        notebook-level `ref_rbd` mapping.
    """
    parsed = []
    for entry in lineage_d["aaSubstitutions"] + lineage_d["aaDeletions"]:
        # only spike entries are of interest
        if not entry.startswith("S:"):
            continue
        aa_change = entry.split(":")[1]
        # first character is wildtype, last is mutant, the middle is the site number
        wildtype, site, mutant = aa_change[0], int(aa_change[1: -1]), aa_change[-1]
        # skip spike mutations that fall outside the RBD
        if not (rbd_coords[0] <= site <= rbd_coords[1]):
            continue
        assert wildtype == ref_rbd[site]
        parsed.append((wildtype, site, mutant))
    return parsed
            
# For every Pango lineage, the list of its spike RBD mutations as
# (wildtype, site, mutant) tuples.
pango_rbd_muts = {
    lineage_name: parse_spike_rbd_muts(lineage_definition)
    for lineage_name, lineage_definition in pango_lineages.items()
}

Reading Pango clade definitions from https://raw.githubusercontent.com/corneliusroemer/pango-sequences/main/data/pango-consensus-sequences_summary.json
Read definitions for 3916 lineages


## Get and import the escape calculator

In [5]:
# Download the escape calculator module into the working directory so it can
# be imported below.
# NOTE(review): this executes Python code fetched from the URL given in the
# config; make sure that URL points at a trusted (ideally pinned) source.
_ = urllib.request.urlretrieve(
    url=config["escape_calculator_module_url"],
    filename="escapecalculator.py",
)

# This import has to follow the download above, so it cannot live in the
# import cell at the top of the notebook.
import escapecalculator

## Now make designs for each parent

In [None]:
# Global design settings shared by every parent / category combination.
repetition_downweight = config["repetition_downweight"]
allow_reversions_to_ref = config["allow_reversions_to_ref"]
categories = config["categories"]  # maps category name -> number of mutations per mutant

# Design mutants for every (parent, category) pair. For each pair we assemble a
# table of per-mutation phenotypes with their weights, then greedily pick the
# top-scoring mutation one at a time for each mutant.
for (parent, parent_d), category in itertools.product(config["parent_specs"].items(), categories):
    nmutants = parent_d["nmutants"]
    nmutations = categories[category]
    with open(parent_d["specs"]) as f:
        parent_config = yaml.YAML().load(f)

    # make a data frame that includes parent and ref amino acid at each site
    parent_aas = (
        # first get reference and parent amino acid for each site
        pd.Series(ref_rbd).rename_axis("site").rename("ref_amino_acid").reset_index()
        .merge(
            pd.DataFrame(
                pango_rbd_muts[parent],
                columns=["ref_amino_acid", "site", "parent_amino_acid"],
            ),
            how="outer",
            on=["ref_amino_acid", "site"],
        )
        # sites not mutated in the parent keep the reference amino acid
        .assign(
            parent_amino_acid=lambda x: x["parent_amino_acid"].where(
                x["parent_amino_acid"].notnull(), x["ref_amino_acid"])
        )
    )
    # exactly one row per RBD site
    assert len(parent_aas) == parent_aas["site"].nunique() == len(ref_rbd)

    # set up escape calculator
    escape_calc = escapecalculator.EscapeCalculator(virus=parent_config["escapecalculator"]["virus"])
    # get sites that differ between parent and virus used to initialize escape calculator
    parent_escape_calc_virus_diff_sites = list(set([
        tup[1]
        for tup in set(pango_rbd_muts[parent]).symmetric_difference(
            pango_rbd_muts[parent_config["escapecalculator"]["virus"]]
        )
    ]))

    # get RBD deep mutational scanning data, clipping both measurements from above
    rbd_dms = (
        pd.read_csv(parent_config["rbd_dms"]["data"])
        .query("target == @parent_config['rbd_dms']['target']")
        .rename(columns={"position": "site", "mutant": "amino_acid"})
        [["site", "amino_acid", "delta_bind", "delta_expr"]]
        .assign(
            delta_bind=lambda x: x["delta_bind"].clip(upper=parent_config["rbd_dms"]["clip"]),
            delta_expr=lambda x: x["delta_expr"].clip(upper=parent_config["rbd_dms"]["clip"]),
        )
    )
    # normalize values to parent amino acids: subtract the value measured for the
    # parent amino acid at each site from every value at that site
    rbd_dms = (
        rbd_dms
        .merge(parent_aas[["site", "parent_amino_acid"]], on="site", validate="many_to_one")
        .query("parent_amino_acid == amino_acid")
        .rename(columns={"delta_expr": "parent_expr", "delta_bind": "parent_bind"})
        .drop(columns=["amino_acid", "parent_amino_acid"])
        .merge(rbd_dms, on="site", validate="one_to_many")
        .assign(
            delta_bind=lambda x: x["delta_bind"] - x["parent_bind"],
            delta_expr=lambda x: x["delta_expr"] - x["parent_expr"],
        )
        .drop(columns=["parent_expr", "parent_bind"])
        .rename(columns={"delta_bind": "rbd_delta_bind", "delta_expr": "rbd_delta_expr"})
    )
    assert len(rbd_dms)

    # get all-clade fitness estimates for spike, requiring a minimum expected count
    fitness = (
        pd.read_csv(parent_config["fitness_estimates"]["fitness"])
        .query("gene == 'S'")
        .query("expected_count >= @parent_config['fitness_estimates']['fitness_min_count']")
        .rename(columns={"aa_site": "site", "aa": "amino_acid"})
        .assign(fitness=lambda x: x["fitness"].clip(upper=parent_config["fitness_estimates"]["clip"]))
        [["site", "amino_acid", "fitness"]]
    )
    assert len(fitness)

    # get clade fitness effects
    by_clade_fitness = (
        pd.read_csv(parent_config["fitness_estimates"]["by_clade"])
        .query("gene == 'S'")
        .query("clade == @parent_config['fitness_estimates']['clade']")
        .query(
            "(expected_count >= @parent_config['fitness_estimates']['clade_min_count'])"
            "or (actual_count >= @parent_config['fitness_estimates']['clade_min_count'])"
        )
        .rename(columns={"aa_site": "site", "mutant_aa": "amino_acid", "delta_fitness": "clade_fitness"})
        # the clip must be written to `clade_fitness` itself: that is the only
        # value column retained by the selection on the next line
        .assign(
            clade_fitness=lambda x: x["clade_fitness"].clip(upper=parent_config["fitness_estimates"]["clip"])
        )
        [["site", "amino_acid", "clade_fitness"]]
    )
    assert len(by_clade_fitness)

    # get spike DMS
    spike_dms = (
        pd.read_csv(parent_config["spike_dms"]["csv"])
        .rename(columns={"mutant": "amino_acid"})
        [["site", "amino_acid", "human sera escape", "spike mediated entry", "ACE2 binding"]]
    )
    # normalize values to parent amino acids
    spike_dms = (
        spike_dms
        .merge(parent_aas[["site", "parent_amino_acid"]], on="site", validate="many_to_one")
        .query("parent_amino_acid == amino_acid")
        .rename(
            columns={
                "human sera escape": "parent_human sera escape",
                "spike mediated entry": "parent_spike mediated entry",
                "ACE2 binding": "parent_ACE2 binding",
            }
        )
        .drop(columns=["amino_acid", "parent_amino_acid"])
        .merge(spike_dms, on="site", validate="one_to_many")
        .assign(
            spike_escape=lambda x: x["human sera escape"] - x["parent_human sera escape"],
            spike_entry=lambda x: x["spike mediated entry"] - x["parent_spike mediated entry"],
            spike_ACE2_binding=lambda x: x["ACE2 binding"] - x["parent_ACE2 binding"],
        )
        [["site", "amino_acid", "spike_escape", "spike_entry", "spike_ACE2_binding"]]
    )
    assert len(spike_dms)

    # now add each phenotype and its weight; every phenotype column `<p>` gets a
    # matching `<p>_weight` column, which the scoring step below relies on
    parent_phenotypes = (
        parent_aas
        # add fitness estimates, only keeping mutations with estimates
        .merge(fitness, on="site", validate="one_to_many", how="inner")
        .assign(fitness_weight=parent_config["fitness_estimates"]["weights"]["fitness"])
        # add clade fitness estimates, clipping lower at zero and also setting missing to zero
        .merge(by_clade_fitness, on=["site", "amino_acid"], validate="one_to_one", how="left")
        .assign(
            clade_fitness=lambda x: x["clade_fitness"].where(x["clade_fitness"] > 0, 0),
            clade_fitness_weight=parent_config["fitness_estimates"]["weights"]["by_clade_effect"],
        )
        # add RBD DMS, only keeping mutations with measurements
        .merge(rbd_dms, on=["site", "amino_acid"], validate="one_to_one", how="inner")
        .assign(
            rbd_delta_bind_weight=parent_config["rbd_dms"]["weights"]["delta_bind"],
            rbd_delta_expr_weight=parent_config["rbd_dms"]["weights"]["delta_expr"],
        )
        # add spike DMS, only keeping mutations with measurements
        .merge(spike_dms, on=["site", "amino_acid"], validate="one_to_one", how="inner")
        .assign(
            spike_escape_weight=parent_config["spike_dms"]["weights"]["sera escape"],
            spike_entry_weight=parent_config["spike_dms"]["weights"]["spike mediated entry"],
            spike_ACE2_binding_weight=parent_config["spike_dms"]["weights"]["ACE2 binding"],
        )
    )

    # Design each mutant
    mut_counts = collections.defaultdict(int)  # count how many times each mutation already in a mutant
    for imutant in range(nmutants):
        mutant_name = f"{parent}-{category}-mutant-{imutant + 1}"
        print(f"\nDesigning {mutant_name}")

        mutated_sites = []  # sites mutated in this variant
        mutated_df = []  # one single-row data frame per chosen mutation
        mutations = []  # names of the chosen mutations, in the order picked
        for imutation in range(nmutations):
            # re-run the escape calculator each time a new site has been mutated,
            # passing the parent-vs-calculator-virus differences plus the sites
            # already mutated in this variant
            mutant_phenotypes = parent_phenotypes.merge(
                escape_calc.escape_per_site(parent_escape_calc_virus_diff_sites + mutated_sites)
                .rename(columns={"retained_escape": "escape_calc"})
                [["site", "escape_calc"]]
                .assign(escape_calc_weight=parent_config["escapecalculator"]["weight"]),
                on="site",
                validate="many_to_one",
            )

            # ignore parent to parent mutations
            mutant_phenotypes = mutant_phenotypes.query("parent_amino_acid != amino_acid")

            # potentially ignore mutations to reference
            if not allow_reversions_to_ref:
                mutant_phenotypes = mutant_phenotypes.query("ref_amino_acid != amino_acid")

            # get the phenotypes: every column that is neither a weight nor an identifier
            phenos = [
                c for c in mutant_phenotypes.columns
                if not c.endswith("_weight") and c not in {"site", "amino_acid", "parent_amino_acid", "ref_amino_acid"}
            ]
            assert all(f"{p}_weight" in set(mutant_phenotypes.columns) for p in phenos)

            # assign score to each mutation: escape_calc enters as a power of its
            # weight, every other phenotype as exp(value * weight). The score is
            # built as a separate Series and attached with `.assign` so that no
            # column is set on the filtered frame returned by `.query`.
            score = mutant_phenotypes["escape_calc"] ** mutant_phenotypes["escape_calc_weight"]
            for pheno in phenos:
                if pheno != "escape_calc":
                    score = score * numpy.exp(
                        mutant_phenotypes[pheno] * mutant_phenotypes[f"{pheno}_weight"]
                    )

            # re-weight mutations already added to variants: the score is multiplied by
            # repetition_downweight once for each earlier use of the mutation
            mutant_phenotypes = (
                mutant_phenotypes
                .assign(score=score)
                .assign(
                    mutation=lambda x: x["parent_amino_acid"] + x["site"].astype(str) + x["amino_acid"],
                    reweight=lambda x: repetition_downweight**x["mutation"].map(mut_counts),
                    score=lambda x: x["score"] * x["reweight"],
                )
            )

            # get top scoring mutation not in already mutated sites
            mutation_row = (
                mutant_phenotypes
                .query("site not in @mutated_sites")
                .query("score.notnull()")
                .sort_values("score")
                [["site", "mutation", "ref_amino_acid", "parent_amino_acid", "amino_acid", *phenos]]
                .tail(1)
            )
            # fail with a clear message rather than an IndexError from `.iloc[0]`
            if mutation_row.empty:
                raise ValueError(
                    f"No candidate mutations remain for {mutant_name} "
                    f"when choosing mutation {imutation + 1} of {nmutations}"
                )
            site = mutation_row["site"].iloc[0]
            parent_amino_acid = mutation_row["parent_amino_acid"].iloc[0]
            amino_acid = mutation_row["amino_acid"].iloc[0]
            mutation = mutation_row["mutation"].iloc[0]
            mutated_sites.append(site)
            mutated_df.append(mutation_row)
            mutations.append(mutation)
            mut_counts[mutation] += 1

        mutated_df = pd.concat(mutated_df, ignore_index=True)
        print(f"Contains the following mutations: {', '.join(mutations)}")
        display(mutated_df)