# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

In [3]:
import os
os.chdir('../../')

## Read input data

Get parameterized variable from [papermill](https://papermill.readthedocs.io/)

In [4]:
# papermill parameters cell (tagged as `parameters`)
prob_escape_csv = None
n_threads = None
pickle_file = None
antibody = None

In [5]:
# Parameters
prob_escape_csv = "results/prob_escape/libA_221021_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221021_1_1C04-5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [6]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()


Reading probabilities of escape from results/prob_escape/libA_221021_1_1C04-5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [7]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# get the reference sites in order
reference_sites = (
    pd.read_csv(config["site_numbering_map"])
    .sort_values("sequential_site")["reference_site"]
    .tolist()
)

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221021_1_1C04-5G04_1.pickle'
antibody_config={'max_epitopes': 2, 'n_bootstrap_samples': 50, 'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0, 'times_seen': 3, 'min_epitope_activity_to_include': 0.2}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [8]:
display(
    prob_escape.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.55,35041
3.1,35041
6.2,35041
12.4,35041
24.8,35041


Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [44]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape = (
    prob_escape.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [10]:
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

fit_kwargs = {
    "reg_escape_weight": antibody_config["reg_escape_weight"],
    "reg_spread_weight": antibody_config["reg_spread_weight"],
    "reg_activity_weight": antibody_config["reg_activity_weight"],
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


Fit a model to all the data, and keep adding epitopes until we either reach the maximum specified or the new epitope has negative activity.
Note that that we fit using the **reference** based-site-numbering scheme, so results are shown with those numbers:Z

In [40]:
models = []

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
    opt_res = model.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # stop if activity below threshold for any epitope and fit at least one epitope
    if len(models) and any(
        model.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model)
        model = models[-2]  # get previous model
        break
    else:
        models.append(model)

print(f"\nThe selected model has {len(model.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:33:22 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.035417       5983.9         5983            0            0      0.90499
           45       1.6973       485.83       477.69       4.9632            0       3.1712
# Successfully finished at Mon Oct 24 12:33:23 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:33:24 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.039234       671.84       611.94       56.734   8.0935e-32       3.1712
           35       1.5967       614.86       608.74       1.9215      0.17444       4.0264
# Successfully finished at Mon Oct 24 12:33:25 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,1.9,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:33:30 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.071037       1124.7       1123.8            0            0      0.90499
           40       2.8082       479.21       471.25       5.8777            0       2.0813
# Successfully finished at Mon Oct 24 12:33:33 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:33:33 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.080193        678.5       611.29       65.128   1.7057e-31       2.0813
           41       4.1886       612.63       606.14       3.1866      0.19319        3.106
# Successfully finished at Mon Oct 24 12:33:37 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,1.8
1,2,1.5


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,1.3,0.0
2,1.1,0.0



The selected model has 2 epitopes


Epitope activities:

In [41]:
model.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


Line plot of escape at each site:

In [42]:
model.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Test modeling with different selections included

In [15]:
display(
    prob_escape.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.55,35041
3.1,35041
6.2,35041
12.4,35041
24.8,35041


### Start by dropping the two lowest selection concentrations

In [74]:
prob_escape_high = prob_escape.loc[(prob_escape['antibody_concentration'] != 1.55) 
                             ]

In [75]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape_high = (
    prob_escape_high.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_high = (
    alt.Chart(mean_prob_escape_high)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_high

  for col_name, dtype in df.dtypes.iteritems():


In [76]:
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

fit_kwargs = {
    "reg_escape_weight": antibody_config["reg_escape_weight"],
    "reg_spread_weight": antibody_config["reg_spread_weight"],
    "reg_activity_weight": antibody_config["reg_activity_weight"],
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


In [77]:
models = []

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model_high = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_high.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
    opt_res = model_high.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model_high.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model_high.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # stop if activity below threshold for any epitope and fit at least one epitope
    if len(models) and any(
        model_high.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model_high)
        model_high = models[-2]  # get previous model
        break
    else:
        models.append(model_high)

print(f"\nThe selected model has {len(model_high.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:51:29 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.03236       2190.9         2190            0            0      0.90499
           20      0.59707       35.155       30.836     0.032326            0       4.2868
# Successfully finished at Mon Oct 24 12:51:30 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:51:30 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.033941        42.48       37.688      0.50513   6.0072e-34       4.2868
            7      0.27144       42.024       37.665    0.0086333   0.00026899         4.35
# Successfully finished at Mon Oct 24 12:51:30 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.4


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.1,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:51:34 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.052794          130       129.09            0            0      0.90499
           26       1.5866       34.319       31.471     0.044036            0       2.8039
# Successfully finished at Mon Oct 24 12:51:36 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:51:36 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.066822        41.81        38.32      0.68656   8.7225e-34       2.8039
           10       0.6864       41.193       38.289     0.011819   0.00035269       2.8913
# Successfully finished at Mon Oct 24 12:51:37 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,2.9
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


In [78]:
model_high.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [79]:
model_high.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


In [27]:
prob_escape_low = prob_escape.loc[(prob_escape['antibody_concentration'] != 24.80)
                             ]

In [46]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape_low = (
    prob_escape_low.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_low = (
    alt.Chart(mean_prob_escape_low)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_low

  for col_name, dtype in df.dtypes.iteritems():


In [29]:
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

fit_kwargs = {
    "reg_escape_weight": antibody_config["reg_escape_weight"],
    "reg_spread_weight": antibody_config["reg_spread_weight"],
    "reg_activity_weight": antibody_config["reg_activity_weight"],
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


In [30]:
models = []

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model_low = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_low.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
    opt_res = model_low.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model_low.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model_low.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # stop if activity below threshold for any epitope and fit at least one epitope
    if len(models) and any(
        model_low.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model_low)
        model_low = models[-2]  # get previous model
        break
    else:
        models.append(model_low)

print(f"\nThe selected model has {len(model_low.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:30:21 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.03997       5950.7       5949.8            0            0      0.90499
           44       1.4647       482.77       474.59       5.0121            0       3.1688
# Successfully finished at Mon Oct 24 12:30:22 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:30:22 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.03512        668.8       608.34       57.294   1.0314e-31       3.1688
           33        1.289       611.25       605.11       1.9345      0.17811       4.0312
# Successfully finished at Mon Oct 24 12:30:23 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,1.9,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:30:27 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.060841       1121.7       1120.7            0            0      0.90499
           41        2.397       475.93       467.97       5.8816            0       2.0812
# Successfully finished at Mon Oct 24 12:30:30 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:30:30 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.064695       674.75       607.46       65.211   1.7561e-31       2.0812
           47        3.224        608.8       602.31       3.1863      0.19379       3.1074
# Successfully finished at Mon Oct 24 12:30:33 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,1.8
1,2,1.5


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,1.3,0.0
2,1.2,0.0



The selected model has 2 epitopes


In [31]:
model_low.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [33]:
model_low.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


In [34]:
prob_escape_lower = prob_escape.loc[(prob_escape['antibody_concentration'] != 24.80) & 
                              (prob_escape['antibody_concentration'] != 12.40)
                             ]

In [48]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape_lower = (
    prob_escape_lower.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_lower = (
    alt.Chart(mean_prob_escape_lower)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_lower

  for col_name, dtype in df.dtypes.iteritems():


In [36]:
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

fit_kwargs = {
    "reg_escape_weight": antibody_config["reg_escape_weight"],
    "reg_spread_weight": antibody_config["reg_spread_weight"],
    "reg_activity_weight": antibody_config["reg_activity_weight"],
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


In [37]:
models = []

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model
    model_lower = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_lower.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
    opt_res = model_lower.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model_lower.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model_lower.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # stop if activity below threshold for any epitope and fit at least one epitope
    if len(models) and any(
        model_lower.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model_lower)
        model_lower = models[-2]  # get previous model
        break
    else:
        models.append(model_lower)

print(f"\nThe selected model has {len(model_lower.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:32:33 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.02639       5820.2       5819.3            0            0      0.90499
           44       1.2039       479.88       471.59       5.1324            0       3.1589
# Successfully finished at Mon Oct 24 12:32:34 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:32:34 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.028828       666.77       605.04       58.567   8.0813e-32       3.1589
           37       1.3178       607.95       601.74       1.9853      0.18605        4.033
# Successfully finished at Mon Oct 24 12:32:36 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,2.0,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:32:39 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.047302       1118.9         1118            0            0      0.90499
           40       1.8801       472.95       464.98       5.8892            0        2.082
# Successfully finished at Mon Oct 24 12:32:41 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:32:41 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.052794       671.32        603.9        65.34   2.4947e-31        2.082
           40       2.3122       605.24       598.76       3.1797      0.19336       3.1111
# Successfully finished at Mon Oct 24 12:32:43 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,1.8
1,2,1.5


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,1.3,0.0
2,1.2,0.0



The selected model has 2 epitopes


In [38]:
model_lower.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [39]:
model_lower.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Modeling at different selections summary

### All 5 concentrations

In [49]:
mean_prob_escape_chart

In [43]:
model.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Highest 3 concentrations

In [50]:
mean_prob_escape_chart_high

In [51]:
model_high.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Lowest 4 concentraitons

In [52]:
mean_prob_escape_chart_low

In [53]:
model_low.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


In [54]:
mean_prob_escape_chart_lower

In [55]:
model_lower.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():
