In [None]:
import os
import pandas as pd
from datenspende.surveydataIO import vaccinations, pcr_tests
import functools
import numpy as np
from rocsDB import postgre
from sciencetools.plottools.colors import flatuicolors
from sciencetools.plottools import styling
from matplotlib import pyplot as plt
from scipy.stats import ttest_ind
from statsmodels.stats.proportion import proportions_ztest
from datetime import timedelta

# Load data

In [None]:
# Feather cache files for the database extracts; `load_data` passes them to `load`,
# which reads them if present and otherwise queries the database and writes them.
# NOTE(review): the driver cell further down rebinds VITALS and USERS to DataFrames,
# so after it has run these two names no longer hold file paths.
TESTDATA = 'data/p14_pcr_test.feather'
VACCINATIONS = 'data/p14_vaccinations.feather'
VITALS = 'data/p14_vitals.feather'
USERS = 'data/p14_user_info.feather'

In [None]:
def load(filename, force_reload, function):
    """Return a DataFrame from a feather cache, (re)building the cache when needed.

    Parameters
    ----------
    filename : str
        Path of the feather cache file.
    force_reload : bool
        If True, ignore any existing cache and rebuild it.
    function : callable
        Zero-argument callable that produces the DataFrame; its result is
        written to ``filename`` before being returned.
    """
    cache_is_usable = os.path.exists(filename) and not force_reload
    if cache_is_usable:
        return pd.read_feather(filename)

    fresh = function()
    fresh.to_feather(filename)
    return fresh


def get_vitals(user_ids):
    
    # Make sure that the IN-condition for the SQL query either takes the form '(userid)' in the case
    # of a single requested user id or '(userid1, userid2, ..., useridN)' in the case of multiple
    # requested user ids
    if isinstance(user_ids, int) or isinstance(user_ids, np.int64):
        formatter = f'({user_ids})'
    elif len(user_ids) == 1:
        formatter = f'({user_ids[0]})'
    else:
        formatter = tuple(user_ids)
    
    db = postgre()
    
    query = f"""
    SELECT 
        user_id AS userid, date, type AS vitalid, value, source AS deviceid
    FROM 
        datenspende.vitaldata
    WHERE 
        vitaldata.user_id IN {formatter}
    AND
        vitaldata.type IN (9, 65, 43)
    """

    vitals = pd.read_sql_query(query, db.conn)
    vitals.date = pd.to_datetime(vitals.date)    

    db.close()
    
    return vitals


def get_user_data(user_ids):
    """Fetch the ``datenspende.users`` rows for the given user ids.

    Parameters
    ----------
    user_ids : int or iterable of int
        One or more user ids; numpy arrays are fine.

    Returns
    -------
    DataFrame of the users table. Missing ``salutation`` values are filled
    with 30.0.

    Raises
    ------
    ValueError
        If no user id is given. This is raised before any database access.
    """
    # int() renders numpy scalars as plain digits (NumPy >= 2 reprs them as
    # 'np.int64(1)'). A hand-built list also avoids the '(5,)' trailing comma
    # that str(tuple) produces for a single id, which is invalid SQL.
    ids = [str(int(uid)) for uid in np.atleast_1d(user_ids)]
    if not ids:
        raise ValueError('get_user_data needs at least one user id.')

    query = 'SELECT * FROM datenspende.users WHERE user_id IN ({0})'.format(', '.join(ids))

    db = postgre()
    try:
        users = pd.read_sql_query(query, db.conn)
    finally:
        db.close()

    # 30 is the salutation code that user_table reports as 'Other'.
    users.salutation = users.salutation.fillna(30.0)

    return users


def load_data(force_reload=False, testdata_file=TESTDATA, vaccinations_file=VACCINATIONS,
              vitals_file=VITALS, users_file=USERS):
    """Load (or query and cache) tests, vaccinations, vitals and user info.

    The cache paths are bound as default arguments when this cell runs. This
    matters because the driver cell later rebinds the globals ``VITALS`` and
    ``USERS`` to DataFrames. With the defaults bound here, re-running the
    driver cell keeps using the file paths, as long as this cell is not itself
    re-executed after the driver cell.

    Parameters
    ----------
    force_reload : bool
        Re-query the database and overwrite every cache file.
    testdata_file, vaccinations_file, vitals_file, users_file : str
        Feather cache paths.

    Returns
    -------
    (data, vitals, user_info)
        ``data`` is vaccinations inner-joined with tests on ``user_id``.
        ``vitals`` and ``user_info`` cover only the users in ``data``.
    """
    tests = load(testdata_file, force_reload, pcr_tests)
    vaccs = load(vaccinations_file, force_reload, vaccinations)

    data = pd.merge(vaccs, tests, on='user_id')

    print('Number of users with vaccination data:', len(vaccs))
    print('Number of users with test data:', len(tests))
    print('Overlap between both groups:', len(data))

    user_ids = data.user_id.unique()
    vitals = load(vitals_file, force_reload, functools.partial(get_vitals, user_ids))
    user_info = load(users_file, force_reload, functools.partial(get_user_data, user_ids))

    return data, vitals, user_info

# Preprocess

In [None]:
def preprocess(vitals):
    """Remove unreliable Apple sleep records, then subtract daily cohort means.

    Drops deviceid 6 / vitalid 43 rows dated on or after 2021-10-20. Per the
    original note, these are Apple users after an October 2021 update that
    affected sleep data. The remaining values are then normalised with
    ``normalize`` over (vitalid, date, deviceid).
    """
    apple_sleep_after_update = (
        (vitals.deviceid == 6)
        & (vitals.vitalid == 43)
        & (vitals.date >= '2021-10-20')
    )
    kept = vitals[~apple_sleep_after_update]
    print("Number of users after removing invalid apple users:", len(kept.userid.unique()))

    return normalize(kept, by=['vitalid', 'date', 'deviceid'])


def normalize(df, by=['vitalid', 'date', 'deviceid']):
    """Subtract the group mean of ``value`` within each ``by`` group.

    Returns a new frame with the same columns as ``df``. Rows are produced by
    an inner merge with the group means, so the index is reset and rows whose
    ``by`` keys contain NaN are dropped.
    """
    daily_means = (
        df.groupby(by)['value']
        .mean()
        .rename('daily_mean')
        .reset_index()
    )
    merged = df.merge(daily_means, on=by)
    merged['value'] = merged['value'] - merged['daily_mean']
    return merged.drop(columns='daily_mean')


