In [1]:
import sqlite3
import pandas
import click
import os
from datetime import timedelta, datetime
from epiweeks import Week
import re
import contextlib
import json
from subprocess import check_output, check_call
from glob import glob, iglob
import multiprocessing as mp
import plotly.express as px
import sys
sys.path.insert(0, "/private/home/mattle/covid19_spread")
import metrics
from functools import partial
from submit_forecast import submit_reichlab
import numpy as np

In [2]:
ctx = click.Context(submit_reichlab)
sweeps = {
    'FAIR-paper-backfill': '/checkpoint/mattle/covid19/forecasts/us/2021_02_02_12_44_23'
}

# def f(team, sweep):
#     f = open(os.devnull, 'w')
#     with contextlib.redirect_stdout(f):
#         if not os.path.exists(os.path.join(sweep, "model_selection.json")):
#             return
#         ms = json.load(open(os.path.join(sweep, "model_selection.json")))
#         ms = {x["name"]: x["pth"] for x in ms}
#         if not os.path.exists(os.path.join(ms["best_mae"], "final_model_validation.csv")):
#             return
#         ctx.invoke(submit_reichlab, pth=sweep, no_push=True, team=team, nweeks=4, no_pull=True)
        
# team = "FAIR-NRAR-smooth-7d-inc-cases"
# def format(team):
#     os.makedirs(f"/checkpoint/mattle/covid19/data/lematt1991_covid19-forecast-hub/data-processed/{team}", exist_ok=True)
#     directories = glob(f"{sweeps[team]}/sweep*")
#     with mp.Pool() as p:
#         p.map(partial(f, team), directories)

# for team, sweep in sweeps.items():
#     format(team)

In [3]:
data_pth = '/checkpoint/mattle/covid19/data/lematt1991_covid19-forecast-hub'
gt = pandas.read_csv(
    os.path.join(data_pth, 'data-truth/truth-Incident Cases.csv'), 
    dtype={'location': str}, 
    parse_dates=['date']
)
gt = gt[gt['location'].str.match('\d{5}')]

In [4]:
def mk_gt(gt, forecast_date):
    forecast_date = pandas.to_datetime(forecast_date)
    next_week = Week.fromdate(forecast_date).daydate(5) if forecast_date.weekday() in {0, 6} else  Week.fromdate(forecast_date).daydate(5) + timedelta(days=7)
    submission = []
    next_week = pandas.to_datetime(next_week)
    prev_date = forecast_date
    for i in range(1, 5):
        if next_week > gt['date'].max():
            break
        submission.append(
            gt[(gt['date'] >= prev_date) & (gt['date'] <= next_week)].groupby('location')['value'].sum().reset_index()
        )
        submission[-1]["target"] = f"{i} wk ahead inc case"
        submission[-1]["target_end_date"] = next_week
        next_week += timedelta(days=7)
    return pandas.concat(submission) if len(submission) > 0 else None

In [18]:
case_forecasts = []
data_pth = '/checkpoint/mattle/covid19/data/lematt1991_covid19-forecast-hub'
for file in iglob(os.path.join(data_pth, 'data-processed/*/*.csv')):
    x = check_output(f'cat {file} | grep "inc case" | head -n 1', shell=True).decode('utf-8')
    if x.strip() != '':
        date = pandas.to_datetime(re.search('(\d+-\d+-\d+)', os.path.basename(file)).group(0))
        case_forecasts.append({'file': file, 'date': date})
case_forecasts = pandas.DataFrame(case_forecasts)

In [22]:
# Code to compute dates for paper backfill
temp = case_forecasts.copy()
temp["model"] = temp["file"].apply(lambda x: x.split('/')[-2])
temp["model"] = temp["model"].apply(lambda x: "CU" if x.startswith("CU") else x)
models = [
    "Google_Harvard-CPF", 
    "CU", 
    "LANL-GrowthRate", 
    "Microsoft-DeepSTIA",
    "UVA-Ensemble",
    "LNQ-ens1",
    "CEID-Walk"
]
subset = temp.set_index("model").loc[models]

