In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from dateutil import parser
from covidcaremap.data import (get_ihme_forecast, read_us_counties_gdf, read_us_states_gdf, 
                               external_data_path, read_state_case_info)
import numpy as np
import requests
import json
import plotly.express as px
import us
from functools import reduce
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from covidcaremap.chime import get_regional_predictions

In [None]:
def date(str):
    """Parse a date string with ``dateutil.parser`` and return only its date part.

    The parameter name shadows the builtin ``str`` inside this function; it is
    kept unchanged so keyword callers (``date(str=...)``) keep working.
    """
    parsed_datetime = parser.parse(str)
    return parsed_datetime.date()

In [None]:
# Inclusive evaluation window, kept both as YYYYMMDD strings and as parsed datetimes.
s_str = '20200330'  # also used as the case-data date for the CHIME run below
e_str = '20200412'
start_date, end_date = map(parser.parse, (s_str, e_str))

## Actual: reported hospitalization increases (`hospitalizedIncrease`)

In [None]:
gt = read_state_case_info(all_days=True)
# Hospitalization increases inside the evaluation window (bounds inclusive),
# sorted by state then date and renamed to 'actual' for the comparison.
gt_t = (
    gt.loc[gt['date'].between(start_date, end_date), ['date', 'state', 'hospitalizedIncrease']]
      .sort_values(['state', 'date'])
      .rename(columns={'hospitalizedIncrease': 'actual'})
)

## IHME: forecast hospital admissions (`admis_mean`)

In [None]:
ihme = get_ihme_forecast()
# Keep rows whose location is a state listed in `us.states.STATES`, then add
# that state's postal abbreviation so it can be joined with the other sources.
ihme = ihme[ihme['location_name'].isin([st.name for st in us.states.STATES])].copy()
ihme['state'] = ihme['location_name'].apply(lambda name: us.states.lookup(name).abbr)
# Restrict to the evaluation window (inclusive) and keep the `admis_mean`
# forecast column, renamed to 'ihme'.
ihme_t = (
    ihme.loc[ihme['date'].between(start_date, end_date), ['state', 'date', 'admis_mean']]
        .sort_values(['state', 'date'])
        .rename(columns={'admis_mean': 'ihme'})
)

## CHIME: modeled hospital admissions (`admits_hospitalized`)

In [None]:
# testing whether passing an arbitrary function to `calculate_infected` works
def example_calc_infected(p):
    return 20

In [None]:
states = read_us_states_gdf()
cases = read_state_case_info(s_str)
# Attach each state's case counts to its row in the states table; the inner
# join drops states missing from either side.
state_cases = pd.merge(states, cases, how='inner',
                       left_on='State', right_on='state')
# Run the CHIME model per state, starting from the first day of the window.
chime = get_regional_predictions(
    state_cases,
    region_id_column='State',
    population_column='Population',
    cases_column='positive',
    tested_column='tested',
    recovered_column='recovered',
    current_date=start_date.date(),
)

In [None]:
# CHIME admissions inside the evaluation window (inclusive), with columns
# renamed to the shared (date, state, <source>) layout used for the merge.
chime_t = (
    chime.loc[chime['date_total'].between(start_date, end_date),
              ['date_total', 'State', 'admits_hospitalized']]
         .sort_values(['State', 'date_total'])
         .rename(columns={'date_total': 'date',
                          'State': 'state',
                          'admits_hospitalized': 'chime'})
)

## Merge the actual, IHME and CHIME series on (state, date)

In [None]:
# Inner-join the three sources on (state, date): only state-days present in
# all of actual, IHME and CHIME are kept.
df = reduce(lambda left, right: pd.merge(left, right, how='inner', on=['state', 'date']),
            [gt_t, ihme_t, chime_t])
# Drop states whose reported values sum to 0 over the window (nothing to
# compare the models against). Summing only the 'actual' column avoids summing
# the datetime 'date' column, which GroupBy.sum rejects on pandas >= 2.0.
actual_totals = df.groupby('state')['actual'].sum()
nulls = actual_totals.index[actual_totals == 0].tolist()
df = df[~df['state'].isin(nulls)].copy()

In [None]:
# Long format: one row per (date, state, source) so plotly can colour lines by source.
x = df.melt(id_vars=['date', 'state'],
            value_vars=['actual', 'ihme', 'chime'],
            var_name='source',
            value_name='hospitalizations')

In [None]:
# One facet per state; each facet overlays the actual series with the two models.
fig = px.line(x, x='date', y='hospitalizations', color='source',
              facet_col='state',
              facet_col_wrap=6)
# Let each state's y-axis scale independently (counts can differ widely between
# states) while sharing the date x-axis. `matches` must name an axis of the same
# kind and type (y with y, x with x); cross-kind values link nothing useful.
fig.update_yaxes(matches=None, title='')
fig.update_xaxes(matches='x', title='')
# Facet titles arrive as "state=XX"; keep only the value.
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.layout.yaxis.update({'title': '# of new COVID-19 hospitalizations'})
# Derive the title's date range from the evaluation window instead of hard-coding it.
date_range = ' - '.join(f'{d.month}/{d.day}/{d.year}' for d in (start_date, end_date))
fig.update_layout(
    autosize=False,
    width=1700,
    height=1000,
    margin=dict(
        l=50,
        r=50,
        b=100,
        t=100,
        pad=4
    ),
    title=go.layout.Title(
        text=f"IHME and CHIME hospitalization models vs actual ({date_range})",
        font=go.layout.title.Font(size=20),
    )
)

In [None]:
# Save the current `fig` as a standalone HTML page. The generated HTML is not
# guaranteed to be ASCII-only, so write UTF-8 explicitly instead of relying on
# the platform's default encoding (which can fail on e.g. Windows code pages).
with open('CHIME_IHME_actual.html', 'w', encoding='utf-8') as f:
    f.write(fig.to_html())

In [None]:
# Small-multiples version of the comparison: one subplot per state.
# make_subplots assigns `subplot_titles` in row-major order, so the states are
# laid out row-major too; otherwise the titles would label the wrong subplots.
subplot_states = list(df['state'].unique())
n_cols = 4
n_rows = max(1, -(-len(subplot_states) // n_cols))  # ceil division: grow the grid instead of overflowing it
fig = make_subplots(rows=n_rows, cols=n_cols,
                    start_cell="top-left",
                    shared_xaxes=True,
                    subplot_titles=subplot_states,
                    x_title='Date',
                    y_title='New hospitalizations')
series_styles = [('blue', 'ihme'), ('green', 'chime'), ('purple', 'actual')]
for i, st in enumerate(subplot_states):
    state_rows = df[df['state'] == st]
    for color, source in series_styles:
        fig.add_trace(
            go.Scatter(
                x=state_rows['date'],
                y=state_rows[source],
                name=source,
                legendgroup=source,
                line={'color': color},
                showlegend=(i == 0),  # a single legend entry per source
                mode='lines'),
            row=i // n_cols + 1, col=i % n_cols + 1
        )

In [None]:
# Size and title the subplot grid; the title's date range is derived from the
# evaluation window instead of being hard-coded.
date_range = ' - '.join(f'{d.month}/{d.day}/{d.year}' for d in (start_date, end_date))
fig.update_layout(
    autosize=False,
    width=1700,
    height=1000,
    margin=dict(
        l=50,
        r=50,
        b=100,
        t=100,
        pad=4
    ),
    title=go.layout.Title(
        text=f"IHME and CHIME hospitalization models vs actual ({date_range})",
        font=go.layout.title.Font(size=20))
)