## Compare serum neutralizing titers obtained from individuals to pooled sera

I want to compare the arithmetic mean, geometric mean, and median titers of the different cohorts of individuals I've run from the Seattle Children's Hospital (SCH) and University of Pennsylvania (Penn) Vaccine Cohorts to the serum pools I made from these same individuals.

I'll want to be sure that I include the same groupings that I made pools of, namely:
* `SCH cohort only`
* `Penn pre-vax only`
* `Penn post-vax only`
* `SCH + Penn pre-vax`
* `SCH + Penn post-vax`
* `Penn pre-vax + Penn post-vax`
* `SCH + Penn pre-vax + Penn post-vax` (also known as the "super pool")
 

In [1]:
# Import packages
import os
import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import root_mean_squared_error 

# Disable Altair's max-rows check so charts built from large DataFrames
# render instead of raising an error.
_ = alt.data_transformers.disable_max_rows()

# Shared color palette (hex codes) used by the plots below
palette = ['#345995',  # blue
           '#03cea4',  # teal
           '#ca1551',  # red
           '#eac435']  # yellow

In [2]:
# Local input and output directories (created if they do not exist yet)
datadir = '../data'
resultsdir = '../results'
for _dir in (datadir, resultsdir):
    os.makedirs(_dir, exist_ok=True)

# Individual titers, Seattle Children's Hospital (SCH) cohort.
# The barcode is the third '_'-separated field of the serum name.
SCH_titers = (
    pd.read_csv('../../../results/aggregated_titers/titers_SCH.csv')
    .assign(barcode=lambda df: df['serum'].str.split('_').str[2])
)

# Individual titers, Penn vaccine cohort. Besides the barcode, 'day' is rebuilt
# as 'd' + the text between the first and second 'd' of the serum name.
# NOTE(review): this assumes no lowercase 'd' occurs in the serum name before
# the day prefix itself -- confirm against the serum naming scheme.
titers_PennVaccineCohort = (
    pd.read_csv('../../../results/aggregated_titers/titers_PennVaccineCohort.csv')
    .assign(
        barcode=lambda df: df['serum'].str.split('_').str[2],
        day=lambda df: 'd' + df['serum'].str.split('d').str[1],
    )
)

# Titers measured on the pooled sera themselves
pooled_titers = pd.read_csv('../../../results/aggregated_titers/titers_PooledSera.csv')

In [3]:
# Names of all pools, shortest name first. sorted() is stable, so names of
# equal length keep their order of first appearance in pooled_titers.
pools = sorted(pooled_titers['serum'].unique(), key=len)
pools

['SCH_pool',
 'PennPreVax_pool',
 'PennPostVax_pool',
 'SCHPennPrePost_pool',
 'SCH_PennPreVax_pool',
 'SCH_PennPostVax_pool',
 'PennPreVax_PennPostVax_pool']

In [4]:
# Strain order used for the x-axis of every per-virus plot
viral_plot_order = pd.read_csv('../../../data/H3N2library_2023-2024_strain_order.csv')
viruses = viral_plot_order['strain'].tolist()

# Vaccine strains: one strain per line, after a header line containing 'strain'.
# strip() (rather than strip('\n')) also removes the '\r' left by CRLF line
# endings, and blank lines are skipped so '' never ends up in the list.
vaccine_strains = []
with open(os.path.join(datadir, 'vaccine_strains.csv')) as f:
    for line in f:
        line = line.strip()
        if line and 'strain' not in line:
            vaccine_strains.append(line)

# Same list, but with A/Massachusetts/18/2022 reclassified as a
# 2023-circulating strain rather than a vaccine strain
vaccine_strains_no_Massachusetts = [s for s in vaccine_strains if s != 'A/Massachusetts/18/2022']

In [5]:
# Stack both cohorts into one long table of individual titers
all_titers = pd.concat([SCH_titers, titers_PennVaccineCohort])

# Map '<group>_<day>' labels to readable cohort names. SCH rows have no 'day'
# column, so their concatenated label is NaN, which is mapped to 'Children'.
group_fine_dict = {np.nan: 'Children',
                   'PennVaccineCohort_d0': 'PennPreVax',
                   'PennVaccineCohort_d28': 'PennPostVax'}

all_titers = all_titers.assign(
    group_detailed=all_titers['group'] + '_' + all_titers['day'])
