In [None]:
import plotly.express as px
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt 

from urllib.request import urlopen
import json

import geopandas
import ast

import sys
sys.path.append('../src/')

from utils.ckm_plotting import plot_rt, gen_dropdown
from utils.state_abbreviations import state_abbr_map, state_abbr_map_r
from generate_rt import create_case_pop_df

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from plotly.offline import iplot

In [None]:
# Relative paths to the CSV inputs read by the cells below:
# the county-level table (loaded into rt_county_df) ...
RT_COUNTY_DATA = '../../DATA/rt_county/rt_county.csv'
# ... and the state-level table (loaded into rt_state_df).
RT_STATE_DATA = '../../DATA/rt_state/rt_state.csv'

In [None]:
# Load the county-level table and render the FIPS codes as 5-character,
# zero-padded strings (the `05d` format requires integer values).
rt_county_df = pd.read_csv(RT_COUNTY_DATA)
rt_county_df['countyFIPS'] = rt_county_df['countyFIPS'].map('{:05d}'.format)
rt_county_df.tail()

In [None]:
# The first two characters of the county FIPS string are used as the state FIPS.
rt_county_df['stateFIPS'] = rt_county_df['countyFIPS'].str[:2]

# Lookup tables: state FIPS -> value of the `state` column, and the reverse.
fips_state_pairs = rt_county_df[['stateFIPS', 'state']].drop_duplicates()
state_fips_map = fips_state_pairs.set_index('stateFIPS')['state'].to_dict()
state_fips_map_r = {state_name: fips for fips, state_name in state_fips_map.items()}

In [None]:
# Load the state-level table, then derive `state` from `region` and
# `stateFIPS` from `state` using the lookup tables built above.
rt_state_df = pd.read_csv(RT_STATE_DATA)
rt_state_df = rt_state_df.assign(state=rt_state_df['region'].map(state_abbr_map))
rt_state_df = rt_state_df.assign(stateFIPS=rt_state_df['state'].map(state_fips_map_r))
rt_state_df.tail()

In [None]:
# County boundaries, fetched over the network on every run.
COUNTY_GEOJSON_URL = 'https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json'

county_geojson_df = geopandas.read_file(COUNTY_GEOJSON_URL)

# Label each county with its state; STATE codes that are missing from
# state_fips_map end up as NaN in STATE_NAME.
county_geojson_df['STATE_NAME'] = county_geojson_df['STATE'].map(state_fips_map)

# NOTE: the `id` column is deliberately not renamed to `countyFIPS`; the plots
# below match features through featureidkey='properties.id'.
county_geojson_df.tail()

In [None]:
# State boundaries come from a local shapefile. An online outline file was tried
# previously: https://eric.clst.org/assets/wiki/uploads/Stuff/gz_2010_us_outline_500k.json
STATE_SHAPEFILE = '../../DATA/geojsons/state_geojsons/cb_2018_us_state_20m.shp'

state_geojson_df = geopandas.read_file(STATE_SHAPEFILE)
state_geojson_df.head()

# RT Animation (Country Wide)

In [None]:
SAMPLE_FREQUENCY = 7  # in days

# Sort the distinct dates and keep every SAMPLE_FREQUENCY-th one, starting with
# the earliest; the animations below use one frame per retained date.
_sorted_dates = np.sort(rt_county_df.date.unique().tolist())
DATE_SUBSET = list(_sorted_dates[::SAMPLE_FREQUENCY])

In [None]:
disp_col = 'median'

# One animation frame per sampled date.
state_frames_df = rt_state_df[rt_state_df.date.isin(DATE_SUBSET)]

# Colour range is clipped to the 5th-95th percentile of the full state table so
# a few extreme values do not wash out the scale.
state_color_range = (
    rt_state_df[disp_col].quantile(q=0.05),
    rt_state_df[disp_col].quantile(q=0.95),
)

