# Analyze fitness effects of mutations fixed in each clade

This notebook analyzes the fitness effects of amino-acid mutations that have fixed in each clade. Each mutation is examined in two polarities: the forward polarity, in clades where the mutation has not yet occurred, and the reverse polarity, as a reversion in clades where the mutation has fixed. Reverse-polarity estimates are only possible when the reversion can be reached by a different nucleotide change, because direct nucleotide reversions are excluded when calculating the fitnesses.

Import Python modules:

In [None]:
import os

import altair as alt

import Bio.Seq

import numpy

import pandas as pd

import yaml

# turn off altair's maximum-rows check so the full data frames below can be embedded
# in the charts; assigning to `_` keeps the call's return value from being displayed
_ = alt.data_transformers.disable_max_rows()

Now get variables from `snakemake`:

In [None]:
if "snakemake" in globals() or "snakemake" in locals():
    # running within the pipeline: take input/output paths and parameters from `snakemake`
    aa_fitness_csv = snakemake.input.aafitness
    aamut_by_clade_csv = snakemake.input.aamut_by_clade
    clade_founder_nts_csv = snakemake.input.clade_founder_nts_csv
    fixed_muts_chartfile = snakemake.output.fixed_muts_chart
    fixed_muts_histfile = snakemake.output.fixed_muts_hist
    min_expected_count = snakemake.params.min_expected_count
    ref = snakemake.params.ref
    orf1ab_to_nsps = snakemake.params.orf1ab_to_nsps
else:
    # running interactively for debugging: hard-coded paths relative to the notebook
    # directory, with parameters read from the pipeline's YAML config
    aa_fitness_csv = "../results/aa_fitness/aa_fitness.csv"
    aamut_by_clade_csv = "../results/aa_fitness/aamut_fitness_by_clade.csv"
    clade_founder_nts_csv = "../results/clade_founder_nts/clade_founder_nts.csv"
    fixed_muts_chartfile = "../results/clade_fixed_muts/clade_fixed_muts.html"
    fixed_muts_histfile = "../results/clade_fixed_muts/clade_fixed_muts_hist.html"

    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    min_expected_count = config["min_expected_count"]
    ref = config["clade_fixed_muts_ref"]
    orf1ab_to_nsps = config["orf1ab_to_nsps"]

Read the input data:

In [None]:
# both fitness tables name their site column `aa_site`; rename it to `site` so it
# matches the column name used throughout the rest of this notebook
aa_fitness, aamut_by_clade = (
    pd.read_csv(csv_path).rename(columns={"aa_site": "site"})
    for csv_path in (aa_fitness_csv, aamut_by_clade_csv)
)
clade_founder_nts = pd.read_csv(clade_founder_nts_csv)

We only consider clades that had sufficient counts to estimate the effects of mutations:

In [None]:
# clades present in the per-clade fitness estimates, in order of first appearance
clades = aamut_by_clade["clade"].drop_duplicates().tolist()
print(f"Analyzing {clades=}")

Build a table converting ORF1ab site numbering to nsp site numbering:

In [None]:
# Table mapping each ORF1ab codon site to its nsp and nsp-relative site number.
# Each value of `orf1ab_to_nsps` is an inclusive (start, end) range of ORF1ab sites,
# and site `start` becomes nsp site 1. Building all records at once also yields an
# (empty) table when no nsps are configured, rather than failing in `pd.concat([])`.
orf1ab_to_nsps_df = pd.DataFrame(
    [
        (orf1ab_site, orf1ab_site - start + 1, nsp)
        for nsp, (start, end) in orf1ab_to_nsps.items()
        for orf1ab_site in range(start, end + 1)
    ],
    columns=["ORF1ab_site", "nsp_site", "nsp"],
)

# the ORF1ab -> nsp merges below use validate="many_to_one", so report overlapping
# nsp ranges here with a clear message instead of as a merge error later
assert orf1ab_to_nsps_df["ORF1ab_site"].is_unique, "overlapping nsp ranges in orf1ab_to_nsps"

First we get the amino acids at each position in each gene in each clade founder and the reference clade.
We ignore ORF1a and just look at ORF1ab:

In [None]:
# translation of each of the 64 unambiguous codons; codons containing any other
# character are absent, so they map to NaN below
codon_to_aa = {
    codon: str(Bio.Seq.Seq(codon).translate())
    for codon in (nt1 + nt2 + nt3 for nt1 in "ACGT" for nt2 in "ACGT" for nt3 in "ACGT")
}

assert (clade_founder_nts["clade"] == ref).any()

