# Get amino-acid mutations in clade founders

In [None]:
# input files
# per-site clade-founder nucleotides; columns used below: clade, gene, site, nt
clade_founder_nts_csv = "results/clade_founder_nts/clade_founder_nts.csv"
# only its "clade" column is used, to choose the clades of interest
rates_by_clade_csv = "results/synonymous_mut_rates/rates_by_clade.csv"

# output file
# NOTE(review): written tab-separated despite the .csv extension (to_csv sep="\t")
output_csv = "results/clade_founder_aa_muts/clade_founder_aa_muts.csv"

In [None]:
import itertools
import os

import altair as alt

import Bio.Seq
import Bio.SeqIO

import pandas as pd

import yaml

In [None]:
# Read the pipeline config and pull out the two sections used here.
with open("config.yaml") as config_fh:
    config = yaml.safe_load(config_fh)

# Nsp name -> first ORF1ab residue of that Nsp (consumed by gene_mutation).
orf1ab_to_nsps = config["orf1ab_to_nsps"]
# clade -> alternate name, shown in parentheses in the final table.
clade_synonyms = config["clade_synonyms"]

Get the protein sequence of every gene in each clade founder:

In [None]:
clades_of_interest = pd.read_csv(rates_by_clade_csv)["clade"].unique().tolist()

# Founder nucleotides for the clades of interest. A site whose "gene" field
# lists several ";"-separated genes becomes one row per gene.
founder_nts = pd.read_csv(clade_founder_nts_csv)
founder_nts = founder_nts[founder_nts["clade"].isin(clades_of_interest)]
founder_nts = founder_nts.assign(gene=founder_nts["gene"].str.split(";")).explode("gene")

# Concatenate nucleotides in site order into each (clade, gene) coding
# sequence, then translate it.
clade_founder_seqs = (
    founder_nts.sort_values(["clade", "gene", "site"])
    .groupby(["clade", "gene"], as_index=False)
    .aggregate(gene_seq=pd.NamedAgg("nt", lambda nts: "".join(nts)))
)
clade_founder_seqs["prot_seq"] = clade_founder_seqs["gene_seq"].map(
    lambda nt_seq: str(Bio.Seq.Seq(nt_seq).translate())
)

# Each gene must be a whole number of codons. Biopython's default translate()
# keeps stop codons as "*", so the lengths should match exactly.
assert (
    clade_founder_seqs["gene_seq"].str.len()
    == 3 * clade_founder_seqs["prot_seq"].str.len()
).all()

clade_founder_seqs

Get amino-acid mutations between each pair of clades:

In [None]:
def get_muts(row):
    """Return the amino-acid substitutions between two aligned protein sequences.

    ``row`` must provide ``prot_seq_1`` and ``prot_seq_2`` of equal length
    (checked with an assert). Each substitution is formatted as
    ``<residue in seq 1><1-based site><residue in seq 2>``, e.g. ``D614G``,
    and the list is in site order.
    """
    seq_1 = row["prot_seq_1"]
    seq_2 = row["prot_seq_2"]
    assert len(seq_1) == len(seq_2)
    muts = []
    for site, (aa_1, aa_2) in enumerate(zip(seq_1, seq_2), start=1):
        if aa_1 == aa_2:
            continue
        muts.append(f"{aa_1}{site}{aa_2}")
    return muts

# Compare every ordered pair of clades (each clade with itself included),
# gene by gene, recording the amino-acid differences from get_muts.
clade_list = clade_founder_seqs["clade"].unique()

pair_frames = []
for first_clade, second_clade in itertools.product(clade_list, repeat=2):
    left = clade_founder_seqs[clade_founder_seqs["clade"] == first_clade]
    right = clade_founder_seqs[clade_founder_seqs["clade"] == second_clade]
    merged = left.merge(right, on="gene", suffixes=("_1", "_2"))
    merged = merged.assign(mutations=merged.apply(get_muts, axis=1))
    pair_frames.append(merged[["clade_1", "clade_2", "gene", "mutations"]])

aa_muts_df = pd.concat(pair_frames, ignore_index=True)

aa_muts_df

Map each ORF1ab mutation to the Nsp that contains it:

In [None]:
# One row per (clade pair, gene, mutation). Pairs/genes with no differences
# explode to NaN and are dropped.
aa_muts_nsp_df = (
    aa_muts_df
    .explode("mutations")
    .query("mutations.notnull()")
)

