In [None]:
import pandas as pd
from scipy.stats import ttest_ind
from long_covid.colors import flatuicolors
from long_covid import styling
from matplotlib import pyplot as plt
import datetime
import numpy as np

In [None]:
def significance_test(data, group1, group2, level, two_sided=False):
    """Run a per-week Welch t-test between two user groups for each vital.

    For every vital id in (65, 9, 43) and every week in -3..12 (inclusive),
    the `vital_change` values of users in `group1` are compared with those of
    users in `group2` using `scipy.stats.ttest_ind` with `equal_var=False`.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns `vitalid`, `userid`, `weeks_since_test` and
        `vital_change`.
    group1, group2 : list-like
        User ids of the two groups (passed to `Series.isin`).
    level : float
        Significance level; a week is flagged when its p-value is below it.
    two_sided : bool, default False
        If True a two-sided test is used for every vital. Otherwise the test
        is one-sided: `'less'` (group1 mean < group2 mean) for vital 9 and
        `'greater'` (group1 mean > group2 mean) for all other vitals.

    Returns
    -------
    dict
        `{vitalid: {week: p_value < level}}`. Weeks without enough data yield
        a NaN p-value, which compares as not significant.
    """
    significance = {65: {}, 9: {}, 43: {}}

    for vitalid, per_week in significance.items():
        df = data[data.vitalid == vitalid]

        # Group membership does not depend on the week, so filter once per
        # vital instead of re-evaluating `isin` in every week iteration.
        df1 = df[df.userid.isin(group1)]
        df2 = df[df.userid.isin(group2)]

        # The test direction only depends on the vital, not on the week.
        if two_sided:
            side = 'two-sided'
        elif vitalid == 9:
            side = 'less'
        else:
            side = 'greater'

        for week in range(-3, 13):
            a = df1[df1.weeks_since_test == week].vital_change.values
            b = df2[df2.weeks_since_test == week].vital_change.values

            per_week[week] = ttest_ind(a, b, equal_var=False, alternative=side)[1] < level

    return significance

In [None]:
def compute_aggregates(vitalid, users, limit, data=None):
    """Aggregate weekly vital changes of a user group into mean and standard error.

    Parameters
    ----------
    vitalid : int
        Vital id to select from the `vitalid` column.
    users : list-like
        User ids to include (passed to `Series.isin` on `userid`).
    limit : int
        Last week (inclusive, label-based) to return; the first week is -3.
    data : pandas.DataFrame, optional
        Frame with the columns `vitalid`, `userid`, `weeks_since_test` and
        `vital_change`. When omitted, the module-level `weekly_deviations`
        frame is used, which keeps the original three-argument call working.

    Returns
    -------
    tuple
        `(weeks, mean, err)`: the week index, the per-week mean of
        `vital_change`, and its standard error (`std / sqrt(count)`).
    """
    if data is None:
        # Backward-compatible fallback to the notebook-level frame.
        data = weekly_deviations

    agg = data[(data.vitalid == vitalid) & data.userid.isin(users)]

    agg = agg.groupby('weeks_since_test')[['vital_change']].agg(['mean', 'std', 'count'])
    # Drop the outer ('vital_change') level so columns are just mean/std/count.
    agg.columns = agg.columns.get_level_values(1)

    agg['err'] = agg['std'] / np.sqrt(agg['count'])
    agg = agg.loc[-3:limit]

    return agg.index, agg['mean'].values, agg['err'].values