# amino acid at every coding site of every clade founder; the gene, codon, and
# codon_site fields can hold several ";"-separated entries, which are split into
# separate rows (ORF1a is skipped because ORF1ab is analyzed instead)
clade_founder_aas = (
    clade_founder_nts
    .loc[lambda df: df["gene"] != "noncoding", ["clade", "gene", "codon", "codon_site"]]
    .drop_duplicates()
    .assign(
        gene=lambda df: df["gene"].str.split(";"),
        codon=lambda df: df["codon"].str.split(";"),
        site=lambda df: df["codon_site"].str.split(";"),
    )
    .explode(["gene", "codon", "site"])
    .loc[lambda df: df["gene"] != "ORF1a"]
    .assign(
        amino_acid=lambda df: df["codon"].map(codon_to_aa),
        site=lambda df: df["site"].astype(int),
    )
    .loc[:, ["clade", "gene", "site", "codon", "amino_acid"]]
    .drop_duplicates()
)

# codon and amino acid of the reference clade founder at each gene / site
ref_aas = (
    clade_founder_aas
    .loc[lambda df: df["clade"] == ref, ["gene", "site", "codon", "amino_acid"]]
    .rename(columns={"codon": "ref_codon", "amino_acid": "ref_amino_acid"})
)

# restrict to the analyzed clades and attach the reference identity at each site
clade_founder_aas = (
    clade_founder_aas
    .loc[lambda df: df["clade"].isin(clades)]
    .merge(ref_aas, on=["gene", "site"], validate="many_to_one")
)

Now get just the amino acid mutations in each clade founder, and summarize the number of such mutations:

In [None]:
# sites where a clade founder's amino acid differs from the reference founder's
clade_founder_muts = clade_founder_aas.loc[
    clade_founder_aas["amino_acid"] != clade_founder_aas["ref_amino_acid"]
]

# number of such amino-acid mutations per clade
clade_founder_muts.groupby("clade")["site"].count().to_frame("n_amino_acid_muts")

Get the estimated fitness effect of each mutation from the overall fitness estimates.
Also convert from ORF1ab to nsp numbering:

In [None]:
# Overall (all-clade) fitness effect of each clade-founder mutation: the overall
# fitness of the founder amino acid minus that of the reference amino acid at the
# site. `expected_count` is the smaller of the two amino acids' expected counts.
# Left merges keep mutations that lack an overall estimate (their effect is NaN).
fitness_effects_tidy = (
    clade_founder_muts
    .merge(
        aa_fitness[["gene", "site", "aa", "fitness", "expected_count"]].rename(
            columns={"aa": "amino_acid", "fitness": "amino_acid_fitness"},
        ),
        on=["site", "gene", "amino_acid"],
        how="left",
        validate="many_to_one",
    )
    .merge(
        aa_fitness[["gene", "site", "aa", "fitness", "expected_count"]].rename(
            columns={"aa": "ref_amino_acid", "fitness": "ref_amino_acid_fitness"},
        ),
        on=["site", "gene", "ref_amino_acid"],
        how="left",
        validate="many_to_one",
    )
    .assign(
        mutation=lambda x: x["ref_amino_acid"] + x["site"].astype(str) + x["amino_acid"],
        fitness_effect=lambda x: x["amino_acid_fitness"] - x["ref_amino_acid_fitness"],
        expected_count=lambda x: numpy.minimum(x["expected_count_x"], x["expected_count_y"]),
    )
    .drop(columns=["expected_count_x", "expected_count_y"])
)

# Replace ORF1ab naming with nsp naming (gene -> nsp, site -> nsp-relative site).
# NOTE(review): `mutation` was built above from the ORF1ab site number and is not
# rebuilt here, so for nsp genes it keeps ORF1ab numbering while `site` is
# nsp-relative. The per-clade table built later uses the same convention and the
# two are merged on `mutation`, so any relabeling must be done in both places.
fitness_effects_tidy = pd.concat(
    [
        fitness_effects_tidy.query("gene != 'ORF1ab'"),
        (
            fitness_effects_tidy
            .query("gene == 'ORF1ab'")
            .merge(
                orf1ab_to_nsps_df.rename(columns={"ORF1ab_site": "site"}),
                on="site",  # explicit key rather than relying on shared column names
                how="left",
                validate="many_to_one",
            )
            .drop(columns=["site", "gene"])
            .rename(columns={"nsp_site": "site", "nsp": "gene"})
        ),
    ],
    ignore_index=True,
)

