In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_1

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
import datetime
import dateutil.parser
from os.path import join

from constants_1_1 import SITE_FILE_TYPES
from utils_1_1 import (
    get_site_file_paths,
    get_site_file_info,
    get_site_ids,
    get_visualization_subtitle,
    get_country_color_map,
)
from theme import apply_theme
from web import for_website

alt.data_transformers.disable_max_rows(); # Allow using rows more than 5000

In [None]:
# Bar outline colour for every wave bar; None means no explicit outline.
STROKE = None

# Colour per wave, in the same order as the colour-scale domain
# ['Early', 'Late', 'Late - Early'] used by the demographics charts.
WAVE_COLOR = [
    '#D45E00',  # Early (alternative tried: '#BA4338')
    '#0072B2',  # Late
    'black',    # Late - Early
]

# Axis presets; all keep the grid and the domain line.
AXIS_SHOW = alt.Axis(
    grid=True, labels=True, ticks=True, domain=True,
    tickMinStep=1,  # no tick step smaller than 1
)
AXIS_HIDE_TITLE = alt.Axis(grid=True, labels=True, ticks=True, domain=True)  # labels + ticks, title set per encoding
AXIS_HIDE = alt.Axis(grid=True, labels=False, ticks=False, domain=True)  # grid + domain only

# Define Function to Visualize

In [None]:
def _wave_bars(base):
    """Return the (Early, Late) bar layers of `base`, shifted left/right within each group slot."""
    return tuple(
        base.transform_filter(
            {'field': 'wave', 'oneOf': [wave]}
        ).mark_bar(
            xOffset=x_offset,
            size=20, stroke=STROKE
        )
        for wave, x_offset in (('Early', -10), ('Late', 10))
    )


def _ci_error_bar(bar, x_offset, upper_field, lower_field):
    """Draw a thin black whisker from `lower_field` to `upper_field` over `bar`.

    `bar` is one of the layers returned by `_wave_bars`, so the whisker inherits
    that layer's wave filter and lines up with it via the same `x_offset`.
    """
    return bar.mark_bar(
        size=2, color='black', opacity=1, xOffset=x_offset,
    ).encode(
        x=alt.X('group:N', title=None, axis=AXIS_HIDE_TITLE),
        y=alt.Y(f'{upper_field}:Q'),
        y2=alt.Y2(f'{lower_field}:Q'),
        color=alt.value('black')
    )


def _sig_marks(base, sig_field, tick_y, star_y):
    """Return (tick, star) layers drawn only for rows whose `sig_field` is True.

    `tick_y` / `star_y` are the y encodings (an aggregate such as `max(...)`);
    the colour is a constant, so the aggregate is taken per group across waves.
    The '*' text is drawn a further 15 px above the tick.
    """
    tick = base.mark_tick(
        size=45, yOffset=-10, strokeWidth=8
    ).encode(
        x=alt.X('group:N', title=None, axis=AXIS_HIDE_TITLE),
        y=tick_y,
        color=alt.value('black')
    ).transform_filter(
        {'field': sig_field, 'oneOf': [True]}
    )
    star = tick.mark_text(
        yOffset=-15, fontSize=20
    ).encode(
        x=alt.X('group:N', title=None, axis=AXIS_HIDE_TITLE),
        y=star_y,
        color=alt.value('black'),
        text=alt.value('*')
    )
    return tick, star


