# Test variant filtering

In [1]:
import pandas as pd
import polyclonal

Let's say we train a model on this dataset.

In [2]:
train_df = pd.DataFrame.from_records(
    [
        ("var1", "", 0.5, 0.1),
        ("var2", "M1A", 0.5, 0.2),
        ("var3", "M1A G2A", 0.5, 0.3),
        ("var4", "M1A G2C", 0.5, 0.4),
        ("var5", "G2A", 1, 0.5),
        ("var6", "M1A", 1, 0.6),
    ],
    columns=["barcode", "aa_substitutions", "concentration", "prob_escape"],
)
train_df

Unnamed: 0,barcode,aa_substitutions,concentration,prob_escape
0,var1,,0.5,0.1
1,var2,M1A,0.5,0.2
2,var3,M1A G2A,0.5,0.3
3,var4,M1A G2C,0.5,0.4
4,var5,G2A,1.0,0.5
5,var6,M1A,1.0,0.6


So, our model has "seen" the following mutations.

In [3]:
seen_mutations = ("M1A", "G2A", "G2C")
seen_mutations

('M1A', 'G2A', 'G2C')
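The tuple above is typed out by hand. As a minimal sketch, assuming substitutions are space-delimited strings as in `train_df`, the same set can be derived from the data. The helper name `mutations_in` is hypothetical and is not part of `polyclonal`.

```python
def mutations_in(substitution_strings):
    """Collect the distinct mutations across space-delimited substitution strings.

    Hypothetical helper for illustration; not part of the polyclonal package.
    """
    mutations = set()
    for subs in substitution_strings:
        # An empty string (wild type) splits to an empty list and adds nothing.
        mutations.update(subs.split())
    return mutations


# Same values as the aa_substitutions column of train_df.
train_substitutions = ["", "M1A", "M1A G2A", "M1A G2C", "G2A", "M1A"]
seen = mutations_in(train_substitutions)
print(sorted(seen))
```

In the notebook, `train_df["aa_substitutions"]` could be passed directly, since iterating over a pandas Series yields its values.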

Now, we want to predict the IC50s in a new dataset.

In [4]:
predict_df = pd.DataFrame.from_records(
    [
        ("var1", "", 0.5),
        ("var2", "M1C", 0.5),
        ("var3", "G2A", 0.5),
        ("var4", "M1C G2C", 0.5),
        ("var5", "G2A", 1),
        ("var6", "M1A G2A", 1),
        ("var7", "G2C", 1),
    ],
    columns=["barcode", "aa_substitutions", "concentration"],
)
predict_df

Unnamed: 0,barcode,aa_substitutions,concentration
0,var1,,0.5
1,var2,M1C,0.5
2,var3,G2A,0.5
3,var4,M1C G2C,0.5
4,var5,G2A,1.0
5,var6,M1A G2A,1.0
6,var7,G2C,1.0


Note that this dataset contains a mutation (M1C) that our model did not observe during training.

In [5]:
predict_mutations = ("M1A", "G2A", "G2C", "M1C")
predict_mutations

('M1A', 'G2A', 'G2C', 'M1C')
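A set difference is one simple way to list exactly which mutations are new to the model. This sketch reuses the two tuples defined above; the name `unseen_mutations` is introduced here only for illustration.

```python
seen_mutations = ("M1A", "G2A", "G2C")
predict_mutations = ("M1A", "G2A", "G2C", "M1C")

# Mutations present in the prediction data but absent from the training data.
unseen_mutations = set(predict_mutations) - set(seen_mutations)
print(unseen_mutations)
```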

The `filter_variants_by_seen_muts` method of a `Polyclonal` model takes a dataframe such as `predict_df` and removes variants that contain mutations the model did not observe during training.

Let's use it to filter `predict_df`.

In [6]:
poly_abs = polyclonal.Polyclonal(data_to_fit=train_df, n_epitopes=2)
filtered_df = poly_abs.filter_variants_by_seen_muts(predict_df)
filtered_df

Unnamed: 0,barcode,aa_substitutions,concentration
0,var1,,0.5
1,var3,G2A,0.5
2,var5,G2A,1.0
3,var6,M1A G2A,1.0
4,var7,G2C,1.0


In [7]:
# No remaining variant should carry the unseen mutation M1C.
assert not filtered_df["aa_substitutions"].str.contains("M1C").any()

Note that only the variants containing M1C (var2 and var4) were removed from the dataframe.
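The filtering rule demonstrated here can be sketched in plain Python: a variant is kept only if every one of its mutations is in the seen set. This is an illustrative sketch and not the `polyclonal` implementation. The function name `keep_seen_variants` and the tuple-based records are assumptions made for the example.

```python
def keep_seen_variants(records, seen_mutations):
    """Keep records whose mutations are all in seen_mutations.

    Each record is (barcode, aa_substitutions, concentration), with
    aa_substitutions a space-delimited string. Hypothetical helper for
    illustration; not part of the polyclonal package.
    """
    seen = set(seen_mutations)
    # A wild-type record has no mutations, so its (empty) set is always a subset.
    return [rec for rec in records if set(rec[1].split()) <= seen]


# Same rows as predict_df.
predict_records = [
    ("var1", "", 0.5),
    ("var2", "M1C", 0.5),
    ("var3", "G2A", 0.5),
    ("var4", "M1C G2C", 0.5),
    ("var5", "G2A", 1),
    ("var6", "M1A G2A", 1),
    ("var7", "G2C", 1),
]

kept = keep_seen_variants(predict_records, ("M1A", "G2A", "G2C"))
print([barcode for barcode, _, _ in kept])
```

Comparing whole mutation tokens with a subset test avoids the partial matches that a substring search could produce, for example matching `M1C` inside a longer token such as `M1CA`.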