fig = px.choropleth(
    data_frame=state_frames_df,
    # to_json() returns JSON text, so it is parsed with json.loads;
    # ast.literal_eval only accepts Python literals and raises on JSON
    # null/true/false.
    geojson=json.loads(state_geojson_df.to_json()),
    locations='stateFIPS',
    color=disp_col,
    color_continuous_scale=[[0., 'rgb(0,255,0)'], [1.0, 'rgb(50,50,50)']],
    animation_frame='date',
    range_color=state_color_range,
    hover_name='region',
    labels={disp_col: f'R_t ({disp_col})'},
    # Rows are matched to shapes by comparing `stateFIPS` with each feature's GEOID.
    featureidkey="properties.GEOID",
    scope='usa'
)

fig.show()

In [None]:
disp_col = 'mean'

# One animation frame per sampled date.
county_frames_df = rt_county_df[rt_county_df.date.isin(DATE_SUBSET)]

# Colour range is clipped to the 5th-95th percentile of the COUNTY values being
# plotted. The previous version took the quantiles from rt_state_df, so the
# scale did not reflect the spread of the data on this map.
county_color_range = (
    rt_county_df[disp_col].quantile(q=0.05),
    rt_county_df[disp_col].quantile(q=0.95),
)

# Keep only counties whose state FIPS was found in state_fips_map.
mapped_counties = county_geojson_df.dropna(subset=['STATE_NAME'])

fig = px.choropleth(
    data_frame=county_frames_df,
    # to_json() returns JSON text, so it is parsed with json.loads;
    # ast.literal_eval raises on JSON null/true/false.
    geojson=json.loads(mapped_counties.to_json()),
    locations='countyFIPS',
    color=disp_col,
    color_continuous_scale=[[0., 'rgb(0,255,0)'], [1.0, 'rgb(50,50,50)']],
    animation_frame='date',
    range_color=county_color_range,
    hover_name='region',
    labels={disp_col: f'R_t ({disp_col})'},
    # Rows are matched to shapes by comparing `countyFIPS` with each feature's id.
    featureidkey="properties.id",
    scope='usa'
)

# fitbounds is intentionally left off here; scope='usa' already frames the map.
fig.update_layout(margin={"r":0,"t":0,"l":0,"b":0})

fig.show()

In [None]:
# Diagnostic: for every STATE_NAME value, check whether that state's county
# GeoJSON text can be parsed by ast.literal_eval (the parser this notebook uses).
# literal_eval accepts Python literals only, so JSON-only tokens such as `null`
# make it raise.
for state in county_geojson_df['STATE_NAME'].unique().tolist():
    try:
        ast.literal_eval(county_geojson_df[county_geojson_df['STATE_NAME']==state].to_json())
    except Exception as err:
        # Report the failure together with its cause instead of swallowing it.
        print (f'{state} : FAIL ({type(err).__name__}: {err})')
    else:
        # Only report success once the parse has actually completed.
        print (f'{state} : SUCCESS')


In [None]:
# County rows that fall on one of the sampled animation dates.
rt_county_df.loc[rt_county_df['date'].isin(DATE_SUBSET)]

# RT Animation (by State)

