# Average mutational effects for an antibody/serum, or receptor affinity from soluble receptor neutralization
This notebook averages selections that measure escape from neutralization by an antibody or serum, or receptor affinity from neutralization by soluble receptor.

Import Python modules.
We use `polyclonal` for the averaging and plotting:

In [None]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

# lift Altair's row limit on embedded data; assign to `_` so the return value is not shown as cell output
_ = alt.data_transformers.disable_max_rows()

This notebook is parameterized by `papermill`.
The next cell is tagged as `parameters` to get the passed parameters.

In [None]:
# this cell is tagged parameters for `papermill` parameterization
# Every value is a `None` placeholder here; the comments describe how each one
# is used further down in this notebook.
# compared against "antibody_escape" and "receptor_affinity"
assay = None
# CSV read into `site_numbering_map`
site_numbering_map_csv = None
# CSVs of mean probability of escape, paired in order with params["selections"]
prob_escape_mean_csvs = None
# pickled model files, paired in order with params["selections"]
pickles = None
# output path the averaged model is pickled to
avg_pickle_file = None
# output path for the CSV of mutation effects
effect_csv = None
# output path for the CSV of ICXX changes
icXX_csv = None
# output path the mutation-effect chart is saved to
effect_html = None
# output path the ICXX chart is saved to
icXX_html = None
# mapping; keys read below: "selections", "escape_plot_kwargs", "plot_hide_stats", "icXX"
params = None

In [None]:
# report which assay this run is analyzing
print(f"Analyzing results for {assay=}")

Read the input data and parameters:

In [None]:
# site numbering map, with `reference_site` renamed to `site` so it can be
# merged with the model data frames on `site`
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)

# selection names must be unique as they label the models and data below
assert len(params["selections"]) == len(set(params["selections"]))

# each selection needs exactly one pickle and one prob_escape CSV; check this
# explicitly because `zip` would otherwise silently drop unmatched entries
for input_name, input_files in [
    ("pickles", pickles),
    ("prob_escape_mean_csvs", prob_escape_mean_csvs),
]:
    if len(input_files) != len(params["selections"]):
        raise ValueError(
            f"{len(input_files)} entries in {input_name} but "
            f"{len(params['selections'])} selections"
        )

# read Polyclonal models into a data frame that can be passed to PolyclonalAverage
# NOTE: unpickling executes code from the file, so these pickles must be trusted
# (they are expected to be outputs of earlier steps of this analysis)
selection_models = []
for selection, pickle_file in zip(params["selections"], pickles):
    # context manager so each file handle is closed as soon as it is read
    with open(pickle_file, "rb") as pickle_handle:
        selection_models.append((selection, pickle.load(pickle_handle)))
models_df = pd.DataFrame(selection_models, columns=["selection", "model"])

# read prob_escape means all into one data frame, labeling rows by selection
prob_escape_means = pd.concat(
    [
        pd.read_csv(f).assign(selection=s)
        for s, f in zip(params["selections"], prob_escape_mean_csvs)
    ],
    ignore_index=True,
)

# get the plot kwargs
escape_plot_kwargs = params["escape_plot_kwargs"]

## Neutralization at concentrations used for each selection
For each selection going into the average, plot the average fraction neutralization (probability of escape) of variants with different numbers of mutations, both for the censored values used to fit the models and the uncensored values.
Note the concentrations **not** used in the model fits are shown fainter and in a different shape:

In [None]:
# y-axis uses a symlog scale for antibody escape and the default scale otherwise
if assay == "antibody_escape":
    prob_escape_y_scale = alt.Scale(type="symlog", constant=0.04)
else:
    prob_escape_y_scale = alt.Scale()

# both facet dimensions (rows and columns) use the same header formatting
facet_header_kwargs = {"labelFontWeight": "bold", "labelFontSize": 10}

# tooltip shows every column, formatting only the probability of escape
prob_escape_tooltips = [
    alt.Tooltip(column, format=".3g") if column == "probability escape" else column
    for column in prob_escape_means.columns
]

mean_prob_escape_chart = (
    alt.Chart(prob_escape_means)
    .mark_line(point=True, size=0.75, opacity=0.8)
    .encode(
        x=alt.X("concentration", scale=alt.Scale(type="log")),
        y=alt.Y("probability escape", scale=prob_escape_y_scale),
        # one column per censoring status, one row per selection
        column=alt.Column(
            "censored", title=None, header=alt.Header(**facet_header_kwargs)
        ),
        row=alt.Row(
            "selection", title=None, header=alt.Header(**facet_header_kwargs)
        ),
        color=alt.Color("n_substitutions"),
        tooltip=prob_escape_tooltips,
        # concentrations not used in the fit get a different shape and are fainter
        shape=alt.Shape("use_in_fit", scale=alt.Scale(domain=[True, False])),
        opacity=alt.Opacity(
            "use_in_fit",
            scale=alt.Scale(domain=[True, False], range=[0.9, 0.3]),
        ),
    )
    .properties(width=230, height=145)
    .configure_axis(grid=False)
    .configure_point(size=50)
)

mean_prob_escape_chart

