# Bootstrapping model fits
The previous section describes fitting a single model.
But we may also want confidence estimates for the fit.
We can get these by bootstrapping the data set.

The overall recommended workflow is to first fit models to all the data to determine the number of epitopes, etc.
Then, once the desired fitting parameters are determined, you can bootstrap to estimate confidence in the predictions.
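
As a rough illustration of the idea (not the `polyclonal` implementation), bootstrapping resamples the data with replacement, refits the model to each resample, and uses the spread of the refit parameters as an error estimate.
The toy data, the `fit_slope` helper, and all variable names below are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data set: noisy measurements of y = 2 * x (made up for illustration)
x = np.linspace(0, 1, 200)
y = 2.0 * x + rng.normal(scale=0.3, size=x.size)


def fit_slope(x, y):
    """Least-squares slope of a line through the origin."""
    return float(np.sum(x * y) / np.sum(x * x))


# "Root" fit to all the data
root_slope = fit_slope(x, y)

# Bootstrap: resample data points with replacement and refit each resample
n_bootstrap_samples = 100
boot_slopes = np.empty(n_bootstrap_samples)
for i in range(n_bootstrap_samples):
    idx = rng.integers(0, x.size, size=x.size)  # indices drawn with replacement
    boot_slopes[i] = fit_slope(x[idx], y[idx])

# The spread of the refit parameter across resamples estimates its uncertainty
boot_std = boot_slopes.std(ddof=1)
print(f"root fit: {root_slope:.2f}, bootstrap std: {boot_std:.3f}")
```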

## Get model fit to the data
The first step is just to fit a `Polyclonal` model to all the data we are using.
We do this much as in the previous notebook for our RBD example, but first shrink the data set to just 7500 variants so that the fit is noisier, which better illustrates the bootstrapping.

We call the model fit to all the data the "root" model, since it is the starting point (root) for the subsequent bootstrapping.
Note that data (which we will bootstrap) are attached to this pre-fit model:

In [1]:
# NBVAL_IGNORE_OUTPUT

import pandas as pd

import polyclonal

# read the data, and just make "barcode" the numerical rank of the variants
noisy_data = (
    pd.read_csv("RBD_variants_escape_noisy.csv", na_filter=None)
    .query('library == "avg3muts"')
    .query("concentration in [0.25, 1, 4]")
    .sort_values(["concentration", "aa_substitutions"])
    .reset_index(drop=True)
    .assign(barcode=lambda x: x.groupby("concentration").cumcount())
)

# just keep some variants to make fitting "noisier"
n_keep = 7500
barcodes_to_keep = (
    noisy_data["barcode"].drop_duplicates().sample(n_keep, random_state=1).tolist()
)
noisy_data = noisy_data.query("barcode in @barcodes_to_keep")

# make and fit the root Polyclonal object with all the data we are using
root_poly = polyclonal.Polyclonal(
    data_to_fit=noisy_data,
    activity_wt_df=pd.DataFrame.from_records(
        [
            ("class 1", 1.0),
            ("class 2", 3.0),
            ("class 3", 2.0),
        ],
        columns=["epitope", "activity"],
    ),
    site_escape_df=pd.DataFrame.from_records(
        [
            ("class 1", 417, 10.0),
            ("class 2", 484, 10.0),
            ("class 3", 444, 10.0),
        ],
        columns=["epitope", "site", "escape"],
    ),
    data_mut_escape_overlap="fill_to_data",
)

opt_res = root_poly.fit(logfreq=100)

# First fitting site-level model.
# Starting optimization of 522 parameters at Sat Mar 19 06:36:28 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.023001       4506     4505.7    0.29701          0
        100     2.6045     550.09     546.34     3.7432          0
        200     4.9812     541.57     537.02     4.5554          0
        300     7.3625        539     533.84     5.1659          0
        400     9.6883     538.27      532.9     5.3674          0
        500     12.051     537.67     532.23     5.4371          0
        600      14.34     537.13     531.63     5.5078          0
        700     16.685     536.64     530.85     5.7896          0
        800      19.25     536.31      530.4     5.9137          0
        900     21.762     536.06     530.11      5.956          0
       1000     24.082     535.51      529.4     6.1068          0
       1100     26.449     535.05      528.9     6.1524          0
       1200     28.804  

## Create and fit bootstrapped models
To create the bootstrapped models, we initialize a `PolyclonalCollection`, here using just 10 samples for speed (for better error estimates you may want more, on the order of 20 to 100).
Note it is important that the root model you are using has already been fit to the data!
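
A bootstrap sample of variant data can be drawn by resampling rows with replacement.
The sketch below resamples separately within each concentration so that every concentration keeps its original number of variants; this is one plausible scheme and an assumption here, not a description of what `PolyclonalCollection` does internally.
The `bootstrap_sample` helper, the `prob_escape` column, and the toy values are hypothetical:

```python
import pandas as pd

# Toy variant data (made-up values for illustration)
data = pd.DataFrame(
    {
        "concentration": [0.25] * 4 + [1.0] * 4,
        "barcode": [0, 1, 2, 3] * 2,
        "aa_substitutions": ["", "E484K", "K417N", "G446V"] * 2,
        "prob_escape": [0.01, 0.40, 0.20, 0.15, 0.00, 0.25, 0.10, 0.05],
    }
)


def bootstrap_sample(df, seed, group_by_col="concentration"):
    """Resample rows with replacement, separately within each group."""
    return (
        df.groupby(group_by_col)
        .sample(frac=1, replace=True, random_state=seed)
        .reset_index(drop=True)
    )


# One resampled data set per bootstrap replicate, each with its own seed
samples = [bootstrap_sample(data, seed) for seed in range(3)]
print(samples[0])
```

Because each sample is drawn with replacement, some variants appear several times and others not at all, so a rare mutation can be absent from an individual bootstrap replicate.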

In [2]:
n_bootstrap_samples = 10

bootstrap_poly = polyclonal.PolyclonalCollection(
    root_polyclonal=root_poly,
    n_bootstrap_samples=n_bootstrap_samples,
)

Now fit the bootstrapped models:

In [3]:
# NBVAL_IGNORE_OUTPUT

import time

start = time.time()
print(f"Starting fitting bootstrap models at {time.asctime()}")
n_fit, n_failed = bootstrap_poly.fit_models()
print(f"Fitting took {time.time() - start:.3g} seconds, finished at {time.asctime()}")
assert n_failed == 0 and n_fit == n_bootstrap_samples

Starting fitting bootstrap models at Sat Mar 19 06:38:45 2022
Fitting took 85.2 seconds, finished at Sat Mar 19 06:40:10 2022


## Look at summarized results
We can get the resulting measurements for the epitope activities and mutation effects both per-replicate and summarized across replicates (mean, median, standard deviation).
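
The relationship between the per-replicate and summarized data frames can be sketched with a plain pandas `groupby`.
The column names follow the outputs in this section, but the numbers are made up, and this is only an assumed equivalent of the summary rather than the library's own code:

```python
import pandas as pd

# Toy per-replicate activities: one row per epitope and bootstrap replicate
replicates = pd.DataFrame(
    {
        "epitope": ["class 1", "class 2"] * 3,
        "activity": [2.0, 2.6, 1.9, 2.7, 2.1, 2.5],
        "bootstrap_replicate": [1, 1, 2, 2, 3, 3],
    }
)

# Summarize each epitope across replicates
summary = replicates.groupby("epitope", as_index=False).agg(
    mean=("activity", "mean"),
    median=("activity", "median"),
    std=("activity", "std"),  # sample standard deviation (ddof=1)
)
print(summary.round(2))
```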

### Epitope activities
Epitope activities for each replicate:

In [4]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df_replicates.round(1)

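| | epitope | activity | bootstrap_replicate |
|---|---|---|---|
| 0 | class 1 | 2.0 | 1 |
| 1 | class 2 | 2.6 | 1 |
| 2 | class 3 | 2.1 | 1 |
| 3 | class 1 | 1.9 | 2 |
| 4 | class 2 | 2.7 | 2 |
| 5 | class 3 | 1.9 | 2 |
| 6 | class 1 | 2.1 | 3 |
| 7 | class 2 | 2.5 | 3 |
| 8 | class 3 | 2.0 | 3 |
| 9 | class 1 | 1.9 | 4 |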


Epitope activities summarized across replicates.
The `std` column gives the standard deviation:

In [5]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df.round(1)

| | epitope | mean | median | std |
|---|---|---|---|---|
| 0 | class 1 | 2.0 | 2.0 | 0.1 |
| 1 | class 2 | 2.6 | 2.6 | 0.1 |
| 2 | class 3 | 2.0 | 2.0 | 0.1 |


We can plot the epitope activities summarized across replicates.
The dropdown allows you to choose the summary stat (mean, median), and the black lines indicate the standard deviation.
Mouse over for values:

In [6]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_barplot()

### Mutation escape values
Mutation escape values for each replicate:

In [7]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df_replicates.round(1).head()

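| | epitope | site | wildtype | mutant | mutation | escape | bootstrap_replicate |
|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | A | N331A | 0.4 | 1 |
| 1 | class 1 | 331 | N | D | N331D | -0.4 | 1 |
| 2 | class 1 | 331 | N | E | N331E | 0.3 | 1 |
| 3 | class 1 | 331 | N | F | N331F | 0.1 | 1 |
| 4 | class 1 | 331 | N | G | N331G | 0.2 | 1 |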


Mutation escape values summarized across replicates.
Note the `frac_bootstrap_replicates` column has the fraction of bootstrap replicates with a value for this mutation:

In [8]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df.round(1).head(n=3)

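| | epitope | site | wildtype | mutant | mutation | mean | median | std | n_bootstrap_replicates | frac_bootstrap_replicates |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | A | N331A | 0.2 | 0.3 | 0.3 | 10 | 1.0 |
| 1 | class 1 | 331 | N | D | N331D | 0.0 | -0.0 | 0.3 | 10 | 1.0 |
| 2 | class 1 | 331 | N | E | N331E | 0.0 | 0.0 | 0.4 | 10 | 1.0 |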


We can plot the mutation escape values across replicates.
The dropdown selects the statistic shown in the heatmap (mean or median), and mouseovers give details on points.
Here we set `min_frac_bootstrap_replicates=0.9` to only report escape values observed in at least 90% of bootstrap replicates (this gets rid of rare mutations):

In [9]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_heatmap(min_frac_bootstrap_replicates=0.9)
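
The effect of a `min_frac_bootstrap_replicates` cutoff can be sketched in pandas.
The example assumes the fraction is the number of replicates with a value for the mutation divided by the total number of bootstrap samples, which matches the column descriptions in this section but is not taken from the library's code; the toy values are made up:

```python
import pandas as pd

n_bootstrap_samples = 4

# Toy per-replicate escape values: the rare mutation N331D is missing from
# two of the four replicates because no resampled variant carried it
replicates = pd.DataFrame(
    {
        "mutation": ["N331A"] * 4 + ["N331D"] * 2,
        "escape": [0.4, 0.2, 0.3, 0.1, -0.4, 0.2],
        "bootstrap_replicate": [1, 2, 3, 4, 1, 3],
    }
)

summary = (
    replicates.groupby("mutation", as_index=False)
    .agg(
        mean=("escape", "mean"),
        n_bootstrap_replicates=("bootstrap_replicate", "nunique"),
    )
    .assign(
        frac_bootstrap_replicates=lambda d: (
            d["n_bootstrap_replicates"] / n_bootstrap_samples
        )
    )
)

# Keep only mutations seen in at least 90% of bootstrap replicates
min_frac_bootstrap_replicates = 0.9
kept = summary[summary["frac_bootstrap_replicates"] >= min_frac_bootstrap_replicates]
print(kept)
```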

### Site summaries of mutation escape
Site summaries of mutation escape values for each replicate:

In [10]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df_replicates.round(1).head()

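| | epitope | site | wildtype | mean | total positive | max | min | total negative | bootstrap_replicate |
|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | 0.5 | 8.9 | 1.8 | -0.7 | -1.3 | 1 |
| 1 | class 1 | 332 | I | 0.6 | 10.6 | 1.5 | 0.0 | 0.0 | 1 |
| 2 | class 1 | 333 | T | 0.5 | 9.4 | 1.3 | -0.7 | -0.9 | 1 |
| 3 | class 1 | 334 | N | 0.8 | 13.9 | 1.9 | -0.2 | -0.3 | 1 |
| 4 | class 1 | 335 | L | 0.5 | 9.5 | 1.5 | -0.5 | -0.8 | 1 |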


Site summaries of mutation escape values summarized (e.g., averaged) across replicates.
Note that there is now a separate row for each site-summary metric (named in the `metric` column), and each metric is summarized by its mean, median, and standard deviation:

In [11]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df.round(1).head()

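| | epitope | site | wildtype | metric | mean | median | std | n_bootstrap_replicates | frac_bootstrap_replicates |
|---|---|---|---|---|---|---|---|---|---|
| 0 | class 1 | 331 | N | max | 1.7 | 1.7 | 0.3 | 10 | 1.0 |
| 1 | class 1 | 331 | N | mean | 0.5 | 0.5 | 0.2 | 10 | 1.0 |
| 2 | class 1 | 331 | N | min | -0.3 | -0.3 | 0.2 | 10 | 1.0 |
| 3 | class 1 | 331 | N | total negative | -0.7 | -0.5 | 0.6 | 10 | 1.0 |
| 4 | class 1 | 331 | N | total positive | 9.2 | 9.1 | 2.6 | 10 | 1.0 |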
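
To make the site-summary metrics concrete, the sketch below computes them from per-mutation escape values for a single replicate and then reshapes them to one row per metric.
It assumes that "total positive" and "total negative" are the sums of the positive and negative escape values at a site, which is inferred from the metric names rather than confirmed by the source; the toy values are made up:

```python
import pandas as pd

# Toy per-mutation escape values for one replicate
mut_escape = pd.DataFrame(
    {
        "epitope": ["class 1"] * 5,
        "site": [331, 331, 331, 332, 332],
        "escape": [0.4, -0.4, 0.3, 0.5, 0.0],
    }
)

# Wide format: one row per site, one column per site-summary metric
site_summary = (
    mut_escape.assign(
        positive=lambda d: d["escape"].clip(lower=0),
        negative=lambda d: d["escape"].clip(upper=0),
    )
    .groupby(["epitope", "site"], as_index=False)
    .agg(
        **{
            "mean": ("escape", "mean"),
            "total positive": ("positive", "sum"),
            "max": ("escape", "max"),
            "min": ("escape", "min"),
            "total negative": ("negative", "sum"),
        }
    )
)

# Long format: one row per site and metric
site_summary_long = site_summary.melt(
    id_vars=["epitope", "site"], var_name="metric", value_name="value"
).sort_values(["epitope", "site", "metric"], ignore_index=True)
print(site_summary_long)
```

In the long format, each (site, metric) pair can then be summarized across bootstrap replicates in the same way as any other per-replicate value.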


We can plot site summaries of the mutation escape.
Note that there are options to toggle the error bars (standard deviations) on or off, to choose which metric is shown (e.g., mean effect of mutations, total positive escape at a site, etc.), and to choose how that metric is summarized (mean, median):

In [12]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_lineplot(min_frac_bootstrap_replicates=0.9)

## Some tests
Below are just tests for approximate consistency of results with what is expected:

In [20]:
for attr in ["activity_wt_df", "mut_escape_site_summary_df", "mut_escape_df"]:
    f = f"RBD_bootstrap_expected_{attr}.csv"
    expected = pd.read_csv(f)
    print(f"Testing {attr}")
    pd.testing.assert_frame_equal(
        getattr(bootstrap_poly, attr),
        expected,
        atol=0.5,
        obj=f"{attr} DataFrame",
    )

Testing activity_wt_df
Testing mut_escape_site_summary_df
Testing mut_escape_df
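
Bootstrap results vary somewhat from run to run, which is why the comparison above uses an absolute tolerance.
The small sketch below, with made-up data frames, shows how `atol` in `pd.testing.assert_frame_equal` changes whether two frames are treated as equal:

```python
import pandas as pd

expected = pd.DataFrame({"epitope": ["class 1", "class 2"], "mean": [2.0, 2.6]})
observed = pd.DataFrame({"epitope": ["class 1", "class 2"], "mean": [2.2, 2.3]})

# Passes: every numeric difference is within the absolute tolerance of 0.5
pd.testing.assert_frame_equal(observed, expected, atol=0.5)

# With a tighter tolerance the same frames are reported as different
try:
    pd.testing.assert_frame_equal(observed, expected, atol=0.05)
    tight_check_passed = True
except AssertionError:
    tight_check_passed = False

print(f"passes with atol=0.05: {tight_check_passed}")
```

Non-numeric columns such as `epitope` are still compared exactly, so the tolerance only loosens the comparison of floating-point values.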