subset = subset.reset_index().drop_duplicates(["date", "model"])
subset["dummy"] = 1
pivot = subset.pivot(index="date", columns="model", values="dummy")
pivot = pivot[pivot.sum(axis=1) > 1]

K = 10
dates = []
for model in pivot.sum(0).index:
    x = pivot[~pivot[model].isnull()]
    # Already have this many dates included, fetch K-current more
    current = (~x.reindex(dates)[model].isnull()).sum()
    not_taken = x[~x.index.isin(dates)]
    dates.extend(not_taken.iloc[np.linspace(0, len(not_taken) - 1, K - current)].index) 


display(pivot.loc[dates].sort_index())

cmd = "python cv.py backfill cv/us_prod.yml bar -remote -array-parallelism 250"
for date in dates:
    cmd += f" -dates {date.date()}"
print(cmd)


model,CU,Google_Harvard-CPF,LANL-GrowthRate,Microsoft-DeepSTIA
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2020-08-06,1.0,,1.0,
2020-08-13,1.0,,1.0,
2020-09-03,1.0,,1.0,
2020-09-27,1.0,1.0,,
2020-10-01,1.0,,1.0,
2020-10-11,1.0,1.0,,
2020-10-15,1.0,,1.0,
2020-10-25,1.0,1.0,,
2020-11-08,1.0,1.0,,
2020-11-15,1.0,1.0,,


python cv.py backfill cv/us_prod.yml bar -remote -array-parallelism 250 -dates 2020-08-06 -dates 2020-09-03 -dates 2020-10-01 -dates 2020-10-15 -dates 2020-10-25 -dates 2020-11-08 -dates 2020-11-19 -dates 2020-12-06 -dates 2020-12-17 -dates 2021-01-03 -dates 2020-09-27 -dates 2020-10-11 -dates 2020-11-15 -dates 2020-12-13 -dates 2021-01-11 -dates 2021-02-01 -dates 2020-08-13 -dates 2021-01-25 -dates 2020-12-07 -dates 2020-12-07 -dates 2020-12-14 -dates 2020-12-21 -dates 2020-12-21 -dates 2021-01-04 -dates 2021-01-18


In [7]:
def compute_mae(current_gt, pth):
    forecast = pandas.read_csv(pth, parse_dates=['forecast_date', 'target_end_date'], dtype={'location': str})
    forecast = forecast[forecast['location'].str.match('\d{4,5}') & (forecast['type'] == 'point')]
    if len(forecast) == 0:
        return None  # no county level forecasts
    diff = (
        forecast.set_index(['location', 'target_end_date'])['value'] -
        current_gt.set_index(['location', 'target_end_date'])['value'] 
    )
    mae = diff.abs().groupby(level=1).mean().sort_index().reset_index()
    mae = mae.merge(diff.groupby(level=1).mean().sort_index().reset_index(), on='target_end_date', suffixes=('_abs', '_raw'))
    mae['nweeks'] = range(1, len(mae) + 1)
    mae['model'] = re.search('\d{4}-\d{2}-\d{2}-(.*).csv', os.path.basename(pth)).group(1)
    mae['basedate'] = basedate
    return mae

# for pth in case_forecasts:
maes = []
for basedate, group in case_forecasts.groupby('date'):
    if not group['file'].str.contains('FAIR').any():
        continue
    current_gt = mk_gt(gt, basedate)
    if current_gt is None:
        continue
    for pth in group['file']:
        maes.append(compute_mae(current_gt, pth))
maes = pandas.concat(maes)

In [8]:
our_model = "FAIR-paper-backfill"
fair = maes[maes["model"] == our_model]

fair = fair[~fair["value_abs"].isnull()]
x = maes.set_index(["basedate", "target_end_date"]).loc[fair.set_index(["basedate", "target_end_date"]).index]
# To remove other FAIR models
x = x[(~x["model"].str.startswith("FAIR")) | (x["model"] == our_model)]
x = x.reset_index()
all_maes = x.copy()
x = x.loc[x.groupby(["model", "basedate"])["target_end_date"].idxmax()]
x["rank"] = x.groupby(["basedate", "target_end_date"])["value_abs"].rank(pct=True)


