In [None]:
import pandas as pd
from scipy.stats import ttest_ind
from long_covid.colors import flatuicolors
from long_covid import styling
from matplotlib import pyplot as plt
import datetime
import numpy as np

In [None]:
# Global variables for styling
# COLORS, ALPHA_VALUES and MARKERS are zipped with the cohort list in
# plot_average_trajectories, so entry i styles the i-th plotted cohort.
COLORS = [flatuicolors.pomegranate, flatuicolors.amethyst, flatuicolors.belizehole]
ALPHA_VALUES = [1., 1., 0.4]
MARKERS = ['o', 'd', 'p']
# One vital id per subplot row (top to bottom); zipped with the axes, Y_LABELS
# and SIGNIFICANCE_MARKER_POSITIONS.
# NOTE(review): judging by the order of Y_LABELS, presumably 65 = resting heart
# rate, 9 = steps, 43 = sleep duration -- confirm against the data dictionary.
VITALS_IDS = [65, 9, 43]
# Per vital: (y0, delta). A significance marker of row r is drawn at y0 - delta * r.
SIGNIFICANCE_MARKER_POSITIONS = [(2.5, 0.3), (880, 440), (70, 8)]
# Permutation applied to the handles/labels returned by
# get_legend_handles_labels(); indexing with it requires at least four legend entries.
ITEM_ORDER_IN_LEGEND = [2,3,0,1]

# Axis labels (Y_LABELS has one entry per subplot row) and the legend label of
# the grey bar marking the test week.
X_LABEL = 'Weeks since PCR-test'
Y_LABELS = ['Change in RHR [bpm]', 'Change in daily steps', 'Change in daily sleep\nduration [minutes]']
BAR_LABEL = 'Week of PCR-test'
    
# Globally accessible input data
# Paths are relative to the notebook's working directory.
COHORTS = pd.read_feather('../data/03_derived/user_cohorts.feather')
WEEKLY_DEVIATIONS = pd.read_feather('../data/03_derived/weekly_vital_deviations_per_user.feather')

# Attach the cohort columns to every weekly deviation row. pd.merge defaults to
# an inner join, so users missing from either table are dropped; the redundant
# right-hand key column 'user_id' is removed afterwards.
DATA = pd.merge(WEEKLY_DEVIATIONS, COHORTS, left_on='userid', right_on='user_id').drop(columns='user_id') 

In [None]:
def significance_test(combinations, level, two_sided=False, data=None):
    """Test cohort pairs for significant differences, per vital and per week.

    For every vital id in (65, 9, 43) and every week from -3 to 12 (inclusive),
    a Welch t-test (``ttest_ind`` with ``equal_var=False``) compares the
    ``vital_change`` values of the two cohorts of each pair.

    Parameters
    ----------
    combinations : iterable of (str, str)
        Pairs of boolean cohort column names ``(cohort1, cohort2)``.
    level : float
        Significance level; a pair counts as significant when p < level.
    two_sided : bool, default False
        If True, use a two-sided alternative. Otherwise a one-sided test is
        used: ``'less'`` for vital id 9 and ``'greater'`` for all other vitals
        (i.e. cohort1 is tested for being smaller/larger than cohort2).
    data : pandas.DataFrame, optional
        Frame with columns ``vitalid``, ``weeks_since_test``, ``vital_change``
        and the cohort columns. Defaults to the module-level ``DATA``.

    Returns
    -------
    pandas.DataFrame
        Boolean frame indexed by ``(vitalid, weeks_since_test)`` with one
        column ``'<cohort1>_<cohort2>'`` per pair. A NaN p-value compares as
        not significant (False).
    """
    if data is None:
        data = DATA

    # Collect one row of p-values per (vital, week) and build the frame once,
    # instead of enlarging a DataFrame cell by cell.
    pvalues = {}

    for vitalid in (65, 9, 43):

        if two_sided:
            side = 'two-sided'
        elif vitalid == 9:
            side = 'less'
        else:
            side = 'greater'

        for week in range(-3, 13):

            # The (vital, week) subset does not depend on the cohort pair,
            # so filter only once per week.
            subset = data[(data.vitalid == vitalid) & (data.weeks_since_test == week)]

            row = {}
            for cohort1, cohort2 in combinations:
                a = subset[subset[cohort1]].vital_change.values
                b = subset[subset[cohort2]].vital_change.values
                # Element [1] of the test result is the p-value.
                row[cohort1 + '_' + cohort2] = ttest_ind(a, b, equal_var=False, alternative=side)[1]

            pvalues[(vitalid, week)] = row

    index = pd.MultiIndex.from_tuples(list(pvalues), names=('vitalid', 'weeks_since_test'))
    sig = pd.DataFrame(list(pvalues.values()), index=index)

    return sig < level


