In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_0

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
from os.path import join
from web import for_website
import requests
import io

from constants_1_0 import COLUMNS
from utils_1_0 import (
    get_visualization_subtitle,
    apply_theme
)

# Data Preprocessing

## Demographics Data From Figshare (WIP)
Use the latest data from https://doi.org/10.6084/m9.figshare.12152973.v1

In [None]:
# Demographics-CombinedByCountry.csv, downloaded from Figshare.
# NOTE(review): the commented-out local loaders in the merge cell pair
# `df_dm_combined` with the combined table and `df_dm` with the by-country
# table, which is the opposite of the file labels in these loading cells.
# Both frames are concatenated later, so the merged rows are the same either
# way (only their order differs) — confirm which file this URL serves.
combined_by_country_url = "https://ndownloader.figshare.com/files/22346619"
df_dm_combined = pd.read_csv(combined_by_country_url)

df_dm_combined.head()

In [None]:
# Demographics-Combined.csv, downloaded from Figshare.
# NOTE(review): the commented-out local loaders in the merge cell pair `df_dm`
# with the by-country table and `df_dm_combined` with the combined table,
# which is the opposite of the file labels in these loading cells. Both frames
# are concatenated later, so the merged rows are the same either way (only
# their order differs) — confirm which file this URL serves.
combined_url = "https://ndownloader.figshare.com/files/22346622"
df_dm = pd.read_csv(combined_url)

df_dm.head()

In [None]:
# Site label used for the pooled rows (replaces the "Combined" site id during
# preprocessing and is the default of the country dropdown), plus its color.
ALL_COUNTRY = "All countries"
ALL_COUNTRY_COLOR = "#444444"

# Countries and their colors; the two lists are index-aligned.
COUNTRIES = ["France", "Germany", "Italy", "Singapore", "USA"]
COUNTRY_COLOR = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00"]

# Country name -> color lookup built from the aligned lists above.
COLOR_BY_COUNTRY = dict(zip(COUNTRIES, COUNTRY_COLOR))

## Merge Data

