# What library size should we use?

We will test library sizes varying from 1000, 5000, 10000, 200000, 30000 variants to see how these improve model training. We will use a library with 3 mutations on average, and a concentration set = [0.25, 1, 4], which were previously determined to be optimal here and here. 

In [1]:
import pandas as pd
import polyclonal
import pickle
import random
import altair as alt
import numpy

In [3]:
noisy_data = (
    pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=None)
    .query("library == 'avg3muts'")
    .query('concentration in [0.25, 1, 4]')
    .reset_index(drop=True)
    )

noisy_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg3muts,,0.25,0.00000,0.1128
1,avg3muts,,0.25,0.01090,0.1128
2,avg3muts,,0.25,0.01458,0.1128
3,avg3muts,,0.25,0.09465,0.1128
4,avg3muts,,0.25,0.03299,0.1128
...,...,...,...,...,...
89995,avg3muts,Y449I L518Y C525R L461I,4.00,0.02197,2.3100
89996,avg3muts,Y449V K529R N394R,4.00,0.04925,0.9473
89997,avg3muts,Y451L N481T F490V,4.00,0.02315,0.9301
89998,avg3muts,Y453R V483G L492V N501P I332P,4.00,0.00000,5.0120


In [3]:
library_sizes = [1000, 5000, 10000, 20000, 30000]

for size in library_sizes:
    poly_abs = polyclonal.Polyclonal(data_to_fit=(noisy_data.groupby('concentration')
                                                            .apply(lambda x: x.sample(n=size, random_state=123))
                                                            .reset_index(drop = True)),
                                     activity_wt_df=pd.DataFrame.from_records(
                                         [('1', 1.0),
                                          ('2', 3.0),
                                          ('3', 2.0),
                                          ],
                                         columns=['epitope', 'activity'],
                                         ),
                                     site_escape_df=pd.DataFrame.from_records(
                                         [('1', 417, 10.0),
                                          ('2', 484, 10.0),
                                          ('3', 444, 10.0),
                                          ],
                                         columns=['epitope', 'site', 'escape'],
                                         ),
                                     data_mut_escape_overlap='fill_to_data',
                                 )
    
    opt_res = poly_abs.fit(logfreq=500)
    pickle.dump(poly_abs, open(f'scipy_results/libsize{size}_noisy_3conc_3muts.pkl', 'wb'))
    print(f"Model fit on library with {size} variants to scipy_results/libsize{size}_noisy_3conc_3muts.pkl")

# First fitting site-level model.
# Starting optimization of 519 parameters at Fri Nov 26 11:41:39 2021.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.011101     609.84     609.54    0.29701          0
        500     5.3137     56.556     51.293     5.2629          0
       1000     10.359     55.669     50.234     5.4356          0
       1239     12.641     55.516      50.03     5.4861          0
# Successfully finished at Fri Nov 26 11:41:52 2021.
# Starting optimization of 5448 parameters at Fri Nov 26 11:41:52 2021.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.010517     112.21     51.685     60.526 1.4313e-29
        500     6.2786     36.836     13.726     16.696      6.414
        776     9.6045     36.248     14.414     15.558     6.2759
# Successfully finished at Fri Nov 26 11:42:02 2021.
Model fit on library with 1000 variants to scipy_results/libsize1000_noisy_3conc_3muts.pkl
# First fitting site-le

## Get correlation between predicted and true beta coefficients for each trained model

In [4]:
library_sizes = [1000, 5000, 10000, 20000, 30000]
all_corrs = pd.DataFrame({'epitope' : [], 
                          'correlation' : [], 
                          'num_variants' : []}
                        )

for size in library_sizes:
    model = pickle.load(open(f'scipy_results/libsize{size}_noisy_3conc_3muts.pkl', 'rb'))

    mut_escape_pred = (
        pd.read_csv('RBD_mut_escape_df.csv')
        .merge((model.mut_escape_df
                .assign(epitope=lambda x: 'class ' + x['epitope'].astype(str))
                .rename(columns={'escape': 'predicted escape'})
                ),
               on=['mutation', 'epitope'],
               validate='one_to_one',
               )
        )

    corr = (mut_escape_pred
            .groupby('epitope')
            .apply(lambda x: x['escape'].corr(x['predicted escape']))
            .rename('correlation')
            .reset_index()
            )
    all_corrs = pd.concat([all_corrs, 
                           corr.assign(num_variants = [size] * len(corr.index))]
                         )