def aggregate(df, cohort_key, limit, first_week=-3):
    """Summarise ``vital_change`` per vital and week for one cohort.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with columns ``vitalid``, ``weeks_since_test``, ``vital_change``
        and a boolean column named ``cohort_key``.
    cohort_key : str
        Name of the boolean cohort column; only rows where it is True are kept.
    limit : int
        Last week (inclusive, label-based) to keep.
    first_week : int, default -3
        First week (inclusive, label-based) to keep.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``(vitalid, weeks_since_test)`` with MultiIndex columns
        ``(cohort_key, 'mean' | 'std' | 'count' | 'err')``, where ``err`` is
        ``std / sqrt(count)`` (standard error of the mean).
    """
    stats = df.groupby(['vitalid', 'weeks_since_test', cohort_key])[['vital_change']].agg(['mean', 'std', 'count'])
    # Put the cohort name on the top column level so several cohorts can be
    # merged side by side later.
    stats = stats.rename(columns={'vital_change': cohort_key})
    stats[cohort_key, 'err'] = stats[cohort_key, 'std'] / np.sqrt(stats[cohort_key, 'count'])

    # Spell out the row key and the column key separately: a bare 3-tuple in
    # .loc on a 2-D frame is ambiguous and its handling varies across pandas
    # versions.
    rows = (slice(None), slice(first_week, limit), True)
    return stats.loc[rows, :].droplevel(cohort_key)
    

def compute_aggregates(df, cohort_keys, limits):
    """Aggregate every cohort and join the results side by side.

    The first cohort is aggregated on its own; each further cohort is
    aggregated with its matching week limit and outer-merged on
    ``('vitalid', 'weeks_since_test')``, so weeks present for only some
    cohorts are kept.
    """
    join_keys = ['vitalid', 'weeks_since_test']
    first_key, first_limit = cohort_keys[0], limits[0]

    merged = aggregate(df, first_key, first_limit)
    remaining = zip(cohort_keys[1:], limits[1:])
    for key, week_limit in remaining:
        per_cohort = aggregate(df, key, week_limit)
        merged = pd.merge(merged, per_cohort, on=join_keys, how='outer')

    return merged
    

