# Analyze correlation of amino-acid fitnesses with dN/dS values
This notebook looks at how dN/dS values from FEL correlate with site summaries of the amino-acid fitness values and DMS effects, ignoring stop codons but retaining wildtype amino acids when calculating the site entropy summary statistic.

In [None]:
import altair as alt

import numpy

import pandas as pd

import scipy.stats.mstats

In [None]:
# All file paths and parameters are supplied by the `snakemake` object
min_expected_count = snakemake.params.min_expected_count
dnds_csv = snakemake.input.dnds
aa_fitness_csv = snakemake.input.aa_fitness
corr_html = snakemake.output.corr_html

# Attach each DMS dataset's table under a "data" key. The CSV for a dataset is
# the `snakemake.input` attribute that has the same name as the dataset.
dms_datasets = snakemake.params.dms_datasets
for dataset, dataset_info in dms_datasets.items():
    dataset_info["data"] = pd.read_csv(getattr(snakemake.input, dataset))

Get dN/dS values averaging the dN (`beta`) and dS (`alpha`) values over timeframes first, and clipping the dN/dS values at 0.05 and 20:

In [None]:
# Gene labels in the dN/dS table that are renamed; labels not listed are kept
gene_map = {
    "3C": "nsp5 (Mpro)",
    "RdRp": "nsp12 (RdRp)",
    "endornase": "nsp15",
    "exonuclease": "nsp14",
    "helicase": "nsp13",
    "leader": "nsp1",
    "methyltransferase": "nsp16",
}

dnds_raw = pd.read_csv(dnds_csv)

# Average alpha (dS) and beta (dN) per site, skipping rows where both are zero
dnds_site_means = (
    dnds_raw[(dnds_raw["alpha"] != 0) | (dnds_raw["beta"] != 0)]
    .groupby(["gene", "site"], as_index=False)
    .agg(alpha=("alpha", "mean"), beta=("beta", "mean"))
)

# Ratio of the averaged values, clipped to [0.05, 20] before taking the log
dnds_clipped = numpy.clip(
    dnds_site_means["beta"] / dnds_site_means["alpha"],
    a_min=0.05,
    a_max=20,
)

dnds = pd.DataFrame(
    {
        "gene": dnds_site_means["gene"].map(lambda g: gene_map.get(g, g)),
        "site": dnds_site_means["site"],
        "dN": dnds_site_means["beta"],
        "log dN/dS": numpy.log(dnds_clipped),
    }
)

dnds

Read amino-acid fitnesses, and for each site compute:
 - the mean fitness of mutations
 - entropy when mutations are assigned a probability weight of $e^{fitness}$

In [None]:
def _site_entropy(p):
    """Return the Shannon entropy (natural log) of the probabilities `p`."""
    return -(p * numpy.log(p)).sum()


aa_fitness_raw = pd.read_csv(aa_fitness_csv)

# Drop stop codons, the ORF1a / ORF1ab entries, and low-count estimates
keep_row = (
    (aa_fitness_raw["aa"] != "*")
    & ~aa_fitness_raw["gene"].isin(["ORF1a", "ORF1ab"])
    & (aa_fitness_raw["expected_count"] >= min_expected_count)
)

# Give each amino acid a weight of exp(fitness), then normalize the weights
# so they sum to one within each (gene, site)
aa_site_probs = (
    aa_fitness_raw[keep_row]
    .rename(columns={"aa_site": "site"})
    .assign(p_aa=lambda df: numpy.exp(df["fitness"]))
    .assign(
        p=lambda df: (
            df["p_aa"] / df.groupby(["gene", "site"])["p_aa"].transform("sum")
        )
    )
)

# Per-site summaries: mean fitness, entropy, and exp(entropy)
aa_fitness = (
    aa_site_probs
    .groupby(["gene", "site"], as_index=False)
    .agg(
        **{
            "mean fitness": ("fitness", "mean"),
            "fitness entropy": ("p", _site_entropy),
        }
    )
    .assign(fitness_n_effective=lambda df: numpy.exp(df["fitness entropy"]))
)

# every gene with dN/dS values must also have fitness estimates
assert set(dnds["gene"].unique()) <= set(aa_fitness["gene"].unique())