all_corrs.head()

Unnamed: 0,epitope,correlation,num_variants
0,class 1,0.024152,1000.0
1,class 2,0.619466,1000.0
2,class 3,0.142082,1000.0
0,class 1,0.231682,5000.0
1,class 2,0.809985,5000.0


In [6]:
base = alt.Chart(all_corrs).mark_point().encode(
    alt.X('num_variants:Q'),
    alt.Y('correlation:Q'),
    alt.Color('epitope:N'),
    tooltip=['num_variants', 'correlation', 'epitope']
)

chart = base + base.transform_loess('num_variants', 'correlation', groupby=['epitope']).mark_line(size=2.5)
chart.save('scipy_results/figures/library_size.pdf')
chart

## Get correlation between predicted and true IC90's for each trained model

In [5]:
exact_data = (
    pd.read_csv('RBD_variants_escape_exact.csv', na_filter=None)
    .query('library == "avg4muts"')
    .query('concentration in [1]')
    .reset_index(drop=True)
    )

In [9]:
model = pickle.load(open(f'scipy_results/libsize1000_noisy_3conc_3muts.pkl', 'rb'))

max_ic90 = 50
ic90s = (exact_data[['aa_substitutions', 'IC90']]
         .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))
         .drop_duplicates()
         )

ic90s

Unnamed: 0,aa_substitutions,IC90
0,,0.1128
569,A344E A348G N388L,1.1700
570,A344E F392V D428G Y451Q G482L,4.8070
571,A344E I468L Q493R,8.7400
572,A344E K378I K386S T415K G476A Q493H T531H,22.2700
...,...,...
29995,Y508V C525R,0.4691
29996,Y508V K529E,0.4704
29997,Y508W,0.2285
29998,Y508W C525F,0.4073


In [8]:
#ic90s['aa_substitutions'].tolist()

In [16]:
def remove_unseen_mutations(mut_df, model):
    """Remove mutations in a mutation escape dataframe that were 
    not seen during model fitting. Useful before icXX prediction,
    to ensure only mutations with fit beta coefficients are used.
    
    Paramters
    ---------
    mut_df : pandas.DataFrame
        Must include a column named 'aa_substitutions'.
    model : a fit `Polyclonal` object
    
    Returns
    -------
    mut_df : pandas.DataFrame
        Copy of input dataframe, with rows of variants
        that have unseen mutations removed.
    
    Example
    -------
    """
    all_muts = []
    for row in mut_df.index:
        for mut in mut_df['aa_substitutions'][row].split():
            all_muts.append(mut) 
    all_muts = list(set(all_muts))
 
    seen_muts = []
    for row in model.data_to_fit.index:
        for mut in model.data_to_fit['aa_substitutions'][row].split():
            seen_muts.append(mut)
    seen_muts = list(set(seen_muts))

    unseen_muts = list(set(all_muts) - set(seen_muts))
    
    for row in mut_df.index:
        if not set(mut_df['aa_substitutions'][row].split()).isdisjoint(unseen_muts):
            mut_df.drop(row, inplace=True)
    
    return mut_df

In [15]:
ic90s = remove_unseen_mutations(ic90s, model)
ic90s

Unnamed: 0,aa_substitutions,IC90
0,,0.1128
589,A344K,0.2010
591,A344K D389G S443T E484R G504E Y508H L518P P521R,50.0000
592,A344K D389T G485A T500H K529W,4.0050
594,A344K F392A V483I F490L L517R,3.4750
...,...,...
29995,Y508V C525R,0.4691
29996,Y508V K529E,0.4704
29997,Y508W,0.2285
29998,Y508W C525F,0.4073


In [47]:
all_muts = unique_muts(ic90s)

In [48]:
len(all_muts)

1932

In [58]:
seen_muts = unique_muts(model.data_to_fit)

In [60]:
len(seen_muts)

1815

In [62]:
unseen_muts = list(set(all_muts) - set(seen_muts))

In [78]:
for row in ic90s.index:
    if not set(ic90s['aa_substitutions'][row].split()).isdisjoint(unseen_muts):
        ic90s.drop(row, inplace=True)

ic90s

