# Fit global epistasis models to functional scores
Here we fit [global epistasis models](https://jbloomlab.github.io/dms_variants/dms_variants.globalepistasis.html) to the functional scores of the variants to estimate how mutations affect viral entry.
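
As a rough illustration of the idea (not the `dms_variants` implementation), a global epistasis model assumes that each variant has a latent phenotype equal to the sum of its mutations' effects, and that the observed phenotype is a monotonic nonlinear function of that latent phenotype. The sketch below uses a sigmoid for that function and made-up effect values; the sigmoid, the mutation names, and the numbers are all assumptions chosen for illustration.

```python
import math

# Hypothetical latent effects of individual mutations (illustrative values only).
latent_effects = {"A1G": -0.5, "K2R": 0.1, "D3E": -2.0}
WT_LATENT = 1.0  # assumed latent phenotype of the unmutated sequence


def latent_phenotype(substitutions):
    """Latent phenotype: wildtype value plus the additive effect of each mutation."""
    return WT_LATENT + sum(latent_effects[m] for m in substitutions.split())


def observed_phenotype(latent):
    """Assumed monotonic global epistasis function (a sigmoid) of the latent phenotype."""
    return 1.0 / (1.0 + math.exp(-latent))


for variant in ["", "A1G", "A1G D3E"]:
    lat = latent_phenotype(variant)
    obs = observed_phenotype(lat)
    print(f"{variant or 'wildtype':10s} latent={lat:+.2f} observed={obs:.3f}")
```

Mutation effects add on the latent scale, but because of the nonlinearity they do not add on the observed scale.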

First import Python modules:

In [None]:
import itertools
import pickle

import altair as alt

import binarymap.binarymap

import dms_variants.globalepistasis

import pandas as pd

import plotnine as p9

import polyclonal.alphabets
import polyclonal.plot

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

## Read input data
Get parameterized variables from [papermill](https://papermill.readthedocs.io/):

In [None]:
# papermill parameters cell (tagged as `parameters`)
func_scores_csv = None
sitenumbering_map_csv = None
pickle_file = None
muteffects_latent_csv = None
muteffects_observed_csv = None
min_times_seen = None
likelihood = None
ftol = None
plot_kwargs = None

Read the functional scores, only keeping those with sufficient pre-selection counts:

In [None]:
func_scores = pd.read_csv(func_scores_csv, na_filter=False).query(
    "pre_count >= pre_count_threshold"
)

assert len(func_scores.groupby(["library", "pre_sample", "post_sample"])) == 1

## Fit the global epistasis model
Create a [BinaryMap](https://jbloomlab.github.io/binarymap/binarymap.binarymap.html#binarymap.binarymap.BinaryMap), using the sequentially numbered amino-acid substitutions:

In [None]:
bmap = binarymap.binarymap.BinaryMap(
    func_scores,
    substitutions_col="aa_substitutions_sequential",
    alphabet=binarymap.binarymap.AAS_WITHSTOP_WITHGAP,
    cols_optional=False,
)
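
Conceptually, a binary map represents each variant as a row of 0/1 indicators, one column per substitution. The following is a minimal plain-Python sketch of that encoding using made-up variants; it is not how `BinaryMap` is implemented, and the column ordering here (simple string sort) is an assumption.

```python
# Toy variants: space-delimited substitutions; "" is the unmutated sequence.
variants = ["", "A1G", "A1G K2R", "K2R D3E"]

# Every distinct substitution gets one column (string-sorted in this sketch).
all_subs = sorted({sub for v in variants for sub in v.split()})
col_index = {sub: i for i, sub in enumerate(all_subs)}


def encode(variant):
    """Return a 0/1 row marking which substitutions the variant carries."""
    row = [0] * len(all_subs)
    for sub in variant.split():
        row[col_index[sub]] = 1
    return row


binary_rows = [encode(v) for v in variants]

print(all_subs)
for v, row in zip(variants, binary_rows):
    print(f"{v or 'wildtype':10s} {row}")
```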

Now fit a [GlobalEpistasis](https://jbloomlab.github.io/dms_variants/dms_variants.globalepistasis.html) model and a [NoEpistasis](https://jbloomlab.github.io/dms_variants/dms_variants.globalepistasis.html#dms_variants.globalepistasis.NoEpistasis) (linear) model:

In [None]:
fit_df = dms_variants.globalepistasis.fit_models(
    binarymap=bmap,
    likelihood=likelihood,
    ftol=ftol,
)

The table below summarizes the fits.
Because the no-epistasis model is nested within the global epistasis model, the global epistasis model should have a log likelihood at least as high as the no-epistasis model if fitting is working correctly:

In [None]:
fit_df.drop(columns=["n_latent_phenotypes", "model"]).round(1)
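
When comparing nested models like these, a common summary is the Akaike information criterion, AIC = 2k - 2 ln L, where k is the number of parameters and ln L is the maximized log likelihood; lower is better. The sketch below computes it for two hypothetical fits. The parameter counts and log likelihoods are invented for illustration and are not results from this analysis.

```python
def aic(n_params, log_likelihood):
    """Akaike information criterion: 2k - 2 ln L (lower is better)."""
    return 2 * n_params - 2 * log_likelihood


# Hypothetical fit summaries (made-up numbers, not real results).
fits = {
    "global epistasis": {"nparams": 1010, "log_likelihood": -5000.0},
    "no epistasis": {"nparams": 1003, "log_likelihood": -6200.0},
}

aics = {name: aic(f["nparams"], f["log_likelihood"]) for name, f in fits.items()}
best = min(aics, key=aics.get)

for name, value in aics.items():
    print(f"{name}: AIC = {value:.1f}")
print(f"preferred model: {best}")
```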

Get the global epistasis model:

In [None]:
model = fit_df.set_index("description").at["global epistasis", "model"]

## Examine global epistasis fit
Plot the relationships among the latent phenotypes from the model, the observed phenotypes from the model, and the measured functional scores for all variants used to fit the model:

In [None]:
for x, y in itertools.combinations(
    ["latent_phenotype", "observed_phenotype", "func_score"], 2
):
    p = (
        p9.ggplot(model.phenotypes_df, p9.aes(x, y))
        + p9.geom_point(alpha=0.05, size=0.5)
        + p9.theme(figure_size=(2.25, 2.25))
    )
    _ = p.draw()

## Mutation effects from global epistasis model
The global epistasis model deconvolves the effects of individual mutations on both the observed and latent phenotype scales.
Get those effects, add both sequential and reference site numbers, and annotate each mutation with the number of times it is seen (the number of variants containing it):

In [None]:
# to renumber to reference
sitenumbering_map = pd.read_csv(sitenumbering_map_csv)

# number of times each mutation seen
times_seen = (
    func_scores["aa_substitutions_sequential"]
    .str.split()
    .explode()
    .dropna()
    .value_counts()
)
assert times_seen.notnull().all() and all(times_seen.astype(int) == times_seen)

# get mutation effects on both phenotypes, renumber to reference
muteffects = (
    pd.concat(
        [
            model.single_mut_effects(phenotype).assign(phenotype=phenotype)
            for phenotype in ["observed", "latent"]
        ],
        ignore_index=True,
    )
    .assign(
        times_seen=lambda x: x["mutation"].map(times_seen.to_dict()).astype("Int64")
    )
    .drop(columns="mutation")
    .rename(columns={"site": "sequential_site"})
    .merge(
        sitenumbering_map[["sequential_site", "reference_site"]],
        how="left",
        on="sequential_site",
        validate="many_to_one",
    )[
        [
            "sequential_site",
            "reference_site",
            "wildtype",
            "mutant",
            "effect",
            "times_seen",
            "phenotype",
        ]
    ]
)
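
The "times seen" count and the split of a mutation string into wildtype, site, and mutant can be sketched without pandas as follows. The variants are made up, and the regular expression assumes single-letter amino acids plus `*` (stop) and `-` (gap); both are assumptions for illustration.

```python
import re
from collections import Counter

# Toy space-delimited substitution strings; "" is an unmutated variant.
variants = ["A1G", "A1G K2R", "", "K2R D3*", "A1G"]

# Number of variants containing each mutation.
times_seen = Counter(sub for v in variants for sub in v.split())

# Assumed mutation format: wildtype character, integer site, mutant character.
MUT_RE = re.compile(r"^(?P<wildtype>[A-Z*-])(?P<site>\d+)(?P<mutant>[A-Z*-])$")


def parse_mutation(mutation):
    """Split a string like 'A1G' into (wildtype, site, mutant)."""
    m = MUT_RE.match(mutation)
    if m is None:
        raise ValueError(f"cannot parse mutation {mutation!r}")
    return m["wildtype"], int(m["site"]), m["mutant"]


for mutation, n in sorted(times_seen.items()):
    print(parse_mutation(mutation), "times_seen =", n)
```

Unmutated variants contribute nothing to the counts, which is why the notebook code drops missing values after `explode`.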

## Plot mutation effects
We use the `polyclonal.plot.lineplot_and_heatmap` function to make an interactive heatmap of mutation effects for each of the observed and latent phenotypes.
These are plotted using the **reference** sites:

In [None]:
df_to_plot = muteffects.rename(columns={"reference_site": "site"})

if "addtl_slider_stats" not in plot_kwargs:
    plot_kwargs["addtl_slider_stats"] = {"times_seen": 1}
elif "times_seen" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["times_seen"] = 1
    
if "region" in sitenumbering_map.columns:
    # add the region annotation for each reference site
    df_to_plot = df_to_plot.merge(
        sitenumbering_map.rename(columns={"reference_site": "site"})[["site", "region"]]
    )
    plot_kwargs["site_zoom_bar_color_col"] = "region"
    
if any(df_to_plot["site"] != df_to_plot["sequential_site"]):
    if "addtl_tooltip_stats" not in plot_kwargs:
        plot_kwargs["addtl_tooltip_stats"] = ["sequential_site"]
    elif "sequential_site" not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append("sequential_site")

for phenotype, df in df_to_plot.groupby("phenotype"):
    
    print(f"\n{phenotype} phenotype\n")
    
    # copy so a default title set for one phenotype is not reused for the next
    kwargs = dict(plot_kwargs)
    if "plot_title" not in kwargs:
        kwargs["plot_title"] = f"functional effects ({phenotype} phenotype)"

    chart = polyclonal.plot.lineplot_and_heatmap(
        data_df=df,
        stat_col="effect",
        category_col="phenotype",
        alphabet=polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP),
        sites=sitenumbering_map.sort_values("sequential_site")["reference_site"].tolist(),
        **kwargs,
    )
    
    display(chart)
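
One way to fill in per-plot defaults such as the title and the `times_seen` slider is to build a fresh keyword dictionary for each call, so that user-supplied settings are never modified. The helper below is a hypothetical sketch of that pattern (`with_defaults` is not part of `polyclonal`), and the starting `user_kwargs` value is made up.

```python
import copy


def with_defaults(user_kwargs, phenotype):
    """Return a copy of user_kwargs with plotting defaults filled in."""
    kwargs = copy.deepcopy(user_kwargs)  # leave the caller's dict untouched
    # Ensure a times_seen slider exists without overriding a user-chosen value.
    kwargs.setdefault("addtl_slider_stats", {}).setdefault("times_seen", 1)
    # Use a phenotype-specific title unless the user supplied one.
    kwargs.setdefault("plot_title", f"functional effects ({phenotype} phenotype)")
    return kwargs


user_kwargs = {"addtl_slider_stats": {"effect": -2}}  # made-up user settings

for phenotype in ["latent", "observed"]:
    print(with_defaults(user_kwargs, phenotype))
```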

## Write output files
Pickle the fitted model and write the latent-scale and observed-scale mutation effects to separate CSV files:

In [None]:
print(f"Pickling model to {pickle_file}")
with open(pickle_file, "wb") as f:
    pickle.dump(model, f)

for phenotype, outfile in [
    ("latent", muteffects_latent_csv),
    ("observed", muteffects_observed_csv),
]:
    print(f"Writing {phenotype}-scale mutation effects to {outfile}")
    (
        muteffects.query("phenotype == @phenotype")
        .drop(columns="phenotype")
        .to_csv(outfile, index=False, float_format="%.4f")
    )