In [None]:
def animate_state(data_df=rt_county_df,
                  geojson_df=county_geojson_df,
                  date_list=DATE_SUBSET,
                  state='NY',
                  data_filter_col='state',
                  geo_filter_col='STATE_NAME',
                  disp_col='mean',
                  date_col='date',
                  featureidkey='properties.id',
                  locations='countyFIPS'
                 ):
    """Build an animated choropleth of `disp_col` for a single state.

    Parameters
    ----------
    data_df : DataFrame
        Table with at least `date_col`, `data_filter_col`, `disp_col`,
        `locations` and a `region` column (used as the hover label).
    geojson_df : GeoDataFrame
        Shapes to draw; the whole frame is serialised and handed to plotly.
    date_list : list-like
        Dates to keep; each retained date becomes one animation frame.
    state : str
        Value of `data_filter_col` selecting the rows to plot.
    data_filter_col : str
        Column of `data_df` compared against `state`.
    geo_filter_col : str
        Currently unused (the GeoJSON is not subset by state); kept so
        existing calls remain valid.
    disp_col : str
        Column mapped to colour.
    date_col : str
        Column holding the frame dates.
    featureidkey : str
        Path inside each GeoJSON feature that is matched against `locations`.
    locations : str
        Column of `data_df` holding the feature ids.

    Returns
    -------
    plotly Figure with the map fitted to the plotted locations.
    """
    # Honour the caller's date_list (previously the global DATE_SUBSET was
    # used here, so the argument had no effect).
    dated_df = data_df[data_df[date_col].isin(date_list)]
    state_df = dated_df[dated_df[data_filter_col] == state]

    # The colour range comes from ALL rows on the selected dates (not only this
    # state's rows), clipped to the 5th-95th percentile.
    color_range = (dated_df[disp_col].quantile(0.05),
                   dated_df[disp_col].quantile(0.95))

    fig = px.choropleth(
        data_frame=state_df,
        # to_json() returns JSON text, so parse it as JSON; ast.literal_eval
        # raises on JSON null/true/false (e.g. features with missing properties).
        geojson=json.loads(geojson_df.to_json()),
        locations=locations,
        color=disp_col,
        color_continuous_scale=[[0., 'rgb(0,255,0)'], [1.0, 'rgb(0,0,0)']],
        animation_frame=date_col,
        range_color=color_range,
        hover_name='region',
        labels={disp_col:f'R_t ({disp_col})'},
        featureidkey=featureidkey,
    )

    # Zoom to the plotted locations and hide the base map.
    fig.update_geos(fitbounds="locations", visible=False)
    fig.update_layout(margin={"r":0,"t":0,"l":0,"b":0})

    return fig

In [None]:
# County-level animation for Texas; counties without a STATE_NAME are dropped
# from the GeoJSON before it is passed in.
counties_with_state = county_geojson_df.dropna(subset=['STATE_NAME'])

fig = animate_state(
    data_df=rt_county_df,
    geojson_df=counties_with_state,
    date_list=DATE_SUBSET,
    state='TX',
    disp_col='mean',
)
fig.show()

In [None]:
# State-level animation for Texas: rows are matched to the state shapes through
# the `stateFIPS` column and each feature's GEOID property.
fig = animate_state(
    data_df=rt_state_df,
    geojson_df=state_geojson_df,
    date_list=DATE_SUBSET,
    state='TX',
    disp_col='mean',
    locations='stateFIPS',
    featureidkey='properties.GEOID',
)
fig.show()

# RT Dropdowns

In [None]:
## SNAPSHOT AT A DATE
# Only the most recent date in the county table is shown.
rt_county_df_subset = rt_county_df[rt_county_df.date==rt_county_df.date.max()]

# Sorted state labels that have at least one county shape.
# (For quick iteration this can be replaced by a short list, e.g. ["NY","NJ","PA","FL", "AL"].)
STATE_LIST = np.sort(county_geojson_df.STATE_NAME.dropna().unique().tolist())

buttons_list = []
fig_dict = {}

for i, state in enumerate(STATE_LIST):

    geojson_subset = county_geojson_df[county_geojson_df.STATE_NAME==state]
    state_df = rt_county_df_subset[rt_county_df_subset['state']==state]

    fig_dict[state] = px.choropleth(
        data_frame=state_df,
        # to_json() returns JSON text, so parse it as JSON rather than with
        # ast.literal_eval, which raises on JSON null/true/false.
        geojson=json.loads(geojson_subset.to_json()),
        locations='countyFIPS',
        color='mean',
        color_continuous_scale=[[0., 'rgb(0,255,0)'], [1.0, 'rgb(0,0,0)']],
        range_color=(0.2, 2),
        hover_name='region',
        labels={'mean':'R_t (mean)'},
        featureidkey="properties.id",
    )

    # Visibility mask for this state's button: only position i is shown.
    # NOTE(review): this assumes gen_dropdown yields exactly one trace per state,
    # in STATE_LIST order - confirm against utils.ckm_plotting.
    visible_list = [j == i for j in range(len(STATE_LIST))]
    buttons_list.append(
        dict(label = state,
             method = 'update',
             args = [{'visible': visible_list},
                     {'title': f'{state} R(t) values'}
                    ]
            )
    )


fig_dropdown = gen_dropdown(
    figure_list=[sub_fig.to_plotly_json() for sub_fig in fig_dict.values()],
    button_labels=list(fig_dict),
)

