# Estimate amino-acid fitness values aggregated across clades

## Explanation of what this notebook is doing
For each clade we have estimated the change in fitness $\Delta f_{xy}$ caused by mutating a site from amino-acid $x$ to $y$, where $x$ is the amino acid in the clade founder sequence.
For each such mutation, we also have $n_{xy}$ which is the number of **expected** mutations from the clade founder amino acid $x$ to $y$.
These $n_{xy}$ values are important because they give some estimate of our "confidence" in the $\Delta f_{xy}$ values: if a mutation has high expected counts (large $n_{xy}$) then we can estimate the change in fitness caused by the mutation more accurately, and if $n_{xy}$ is small then the estimate will be much noisier.

However, we would like to aggregate the data across multiple clades to estimate amino-acid fitness values at a site under the assumption that these are constant across clades.
Now things get more complicated.
For instance, let's say at our site of interest, the clade founder amino acid is $x$ in one clade and $z$ in another clade.
We then have a set of up to 19 $\Delta f_{xy}$ and $n_{xy}$ values for the first clade (where $y$ ranges over the 19 amino acids that aren't $x$), and another set of up to 19 $\Delta f_{zy}$ and $n_{zy}$ values for the second clade (where $y$ ranges over the 19 amino acids that aren't $z$).

From these sets of mutation fitness changes, we'd like to estimate the fitness $f_x$ of each amino acid $x$, where the $f_x$ values satisfy $\Delta f_{xy} = f_y - f_x$ (in other words, a higher $f_x$ means higher fitness of that amino acid).
When there are multiple clades with different founder amino acids at the site, there is no guarantee that we can find $f_x$ values that precisely satisfy the above equation: there are more $\Delta f_{xy}$ values than $f_x$ values, and the $\Delta f_{xy}$ values may have noise (and in some cases even real shifts among clades due to epistasis).
Nonetheless, we can try to find the $f_x$ values that come closest to satisfying the above equation.

First, we choose one amino acid to have a fitness value of zero, since the scale of the $f_x$ values is arbitrary and there are really only 19 unique parameters among the 20 $f_x$ values (since we only measure differences among them, not absolute values).
Typically if there was just one clade, we would set the wildtype value of $f_x = 0$ and then for mutations to all other amino acids $y$ we would simply have $f_y = \Delta f_{xy}$.
However, when there are multiple clades with different founder amino acids, there is no longer a well-defined "wildtype".
So we choose the amino acid with the highest count among the observed mutations and set its fitness to zero.
In the implementation below, this is the amino acid $x$ that maximizes $\sum_y \left(n_{xy} + n_{yx}\right)$, counting expected mutations both from and to $x$.
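
To make the single-clade case described above concrete, the following minimal sketch uses invented values for one site whose only founder amino acid is `A`. The column names follow the input table used later in this notebook; the variable names `toy_single` and `single_clade_f` are hypothetical.

```python
import pandas as pd

# toy data for one site in one clade with founder amino acid "A" (values are invented)
toy_single = pd.DataFrame(
    {
        "clade_founder_aa": ["A", "A", "A"],
        "mutant_aa": ["C", "D", "E"],
        "delta_fitness": [-1.2, 0.3, -4.0],
    }
)

# the founder gets fitness zero, and each mutant gets f_y = Delta f_xy
founder = toy_single["clade_founder_aa"].iloc[0]
single_clade_f = {founder: 0.0}
single_clade_f.update(zip(toy_single["mutant_aa"], toy_single["delta_fitness"]))
print(single_clade_f)
```

With only one founder there is nothing to fit, so no optimization is needed in this case.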

Next, we choose the $f_x$ values that most closely match the measured mutation effects, weighting more strongly mutation effects with higher expected counts (since these should be more accurate).
Specifically, we define a loss function as
$$
L = \sum_x \sum_{y \ne x} n_{xy} \left(\Delta f_{xy} - \left[f_y - f_x\right]\right)^2
$$
where we ignore effects of synonymous mutations (hence the restriction $y \ne x$ in the second sum) because we are only examining protein-level effects.
We then use numerical optimization to find the $f_x$ values that minimize that loss $L$.
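
As a rough illustration of this minimization, the sketch below fits a toy site with two founder amino acids (`A` and `C`) and three amino acids in total. All counts and fitness changes are invented, and the names `toy_muts`, `ref_aa`, `free_aas`, and `toy_loss` are hypothetical; only the form of the loss and the use of `scipy.optimize.minimize` with the Powell method follow this notebook.

```python
import numpy
import scipy.optimize

# toy data for one site: (parent x, mutant y) -> (n_xy, Delta f_xy); values are invented
toy_muts = {
    ("A", "C"): (50.0, -1.0),
    ("A", "D"): (20.0, -3.0),
    ("C", "A"): (5.0, 0.6),
    ("C", "D"): (2.0, -1.5),
}
ref_aa = "A"  # amino acid whose fitness is fixed at zero
free_aas = ["C", "D"]  # amino acids whose fitness values are fit


def toy_loss(f_vec):
    """Count-weighted squared error between measured and implied fitness differences."""
    f = dict(zip(free_aas, f_vec))
    f[ref_aa] = 0.0
    return sum(
        n * (delta_f - (f[y] - f[x])) ** 2
        for (x, y), (n, delta_f) in toy_muts.items()
    )


toy_res = scipy.optimize.minimize(toy_loss, numpy.zeros(len(free_aas)), method="Powell")
toy_f = dict(zip(free_aas, toy_res.x))
print(toy_f)
```

In this toy example the two clades disagree slightly about the `A`/`C` difference (-1.0 from the `A` founder versus -0.6 implied by the `C` founder), and the fitted value for `C` lands near -0.98, much closer to the estimate backed by the larger expected count.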

Finally, we would still like to report an equivalent of the $n_{xy}$ values that gives us some sense of how accurately we have estimated the fitness $f_x$ of each amino acid.
To do that, we tabulate $N_x = \sum_y \left(n_{xy} + n_{yx} \right)$, the total number of expected mutations either from or to amino acid $x$, and use it as the "count" for that amino acid.
Amino acids with larger values of $N_x$ should have more accurate estimates of $f_x$.
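
A minimal sketch of tabulating these per-amino-acid counts with pandas, using invented expected counts for a site with founders `A` and `C`. The column names follow the input table used later in this notebook; the variable names `toy_counts`, `from_counts`, `to_counts`, and `n_total` are hypothetical.

```python
import pandas as pd

# toy expected counts for one site with founders "A" and "C" (values are invented)
toy_counts = pd.DataFrame(
    {
        "clade_founder_aa": ["A", "A", "C", "C"],
        "mutant_aa": ["C", "D", "A", "D"],
        "expected_count": [50.0, 20.0, 5.0, 2.0],
    }
)

# sum of n_xy over mutations *from* each amino acid
from_counts = toy_counts.groupby("clade_founder_aa")["expected_count"].sum()
# sum of n_yx over mutations *to* each amino acid
to_counts = toy_counts.groupby("mutant_aa")["expected_count"].sum()

# N_x is the sum of the two, treating an amino acid missing from one side as zero
n_total = from_counts.add(to_counts, fill_value=0).to_dict()
print(n_total)  # {'A': 75.0, 'C': 57.0, 'D': 22.0}
```

Here `D` is never a founder, so its count comes only from mutations to it.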

## Implementation of calculation

Get variables from `snakemake`:

In [None]:
# from snakemake 
aamut_fitness_csv = snakemake.input.aamut_fitness
aa_fitness_csv = snakemake.output.aa_fitness

# manually defined for debugging
# aamut_fitness_csv = "../results/aa_fitness/aamut_fitness_all.csv"

Import Python modules:

In [None]:
import numpy

import pandas as pd

import scipy.optimize

We read the amino-acid mutation fitnesses, **ignoring** synonymous mutations:

In [None]:
aamut_fitness = pd.read_csv(aamut_fitness_csv).query("clade_founder_aa != mutant_aa")

aamut_fitness

In [None]:
def get_aa_fitness(site_df):
    """Estimates fitness of amino acids at site in a gene."""
    assert site_df["gene"].nunique() == 1
    gene = site_df["gene"].unique()[0]
    
    assert site_df["aa_site"].nunique() == 1
    site = site_df["aa_site"].unique()[0]
    
    assert len(site_df) == len(site_df.groupby(["clade_founder_aa", "mutant_aa"]))
    
    # if there is just one clade founder x, then f_y = Delta f_xy,
    # and n_y = n_xy for y != x and n_x = sum_y n_xy.
    if site_df["clade_founder_aa"].nunique() == 1:
        return pd.concat(
            [
                (
                    site_df
                    .rename(
                        columns={
                            "mutant_aa": "aa",
                            "expected_count": "count",
                            "delta_fitness": "fitness",
                        }
                    )
                    [["gene", "aa_site", "aa", "fitness", "count"]]
                ),
                pd.DataFrame(
                    {
                        "gene": [gene],
                        "aa_site": [site],
                        "aa": site_df["clade_founder_aa"].unique(),
                        "fitness": [0.0],
                        "count": [site_df["expected_count"].sum()]
                    }
                )
            ],
        ).assign(aa_differs_among_clade_founders=False)
    
    # If we get here, there are multiple clade founders and we need to solve for f_x.
    # The code below is highly inefficient in terms of speed, but is fast enough
    # for current purposes.
    
    # first get counts of each amino-acid and the highest count one for which
    # we set f_x to zero
    count_df = (
        site_df
        .rename(columns={"clade_founder_aa": "aa"})
        .groupby("aa", as_index=False)
        .aggregate(count_1=pd.NamedAgg("expected_count", "sum"))
        .merge(
            site_df
            .rename(columns={"mutant_aa": "aa"})
            .groupby("aa", as_index=False)
            .aggregate(count_2=pd.NamedAgg("expected_count", "sum")),
            how="outer",
            on="aa",
        )
        .fillna(0)
        .assign(count=lambda x: x["count_1"] + x["count_2"])
        .sort_values("count")
    )
    counts = count_df.set_index("aa")["count"].to_dict()
    highest_count_aa = count_df["aa"].tolist()[-1]
    aas = count_df["aa"].tolist()[: -1]  # all but highest count
    
    parent_aas = site_df["clade_founder_aa"].unique()
    mutant_aas = site_df["mutant_aa"].unique()
    # keyed by (parent_aa, mutant_aa)
    site_dict = (
        site_df
        .set_index(["clade_founder_aa", "mutant_aa"])
        [["expected_count", "delta_fitness"]]
        .to_dict(orient="index")
    )
    
    def loss(f_vec):
        f_aa = dict(zip(aas, f_vec))
        f_aa[highest_count_aa] = 0
        loss_val = 0.0
        for parent_aa in parent_aas:
            f_parent = f_aa[parent_aa]
            for mutant_aa in mutant_aas:
                try:
                    delta_f = site_dict[(parent_aa, mutant_aa)]["delta_fitness"]
                    n = site_dict[(parent_aa, mutant_aa)]["expected_count"]
                except KeyError:
                    continue
                f_mutant = f_aa[mutant_aa]
                loss_val += n * (delta_f - (f_mutant - f_parent))**2
        return loss_val
    
    opt_res = scipy.optimize.minimize(loss, numpy.zeros(len(aas)), method="Powell")
    assert opt_res.success, f"{opt_res}\n\n{site_df}"
    
    fs = dict(zip(aas, opt_res.x))
    fs[highest_count_aa] = 0
    assert list(fs.keys()) == list(counts.keys())

    return pd.DataFrame(
        {
            "gene": gene,
            "aa_site": site,
            "aa": fs.keys(),
            "fitness": fs.values(),
            "count": counts.values(),
            "aa_differs_among_clade_founders": True,
        }
    )


site_dfs = []
for i, ((gene, site), site_df) in enumerate(aamut_fitness.groupby(["gene", "aa_site"])):
    site_dfs.append(get_aa_fitness(site_df))
    if i % 500 == 0:
        print(f"Completed optimization {i + 1}")
print(f"Completed all {i + 1} optimizations.")

fitness_df = (
    pd.concat(site_dfs)
    .merge(aamut_fitness[["gene", "subset_of_ORF1ab"]].drop_duplicates())
    .sort_values(["gene", "aa_site", "aa"])
    .reset_index(drop=True)
)

assert len(fitness_df) == len(fitness_df.groupby(["gene", "aa_site", "aa"]))

fitness_df

Look at how many sites in each gene have an amino acid that differs among clade founders:

In [None]:
(
    fitness_df
    .query("not subset_of_ORF1ab")
    .groupby(["gene", "aa_differs_among_clade_founders"], as_index=False)
    .aggregate(n_sites=pd.NamedAgg("aa_site", "nunique"))
    .pivot_table(
        index="gene",
        columns="aa_differs_among_clade_founders",
        values="n_sites",
    )
    .assign(percent_that_differ=lambda x: 100 * x[True] / (x[False] + x[True]))
    .round(1)
)

Now we compare the amino-acid fitness estimates to the mutation delta fitness values.
First do this for all sites where all clade founders share the same amino acid.
This correlation should be exactly one:

In [None]:
one_founder_corrs = (
    fitness_df
    .query("not aa_differs_among_clade_founders")
    .query("not subset_of_ORF1ab")
    [["gene", "aa_site", "aa", "fitness"]]
    .merge(
        aamut_fitness
        [["gene", "mutant_aa", "aa_site", "delta_fitness"]]
        .rename(columns={"mutant_aa": "aa"}),
    )
    [["fitness", "delta_fitness"]]
    .corr()
)

# compare with a tolerance rather than exact floating-point equality
assert numpy.allclose(one_founder_corrs.values, 1)

one_founder_corrs

Now get the correlations for sites with multiple clade founders.
To do this, we adjust each $\Delta f_{xy}$ value by adding the estimated fitness $f_x$ of its clade founder, so that the adjusted value is comparable to $f_y$.
We expect these correlations to be good, but not necessarily exactly one.
They should be better when the clade founder is the most abundant one (the amino acid with $f_x = 0$), because mutations from that founder are weighted more heavily in the amino-acid fitness estimates:

In [None]:
multi_founder_corrs = (
    fitness_df
    .query("aa_differs_among_clade_founders")
    .query("not subset_of_ORF1ab")
    .drop(columns=["aa_differs_among_clade_founders", "subset_of_ORF1ab"])
    .merge(
        aamut_fitness
        [["gene", "clade_founder_aa", "mutant_aa", "aa_site", "delta_fitness"]]
        .rename(columns={"mutant_aa": "aa"})
    )
    .merge(
        fitness_df[["gene", "aa_site", "aa", "fitness"]]
        .rename(columns={"aa": "clade_founder_aa", "fitness": "clade_founder_fitness"})
    )
    .assign(
        most_abundant_clade_founder=lambda x: x["clade_founder_fitness"] == 0,
        adjusted_delta_fitness=lambda x: x["delta_fitness"] + x["clade_founder_fitness"],
    )
    .groupby("most_abundant_clade_founder")
    [["fitness", "adjusted_delta_fitness"]]
    .corr()
)

# check the multi-founder correlations computed in this cell
assert (multi_founder_corrs.values >= 0.85).all()

multi_founder_corrs.round(3)

Write the values to a file:

In [None]:
print(f"Writing to {aa_fitness_csv}")

fitness_df.to_csv(aa_fitness_csv, index=False, float_format="%.5g")