# Only ORF1ab mutations are kept below, on the premise that ORF1a's are a
# subset of them. Check that premise separately for each clade pair. Pooling
# all pairs together would let a mutation present in ORF1ab for one pair mask
# its absence for another, hiding ORF1a-only changes that would then be lost.
_pair_mut_cols = ["clade_1", "clade_2", "mutations"]
_orf1a_pair_muts = set(
    aa_muts_nsp_df.loc[aa_muts_nsp_df["gene"] == "ORF1a", _pair_mut_cols]
    .itertuples(index=False, name=None)
)
_orf1ab_pair_muts = set(
    aa_muts_nsp_df.loc[aa_muts_nsp_df["gene"] == "ORF1ab", _pair_mut_cols]
    .itertuples(index=False, name=None)
)
_missing_from_orf1ab = _orf1a_pair_muts - _orf1ab_pair_muts
assert not _missing_from_orf1ab, (
    f"ORF1a mutations absent from ORF1ab: {sorted(_missing_from_orf1ab)}"
)

def gene_mutation(row, nsp_starts=None):
    """Return ``"<gene> <mutation>"``, re-expressing ORF1ab mutations per Nsp.

    ``row`` must provide ``"gene"`` and ``"mutations"``. The mutation is a
    string ``<residue><site><residue>`` with single-character residues,
    e.g. ``"A200T"``. Rows for genes other than ORF1ab are returned unchanged
    as ``"<gene> <mutation>"``.

    For ORF1ab, the site is assigned to the Nsp with the largest start that is
    <= the site. It is renumbered 1-based from that Nsp's start, so with
    starts ``{"nsp1": 1, "nsp2": 181}`` the mutation ``A200T`` becomes
    ``"nsp2 A20T"``. When several Nsps share a start, the one listed last
    wins.

    Args:
        row: mapping/Series with ``"gene"`` and ``"mutations"`` entries.
        nsp_starts: mapping of Nsp name -> first ORF1ab residue. Defaults to
            the module-level ``orf1ab_to_nsps`` read from the config. It is
            sorted by start here, so the mapping need not be ordered.

    Raises:
        ValueError: an ORF1ab site precedes every Nsp start (including an
            empty mapping). Previously this surfaced as an UnboundLocalError.
    """
    gene = row["gene"]
    mutation = row["mutations"]
    if gene != "ORF1ab":
        return f"{gene} {mutation}"

    if nsp_starts is None:
        nsp_starts = orf1ab_to_nsps
    site = int(mutation[1:-1])

    # Walk Nsps in ascending start order (stable, so ties keep config order)
    # and remember the last one starting at or before the site.
    nsp = nsp_start = None
    for name, start in sorted(nsp_starts.items(), key=lambda item: item[1]):
        if start > site:
            break
        nsp, nsp_start = name, start

    if nsp is None:
        raise ValueError(
            f"ORF1ab site {site} of {mutation} precedes the first Nsp start"
        )
    return f"{nsp} {mutation[0]}{site - nsp_start + 1}{mutation[-1]}"

# Drop the redundant ORF1a rows, then relabel each remaining mutation via
# gene_mutation (ORF1ab -> containing Nsp; other genes unchanged).
orf_filtered = aa_muts_nsp_df[aa_muts_nsp_df["gene"] != "ORF1a"]
gene_mutation_strs = orf_filtered.apply(gene_mutation, axis=1)
gene_mutation_parts = gene_mutation_strs.str.split()

aa_muts_nsp_df = (
    orf_filtered
    .assign(
        gene_mutation=gene_mutation_strs,
        gene=gene_mutation_parts.str[0],
        mutation=gene_mutation_parts.str[1],
    )
    .drop(columns="mutations")
)

aa_muts_nsp_df

Get each clade's non-spike mutations relative to the first clade of interest:

In [None]:
ref_clade = clades_of_interest[0]
print(f"Getting mutations relative to {ref_clade}")

# Rows comparing the reference clade (clade_1) against every clade (clade_2),
# with spike excluded.
is_non_spike_vs_ref = (aa_muts_nsp_df["gene"] != "S") & (
    aa_muts_nsp_df["clade_1"] == ref_clade
)
muts_vs_ref = aa_muts_nsp_df[is_non_spike_vs_ref].rename(columns={"clade_2": "clade"})

# One "<gene>: <mut>, <mut>, ..." string per (clade, gene); mutations keep
# their row order within each group.
muts_by_gene = muts_vs_ref.groupby(["clade", "gene"], as_index=False).aggregate(
    mutations=pd.NamedAgg("mutation", lambda muts: ", ".join(muts))
)
muts_by_gene["gene_mutations"] = muts_by_gene["gene"] + ": " + muts_by_gene["mutations"]

# Then one "; "-joined string per clade, labelled "<clade> (<synonym>)".
non_spike_muts = muts_by_gene.groupby("clade", as_index=False).aggregate(
    gene_mutations=pd.NamedAgg("gene_mutations", lambda parts: "; ".join(parts))
)
non_spike_muts["clade"] = non_spike_muts["clade"].map(
    lambda name: f"{name} ({clade_synonyms[name]})"
)

# NOTE(review): the output is tab-separated even though output_csv ends in
# .csv; confirm downstream readers expect tabs before changing either.
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
non_spike_muts.to_csv(output_csv, index=False, sep="\t")

pd.options.display.max_colwidth = 300

non_spike_muts