all_titers = all_titers.replace({'group_detailed': group_fine_dict})

In [6]:
# Per-virus aggregates (arithmetic mean, geometric mean, median) of the
# individual titers, keyed by the cohort combination each serum pool was made from:
#   * a single cohort name (str) for single-cohort pools
#   * a sorted tuple of cohort names for combined pools
from itertools import combinations


def _aggregate_titers_by_virus(titers_df):
    """Return ``[[virus, amean, gmean, median], ...]`` for each virus in `titers_df`.

    Viruses appear in order of first occurrence. The geometric mean is
    ``exp(mean(log(titer)))``.
    """
    rows = []
    for virus in titers_df['virus'].unique():
        titers = titers_df.loc[titers_df['virus'] == virus, 'titer']
        rows.append([virus,
                     titers.mean(),
                     np.exp(np.log(titers).mean()),
                     titers.median()])
    return rows


cohorts = list(group_fine_dict.values())
pooled_mean_median_dict = {}

# Single cohorts
for grp in cohorts:
    pooled_mean_median_dict[grp] = _aggregate_titers_by_virus(
        all_titers[all_titers['group_detailed'] == grp])

# Every pair of cohorts. Keys are sorted tuples so a pair is found regardless of order.
for grp_a, grp_b in combinations(cohorts, 2):
    pair = tuple(sorted([grp_a, grp_b]))
    pooled_mean_median_dict[pair] = _aggregate_titers_by_virus(
        all_titers[all_titers['group_detailed'].isin(pair)])

# Super pool: every row of all_titers, so each Penn individual contributes both
# its pre-vax (d0) and post-vax (d28) titers
pooled_mean_median_dict[tuple(sorted(cohorts))] = _aggregate_titers_by_virus(all_titers)

## Plot actual titers across strains for matching pools and correlations

In [7]:
# Map each pooled-serum name (the 'serum' column of pooled_titers) to the key
# in pooled_mean_median_dict that holds the aggregates of that pool's
# individuals. The key is a cohort name (str) for single-cohort pools, or a
# tuple of cohort names for combined pools. The tuple is sorted, to match how
# pooled_mean_median_dict builds its tuple keys.
# 'SCHPennPrePost_pool' is the "super pool" (all three cohorts).
pooledSerum_to_aggreagate_dict = {
    'PennPostVax_pool': 'PennPostVax',
    'PennPreVax_PennPostVax_pool': ('PennPostVax', 'PennPreVax'),
    'PennPreVax_pool': 'PennPreVax',
    'SCHPennPrePost_pool': ('Children', 'PennPostVax', 'PennPreVax'),
    'SCH_PennPostVax_pool': ('Children', 'PennPostVax'),
    'SCH_PennPreVax_pool': ('Children', 'PennPreVax'),
    'SCH_pool': 'Children'}

In [8]:
def get_titer_plot(data, title):
    """Plot titers per virus as lines plus points, one color per aggregate type.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format table with 'virus', 'titer' and 'aggregate_type' columns
        (as produced by ``pd.melt`` in the per-pool loops).
    title : str
        Chart title.

    Returns
    -------
    altair.LayerChart
        Line layer (with axes) overlaid by a point layer. The y scale is log
        over NT50 30-20000, and viruses are ordered by the module-level `viruses`.
    """
    title_size, label_size = 14, 11
    colors = alt.Color('aggregate_type', sort=['Children']).scale(range=palette)

    x_axis = alt.Axis(grid=False, titleFontSize=title_size, labelFontSize=label_size,
                      title=None, titleY=330, labelLimit=1000, labelAlign='right')
    y_axis = alt.Axis(grid=False, titleFontSize=title_size, labelFontSize=label_size,
                      title="NT50")
    y_scale = alt.Scale(type='log', domain=[30, 20000], nice=False)

    lines = (alt.Chart(data, width=700, height=150)
             .mark_line(size=3, point=False)
             .encode(
                 alt.X('virus', axis=x_axis, sort=viruses),
                 alt.Y('titer', scale=y_scale, axis=y_axis),
                 color=colors,
             ))

    markers = (alt.Chart(data)
               .mark_point(size=60, strokeWidth=2)
               .encode(
                   alt.X('virus', sort=viruses),
                   alt.Y('titer'),
                   color=colors,
               ))

    heading = alt.TitleParams(text=title, fontSize=title_size, anchor='middle')
    return alt.layer(lines, markers).properties(title=heading)