# an ORF1ab site outside every configured nsp range would otherwise silently get a
# missing gene and site
assert fitness_effects_tidy["gene"].notnull().all(), "ORF1ab sites outside all nsp ranges"

fitness_effects_tidy

Now get this information in a wide-form data frame that just lists the clades with the mutation, and also those clades alongside their clade founder codons:

In [None]:
# one row per mutation, listing the clades whose founders carry it and, separately,
# each of those clades with its founder codon
fitness_effects = (
    fitness_effects_tidy
    .assign(clade_codon=lambda df: df["clade"] + " (" + df["codon"] + ")")
    .groupby(
        [
            "gene",
            "site",
            "mutation",
            "ref_codon",
            "ref_amino_acid",
            "fitness_effect",
            "expected_count",
        ],
        as_index=False,
        dropna=False,
    )
    .agg(
        clades_with_mutation=("clade", "; ".join),
        clade_codons_with_mutation=("clade_codon", "; ".join),
    )
)

# each gene / mutation pair must occupy exactly one row
assert fitness_effects.groupby(["gene", "mutation"]).ngroups == len(fitness_effects)

fitness_effects

Get amino-acid mutation fitness estimates by clade, polarized so they are coming from the reference amino-acid identity at the site, and keeping track whether each estimate is for a forward or reversion mutation in that clade.

First in tidy form:

In [None]:
# Per-clade fitness effects, polarized so each mutation reads reference amino acid ->
# non-reference amino acid. "forward" rows come from clades whose founder has the
# reference amino acid; "reverse" rows are reversions to the reference amino acid,
# with the sign of the effect flipped.
fitness_effects_by_clade_tidy = (
    aamut_by_clade
    .loc[lambda df: df["mutant_aa"] != df["clade_founder_aa"]]
    .merge(ref_aas, on=["gene", "site"], validate="many_to_one")
    .loc[
        lambda df: (df["clade_founder_aa"] == df["ref_amino_acid"])
        | (df["mutant_aa"] == df["ref_amino_acid"])
    ]
    .assign(
        mutation_polarity=lambda df: (df["clade_founder_aa"] == df["ref_amino_acid"]).map(
            {True: "forward", False: "reverse"}
        ),
        mutation=lambda df: numpy.where(
            df["mutation_polarity"] == "reverse",
            df["mutant_aa"] + df["site"].astype(str) + df["clade_founder_aa"],
            df["clade_founder_aa"] + df["site"].astype(str) + df["mutant_aa"],
        ),
        fitness_effect=lambda df: numpy.where(
            df["mutation_polarity"] == "reverse",
            -df["delta_fitness"],
            df["delta_fitness"],
        ),
    )
    .loc[
        :,
        [
            "clade",
            "gene",
            "site",
            "mutation",
            "fitness_effect",
            "expected_count",
            "mutation_polarity",
        ],
    ]
)


# Replace ORF1ab naming with nsp naming (gene -> nsp, site -> nsp-relative site).
# NOTE(review): `mutation` keeps ORF1ab numbering for nsp genes while `site` becomes
# nsp-relative; the overall table built earlier follows the same convention and the
# two are merged on `mutation` below, so any relabeling must be done in both places.
fitness_effects_by_clade_tidy = pd.concat(
    [
        fitness_effects_by_clade_tidy.query("gene != 'ORF1ab'"),
        (
            fitness_effects_by_clade_tidy
            .query("gene == 'ORF1ab'")
            .merge(
                orf1ab_to_nsps_df.rename(columns={"ORF1ab_site": "site"}),
                on="site",
                how="left",
                validate="many_to_one",
            )
            .drop(columns=["site", "gene"])
            .rename(columns={"nsp_site": "site", "nsp": "gene"})
        ),
    ],
    ignore_index=True,
)

# no missing values, which also means every ORF1ab site mapped to an nsp
assert fitness_effects_by_clade_tidy.notnull().all().all()

fitness_effects_by_clade_tidy

Now make a wide version with just the mutations that occurred in the clade founders, grouping together the forward and reverse clade estimates.
We calculate the **weighted average** (weighted by clade expected count) for each mutation polarity:

In [None]:
# For each mutation that occurs in some clade founder, combine the per-clade
# estimates of each polarity into an expected-count-weighted average, and keep a
# readable "clade (effect)" list of the individual estimates.
fitness_effects_by_clade = (
    fitness_effects_by_clade_tidy
    # inner merge keeps only mutations present in `fitness_effects`
    .merge(fitness_effects[["gene", "site", "mutation"]], on=["gene", "site", "mutation"])
    .assign(
        clade_effect=lambda df: (
            df["clade"] + " (" + df["fitness_effect"].map("{:.2f}".format) + ")"
        ),
        weighted_effect=lambda df: df["expected_count"] * df["fitness_effect"],
    )
    .groupby(["gene", "mutation", "mutation_polarity"], as_index=False)
    .agg(
        expected_count=("expected_count", "sum"),
        weighted_effect=("weighted_effect", "sum"),
        clade_effect=("clade_effect", "; ".join),
    )
    .assign(fitness_effect=lambda df: df["weighted_effect"] / df["expected_count"])
    .drop(columns="weighted_effect")
)

fitness_effects_by_clade

Finally, aggregate information about all substitutions in the clades of interest (relative to the reference) with both the forward and reverse polarity information as well as the overall estimated effect:

In [None]:
# All substitutions relative to the reference: the overall estimate (polarity "all")
# plus the weighted forward / reverse per-clade estimates, each row annotated with
# the founder codon details of the mutation.
mut_effects = (
    pd.concat(
        [
            fitness_effects.drop(columns=["site", "ref_amino_acid"]).assign(
                mutation_polarity="all"
            ),
            fitness_effects_by_clade.merge(
                fitness_effects[
                    [
                        "gene",
                        "mutation",
                        "ref_codon",
                        "clades_with_mutation",
                        "clade_codons_with_mutation",
                    ]
                ],
                on=["gene", "mutation"],
                how="left",
                validate="many_to_one",
            ),
        ]
    )
    .sort_values(["gene", "mutation", "mutation_polarity"], ignore_index=True)
    .assign(gene_mutation=lambda df: df["mutation"] + " in " + df["gene"])
)

mut_effects

Now plot the mutation effects, showing the estimated effects overall and from forward and reverse mutations.
Note that there is a slider for the minimum expected counts:

In [None]:
# The x-axis is ordered inside the chart itself (see `transform_joinaggregate` in
# `base`), by each mutation's largest fitness effect among the displayed polarities.

# legend-bound selection toggling which mutation polarities are shown; all start selected
polarity_selection = alt.selection_multi(
    fields=["mutation_polarity"],
    bind="legend",
    init=[{"mutation_polarity": polarity} for polarity in mut_effects["mutation_polarity"].unique()],
)

# legend-bound selection that highlights the points of the selected gene(s)
gene_selection = alt.selection_multi(
    fields=["gene"],
    bind="legend",
)

# slider for the minimum expected count; its upper bound is kept at or above the
# initial cutoff so the slider can always represent `min_expected_count`
expected_count_selection = alt.selection_single(
    bind=alt.binding_range(
        min=0,
        max=max(
            min_expected_count,
            min(4 * min_expected_count, mut_effects["expected_count"].quantile(0.8)),
        ),
        step=1,
        name="minimum expected count",
    ),
    fields=["cutoff"],
    init={"cutoff": min_expected_count},
)

base = (
    alt.Chart(
        mut_effects.assign(
            is_spike=lambda x: (x["gene"] == "S").map(
                {True: "spike mutations", False: "non-spike mutations"}
            )
        )
    )
    # small tolerance so a cutoff equal to an expected count keeps that point
    .transform_filter(
        alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
    )
    .transform_filter(polarity_selection)
    .transform_joinaggregate(
        sort="max(fitness_effect)",
        groupby=["gene_mutation"],
    )
    .encode(
        x=alt.X(
            "gene_mutation",
            title=None,
            sort=alt.SortField("sort", order="descending"),
        ),
    )
)

points_chart = (
    base
    .encode(
        y=alt.Y(
            "fitness_effect",
            title="fitness effect",
            scale=alt.Scale(
                nice=False,
                domain=(mut_effects["fitness_effect"].min(), mut_effects["fitness_effect"].max()),
            ),
        ),
        color=alt.Color(
            "mutation_polarity",
            scale=alt.Scale(domain=mut_effects["mutation_polarity"].unique()),
        ),
        shape=alt.Shape(
            "mutation_polarity",
            scale=alt.Scale(domain=mut_effects["mutation_polarity"].unique()),
            legend=alt.Legend(orient="bottom"),
        ),
        # constant stroke range with strokeWidth=0: the encoding only provides a
        # clickable gene legend for `gene_selection`
        stroke=alt.Stroke(
            "gene",
            scale=alt.Scale(domain=mut_effects["gene"].unique(), range=[1] * mut_effects["gene"].nunique()),
            legend=alt.Legend(orient="bottom", columns=7),
        ),
        opacity=alt.condition(gene_selection, alt.value(0.75), alt.value(0.25)),
        size=alt.condition(gene_selection, alt.value(50), alt.value(20)),
        tooltip=mut_effects.columns.tolist(),
    )
    .mark_point(filled=True, strokeWidth=0)
    .properties(width=alt.Step(12), height=150)
    .add_selection(expected_count_selection)
    .add_selection(polarity_selection)
    .add_selection(gene_selection)
)