## Average mutation effects
First build a `PolyclonalAverage`:

In [None]:
# build the average from the per-selection models
avg_model = polyclonal.PolyclonalAverage(models_df)

# pickle the averaged model to `avg_pickle_file`
print(f"Saving the average model to {avg_pickle_file}")
with open(avg_pickle_file, mode="wb") as avg_pickle_handle:
    pickle.dump(avg_model, avg_pickle_handle)

Get the `times_seen` (how many variants a mutation must be seen in) cutoff for the plots:

In [None]:
# default cutoff of 1, overridden by the `times_seen` slider setting if the
# plot kwargs specify one
times_seen = 1
try:
    times_seen = escape_plot_kwargs["addtl_slider_stats"]["times_seen"]
except KeyError:
    # either key is absent, so keep the default
    pass

print(f"Making plots for {times_seen=}")

Correlation of escape across different selections:

In [None]:
# pass the `times_seen` cutoff chosen above as `min_times_seen`; the returned
# value is the last expression in the cell
avg_model.mut_escape_corr_heatmap(min_times_seen=times_seen)

Neutralization curves against unmutated protein (which reflect the wildtype activities, Hill coefficients, and non-neutralizable fractions):

In [None]:
# called with default arguments; the returned value is the last expression in the cell
avg_model.curves_plot()

Site line plots for the site values for each individual selection (model) in the average.
This makes it easier to tell if one selection is an outlier before we plot the full averages below, and how correlated the selections are.
Note the plot is interactive: you can mouseover points and change the site metric shown.

In [None]:
# per-selection site summaries in long form: one row per selection, site,
# epitope, and site statistic
site_summary_long = avg_model.mut_escape_site_summary_df_replicates(
    min_times_seen=times_seen
).melt(
    id_vars=["selection", "site", "wildtype", "epitope"],
    value_vars=["mean", "total positive", "total negative"],
    var_name="site statistic",
    value_name="site escape",
)

# sequential site numbers, with `site` cast to the same type as the model's
# sites so the merge keys match
sequential_site_lookup = site_numbering_map[["site", "sequential_site"]].assign(
    site=lambda df: df["site"].astype(type(avg_model.sites[0]))
)

per_selection_site_escape = site_summary_long.merge(
    sequential_site_lookup,
    validate="many_to_one",
)

if assay == "receptor_affinity":
    # invert because negative escape means better affinity; flipping the sign
    # also swaps which total is the positive one and which is the negative one
    flipped_statistic_names = {
        "mean": "mean",
        "total positive": "total negative",
        "total negative": "total positive",
    }
    per_selection_site_escape = per_selection_site_escape.assign(
        **{
            "site escape": lambda df: -df["site escape"],
            "site statistic": lambda df: df["site statistic"].map(
                lambda statistic: flipped_statistic_names[statistic]
            ),
        }
    )

# dropdown choosing which site statistic is plotted (initially the mean)
site_statistic_dropdown = alt.binding_select(
    name="site statistic",
    options=per_selection_site_escape["site statistic"].unique(),
)
site_statistic_selection = alt.selection_point(
    fields=["site statistic"],
    bind=site_statistic_dropdown,
    value="mean",
)

# highlight the site under the mouse
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# x-axis: sites ordered by their sequential number
site_x_encoding = alt.X(
    "site",
    sort=alt.SortField("sequential_site"),
    axis=alt.Axis(labelOverlap=True),
    scale=alt.Scale(nice=False, zero=False),
)
site_tooltips = ["site", alt.Tooltip("site escape", format=".2f")]

# encodings, selections, and filter shared by the line and point layers
per_selection_site_escape_chart_base = (
    alt.Chart(per_selection_site_escape)
    .encode(
        x=site_x_encoding,
        y=alt.Y("site escape", title=assay),
        color="epitope",
        tooltip=site_tooltips,
    )
    .properties(width=800, height=85)
    .add_params(site_statistic_selection, site_selection)
    .transform_filter(site_statistic_selection)
)

per_selection_site_escape_chart_lines = per_selection_site_escape_chart_base.mark_line(
    size=0.75
)

# points get larger and gain an orange outline when their site is selected
selected_point_size = alt.condition(site_selection, alt.value(75), alt.value(30))
selected_point_stroke = alt.condition(site_selection, alt.value(2), alt.value(0))
per_selection_site_escape_chart_points = per_selection_site_escape_chart_base.encode(
    size=selected_point_size,
    strokeWidth=selected_point_stroke,
).mark_circle(filled=True, stroke="orange")

# layer lines and points, then stack one panel per selection in a single column
layered_site_escape_chart = (
    per_selection_site_escape_chart_lines + per_selection_site_escape_chart_points
)
selection_facet = alt.Facet(
    "selection",
    title=None,
    header=alt.Header(labelPadding=0),
)
per_selection_escape_chart = layered_site_escape_chart.facet(
    facet=selection_facet,
    columns=1,
    spacing=5,
).configure_axis(grid=False)

per_selection_escape_chart

Plot and save the mutation effects for the average model:

