# Test epitope harmonizing

## Setup

In [1]:
import numpy

import random

import pandas as pd

from polyclonal import Polyclonal

import polyclonal.utils as utils

import unittest

# Ground-truth model used to simulate data: two epitopes with fixed wildtype
# activities and known per-mutation escape values.
activity_wt_df = pd.DataFrame.from_records(
    [("1", 2.0), ("2", 1.0)],
    columns=["epitope", "activity"],
)

# One row per (mutation, epitope) pair. M1C and G2A only escape epitope 1;
# the three substitutions at site 4 only escape epitope 2.
mut_escape_df = pd.DataFrame.from_records(
    [
        ("M1C", "1", 2.0),
        ("M1C", "2", 0.0),
        ("G2A", "1", 3.0),
        ("G2A", "2", 0.0),
        ("A4K", "1", 0.0),
        ("A4K", "2", 2.5),
        ("A4L", "1", 0.0),
        ("A4L", "2", 1.5),
        ("A4Q", "1", 0.0),
        ("A4Q", "2", 3.5),
    ],
    columns=["mutation", "epitope", "escape"],
)

polyclonal_sim = Polyclonal(
    activity_wt_df=activity_wt_df,
    mut_escape_df=mut_escape_df,
)

# Wildtype plus single, double and triple mutants of the mutations above, each
# tagged with its own two-letter barcode (all 16 combinations are used once).
variants_df = pd.DataFrame.from_records(
    [
        ("AA", ""),
        ("AC", "M1C"),
        ("AG", "G2A"),
        ("AT", "A4K"),
        ("TA", "A4L"),
        ("GA", "A4Q"),
        ("CA", "M1C G2A"),
        ("CG", "M1C A4K"),
        ("TT", "M1C A4L"),
        ("GT", "M1C A4Q"),
        ("CC", "G2A A4K"),
        ("TC", "G2A A4L"),
        ("GG", "G2A A4Q"),
        ("CT", "M1C G2A A4K"),
        ("TG", "M1C G2A A4L"),
        ("GC", "M1C G2A A4Q"),
    ],
    columns=["barcode", "aa_substitutions"],
)
# Each barcode must identify exactly one variant.
assert variants_df["barcode"].is_unique

# Predicted escape probabilities from the ground-truth model at several
# concentrations, relabeled so they can be used directly as data to fit.
escape_probs = polyclonal_sim.prob_escape(
    variants_df=variants_df, concentrations=[1.0, 2.0, 4.0]
)

data_to_fit = escape_probs.rename(columns={"predicted_prob_escape": "prob_escape"})

In [2]:
# Two untrained models with identical settings, both on the same simulated data.
n_eps = 2
poly_one, poly_two = [
    Polyclonal(
        data_to_fit=data_to_fit,
        n_epitopes=n_eps,
        activity_wt_df=None,
        site_escape_df=None,
    )
    for _ in range(2)
]

In [3]:
# `random.seed` only seeds Python's stdlib RNG. Also seed numpy's global RNG so
# fitting stays reproducible if it draws from numpy's random state.
random.seed(1)
numpy.random.seed(1)
_ = poly_one.fit(fit_site_level_first=False)
_ = poly_two.fit(fit_site_level_first=False)

## Tests

### Epitope correlation
The following tests check that the epitope-harmonization helper methods work when the two models are identical (the 1s in the correlation matrix should be on the diagonal):

In [4]:
# Both models were fit the same way to the same data, so each epitope should
# correlate perfectly with its own counterpart (1s on the diagonal).
corr_df = poly_two.mut_escape_corr(poly_one)
assert len(corr_df) == n_eps**2
assert corr_df.correlation.between(-1, 1).all()
diagonal = corr_df.query("ref_epitope == self_epitope")
assert len(diagonal) == n_eps
assert numpy.allclose(diagonal["correlation"], 1)
corr_df

Unnamed: 0,ref_epitope,self_epitope,correlation
0,1,1,1.0
1,1,2,-0.872404
2,2,1,-0.872404
3,2,2,1.0


#### Scenario: Flipped `mut_escape_df`

Now we create an example in which two models learned the same parameters but with the epitopes flipped. Rather than training new models, we flip the values in one model's `mut_escape_df` and re-create the `Polyclonal` objects:

In [5]:
# Create a test example where two models "flipped" the epitopes: the same learned
# values, but with the epitope labels shifted by one (a swap for two epitopes).
one_df = poly_one.mut_escape_df
# Copy before overwriting `escape` below so poly_two's frame is never modified.
two_df = poly_two.mut_escape_df.copy()
one_wt_df = poly_one.activity_wt_df

# "Flipped" `activity_wt_df`: keep the epitope labels in order but shift the
# activities by one epitope. Plain arrays keep the shift positional rather
# than relying on pandas index alignment.
two_wt_df = pd.DataFrame(
    {
        "epitope": one_wt_df["epitope"].to_numpy(),
        "activity": numpy.roll(one_wt_df["activity"].to_numpy(), -1),
    }
)

# "Flipped" `mut_escape_df`: rows are grouped by epitope (all mutations for
# epitope 1, then all for epitope 2, ...). Shifting the escape values by one
# block of mutations therefore moves each epitope's values to another label.
assert len(one_df) % n_eps == 0, "expected equal-sized per-epitope blocks"
n_muts_per_epitope = len(one_df) // n_eps
two_df["escape"] = numpy.roll(one_df["escape"].to_numpy(), -n_muts_per_epitope)