In [9]:
def _corr_scatter(data, x_field, x_title, width, title_size, label_size, color):
    """Return a log-log scatter of pooled-serum NT50 (y) against column `x_field` (x)."""
    axis = alt.Axis(grid=False, titleFontSize=title_size, labelFontSize=label_size)
    scale = alt.Scale(nice=False, padding=6, type="log")
    return (
        alt.Chart(data, width=width, height=width)
        .mark_circle(size=60, filled=False)
        .encode(
            alt.Y('pooled_serum:Q', title='pooled serum NT50', scale=scale, axis=axis),
            alt.X(f'{x_field}:Q', title=x_title, scale=scale, axis=axis),
            color=color))


def _corr_r_squared(data, agg):
    """Return the in-sample R² of a linear fit between 'pooled_serum' and column `agg`."""
    predictor, response = data[['pooled_serum']], data[[agg]]
    model = LinearRegression().fit(predictor, response)
    return float(model.score(predictor, response))


def _corr_r2_label(r_squared, offset, font_size):
    """Return a text layer 'R² = <r_squared to 2 d.p.>', shifted by `offset` px (dx, dy)."""
    return (alt.Chart(pd.DataFrame({'r2': [r_squared]}))
            .mark_text(baseline="bottom", align="right", dx=offset, dy=offset,
                       color="black", fontSize=font_size)
            .transform_calculate(r2_string='"R² = " + format(datum.r2, ".2f")')
            .encode(text=alt.Text('r2_string:N')))


def get_corr_plot(data, title, titer_range = (30, 20000)):
    """Scatter pooled-serum NT50 against each aggregate of the individual titers.

    Draws three log-log panels side by side. Each panel has the pooled-serum
    NT50 on y and one aggregate of the individual NT50s on x (arithmetic mean,
    geometric mean, median). Each panel also shows a dashed x = y line and the
    R² of a linear fit between the two.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'pooled_serum', 'amean', 'gmean', 'median' and 'serum'
        columns.
    title : str
        Title shown above the three panels.
    titer_range : sequence of two numbers, optional
        End points of the dashed x = y reference line.

    Returns
    -------
    altair.ConcatChart
    """
    width = 150
    title_size = 16
    label_size = 14
    color = alt.Color('serum', legend=None).scale(scheme='plasma')

    # Dashed x = y reference line
    identity = pd.DataFrame({'x': list(titer_range), 'y': list(titer_range)})
    identity_line = (alt.Chart(identity)
                     .mark_line(color='black', strokeDash=[8, 8])
                     .encode(x='x', y='y'))

    panels = []
    for agg, x_title in (('amean', 'arithmetic mean NT50'),
                         ('gmean', 'geometric mean NT50'),
                         ('median', 'median NT50')):
        scatter = _corr_scatter(data, agg, x_title, width, title_size, label_size, color)
        # R² is passed as a float so Vega's format() rounds it to 2 d.p.;
        # slicing its string form truncated instead of rounding and broke on
        # values printed in scientific notation (e.g. '3e-05' -> '3e-0').
        label = _corr_r2_label(_corr_r_squared(data, agg), offset=70, font_size=label_size)
        panels.append(scatter + identity_line + label)

    return (alt.concat(*panels, columns=3, spacing=10)
            .properties(title=alt.TitleParams(text=title,
                                              fontSize=title_size,
                                              anchor='middle')))

In [10]:
# For every pool, compare the pooled-serum titers with the per-virus
# aggregates of the individuals it was made from: one titer plot and one
# correlation plot per pool.
titer_plots = []
corr_plots = []

for pool in pools:
    # Pooled-serum titers for this pool only
    pool_df = pooled_titers[pooled_titers['serum'] == pool]
    # Per-virus aggregates of the individuals that make up this pool
    individual_aggregate_titers_df = pd.DataFrame(
        pooled_mean_median_dict[pooledSerum_to_aggreagate_dict[pool]],
        columns=['virus', 'amean', 'gmean', 'median'])
    data = (pool_df
            .merge(individual_aggregate_titers_df, on='virus')
            .rename(columns={'titer': 'pooled_serum'}))  # clearer name for the pool titer

    # Long format: one row per (virus, aggregate type) for the titer plot
    melt_data = pd.melt(data, id_vars=['group', 'serum', 'virus'],
                        value_vars=['pooled_serum', 'amean', 'gmean', 'median'],
                        var_name='aggregate_type',
                        value_name='titer')

    # Single cohorts are stored as a str, combined cohorts as a tuple of names
    pool_str = pooledSerum_to_aggreagate_dict[pool]
    title = ' + '.join(pool_str) if isinstance(pool_str, tuple) else pool_str

    titer_plots.append(get_titer_plot(melt_data, title))
    corr_plots.append(get_corr_plot(data, title))
  

