# Bootstrapping the simulated RBD dataset

## Imports and setup

In [1]:
import pandas as pd
from polyclonal import Polyclonal, PolyclonalCollection
import polyclonal.polyclonal_collection as boot
import numpy as np

# Load the simulated noisy RBD escape data, keeping only the "avg2muts" library
# measured at three concentrations. `na_filter=False` turns off NA detection so
# empty `aa_substitutions` fields (variants with no substitutions, as in the
# rows shown below) are read as empty strings rather than NaN.
rbd_data = (
    pd.read_csv("../notebooks/RBD_variants_escape_noisy.csv", na_filter=False)
    .query('library == "avg2muts"')
    .query("concentration in [0.25, 1, 4]")
    .reset_index(drop=True)
)

rbd_data.head()

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg2muts,,0.25,0.05044,0.1128
1,avg2muts,,0.25,0.1431,0.1128
2,avg2muts,,0.25,0.05452,0.1128
3,avg2muts,,0.25,0.08473,0.1128
4,avg2muts,,0.25,0.04174,0.1128


## Create and fit a `root_polyclonal` model to RBD

In [2]:
# Root Polyclonal model that every bootstrap collection below resamples from.
rbd_poly = Polyclonal(
    data_to_fit=rbd_data,
    # Wildtype activity for each of the three epitopes.
    activity_wt_df=pd.DataFrame(
        {
            "epitope": ["1", "2", "3"],
            "activity": [1.0, 3.0, 2.0],
        }
    ),
    # Site-level escape: one site per epitope, each given escape 10.0.
    site_escape_df=pd.DataFrame(
        {
            "epitope": ["1", "2", "3"],
            "site": [417, 484, 444],
            "escape": [10.0, 10.0, 10.0],
        }
    ),
    data_mut_escape_overlap="fill_to_data",
)

In [3]:
# Fit the root model, logging optimization progress every 100 steps.
# The trailing `;` suppresses display of the return value without binding
# anything to `_`, which IPython uses for its last-output history.
rbd_poly.fit(logfreq=100);

# First fitting site-level model.
# Starting optimization of 522 parameters at Thu Mar 10 16:09:29 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.059647     9144.4     9144.2    0.29701          0
        100     6.8575     1336.8     1333.2     3.5443          0
        200     13.214     1313.2     1308.8     4.3872          0
        300     19.418       1305     1299.8     5.1347          0
        400     25.149     1301.8     1296.2     5.6246          0
        500     31.024     1298.4     1292.5     5.8941          0
        600     36.845     1297.8     1291.7     6.0372          0
        700     43.025     1296.9     1290.4     6.5236          0
        800     48.887     1296.4     1289.8     6.6654          0
        900     54.739     1296.2     1289.4      6.759          0
       1000     60.581     1295.7     1288.8     6.8687          0
       1100     66.655     1295.4     1288.6      6.848          0
       1200     72.483  

## Create and fit different `PolyclonalCollection` objects

In [4]:
# Build four bootstrap collections from the same root model:
#   * rbd_pc_a / rbd_pc_b use different seeds (0 and 10);
#   * rbd_pc_a_copy repeats rbd_pc_a's seed;
#   * rbd_pc_b_copy repeats rbd_pc_b's seed but doubles the thread count,
#     so we can check that results do not depend on threading.
n_samps = 3
n_threads = 32
shared_kwargs = dict(root_polyclonal=rbd_poly, n_bootstrap_samples=n_samps)

rbd_pc_a = PolyclonalCollection(**shared_kwargs, n_threads=n_threads, seed=0)
rbd_pc_b = PolyclonalCollection(**shared_kwargs, n_threads=n_threads, seed=10)
rbd_pc_a_copy = PolyclonalCollection(**shared_kwargs, n_threads=n_threads, seed=0)
rbd_pc_b_copy = PolyclonalCollection(
    **shared_kwargs, n_threads=n_threads * 2, seed=10
)

In [5]:
# Fit the models of the two differently seeded collections.
for collection in (rbd_pc_a, rbd_pc_b):
    collection.fit_models(fit_site_level_first=False)

In [6]:
# Fit the "copy" collections, which reuse the seeds of rbd_pc_a / rbd_pc_b.
for collection in (rbd_pc_a_copy, rbd_pc_b_copy):
    collection.fit_models(fit_site_level_first=False)

## Test that the random seed is respected (same seed → same results, different seed → different results)

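Checking the mutation-frequency dictionary (i.e., whether some mutation is missing from a bootstrap sample) is not a suitable seed test here.
With so many multiply-mutated variants, every mutation is likely to be sampled by every bootstrap model, so that check would rarely differ between seeds.
Instead, these tests check two things about the summary statistics and predictions:
- the same seed gives identical results, even with a different thread count;
- different seeds give different results.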

In [7]:
# Summarize each collection. Every call returns a pair of dicts (escape
# summaries, wildtype-activity summaries) that are indexed below by statistic
# name ("mean", "median", "std").
rbd_pc_a_escape_dict, rbd_pc_a_activity_wt_dict = (
    rbd_pc_a.summarize_bootstrapped_params()
)
rbd_pc_b_escape_dict, rbd_pc_b_activity_wt_dict = (
    rbd_pc_b.summarize_bootstrapped_params()
)
rbd_pc_a_copy_escape_dict, rbd_pc_a_copy_activity_wt_dict = (
    rbd_pc_a_copy.summarize_bootstrapped_params()
)
rbd_pc_b_copy_escape_dict, rbd_pc_b_copy_activity_wt_dict = (
    rbd_pc_b_copy.summarize_bootstrapped_params()
)

In [8]:
# Collections fit with the same seed (rbd_pc_a and rbd_pc_a_copy) must yield
# identical bootstrapped escape summaries. The loop keeps the three checks in
# sync, and the message reports which statistic failed.
for stat in ["mean", "median", "std"]:
    assert rbd_pc_a_escape_dict[stat].equals(
        rbd_pc_a_copy_escape_dict[stat]
    ), f"escape {stat} differs between rbd_pc_a and rbd_pc_a_copy (same seed)"

In [9]:
# The same seed must also give identical wildtype-activity summaries.
for stat in ["mean", "median", "std"]:
    assert rbd_pc_a_activity_wt_dict[stat].equals(
        rbd_pc_a_copy_activity_wt_dict[stat]
    ), f"activity_wt {stat} differs between rbd_pc_a and rbd_pc_a_copy (same seed)"

In [10]:
# Same seed but double the thread count (rbd_pc_b vs. rbd_pc_b_copy): the
# summaries must still be identical, i.e. results do not depend on threading.
for stat in ["mean", "median", "std"]:
    assert rbd_pc_b_escape_dict[stat].equals(
        rbd_pc_b_copy_escape_dict[stat]
    ), f"escape {stat} differs between rbd_pc_b and rbd_pc_b_copy (thread count)"

for stat in ["mean", "median", "std"]:
    assert rbd_pc_b_activity_wt_dict[stat].equals(
        rbd_pc_b_copy_activity_wt_dict[stat]
    ), f"activity_wt {stat} differs between rbd_pc_b and rbd_pc_b_copy (thread count)"

In [11]:
# Different seeds (0 vs. 10) must produce different bootstrapped escape summaries.
for stat in ["mean", "median", "std"]:
    assert not rbd_pc_a_escape_dict[stat].equals(
        rbd_pc_b_escape_dict[stat]
    ), f"escape {stat} is identical for rbd_pc_a and rbd_pc_b (different seeds)"

In [12]:
# Different seeds must also produce different wildtype-activity summaries.
for stat in ["mean", "median", "std"]:
    assert not rbd_pc_a_activity_wt_dict[stat].equals(
        rbd_pc_b_activity_wt_dict[stat]
    ), f"activity_wt {stat} is identical for rbd_pc_a and rbd_pc_b (different seeds)"

In [13]:
# Predict on a fixed random subset of 200 variants with every collection.
# Each collection's `make_predictions` output is concatenated into one frame.
test_df = rbd_data.sample(n=200, random_state=0)
pc_a_preds, pc_b_preds, pc_a_copy_preds, pc_b_copy_preds = (
    pd.concat(collection.make_predictions(test_df))
    for collection in (rbd_pc_a, rbd_pc_b, rbd_pc_a_copy, rbd_pc_b_copy)
)

In [14]:
assert pc_a_copy_preds.equals(pc_a_preds)  # same seed -> identical predictions

In [15]:
assert not pc_b_preds.equals(pc_a_preds)  # different seeds -> different predictions

In [16]:
# Reproducibility across thread counts: rbd_pc_b_copy used twice as many
# threads as rbd_pc_b but the same seed, so predictions must match exactly.
assert pc_b_copy_preds.equals(pc_b_preds)