# Test variant filtering

In [1]:
import pandas as pd
import polyclonal
import random
import numpy
import itertools
import collections

Let's say we train a model on this dataset.

In [2]:
# Toy training set: one row per (variant, concentration) measurement.
# The wildtype variant is encoded as an empty substitution string.
train_df = pd.DataFrame({
    'barcode': ['var1', 'var2', 'var3', 'var4', 'var5', 'var6'],
    'aa_substitutions': ['', 'M1A', 'M1A G2A', 'M1A G2C', 'G2A', 'M1A'],
    'concentration': [0.5, 0.5, 0.5, 0.5, 1, 1],
    'prob_escape': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
})
train_df

Unnamed: 0,barcode,aa_substitutions,concentration,prob_escape
0,var1,,0.5,0.1
1,var2,M1A,0.5,0.2
2,var3,M1A G2A,0.5,0.3
3,var4,M1A G2C,0.5,0.4
4,var5,G2A,1.0,0.5
5,var6,M1A,1.0,0.6


So, our model has "seen" the following mutations.

In [3]:
seen_mutations = ('M1A', 'G2A', 'G2C')
# Guard against the hand-written tuple drifting from the training data:
# it must list exactly the mutations that occur in `train_df`.
assert set(seen_mutations) == {
    mut for subs in train_df['aa_substitutions'] for mut in subs.split()
}
seen_mutations

('M1A', 'G2A', 'G2C')

Now, we want to predict the IC50s in a new dataset.

In [4]:
# Variants we want predictions for. There is no `prob_escape` column because
# these values are what the model would predict.
predict_df = pd.DataFrame({
    'barcode': ['var1', 'var2', 'var3', 'var4', 'var5', 'var6', 'var7'],
    'aa_substitutions': ['', 'M1C', 'G2A', 'M1C G2C', 'G2A', 'M1A G2A', 'G2C'],
    'concentration': [0.5, 0.5, 0.5, 0.5, 1, 1, 1],
})
predict_df

Unnamed: 0,barcode,aa_substitutions,concentration
0,var1,,0.5
1,var2,M1C,0.5
2,var3,G2A,0.5
3,var4,M1C G2C,0.5
4,var5,G2A,1.0
5,var6,M1A G2A,1.0
6,var7,G2C,1.0


Note that there is a mutation (M1C) that our model did not observe.

In [5]:
predict_mutations = ('M1A', 'G2A', 'G2C', 'M1C')
# The hand-written tuple must match the mutations actually present in `predict_df` ...
assert set(predict_mutations) == {
    mut for subs in predict_df['aa_substitutions'] for mut in subs.split()
}
# ... and M1C must be the only one the model has not seen.
assert set(predict_mutations) - set(seen_mutations) == {'M1C'}
predict_mutations

('M1A', 'G2A', 'G2C', 'M1C')

The `Polyclonal` method `filter_variants_by_seen_muts` takes a dataframe such as `predict_df` as input and removes variants that contain mutations that were not observed by the model.

Let's filter out the variants that contain mutations not seen by the model.

In [6]:
# The model is only constructed (not fit) here before filtering.
poly_abs = polyclonal.Polyclonal(data_to_fit=train_df, n_epitopes=2)
filtered_predict_df = poly_abs.filter_variants_by_seen_muts(predict_df)

# Only the variants carrying the unseen mutation M1C (var2 and var4) should be
# dropped; every other variant is kept in its original order.
assert filtered_predict_df['barcode'].tolist() == ['var1', 'var3', 'var5', 'var6', 'var7']
filtered_predict_df

Unnamed: 0,barcode,aa_substitutions,concentration
0,var1,,0.5
1,var3,G2A,0.5
2,var5,G2A,1.0
3,var6,M1A G2A,1.0
4,var7,G2C,1.0


Note that only the variants containing M1C (var2 and var4) were removed from the dataframe.

## Realistic example

