# Getting started with `polyclonal` bootstrapping

## Imports and setup

In [1]:
import pandas as pd
from polyclonal import Polyclonal, PolyclonalCollection
import polyclonal.polyclonal_collection as boot
import polyclonal.plot
import numpy as np

Below we create some simulated data (from Jesse):
* One simulation where we have the same variants at each concentration
* One simulation where we do not have the same variants at each concentration
* These two situations are handled differently in `polyclonal` objects

In [2]:
# Wildtype activities of the two simulated epitopes.
activity_wt_df = pd.DataFrame({"epitope": ["1", "2"], "activity": [2.0, 1.0]})

# Mutation escape values used for the simulation: mutations at sites 1 and 2
# only affect epitope 1, mutations at site 4 only affect epitope 2.
mut_escape_df = pd.DataFrame(
    {
        "mutation": ["M1C", "M1C", "G2A", "G2A", "A4K", "A4K", "A4L", "A4L"],
        "epitope": ["1", "2", "1", "2", "1", "2", "1", "2"],
        "escape": [2.0, 0.0, 3.0, 0.0, 0.0, 2.5, 0.0, 1.5],
    }
)

# Model with known parameters; used only to simulate the data below.
polyclonal_sim = Polyclonal(activity_wt_df=activity_wt_df, mut_escape_df=mut_escape_df)

# Barcoded variants. Barcodes "AC" and "GA" carry the same substitution (M1C).
variants_df = pd.DataFrame.from_records(
    [
        ("AA", ""),
        ("AC", "M1C"),
        ("AG", "G2A"),
        ("AT", "A4K"),
        ("TA", "A4L"),
        ("CA", "M1C G2A"),
        ("CG", "M1C A4K"),
        ("CC", "G2A A4K"),
        ("TC", "G2A A4L"),
        ("CT", "M1C G2A A4K"),
        ("TG", "M1C G2A A4L"),
        ("GA", "M1C"),
    ],
    columns=["barcode", "aa_substitutions"],
)

# Simulated escape probabilities of every variant at three concentrations.
escape_probs = polyclonal_sim.prob_escape(
    variants_df=variants_df, concentrations=[1.0, 2.0, 4.0]
)

# The simulated probabilities (no noise is added) become the data to fit.
data_to_fit = escape_probs.rename(columns={"predicted_prob_escape": "prob_escape"})

# Site-level escape values passed to both models below. Each model receives its
# own copy so the two models never share one mutable frame.
site_escape_df = pd.DataFrame.from_records(
    [("1", 1, 1.0), ("1", 4, 0.0), ("2", 1, 0.0), ("2", 4, 2.0)],
    columns=["epitope", "site", "escape"],
)

# Simulation 1: every variant is measured at every concentration.
polyclonal_data = Polyclonal(
    data_to_fit=data_to_fit,
    activity_wt_df=activity_wt_df,
    site_escape_df=site_escape_df.copy(),
    data_mut_escape_overlap="fill_to_data",
)

# Simulation 2: keep only the first 20 rows of the data so that not every
# variant is present at every concentration.
polyclonal_data2 = Polyclonal(
    data_to_fit=data_to_fit.head(20),
    activity_wt_df=activity_wt_df,
    site_escape_df=site_escape_df.copy(),
    data_mut_escape_overlap="fill_to_data",
)

## Test basic functionality of helper functions outside the class

In [3]:
# Draw one bootstrap sample of the data (note: this draw is not seeded here).
boot_df = boot.create_bootstrap_sample(data_to_fit)

# Make sure we got an appropriate number of samples
assert len(boot_df) == len(data_to_fit)
# Make sure we sampled with replacement: some rows must be duplicated. This is
# probabilistic, but a draw with no duplicates at all is vanishingly unlikely.
assert len(boot_df.drop_duplicates()) != len(data_to_fit)
# Make sure we sampled the same number of variants at each concentration as the
# original dataset has. Sort by concentration so the comparison does not depend
# on how `value_counts` orders concentrations with tied counts.
assert (
    boot_df.concentration.value_counts()
    .sort_index()
    .equals(data_to_fit.concentration.value_counts().sort_index())
)

## Test initialization of `PolyclonalCollection` objects

In [4]:
n_samps = 5
n_threads = 4
pc = PolyclonalCollection(
    root_polyclonal=polyclonal_data,
    n_bootstrap_samples=n_samps,
    n_threads=n_threads,
    seed=0,
)

# One bootstrapped model should be stored per requested sample...
assert len(pc.models) == n_samps
# ...and the threading option should be kept on the collection.
assert pc.n_threads == n_threads
# Each bootstrapped model must be fit on resampled data, not on a copy of the
# root model's data.
for bootstrap_model in pc.models:
    assert not pc.root_polyclonal.data_to_fit.equals(bootstrap_model.data_to_fit)

## Test random seeding

