In [64]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including pandas SettingWithCopyWarning, which the column assignments on
# boolean-filtered frames further down can trigger. Consider narrowing the filter
# (e.g. by category) so genuine problems stay visible.
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# All paths below ('results/...') are relative to the directory two levels up,
# presumably the repository root. Only move up when 'results' is not already
# visible, so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('results'):
    os.chdir('../../')

In [75]:
# Antigenic site A is one contiguous run of residues; site B is two runs.
a_sites = list(range(122, 147))

b_sites = list(range(155, 161)) + list(range(186, 199))

In [113]:
site_list = [*a_sites, *b_sites]  # NOTE(review): immediately overwritten by the explicit site_list in the next cell

In [124]:
# Sites to plot, followed by the sera belonging to each age cohort.
site_list = [
    50, 62, 82, 94, 103, 121, 122, 124, 131, 135, 137, 138,
    145, 156, 157, 159, 160, 188, 189, 193, 220, 224, 276,
]

ped_sera_list = [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]
teen_sera_list = [2343, 2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862, 3895]
# Every adult serum ID carries a trailing 'C'.
adult_sera_list = [f'{serum_id}C' for serum_id in (33, 34, 197, 199, 215, 210, 74, 68, 150, 18)]

In [125]:
def get_summed_escapes(sera_list, age_group, site_list, min_times_seen=3):
    """Sum per-site escape for each serum and stack the results.

    For every serum, reads ``results/antibody_escape/{serum}_avg.csv``, keeps
    mutations with ``times_seen >= min_times_seen``, sums ``escape_mean`` per
    site, and keeps only the sites listed in ``site_list``.

    Parameters
    ----------
    sera_list : list
        Serum identifiers used to build the CSV file names.
    age_group : str
        Label written to the ``age_group`` column of every row.
    site_list : list
        Sites to keep after summing.
    min_times_seen : int, default 3
        Minimum ``times_seen`` for a mutation to contribute to its site's sum.

    Returns
    -------
    pandas.DataFrame
        Columns ``site`` (as str), ``escape_mean``, ``serum`` and ``age_group``,
        one row per (serum, kept site).
    """
    summed_escape_list = []

    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= @min_times_seen"
        )

        prob_escape_sum = prob_escape.groupby('site', as_index=False).aggregate({'escape_mean': 'sum'})

        # pandas flags a boolean-indexed result as a possible view, so assigning
        # columns to it raises SettingWithCopyWarning; take an explicit copy.
        prob_escape_filt = prob_escape_sum[prob_escape_sum['site'].isin(site_list)].copy()

        prob_escape_filt['serum'] = serum
        prob_escape_filt['age_group'] = age_group
        # String sites so the chart treats them as discrete axis values. (Going
        # through an ordered Categorical first produced the same strings.)
        prob_escape_filt['site'] = prob_escape_filt['site'].astype(str)

        summed_escape_list.append(prob_escape_filt)

    summed_escape = pd.concat(summed_escape_list)
    return summed_escape

In [126]:
# One summed-escape table per age cohort, over the same site_list.
ped, teen, adult = (
    get_summed_escapes(cohort_sera, cohort_label, site_list)
    for cohort_sera, cohort_label in (
        (ped_sera_list, '0-5'),
        (teen_sera_list, '15-18'),
        (adult_sera_list, '40-45'),
    )
)

In [127]:
full_escape = pd.concat([ped, teen, adult], ignore_index=True)

# Sites are plotted as strings, so left-pad every site to three digits
# ('50' -> '050') to keep string order equal to numeric order on the x axis.
# This generalizes the old hand-written mapping, which covered only sites
# 50, 62, 82 and 94.
full_escape['site'] = full_escape['site'].astype(str).str.zfill(3)

In [153]:
# Summed escape per site for every serum, one facet per age cohort.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color(
            'age_group:N',
            legend=alt.Legend(orient="right", title='age group'),
        ).scale(scheme='set2'),
        detail='serum',
    )
    .mark_circle(size=30, opacity=0.7)
    .properties(width=400, height=200)
)

# Horizontal reference rule at zero escape.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Bind the chart to `chart` so the save call below works when uncommented.
chart = alt.layer(
    summed_escape_lineplot, x_axis, data=full_escape
).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=2,
).configure_axis(
    grid=False
).resolve_axis(
    x='independent'
)
chart