In [None]:
def preprocess_demo_df(df_dm):
    """Reshape a wide demographics table to long format and add percentage/CI columns.

    Args:
        df_dm: Wide demographics frame with a site id column, a sex column,
            one patient-count column per age group, and the masked/unmasked
            site-level columns dropped below. The input is not modified.

    Returns:
        A new long-format frame with one row per (site, sex, age group) and
        the columns: site id, sex, age group, number of patients,
        ``per_patients`` (percentage within the site/sex group), ``P``
        (the same value as a proportion), ``N`` (the site/sex total that is
        the denominator of ``P``), ``standard_error`` (``sqrt(P*(1-P)/N)``),
        ``95_CI_below`` / ``95_CI_above`` (normal-approximation bounds as
        proportions, lower bound clipped at 0) and their ``_x100`` variants
        in percent.

    Raises:
        KeyError: If an expected column is missing, or if a melted column
            name has no readable age-group label.
    """
    # Columns that are not needed for the charts. The masked-upper-bound age
    # columns are dropped here as well, instead of being carried through the
    # melt as id columns and removed afterwards.
    unused_columns = [
        COLUMNS.UNMASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.UNMASKED_SITES_AGE_0TO2,
        COLUMNS.UNMASKED_SITES_AGE_3TO5,
        COLUMNS.UNMASKED_SITES_AGE_6TO11,
        COLUMNS.UNMASKED_SITES_AGE_12TO17,
        COLUMNS.UNMASKED_SITES_AGE_18TO25,
        COLUMNS.UNMASKED_SITES_AGE_26TO49,
        COLUMNS.UNMASKED_SITES_AGE_50TO69,
        COLUMNS.UNMASKED_SITES_AGE_70TO79,
        COLUMNS.UNMASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.MASKED_SITES_AGE_0TO2,
        COLUMNS.MASKED_SITES_AGE_3TO5,
        COLUMNS.MASKED_SITES_AGE_6TO11,
        COLUMNS.MASKED_SITES_AGE_12TO17,
        COLUMNS.MASKED_SITES_AGE_18TO25,
        COLUMNS.MASKED_SITES_AGE_26TO49,
        COLUMNS.MASKED_SITES_AGE_50TO69,
        COLUMNS.MASKED_SITES_AGE_70TO79,
        COLUMNS.MASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_UPPER_BOUND_TOTAL_PATIENTS,
        COLUMNS.TOTAL_PATIENTS,
        COLUMNS.MASKED_UPPER_BOUND_AGE_0TO2,
        COLUMNS.MASKED_UPPER_BOUND_AGE_3TO5,
        COLUMNS.MASKED_UPPER_BOUND_AGE_6TO11,
        COLUMNS.MASKED_UPPER_BOUND_AGE_12TO17,
        COLUMNS.MASKED_UPPER_BOUND_AGE_18TO25,
        COLUMNS.MASKED_UPPER_BOUND_AGE_26TO49,
        COLUMNS.MASKED_UPPER_BOUND_AGE_50TO69,
        COLUMNS.MASKED_UPPER_BOUND_AGE_70TO79,
        COLUMNS.MASKED_UPPER_BOUND_AGE_80PLUS,
    ]

    # Wide to long: every remaining non-id column becomes one age-group row.
    long_df = df_dm.drop(columns=unused_columns).melt(
        id_vars=[COLUMNS.SITE_ID, COLUMNS.SEX],
        var_name=COLUMNS.AGE_GROUP,
        value_name=COLUMNS.NUM_PATIENTS,
    )

    long_df[COLUMNS.SEX] = long_df[COLUMNS.SEX].str.capitalize()

    # Percentage of patients within each (site, sex) group. The group total is
    # broadcast back onto every row, replacing the per-site/per-sex loop.
    group_total = long_df.groupby(
        [COLUMNS.SITE_ID, COLUMNS.SEX]
    )[COLUMNS.NUM_PATIENTS].transform("sum")
    long_df["per_patients"] = long_df[COLUMNS.NUM_PATIENTS] / group_total * 100

    # Use readable names
    long_df.loc[long_df[COLUMNS.SITE_ID] == "Combined", COLUMNS.SITE_ID] = ALL_COUNTRY
    readable_age_group = {
        COLUMNS.AGE_0TO2: "0 - 2",
        COLUMNS.AGE_3TO5: "3 - 5",
        COLUMNS.AGE_6TO11: "6 - 11",
        COLUMNS.AGE_12TO17: "12 - 17",
        COLUMNS.AGE_18TO25: "18 - 25",
        COLUMNS.AGE_26TO49: "26 - 49",
        COLUMNS.AGE_50TO69: "50 - 69",
        COLUMNS.AGE_70TO79: "70 - 79",
        COLUMNS.AGE_80PLUS: "80+"
    }
    # Direct dict lookup (not Series.map(dict)) so that an unexpected column
    # raises KeyError instead of silently turning into NaN.
    long_df[COLUMNS.AGE_GROUP] = long_df[COLUMNS.AGE_GROUP].apply(
        lambda column_name: readable_age_group[column_name]
    )

    # Standard error and 95% confidence interval of the proportion P.
    # N is the denominator of P (the site/sex total), as the notebook's CI note
    # specifies; using the per-row patient count instead overstates the width
    # of the interval.
    proportion = long_df["per_patients"] / 100
    # Empty groups have no defined interval: propagate NaN instead of dividing by 0.
    valid_total = group_total.where(group_total > 0)
    standard_error = np.sqrt(proportion * (1 - proportion) / valid_total)

    long_df["P"] = proportion
    long_df["N"] = group_total
    long_df["standard_error"] = standard_error
    long_df["95_CI_below"] = (proportion - 1.96 * standard_error).clip(lower=0)
    long_df["95_CI_above"] = proportion + 1.96 * standard_error
    long_df["95_CI_below_x100"] = long_df["95_CI_below"] * 100
    long_df["95_CI_above_x100"] = long_df["95_CI_above"] * 100

    return long_df

