## importações e configurações

In [None]:
import os, sys
from pathlib import Path
import importlib
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Never truncate DataFrame output: show every column and every row.
pd.options.display.max_columns = None
pd.options.display.max_rows = None

In [None]:
# Project root (a Google Drive path, so it only resolves after the mount below)
PROJ = Path("/content/drive/MyDrive/tcc-modelo/tcc-demand-forecasting")

# Mount Google Drive (requires the Colab runtime)
from google.colab import drive
drive.mount('/content/drive')

# Make sure the project root is on sys.path so project modules can be imported
if str(PROJ) not in sys.path:
    sys.path.append(str(PROJ))

print("Repositório ativo em:", PROJ)

In [None]:
# Project-local metric helpers (resolved through the sys.path entry for PROJ)
from src.evaluations.models_metrics import calculate_metrics, compare_models

# Location of the parquet dataset: <PROJ>/data/interim/<output_name_imputed>
interim_dir = PROJ / "data" / "interim"
output_name_imputed = "olist_weekly_agg_withlags_imputed_2.parquet"
df_path = interim_dir / output_name_imputed

In [None]:
# Dataset columns
cutoff_col  = "order_week"   # time column (datetime) used to cut train/validation
date_col    = "order_week"   # same column, used to stamp the predictions
target_col  = "sales_qty"    # target
id_col      = "id"   # optional

# Periods (example)
first_train_end = pd.Timestamp("2018-03-18")
test_start      = pd.Timestamp("2018-03-19")
test_end        = pd.Timestamp("2018-08-27")

# Rolling configuration: (gap after cutoff, window length).
# NOTE(review): the window is 7 days (one weekly step), not 4 weeks, and
# split_rolling in this file only applies the window length; the 4-day gap
# has no effect on which rows are selected.
horizon_stride = (pd.Timedelta(days=4), pd.Timedelta(days=7))  # (gap after cutoff, window)

## definição do df

In [None]:
# Load the dataset from the parquet file at df_path
df = pd.read_parquet(df_path)

In [None]:
# Expose the DataFrame index as an explicit 'id' column
df['id'] = df.index

In [None]:
# Describe the order_week column, reporting every decile (10% .. 90%)
deciles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
df['order_week'].describe(percentiles=deciles)

## funções úteis

In [None]:
from typing import Dict, List, Tuple, Optional
import pandas as pd
import numpy as np

# ---------- 1) Esquemas de treinamento ----------
def split_static(df, cutoff_col, first_train_end, test_start, test_end) -> List[Tuple[pd.DataFrame, pd.DataFrame]]:
    """Build the single (train, test) pair for the train-once scheme.

    Train holds every row whose ``cutoff_col`` is on or before
    ``first_train_end``; test holds the rows inside the inclusive range
    ``[test_start, test_end]``. Both frames are independent copies.

    Returns:
        A one-element list ``[(train, test)]``, the same shape of result
        that ``split_rolling`` produces.
    """
    when = df[cutoff_col]
    train_mask = when <= first_train_end
    test_mask = (when <= test_end) & (when >= test_start)

    train_part = df.loc[train_mask].copy()
    test_part = df.loc[test_mask].copy()
    return [(train_part, test_part)]

def split_rolling(df, cutoff_col, first_train_end, horizon_stride, step_k=1) -> List[Tuple[pd.DataFrame, pd.DataFrame]]:
    """Build expanding-window (train, valid) pairs.

    Starting at ``first_train_end``, the cutoff advances by the window
    length ``horizon_stride[1]``. Each validation frame holds the rows with
    ``cutoff < date <= cutoff + window``; the matching train frame holds
    every row with ``date <= cutoff``.

    Args:
        df: Source frame.
        cutoff_col: Name of the datetime column used for the cuts.
        first_train_end: Last date included in the first training frame.
        horizon_stride: ``(gap, window)``. Only ``window`` is applied; the
            gap is accepted for interface compatibility and does not affect
            which rows are selected.
        step_k: The training cutoff is refreshed every ``step_k`` emitted
            windows (``1`` re-cuts the training frame for every window). In
            between, the previous training cutoff is reused.

    Returns:
        List of ``(train, valid)`` frames, each an independent copy.

    Raises:
        ValueError: If ``step_k`` is smaller than 1 or the window does not
            move the cutoff forward.
    """
    step_k = int(step_k)
    if step_k < 1:
        raise ValueError(f"step_k must be >= 1 (got {step_k}).")

    window = horizon_stride[1]
    if not (first_train_end + window > first_train_end):
        raise ValueError("horizon_stride[1] must be a positive duration.")

    if df.empty:
        return []

    dates = df[cutoff_col]
    last_date = dates.max()

    pairs = []
    current_end = first_train_end
    train_cutoff = None
    emitted = 0

    # Walk until the cutoff passes the last observed date. An empty window
    # (e.g. a missing week) is skipped instead of ending the walk, so data
    # after a gap is still evaluated.
    while current_end < last_date:
        val_end = current_end + window
        valid = df[(dates > current_end) & (dates <= val_end)].copy()
        if not valid.empty:
            if train_cutoff is None or emitted % step_k == 0:
                train_cutoff = current_end
            train = df[dates <= train_cutoff].copy()
            pairs.append((train, valid))
            emitted += 1
        current_end = val_end

    return pairs

# ---------- 2) Conjuntos de features ----------
def feature_set_basic(cols_all: List[str]) -> List[str]:
    """Return the columns prefixed with ``sales_qty_`` or ``cal_``, in input order."""
    wanted_prefixes = ("sales_qty_", "cal_")
    return [name for name in cols_all if name.startswith(wanted_prefixes)]

