# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
import os
os.chdir('../../')

In [3]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

## Read input data

Get parameterized variable from [papermill](https://papermill.readthedocs.io/)

In [14]:
# papermill parameters cell (tagged as `parameters`)
prob_escape_csv = None
n_threads = None
pickle_file = None
antibody = None

In [44]:
# Parameters
prob_escape_csv = "results/prob_escape/libA_221021_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221021_1_1C04-5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [45]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()


Reading probabilities of escape from results/prob_escape/libA_221021_1_1C04-5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [17]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# get site numbering map and the reference sites in order
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = site_numbering_map.sort_values("sequential_site")[
    "reference_site"
].tolist()

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221021_1_1C04-5G04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 2, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [85]:
display(
    prob_escape.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.55,35041
3.1,35041
6.2,35041


In [84]:
prob_escape = prob_escape.loc[(prob_escape['antibody_concentration'] != 24.80) &
                              (prob_escape['antibody_concentration'] != 12.40)
                             ]

Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [86]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape = (
    prob_escape.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [37]:
max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

# fit_kwargs = antibody_config["fit_kwargs"]
fit_kwargs = {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0,
              'reg_similarity_weight': 0.12
             }
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0, 'reg_similarity_weight': 0.12}
min_epitope_activity_to_include=0.2


Fit a model to all the data, and keep adding epitopes until we either reach the maximum specified or the new epitope has negative activity.
Note that that we fit using the **reference** based-site-numbering scheme, so results are shown with those numbers:Z

In [20]:
models = []

for n_epitopes in range(2, 3):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
#     opt_res = model.fit(logfreq=200, **fit_kwargs)
    opt_res = model.fit(
        logfreq=200,
        reg_escape_weight=0.1,
        reg_similarity_weight=1.0,  # regularize epitope similarity
    )

    # display activities
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # stop if activity below threshold for any epitope and fit at least one epitope
    if len(models) and any(
        model.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model)
        model = models[-2]  # get previous model
        break
    else:
        models.append(model)

print(f"\nThe selected model has {len(model.epitopes)} epitopes")


Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 21:36:09 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.069164       1124.7       1123.8            0            0            0      0.90499
           58       4.2668       479.42       471.45       5.8215            0       0.1108       2.0335
# Successfully finished at Mon Oct 31 21:36:13 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 21:36:13 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.078964       695.71        611.3       65.016   2.0744e-31       17.363       2.0335
           53       4.6264       612.75       606.86       2.6958      0.23609     0.046231       2.9155
# Successfully finished at Mon Oct 31 21:36:18 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,2.9
1,2,0.2


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,2.2,0.0
2,0.2,0.0



The selected model has 2 epitopes


Epitope activities:

In [21]:
model.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


Plot of escape values:

In [22]:
model.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


In [87]:
# NBVAL_IGNORE_OUTPUT
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_sequential": "aa_substitutions",
        }
    ).query("concentration > 1.0"),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1,
    reg_similarity_weight=0.1
#     reg_similarity_weight=1.501e-2,  # regularize epitope similarity
)

model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 22:25:15 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.043535       1118.9         1118            0            0            0      0.90499
           87       4.1844       472.99       465.02       5.8545            0      0.04603       2.0701
# Successfully finished at Mon Oct 31 22:25:19 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 22:25:19 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.051265       678.69        603.9       65.128   3.1662e-31       7.5878       2.0701
           49       3.0629       605.31       599.26       2.8529       0.2353     0.022621       2.9398
# Successfully finished at Mon Oct 31 22:25:22 2022.


  for col_name, dtype in df.dtypes.iteritems():


In [40]:
df_to_merge = site_numbering_map.rename(columns={"reference_site": "site"})

plot_kwargs = antibody_config["plot_kwargs"]
if "plot_title" not in plot_kwargs:
    plot_kwargs["plot_title"] = str(antibody)
if "region" in site_numbering_map:
    plot_kwargs["site_zoom_bar_color_col"] = "region"
if "addtl_slider_stats" not in plot_kwargs:
    plot_kwargs["addtl_slider_stats"] = {"times_seen": 1}
elif "times_seen" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["times_seen"] = 1
if "functional effect" in plot_kwargs["addtl_slider_stats"]:
    del plot_kwargs["addtl_slider_stats"]["functional effect"]  # only antibody averages
if any(site_numbering_map["sequential_site"] != site_numbering_map["reference_site"]):
    if "addtl_tooltip_stats" not in plot_kwargs:
        plot_kwargs["addtl_tooltip_stats"] = ["sequential_site"]
    else:
        plot_kwargs["addtl_tooltip_stats"].append("sequential_site")

model.mut_escape_plot(df_to_merge=df_to_merge, **plot_kwargs)

  for col_name, dtype in df.dtypes.iteritems():