In [11]:
# Get titer plots
(alt.concat(*titer_plots, title = '', columns = 2).resolve_scale(
    y='shared',
    x='shared')
 .configure_legend(titleFontSize=16, 
                   labelFontSize = 16,
                   strokeColor='gray',
                   padding=5,
                   cornerRadius=10,
                   symbolStrokeWidth = 3))


Surprisingly, `gmean` and `median` track each other very closely. Pooled serum tends to overestimate NT50 values (especially for children), but overall it follows the same pattern across strains as `gmean` and `median`.

In [12]:
# Get corr plots
(alt.concat(*corr_plots, title = '', columns = 1, padding = 20).resolve_scale(
    y='shared',
    x='shared'))

The children's pool is furthest from the x = y line. This is expected, because the children's cohort has the greatest relative variance among individuals.

## Make correlation plots without vaccine strains
A/Massachusetts/18/2022 stays in the analysis because it is also a currently circulating strain.

In [13]:
# Pooled titers with vaccine strains removed (A/Massachusetts/18/2022 is kept)
is_vaccine_strain = pooled_titers['virus'].isin(vaccine_strains_no_Massachusetts)
pooled_titers_no_vax = pooled_titers[~is_vaccine_strain]

In [14]:
# Same per-pool comparison as above, but on pooled titers with vaccine strains
# removed; only the correlation plots are displayed.
titer_plots = []
corr_plots = []

for pool in pools:
    # Pooled-serum titers (no vaccine strains) for this pool only
    pool_df = pooled_titers_no_vax[pooled_titers_no_vax['serum'] == pool]
    # Per-virus aggregates of the individuals that make up this pool
    individual_aggregate_titers_df = pd.DataFrame(
        pooled_mean_median_dict[pooledSerum_to_aggreagate_dict[pool]],
        columns=['virus', 'amean', 'gmean', 'median'])
    data = (pool_df
            .merge(individual_aggregate_titers_df, on='virus')
            .rename(columns={'titer': 'pooled_serum'}))  # clearer name for the pool titer

    # Long format: one row per (virus, aggregate type) for the titer plot
    melt_data = pd.melt(data, id_vars=['group', 'serum', 'virus'],
                        value_vars=['pooled_serum', 'amean', 'gmean', 'median'],
                        var_name='aggregate_type',
                        value_name='titer')

    # Single cohorts are stored as a str, combined cohorts as a tuple of names
    pool_str = pooledSerum_to_aggreagate_dict[pool]
    title = ' + '.join(pool_str) if isinstance(pool_str, tuple) else pool_str

    titer_plots.append(get_titer_plot(melt_data, title))
    # titer_range sets the extent of the dashed x = y line
    corr_plots.append(get_corr_plot(data, title, titer_range=[30, 3000]))

# Show the correlation plots in a single column with shared axes
(alt.concat(*corr_plots, title='', columns=1, padding=20)
 .resolve_scale(y='shared', x='shared'))

## Make tidy plots, showing only medians for a subset of groups
These will be for the paper, and are a subset of what has been plotted above. 

