In [326]:
# standard imports
import h5py
import numpy as np
import os
import pandas as pd
import pickle
import yaml

from datetime import timedelta
from pathlib import Path

# bokeh
from bokeh.io import output_notebook, export_png, export_svgs
from bokeh.layouts import gridplot
from bokeh.models import HoverTool, Title, FactorRange, LinearAxis, Legend
from bokeh.palettes import Blues
from bokeh.plotting import figure, show, output_file, ColumnDataSource
from bokeh.transform import factor_cmap
from bokeh.sampledata.us_counties import data as counties
output_notebook()

# lib
import sys
sys.path.append('../')
from metrics import compute_metrics, _compute_metrics

In [339]:
def aggregate_states(df):
    """Sum region-level columns into one column per state.

    Column labels of ``df`` are split on ``', '`` and the last piece is used
    as the state key (a label without ``', '`` is used as-is). All columns
    sharing a key are added together.

    Returns a frame with the same rows as ``df``, one column per state, and
    the row index named ``'date'``.
    """
    per_region = df.transpose()
    state_labels = [region.split(', ')[-1] for region in per_region.index]
    # Attach the state key as a regular column and discard the region labels
    per_region = per_region.assign(state=state_labels).reset_index(drop=True)
    per_state = per_region.groupby('state').sum().transpose()
    per_state.index = per_state.index.set_names(['date'])
    return per_state

def load_backfill(job, basedir='/checkpoint/maxn/covid19/forecasts'):
    """Collect all forecasts and their training configs from a job directory.

    Searches ``<basedir>/<job>`` recursively for ``*_forecast.csv`` files.
    The first directory below the job directory must be named
    ``sweep_<date>``; ``<date>`` becomes the key for that forecast. The
    ``'train'`` section of the ``ar.yml`` file next to each forecast is
    collected as that date's config.

    Args:
        job: job directory, relative to ``basedir``.
        basedir: root directory that holds all forecast jobs.

    Returns:
        Tuple ``(forecasts, configs)``: ``forecasts`` maps date string to the
        forecast CSV path, ``configs`` is a DataFrame of the training configs
        indexed by ``'date'``. If no forecast is found, ``forecasts`` is empty
        and ``configs`` is an empty frame with a ``'date'`` index.

    Raises:
        AssertionError: if the directory directly below the job directory
            does not start with ``sweep_``.
    """
    jobdir = Path(basedir) / job
    forecasts = {}
    configs = []
    # Sorted so that the row order of `configs` does not depend on the filesystem
    for path in sorted(jobdir.rglob('*_forecast.csv')):
        # Take the sweep directory relative to the job dir instead of a fixed
        # position in the absolute path, so any basedir/job depth works
        date = path.relative_to(jobdir).parts[0]
        assert date.startswith('sweep_'), date
        date = date[len('sweep_'):]
        forecasts[date] = path
        # NOTE(review): FullLoader can build more than plain data types;
        # consider yaml.safe_load if these configs are plain mappings
        with open(path.with_name('ar.yml')) as fin:
            cfg = yaml.load(fin, Loader=yaml.FullLoader)['train']
        cfg['date'] = date
        configs.append(cfg)
    if not configs:
        # No 'date' column exists to index on; return an empty, well-formed frame
        return forecasts, pd.DataFrame(index=pd.Index([], name='date'))
    configs = pd.DataFrame(configs).set_index('date')
    return forecasts, configs
        
def load_predictions(path):
    """Read a predictions CSV, using its parsed ``date`` column as the index."""
    return pd.read_csv(path, parse_dates=['date'], index_col='date')

def plot_comparison(mets, other, date, metric):
    """Plot one metric for the AR model, the naive baseline and another model.

    Args:
        mets: data passed to ``ColumnDataSource``; must provide the fields
            ``'day'``, ``'AR'``, ``'Naive'`` and ``other``.
        other: name of the competing model; used both as the field name in
            ``mets`` and as the legend label.
        date: forecast date shown in the plot title.
        metric: metric name used as the y-axis label.

    Returns:
        Tuple ``(p, legend)`` of the bokeh figure and its legend.
    """
    source = ColumnDataSource(mets)
    p = figure(
        x_axis_type='datetime',
        plot_height=300,
        plot_width=350,
        title=f"Forecast Quality US {date}",
        tools="save,hover",
        x_axis_label='Day',
        y_axis_label=metric,
    )
    p.line(x='day', y='AR', source=source, line_width=3, color='#009ed7', legend_label='FAIR-AR')
    p.line(x='day', y='Naive', source=source, line_width=3, color='LightGray', legend_label='Naive')
    # Same color as the AR line, distinguished by the dotted dash style
    p.line(x='day', y=other, source=source, line_width=3, color='#009ed7', line_dash='dotted', legend_label=other)
    p.legend.location = 'top_left'
    p.output_backend = 'svg'
    # The name `legend` was never defined here; return the figure's own legend
    return p, p.legend

