# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
import os

# Later cells open paths such as "config.yaml" relative to the working
# directory, so move two levels up to where that file is expected.
# Only do so when the file is not already visible from the current directory:
# an unconditional `os.chdir` would climb two more levels every time this
# cell is re-run in the same kernel.
if not os.path.isfile("config.yaml"):
    os.chdir("../../")

In [3]:
# allow more rows for Altair by disabling its maximum-rows check on chart data
# (the return value is bound to `_`, presumably so the cell displays nothing)
_ = alt.data_transformers.disable_max_rows()

## Read input data

Get the parameterized variables from [papermill](https://papermill.readthedocs.io/):

In [4]:
# papermill parameters cell (tagged as `parameters`)
# The `None` values are placeholders; concrete values are assigned in the
# "# Parameters" cell that follows.
prob_escape_csv = None  # path of the CSV of escape probabilities read with `pd.read_csv`
n_threads = None  # only printed in the visible part of this notebook
pickle_file = None  # only printed in the visible part; presumably the output path for the fit -- TODO confirm
antibody = None  # re-derived later from the `antibody` column of the data

In [5]:
# Parameters
# Concrete values that replace the `None` placeholders defined above
# (NOTE(review): this looks like the cell papermill injects -- confirm).
# `antibody` is not set here; it is derived from the data further below.
prob_escape_csv = "results/prob_escape/libB_221108_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libB_221108_1_1C04-5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [49]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

# Only the literal string "nan" is parsed as missing (`keep_default_na=False`
# turns off pandas' default NA strings). Rows are kept when their
# `no-antibody_count` reaches the per-row `no_antibody_count_threshold` column.
prob_escape = pd.read_csv(
    prob_escape_csv,
    na_values="nan",
    keep_default_na=False,
).loc[lambda df: df["no-antibody_count"] >= df["no_antibody_count_threshold"]]

# no missing values are allowed anywhere in the filtered data
assert not prob_escape.isnull().values.any()


Reading probabilities of escape from results/prob_escape/libB_221108_1_1C04-5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [50]:
# load the top-level configuration
with open("config.yaml") as config_handle:
    config = yaml.safe_load(config_handle)

# the data must be for exactly one antibody; keep its name as a scalar
antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# site numbering map, plus the reference sites ordered by sequential site
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = (
    site_numbering_map.sort_values(by="sequential_site")
    .loc[:, "reference_site"]
    .tolist()
)

# look up the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as polyclonal_config_handle:
    polyclonal_config = yaml.safe_load(polyclonal_config_handle)
if antibody in polyclonal_config:
    antibody_config = polyclonal_config[antibody]
else:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")

# report the variables and settings being used, one per line
print(
    f"{antibody=}",
    f"{n_threads=}",
    f"{pickle_file=}",
    f"{antibody_config=}",
    sep="\n",
)

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libB_221108_1_1C04-5G04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 2, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [51]:
# number of distinct barcodes at each antibody concentration
display(
    prob_escape.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.37,31045
2.05,31045
3.08,31045
4.62,31045
6.93,31045
10.4,31045


Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [43]:
max_aa_subs = 4  # group if >= this many substitutions


def mean_prob_escape_by_n_subs(prob_escape_df, max_aa_subs):
    """Average the escape probabilities by concentration and substitution count.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Must have the columns `aa_substitutions_reference` (space-delimited
        substitutions), `antibody_concentration`, `prob_escape` and
        `prob_escape_uncensored`.
    max_aa_subs : int
        Variants with at least this many substitutions are pooled under the
        label ``f">{max_aa_subs - 1}"``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with the columns `antibody_concentration`, `n_subs`,
        `censored` and `probability escape`. Each variant (row of the input)
        gets the same weight in the means.
    """
    return (
        prob_escape_df.assign(
            n_subs=lambda x: (
                x["aa_substitutions_reference"]
                .str.split()
                .map(len)
                .clip(upper=max_aa_subs)
                .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
            )
        )
        .groupby(["antibody_concentration", "n_subs"], as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        .melt(
            id_vars=["antibody_concentration", "n_subs"],
            var_name="censored",
            value_name="probability escape",
        )
    )


def plot_mean_prob_escape(mean_df):
    """Line chart of mean probability of escape versus antibody concentration.

    Parameters
    ----------
    mean_df : pandas.DataFrame
        Frame in the format returned by `mean_prob_escape_by_n_subs`.

    Returns
    -------
    altair chart
        One column facet per value of `censored`, lines colored by `n_subs`,
        and a symlog y-axis. Tooltips show every column, with float columns
        formatted to three significant digits.
    """
    return (
        alt.Chart(mean_df)
        .encode(
            x=alt.X("antibody_concentration"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            column=alt.Column("censored", title=None),
            color=alt.Color("n_subs", title="n substitutions"),
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_df[c].dtype == float else c
                for c in mean_df.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125)
        .configure_axis(grid=False)
    )


mean_prob_escape = mean_prob_escape_by_n_subs(prob_escape, max_aa_subs)

mean_prob_escape_chart = plot_mean_prob_escape(mean_prob_escape)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [44]:
# pull the fitting-related settings out of the antibody configuration
max_epitopes = antibody_config["max_epitopes"]
fit_kwargs = antibody_config["fit_kwargs"]
min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]

# report them, one per line
print(
    f"{max_epitopes=}",
    f"{fit_kwargs=}",
    f"{min_epitope_activity_to_include=}",
    sep="\n",
)

max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


Fit a model to all the data, and keep adding epitopes until we either reach the maximum specified or an epitope in the newest model has activity at or below `min_epitope_activity_to_include` (in which case the previous model is selected).
Note that we fit using the **reference**-based site-numbering scheme, so results are shown with those numbers:

In [45]:
# Without this check, `max_epitopes < 1` would skip the loop entirely and
# leave `model` undefined (or pointing at a stale model from an earlier run).
if max_epitopes < 1:
    raise ValueError(f"{max_epitopes=} must be at least 1 to fit a model")

models = []  # every fitted model, in order of increasing number of epitopes

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model; the columns are renamed to the names passed as
    # `concentration` and `aa_substitutions`, using reference site numbering
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model
    opt_res = model.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    models.append(model)

    # Stop once any epitope of the newest model has activity at or below the
    # threshold, but never reject the first (one-epitope) model. When
    # stopping, the selected model is the one fitted just before this one.
    if (
        len(models) > 1
        and (model.activity_wt_df["activity"] <= min_epitope_activity_to_include).any()
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model = models[-2]  # get previous model
        break

print(f"\nThe selected model has {len(model.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 521 parameters at Wed Nov  9 17:49:00 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.026862       2645.7       2644.8            0            0            0      0.90499
           25      0.62029       35.824        31.22      0.12377            0            0       4.4807
# Successfully finished at Wed Nov  9 17:49:01 2022.
# Starting optimization of 3363 parameters at Wed Nov  9 17:49:01 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.028375       50.561       44.415        1.666   5.1638e-34            0       4.4807
           18       0.6669       49.026       44.431       0.0075   0.00019871            0       4.5876
# Successfully finished at Wed Nov  9 17:49:02 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.7


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.1,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1042 parameters at Wed Nov  9 17:49:05 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.040495       141.45       140.55            0            0            0      0.90499
           31       1.4483       34.221       31.201      0.16374            0            0       2.8566
# Successfully finished at Wed Nov  9 17:49:06 2022.
# Starting optimization of 6726 parameters at Wed Nov  9 17:49:06 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.051787       49.583       44.489       2.2376    5.036e-33            0       2.8566
           22       1.3642       47.527       44.512     0.014202   0.00038534            0       3.0006
# Successfully finished at Wed Nov  9 17:49:08 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,3.0
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.1,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


Epitope activities:

In [46]:
# bar plot of epitope activities for the model currently bound to `model`
model.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


Plot of escape values:

In [47]:
# site numbering map keyed by `site` so it can be merged into the plot data
df_to_merge = site_numbering_map.rename(columns={"reference_site": "site"})

# Work on copies of the configured plotting options. Editing
# `antibody_config["plot_kwargs"]` in place would silently change the
# configuration for later cells and make this cell give different results
# when re-run (for example, appending "sequential_site" a second time).
plot_kwargs = dict(antibody_config["plot_kwargs"])
plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_numbering_map:
    plot_kwargs["site_zoom_bar_color_col"] = "region"

# always have a `times_seen` slider (default 1), and drop any
# "functional effect" slider statistic
addtl_slider_stats = dict(plot_kwargs.get("addtl_slider_stats", {}))
addtl_slider_stats.setdefault("times_seen", 1)
addtl_slider_stats.pop("functional effect", None)  # only antibody averages
plot_kwargs["addtl_slider_stats"] = addtl_slider_stats

# show the sequential site in tooltips when it differs from the reference site
if any(site_numbering_map["sequential_site"] != site_numbering_map["reference_site"]):
    addtl_tooltip_stats = list(plot_kwargs.get("addtl_tooltip_stats", []))
    if "sequential_site" not in addtl_tooltip_stats:
        addtl_tooltip_stats.append("sequential_site")
    plot_kwargs["addtl_tooltip_stats"] = addtl_tooltip_stats

model.mut_escape_plot(df_to_merge=df_to_merge, **plot_kwargs)

  for col_name, dtype in df.dtypes.iteritems():


In [25]:
# Two-epitope model on the sequential-numbered substitutions, keeping only
# rows with concentration above 1.0. Note this rebinds `model`.
model = polyclonal.Polyclonal(
    data_to_fit=prob_escape.rename(
        columns=dict(
            antibody_concentration="concentration",
            aa_substitutions_sequential="aa_substitutions",
        )
    ).loc[lambda df: df["concentration"] > 1.0],
    n_epitopes=2,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit with explicit regularization weights, including epitope similarity
opt_res = model.fit(
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_similarity_weight=0.2,
    logfreq=200,
)

# escape plot with a fixed color for each of the two epitopes
model.mut_escape_plot(category_colors=dict(zip("12", ("#0072B2", "#CC79A7"))))

# First fitting site-level model.
# Starting optimization of 1042 parameters at Wed Nov  9 17:46:22 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.048524       508.33       507.43            0            0            0      0.90499
           53       2.7185       174.56       170.54       1.5508            0     0.016722       2.4501
# Successfully finished at Wed Nov  9 17:46:24 2022.
# Starting optimization of 6726 parameters at Wed Nov  9 17:46:25 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.059301       275.29       248.45       20.307   4.5985e-32        4.081       2.4501
           88       5.9217       250.26       244.64       2.2733      0.46527     0.051356        2.833
# Successfully finished at Wed Nov  9 17:46:30 2022.


  for col_name, dtype in df.dtypes.iteritems():


## Combine with libA data

In [52]:
# Parameters
# NOTE(review): these assignments overwrite the values set for libB earlier
# in the notebook, so any earlier cell re-run after this point reads the libA
# paths instead.
prob_escape_csv = "results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle"
n_threads = 2


In [53]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

# Only the literal string "nan" is parsed as missing. Rows are kept when
# their `no-antibody_count` reaches the `no_antibody_count_threshold` column.
prob_escape_libA = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
# validate the frame that was just read, not the previously loaded `prob_escape`
assert prob_escape_libA.notnull().all().all()


Reading probabilities of escape from results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv


In [54]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# The two libraries are about to be combined, so require that together they
# contain exactly one antibody. Checking `prob_escape` alone would leave the
# libA data unvalidated.
antibody = pd.concat(
    [prob_escape["antibody"], prob_escape_libA["antibody"]]
).unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# get site numbering map and the reference sites in order
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = site_numbering_map.sort_values("sequential_site")[
    "reference_site"
].tolist()

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 2, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


In [55]:
# stack the two per-library data frames row-wise
# NOTE(review): the original row labels are kept (no `ignore_index`), so index
# values may repeat across the two libraries -- confirm nothing downstream
# relies on a unique index
prob_escape_full = pd.concat([prob_escape, prob_escape_libA])
prob_escape_full

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,total_no_antibody_count,no_antibody_count_threshold,aa_substitutions_reference,antibody,antibody_concentration
0,libB,221108_1_antibody_1C04-5G04_10.4_1,221108_1_no-antibody_control_1,Q76M L263M,2,TAACGAGGTCTAATCA,0.0006,0.0006,15667,85467,984817,3489,9397483,15,Q57M L244M,1C04-5G04,10.40
1,libB,221108_1_antibody_1C04-5G04_10.4_1,221108_1_no-antibody_control_1,Q76L K101A P246S E429M,4,GTTATATAAGCGGTAA,0.2806,0.2806,12912,163,984817,3489,9397483,15,Q57L K82A P227S E410M,1C04-5G04,10.40
2,libB,221108_1_antibody_1C04-5G04_10.4_1,221108_1_no-antibody_control_1,P122D R220A R239A L515I,4,AATAATAGGTGAGTCT,0.0007,0.0007,9564,49975,984817,3489,9397483,15,P103D R201A R220A L496I,1C04-5G04,10.40
3,libB,221108_1_antibody_1C04-5G04_10.4_1,221108_1_no-antibody_control_1,V509T,1,ACCGCGGCCTAATACG,0.0004,0.0004,8054,81504,984817,3489,9397483,15,V490T,1C04-5G04,10.40
4,libB,221108_1_antibody_1C04-5G04_10.4_1,221108_1_no-antibody_control_1,R220S,1,ACATGATTTATTGTCT,0.0005,0.0005,7183,50325,984817,3489,9397483,15,R201S,1C04-5G04,10.40
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399473,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,R111I G161N K410E,3,TTTTTTCCCCTATGCA,0.0000,0.0000,0,32,68369,6147,10886757,15,R92I G142N K391E,1C04-5G04,1.37
399475,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,A182V K283A S289T T337L,4,TTTTTTGAACACCAAG,0.0000,0.0000,0,153,68369,6147,10886757,15,A163V K264A S270T T318L,1C04-5G04,1.37
399478,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,Q330A K487T,2,TTTTTTGCATAGAGAG,0.0000,0.0000,0,46,68369,6147,10886757,15,Q311A K468T,1C04-5G04,1.37
399480,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,T179R G205C P343R V366L E378D L458I,6,TTTTTTTCACCAACTG,0.0000,0.0000,0,84,68369,6147,10886757,15,T160R G186C P324R V347L E359D L439I,1C04-5G04,1.37


In [56]:
# number of distinct barcodes at each antibody concentration, both libraries combined
display(
    prob_escape_full.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.37,61762
2.05,61762
3.08,61762
4.62,61762
6.93,61762
10.4,61762


In [57]:
max_aa_subs = 4  # group if >= this many substitutions

# Mean censored and uncensored escape probability for each concentration and
# substitution-count label, reshaped to long format for plotting.
mean_prob_escape = (
    prob_escape_full.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            # counts at the cap are pooled under a ">k" label
            .map(lambda n: f">{max_aa_subs - 1}" if n >= max_aa_subs else str(n))
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)[
        ["prob_escape", "prob_escape_uncensored"]
    ]
    .mean()
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

# One facet per censoring mode, one line per substitution-count label,
# symlog y-axis; float columns get three significant digits in the tooltip.
mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .mark_line(point=True, size=0.5)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            c if mean_prob_escape[c].dtype != float else alt.Tooltip(c, format=".3g")
            for c in mean_prob_escape.columns
        ],
    )
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


In [59]:
# Two-epitope model on the combined libraries, using the sequential-numbered
# substitutions and only rows with concentration above 1.0.
# Note this rebinds `model`.
model = polyclonal.Polyclonal(
    data_to_fit=prob_escape_full.rename(
        columns=dict(
            antibody_concentration="concentration",
            aa_substitutions_sequential="aa_substitutions",
        )
    ).loc[lambda df: df["concentration"] > 1.0],
    n_epitopes=2,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit with explicit regularization weights, including epitope similarity
opt_res = model.fit(
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_similarity_weight=0.3,
    logfreq=200,
)

# escape plot with a fixed color for each of the two epitopes
model.mut_escape_plot(category_colors=dict(zip("12", ("#0072B2", "#CC79A7"))))

# First fitting site-level model.
# Starting optimization of 1064 parameters at Wed Nov  9 17:55:16 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0      0.11419       2744.9         2744            0            0            0      0.90499
          200       24.029         1215       1203.1       10.131            0      0.18464       1.5739
          221       26.796       1214.9         1203        10.16            0      0.18145       1.5731
# Successfully finished at Wed Nov  9 17:55:42 2022.
# Starting optimization of 7156 parameters at Wed Nov  9 17:55:43 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0      0.15753       1904.3       1798.4       87.157   3.3149e-31       17.211       1.5731
          117       20.632       1776.2       1746.7       24.787       1.5047        0.931       2.3451
# Successfully finished at Wed N

  for col_name, dtype in df.dtypes.iteritems():


In [60]:
# drop all rows measured at antibody concentration 1.37
# NOTE(review): this rebinds `prob_escape_full` to the filtered frame, so
# earlier cells that use `prob_escape_full` see the reduced data if re-run
# after this cell
prob_escape_full = prob_escape_full.loc[prob_escape_full['antibody_concentration'] != 1.37]

In [61]:
# Two-epitope model on whatever `prob_escape_full` currently holds, using the
# sequential-numbered substitutions and only rows with concentration above
# 1.0. Note this rebinds `model`.
model = polyclonal.Polyclonal(
    data_to_fit=prob_escape_full.rename(
        columns=dict(
            antibody_concentration="concentration",
            aa_substitutions_sequential="aa_substitutions",
        )
    ).loc[lambda df: df["concentration"] > 1.0],
    n_epitopes=2,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit with explicit regularization weights, including epitope similarity
opt_res = model.fit(
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_similarity_weight=0.3,
    logfreq=200,
)

# escape plot with a fixed color for each of the two epitopes
model.mut_escape_plot(category_colors=dict(zip("12", ("#0072B2", "#CC79A7"))))

# First fitting site-level model.
# Starting optimization of 1064 parameters at Wed Nov  9 17:56:48 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0     0.088732        987.3       986.39            0            0            0      0.90499
           57       5.6948       245.41       240.95       1.7366            0     0.021766       2.7016
# Successfully finished at Wed Nov  9 17:56:54 2022.
# Starting optimization of 7156 parameters at Wed Nov  9 17:56:54 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spreadreg_similarity reg_activity
            0      0.13295       386.99       356.22       22.915   6.1949e-32       5.1517       2.7016
           87       12.723       358.57       352.83       2.1812      0.37521     0.041357       3.1422
# Successfully finished at Wed Nov  9 17:57:06 2022.


  for col_name, dtype in df.dtypes.iteritems():