# chart.save('230501_immune_escape_by_age.pdf')


## Overlay of all sera, colored by age cohort

In [155]:
# All cohorts on one shared x axis. A random horizontal offset ("jitter")
# spreads out points that share a site.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean:Q", title="summed_escape"),
        xOffset='jitter:Q',
        color=alt.Color(
            'age_group:N',
            legend=alt.Legend(orient="right", title='age group'),
        ).scale(scheme='set2'),
        detail='serum',
    )
    # Box-Muller transform: approximately standard-normal jitter per point.
    # random() is re-drawn whenever the chart renders, so positions are not reproducible.
    .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
    .mark_circle(size=30, opacity=0.7)
    .properties(width=700, height=200)
)

# Horizontal reference rule at zero escape.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Bind the chart to `chart` so the save call below works when uncommented.
chart = alt.layer(
    summed_escape_lineplot, x_axis, data=full_escape
).configure_axis(
    grid=False
).resolve_axis(
    x='independent'
)
chart
# chart.save('230508_immune_escape_scatter.pdf')


## Correlation of mutation escape (β) values between age cohorts

In [None]:
# Plan: load each serum's model pickle from results/antibody_escape/{serum}.pickle,
# collect the models into a DataFrame (as in the cells below),
# then pass that DataFrame to polyclonal.PolyclonalAverage to average them.

In [5]:
# Sera in each age cohort, keyed by the cohort label.
age_sera_dict = dict(
    [
        ('0-5', [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]),
        ('15-18', [2343, 2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862, 3895]),
        ('40-45', [f'{n}C' for n in (33, 34, 197, 199, 215, 210, 74, 68, 150, 18)]),
    ]
)

In [19]:
# NOTE(review): superseded, commented-out version of get_model_df. It took the
# whole age_sera_dict, loaded results/antibody_escape/{serum}.pickle for every
# serum, and returned one frame with columns model/serum/age. Its signature
# differs from the active get_model_df(sera, age, date, replicate) defined in
# the next cell.
# def get_model_df(age_sera_dict):
    
#     df_list = []
    
#     for age in age_sera_dict:
#         model_list = []
#         serum_list = []
        
#         for serum in age_sera_dict[age]:
#             model = pickle.load(open(f'results/antibody_escape/{serum}.pickle', 'rb'))
#             model_list.append(model)
#             serum_list.append(serum)
        
#         df = pd.DataFrame({
#             'model': model_list,
#             'serum': serum_list,
#             'age': age
#         })
        
#         df_list.append(df)
        
#     models_df = pd.concat(df_list, ignore_index=True)
    
#     return models_df

In [56]:
def get_model_df(sera, age, date, replicate, lib='libA'):
    """Load one fitted model per serum into a DataFrame.

    Each model is read from
    ``results/polyclonal_fits/{lib}_{date}_1_{serum}_{replicate}.pickle``, and
    the first element of the unpickled object is used as the model.

    Parameters
    ----------
    sera : iterable
        Serum identifiers used in the pickle file names.
    age : str
        Age-cohort label copied into the ``age`` column of every row.
    date : str
        Date component of the file name.
    replicate : str
        Replicate component of the file name.
    lib : str, default 'libA'
        Library component of the file name (previously hard-coded to 'libA').

    Returns
    -------
    pandas.DataFrame
        Columns ``age``, ``model`` and ``serum``, one row per serum in the order
        given by ``sera``.
    """
    model_list = []
    serum_list = []

    for serum in sera:
        # Context manager so each pickle file handle is closed promptly.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(f'results/polyclonal_fits/{lib}_{date}_1_{serum}_{replicate}.pickle', 'rb') as fh:
            model = pickle.load(fh)

        # NOTE(review): the pickle apparently stores a sequence whose first
        # element is the model; confirm against the code that writes these files.
        model_list.append(model[0])
        serum_list.append(serum)

    df = pd.DataFrame({
        'age': age,
        'model': model_list,
        'serum': serum_list,
    })

    return df

In [48]:
# Each cohort's sera are split into two halves; each half is loaded below with
# its own date.
adult_sera_1, adult_sera_2 = (
    ['33C', '34C', '197C', '199C', '215C'],
    ['210C', '74C', '68C', '150C', '18C'],
)

ped_sera_1, ped_sera_2 = (
    [2367, 3944, 2462, 2389, 2323],
    [2388, 2463, 3973, 4299, 4584],
)

