# Functional effects of mutations averaged across replicates
This notebook aggregates the global epistasis fits for the individual replicates, which estimate the functional effects of mutations on viral entry.
It analyzes both the latent and observed phenotypes from the global epistasis models.

First, import Python modules:

In [None]:
import os

import altair as alt

import dms_variants.utils

import pandas as pd

import polyclonal
import polyclonal.alphabets
import polyclonal.plot

import yaml

In [None]:
# allow more rows for Altair: disable Altair's maximum-row check so that data
# frames with many rows (such as the per-mutation tables below) can be charted.
# Assigning to `_` suppresses display of the returned value in the notebook.
_ = alt.data_transformers.disable_max_rows()

Get configuration information:

In [None]:
# If you are running notebook interactively rather than in pipeline that handles
# working directories, you may have to first `os.chdir` to appropriate directory.

# Read the pipeline configuration. Decode explicitly as UTF-8 (YAML streams are
# Unicode) rather than relying on the platform's locale-dependent default
# encoding; `safe_load` avoids constructing arbitrary Python objects.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

Read the sequential-to-reference site numbering map:

In [None]:
# Table relating sequential site numbers to reference site numbers; it may also
# carry a "region" column, which is used later to color the plot zoom bar.
sitenumbering_map = pd.read_csv(filepath_or_buffer=config["site_numbering_map"])

## Read the mutation effects
The functional selections data frame:

In [None]:
# Metadata for each functional selection, keyed by "selection_name".
func_selections = pd.read_csv(filepath_or_buffer=config["functional_selections"])

The mutation effects:

In [None]:
phenotypes = ["observed", "latent"]


def _read_selection_muteffects(selection_name, phenotype):
    """Read one selection's mutation effects for one phenotype.

    The returned frame is tagged with the selection name and phenotype,
    has `times_seen` cast to nullable integers, and gains a `mutation`
    label built as wildtype + sequential site + mutant.
    """
    csv_path = os.path.join(
        config["globalepistasis_dir"],
        f"{selection_name}_muteffects_{phenotype}.csv",
    )
    return pd.read_csv(csv_path).assign(
        selection_name=selection_name,
        phenotype=phenotype,
        times_seen=lambda df: df["times_seen"].astype("Int64"),
        mutation=lambda df: (
            df["wildtype"] + df["sequential_site"].astype(str) + df["mutant"]
        ),
    )


# Stack every (selection, phenotype) file, selections outermost, then attach
# the per-selection metadata (each selection_name must map to a single row).
muteffects = pd.concat(
    [
        _read_selection_muteffects(selection_name, phenotype)
        for selection_name in func_selections["selection_name"]
        for phenotype in phenotypes
    ],
    ignore_index=True,
)
muteffects = muteffects.merge(
    func_selections,
    on="selection_name",
    how="left",
    validate="many_to_one",
)

# sanity checks: no fully duplicated rows, and only `times_seen` may be missing
assert not muteffects.duplicated().any()
assert muteffects.drop(columns="times_seen").notnull().to_numpy().all()

## Correlations among mutation effects
Correlations among replicates:

In [None]:
# Pairwise correlations of mutation effects between selections, computed
# separately for each phenotype; keep only the squared correlation (r2).
corrs = dms_variants.utils.tidy_to_corr(
    df=muteffects,
    sample_col="selection_name",
    label_col="mutation",
    value_col="effect",
    group_cols="phenotype",
)
corrs["r2"] = corrs["correlation"] ** 2
corrs = corrs.drop(columns="correlation")

# tooltip fields shared by every phenotype's heatmap; r2 gets a compact format
corr_tooltips = [
    "phenotype",
    "selection_name_1",
    "selection_name_2",
    alt.Tooltip("r2", format=".3g"),
]

# one square heatmap of replicate-vs-replicate r2 per phenotype
for phenotype, corr_df in corrs.groupby("phenotype"):
    corr_chart = (
        alt.Chart(corr_df)
        .mark_rect(stroke="black")
        .encode(
            x=alt.X("selection_name_1", title=None),
            y=alt.Y("selection_name_2", title=None),
            color=alt.Color("r2", scale=alt.Scale(zero=True)),
            tooltip=corr_tooltips,
        )
        .properties(width=alt.Step(15), height=alt.Step(15), title=phenotype)
        .configure_axis(labelLimit=500)
    )

    display(corr_chart)

## Compute average mutation effects
Compute averages for each library individually, and across all replicates of all libraries.
Note that the cross-library averages are **not** weighted equally by library; instead, each library is weighted by its total number of replicates:

In [None]:
# How replicate effects are combined ("median" or "mean", from the config).
# Validate before announcing it, and raise rather than `assert` so that the
# check survives `python -O` and reports the offending value.
muteffects_avg_method = config["muteffects_avg_method"]
if muteffects_avg_method not in {"median", "mean"}:
    raise ValueError(
        "config['muteffects_avg_method'] must be 'median' or 'mean', "
        f"got {muteffects_avg_method!r}"
    )
print(f"Defining the average as the {muteffects_avg_method} across replicates")

# every functional selection should have contributed mutation effects
n_selections = muteffects["selection_name"].nunique()
assert n_selections == len(func_selections)