def create_cohorts(metadata):
    """Split test records into analysis cohorts of ``user_id`` arrays.

    Returned keys, in order:
      - 'unvaccinated': positive tests taken before the first dose
      - 'vaccinated': positive tests after the second dose, status 'full' or
        'booster', and jansen_received == False
      - 'negative': negative tests
      - 'total': vaccinated + unvaccinated + negative (concatenated, not de-duplicated)
      - 'vaccinated_delta' / 'vaccinated_omicron': 'vaccinated' split at test
        date 2021-12-15 (before / on-or-after)
      - 'positive': every positive test

    Jansen recipients are only counted here (printed) and excluded from the
    vaccinated cohorts.

    NOTE(review): a missing first_dose (NaT) never satisfies
    ``first_dose > test_date``, so users without any dose are not in
    'unvaccinated'. Confirm this is intended.
    """
    # Explicit == True / == False (not ~) so that missing values count as neither.
    received_jansen = metadata.jansen_received == True  # noqa: E712
    print('Number of users that received jansen:', received_jansen.sum())

    positive_test = metadata.test_result == 'positive'
    negative_test = metadata.test_result == 'negative'

    vaccinated_before_test = (
        metadata.status.isin(['full', 'booster'])
        & (metadata.second_dose < metadata.test_date)
        & positive_test
        & (metadata.jansen_received == False)  # noqa: E712
    )
    unvaccinated_at_test = (metadata.first_dose > metadata.test_date) & positive_test
    # Two explicit comparisons rather than a negation, so NaT test dates fall in neither period.
    delta_period = metadata.test_date < '2021-12-15'
    omicron_period = metadata.test_date >= '2021-12-15'

    def user_ids(mask):
        return metadata[mask].user_id.values

    unvacc = user_ids(unvaccinated_at_test)
    vacc = user_ids(vaccinated_before_test)
    negative = user_ids(negative_test)

    return {
        'unvaccinated': unvacc,
        'vaccinated': vacc,
        'negative': negative,
        'total': np.concatenate([vacc, unvacc, negative]),
        'vaccinated_delta': user_ids(vaccinated_before_test & delta_period),
        'vaccinated_omicron': user_ids(vaccinated_before_test & omicron_period),
        'positive': user_ids(positive_test),
    }


def add_user_age(users):
    """Add an integer-valued ``age`` column to ``users`` in place and return it.

    age = floor((2022 + 4/12) - birth_date + 2.5). The reference point is
    2022 + 4/12, presumably end of April 2022.
    NOTE(review): the +2.5 looks like a mid-point correction for binned
    birth years. Confirm against the users-table schema.
    """
    reference_year = 2022 + 4 / 12
    users['age'] = np.floor(reference_year - users['birth_date'] + 2.5)
    return users

#vacc = data.status.isin(['full', 'booster']) & (data.second_dose < data.test_date) & (data.test_date < '2021-12-15')
#vacc = data[vacc & (data.test_result == 'positive')].user_id.values
#negative = data[data.test_result == 'negative'].user_id.values
#total = np.concatenate([vacc, unvacc, negative])

# Transform

In [None]:
def remove_unplausible_values(df, key='vital_change', threshold=1E6):
    """Keep only rows whose ``key`` column is strictly below ``threshold``.

    This is a one-sided guard against absurdly large values:
      - Large negative values are kept.
      - Rows where ``key`` is NaN are dropped, because ``NaN < threshold`` is False.

    The default 1E6 was chosen with step counts in mind; the original note
    says it affected a single data point.

    Parameters
    ----------
    df : DataFrame
    key : str
        Column to check.
    threshold : float
        Exclusive upper bound for retained values.
    """
    return df[df[key] < threshold]


def weekly_deviations(vitaldata, testdata, min_points_per_week=6, min_weeks_for_baseline=3):
    """Aggregate vitals into weekly means relative to each user's test and baseline.

    Steps:
      1. Join vitals to tests on the user id.
      2. Compute ``weeks_since_test`` with floor division of the day difference by 7.
      3. Keep weeks -8..20 and drop values >= 1E6.
      4. Aggregate per (userid, vitalid, week) and keep weeks with at least
         ``min_points_per_week`` values.
      5. Compute a per-user baseline as the mean of weeks < -1, required to
         span at least ``min_weeks_for_baseline`` weeks.

    Returns one row per (userid, vitalid, weeks_since_test) with columns
    ``test_result, value, std, max, min, baseline, vital_change,
    vital_change_max, vital_change_min``.

    NOTE(review): if ``testdata`` holds several tests per user, the merge
    duplicates that user's vitals once per test, and those rows are pooled by
    week. Confirm that tests are unique per user.
    """
    merged = pd.merge(vitaldata, testdata, left_on='userid', right_on='user_id')
    merged['weeks_since_test'] = (merged.date - merged.test_date).dt.days // 7

    in_window = merged.drop(columns=['user_id', 'date', 'test_date'])
    in_window = in_window[in_window.weeks_since_test.between(-8, 20)]
    print('Number of users that donate at least one data point between -8 and 20 weeks around test:', len(in_window.userid.unique()))

    plausible = remove_unplausible_values(in_window, key='value')
    print('Number of users after removal of unplausible values:', len(plausible.userid.unique()))

    weekly = plausible.groupby(['userid', 'vitalid', 'weeks_since_test']).agg(
        {'test_result': 'first', 'value': ['mean', 'count', 'std', 'max', 'min']}
    )
    weekly = weekly[weekly[('value', 'count')] >= min_points_per_week].copy()
    print(f'Number of users with at least one week of {min_points_per_week} data points:', len(weekly.reset_index().userid.unique()))

    # Flatten the (column, statistic) MultiIndex down to the statistic names.
    weekly.columns = weekly.columns.get_level_values(1)
    weekly = (
        weekly.drop(columns=['count'])
        .reset_index()
        .rename(columns={'first': 'test_result', 'mean': 'value'})
    )

    # The baseline uses pre-test weeks only; week -1 is excluded.
    pre_test = weekly[weekly.weeks_since_test < -1]
    baseline = pre_test.groupby(['userid', 'vitalid'])[['value']].agg(['mean', 'count'])
    baseline = baseline[baseline[('value', 'count')] >= min_weeks_for_baseline]
    baseline.columns = baseline.columns.get_level_values(1)
    baseline = baseline.rename(columns={'mean': 'baseline'}).drop(columns='count')
    print(f'Number of users with at least {min_weeks_for_baseline} weeks of baseline data:', len(baseline.reset_index().userid.unique()))

    result = pd.merge(weekly, baseline, on=['userid', 'vitalid'])
    result['vital_change'] = result['value'] - result['baseline']
    result['vital_change_max'] = result['max'] - result['baseline']
    result['vital_change_min'] = result['min'] - result['baseline']

    return result

