In [None]:
from src.silence import silence
silence()

from pathlib import Path
from typing import Dict, List, Sequence, Union, Tuple

import pandas as pd
from darts import concatenate
from darts.models import LightGBMModel, TSMixerModel
from tqdm import tqdm

from config import (
    ALLOWED_DATA_MODES,
    ALLOWED_MODELS,
    ALLOWED_MODES,
    DATA_MODE_CONFIG,
    ENCODERS,
    FORECAST_DATES,
    HORIZON,
    NUM_SAMPLES,
    QUANTILES,
    RANDOM_SEEDS,
    ROOT,
    SHARED_ARGS,
    DataMode,
    Mode,
    ModelName,
)
from src.load_data import encode_static_covariates, reshape_forecast
from src.realtime_utils import (
    load_nowcast,
    load_realtime_training_data,
    make_target_paths,
)
from src.tuning import exclude_covid_weights, get_best_parameters

In [None]:
# Lookup from model-family name to the Darts model class implementing it.
# NOTE(review): in the cells visible here it is referenced only by the
# commented-out `load_model` helper — confirm it is still needed.
MODEL_REGISTRY = dict(lightgbm=LightGBMModel, tsmixer=TSMixerModel)


# def load_model(path: Path, disable_progress_bar=True):
#     model_name = path.name[11:].rsplit("-", 1)[0]
#     model_family = model_name.split("-", 1)[0]
#     model = MODEL_REGISTRY[model_family].load(str(path))

#     if disable_progress_bar:
#         kwargs = getattr(model, "trainer_params", {}) or {}
#         kwargs["enable_progress_bar"] = False
#         model.trainer_params = kwargs

#     return model


# core (pure, no I/O)
def compute_forecast(
    model,
    *,
    # preloaded inputs
    targets=None,  # as-of targets (for "naive", "coupling", "discard")
    covariates=None,  # as-of covariates
    ts_nowcast=None,  # preloaded nowcast (for "coupling", "discard")
    complete_targets=None,  # fully corrected truth (only for "oracle")
    # meta
    forecast_date=None,
    mode: Mode = "naive",
    # defaults
    horizon: int = HORIZON,
    num_samples: int = NUM_SAMPLES,
) -> pd.DataFrame:
    """
    Forecast `horizon` steps with an already fitted model (this function does no file I/O).

    How the model input is built, per mode:
      - 'naive': `targets` and `covariates` are passed to the model unchanged.
      - 'coupling': `make_target_paths(targets, ts_nowcast)` yields a list of target paths;
        each path goes through `encode_static_covariates(..., ordinal=False)` and
        `covariates` (if given) is repeated once per path.
      - 'discard': like 'coupling', but the last data point of every path is dropped and
        the `horizon` column of the result is shifted by -1.
      - 'oracle': `complete_targets` sliced up to `forecast_date`, then encoded
        (no nowcast involved).

    Parameters
    ----------
    model : fitted forecasting model exposing `predict(n, series, past_covariates, num_samples)`.
    targets, covariates, ts_nowcast, complete_targets : preloaded inputs, see modes above.
    forecast_date : written to the `forecast_date` column (via `pd.Timestamp`); in
        'oracle' mode it is also the truncation point and therefore mandatory.
    mode : one of 'naive', 'coupling', 'discard', 'oracle'.
    horizon : number of steps passed to `model.predict` as `n`.
    num_samples : number of samples passed to `model.predict`.

    Returns
    -------
    pd.DataFrame
        Output of `reshape_forecast` with an added `forecast_date` column.

    Raises
    ------
    ValueError
        Unknown `mode`, or inputs required by the chosen mode are missing.
    """
    if mode not in ("naive", "coupling", "discard", "oracle"):
        raise ValueError("mode must be one of {'naive','coupling','discard','oracle'}.")

    if mode == "naive":
        series_for_model = targets
        covs_for_model = covariates

    elif mode in {"coupling", "discard"}:
        if targets is None or ts_nowcast is None:
            raise ValueError("coupling/discard require `targets` (as-of) and `ts_nowcast`.")
        target_list = make_target_paths(targets, ts_nowcast)
        target_list = [encode_static_covariates(t, ordinal=False) for t in target_list]
        if mode == "discard":
            target_list = [t[:-1] for t in target_list]  # discard last data point
        series_for_model = target_list
        # the model expects one covariate series per target series
        covs_for_model = [covariates] * len(target_list) if covariates is not None else None

    else:  # "oracle"
        if complete_targets is None:
            raise ValueError("oracle requires `complete_targets` (fully corrected).")
        if forecast_date is None:
            # pd.Timestamp(None) is NaT; slicing with NaT would not truncate at a meaningful point
            raise ValueError("oracle requires `forecast_date` to truncate `complete_targets`.")
        ts_cut = complete_targets[: pd.Timestamp(forecast_date)]
        ts_cut = encode_static_covariates(ts_cut, ordinal=False)
        series_for_model = ts_cut
        covs_for_model = covariates

    fct = model.predict(
        n=horizon,
        series=series_for_model,
        past_covariates=covs_for_model,
        num_samples=num_samples,
    )

    # A list of forecasts (one per input series) is merged along the sample axis.
    ts_forecast = concatenate(fct, axis="sample") if isinstance(fct, list) else fct
    df = reshape_forecast(ts_forecast)

    df["forecast_date"] = pd.Timestamp(forecast_date)
    if mode == "discard":
        # forecasts start one step earlier because the last observation was dropped
        df["horizon"] = df["horizon"] - 1
    return df