def _demographic_panels(data, x_title, width, show_y_axis, patient_group, title_x=-60):
    """Build the (% patients, # patients, % severe) layers for one demographic dimension.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows of a single dimension (e.g. only the age groups), with renamed columns.
    x_title : str
        x-axis title, shown under the bottom (# patients) panel.
    width : int
        Panel width in pixels.
    show_y_axis : bool
        True for the left-most dimension, whose y axes carry labels and titles;
        False hides y labels/ticks so neighbouring panels line up.
    patient_group : {'all', 'severe'}
        Selects the `p_*` / `n_*` columns for the top and bottom panels.
    title_x : int
        x position of the y-axis titles.

    Returns
    -------
    tuple of (p_layer, n_layer, s_layer) altair layer charts.
    """
    p_col = f'p_{patient_group}'
    n_col = f'n_{patient_group}'
    hidden_axis = dict(grid=True, labels=False, ticks=False, domain=False)

    if show_y_axis:
        p_y = alt.Y(f'{p_col}:Q', axis=alt.Axis(format='.0%', titleX=title_x), title="Percentage of Patients")
        n_y = alt.Y(f'{n_col}:Q', title="# Patients", axis=alt.Axis(titleX=title_x))
        s_y = alt.Y('p:Q', title="% Severe", axis=alt.Axis(format=".0%", titleX=title_x))
    else:
        p_y = alt.Y(f'{p_col}:Q', axis=alt.Axis(format='.0%', **hidden_axis), title=None)
        n_y = alt.Y(f'{n_col}:Q', axis=alt.Axis(title=None, **hidden_axis), title=None)
        s_y = alt.Y('p:Q', axis=alt.Axis(title=None, **hidden_axis), title=None)

    def wave_color(domain):
        return alt.Color("wave:N", scale=alt.Scale(domain=domain, range=WAVE_COLOR), title="Wave")

    # --- Top panel: % of patients, with CI whiskers and significance markers.
    # 'Late - Early' maps to the third WAVE_COLOR entry (black), presumably so the
    # legend explains the black difference tick that the caller layers on top.
    p_base = alt.Chart(data).encode(
        x=alt.X('group:N', title=None, axis=AXIS_HIDE_TITLE),
        y=p_y,
        color=wave_color(['Early', 'Late', 'Late - Early'])
    ).properties(width=width, height=200)
    p_early, p_late = _wave_bars(p_base)
    p_tick, p_star = _sig_marks(
        p_base, f'{p_col}_sig',
        tick_y=alt.Y(f'max({p_col}):Q', axis=alt.Axis(format='.0%', titleX=title_x), title="Percentage of Patients"),
        star_y=alt.Y(f'max({p_col}):Q', axis=alt.Axis(format='.0%', titleX=title_x), title="Percentage of Patients"),
    )
    p_layer = alt.layer(
        p_early, p_late,
        _ci_error_bar(p_early, -10, f'{p_col}_ci_u', f'{p_col}_ci_l'),
        _ci_error_bar(p_late, 10, f'{p_col}_ci_u', f'{p_col}_ci_l'),
        p_tick, p_star,
    )

    # --- Bottom panel: number of patients; the only panel with x-axis labels.
    n_base = alt.Chart(data).encode(
        x=alt.X('group:N', title=x_title, axis=AXIS_SHOW),
        y=n_y,
        color=wave_color(['Early', 'Late'])
    ).properties(width=width, height=70)
    n_layer = alt.layer(*_wave_bars(n_base))

    # --- Middle panel: % severe. Uses columns p / p_ci_u / p_ci_l / p_sig,
    # independently of `patient_group`.
    s_base = alt.Chart(data).encode(
        x=alt.X('group:N', title=None, axis=AXIS_HIDE),
        y=s_y,
        color=wave_color(['Early', 'Late'])
    ).properties(width=width, height=70)
    s_early, s_late = _wave_bars(s_base)
    s_tick, s_star = _sig_marks(
        s_base, 'p_sig',
        tick_y=alt.Y('max(p_ci_u):Q', axis=alt.Axis(format='.0%')),
        star_y=alt.Y('max(p_ci_u):Q'),
    )
    s_layer = alt.layer(
        s_early, s_late,
        _ci_error_bar(s_early, -10, 'p_ci_u', 'p_ci_l'),
        _ci_error_bar(s_late, 10, 'p_ci_u', 'p_ci_l'),
        s_tick, s_star,
    )

    return p_layer, n_layer, s_layer