In [7]:
# Load the noisy simulated data, keeping the "avg3muts" library at the three
# concentrations used for fitting.
# `na_filter` is documented as a bool: `False` keeps the wildtype's empty
# `aa_substitutions` as '' rather than NaN, so `.str.split()` works on every row.
noisy_data = (
    pd.read_csv('../notebooks/RBD_variants_escape_noisy.csv', na_filter=False)
    .query("library == 'avg3muts'")
    .query('concentration in [0.25, 1, 4]')
    .reset_index(drop=True)
)

noisy_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg3muts,,0.25,0.00000,0.1128
1,avg3muts,,0.25,0.01090,0.1128
2,avg3muts,,0.25,0.01458,0.1128
3,avg3muts,,0.25,0.09465,0.1128
4,avg3muts,,0.25,0.03299,0.1128
...,...,...,...,...,...
89995,avg3muts,Y449I L518Y C525R L461I,4.00,0.02197,2.3100
89996,avg3muts,Y449V K529R N394R,4.00,0.04925,0.9473
89997,avg3muts,Y451L N481T F490V,4.00,0.02315,0.9301
89998,avg3muts,Y453R V483G L492V N501P I332P,4.00,0.00000,5.0120


I randomly sample 500 variants at each concentration (0.25, 1, and 4), giving 1,500 training rows in total. This training dataset will not have seen all mutations.

In [8]:
# Sample a fixed number of rows at each concentration. Each group is sampled
# with the same random_state, exactly as the per-group `.sample` calls did
# before. Concatenating the groups avoids `DataFrameGroupBy.apply`, whose
# operation on the grouping column is deprecated in pandas >= 2.2; a future
# pandas would drop `concentration` from the result.
n_per_concentration = 500
fit_data = pd.concat(
    [group.sample(n=n_per_concentration, random_state=123)
     for _, group in noisy_data.groupby('concentration')],
    ignore_index=True,
)

poly_abs = polyclonal.Polyclonal(
    data_to_fit=fit_data,
    # Initial wildtype activity for each of the three epitopes.
    activity_wt_df=pd.DataFrame.from_records(
        [('1', 1.0),
         ('2', 3.0),
         ('3', 2.0),
         ],
        columns=['epitope', 'activity'],
    ),
    # Initial site-level escape values: one site per epitope.
    site_escape_df=pd.DataFrame.from_records(
        [('1', 417, 10.0),
         ('2', 484, 10.0),
         ('3', 444, 10.0),
         ],
        columns=['epitope', 'site', 'escape'],
    ),
    data_mut_escape_overlap='fill_to_data',
)

In [9]:
# Fit the model. Progress is logged every 500 optimization steps; the log
# below shows a site-level model being fit first, then the full model.
# `opt_res` holds whatever `fit` returns.
opt_res = poly_abs.fit(logfreq=500)

# First fitting site-level model.
# Starting optimization of 504 parameters at Fri Dec  3 15:43:15 2021.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0  0.0099931     308.37     308.07    0.29701          0
        500     4.6128     21.298     16.412     4.8862          0
       1000     9.4118      20.62      15.47     5.1501          0
       1038      9.749      20.62     15.472     5.1484          0
# Successfully finished at Fri Dec  3 15:43:25 2021.
# Starting optimization of 4659 parameters at Fri Dec  3 15:43:25 2021.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.009223     67.751     15.733     52.018 8.2874e-30
        427     4.2811      19.25     5.4251     9.9417     3.8831
# Successfully finished at Fri Dec  3 15:43:29 2021.


In [10]:
# Number of distinct mutations in the full (unsampled) noisy dataset,
# computed from the data instead of a hard-coded total.
num_total_muts = len(set(itertools.chain.from_iterable(
    noisy_data['aa_substitutions'].str.split()
)))
num_unique_muts = len(poly_abs.mutations)

