In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_1

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
import itertools

from constants_1_1 import SITE_FILE_TYPES
from utils_1_1 import (
    read_loinc_df,
    get_site_file_paths,
    get_site_file_info,
    get_site_ids,
    get_siteid_country_map,
    get_siteid_color_maps,
    get_country_color_map,
    read_full_demographics_df,
    get_visualization_subtitle,
    apply_theme,
)
from web import for_website

# Altair caps embedded datasets at 5,000 rows by default; lift that limit.
# The trailing ";" suppresses the notebook echo of the call's return value.
alt.data_transformers.disable_max_rows();

In [None]:
# Data release date passed to get_visualization_subtitle() for every chart subtitle below.
DATA_RELEASE: str = "2020-09-30"

In [None]:
# Pairwise comparisons to load, each named "<comparison>_vs_<reference>".
# Every name selects one ../data/Figure_1_dem_plot_meta_metafor_<name>.csv file.
# All age bands are compared against the youngest band (00to25).
GROUPS = [f"{band}_vs_00to25" for band in ("26to49", "50to69", "70to79", "80plus")] + [
    "black_vs_white",
    "female_vs_male",
    "other_vs_white",
]

In [None]:
# Load the meta-analysis output of every comparison into one frame, tagging each
# row with the raw comparison name ("resname") it came from.
# DataFrame.append was removed in pandas 2.0 (and copied the whole frame on every
# iteration), so collect the per-group frames and concatenate them once.
_group_frames = [
    pd.read_csv(f"../data/Figure_1_dem_plot_meta_metafor_{group}.csv").assign(resname=group)
    for group in GROUPS
]
# pd.concat rejects an empty list; keep the original "empty frame" result in that case.
df = pd.concat(_group_frames, ignore_index=True) if _group_frames else pd.DataFrame()

In [None]:
# Keep only the pooled rows (siteId "pool_<country>"): individual site results are
# not plotted, only the per-country pooled estimates. na=False treats a missing
# siteId as "not pooled" instead of raising.
df["is_pool"] = df["siteId"].str.startswith("pool_", na=False)

# Number of distinct individual (non-pooled) sites, counted before those rows are
# dropped; it feeds the chart subtitles. dropna=False matches len(unique()).
NUM_SITES = df.loc[~df["is_pool"], "siteid"].nunique(dropna=False)

df = df.loc[df["is_pool"]].reset_index(drop=True)

# Display country name from "pool_<country>": capitalize() already lowercases the
# rest of the string; "USA" is the one name kept fully upper-case.
df["country"] = df["siteId"].str[len("pool_"):].str.capitalize().replace({"Usa": "USA"})
df = df.drop(columns=['siteId', 'siteid', 'is_pool'])
df.head()

In [None]:
# Human-readable labels for the group tokens embedded in each raw comparison name.
GROUP_NAME_MAP = {
    '00to25': '0 to 25',
    '26to49': '26 to 49',
    '50to69': '50 to 69',
    '70to79': '70 to 79',
    '80plus': '80 plus',
    'black': 'Black',
    'white': 'White',
    'female': 'Female',
    'male': 'Male',
    'other': 'Other'
}
# Demographic dimension each raw comparison belongs to (used to pick which rows a chart shows).
GROUP_TYPE_MAP = {
    '26to49_vs_00to25': 'Age Group',
    '50to69_vs_00to25': 'Age Group',
    '70to79_vs_00to25': 'Age Group',
    '80plus_vs_00to25': 'Age Group',
    'black_vs_white': 'Race',
    'female_vs_male': 'Sex',
    'other_vs_white': 'Race',
}

# Raw names have the form "<group1>_vs_<group2>"; unknown tokens raise KeyError.
df["group_type"] = [GROUP_TYPE_MAP[raw_name] for raw_name in df["resname"]]
df["group1"] = [GROUP_NAME_MAP[raw_name.split("_vs_")[0]] for raw_name in df["resname"]]
df["group2"] = [GROUP_NAME_MAP[raw_name.split("_vs_")[1]] for raw_name in df["resname"]]

