# Average mutation functional effect shifts for a set of comparisons

Import Python modules.
We use `multidms` for the modeling and plotting:

In [None]:
from natsort import natsort_keygen

import pickle

import pandas as pd

import multidms

This notebook is parameterized by `papermill`.
The next cell is tagged as `parameters` to get the passed parameters.

In [None]:
# this cell is tagged parameters for `papermill` parameterization
# The None defaults below are placeholders; real values are passed in by papermill.
site_numbering_map_csv = None  # NOTE(review): not used by any cell visible here; confirm still needed
shifts_csv = None  # output path for the CSV of aggregated shifts (written in the last cell)
shifts_html = None  # output path for the interactive shift heatmap (chart.save)
params = None  # dict read for keys "comparisons", "lasso_shift", "avg_method", and optional "plot_kwargs"

Load the pickled fit-collection dataframe for each comparison and combine them to initialize a `multidms.ModelCollection` object.

In [None]:
comparisons = params["comparisons"]
if not comparisons:
    # Fail with a clear message instead of pd.concat's "No objects to concatenate".
    raise ValueError("params['comparisons'] must list at least one comparison")


def _load_fit_collection(comparison):
    """Return the pickled fit-collection dataframe saved for ``comparison``.

    NOTE: ``pickle.load`` can execute arbitrary code, so only load files
    written by this pipeline.
    """
    pkl_path = (
        f"results/func_effect_shifts/by_comparison/{comparison}_fit_collection.pkl"
    )
    # The `with` block closes each file deterministically instead of leaving
    # the handle for garbage collection.
    with open(pkl_path, "rb") as f:
        return pickle.load(f)


# Stack the per-comparison fit collections into one table with a fresh index.
combined_fit_collection = pd.concat(
    [_load_fit_collection(c) for c in comparisons]
).reset_index(drop=True)
combined_fit_collection

In [None]:
# Wrap the combined fits in a ModelCollection; its methods are used in the
# cells below to plot and aggregate the shifts across comparisons.
mc = multidms.ModelCollection(combined_fit_collection)

For each lasso-shift penalty, plot the sparsity of the shifts and the correlation of shifts among comparisons. Restrict to mutations that meet the minimum `times_seen` threshold, and do not plot shifts for wildtype residues.
In general, you hope to find a lasso-shift penalty that yields relatively few non-zero shifts, and for those shifts to be correlated among comparisons.

In [None]:
# Minimum `times_seen` applied to the plots below and to the aggregated shifts.
# Read it from params["plot_kwargs"]["addtl_slider_stats"]["times_seen"] and
# fall back to 3 when any level of that nested config is missing or null.
# A bare KeyError handler would crash with TypeError on a null level.
_slider_stats = (params.get("plot_kwargs") or {}).get("addtl_slider_stats") or {}
times_seen = _slider_stats.get("times_seen", 3)

# Sparsity of the shift parameters and correlation of mutation parameters
# between datasets, both filtered with `times_seen`.
sparsity_chart = mc.shift_sparsity(times_seen_threshold=times_seen, width_scalar=200)
correlation_chart = mc.mut_param_dataset_correlation(
    times_seen_threshold=times_seen, width_scalar=200
)

# `&` stacks the two charts vertically; each keeps its own color scale.
(sparsity_chart & correlation_chart).resolve_scale(color="independent")

Now make an interactive plot of the shifts.

In [None]:
# Select only the fits whose lasso-shift penalty equals the configured value.
heatmap_query = f"scale_coeff_lasso_shift == {float(params['lasso_shift'])}"

# Interactive heatmap of the "shift" parameter for those fits, with
# `times_seen` as the times-seen threshold.
chart = mc.mut_param_heatmap(
    mut_param="shift",
    query=heatmap_query,
    times_seen_threshold=times_seen,
)

# Persist the chart to the papermill-supplied HTML path, then display it.
print(f"Saving chart to {shifts_html}")
chart.save(shifts_html)
chart

Now write the averaged shifts to a CSV file.

In [None]:
# Combine each mutation's parameters across all fits matching the query
# (the fits at the configured lasso-shift penalty), using the aggregation
# named by params["avg_method"].
muts_df = mc.split_apply_combine_muts(
    groupby="scale_coeff_lasso_shift",
    query=f"scale_coeff_lasso_shift == {float(params['lasso_shift'])}",  # query on fit collection
    aggregate_func=params["avg_method"],  # how to combine fit collection muts
    inner_merge_dataset_muts=True,  # only keep muts that are in all datasets being combined
    times_seen_threshold=times_seen,
)

# Split each mutation string into (wildtype, site, mutant) with the same
# parser the fitted models' data uses.
parse_mut_fn = mc.fit_models.iloc[0].model.data.parse_mut
parsed_muts = muts_df.reset_index()["mutation"].map(parse_mut_fn).tolist()
# Build each column explicitly: `a, b, c = zip(*parsed)` raises ValueError
# when no mutations pass the filters. These list comprehensions assign rows
# positionally, just as the tuple assignment would.
muts_df["wildtype"] = [wt for wt, _, _ in parsed_muts]
muts_df["site"] = [site for _, site, _ in parsed_muts]
muts_df["mutant"] = [mut for _, _, mut in parsed_muts]

# Write the table in natural site order, with `beta` renamed for clarity.
(
    muts_df.reset_index()
    .rename(columns={"beta": "latent_phenotype_effect"})
    .sort_values("site", key=natsort_keygen())
    .to_csv(shifts_csv, index=False, float_format="%.4g")
)