teen_sera_1, teen_sera_2 = (
    [2343, 2350, 2365, 2382, 3866],
    [2380, 3856, 3857, 3862, 3895],
)

In [57]:
# Pediatric models from both halves, stacked into one frame.
ped1, ped2 = (
    get_model_df(ped_sera_1, '0-5', '230221', '1'),
    get_model_df(ped_sera_2, '0-5', '230323', '1'),
)
peds = pd.concat((ped1, ped2), ignore_index=True)

In [50]:
# Average the pediatric-cohort models and display the output of
# mut_escape_corr_heatmap() (mutation-escape correlation across the averaged models).
avg_model = polyclonal.PolyclonalAverage(peds)
avg_model.mut_escape_corr_heatmap()

In [58]:
# Teen models from both halves, stacked into one frame.
teen1, teen2 = (
    get_model_df(teen_sera_1, '15-18', '230317', '1'),
    get_model_df(teen_sera_2, '15-18', '230403', '1'),
)
teens = pd.concat((teen1, teen2), ignore_index=True)

In [52]:
# Average the teen-cohort models and display the output of
# mut_escape_corr_heatmap() (mutation-escape correlation across the averaged models).
avg_model = polyclonal.PolyclonalAverage(teens)
avg_model.mut_escape_corr_heatmap()

In [59]:
# Adult models from both halves, stacked into one frame.
adult1, adult2 = (
    get_model_df(adult_sera_1, '40-45', '230419', '1'),
    get_model_df(adult_sera_2, '40-45', '230425', '1'),
)
adults = pd.concat((adult1, adult2), ignore_index=True)

In [60]:
# Pool the models from all three cohorts into one frame.
full_df = pd.concat((peds, teens, adults), ignore_index=True)

# Average the pooled models and display the mutation-escape correlation heatmap.
avg_model = polyclonal.PolyclonalAverage(full_df)
avg_model.mut_escape_corr_heatmap()

In [7]:
# NOTE(review): get_prob_escape is defined in the scratch section further down;
# that cell must run before this one.
# Each cohort's two halves use different dates; each pair is joined into one list.
adult_1 = get_prob_escape(adult_sera_1, 'libA', '230419', '1')
adult_2 = get_prob_escape(adult_sera_2, 'libA', '230425', '1')
adult_prob_escape = [*adult_1, *adult_2]

ped_1 = get_prob_escape(ped_sera_1, 'libA', '230221', '1')
ped_2 = get_prob_escape(ped_sera_2, 'libA', '230323', '1')
ped_prob_escape = [*ped_1, *ped_2]

teen_1 = get_prob_escape(teen_sera_1, 'libA', '230317', '1')
teen_2 = get_prob_escape(teen_sera_2, 'libA', '230403', '1')
teen_prob_escape = [*teen_1, *teen_2]

In [20]:
# The active get_model_df(sera, age, date, replicate) no longer accepts a dict,
# so build the frame here instead: one row per serum, with each serum's model
# loaded from results/antibody_escape/{serum}.pickle (the layout the
# commented-out get_model_df above produced).
_model_rows = []
for _age, _sera in age_sera_dict.items():
    for _serum in _sera:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(f'results/antibody_escape/{_serum}.pickle', 'rb') as _fh:
            _model_rows.append({'model': pickle.load(_fh), 'serum': _serum, 'age': _age})

models_df = pd.DataFrame(_model_rows, columns=['model', 'serum', 'age'])
models_df

Unnamed: 0,model,serum,age
0,<polyclonal.polyclonal_collection.PolyclonalAv...,2367,0-5
1,<polyclonal.polyclonal_collection.PolyclonalAv...,3944,0-5
2,<polyclonal.polyclonal_collection.PolyclonalAv...,2462,0-5
3,<polyclonal.polyclonal_collection.PolyclonalAv...,2389,0-5
4,<polyclonal.polyclonal_collection.PolyclonalAv...,2323,0-5
5,<polyclonal.polyclonal_collection.PolyclonalAv...,2388,0-5
6,<polyclonal.polyclonal_collection.PolyclonalAv...,2463,0-5
7,<polyclonal.polyclonal_collection.PolyclonalAv...,3973,0-5
8,<polyclonal.polyclonal_collection.PolyclonalAv...,4299,0-5
9,<polyclonal.polyclonal_collection.PolyclonalAv...,4584,0-5