In [15]:
def get_corr_plot_tidy(data, title, titer_range = (30, 20000), palette=palette):
    """Publication-style scatter of pooled-serum titer against median individual titer.

    Only the median panel is drawn, on log-log axes. It has larger fonts and
    markers, a dashed x = y line and the R² of a linear fit between the two.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'pooled_serum', 'median' and 'serum' columns.
    title : str
        Title shown above the panel.
    titer_range : sequence of two numbers, optional
        End points of the dashed x = y reference line.
    palette : list of str, optional
        Color range for the 'serum' color encoding.

    Returns
    -------
    altair.ConcatChart
    """
    width = 260
    r2_offset = 120  # pixel offset (dx, dy) of the R² label
    title_size = 19
    label_size = 19

    axis_kwargs = dict(grid=False, titleFontSize=title_size, labelFontSize=label_size)
    log_scale = alt.Scale(nice=False, padding=6, type="log")
    color = alt.Color('serum', legend=None, sort=['Children']).scale(range=palette)

    scatter = (
        alt.Chart(data, width=width, height=width)
        .mark_circle(size=140, stroke='black', strokeWidth=2.4, filled=True, opacity=0.6)
        .encode(
            alt.Y('pooled_serum:Q', title='pool neutralization titer', scale=log_scale,
                  axis=alt.Axis(**axis_kwargs)),
            alt.X('median:Q', title='median neutralization titer', scale=log_scale,
                  # titleY moves the x-axis title away from the tick labels
                  axis=alt.Axis(**axis_kwargs, titleY=35)),
            color=color))

    # Dashed x = y reference line
    identity = pd.DataFrame({'x': list(titer_range), 'y': list(titer_range)})
    identity_line = (alt.Chart(identity)
                     .mark_line(color='black', strokeDash=[8, 8])
                     .encode(x='x', y='y'))

    # In-sample R² of a linear fit between pooled and median titers. It is kept
    # as a float so Vega's format() rounds it, rather than slicing its string
    # form, which truncates and breaks on scientific notation.
    predictor, response = data[['pooled_serum']], data[['median']]
    model = LinearRegression().fit(predictor, response)
    r_squared = float(model.score(predictor, response))

    r2_label = (alt.Chart(pd.DataFrame({'r2': [r_squared]}))
                .mark_text(baseline="bottom", align="right", dx=r2_offset, dy=r2_offset,
                           color="black", fontSize=label_size)
                .transform_calculate(r2_string='"R² = " + format(datum.r2, ".2f")')
                .encode(text=alt.Text('r2_string:N')))

    layered = scatter + identity_line + r2_label

    return (alt.concat(layered, columns=3, spacing=10)
            .properties(title=alt.TitleParams(text=title,
                                              fontSize=title_size,
                                              align='center', anchor='middle',
                                              dx=20)))

In [37]:
def get_titer_plot_tidy(data, title, pool_position, final_pool):
    """Publication-style per-virus titer plot for one pool.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format table with 'virus', 'titer' and 'aggregate_type' columns.
        Only rows whose aggregate_type is 'pooled_serum', 'median', 'gmean'
        or 'amean' are drawn.
    title : str
        Chart title.
    pool_position : int
        Position of this panel in the stacked figure.
    final_pool : int
        Position of the last panel. Only that panel shows x-axis (virus)
        labels; the others hide them with a font size of 0.

    Returns
    -------
    altair.LayerChart
    """
    title_size = label_size = 19
    x_font = title_size if pool_position == final_pool else 0

    x_axis = alt.Axis(grid=False, titleFontSize=x_font, labelFontSize=x_font,
                      title=None, titleY=330, labelLimit=1000, labelAlign='right')
    y_axis = alt.Axis(grid=False, titleFontSize=title_size, labelFontSize=label_size,
                      title="Neutralization titer")
    colors = alt.Color('aggregate_type', sort=['Children']).scale(
        range=['black', palette[2], palette[0], palette[1]])

    keep = data['aggregate_type'].isin(['pooled_serum', 'median', 'gmean', 'amean'])
    data = data[keep]

    lines = (alt.Chart(data, width=1100, height=200)
             .mark_line(size=3.4, point=False, opacity=0.7)
             .encode(
                 alt.X('virus', axis=x_axis, sort=viruses),
                 alt.Y('titer', scale=alt.Scale(type='log', domain=[40, 1000], nice=False),
                       axis=y_axis),
                 color=colors,
             ))

    markers = (alt.Chart(data)
               .mark_point(size=120, stroke='black', strokeWidth=2,
                           filled=True, opacity=0.7)
               .encode(
                   alt.X('virus', sort=viruses),
                   alt.Y('titer'),
                   color=colors,
               ))

    return (alt.layer(lines, markers)
            .properties(title=alt.TitleParams(text=title, fontSize=24, anchor='middle')))

In [38]:
# Titles (built from pooledSerum_to_aggreagate_dict) of the pools to show in the
# paper figures: children, adults pre-vaccination, and children + adults pre-vaccination
subset_pools = ['Children', 'PennPreVax', 'Children + PennPreVax']