# Statistics

In [None]:
def significance_test(data, group1, group2, level, two_sided=False):
    """Welch t-tests of weekly ``vital_change`` between two user groups.

    Returns ``{vitalid: {week: bool}}`` for vitalids 65, 9, 43 and weeks -3..12.
    Each bool is ``p < level``. When ``two_sided`` is False, the alternative
    for vitalid 9 is 'less' (group1 below group2); for the other vitals it is
    'greater'. Weeks with missing data give a NaN p-value and therefore False.
    """
    results = {}
    for vitalid in (65, 9, 43):
        subset = data[data.vitalid == vitalid]
        in_first = subset.userid.isin(group1)
        in_second = subset.userid.isin(group2)

        if two_sided:
            side = 'two-sided'
        elif vitalid == 9:
            side = 'less'
        else:
            side = 'greater'

        per_week = {}
        for week in range(-3, 13):
            this_week = subset.weeks_since_test == week
            sample1 = subset[in_first & this_week].vital_change.values
            sample2 = subset[in_second & this_week].vital_change.values
            pvalue = ttest_ind(sample1, sample2, equal_var=False, alternative=side)[1]
            per_week[week] = pvalue < level
        results[vitalid] = per_week

    return results

# Plots

In [None]:
def plot_example_timeseries(vitals, metadata, unvacc_user, vacc_user):
    """Plot daily RHR, steps and sleep around the test date for two example users.

    Layout is a 3x2 grid: the left column is ``unvacc_user``, the right column
    is ``vacc_user``; rows are vitalids 65, 9, 43. Each panel shows the
    window from BASELINE_LENGTH days before the user's first listed test
    date to 90 days after it, plus a dashed baseline line. A grey rectangle
    marks the week after the test. Saves output/example_timeseries.jpg.

    Assumes raw, un-normalised vital values, since the panels use absolute
    y-limits (e.g. 60-80 bpm).
    """
    
    f, axarr = plt.subplots(3, 2, sharex='col', figsize=(8, 4))

    # Days of pre-test data shown (and intended as baseline).
    BASELINE_LENGTH = 54
    
    for i, userid in enumerate((unvacc_user, vacc_user)):
        
        # If the user has several tests, the first listed one is used.
        testdate = metadata[metadata.user_id == userid].test_date.values[0]

        for ax, vitalid in zip(axarr[:, i], (65, 9, 43)):

            if i == 0:
                color = flatuicolors.pomegranate
            else:
                color = flatuicolors.wisteria
                
            # .values on mixed columns gives object arrays (timestamps and values).
            x, y = vitals[
                (vitals.userid == userid) & 
                (vitals.vitalid == vitalid) &
                (vitals.date.between(testdate - np.timedelta64(BASELINE_LENGTH,'D'), testdate + np.timedelta64(90,'D')))
            ][['date', 'value']].values.T

            if vitalid == 43:
                # Sleep: drop values <= 250 and convert minutes to hours (label below).
                mask = y > 250
                x = x[mask]
                y = y[mask]
                y = y / 60
                ax.set_ylim(5, 15)
                ax.set_ylabel('Sleep duration\n[hrs/day]')
                
            if vitalid == 9:
                # Steps: drop days with 20000 or more steps.
                mask = y < 20000
                x = x[mask]
                y = y[mask]
                ax.set_ylim(0, 15000)
                ax.set_ylabel('Activity\n[steps/day]')

            if vitalid == 65:
                ax.set_ylabel('Resting heart rate\n[bpm]')

            
            
            # Only the bottom row (sleep) keeps its x-axis; the right column puts its y-axis on the right.
            if i == 0:
                if vitalid == 43:
                    styling.hide_and_move_axis(ax)
                else:
                    styling.hide_and_move_axis(ax, hide=['right', 'top', 'bottom'])
                    ax.xaxis.set_visible(False) 
                
                if vitalid == 65:
                    ax.set_ylim(60, 80)
            if i == 1:
                if vitalid == 43:
                    styling.hide_and_move_axis(ax, hide=['top', 'left'])
                else:
                    styling.hide_and_move_axis(ax, hide=['top', 'left', 'bottom'])
                    ax.xaxis.set_visible(False) 
                                 
                if vitalid == 65:
                    ax.set_ylim(45, 60)

            ax.plot(x, y, c=color, lw=1.2, alpha=0.9, zorder=10)
            # NOTE(review): this averages the first BASELINE_LENGTH *samples*, not the days
            # before testdate. With missing days or masked values it can include post-test data.
            ax.axhline(y[:BASELINE_LENGTH].mean(), color=flatuicolors.midnightblue, ls='--')
            #ax.fill_between(x=[testdate, testdate + np.timedelta64(7,'D')], y1=y.min()-1, y2=y.max()+1, color='r', alpha=0.5)

        # Test-week shading: drawn once per column on the last (bottom) axis, with clip_on=False,
        # in x-data / y-axes coordinates. The height is presumably sized to span all three rows.
        rect = plt.Rectangle((testdate, -0.04), width=np.timedelta64(7,'D'), height=3+2*f.subplotpars.hspace+0.15,
                     transform=ax.get_xaxis_transform(), clip_on=False,
                     edgecolor='none', facecolor="k", alpha=0.25, linewidth=3)
        ax.add_patch(rect)
        
    plt.subplots_adjust(wspace=0.03, hspace=0.24, bottom=0.075, top=0.97, left=0.14, right=0.86)
    plt.savefig('output/example_timeseries.jpg', dpi=400)

