# Test `PolyclonalAverage`

First we create some models to average.
They should all be similar, but we add random noise, randomly drop one mutation from each model, and flip the epitope labels for some of them:

In [1]:
# NBVAL_IGNORE_OUTPUT

import numpy

import pandas as pd

import polyclonal


# "True" wildtype activity of each of the two epitopes, one (epitope, activity) row each.
activity_wt_df = pd.DataFrame(
    [(1, 2.0), (2, 1.0)],
    columns=["epitope", "activity"],
)

# "True" escape values, one (mutation, epitope, escape) row at a time:
# epitope 1 is escaped by the site 1 and 2 mutations, epitope 2 by the site 4 ones.
mut_escape_df = pd.DataFrame(
    [
        ("M1C", 1, 2.0),
        ("M1C", 2, 0.0),
        ("G2A", 1, 3.0),
        ("G2A", 2, 0.0),
        ("A4K", 1, 0.0),
        ("A4K", 2, 2.5),
        ("A4L", 1, 0.0),
        ("A4L", 2, 1.5),
        ("A4Q", 1, 0.0),
        ("A4Q", 2, 3.5),
    ],
    columns=["mutation", "epitope", "escape"],
)


# Build five noisy copies of the "true" model above. Each copy reseeds the
# global RNG with its own index so it is reproducible on its own; odd-indexed
# copies get their epitope labels swapped, and every copy randomly drops one
# of the mutations.
flip_epitopes = {1: 2, 2: 1}
keep_epitopes = {1: 1, 2: 2}
n_muts = mut_escape_df["mutation"].nunique()
models = []
for i in range(5):
    numpy.random.seed(i)
    epitope_map = flip_epitopes if i % 2 else keep_epitopes

    # NOTE: the three RNG draws below must stay in this order (activity noise,
    # dropped mutation, escape noise) or every downstream number changes.
    a_df = activity_wt_df.assign(
        activity=activity_wt_df["activity"] + numpy.random.random(len(activity_wt_df)),
        epitope=activity_wt_df["epitope"].map(epitope_map),
    )
    muts_to_keep = numpy.random.choice(
        mut_escape_df["mutation"].unique(), size=n_muts - 1, replace=False
    ).tolist()
    noisy_escape = mut_escape_df["escape"] + numpy.random.random(len(mut_escape_df))
    e_df = mut_escape_df.assign(
        escape=noisy_escape,
        epitope=mut_escape_df["epitope"].map(epitope_map),
    ).loc[lambda x: x["mutation"].isin(muts_to_keep)]

    models.append(polyclonal.Polyclonal(mut_escape_df=e_df, activity_wt_df=a_df))

# Tabulate the models by library and replicate: the first three models are
# replicates 0-2 of library "A"; the remaining ones are replicates of library "B".
models_df = pd.DataFrame(
    {
        "library": ["A" if idx < 3 else "B" for idx in range(len(models))],
        "replicate": [idx % 3 for idx in range(len(models))],
        "model": models,
    }
)

models_df

Unnamed: 0,library,replicate,model
0,A,0,<polyclonal.polyclonal.Polyclonal object at 0x...
1,A,1,<polyclonal.polyclonal.Polyclonal object at 0x...
2,A,2,<polyclonal.polyclonal.Polyclonal object at 0x...
3,B,0,<polyclonal.polyclonal.Polyclonal object at 0x...
4,B,1,<polyclonal.polyclonal.Polyclonal object at 0x...


Now make the average model:

In [2]:
# Combine the library/replicate models listed in `models_df` into one averaged model.
avg_model = polyclonal.PolyclonalAverage(models_df)

Get the correlations between the library / replicate models:

In [3]:
# Per-epitope escape correlations between pairs of models, plus the squared
# correlation (r2). `assign` returns a new frame, so the object returned by
# `mut_escape_corr()` is never modified in place.
corr = avg_model.mut_escape_corr()
corr = corr.assign(r2=corr["correlation"].pow(2))

corr.round(3)

Unnamed: 0,epitope,correlation,library_1,replicate_1,library_2,replicate_2,r2
0,1,1.0,A,0,A,0,1.0
1,2,1.0,A,0,A,0,1.0
2,1,0.877,A,1,A,0,0.77
3,2,1.0,A,1,A,0,0.999
4,1,0.997,A,2,A,0,0.994
5,2,0.992,A,2,A,0,0.984
6,1,0.999,B,0,A,0,0.998
7,2,0.992,B,0,A,0,0.984
8,1,0.893,B,1,A,0,0.798
9,2,0.988,B,1,A,0,0.976


Plot the correlations among models:

In [4]:
# NBVAL_IGNORE_OUTPUT