Unnamed: 0,aa_substitutions,IC90
0,,0.1128
589,A344K,0.2010
591,A344K D389G S443T E484R G504E Y508H L518P P521R,50.0000
592,A344K D389T G485A T500H K529W,4.0050
594,A344K F392A V483I F490L L517R,3.4750
...,...,...
29995,Y508V C525R,0.4691
29996,Y508V K529E,0.4704
29997,Y508W,0.2285
29998,Y508W C525F,0.4073


In [79]:
ic90s = model.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)
ic90s

Unnamed: 0,aa_substitutions,IC90,predicted_IC90
0,,0.1128,0.128422
1,A344K,0.2010,0.220149
2,A344K D389G S443T E484R G504E Y508H L518P P521R,50.0000,14.930351
3,A344K D389T G485A T500H K529W,4.0050,4.390572
4,A344K F392A V483I F490L L517R,3.4750,8.423201
...,...,...,...
24458,Y508V C525R,0.4691,0.674645
24459,Y508V K529E,0.4704,0.434172
24460,Y508W,0.2285,0.196201
24461,Y508W C525F,0.4073,0.228384


In [6]:
library_sizes = [1000, 5000, 10000, 20000, 30000]
ic90_data = pd.DataFrame({'aa_substitutions' : [], 
                          'log_IC90' : [],
                          'predicted_log_IC90' : [],
                          'num_variants' : []}
                          )

max_ic90 = 50
for size in library_sizes:
    model = pickle.load(open(f'scipy_results/libsize{size}_noisy_3conc_3muts.pkl', 'rb'))
    
    # we only need the variants, not the concentration for the IC90 comparison
    ic90s = (exact_data[['aa_substitutions', 'IC90']]
         .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))
         .drop_duplicates()
         )

    ic90s = model.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)

    ic90s = (
        ic90s
        .assign(log_IC90=lambda x: numpy.log10(x['IC90']),
                predicted_log_IC90=lambda x: numpy.log10(x['predicted_IC90']),
        )
    )

    corr = ic90s['log_IC90'].corr(ic90s['predicted_log_IC90'])
    print(f"Correlation is {corr:.2f} for library with {size} variants.")
    
    ic90_data = pd.concat([ic90_data,
                          (ic90s.drop(columns=['IC90', 'predicted_IC90'])
                                .assign(num_variants = [size] * len(ic90s.index)))])

ic90_data

ValueError: substitutions not in `allowed_subs`: ['A344E', 'A344N', 'A348H', 'A352E', 'A363I', 'A363M', 'A363W', 'A363Y', 'A372D', 'A372K', 'A372M', 'A475F', 'A475K', 'A520C', 'A522M', 'C391N', 'C391Y', 'D420K', 'D427M', 'E340M', 'E471F', 'E471Q', 'E484K', 'F338I', 'F338W', 'F374Y', 'F377M', 'F377W', 'F429W', 'F464W', 'F486C', 'F490W', 'G339F', 'G339K', 'G413F', 'G413Q', 'G446C', 'G446F', 'G504F', 'I332H', 'I332Q', 'I332S', 'I358W', 'I358Y', 'I468M', 'I472C', 'I472M', 'I472Q', 'K356Q', 'K386E', 'K386Q', 'K417W', 'K458W', 'K462E', 'K462H', 'K528W', 'L335I', 'L390N', 'L425M', 'L517Y', 'N331Q', 'N334C', 'N334D', 'N334G', 'N334Y', 'N394C', 'N394M', 'N450C', 'N481K', 'N481M', 'P337M', 'P384W', 'P527F', 'P527M', 'Q474K', 'Q474M', 'Q493K', 'Q493N', 'Q498C', 'Q498E', 'Q498F', 'Q498H', 'R408F', 'R408W', 'S359I', 'S359W', 'S371I', 'S371W', 'S373M', 'S383M', 'S383Q', 'S459Y', 'S477C', 'S514H', 'S514W', 'S514Y', 'S530E', 'S530K', 'T333W', 'T376C', 'T376M', 'T393D', 'T470M', 'T500W', 'T531Y', 'V367C', 'V367M', 'V395M', 'V483K', 'V503C', 'Y365V', 'Y369G', 'Y369H', 'Y396M', 'Y451M', 'Y473N', 'Y489W']