def plot_weekly_deviations(data, plot_info, significance_info, outfile, order=[2,3,0,1], pos = {65: (2.5, 0.3), 9: (880, 440), 43: (70, 8)}):
    """Draw the three-panel weekly-deviation figure and save it to `outfile`.

    One panel is drawn per vital id (65, 9, 43 from top to bottom), sharing
    the x axis.

    Parameters
    ----------
    data : pandas.DataFrame
        Weekly deviations; forwarded to `compute_aggregates` for every curve.
    plot_info : iterable of tuples
        Each entry is `(users, limit, label, color, alpha, marker,
        show_errors)` and produces one curve per panel; `show_errors` selects
        between an error-bar plot and a plain line.
    significance_info : iterable of tuples
        Each entry is `(sig_data, facecolor, fillcolor, row)`, where
        `sig_data[vitalid]` maps week -> truthy flag. A star is drawn for each
        flagged week; `row` selects the vertical star row.
    outfile : str
        Path the figure is written to (dpi=400).
    order : list of int
        Indices into the legend handles of the bottom panel, giving the order
        of the legend entries. Not mutated.
    pos : dict
        Maps vitalid -> `(y, delta)`; stars are placed at `y - delta * row`.
        Not mutated.
    """
    fig, axarr = plt.subplots(3, 1, sharex=True, figsize=(4, 7))

    # Named `ylabels` so it is not clobbered by the legend labels further down.
    ylabels = {65: 'Change in RHR [bpm]', 9: 'Change in daily steps', 43: 'Change in daily sleep\nduration [minutes]'}
    xlabel = 'Weeks since PCR-test'
    barlabel = 'Week of PCR-test'

    for ax, vitalid in zip(axarr, (65, 9, 43)):

        for users, limit, label, color, alpha, marker, show_errors in plot_info:

            # Use the frame handed to this function rather than a global.
            x, y, yerr = compute_aggregates(vitalid, users, limit, data=data)
            if show_errors:
                ax.errorbar(x, y, yerr, lw=1.5, ls='-', label=label, marker=marker, c=color, alpha=alpha)
            else:
                ax.plot(x, y, label=label, marker=marker, ls='-', c=color, alpha=alpha)

        # Mark the weeks flagged as significant with stars.
        for sig_data, facecolor, fillcolor, row in significance_info:
            for week, is_significant in sig_data[vitalid].items():
                if is_significant:
                    star_y, star_delta = pos[vitalid]
                    ax.scatter(week, star_y - star_delta * row, marker="*", edgecolor=facecolor, facecolor=fillcolor, s=90)

        # Remember the data-driven limits so the shaded bands span the full
        # panel height without rescaling it.
        y0, y1 = ax.get_ylim()
        styling.hide_and_move_axis(ax)

        ax.fill_between(x=[-0.5, 0.5], y1=y0, y2=y1, label=barlabel, color='k', alpha=0.115, zorder=-10)
        ax.fill_between(x=[-1, 1], y1=y0, y2=y1, color='k', alpha=0.075, zorder=-10)

        ax.set_ylim(y0, y1)
        ax.set_ylabel(ylabels[vitalid], size=12)

    bottom_ax = axarr[-1]
    bottom_ax.set_xticks(range(-2, 16, 2))

    # Legend entries are collected from the bottom panel and shown in the
    # middle one.
    handles, legend_labels = bottom_ax.get_legend_handles_labels()
    axarr[-2].legend([handles[idx] for idx in order], [legend_labels[idx] for idx in order], fontsize=9, frameon=False)

    # Panel letters: A and C go top-right, B goes bottom-left.
    for ax, panel in zip(axarr, ('A', 'B', 'C')):
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        if panel != 'B':
            ax.text(x0 + 0.88 * (x1 - x0), y0 + 0.8 * (y1 - y0), panel, size=24)
        else:
            ax.text(x0 + 0.03 * (x1 - x0), y0 + 0.1 * (y1 - y0), panel, size=24)

    bottom_ax.set_xlabel(xlabel, size=12)
    fig.tight_layout()
    fig.savefig(outfile, dpi=400)
    

def figure2(cohortkey, label, outfile, max_weeks, siglevel, altkey='unvaccinated', altlabel='Unvaccinated'):
    """Build the weekly-deviation figure for one cohort and write it to `outfile`.

    Three user groups are read from the module-level `cohorts` frame: the
    reference group (`altkey`), the cohort of interest (`cohortkey`) and the
    "negative" group. All pairwise significance tests are computed on the
    module-level `weekly_deviations` frame at level `siglevel`, and the result
    is handed to `plot_weekly_deviations`.

    Parameters
    ----------
    cohortkey, altkey : str
        Column names in `cohorts` used as row masks to select `user_id`
        (presumably boolean columns; confirm against the cohort file).
    label, altlabel : str
        Legend labels for the two groups.
    outfile : str
        Destination path of the figure.
    max_weeks : int
        Last week plotted for the `cohortkey` curve; the other two curves run
        up to week 15.
    siglevel : float
        Significance level passed to `significance_test`.
    """

    def members(key):
        # User ids of every row whose `key` column is set.
        return cohorts.user_id[cohorts[key]].values

    def compare(first, second):
        # One-sided per-week test of `first` against `second`.
        return significance_test(weekly_deviations, first, second, level=siglevel)

    reference_users = members(altkey)
    cohort_users = members(cohortkey)
    negative_users = members("negative")

    reference_vs_cohort = compare(reference_users, cohort_users)
    reference_vs_negative = compare(reference_users, negative_users)
    cohort_vs_negative = compare(cohort_users, negative_users)

    # Curves, one tuple per group:
    # (users, limit, label, color, alpha, marker, show_errors)
    curves = [
        (reference_users, 15, altlabel, flatuicolors.pomegranate, 1, 'o', True),
        (cohort_users, max_weeks, label, flatuicolors.amethyst, 1, 'd', True),
        (negative_users, 15, 'Negative', flatuicolors.belizehole, 0.4, 'p', False),
    ]

    # Significance markers, one tuple per comparison:
    # (sig_data, facecolor, fillcolor, row)
    markers = [
        (reference_vs_negative, flatuicolors.pomegranate, 'w', 0),
        (reference_vs_cohort, flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
        (cohort_vs_negative, flatuicolors.amethyst, flatuicolors.amethyst, 1),
    ]

    plot_weekly_deviations(
        data=weekly_deviations,
        plot_info=curves,
        significance_info=markers,
        outfile=outfile,
    )


In [None]:
# Load the derived inputs that the functions above read as notebook-level globals.
weekly_deviations = pd.read_feather('../data/03_derived/weekly_vital_deviations_per_user.feather')
cohorts = pd.read_feather('../data/03_derived/user_cohorts.feather')

# The current timestamp is embedded in the file name, so each run writes a new file.
figure2(
    cohortkey='vaccinated',
    label='Vaccinated',
    outfile='../output/figure2_weekly_changes_{0}.jpg'.format(str(datetime.datetime.now().timestamp())),
    max_weeks=12,
    siglevel=0.01,
)