# Heatmap of the escape correlations among the library/replicate models.
avg_model.mut_escape_corr_heatmap()

Activities:

In [5]:
# Wildtype activity of each epitope in every individual library/replicate model.
avg_model.activity_wt_df_replicates

Unnamed: 0,epitope,activity,library,replicate
0,1,2.548814,A,0
1,2,1.715189,A,0
2,1,2.417022,A,1
3,2,1.720324,A,1
4,1,2.435995,A,2
5,2,1.025926,A,2
6,1,2.550798,B,0
7,2,1.708148,B,0
8,1,2.96703,B,1
9,2,1.547232,B,1


In [6]:
# Wildtype activities summarized across models: mean, median, and standard deviation per epitope.
avg_model.activity_wt_df

Unnamed: 0,epitope,activity_mean,activity_median,activity_std
0,1,2.583932,2.548814,0.222957
1,2,1.543364,1.708148,0.298224


In [7]:
# NBVAL_IGNORE_OUTPUT

# Bar plot of the averaged wildtype activities using the default `avg_type`.
avg_model.activity_wt_barplot()

You can also plot the mean:

In [8]:
# NBVAL_IGNORE_OUTPUT

# Same bar plot, but averaging the models by their mean.
avg_model.activity_wt_barplot(avg_type="mean")

Mutation escapes:

In [9]:
# Escape of each mutation at each epitope in every individual library/replicate model.
avg_model.mut_escape_df_replicates

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape,library,replicate
0,1,1,M,C,M1C,2.645894,A,0
1,1,2,G,A,G2A,3.891773,A,0
2,1,4,A,K,A4K,0.383442,A,0
3,1,4,A,Q,A4Q,0.925597,A,0
4,2,1,M,C,M1C,0.437587,A,0
5,2,2,G,A,G2A,0.963663,A,0
6,2,4,A,K,A4K,3.291725,A,0
7,2,4,A,Q,A4Q,3.571036,A,0
8,1,1,M,C,M1C,2.236089,A,1
9,1,4,A,K,A4K,0.935539,A,1


In [10]:
# Escape summarized across models; `n_models` / `frac_models` give the number and
# fraction of models that include each mutation.
avg_model.mut_escape_df

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape_mean,escape_median,escape_std,n_models,frac_models
0,1,1,M,C,M1C,2.452469,2.463946,0.200167,4,0.8
1,1,2,G,A,G2A,3.545543,3.619271,0.388378,3,0.6
2,1,4,A,K,A4K,0.508091,0.383442,0.394036,5,1.0
3,1,4,A,L,A4L,0.116103,0.029876,0.171201,3,0.6
4,1,4,A,Q,A4Q,0.578654,0.513578,0.2208,5,1.0
5,2,1,M,C,M1C,0.29507,0.300615,0.144206,4,0.8
6,2,2,G,A,G2A,0.490187,0.299655,0.412637,3,0.6
7,2,4,A,K,A4K,3.166589,3.132963,0.159791,5,1.0
8,2,4,A,L,A4L,2.017913,2.024548,0.058047,3,0.6
9,2,4,A,Q,A4Q,3.773534,3.729577,0.200122,5,1.0


In [11]:
# NBVAL_IGNORE_OUTPUT

# Heatmap of the mutation escape values of the averaged model.
avg_model.mut_escape_heatmap()

Site escape:

In [12]:
# Per-site escape summaries (mean, total positive, max, min, total negative,
# n mutations) for each individual library/replicate model.
avg_model.mut_escape_site_summary_df_replicates()

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations,library,replicate
0,1,1,M,2.645894,2.645894,2.645894,2.645894,0.0,1,A,0
1,1,2,G,3.891773,3.891773,3.891773,3.891773,0.0,1,A,0
2,1,4,A,0.654519,1.309038,0.925597,0.383442,0.0,2,A,0
3,2,1,M,0.437587,0.437587,0.437587,0.437587,0.0,1,A,0
4,2,2,G,0.963663,0.963663,0.963663,0.963663,0.0,1,A,0
5,2,4,A,3.431381,6.862761,3.571036,3.291725,0.0,2,A,0
6,1,1,M,2.236089,2.236089,2.236089,2.236089,0.0,1,A,1
7,1,4,A,0.564088,1.692265,0.935539,0.313274,0.0,3,A,1
8,2,1,M,0.396581,0.396581,0.396581,0.396581,0.0,1,A,1
9,2,4,A,3.033479,9.100436,3.729577,2.024548,0.0,3,A,1


In [13]:
# Per-site escape summaries averaged across models, in long format (one row per site and `metric`).
avg_model.mut_escape_site_summary_df()