In [62]:
models_df.at[0, 'model'].mut_escape_df  # mutation escape table of the first serum's model

Unnamed: 0,epitope,site,wildtype,mutant,mutation,escape_mean,escape_median,escape_min_magnitude,escape_std,n_models,times_seen,frac_models
0,1,-18,K,N,K-18N,0.016676,0.016676,0.016676,,1,1.0,0.5
1,1,-14,L,M,L-14M,-0.041901,-0.041901,-0.041901,,1,1.0,0.5
2,1,-13,V,F,V-13F,0.025171,0.025171,0.025171,,1,1.0,0.5
3,1,-9,A,S,A-9S,0.006079,0.006079,0.001243,0.006840,2,1.0,1.0
4,1,-6,A,S,A-6S,-0.016964,-0.016964,-0.016964,,1,1.0,0.5
...,...,...,...,...,...,...,...,...,...,...,...,...
3509,1,540,Q,K,Q540K,-0.214534,-0.214534,0.002010,0.306239,2,10.5,1.0
3510,1,540,Q,R,Q540R,-0.721596,-0.721596,-0.083235,0.902778,2,9.5,1.0
3511,1,540,Q,*,Q540*,-0.047928,-0.047928,0.025789,0.104252,2,2.0,1.0
3512,1,541,K,E,K541E,0.029264,0.029264,0.029264,,1,1.0,0.5


In [13]:
# NOTE(review): the recorded run of this cell raised
# AttributeError: 'PolyclonalAverage' object has no attribute 'epitope_harmonized_model'.
# The 'model' column here holds PolyclonalAverage objects (see the displayed
# frame above); presumably PolyclonalAverage expects individual fitted models
# instead -- confirm before re-running.
avg_model = polyclonal.PolyclonalAverage(models_df)
avg_model.mut_escape_corr_heatmap()

AttributeError: 'PolyclonalAverage' object has no attribute 'epitope_harmonized_model'

In [114]:
# NOTE(review): libA_3944_model and libB_3944_model are not defined anywhere in
# this notebook; presumably they came from an earlier, since-deleted cell.
libs, replicates = ['libA', 'libB'], ['1', '1']
models = [libA_3944_model, libB_3944_model]
models_df = pd.DataFrame(dict(library=libs, replicate=replicates, model=models))

models_df

Unnamed: 0,library,replicate,model
0,libA,1,<polyclonal.polyclonal.Polyclonal object at 0x...
1,libB,1,<polyclonal.polyclonal.Polyclonal object at 0x...


In [115]:
# Average the libA/libB models for serum 3944 and display the output of
# mut_escape_corr_heatmap() (mutation-escape correlation between the two libraries).
avg_model = polyclonal.PolyclonalAverage(models_df)
avg_model.mut_escape_corr_heatmap()

## Analyzing antigenic sites A and B

In [None]:
# Antigenic site A: residues 122-146. Antigenic site B: residues 155-160 and 186-198.
a_sites = [*range(122, 147)]

b_sites = [*range(155, 161), *range(186, 199)]

In [None]:
def get_summed_escapes(sera_list, age_group, site_list, min_times_seen=3):
    """Sum per-site escape for each serum and stack the results.

    For every serum, reads ``results/antibody_escape/{serum}_avg.csv``, keeps
    mutations with ``times_seen >= min_times_seen``, sums ``escape_mean`` per
    site, and keeps only the sites listed in ``site_list``. Unlike the earlier
    definition of this function, ``site`` keeps the dtype read from the CSV;
    the caller converts it to str.

    Parameters
    ----------
    sera_list : list
        Serum identifiers used to build the CSV file names.
    age_group : str
        Label written to the ``age_group`` column of every row.
    site_list : list
        Sites to keep after summing.
    min_times_seen : int, default 3
        Minimum ``times_seen`` for a mutation to contribute to its site's sum.

    Returns
    -------
    pandas.DataFrame
        Columns ``site``, ``escape_mean``, ``serum`` and ``age_group``.
    """
    summed_escape_list = []

    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= @min_times_seen"
        )

        prob_escape_sum = prob_escape.groupby('site', as_index=False).aggregate({'escape_mean': 'sum'})

        # pandas flags a boolean-indexed result as a possible view, so assigning
        # columns to it raises SettingWithCopyWarning; take an explicit copy.
        prob_escape_filt = prob_escape_sum[prob_escape_sum['site'].isin(site_list)].copy()

        prob_escape_filt['serum'] = serum
        prob_escape_filt['age_group'] = age_group

        summed_escape_list.append(prob_escape_filt)

    summed_escape = pd.concat(summed_escape_list)
    return summed_escape