# Replace the raw name with a display label.
# NOTE(review): the label puts group2 first ("<group2> vs. <group1>"), the reverse of
# the raw "<group1>_vs_<group2>" order; confirm this is the intended reading.
df["resname"] = [
    f"{second} vs. {first}"
    for first, second in zip(df["group1"], df["group2"])
]
df.head()

In [None]:
# Clip both CI bounds into the charts' fixed x-axis domain [-1, 1]. Clipping each
# bound on only one side let an interval lying entirely beyond +/-1 (e.g. ci_95L > 1)
# draw its bar outside the plot area.
# Note: the tooltips read these same columns, so they show the clipped values.
df["ci_95L"] = df["ci_95L"].clip(lower=-1.0, upper=1.0)
df["ci_95U"] = df["ci_95U"].clip(lower=-1.0, upper=1.0)

In [None]:
# Per-country colors from the shared helper, plus a neutral gray for the "All" pool.
country_color_map = get_country_color_map()
country_color_map.update(All="gray")

In [None]:
# Keep colors only for countries present in the pooled data. The charts build their
# color-scale domain from this dict's keys, so absent countries would otherwise
# appear in that domain. The set of present countries is computed once, not per key.
_countries_present = set(df["country"].unique())
country_color_map = {
    country_name: color
    for country_name, color in country_color_map.items()
    if country_name in _countries_present
}
country_color_map

In [None]:
def create_plot_risk_comparison_for_group_type_by_country(group_type, width=300, height=400, tick_size=20):
    """Plot pooled 95% CIs per country for one comparison type.

    Draws one horizontal CI bar per country (``ci_95L`` to ``ci_95U``) with a white
    tick at the pooled ``mean``, next to a column of p values. Two dropdowns choose
    the comparison group (``group1``) and the reference group (``group2``); only
    rows matching both are drawn. The chart is also passed to ``for_website``
    under the "Demographics" section.

    Args:
        group_type: Value of ``df["group_type"]`` to plot ("Age Group", "Race" or
            "Sex" in this notebook).
        width: Width in pixels of the CI panel.
        height: Height in pixels of the CI panel and of the p-value column.
        tick_size: ``size`` passed to both the CI bars and the mean ticks.

    Returns:
        The chart returned by ``apply_theme``, with both dropdown selections added.

    Raises:
        ValueError: If ``df`` has no rows for ``group_type``.
    """
    group_type_df = df.loc[df["group_type"] == group_type]
    if group_type_df.empty:
        # Without this, the dropdown defaults below fail with a bare IndexError.
        raise ValueError(f"No pooled rows found for group_type={group_type!r}")

    group1_values = group_type_df["group1"].unique().tolist()
    group2_values = group_type_df["group2"].unique().tolist()

    # Each dropdown starts on the first value seen in the data. The selection names
    # (with spaces) appear to double as the dropdown labels.
    group1_dropdown = alt.binding_select(options=group1_values)
    group1_selection = alt.selection_single(fields=["group1"], bind=group1_dropdown, name="Comparison Group", init={"group1": group1_values[0]})

    group2_dropdown = alt.binding_select(options=group2_values)
    group2_selection = alt.selection_single(fields=["group2"], bind=group2_dropdown, name="Reference Group", init={"group2": group2_values[0]})

    # Country colors come from the notebook-level country_color_map.
    country_color_scale = alt.Scale(domain=list(country_color_map.keys()), range=list(country_color_map.values()))

    tooltips = [
        alt.Tooltip("group_type", title="Comparison Type"),
        alt.Tooltip("group2", title="Reference Group"),
        alt.Tooltip("group1", title="Comparison Group"),
        alt.Tooltip("country", title="Country"),
        alt.Tooltip("ci_95L", title="95% CI lower bound"),
        alt.Tooltip("ci_95U", title="95% CI upper bound"),
        alt.Tooltip("mean", title="Mean"),
        alt.Tooltip("p", title="p value"),
    ]

    # Base chart: only rows matching both dropdown selections are drawn.
    chart = alt.Chart(group_type_df).transform_filter(
        group1_selection
    ).transform_filter(
        group2_selection
    )

    ci_bars = chart.mark_bar(size=tick_size).encode(
        y=alt.Y("country:N", axis=alt.Axis(title='Country')),
        color=alt.Color("country:N", legend=alt.Legend(title="Country", orient="right"), scale=country_color_scale),
        x=alt.X("ci_95L:Q", axis=alt.Axis(title='Pooled mean (CI)'), scale=alt.Scale(domain=[-1.0, 1.0])),
        x2=alt.X2("ci_95U:Q"),
        tooltip=tooltips,
    ).properties(width=width, height=height)

    # White tick marking the pooled mean inside each CI bar.
    mean_ticks = chart.mark_tick(size=tick_size, thickness=2).encode(
        y=alt.Y("country:N"),
        opacity=alt.value(1),
        color=alt.value('white'),
        x=alt.X('mean:Q', axis=alt.Axis(title='Pooled mean (CI)')),
        tooltip=tooltips,
    )

    # p-value column. It gets the same height as the CI panel so its rows line up
    # with the bars on the shared y scale.
    p_val_texts = chart.mark_text(size=14).encode(
        y=alt.Y("country:N", axis=None, title=None),
        color=alt.value('black'),
        text=alt.Text('p:Q', format='.2f'),
        tooltip=tooltips,
    ).properties(
        width=60,
        height=height,
        title={
            'text': 'p value',
            'orient': 'top',
            'fontSize': 12
        }
    )

    chart = alt.hconcat(alt.layer(ci_bars, mean_ticks), p_val_texts).resolve_scale(y='shared').properties(title={
        "text": [f"{group_type} Risk Comparison by Country"], 
        "dx": 50,
        "subtitle": get_visualization_subtitle(data_release=DATA_RELEASE, num_sites=NUM_SITES),
        "subtitleColor": "gray",
        "anchor": "middle",
    })

    chart = apply_theme(chart).add_selection(
        group1_selection
    ).add_selection(
        group2_selection
    )

    for_website(chart, "Demographics", f"plot_risk_comparison_for_{group_type}_by_country", df=group_type_df)

    return chart