In [29]:
# NBVAL_IGNORE_OUTPUT
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_sequential": "aa_substitutions",
        }
    ).query("concentration > 1.0"),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_similarity_weight=1.0,  # regularize epitope similarity
)
model.mut_escape_plot(category_colors={"1": "#0072B2", "2": "#CC79A7"})

# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:22:14 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.075698       1124.7       1123.8            0            0            0      0.90499
           58       4.1571       479.42       471.45       5.8215            0       0.1108       2.0335
# Successfully finished at Mon Oct 31 19:22:18 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:22:18 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0       0.0781       695.71        611.3       65.016   1.9809e-31       17.363       2.0335
           53        4.537       612.75       606.86       2.6958      0.23609     0.046231       2.9155
# Successfully finished at Mon Oct 31 19:22:23 2022.


  for col_name, dtype in df.dtypes.iteritems():


In [60]:
# NBVAL_IGNORE_OUTPUT
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_sequential": "aa_substitutions",
        }
    ).query("concentration > 1.0"),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_similarity_weight=0.8,  # regularize epitope similarity
)
model.mut_escape_plot(category_colors={"1": "#0072B2", "2": "#CC79A7"})

# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:40:09 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.071874       1124.7       1123.8            0            0            0      0.90499
           69       4.8528        479.4       471.44       5.8144            0        0.105       2.0403
# Successfully finished at Mon Oct 31 19:40:14 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:40:14 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.076147       694.75       611.25       64.945   2.6095e-31       16.518       2.0403
           54       4.4088       612.74       606.82        2.722      0.23712     0.043919       2.9147
# Successfully finished at Mon Oct 31 19:40:18 2022.


  for col_name, dtype in df.dtypes.iteritems():


In [72]:
reg_weights_to_test = [
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.8,
    0.9
]

# models = []

for weight in reg_weights_to_test:
    model = polyclonal.Polyclonal(
        n_epitopes=2,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_sequential": "aa_substitutions",
            }
        ).query("concentration > 1.0"),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # fit model
    opt_res = model.fit(
        logfreq=200,
        reg_escape_weight=0.1,
        reg_similarity_weight=weight,  # regularize epitope similarity
    )
    
#     models.append(model)
    model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:53:55 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.072465       1124.7       1123.8            0            0            0      0.90499
           77       6.7845       479.25       471.29       5.8474            0     0.044585       2.0675
# Successfully finished at Mon Oct 31 19:54:02 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:54:02 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0      0.69621       685.57        611.3       64.986   1.8522e-31       7.2158       2.0675
           47       5.3702       612.67       606.64         2.85      0.23588     0.021092       2.9235
# Successfully finished at Mon Oct 31 19:54:07 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:54:13 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.064265       1124.7       1123.8            0            0            0      0.90499
           78       5.5411       479.28       471.35       5.8157            0     0.061049       2.0583
# Successfully finished at Mon Oct 31 19:54:19 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:54:19 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.076901       688.01        611.3       64.737   3.0156e-31       9.9128       2.0583
           43       3.6095       612.69       606.72       2.7886      0.23255     0.028145       2.9197
# Successfully finished at Mon Oct 31 19:54:22 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:54:28 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.069191       1124.7       1123.8            0            0            0      0.90499
           83       5.9666       479.31       471.34       5.8411            0     0.074254       2.0516
# Successfully finished at Mon Oct 31 19:54:34 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:54:34 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.077148       690.45        611.3       65.009   3.2245e-31       12.091       2.0516
           36       3.5982        612.7       606.75       2.7664      0.23412      0.03275       2.9187
# Successfully finished at Mon Oct 31 19:54:38 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:54:44 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.071886       1124.7       1123.8            0            0            0      0.90499
           81       5.9828       479.33       471.37       5.8314            0     0.084072       2.0484
# Successfully finished at Mon Oct 31 19:54:50 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:54:50 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.077523       691.94        611.3       64.933   2.8165e-31       13.662       2.0484
           40       3.4808       612.71        606.8       2.7207      0.23164     0.035873       2.9181
# Successfully finished at Mon Oct 31 19:54:54 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:54:59 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.065823       1124.7       1123.8            0            0            0      0.90499
           69       5.0062       479.35        471.4       5.8183            0     0.091371        2.045
# Successfully finished at Mon Oct 31 19:55:04 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:55:05 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.075452       693.02       611.29        64.94   3.2759e-31       14.748        2.045
           39       3.4059       612.72       606.83        2.706      0.23127     0.038789       2.9176
# Successfully finished at Mon Oct 31 19:55:08 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:55:14 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.070792       1124.7       1123.8            0            0            0      0.90499
           71       4.9201       479.37       471.39       5.8349            0     0.097266       2.0416
# Successfully finished at Mon Oct 31 19:55:19 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:55:19 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.082014       693.94       611.28       65.038   2.7927e-31       15.579       2.0416
           55       4.6294       612.73       606.82       2.7109      0.23389     0.040398       2.9167
