## Plot fraction of individuals with low neutralization titers by strain

In [2]:
# Import packages
import os
import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import gmean


# Altair refuses to embed dataframes larger than its default row limit; lift
# that limit because the combined titer table is large. The returned value is
# not needed and is discarded into `_`.
_ = alt.data_transformers.disable_max_rows()

# Basic color palette shared by the plots below. Order matters: plotting code
# slices this list by position (e.g. `color_palette[4:]`).
# White ('#EDEDF4') is intentionally left out.
_palette_by_name = {
    'blue': '#345995',
    'teal': '#03cea4',
    'red': '#ca1551',
    'yellow': '#eac435',
    'rose quartz': '#A499B3',
    'ebony': '#515A47',
}
color_palette = list(_palette_by_name.values())

In [3]:
# Data and results locations, relative to the notebook's working directory.
datadir, resultsdir = '../data', '../results'

# Create both directories if they are missing; existing ones are left alone.
for _directory in (datadir, resultsdir):
    os.makedirs(_directory, exist_ok = True)

# Aggregated neutralization titers, one table per cohort. Each cohort encodes
# the sample barcode (and, for the vaccine cohorts, the timepoint) inside the
# `serum` name, so those fields are parsed out into their own columns here.

# SCH cohort: barcode is the third '_'-separated field of the serum name.
SCH_titers = pd.read_csv('../../../results/aggregated_titers/titers_SCH.csv')
SCH_titers = SCH_titers.assign(barcode=SCH_titers['serum'].str.split('_').str[2])

# Penn vaccine cohort: barcode is the third '_'-separated field; timepoint is
# 'd' followed by the text between the first and second 'd' of the serum name.
# NOTE(review): this assumes the first 'd' in the serum name starts the
# timepoint suffix (presumably d0 / d28) -- confirm no earlier field has a 'd'.
titers_PennVaccineCohort = pd.read_csv('../../../results/aggregated_titers/titers_PennVaccineCohort.csv')
_penn_serum = titers_PennVaccineCohort['serum']
titers_PennVaccineCohort = titers_PennVaccineCohort.assign(
    barcode=_penn_serum.str.split('_').str[2],
    timepoint='d' + _penn_serum.str.split('d').str[1],
)

# Australian MA22 vaccine cohort: barcode and timepoint are the second and
# third '_'-separated fields of the serum name.
Australia_MA22_titers = pd.read_csv('../../../results/aggregated_titers/titers_AusVaccineCohort.csv')
_aus_serum_fields = Australia_MA22_titers['serum'].str.split('_')
Australia_MA22_titers = Australia_MA22_titers.assign(
    barcode=_aus_serum_fields.str[1],
    timepoint=_aus_serum_fields.str[2],
)

# Drop the parsing intermediates so they do not linger in the kernel namespace.
del _penn_serum, _aus_serum_fields

# Sera metadata for each cohort, keyed by the `serum` column that the titer
# tables also use.
SCH_metadata = pd.read_csv('../../../data/sera_metadata/metadata_SCH.csv')
AusVaccineCohort_metadata = (pd.read_csv('../../../data/sera_metadata/metadata_AusVaccineCohort.csv')
                             .rename(columns = {'Bloom_lab_serum_ID': 'serum'}))

# The Penn metadata stores the day-0 and day-28 serum IDs side by side in
# `Bloom_lab_ID_d0` / `Bloom_lab_ID_d28`. Reshape it to one row per serum by
# stacking a day-0 copy and a day-28 copy, each with its ID column renamed to
# `serum`. The CSV is read once and reused for both copies (`drop`/`rename`
# return new frames, so the raw table is not modified).
penn_metadata_raw = pd.read_csv('../../../data/sera_metadata/metadata_PennVaccineCohort.csv')
day0 = (penn_metadata_raw
        .drop(['Bloom_lab_ID_d28'], axis=1)
        .rename(columns = {'Bloom_lab_ID_d0': 'serum'}))
day28 = (penn_metadata_raw
         .drop(['Bloom_lab_ID_d0'], axis=1)
         .rename(columns = {'Bloom_lab_ID_d28': 'serum'}))
PennVaccineCohort_metadata = pd.concat([day0, day28])

# Stack every cohort's metadata into a single table
all_metadata = pd.concat([SCH_metadata, PennVaccineCohort_metadata, AusVaccineCohort_metadata])


# Stack every cohort's titers into a single long table
_titer_tables = [SCH_titers, titers_PennVaccineCohort, Australia_MA22_titers]
all_titers = pd.concat(_titer_tables)

# Attach sera metadata. This is an inner join on `serum`, so titers whose
# serum has no metadata row are dropped.
all_titers = all_titers.merge(all_metadata, on = 'serum')

# `group_detail`: the bare group name for SCH sera; for every other group,
# '<group>_<timepoint>' so pre- and post-vaccination sera can be told apart.
_is_sch = all_titers['group'] == 'SCH'
_group_with_timepoint = all_titers['group'].astype(str) + '_' + all_titers['timepoint']
all_titers['group_detail'] = np.where(_is_sch, all_titers['group'], _group_with_timepoint)


In [4]:
# Order in which viruses are placed along the x axis of the plots
viral_plot_order = pd.read_csv('../../../data/H3N2library_2023-2024_strain_order.csv')
virus_order = viral_plot_order['strain'].tolist()


def _read_strain_list(path):
    """Return the strain names listed one per line in the file at `path`.

    Any line containing the substring 'strain' is treated as a header and
    skipped. Surrounding whitespace (including Windows-style line endings) is
    stripped from each line, and blank lines are ignored so they do not end up
    as empty-string entries.
    """
    strains = []
    with open(path) as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line and 'strain' not in line:
                strains.append(line)
    return strains