aa_fitness

Merge amino-acid fitness estimates aggregated at site level with dN/dS values:

In [None]:
# Keep only sites present in both tables; each (gene, site) must be unique
# on both sides or the merge raises.
fitness_dnds_merged = pd.merge(
    aa_fitness,
    dnds,
    on=["gene", "site"],
    validate="one_to_one",
)

fitness_dnds_df = fitness_dnds_merged[["gene", "site", "mean fitness", "log dN/dS"]]

Get the Pearson correlation of site-mean fitness versus log dN/dS for each gene:

In [None]:
# Per-gene correlation matrices in long form: after `reset_index` the row
# label of each matrix is in the column "level_1".
per_gene_corr_matrices = fitness_dnds_df.groupby("gene").corr().reset_index()

# From each gene's matrix keep the "mean fitness" row, whose "log dN/dS"
# entry is the correlation between the two quantities.
is_fitness_row = per_gene_corr_matrices["level_1"] == "mean fitness"

fitness_dnds_corrs = (
    per_gene_corr_matrices
    .loc[is_fitness_row, ["gene", "log dN/dS"]]
    .rename(columns={"log dN/dS": "Pearson correlation"})
)

fitness_dnds_corrs

Now plot the fitness versus dN/dS per gene correlations:

In [None]:
# Genes go on the x-axis from highest to lowest correlation
gene_order = (
    fitness_dnds_corrs
    .sort_values("Pearson correlation", ascending=False)["gene"]
    .tolist()
)

# Float columns are shown with two decimals in the tooltip, others as they are
corr_tooltips = []
for col in fitness_dnds_corrs.columns:
    if fitness_dnds_corrs[col].dtype == float:
        corr_tooltips.append(alt.Tooltip(col, format=".2f"))
    else:
        corr_tooltips.append(col)

fitness_dnds_corr_chart = (
    alt.Chart(fitness_dnds_corrs)
    .mark_point(filled=True, size=50, color="black")
    .encode(
        x=alt.X("gene", scale=alt.Scale(domain=gene_order), title=None),
        y=alt.Y("Pearson correlation"),
        tooltip=corr_tooltips,
    )
    .properties(
        title="Correlation of site-mean estimated fitness and log dN/dS for each gene",
        width=alt.Step(20),
        height=150,
    )
)

fitness_dnds_corr_chart

Now for each DMS dataset add in the DMS measured values aggregated at the site level:

In [None]:
# For each DMS dataset, compute the per-site mean DMS effect and merge it with
# the site-level fitness and dN/dS values. The per-dataset frames are collected
# in a list and concatenated once after the loop.
dms_site_dfs = []

for dms_dataset, dms_d in dms_datasets.items():

    # ignore stop codons, both as wildtype and as mutant
    dms_df = dms_d["data"].query("(wildtype != '*') and (mutant != '*')")

    # Use a dtype check that does not depend on the platform's default integer
    # width: `dtype != int` misclassifies an int64 column where `int` maps to
    # int32, and the `.str` accessor below then fails on the integer column.
    if not pd.api.types.is_integer_dtype(dms_df["site"]):
        # keep only rows whose site label is purely numeric, then make it int
        dms_df = (
            dms_df
            .query("site.str.isnumeric()")
            .assign(site=lambda x: x["site"].astype(int))
        )

    # Decide on the unfiltered data whether the dataset lacks wildtype rows.
    needs_wildtype = not (dms_df["wildtype"] == dms_df["mutant"]).any()

    # Apply the minimum-value filters BEFORE adding wildtype rows. The added
    # rows have NaN in every filter column, so filtering afterwards would
    # silently drop them again (`NaN >= x` is False).
    if "filter_cols" in dms_d:
        for filter_col, min_value in dms_d["filter_cols"].items():
            dms_df = dms_df[dms_df[filter_col] >= min_value]

    if needs_wildtype:
        # add wildtypes with effects of zero for every site that still has data
        wildtype_rows = (
            dms_df[["site", "wildtype"]]
            .drop_duplicates()
            .assign(mutant=lambda x: x["wildtype"], effect=0)
        )
        dms_df = pd.concat([dms_df, wildtype_rows])

    # calculate site summary DMS and merge with dN/dS and fitness
    dms_site_dfs.append(
        dms_df
        .groupby("site", as_index=False)
        .aggregate(mean_effect=pd.NamedAgg("effect", "mean"))
        .assign(gene=dms_d["gene"], study=dms_d["description"])
        .merge(fitness_dnds_df, on=["gene", "site"], validate="one_to_one")
    )