# Successfully finished at Mon Oct 31 19:55:23 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:55:29 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.067537       1124.7       1123.8            0            0            0      0.90499
           68       4.9231       479.38       471.42       5.8196            0      0.10098       2.0395
# Successfully finished at Mon Oct 31 19:55:34 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:55:34 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.076076        694.3        611.3       64.939   2.3954e-31       16.023       2.0395
           52       4.4977       612.73       606.84       2.7003      0.23371     0.042142       2.9168
# Successfully finished at Mon Oct 31 19:55:39 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:55:45 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.064726       1124.7       1123.8            0            0            0      0.90499
           69        4.862        479.4       471.44       5.8144            0        0.105       2.0403
# Successfully finished at Mon Oct 31 19:55:49 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:55:50 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.074867       694.75       611.25       64.945   2.6095e-31       16.518       2.0403
           54       4.4129       612.74       606.82        2.722      0.23712     0.043919       2.9147
# Successfully finished at Mon Oct 31 19:55:54 2022.


  for col_name, dtype in df.dtypes.iteritems():


# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 31 19:56:00 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.072354       1124.7       1123.8            0            0            0      0.90499
           70       5.6137       479.41       471.45       5.8192            0      0.10724       2.0358
# Successfully finished at Mon Oct 31 19:56:05 2022.
# Starting optimization of 6922 parameters at Mon Oct 31 19:56:06 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0      0.08636       695.16       611.29       64.934   3.6771e-31       16.898       2.0358
           55       5.5014       612.74       606.83       2.7138      0.23697     0.045006       2.9155
# Successfully finished at Mon Oct 31 19:56:11 2022.


  for col_name, dtype in df.dtypes.iteritems():


In [33]:
prob_escape = pd.read_csv(
    "https://raw.githubusercontent.com/jbloomlab/polyclonal/main/notebooks/libA_220810_1_1C04-5G04_1_prob_escape.csv",
    keep_default_na=False, na_values="nan"
).query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)  # filter for those with sufficient no-antibody counts
assert prob_escape.notnull().all().all()
prob_escape.head()

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,total_no_antibody_count,no_antibody_count_threshold,aa_substitutions_reference,antibody,antibody_concentration
0,libA,220810_1_antibody_1C04-5G04_3.65_1,220810_1_no-antibody_control_1,K297I,1,ATAACACAAAAAAGTA,0.0017,0.0017,78972,246344,8599550,44895,10428350,15,K278I,1C04-5G04,3.65
1,libA,220810_1_antibody_1C04-5G04_3.65_1,220810_1_no-antibody_control_1,R111S V366M R402S,3,TATCTACCTAACGAAA,0.0047,0.0047,70662,78014,8599550,44895,10428350,15,R92S V347M R383S,1C04-5G04,3.65
2,libA,220810_1_antibody_1C04-5G04_3.65_1,220810_1_no-antibody_control_1,A125M P246H I393A G398Q F411Y,5,CTTTCAATTATGAGAC,0.037,0.037,57908,8163,8599550,44895,10428350,15,A106M P227H I374A G379Q F392Y,1C04-5G04,3.65
3,libA,220810_1_antibody_1C04-5G04_3.65_1,220810_1_no-antibody_control_1,Y113M S143N S164N I307M I393Y E468Q,6,TGTATTAGCATTTTGA,0.0074,0.0074,37740,26593,8599550,44895,10428350,15,Y94M S124N S145N I288M I374Y E449Q,1C04-5G04,3.65
4,libA,220810_1_antibody_1C04-5G04_3.65_1,220810_1_no-antibody_control_1,G237H P246H V366M,3,CCAAGGAGCACGAAAA,0.0218,0.0218,26699,6407,8599550,44895,10428350,15,G218H P227H V347M,1C04-5G04,3.65


In [34]:
# NBVAL_IGNORE_OUTPUT
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_sequential": "aa_substitutions",
        }
    ).query("concentration > 1.0"),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_similarity_weight=1.0,  # regularize epitope similarity
)
model.mut_escape_plot(category_colors={"1": "#0072B2", "2": "#CC79A7"})

# First fitting site-level model.
# Starting optimization of 1026 parameters at Mon Oct 31 21:46:43 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.032203       264.24       263.33            0            0            0      0.90499
          173       6.2841       125.87       121.59       2.9759            0     0.072274       1.2303
# Successfully finished at Mon Oct 31 21:46:50 2022.
# Starting optimization of 6736 parameters at Mon Oct 31 21:46:50 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.036241       194.97       157.43       27.877   1.1857e-31       8.4309       1.2303
          110       4.7972       165.55        153.8       9.5379      0.38838      0.39313       1.4347
# Successfully finished at Mon Oct 31 21:46:54 2022.


  for col_name, dtype in df.dtypes.iteritems():