In [5]:
# Do two different seeds generate different objects?
pc2 = PolyclonalCollection(
    root_polyclonal=polyclonal_data,
    n_bootstrap_samples=n_samps,
    n_threads=n_threads,
    seed=10,
)

# A different seed should give different bootstrap samples. Requiring only that
# *some* model pair differs keeps the check robust to one unlucky identical draw
# and to the models being stored in a different order.
assert any(
    not model.data_to_fit.equals(model2.data_to_fit)
    for model, model2 in zip(pc.models, pc2.models)
)

In [6]:
# What if we use the same seed with multiple threads? `pc_copy` reuses the seed
# of `pc`, so its aggregated predictions are expected to match those of `pc`.
pc_copy = PolyclonalCollection(
    seed=0,
    root_polyclonal=polyclonal_data,
    n_threads=n_threads,
    n_bootstrap_samples=n_samps,
)

## Test `PolyclonalCollection` bootstrapping results

In [7]:
# Test to ensure that parameters change during fitting.
for model in pc.models:
    # Snapshot the parameters. A bare reference would compare the array with
    # itself if fitting updated it in place.
    old_params = np.copy(model._params)
    boot._fit_polyclonal_model_static(model, fit_site_level_first=False)
    # Did the params change?
    assert not np.array_equal(model._params, old_params)

In [8]:
# Keyword arguments given to `fit_models()` are forwarded to the fits; this call
# should complete without raising. Its return value is displayed below.
fit_kwargs = {"fit_site_level_first": False}
pc_copy.fit_models(**fit_kwargs)

(5, 0)

### Test predictions across all models

In [9]:
# One prediction dataframe per bootstrapped model.
test_list = pc.make_predictions(variants_df=boot_df)

n_boot_rows, n_boot_cols = boot_df.shape
for result in test_list:
    # Every model predicts exactly one row per input row...
    assert len(result) == n_boot_rows
    # ...and its output must not have the same number of columns as the input.
    assert result.shape[1] != n_boot_cols

In [10]:
# Combine the per-model predictions into summary statistics.
agg_results = pc.summarize_bootstraped_predictions(test_list)
# The summary should not have the same row count as a single model's predictions.
rows_single_model = len(test_list[0])
assert rows_single_model != len(agg_results)

## Test aggregation results

### Test summary validity

In [11]:
# The central summary statistics must still be probabilities in [0, 1].
for center_col in ["mean_predicted_prob_escape", "median_predicted_prob_escape"]:
    assert agg_results[center_col].between(0, 1).all()

In [12]:
# `pc_copy` was built with the same seed as `pc`, so summarizing its predictions
# on the same input must reproduce `agg_results` exactly.
test_list_copy = pc_copy.make_predictions(variants_df=boot_df)
agg_results_copy = pc_copy.summarize_bootstraped_predictions(test_list_copy)
same_summary = agg_results.equals(agg_results_copy)
assert same_summary

In [13]:
# Test that different seeds result in different results. First fit the models
# of `pc2` (seed=10) with the same option used for the other collections.
pc2_fit_kwargs = dict(fit_site_level_first=False)
pc2.fit_models(**pc2_fit_kwargs)

(5, 0)

In [14]:
# Predictions from the differently seeded collection...
test_list_2 = pc2.make_predictions(variants_df=boot_df)
agg_results_2 = pc2.summarize_bootstraped_predictions(test_list_2)
# ...should not reproduce the summary of the seed-0 collection `pc`.
assert not agg_results_2.equals(agg_results)

## Plotting

### Heatmaps for summary statistics and parameters

In [18]:
# Plotting a single (non-bootstrapped) model should be unaffected.
single_model_heatmap = polyclonal_data.mut_escape_heatmap()
single_model_heatmap

In [19]:
# Plot of mean
# NOTE(review): `summary_mut_escape_dict` is not defined in any cell shown here
# (execution counts jump from In [14] to In [18]), so on Restart & Run All this
# cell raises NameError. TODO: restore the cell that builds this dict of summary
# frames (indexed here by "mean"), presumably from one of the fitted collections.
polyclonal.plot.mut_escape_heatmap(
    mut_escape_df=summary_mut_escape_dict["mean"],
    alphabet=polyclonal_data.alphabet,
    epitope_colors=polyclonal_data.epitope_colors,
    stat="mean",
)

In [20]:
# Wildtype activity barplot of the single (non-bootstrapped) model.
single_model_barplot = polyclonal_data.activity_wt_barplot()
single_model_barplot

In [23]:
# Plot the "std" summary of the wildtype activities.
# NOTE(review): `summary_activity_wt_dict` is not defined in any cell shown here
# (execution counts jump from In [20] to In [23]), so on Restart & Run All this
# cell raises NameError. TODO: restore the cell that builds this dict.
polyclonal.plot.activity_wt_barplot(
    activity_wt_df=summary_activity_wt_dict["std"],
    epitope_colors=polyclonal_data.epitope_colors,
    stat="std",
)