In [None]:
from src.silence import silence
silence()

from pathlib import Path
from typing import Dict, List, Literal

import pandas as pd
from darts import concatenate
from darts.models import LightGBMModel, TSMixerModel

from config import FORECAST_DATES, HORIZON, NUM_SAMPLES, RANDOM_SEEDS, ROOT
from src.load_data import encode_static_covariates, reshape_forecast
from src.realtime_utils import (
    load_nowcast,
    load_realtime_training_data,
    make_target_paths,
)

In [None]:
# Accepted forecasting modes (see `compute_forecast` for what each one does).
Mode = Literal["naive", "coupling", "discard", "oracle"]

# Maps the model-family prefix of a saved model's name (the part before the
# first "-") to the darts class whose `.load` restores it; used by `load_model`.
MODEL_REGISTRY = {
    "lightgbm": LightGBMModel,
    "tsmixer": TSMixerModel,
}


def load_model(path: Path, disable_progress_bar=True):
    """
    Load a saved darts model, picking the model class from the file name.

    The file name is expected to look like ``<YYYY-MM-DD>-<model_name>-<seed>.<ext>``
    where the first dash-separated token of ``<model_name>`` is the model family
    (a key of ``MODEL_REGISTRY``), e.g. ``2024-01-04-tsmixer-no_covariates-3.pt``.

    Args:
        path: Location of the saved model file.
        disable_progress_bar: If true and the loaded model exposes
            ``trainer_params``, set ``enable_progress_bar`` to False in it.

    Returns:
        The loaded model instance.

    Raises:
        KeyError: If the model family parsed from the file name is not
            registered in ``MODEL_REGISTRY``.
    """
    # The file name starts with an ISO date followed by a dash (11 characters).
    date_prefix_len = len("YYYY-MM-DD-")
    model_name = path.name[date_prefix_len:].rsplit("-", 1)[0]  # strip "-<seed>.<ext>"
    model_family = model_name.split("-", 1)[0]

    try:
        model_cls = MODEL_REGISTRY[model_family]
    except KeyError:
        raise KeyError(
            f"Unknown model family {model_family!r} parsed from {path.name!r}; "
            f"expected one of {sorted(MODEL_REGISTRY)}."
        ) from None

    model = model_cls.load(str(path))

    # Only models that carry trainer settings are touched; other models are
    # returned unchanged instead of getting a stray `trainer_params` attribute.
    if disable_progress_bar and hasattr(model, "trainer_params"):
        params = model.trainer_params or {}
        params["enable_progress_bar"] = False
        model.trainer_params = params

    return model


# core (pure, no I/O)
def compute_forecast(
    model,
    *,
    # preloaded inputs
    targets=None,  # as-of targets (for "naive", "coupling", "discard")
    covariates=None,  # as-of covariates
    ts_nowcast=None,  # preloaded nowcast (for "coupling", "discard")
    complete_targets=None,  # fully corrected truth (only for "oracle")
    # meta
    forecast_date=None,
    mode: Mode = "naive",
    # defaults
    horizon: int = HORIZON,
    num_samples: int = NUM_SAMPLES,
) -> pd.DataFrame:
    """
    Produce one forecast DataFrame from already-loaded inputs (no file access here).

    How the model input is assembled depends on ``mode``:
      - 'naive': ``targets`` and ``covariates`` are passed to the model as given.
      - 'coupling': ``make_target_paths(targets, ts_nowcast)`` yields a list of
        target series; each is passed through ``encode_static_covariates`` and
        paired with the same ``covariates`` (when covariates are given).
      - 'discard': same as 'coupling', but the last point of every series is
        removed before predicting and the ``horizon`` column of the result is
        decremented by one afterwards.
      - 'oracle': ``complete_targets`` is sliced up to ``forecast_date`` and
        passed through ``encode_static_covariates``; no nowcast is involved.

    When the model returns a list of forecasts, they are concatenated along the
    sample axis before being reshaped with ``reshape_forecast``.

    Returns:
        The reshaped forecast with a ``forecast_date`` column set to
        ``pd.Timestamp(forecast_date)``.

    Raises:
        ValueError: For an unknown ``mode`` or when the inputs required by the
            chosen mode are missing.
    """
    valid_modes = ("naive", "coupling", "discard", "oracle")
    if mode not in valid_modes:
        raise ValueError("mode must be one of {'naive','coupling','discard','oracle'}.")

    if mode == "oracle":
        if complete_targets is None:
            raise ValueError("oracle requires `complete_targets` (fully corrected).")
        cutoff = pd.Timestamp(forecast_date)
        model_series = encode_static_covariates(complete_targets[:cutoff], ordinal=False)
        model_covariates = covariates
    elif mode == "naive":
        model_series, model_covariates = targets, covariates
    else:  # "coupling" or "discard"
        if ts_nowcast is None or targets is None:
            raise ValueError("coupling/discard require `targets` (as-of) and `ts_nowcast`.")
        paths = [
            encode_static_covariates(path, ordinal=False)
            for path in make_target_paths(targets, ts_nowcast)
        ]
        if mode == "discard":
            # drop the most recent observation of every path
            paths = [path[:-1] for path in paths]
        model_series = paths
        # one covariate entry per target path
        model_covariates = None if covariates is None else [covariates] * len(paths)

    prediction = model.predict(
        n=horizon,
        series=model_series,
        past_covariates=model_covariates,
        num_samples=num_samples,
    )
    if isinstance(prediction, list):
        # per-path forecasts are pooled into a single set of samples
        prediction = concatenate(prediction, axis="sample")

    result = reshape_forecast(prediction)
    result["forecast_date"] = pd.Timestamp(forecast_date)
    if mode == "discard":
        # NOTE(review): presumably compensates for the dropped last data point
        # so horizons stay relative to forecast_date — confirm.
        result["horizon"] = result["horizon"] - 1
    return result