In [9]:
x["rank"] = x.groupby(["basedate", "target_end_date"])["value_abs"].apply(lambda x: (x.rank()-1) / len(x))
x["groupsize"] = x.groupby(["basedate", "target_end_date"])["value_abs"].transform('size')
grouped = x[x["groupsize"] > 1].groupby("model").agg({"rank": "mean"}).sort_values(by="rank")
grouped["nsubmissions"] = x[x["groupsize"] > 1].groupby("model")["rank"].apply(len)
grouped.sort_values(by="rank")

Unnamed: 0_level_0,rank,nsubmissions
model,Unnamed: 1_level_1,Unnamed: 2_level_1
FAIR-paper-backfill,0.051446,22
CU-scenario_mid,0.220663,14
UVA-Ensemble,0.254545,7
CU-select,0.257259,14
CU-scenario_high,0.295852,14
LNQ-ens1,0.302333,8
Microsoft-DeepSTIA,0.311869,6
CEID-Walk,0.352746,10
Google_Harvard-CPF,0.358126,11
COVIDhub-baseline,0.368831,7


In [26]:
def highlight_max(data, color='yellow'):
    attr = 'background-color: {}'.format(color)
    is_max = data == data.min().min()    
    return [attr if v else '' for v in is_max]

    
models = [
    our_model, 
    "LANL-GrowthRate", 
    "Microsoft-DeepSTIA", 
#     'CU-nochange', 
#     'CU-scenario_high',
#     'CU-scenario_low',
#     'CU-scenario_mid',
    'CU-select',
    'Google_Harvard-CPF',
    "UVA-Ensemble",
    "LNQ-ens1",
    "CEID-Walk",
]

# models = [
#     "Google_Harvard-CPF", 
#     "CU", 
#     "LANL-GrowthRate", 
#     "Microsoft-DeepSTIA",
#     "UVA-Ensemble",
#     "LNQ-ens1",
#     "CEID-Walk"
# ]

temp = all_maes.set_index("model").loc[models].reset_index()
pivot = temp.pivot(index=["nweeks", "basedate"], columns="model", values="value_abs")
# pivot.style.apply(highlight_max, axis=1)


def fmt(row):
    result = row.copy()
    minval = row.min()
    for key in result.index:
        val = "-" if np.isnan(row[key]) else f"{row[key]:.3f}"
        if row[key] == minval:
            val = r'{\cellcolor{blue!25} ' + val + '}'
#             val = f"\\cellcolor{{{val}}}"

        result[key] = val
    return result

        
pivot = pivot.apply(fmt, axis=1).rename(columns={"FAIR-paper-backfill": r"$\bAR$"})

pivot.columns = [c.split('-')[0].replace('_', r'\_') for c in pivot.columns]

for nweeks, group in pivot.groupby(level=0):

    print(f"""
    
\\begin{{table*}}[t]
\\small
\\caption{{COVID-19 Forecast Hub MAE ({nweeks} week horizon).\label{{tab:reichalb_eval_{nweeks}}} }}

\\centering
    """)
    
    group = group.droplevel(0)
    print(group.sort_index().to_latex(escape=False, column_format="c" * (len(group.columns)+1)))
    print("\\end{table*}")
    


    
\begin{table*}[t]
\small
\caption{COVID-19 Forecast Hub MAE (1 week horizon).\label{tab:reichalb_eval_1} }

\centering
    
\begin{tabular}{ccccccccc}
\toprule
{} &                           CEID &       CU &                         $\bAR$ & Google\_Harvard &                           LANL &                            LNQ &                     Microsoft &      UVA \\
basedate   &                                &          &                                &                 &                                &                                &                               &          \\
\midrule
2020-08-06 &                              - &   70.681 &   {\cellcolor{blue!25} 58.610} &               - &                         90.897 &                              - &                             - &        - \\
2020-08-13 &                              - &   61.309 &   {\cellcolor{blue!25} 53.553} &               - &                         74.575 &                              - &      