# Bootstrapping the simulated RBD dataset

## Imports and setup

In [1]:
import pandas as pd
from polyclonal import Polyclonal, PolyclonalCollection
import polyclonal.polyclonal_collection as boot
import numpy as np

# Read the noisy simulated RBD escape data and keep only the "avg2muts"
# library at concentrations 0.25, 1, and 4, renumbering the rows from zero.
rbd_data = (
    pd.read_csv(
        "../notebooks/RBD_variants_escape_noisy.csv",
        na_filter=None,
    )
    .loc[
        lambda variants: (variants["library"] == "avg2muts")
        & variants["concentration"].isin([0.25, 1, 4])
    ]
    .reset_index(drop=True)
)

rbd_data.head()

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg2muts,,0.25,0.05044,0.1128
1,avg2muts,,0.25,0.1431,0.1128
2,avg2muts,,0.25,0.05452,0.1128
3,avg2muts,,0.25,0.08473,0.1128
4,avg2muts,,0.25,0.04174,0.1128


## Create and fit a `root_polyclonal` model to the RBD data

In [2]:
# Build the root Polyclonal model on the filtered data. Three epitopes
# ("1", "2", "3") are given a wildtype activity each, and one site per epitope
# is given an escape value of 10.
rbd_poly = Polyclonal(
    data_to_fit=rbd_data,
    activity_wt_df=pd.DataFrame(
        {
            "epitope": ["1", "2", "3"],
            "activity": [1.0, 3.0, 2.0],
        }
    ),
    site_escape_df=pd.DataFrame(
        {
            "epitope": ["1", "2", "3"],
            "site": [417, 484, 444],
            "escape": [10.0, 10.0, 10.0],
        }
    ),
    data_mut_escape_overlap="fill_to_data",
)

In [3]:
# Fit the root model, logging optimization progress every 100 steps.
# The trailing semicolon hides the returned value in the notebook without
# rebinding `_`, which IPython reserves for the most recent cell output.
rbd_poly.fit(logfreq=100);

# First fitting site-level model.
# Starting optimization of 522 parameters at Thu Feb 24 11:00:01 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.052244     9144.4     9144.2    0.29701          0
        100     6.8986     1336.8     1333.2     3.5443          0
        200     13.906     1313.2     1308.8     4.3872          0
        300      21.41       1305     1299.8     5.1347          0
        400     28.214     1301.8     1296.2     5.6246          0
        500     34.967     1298.4     1292.5     5.8941          0
        600      41.68     1297.8     1291.7     6.0372          0
        700     48.761     1296.9     1290.4     6.5236          0
        800     55.565     1296.4     1289.8     6.6654          0
        900     62.327     1296.2     1289.4      6.759          0
       1000     69.127     1295.7     1288.8     6.8687          0
       1100     76.115     1295.4     1288.6      6.848          0
       1200     82.921  

## Create and fit two different `PolyclonalCollection` objects

In [4]:
# Build two PolyclonalCollection objects from the same root model. They are
# configured identically except for the seed (0 for "a", 10 for "b").
n_samps = 5
n_threads = 32
rbd_pc_a, rbd_pc_b = (
    PolyclonalCollection(
        root_polyclonal=rbd_poly,
        n_bootstrap_samples=n_samps,
        n_threads=n_threads,
        seed=seed,
    )
    for seed in (0, 10)
)

In [5]:
# Fit the models held by each collection (seed 0 first, then seed 10).
# NOTE(review): presumably this fits one model per bootstrap sample using
# `n_threads` workers -- confirm against PolyclonalCollection.fit_models.
rbd_pc_a.fit_models()
rbd_pc_b.fit_models()

## Test that the seed is respected (different seeds give different results)

In [6]:
# The mutation-frequency-dictionary test is not suitable for this dataset:
# with so many multi-mutant variants, we may never see a mutation that is
# left out of some model's sample.
# For the seed tests we instead just check that each seed produces different
# summary statistics.
rbd_pc_a_escape_dict, rbd_pc_a_activity_wt_dict = (
    rbd_pc_a._summarize_bootstrapped_params()
)
rbd_pc_b_escape_dict, rbd_pc_b_activity_wt_dict = (
    rbd_pc_b._summarize_bootstrapped_params()
)

In [7]:
# The escape summaries from the two seeds must differ for every statistic
# (mean, median, and std); a single exact match fails the check.
assert not any(
    rbd_pc_a_escape_dict[stat].equals(rbd_pc_b_escape_dict[stat])
    for stat in ("mean", "median", "std")
)

In [8]:
# The wildtype-activity summaries from the two seeds must differ for every
# statistic (mean, median, and std).
assert all(
    not rbd_pc_a_activity_wt_dict[stat].equals(rbd_pc_b_activity_wt_dict[stat])
    for stat in ("mean", "median", "std")
)

In [9]:
# Draw a fixed random subset of 200 variants and collect each collection's
# predictions on it, concatenated into one frame per collection.
test_df = rbd_data.sample(n=200, random_state=0)
pc_a_preds, pc_b_preds = (
    pd.concat(collection.make_predictions(test_df))
    for collection in (rbd_pc_a, rbd_pc_b)
)

In [10]:
# Collections built with different seeds should not give identical predictions
# on the same test variants.
assert not pc_a_preds.equals(pc_b_preds)