In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_1

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
import datetime
import dateutil.parser
from os.path import join

from constants_1_1 import SITE_FILE_TYPES
from utils_1_1 import (
    get_site_file_paths,
    get_site_file_info,
    get_site_ids,
    get_visualization_subtitle,
    get_country_color_map,
)
from theme import apply_theme
from web import for_website

# Altair refuses to embed more than 5000 rows in a chart by default; lift that
# limit so the full results table can be plotted. The trailing `;` suppresses
# display of the call's return value in the cell output.
alt.data_transformers.disable_max_rows();

In [None]:
# Alternative input file (kept for reference, not loaded):
# df = pd.read_csv(join("..", "data", "res.score.toShare.csv"))
# Risk-score results plotted below, read relative to the current working
# directory. `risk()` reads the columns site, cat, month, metric and value.
df = pd.read_csv(join("..", "data", "res.score.train.early.valid.all.toShare.csv"))

df

In [None]:
def risk(_d, metric='pos', label_threshold=0.08):
    """Build a site-faceted Altair chart of risk-score categories by month.

    Parameters
    ----------
    _d : pandas.DataFrame
        Long-format scores. The columns read here are ``site``, ``cat``
        (codes ``'L'``/``'M'``/``'H'``), ``month``, ``metric`` and ``value``.
        The frame is copied and never mutated.
    metric : str, default 'pos'
        Value of the ``metric`` column to plot. ``'pos'`` draws stacked bars
        (share of patients per category, y fixed to [0, 1]) with in-bar
        percentage labels; any other value (e.g. ``'ppv'``) draws one line
        with points per category.
    label_threshold : float, default 0.08
        In-bar percentage labels are made fully transparent for segments whose
        ``value`` is not greater than this, so thin segments stay readable.
        Only used when ``metric == 'pos'``.

    Returns
    -------
    altair.FacetChart
        One column per site; the site ``'combine'`` is relabelled ``'All Sites'``.

    Raises
    ------
    KeyError
        If ``cat`` contains a code other than ``'L'``, ``'M'`` or ``'H'``.
    """
    risk_labels = {'L': 'Low Risk', 'M': 'Medium Risk', 'H': 'High Risk'}
    category_order = ['Low Risk', 'Medium Risk', 'High Risk']
    is_pos = metric == 'pos'

    # --- Data preprocessing (on a copy so the caller's frame is untouched) ---
    d = _d.copy()
    d.loc[d['site'] == 'combine', 'site'] = 'All Sites'
    d['cat'] = d['cat'].map(lambda code: risk_labels[code])
    # Explicit stacking key shared by the bar and text layers. Previously the bar
    # layer derived it from `datum.variable` (a field this frame does not have)
    # and the text layer never defined it, so segment/label order was undefined.
    d['order'] = d['cat'].map(category_order.index)
    d['visibility'] = d['value'] > label_threshold

    # --- Plot ---
    y_title = '% of Patients in Each Category' if is_pos else '% of Event in Each Category'
    colors = ['#7BADD1', '#427BB5', '#14366E'] if is_pos else ['#A8DED1', '#3FA86F', '#005A24']
    y_scale = alt.Scale(domain=[0, 1]) if is_pos else alt.Scale()
    width = 300
    bar_size = 50

    base = alt.Chart(
        d
    ).transform_filter(
        {'field': 'metric', 'oneOf': [metric]}
    ).encode(
        x=alt.X("month:N", title='Month', scale=alt.Scale(domain=['Mar-Apr', 'May-Jun', 'Jul-Aug', 'Sep-Oct', 'Nov-Dec'])),
        y=alt.Y("value:Q", title=y_title, axis=alt.Axis(format='.0%'), scale=y_scale),
        color=alt.Color("cat:N", title='Category', scale=alt.Scale(domain=category_order, range=colors)),
        order="order:O",
    ).properties(
        width=width
    )

    if is_pos:
        # Percentage labels stacked with the same `order` key as the bars.
        # Both layers use the same `d` object so the layer can be faceted.
        text = alt.Chart(
            d
        ).transform_filter(
            {'field': 'metric', 'oneOf': [metric]}
        ).mark_text(size=16, dx=0, dy=5, color='white', baseline='top', fontWeight=500).encode(
            x=alt.X('month:N'),
            y=alt.Y('value:Q', stack='zero'),
            detail='cat:N',
            text=alt.Text('value:Q', format='.0%'),
            order="order:O",
            opacity=alt.Opacity('visibility:N', scale=alt.Scale(domain=[True, False], range=[1, 0])),
        )
        chart = base.mark_bar(size=bar_size, stroke='black') + text
    else:
        chart = base.mark_line(size=3, point=True, opacity=0.8)

    chart = chart.facet(
        column=alt.Column('site:N', header=alt.Header(title=None)),
    )

    # --- Title ---
    title_text = "Distribution of Risk Scores" if is_pos else "Event Rate of Risk Scores"
    return chart.properties(
        title={
            "text": [title_text],
            "dx": 80,
            "subtitle": [
                get_visualization_subtitle(data_release='2021-01-31', with_num_sites=False)
            ],
            "subtitleColor": "gray",
        }
    )

In [None]:
# Patient-share bars on top, event-rate lines below; each panel keeps its own
# colour legend because the two palettes differ.
pos, ppv = (risk(df, metric=m) for m in ('pos', 'ppv'))

res = apply_theme(
    alt.vconcat(pos, ppv, spacing=30).resolve_scale(color='independent'),
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right',
    header_label_font_size=16,
)

res