In [None]:
# First build up the arguments used to format the plot. Work on copies of the
# configured values (including the nested containers modified below) so that
# `params` is not mutated and re-running this cell does not accumulate
# duplicate entries.
escape_plot_kwargs = dict(params["escape_plot_kwargs"])
plot_hide_stats = params["plot_hide_stats"]

escape_plot_kwargs["addtl_slider_stats"] = dict(
    escape_plot_kwargs.get("addtl_slider_stats") or {}
)
escape_plot_kwargs["addtl_slider_stats_hide_not_filter"] = list(
    escape_plot_kwargs.get("addtl_slider_stats_hide_not_filter") or []
)

escape_plot_kwargs["df_to_merge"] = []

# each statistic in `plot_hide_stats` becomes a slider that hides (rather than
# filters) mutations, with its values read from the statistic's CSV
for stat, stat_d in plot_hide_stats.items():
    escape_plot_kwargs["addtl_slider_stats"][stat] = stat_d["init"]
    escape_plot_kwargs["addtl_slider_stats_hide_not_filter"].append(stat)
    csv_col = stat_d["csv_col"]
    merge_df = pd.read_csv(stat_d["csv"]).rename(columns={csv_col: stat})
    if stat not in merge_df.columns:
        # fail with a clear message rather than a bare KeyError further down
        raise ValueError(f"{stat=} CSV lacks {csv_col=}\n{merge_df.columns=}")
    if "min_filters" in stat_d:
        # keep only rows meeting the minimum value for each filter column
        for col, col_min in stat_d["min_filters"].items():
            if col not in merge_df.columns:
                raise ValueError(f"{stat=} CSV lacks {col=}\n{merge_df.columns=}")
            merge_df = merge_df[merge_df[col] >= col_min]
    escape_plot_kwargs["df_to_merge"].append(merge_df[["site", "mutant", stat]])

escape_plot_kwargs["df_to_merge"].append(
    site_numbering_map[["site", "sequential_site", "region"]]
)

# by default, require a mutation to be in just over half of the models
if "init_n_models" not in escape_plot_kwargs:
    escape_plot_kwargs["init_n_models"] = len(avg_model.models) // 2 + 1

if assay == "receptor_affinity":
    # negative escape means more affinity, so flip the sign and relabel
    escape_plot_kwargs["scale_stat_col"] = -1
    escape_plot_kwargs["rename_stat_col"] = "receptor affinity"

print(f"Writing mutation values to {effect_csv}")
effects_for_csv = avg_model.mut_escape_df_w_model_values.drop(
    columns=["escape_std", "escape_min_magnitude"]
)
if assay == "receptor_affinity":
    # invert as negative means more affinity, and rename the summary columns
    summary_cols = ["escape_mean", "escape_median"]
    for c in summary_cols + params["selections"]:
        effects_for_csv[c] = -effects_for_csv[c]
    effects_for_csv = effects_for_csv.rename(
        columns={c: c.replace("escape_", "affinity_") for c in summary_cols}
    )
effects_for_csv.to_csv(effect_csv, index=False, float_format="%.4g")

escape_chart = avg_model.mut_escape_plot(**escape_plot_kwargs)
print(f"Writing chart to {effect_html}")
escape_chart.save(effect_html)

display(escape_chart)

Plot and save the predicted change in ICXX induced by each mutation:

In [None]:
# the XX in ICXX, given as a percentage in the parameters
icXX = params["icXX"]
print(f"Getting predicted changes in IC{icXX}")

icXX_col = f"IC{icXX}"
log_fold_change_icXX_col = f"log2 fold change {icXX_col}"

print(f"Writing changes in ICXX to {icXX_csv}")

# arguments shared by the data frame and plotting calls (percentage -> fraction)
icXX_call_kwargs = {
    "x": icXX / 100.0,
    "icXX_col": icXX_col,
    "log_fold_change_icXX_col": log_fold_change_icXX_col,
}

# summary columns left out of the CSV
icXX_cols_to_drop = [
    f"{log_fold_change_icXX_col} min_magnitude",
    f"{log_fold_change_icXX_col} std",
]
icXX_for_csv = avg_model.mut_icXX_df_w_model_values(**icXX_call_kwargs).drop(
    columns=icXX_cols_to_drop
)

# summary columns plus one column per selection
icXX_value_cols = [
    f"{log_fold_change_icXX_col} mean",
    f"{log_fold_change_icXX_col} median",
] + params["selections"]
if assay == "receptor_affinity":
    # invert as negative means more affinity
    for value_col in icXX_value_cols:
        icXX_for_csv[value_col] = -icXX_for_csv[value_col]
icXX_for_csv.to_csv(icXX_csv, index=False, float_format="%.4g")

# the "receptor affinity" relabeling is not passed on to the ICXX plot
if assay == "receptor_affinity":
    escape_plot_kwargs.pop("rename_stat_col", None)

icXX_chart = avg_model.mut_icXX_plot(**icXX_call_kwargs, **escape_plot_kwargs)
print(f"Writing ICXX chart to {icXX_html}")
icXX_chart.save(icXX_html)

display(icXX_chart)