# Electricity Demand Forecasting — Classical Time Series Models (Darts)
**Data**: Monthly electricity consumption (KWh) | Jan 2019 – Jan 2026  
**Goal**: Forecast next 3 months using classical time series models tracked via MLflow  
**Library**: [Darts](https://unit8co.github.io/darts/) by Unit8  
**Model Progression**: Baseline → SES → Holt's Linear → Holt-Winters (4 variants) → ARIMA (auto) → SARIMA (auto + manual candidates)


## 1. Imports & Configuration

In [None]:
import warnings
import pickle
import os
import tempfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns

# Darts core
from darts import TimeSeries
from darts.models import ExponentialSmoothing, ARIMA, AutoARIMA
from darts.utils.utils import ModelMode, SeasonalityMode
from darts.metrics import mae, rmse, mape

# Stats helpers (for EDA/checks — not for model building)
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.stattools import adfuller
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.stats.diagnostic import acorr_ljungbox

# MLflow
import mlflow

# Notebook-wide setup: silence library warnings and set the default figure geometry.
warnings.filterwarnings(action="ignore")
plt.rcParams["figure.dpi"] = 120
plt.rcParams["figure.figsize"] = (12, 4)

# ── Global constants ──────────────────────────────────────────────────────────
FORECAST_HORIZON = 3             # number of months in the final forward forecast
SEASONAL_PERIOD = 12             # observations per seasonal cycle (monthly data, yearly cycle)
TRAIN_END_DATE = "2025-09-01"    # last month of the training window
MLFLOW_TRACKING_URI = "mlruns"   # value handed to mlflow.set_tracking_uri below

# Register the tracking location, then echo the URI MLflow reports back.
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
print("MLflow tracking URI:", mlflow.get_tracking_uri())


## 2. Data Loading & Preprocessing

In [None]:
# Inline copy of the raw dataset: 85 month-start date strings (2019-01 … 2026-01)
# paired position-by-position with the consumption reading for that month.
# Each row of "total_units_kwh" below holds six consecutive months (half a year).
RAW_DATA = {
    "month_year": [
        "2019-01-01","2019-02-01","2019-03-01","2019-04-01","2019-05-01","2019-06-01",
        "2019-07-01","2019-08-01","2019-09-01","2019-10-01","2019-11-01","2019-12-01",
        "2020-01-01","2020-02-01","2020-03-01","2020-04-01","2020-05-01","2020-06-01",
        "2020-07-01","2020-08-01","2020-09-01","2020-10-01","2020-11-01","2020-12-01",
        "2021-01-01","2021-02-01","2021-03-01","2021-04-01","2021-05-01","2021-06-01",
        "2021-07-01","2021-08-01","2021-09-01","2021-10-01","2021-11-01","2021-12-01",
        "2022-01-01","2022-02-01","2022-03-01","2022-04-01","2022-05-01","2022-06-01",
        "2022-07-01","2022-08-01","2022-09-01","2022-10-01","2022-11-01","2022-12-01",
        "2023-01-01","2023-02-01","2023-03-01","2023-04-01","2023-05-01","2023-06-01",
        "2023-07-01","2023-08-01","2023-09-01","2023-10-01","2023-11-01","2023-12-01",
        "2024-01-01","2024-02-01","2024-03-01","2024-04-01","2024-05-01","2024-06-01",
        "2024-07-01","2024-08-01","2024-09-01","2024-10-01","2024-11-01","2024-12-01",
        "2025-01-01","2025-02-01","2025-03-01","2025-04-01","2025-05-01","2025-06-01",
        "2025-07-01","2025-08-01","2025-09-01","2025-10-01","2025-11-01","2025-12-01",
        "2026-01-01",
    ],
    "total_units_kwh": [
        199722489.0,196282917.0,229630601.0,307610021.0,378333246.0,386806926.0,
        318453165.0,303440777.0,299577754.0,302758063.0,287854703.0,249393096.0,
        226911059.0,233987135.0,249055256.0,311844901.0,378972626.0,371104608.0,
        362331096.0,376818006.5,349757651.6,342309639.0,317955772.0,269855118.0,
        # 2021 H1: the last entry (2021-06, 2158077.0) is the outlier replaced in Section 3.
        243155839.6,253523091.0,263214583.0,339192666.0,405477414.0,2158077.0,
        368050250.0,384068332.0,361304647.0,350300159.0,326220132.17,284013708.0,
        243976433.0,236413333.11,260031223.0,371437583.28,424833878.1,443537503.0,
        376807680.32,322593689.64,351653428.59,332648965.59,294138231.0,249154006.57,
        259018866.16,255718830.0,262851784.0,341098867.0,373179111.0,468056532.0,
        414786578.0,340837857.0,380861584.0,348924442.0,346219332.0,295555199.0,
        # NOTE(review): the last entry of this row (2024-06, 276370652.0) is far below the
        # adjacent May/July readings and every other June except the known 2021 error;
        # it is NOT corrected anywhere in this notebook — confirm against the source system.
        240950774.0,288656864.0,296623052.0,389765964.0,467155674.0,276370652.0,
        429853045.0,385815658.0,417634123.0,388451915.0,402365044.0,301822337.0,
        292729563.0,303704651.0,328929748.0,427534816.0,485839957.0,498140707.0,
        442414161.0,432117294.0,416787169.0,414938105.0,390481404.0,290133269.0,
        273816782.0,
    ],
}

# Build the working frame: parse the month strings and promote them to a month-start index.
df = (
    pd.DataFrame(RAW_DATA)
    .assign(month_year=lambda frame: pd.to_datetime(frame["month_year"]))
    .set_index("month_year")
)
df.index.freq = "MS"

# Quick sanity report on size, covered period and missing values.
first_month, last_month = df.index.min(), df.index.max()
missing_count = df.isnull().sum().values[0]
print(f"Shape     : {df.shape}")
print(f"Date range: {first_month.date()} → {last_month.date()}")
print(f"Null count: {missing_count}")
df.head()


## 3. Outlier Fix — 2021-06

In [None]:
# Replace the implausible June 2021 reading with a value interpolated from its neighbours.
OUTLIER_DATE = "2021-06-01"
value_before = df.loc[OUTLIER_DATE, "total_units_kwh"]
print(f"Original 2021-06 value : {value_before:,.0f} KWh  ← data entry error")

# Blank the bad point first so that time-based interpolation has a gap to fill.
df.loc[OUTLIER_DATE, "total_units_kwh"] = np.nan
df["total_units_kwh"] = df["total_units_kwh"].interpolate(method="time")

value_after = df.loc[OUTLIER_DATE, "total_units_kwh"]
print(f"Imputed  2021-06 value : {value_after:,.0f} KWh  ← time-interpolated")

# Plot the cleaned series and mark where the correction was applied.
fig, ax = plt.subplots()
ax.plot(df.index, df["total_units_kwh"], color="steelblue", linewidth=1.5)
ax.axvline(
    pd.Timestamp(OUTLIER_DATE),
    label="Outlier fixed (2021-06)",
    color="red",
    linestyle="--",
)
ax.set_title("Monthly Electricity Consumption — Cleaned Series")
ax.set_ylabel("KWh")
millions_formatter = mticker.FuncFormatter(lambda x, _: f"{x/1e6:.0f}M")
ax.yaxis.set_major_formatter(millions_formatter)
ax.legend()
plt.tight_layout()
plt.savefig("01_cleaned_series.png", bbox_inches="tight")
plt.show()


## 4. Exploratory Data Analysis

In [None]:
# Three stacked EDA panels: full series, per-month boxplot, year-on-year overlay.
month_labels = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
fig, axes = plt.subplots(3, 1, figsize=(12, 11))
ax_series, ax_box, ax_overlay = axes

# Panel 1 — the cleaned series end to end.
ax_series.plot(df.index, df["total_units_kwh"], color="steelblue", linewidth=1.5)
ax_series.set_title("Full Time Series (Cleaned)")

# Panel 2 — distribution of readings per calendar month.
df_eda = df.copy()
df_eda["month_name"] = [month_labels[m - 1] for m in df_eda.index.month]
sns.boxplot(
    data=df_eda, x="month_name", y="total_units_kwh",
    order=month_labels, ax=ax_box, palette="coolwarm"
)
ax_box.set_title("Monthly Seasonality Boxplot (per Calendar Month)")
ax_box.set_xlabel("")

# Panel 3 — one line per year; 2026 is excluded because only January exists for it.
df_eda["year"] = df_eda.index.year
df_eda["month"] = df_eda.index.month
complete_years = df_eda.loc[df_eda["year"] <= 2025]
pivot = complete_years.pivot_table(index="month", columns="year", values="total_units_kwh")
for yr, yearly_profile in pivot.items():
    ax_overlay.plot(month_labels, yearly_profile, marker="o", markersize=3,
                    label=str(yr), linewidth=1.2)
ax_overlay.set_title("Year-on-Year Monthly Overlay (2019–2025)")
ax_overlay.legend(ncol=4, fontsize=8)

# Shared y-axis styling; each axis gets its own formatter instance.
for panel in axes:
    panel.set_ylabel("KWh")
    panel.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1e6:.0f}M"))

plt.tight_layout()
plt.savefig("02_eda.png", bbox_inches="tight")
plt.show()


## 5. Seasonal Decomposition

In [None]:
# Multiplicative decomposition of the cleaned series into trend, seasonal and residual parts.
decomp = seasonal_decompose(df["total_units_kwh"], model="multiplicative", period=SEASONAL_PERIOD)

# Panel title → (component series, line colour), in top-to-bottom plotting order.
decomp_panels = {
    "Observed": (decomp.observed, "steelblue"),
    "Trend":    (decomp.trend,    "darkorange"),
    "Seasonal": (decomp.seasonal, "green"),
    "Residual": (decomp.resid,    "red"),
}

fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
for ax, (title, (comp, color)) in zip(axes, decomp_panels.items()):
    ax.plot(comp.index, comp, color=color, linewidth=1.2)
    ax.set_ylabel(title, fontsize=9)
    ax.grid(alpha=0.3)

# Only the observed panel is in KWh, so only it gets the "millions" tick formatter.
observed_axis = axes[0]
observed_axis.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1e6:.0f}M"))
plt.suptitle("Multiplicative Seasonal Decomposition (period=12)", y=1.01, fontsize=12)
plt.tight_layout()
plt.savefig("03_decomposition.png", bbox_inches="tight")
plt.show()


## 6. Stationarity Check (ADF Test)

In [None]:
def run_adf(series: pd.Series, label: str) -> None:
    """Run an Augmented Dickey-Fuller test on *series* and print a one-line summary.

    Missing values (e.g. those introduced by differencing) are dropped before
    testing. The series is reported as stationary when the p-value is below 0.05.
    """
    adf_stat, p_value = adfuller(series.dropna(), autolag="AIC")[:2]
    stationary = p_value < 0.05
    print(f"[{label:<30}]  ADF={adf_stat:>8.4f}  p={p_value:.4f}  Stationary={stationary}")

# Test the raw series and each differencing scheme considered for ARIMA/SARIMA.
kwh_series = df["total_units_kwh"]
adf_cases = [
    (kwh_series,                  "Raw Series"),
    (kwh_series.diff(1),          "1st Difference (d=1)"),
    (kwh_series.diff(12),         "Seasonal Diff (D=1, lag=12)"),
    (kwh_series.diff(1).diff(12), "1st + Seasonal Diff (d=1, D=1)"),
]
for candidate, case_label in adf_cases:
    run_adf(candidate, case_label)


## 7. ACF & PACF Plots

In [None]:
# 2×2 grid: ACF (left) and PACF (right) for the raw series (top row) and for the
# series after first + seasonal differencing (bottom row).
kwh_raw = df["total_units_kwh"]
kwh_differenced = kwh_raw.diff(1).diff(12).dropna()
correlogram_rows = [
    (kwh_raw,         "Raw Series"),
    (kwh_differenced, "After d=1 + D=1 Differencing"),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 7))
for (acf_axis, pacf_axis), (values, caption) in zip(axes, correlogram_rows):
    plot_acf(values, lags=30, ax=acf_axis, title=f"ACF — {caption}")
    plot_pacf(values, lags=30, ax=pacf_axis, title=f"PACF — {caption}")
plt.tight_layout()
plt.savefig("04_acf_pacf.png", bbox_inches="tight")
plt.show()


## 8. Build Darts TimeSeries Objects
Darts requires data in its `TimeSeries` format. We create the full series, train, and test splits here.


In [None]:
# Full cleaned series as a Darts TimeSeries.
series_full = TimeSeries.from_series(df["total_units_kwh"])

# Train / test split driven by the TRAIN_END_DATE constant from Section 1, so the
# configured training end and the actual split cannot drift apart.
# split_after keeps the split point in the first part: train ends at TRAIN_END_DATE
# and the hold-out is everything after it (Oct 2025 – Jan 2026, i.e. 4 months).
# NOTE: the hold-out is one month longer than FORECAST_HORIZON (3); test metrics in
# later sections are computed over all len(test_ts) hold-out months.
train_ts, test_ts = series_full.split_after(pd.Timestamp(TRAIN_END_DATE))

print(f"Full series : {series_full.start_time().date()} → {series_full.end_time().date()}  ({len(series_full)} obs)")
print(f"Train series: {train_ts.start_time().date()} → {train_ts.end_time().date()}  ({len(train_ts)} obs)")
print(f"Test  series: {test_ts.start_time().date()} → {test_ts.end_time().date()}   ({len(test_ts)} obs)")


## 9. Shared Utilities

In [None]:
def compute_metrics(actual_ts: TimeSeries, forecast_ts: TimeSeries) -> dict:
    """Score a forecast against actuals with the Darts MAE, RMSE and MAPE metrics.

    Returns a dict with keys "MAE" and "RMSE" (rounded to 2 decimals) and
    "MAPE" (rounded to 4 decimals), in that order.
    """
    scorers = (("MAE", mae, 2), ("RMSE", rmse, 2), ("MAPE", mape, 4))
    return {
        name: round(float(score(actual_ts, forecast_ts)), digits)
        for name, score, digits in scorers
    }


def plot_forecast(
    train_ts: TimeSeries,
    test_ts: TimeSeries,
    forecast_ts: TimeSeries,
    model_name: str,
    save_path: str,
    conf_low: TimeSeries = None,
    conf_high: TimeSeries = None,
) -> None:
    """Draw training history, hold-out actuals and a forecast on one axis.

    When both ``conf_low`` and ``conf_high`` are supplied, the area between them
    is shaded and labelled as the 95% CI. The figure is written to ``save_path``
    and then shown.
    """
    history = train_ts.pd_series()
    actuals = test_ts.pd_series()
    predicted = forecast_ts.pd_series()

    fig, ax = plt.subplots(figsize=(13, 4))
    ax.plot(history.index, history, label="Train", color="steelblue", linewidth=1.2)
    ax.plot(actuals.index, actuals, label="Test Actual", color="black",
            linewidth=1.5, linestyle="--")
    ax.plot(predicted.index, predicted, label=f"{model_name} Forecast",
            color="tomato", linewidth=1.5, marker="o", markersize=5)

    # The band is drawn only when both bounds are present; its x-positions are
    # the forecast's timestamps.
    band_available = not (conf_low is None or conf_high is None)
    if band_available:
        lower_bound = conf_low.pd_series().values
        upper_bound = conf_high.pd_series().values
        ax.fill_between(predicted.index, lower_bound, upper_bound,
                        alpha=0.2, color="tomato", label="95% CI")

    ax.set_title(f"{model_name} — Forecast vs Actual")
    ax.set_ylabel("KWh")
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1e6:.0f}M"))
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()


def plot_residuals(residuals: pd.Series, model_name: str, save_path: str) -> None:
    """Plot residuals over time (left) and their ACF up to lag 20 (right).

    The two-panel figure is saved to ``save_path`` and then shown. Missing
    residuals are dropped before the ACF is computed.
    """
    fig, (time_axis, acf_axis) = plt.subplots(1, 2, figsize=(13, 4))

    # Left panel: residual trace with a zero reference line.
    time_axis.plot(residuals.index, residuals, color="purple", linewidth=1)
    time_axis.axhline(0, linestyle="--", color="black", linewidth=0.8)
    time_axis.set_title(f"{model_name} — Residuals")
    time_axis.grid(alpha=0.3)

    # Right panel: autocorrelation of the non-missing residuals.
    clean_residuals = residuals.dropna()
    plot_acf(clean_residuals, lags=20, ax=acf_axis, title=f"{model_name} — Residual ACF")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()


def ljung_box_test(residuals: pd.Series, lags: int = 10) -> dict:
    """Run the Ljung-Box test on *residuals* at a single lag.

    Missing values are dropped first. Returns ``{"lb_stat": ..., "lb_pvalue": ...}``
    with both numbers rounded to 4 decimals.
    """
    table = acorr_ljungbox(residuals.dropna(), lags=[lags], return_df=True)
    only_row = table.iloc[0]  # a single lag was requested, so there is one row
    return {key: round(float(only_row[key]), 4) for key in ("lb_stat", "lb_pvalue")}


def save_model_artifact(model, filename: str) -> str:
    """Serialize *model* with pickle into the system temp directory.

    The file is named ``filename`` (an existing file of that name is
    overwritten) and its full path is returned.
    """
    target = os.path.join(tempfile.gettempdir(), filename)
    with open(target, "wb") as handle:
        handle.write(pickle.dumps(model))
    return target


def log_run(
    experiment_name: str,
    run_name: str,
    model_family: str,
    model_variant: str,
    params: dict,
    metrics: dict,
    lb: dict,
    forecast_png: str,
    residual_png: str,
    model_obj,
    extra_metrics: dict = None,
) -> None:
    """Record one fitted model as an MLflow run.

    Activates ``experiment_name``, opens a run called ``run_name`` and logs, in
    order: the family/variant/outlier tags, ``params``, the accuracy ``metrics``,
    the Ljung-Box results ``lb``, any ``extra_metrics``, the two PNG files, and a
    pickle of ``model_obj`` (stored under the ``model`` artifact folder).
    A one-line summary is printed once the run has closed.
    """
    run_tags = (
        ("model_family", model_family),
        ("model_variant", model_variant),
        ("outlier_fixed", "True"),
    )

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run(run_name=run_name):
        for tag_key, tag_value in run_tags:
            mlflow.set_tag(tag_key, tag_value)

        mlflow.log_params(params)
        for metric_batch in (metrics, lb):
            mlflow.log_metrics(metric_batch)
        if extra_metrics:
            mlflow.log_metrics(extra_metrics)

        for png_path in (forecast_png, residual_png):
            mlflow.log_artifact(png_path)

        pickled_path = save_model_artifact(model_obj, f"{run_name}.pkl")
        mlflow.log_artifact(pickled_path, artifact_path="model")
    print(f"[{run_name}] {metrics} | Ljung-Box p={lb['lb_pvalue']}")


# ── Create MLflow experiments ─────────────────────────────────────────────────
EXP_SMOOTHING = "exponential_smoothing"
EXP_ARIMA     = "arima_sarima"

# Create each experiment only if a lookup by name finds nothing, so re-running
# the cell does not attempt to create it twice.
for exp in (EXP_SMOOTHING, EXP_ARIMA):
    if mlflow.get_experiment_by_name(exp):
        print(f"Experiment exists : {exp}")
        continue
    mlflow.create_experiment(exp)
    print(f"Created experiment: {exp}")


## 10. Stage 1 — Baseline Models
Benchmarks that every subsequent model must beat.
- **Mean Forecast**: predicts training mean for all test periods
- **Drift Forecast**: extrapolates the line between first and last training point


In [None]:
train_vals = train_ts.pd_series()
holdout_len = len(test_ts)

# ── Mean Baseline: repeat the training mean for every hold-out month ──────────
mean_vals    = np.full(holdout_len, train_vals.mean())
fc_mean_ts   = TimeSeries.from_times_and_values(test_ts.time_index, mean_vals)
metrics_mean = compute_metrics(test_ts, fc_mean_ts)

# ── Drift Baseline: extend the straight line through the first and last point ─
n = len(train_vals)
drift_slope   = (train_vals.iloc[-1] - train_vals.iloc[0]) / (n - 1)
drift_vals    = train_vals.iloc[-1] + drift_slope * np.arange(1, holdout_len + 1)
fc_drift_ts   = TimeSeries.from_times_and_values(test_ts.time_index, drift_vals)
metrics_drift = compute_metrics(test_ts, fc_drift_ts)

# ── Log both baselines (mean first, then drift) in the smoothing experiment ───
baseline_runs = [
    ("mean",  "Mean",  fc_mean_ts,  metrics_mean,  "05a_baseline_mean.png"),
    ("drift", "Drift", fc_drift_ts, metrics_drift, "05b_baseline_drift.png"),
]

mlflow.set_experiment(EXP_SMOOTHING)
for variant, title, baseline_fc, baseline_scores, png_name in baseline_runs:
    with mlflow.start_run(run_name=f"baseline_{variant}"):
        mlflow.set_tag("model_family", "baseline")
        mlflow.set_tag("model_variant", variant)
        mlflow.set_tag("outlier_fixed", "True")
        mlflow.log_params({"model_type": f"{variant}_forecast", "forecast_horizon": FORECAST_HORIZON})
        mlflow.log_metrics(baseline_scores)
        plot_forecast(train_ts, test_ts, baseline_fc, f"{title} Baseline", png_name)
        mlflow.log_artifact(png_name)
    print(f"{title} Baseline →", baseline_scores)


## 11. Stage 2 — Simple Exponential Smoothing (SES)
In Darts: `ExponentialSmoothing(trend=ModelMode.NONE, seasonal=SeasonalityMode.NONE)`.  
Models level only — expected to underfit given trend + seasonality in data.


In [None]:
# Simple exponential smoothing: level only, no trend and no seasonal component.
model_ses = ExponentialSmoothing(
    trend=ModelMode.NONE,
    seasonal=SeasonalityMode.NONE,
    seasonal_periods=SEASONAL_PERIOD,
)
model_ses.fit(train_ts)
fc_ses = model_ses.predict(len(test_ts))
metrics_ses = compute_metrics(test_ts, fc_ses)

# In-sample residuals = actual − fitted, taken from the underlying statsmodels result.
# np.asarray accepts both a pandas Series and a plain ndarray, so this does not
# depend on which of the two the statsmodels result exposes (`.values` only
# exists on the former).
fitted_vals_ses = np.asarray(model_ses.model.fittedvalues)
residuals_ses   = pd.Series(
    train_ts.pd_series().values - fitted_vals_ses,
    index=train_ts.time_index,
)
lb_ses = ljung_box_test(residuals_ses)

plot_forecast(train_ts, test_ts, fc_ses, "SES", "06_ses_forecast.png")
plot_residuals(residuals_ses, "SES", "06_ses_residuals.png")

# Fitted smoothing parameter (alpha), logged alongside the accuracy metrics.
alpha_ses = round(model_ses.model.params["smoothing_level"], 4)
log_run(
    EXP_SMOOTHING, "SES", "exponential_smoothing", "SES",
    params={"model_type": "SES", "alpha": alpha_ses, "forecast_horizon": FORECAST_HORIZON},
    metrics=metrics_ses, lb=lb_ses,
    forecast_png="06_ses_forecast.png", residual_png="06_ses_residuals.png",
    model_obj=model_ses,
)


## 12. Stage 3 — Holt's Linear Exponential Smoothing (Double ES)
In Darts: `ExponentialSmoothing(trend=ModelMode.ADDITIVE, seasonal=SeasonalityMode.NONE)`.  
Adds trend component (α + β). Two variants: standard and damped.


In [None]:
# Holt's linear trend method: one undamped and one damped variant, no seasonality.
holt_variants = [
    {"damped": False, "run_name": "Holts_Linear",        "label": "Holt Linear"},
    {"damped": True,  "run_name": "Holts_Linear_Damped", "label": "Holt Linear Damped"},
]

for v in holt_variants:
    model_holt = ExponentialSmoothing(
        trend=ModelMode.ADDITIVE,
        damped=v["damped"],
        seasonal=SeasonalityMode.NONE,
        seasonal_periods=SEASONAL_PERIOD,
    )
    model_holt.fit(train_ts)
    fc_holt = model_holt.predict(len(test_ts))
    metrics_holt = compute_metrics(test_ts, fc_holt)

    # In-sample residuals = actual − fitted. np.asarray works whether the
    # statsmodels result exposes fitted values as a Series or as an ndarray.
    fitted_holt = np.asarray(model_holt.model.fittedvalues)
    residuals_holt = pd.Series(
        train_ts.pd_series().values - fitted_holt,
        index=train_ts.time_index,
    )
    lb_holt = ljung_box_test(residuals_holt)

    holt_params = model_holt.model.params
    alpha = round(holt_params["smoothing_level"], 4)
    beta  = round(holt_params["smoothing_trend"], 4)
    # An undamped fit has no damping factor: the key may be absent or hold NaN.
    # Both cases are reported as phi = 1.0 (no damping) instead of logging "nan".
    phi_raw = holt_params.get("damping_trend")
    phi = 1.0 if phi_raw is None or np.isnan(phi_raw) else round(phi_raw, 4)

    fc_png  = f"07_{v['run_name']}_forecast.png"
    res_png = f"07_{v['run_name']}_residuals.png"
    plot_forecast(train_ts, test_ts, fc_holt, v["label"], fc_png)
    plot_residuals(residuals_holt, v["label"], res_png)

    log_run(
        EXP_SMOOTHING, v["run_name"], "exponential_smoothing", v["run_name"],
        params={
            "model_type": "Holts_ExponentialSmoothing",
            "trend": "additive", "damped": v["damped"],
            "alpha": alpha, "beta": beta, "phi": phi,
            "forecast_horizon": FORECAST_HORIZON,
        },
        metrics=metrics_holt, lb=lb_holt,
        forecast_png=fc_png, residual_png=res_png,
        model_obj=model_holt,
        # AIC and BIC are both logged, matching the Holt-Winters runs, so the
        # comparison table has both columns populated for every smoothing model.
        extra_metrics={
            "AIC": round(model_holt.model.aic, 4),
            "BIC": round(model_holt.model.bic, 4),
        },
    )


## 13. Stage 4 — Holt-Winters Exponential Smoothing
In Darts: `ExponentialSmoothing` with both `trend` and `seasonal` set.  
Four variants: additive/multiplicative seasonal × standard/damped trend.


In [None]:
# Holt-Winters grid: {additive, multiplicative} seasonality × {undamped, damped} trend.
hw_variants = [
    {"seasonal": SeasonalityMode.ADDITIVE,       "damped": False,
     "run_name": "HW_Additive",             "label": "HW Additive"},
    {"seasonal": SeasonalityMode.ADDITIVE,       "damped": True,
     "run_name": "HW_Additive_Damped",      "label": "HW Additive Damped"},
    {"seasonal": SeasonalityMode.MULTIPLICATIVE, "damped": False,
     "run_name": "HW_Multiplicative",       "label": "HW Multiplicative"},
    {"seasonal": SeasonalityMode.MULTIPLICATIVE, "damped": True,
     "run_name": "HW_Multiplicative_Damped","label": "HW Multiplicative Damped"},
]

for v in hw_variants:
    model_hw = ExponentialSmoothing(
        trend=ModelMode.ADDITIVE,
        damped=v["damped"],
        seasonal=v["seasonal"],
        seasonal_periods=SEASONAL_PERIOD,
    )
    model_hw.fit(train_ts)
    fc_hw = model_hw.predict(len(test_ts))
    metrics_hw = compute_metrics(test_ts, fc_hw)

    # In-sample residuals = actual − fitted. np.asarray works whether the
    # statsmodels result exposes fitted values as a Series or as an ndarray.
    fitted_hw    = np.asarray(model_hw.model.fittedvalues)
    residuals_hw = pd.Series(
        train_ts.pd_series().values - fitted_hw,
        index=train_ts.time_index,
    )
    lb_hw = ljung_box_test(residuals_hw)

    hw_params = model_hw.model.params
    alpha = round(hw_params["smoothing_level"],    4)
    beta  = round(hw_params["smoothing_trend"],    4)
    gamma = round(hw_params["smoothing_seasonal"], 4)
    # An undamped fit has no damping factor: the key may be absent or hold NaN.
    # Both cases are reported as phi = 1.0 (no damping) instead of logging "nan".
    phi_raw = hw_params.get("damping_trend")
    phi = 1.0 if phi_raw is None or np.isnan(phi_raw) else round(phi_raw, 4)

    fc_png  = f"08_{v['run_name']}_forecast.png"
    res_png = f"08_{v['run_name']}_residuals.png"
    plot_forecast(train_ts, test_ts, fc_hw, v["label"], fc_png)
    plot_residuals(residuals_hw, v["label"], res_png)

    log_run(
        EXP_SMOOTHING, v["run_name"], "holt_winters", v["run_name"],
        params={
            "model_type": "HoltWinters",
            "trend": "additive",
            # Enum member rendered as its lower-cased name, e.g. "additive".
            "seasonal": str(v["seasonal"]).split(".")[-1].lower(),
            "seasonal_periods": SEASONAL_PERIOD,
            "damped": v["damped"],
            "alpha": alpha, "beta": beta, "gamma": gamma, "phi": phi,
            "forecast_horizon": FORECAST_HORIZON,
        },
        metrics=metrics_hw, lb=lb_hw,
        forecast_png=fc_png, residual_png=res_png,
        model_obj=model_hw,
        extra_metrics={
            "AIC": round(model_hw.model.aic, 4),
            "BIC": round(model_hw.model.bic, 4),
        },
    )


## 14. Stage 5 — ARIMA (Auto Order Selection)
Darts `AutoARIMA` wraps `statsforecast.models.AutoARIMA` for automatic non-seasonal order selection.  
Acts as a stepping stone before adding the seasonal component.


In [None]:
mlflow.set_experiment(EXP_ARIMA)

print("Fitting AutoARIMA (non-seasonal)...")
model_arima = AutoARIMA(season_length=1)   # season_length=1 → non-seasonal search
model_arima.fit(train_ts)
fc_arima = model_arima.predict(len(test_ts))
metrics_arima = compute_metrics(test_ts, fc_arima)

# Handle on the wrapped backend model, kept for interactive inspection in the notebook.
arima_summary = model_arima.model

plot_forecast(train_ts, test_ts, fc_arima, "AutoARIMA (non-seasonal)", "09_arima_forecast.png")

# Residuals over the training window.
# predict(n, series=train_ts) forecasts the months AFTER the training window, so it
# cannot serve as fitted values: aligned to the training dates it is entirely NaN.
# Darts' residuals() instead backtests one-step-ahead forecasts across train_ts
# (by default re-fitting on an expanding window) and returns actual − forecast.
# The earliest observations, needed as the minimum training window, have no residual.
residuals_arima = model_arima.residuals(train_ts).pd_series().dropna()
lb_arima = ljung_box_test(residuals_arima)

plot_residuals(residuals_arima, "AutoARIMA", "09_arima_residuals.png")

with mlflow.start_run(run_name="ARIMA_auto"):
    mlflow.set_tag("model_family",  "arima")
    mlflow.set_tag("model_variant", "AutoARIMA_non_seasonal")
    mlflow.set_tag("outlier_fixed", "True")
    mlflow.log_params({
        "model_type": "AutoARIMA", "seasonal": False,
        "season_length": 1, "forecast_horizon": FORECAST_HORIZON,
    })
    mlflow.log_metrics(metrics_arima)
    mlflow.log_metrics(lb_arima)
    mlflow.log_artifact("09_arima_forecast.png")
    mlflow.log_artifact("09_arima_residuals.png")
    model_path = save_model_artifact(model_arima, "arima_auto.pkl")
    mlflow.log_artifact(model_path, artifact_path="model")

print(f"AutoARIMA → {metrics_arima} | Ljung-Box p={lb_arima['lb_pvalue']}")


## 15. Stage 6a — SARIMA (AutoARIMA with m=12)
Darts `AutoARIMA(season_length=12)` searches optimal (p,d,q)(P,D,Q)[12] automatically.


In [None]:
mlflow.set_experiment(EXP_ARIMA)

print("Fitting AutoARIMA (seasonal, m=12)... may take ~60 seconds.")
model_sarima_auto = AutoARIMA(season_length=SEASONAL_PERIOD)
model_sarima_auto.fit(train_ts)
fc_sarima_auto = model_sarima_auto.predict(len(test_ts))
metrics_sarima_auto = compute_metrics(test_ts, fc_sarima_auto)

plot_forecast(train_ts, test_ts, fc_sarima_auto,
              "AutoARIMA SARIMA (m=12)", "10_sarima_auto_forecast.png")

# Residuals over the training window.
# predict(n, series=train_ts) forecasts the months AFTER the training window, so it
# cannot serve as fitted values: aligned to the training dates it is entirely NaN.
# Darts' residuals() instead backtests one-step-ahead forecasts across train_ts
# (by default re-fitting on an expanding window, which adds run time for the
# seasonal search) and returns actual − forecast. The earliest observations,
# needed as the minimum training window, have no residual.
residuals_sarima = model_sarima_auto.residuals(train_ts).pd_series().dropna()
lb_sarima_auto   = ljung_box_test(residuals_sarima)

plot_residuals(residuals_sarima, "SARIMA Auto", "10_sarima_auto_residuals.png")

with mlflow.start_run(run_name="SARIMA_auto"):
    mlflow.set_tag("model_family",  "sarima")
    mlflow.set_tag("model_variant", "AutoARIMA_seasonal_m12")
    mlflow.set_tag("outlier_fixed", "True")
    mlflow.log_params({
        "model_type": "AutoARIMA_seasonal",
        "season_length": SEASONAL_PERIOD,
        "forecast_horizon": FORECAST_HORIZON,
    })
    mlflow.log_metrics(metrics_sarima_auto)
    mlflow.log_metrics(lb_sarima_auto)
    mlflow.log_artifact("10_sarima_auto_forecast.png")
    mlflow.log_artifact("10_sarima_auto_residuals.png")
    model_path = save_model_artifact(model_sarima_auto, "sarima_auto.pkl")
    mlflow.log_artifact(model_path, artifact_path="model")

print(f"SARIMA Auto → {metrics_sarima_auto} | Ljung-Box p={lb_sarima_auto['lb_pvalue']}")


## 16. Stage 6b — SARIMA Manual Candidates
Using Darts `ARIMA` class with explicit `(p,d,q)` and `seasonal_order=(P,D,Q,m)`.  
Four candidates backed by ACF/PACF analysis (strong lag-1 and lag-12 autocorrelation).


In [None]:
from statsmodels.tsa.statespace.sarimax import SARIMAX as SM_SARIMAX

mlflow.set_experiment(EXP_ARIMA)

# Hand-picked (p,d,q)(P,D,Q)[12] orders; run_name encodes the order for MLflow.
sarima_candidates = [
    {"p":1,"d":1,"q":1,"P":1,"D":1,"Q":1, "run_name":"SARIMA_1_1_1_1_1_1_12"},
    {"p":1,"d":0,"q":1,"P":1,"D":1,"Q":1, "run_name":"SARIMA_1_0_1_1_1_1_12"},
    {"p":2,"d":1,"q":1,"P":1,"D":1,"Q":1, "run_name":"SARIMA_2_1_1_1_1_1_12"},
    {"p":1,"d":1,"q":2,"P":1,"D":1,"Q":1, "run_name":"SARIMA_1_1_2_1_1_1_12"},
]

for c in sarima_candidates:
    label = f"SARIMA({c['p']},{c['d']},{c['q']})({c['P']},{c['D']},{c['Q']})[12]"
    # Best-effort loop: a candidate that fails is reported and skipped so the
    # remaining candidates still run.
    try:
        # Darts ARIMA with seasonal_order
        model_s = ARIMA(
            p=c["p"], d=c["d"], q=c["q"],
            seasonal_order=(c["P"], c["D"], c["Q"], SEASONAL_PERIOD),
        )
        model_s.fit(train_ts)
        fc_s = model_s.predict(len(test_ts))
        metrics_s = compute_metrics(test_ts, fc_s)

        # Confidence intervals via the underlying statsmodels results object.
        # conf_int() yields a DataFrame only when the model carries a pandas index
        # and a bare ndarray otherwise; np.asarray normalises both to a
        # (steps, 2) array of [lower, upper] so column slicing always works.
        sm_result     = model_s.model
        fc_sm         = sm_result.get_forecast(steps=len(test_ts))
        conf_int      = np.asarray(fc_sm.conf_int())
        conf_low_ts   = TimeSeries.from_times_and_values(
            test_ts.time_index, conf_int[:, 0]
        )
        conf_high_ts  = TimeSeries.from_times_and_values(
            test_ts.time_index, conf_int[:, 1]
        )

        # Positional alignment of the residuals with the training timestamps.
        # NOTE(review): with d=1 and D=1 the first d + D*m residuals likely reflect
        # model initialisation rather than fit quality — consider excluding them
        # before the Ljung-Box test; confirm against statsmodels' behaviour.
        residuals_s = pd.Series(np.asarray(sm_result.resid), index=train_ts.time_index)
        lb_s        = ljung_box_test(residuals_s)

        fc_png  = f"11_{c['run_name']}_forecast.png"
        res_png = f"11_{c['run_name']}_residuals.png"
        plot_forecast(train_ts, test_ts, fc_s, label, fc_png, conf_low_ts, conf_high_ts)
        plot_residuals(residuals_s, label, res_png)

        log_run(
            EXP_ARIMA, c["run_name"], "sarima", label,
            params={
                "model_type": "SARIMA_manual",
                "p": c["p"], "d": c["d"], "q": c["q"],
                "P": c["P"], "D": c["D"], "Q": c["Q"], "m": SEASONAL_PERIOD,
                "forecast_horizon": FORECAST_HORIZON,
            },
            metrics=metrics_s, lb=lb_s,
            forecast_png=fc_png, residual_png=res_png,
            model_obj=model_s,
            extra_metrics={
                "AIC": round(sm_result.aic, 4),
                "BIC": round(sm_result.bic, 4),
            },
        )

    except Exception as e:
        print(f"Failed {c['run_name']}: {e}")


## 17. Model Comparison — MLflow Results Summary

In [None]:
def fetch_runs(experiment_name: str) -> pd.DataFrame:
    """Return all MLflow runs of the named experiment as a DataFrame.

    An empty DataFrame is returned when no experiment with that name exists.
    """
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if not experiment:
        return pd.DataFrame()
    return mlflow.search_runs([experiment.experiment_id])

# MLflow column name → display name, in the order the table should show them.
column_labels = {
    "tags.mlflow.runName": "Model",
    "metrics.MAPE":        "MAPE(%)",
    "metrics.RMSE":        "RMSE",
    "metrics.MAE":         "MAE",
    "metrics.AIC":         "AIC",
    "metrics.BIC":         "BIC",
    "metrics.lb_pvalue":   "LjungBox_p",
    "tags.model_family":   "Family",
}
cols = list(column_labels)

# Stack the runs of both experiments and keep only the columns that were actually logged.
all_runs = pd.concat(
    [fetch_runs(exp_name) for exp_name in (EXP_SMOOTHING, EXP_ARIMA)],
    ignore_index=True,
)
avail = [col for col in cols if col in all_runs.columns]

# Best (lowest) MAPE first.
summary = all_runs[avail].rename(columns=column_labels)
summary = summary.sort_values("MAPE(%)").reset_index(drop=True)

pd.set_option("display.float_format", lambda x: f"{x:,.4f}")
print("\n===== MODEL COMPARISON (sorted by MAPE) =====\n")
print(summary.to_string(index=False))
summary.to_csv("12_model_comparison.csv", index=False)
print("\nSaved: 12_model_comparison.csv")


## 18. Final Forecast — Best Model Retrained on Full Data
Retrain the best-performing model on the **full dataset** (all 85 months including Jan 2026)  
and forecast Feb–Apr 2026 (next 3 months).

> Update `BEST_MODEL_TYPE` below after reviewing the comparison table in Section 17.


In [None]:
# ── Select best model type after reviewing Section 17 ────────────────────────
# Options: "hw_multiplicative_damped"  |  "sarima_manual"  |  "sarima_auto"
BEST_MODEL_TYPE = "hw_multiplicative_damped"   # update if SARIMA wins

# Forecast window: starts one step after the last observed month, so it follows
# the data instead of a hard-coded date.
future_index = pd.date_range(
    start=series_full.end_time() + series_full.freq,
    periods=FORECAST_HORIZON,
    freq=series_full.freq,
)

# Model type → (factory for an unfitted model, label used in the plot title).
final_model_builders = {
    "hw_multiplicative_damped": (
        lambda: ExponentialSmoothing(
            trend=ModelMode.ADDITIVE,
            damped=True,
            seasonal=SeasonalityMode.MULTIPLICATIVE,
            seasonal_periods=SEASONAL_PERIOD,
        ),
        "HW Multiplicative Damped (Full Data)",
    ),
    "sarima_auto": (
        lambda: AutoARIMA(season_length=SEASONAL_PERIOD),
        "AutoARIMA SARIMA m=12 (Full Data)",
    ),
    "sarima_manual": (
        # Update (p,d,q)(P,D,Q) based on best manual candidate from Section 16
        lambda: ARIMA(p=1, d=1, q=1, seasonal_order=(1, 1, 1, SEASONAL_PERIOD)),
        "SARIMA(1,1,1)(1,1,1)[12] (Full Data)",
    ),
}

# Fail immediately on a typo instead of hitting a NameError further down.
if BEST_MODEL_TYPE not in final_model_builders:
    raise ValueError(
        f"Unknown BEST_MODEL_TYPE {BEST_MODEL_TYPE!r}; "
        f"expected one of {sorted(final_model_builders)}"
    )

build_best_model, model_label = final_model_builders[BEST_MODEL_TYPE]
best_model = build_best_model()
best_model.fit(series_full)
final_forecast_ts = best_model.predict(FORECAST_HORIZON)

# Re-wrap the forecast values on the explicit future timestamps.
final_fc_vals = final_forecast_ts.pd_series().values
final_fc_ts   = TimeSeries.from_times_and_values(
    pd.DatetimeIndex(future_index), final_fc_vals
)

# ── Log final forecast to MLflow ──────────────────────────────────────────────
target_exp = EXP_SMOOTHING if "hw" in BEST_MODEL_TYPE else EXP_ARIMA
mlflow.set_experiment(target_exp)

with mlflow.start_run(run_name=f"FINAL_FORECAST_{BEST_MODEL_TYPE.upper()}"):
    mlflow.set_tag("model_family", BEST_MODEL_TYPE)
    mlflow.set_tag("stage",        "production_forecast")
    mlflow.set_tag("trained_on",   "full_data_2019_2026")
    mlflow.log_param("forecast_horizon", FORECAST_HORIZON)
    mlflow.log_param("forecast_start",   str(future_index[0].date()))
    mlflow.log_param("forecast_end",     str(future_index[-1].date()))
    # One metric per forecast step: forecast_month_1, forecast_month_2, ...
    for i, val in enumerate(final_fc_vals, 1):
        mlflow.log_metric(f"forecast_month_{i}", round(float(val), 2))

    # Plot: last 24 months of history + the forward forecast
    hist_pd = series_full.pd_series().iloc[-24:]
    fig, ax = plt.subplots(figsize=(13, 5))
    ax.plot(hist_pd.index, hist_pd,
            label="Historical (last 24 months)", color="steelblue", linewidth=1.5)
    ax.plot(future_index, final_fc_vals,
            label=f"{FORECAST_HORIZON}-Month Forecast", color="tomato", linewidth=2,
            marker="o", markersize=7)
    # Vertical marker at the last observed month.
    ax.axvline(series_full.end_time(), linestyle="--", color="gray", linewidth=0.8)
    ax.set_title(f"{FORECAST_HORIZON}-Month Ahead Forecast — {model_label}")
    ax.set_ylabel("KWh")
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x/1e6:.0f}M"))
    ax.legend()
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig("13_final_forecast.png", bbox_inches="tight")
    plt.show()
    mlflow.log_artifact("13_final_forecast.png")

# Header is built from the forecast window (e.g. "Feb–Apr 2026" for the current data).
forecast_span = f"{future_index[0].strftime('%b')}–{future_index[-1].strftime('%b %Y')}"
print(f"\n===== FINAL {FORECAST_HORIZON}-MONTH FORECAST ({forecast_span}) =====")
for dt, val in zip(future_index, final_fc_vals):
    print(f"  {dt.strftime('%Y-%m')}: {float(val):>20,.0f} KWh")


## 19. View Results in MLflow UI
Launch the MLflow tracking UI from your terminal:
```bash
mlflow ui --port 5000
```
Then open: **http://localhost:5000**

Navigate to:
- `exponential_smoothing` experiment → compare the mean/drift baselines, SES, Holt, and Holt-Winters variants by MAPE
- `arima_sarima` experiment → compare ARIMA / SARIMA candidates by MAPE, AIC, BIC, Ljung-Box p

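**Artifacts logged per model run**: forecast plot, residual plot (residual series + ACF), and pickled model file. The baseline runs and the final-forecast run log only their forecast plot.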
