# Fit model to data
We will fit a `Polyclonal` model to the RBD antibody mix we simulated.

First, we read in that simulated data.
Recall that we simulated both "exact" and "noisy" data, with several average per-library mutation rates, and at six different concentrations.
Here we analyze the noisy data for the library with an average of 2 mutations per gene, measured at three different concentrations, as this is a fairly realistic approximation of a real experiment:

In [6]:
import pandas as pd

import polyclonal

noisy_data = (
    pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=None)
    .query('library == "avg2muts"')
    .query('concentration in [0.25, 1, 0.5]')
    .reset_index(drop=True)
    )

noisy_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape
0,avg2muts,,0.25,0.05044
1,avg2muts,,0.25,0.14310
2,avg2muts,,0.25,0.05452
3,avg2muts,,0.25,0.08473
4,avg2muts,,0.25,0.04174
...,...,...,...,...
89995,avg2muts,Y508V,1.00,0.02958
89996,avg2muts,Y508V A520L,1.00,0.05519
89997,avg2muts,Y508V H519N,1.00,0.07836
89998,avg2muts,Y508W,1.00,0.01435


Initialize a `Polyclonal` model with these data, including three epitopes.
We know from [prior work](https://www.nature.com/articles/s41467-021-24435-8) the three most important epitopes and a key mutation in each, so we use this prior knowledge to "seed" initial guesses that assign large escape values to a key site in each epitope:

 - site 417 for class 1 epitope, which is often the least important
 - site 484 for class 2 epitope, which is often the dominant one
 - site 444 for class 3 epitope, which is often the second most dominant one

In [2]:
poly_abs = polyclonal.Polyclonal(data_to_fit=noisy_data,
                                 activity_wt_df=pd.DataFrame.from_records(
                                         [('1', 1.0),
                                          ('2', 3.0),
                                          ('3', 2.0),
                                          ],
                                         columns=['epitope', 'activity'],
                                         ),
                                 site_escape_df=pd.DataFrame.from_records(
                                         [('1', 417, 10.0),
                                          ('2', 484, 10.0),
                                          ('3', 444, 10.0),
                                          ],
                                         columns=['epitope', 'site', 'escape'],
                                         ),
                                 data_mut_escape_overlap='fill_to_data',
                                 )
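
To build intuition for what the seeded activity and escape values mean, the sketch below evaluates one possible escape-probability function. It assumes, following the biophysical model described in the `polyclonal` documentation rather than anything stated in this section, that each epitope is left unbound with probability 1 / (1 + c * exp(-phi)), where phi is the sum of a variant's escape values at that epitope minus the epitope's wildtype activity, and that a variant escapes only if every epitope is unbound. The function name `prob_escape_sketch` and the dictionary layout are hypothetical and are not part of the library's API.

```python
import math


def prob_escape_sketch(activities, escapes, mutated_sites, concentration):
    """Escape probability under an ASSUMED independent-epitope model.

    activities: {epitope: wildtype activity}
    escapes: {(epitope, site): escape value}; missing pairs count as 0
    mutated_sites: list of mutated site numbers
    """
    p = 1.0
    for epitope, activity in activities.items():
        # Assumed form: escape values add, wildtype activity subtracts.
        phi = -activity + sum(
            escapes.get((epitope, site), 0.0) for site in mutated_sites
        )
        # Assumed probability that this epitope is not bound.
        unbound = 1.0 / (1.0 + concentration * math.exp(-phi))
        p *= unbound
    return p


# Initial guesses copied from the cell above.
activities = {'1': 1.0, '2': 3.0, '3': 2.0}
escapes = {('1', 417): 10.0, ('2', 484): 10.0, ('3', 444): 10.0}

p_wt = prob_escape_sketch(activities, escapes, [], concentration=1.0)
p_484 = prob_escape_sketch(activities, escapes, [484], concentration=1.0)
print(f"unmutated: {p_wt:.4f}, site 484 mutated: {p_484:.4f}")
```

Under these assumptions, mutating a seeded site leaves that epitope almost always unbound, so the initial guess already ties each epitope to a distinct site before any fitting happens.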

In [3]:
# Fit the model, logging the loss every 100 steps.
opt_res = poly_abs.fit(logfreq=100)

# First fitting site-level model.
# Starting optimization of 522 parameters at Thu Nov 18 05:45:03 2021.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.050129      12297      12295     1.4269          0
        100     5.6538     1543.5     1531.8     11.674          0
        200     11.324     1522.9     1508.9     14.043          0
        300      16.91     1518.8     1503.6     15.205          0
        400     22.425     1517.2     1501.3      15.85          0
        500     27.853     1515.5       1498     17.482          0
        600     33.118     1514.1     1496.8     17.283          0
        700     38.417     1512.5     1495.6     16.997          0
        800     43.813     1511.9     1494.6     17.231          0
        900     49.357     1511.5     1493.7     17.711          0
       1000     54.808     1511.3     1493.2     18.076          0
       1094     59.816     1511.1       1493     18.087          0
# Successfully finished 

In [4]:
# Plot the fitted wildtype activity for each epitope.
poly_abs.activity_wt_barplot()

In [5]:
# Plot the fitted escape values along the sequence for each epitope.
poly_abs.mut_escape_lineplot()

In [22]:
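# Read the exact (noise-free) simulated data for a different library
# (average of 4 mutations per gene) to test the fitted model's predictions.
exact_data = (
    pd.read_csv('RBD_variants_escape_exact.csv', na_filter=None)
    .query('library == "avg4muts"')
    .query('concentration in [0.25, 1, 0.5]')
    .reset_index(drop=True)
    )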

exact_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape
0,avg4muts,,0.25,0.025120
1,avg4muts,,0.25,0.025120
2,avg4muts,,0.25,0.025120
3,avg4muts,,0.25,0.025120
4,avg4muts,,0.25,0.025120
...,...,...,...,...
89995,avg4muts,Y508V C525R,1.00,0.027300
89996,avg4muts,Y508V K529E,1.00,0.027430
89997,avg4muts,Y508W,1.00,0.005799
89998,avg4muts,Y508W C525F,1.00,0.019370


In [25]:
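# Predict escape probabilities for the exact data, then keep only variants
# with more than two amino-acid substitutions.
exact_vs_pred = (
    poly_abs.prob_escape(variants_df=exact_data)
    .assign(
        # Substitutions are space-delimited; an empty string counts as 0.
        n_aa_substitutions=lambda x: x['aa_substitutions'].map(
            lambda s: len(s.split())
        )
    )
    .query('n_aa_substitutions > 2')
    )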

exact_vs_pred

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,predicted_prob_escape,n_aa_substitutions
569,avg4muts,A344E A348G N388L,0.25,0.4535,0.455409,3
570,avg4muts,A344E F392V D428G Y451Q G482L,0.25,0.7929,0.797678,5
571,avg4muts,A344E I468L Q493R,0.25,0.7957,0.772309,3
572,avg4muts,A344E K378I K386S T415K G476A Q493H T531H,0.25,0.9452,0.937859,7
573,avg4muts,A344E K386A L492I S514R,0.25,0.6497,0.635111,4
...,...,...,...,...,...,...
89969,avg4muts,Y489L L518H P527L,1.00,0.1244,0.136090,3
89970,avg4muts,Y489M G496S G526L,1.00,0.1257,0.097216,3
89974,avg4muts,Y489Q G496R S514R,1.00,0.1224,0.115763,3
89984,avg4muts,Y505N K528T K529Q T531Q,1.00,0.1713,0.132516,4
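
The substitution count used for the filter above relies on `aa_substitutions` being a space-delimited string, with unmutated variants stored as an empty string (which appears to be why the CSV is read with `na_filter=None`, so that empty fields are not converted to NaN). A minimal sketch of that counting logic, using a hypothetical helper name `count_substitutions`:

```python
def count_substitutions(aa_substitutions: str) -> int:
    """Count space-delimited substitutions; an empty string means unmutated."""
    # str.split() with no arguments drops empty tokens, so "" gives [].
    return len(aa_substitutions.split())


for variant in ["", "Y508V", "Y508V A520L", "A344E A348G N388L"]:
    print(repr(variant), count_substitutions(variant))
```

Because `str.split()` without arguments never returns empty tokens, no extra check for empty strings is needed.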


In [26]:
# Correlate actual and predicted escape probabilities at each concentration.
(exact_vs_pred
 .groupby(['concentration'])
 .apply(lambda x: x['prob_escape'].corr(x['predicted_prob_escape']))
 )

concentration
0.25    0.985599
0.50    0.986294
1.00    0.985332
dtype: float64

In [46]:
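# Read the true mutation-escape values used in the simulation.
exact_mut_escape = pd.read_csv('RBD_mut_escape_df.csv')

# For each true epitope, correlate its escape values with those of every
# fitted epitope and keep the best (maximum) correlation.
corrs = []
for _, exact_ep_escape in exact_mut_escape.groupby('epitope'):
    escape_actual_vs_pred = (
        exact_ep_escape
        .rename(columns={'escape': 'actual_escape'})
        [['mutation', 'actual_escape']]
        .merge(poly_abs.mut_escape_df.rename(columns={'escape': 'pred_escape'}),
               on='mutation',
               validate='one_to_many',
              )
        )
    # After the merge, 'epitope' is the fitted epitope label.
    ep_corrs = (escape_actual_vs_pred
                .groupby('epitope')
                .apply(lambda x: x['actual_escape'].corr(x['pred_escape']))
                )
    corrs.append(ep_corrs.max())

print(corrs)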

[0.7144821835556919, 0.9303172902862691, 0.865877338208553]
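
The loop above keeps, for each true epitope, the maximum correlation over all fitted epitopes, so the comparison does not depend on the epitope labels matching between the two tables. The sketch below illustrates that matching step in plain Python on made-up escape values for four mutations; the labels `A`, `B`, `x`, `y` and the helper `pearson` are hypothetical and only for illustration.

```python
import math


def pearson(x, y):
    """Pearson correlation of two equal-length numeric sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Made-up escape values for the same four mutations (illustration only).
actual = {'A': [2.0, 0.1, 0.0, 0.3], 'B': [0.0, 1.8, 0.2, 0.1]}
fitted = {'x': [0.1, 1.5, 0.3, 0.0], 'y': [1.7, 0.0, 0.2, 0.4]}

best = {}
for actual_name, actual_vals in actual.items():
    # Correlate this true epitope against every fitted epitope.
    ep_corrs = {
        fitted_name: pearson(actual_vals, fitted_vals)
        for fitted_name, fitted_vals in fitted.items()
    }
    # Keep the fitted epitope with the highest correlation.
    best_name = max(ep_corrs, key=ep_corrs.get)
    best[actual_name] = (best_name, ep_corrs[best_name])

for actual_name, (fitted_name, r) in best.items():
    print(f"{actual_name} best matches {fitted_name} (r = {r:.3f})")
```

Taking the maximum independently for each true epitope does not guarantee a one-to-one assignment; two true epitopes could in principle pick the same fitted epitope, which is worth checking when correlations are low.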