# dashed reference line at zero fitness effect
line = (
    base
    .transform_calculate(zero="0")
    .encode(y="zero:Q")
    .mark_line(color="gray", strokeDash=[3, 3])
)

mut_effects_chart = (
    (points_chart + line)
    .facet(
        alt.Facet(
            "is_spike:N",
            title=None,
            header=alt.Header(labelFontStyle="bold", labelFontSize=13, labelPadding=0),
        ),
        columns=1,
        spacing=5,
    )
    .configure_axis(grid=False)
    .configure_legend(padding=12)
    .resolve_scale(x="independent")
    .configure_view(strokeWidth=0)
)

# `or "."` avoids os.makedirs("") failing when the output path has no directory part
os.makedirs(os.path.dirname(fixed_muts_chartfile) or ".", exist_ok=True)
mut_effects_chart.save(fixed_muts_chartfile)

mut_effects_chart

Now plot histogram of fitness effects of fixed mutations and all mutations from the reference, using the overall amino-acid fitness estimates.

First, get the effects of all amino-acid mutations from the reference (for which there are fitness estimates):

In [None]:
# Overall fitness effects of every amino-acid mutation away from the reference clade
# founder's amino acid (`ref_aas`), restricted to amino acids that have overall fitness
# estimates (inner merges). `expected_count` is the smaller of the two amino acids'
# expected counts, as for the fixed mutations.
all_ref_fitness_effects = (
    ref_aas
    # pair every reference site with every amino acid that has a fitness estimate
    .merge(pd.DataFrame({"amino_acid": aa_fitness["aa"].unique()}), how="cross")
    .merge(
        aa_fitness[["gene", "site", "aa", "fitness", "expected_count"]].rename(
            columns={"aa": "amino_acid", "fitness": "amino_acid_fitness"},
        ),
        on=["site", "gene", "amino_acid"],
        how="inner",
        validate="many_to_one",
    )
    .merge(
        aa_fitness[["gene", "site", "aa", "fitness", "expected_count"]].rename(
            columns={"aa": "ref_amino_acid", "fitness": "ref_amino_acid_fitness"},
        ),
        on=["site", "gene", "ref_amino_acid"],
        how="inner",
        validate="many_to_one",
    )
    .query("ref_amino_acid != amino_acid")
    .assign(
        fitness_effect=lambda x: x["amino_acid_fitness"] - x["ref_amino_acid_fitness"],
        expected_count=lambda x: numpy.minimum(x["expected_count_x"], x["expected_count_y"]),
    )
    [["fitness_effect", "expected_count"]]
)

all_ref_fitness_effects

In [None]:
# Histogram data. The fixed-mutation panel uses only each mutation's overall
# ("all" polarity) estimate. `mut_effects` also holds the per-clade forward and
# reverse estimates, which would otherwise count each fixed mutation up to three
# times and mix in non-overall estimates.
hist_df = pd.concat(
    [
        all_ref_fitness_effects.assign(mut_type="all mutations"),
        (
            mut_effects
            .query("mutation_polarity == 'all'")
            [["fitness_effect", "expected_count"]]
            .assign(mut_type="mutations fixed in a clade")
        ),
    ],
    ignore_index=True,
)

hist_chart = (
    alt.Chart(hist_df)
    .encode(
        x=alt.X(
            "fitness_effect",
            bin=alt.Bin(step=0.5),
            axis=alt.Axis(),
            title="fitness effect of mutation",
        ),
        y=alt.Y("count()", title="number of mutations"),
        facet=alt.Facet(
            "mut_type",
            title=None,
            columns=2,
            header=alt.Header(labelFontStyle="bold", labelFontSize=12, labelPadding=0),
            spacing=5,
        ),
    )
    .mark_bar()
    .add_selection(expected_count_selection)
    # same expected-count slider and tolerance as the mutation-effects chart
    .transform_filter(
        alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
    )
    .properties(width=250, height=115)
    .resolve_scale(y="independent")
    .configure_axis(grid=False)
)

# make sure the output directory exists even if it differs from the chart's
os.makedirs(os.path.dirname(fixed_muts_histfile) or ".", exist_ok=True)
hist_chart.save(fixed_muts_histfile)

hist_chart