def FUNC_DEMOGRAPHICS_BY_WAVE(
    _data, 
    country='', 
    race=True, 
    patient_group='all' # either 'all' or 'severe'
):
    """Plot patient demographics by age, sex and (optionally) race for early vs late waves.

    Layout (top to bottom):
    - % of patients per group, with CI whiskers, significance tick/star and a black
      tick for the Late - Early difference;
    - % severe per group, with CI whiskers and significance tick/star;
    - number of patients per group.

    Parameters
    ----------
    _data : pandas.DataFrame
        Columns used (inferred from the code): `group`, `wave`, `p.all`, `n.all`,
        `p.severe`, `n.severe`, `p`, their CI/sig columns (`p.all.lwr.ci`,
        `p.ever.upr.ci`, `p.lwr.ci`, `p.sig`, ...). Not modified.
    country : str
        Prefix for the chart title ('' for the pooled data).
    race : bool
        Whether to include the race panels.
    patient_group : {'all', 'severe'}
        Which cohort's percentages/counts go in the top and bottom rows.

    Returns
    -------
    altair.VConcatChart

    Raises
    ------
    ValueError
        If `patient_group` is not 'all' or 'severe'.
    KeyError
        If `group` or `wave` contains a value without a display label.
    """
    if patient_group not in ('all', 'severe'):
        raise ValueError(f"patient_group must be 'all' or 'severe', got {patient_group!r}")

    # --- Rename columns and values to display-friendly names.
    column_names = {
        'p.all': 'p_all',
        'n.all': 'n_all',
        'p.severe': 'p_severe',
        'n.severe': 'n_severe',

        'p.all.lwr.ci': 'p_all_ci_l',
        'p.all.upr.ci': 'p_all_ci_u',
        'p.ever.lwr.ci': 'p_severe_ci_l',
        'p.ever.upr.ci': 'p_severe_ci_u',
        'p.all.sig': 'p_all_sig',
        'p.ever.sig': 'p_severe_sig',

        'p.lwr.ci': 'p_ci_l',
        'p.upr.ci': 'p_ci_u',
        'p.sig': 'p_sig'
    }
    group_labels = {
        '00to25': '0-25',
        '26to49': '26-49',
        '50to69': '50-69',
        '70to79': '70-79',
        '80plus': '80+',
        'female': 'Female',
        'male': 'Male',
        'white': 'White',
        'black': 'Black',
        'other': 'Other',
        'other_age': 'Other',
        'other_sex': 'Other',
        'other_race': 'Other'
    }
    wave_labels = {'early': 'Early', 'late': 'Late'}

    d = _data.rename(columns=column_names)  # rename returns a new frame; _data stays intact
    # Strict lookups: an unexpected code raises KeyError instead of silently becoming NaN.
    d['group'] = d['group'].map(lambda g: group_labels[g])
    d['wave'] = d['wave'].map(lambda w: wave_labels[w])

    AGE_GROUPS = ['0-25', '26-49', '50-69', '70-79', '80+']
    SEX_GROUPS = ['Female', 'Male']
    RACE_GROUPS = ['White', 'Black']

    width = 140   # width of the sex/race panels (age panels are wider)
    titleX = -60  # x position of the y-axis titles

    # --- Late - Early difference of the percentage actually plotted in the top row.
    p_col = f'p_{patient_group}'
    signed_p = d[p_col].where(d['wave'] != 'Early', -d[p_col])
    diff = signed_p.groupby(d['group']).sum().reset_index()  # columns: group, p_col

    def diff_tick(groups):
        """Black tick at the Late - Early value for each of `groups`."""
        return alt.Chart(
            diff[diff['group'].isin(groups)]
        ).mark_tick(
            color="black", size=40, stroke="white", strokeWidth=2, thickness=5, xOffset=-7
        ).encode(
            x=alt.X("group:N", title=None, axis=None),
            y=alt.Y(f"{p_col}:Q")
        )

    # (groups, x-axis title, panel width, show y axis)
    dimensions = [
        (AGE_GROUPS, 'Age', 320, True),
        (SEX_GROUPS, 'Sex', width, False),
    ]
    if race:
        # Race info is not shown for countries other than the USA; callers pass race=False there.
        dimensions.append((RACE_GROUPS, 'Race', width, False))

    p_row, s_row, n_row = [], [], []
    for groups, x_title, panel_width, show_y_axis in dimensions:
        p_layer, n_layer, s_layer = _demographic_panels(
            d[d['group'].isin(groups)], x_title, panel_width, show_y_axis, patient_group, title_x=titleX
        )
        p_row.append(p_layer + diff_tick(groups))
        s_row.append(s_layer)
        n_row.append(n_layer)

    def shared_row(charts):
        return alt.hconcat(*charts).resolve_scale(y='shared', color='shared')

    # strip() drops the leading space when no country prefix is given.
    title_text = f"{country} Demographics of {'All' if patient_group == 'all' else 'Ever Severe'} Patients by Wave".strip()

    return (shared_row(p_row) & shared_row(s_row) & shared_row(n_row)).properties(
        title={
            "text": title_text,
            "dx": 80,
            "subtitle": get_visualization_subtitle(data_release='2021-02-15', with_num_sites=False),
            "subtitleColor": "gray",
        }
    )