groupcols = ["sequential_site", "reference_site", "wildtype", "mutant", "phenotype"]
muteffects_avg = (
    muteffects.groupby(groupcols, as_index=False)
    .aggregate(
        effect=pd.NamedAgg("effect", muteffects_avg_method),
        effect_std=pd.NamedAgg("effect", "std"),
        # divide by *all* selections, so replicates in which the mutation was
        # not observed effectively contribute zero to the average
        times_seen=pd.NamedAgg("times_seen", lambda s: s.sum() / n_selections),
        n_libraries=pd.NamedAgg("library", "nunique"),
    )
    .assign(
        # times_seen is not meaningful for wildtype (non-mutation) entries
        times_seen=lambda x: x["times_seen"].where(x["wildtype"] != x["mutant"], pd.NA)
    )
    # add per-library effects as "<library> effect" columns (wide format)
    .merge(
        muteffects.groupby(["library", *groupcols], as_index=False)
        .aggregate(
            effect=pd.NamedAgg("effect", muteffects_avg_method),
        )
        .assign(library=lambda x: x["library"].astype(str) + " effect")
        .pivot_table(index=groupcols, columns="library", values="effect"),
        on=groupcols,
        validate="one_to_one",
        how="left",
    )
)

Write average mutation effects to CSVs:

In [None]:
# Write one CSV of averaged mutation effects per phenotype, to the path given
# by config["muteffects_<phenotype>"].
for phenotype, df in muteffects_avg.groupby("phenotype"):
    outfile = config[f"muteffects_{phenotype}"]
    outdir = os.path.dirname(outfile)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the configured path actually has one (a bare filename means the cwd).
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    print(f"Writing {phenotype}-phenotype mutation effects to {outfile}")
    df.to_csv(outfile, index=False, float_format="%.4f")

## Plot average mutational effects
These are interactive plots.
The `times_seen` value is averaged across all replicates (replicates in which a mutation was not observed count as zero), and you can select how many libraries must have data for the mutation.
The tooltips show library-specific values as well.
Plot using the **reference** site numbering:

In [None]:
# Work on copies of the configured plotting options, including the nested
# slider/tooltip containers modified below, so this cell never mutates
# `config`. Mutating it in place would make each re-run of the cell append
# duplicate tooltip entries and leak the per-phenotype plot title into config.
plot_kwargs = dict(config["muteffects_plot_kwargs"] or {})
plot_kwargs["addtl_slider_stats"] = dict(plot_kwargs.get("addtl_slider_stats") or {})
plot_kwargs["addtl_tooltip_stats"] = list(
    plot_kwargs.get("addtl_tooltip_stats") or []
)

# the plotting function expects the displayed site numbering in a "site" column
df_to_plot = muteffects_avg.rename(columns={"reference_site": "site"})

# default slider minima unless the config already sets them
plot_kwargs["addtl_slider_stats"].setdefault("times_seen", 1)
plot_kwargs["addtl_slider_stats"].setdefault("n_libraries", 1)

if "region" in sitenumbering_map.columns:
    # color the site zoom bar by region; each site must map to exactly one
    # region, otherwise the merge would silently duplicate mutation rows
    site_regions = sitenumbering_map.rename(columns={"reference_site": "site"})[
        ["site", "region"]
    ]
    df_to_plot = df_to_plot.merge(site_regions, validate="many_to_one")
    plot_kwargs["site_zoom_bar_color_col"] = "region"

# Extra tooltip columns, in display order; each is added only once.
libraries = sorted(muteffects["library"].unique())
extra_tooltip_stats = ["effect_std"]
# only show sequential numbering when it differs from the reference numbering
if (df_to_plot["site"] != df_to_plot["sequential_site"]).any():
    extra_tooltip_stats.append("sequential_site")
extra_tooltip_stats.extend(f"{lib} effect" for lib in libraries)
for stat in extra_tooltip_stats:
    if stat not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append(stat)

# make strings for proper plotting of null values for library effects
for lib in libraries:
    lib_col = f"{lib} effect"
    df_to_plot[lib_col] = df_to_plot[lib_col].map(
        lambda v: "na" if pd.isnull(v) else f"{v:.2f}"
    )

# reference-site labels ordered by sequential position (same for every phenotype)
sites = sitenumbering_map.sort_values("sequential_site")["reference_site"].tolist()

for phenotype, df in df_to_plot.groupby("phenotype"):

    print(f"\n{phenotype} phenotype\n")

    plot_kwargs["plot_title"] = f"functional effects ({phenotype} phenotype)"

    chart = polyclonal.plot.lineplot_and_heatmap(
        data_df=df,
        stat_col="effect",
        category_col="phenotype",
        alphabet=polyclonal.alphabets.biochem_order_aas(
            polyclonal.AAS_WITHSTOP_WITHGAP
        ),
        sites=sites,
        **plot_kwargs,
    )

    # save the interactive chart next to the corresponding CSV
    heatmapfile = (
        os.path.splitext(config[f"muteffects_{phenotype}"])[0] + "_heatmap.html"
    )
    print(f"Saving to {heatmapfile}")
    chart.save(heatmapfile)

    display(chart)