def feature_set_all(cols_all: List[str]) -> List[str]:
    """Return every column except the target, identifier, date and bookkeeping columns.

    Excluded names: ``sales_qty``, ``id``, ``order_week``, ``have_nulls`` and
    ``product_category_name``. Input order is preserved.
    """
    excluded = {"sales_qty", "id", "order_week", "have_nulls", "product_category_name"}
    return [name for name in cols_all if name not in excluded]

def feature_set_selected(selected_list_path: str) -> List[str]:
    """Load a frozen feature list from a text file, one feature name per line.

    Surrounding whitespace is stripped and blank lines are ignored. The file
    is read as UTF-8 so the result does not depend on the platform's default
    encoding.

    Args:
        selected_list_path: Path to the text file.

    Returns:
        Feature names in file order.
    """
    with open(selected_list_path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [name for name in stripped if name]

# ---------- 3) Modelos (LGBM com HPO opcional) ----------
def _get_lgbm_search_objects(cfg_hpo: Dict):
    """Build the base LightGBM regressor and the hyper-parameter searcher.

    Keys read from ``cfg_hpo`` (all optional):
        search: ``"grid"`` selects GridSearchCV; anything else selects
            RandomizedSearchCV (default ``"random"``).
        scoring: scikit-learn scoring name (default
            ``"neg_root_mean_squared_error"``).
        cv: cross-validation splitter handed to the searcher (default None).
        n_jobs, verbose: forwarded to the searcher (defaults -1 and 0).
        param_grid: search space; a built-in space is used when absent.
        n_iter: number of sampled candidates, randomized search only
            (default 30).

    Returns:
        ``(estimator, searcher)``; the searcher wraps that same estimator.
    """
    from lightgbm import LGBMRegressor
    from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

    # Search space used when the caller does not supply "param_grid".
    fallback_space = {
        "num_leaves": [31, 63, 127],
        "max_depth": [-1, 8, 12],
        "learning_rate": [0.03, 0.05, 0.1],
        "n_estimators": [100, 400, 800],
        "subsample": [0.8, 1.0],
        "colsample_bytree": [0.8, 1.0],
        "min_child_samples": [10, 20, 30],
        "reg_alpha": [0.0, 0.1],
        "reg_lambda": [0.0, 0.1],
    }

    estimator = LGBMRegressor(random_state=42, n_jobs=-1)
    space = cfg_hpo.get("param_grid", fallback_space)

    # Arguments shared by both searcher flavours.
    shared_kwargs = dict(
        estimator=estimator,
        scoring=cfg_hpo.get("scoring", "neg_root_mean_squared_error"),
        cv=cfg_hpo.get("cv", None),
        n_jobs=cfg_hpo.get("n_jobs", -1),
        verbose=cfg_hpo.get("verbose", 0),
    )

    if cfg_hpo.get("search", "random") == "grid":
        return estimator, GridSearchCV(param_grid=space, **shared_kwargs)

    return estimator, RandomizedSearchCV(
        param_distributions=space,
        n_iter=cfg_hpo.get("n_iter", 30),
        random_state=42,
        **shared_kwargs,
    )

def _make_timeseries_cv(train_df: pd.DataFrame, date_col: str, n_splits: int = 3):
    """Return a ``TimeSeriesSplit`` with ``n_splits`` folds.

    ``train_df`` and ``date_col`` are not used here: the splitter works on
    row positions, so the caller must pass data already ordered by date.
    """
    from sklearn.model_selection import TimeSeriesSplit

    splitter = TimeSeriesSplit(n_splits=n_splits)
    return splitter

def fit_predict_lgbm(
    train: pd.DataFrame,
    valid: pd.DataFrame,
    features: List[str],
    id_col: Optional[str],
    date_col: str,
    target_col: str,
    cfg: Optional[Dict] = None
):
    """Fit a LightGBM regressor on ``train`` and predict ``valid``.

    When ``cfg["lgbm_hpo"]["enable"]`` is truthy, a hyper-parameter search
    with a time-ordered CV runs on ``train`` and a fresh model is fitted
    with the best parameters; otherwise library defaults are used.

    Args:
        train: Training rows; sorted by ``date_col`` before fitting.
        valid: Rows to predict. They are NOT reordered.
        features: Feature column names.
        id_col: Unused; kept for interface compatibility.
        date_col: Datetime column used to order ``train``.
        target_col: Target column (read from both frames).
        cfg: Optional config; only the ``"lgbm_hpo"`` entry is read
            (``enable``, ``cv_splits``, ``early_stopping_rounds`` plus the
            keys consumed by ``_get_lgbm_search_objects``).

    Returns:
        ``(preds, model)`` where ``preds[i]`` is the prediction for
        ``valid.iloc[i]``. Callers assign the array positionally to
        ``valid``-derived frames, so the row order must match.
    """
    import lightgbm as lgb

    cfg = cfg or {}
    hpo_cfg = cfg.get("lgbm_hpo", {"enable": False})

    # Train is date-ordered because TimeSeriesSplit splits by row position.
    train_sorted = train.sort_values(date_col)
    X_tr = train_sorted[features]
    y_tr = train_sorted[target_col]

    # Keep valid in its original row order: sorting it (the default sort is
    # not stable on tied dates) would misalign predictions with the rows.
    X_va = valid[features]
    y_va = valid[target_col]

    if hpo_cfg.get("enable", False):
        cv = _make_timeseries_cv(train_sorted, date_col, n_splits=hpo_cfg.get("cv_splits", 3))
        _, searcher = _get_lgbm_search_objects({
            **hpo_cfg,
            "cv": cv
        })

        searcher.fit(X_tr, y_tr)
        best_params = searcher.best_params_

        mdl = lgb.LGBMRegressor(random_state=42, **best_params)
        stopping_rounds = hpo_cfg.get("early_stopping_rounds", 100)
    else:
        mdl = lgb.LGBMRegressor(
            random_state=42,
            n_jobs=-1
        )
        stopping_rounds = 100

    # NOTE(review): early stopping monitors the very rows being predicted,
    # so their targets influence the number of trees (optimistic metrics).
    # Kept as-is to preserve current results; consider a held-out slice of train.
    mdl.fit(
        X_tr, y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric="rmse",
        callbacks=[lgb.early_stopping(stopping_rounds=stopping_rounds, verbose=False)]
    )

    preds = mdl.predict(X_va)
    return preds, mdl

# ---------- Prophet ----------
def fit_predict_prophet(train, valid, features_exog, date_col, target_col):
    """Fit Prophet on ``train`` and forecast the dates in ``valid``.

    Assumes ``train``/``valid`` hold ONE series (e.g. a single category).
    Every column in ``features_exog`` is added as an extra regressor and
    must be present in both frames.

    Returns:
        ``(yhat, model)`` where ``yhat[i]`` is the forecast for
        ``valid.iloc[i]``, whatever the date order of ``valid``.
    """
    from prophet import Prophet

    features_exog = list(features_exog)

    df_p = train[[date_col, target_col] + features_exog].rename(columns={date_col: "ds", target_col: "y"})
    m = Prophet()
    for c in features_exog:
        m.add_regressor(c)
    m.fit(df_p)

    # Hand Prophet a frame already ordered by date (stable on ties) and
    # scatter the forecasts back afterwards. Prophet orders its forecast by
    # ds, so without this the output would not line up with an unsorted valid.
    chrono_pos = np.argsort(valid[date_col].to_numpy(), kind="stable")
    df_future = valid.iloc[chrono_pos][[date_col] + features_exog].rename(columns={date_col: "ds"})
    yhat_chrono = m.predict(df_future)["yhat"].to_numpy()

    yhat = np.empty_like(yhat_chrono)
    yhat[chrono_pos] = yhat_chrono
    return yhat, m

# ---------- SARIMA ----------
def fit_predict_sarima(
    train: pd.DataFrame,
    valid: pd.DataFrame,
    date_col: str,
    target_col: str,
    cfg: Optional[Dict] = None
):
    """Fit a univariate SARIMA on ``train`` and forecast ``len(valid)`` steps.

    Assumes ``train``/``valid`` hold a single series (e.g. already filtered
    by category) with one row per period.

    Args:
        train: Training rows; ordered by ``date_col`` before fitting.
        valid: Rows to forecast; only their count and date order are used.
        date_col: Datetime column.
        target_col: Target column.
        cfg: Optional config; ``cfg["sarima"]`` may set ``order``,
            ``seasonal_order``, ``enforce_stationarity`` and
            ``enforce_invertibility``.

    Returns:
        ``(yhat, fitted_results)`` where ``yhat[i]`` is the forecast for
        ``valid.iloc[i]``: step k of the forecast goes to the k-th earliest
        validation row, so the output matches ``valid``'s row order.
    """
    from statsmodels.tsa.statespace.sarimax import SARIMAX

    cfg = cfg or {}
    sar_cfg = cfg.get("sarima", {})

    order = sar_cfg.get("order", (1, 1, 1))
    seasonal_order = sar_cfg.get("seasonal_order", (1, 1, 1, 52))  # 52-period season by default
    enforce_stationarity = sar_cfg.get("enforce_stationarity", True)
    enforce_invertibility = sar_cfg.get("enforce_invertibility", True)

    train_sorted = train.sort_values(date_col)
    # Chronological position of each validation row (stable on tied dates).
    chrono_pos = np.argsort(valid[date_col].to_numpy(), kind="stable")

    y_tr = train_sorted[target_col].values

    model = SARIMAX(
        y_tr,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=enforce_stationarity,
        enforce_invertibility=enforce_invertibility
    )
    res = model.fit(disp=False)

    n_forecast = len(chrono_pos)
    yhat_chrono = np.asarray(res.forecast(steps=n_forecast))

    # Scatter the time-ordered forecast back to valid's original row order.
    yhat = np.empty_like(yhat_chrono)
    yhat[chrono_pos] = yhat_chrono
    return yhat, res

# ---------- LSTM ----------
def _make_lstm_sequences(series: np.ndarray, lookback: int):
    """
    Transforma série 1D em (X, y) para LSTM:
    X: (n_samples, lookback, 1), y: (n_samples,)
    """
    X, y = [], []
    for i in range(lookback, len(series)):
        X.append(series[i - lookback:i])
        y.append(series[i])
    X = np.array(X)
    y = np.array(y)
    # adiciona dimensão de feature = 1
    X = X[..., None]
    return X, y

def fit_predict_lstm(
    train: pd.DataFrame,
    valid: pd.DataFrame,
    date_col: str,
    target_col: str,
    cfg: Optional[Dict] = None
):
    """Fit a simple univariate LSTM on ``train`` and forecast ``len(valid)`` steps.

    Assumes ``train``/``valid`` hold a single series (e.g. one category).
    The forecast is recursive: each prediction is appended to the history
    and feeds the next step.

    Args:
        train: Training rows; ordered by ``date_col`` before fitting.
        valid: Rows to forecast; only their count and date order are used.
        date_col: Datetime column.
        target_col: Target column.
        cfg: Optional config; ``cfg["lstm"]`` may set ``lookback`` (8),
            ``epochs`` (50), ``batch_size`` (16), ``units`` (32), ``lr`` (1e-3).

    Returns:
        ``(preds, model)`` with ``preds`` a float32 array where ``preds[i]``
        is the forecast for ``valid.iloc[i]``: step k goes to the k-th
        earliest validation row.

    Raises:
        ValueError: If the training series is not longer than ``lookback``.
    """
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense
    from tensorflow.keras.optimizers import Adam

    cfg = cfg or {}
    lstm_cfg = cfg.get("lstm", {})

    lookback   = lstm_cfg.get("lookback", 8)
    epochs     = lstm_cfg.get("epochs", 50)
    batch_size = lstm_cfg.get("batch_size", 16)
    units      = lstm_cfg.get("units", 32)
    lr         = lstm_cfg.get("lr", 1e-3)

    train_sorted = train.sort_values(date_col)
    # Chronological position of each validation row (stable on tied dates).
    chrono_pos = np.argsort(valid[date_col].to_numpy(), kind="stable")

    y_tr = train_sorted[target_col].values.astype("float32")

    # At least one full window plus one target is needed.
    if len(y_tr) <= lookback:
        raise ValueError(f"Série de treino muito curta para LSTM (len={len(y_tr)}, lookback={lookback}).")

    X_tr, y_tr_supervised = _make_lstm_sequences(y_tr, lookback)

    model = Sequential([
        LSTM(units, input_shape=(lookback, 1)),
        Dense(1)
    ])
    model.compile(loss="mse", optimizer=Adam(learning_rate=lr))

    model.fit(
        X_tr, y_tr_supervised,
        epochs=epochs,
        batch_size=batch_size,
        verbose=0
    )

    # Recursive forecast. The history already holds more than `lookback`
    # points and only grows, so no length check is needed inside the loop.
    n_forecast = len(chrono_pos)
    history = y_tr.copy()
    preds_chrono = []

    for _ in range(n_forecast):
        x_input = history[-lookback:].reshape(1, lookback, 1)
        yhat = model.predict(x_input, verbose=0)[0, 0]
        preds_chrono.append(float(yhat))
        history = np.append(history, yhat)

    # Scatter the time-ordered forecast back to valid's original row order.
    preds = np.empty(n_forecast, dtype="float32")
    preds[chrono_pos] = preds_chrono
    return preds, model

# ---------- 4) Ensemble ----------
def ensemble_mean(preds_dict: Dict[str, np.ndarray]) -> np.ndarray:
    """Element-wise unweighted average of every model's prediction array."""
    per_model = np.vstack(list(preds_dict.values()))
    return np.mean(per_model, axis=0)

def ensemble_weighted(preds_dict: Dict[str, np.ndarray], weights: Dict[str, float]) -> np.ndarray:
    """Element-wise weighted average of the models' predictions.

    A model missing from ``weights`` gets weight 1.0; weights are normalised
    to sum to one before being applied.
    """
    model_names = list(preds_dict)
    raw_weights = np.array([weights.get(name, 1.0) for name in model_names])
    normalised = raw_weights / raw_weights.sum()
    per_model = np.vstack([preds_dict[name] for name in model_names])
    # One weight per row (model), broadcast across the prediction axis.
    weighted = per_model * normalised[:, None]
    return weighted.sum(axis=0)

# ---------- 5) Métricas ----------
def wape(y_true, y_pred):
    """Weighted absolute percentage error: ``sum(|y - yhat|) / sum(|y|)``.

    Inputs may be any array-like (lists, NumPy arrays, pandas Series). They
    are converted to float arrays first, so values are compared by position
    (no index alignment) and unsigned integers cannot wrap around.

    Returns:
        The WAPE as a float, or ``inf`` when ``sum(|y_true|)`` is zero.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    den = np.abs(y_true).sum()
    if den == 0:
        return np.inf
    return float(np.abs(y_true - y_pred).sum() / den)

def rmse(y_true, y_pred):
    """Root mean squared error between ``y_true`` and ``y_pred``.

    Inputs may be any array-like (lists, NumPy arrays, pandas Series). They
    are converted to float arrays first, so values are compared by position
    (no index alignment) and unsigned integers cannot wrap around.

    Returns:
        The RMSE as a Python float.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

# ---------- 6) Runner de cenário ----------
def run_scenario(pairs, feature_mode, model_mode, cols_all, cfg) -> Dict:
    """Evaluate one (feature set, model) scenario over a list of folds.

    Args:
        pairs: List of ``(train, valid)`` frames.
        feature_mode: ``"basic"`` or ``"selected"``; any other value uses
            the full feature set.
        model_mode: One of ``"single:lgbm"``, ``"single:prophet"``,
            ``"single:sarima"``, ``"single:lstm"`` or ``"ensemble:all"``
            (simple mean of the four models).
        cols_all: All column names of the dataset.
        cfg: Config dict; reads ``target_col``, ``date_col``, optional
            ``id_col``, ``selected_list_path`` (selected mode only) and the
            per-model entries consumed by the fit functions.

    Returns:
        ``{"rows": [per-fold dict], "avg_WAPE": ..., "avg_RMSE": ...}``.
    """
    # Feature list for this scenario.
    if feature_mode == "selected":
        feats = feature_set_selected(cfg["selected_list_path"])
    elif feature_mode == "basic":
        feats = feature_set_basic(cols_all)
    else:
        feats = feature_set_all(cols_all)

    def _requested(model_name):
        # A model runs when asked for on its own or as part of the ensemble.
        return model_mode in (f"single:{model_name}", "ensemble:all")

    results = []
    for fold_idx, (train_fold, valid_fold) in enumerate(pairs):
        y_true = valid_fold[cfg["target_col"]].values
        fold_preds = {}

        # LightGBM works on the tabular features.
        if _requested("lgbm"):
            lgbm_out = fit_predict_lgbm(
                train_fold, valid_fold, feats,
                cfg.get("id_col"), cfg["date_col"], cfg["target_col"],
                cfg=cfg
            )
            fold_preds["lgbm"] = lgbm_out[0]

        # Prophet gets the features that are not sales_qty_* as regressors.
        if _requested("prophet"):
            exog = [name for name in feats if not name.startswith("sales_qty_")]
            prophet_out = fit_predict_prophet(
                train_fold, valid_fold, exog, cfg["date_col"], cfg["target_col"]
            )
            fold_preds["prophet"] = prophet_out[0]

        # SARIMA is univariate.
        if _requested("sarima"):
            sarima_out = fit_predict_sarima(
                train_fold, valid_fold,
                date_col=cfg["date_col"],
                target_col=cfg["target_col"],
                cfg=cfg
            )
            fold_preds["sarima"] = sarima_out[0]

        # LSTM is univariate.
        if _requested("lstm"):
            lstm_out = fit_predict_lstm(
                train_fold, valid_fold,
                date_col=cfg["date_col"],
                target_col=cfg["target_col"],
                cfg=cfg
            )
            fold_preds["lstm"] = lstm_out[0]

        # Single-model modes pick that model; anything else averages them.
        if not model_mode.startswith("single:"):
            y_pred = ensemble_mean(fold_preds)
        else:
            y_pred = fold_preds[model_mode.split(":")[1]]

        results.append({
            "fold": fold_idx,
            "feature_mode": feature_mode,
            "model_mode": model_mode,
            "WAPE": wape(y_true, y_pred),
            "RMSE": rmse(y_true, y_pred),
        })

    fold_wapes = [row["WAPE"] for row in results]
    fold_rmses = [row["RMSE"] for row in results]
    return {
        "rows": results,
        "avg_WAPE": np.mean(fold_wapes),
        "avg_RMSE": np.mean(fold_rmses)
    }


## lista de features

In [None]:
# Columns removed from the model feature list (see the filter that builds `features`).
desconsiderar = [ 'approval_time_hours_mean_co',
 'approval_time_hours_mean_ne',
 'approval_time_hours_mean_n',
 'approval_time_hours_mean_se',
 'approval_time_hours_mean_s',
 'delivery_diff_estimated_mean_co',
 'delivery_diff_estimated_mean_ne',
 'delivery_diff_estimated_mean_n',
 'delivery_diff_estimated_mean_se',
 'delivery_diff_estimated_mean_s',
 'est_delivery_lead_days_mean_co',
 'est_delivery_lead_days_mean_ne',
 'est_delivery_lead_days_mean_n',
 'est_delivery_lead_days_mean_se',
 'est_delivery_lead_days_mean_s',
 'delivery_diff_estimated_weighted',
 'est_delivery_lead_days_weighted',
 'approval_time_hours_weighted',
 'customer_regions']

In [None]:
# Make sure the cutoff column is datetime
df[cutoff_col] = pd.to_datetime(df[cutoff_col])

# Start from the "full" feature set (feature_set_all already drops target/id/date
# and bookkeeping columns), then remove the columns listed in `desconsiderar`
features = feature_set_all(df.columns.tolist())
features = [c for c in features if c not in desconsiderar]

In [None]:
# Hand-picked feature subset.
# NOTE(review): no cell visible in this notebook reads `selected_features`,
# and three of its entries also appear in `desconsiderar` - confirm intent.
selected_features = ['sales_qty_roll8_mean',
 'sales_qty_lag1',
 'sales_qty_roll4_mean',
 'sales_qty_lag2',
 'sales_qty_lag4',
 'sales_qty_lag8',
 'sales_qty_roll8_std',
 'sales_qty_roll4_std',
 'approval_time_hours_weighted_roll8_std',
 'price_var_m4_vs_prev4_mean_roll8_std',
 'est_delivery_lead_days_weighted_roll8_std',
 'approval_time_hours_weighted_roll4_std',
 'est_delivery_lead_days_weighted',
 'delivery_diff_estimated_weighted',
 'est_delivery_lead_days_weighted_roll4_std',
 'price_var_m4_vs_prev4_mean_roll4_std',
 'price_var_w1_point_mean_roll8_std',
 'price_var_w1_point_mean_roll4_std',
 'price_var_w1_smooth_mean_roll8_std',
 'approval_time_hours_weighted']

### teste com retreino

In [None]:
# Build the rolling (expanding-window) train/validation pairs on the whole dataset
pairs_rolling = split_rolling(
    df=df,
    cutoff_col=cutoff_col,
    first_train_end=first_train_end,
    horizon_stride=horizon_stride,
    step_k=1,  # re-train on every window
)

preds_roll = []
for tr_i, va_i in pairs_rolling:
    # Keep only the validation rows that fall inside the test period
    va_i = va_i[(va_i[cutoff_col] >= test_start) & (va_i[cutoff_col] <= test_end)]
    if va_i.empty:
        continue

    yhat_i, _ = fit_predict_lgbm(
        tr_i, va_i,
        features=features,
        id_col=id_col, date_col=date_col, target_col=target_col
    )

    # Evaluation frame for this window; yhat_i is assigned positionally,
    # so it must be in the same row order as va_i
    tmp = va_i[[date_col, target_col] + ([id_col] if id_col in va_i.columns else [])].copy()
    tmp["y_pred"] = yhat_i
    tmp.rename(columns={target_col: "y_true"}, inplace=True)
    preds_roll.append(tmp)

# Concatenate the predictions of every window (within the test period);
# fall back to an empty frame when no window produced predictions
df_pred_roll = pd.concat(preds_roll, ignore_index=True) if preds_roll else \
               pd.DataFrame(columns=[date_col, "y_true", "y_pred"])

# Metrics (calculate_metrics is a project helper; its result exposes .to_dict())
m_roll = calculate_metrics(df_pred_roll, y_true="y_true", y_pred="y_pred")
m_roll.to_dict()


## teste 2

In [None]:
# Accumulators for predictions and metrics, one entry per category
all_preds_static = []
all_preds_roll   = []
metrics_static_list = []
metrics_roll_list   = []

# Iterate over product categories: each category is modelled on its own,
# first with a single static fit and then with rolling re-training (LGBM only)
for cat, df_cat in df.groupby("product_category_name"):
    print(f"Treinando categoria: {cat}")

    # ---------- STATIC TRAINING (per category) ----------
    pairs_static = split_static(
        df=df_cat,
        cutoff_col=cutoff_col,
        first_train_end=first_train_end,
        test_start=test_start,
        test_end=test_end,
    )

    (train_static, test_static) = pairs_static[0]

    # Skip the static scheme when this category has no test rows
    if test_static.empty:
        print(f"  >> Sem dados de teste para categoria {cat}, pulando treino estático.")
    else:
        # Train once on the static training frame and predict the test frame
        yhat_static, mdl_static = fit_predict_lgbm(
            train_static, test_static,
            features=features,
            id_col=id_col, date_col=date_col, target_col=target_col
        )

        # Build the standard evaluation frame
        base_cols = [date_col, target_col, "product_category_name"]
        if id_col in test_static.columns:
            base_cols.append(id_col)

        df_pred_static = test_static[base_cols].copy()
        # Positional assignment: yhat_static must follow test_static's row order
        df_pred_static["y_pred"] = yhat_static
        df_pred_static.rename(columns={target_col: "y_true"}, inplace=True)

        # Metrics for this category
        m_static = calculate_metrics(df_pred_static, y_true="y_true", y_pred="y_pred")
        # Store the metrics as a dict, tagged with the category
        metrics_static_list.append({
            "product_category_name": cat,
            **m_static.to_dict()
        })

        # Keep the predictions for the final concatenation
        all_preds_static.append(df_pred_static)

    # ---------- ROLLING TRAINING (per category) ----------
    pairs_rolling = split_rolling(
        df=df_cat,
        cutoff_col=cutoff_col,
        first_train_end=first_train_end,
        horizon_stride=horizon_stride,
        step_k=1,  # re-train on every window
    )

    preds_roll_cat = []
    for tr_i, va_i in pairs_rolling:
        # Restrict the validation window to the global test period
        va_i = va_i[
            (va_i[cutoff_col] >= test_start) &
            (va_i[cutoff_col] <= test_end)
        ]
        if va_i.empty:
            continue

        yhat_i, _ = fit_predict_lgbm(
            tr_i, va_i,
            features=features,
            id_col=id_col, date_col=date_col, target_col=target_col
        )

        base_cols = [date_col, target_col, "product_category_name"]
        if id_col in va_i.columns:
            base_cols.append(id_col)

        tmp = va_i[base_cols].copy()
        # Positional assignment: yhat_i must follow va_i's row order
        tmp["y_pred"] = yhat_i
        tmp.rename(columns={target_col: "y_true"}, inplace=True)
        preds_roll_cat.append(tmp)

    # Compute the category metrics when at least one window produced predictions
    if preds_roll_cat:
        df_pred_roll_cat = pd.concat(preds_roll_cat, ignore_index=True)
        m_roll = calculate_metrics(df_pred_roll_cat, y_true="y_true", y_pred="y_pred")
        metrics_roll_list.append({
            "product_category_name": cat,
            **m_roll.to_dict()
        })

        all_preds_roll.append(df_pred_roll_cat)
    else:
        print(f"  >> Sem janelas válidas no período de teste para categoria {cat}, pulando treino rolling.")

# ---------- Consolidate the results of ALL categories ----------

# Consolidated prediction frames (empty frame with the expected columns as fallback)
df_pred_static_all = (
    pd.concat(all_preds_static, ignore_index=True)
    if all_preds_static else
    pd.DataFrame(columns=[date_col, "y_true", "y_pred", "product_category_name"] + ([id_col] if id_col in df.columns else []))
)

df_pred_roll_all = (
    pd.concat(all_preds_roll, ignore_index=True)
    if all_preds_roll else
    pd.DataFrame(columns=[date_col, "y_true", "y_pred", "product_category_name"] + ([id_col] if id_col in df.columns else []))
)

# Metrics per category (one row per category)
metrics_static_df = pd.DataFrame(metrics_static_list)  # static training
metrics_roll_df   = pd.DataFrame(metrics_roll_list)    # rolling training

# Same metrics as a {category: metrics} dict
metrics_static_by_cat = {
    row["product_category_name"]: row.drop("product_category_name").to_dict()
    for _, row in metrics_static_df.iterrows()
}

metrics_roll_by_cat = {
    row["product_category_name"]: row.drop("product_category_name").to_dict()
    for _, row in metrics_roll_df.iterrows()
}


In [None]:
# Rolling metrics as a table: one column per category, one row per metric
resultado = pd.DataFrame(metrics_roll_by_cat)
resultado

In [None]:
# Average each metric (row) across all categories (columns)
resultado.mean(axis=1)


## teste 3

In [None]:
# Models run in this experiment
model_names = ["lgbm", "prophet", "sarima", "lstm"]

# Accumulators for predictions and metrics per category
all_preds_static = []  # each frame has: y_true, y_pred_lgbm, ..., y_pred_ensemble
all_preds_roll   = []

metrics_static_list = []  # one row per (category, model, scheme)
metrics_roll_list   = []

# Iterate over product categories: for each one, fit every model under the
# static and the rolling scheme. A model that raises is reported and left out
# of that category/window; the ensemble averages whichever models succeeded.
for cat, df_cat in df.groupby("product_category_name"):
    print(f"Treinando categoria: {cat}")

    # ---------- STATIC TRAINING (per category) ----------
    pairs_static = split_static(
        df=df_cat,
        cutoff_col=cutoff_col,
        first_train_end=first_train_end,
        test_start=test_start,
        test_end=test_end,
    )

    (train_static, test_static) = pairs_static[0]

    # Skip the static scheme when this category has no test rows
    if test_static.empty:
        print(f"  >> Sem dados de teste para categoria {cat}, pulando treino estático.")
    else:
        # Common base frame (no predictions yet)
        base_cols = [date_col, target_col, "product_category_name"]
        if id_col in test_static.columns:
            base_cols.append(id_col)

        df_pred_static = test_static[base_cols].copy()
        df_pred_static.rename(columns={target_col: "y_true"}, inplace=True)

        # Prediction arrays keyed by model name
        preds_static_dict = {}

        # --- LGBM ---
        try:
            yhat_lgbm, mdl_static_lgbm = fit_predict_lgbm(
                train_static, test_static,
                features=features,
                id_col=id_col, date_col=date_col, target_col=target_col
            )
            preds_static_dict["lgbm"] = np.asarray(yhat_lgbm)
        except Exception as e:
            print(f"  >> Erro ao treinar LGBM (static) categoria {cat}: {e}")

        # --- Prophet ---
        try:
            # Same exogenous set as run_scenario: every feature not starting with 'sales_qty_'
            exog = [c for c in features if not c.startswith("sales_qty_")]
            yhat_prophet, mdl_static_prophet = fit_predict_prophet(
                train_static, test_static,
                features_exog=exog,
                date_col=date_col,
                target_col=target_col
            )
            preds_static_dict["prophet"] = np.asarray(yhat_prophet)
        except Exception as e:
            print(f"  >> Erro ao treinar Prophet (static) categoria {cat}: {e}")

        # --- SARIMA ---
        try:
            yhat_sarima, mdl_static_sarima = fit_predict_sarima(
                train_static, test_static,
                date_col=date_col,
                target_col=target_col,
                cfg=None  # or pass a dict carrying cfg["sarima"]
            )
            preds_static_dict["sarima"] = np.asarray(yhat_sarima)
        except Exception as e:
            print(f"  >> Erro ao treinar SARIMA (static) categoria {cat}: {e}")

        # --- LSTM ---
        try:
            yhat_lstm, mdl_static_lstm = fit_predict_lstm(
                train_static, test_static,
                date_col=date_col,
                target_col=target_col,
                cfg=None  # or a cfg carrying the LSTM parameters
            )
            preds_static_dict["lstm"] = np.asarray(yhat_lstm)
        except Exception as e:
            print(f"  >> Erro ao treinar LSTM (static) categoria {cat}: {e}")

        # Keep only the models whose prediction array matches the base frame length
        n_test = len(df_pred_static)
        preds_static_dict = {
            k: v for k, v in preds_static_dict.items()
            if len(v) == n_test
        }

        if not preds_static_dict:
            print(f"  >> Nenhum modelo válido (static) para categoria {cat}.")
        else:
            # One prediction column per model (assigned positionally)
            for mname, yhat in preds_static_dict.items():
                df_pred_static[f"y_pred_{mname}"] = yhat

            # Ensemble (simple mean of the available models)
            df_pred_static["y_pred_ensemble"] = ensemble_mean(preds_static_dict)

            # ---- Metrics per model (ensemble included) ----
            modelos_para_metricas = list(preds_static_dict.keys()) + ["ensemble"]

            for mname in modelos_para_metricas:
                col_pred = f"y_pred_{mname}"
                df_tmp = df_pred_static[[date_col, "y_true", col_pred, "product_category_name"]].copy()
                df_tmp.rename(columns={col_pred: "y_pred"}, inplace=True)

                m_static = calculate_metrics(df_tmp, y_true="y_true", y_pred="y_pred")
                metrics_static_list.append({
                    "product_category_name": cat,
                    "train_scheme": "static",
                    "model": mname,
                    **m_static.to_dict()
                })

            # Keep the predictions for the final concatenation
            all_preds_static.append(df_pred_static)

    # ---------- ROLLING TRAINING (per category) ----------
    pairs_rolling = split_rolling(
        df=df_cat,
        cutoff_col=cutoff_col,
        first_train_end=first_train_end,
        horizon_stride=horizon_stride,
        step_k=1,  # re-train on every window
    )

    preds_roll_cat = []  # frames with y_true + each model's predictions, one per window

    for tr_i, va_i in pairs_rolling:
        # Restrict the validation window to the global test period
        va_i = va_i[
            (va_i[cutoff_col] >= test_start) &
            (va_i[cutoff_col] <= test_end)
        ]
        if va_i.empty:
            continue

        base_cols = [date_col, target_col, "product_category_name"]
        if id_col in va_i.columns:
            base_cols.append(id_col)

        tmp = va_i[base_cols].copy()
        tmp.rename(columns={target_col: "y_true"}, inplace=True)

        preds_roll_dict = {}

        # --- LGBM (rolling) ---
        try:
            yhat_i_lgbm, _ = fit_predict_lgbm(
                tr_i, va_i,
                features=features,
                id_col=id_col, date_col=date_col, target_col=target_col
            )
            preds_roll_dict["lgbm"] = np.asarray(yhat_i_lgbm)
        except Exception as e:
            print(f"  >> Erro LGBM (rolling) categoria {cat}: {e}")

        # --- Prophet (rolling) ---
        try:
            exog = [c for c in features if not c.startswith("sales_qty_")]
            yhat_i_prophet, _ = fit_predict_prophet(
                tr_i, va_i,
                features_exog=exog,
                date_col=date_col,
                target_col=target_col
            )
            preds_roll_dict["prophet"] = np.asarray(yhat_i_prophet)
        except Exception as e:
            print(f"  >> Erro Prophet (rolling) categoria {cat}: {e}")

        # --- SARIMA (rolling) ---
        try:
            yhat_i_sarima, _ = fit_predict_sarima(
                tr_i, va_i,
                date_col=date_col,
                target_col=target_col,
                cfg=None
            )
            preds_roll_dict["sarima"] = np.asarray(yhat_i_sarima)
        except Exception as e:
            print(f"  >> Erro SARIMA (rolling) categoria {cat}: {e}")

        # --- LSTM (rolling) ---
        try:
            yhat_i_lstm, _ = fit_predict_lstm(
                tr_i, va_i,
                date_col=date_col,
                target_col=target_col,
                cfg=None
            )
            preds_roll_dict["lstm"] = np.asarray(yhat_i_lstm)
        except Exception as e:
            print(f"  >> Erro LSTM (rolling) categoria {cat}: {e}")

        # Keep only the models whose prediction array matches this window's length
        n_va = len(tmp)
        preds_roll_dict = {
            k: v for k, v in preds_roll_dict.items()
            if len(v) == n_va
        }

        # No usable model for this window: drop the window
        if not preds_roll_dict:
            continue

        for mname, yhat in preds_roll_dict.items():
            tmp[f"y_pred_{mname}"] = yhat

        tmp["y_pred_ensemble"] = ensemble_mean(preds_roll_dict)

        preds_roll_cat.append(tmp)

    # Compute the category metrics when at least one window produced predictions
    if preds_roll_cat:
        df_pred_roll_cat = pd.concat(preds_roll_cat, ignore_index=True)

        # Model names recovered from the y_pred_* columns present after the concat
        modelos_para_metricas = [
            c.replace("y_pred_", "") for c in df_pred_roll_cat.columns
            if c.startswith("y_pred_")
        ]

        for mname in modelos_para_metricas:
            col_pred = f"y_pred_{mname}"
            df_tmp = df_pred_roll_cat[[date_col, "y_true", col_pred, "product_category_name"]].copy()
            df_tmp.rename(columns={col_pred: "y_pred"}, inplace=True)

            m_roll = calculate_metrics(df_tmp, y_true="y_true", y_pred="y_pred")
            metrics_roll_list.append({
                "product_category_name": cat,
                "train_scheme": "rolling",
                "model": mname,
                **m_roll.to_dict()
            })

        all_preds_roll.append(df_pred_roll_cat)
    else:
        print(f"  >> Sem janelas válidas no período de teste para categoria {cat}, pulando treino rolling.")

# ---------- Consolidate the results of ALL categories ----------

# Consolidated prediction frames (empty frame with the base columns as fallback)
if all_preds_static:
    df_pred_static_all = pd.concat(all_preds_static, ignore_index=True)
else:
    base_cols = [date_col, "y_true", "product_category_name"] + (
        [id_col] if id_col in df.columns else []
    )
    df_pred_static_all = pd.DataFrame(columns=base_cols)

if all_preds_roll:
    df_pred_roll_all = pd.concat(all_preds_roll, ignore_index=True)
else:
    base_cols = [date_col, "y_true", "product_category_name"] + (
        [id_col] if id_col in df.columns else []
    )
    df_pred_roll_all = pd.DataFrame(columns=base_cols)

# Metrics per category / model / scheme
metrics_static_df = pd.DataFrame(metrics_static_list)
metrics_roll_df   = pd.DataFrame(metrics_roll_list)

# Same metrics as a { (category, model, scheme): metrics } dict
metrics_static_by_key = {
    (row["product_category_name"], row["model"], row["train_scheme"]): row.drop(
        ["product_category_name", "model", "train_scheme"]
    ).to_dict()
    for _, row in metrics_static_df.iterrows()
}

metrics_roll_by_key = {
    (row["product_category_name"], row["model"], row["train_scheme"]): row.drop(
        ["product_category_name", "model", "train_scheme"]
    ).to_dict()
    for _, row in metrics_roll_df.iterrows()
}


In [None]:
# Display the rolling metrics keyed by (category, model, scheme)
metrics_roll_by_key