# Demographics

In [None]:
df2 = pd.read_csv(join("..", "data", "1.1.resurgence", "demographics", "demographic_stats_withCI.csv"))

# Pooled (all-country) demographics: the 'all' and 'severe' views side by side.
plot = alt.hconcat(
    *(
        FUNC_DEMOGRAPHICS_BY_WAVE(df2, country="", patient_group=group)
        for group in ('all', 'severe')
    ),
    spacing=30,
)

plot = apply_theme(
    plot,
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right',
)
plot

# Demographics by Country

In [None]:
df = pd.read_csv(join("..", "data", "1.1.resurgence", "demographics", "demographic_stats_bycountry_withCI.csv"))

# Demographic group codes present in the file. Every code must have a display label in the
# plotting functions' group mapping; an unknown code makes them raise KeyError.
# NOTE(review): original note says Singapore data shouldn't be included here — confirm.
display(df['group'].unique().tolist())

df.head()

In [None]:
df['country'].drop_duplicates().tolist()  # countries in order of first appearance

In [None]:
# Countries left out of the per-country grid.
# NOTE(review): the reason Germany is excluded is not recorded here — document it.
EXCLUDED_COUNTRIES = ['GERMANY']

country_rows = []
for country in df['country'].unique().tolist():
    if country in EXCLUDED_COUNTRIES:
        continue
    df_country = df[df['country'] == country]

    is_usa = country == 'USA'
    # Race panels are only drawn for the USA; other names are shown via str.capitalize().
    label = country if is_usa else country.capitalize()

    # One row per country: the 'all' and 'severe' views side by side.
    plot = alt.hconcat(
        *(
            FUNC_DEMOGRAPHICS_BY_WAVE(df_country, race=is_usa, country=label, patient_group=group)
            for group in ('all', 'severe')
        ),
        spacing=30,
    )
    country_rows.append(plot)

if not country_rows:
    raise ValueError(f"No countries left to plot after excluding {EXCLUDED_COUNTRIES}")

if len(country_rows) == 1:
    view = country_rows[0].copy()
else:
    # Independent colour scales give every country row its own legend.
    view = alt.vconcat(*country_rows, spacing=30).resolve_scale(color='independent')

view = apply_theme(
    view,
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right'
)

view

In [None]:
def _wave_line_facet(data, y, color, height, x_axis, header, width=120, padding=0.3):
    """Facet `data` into one column per demographic group, with Early -> Late lines.

    Each column draws one line (with points) per `color` category; the caller
    encodes `country` there. `x_axis` configures the wave axis (None hides it)
    and `header` configures the facet column header.
    """
    return alt.Chart(
        data
    ).mark_line(
        point=True,
        size=3,
    ).encode(
        x=alt.X(
            'wave:N',
            title=None,
            axis=x_axis,
            scale=alt.Scale(padding=padding)
        ),
        y=y,
        color=color
    ).properties(
        width=width, height=height
    ).facet(
        spacing=2,
        column=alt.Column("group:N", header=header)
    )


