In [None]:
import pandas as pd
from scipy.stats import ttest_ind
from long_covid.colors import flatuicolors
from long_covid import styling
from matplotlib import pyplot as plt
import datetime
import numpy as np

In [None]:
# --- Styling ----------------------------------------------------------------
# One entry per cohort; the three lists are zipped together with the cohorts
# in plot_average_trajectories, so position i styles the i-th cohort.
COLORS = [
    flatuicolors.pomegranate,
    flatuicolors.amethyst,
    flatuicolors.belizehole,
]
ALPHA_VALUES = [1., 1., 0.4]
MARKERS = ['o', 'd', 'p']

# One entry per panel; position i belongs to the i-th subplot (top to bottom).
VITALS_IDS = [65, 9, 43]
Y_LABELS = [
    'Change in RHR [bpm]',
    'Change in daily steps',
    'Change in daily sleep\nduration [minutes]',
]
# (y0, delta) per panel: significance stars are drawn at y0 - delta * row.
SIGNIFICANCE_MARKER_POSITIONS = [
    (2.5, 0.3),
    (880, 440),
    (70, 8),
]

# Indices into the legend handles of the bottom panel, in display order.
ITEM_ORDER_IN_LEGEND = [2, 3, 0, 1]

X_LABEL = 'Weeks since PCR-test'
BAR_LABEL = 'Week of PCR-test'

# --- Input data (read once, shared by all functions below) -------------------
COHORTS = pd.read_feather('../data/03_derived/user_cohorts.feather')
WEEKLY_DEVIATIONS = pd.read_feather('../data/03_derived/weekly_vital_deviations_per_user.feather')

In [None]:
def significance_test(data, group1, group2, level, two_sided=False, weeks=None):
    """Find the weeks in which two user groups differ significantly, per vital.

    For each of the vital ids 65, 9 and 43 and for each requested week, the
    ``vital_change`` values of ``group1`` are compared against those of
    ``group2`` with an unequal-variance t-test
    (``ttest_ind(..., equal_var=False)``).

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns ``vitalid``, ``userid``,
        ``weeks_since_test`` and ``vital_change``.
    group1, group2 : list-like
        User ids of the two groups to compare.
    level : float
        A week is reported when its p-value is strictly below this value.
    two_sided : bool, default False
        If True a two-sided alternative is used. Otherwise the test is
        one-sided: ``group1 < group2`` for vital 9 and ``group1 > group2``
        for the other vitals.
    weeks : iterable of int, optional
        Weeks (values of ``weeks_since_test``) to test. Defaults to
        ``range(-3, 13)``, i.e. weeks -3 to 12 inclusive.

    Returns
    -------
    dict
        Maps each vital id to the list of significant weeks, in the order
        the weeks were tested.
    """
    # Materialise once: the same weeks are walked for every vital.
    weeks = list(range(-3, 13)) if weeks is None else list(weeks)

    significance = {65: [], 9: [], 43: []}

    for vitalid, significant_weeks in significance.items():

        # The test direction depends only on the vital, not on the week.
        if two_sided:
            side = 'two-sided'
        elif vitalid == 9:
            side = 'less'
        else:
            side = 'greater'

        # Resolve group membership once per vital instead of once per week.
        df = data[data.vitalid == vitalid]
        first = df[df.userid.isin(group1)]
        second = df[df.userid.isin(group2)]

        for week in weeks:
            a = first.loc[first.weeks_since_test == week, 'vital_change'].values
            b = second.loc[second.weeks_since_test == week, 'vital_change'].values

            # With fewer than two observations in a group the test yields a
            # NaN p-value (never below `level`) and emits runtime warnings;
            # skip such weeks explicitly.
            if len(a) < 2 or len(b) < 2:
                continue

            p_value = ttest_ind(a, b, equal_var=False, alternative=side)[1]
            if p_value < level:
                significant_weeks.append(week)

    return significance