# NOTE: this cell rebinds `df_dm` and `df_dm_combined` to their preprocessed
# versions, so it cannot be re-run on its own: `preprocess_demo_df` drops
# columns that no longer exist after the first pass and would raise KeyError.
# Re-run the two loading cells before re-running this one.

# df_dm = read_combined_by_country_demographics_df() # For loading local data
df_dm = preprocess_demo_df(df_dm)

# df_dm_combined = read_combined_demographics_df() # For loading local data
df_dm_combined = preprocess_demo_df(df_dm_combined)

# Merge. `ignore_index=True` renumbers the rows so the stacked frame does not
# carry duplicate index labels from its two inputs.
df_dm = pd.concat([df_dm, df_dm_combined], ignore_index=True)

df_dm

For the CI:
- For percentages in demographics
    - $P = \texttt{percentage of patients in group in country}$ (for example age group 50-69 in France, or 50-69 AND Male in France)
    - $N = \texttt{total number of cases in country}$ (the denominator of $P$)
    - then the standard error is $\sqrt{P*(1-P)/N}$
    - so 95% CI would be $P \pm 1.96*\sqrt{P*(1-P)/N}$


# Visualizations

In [None]:
def demographics(is_percent=False, by_country=False):
    """Build the faceted demographics bar chart from the module-level ``df_dm``.

    One facet column is drawn per age group. Inside each facet the bars are
    split and colored either by sex or by country. Rows whose sex is "Other"
    are always excluded.

    Args:
        is_percent: If True, plot ``per_patients`` and overlay the 95% CI
            error bars; otherwise plot the raw number of patients without
            error bars (the CI columns are in percent units).
        by_country: If True, draw one bar per country (excluding the pooled
            ``ALL_COUNTRY`` rows) and bind a sex dropdown; otherwise draw one
            bar per sex (excluding "All") and bind a country dropdown.

    Returns:
        The faceted Altair chart with the legend selection and the relevant
        dropdown selection attached.
    """
    # Facet order. These are the readable labels that preprocess_demo_df
    # writes into the age-group column; the raw column names ("age_0to2", ...)
    # no longer occur in the data and therefore cannot order the facets.
    age_group_order = [
        "0 - 2", "3 - 5", "6 - 11", "12 - 17", "18 - 25",
        "26 - 49", "50 - 69", "70 - 79", "80+",
    ]

    # Selection components
    country_dropdown = alt.binding_select(options=[ALL_COUNTRY] + COUNTRIES)
    country_selection = alt.selection_single(
        fields=[COLUMNS.SITE_ID],
        bind=country_dropdown,
        name="Country",
        init={COLUMNS.SITE_ID: ALL_COUNTRY},
    )
    sex_dropdown = alt.binding_select(options=["All", "Male", "Female"])
    sex_selection = alt.selection_single(
        fields=[COLUMNS.SEX],
        bind=sex_dropdown,
        name="Sex",
        init={COLUMNS.SEX: "All"},
    )
    color_field = COLUMNS.SITE_ID if by_country else COLUMNS.SEX
    legend_selection = alt.selection_multi(fields=[color_field], bind="legend")

    # The dropdown filters on the dimension that is NOT shown as bars, and the
    # aggregate level of the bar dimension itself is removed.
    if by_country:
        dropdown_selection = sex_selection
        excluded_aggregate = alt.datum[COLUMNS.SITE_ID] != ALL_COUNTRY
    else:
        dropdown_selection = country_selection
        excluded_aggregate = alt.datum[COLUMNS.SEX] != "All"

    # Filter
    filtered_chart = (
        alt.Chart(df_dm)
        .transform_filter(alt.datum[COLUMNS.SEX] != "Other")
        .transform_filter(legend_selection)
        .transform_filter(dropdown_selection)
        .transform_filter(excluded_aggregate)
    )

    tooltip = [
        alt.Tooltip(COLUMNS.SITE_ID, title="Country"),
        alt.Tooltip(COLUMNS.SEX, title="Sex"),
        alt.Tooltip(COLUMNS.AGE_GROUP, title="Age group"),
        alt.Tooltip(COLUMNS.NUM_PATIENTS, title="Number of patients"),
    ]

    y_field = "per_patients" if is_percent else COLUMNS.NUM_PATIENTS
    if is_percent:
        # The CI bounds use the same one-decimal format as the percentage.
        tooltip += [
            alt.Tooltip("95_CI_below_x100", title="95% CI lower bound", format=".1f"),
            alt.Tooltip("per_patients", title="Percentage of patients (%)", format=".1f"),
            alt.Tooltip("95_CI_above_x100", title="95% CI upper bound", format=".1f"),
        ]

    # Render
    if by_country:
        color_scale = alt.Scale(domain=COUNTRIES, range=COUNTRY_COLOR)
    else:
        color_scale = alt.Scale(domain=["Male", "Female"], range=["#3366cc", "#dc3912"])
    y_title = "Percentage of patients (%)" if is_percent else "Number of patients"

    bar = filtered_chart.mark_bar().encode(
        x=alt.X(f"{color_field}:N", title=None, axis=None),
        y=alt.Y(f"{y_field}:Q", title=y_title, axis=alt.Axis(tickCount=5)),
        color=alt.Color(f"{color_field}:N", title=None, scale=color_scale),
        tooltip=tooltip
    ).properties(width=67, height=400)

    layers = [bar]
    if is_percent:
        # The CI columns are percentages, so the error bars are only
        # meaningful on the percentage axis.
        errorbar = filtered_chart.mark_errorbar(color="black").encode(
            x=alt.X(f"{color_field}:N", title=None, axis=None),
            y=alt.Y("95_CI_above_x100:Q", title=""),
            y2=alt.Y2("95_CI_below_x100:Q", title="")
        )
        layers.append(errorbar)

    result_vis = (
        alt.layer(*layers, data=df_dm)
            .facet(
                column=alt.Column(
                    f"{COLUMNS.AGE_GROUP}:O",
                    sort=age_group_order,
                    header=alt.Header(labelOrient="bottom", title="Age group", titleOrient="bottom")
                )
            )
            .add_selection(legend_selection)
            .add_selection(dropdown_selection)
    )

    return result_vis