In [None]:
def plot_weekly_deviations(data, plot_info, significance_info, outfile, lang='en', order=[2,3,0,1], pos = {65: (2.5, 0.3), 9: (880, 440), 43: (70, 8)}):
    """Plot mean weekly vital changes per cohort, one panel per vital, and save the figure.

    Parameters
    ----------
    data : DataFrame
        Weekly deviations with ``userid, vitalid, weeks_since_test,
        vital_change`` columns.
    plot_info : list of tuple
        ``(users, limit, label, color, alpha, marker, show_errors)`` per
        cohort. ``limit`` is the last week drawn; ``show_errors`` toggles
        standard-error bars.
    significance_info : list of tuple
        ``(sig_data, facecolor, fillcolor, row)``, where ``sig_data`` is a
        ``{vitalid: {week: bool}}`` mapping. Each True week gets a star drawn
        in star row ``row``.
    outfile : str
        Path the figure is saved to.
    lang : str
        ``'en'`` for English labels; anything else gives German labels.
    order : list of int
        Order of the legend entries shown on the middle panel.
    pos : dict
        ``{vitalid: (y0, dy)}``. Star row ``r`` is drawn at ``y0 - dy * r``.
    """
    f, axarr = plt.subplots(3, 1, sharex=True, figsize=(4, 7))

    if lang == 'en':
        ylabels = {
            65: 'Change in RHR [bpm]',
            9: 'Change in daily steps',
            43: 'Change in daily sleep\nduration [minutes]'
        }
        xlabel = 'Weeks since PCR-test'
        barlabel = 'Week of PCR-test'
    else:
        ylabels = {
            65: 'Ruhepulsänderung [bpm]',
            9: 'Änderung der\ntgl. Schrittzahl',
            43: 'Änderung der\nSchlafdauer [min]'
        }
        xlabel = 'Wochen nach PCR-Test'
        barlabel = 'Testwoche'

    for ax, vitalid in zip(axarr, (65, 9, 43)):

        for users, limit, label, color, alpha, marker, show_errors in plot_info:

            agg = data[(data.vitalid == vitalid) & data.userid.isin(users)]
            agg = agg.groupby('weeks_since_test')[['vital_change']].agg(['mean', 'std', 'count'])
            agg.columns = agg.columns.get_level_values(1)
            # Standard error of the weekly cohort mean.
            agg['err'] = agg['std'] / np.sqrt(agg['count'])
            agg = agg.loc[-3:limit]

            if show_errors:
                ax.errorbar(
                    agg.index, agg['mean'], agg['err'], lw=1.5, ls='-', label=label, marker=marker,
                    c=color, alpha=alpha
                )
            else:
                ax.plot(
                    agg.index, agg['mean'], label=label, marker=marker, ls='-',
                    c=color, alpha=alpha
                )

        # Stars mark weeks flagged as significant in the supplied test results.
        for sig_data, facecolor, fillcolor, row in significance_info:
            for week, is_significant in sig_data[vitalid].items():
                if not is_significant:
                    continue
                x0, delta = pos[vitalid]
                ax.scatter(week, x0 - delta * row, marker="*", edgecolor=facecolor, facecolor=fillcolor, s=90)

        # Freeze the data-driven y-range, shade the test week over it, then restore it.
        y0, y1 = ax.get_ylim()
        styling.hide_and_move_axis(ax)

        ax.fill_between(x=[-0.5, 0.5], y1=y0, y2=y1, label=barlabel, color='k', alpha=0.115, zorder=-10)
        ax.fill_between(x=[-1, 1], y1=y0, y2=y1, color='k', alpha=0.075, zorder=-10)

        ax.set_ylim(y0, y1)
        ax.set_ylabel(ylabels[vitalid], size=12)

    axarr[-1].set_xticks(range(-2, 16, 2))

    # Every panel carries the same labelled artists. Take the handles from the
    # bottom panel and show the legend on the middle panel in the requested order.
    handles, legend_labels = axarr[-1].get_legend_handles_labels()
    axarr[-2].legend([handles[idx] for idx in order], [legend_labels[idx] for idx in order], fontsize=9, frameon=False)

    for ax, panel_letter in zip(axarr, ('A', 'B', 'C')):
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        if panel_letter != 'B':
            ax.text(x0 + 0.88 * (x1 - x0), y0 + 0.8 * (y1 - y0), panel_letter, size=24)
        else:
            ax.text(x0 + 0.03 * (x1 - x0), y0 + 0.1 * (y1 - y0), panel_letter, size=24)

    axarr[-1].set_xlabel(xlabel, size=12)
    plt.tight_layout()
    plt.savefig(outfile, dpi=400)
    
    
def figure2(cohortkey, label, outfile, max_weeks, siglevel, lang='en', with_sig=True, altkey='unvaccinated', altlabel='Unvaccinated'):
    """Weekly-deviation figure: cohort ``cohortkey`` vs cohort ``altkey`` vs negatives.

    Reads the notebook globals ``DF`` and ``COHORTS``. Runs three one-sided
    significance tests at ``siglevel``: alt vs cohort, alt vs negative, and
    cohort vs negative. With ``with_sig=False`` the stars are still drawn, but
    fully transparent, so the axis limits stay comparable between the two
    figure variants.
    """
    neg_label = 'Negative' if lang == 'en' else 'negativer Test'

    reference = COHORTS[altkey]
    cohort = COHORTS[cohortkey]
    negatives = COHORTS['negative']

    sig_alt_vs_cohort = significance_test(DF, reference, cohort, level=siglevel)
    sig_alt_vs_negative = significance_test(DF, reference, negatives, level=siglevel)
    sig_cohort_vs_negative = significance_test(DF, cohort, negatives, level=siglevel)

    # Rows: (users, limit, label, color, alpha, marker, show_errors)
    plot_info = [
        (reference, 15, altlabel, flatuicolors.pomegranate, 1, 'o', True),
        (cohort, max_weeks, label, flatuicolors.amethyst, 1, 'd', True),
        (negatives, 15, neg_label, flatuicolors.belizehole, 0.4, 'p', False),
    ]

    # Star (facecolor, fillcolor) per test, in the same order as the tests below.
    if with_sig:
        star_styles = [
            (flatuicolors.pomegranate, 'w'),
            (flatuicolors.pomegranate, flatuicolors.pomegranate),
            (flatuicolors.amethyst, flatuicolors.amethyst),
        ]
    else:
        transparent = [(0, 0, 0, 0)]
        star_styles = [(transparent, transparent)] * 3

    # Rows: (sig_data, facecolor, fillcolor, row)
    sig_info = [
        (sig_data, facecolor, fillcolor, row)
        for sig_data, (facecolor, fillcolor), row in zip(
            (sig_alt_vs_negative, sig_alt_vs_cohort, sig_cohort_vs_negative),
            star_styles,
            (0, 0, 1),
        )
    ]

    plot_weekly_deviations(data=DF, plot_info=plot_info, significance_info=sig_info, outfile=outfile, lang=lang)
    
    
