# Choosing library size

In [1]:
import pandas as pd
import polyclonal
import pickle
import altair as alt
import numpy as np
import time
import os

In [2]:
# Read the noisy escape measurements and keep only the rows for the
# 'avg3muts' library at the three concentrations used for fitting.
# `na_filter=None` is passed through to pandas unchanged; the displayed output
# shows unmutated variants with an empty `aa_substitutions` rather than NaN.
noisy_data = (
    pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=None)
    .loc[lambda df: (df['library'] == 'avg3muts')
                    & df['concentration'].isin([0.25, 1, 4])]
    .reset_index(drop=True)
)

noisy_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg3muts,,0.25,0.00000,0.1128
1,avg3muts,,0.25,0.01090,0.1128
2,avg3muts,,0.25,0.01458,0.1128
3,avg3muts,,0.25,0.09465,0.1128
4,avg3muts,,0.25,0.03299,0.1128
...,...,...,...,...,...
89995,avg3muts,Y449I L518Y C525R L461I,4.00,0.02197,2.3100
89996,avg3muts,Y449V K529R N394R,4.00,0.04925,0.9473
89997,avg3muts,Y451L N481T F490V,4.00,0.02315,0.9301
89998,avg3muts,Y453R V483G L492V N501P I332P,4.00,0.00000,5.0120


In [4]:
# Number of rows to sample per concentration for each fit.
library_sizes = [500, 1_000, 2_500, 5_000, 10_000, 20_000]

# Directory where fitted models are stored as pickle files.
os.makedirs('scipy_results', exist_ok=True)
       
def fit_polyclonal(size):
    """Fit a ``polyclonal.Polyclonal`` model to a subsample of ``noisy_data``.

    Parameters
    ----------
    size : int
        Number of rows to sample from each ``concentration`` group of the
        module-level ``noisy_data`` frame.

    Returns
    -------
    tuple
        ``(model, seconds)`` where ``model`` is the fitted
        ``polyclonal.Polyclonal`` object and ``seconds`` is the wall-clock
        time spent inside ``model.fit()``.
    """
    # Sample every concentration group with the same fixed seed, as the
    # original ``groupby(...).apply(lambda x: x.sample(...))`` did. Iterating
    # over the groups explicitly keeps the ``concentration`` column in the
    # result without relying on ``DataFrameGroupBy.apply`` operating on the
    # grouping column, which newer pandas releases deprecate.
    # NOTE(review): re-seeding per group means equally sized groups draw the
    # same row positions -- presumably intended so the same variants are used
    # at every concentration; confirm before changing the sampling scheme.
    data_to_fit = pd.concat(
        [
            conc_df.sample(n=size, random_state=123)
            for _, conc_df in noisy_data.groupby('concentration')
        ],
        ignore_index=True,
    )

    # Initial wild-type activity for each of the three epitopes.
    activity_wt_df = pd.DataFrame.from_records(
        [('1', 1.0),
         ('2', 3.0),
         ('3', 2.0),
         ],
        columns=['epitope', 'activity'],
    )

    # Initial site-level escape: one site per epitope.
    site_escape_df = pd.DataFrame.from_records(
        [('1', 417, 10.0),
         ('2', 484, 10.0),
         ('3', 444, 10.0),
         ],
        columns=['epitope', 'site', 'escape'],
    )

    poly_abs = polyclonal.Polyclonal(
        data_to_fit=data_to_fit,
        activity_wt_df=activity_wt_df,
        site_escape_df=site_escape_df,
        data_mut_escape_overlap='fill_to_data',
    )

    # Time only the optimization, not the sampling / model construction.
    # ``perf_counter`` is monotonic, so the interval cannot be skewed by
    # system clock adjustments.
    start = time.perf_counter()
    poly_abs.fit()
    return poly_abs, time.perf_counter() - start

# Fit one model per library size, reusing a pickled fit from a previous run
# when one exists in `scipy_results/`.
fit_models = {}
for size in library_sizes:
    model_string = f'noisy_[0.25, 1, 4]conc_3muts_{size}vars'
    pickle_path = f'scipy_results/{model_string}.pkl'
    if os.path.exists(pickle_path):
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # that were produced by this notebook.
        with open(pickle_path, 'rb') as pickle_file:
            model = pickle.load(pickle_file)
        fit_models[model_string] = model
        print(f"Model with {size} variants was already fit.")
    else:
        model, time_elapsed = fit_polyclonal(size)
        fit_models[model_string] = model
        # The context manager guarantees the file is flushed and closed, so a
        # later run never reads a partially written cache file left open here.
        with open(pickle_path, 'wb') as pickle_file:
            pickle.dump(model, pickle_file)
        print(f"Model with {size} variants fit in {time_elapsed:.1f} seconds.")

Model with 500 variants was already fit.
Model with 1000 variants fit in 19.7 seconds.
Model with 2500 variants fit in 28.9 seconds.
Model with 5000 variants fit in 56.5 seconds.
Model with 10000 variants fit in 70.8 seconds.
Model with 20000 variants fit in 230.5 seconds.


