In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os
os.chdir('../../../')

## Plotting immune escape across full protein and selected sites

In [3]:
def get_summed_escapes(sera_list, age_group, site_list=None):
    summed_escape_list = []
    
    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )
        
        prob_escape_sum = prob_escape.groupby(['site', 'wildtype'], as_index=False).aggregate({'escape_mean': 'sum'})

        if site_list:
            prob_escape_final = prob_escape_sum[prob_escape_sum['site'].isin(site_list)]
            prob_escape_final['site'] = pd.Categorical(prob_escape_final['site'], ordered=True)
            prob_escape_final['site'] = prob_escape_final['site'].astype(str)

        else:
            prob_escape_final = prob_escape_sum.copy()
            
        prob_escape_final['serum'] = serum
        prob_escape_final['age_group'] = age_group
        
        summed_escape_list.append(prob_escape_final)
        
    summed_escape = pd.concat(summed_escape_list)
    return summed_escape

In [25]:
# initialize list of key sites, plus samples in each age cohort
site_list = [50, 82, 103, 121, 122, 124, 131, 135, 137, 138, 145, 156, 157, 
              159, 160, 186, 188, 189, 193, 220, 224, 244, 276]

# site_list = [53, 83, 94, 128, 131, 135, 137, 138, 156, 159, 160, 164, 186, 190, 195, 276]

peds = [2367, 3944, 2389, 2323, 2388, 3973, 4299, 4584]
teens = [2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862]
adults = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']
infant = [2462]
ferrets = ['ferret_1', 'ferret_2', 'ferret_3']

sample_lists = [peds, teens, adults, infant, ferrets]
cohorts = ['02-05 years', '15-20 years', '40-45 years', 'infant', 'ferrets']
summed_escapes_filtered = []

i=0 # for looping through age cohort definitions

# start by getting full escape df filtered to key sites
for list in sample_lists:
    summed_escape_selected_sites = get_summed_escapes(list, cohorts[i], site_list)
    summed_escapes_filtered.append(summed_escape_selected_sites)

    i+=1

escape_df_filtered = pd.concat(summed_escapes_filtered)

site_dict = {'50': '050', 
             '53': '053',
             '62': '062', 
             '82': '082', 
             '83': '083',
             '94': '094'}

escape_df_filtered['site'] = escape_df_filtered['site'].apply(
    lambda x: site_dict[x] if x in site_dict else x
)

escape_df_filtered['serum'] = escape_df_filtered['serum'].astype(str)

# add 'mean_escape_mean' column of mean escape values per site within an age group
escape_df_filtered['mean_escape_mean'] = (
    escape_df_filtered.groupby(['site', 'age_group'])['escape_mean']
    .transform('mean')
)

In [26]:
# also generate escape df with all sites included
summed_escapes = []
i=0 # for looping through age cohort definitions

for list in sample_lists:
    summed_escape = get_summed_escapes(list, cohorts[i])
    summed_escapes.append(summed_escape)

    i+=1

escape_df_full = pd.concat(summed_escapes)

escape_df_full['serum'] = escape_df_full['serum'].astype(str)

## Set up different chart options

In [27]:
# filtered sites, scatterplot
summed_escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.7)
    .properties(width=500, height=200)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_scatter_filtered = alt.layer(
    summed_escape_scatterplot, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

faceted_scatter_filtered

In [28]:
# filtered sites, lineplot
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=0.6, opacity=.6)
    .properties(width=550, height=175)
)

mean_line = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean_escape_mean", title="summed escape"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape_mean']
    )
    .mark_line(size=3, opacity=0.75)
    .properties(width=550, height=175)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_lineplot_filtered = alt.layer(
    summed_escape_lineplot, mean_line, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

# faceted_lineplot_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230715_summed_escape_lineplot.html')

faceted_lineplot_filtered

In [29]:
# filtered sites, line and scatter overlay
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            # scale=alt.Scale(domain=[-14, 16]),
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=1.5, opacity=0.4)
    .properties(width=525, height=120)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean_escape_mean"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape_mean']
    )
    .mark_circle(size=45, opacity=0.75)
    # .properties(width=400, height=150)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_line_scatter_overlay = alt.layer(
    summed_escape_lineplot, mean_points, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    spacing=5,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15
)

faceted_line_scatter_overlay.save('scratch_notebooks/figure_drafts/sitewise_escape/230724_summed_escape_filt_sites.html')

faceted_line_scatter_overlay

In [31]:
# all sites, lineplot
summed_escape_lineplot_full = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None,
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=0.7, opacity=0.7)
    .properties(width=800, height=120)
)