def FUNC_DEMOGRAPHICS_BY_WAVE_WITH_LINES(
    _data,
    patient_group='all' # either 'all' or 'severe'
):
    """Compare countries' demographics between waves as faceted line charts.

    For each of age and sex there is a column of three charts: % of patients
    (height 200), % severe (height 70) and number of patients (height 70, with the
    wave labels and the group header). Each demographic group is a facet column,
    and each country is one coloured line from Early to Late.

    Parameters
    ----------
    _data : pandas.DataFrame
        Columns used (inferred from the code): `country`, `group`, `wave`,
        `p.all`, `n.all`, `p.severe`, `n.severe`, `p`. Not modified.
    patient_group : {'all', 'severe'}
        Which cohort's percentages/counts are shown in the first and last rows.

    Returns
    -------
    altair.HConcatChart

    Raises
    ------
    ValueError
        If `patient_group` is not 'all' or 'severe'.
    KeyError
        If `group` or `wave` contains a value without a display label.
    """
    if patient_group not in ('all', 'severe'):
        raise ValueError(f"patient_group must be 'all' or 'severe', got {patient_group!r}")

    # --- Rename columns and values to display-friendly names.
    d = _data.rename(columns={
        'p.all': 'p_all',
        'n.all': 'n_all',
        'p.severe': 'p_severe',
        'n.severe': 'n_severe'
    })
    group_labels = {
        '00to25': '0-25',
        '26to49': '26-49',
        '50to69': '50-69',
        '70to79': '70-79',
        '80plus': '80+',
        'female': 'Female',
        'male': 'Male',
        'white': 'White',
        'black': 'Black',
        'other': 'Other',
        'other_age': 'Other',
        'other_sex': 'Other',
        'other_race': 'Other'
    }
    wave_labels = {'early': 'Early', 'late': 'Late'}
    # Strict lookups: an unexpected code raises KeyError instead of becoming NaN.
    d['group'] = d['group'].map(lambda g: group_labels[g])
    d['wave'] = d['wave'].map(lambda w: wave_labels[w])

    AGE_GROUPS = ['0-25', '26-49', '50-69', '70-79', '80+']
    SEX_GROUPS = ['Female', 'Male']
    COUNTRY_COLORS = ['#0072B2', '#E79F00', '#029F73', '#D45E00', '#CB7AA7']

    width = 120    # width of each facet column
    titleX = -60   # x position of the y-axis titles
    padding = 0.3  # outer padding of the Early/Late point scale

    def country_color(**kwargs):
        return alt.Color("country:N", scale=alt.Scale(range=COUNTRY_COLORS), **kwargs)

    def unlabeled_header():
        return alt.Header(labelOrient="bottom", title=None, titleOrient="bottom", labels=False)

    # (groups, bottom header title, extra y-axis options, legend title of the top chart)
    dimensions = (
        (AGE_GROUPS, "Age Group", {}, None),
        (SEX_GROUPS, "Sex", {'orient': 'left'}, 'Country'),
    )

    columns = []
    for groups, header_title, y_axis_extra, legend_title in dimensions:
        sub = d[d['group'].isin(groups)]

        # % of patients (line chart)
        pct_line = _wave_line_facet(
            sub,
            y=alt.Y(f'p_{patient_group}:Q', axis=alt.Axis(format='.0%', titleX=titleX, **y_axis_extra), title="Percentage of Patients"),
            color=country_color(title=legend_title),
            height=200,
            x_axis=None,
            header=unlabeled_header(),
            width=width, padding=padding,
        )

        # % ever severe (line chart)
        severe_line = _wave_line_facet(
            sub,
            y=alt.Y('p:Q', title="% Severe", axis=alt.Axis(format=".0%", titleX=titleX, **y_axis_extra)),
            color=country_color(),
            height=70,
            x_axis=None,
            header=unlabeled_header(),
            width=width, padding=padding,
        )

        # Number of patients (line chart); carries the wave labels and the group header.
        count_line = _wave_line_facet(
            sub,
            y=alt.Y(f'n_{patient_group}:Q', title="# Patients", axis=alt.Axis(titleX=titleX, **y_axis_extra)),
            color=country_color(),
            height=70,
            x_axis=alt.Axis(grid=False, labels=True, ticks=False, domain=False, tickMinStep=1),
            header=alt.Header(labelOrient="bottom", title=header_title, titleOrient="bottom"),
            width=width, padding=padding,
        )

        columns.append(pct_line & severe_line & count_line)

    return alt.hconcat(*columns, spacing=40).properties(
        title={
            "text": f"Country-level Demographics of {'All' if patient_group == 'all' else 'Ever Severe'} Patients by Wave",
            "dx": 80,
            "subtitle": get_visualization_subtitle(data_release='2021-02-15', with_num_sites=False),
            "subtitleColor": "gray",
        }
    )