In [5]:
# Load a previously fitted model and register it under the 30000-variant key.
# The pickle read here is not written by any cell shown in this notebook, so
# it must already exist in `scipy_results/` (a missing file raises
# FileNotFoundError).
full_model_string = 'noisy_[0.25, 1, 4]conc_3muts'
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(f'scipy_results/{full_model_string}.pkl', 'rb') as pickle_file:
    full_model = pickle.load(pickle_file)
fit_models[f'{full_model_string}_30000vars'] = full_model

In [6]:
# Library sizes to compare, now including the 30000-variant model that was
# added to `fit_models` in the previous cell.
library_sizes = [500, 1000, 2500, 5000, 10000, 20000, 30000]

# The reference escape values are the same for every model, so read them once
# instead of re-parsing the CSV on each iteration.
true_mut_escape = pd.read_csv('RBD_mut_escape_df.csv')

# Collect one small frame per library size and concatenate once at the end;
# growing a DataFrame with `pd.concat` inside the loop is quadratic and
# concatenating onto an empty frame is deprecated in recent pandas.
corrs_by_size = []
for size in library_sizes:
    model = fit_models[f'noisy_[0.25, 1, 4]conc_3muts_{size}vars']

    # Pair each reference escape value with the model's predicted value.
    # The model's epitope labels are prefixed with 'class ' so they match the
    # labels used as merge keys in the reference file.
    mut_escape_pred = true_mut_escape.merge(
        (model.mut_escape_df
         .assign(epitope=lambda x: 'class ' + x['epitope'].astype(str))
         .rename(columns={'escape': 'predicted escape'})
         ),
        on=['mutation', 'epitope'],
        validate='one_to_one',
    )

    # Correlation between reference and predicted escape, per epitope.
    # Selecting the two value columns first means `apply` does not operate on
    # the grouping column, which newer pandas releases deprecate.
    corr = (mut_escape_pred
            .groupby('epitope')[['escape', 'predicted escape']]
            .apply(lambda x: x['escape'].corr(x['predicted escape']))
            .rename('correlation')
            .reset_index()
            )

    # `num_variants` is stored as a string, as before.
    corrs_by_size.append(corr.assign(num_variants=str(size)))

all_corrs = pd.concat(corrs_by_size, ignore_index=True)

In [8]:
# NBVAL_IGNORE_OUTPUT
# Scatter of per-epitope correlation against library size.
base = (
    alt.Chart(all_corrs)
    .mark_point()
    .encode(
        alt.X('num_variants:Q'),
        alt.Y('correlation:Q'),
        alt.Color('epitope:N'),
        tooltip=['num_variants', 'correlation', 'epitope'],
    )
)

# LOESS-smoothed line per epitope, layered on top of the points.
loess_trend = (
    base
    .transform_loess('num_variants', 'correlation', groupby=['epitope'])
    .mark_line(size=2.5)
    .properties(title='predicted vs. true beta coefficients')
)

base + loess_trend

In [9]:
# Read the exact (noise-free, per the file name) escape data and keep only
# the 'avg4muts' library at concentration 1.
exact_data = (
    pd.read_csv('RBD_variants_escape_exact.csv', na_filter=None)
    .loc[lambda df: (df['library'] == 'avg4muts')
                    & df['concentration'].isin([1])]
    .reset_index(drop=True)
)

In [10]:
# Upper limit applied to the true IC90s and passed to the model as `max_c`;
# later cells read this name as well.
max_ic90 = 50

# Collect one record per library size and build the frame once at the end;
# growing a DataFrame with `pd.concat` inside the loop is quadratic and
# concatenating onto an empty frame is deprecated in recent pandas.
ic90_corr_records = []
for size in library_sizes:
    model = fit_models[f'noisy_[0.25, 1, 4]conc_3muts_{size}vars']

    # True IC90 per distinct variant, clipped at `max_ic90`. Rebuilt on every
    # iteration because it is not visible here whether the model helpers
    # below modify the frame they are given.
    ic90s = (exact_data[['aa_substitutions', 'IC90']]
             .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))
             .drop_duplicates()
             )
    # Presumably drops variants with mutations this model has not seen
    # (going by the method name) -- TODO confirm against polyclonal docs.
    ic90s = model.filter_variants_by_seen_muts(ic90s)
    # Adds the model's prediction as the 'predicted_IC90' column.
    ic90s = model.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)

    # Compare true and predicted values on a log10 scale.
    ic90s = ic90s.assign(
        log_IC90=lambda x: np.log10(x['IC90']),
        predicted_log_IC90=lambda x: np.log10(x['predicted_IC90']),
    )

    corr = ic90s['log_IC90'].corr(ic90s['predicted_log_IC90'])

    # `num_variants` is stored as a string, as before.
    ic90_corr_records.append({'correlation': corr,
                              'num_variants': str(size)})

ic90_corrs = pd.DataFrame(ic90_corr_records,
                          columns=['correlation', 'num_variants'])