def compute_aggregates(vitalid, users, limit):
    """Aggregate the weekly vital changes of a set of users.

    Selects the rows of ``WEEKLY_DEVIATIONS`` that belong to ``vitalid`` and
    to one of ``users``, groups them by ``weeks_since_test`` and keeps the
    weeks from -3 up to and including ``limit``.

    Returns
    -------
    tuple
        ``(weeks, mean, err)`` where ``weeks`` is the index of the retained
        weeks, ``mean`` the average ``vital_change`` per week and ``err`` the
        standard deviation divided by the square root of the number of
        observations in that week.
    """
    is_vital = WEEKLY_DEVIATIONS.vitalid == vitalid
    is_member = WEEKLY_DEVIATIONS.userid.isin(users)
    selected = WEEKLY_DEVIATIONS[is_vital & is_member]

    # Aggregating the single column directly yields flat column names
    # ('mean', 'std', 'count').
    weekly = selected.groupby('weeks_since_test')['vital_change'].agg(['mean', 'std', 'count'])

    # Label-based slice: both ends are inclusive.
    weekly = weekly.loc[-3:limit]
    error = weekly['std'] / np.sqrt(weekly['count'])

    return weekly.index, weekly['mean'].values, error.values


def plot_average_trajectories(cohort_keys, labels, max_weeks, significance_level, outfile, show_errors=(True, True, False)):
    """Plot the average weekly vital changes per cohort and save the figure.

    Creates three vertically stacked panels with a shared x-axis, one per
    entry of ``VITALS_IDS``. Every panel shows the mean weekly change of each
    cohort, star markers for the weeks in which cohorts differ significantly,
    and a grey band around the week of the test.

    Parameters
    ----------
    cohort_keys : sequence of str
        Columns of ``COHORTS`` used as a mask on ``COHORTS.user_id`` to obtain
        the user ids of each cohort (presumably boolean columns). The
        significance markers compare the first three cohorts with each other,
        so at least three keys are needed.
    labels : sequence of str
        Legend label of each cohort.
    max_weeks : sequence of int
        Last week (inclusive) that is plotted for each cohort.
    significance_level : float
        p-value threshold passed to ``significance_test``.
    outfile : str
        Path the figure is saved to (at 400 dpi).
    show_errors : sequence of bool, default (True, True, False)
        Per cohort, whether the trajectory is drawn with error bars.

    Notes
    -----
    The per-cohort sequences are zipped with ``COLORS``, ``ALPHA_VALUES`` and
    ``MARKERS``; cohorts beyond the shortest of these sequences are silently
    not plotted.
    """
    fig, axarr = plt.subplots(3, 1, sharex=True, figsize=(4, 7))

    # Obtain list of user ids for each cohort
    cohorts = [COHORTS.user_id[COHORTS[key]].values for key in cohort_keys]

    # Plot average trajectories
    plot_info = list(zip(cohorts, max_weeks, labels, COLORS, ALPHA_VALUES, MARKERS, show_errors))

    for vitalid, ax in zip(VITALS_IDS, axarr):
        # `with_errors` deliberately differs from the `show_errors` parameter
        # so that the argument is not shadowed.
        for users, limit, label, color, alpha, marker, with_errors in plot_info:
            x, y, yerr = compute_aggregates(vitalid, users, limit)
            if with_errors:
                ax.errorbar(x, y, yerr, lw=1.5, ls='-', label=label, marker=marker, c=color, alpha=alpha)
            else:
                ax.plot(x, y, label=label, marker=marker, ls='-', c=color, alpha=alpha)

    # Plot significance
    sig_unvacc_vs_vacc = significance_test(WEEKLY_DEVIATIONS, cohorts[0], cohorts[1], level=significance_level)
    sig_unvacc_vs_negative = significance_test(WEEKLY_DEVIATIONS, cohorts[0], cohorts[2], level=significance_level)
    sig_vacc_vs_negative = significance_test(WEEKLY_DEVIATIONS, cohorts[1], cohorts[2], level=significance_level)

    # (significant weeks per vital, marker edge colour, marker fill colour, row).
    # Markers sharing a row are drawn at the same height, so the filled star
    # of the second entry is drawn over the hollow star of the first one.
    significance_info = [
        (sig_unvacc_vs_negative, flatuicolors.pomegranate, 'w', 0),
        (sig_unvacc_vs_vacc, flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
        (sig_vacc_vs_negative, flatuicolors.amethyst, flatuicolors.amethyst, 1),
    ]

    for vitalid, (y0, delta), ax in zip(VITALS_IDS, SIGNIFICANCE_MARKER_POSITIONS, axarr):
        for sig_data, edgecolor, fillcolor, row in significance_info:
            significant_weeks = sig_data[vitalid]
            y = [y0 - delta * row] * len(significant_weeks)
            ax.scatter(significant_weeks, y, marker="*", edgecolor=edgecolor, facecolor=fillcolor, s=90)

    # Add grey bar to indicate week of the test
    for ax in axarr:
        # Remember the limits and restore them afterwards so that the bands
        # span the full height without rescaling the y-axis.
        y0, y1 = ax.get_ylim()
        # Only the inner band is labelled and therefore shows up in the legend.
        ax.fill_between(x=[-0.5, 0.5], y1=y0, y2=y1, label=BAR_LABEL, color='k', alpha=0.115, zorder=-10)
        ax.fill_between(x=[-1, 1], y1=y0, y2=y1, color='k', alpha=0.075, zorder=-10)
        ax.set_ylim(y0, y1)

    # Axis styling
    for ax, y_label in zip(axarr, Y_LABELS):
        styling.hide_and_move_axis(ax)
        ax.set_ylabel(y_label, size=12)

    # Add labels for subfigures ('B' goes to the lower left, the others to
    # the upper right of their panel)
    for ax, panel_label in zip(axarr, ('A', 'B', 'C')):
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        if panel_label != 'B':
            ax.text(x0 + 0.88 * (x1 - x0), y0 + 0.8 * (y1 - y0), panel_label, size=24)
        else:
            ax.text(x0 + 0.03 * (x1 - x0), y0 + 0.1 * (y1 - y0), panel_label, size=24)

    # Style x-axis
    axarr[-1].set_xticks(range(-2, 16, 2))
    axarr[-1].set_xlabel(X_LABEL, size=12)

    # Add legend: the entries are collected from the bottom panel (the axes
    # that is current right after plt.subplots) and shown in the middle one.
    # ITEM_ORDER_IN_LEGEND assumes that panel provides four labelled handles.
    legend_handles, legend_labels = axarr[-1].get_legend_handles_labels()
    legend_handles = [legend_handles[idx] for idx in ITEM_ORDER_IN_LEGEND]
    legend_labels = [legend_labels[idx] for idx in ITEM_ORDER_IN_LEGEND]
    axarr[1].legend(legend_handles, legend_labels, fontsize=9, frameon=False)

    # Finalize
    fig.tight_layout()
    fig.savefig(outfile, dpi=400)

In [None]:
def figure2():
    """Render figure 2: weekly vital changes of the three cohorts.

    Compares the 'unvaccinated', 'vaccinated' and 'negative' cohorts at a
    significance level of 0.01 and writes the result to
    '../output/figure2_weekly_changes.jpg'.
    """
    cohort_keys = ['unvaccinated', 'vaccinated', 'negative']

    plot_average_trajectories(
        cohort_keys,
        # Legend labels are the capitalised cohort keys.
        list(map(str.capitalize, cohort_keys)),
        # Last plotted week per cohort.
        [15, 12, 15],
        0.01,
        '../output/figure2_weekly_changes.jpg',
    )

In [None]:
# Render figure 2 and save it to disk.
figure2()