# Long format: one row per (study, site, estimation method), so the two
# estimation methods can be faceted against the same DMS measurement.
dms_site_df = (
    pd.concat(dms_site_dfs)
    .rename(
        columns={
            "mean fitness": "fitness (current study)",
            "mean_effect": "DMS measurement",
        }
    )
    .melt(
        id_vars=["study", "site", "DMS measurement"],
        value_vars=["fitness (current study)", "log dN/dS"],
        var_name="estimation method",
        value_name="estimated value",
    )
)

dms_site_df

Now plot site correlations of DMS with fitness estimates and dN/dS:

In [None]:
# Hovering over a point highlights that (study, site) in every facet that
# shows it; nothing is highlighted when the mouse is not over a point.
highlight = alt.selection_single(
    on="mouseover",
    fields=["study", "site"],
    empty="none",
)

# Shared x / y encodings for the points, regression line, and "r = ..." label
base = alt.Chart(dms_site_df).encode(
    x=alt.X("DMS measurement", scale=alt.Scale(nice=False), axis=alt.Axis(grid=False)),
    y=alt.Y("estimated value", axis=alt.Axis(title=None, grid=False), scale=alt.Scale(nice=False)),
)

dms_site_chart = (
    base
    .encode(
        color=alt.Color("study", legend=None),
        shape=alt.Shape("estimation method", legend=None),
        strokeWidth=alt.condition(highlight, alt.value(2), alt.value(0)),
        size=alt.condition(highlight, alt.value(50), alt.value(25)),
        opacity=alt.condition(highlight, alt.value(1), alt.value(0.5)),
        tooltip=[
            alt.Tooltip(c, format=".2f") if dms_site_df[c].dtype == float else c
            for c in dms_site_df.columns if c != "study"
        ],
    )
    .mark_point(filled=True, stroke="black")
)

# regression line and correlation coefficient: https://stackoverflow.com/a/60239699
line = (
    base
    .transform_regression("DMS measurement", "estimated value")
    .mark_line(color="gray", size=4, opacity=0.5)
)

# With `params=True` the regression transform returns the fit parameters
# (`coef` = [intercept, slope] and `rSquared`) instead of line points.
# sqrt(rSquared) alone is |r|, so take the sign from the fitted slope; otherwise
# a negative correlation would be labeled as positive.
params_r = (
    base
    .transform_regression("DMS measurement", "estimated value", params=True)
    .transform_calculate(
        r="(datum.coef[1] < 0 ? -1 : 1) * sqrt(datum.rSquared)",
        label='"r = " + format(datum.r, ".2f")',
    )
    .mark_text(align="left", color="gray", fontWeight="bold", fontSize=13)
    .encode(x=alt.value(5), y=alt.value(8), text=alt.Text("label:N"))
)

# One column per study and one row per estimation method, each facet with its
# own axis scales.
dms_corr_chart = (
    (dms_site_chart + line + params_r)
    .properties(width=155, height=155)
    .facet(
        column=alt.Column(
            "study",
            title="Correlation of site-average DMS measurements with site-mean estimated fitness or log dN/dS",
            header=alt.Header(labelFontSize=12, labelPadding=1, labelFontStyle="bold", titleFontSize=13),
        ),
        row=alt.Row(
            "estimation method",
            title=None,
            header=alt.Header(labelFontStyle="bold", labelFontSize=11, labelPadding=1),
        ),
        spacing=6,
    )
    .resolve_scale(x="independent", y="independent")
    .add_selection(highlight)
)

dms_corr_chart

In [None]:
# Stack the per-gene correlation panel above the DMS facet grid, write the
# combined figure to the HTML output, and display it.
panels = [fitness_dnds_corr_chart, dms_corr_chart]
chart = alt.vconcat(*panels, spacing=35, center=True)

chart.save(corr_html)

chart