In [13]:
# NBVAL_IGNORE_OUTPUT
# Scatter of the log-IC90 correlation against library size.
base = (
    alt.Chart(ic90_corrs)
    .mark_point()
    .encode(
        alt.X('num_variants:Q'),
        alt.Y('correlation:Q'),
        tooltip=['num_variants', 'correlation'],
    )
)

# LOESS-smoothed line layered on top of the points.
ic90_loess_trend = (
    base
    .transform_loess('num_variants', 'correlation')
    .mark_line(size=2.5)
    .properties(title='predicted vs. true IC90')
)

base + ic90_loess_trend

In [14]:
# True vs. predicted IC90s for the model fit to 500 variants per concentration.
model = fit_models['noisy_[0.25, 1, 4]conc_3muts_500vars']

# True IC90 per distinct variant, clipped at `max_ic90`.
ic90s = exact_data[['aa_substitutions', 'IC90']].assign(
    IC90=lambda df: df['IC90'].clip(upper=max_ic90)
).drop_duplicates()

# Filter the variants with the model, then add its 'predicted_IC90' column.
ic90s = model.icXX(
    model.filter_variants_by_seen_muts(ic90s),
    x=0.9,
    col='predicted_IC90',
    max_c=max_ic90,
)

# log10-transform both the true and the predicted values.
ic90s = ic90s.assign(
    log_IC90=lambda df: np.log10(df['IC90']),
    predicted_log_IC90=lambda df: np.log10(df['predicted_IC90']),
)
ic90s

Unnamed: 0,aa_substitutions,IC90,predicted_IC90,log_IC90,predicted_log_IC90
0,,0.1128,0.194264,-0.947691,-0.711608
1,A344K,0.2010,0.253763,-0.696804,-0.595571
2,A344K F392A V483I F490L L517R,3.4750,8.518010,0.540955,0.930338
3,A344K G446L N501S,2.7660,4.732901,0.441852,0.675127
4,A344K N370R D428R,0.7868,1.424515,-0.104136,0.153667
...,...,...,...,...,...
15918,Y508V C525R,0.4691,0.624992,-0.328735,-0.204126
15919,Y508V K529E,0.4704,0.400234,-0.327533,-0.397686
15920,Y508W,0.2285,0.468951,-0.641114,-0.328872
15921,Y508W C525F,0.4073,0.615208,-0.390086,-0.210978


In [22]:
# Predicted (x) against true (y) log10 IC90 for the 500-variant model.
ic90_scatter_500 = (
    alt.Chart(ic90s)
    .mark_circle(size=30, opacity=0.3)
    .encode(x='predicted_log_IC90', y='log_IC90')
    .properties(title='predicted vs. true IC90 with 500 variant library')
)
ic90_scatter_500

In [27]:
# True vs. predicted IC90s for the model fit to 10000 variants per concentration.
model = fit_models['noisy_[0.25, 1, 4]conc_3muts_10000vars']

# True IC90 per distinct variant, clipped at `max_ic90`.
ic90s = exact_data[['aa_substitutions', 'IC90']].assign(
    IC90=lambda df: df['IC90'].clip(upper=max_ic90)
).drop_duplicates()

# Filter the variants with the model, then add its 'predicted_IC90' column.
ic90s = model.icXX(
    model.filter_variants_by_seen_muts(ic90s),
    x=0.9,
    col='predicted_IC90',
    max_c=max_ic90,
)

# log10-transform both the true and the predicted values.
ic90s = ic90s.assign(
    log_IC90=lambda df: np.log10(df['IC90']),
    predicted_log_IC90=lambda df: np.log10(df['predicted_IC90']),
)
ic90s

Unnamed: 0,aa_substitutions,IC90,predicted_IC90,log_IC90,predicted_log_IC90
0,,0.1128,0.120042,-0.947691,-0.920667
1,A344E A348G N388L,1.1700,1.163510,0.068186,0.065770
2,A344E F392V D428G Y451Q G482L,4.8070,5.724790,0.681874,0.757760
3,A344E I468L Q493R,8.7400,9.403300,0.941511,0.973280
4,A344E K378I K386S T415K G476A Q493H T531H,22.2700,12.728816,1.347720,1.104788
...,...,...,...,...,...
28377,Y508V C525R,0.4691,0.440172,-0.328735,-0.356378
28378,Y508V K529E,0.4704,0.487389,-0.327533,-0.312125
28379,Y508W,0.2285,0.218778,-0.641114,-0.659996
28380,Y508W C525F,0.4073,0.391173,-0.390086,-0.407632


In [28]:
# Predicted (x) against true (y) log10 IC90 for the 10000-variant model.
ic90_scatter_10000 = (
    alt.Chart(ic90s)
    .mark_circle(size=30, opacity=0.3)
    .encode(x='predicted_log_IC90', y='log_IC90')
    .properties(title='predicted vs. true IC90 with 10000 variant library')
)
ic90_scatter_10000