# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [None]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [None]:
# Lift Altair's default cap on the number of data rows a chart may embed.
# The return value is bound to a throwaway name so the notebook does not display it.
_max_rows_disabled = alt.data_transformers.disable_max_rows()

## Read input data

Get parameterized variables from [papermill](https://papermill.readthedocs.io/):

In [None]:
# papermill parameters cell (tagged as `parameters`)
# Every value is None here and is expected to be injected by papermill at run time.
prob_escape_csv = None  # CSV of per-variant probabilities of escape to fit
n_threads = None  # threads passed to `polyclonal.PolyclonalCollection`
pickle_file = None  # output path for the pickled bootstrapped models
antibody = None  # overwritten later by the single antibody found in `prob_escape_csv`
no_antibody_count_threshold = None  # minimum `no-antibody_count` for a variant to be kept

Read the probabilities of escape and keep only variants with sufficient no-antibody counts (a `no-antibody_count` of at least `no_antibody_count_threshold`):

In [None]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

# The threshold is a papermill parameter; comparing against None would fail
# with an opaque TypeError inside `query`, so fail early with a clear message.
if no_antibody_count_threshold is None:
    raise ValueError("`no_antibody_count_threshold` parameter was not set")

# `keep_default_na=False` keeps empty strings as strings (presumably, e.g., an
# empty substitutions field for unmutated variants); only "nan" is read as
# missing. In `DataFrame.query`, a Python variable must be prefixed with `@`.
# Without the prefix, the name is looked up as a column and the query fails.
prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= @no_antibody_count_threshold")

if not prob_escape.notnull().all().all():
    raise ValueError(f"{prob_escape_csv} has missing values after filtering")

Read the rest of the configuration and input data:

In [None]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# The antibody is determined from the data itself.
# This overrides any value passed as the `antibody` papermill parameter.
antibodies = prob_escape["antibody"].unique()
if len(antibodies) != 1:
    raise ValueError(
        f"expected exactly one antibody in {prob_escape_csv}, found {antibodies}"
    )
antibody = antibodies[0]

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline and the no-antibody count filter applied above.

Number of variants per concentration:

In [None]:
# Count distinct variant barcodes at each antibody concentration.
variants_per_concentration = prob_escape.groupby("antibody_concentration").agg(
    n_variants=("barcode", "nunique"),
)
display(variants_per_concentration)

Plot the mean probability of escape across all variants with the indicated number of amino-acid substitutions; variants with many substitutions are grouped together (see `max_aa_subs`).
Note that this plot weights each variant equally in the means, regardless of how many barcode counts it has.
We plot means for both censored (clipped to between 0 and 1) and uncensored probabilities of escape.
Also, note that it uses a symlog scale for the y-axis.
Mouse over points to see their values:

In [None]:
max_aa_subs = 4  # variants with at least this many substitutions are pooled into one group


def _n_subs_label(n_subs):
    """Return the plot label for a substitution count, pooling counts >= `max_aa_subs`."""
    if n_subs < max_aa_subs:
        return str(n_subs)
    return f">{max_aa_subs - 1}"


# number of substitutions per variant, as a (possibly pooled) string label
n_subs_per_variant = (
    prob_escape["aa_substitutions_sequential"]
    .str.split()
    .map(len)
    .clip(upper=max_aa_subs)
    .map(_n_subs_label)
)

# mean censored and uncensored escape per concentration and substitution group,
# reshaped to long format so censoring can be faceted as columns
mean_prob_escape = (
    prob_escape.assign(n_subs=n_subs_per_variant)
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

# float columns get compact 3-significant-figure tooltips; others are shown as-is
tooltips = [
    alt.Tooltip(col, format=".3g") if mean_prob_escape[col].dtype == float else col
    for col in mean_prob_escape.columns
]

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .mark_line(point=True, size=0.5)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=tooltips,
    )
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

## Fit `polyclonal` models
First, get the fitting-related settings for this antibody from its `polyclonal` configuration:

In [None]:
# Pull each fitting setting for this antibody from its configuration and echo it
# so the values actually used are recorded in the notebook output.
n_bootstrap_samples = antibody_config["n_bootstrap_samples"]
print(f"{n_bootstrap_samples=}")

times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

# regularization weights, forwarded as keyword arguments to the model fits
fit_kwargs = {
    weight_name: antibody_config[weight_name]
    for weight_name in (
        "reg_escape_weight",
        "reg_spread_weight",
        "reg_activity_weight",
    )
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

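Fit a model to all the data, adding epitopes one at a time until either the specified maximum number of epitopes is reached or some epitope has an activity at or below `min_epitope_activity_to_include`. In the latter case, the model fit before that epitope was added is kept (a one-epitope model is always kept).
This will be the "root" model for the bootstrapping.
Note that for now the amino-acid substitutions are in **sequential** (not reference) site numbering: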

In [None]:
# Rename columns to the names passed to `polyclonal.Polyclonal` as `data_to_fit`.
polyclonal_column_renames = {
    "antibody_concentration": "concentration",
    "aa_substitutions_sequential": "aa_substitutions",
}

# Without at least one iteration, no model is fit and `root_model` is undefined.
if max_epitopes < 1:
    raise ValueError(f"{max_epitopes=} must be at least 1")

models = []  # every model fit, including a final rejected one if fitting stops early

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(columns=polyclonal_column_renames),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # fit model (the returned optimization result is not used)
    model.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # Once at least one model has been kept, stop as soon as any epitope's activity
    # is at or below the threshold; the previous (smaller) model becomes the root.
    if models and (
        model.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ).any():
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        root_model = models[-1]
        models.append(model)
        break

    models.append(model)
    root_model = model

print(f"\nThe selected model has {len(root_model.epitopes)} epitopes")

Now perform bootstrapping:

In [None]:
# Bootstrap around the selected root model, using the same regularization
# keyword arguments as the root fit.
bootstrap_model = polyclonal.PolyclonalCollection(
    root_polyclonal=root_model,
    n_bootstrap_samples=n_bootstrap_samples,
    n_threads=n_threads,
)

# `fit_models` returns the number of successful and failed fits.
n_fit, n_failed = bootstrap_model.fit_models(**fit_kwargs)

print(f"Successfully fit {n_fit} models with {n_failed} failures")

# Require every bootstrap replicate to fit; an explicit check (rather than
# `assert`) cannot be disabled and reports the counts on failure.
if n_fit != n_bootstrap_samples or n_failed != 0:
    raise RuntimeError(
        f"Expected {n_bootstrap_samples} successful bootstrap fits and no failures, "
        f"got {n_fit=} and {n_failed=}"
    )

Epitope activities:

In [None]:
# Bar plot of epitope activities from the bootstrapped model collection.
activity_chart = bootstrap_model.activity_wt_barplot()
activity_chart

Line plot of escape at each site:

In [None]:
# The site summary is computed with the configured `times_seen` setting.
site_summary_kwargs = {"min_times_seen": times_seen}
escape_lineplot = bootstrap_model.mut_escape_lineplot(
    mut_escape_site_summary_df_kwargs=site_summary_kwargs
)
escape_lineplot

Escape for each mutation:

In [None]:
# Per-mutation escape heatmap, initialized with the configured `times_seen` setting.
escape_heatmap = bootstrap_model.mut_escape_heatmap(
    init_min_times_seen=times_seen,
)
escape_heatmap

Pickle and save bootstrapped models:

In [None]:
print(f"Saving bootstrapped models to {pickle_file=}")
# Serialize the whole bootstrapped collection. Pickles can execute code when
# loaded, so only unpickle this file from a trusted source.
with open(pickle_file, mode="wb") as pickle_handle:
    pickle.dump(bootstrap_model, pickle_handle)