faceted_lineplot_full = alt.layer(
    summed_escape_lineplot_full, x_axis, data=escape_df_full
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='escape at all residues',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    spacing=5,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

faceted_lineplot_full.save('scratch_notebooks/figure_drafts/sitewise_escape/230724_summed_escape_full.html')

faceted_lineplot_full

# Chart comparison

In [13]:
faceted_scatter_filtered

In [14]:
faceted_lineplot_filtered

In [15]:
faceted_line_scatter_overlay

In [16]:
faceted_lineplot_full

## scratch code

In [75]:
def get_escapes(sera_list, age_group, agg_type, site_list=None):
    escape_df_list = []
    
    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )

        # prob_escape['summed'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform('sum')
        # prob_escape['mean'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform('mean')
        # prob_escape = prob_escape[['site', 'wildtype', 'summed', 'mean']].drop_duplicates()

        # prob_escape = prob_escape.melt(
        #     id_vars=['site', 'wildtype'], 
        #     value_vars=['summed', 'mean'],
        #     var_name='escape_type',
        #     value_name='escape'
        # )

        prob_escape['escape'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform(agg_type)
        prob_escape = prob_escape[['site', 'wildtype', 'escape']].drop_duplicates()

        if site_list:
            prob_escape = prob_escape[prob_escape['site'].isin(site_list)]
            prob_escape['site'] = pd.Categorical(prob_escape['site'], ordered=True)
            prob_escape['site'] = prob_escape['site'].astype(str)
            
        prob_escape['serum'] = serum
        prob_escape['age_group'] = age_group
        
        escape_df_list.append(prob_escape)
        
    full_escape_df = pd.concat(escape_df_list)
    return full_escape_df

In [76]:
# initialize list of key sites, plus samples in each age cohort
site_list = [50, 82, 103, 121, 122, 124, 131, 135, 137, 138, 145, 156, 157, 
              159, 160, 186, 188, 189, 193, 220, 224, 244, 276]

peds = [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]
teens = [2343, 2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862, 3895]
adults = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']
ferrets = ['ferret_1', 'ferret_2', 'ferret_3']

sample_lists = [peds, teens, adults, ferrets]
cohorts = ['0-5', '15-18', '40-45', 'ferrets']
summed_escapes_filtered = []
mean_escapes_filtered = []

i=0 # for looping through age cohort definitions

# start by getting full escape df filtered to key sites
for list in sample_lists:
    escape_sum = get_escapes(list, cohorts[i], 'sum', site_list)
    escape_mean = get_escapes(list, cohorts[i], 'mean', site_list)
    
    summed_escapes_filtered.append(escape_sum)
    mean_escapes_filtered.append(escape_mean)

    i+=1

filtered_df_sum = pd.concat(summed_escapes_filtered)
filtered_df_mean = pd.concat(mean_escapes_filtered)

site_dict = {'50': '050', 
             '62': '062', 
             '82': '082', 
             '94': '094'}

filtered_df_sum['site'] = filtered_df_sum['site'].apply(
    lambda x: site_dict[x] if x in site_dict else x
)
filtered_df_mean['site'] = filtered_df_mean['site'].apply(
    lambda x: site_dict[x] if x in site_dict else x
)

filtered_df_sum['serum'] = filtered_df_sum['serum'].astype(str)
filtered_df_mean['serum'] = filtered_df_mean['serum'].astype(str)

# add 'escape_mean' columns of mean escape values per site within an age group
filtered_df_sum['escape_mean'] = (
    filtered_df_sum.groupby(['site', 'age_group'])['escape']
    .transform('mean')
)

filtered_df_mean['escape_mean'] = (
    filtered_df_sum.groupby(['site', 'age_group'])['escape']
    .transform('mean')
)

# full_df_filtered['summed_escape_mean'] = (
#     full_df_filtered.groupby(['site', 'age_group'])['summed_escape']
#     .transform('mean')
# )

# full_df_filtered['mean_escape_mean'] = (
#     full_df_filtered.groupby(['site', 'age_group'])['mean_escape']
#     .transform('mean')
# )

filtered_df_mean

Unnamed: 0,site,wildtype,escape,serum,age_group,escape_mean
487,050,E,0.030716,2367,0-5,1.644210
645,082,K,-0.003329,2367,0-5,0.108860
768,103,P,-0.032294,2367,0-5,-0.376590
859,121,K,-0.034284,2367,0-5,-0.643280
879,122,N,0.000029,2367,0-5,-0.560240
...,...,...,...,...,...,...
1316,193,S,-0.038428,ferret_3,ferrets,-0.542600
1571,220,R,-0.333139,ferret_3,ferrets,-5.531733
1614,224,R,-0.357231,ferret_3,ferrets,-4.203567
1693,244,L,-0.235982,ferret_3,ferrets,-2.630167


In [84]:
# filtered sites, scatterplot
escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape",
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.7)
    .properties(width=500, height=200)
)



x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_scatter_sum = alt.layer(
    escape_scatterplot, x_axis, data=filtered_df_sum
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=0
            # labelOrient='right',
            # labelAngle=0,
            # labelFontStyle='italic'
        )
    ),
    columns=1
)

faceted_scatter_mean = alt.layer(
    escape_scatterplot, x_axis, data=filtered_df_mean
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='mean escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

(faceted_scatter_sum | faceted_scatter_mean).configure_axis(
        grid=False,
        labelFontSize=13,
        titleFontSize=15
    )

In [89]:
escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape",
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape']
    )
    .mark_line(size=1.5, opacity=0.4)
    .properties(width=500, height=150)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean", title="escape score"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape']
    )
    .mark_circle(size=45, opacity=0.75)
    .properties(width=500, height=150)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

line_scatter_sum = alt.layer(
    escape_lineplot, mean_points, x_axis, data=filtered_df_sum
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

line_scatter_mean = alt.layer(
    escape_lineplot, mean_points, x_axis, data=filtered_df_mean
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='mean escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

(line_scatter_sum | line_scatter_mean).configure_axis(
        grid=False,
        labelFontSize=13,
        titleFontSize=15
    )

ValueError: Unable to determine data type for the field "mean"; verify that the field name is not misspelled. If you are referencing a field from a transform, also confirm that the data type is specified correctly.

alt.HConcatChart(...)