## Demographics by sex

In [None]:
# Percentage chart with one bar per sex, themed first and titled afterwards.
themed_by_sex = apply_theme(
    demographics(is_percent=True),
    legend_stroke_color="lightgray",
    axis_title_font_size=18,
)
demo = themed_by_sex.properties(
    title={
        "text": "Demographics by Sex",
        "subtitle": get_visualization_subtitle(alt_num_sites=21),
        "subtitleColor": "gray",
        "anchor": "start",
        "dx": 60,
    }
)
demo.display()

for_website(demo, "Demographics", "Demographics by sex with confidence intervals")
# NOTE(review): SAVE_DIR is not defined in the visible cells — define it before uncommenting.
# save(demo, join(SAVE_DIR, f"demographics.png".lower()), scalefactor=2.0) # Uncomment this to save *.png files

## Demographics by country

In [None]:
# Percentage chart with one bar per country, themed first and titled afterwards.
themed_by_country = apply_theme(
    demographics(is_percent=True, by_country=True),
    legend_stroke_color="lightgray",
)
demo = themed_by_country.properties(
    title={
        "text": "Demographics by Country",
        "subtitle": get_visualization_subtitle(alt_num_sites=21),
        "subtitleColor": "gray",
        "anchor": "start",
        "dx": 60,
    }
)
demo.display()

for_website(demo, "Demographics", "Demographics by country with confidence intervals")
# NOTE(review): SAVE_DIR is not defined in the visible cells, and this file name is the
# same one used by the by-sex cell, so saving both would overwrite the first image.
# save(demo, join(SAVE_DIR, f"demographics.png".lower()), scalefactor=2.0) # Uncomment this to save *.png files