# Only the traces are taken from gen_dropdown's figure; the layout (including
# the dropdown menu) is rebuilt below.
data = fig_dropdown.to_plotly_json()['data']

updatemenus = [
    dict(active=0,
         buttons=buttons_list,
    )
]

layout = dict(
        title='Tracking R(t) values',
        showlegend=False,
        updatemenus=updatemenus,
        geo = dict(
            fitbounds='locations',
            visible=False
              ),
        # NOTE(review): this blue->red scale differs from the green->black scale
        # passed to px.choropleth above; since the layout is rebuilt here, this
        # is presumably the one that takes effect - confirm which is intended.
        coloraxis = dict(
            cmin=0.2,
            cmax=2,
            colorscale=[[0.0, 'rgb(0,0,255)'], [1.0, 'rgb(255,0,0)']],
    )
)

fig_dropdown = dict( data=data, layout=layout )
iplot(fig_dropdown)

# RT Trends

In [None]:
subset_df = rt_county_df[rt_county_df.state=='AL']

# One panel per region, laid out on a grid of at most 4 columns.
region_groups = list(subset_df.groupby('region'))
if not region_groups:
    # Fail with a clear message instead of a division by zero below.
    raise ValueError("No regions found for the selected state; nothing to plot.")

ncols = min(4, len(region_groups))
nrows = int(np.ceil(len(region_groups) / ncols))

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(14, nrows*3),
    sharey='row',
    # Always return a 2-D array of Axes, so `.flat` also works for a 1x1 grid.
    squeeze=False)

for ax, (county_state, result) in zip(axes.flat, region_groups):
    plot_rt(county_state, result, ax)

# Hide grid cells left over when the region count is not a multiple of ncols.
for unused_ax in axes.flat[len(region_groups):]:
    unused_ax.set_visible(False)

fig.tight_layout()
fig.set_facecolor('w')

# RT Snapshot

In [None]:
def rt_live_error_plot(rt_df=rt_county_df, 
                       filter_column='state', 
                       filter_field='NY',
                       height=800,
                       width=800
                      ):
    """Scatter the latest mean per region with asymmetric error bars.

    Parameters
    ----------
    rt_df : DataFrame
        Table with `date`, `region`, `mean`, `lower_90` and `upper_90`
        columns plus `filter_column`.
    filter_column : str
        Column used to select rows.
    filter_field : object
        Value of `filter_column` to keep.
    height, width : int
        Figure size passed to plotly.

    Returns
    -------
    plotly Figure with one point per row on the most recent date of the
    selection, sorted by `mean` and coloured by whether `mean` >= 1.
    """
    subset_df = rt_df[rt_df[filter_column]==filter_field]

    # Keep only the most recent date of the selection before deriving columns.
    latest_df = subset_df[subset_df.date == subset_df.date.max()]

    # assign() returns a new frame, so nothing is written onto a slice of rt_df
    # (the previous `.loc[:, col] = ...` on a filtered frame triggered
    # SettingWithCopyWarning).
    latest_df = latest_df.assign(
        error_y_plus=latest_df['upper_90'] - latest_df['mean'],
        error_y_minus=latest_df['mean'] - latest_df['lower_90'],
        color=(latest_df['mean'] >= 1).map({True: 'High Risk', False: 'Low/Moderate Risk'}),
    ).sort_values('mean')

    fig = px.scatter(
        data_frame=latest_df,
        y='region',
        x='mean',
        # The columns keep their historical `error_y_*` names although they are
        # drawn as horizontal (x) error bars.
        error_x='error_y_plus',
        error_x_minus='error_y_minus',
        color='color',
        hover_name='region',
        width=width,
        height=height
    )
    
    return fig

In [None]:
# Latest-date snapshot for rows whose `state` is 'NY'; the figure is the cell output.
ny_snapshot_fig = rt_live_error_plot(filter_field='NY')
ny_snapshot_fig

# RT Dashboard

In [None]:
# All rows for the county whose FIPS code is '36047'.
test_df = rt_county_df.loc[rt_county_df['countyFIPS'].eq('36047')]
test_df