def plot_average_trajectories(aggregate_data, significance, cohorts, significance_info, labels, outfile, show_errors=(True, True, False)):
    """Plot weekly average vital changes per cohort and save the figure.

    Creates a 3x1 figure (one row per entry of ``VITALS_IDS``), draws one
    trajectory per cohort, marks weeks with significant differences with
    stars, shades the test week, and writes the result to ``outfile``.

    Parameters
    ----------
    aggregate_data : pandas.DataFrame
        Indexed by ``(vitalid, weeks_since_test)`` with columns
        ``(cohort, 'mean')`` and ``(cohort, 'err')`` for every cohort.
    significance : pandas.DataFrame
        Boolean frame indexed by ``(vitalid, weeks_since_test)`` with one
        column per key listed in ``significance_info``.
    cohorts : sequence of str
        Cohort column names to plot; styled by position via ``COLORS``,
        ``ALPHA_VALUES`` and ``MARKERS``.
    significance_info : iterable of (key, edge_color, fill_color, row)
        For each significance column: marker edge colour, marker fill colour
        and the marker row (vertical offset multiplier).
    labels : sequence of str
        Legend label per cohort.
    outfile : str
        Path the figure is saved to (dpi=400).
    show_errors : sequence of bool, default (True, True, False)
        Per cohort: draw error bars (True) or a plain line (False).
    """
    fig, axarr = plt.subplots(3, 1, sharex=True, figsize=(4, 7))

    # Plot average trajectories
    plot_info = zip(cohorts, labels, COLORS, ALPHA_VALUES, MARKERS, show_errors)
    for cohort, label, color, alpha, marker, with_errorbars in plot_info:
        for ax, vital in zip(axarr, VITALS_IDS):
            # reset_index() turns the week index into the first column, giving (x, y, yerr).
            x, y, yerr = aggregate_data.loc[vital][cohort][['mean', 'err']].reset_index().values.T
            if with_errorbars:
                ax.errorbar(x, y, yerr, lw=1.5, ls='-', label=label, marker=marker, c=color, alpha=alpha)
            else:
                ax.plot(x, y, label=label, marker=marker, ls='-', c=color, alpha=alpha)

    # Set markers for significant differences
    for ax, vital, marker_positions in zip(axarr, VITALS_IDS, SIGNIFICANCE_MARKER_POSITIONS):
        y0, delta = marker_positions
        for key, edge_color, fill_color, row in significance_info:
            per_week = significance.loc[vital]
            significant_weeks = per_week[per_week[key]].index.values
            y = [y0 - delta * row] * len(significant_weeks)
            ax.scatter(significant_weeks, y, marker="*", edgecolor=edge_color, facecolor=fill_color, s=90)

    # Add grey bar to indicate week of the test
    for ax in axarr:
        y0, y1 = ax.get_ylim()
        ax.fill_between(x=[-0.5, 0.5], y1=y0, y2=y1, label=BAR_LABEL, color='k', alpha=0.115, zorder=-10)
        ax.fill_between(x=[-1, 1], y1=y0, y2=y1, color='k', alpha=0.075, zorder=-10)
        # Restore the limits so the shading does not rescale the axis.
        ax.set_ylim(y0, y1)

    # Axis styling
    for ax, y_label in zip(axarr, Y_LABELS):
        styling.hide_and_move_axis(ax)
        ax.set_ylabel(y_label, size=12)

    # Add labels for subfigures
    for ax, panel in zip(axarr, ('A', 'B', 'C')):
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        if panel != 'B':
            ax.text(x0 + 0.88 * (x1 - x0), y0 + 0.8 * (y1 - y0), panel, size=24)
        else:
            # Panel B carries its letter in the lower-left corner instead.
            ax.text(x0 + 0.03 * (x1 - x0), y0 + 0.1 * (y1 - y0), panel, size=24)

    # Style x-axis
    axarr[-1].set_xticks(range(-2, 16, 2))
    axarr[-1].set_xlabel(X_LABEL, size=12)

    # Add legend: entries are read from the bottom axes and drawn on the
    # middle one. ITEM_ORDER_IN_LEGEND indexes into the order in which
    # matplotlib reports the handles, so it must match the number of entries.
    handles, legend_labels = axarr[-1].get_legend_handles_labels()
    handles = [handles[idx] for idx in ITEM_ORDER_IN_LEGEND]
    legend_labels = [legend_labels[idx] for idx in ITEM_ORDER_IN_LEGEND]
    axarr[1].legend(handles, legend_labels, fontsize=9, frameon=False)

    # Finalize
    fig.tight_layout()
    fig.savefig(outfile, dpi=400)