def plot_metric_for_dates(dates, other, ftemplate, metric='MAE', forecasts=None, ground_truth=None):
    """Build a grid of comparison plots, one per forecast date.

    For every date, the competing model's predictions and the AR forecast
    (aggregated per state) are scored against the ground truth with
    ``_compute_metrics`` and drawn with ``plot_comparison``.

    Args:
        dates: forecast dates; each must be a key of ``forecasts`` and yield
            an existing file when formatted into ``ftemplate``.
        other: name of the competing model (legend label / column name).
        ftemplate: format string with one ``{}`` placeholder for the date,
            giving the path of the competing model's predictions CSV.
        metric: row label to read from the metrics frames; the naive baseline
            is read from the row ``f'{metric}_NAIVE'``.
        forecasts: mapping from date to AR forecast CSV path. Defaults to the
            notebook-level ``fs``.
        ground_truth: observed values to score against. Defaults to the
            notebook-level ``df_states``.

    Returns:
        The bokeh ``gridplot`` layout holding one figure per date.
    """
    # Fall back to the notebook globals so existing calls keep working
    if forecasts is None:
        forecasts = fs
    if ground_truth is None:
        ground_truth = df_states

    ps = []
    for date in dates:
        df_other = load_predictions(ftemplate.format(date))
        df_ar = aggregate_states(load_predictions(forecasts[date]))

        # presumably frames indexed by metric name with one column per
        # forecast day — TODO confirm against metrics._compute_metrics
        met_other = _compute_metrics(ground_truth, df_other)
        met_ar = _compute_metrics(ground_truth, df_ar)
        source = pd.DataFrame({
            'Naive': met_other.loc[f'{metric}_NAIVE'],
            other: met_other.loc[metric],
            'AR': met_ar.loc[metric]
        })
        # plot_comparison reads the x values from a field named 'day'
        source.index.set_names('day', inplace=True)
        p, _ = plot_comparison(source, other, date, metric)
        ps.append(p)
    return gridplot(ps, ncols=2, plot_width=350)

### Progression in the US

In [369]:
# Ground truth: read the per-region deaths table and sum it up per state
df_states = aggregate_states(
    pd.read_csv('../data/usa/data_deaths.csv', index_col='region').T
)
df_states.index = pd.to_datetime(df_states.index)

# Deaths over time, one line per state; the first 60 rows are left out
# (presumably the early period with little data — TODO confirm)
source = ColumnDataSource(df_states.iloc[60:])
p = figure(
    title="Morbidity per US State",
    x_axis_type='datetime',
    x_axis_label='Day',
    y_axis_label='Deaths',
    plot_width=500,
    plot_height=350,
    tools="save",
)
for state in df_states.columns:
    p.line(x='date', y=state, source=source, color='#009ed7', line_width=1)
p.output_backend = 'svg'
show(p)

### Load Backfill and Configs

In [391]:
# Backfill run to analyse. Runs looked at previously:
#   us/2020_05_08_18_10, us/2020_05_09_06_25, us/2020_05_09_07_50,
#   us/2020_05_09_08_06, us/2020_05_09_12_58
job = "us/2020_05_09_13_08"
fs, cfgs = load_backfill(job)
# Show the training configs without the 'fdat' and 'fpop' columns
cfgs.drop(columns=['fdat', 'fpop'])

Unnamed: 0_level_0,decay,loss,lr,momentum,niters,t0,test_on,weight_decay,window
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-04-05,latent16_1,nb,0.001,0.9,20000,10,21,0.1,20
2020-04-08,latent8_1,nb,0.001,0.9,20000,10,21,0.1,20
2020-04-12,latent4_1,nb,0.001,0.9,20000,10,21,0.1,25
2020-04-15,latent8_1,nb,0.001,0.9,20000,10,21,0.1,15
2020-04-19,latent16_1,nb,0.001,0.9,20000,10,21,0.1,10
2020-04-22,latent4_1,nb,0.001,0.9,20000,10,21,0.1,20
2020-04-26,latent16_1,nb,0.001,0.9,20000,10,21,0.1,20
2020-04-29,latent8_1,nb,0.001,0.9,20000,10,21,0.1,10
2020-05-03,latent4_1,nb,0.001,0.9,20000,10,21,0.1,10


### Los Alamos

Compare our forecasts to published data by [Los Alamos National Laboratory](https://covid-19.bsvgateway.org/)

In [392]:
# Forecast dates compared against Los Alamos. Currently left out:
# 2020-04-05, 2020-04-08, 2020-04-12, 2020-04-15 and 2020-05-06.
dates_los_alamos = [
    '2020-04-19',
    '2020-04-22',
    '2020-04-26',
    '2020-04-29',
    '2020-05-03',
]
p = plot_metric_for_dates(
    dates_los_alamos,
    other='Los Alamos',
    ftemplate='../data/losalamos/predictions_{}.csv',
    metric='MAE',
)
show(p)
# Assign the return value so it is not echoed as the cell output
_ = export_png(p, filename='/tmp/losalamos.png')

### IHME

In [250]:
dates_ihme = [
    '2020-04-01'
]
# `plot_dates` is not defined in this notebook; the comparison helper is
# `plot_metric_for_dates`. Each date needs both an IHME predictions file at
# this path and an entry in the loaded backfill forecasts.
show(plot_metric_for_dates(dates_ihme, 'IHME', '../data/ihme/prediction-deaths-{}.csv'))

FileNotFoundError: [Errno 2] File ../data/ihme/prediction-deaths-2020-04-01.csv does not exist: '../data/ihme/prediction-deaths-2020-04-01.csv'