# Re-create Polyclonal objects from the data frames (`mut_escape_df` is exposed
# as a property, so it can't be edited in place).
original_poly = Polyclonal(
    mut_escape_df=one_df, activity_wt_df=one_wt_df, data_to_fit=None
)
flipped_poly = Polyclonal(
    mut_escape_df=two_df, activity_wt_df=two_wt_df, data_to_fit=None
)

##### Sanity checks
Here, we just want to make sure the `mut_escape_df` properties of the two `Polyclonal` objects are flipped.
That is, the escape values for a given mutation (e.g. M1C) in epitopes 1 and 2 of the `original_poly` object should be swapped in the `flipped_poly` object, and so on.

This should manifest in the `_params` array as pairwise swaps (e.g. [1, 2, 3, 4] → [2, 1, 4, 3]) after the first two epitope params, because params are stored in mutation-epitope order.

In [6]:
# Reference model's escape values, listed epitope by epitope.
original_poly.mut_escape_df

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape
0,1,1,M,C,M1C,0.152164
1,1,2,G,A,G2A,0.042005
2,1,4,A,K,A4K,2.358184
3,1,4,A,L,A4L,1.513825
4,1,4,A,Q,A4Q,3.036481
5,2,1,M,C,M1C,2.067248
6,2,2,G,A,G2A,3.227758
7,2,4,A,K,A4K,0.012367
8,2,4,A,L,A4L,-0.060366
9,2,4,A,Q,A4Q,0.077058


In [7]:
# Same mutations as above, but the epitope 1 and epitope 2 escape values are swapped.
flipped_poly.mut_escape_df

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape
0,1,1,M,C,M1C,2.067248
1,1,2,G,A,G2A,3.227758
2,1,4,A,K,A4K,0.012367
3,1,4,A,L,A4L,-0.060366
4,1,4,A,Q,A4Q,0.077058
5,2,1,M,C,M1C,0.152164
6,2,2,G,A,G2A,0.042005
7,2,4,A,K,A4K,2.358184
8,2,4,A,L,A4L,1.513825
9,2,4,A,Q,A4Q,3.036481


#### Scenario: Helper method input violations
Another set of tests on the helper methods; this time the 1s should be on the off-diagonal:

In [8]:
# The flipped model's epitopes are swapped relative to the original, so the
# perfect correlations should now sit off the diagonal.
corr_df2 = flipped_poly.mut_escape_corr(original_poly)
assert len(corr_df2) == n_eps**2
assert corr_df2.correlation.between(-1, 1).all()
off_diagonal = corr_df2.query("ref_epitope != self_epitope")
assert len(off_diagonal) == n_eps**2 - n_eps
assert numpy.allclose(off_diagonal["correlation"], 1)
corr_df2

Unnamed: 0,ref_epitope,self_epitope,correlation
0,1,1,-0.872404
1,1,2,1.0
2,2,1,1.0
3,2,2,-0.872404


### Epitope harmonization

Now we harmonize the flipped object with the original one -- since these are the exact same dataframes but with flipped epitopes, after harmonization, `flipped_poly.mut_escape_df` should be equal to `original_poly.mut_escape_df`.

We should also have equal `activity_wt_df` and `_params` properties after harmonization.

In [9]:
# The flipped model before harmonization: epitopes are still swapped relative
# to `original_poly`.
flipped_poly.mut_escape_df

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape
0,1,1,M,C,M1C,2.067248
1,1,2,G,A,G2A,3.227758
2,1,4,A,K,A4K,0.012367
3,1,4,A,L,A4L,-0.060366
4,1,4,A,Q,A4Q,0.077058
5,2,1,M,C,M1C,0.152164
6,2,2,G,A,G2A,0.042005
7,2,4,A,K,A4K,2.358184
8,2,4,A,L,A4L,1.513825
9,2,4,A,Q,A4Q,3.036481


In [10]:
def values_match(self_df, ref_df, value_col, keys=("epitope",)):
    """Compare ``value_col`` between two data frames, row by row on ``keys``.

    The frames are outer-merged on ``keys``, so a key present in only one frame
    gets NaN on the other side and therefore counts as a mismatch.

    Returns a boolean Series with one entry per merged row.
    """
    cols = [*keys, value_col]
    merged = self_df[cols].merge(
        ref_df[cols], how="outer", on=list(keys), suffixes=("_self", "_ref")
    )
    return merged[f"{value_col}_self"] == merged[f"{value_col}_ref"]


MUT_KEYS = ("epitope", "mutation")

# Before harmonization, no epitope of the flipped model matches the reference.
assert not values_match(
    flipped_poly.mut_escape_df, original_poly.mut_escape_df, "escape", MUT_KEYS
).any()
assert not values_match(
    flipped_poly.activity_wt_df, original_poly.activity_wt_df, "activity"
).any()

map_df = flipped_poly.harmonize_epitopes_with(original_poly)

# After harmonization, escape values, activities and the parameter vector
# should all match the reference model.
assert values_match(
    flipped_poly.mut_escape_df, original_poly.mut_escape_df, "escape", MUT_KEYS
).all()
assert values_match(
    flipped_poly.activity_wt_df, original_poly.activity_wt_df, "activity"
).all()
assert numpy.allclose(flipped_poly._params, original_poly._params)

# Every epitope was flipped, so each should be relabeled onto a perfectly
# correlated reference epitope.
assert (map_df["self_initial_epitope"] != map_df["self_harmonized_epitope"]).all()
assert numpy.allclose(map_df["correlation"], 1)

map_df

Unnamed: 0,self_initial_epitope,self_harmonized_epitope,ref_epitope,correlation
0,1,2,2,1.0
1,2,1,1,1.0