In [None]:
# Country-level line charts: the 'all' view above the 'severe' view, each with its own legend.
line_views = [
    FUNC_DEMOGRAPHICS_BY_WAVE_WITH_LINES(df, patient_group=group)
    for group in ('all', 'severe')
]
plot = alt.vconcat(*line_views, spacing=30).resolve_scale(color='independent')

plot = apply_theme(
    plot,
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right',
)

plot

In [None]:
def _stacked_percentage_labels(data, patient_group, min_share=None):
    """Build a text layer that labels each segment of a stacked percentage bar.

    The labels are stacked from exactly the same rows, with `stack='zero'`, as
    the bar layer they are drawn on. Each label therefore lands just below the
    top edge of its own segment.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows with `wave`, `group` and `p_<patient_group>` columns.
    patient_group : str
        'all' or 'severe'; selects the `p_<patient_group>` column.
    min_share : float, optional
        When given, labels of segments whose summed share is not greater than
        this value are hidden. They are made transparent rather than filtered
        out: a view-level filter runs before the encoding-level sum/stack, so
        dropping rows would shift every remaining label off its segment.
    """
    value = f'p_{patient_group}'
    encoding = dict(
        y=alt.Y(f'sum({value}):Q', stack='zero'),
        x=alt.X('wave:N'),
        detail='group:N',
        text=alt.Text(f'sum({value}):Q', format='.0%'),
    )
    if min_share is not None:
        # Vega-Lite names encoding-level aggregates `<op>_<field>`, here `sum_p_<group>`.
        encoding['opacity'] = alt.condition(
            f'datum.sum_{value} > {min_share}', alt.value(1), alt.value(0)
        )
    return alt.Chart(data).mark_text(
        size=16, dx=0, dy=5, color='white', baseline='top', fontWeight=500
    ).encode(**encoding)