In [39]:
# Paper figures: build correlation and titer plots only for pools whose title
# is in subset_pools, using display titles that include cohort sizes.
# NOTE: the n values are hard-coded here, not derived from the data.
pretty_titles = {
    'Children': 'Children (n=56)',
    'PennPreVax': 'Adults pre-vaccination (n=39)',
    'Children + PennPreVax': 'Children and adults (n=95)',
}

titer_plots = []
corr_plots = []

# 1-based panel position; get_titer_plot_tidy shows x labels only when it equals final_pool
i = 1
final_pool = len(subset_pools)

for pool in pools:
    # Single cohorts are stored as a str, combined cohorts as a tuple of names
    pool_str = pooledSerum_to_aggreagate_dict[pool]
    title = ' + '.join(pool_str) if isinstance(pool_str, tuple) else pool_str
    if title not in subset_pools:
        continue
    # Fall back to the raw title for a subset pool without a display title
    pretty_title = pretty_titles.get(title, title)

    # Pooled-serum titers (no vaccine strains) for this pool only
    pool_df = pooled_titers_no_vax[pooled_titers_no_vax['serum'] == pool]
    # Per-virus aggregates of the individuals that make up this pool
    individual_aggregate_titers_df = pd.DataFrame(
        pooled_mean_median_dict[pooledSerum_to_aggreagate_dict[pool]],
        columns=['virus', 'amean', 'gmean', 'median'])
    data = (pool_df
            .merge(individual_aggregate_titers_df, on='virus')
            .rename(columns={'titer': 'pooled_serum'}))  # clearer name for the pool titer

    # Long format: one row per (virus, aggregate type) for the titer plot
    melt_data = pd.melt(data, id_vars=['group', 'serum', 'virus'],
                        value_vars=['pooled_serum', 'amean', 'gmean', 'median'],
                        var_name='aggregate_type',
                        value_name='titer')

    # titer_range sets the extent of the dashed x = y line
    corr_plots.append(get_corr_plot_tidy(data, pretty_title, titer_range=[60, 1000]))
    titer_plots.append(get_titer_plot_tidy(melt_data, pretty_title,
                                           pool_position=i, final_pool=final_pool))
    i += 1


In [40]:
# Lay the subset correlation plots side by side with shared axes,
# save them as a PDF, and display them
side_by_side = alt.concat(*corr_plots, title='', columns=3, spacing=140)
pretty_corr_plot = side_by_side.resolve_scale(y='shared', x='shared')

outfile = os.path.join(resultsdir, 'corr_plots.pdf')
pretty_corr_plot.save(outfile, dpi=300)
pretty_corr_plot

In [20]:
# Stack the subset titer plots vertically with shared axes, style the legend,
# save as a PNG, and display
legend_style = dict(titleFontSize=16, labelFontSize=16, strokeColor='gray',
                    padding=5, cornerRadius=10, symbolStrokeWidth=3)
stacked = (alt.concat(*titer_plots, title='', columns=1)
           .resolve_scale(y='shared', x='shared'))
pretty_titer_plot = stacked.configure_legend(**legend_style)

outfile = os.path.join(resultsdir, 'median_adults_and_children_vs_pool.png')
pretty_titer_plot.save(outfile, dpi=600)
pretty_titer_plot

Get the correlation plot for post-vaccination adults only (for the supplement).

In [21]:
# Inspect the pooled titers with vaccine strains (except A/Massachusetts/18/2022) removed
pooled_titers_no_vax

Unnamed: 0,group,serum,virus,titer,titer_bound,titer_sem,n_replicates,titer_as
0,PooledSera,PennPostVax_pool,A/AbuDhabi/6753/2023,559.1,interpolated,52.98,3,midpoint
1,PooledSera,PennPostVax_pool,A/Bangkok/P3599/2023,599.0,interpolated,56.92,3,midpoint
2,PooledSera,PennPostVax_pool,A/Bangkok/P3755/2023,500.3,interpolated,20.39,3,midpoint
3,PooledSera,PennPostVax_pool,A/Bhutan/0006/2023,599.3,interpolated,53.84,3,midpoint
4,PooledSera,PennPostVax_pool,A/Bhutan/0845/2023,644.3,interpolated,123.10,3,midpoint
...,...,...,...,...,...,...,...,...
538,PooledSera,SCH_pool,A/TECPAN/017FLU/2023,531.4,interpolated,41.76,3,midpoint
542,PooledSera,SCH_pool,A/Townsville/68/2023,608.0,interpolated,13.34,3,midpoint
543,PooledSera,SCH_pool,A/Victoria/1033/2023,484.9,interpolated,45.42,3,midpoint
544,PooledSera,SCH_pool,A/Wisconsin/27/2023,542.9,interpolated,48.85,3,midpoint