In [None]:
def figure2():
    """Create figure 2: unvaccinated vs. vaccinated vs. negative cohorts.

    Aggregates the weekly vital changes per cohort, runs one-sided
    significance tests at level 0.01 for all three cohort pairs and saves
    the plot to ``../output/figure2_weekly_changes.jpg``.
    """

    # General settings for cohort
    cohorts = ['unvaccinated', 'vaccinated', 'negative']
    labels = [key.capitalize() for key in cohorts]
    max_weeks = [15, 12, 15]

    # Compute corresponding aggregations
    aggregations = compute_aggregates(DATA, cohorts, max_weeks)

    # Significance tests
    significance_level = 0.01
    two_sided_test = False

    combinations = (('unvaccinated', 'negative'), ('unvaccinated', 'vaccinated'), ('vaccinated', 'negative'))
    significance = significance_test(combinations, significance_level, two_sided=two_sided_test)

    # (significance column, marker edge colour, marker fill colour, marker row)
    significance_info = [
        ('unvaccinated_negative', flatuicolors.pomegranate, 'w', 0),
        ('unvaccinated_vaccinated', flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
        ('vaccinated_negative', flatuicolors.amethyst, flatuicolors.amethyst, 1),
    ]

    # Path to output file (the former .format(timestamp) call had no
    # placeholder to fill and therefore never changed the name)
    outfile = '../output/figure2_weekly_changes.jpg'

    plot_average_trajectories(aggregations, significance, cohorts, significance_info, labels, outfile)
    
    
def si_figure2():
    """Create SI figure 2: unvaccinated vs. vaccinated (Delta) vs. negative.

    Same layout as ``figure2`` but with the ``vaccinated_delta`` cohort and a
    significance level of 0.05; saves the plot to
    ``../output/si_figure2_only_delta_breakthrough_infections_with_delta.jpg``.
    """

    # General settings for cohort
    cohorts = ['unvaccinated', 'vaccinated_delta', 'negative']
    labels = ['Unvaccinated', 'Vaccinated/Delta', 'Negative']
    max_weeks = [15, 12, 15]

    # Compute corresponding aggregations
    aggregations = compute_aggregates(DATA, cohorts, max_weeks)

    # Significance tests
    # NOTE(review): 0.05 here vs. 0.01 in figure2/si_figure3 -- confirm this is intentional.
    significance_level = 0.05
    two_sided_test = False

    combinations = (('unvaccinated', 'negative'), ('unvaccinated', 'vaccinated_delta'), ('vaccinated_delta', 'negative'))
    significance = significance_test(combinations, significance_level, two_sided=two_sided_test)

    # (significance column, marker edge colour, marker fill colour, marker row)
    significance_info = [
        ('unvaccinated_negative', flatuicolors.pomegranate, 'w', 0),
        ('unvaccinated_vaccinated_delta', flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
        ('vaccinated_delta_negative', flatuicolors.amethyst, flatuicolors.amethyst, 1),
    ]

    # Path to output file (the former .format(timestamp) call had no
    # placeholder to fill and therefore never changed the name)
    outfile = '../output/si_figure2_only_delta_breakthrough_infections_with_delta.jpg'

    plot_average_trajectories(aggregations, significance, cohorts, significance_info, labels, outfile)
    
    
def si_figure3():
    """Create SI figure 3: vaccinated Delta vs. vaccinated Omicron infections.

    Plots both vaccinated cohorts together with the negative cohort, runs a
    two-sided significance test at level 0.01 between Delta and Omicron only,
    and saves the plot to ``../output/si_figure3_delta_vs_omicron.jpg``.
    """

    # General settings for cohort
    cohorts = ['vaccinated_delta', 'vaccinated_omicron', 'negative']
    labels = ['Delta-Infection', 'Omicron-Infection', 'Negative']
    max_weeks = [12, 12, 15]

    # Compute corresponding aggregations
    aggregations = compute_aggregates(DATA, cohorts, max_weeks)

    # Significance tests
    significance_level = 0.01
    two_sided_test = True

    combinations = (('vaccinated_delta', 'vaccinated_omicron'),)
    significance = significance_test(combinations, significance_level, two_sided=two_sided_test)

    # (significance column, marker edge colour, marker fill colour, marker row)
    significance_info = [
        ('vaccinated_delta_vaccinated_omicron', flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
    ]

    # Path to output file (the former .format(timestamp) call had no
    # placeholder to fill and therefore never changed the name)
    outfile = '../output/si_figure3_delta_vs_omicron.jpg'

    plot_average_trajectories(aggregations, significance, cohorts, significance_info, labels, outfile)

In [None]:
# Render all three figures; each call saves one JPG below ../output/.
figure2()
si_figure2()
si_figure3()