Unnamed: 0,epitope,site,wildtype,metric,escape_mean,escape_median,escape_std,n_models,frac_models,n mutations
0,1,1,M,max,2.452469,2.463946,0.200167,4,0.8,1.0
1,1,1,M,mean,2.452469,2.463946,0.200167,4,0.8,1.0
2,1,1,M,min,2.452469,2.463946,0.200167,4,0.8,1.0
3,1,1,M,total negative,0.0,0.0,0.0,4,0.8,1.0
4,1,1,M,total positive,2.452469,2.463946,0.200167,4,0.8,1.0
5,1,2,G,max,3.545543,3.619271,0.388378,3,0.6,1.0
6,1,2,G,mean,3.545543,3.619271,0.388378,3,0.6,1.0
7,1,2,G,min,3.545543,3.619271,0.388378,3,0.6,1.0
8,1,2,G,total negative,0.0,0.0,0.0,3,0.6,1.0
9,1,2,G,total positive,3.545543,3.619271,0.388378,3,0.6,1.0


In [14]:
# NBVAL_IGNORE_OUTPUT

# Line plot of the escape values of the averaged model.
avg_model.mut_escape_lineplot()

Now make sure things work when averaging just **one** model (an edge case):

In [15]:
# Average only the first row of `models_df` (library A, replicate 0); `.copy()` gives
# a standalone frame rather than a slice of `models_df`.
avg_one_model = polyclonal.PolyclonalAverage(models_df.head(n=1).copy())

In [16]:
# With one model, the only pair is the model with itself, so every correlation is 1.
avg_one_model.mut_escape_corr()

Unnamed: 0,epitope,correlation,library_1,replicate_1,library_2,replicate_2
0,1,1.0,A,0,A,0
1,2,1.0,A,0,A,0


In [17]:
# NBVAL_IGNORE_OUTPUT

# The correlation heatmap should still render for a single model.
avg_one_model.mut_escape_corr_heatmap()

In [18]:
# `activity_std` is NaN here: a standard deviation is undefined for a single model.
avg_one_model.activity_wt_df

Unnamed: 0,epitope,activity_mean,activity_median,activity_std
0,1,2.548814,2.548814,
1,2,1.715189,1.715189,


In [19]:
# NBVAL_IGNORE_OUTPUT

# The mean bar plot must still work when the standard deviations are NaN.
avg_one_model.activity_wt_barplot(avg_type="mean")

In [20]:
# Every mutation comes from the single model: `n_models` is 1 and `escape_std` is NaN.
avg_one_model.mut_escape_df

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape_mean,escape_median,escape_std,n_models,frac_models
0,1,1,M,C,M1C,2.645894,2.645894,,1,1.0
1,1,2,G,A,G2A,3.891773,3.891773,,1,1.0
2,1,4,A,K,A4K,0.383442,0.383442,,1,1.0
3,1,4,A,Q,A4Q,0.925597,0.925597,,1,1.0
4,2,1,M,C,M1C,0.437587,0.437587,,1,1.0
5,2,2,G,A,G2A,0.963663,0.963663,,1,1.0
6,2,4,A,K,A4K,3.291725,3.291725,,1,1.0
7,2,4,A,Q,A4Q,3.571036,3.571036,,1,1.0


In [21]:
# NBVAL_IGNORE_OUTPUT

# Escape heatmap for the single-model average.
avg_one_model.mut_escape_heatmap()

In [22]:
# Site summaries for the single-model average (`escape_std` is NaN throughout).
avg_one_model.mut_escape_site_summary_df()

Unnamed: 0,epitope,site,wildtype,metric,escape_mean,escape_median,escape_std,n_models,frac_models,n mutations
0,1,1,M,max,2.645894,2.645894,,1,1.0,1.0
1,1,1,M,mean,2.645894,2.645894,,1,1.0,1.0
2,1,1,M,min,2.645894,2.645894,,1,1.0,1.0
3,1,1,M,total negative,0.0,0.0,,1,1.0,1.0
4,1,1,M,total positive,2.645894,2.645894,,1,1.0,1.0
5,1,2,G,max,3.891773,3.891773,,1,1.0,1.0
6,1,2,G,mean,3.891773,3.891773,,1,1.0,1.0
7,1,2,G,min,3.891773,3.891773,,1,1.0,1.0
8,1,2,G,total negative,0.0,0.0,,1,1.0,1.0
9,1,2,G,total positive,3.891773,3.891773,,1,1.0,1.0


In [23]:
# NBVAL_IGNORE_OUTPUT

# Escape line plot for the single-model average.
avg_one_model.mut_escape_lineplot()