# Estimate nucleotide fitness values aggregated across clades

## Explanation of what this notebook is doing
For each clade, we have estimated the change in fitness $\Delta f_{xy}$ caused by mutating a site from nucleotide $x$ to $y$, where $x$ is the nucleotide in the clade founder sequence.
For each such mutation, we also have $n_{xy}$, the **expected** number of mutations from the clade founder nucleotide $x$ to $y$.
These $n_{xy}$ values are important because they give some estimate of our "confidence" in the $\Delta f_{xy}$ values: if a mutation has high expected counts (large $n_{xy}$) then we can estimate the change in fitness caused by the mutation more accurately, and if $n_{xy}$ is small then the estimate will be much noisier.

However, we would like to aggregate the data across multiple clades to estimate nucleotide fitness values at a site under the assumption that these are constant across clades.
Now things get more complicated.
For instance, let's say at our site of interest, the clade founder nucleotide is $x$ in one clade and $z$ in another clade.
We then have up to 3 $\Delta f_{xy}$ and $n_{xy}$ values for the first clade (where $y$ ranges over the 3 nucleotides that aren't $x$), and another set of up to 3 $\Delta f_{zy}$ and $n_{zy}$ values for the second clade (where $y$ ranges over the 3 nucleotides that aren't $z$).

From these sets of mutation fitness changes, we'd like to estimate the fitness $f_x$ of each nucleotide $x$, where the $f_x$ values satisfy $\Delta f_{xy} = f_y - f_x$ (in other words, a higher $f_x$ means higher fitness of that nucleotide).
When there are multiple clades with different founder nucleotides at the site, there is no guarantee that we can find $f_x$ values that exactly satisfy the above equation, since there can be more $\Delta f_{xy}$ values than free $f_x$ values, and the $\Delta f_{xy}$ values are noisy (and in some cases may even reflect real shifts among clades due to epistasis).
Nonetheless, we can try to find the $f_x$ values that come closest to satisfying the above equation.
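To make the over-determination concrete, here is a small sketch on a hypothetical site (all numbers are made up for illustration) where the founder nucleotide is A in one clade and G in another. If $\Delta f_{xy} = f_y - f_x$ held exactly, going A→G→A would give zero net change and A→G→C would match A→C directly; the nonzero residuals show how noisy measurements break these identities.

```python
# Hypothetical toy site: founder is A in clade 1 and G in clade 2.
# Keys are (x, y) = (founder nt, mutant nt); values are (delta_f_xy, n_xy).
toy = {
    ("A", "C"): (-1.0, 10.0),
    ("A", "G"): (0.5, 20.0),
    ("A", "T"): (-2.0, 5.0),
    ("G", "A"): (-0.3, 8.0),
    ("G", "C"): (-1.2, 4.0),
    ("G", "T"): (-2.5, 2.0),
}
n_free_params = 4 - 1  # four f_x values, one of which is fixed at zero
print(len(toy), "measured effects vs", n_free_params, "free parameters")

# If Delta f_xy = f_y - f_x held exactly, both residuals below would be zero:
# Delta f_AG + Delta f_GA = 0, and Delta f_AG + Delta f_GC = Delta f_AC.
loop = toy[("A", "G")][0] + toy[("G", "A")][0]
path = toy[("A", "G")][0] + toy[("G", "C")][0] - toy[("A", "C")][0]
print(f"A->G->A residual: {loop:.2f}; A->G->C minus A->C residual: {path:.2f}")
```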

First, we fix one nucleotide's fitness at zero: we only measure differences among the $f_x$ values, not their absolute values, so the 4 $f_x$ values (one per nucleotide) contain only 3 free parameters.
If there were just one clade (or all clades shared the same founder nucleotide), we would set the founder's $f_x = 0$ and then simply have $f_y = \Delta f_{xy}$ for every other nucleotide $y$.
However, when there are multiple clades with different founder nucleotides, there is no longer a well-defined "wildtype".
So instead we set $f_x = 0$ for the nucleotide with the largest total expected count of mutations from or to it; in other words, we find the $x$ that maximizes $\sum_{y \ne x} \left(n_{xy} + n_{yx}\right)$ (the quantity $N_x$ defined below).
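A sketch of choosing the zero-fitness reference nucleotide from hypothetical expected counts. It ranks nucleotides by the total expected count of mutations from or to each ($N_x$), which is what `get_nt_fitness` below sums into its `count` column; the counts themselves are made up.

```python
from collections import defaultdict

# Hypothetical expected counts n_xy at one site, keyed by (parent x, mutant y).
n = {
    ("A", "C"): 10.0, ("A", "G"): 20.0, ("A", "T"): 5.0,
    ("G", "A"): 8.0, ("G", "C"): 4.0, ("G", "T"): 2.0,
}

# N_x = sum over y of (n_xy + n_yx): total expected count of mutations from or to x.
N = defaultdict(float)
for (x, y), count in n.items():
    N[x] += count  # mutation from x
    N[y] += count  # mutation to y

reference_nt = max(N, key=N.get)  # this nucleotide gets f_x = 0
print(dict(N), "reference:", reference_nt)
```

Here A wins with $N_A = 43$ (35 expected mutations from A plus 8 back to A). Ties would be broken by dictionary order in this sketch, whereas the notebook takes the last row after sorting by count.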

Next, we choose the $f_x$ values that most closely match the measured mutation effects, weighting more strongly mutation effects with higher expected counts (since these should be more accurate).
Specifically, we define a loss function as
$$
L = \sum_x \sum_{y \ne x} n_{xy} \left(\Delta f_{xy} - \left[f_y - f_x\right]\right)^2
$$
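where the inner sum excludes $y = x$ (a nucleotide cannot mutate to itself) and mutations without a measured effect contribute nothing (equivalently, $n_{xy} = 0$).
We then use numerical optimization (Powell's method via `scipy.optimize.minimize` in the code below) to find the $f_x$ values that minimize the loss $L$, holding the reference nucleotide fixed at zero.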
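A minimal sketch of this minimization on a hypothetical site, mirroring the notebook's use of `scipy.optimize.minimize` with `method="Powell"`. The data values and the choice of A as the reference nucleotide are illustrative assumptions.

```python
import numpy as np
import scipy.optimize

# Hypothetical toy site: (parent x, mutant y) -> (delta_f_xy, n_xy).
data = {
    ("A", "C"): (-1.0, 10.0), ("A", "G"): (0.5, 20.0), ("A", "T"): (-2.0, 5.0),
    ("G", "A"): (-0.3, 8.0), ("G", "C"): (-1.2, 4.0), ("G", "T"): (-2.5, 2.0),
}
reference_nt = "A"  # assumed highest-count nucleotide, fixed at f = 0
free_nts = ["C", "G", "T"]


def loss(f_vec):
    """Weighted squared error L for the free fitness values."""
    f = dict(zip(free_nts, f_vec))
    f[reference_nt] = 0.0
    return sum(
        n_xy * (delta_f - (f[y] - f[x])) ** 2
        for (x, y), (delta_f, n_xy) in data.items()
    )


res = scipy.optimize.minimize(loss, np.zeros(len(free_nts)), method="Powell")
fitness = dict(zip(free_nts, res.x), **{reference_nt: 0.0})
print(res.success, {nt: round(val, 3) for nt, val in fitness.items()})
```

For this toy site $f_G$ comes out near 0.42: between the A→G estimate (0.5) and the negated G→A estimate (0.3), and closer to the A→G estimate, which has the larger expected count.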

Finally, we would like to report, for each nucleotide, a count analogous to $n_{xy}$ that gives some sense of how accurately we have estimated its fitness $f_x$.
To do that, we use $N_x = \sum_{y \ne x} \left(n_{xy} + n_{yx} \right)$, the total expected number of mutations either from or to nucleotide $x$, as the "count" for the nucleotide.
Nucleotides with larger values of $N_x$ should have more accurate estimates of $f_x$.
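For a site with a single founder nucleotide, the shortcut can be checked directly on hypothetical values: setting the founder to zero and $f_y = \Delta f_{xy}$ makes every term of $L$ vanish, and $N_x$ reduces to $\sum_y n_{xy}$ for the founder and to $n_{xy}$ for each mutant $y$ (the same rule stated in the comment inside `get_nt_fitness`).

```python
# Hypothetical single-clade site with founder A: (x, y) -> (delta_f_xy, n_xy).
data = {("A", "C"): (-1.0, 10.0), ("A", "G"): (0.5, 20.0), ("A", "T"): (-2.0, 5.0)}

founder = "A"
fitness = {founder: 0.0}  # founder fixed at zero
N = {founder: 0.0}
for (x, y), (delta_f, n_xy) in data.items():
    fitness[y] = delta_f  # f_y = Delta f_xy because f_x = 0
    N[y] = n_xy  # the only mutation *to* y comes from the founder
    N[founder] += n_xy  # every mutation is *from* the founder

# Every residual Delta f_xy - (f_y - f_x) is zero, so the loss L is exactly zero.
loss = sum(
    n_xy * (delta_f - (fitness[y] - fitness[x])) ** 2
    for (x, y), (delta_f, n_xy) in data.items()
)
print(fitness, N, loss)
```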

## Implementation of calculation

Get variables from `snakemake`:

In [None]:
if "snakemake" in locals() or "snakemake" in globals():
    # from snakemake 
    ntmut_fitness_csv = snakemake.input.ntmut_fitness
    nt_fitness_csv = snakemake.output.nt_fitness
else:
    # manually defined for debugging outside snakemake pipeline
    ntmut_fitness_csv = "../results/nt_fitness/ntmut_fitness_all.csv"
    nt_fitness_csv = "../results/nt_fitness/nt_fitness.csv"

Import Python modules:

In [None]:
import numpy

import pandas as pd

import scipy.optimize

Read the nucleotide mutation fitness values and extract the parent and mutant nucleotide from each mutation:

In [None]:
ntmut_fitness = pd.read_csv(ntmut_fitness_csv).assign(
    parent_nt=lambda x: x["nt_mutation"].str[0],
    mutant_nt=lambda x: x["nt_mutation"].str[-1],
)

ntmut_fitness
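The parsing above assumes each `nt_mutation` string has the form parent nucleotide, site number, mutant nucleotide (e.g. `A123G`); this format is an assumption inferred from the `.str[0]` / `.str[-1]` slicing. A small sketch on made-up strings (the `site_from_string` column is hypothetical and only shows what sits between the two nucleotides; the notebook itself uses the separate `nt_site` column):

```python
import pandas as pd

# Hypothetical mutation strings in the assumed <parent><site><mutant> format.
df = pd.DataFrame({"nt_mutation": ["A123G", "C4567T"]}).assign(
    parent_nt=lambda x: x["nt_mutation"].str[0],  # first character
    mutant_nt=lambda x: x["nt_mutation"].str[-1],  # last character
    site_from_string=lambda x: x["nt_mutation"].str[1:-1].astype(int),  # illustration only
)
print(df)
```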

In [None]:
def get_nt_fitness(site_df):
    """Estimates fitness of nucleotides at site in a gene."""
    assert site_df["gene"].nunique() == 1
    gene = site_df["gene"].unique()[0]
    
    assert site_df["nt_site"].nunique() == 1
    site = site_df["nt_site"].unique()[0]
    
    assert len(site_df) == len(site_df.groupby(["parent_nt", "mutant_nt"])), site_df
    
    # if there is just one clade founder x, then f_y = Delta f_xy,
    # and n_y = n_xy for y != x and n_x = sum_y n_xy.
    if site_df["parent_nt"].nunique() == 1:
        return pd.concat(
            [
                (
                    site_df
                    .rename(
                        columns={
                            "mutant_nt": "nt",
                            "delta_fitness": "fitness",
                        }
                    )
                    [["nt_site", "nt", "fitness", "expected_count"]]
                ),
                pd.DataFrame(
                    {
                        "nt_site": [site],
                        "nt": site_df["parent_nt"].unique(),
                        "fitness": [0.0],
                        "expected_count": [site_df["expected_count"].sum()]
                    }
                )
            ],
        ).assign(nt_differs_among_clade_founders=False)
    
    # If we get here, there are multiple clade founders and we need to solve for f_x.
    # The code below is highly inefficient in terms of speed, but is fast enough
    # for current purposes.
    
    # first get counts of each nucleotide and the highest count one for which
    # we set f_x to zero
    count_df = (
        site_df
        .rename(columns={"parent_nt": "nt"})
        .groupby("nt", as_index=False)
        .aggregate(count_1=pd.NamedAgg("expected_count", "sum"))
        .merge(
            site_df
            .rename(columns={"mutant_nt": "nt"})
            .groupby("nt", as_index=False)
            .aggregate(count_2=pd.NamedAgg("expected_count", "sum")),
            how="outer",
            on="nt",
        )
        .fillna(0)
        .assign(count=lambda x: x["count_1"] + x["count_2"])
        .sort_values("count")
    )
    counts = count_df.set_index("nt")["count"].to_dict()
    highest_count_nt = count_df["nt"].tolist()[-1]
    nts = [nt for nt in count_df["nt"] if nt != highest_count_nt]  # all but highest count
    
    parent_nts = site_df["parent_nt"].unique()
    mutant_nts = site_df["mutant_nt"].unique()
    # keyed by (parent_nt, mutant_nt)
    site_dict = (
        site_df
        .set_index(["parent_nt", "mutant_nt"])
        [["expected_count", "delta_fitness"]]
        .to_dict(orient="index")
    )
    
    def loss(f_vec):
        f_nt = dict(zip(nts, f_vec))
        f_nt[highest_count_nt] = 0
        loss_val = 0.0
        for parent_nt in parent_nts:
            f_parent = f_nt[parent_nt]
            for mutant_nt in mutant_nts:
                try:
                    delta_f = site_dict[(parent_nt, mutant_nt)]["delta_fitness"]
                    n = site_dict[(parent_nt, mutant_nt)]["expected_count"]
                except KeyError:
                    continue
                f_mutant = f_nt[mutant_nt]
                loss_val += n * (delta_f - (f_mutant - f_parent))**2
        return loss_val
    
    opt_res = scipy.optimize.minimize(loss, numpy.zeros(len(nts)), method="Powell")
    assert opt_res.success, f"{opt_res}\n\n{site_df}"
    
    fs = dict(zip(nts, opt_res.x))
    fs[highest_count_nt] = 0

    return pd.DataFrame(
        {
            "nt_site": site,
            "nt": list(fs),
            "fitness": list(fs.values()),
            "expected_count": [counts[nt] for nt in fs],
            "nt_differs_among_clade_founders": True,
        }
    )


site_dfs = []
for i, (site, site_df) in enumerate(ntmut_fitness.groupby("nt_site")):
    site_dfs.append(get_nt_fitness(site_df))
    if i % 500 == 0:
        print(f"Completed optimization {i + 1}")
print(f"Completed all {i + 1} optimizations.")

fitness_df = (
    pd.concat(site_dfs)
    .sort_values(["nt_site", "nt"])
    .reset_index(drop=True)
)

assert len(fitness_df) == len(fitness_df.groupby(["nt_site", "nt"]))

fitness_df
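Because $L$ is a quadratic function of the free $f_x$ values, the Powell search above could in principle be replaced by a direct weighted linear least-squares solve, which is exact and much faster. The sketch below is an alternative, not what this notebook does: `weighted_lstsq_fitness` is a hypothetical helper that takes a dict of `(parent_nt, mutant_nt) -> (delta_fitness, expected_count)` and builds one row per mutation with $+\sqrt{n_{xy}}$ in the column for $y$ and $-\sqrt{n_{xy}}$ in the column for $x$, so that the squared residuals sum to $L$.

```python
import numpy as np


def weighted_lstsq_fitness(data, reference_nt):
    """Minimize L exactly by weighted linear least squares (hypothetical helper).

    data maps (parent x, mutant y) -> (delta_f_xy, n_xy); reference_nt is fixed at 0.
    """
    nts = sorted({nt for pair in data for nt in pair})
    free = [nt for nt in nts if nt != reference_nt]
    col = {nt: i for i, nt in enumerate(free)}
    A = np.zeros((len(data), len(free)))
    b = np.zeros(len(data))
    for row, ((x, y), (delta_f, n_xy)) in enumerate(data.items()):
        w = np.sqrt(n_xy)  # sqrt weights so each squared residual is weighted by n_xy
        if y in col:
            A[row, col[y]] += w
        if x in col:
            A[row, col[x]] -= w
        b[row] = w * delta_f
    solution, *_ = np.linalg.lstsq(A, b, rcond=None)
    fitness = dict(zip(free, solution))
    fitness[reference_nt] = 0.0
    return fitness


# Hypothetical toy site with founders A and G.
data = {
    ("A", "C"): (-1.0, 10.0), ("A", "G"): (0.5, 20.0), ("A", "T"): (-2.0, 5.0),
    ("G", "A"): (-0.3, 8.0), ("G", "C"): (-1.2, 4.0), ("G", "T"): (-2.5, 2.0),
}
print(weighted_lstsq_fitness(data, "A"))
```

If the measured mutations do not connect every nucleotide to the reference (for example, if the only mutations are A→C in one clade and G→T in another), the system is rank-deficient and `numpy.linalg.lstsq` returns the minimum-norm solution, so those $f_x$ values would not be meaningful.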

Count how many sites have a nucleotide that differs among clade founders:

In [None]:
(
    fitness_df
    .groupby("nt_differs_among_clade_founders", as_index=False)
    .aggregate(n_sites=pd.NamedAgg("nt_site", "nunique"))
)

Now we compare the fitness estimates to the mutation delta fitness values.
First do this for all sites where every clade founder has the same nucleotide.
For these sites each mutant nucleotide's fitness is just its $\Delta f_{xy}$, so the correlation should be exactly one:

In [None]:
one_founder_corrs = (
    fitness_df
    .query("not nt_differs_among_clade_founders")
    [["nt_site", "nt", "fitness"]]
    .merge(
        ntmut_fitness
        [["mutant_nt", "nt_site", "delta_fitness"]]
        .rename(columns={"mutant_nt": "nt"}),
    )
    [["fitness", "delta_fitness"]]
    .corr()
)

assert numpy.allclose(one_founder_corrs.values, 1)

one_founder_corrs

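Now get the correlations for sites where the nucleotide differs among clade founders.
To do this, we adjust each $\Delta f_{xy}$ by adding the fitted fitness $f_x$ of its clade founder nucleotide, since $\Delta f_{xy} + f_x$ should approximately equal $f_y$.
We expect these correlations to be high, but not necessarily one.
They should be higher for mutations from the most abundant clade founder nucleotide (the one fixed at $f_x = 0$), since those mutations tend to have larger expected counts and so are weighted more heavily in the fitness estimates: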
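A small sketch of this adjustment on hypothetical numbers: rounded fitted fitness values for a toy site, plus measured $\Delta f_{xy}$ from founders A and G. The merges follow the same pattern as the cell below; after adding the founder's fitness, each adjusted value lands close to, but not exactly on, the fitted $f_y$.

```python
import pandas as pd

# Hypothetical fitted fitness values at one site (reference nt A fixed at 0).
fitness = pd.DataFrame({"nt": ["A", "C", "G", "T"], "fitness": [0.0, -0.94, 0.42, -2.02]})
# Hypothetical measured effects from two clades with founders A and G.
muts = pd.DataFrame({
    "parent_nt": ["A", "A", "G", "G"],
    "nt": ["C", "G", "C", "T"],
    "delta_fitness": [-1.0, 0.5, -1.2, -2.5],
})

adjusted = (
    muts
    # attach the fitted fitness of each mutation's clade founder nucleotide
    .merge(fitness.rename(columns={"nt": "parent_nt", "fitness": "clade_founder_fitness"}))
    # attach the fitted fitness of the mutant nucleotide
    .merge(fitness, on="nt")
    # Delta f_xy + f_x should approximate f_y
    .assign(adjusted_delta_fitness=lambda x: x["delta_fitness"] + x["clade_founder_fitness"])
)
print(adjusted[["parent_nt", "nt", "fitness", "adjusted_delta_fitness"]])
```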

In [None]:
multi_founder_corrs = (
    fitness_df
    .query("nt_differs_among_clade_founders")
    .drop(columns="nt_differs_among_clade_founders")
    .merge(
        ntmut_fitness
        [["parent_nt", "mutant_nt", "nt_site", "delta_fitness"]]
        .rename(columns={"mutant_nt": "nt"})
    )
    .merge(
        fitness_df[["nt_site", "nt", "fitness"]]
        .rename(columns={"nt": "parent_nt", "fitness": "clade_founder_fitness"})
    )
    .assign(
        most_abundant_clade_founder=lambda x: x["clade_founder_fitness"] == 0,
        adjusted_delta_fitness=lambda x: x["delta_fitness"] + x["clade_founder_fitness"],
    )
    .groupby("most_abundant_clade_founder")
    [["fitness", "adjusted_delta_fitness"]]
    .corr()
)

assert (multi_founder_corrs.values >= 0.85).all()

multi_founder_corrs.round(3)

Write the values to a file:

In [None]:
print(f"Writing to {nt_fitness_csv}")

fitness_df.to_csv(nt_fitness_csv, index=False, float_format="%.5g")