In [None]:
# The previous version selected 'Albany County NEW YORK' from `test_df2`, which
# is never defined in this notebook (NameError on a fresh kernel). Use the
# rt_county_df selection the author had left commented out instead.
# NOTE(review): if another county is wanted, change the FIPS code below.
test_df = rt_county_df[rt_county_df['countyFIPS']=='36047']
if test_df.empty:
    raise ValueError("No rows found for the selected county; cannot build the dashboard.")

test_df = test_df.sort_values('date')

# First and most recent mean values along the date-sorted series.
mean_values = test_df['mean'].tolist()
first_mean = mean_values[0]
latest_mean = mean_values[-1]

# Row 1: three indicator tiles. Rows 2-3: one time-series panel spanning all columns.
fig = make_subplots(
    rows=3,
    cols=3,
    specs=[[{'type':'indicator'}, {'type':'indicator'}, {'type':'indicator'}],
           [{'colspan':3, 'rowspan':2,'type':'xy'}, None,None],
            [None, None,None]],
    subplot_titles=(
        r"$\text{R}_t$",
        r"$\text{Change}$",
        r"$\text{Risk}$", 
        (test_df.region.tolist()[0])
    )
)

# Shaded band: the lower bound is drawn first, then the upper bound fills down to it.
fig.add_trace(go.Scatter(
    x=test_df['date'].tolist(), 
    y=test_df['lower_90'].tolist(),
    fill=None,
    mode='lines',
    fillcolor='rgba(0,255,0,0.1)',
    name='lower_90',
    line_color='rgba(0,255,0,0.1)',
    ),
    row=2,
    col=1
)
fig.add_trace(go.Scatter(
    x=test_df['date'].tolist(), 
    y=test_df['upper_90'].tolist(),
    fill='tonexty', # fill area between this trace and the previous one
    mode='lines', 
    fillcolor='rgba(0,255,0,0.1)',
    name='upper_90',
    line_color='rgba(0,255,0,0.1)',
    ),
    row=2,
    col=1
)

# Mean series drawn on top of the band.
fig.add_trace(go.Scatter(
    x=test_df['date'].tolist(), 
    y=mean_values,
    mode='markers+lines', 
    fillcolor='rgba(0,255,0,0.1)',
    name='mean',
    line_color='gray'
    ),
    row=2,
    col=1
)

# Tile 1: latest mean, green below 1 and red otherwise.
fig.add_trace(go.Indicator(
    mode = "number",
    value = latest_mean,
    number={
        'valueformat':'0.3f',
        'font':{
            'color': '#3D9970' if (latest_mean <1) else '#FF4136',
            'size' : 22
        }
    }
    ),
    row=1,
    col=1
)

# Tile 2: relative change of the latest mean against the first value in the series.
fig.add_trace(go.Indicator(
    mode = "delta",
    delta= {
        'reference': first_mean, 
        'relative':True,
        'increasing': {
            'color':'#FF4136'
        },
        'decreasing': {
            'color': '#3D9970'
        },
        'font':{
            'size' : 22
        }
    },
    value = latest_mean,
    ),
    row=1,
    col=2
)

# Risk label thresholds: above 1 -> CRITICAL, above 0.6 -> MEDIUM, otherwise LOW.
if latest_mean>1:
    val = 'CRITICAL'
    col = 'red'
elif latest_mean>0.6:
    val = 'MEDIUM'
    col = 'orange'
else:
    val = 'LOW'
    col = 'green'

# Tile 3: latest mean prefixed with the risk label, coloured by risk level.
fig.add_trace(go.Indicator(
    mode = "number",
    delta= {
        'reference': first_mean, 
        'relative':True,
        'increasing': {
            'color':'#FF4136'
        },
        'decreasing': {
            'color': '#3D9970'
        }
    },
    value = latest_mean,
    number={
        'valueformat':'0.3f',
        'prefix' : r'$\textbf{'+val+'}$',
        'font':{
            'color': col,
            'size':30
        }
    }
    ),
    row=1,
    col=3
)

fig.update_layout(showlegend=False, title_text="", width=900, height=600)
fig.show()