In [None]:
import datetime
from matplotlib import pyplot as plt
from long_covid.colors import flatuicolors
import numpy as np
from statsmodels.stats.proportion import proportions_ztest
from long_covid import styling
import pandas as pd

In [None]:
def significance_test(data, group1, group2, significance_level, weeks=range(-1, 6)):
    """Find the weeks in which group1 shows extreme vital changes more often than group2.

    For every vital id and every requested week, the share of rows whose
    ``vital_change`` lies beyond ``VITAL_THRESHOLDS[vitalid]`` is compared between
    the two groups with a one-sided two-sample proportions z-test
    (``alternative='larger'``: group1 proportion greater than group2 proportion).

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns ``vitalid``, ``userid``, ``weeks_since_test``
        and ``vital_change``.
    group1, group2 : list-like
        User ids of the two groups; matched against ``data.userid``.
    significance_level : float
        A week is reported when the p-value is strictly below this value.
    weeks : iterable of int, optional
        Values of ``weeks_since_test`` to test. Defaults to -1 through 5.

    Returns
    -------
    dict
        Maps each vital id (65, 9, 43) to the list of weeks with a significant
        difference, in the order the weeks were tested.
    """
    significance = {65: [], 9: [], 43: []}

    for vitalid, significant_weeks in significance.items():

        df = data[data.vitalid == vitalid]
        threshold = VITAL_THRESHOLDS[vitalid]

        # Group membership does not depend on the week, so compute it once per vital.
        in_group1 = df.userid.isin(group1)
        in_group2 = df.userid.isin(group2)

        for week in weeks:

            in_week = df.weeks_since_test == week
            a = df[in_group1 & in_week].vital_change.values
            b = df[in_group2 & in_week].vital_change.values

            # Without observations in a group the test is undefined (0/0 -> NaN
            # p-value plus RuntimeWarnings); such a week can never be significant.
            if len(a) == 0 or len(b) == 0:
                continue

            # Vital 9 counts as extreme when the change falls below its threshold,
            # every other vital when it rises above it.
            if vitalid == 9:
                a = a < threshold
                b = b < threshold
            else:
                a = a > threshold
                b = b > threshold

            p_value = proportions_ztest(
                [a.sum(), b.sum()], [len(a), len(b)], alternative='larger'
            )[1]

            # A NaN p-value (e.g. no extreme rows in either group) compares False here.
            if p_value < significance_level:
                significant_weeks.append(week)

    return significance

