## About WIS
Good ressources:
-  Supplement of Cramer et al.
- code cramer et al. here https://github.com/reichlab/covid19-forecast-evals
- obviously https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008618#sec015
- [git clone https://github.com/adrian-lison/interval-scoring.git](https://github.com/adrian-lison/interval-scoring/tree/master) Adrian Lison's code for WIS
- https://epiforecasts.io/scoringutils/ Scoring utils package -- perhaps best to use ?

In [1]:
from interval_scoring import scoring
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

In [2]:
# A modification of Lison's code that splits the calibration in underprediction and overprediction
def weighted_interval_score_fast(
    observations, alphas, q_dict, weights=None, percent=False, check_consistency=True
):
    """
    Compute weighted interval scores for an array of observations and a number of different predicted intervals.
    
    This function implements the WIS-score (2). A dictionary with the respective (alpha/2)
    and (1-(alpha/2)) quantiles for all alpha levels given in `alphas` needs to be specified.
    
    This is a more efficient implementation using array operations instead of repeated calls of `interval_score`.
    
    Parameters
    ----------
    observations : array_like
        Ground truth observations.
    alphas : iterable
        Alpha levels for (1-alpha) intervals.
    q_dict : dict
        Dictionary with predicted quantiles for all instances in `observations`.
    weights : iterable, optional
        Corresponding weights for each interval. If `None`, `weights` is set to `alphas`, yielding the WIS^alpha-score.
    percent: bool, optional
        If `True`, score is scaled by absolute value of observations to yield a percentage error. Default is `False`.
    check_consistency: bool, optional
        If `True`, quantiles in `q_dict` are checked for consistency. Default is `True`.
        
    Returns
    -------
    total : array_like
        Total weighted interval scores.
    sharpness : array_like
        Sharpness component of weighted interval scores.
    calibration : array_like
        Calibration component of weighted interval scores.
        
    (2) Bracher, J., Ray, E. L., Gneiting, T., & Reich, N. G. (2020). Evaluating epidemic forecasts in an interval format. arXiv preprint arXiv:2005.12881.
    """
    if weights is None:
        weights = np.array(alphas)/2

    if not all(alphas[i] <= alphas[i + 1] for i in range(len(alphas) - 1)):
        raise ValueError("Alpha values must be sorted in ascending order.")

    reversed_weights = list(reversed(weights))

    lower_quantiles = [q_dict.get(alpha / 2) for alpha in alphas]
    upper_quantiles = [q_dict.get(1 - (alpha / 2)) for alpha in reversed(alphas)]
    if any(q is None for q in lower_quantiles) or any(
        q is None for q in upper_quantiles
    ):
        raise ValueError(
            f"Quantile dictionary does not include all necessary quantiles."
        )

    lower_quantiles = np.vstack(lower_quantiles)
    upper_quantiles = np.vstack(upper_quantiles)

    # Check for consistency
    if check_consistency and np.any(
        np.diff(np.vstack((lower_quantiles, upper_quantiles)), axis=0) < 0
    ):
        raise ValueError("Quantiles are not consistent.")

    lower_q_alphas = (2 / np.array(alphas)).reshape((-1, 1))
    upper_q_alphas = (2 / np.array(list(reversed(alphas)))).reshape((-1, 1))

    # compute score components for all intervals
    sharpnesses = np.flip(upper_quantiles, axis=0) - lower_quantiles

    lower_calibrations = (
        np.clip(lower_quantiles - observations, a_min=0, a_max=None) * lower_q_alphas
    )
    upper_calibrations = (
        np.clip(observations - upper_quantiles, a_min=0, a_max=None) * upper_q_alphas
    )
    calibrations = lower_calibrations + np.flip(upper_calibrations, axis=0)
    upper_calibrations = np.flip(upper_calibrations, axis=0)
    lower_calibrations = lower_calibrations

    # scale to percentage absolute error
    if percent:
        sharpnesses = sharpnesses / np.abs(observations)
        calibrations = calibrations / np.abs(observations)
        raise ValueError("Not Supported with the calibration split")

    totals = sharpnesses + calibrations

    # weigh scores
    weights = np.array(weights).reshape((-1, 1))

    sharpnesses_weighted = sharpnesses * weights
    calibrations_weighted = calibrations * weights
    upper_calibrations_weighted = upper_calibrations * weights
    lower_calibrations_weighted = lower_calibrations * weights
    totals_weighted = totals * weights

    # normalize and aggregate all interval scores
    weights_sum = np.sum(weights)

    sharpnesses_final = np.sum(sharpnesses_weighted, axis=0) / weights_sum
    calibrations_final = np.sum(calibrations_weighted, axis=0) / weights_sum
    upper_calibrations_final = np.sum(upper_calibrations_weighted, axis=0) / weights_sum
    lower_calibrations_final = np.sum(lower_calibrations_weighted, axis=0) / weights_sum
    totals_final = np.sum(totals_weighted, axis=0) / weights_sum

    return totals_final, sharpnesses_final, calibrations_final, lower_calibrations_final, upper_calibrations_final

In [3]:
def score_Nwk_forecasts(gt, forecasts, n=4) -> pd.DataFrame: 
    if isinstance(gt, str):
        gt = pd.read_csv(gt)
    if isinstance(forecasts, str):
        forecast = pd.read_csv(forecast)

    # take only the locations and dates that are forecasted
    gt = gt[gt["location"].isin(forecasts["location"])]
    gt = gt[gt["date"].isin(forecasts.target_end_date)]

    #first_forecast_date = datetime.datetime.strptime(forecasts["target_end_date"].sort_values()[0], "%Y-%m-%d").date()
    #target_dates = pd.date_range(first_forecast_date, first_forecast_date + datetime.timedelta(days=n*7), freq="W-SAT").date

    gt_piv = gt.pivot(index="date", columns="location", values="value").sort_index()


    target_dict = dict(zip(gt_piv.index, [f"{n} wk ahead" for n in range(1,n+1)]))
    
    # Alpha for WIS
    alphas=np.array(sorted(forecasts["quantile"].unique()))[:11]*2
    
    # gt_piv.index should be similar to target_dict.keys() apart from format

    all_targets = []
    
    for target in target_dict.keys():
        f = forecasts[forecasts["target_end_date"] == target]
        q_dict = {}
        for q in f["quantile"].unique():
            q_dict[float(q)] = f[f["quantile"]==q].pivot(index=["target_end_date"], columns="location", values="value").sort_index().to_numpy().ravel()
        wis_total, wis_sharpness, wis_calibration, underprediction, overprediction =   weighted_interval_score_fast(observations=gt_piv.loc[target].to_numpy(), 
                                                                                        alphas=alphas, 
                                                                                        q_dict=q_dict, 
                                                                                        weights=alphas/2)
        df = pd.DataFrame([wis_total, wis_sharpness, wis_calibration, underprediction, overprediction], index = ["wis_total", "wis_sharpness", "wis_calibration", "wis_underprediction", "wis_overprediction"], columns=gt_piv.columns)
        df["target"] = target_dict[target]
        df["target_end_date"] = target    
        all_targets.append(df)

    
    return pd.concat(all_targets).reset_index(names="wis_type").set_index(["target", "target_end_date"])

In [4]:

# "CADPH-FluCAT_Ensemble",
flusight_model_list = [ # %ls Flusight/Flusight-forecast-data/data-forecasts
 "LUcompUncertLab-humanjudgment",
"CEID-Walk", "LUcompUncertLab-stacked_ili",
"CEPH-Rtrend_fluH", #"LosAlamos_NAU-CModel_Flu",
"CMU-TimeSeries", #"METADATA.m",
"CU-ensemble", "MIGHTE-Nsemble",
"Flusight-baseline", "MOBS-GLEAM_FLUH",
"Flusight-ensemble", "NIH-Flu_ARIMA",
"GH-Flusight", "PSI-DICE",
"GT-FluFNP", #"README.m",
"IEM_Health-FluProject", "SGroup-RandomForest",
"ISU_NiemiLab-Flu", "SGroup-SIkJalpha",
"JHUAPL-Gecko", #"SigSci-CREG",
"JHU_IDD-CovidSP", #"SigSci-TSENS",
"LUcompUncertLab-HWAR2", "UGA_flucast-OKeeffe",
"LUcompUncertLab-KalmanFilter", "UGuelph-FluPLUG",
"LUcompUncertLab-TEVA", "UMass-gbq",
"LUcompUncertLab-VAR2", "UMass-trends_ensemble",
"LUcompUncertLab-VAR2K", "UNC_IDD-InfluPaint",
"LUcompUncertLab-VAR2K_plusCOVID", "UT_FluCast-Voltaire",
"LUcompUncertLab-VAR2_plusCOVID", "UVAFluX-Ensemble",
"LUcompUncertLab-ensemble_rclp", "Umass-ARIMA",
#"LUcompUncertLab-experthuman", "VTSanghani-ExogModel",
"LUcompUncertLab-hier_mech_model", "VTSanghani-Transformer"]

model_list = ["UNC_IDD-InfluPaint"]
model_list = flusight_model_list

In [26]:
len(fdates)

9

In [16]:
6*32/4

48.0

In [20]:
fdates = pd.date_range("2022-11-14", "2023-05-15", freq="3W-MON")

gt = pd.read_csv("Flusight/Flusight-forecast-data/data-truth/truth-Incident Hospitalizations.csv")


scores = {}
for model in model_list:
    
    skipped = []
    scores[model] = {}
    for date in fdates:
        date = date.date()
        try:
            forecasts = pd.read_csv(f"Flusight/Flusight-forecast-data/data-forecasts/{model}/{str(date)}-{model}.csv")
            #forecasts = pd.read_csv(f"Flusight/Flusight-forecast-data/data-forecasts/JHU_IDD-covidSP/{str(date)}-JHU_IDD-covidSP.csv")
            #forecasts = pd.read_csv(f"Flusight/Flusight-forecast-data/data-forecasts/MOBS-GLEAM_FLUH/{str(date)}-MOBS-GLEAM_FLUH.csv")
            forecasts = forecasts[forecasts["type"]=="quantile"]
            this_date=True
        except FileNotFoundError:
            skipped.append(date)
            this_date=False
        if this_date:
            wis_all = score_Nwk_forecasts(gt, forecasts)
            scores[model][date] = wis_all
        
    
    if len(skipped) < 5:
        print(f"Adding {model}")
        if len(skipped): print(f">> skipped {','.join([str(i) for i in skipped])}")
        scores[model] = pd.concat(scores[model], names=["forecast_date", "target", "target_end_date"])
    else:
        scores.pop(model)
        #print(f">> Too many skipped, removing")

Adding LUcompUncertLab-humanjudgment
>> skipped 2022-11-14,2023-03-20,2023-04-10,2023-05-01
Adding CEPH-Rtrend_fluH
Adding CMU-TimeSeries
Adding CU-ensemble
>> skipped 2022-12-26
Adding MIGHTE-Nsemble
Adding Flusight-baseline
Adding MOBS-GLEAM_FLUH
Adding Flusight-ensemble
Adding NIH-Flu_ARIMA
>> skipped 2023-02-27,2023-05-01
Adding PSI-DICE
Adding GT-FluFNP
Adding SGroup-RandomForest
>> skipped 2022-12-26
Adding ISU_NiemiLab-Flu
>> skipped 2023-04-10,2023-05-01
Adding JHU_IDD-CovidSP
>> skipped 2022-12-26
Adding UGA_flucast-OKeeffe
>> skipped 2022-12-26
Adding UMass-trends_ensemble
Adding UNC_IDD-InfluPaint
Adding UVAFluX-Ensemble
>> skipped 2023-05-01


In [21]:
all_scores = pd.concat(scores, names=["model", "forecast_date", "target", "target_end_date"])
wis_total = all_scores[all_scores["wis_type"] == "wis_total"].drop("wis_type", axis=1)
wis_total  = pd.melt(wis_total , var_name="location", value_name="wis_total",ignore_index=False).reset_index()

wis_underprediction = all_scores[all_scores["wis_type"] == "wis_underprediction"].drop("wis_type", axis=1)
wis_underprediction = pd.melt(wis_underprediction , var_name="location", value_name="wis_underprediction",ignore_index=False).reset_index()

wis_overprediction = all_scores[all_scores["wis_type"] == "wis_overprediction"].drop("wis_type", axis=1)
wis_overprediction = pd.melt(wis_overprediction , var_name="location", value_name="wis_overprediction",ignore_index=False).reset_index()

wis_sharpness = all_scores[all_scores["wis_type"] == "wis_sharpness"].drop("wis_type", axis=1)
wis_sharpness = pd.melt(wis_sharpness , var_name="location", value_name="wis_sharpness",ignore_index=False).reset_index()

wis_total 

Unnamed: 0,model,forecast_date,target,target_end_date,location,wis_total
0,LUcompUncertLab-humanjudgment,2022-12-05,1 wk ahead,2022-12-10,01,424.453978
1,LUcompUncertLab-humanjudgment,2022-12-05,2 wk ahead,2022-12-17,01,871.445129
2,LUcompUncertLab-humanjudgment,2022-12-05,3 wk ahead,2022-12-24,01,927.082232
3,LUcompUncertLab-humanjudgment,2022-12-05,4 wk ahead,2022-12-31,01,1069.695353
4,LUcompUncertLab-humanjudgment,2022-12-26,1 wk ahead,2022-12-31,01,673.532752
...,...,...,...,...,...,...
32179,UVAFluX-Ensemble,2023-03-20,4 wk ahead,2023-04-15,US,399.641256
32180,UVAFluX-Ensemble,2023-04-10,1 wk ahead,2023-04-15,US,410.594046
32181,UVAFluX-Ensemble,2023-04-10,2 wk ahead,2023-04-22,US,487.871228
32182,UVAFluX-Ensemble,2023-04-10,3 wk ahead,2023-04-29,US,609.518295


In [8]:
a = wis_total[["wis_total","model"]].groupby("model").sum().sort_values(by="wis_total")
a

Unnamed: 0_level_0,wis_total
model,Unnamed: 1_level_1
MOBS-GLEAM_FLUH,1541851.0
CMU-TimeSeries,1802241.0
PSI-DICE,1967445.0
NIH-Flu_ARIMA,1994083.0
Flusight-ensemble,2083896.0
MIGHTE-Nsemble,2150634.0
SGroup-RandomForest,2158199.0
GT-FluFNP,2273337.0
UMass-trends_ensemble,2352290.0
CEPH-Rtrend_fluH,2464946.0


In [23]:
a = wis_total[["wis_total","model"]].groupby("model").sum().sort_values(by="wis_total")
a

Unnamed: 0_level_0,wis_total
model,Unnamed: 1_level_1
NIH-Flu_ARIMA,534734.9
CU-ensemble,585450.0
MOBS-GLEAM_FLUH,657359.2
CMU-TimeSeries,662484.5
SGroup-RandomForest,685044.0
GT-FluFNP,783157.3
PSI-DICE,784792.9
UGA_flucast-OKeeffe,794882.0
LUcompUncertLab-humanjudgment,845172.4
Flusight-ensemble,855282.9


In [None]:
all_scores

In [None]:
tp = wis_total[wis_total["location"]=="US"].pivot(values="wis_total", index="target", columns="forecast_date")

f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(np.log(tp), annot=False, fmt="", linewidths=1, ax=ax)

In [None]:
tp = wis_total[wis_total["target"]=="1 wk ahead"].pivot(values="wis_total", index="location", columns="forecast_date")
tp = np.log(tp)
f, ax = plt.subplots(figsize=(9, 12))
sns.heatmap(np.log(tp), annot=False, fmt="", linewidths=1, ax=ax)

In [None]:
# tp = wis_total.pivot(values="wis_total", index="location", columns=["forecast_date","target"])
# tp = np.log(tp)
# f, ax = plt.subplots(figsize=(12, 12), dpi=300)
# sns.heatmap(np.log(tp), annot=False, fmt="", linewidths=1, ax=ax)

In [None]:
tp1 = wis_underprediction[wis_underprediction["location"]=="US"].pivot(values="wis_underprediction", index="target", columns="forecast_date")
tp2 = wis_overprediction[wis_overprediction["location"]=="US"].pivot(values="wis_overprediction", index="target", columns="forecast_date")

tp1 = wis_sharpness[wis_sharpness["location"]=="US"].pivot(values="wis_sharpness", index="target", columns="forecast_date")
tp2 = wis_total[wis_total["location"]=="US"].pivot(values="wis_total", index="target", columns="forecast_date")

f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(tp1/tp2, annot=False, fmt="", linewidths=1, ax=ax)
print((tp1/tp2).mean().mean())

In [None]:
tp1 = wis_underprediction[wis_underprediction["location"]=="US"].pivot(values="wis_underprediction", index="target", columns="forecast_date")
#tp2 = wis_overprediction[wis_overprediction["location"]=="US"].pivot(values="wis_overprediction", index="target", columns="forecast_date")

tp2 = wis_total[wis_total["location"]=="US"].pivot(values="wis_total", index="target", columns="forecast_date")

f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(tp1/tp2, annot=False, fmt="", linewidths=1, ax=ax)
print((tp1/tp2).mean().mean())

In [None]:
tp1 = wis_overprediction[wis_overprediction["location"]=="US"].pivot(values="wis_overprediction", index="target", columns="forecast_date")

tp2 = wis_total[wis_total["location"]=="US"].pivot(values="wis_total", index="target", columns="forecast_date")

f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(tp1/tp2, annot=False, fmt="", linewidths=1, ax=ax)
print((tp1/tp2).mean().mean())