# helpers
def aggregate_runs(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """
    Average `value` across runs for every forecast key.

    All frames are stacked and grouped by location, age group, forecast date,
    target end date, horizon, type and quantile; `value` is averaged per group and
    the result is sorted by location, age group, horizon and quantile.

    Rows whose key contains a missing value (e.g. an empty `quantile`) are kept as
    their own group instead of being silently dropped by `groupby`.

    Raises
    ------
    ValueError
        If `dfs` is empty.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("aggregate_runs requires at least one forecast DataFrame.")

    group_keys = ["location", "age_group", "forecast_date", "target_end_date", "horizon", "type", "quantile"]
    return (
        pd.concat(dfs, ignore_index=True)
        # dropna=False: pandas would otherwise discard every row with a NaN in any key column
        .groupby(group_keys, as_index=False, dropna=False)["value"]
        .mean()
        .sort_values(["location", "age_group", "horizon", "quantile"])
    )


def save_csv(df: pd.DataFrame, out_dir: Path, filename: str) -> None:
    """Write `df` as CSV (index omitted) to `out_dir / filename`, creating `out_dir` if it is missing."""
    out_dir.mkdir(exist_ok=True, parents=True)
    destination = out_dir / filename
    df.to_csv(destination, index=False)



In [7]:
def fit_model(model, targets, covariates, params, use_covariates, use_encoders, weights, seed):
    """
    Build and fit one forecasting model.

    Parameters
    ----------
    model : "lightgbm" or "tsmixer".
    targets : training target series passed to `.fit`.
    covariates : past covariates; only passed to `.fit` when `use_covariates` is true.
    params : keyword arguments forwarded to the model constructor.
    use_covariates : whether `covariates` are used as past covariates.
    use_encoders : whether `ENCODERS` are passed as `add_encoders`.
    weights : forwarded to `.fit` as `sample_weight`.
    seed : forwarded to the constructor as `random_state`.

    Returns
    -------
    The fitted model instance.

    Raises
    ------
    ValueError
        If `model` is not a supported model name.
    """
    # Fail early: without this check an unknown name skipped both branches and
    # crashed on `return mdl` with an UnboundLocalError.
    if model not in ("lightgbm", "tsmixer"):
        raise ValueError(f"Unknown model: {model!r}. Expected 'lightgbm' or 'tsmixer'.")

    encoders = ENCODERS if use_encoders else None
    past_covariates = covariates if use_covariates else None

    if model == "lightgbm":
        mdl = LightGBMModel(
            **params,
            output_chunk_length=HORIZON,
            add_encoders=encoders,
            likelihood="quantile",
            quantiles=QUANTILES,
            verbose=-1,
            random_state=seed,
        )
        mdl.fit(
            targets,
            past_covariates=past_covariates,
            sample_weight=weights,
        )
    else:  # "tsmixer"
        mdl = TSMixerModel(
            **params,
            add_encoders=encoders,
            **SHARED_ARGS,
            random_state=seed,
        )
        mdl.fit(
            targets,
            past_covariates=past_covariates,
            sample_weight=weights,
            dataloader_kwargs={"pin_memory": False},
        )
    return mdl


# one date → all modes (load once, reuse across modes & seeds)
def generate_forecasts_for_date(
    model: ModelName,
    forecast_date: str,
    *,
    modes: Union[Mode, Sequence[Mode]] = ("naive", "coupling", "discard", "oracle"),
    data_mode: DataMode = "all",
    seeds=RANDOM_SEEDS,
    save_model: bool = False,
):
    """
    Train and forecast for a single forecast date, then export one CSV per mode.

    Inputs are loaded once for the date. For every seed one model is fitted and used
    to forecast each requested mode. Per mode, the runs are averaged across seeds
    (`aggregate_runs`) and written to
    ``ROOT / "forecasts" / f"{model_name}-{mode}"``.

    Parameters
    ----------
    model : model family; must be in `ALLOWED_MODELS`.
    forecast_date : as-of date used for loading data, for the nowcast and in file names.
    modes : a single mode or a sequence of modes; each must be in `ALLOWED_MODES`.
    data_mode : key of `DATA_MODE_CONFIG` (yields `use_covariates` and `sample_weight`);
        any value other than "all" is appended to the model name.
    seeds : non-empty sized collection of random seeds (one fitted model per seed).
    save_model : if True, each fitted model is saved under
        ``ROOT / "models" / forecast_date``.

    Returns
    -------
    dict
        ``{mode: ensembled_df}`` with the seed-averaged forecasts per mode.

    Raises
    ------
    ValueError
        Invalid `model`, `modes` or `data_mode`, or empty `seeds`.
    """

    if model not in ALLOWED_MODELS:
        raise ValueError(f"Invalid model: {model!r}. Allowed values: {sorted(ALLOWED_MODELS)}")
    modes = [modes] if isinstance(modes, str) else modes
    if any(m not in ALLOWED_MODES for m in modes):
        raise ValueError(f"Invalid mode(s): {modes!r}. Allowed values: {sorted(ALLOWED_MODES)}")
    if data_mode not in ALLOWED_DATA_MODES:
        raise ValueError(f"Invalid data_mode: {data_mode!r}. Allowed values: {sorted(ALLOWED_DATA_MODES)}")
    if len(seeds) == 0:
        # otherwise min(seeds) below fails with an unrelated-looking message
        raise ValueError("seeds must contain at least one seed.")

    model_name = model if data_mode == "all" else f"{model}-{data_mode}"

    use_covariates, sample_weight = DATA_MODE_CONFIG[data_mode]

    # pick best hyperparams for this family (optionally filtered)
    params, wis = get_best_parameters(
        model, use_covariates=use_covariates, sample_weight=sample_weight, clean=True, return_score=True
    )
    # Work on a copy so that popping does not mutate the object handed out by
    # get_best_parameters (in case it is cached or reused by the caller).
    params = dict(params)
    use_encoders = params.pop("use_encoders")

    print(
        f"\n=== Training config ===\n"
        f"  model          : {model_name}\n"
        f"  use_covariates : {use_covariates}\n"
        f"  sample_weight  : {sample_weight}\n"
        f"  modes          : {modes}\n"
        f"  forecast_date  : {forecast_date}\n"
        f"  seeds          : {min(seeds)} → {max(seeds)} (n={len(seeds)})\n"
        f"=======================\n"
        f"  Parameters:"
    )
    for key, value in params.items():
        print(f"    {key}: {value}")
    print(f"\n  Validation score : {wis:.3f}\n=======================\n")

    # Training data (only complete data points)
    targets_train, covars_train = load_realtime_training_data(as_of=forecast_date, drop_incomplete=True)
    weights = exclude_covid_weights(targets_train) if sample_weight == "no-covid" else sample_weight

    # As-of data for prediction (with incomplete data points)
    targets_asof, covars_asof = load_realtime_training_data(as_of=forecast_date, drop_incomplete=False)

    # Only load final target data if 'oracle' is requested
    complete_targets = None
    if "oracle" in modes:
        complete_targets, _ = load_realtime_training_data()

    # Nowcast once per date (used by 'coupling' and 'discard')
    ts_now = None
    if any(m in ("coupling", "discard") for m in modes):
        ts_now = load_nowcast(forecast_date=forecast_date)

    # Mode-specific inputs for compute_forecast. The dict order fixes the order of
    # the predict calls per seed (naive → coupling → discard → oracle).
    mode_inputs = {
        "naive": {"targets": targets_asof},
        "coupling": {"targets": targets_asof, "ts_nowcast": ts_now},
        "discard": {"targets": targets_asof, "ts_nowcast": ts_now},
        "oracle": {"complete_targets": complete_targets},
    }

    # Models are saved in the same directory that is created here; previously the
    # directory was created under ROOT but the file was written relative to "../models".
    model_dir = ROOT / "models" / forecast_date
    if save_model:
        model_dir.mkdir(parents=True, exist_ok=True)

    per_mode_runs: Dict[Mode, List[pd.DataFrame]] = {m: [] for m in modes}

    for seed in tqdm(seeds, desc=f"{forecast_date}", leave=False):
        # Train models
        mdl = fit_model(model, targets_train, covars_train, params, use_covariates, use_encoders, weights, seed)

        if save_model:
            model_path = model_dir / f"{forecast_date}-{model_name}-{seed}.pt"
            mdl.save(str(model_path), clean=True)  # Darts .save only accepts str

        # If the model was trained without past covariates, do not pass any at prediction time
        covars_pred = covars_asof if mdl.uses_past_covariates else None

        for m, inputs in mode_inputs.items():
            if m not in modes:
                continue
            per_mode_runs[m].append(
                compute_forecast(
                    mdl,
                    covariates=covars_pred,
                    forecast_date=forecast_date,
                    mode=m,
                    **inputs,
                )
            )

    # Aggregate across seeds per mode
    ensembled: Dict[Mode, pd.DataFrame] = {m: aggregate_runs(per_mode_runs[m]) for m in modes}

    for m, df in ensembled.items():
        out_dir = ROOT / "forecasts" / f"{model_name}-{m}"
        fname = f"{forecast_date}-icosari-sari-{model_name}-{m}.csv"
        save_csv(df, out_dir, fname)

    return ensembled

In [None]:
# LightGBM for one forecast date with all defaults: four modes, data_mode="all", every seed in RANDOM_SEEDS, models not saved.
generate_forecasts_for_date("lightgbm", "2023-11-16")

In [None]:
# Same date restricted to a single seed (1); the fitted model is additionally saved (save_model=True).
generate_forecasts_for_date("lightgbm", "2023-11-16", seeds=[1], save_model=True)

In [10]:
# Single-seed run with data_mode='no_covariates'; the recorded output reports model "lightgbm-no_covariates" with use_covariates=False.
generate_forecasts_for_date("lightgbm", "2023-11-16", data_mode='no_covariates', seeds=[1], save_model=True)


=== Training config ===
  model          : lightgbm-no_covariates
  use_covariates : False
  sample_weight  : linear
  modes          : ('naive', 'coupling', 'discard', 'oracle')
  forecast_date  : 2023-11-16
  seeds          : 1 → 1 (n=1)
  Parameters:
    colsample_bytree: 0.8
    lags: 8
    lags_future_covariates: (0, 1)
    learning_rate: 0.01
    max_bin: 1024
    max_depth: -1
    min_child_samples: 20
    min_split_gain: 0.0
    n_estimators: 1000
    num_leaves: 20
    reg_alpha: 1.0
    reg_lambda: 0.5
    subsample: 0.8
    subsample_freq: 1
    use_static_covariates: False

  Validation score : 450.189



                                                          

In [8]:
def generate_forecasts_for_all_dates(
    model: ModelName,
    forecast_dates: Sequence[str] = FORECAST_DATES,
    *,
    modes: Union[Mode, Sequence[Mode]] = ("naive", "coupling", "discard", "oracle"),
    data_mode: DataMode = "all",
    seeds=RANDOM_SEEDS,
    save_model: bool = False,
) -> None:
    """
    Call `generate_forecasts_for_date` once per entry of `forecast_dates`.

    A date that raises is reported and skipped so the remaining dates still run;
    a summary listing the failed dates (or a success message) is printed at the end.
    Nothing is returned.
    """
    errors: list[Tuple[str, str]] = []

    for date in forecast_dates:
        print(f"→ {date}")
        try:
            generate_forecasts_for_date(
                model=model,
                forecast_date=date,
                modes=modes,
                data_mode=data_mode,
                seeds=seeds,
                save_model=save_model,
            )
        except Exception as exc:
            # best effort: remember the failure and continue with the next date
            message = f"{type(exc).__name__}: {exc}"
            errors.append((date, message))
            print(f"[{date}] ABORTED — {message}")

    if not errors:
        print("\nAll dates completed successfully.")
        return

    print("\nCompleted with errors — the following dates failed:")
    for date, message in errors:
        print(f"  {date}: {message}")

In [None]:

# multiple dates and modes
def run_all_dates_all_modes(
    forecast_dates: List[str],
    model_name: str = "lightgbm",
    *,
    modes: List[Mode] = ("naive", "coupling", "discard", "oracle"),
    seeds=RANDOM_SEEDS,
    export: bool = True,
) -> Dict[tuple, pd.DataFrame]:
    """
    Collect the per-mode ensembles of every forecast date into one dictionary.

    For each date, `compute_ensembles_for_date` is called and its `{mode: df}`
    result is stored under the key `(forecast_date, mode)`.

    NOTE(review): `compute_ensembles_for_date` is neither defined nor imported in the
    cells visible here — presumably a leftover from an earlier version of this
    notebook; confirm it exists before running on a fresh kernel.
    """
    collected: Dict[tuple, pd.DataFrame] = {}
    for date in forecast_dates:
        print(f"→ {date}")
        per_mode = compute_ensembles_for_date(date, model_name, modes=modes, seeds=seeds, export=export)
        collected.update({(date, mode): frame for mode, frame in per_mode.items()})
    return collected


In [None]:
# TSMixer, all default modes, every date in FORECAST_DATES.
# NOTE(review): run_all_dates_all_modes calls `compute_ensembles_for_date`, which is not defined in the visible cells — confirm before re-running.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer")

In [None]:
# "tsmixer-no_covariates" model name, 'coupling' mode only, every date in FORECAST_DATES.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covariates", modes=["coupling"])

In [None]:
# "tsmixer-no_covariates" model name with the three non-coupling modes (naive, discard, oracle).
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covariates", modes=["naive", "discard", "oracle"])

In [None]:
# "tsmixer-no_covid" model name, 'coupling' mode only, every date in FORECAST_DATES.
run_all_dates_all_modes(FORECAST_DATES, "tsmixer-no_covid", modes=["coupling"])

In [None]:
# "lightgbm-no_covariates" model name, 'coupling' mode only, every date in FORECAST_DATES.
run_all_dates_all_modes(FORECAST_DATES, "lightgbm-no_covariates", modes=["coupling"])

In [None]:
# "lightgbm-no_covid" model name, 'coupling' mode only, every date in FORECAST_DATES.
run_all_dates_all_modes(FORECAST_DATES, "lightgbm-no_covid", modes=["coupling"])