In [None]:
def extreme_vitals(df, significance_level=0.01):
    """Plot the weekly prevalence of extreme vital changes per cohort and save it as PDF.

    One panel is drawn per entry of ``VITAL_IDS``. Within a panel every cohort from
    ``COHORT_KEYS`` gets one bar per week (weeks -2 to 5) whose height is the share
    of rows with ``vital_change`` beyond ``VITAL_THRESHOLDS[vitalid]``; the error bar
    is ``sqrt(p * (1 - p) / n)``. Star markers flag the weeks for which
    ``significance_test`` reports a significant difference between two cohorts.

    Relies on module-level configuration: ``COHORTS``, ``COHORT_KEYS``, ``LABELS``,
    ``COLORS``, ``ALPHA_VALUES``, ``VITAL_IDS``, ``VITAL_THRESHOLDS``,
    ``SIGNIFICANCE_MARKER_POSITIONS``, ``Y_LABELS`` and ``X_LABEL``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``vitalid``, ``userid``, ``weeks_since_test`` and
        ``vital_change``.
    significance_level : float, optional
        p-value cut-off handed to ``significance_test``.

    Returns
    -------
    None
        The figure is written to ``../output/figure4_acute_phase_extreme_vitals.pdf``.
    """
    fig, axarr = plt.subplots(3, 1, sharex=True, figsize=(4, 7))

    cohorts = [COHORTS.user_id[COHORTS[key]].values for key in COHORT_KEYS]
    cases = list(zip(cohorts, LABELS, COLORS, ALPHA_VALUES))

    sig_unvacc_vs_vacc = significance_test(df, cohorts[0], cohorts[1], significance_level)
    sig_unvacc_vs_negative = significance_test(df, cohorts[0], cohorts[2], significance_level)
    sig_vacc_vs_negative = significance_test(df, cohorts[1], cohorts[2], significance_level)

    # (test result, marker edge colour, marker fill colour, marker row)
    significance_info = [
        (sig_unvacc_vs_negative, flatuicolors.pomegranate, 'w', 0),
        (sig_unvacc_vs_vacc, flatuicolors.pomegranate, flatuicolors.pomegranate, 0),
        (sig_vacc_vs_negative, flatuicolors.amethyst, flatuicolors.amethyst, 1),
    ]

    for ax, vitalid in zip(axarr, VITAL_IDS):

        # NOTE(review): the per-cohort alpha from ALPHA_VALUES is unpacked but never
        # applied to the bars - confirm whether that is intended.
        for i, (users, label, color, _alpha) in enumerate(cases):

            agg = df[(df.vitalid == vitalid) & df.userid.isin(users)].copy()

            # Vitals 65 and 43 are extreme above their threshold, the remaining one below.
            if vitalid in (65, 43):
                agg['is_extreme'] = agg.vital_change > VITAL_THRESHOLDS[vitalid]
            else:
                agg['is_extreme'] = agg.vital_change < VITAL_THRESHOLDS[vitalid]

            agg = agg.groupby('weeks_since_test').is_extreme.agg(['mean', 'count'])
            agg['err'] = np.sqrt(agg['mean'] / agg['count'] * (1 - agg['mean']))

            # Work on an explicit copy of the plotted weeks so the assignment below
            # does not write into a view of the groupby result.
            agg = agg.loc[-2:5].copy()

            # Lift zero prevalences slightly so the bar stays visible. Only the
            # 'mean' column is touched; 'count' and 'err' keep their real values.
            agg.loc[agg['mean'] == 0, 'mean'] = 0.001

            ax.bar(
                agg.index + i / 4 - 0.25,  # shift each cohort's bars side by side
                agg['mean'],
                width=0.18,
                color=color,
                yerr=agg['err'],
                label=label,
                error_kw=dict(ecolor='k', alpha=0.85, lw=1.5, capsize=2, capthick=1.5)
            )

    for ax, vitalid, (y0, delta) in zip(axarr, VITAL_IDS, SIGNIFICANCE_MARKER_POSITIONS):
        for sig_data, facecolor, fillcolor, row in significance_info:

            significant_weeks = sig_data[vitalid]
            # Markers of the second row sit `delta` below the first row.
            y = [y0 - delta * row] * len(significant_weeks)
            ax.scatter(
                np.array(significant_weeks) - 0.17,
                y,
                marker='*',
                edgecolor=facecolor,
                facecolor=fillcolor,
                s=90,
            )

    for ax, y, label, y_label in zip(axarr, (.21, .31, .40), ('A', 'B', 'C'), Y_LABELS):
        ax.text(-2.5, y, label, size=24)
        styling.hide_and_move_axis(ax)
        ax.set_ylabel(y_label, size=12)

    # Ticks and label go on the bottom panel; the x axis is shared across panels.
    bottom_ax = axarr[-1]
    bottom_ax.set_xticks(range(-2, 6, 1))
    bottom_ax.set_xlabel(X_LABEL, size=12)

    axbox = axarr[1].get_position()
    axarr[1].legend(loc=(axbox.x0 + .4, axbox.y0), frameon=False)

    fig.tight_layout()
    # The file name has no placeholder, so no timestamp is formatted into it.
    fig.savefig('../output/figure4_acute_phase_extreme_vitals.pdf', dpi=400)

In [None]:
# --- Input data --------------------------------------------------------------
# Cohort table (indexed by cohort key in extreme_vitals) and the weekly vital
# deviations that are plotted.
COHORTS = pd.read_feather('../data/03_derived/user_cohorts.feather')
WEEKLY_DEVIATIONS = pd.read_feather('../data/03_derived/weekly_vital_deviations_per_user.feather')

# --- Vitals --------------------------------------------------------------------
# Threshold per vital id that vital_change is compared against; the dict order
# defines the panel order of the figure.
VITAL_THRESHOLDS = {65: 5, 9: -5000, 43: 60}
VITAL_IDS = list(VITAL_THRESHOLDS)

# --- Cohorts -------------------------------------------------------------------
# Column names in COHORTS, with matching legend labels, bar colours and alphas.
COHORT_KEYS = ['unvaccinated', 'vaccinated', 'negative']
LABELS = list(map(str.capitalize, COHORT_KEYS))
COLORS = [
    flatuicolors.pomegranate,
    flatuicolors.amethyst,
    flatuicolors.belizehole,
]
ALPHA_VALUES = [1., 1., 0.5]

# --- Figure text and layout ----------------------------------------------------
X_LABEL = 'Weeks since PCR-test'

# One y-axis label per panel, in VITAL_IDS order.
Y_LABELS = [
    'Prevalence\nof more than 5 bpm/day\nincrease in RHR',
    'Prevalence\nof more than 5000 steps/day\nreduced activity',
    'Prevalence of more\nthan 1 hr/day\nadditional sleep',
]

# Per panel: (y of the first marker row, vertical offset of the next row).
SIGNIFICANCE_MARKER_POSITIONS = [
    (.24, .022),
    (.35, .033),
    (.46, .04),
]

# --- Build and save the figure -------------------------------------------------
extreme_vitals(WEEKLY_DEVIATIONS, significance_level=0.01)