In [None]:
ped = get_summed_escapes(ped_sera_list, '0-5', site_list=[*a_sites, *b_sites])  # pediatric cohort, sites A+B

In [35]:
teen = get_summed_escapes(teen_sera_list, '15-18', site_list=[*a_sites, *b_sites])  # teen cohort, sites A+B

In [36]:
adult = get_summed_escapes(adult_sera_list, '40-45', site_list=[*a_sites, *b_sites])  # adult cohort, sites A+B

In [37]:
# Stack the three cohorts and turn sites into strings for a discrete x axis.
# Every site in A and B has three digits, so no zero-padding is needed here.
full_escape = pd.concat([ped, teen, adult]).astype({'site': str})

In [37]:
# NOTE(review): repeats the previous cell; re-running it rebuilds the same frame.
full_escape = pd.concat([ped, teen, adult]).assign(
    site=lambda frame: frame['site'].astype(str)
)

In [None]:
# One line per serum across sites A and B, one facet per age cohort.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .mark_line(size=1, opacity=0.7)
    .properties(width=400, height=200)
)

# Horizontal reference rule at zero escape.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Bind the chart to `chart` so the save call below works when uncommented.
chart = alt.layer(
    summed_escape_lineplot, x_axis, data=full_escape
).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=2,
).configure_axis(
    grid=False
).resolve_axis(
    x='independent'
)
chart

# chart.save('230501_immune_escape_by_age.pdf')


In [9]:
# NOTE(review): the recorded run raised NameError because prob_escape_list and
# get_summed_escape are only defined in the scratch section further down.
sites_list = [103, 121, 122, 124, 135, 137, 138, 145, 159, 160, 186, 189, 193, 220, 224]

# Fail loudly (as the old index-based loop did) when a table would have no serum name.
assert len(adult_sera_list) >= len(prob_escape_list), 'need a serum name for every prob_escape table'

# Pair each table with its serum name instead of keeping a manual counter.
# .assign returns a new frame, so no column is written onto a filtered slice.
summed_escape_list = [
    get_summed_escape(prob_escape, sites_list).assign(sera=sera_name)
    for sera_name, prob_escape in zip(adult_sera_list, prob_escape_list)
]

NameError: name 'prob_escape_list' is not defined

In [36]:
# Combine the per-serum tables; string sites give a discrete x axis.
summed_escape_full = pd.concat(summed_escape_list).astype({'site': str})
summed_escape_full.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 75 entries, 87 to 185
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   site    75 non-null     object 
 1   escape  75 non-null     float64
 2   sera    75 non-null     object 
dtypes: float64(1), object(2)
memory usage: 2.3+ KB


In [56]:
# Shared encoding: summed escape per site, one colour per serum.
summed_escape_base = alt.Chart(summed_escape_full).encode(
    x=alt.X("site", title="site"),
    y=alt.Y("escape", title="summed_escape"),
    color=alt.Color('sera:N', legend=alt.Legend(orient="right", title='sera')),
)

# Lines plus points for each serum, and a reference rule at zero escape.
site_lineplot = (
    summed_escape_base.mark_line(size=1, opacity=0.7)
    + summed_escape_base.mark_circle(opacity=0.7, size=20)
    + alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')
)

site_lineplot.configure_axis(grid=False)

### Scratch code

In [19]:
def get_prob_escape(sera_list, lib, date, replicate):
    """Read each serum's probability-of-escape CSV and filter low-count rows.

    Files are read from
    ``results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv``.
    Only the literal string "nan" is parsed as missing; strings such as "NA"
    are kept as-is.

    Returns
    -------
    list of pandas.DataFrame
        One filtered frame per serum, in ``sera_list`` order.
    """
    def _read_one(serum):
        csv_path = f'results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv'
        table = pd.read_csv(csv_path, keep_default_na=False, na_values="nan")
        # NOTE(review): `no_antibody_count_threshold` has no '@' prefix, so pandas
        # resolves it as a column of the CSV, not as a Python variable.
        return table.query("`no-antibody_count` >= no_antibody_count_threshold")

    return [_read_one(serum) for serum in sera_list]