In [45]:
# Supplementary figure: correlation plot for the post-vaccination adult pool only
pooled_titers_no_vax = pooled_titers[~pooled_titers['virus'].isin(vaccine_strains_no_Massachusetts)]
subset_pools = ['PennPostVax']

# Initialize empty lists for correlation plots
corr_plots = []

subset = []
subset_corr_plots = []

for pool in subset_pools:
    # Pooled-serum names carry a '_pool' suffix
    pool_name = f'{pool}_pool'
    pool_df = pooled_titers_no_vax[pooled_titers_no_vax['serum'] == pool_name]
    # Per-virus aggregates of the individuals that make up this pool
    individual_aggregate_titers_df = pd.DataFrame(
        pooled_mean_median_dict[pooledSerum_to_aggreagate_dict[pool_name]],
        columns=['virus', 'amean', 'gmean', 'median'])
    data = (pool_df
            .merge(individual_aggregate_titers_df, on='virus')
            .rename(columns={'titer': 'pooled_serum'}))  # clearer name for the pool titer

    pretty_title = 'Adults post-vaccination (n=39)'
    # titer_range sets the extent of the dashed x = y line
    corr_plots.append(get_corr_plot_tidy(data, pretty_title, titer_range=[60, 1000],
                                         palette=['#515A47', '#03cea4']))

# Combine, save as PDF, and display
pretty_corr_plot = (alt.concat(*corr_plots, title='', columns=3, spacing=140)
                    .resolve_scale(y='shared', x='shared'))

outfile = os.path.join(resultsdir, 'post_vax_corr_plot.pdf')
pretty_corr_plot.save(outfile, dpi=300)
pretty_corr_plot

## Compare pooled sera to individually plotted titers

In [23]:
# Individual titers with vaccine strains (except A/Massachusetts/18/2022) removed
SCH_titers_no_vax = SCH_titers.loc[~SCH_titers['virus'].isin(vaccine_strains_no_Massachusetts)]

penn_keep = ~titers_PennVaccineCohort['virus'].isin(vaccine_strains_no_Massachusetts)
# NOTE(review): 'day' is overwritten here with the bare text after the first
# 'd' (e.g. '0'), whereas titers_PennVaccineCohort stores it with a 'd' prefix
# (e.g. 'd0') -- confirm which form the code below expects.
titers_PennVaccineCohort_no_vax = titers_PennVaccineCohort[penn_keep].assign(
    day=lambda df: df['serum'].str.split('d').str[1])

print('Here are pools...')
pooled_titers_no_vax.serum.unique()

Here are pools...


array(['PennPostVax_pool', 'PennPreVax_PennPostVax_pool',
       'PennPreVax_pool', 'SCHPennPrePost_pool', 'SCH_PennPostVax_pool',
       'SCH_PennPreVax_pool', 'SCH_pool'], dtype=object)