In [None]:
# Age-group comparisons by country, with comparison/reference group dropdowns.
create_plot_risk_comparison_for_group_type_by_country(group_type="Age Group", width=300, height=200)

In [None]:
# Race comparisons by country, with comparison/reference group dropdowns.
create_plot_risk_comparison_for_group_type_by_country(group_type="Race", width=300, height=200)

In [None]:
# Sex comparisons by country, with comparison/reference group dropdowns.
create_plot_risk_comparison_for_group_type_by_country(group_type="Sex", width=300, height=200)

In [None]:
def create_plot_risk_comparison_for_group_type_with_country_dropdown(group_type, width=300, height=400, tick_size=20):
    """Plot pooled 95% CIs for every comparison of one type, with a country dropdown.

    Draws one horizontal CI bar per comparison label (``resname``) with a white tick
    at the pooled ``mean``, next to a column of p values. A dropdown selects which
    country's pooled results are shown. It starts on "All" when that pool exists,
    and otherwise on the first country in the data. The chart is also passed to
    ``for_website`` under the "Demographics" section.

    Args:
        group_type: Value of ``df["group_type"]`` to plot ("Age Group", "Race" or
            "Sex" in this notebook).
        width: Width in pixels of the CI panel.
        height: Height in pixels of the CI panel and of the p-value column.
        tick_size: ``size`` passed to both the CI bars and the mean ticks.

    Returns:
        The chart returned by ``apply_theme``, with the country selection added.

    Raises:
        ValueError: If ``df`` has no rows for ``group_type``.
    """
    group_type_df = df.loc[df["group_type"] == group_type]
    if group_type_df.empty:
        raise ValueError(f"No pooled rows found for group_type={group_type!r}")

    country_values = group_type_df["country"].unique().tolist()
    # Hard-coding init "All" left the chart empty whenever no "All" pool existed;
    # fall back to the first country so the initial selection always matches rows.
    initial_country = "All" if "All" in country_values else country_values[0]

    country_dropdown = alt.binding_select(options=country_values)
    country_selection = alt.selection_single(fields=["country"], bind=country_dropdown, name="Country", init={"country": initial_country})

    # Country colors come from the notebook-level country_color_map.
    country_color_scale = alt.Scale(domain=list(country_color_map.keys()), range=list(country_color_map.values()))

    tooltips = [
        alt.Tooltip("group_type", title="Comparison Type"),
        alt.Tooltip("resname", title="Comparison"),
        alt.Tooltip("country", title="Country"),
        alt.Tooltip("ci_95L", title="95% CI lower bound"),
        alt.Tooltip("ci_95U", title="95% CI upper bound"),
        alt.Tooltip("mean", title="Mean"),
        alt.Tooltip("p", title="p value"),
    ]

    # Base chart: only rows of the selected country are drawn.
    chart = alt.Chart(group_type_df).transform_filter(
        country_selection
    )

    ci_bars = chart.mark_bar(size=tick_size).encode(
        y=alt.Y("resname:N", axis=alt.Axis(title='Comparison')),
        color=alt.Color("country:N", legend=alt.Legend(title="Country", orient="right"), scale=country_color_scale),
        x=alt.X("ci_95L:Q", axis=alt.Axis(title='Pooled mean (CI)'), scale=alt.Scale(domain=[-1.0, 1.0])),
        x2=alt.X2("ci_95U:Q"),
        tooltip=tooltips,
    ).properties(width=width, height=height)

    # White tick marking the pooled mean inside each CI bar.
    mean_ticks = chart.mark_tick(size=tick_size, thickness=2).encode(
        y=alt.Y("resname:N"),
        opacity=alt.value(1),
        color=alt.value('white'),
        x=alt.X('mean:Q', axis=alt.Axis(title='Pooled mean (CI)')),
        tooltip=tooltips,
    )

    # p-value column. It gets the same height as the CI panel so its rows line up
    # with the bars on the shared y scale.
    p_val_texts = chart.mark_text(size=14).encode(
        y=alt.Y("resname:N", axis=None, title=None),
        color=alt.value('black'),
        text=alt.Text('p:Q', format='.2f'),
        tooltip=tooltips,
    ).properties(
        width=60,
        height=height,
        title={
            'text': 'p value',
            'orient': 'top',
            'fontSize': 12
        }
    )

    chart = alt.hconcat(alt.layer(ci_bars, mean_ticks), p_val_texts).resolve_scale(y='shared').properties(title={
        "text": [f"{group_type} Risk Comparison by Country"], 
        "dx": 50,
        "subtitle": get_visualization_subtitle(data_release=DATA_RELEASE, num_sites=NUM_SITES),
        "subtitleColor": "gray",
        "anchor": "middle",
    })

    chart = apply_theme(chart).add_selection(
        country_selection
    )

    for_website(chart, "Demographics", f"plot_risk_comparison_for_{group_type}_with_country_dropdown", df=group_type_df)

    return chart