def FUNC_DEMOGRAPHICS_BY_WAVE_WITH_STACKED_BARS(
    _data,
    patient_group='all' # either 'all' or 'severe'
):
    """Side-by-side stacked bars of age-group and sex composition for each wave.

    Parameters
    ----------
    _data : pandas.DataFrame
        Demographic stats with `group`, `wave` and dotted `p.<group>` / `n.<group>`
        columns. It is not modified.
    patient_group : str
        'all' or 'severe'; selects which percentage column is plotted and the
        wording of the title.

    Returns
    -------
    The age | sex concatenation (independent color scales) as returned by
    `apply_theme`.

    Raises
    ------
    KeyError
        If `group` or `wave` contains a code missing from the label maps below.
    """
    """
    RENAME COLUMNS AND VALUES
    """
    # `rename` already returns a new frame, so `_data` is never mutated.
    d = _data.rename(columns={
        'p.all': 'p_all',
        'n.all': 'n_all',
        'p.severe': 'p_severe',
        'n.severe': 'n_severe'
    })
    # Every raw code must be listed: an unmapped code raises KeyError instead of
    # silently becoming NaN. Race codes are mapped even though no race chart is drawn.
    group_labels = {
        '00to25': '0-25',
        '26to49': '26-49',
        '50to69': '50-69',
        '70to79': '70-79',
        '80plus': '80+',
        'female': 'Female',
        'male': 'Male',
        'white': 'White',
        'black': 'Black',
        'other': 'Other',
        'other_age': 'Other',
        'other_sex': 'Other',
        'other_race': 'Other'
    }
    wave_labels = {
        'early': 'Early',
        'late': 'Late'
    }
    d['group'] = d['group'].apply(group_labels.__getitem__)
    d['wave'] = d['wave'].apply(wave_labels.__getitem__)

    """
    CATEGORIES WE USE
    """
    AGE_GROUPS = ['0-25', '26-49', '50-69', '70-79', '80+']
    SEX_GROUPS = ['Female', 'Male']

    """
    COMMON VISUAL PARAMETERS
    """
    width = 200
    height = 200
    titleX = -60
    padding = 0.3
    size = 60

    """
    AGE GROUPS
    """
    ad = d[d.group.isin(AGE_GROUPS)]

    age_bars = alt.Chart(
        ad
    ).mark_bar(
        size=size,
        stroke="black"
    ).encode(
        x=alt.X(
            'wave:N',
            title="Wave",
            scale=alt.Scale(padding=padding)
        ),
        y=alt.Y(f'p_{patient_group}:Q', title="Percentage of Patients", axis=alt.Axis(titleX=titleX, format=".0%")),
        color=alt.Color("group:N", scale=alt.Scale(range=['#E9F1FA', '#BBD7EB', '#6DAED5', '#2D7CBA', '#083672']), title="Age Group")
    ).properties(
        width=width, height=height
    )
    # Age segments can be thin, so labels of segments at or below 10% are hidden.
    age_stacked = age_bars + _stacked_percentage_labels(ad, patient_group, min_share=0.10)

    """
    SEX GROUPS
    """
    sd = d[d.group.isin(SEX_GROUPS)]

    sex_bars = alt.Chart(
        sd
    ).mark_bar(
        size=size,
        stroke="black"
    ).encode(
        x=alt.X(
            'wave:N',
            title="Wave",
            axis=AXIS_SHOW,
            scale=alt.Scale(padding=padding)
        ),
        y=alt.Y(f'p_{patient_group}:Q', axis=alt.Axis(titleX=titleX, orient="left", format=".0%"), title="Percentage of Patients"),
        color=alt.Color("group:N", scale=alt.Scale(range=['#DC3912', '#3366CC']), title="Sex")
    ).properties(
        width=width, height=height
    )
    sex_stacked = sex_bars + _stacked_percentage_labels(sd, patient_group)

    """
    ////////////////////////////////////
    ASSEMBLE
    ////////////////////////////////////
    """
    final_chart = (
        alt.hconcat(age_stacked, sex_stacked, spacing=40).resolve_scale(color='independent').properties(
            title={
                "text": f"Demographics of {'All' if patient_group == 'all' else 'Ever Severe'} Patients by Wave",
                "dx": 80,
                "subtitle": get_visualization_subtitle(data_release='2021-01-25', with_num_sites=False),
                "subtitleColor": "gray",
            }
        )
    )

    final_chart = apply_theme(
        final_chart,
        axis_y_title_font_size=16,
        title_anchor='start',
        legend_orient='right'
    )
    return final_chart

In [None]:
# Demographic stats from the non-by-country file feed the stacked-bar charts:
# one figure per patient cohort.
df2 = pd.read_csv(
    join("..", "data", "1.1.resurgence", "demographics", "demographic_stats.csv")
)

for patient_group in ('all', 'severe'):
    (
        FUNC_DEMOGRAPHICS_BY_WAVE_WITH_STACKED_BARS(df2, patient_group=patient_group)
        .display()
    )

In [None]:
# Country-level demographic stats with confidence intervals (per the file name).
df = pd.read_csv(
    join("..", "data", "1.1.resurgence", "demographics", "demographic_stats_bycountry_withCI.csv")
)