# Vaccine strains
vaccine_strains = _read_strain_list(os.path.join(datadir, 'vaccine_strains.csv'))

# Egg-passaged vaccine strains
egg_passaged_vaccine_strains = _read_strain_list(os.path.join(datadir, 'egg-passaged_vaccine_strains.csv'))

# Vaccine strains excluding A/Massachusetts/18/2022, so that strain is treated
# as a 2023-circulating strain rather than a vaccine strain
vaccine_strains_no_Massachusetts = [item for item in vaccine_strains if item != 'A/Massachusetts/18/2022']

# NOTE(review): despite its name, this list holds strains dated 2021-2022;
# `vaccine_strains_pre_2020` below is the vaccine strains NOT in this list.
# Names are kept unchanged because other cells may reference them.
pre_2020_strains = [
    'A/Massachusetts/18/2022', 'A/Thailand/8/2022',
    'A/Darwin/6/2021', 'A/Darwin/9/2021',
]

vaccine_strains_pre_2020 = [item for item in vaccine_strains if item not in pre_2020_strains]

## Plots

In [5]:
def plot_titers_vaccination_cohorts(data, sort_order, _range = [30, 16000], title=None,
                                    bold_labels=('A/Massachusetts/18/2022', 'A/Thailand/8/2022',
                                                 'A/Darwin/6/2021', 'A/Darwin/9/2021')):
    """Plot per-virus titer distributions for vaccination cohorts, one facet row per group.

    For each virus on the x axis (ordered by the module-level `virus_order`),
    draws an interquartile-range band of `titer` and a point at the median
    titer, both colored by `timepoint` ('prevax' listed first).

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form titers with 'virus', 'titer', 'timepoint' and 'group' columns.
    sort_order : list of str
        Order of the facet rows (values of the 'group' column).
    _range : list of two numbers
        Domain of the log-scaled titer (y) axis. Not modified.
    title : str, optional
        Chart title. No title is set when None.
    bold_labels : iterable of str
        Virus names whose x-axis labels are drawn in bold; all other labels are
        drawn in normal weight.

    Returns
    -------
    The faceted, configured Altair chart.
    """
    # Shared styling constants
    titleFontSize = 18
    labelFontSize = 18
    markerOpacity = 0.8
    markerSize = 160
    markerStrokeWidth = 2.2
    width = 1100
    height = 200

    # Timepoint colors come from the tail of the shared palette
    color_scheme = alt.Color('timepoint', sort=['prevax']).scale(range=color_palette[4:])

    # Conditional font weight for x-axis labels: bold for `bold_labels`, normal otherwise
    vacc_weights = {
        'condition': [{'test': f'datum.label == "{label}"', 'value': 'bold'}
                      for label in bold_labels],
        'value': 'normal',
    }

    # Interquartile-range band of titers for each virus
    band = (alt.Chart(data, width=width, height=height)
            .mark_errorband(extent='iqr', opacity=0.4)
            .encode(alt.X('virus',
                          axis=alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=labelFontSize,
                                        title=None, labelLimit=1000, labelAlign='right',
                                        labelFontWeight=vacc_weights),
                          sort=virus_order),
                    alt.Y('titer',
                          scale=alt.Scale(type='log', domain=_range, nice=False),
                          axis=alt.Axis(grid=False, titleFontSize=titleFontSize, labelFontSize=labelFontSize,
                                        title="neutralization titer")),
                    color=color_scheme))

    # Median titer for each virus
    points = (alt.Chart(data, width=width)
              .mark_point(size=markerSize, stroke='black', strokeWidth=markerStrokeWidth,
                          filled=True, opacity=markerOpacity)
              .encode(alt.X('virus', sort=virus_order),
                      alt.Y('median(titer)'),
                      color=color_scheme))

    layered = (alt.layer(band, points)
               .facet(row=alt.Row('group:N', title=None, sort=sort_order)))

    # Only attach a title when one was given, so the default title=None is never
    # passed through Altair's property validation.
    if title is not None:
        layered = layered.properties(title=title)

    # Legend settings live solely in configure_legend: it replaces the whole
    # legend config, so a legend config passed to facet() would be discarded anyway.
    return (layered
            .configure_header(labels=False,  # hide facet row labels for publication figures; set True to show them
                              labelFontSize=labelFontSize, labelFontWeight='bold',
                              labelOrient='right')
            .configure_title(align='center', anchor='middle', fontSize=titleFontSize, fontWeight='bold')
            .configure_legend(symbolSize=markerSize, symbolOpacity=markerOpacity,
                              symbolStrokeWidth=markerStrokeWidth, symbolStrokeColor='black',
                              titleFontSize=titleFontSize, labelFontSize=labelFontSize,
                              strokeColor='gray', padding=10, cornerRadius=10,
                              labelLimit=1000))  # let legend labels be as long as they need

In [7]:
# Vaccination cohorts to plot, in facet-row order
cohort_order = ['PennVaccineCohort', 'AusVaccineCohort']

# Keep the vaccination cohorts and the 2023-circulating strains (every vaccine
# strain is dropped except A/Massachusetts/18/2022), then relabel the timepoints.
in_cohorts = all_titers['group'].isin(cohort_order)
is_vaccine_strain = all_titers['virus'].isin(vaccine_strains_no_Massachusetts)
data = (all_titers[in_cohorts & ~is_vaccine_strain]
        .replace({'timepoint': {'d0': 'prevax', 'd28': 'postvax'}}))

plot = plot_titers_vaccination_cohorts(data, sort_order=cohort_order,
                                       _range=[30, 4000], title='2023-circulating strains')

# Save the figure into the results directory, then display it
outfile = os.path.join(resultsdir, 'post_vaccination_2023_titers.pdf')
plot.save(outfile, dpi = 600)
plot