# Sub-sampling the training data should leave some mutations unseen.
assert num_unique_muts < num_total_muts
print(f"{num_unique_muts}/{num_total_muts} mutations were seen during model fitting.")

1552/1932 mutations were seen during model fitting.


Now let's load a dataset for which we want to predict IC90s.

In [11]:
# Load the exact (noise-free) simulated data for the "avg3muts" library at
# two concentrations. As with the noisy data, `na_filter=False` keeps the
# wildtype's empty `aa_substitutions` as '' rather than NaN.
exact_data = (
    pd.read_csv('../notebooks/RBD_variants_escape_exact.csv', na_filter=False)
    .query("library == 'avg3muts'")
    .query('concentration in [0.25, 0.5]')
    .reset_index(drop=True)
)
exact_data

Unnamed: 0,library,aa_substitutions,concentration,prob_escape,IC90
0,avg3muts,,0.25,0.02512,0.1128
1,avg3muts,,0.25,0.02512,0.1128
2,avg3muts,,0.25,0.02512,0.1128
3,avg3muts,,0.25,0.02512,0.1128
4,avg3muts,,0.25,0.02512,0.1128
...,...,...,...,...,...
59995,avg3muts,Y508V,0.50,0.03305,0.2531
59996,avg3muts,Y508V H519D,0.25,0.39300,1.2960
59997,avg3muts,Y508V H519D,0.50,0.23920,1.2960
59998,avg3muts,Y508W,0.25,0.08727,0.2285


Let's filter out variants that have unseen mutations, and then predict IC90s for the remaining variants.

In [12]:
max_ic90 = 50

# IC90 comparisons only need each variant's substitutions and its true IC90,
# so concentration is dropped and the repeated rows collapsed. True IC90s
# are capped at the same ceiling used for the predictions.
ic90s = (
    exact_data
    .loc[:, ['aa_substitutions', 'IC90']]
    .assign(IC90=lambda frame: frame['IC90'].clip(upper=max_ic90))
    .drop_duplicates()
)

# Keep only variants whose mutations the model has seen, then add the
# model's predicted IC90 for each of them.
filtered_ic90s = poly_abs.icXX(
    poly_abs.filter_variants_by_seen_muts(ic90s),
    x=0.9,
    col='predicted_IC90',
    max_c=max_ic90,
)
filtered_ic90s

Unnamed: 0,aa_substitutions,IC90,predicted_IC90
0,,0.1128,0.194264
1,A344K G404K,0.4673,0.826282
2,A344K G446R,0.7924,1.336770
3,A344K K529T,0.5077,0.332364
4,A344K N354R D428A S494M A520P,4.7700,3.995793
...,...,...,...
16291,Y508T,0.2475,0.359703
16292,Y508T L517V,0.6513,0.505955
16293,Y508V,0.2531,0.359603
16294,Y508V H519D,1.2960,2.001409


In [13]:
# Distinct mutations that remain among the filtered variants.
filtered_muts = numpy.unique(list(
    itertools.chain.from_iterable(filtered_ic90s['aa_substitutions'].str.split())
))

# How often each mutation occurs in the data the model was fit to. Looking up
# a missing key in a Counter returns 0 without inserting it.
seen_mut_counts = collections.Counter(
    itertools.chain.from_iterable(poly_abs.data_to_fit['aa_substitutions'].str.split())
)

# 1. Filtering must not keep any variant carrying a mutation absent from the fit data.
unseen_kept = [m for m in filtered_muts if seen_mut_counts[m] == 0]
assert not unseen_kept, f"filtered variants contain unseen mutations: {unseen_kept}"

# 2. Conversely, filtering must not drop variants whose mutations were all seen.
n_all_seen = sum(
    all(seen_mut_counts[m] > 0 for m in subs.split())
    for subs in ic90s['aa_substitutions']
)
assert len(filtered_ic90s) == n_all_seen, (
    f"expected {n_all_seen} variants with only seen mutations, "
    f"but {len(filtered_ic90s)} were kept"
)