def figure_breakthrough_timing(metadata, weekly_deviations, vaccinateds):

    f, ax = plt.subplots(figsize=(4,3))
    
    valid_users = weekly_deviations[weekly_deviations.userid.isin(vaccinateds)].userid.unique()
    temp = metadata[metadata.user_id.isin(valid_users)]
    temp['dt'] = temp.test_date - temp.third_dose
    temp['dt'][temp.dt.isna()] = temp[temp.dt.isna()].test_date - temp[temp.dt.isna()].second_dose
    
    x = np.array([t.days // 30 for t  in temp.dt])

    print('Share of breakthrough cases in month after vaccination:', (x == 0).sum() / len(x))
    
    ax.hist(x, bins=np.arange(-0.4, 11.5, 1), width=0.8, color=flatuicolors.midnightblue)
    styling.hide_and_move_axis(ax)
    ax.set_xlabel('Months since vaccination')
    ax.set_ylabel('Number of recorded\nbreakthrough infections')
    
    plt.tight_layout()
    plt.savefig('output/time_diff_vaccination_infection.jpg', dpi=400)
    
    
def user_table(valid_user, user_info, cohorts, out_path='output/user_data_vaccinations.txt'):
    """Write a LaTeX tabular of salutation shares and age statistics per cohort.

    Columns are the cohorts 'vaccinated', 'unvaccinated', 'negative' and
    'total'. Rows are 'Female', 'Male' and 'Other' (salutation codes 10, 20
    and 30, shown as count and percentage), then 'Age (mean)' and
    'Age (std)'. Only users in ``valid_user`` are counted. The table is
    printed and written to ``out_path``.

    An empty cohort gives ``nan`` percentages instead of raising
    ZeroDivisionError.
    """
    salutation_codes = {'Female': 10, 'Male': 20, 'Other': 30}
    groups = [cohorts[key] for key in ('vaccinated', 'unvaccinated', 'negative', 'total')]
    users = user_info[user_info.user_id.isin(valid_user)]

    # Backslashes are escaped explicitly; the old '\%' and '\end' literals were
    # invalid escape sequences (SyntaxWarning on Python 3.12+).
    table = "\\begin{tabular}{l | r r r | r}\n"
    table += "& Vaccinated & Unvaccinated & Negative & Total\\\\ [0.5ex] \\hline\\hline\n"

    for row in ('Female', 'Male', 'Other', 'Age (mean)', 'Age (std)'):
        cells = []
        for group in groups:
            members = users[users.user_id.isin(group)]
            if row in salutation_codes:
                n = int((members.salutation == salutation_codes[row]).sum())
                share = n / len(members) if len(members) else float('nan')
                cells.append(f'{n} ({share * 100:.2f}\\%)')
            else:
                value = members.age.mean() if row == 'Age (mean)' else members.age.std()
                cells.append(f'{value:.2f}yr ')
        table += row + ''.join(' & ' + cell for cell in cells) + ' \\\\ \n'

    table += "\\end{tabular}"
    print(table)

    with open(out_path, 'w') as handle:
        handle.write(table)
        
        
def figure_si_age_groups(users):
    """Compare the study cohort's age distribution with Germany's 2020 population.

    Ages are binned into 0-20, 21-24, 25-39, 40-59, 60-64 and 65+ (edges
    0, 21, 25, 40, 60, 65 and open-ended). The first five rows of the Statista
    sheet are summed into the first group so the reference series lines up
    with these bins. Both series are shown as relative frequencies. Saves
    output/age_groups.jpg.
    """
    f, ax = plt.subplots(figsize=(5,3.8))

    # Open upper edge so nobody aged 65+ falls outside the last bin (it was capped at 100).
    count, bins = np.histogram(users['age'], bins=[0, 21, 25, 40, 60, 65, np.inf])

    german_pop = pd.read_excel(
        'data/statistic_id1365_bevoelkerung-deutschlands-nach-relevanten-altersgruppen-2020.xlsx',
        sheet_name='Daten',
        header=4,
        usecols=[1, 2],
        nrows=10,
    )

    german_vals = np.append([german_pop['2020'][:5].sum()], german_pop['2020'][5:].values)
    # Out-of-place division: an in-place '/=' fails if the counts were read as integers.
    german_vals = german_vals / german_vals.sum()

    ax.bar(np.arange(len(count)) - 0.22, count / count.sum(), width=0.37, label='Study cohort', color=flatuicolors.midnightblue)
    ax.bar(np.arange(len(german_vals)) + 0.22, german_vals, width=0.37, label='German Population', color=flatuicolors.concrete)

    ax.set_xticks(np.arange(len(count)))
    # The second bin is [21, 25) on floored ages, i.e. 21-24 (it was mislabelled '20-24').
    ax.set_xticklabels(['0-20', '21-24', '25-39', '40-59', '60-64', '65+'])

    styling.hide_and_move_axis(ax)
    ax.legend(loc='upper left')

    ax.set_xlabel('Age group')
    ax.set_ylabel('Relative frequency')

    ax.set_ylim(0, 0.6)
    plt.tight_layout()
    plt.savefig('output/age_groups.jpg', dpi=400)

In [None]:
def figure_delta_omicron(outfile, max_weeks, siglevel):
    """Weekly-deviation figure for vaccinated Delta- vs Omicron-period infections.

    Reads the notebook globals ``DF`` and ``COHORTS``. Stars mark weeks where a
    two-sided Welch test between the two variant cohorts gives p < ``siglevel``.
    Negatives are drawn for reference only.
    """
    delta = COHORTS['vaccinated_delta']
    omicron = COHORTS['vaccinated_omicron']

    sig_delta_vs_omicron = significance_test(DF, delta, omicron, level=siglevel, two_sided=True)

    # Rows: (users, limit, label, color, alpha, marker, show_errors)
    plot_info = [
        (delta, max_weeks, 'Delta-Infection', flatuicolors.pomegranate, 1, 'o', True),
        (omicron, max_weeks, 'Omicron-Infection', flatuicolors.amethyst, 1, 'd', True),
        (COHORTS['negative'], 15, 'Negative', flatuicolors.belizehole, 0.4, 'p', False),
    ]

    # Rows: (sig_data, facecolor, fillcolor, row)
    sig_info = [(sig_delta_vs_omicron, flatuicolors.pomegranate, flatuicolors.pomegranate, 0)]

    plot_weekly_deviations(
        data=DF,
        plot_info=plot_info,
        significance_info=sig_info,
        outfile=outfile,
        order=[2, 3, 0, 1],
    )


In [None]:
def monthly_reports():
    """Stacked monthly bar chart of positive PCR tests, vaccinated vs unvaccinated.

    Reads the notebook globals ``DF``, ``COHORTS`` and ``METADATA``. Only users
    that appear in ``DF`` are counted. Saves output/pcr_counts.jpg.
    """
    def monthly_test_counts(cohort_key):
        # Tests of this cohort's users that have weekly vital data, counted per month.
        users_with_vitals = DF[DF.userid.isin(COHORTS[cohort_key])].userid.unique()
        cohort_tests = METADATA[METADATA.user_id.isin(users_with_vitals)]
        return cohort_tests.set_index('test_date').resample('1M').count()

    unvacc = monthly_test_counts('unvaccinated')
    vacc = monthly_test_counts('vaccinated')

    # Outer merge keeps months that appear in only one cohort: _x = vaccinated, _y = unvaccinated.
    plot_data = pd.merge(vacc, unvacc, how='outer', on='test_date').sort_index()
    plot_data = plot_data.iloc[2:]  # the first two months are omitted from the figure

    # '1M' bins are labelled by month end; shift the bars back by 30 days.
    bar_positions = plot_data.index - timedelta(days=30)
    vaccinated_counts = plot_data.user_id_x
    unvaccinated_counts = plot_data.user_id_y

    f, ax = plt.subplots(figsize=(4.5,3))

    ax.bar(bar_positions, vaccinated_counts, width=20, color=flatuicolors.amethyst, label='Vaccinated')
    ax.bar(bar_positions, unvaccinated_counts.fillna(0), width=20, bottom=vaccinated_counts.fillna(0),
           color=flatuicolors.pomegranate, label='Unvaccinated')
    ax.legend(loc='upper left')

    ax.semilogy()
    ax.set_ylabel('Number of self-reported\npositive PCR-tests')
    styling.hide_and_move_axis(ax)
    plt.tight_layout()

    ax.set_ylim(0.5, 2E3)
    plt.savefig('output/pcr_counts.jpg', dpi=400)

In [None]:
def vital_change_distributions(weekly_deviations, unvacc, vacc, negative):

    f, axarr = plt.subplots(1, 3, figsize=(10, 3), sharey=True)

    for ax, vital, label in zip(axarr, (65, 9, 43), ['RHR increase [bpm/day]', 'Activity reduction [steps/day]', 'Sleep increase [min/day]']):

        a = weekly_deviations[(weekly_deviations.vitalid == vital) & weekly_deviations.weeks_since_test.between(0, 4)]

        xmin = np.min(a.vital_change)
        xmax = np.max(a.vital_change)

        #print(xmin, xmax)
        if vital == 43:
            bins = np.arange(-15, 180, 30)
        elif vital == 65:
            bins = np.arange(-1, 12, 2)
        else:
            bins = np.arange(-11000, 2000, 2000)

        a.vital_change[a.vital_change < bins[0]] = bins[0]
        a.vital_change[a.vital_change > bins[-1]] = bins[-1]

        ncount, x = np.histogram(a[a.userid.isin(negative)].vital_change, bins=bins)
        vcount, x = np.histogram(a[a.userid.isin(vacc)].vital_change, bins=bins)
        ucount, x = np.histogram(a[a.userid.isin(unvacc)].vital_change, bins=bins)

        ncount = ncount / len(a[a.userid.isin(negative)])
        vcount = vcount / len(a[a.userid.isin(vacc)])
        ucount = ucount / len(a[a.userid.isin(unvacc)])

        print(ncount.sum(), vcount.sum(), ucount.sum())

        x = x[:-1] + np.diff(x)[0] / 2
        width = np.diff(x)[0] / 3 * .7

        if vital == 9:
            x *= -1

        ax.bar(x-1.15*width, ncount, width=width, color=flatuicolors.belizehole, label='Negative')
        ax.bar(x, vcount, width=width, color=flatuicolors.amethyst, label='Vaccinated') 
        ax.bar(x+1.15*width, ucount, width=width, color=flatuicolors.pomegranate, label='Unvaccinated')

        ax.set_xticks(x)
        styling.hide_and_move_axis(ax)
        ax.set_xlabel(label, size=12)

    axarr[0].text(10, .61, 'A', size=24)
    axarr[1].text(10000, .61, 'B', size=24)
    axarr[2].text(150, .61, 'C', size=24)

    axarr[0].set_ylabel('Relative frequency', size=12)
    axarr[0].legend(loc='center right', frameon=False)
    plt.tight_layout()
    plt.savefig('output/change_frequencies.jpg', dpi=400)
    

def _flag_extreme(values, vitalid, threshold):
    """Boolean flags marking extreme weekly changes of one vital.

    For activity (vitalid 9) a change is extreme when it falls below
    ``threshold``, which is negative. For the other vitals it is extreme when
    it exceeds ``threshold``.
    """
    if vitalid == 9:
        return values < threshold
    return values > threshold


def extreme_vitals(df, unvacc, vacc, negative, lang='en', with_sig=True, sig_level=0.01):
    """Weekly prevalence of extreme vital changes (weeks -2..5) per cohort.

    "Extreme" means more than +5 bpm RHR, more than 5000 fewer steps/day, or
    more than 60 extra minutes of sleep, relative to the user's baseline.
    Bars show the prevalence with binomial standard errors. With ``with_sig``,
    stars mark weeks -1..5 where a one-sided two-proportion z-test (first
    cohort larger) gives p < ``sig_level``. Saves
    output/acute_phase_extreme_vitals_{lang}.jpg.
    """
    f, axarr = plt.subplots(3, 1, sharex=True, figsize=(4,7))

    if lang == 'en':
        labels = [
            'Prevalence\nof more than 5 bpm/day\nincrease in RHR',
            'Prevalence\nof more than 5000 steps/day\nreduced activity',
            'Prevalence of more\nthan 1 hr/day\nadditional sleep'
        ]

        cases = [
            ('positive', unvacc, 15, 'Unvaccinated', flatuicolors.pomegranate, 1),
            ('positive', vacc, 9, 'Vaccinated', flatuicolors.amethyst, 1),
            ('negative', negative, 15, 'Negative', flatuicolors.peterriver, 0.5),
        ]
        xlabel = 'Weeks since PCR-test'

    else:
        labels = [
            'Prävalenz von mehr\nals 5 bpm/Tag\nRuhepulserhöhung',
            'Prävalenz von mehr\nals 5000 Schritten/Tag\nreduzierter Aktivität',
            'Prävalenz von mehr\nals 1 h/Tag\nverlängerter Schlafdauer'
        ]

        cases = [
            ('positive', unvacc, 15, 'ohne Impfschutz', flatuicolors.pomegranate, 1),
            ('positive', vacc, 9, 'mit Impfschutz', flatuicolors.amethyst, 1),
            ('negative', negative, 15, 'negativer Test', flatuicolors.peterriver, 0.5),
        ]
        xlabel = 'Wochen nach PCR-Test'

    # (first cohort, second cohort) pairs for the z-test, with their star colours.
    comparisons = list(zip(
        [(unvacc, negative), (unvacc, vacc), (vacc, negative)],
        [flatuicolors.pomegranate, flatuicolors.pomegranate, flatuicolors.amethyst],
        ['w', flatuicolors.pomegranate, flatuicolors.amethyst],
    ))
    # Star y-position per vital: (first row, downward offset of the vacc-vs-negative row).
    star_rows = {65: (0.24, 0.022), 9: (0.35, 0.033), 43: (0.44, 0.04)}

    for ax, vitalid, ylabel, threshold in zip(axarr, (65, 9, 43), labels, (5, -5000, 60)):
        vital_df = df[df.vitalid == vitalid]

        # cases rows: (test_result, users, limit, label, color, alpha); only users/label/color are used.
        for i, (_, users, _, label, color, _) in enumerate(cases):

            agg = vital_df[vital_df.userid.isin(users)].copy()
            agg['is_extreme'] = _flag_extreme(agg.vital_change, vitalid, threshold)

            agg = agg.groupby('weeks_since_test').is_extreme.agg(['mean', 'count'])
            # Binomial standard error of the weekly prevalence.
            agg['err'] = np.sqrt(agg['mean'] / agg['count'] * (1 - agg['mean']))
            agg = agg.loc[-2:5].copy()
            # Give zero-prevalence weeks a sliver of visible height. Only 'mean' is
            # changed; the old row-wide assignment also overwrote 'err' and 'count'.
            agg.loc[agg['mean'] == 0, 'mean'] = 0.001

            ax.bar(agg.index + i / 4 - 0.25, agg['mean'], width=0.18, color=color, yerr=agg['err'], label=label,
                   error_kw=dict(ecolor='k', alpha=0.85, lw=1.5, capsize=2, capthick=1.5)
                   )

        if with_sig:
            first_row_y, second_row_offset = star_rows[vitalid]
            for week in range(-1, 6):
                week_df = vital_df[vital_df.weeks_since_test == week]
                for j, ((group_a, group_b), facecolor, fillcolor) in enumerate(comparisons):

                    a = _flag_extreme(week_df[week_df.userid.isin(group_a)].vital_change.values, vitalid, threshold)
                    b = _flag_extreme(week_df[week_df.userid.isin(group_b)].vital_change.values, vitalid, threshold)

                    pvalue = proportions_ztest([a.sum(), b.sum()], [len(a), len(b)], alternative='larger')[1]
                    if pvalue < sig_level:
                        ax.scatter(week - 0.17, first_row_y - second_row_offset * (j >= 2), marker='*',
                                   edgecolor=facecolor, facecolor=fillcolor, s=90)

        styling.hide_and_move_axis(ax)
        ax.set_ylabel(ylabel, size=12)

    axarr[0].text(-2.5, .21, 'A', size=24)
    axarr[1].text(-2.5, .31, 'B', size=24)
    axarr[2].text(-2.5, .39, 'C', size=24)

    axarr[-1].set_xticks(range(-2, 6, 1))

    axbox = axarr[1].get_position()
    axarr[1].legend(loc=(axbox.x0 + .4, axbox.y0), frameon=False)

    axarr[-1].set_xlabel(xlabel, size=12)
    plt.tight_layout()
    plt.savefig(f'output/acute_phase_extreme_vitals_{lang}.jpg', dpi=400)

In [None]:
# Analysis configuration
FORCE_RELOAD = False          # True: re-query the database and overwrite the feather caches
MIN_POINTS_PER_WEEK = 6       # data points needed for a (user, vital, week) to be kept
MIN_WEEKS_FOR_BASELINE = 3    # kept pre-test weeks needed for a per-user baseline

# Load data.
# NOTE(review): this rebinds the module-level USERS (and VITALS below), which
# the config cell set to file paths, to DataFrames.
METADATA, RAW_VITALS, USERS = load_data(force_reload=FORCE_RELOAD)

# Preprocess: clean and normalise vitals, build cohorts, derive ages
VITALS = preprocess(RAW_VITALS)
COHORTS = create_cohorts(METADATA)
USERS = add_user_age(USERS)

# Transform: weekly changes relative to each user's pre-test baseline
DF = weekly_deviations(VITALS, METADATA, MIN_POINTS_PER_WEEK, MIN_WEEKS_FOR_BASELINE)

In [None]:
plot_example_timeseries(vitals=RAW_VITALS, metadata=METADATA, unvacc_user=276043, vacc_user=460177)

In [None]:
figure2(
    cohortkey='vaccinated',
    label='Vaccinated',
    outfile='output/figure2_weekly_changes.jpg',
    max_weeks=12,
    siglevel=0.01,
)

In [None]:
vital_change_distributions(
    weekly_deviations=DF,
    unvacc=COHORTS['unvaccinated'],
    vacc=COHORTS['vaccinated'],
    negative=COHORTS['negative'],
)

In [None]:
extreme_vitals(
    df=DF,
    unvacc=COHORTS['unvaccinated'],
    vacc=COHORTS['vaccinated'],
    negative=COHORTS['negative'],
    sig_level=0.01,
)

In [None]:
# Supplementary figure: vaccinated Delta-period infections, significance level 0.05.
figure2(
    cohortkey='vaccinated_delta',
    label='Vaccinated/Delta',
    outfile='output/figure_SI_weekly_changes_delta_signifance005.jpg',
    max_weeks=12,
    siglevel=0.05,
)

In [None]:
figure_delta_omicron(outfile='output/delta_omicron.jpg', max_weeks=12, siglevel=0.01)

In [None]:
figure_breakthrough_timing(metadata=METADATA, weekly_deviations=DF, vaccinateds=COHORTS['vaccinated'])

In [None]:
# Demographics of every user that made it into the weekly-deviation frame.
user_table(valid_user=DF.userid.unique(), user_info=USERS, cohorts=COHORTS)

In [None]:
figure_si_age_groups(users=USERS)

In [None]:
# Monthly counts of positive PCR tests (vaccinated vs unvaccinated); saves output/pcr_counts.jpg.
monthly_reports()

# EQ-5D self-reported health questionnaire

In [None]:
# Re-query tests and vaccinations for the EQ-5D analysis (force_reload=True
# always hits the database); cached separately from the main analysis files.
tests = load(filename='data/recent_tests.feather', force_reload=True, function=pcr_tests)
vaccs = load(filename='data/recent_vaccs.feather', force_reload=True, function=vaccinations)

metadata = vaccs.merge(tests, on='user_id')

In [None]:
def load_eq5d(elements=(144, 145, 146, 147, 148)):
    """Load summed EQ-5D answers per questionnaire session.

    Parameters
    ----------
    elements : sequence of int
        Question ids to include. A tuple or a list both work.

    Returns
    -------
    DataFrame with columns ``user_id, questionnaire_session, date, EQ5D``:
      - ``EQ5D`` is the sum of ``choice_id`` over the selected questions.
      - ``date`` is the mean ``created_at`` of the session, converted with
        ``unit='ms'``, so epoch milliseconds are assumed.
      - Only sessions with exactly ``len(elements)`` joined answer rows are kept.

    Raises
    ------
    ValueError
        If ``elements`` is empty. This is raised before any database access.
    """
    if len(elements) == 0:
        raise ValueError('load_eq5d needs at least one question id.')

    # Render an explicit '(a, b, ...)' list: str(list) would produce '[...]'
    # (invalid SQL), and int() keeps non-integers out of the query.
    formatter = '(' + ', '.join(str(int(e)) for e in elements) + ')'

    query = """
    SELECT 
        answers.user_id, answers.questionnaire_session, answers.created_at, choice.choice_id 
    FROM 
        datenspende.answers, datenspende.choice 
    WHERE
        answers.element = choice.element and 
        answers.question in {0}
    """.format(formatter)

    db = postgre()
    try:
        df = pd.read_sql(query, db.conn)
    finally:
        # The connection was previously never closed.
        db.close()

    df = df.groupby(['user_id', 'questionnaire_session']).agg({'created_at': 'mean', 'choice_id': ['sum', 'count']})
    df.columns = df.columns.get_level_values(1)
    df['mean'] = pd.to_datetime(df['mean'], unit='ms').dt.date
    df = df.rename(columns={'mean': 'date', 'sum': 'EQ5D'})
    # Drop sessions that did not answer every requested question.
    df = df[df['count'] == len(elements)].drop(columns='count')

    return df.reset_index()

In [None]:
# Panel titles: the first panel pools all five EQ-5D questions; the others show one question each.
LABEL = [
    'Overall health problems',
    'Problems in walking about',
    'Problems washing or dressing',
    'Problems in usual activities',
    'Pain or discomfort',
    'Anxiety and depression'
]

f, axarr = plt.subplots(2, 3, figsize=(10, 6))

ELEMENTS = [(144, 145, 146, 147, 148), (144, ), (145, ), (146, ), (147, ), (148, )]

# (cohort key, legend label, colour, bar offset in units of bar width)
eq5d_cohorts = [
    ('unvaccinated', 'Unvaccinated', flatuicolors.pomegranate, -1.15),
    ('vaccinated', 'Vaccinated', flatuicolors.amethyst, 0.0),
    ('negative', 'Negative', flatuicolors.belizehole, 1.15),
]
level_edges = np.arange(0.5, 5.5)  # one bin per answer level 1..5

for questions, panel, title in zip(ELEMENTS, axarr.flatten(), LABEL):

    eq5d = load_eq5d(elements=questions)
    answers = pd.merge(eq5d, metadata, on='user_id')
    # Mean answer level across the selected questions, rounded to a whole level.
    answers['EQ5D_trunc'] = round(answers['EQ5D'] / len(questions))

    totals = []
    shares = []
    for cohort_key, _, _, _ in eq5d_cohorts:
        scores = answers[answers.user_id.isin(COHORTS[cohort_key])].EQ5D_trunc
        counts, _ = np.histogram(scores, bins=level_edges)
        totals.append(counts.sum())
        shares.append(counts / len(scores))

    print(*totals)

    centers = level_edges[:-1] + np.diff(level_edges)[0] / 2
    width = np.diff(centers)[0] / 3 * .7

    for (_, cohort_label, color, offset), share in zip(eq5d_cohorts, shares):
        panel.bar(centers + offset * width, share, width=width, label=cohort_label, color=color)

    styling.hide_and_move_axis(panel)
    panel.set_xticks([1, 2, 3, 4, 5])
    panel.set_xticklabels(['None', 'Slight', 'Intermediate', 'Large', 'Extreme'], rotation=45, ha="right")

    panel.set_title(title)

for row_axis in axarr[:, 0]:
    row_axis.set_ylabel('Relative frequency')

# The legend goes on the bottom-left panel.
axarr[-1, 0].legend()
plt.tight_layout()
plt.savefig('output/eq5d.jpg', dpi=400)