In [52]:
import pandas
import os
from datetime import timedelta, datetime
import re
import contextlib
import json
from subprocess import check_output, check_call
from glob import glob, iglob
import multiprocessing as mp
import plotly.express as px
import sys
sys.path.insert(0, "/private/home/mattle/covid19_spread")
import metrics
from functools import partial
import numpy as np
from xvfbwrapper import Xvfb
import os

In [53]:
# Ablation label -> directory of the forecast sweep produced for that variant.
# Insertion order is preserved, so "baseline" stays first.
sweeps = dict(
    [
        ("baseline", "/checkpoint/mattle/covid19/forecasts/us/2021_02_02_12_44_23"),
        ("no_granger", "/checkpoint/mattle/covid19/forecasts/us/2021_02_04_12_10_29"),
        ("no_cross_correlation", "/checkpoint/mattle/covid19/forecasts/us/2021_02_04_12_12_23"),
        ("poisson", "/checkpoint/mattle/covid19/forecasts/us/2021_02_04_12_37_32"),
    ]
)

In [54]:
# Ground-truth case counts. The CSV is read with "region" as its index and then
# transposed, so the former column labels become the row index.
gt = pandas.read_csv(
    "/private/home/mattle/prod_covid19_spread/data/usa/data_cases.csv",
    index_col="region",
).T
# Parse the row labels into datetimes so they can be compared with forecast dates.
gt.index = pandas.to_datetime(gt.index)

In [55]:
def load_backfill(name):
    """Collect validation metrics for every run of one sweep.

    For each ``model_selection.json`` matched under the sweep directory
    ``sweeps[name]``, the job path registered under ``"best_mae"`` is looked up
    and its ``final_model_validation.csv`` forecast is scored against the
    module-level ground truth ``gt`` with ``metrics.compute_metrics``. Runs
    whose forecast file does not exist are skipped.

    Args:
        name: Key into the module-level ``sweeps`` mapping.

    Returns:
        A list of dicts, one per scored run, with keys ``"name"`` (the sweep
        name), ``"mets"`` (whatever ``metrics.compute_metrics`` returned) and
        ``"basedate"`` (one day before the earliest date in the forecast).

    Raises:
        KeyError: If ``name`` is not in ``sweeps`` or a selection file has no
            ``"best_mae"`` entry.
    """
    pth = sweeps[name]
    results = []
    # NOTE: without ``recursive=True`` the ``**`` component behaves like ``*``,
    # so only selection files exactly one directory below ``pth`` are matched.
    for selection_path in iglob(os.path.join(pth, "**/model_selection.json")):
        # Close the handle deterministically instead of leaking one per run.
        with open(selection_path) as fh:
            ms = json.load(fh)
        # Index the selection entries by criterion name -> job directory.
        ms = {x['name']: x['pth'] for x in ms}
        job = ms["best_mae"]
        fcst_file = os.path.join(job, "final_model_validation.csv")
        if not os.path.exists(fcst_file):
            # This run has no final validation forecast; nothing to score.
            continue
        fcst = pandas.read_csv(fcst_file, index_col="date", parse_dates=["date"])
        mets = metrics.compute_metrics(gt, fcst)
        results.append({
            "name": name,
            "mets": mets,
            # The forecast's first row is taken to be day 1 of the horizon, so
            # the base date is the day before it.
            "basedate": fcst.index.min() - timedelta(days=1),
        })
    return results
# Score every sweep once; keys mirror those of `sweeps`.
results = {sweep_name: load_backfill(sweep_name) for sweep_name in sweeps}

In [56]:
# Sanity check: number of scored runs per sweep.
{sweep_name: len(runs) for sweep_name, runs in results.items()}


{'baseline': 23, 'no_granger': 23, 'no_cross_correlation': 23, 'poisson': 23}

In [57]:
def concat(res):
    """Stack the metric tables of several runs into one long-format frame.

    Each element of ``res`` is a mapping with a ``"mets"`` DataFrame (its index
    is reset and a ``"Measure"`` column is used as the id variable) and a
    ``"basedate"`` value. Every table is melted so that its remaining columns
    become a ``"target_date"`` column, tagged with its run's ``"basedate"``,
    and the pieces are concatenated in order.

    Args:
        res: Iterable of run dicts with ``"mets"`` and ``"basedate"`` entries.

    Returns:
        A DataFrame with columns ``Measure``, ``target_date``, ``value`` and
        ``basedate``; each run keeps its own 0-based row index.
    """
    frames = [
        run["mets"]
        .reset_index()
        .melt(id_vars=["Measure"], var_name="target_date")
        .assign(basedate=run["basedate"])
        for run in res
    ]
    return pandas.concat(frames)


# Long-format baseline metrics, labelled so they stay identifiable once stacked
# with an ablation's metrics.
baseline = concat(results["baseline"]).assign(name="baseline")
others = {label: runs for label, runs in results.items() if label != 'baseline'}

# One figure per ablation: mean MAE by forecast horizon, ablation vs. baseline.
for name, res in others.items():
    res = concat(res).assign(name=name)
    # Compare only on base dates for which both sweeps produced a forecast.
    common_dates = list(set(baseline["basedate"]).intersection(res["basedate"]))
    df = pandas.concat(
        [
            frame.set_index("basedate").loc[common_dates].reset_index()
            for frame in (res, baseline)
        ]
    )
    # Horizon in days between the base date and the date being forecast.
    df["ndays"] = (df["target_date"] - df["basedate"]).dt.days
    grouped = df.groupby(["name", "ndays", "Measure"])["value"].mean().reset_index()
    title = f"{name} Ablation (MAE vs time horizon)"
    # NOTE(review): log scale is applied to this ablation only, presumably because
    # its errors span a much wider range than the baseline's; confirm.
    kwargs = {"log_y": True} if name == "no_cross_correlation" else {}

    px.line(
        grouped[grouped["Measure"] == "MAE"],
        x="ndays",
        y="value",
        color="name",
        title=title,
        **kwargs,
    ).show()
        
        

In [59]:
# Re-check that the per-sweep run counts are unchanged after plotting.
{sweep_name: len(runs) for sweep_name, runs in results.items()}

{'baseline': 23, 'no_granger': 23, 'no_cross_correlation': 23, 'poisson': 23}