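# Choosing the library mutation rate

Fit polyclonal models to simulated libraries averaging 1–4 mutations per variant, and compare how well each recovers the true mutation escape values and variant IC90s.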

In [1]:
import pandas as pd
import polyclonal
import pickle
import altair as alt
import numpy as np
import time
import os

In [2]:
# Load the noisy simulated escape measurements and keep three concentrations.
# na_filter=False keeps empty strings (e.g. the empty `aa_substitutions` of
# presumably-wildtype variants) as '' instead of converting them to NaN.
noisy_data = (
    pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=False)
    .query('concentration in [0.25, 1, 4]')
    .reset_index(drop=True)
    )

noisy_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg1muts,,0.25,0.087480,0.1128
1,avg1muts,,0.25,0.034240,0.1128
2,avg1muts,,0.25,0.037880,0.1128
3,avg1muts,,0.25,0.035730,0.1128
4,avg1muts,,0.25,0.000000,0.1128
...,...,...,...,...,...
359995,avg2muts,Y473E L518F D427L,4.00,0.002918,1.1600
359996,avg1muts,Y473S G413Q,4.00,0.000000,0.5780
359997,avg1muts,Y473V P479R F392W,4.00,0.160200,1.4550
359998,avg3muts,Y489Q N501Y,4.00,0.000000,0.5881


In [3]:
# Library mutation rates to compare; library n is named f'avg{n}muts'.
avg_mut_rates = list(range(1, 5))
# Cache directory for pickled model fits, created up front so later writes succeed.
os.makedirs(name='scipy_results', exist_ok=True)
       
def fit_polyclonal(num):
    """Fit a three-epitope ``polyclonal.Polyclonal`` model to one library.

    Parameters
    ----------
    num : int
        Library mutation rate; selects the rows of the global ``noisy_data``
        whose ``library`` column equals ``f'avg{num}muts'``.

    Returns
    -------
    tuple
        ``(poly_abs, elapsed)``: the fitted model and the number of seconds
        spent inside ``poly_abs.fit()``.
    """
    # Per-epitope wildtype activities handed to the model
    # (presumably starting values for the fit -- confirm in polyclonal docs).
    activity_wt_df = pd.DataFrame.from_records(
        [('1', 1.0),
         ('2', 3.0),
         ('3', 2.0),
         ],
        columns=['epitope', 'activity'],
        )
    # One site per epitope with a large escape value; presumably used by
    # polyclonal to initialize and distinguish the epitopes -- confirm.
    site_escape_df = pd.DataFrame.from_records(
        [('1', 417, 10.0),
         ('2', 484, 10.0),
         ('3', 444, 10.0),
         ],
        columns=['epitope', 'site', 'escape'],
        )
    poly_abs = polyclonal.Polyclonal(
        data_to_fit=noisy_data.query(f"library == 'avg{num}muts'"),
        activity_wt_df=activity_wt_df,
        site_escape_df=site_escape_df,
        data_mut_escape_overlap='fill_to_data',
        )
    # perf_counter is monotonic and high-resolution, so the measured duration
    # cannot be skewed by system clock adjustments (unlike time.time()).
    start = time.perf_counter()
    poly_abs.fit()
    return poly_abs, time.perf_counter() - start

# Fit (or load cached fits of) one model per library mutation rate.
fit_models = {}
for n in avg_mut_rates:
    model_string = f'noisy_[0.25, 1, 4]conc_{n}muts'
    pickle_path = f'scipy_results/{model_string}.pkl'
    if os.path.exists(pickle_path):
        # These pickles are written by this notebook; never point this at
        # untrusted files, since unpickling can execute arbitrary code.
        with open(pickle_path, 'rb') as f:
            model = pickle.load(f)
        print(f"Model with {n} average mutation rate was already fit.")
    else:
        model, time_elapsed = fit_polyclonal(n)
        # Dump to a temporary file and atomically move it into place, so an
        # interrupted write never leaves a truncated pickle that later runs
        # would try (and fail) to load.
        tmp_path = f'{pickle_path}.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(model, f)
        os.replace(tmp_path, pickle_path)
        print(f"Model with {n} average mutation rate fit in {time_elapsed:.1f} seconds.")
    fit_models[model_string] = model

Model with 1 average mutation rate was already fit.
Model with 2 average mutation rate was already fit.
Model with 3 average mutation rate was already fit.
Model with 4 average mutation rate was already fit.


In [4]:
# Ground-truth mutation escape values; identical for every model, so read once.
true_mut_escape = pd.read_csv('RBD_mut_escape_df.csv')

# Collect one correlation table per model and concatenate once after the loop,
# rather than re-concatenating a growing DataFrame on every iteration.
corr_frames = []
for n in avg_mut_rates:
    model = fit_models[f'noisy_[0.25, 1, 4]conc_{n}muts']

    # Pair each true escape value with the model's prediction for the same
    # mutation/epitope. Model epitope labels are prefixed with 'class ' to
    # match the reference file (presumably labelled 'class 1', ...).
    mut_escape_pred = (
        true_mut_escape
        .merge((model.mut_escape_df
                .assign(epitope=lambda x: 'class ' + x['epitope'].astype(str))
                .rename(columns={'escape': 'predicted escape'})
                ),
               on=['mutation', 'epitope'],
               validate='one_to_one',
               )
        )

    # Selecting the value columns keeps the grouping column out of apply(),
    # which pandas >= 2.2 deprecates.
    corr = (mut_escape_pred
            .groupby('epitope')[['escape', 'predicted escape']]
            .apply(lambda x: x['escape'].corr(x['predicted escape']))
            .rename('correlation')
            .reset_index()
            .assign(mutation_rate=f'avg{n}muts')
            )
    corr_frames.append(corr)

all_corrs = (pd.concat(corr_frames, ignore_index=True) if corr_frames
             else pd.DataFrame(columns=['epitope', 'correlation', 'mutation_rate']))

In [5]:
# NBVAL_IGNORE_OUTPUT
# One panel per epitope: correlation between true and predicted mutation escape
# for each library mutation rate. `sort='descending'` orders the categories
# themselves (the previous EncodingSortField referenced a nonexistent field 'x').
alt.Chart(all_corrs).mark_circle(size=125).encode(
    x=alt.X('mutation_rate:O', sort='descending'),
    y='correlation:Q',
    column='epitope:N',
    tooltip=['mutation_rate', 'correlation'],
    color=alt.Color('epitope', legend=None),
).properties(width=200, height=200, title='predicted vs. true beta coefficients')

In [6]:
# Exact (noise-free) simulated data used to evaluate IC90 predictions; only the
# avg4muts library at a single concentration is kept (presumably so each
# variant's per-variant IC90 appears once per variant -- confirm).
# na_filter=False keeps empty `aa_substitutions` as '' rather than NaN.
exact_data = (
    pd.read_csv('RBD_variants_escape_exact.csv', na_filter=False)
    .query('library == "avg4muts" and concentration in [0.5]')
    .reset_index(drop=True)
    )

In [8]:
# Ceiling for IC90: true values are clipped to it, and it is passed to the
# model's icXX as `max_c` for the predictions.
max_ic90 = 50

# True IC90 per unique variant; identical for every model, so build it once.
true_ic90s = (exact_data[['aa_substitutions', 'IC90']]
              .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))
              .drop_duplicates()
              )

# Collect per-model results and build the output frames once after the loop,
# instead of re-concatenating growing DataFrames on every iteration.
corr_records = []
ic90_frames = []
for n in avg_mut_rates:
    model = fit_models[f'noisy_[0.25, 1, 4]conc_{n}muts']
    mutation_rate = f'avg{n}muts'

    # Presumably drops variants carrying mutations the model never saw during
    # fitting; pass a copy so the shared table can never be altered in place.
    ic90s = model.filter_variants_by_seen_muts(true_ic90s.copy())
    ic90s = model.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)
    ic90s = ic90s.assign(
        log_IC90=lambda x: np.log10(x['IC90']),
        predicted_log_IC90=lambda x: np.log10(x['predicted_IC90']),
        )

    corr = ic90s['log_IC90'].corr(ic90s['predicted_log_IC90'])
    corr_records.append({'correlation': corr, 'mutation_rate': mutation_rate})
    ic90_frames.append(
        ic90s[['log_IC90', 'predicted_log_IC90']]
        .assign(mutation_rate=mutation_rate)
        )

ic90_corrs = pd.DataFrame(corr_records, columns=['correlation', 'mutation_rate'])
ic90_data = (pd.concat(ic90_frames, ignore_index=True) if ic90_frames
             else pd.DataFrame(columns=['log_IC90', 'predicted_log_IC90', 'mutation_rate']))
    

In [9]:
# NBVAL_IGNORE_OUTPUT
# Correlation between true and predicted log10 IC90, one point per library mutation rate.
(
    alt.Chart(ic90_corrs)
    .mark_circle(size=125)
    .encode(
        x=alt.X('mutation_rate', type='ordinal'),
        y=alt.Y('correlation', type='quantitative'),
        tooltip=['mutation_rate', 'correlation'],
    )
    .properties(width=200, height=200, title='predicted vs. true IC90')
)

In [19]:
# NBVAL_IGNORE_OUTPUT
# Per-variant true vs. predicted log10 IC90 for the avg1muts and avg3muts
# libraries, one panel per library.
(
    alt.Chart(ic90_data.query("mutation_rate in ['avg1muts', 'avg3muts']"))
    .mark_circle(size=30, opacity=0.3)
    .encode(
        x=alt.X('predicted_log_IC90:Q', title='predicted log10 IC90'),
        y=alt.Y('log_IC90:Q', title='true log10 IC90'),
        color=alt.Color('mutation_rate', legend=None),
    )
    .properties(width=200, height=200, title='predicted vs. true IC90')
    .facet(facet='mutation_rate', columns=1)
)