# helpers
def aggregate_runs(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """
    Average the ``value`` column across several forecast runs.

    All frames are stacked and grouped by every identifying column; the mean of
    ``value`` is taken per group. Rows whose grouping key contains NaN (for
    example a missing ``quantile``) are kept as their own group rather than
    being silently dropped.

    Args:
        dfs: Forecast frames with the columns ``location``, ``age_group``,
            ``forecast_date``, ``target_end_date``, ``horizon``, ``type``,
            ``quantile`` and ``value``.

    Returns:
        One row per distinct key combination, sorted by ``location``,
        ``age_group``, ``horizon`` and ``quantile``.

    Raises:
        ValueError: If ``dfs`` is empty.
    """
    if not dfs:
        raise ValueError("aggregate_runs requires at least one forecast DataFrame.")

    key_columns = [
        "location",
        "age_group",
        "forecast_date",
        "target_end_date",
        "horizon",
        "type",
        "quantile",
    ]
    return (
        pd.concat(dfs, ignore_index=True)
        # dropna=False: groupby would otherwise discard rows with a NaN key.
        .groupby(key_columns, as_index=False, dropna=False)["value"]
        .mean()
        .sort_values(["location", "age_group", "horizon", "quantile"])
    )


def save_csv(df: pd.DataFrame, out_dir: Path, filename: str) -> None:
    """Write ``df`` (without its index) to ``out_dir / filename``, creating ``out_dir`` if needed."""
    destination = out_dir / filename
    # Missing parent directories are created; an existing directory is fine.
    out_dir.mkdir(exist_ok=True, parents=True)
    df.to_csv(destination, index=False)


# one date → all modes (load once, reuse across modes & seeds)
def compute_ensembles_for_date(
    forecast_date: str,
    model_name: str,
    *,
    modes: List[Mode] = ("naive", "coupling", "discard", "oracle"),
    seeds=RANDOM_SEEDS,
    export: bool = True,
    out_root: Path = ROOT / "forecasts",
    probabilistic_nowcast: bool = True,
    local: bool = True,
    nowcast_model: str = "simple_nowcast",
) -> Dict[Mode, pd.DataFrame]:
    """
    Build the seed-averaged forecast for one forecast date in every requested mode.

    Inputs are loaded once for the date, each seed's model is loaded once and
    run in all requested modes, and the per-seed forecasts are averaged per mode
    with ``aggregate_runs``.

    Args:
        forecast_date: Forecast date string; it is also part of the model path
            and of the exported file names.
        model_name: Name of the trained model as used in the saved file names.
        modes: Modes to run; must be a subset of
            {"naive", "coupling", "discard", "oracle"}.
        seeds: Seeds identifying the saved models to ensemble.
        export: If true, write one CSV per mode below ``out_root``.
        out_root: Root directory for exported forecasts.
        probabilistic_nowcast: Forwarded to ``load_nowcast`` as ``probabilistic``.
        local: Forwarded to ``load_nowcast``.
        nowcast_model: Forwarded to ``load_nowcast`` as ``model``.

    Returns:
        ``{mode: ensembled_df}`` with keys in the order given by ``modes``.

    Raises:
        ValueError: If ``modes`` contains an unknown mode (raised before any
            data or model is loaded).
    """
    known_modes = ("naive", "coupling", "discard", "oracle")
    nowcast_modes = ("coupling", "discard")

    unknown = [m for m in modes if m not in known_modes]
    if unknown:
        raise ValueError(f"Unknown mode(s) {unknown}; expected a subset of {known_modes}.")

    # Modes are always executed in this fixed order, independent of the order
    # in `modes`, so repeated runs invoke the model in the same sequence.
    active_modes = [m for m in known_modes if m in modes]

    # As-of data once per date
    targets_asof, covars_asof = load_realtime_training_data(as_of=forecast_date, drop_incomplete=False)

    # Only load complete targets if 'oracle' is requested
    complete_targets = None
    if "oracle" in active_modes:
        complete_targets, _ = load_realtime_training_data()

    # Nowcast once per date (used by 'coupling' and 'discard')
    ts_now = None
    if any(m in nowcast_modes for m in active_modes):
        # The indicator is the second dash-separated token of the first component name.
        indicator = targets_asof.components[0].split("-")[1]
        ts_now = load_nowcast(
            forecast_date=forecast_date,
            probabilistic=probabilistic_nowcast,
            indicator=indicator,
            local=local,
            model=nowcast_model,
        )

    per_mode_runs: Dict[Mode, List[pd.DataFrame]] = {m: [] for m in modes}

    for seed in seeds:
        # NOTE(review): this path is relative to the current working directory,
        # unlike `out_root`, which is anchored at ROOT — confirm this is intended.
        model_path = Path("../models") / forecast_date / f"{forecast_date}-{model_name}-{seed}.pt"
        model = load_model(model_path)

        # Decide per model whether covariates are passed. A local variable is
        # used so that a model without past covariates does not wipe the
        # covariates for the models of the following seeds.
        covars_for_model = covars_asof if model.uses_past_covariates else None

        for m in active_modes:
            if m == "oracle":
                inputs = {"complete_targets": complete_targets}
            else:
                inputs = {"targets": targets_asof}
                if m in nowcast_modes:
                    inputs["ts_nowcast"] = ts_now

            per_mode_runs[m].append(
                compute_forecast(
                    model,
                    covariates=covars_for_model,
                    forecast_date=forecast_date,
                    mode=m,
                    **inputs,
                )
            )

    ensembled: Dict[Mode, pd.DataFrame] = {m: aggregate_runs(per_mode_runs[m]) for m in modes}

    if export:
        for m, df in ensembled.items():
            out_dir = out_root / f"{model_name}-{m}"
            fname = f"{forecast_date}-icosari-sari-{model_name}-{m}.csv"
            save_csv(df, out_dir, fname)

    return ensembled


# multiple dates and modes
def run_all_dates_all_modes(
    forecast_dates: List[str],
    model_name: str = "lightgbm",
    *,
    modes: List[Mode] = ("naive", "coupling", "discard", "oracle"),
    seeds=RANDOM_SEEDS,
    export: bool = True,
) -> Dict[tuple, pd.DataFrame]:
    """
    Run ``compute_ensembles_for_date`` for every forecast date and collect the results.

    Each date is printed before it is processed. The returned dict is keyed by
    ``(forecast_date, mode)`` and holds the ensembled DataFrame for that pair.
    """
    collected: Dict[tuple, pd.DataFrame] = {}
    for date in forecast_dates:
        print(f"→ {date}")
        per_mode = compute_ensembles_for_date(date, model_name, modes=modes, seeds=seeds, export=export)
        collected.update({(date, mode): frame for mode, frame in per_mode.items()})
    return collected


In [None]:
# "tsmixer" model: all four modes (the default) for every date in FORECAST_DATES; CSV export is on by default.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer")

In [None]:
# "tsmixer-no_covariates" model: "coupling" mode only.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covariates", modes=["coupling"])

In [None]:
# "tsmixer-no_covariates" model: the remaining modes ("naive", "discard", "oracle").
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covariates", modes=["naive", "discard", "oracle"])

In [None]:
# "tsmixer-no_covid" model: "coupling" mode only.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covid", modes=["coupling"])

In [None]:
# "lightgbm-no_covariates" model: "coupling" mode only.
run_all_dates_all_modes(FORECAST_DATES, "lightgbm-no_covariates", modes=["coupling"])

In [None]:
# "lightgbm-no_covid" model: "coupling" mode only.
run_all_dates_all_modes(FORECAST_DATES, "lightgbm-no_covid", modes=["coupling"])