In [20]:
# NOTE(review): this overwrites the 10-serum adult_sera_list defined earlier with
# only the first five adult sera.
adult_sera_list = ['33C', '34C', '197C', '199C', '215C']

prob_escape_list = get_prob_escape(
    adult_sera_list, lib='libA', date='230419', replicate='2'
)

In [6]:
def generate_model(
    prob_escape_df,
    n_epitopes=1,
    reg_escape_weight=0.1,
    logfreq=200,
):
    """Fit a polyclonal model to one serum's probability-of-escape table.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Probability-of-escape table. Its ``antibody_concentration`` and
        ``aa_substitutions_reference`` columns are renamed to ``concentration``
        and ``aa_substitutions`` before fitting.
    n_epitopes : int, default 1
        Number of epitopes in the model.
    reg_escape_weight : float, default 0.1
        Escape-regularization weight passed to ``fit``.
    logfreq : int, default 200
        Logging frequency passed to ``fit``; the printed log is captured and
        discarded.

    Returns
    -------
    polyclonal.Polyclonal
        The fitted model.
    """
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_df.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # Fit the model, capturing its progress output to keep the notebook tidy.
    # (The optimizer result and an escape plot built afterwards were never used,
    # so neither is kept.)
    with io.capture_output():
        model.fit(
            logfreq=logfreq,
            reg_escape_weight=reg_escape_weight,
        )

    return model

In [23]:
def get_summed_escape(prob_escape, sites_list, min_times_seen=3):
    """Fit a model to ``prob_escape`` and sum its escape values per site.

    Mutations with ``times_seen`` below ``min_times_seen`` are dropped before
    summing, and only sites in ``sites_list`` are returned.

    Returns
    -------
    pandas.DataFrame
        Columns ``site`` and ``escape``. The result is an independent copy, so
        callers may add columns without triggering SettingWithCopyWarning.
    """
    model = generate_model(prob_escape)
    df = model.mut_escape_df
    df = df.loc[df['times_seen'] >= min_times_seen]

    summed_escapes = df.groupby('site', as_index=False).aggregate({'escape': 'sum'})

    return summed_escapes[summed_escapes['site'].isin(sites_list)].copy()

In [24]:
sites_list = [103, 121, 122, 124, 135, 137, 138, 145, 159, 160, 186, 189, 193, 220, 224]

# Fail loudly (as the old index-based loop did) when a table would have no serum name.
assert len(adult_sera_list) >= len(prob_escape_list), 'need a serum name for every prob_escape table'

# Pair each table with its serum name instead of keeping a manual counter.
# .assign returns a new frame, so no column is written onto a filtered slice.
summed_escape_list = [
    get_summed_escape(prob_escape, sites_list).assign(sera=sera_name)
    for sera_name, prob_escape in zip(adult_sera_list, prob_escape_list)
]

In [36]:
# Combine the per-serum tables, then make sites strings for a discrete x axis.
summed_escape_full = pd.concat(summed_escape_list)
summed_escape_full = summed_escape_full.assign(site=summed_escape_full['site'].astype(str))
summed_escape_full.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 75 entries, 87 to 185
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   site    75 non-null     object 
 1   escape  75 non-null     float64
 2   sera    75 non-null     object 
dtypes: float64(1), object(2)
memory usage: 2.3+ KB


In [56]:
# Shared encoding: summed escape per site, one colour per serum.
summed_escape_base = alt.Chart(summed_escape_full).encode(
    x=alt.X("site", title="site"),
    y=alt.Y("escape", title="summed_escape"),
    color=alt.Color('sera:N', legend=alt.Legend(orient="right", title='sera')),
)

# Lines plus points for each serum, and a reference rule at zero escape.
site_lineplot = (
    summed_escape_base.mark_line(size=1, opacity=0.7)
    + summed_escape_base.mark_circle(opacity=0.7, size=20)
    + alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')
)

site_lineplot.configure_axis(grid=False)

In [18]:
# One line per serum, with the age cohorts stacked in a single column of facets.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .mark_line(size=1, opacity=0.7)
    .properties(width=800, height=200)
)

# Horizontal reference rule at zero escape.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Bind the chart to `chart` so the save call below works when uncommented.
chart = alt.layer(
    summed_escape_lineplot, x_axis, data=full_escape
).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False)
chart

# chart.save('immune_escape_by_age.pdf')