In [24]:
# Plot every individual's titers, the median of the individuals, and the
# titers of the matching serum pool (solid black line) on one set of axes.
def plot_individuals_and_pool(individual_titers_df, pooled_titers_df, title='', position = 1):
    """Layer individual titers, their per-virus median, and pooled-serum titers.

    Layers drawn, per virus on the x axis (ordered by the module-level
    ``viruses`` list), with titers on a log-scaled y axis:
      * one faint line per individual serum (grouped with ``detail='serum'``),
      * the median titer across individuals as filled, black-outlined points,
      * the pooled-serum titers as a solid black line,
      * the pooled-serum titers as filled black points.

    Parameters
    ----------
    individual_titers_df : pandas.DataFrame
        Individual titers; must contain ``virus``, ``titer`` and ``serum`` columns.
    pooled_titers_df : pandas.DataFrame
        Pooled-serum titers; must contain ``virus`` and ``titer`` columns
        (presumably rows for a single pool -- TODO confirm at call sites).
    title : str, default ''
        Chart title.
    position : int, default 1
        Index into the module-level ``palette`` for the individual lines and
        median points. When 0, the virus (x-axis) labels are hidden by
        setting their font size to 0.

    Returns
    -------
    alt.LayerChart
        The four layers combined with ``alt.layer``.
    """

    # Configure plot
    titer_range = [30, 16000]
    titleFontSize=19
    labelFontSize=19
    color=palette[position]
    lineOpacity = 0.15
    lineSize = 2.8
    markerOpacity = 0.9
    markerSize = 160
    width = 1300
    height = 260

    # Hide the virus (x-axis) labels for position 0 by shrinking their font to 0
    if position == 0:
        xlabelFontSize=0
    else:
        xlabelFontSize=19

    # Plot all individual lines (one faint line per serum)
    line = (alt.Chart(individual_titers_df, width = width,height=height, title = title)
            .mark_line(size = lineSize, point = False, opacity = lineOpacity, color=color)
            .encode(
                alt.X('virus',
                          axis = alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=xlabelFontSize,
                                          title = None,labelLimit = 1000, labelAlign = 'right',
                                         ),
                          sort = viruses
                         ),
                alt.Y('titer',
                          scale =alt.Scale(type='log',domain=titer_range, nice=False),
                          axis=alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=labelFontSize, title="neutralization titer")
                     ),
                detail = 'serum',
            )
           )

    # Plot median of individuals as points
    points = (alt.Chart(individual_titers_df, width=width,)
              .mark_point(size = markerSize,
                          color=color,
                          stroke = 'black',
                          strokeWidth = 2.2,
                          filled=True,
                          opacity=markerOpacity)
              .encode(
                  alt.X('virus', sort = viruses),
                  alt.Y('median(titer)'),
              )
             )

    # Plot pool as a solid black line. The x axis uses the same label font size
    # (`xlabelFontSize`) as the individual-line layer so the layered x axes do
    # not declare conflicting label sizes (and position 0 really hides labels).
    pool_line = (alt.Chart(pooled_titers_df, width = width,height=height)
            .mark_line(size = 3, point = False, opacity = 0.8, color='black')
            .encode(
                alt.X('virus',
                          axis = alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=xlabelFontSize,
                                          title = None,labelLimit = 1000, labelAlign = 'right',
                                         ),
                          sort = viruses
                         ),
                alt.Y('titer',
                          scale =alt.Scale(type='log',domain=titer_range, nice=False),
                          axis=alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=labelFontSize, title="")
                     )))

    # Plot pool points; median(titer) collapses any duplicate rows per virus
    pool_points = (alt.Chart(pooled_titers_df, width=width,)
              .mark_point(size = markerSize,
                          color='black',
                          stroke = 'black',
                          strokeWidth = 2.2,
                          filled=True,
                          opacity=markerOpacity)
              .encode(
                  alt.X('virus', sort = viruses),
                  alt.Y('median(titer)'),
              )
             )

    layered = (alt.layer(line, points, pool_line, pool_points))

    return layered
    

In [25]:
# One panel per cohort: individual titers (faint lines + median points)
# overlaid with the pool made from the same sera.
# NOTE(review): the individual frame here is SCH-only, while the pool is
# SCH + Penn pre-vax and the title says "Children and adults (n=95)" --
# confirm the individual titers are the intended ones.
plots = [
    plot_individuals_and_pool(
        SCH_titers_no_vax,
        pooled_titers_no_vax.query('serum == "SCH_PennPreVax_pool"'),
        'Children and adults (n=95)',
        2,
    ),
    # Penn pre-vax panel (currently disabled):
    # plot_individuals_and_pool(titers_PennVaccineCohort_no_vax.query('day == "0"'), pooled_titers_no_vax.query('serum == "PennPreVax_pool"'), 'Adults (n=39)', 1),
]

# Font sizes applied to the chart title and facet headers
titleFontSize = 19
labelFontSize = 19

# Stack the panels in a single column with shared x/y scales
plot = (
    alt.concat(*plots, columns=1)
    .configure_title(fontSize=titleFontSize)
    .configure_header(
        labelFontSize=labelFontSize,
        labelFontWeight='bold',
        labelOrient='top',
    )
    .resolve_scale(y='shared', x='shared')
)

# Write a high-resolution PDF to the results directory and display inline
outfile = os.path.join(resultsdir, 'cohort_individuals_and_medians_vs_pool.pdf')
plot.save(outfile, dpi=600)
plot