In [None]:
# Age-group comparisons for one country at a time (country dropdown).
create_plot_risk_comparison_for_group_type_with_country_dropdown(group_type="Age Group", width=300, height=200)

In [None]:
# Race comparisons for one country at a time (country dropdown); fewer rows, so shorter.
create_plot_risk_comparison_for_group_type_with_country_dropdown(group_type="Race", width=300, height=100)

In [None]:
# Sex comparison for one country at a time (country dropdown); a single row, so shortest.
create_plot_risk_comparison_for_group_type_with_country_dropdown(group_type="Sex", width=300, height=50)

In [None]:
def create_plot_risk_comparison_for_group_type_with_country_rowfacet(group_type, width=300, height=400, tick_size=20):
    """Plot pooled 95% CIs for every comparison of one type, one facet row per country.

    The left column shows, for each country, one horizontal CI bar per comparison
    label (``resname``) with a white tick at the pooled ``mean``. The right column
    repeats the same row facets with the p values. Rows are ordered by first
    appearance of each country in the data. The chart is also passed to
    ``for_website`` under the "Demographics" section.

    Args:
        group_type: Value of ``df["group_type"]`` to plot ("Age Group", "Race" or
            "Sex" in this notebook).
        width: Width in pixels of each CI panel.
        height: Height in pixels of each facet row (CI panel and p-value column).
        tick_size: ``size`` passed to both the CI bars and the mean ticks.

    Returns:
        The chart returned by ``apply_theme``.

    Raises:
        ValueError: If ``df`` has no rows for ``group_type``.
    """
    group_type_df = df.loc[df["group_type"] == group_type]
    if group_type_df.empty:
        raise ValueError(f"No pooled rows found for group_type={group_type!r}")

    # Facet row order: countries in order of first appearance.
    # (The unused dropdown/selection objects of the original were never added to
    # the chart and have been removed.)
    country_values = group_type_df["country"].unique().tolist()

    # Country colors come from the notebook-level country_color_map.
    country_color_scale = alt.Scale(domain=list(country_color_map.keys()), range=list(country_color_map.values()))

    tooltips = [
        alt.Tooltip("group_type", title="Comparison Type"),
        alt.Tooltip("resname", title="Comparison"),
        alt.Tooltip("country", title="Country"),
        alt.Tooltip("ci_95L", title="95% CI lower bound"),
        alt.Tooltip("ci_95U", title="95% CI upper bound"),
        alt.Tooltip("mean", title="Mean"),
        alt.Tooltip("p", title="p value"),
    ]

    chart = alt.Chart(group_type_df)

    ci_bars = chart.mark_bar(size=tick_size).encode(
        y=alt.Y("resname:N", axis=alt.Axis(title=None)),
        color=alt.Color("country:N", legend=alt.Legend(title="Country", orient="right"), scale=country_color_scale),
        x=alt.X("ci_95L:Q", axis=alt.Axis(title='Pooled mean (CI)'), scale=alt.Scale(domain=[-1.0, 1.0])),
        x2=alt.X2("ci_95U:Q"),
        tooltip=tooltips,
    ).properties(width=width, height=height)

    # White tick marking the pooled mean inside each CI bar.
    mean_ticks = chart.mark_tick(size=tick_size, thickness=2).encode(
        y=alt.Y("resname:N"),
        opacity=alt.value(1),
        color=alt.value('white'),
        x=alt.X('mean:Q', axis=alt.Axis(title='Pooled mean (CI)')),
        tooltip=tooltips,
    )

    p_val_texts = chart.mark_text(size=14).encode(
        y=alt.Y("resname:N", axis=None, title=None),
        color=alt.value('black'),
        text=alt.Text('p:Q', format='.2f'),
        tooltip=tooltips,
    ).properties(
        width=60,
        height=height,
    )

    chart_left = alt.layer(ci_bars, mean_ticks).facet(
        row=alt.Row(
            "country:N",
            sort=country_values,
            header=alt.Header(title=None)
        ),
    )
    # Same row facets as the left column, but with the row labels hidden; the
    # header title serves as the "p value" column heading.
    chart_right = p_val_texts.facet(
        row=alt.Row(
            "country:N",
            sort=country_values,
            header=alt.Header(title='p value', labels=False, titlePadding=5, titleOrient='top')
        ),
    )

    chart = alt.hconcat(chart_left, chart_right).resolve_scale(y='shared').properties(title={
        "text": [f"{group_type} Risk Comparison by Country"], 
        "dx": 50,
        "subtitle": get_visualization_subtitle(data_release=DATA_RELEASE, num_sites=NUM_SITES),
        "subtitleColor": "gray",
        "anchor": "middle",
    })

    chart = apply_theme(chart)

    for_website(chart, "Demographics", f"plot_risk_comparison_for_{group_type}_with_country_rowfacet", df=group_type_df)

    return chart

In [None]:
# Age-group comparisons, one facet row per country.
create_plot_risk_comparison_for_group_type_with_country_rowfacet(group_type="Age Group", width=300, height=100, tick_size=15)

In [None]:
# Race comparisons, one facet row per country.
create_plot_risk_comparison_for_group_type_with_country_rowfacet(group_type="Race", width=300, height=60, tick_size=15)

In [None]:
# Sex comparison, one facet row per country.
create_plot_risk_comparison_for_group_type_with_country_rowfacet(group_type="Sex", width=300, height=30, tick_size=15)