In [None]:
# Databricks notebook source
# 🔍 Notebook de Debug: Validação de Métricas e Identificação de Lojas
# Objetivo: Identificar por que o 'codigo_loja' não está sendo salvo corretamente na tabela de métricas.
# Este notebook replica inline e instrumenta as funções críticas de `src/validation/data.py` e `src/validation/trainer.py`.

# --- IMPORTS ---
%load_ext autoreload
%autoreload 2

import sys
import pickle
import os
import hashlib
import traceback
import pandas as pd
import numpy as np
import mlflow
from typing import List, Tuple, Optional, Any, Dict
from pyspark.sql import SparkSession, DataFrame
import pyspark.sql.functions as F
from mlflow.models import ModelSignature
from mlflow.types.schema import Schema, ColSpec

# Add the current working directory to the import path so the `src.*` imports below can resolve
# (presumably the notebook is run from the repository root — verify).
sys.path.append(os.getcwd())

from src.validation.config import Config
from src.validation.data import DataIngestion 
from src.validation.pipeline import ProjectPipeline
from src.validation.trainer import ModelTrainer
from src.validation.models import DartsWrapper

from darts import TimeSeries
from darts.metrics import mape, mse, rmse, r2_score, smape
from darts.models import LinearRegressionModel

# --- CONFIGURAÇÃO ---
# Only build a SparkSession when no `spark` name is already defined in this scope;
# `getOrCreate()` returns the active session when one exists.
if 'spark' not in locals():
    spark = SparkSession.builder.getOrCreate()

# Project configuration; the prints echo the version and date boundaries used by this run.
config = Config(spark)
print(f"üîß Config Loaded: Version {config.VERSION}")
print(f"üìÖ Data Start: {config.DATA_START}, Train End: {config.TRAIN_END_DATE}, Ingestion End: {config.INGESTION_END}")

# COMMAND ----------

# --- 1. DATA INGESTION DEBUGGING (Exploded Function) ---
# Esta função substitui DataIngestion.build_darts_objects com prints de debug focados no índice das covariáveis estáticas

def build_darts_objects_debug(
    df_spark_wide: DataFrame,
    df_global_support: pd.DataFrame
) -> Tuple[List[TimeSeries], List[TimeSeries]]:
    """Build per-store target series and placeholder covariates, printing ID diagnostics.

    Instrumented replacement for ``DataIngestion.build_darts_objects`` that focuses on
    what ends up in the static-covariates index of every generated series.

    Args:
        df_spark_wide: Spark DataFrame that is collected to the driver with ``toPandas()``.
            The code reads the columns ``codigo_loja``, ``data`` and ``target_vendas``,
            plus whichever of ``cluster_loja``, ``sigla_uf``, ``tipo_loja`` and
            ``modelo_loja`` are present (used as static covariates).
        df_global_support: Kept for signature compatibility; this debug version never reads it.

    Returns:
        ``(targets, covariates)``: the target series that carry non-empty static
        covariates (one per distinct extracted key, later duplicates overwrite earlier
        ones) and, aligned with them, one all-zero ``dummy_cov`` series each.

    Raises:
        Exception: any error raised by ``TimeSeries.from_group_dataframe`` is printed
            and re-raised unchanged.
    """
    print("\nüêõ [DEBUG] Iniciando build_darts_objects_debug...")

    # 1. Spark -> pandas: the whole dataset is materialized on the driver.
    print("‚öôÔ∏è Materializando dados do Spark para Pandas (Driver)...")
    df_wide = df_spark_wide.toPandas()

    # Keep only the first occurrence of every column label.
    df_wide = df_wide.loc[:, ~df_wide.columns.duplicated()]

    # Defensive re-check: selecting a duplicated label returns a DataFrame, not a Series.
    if "codigo_loja" in df_wide.columns and isinstance(df_wide["codigo_loja"], pd.DataFrame):
        print("   ‚ö†Ô∏è CRITICAL: 'codigo_loja' is still a DataFrame (duplicate columns)!")
        df_wide = df_wide.loc[:, ~df_wide.columns.duplicated(keep='first')]

    df_wide['data'] = pd.to_datetime(df_wide['data'])

    # Static covariates: only the candidate columns that actually exist.
    possible_static = ["cluster_loja", "sigla_uf", "tipo_loja", "modelo_loja"]
    static_cols = [c for c in possible_static if c in df_wide.columns]

    print(f"   ‚ÑπÔ∏è Colunas Est√°ticas identificadas: {static_cols}")

    # 2. One target series per store.
    print("   Build: Criando Target Series (Vetorizado)...")
    try:
        target_series_list = TimeSeries.from_group_dataframe(
            df_wide,
            group_cols="codigo_loja",
            time_col="data",
            value_cols="target_vendas",
            static_cols=static_cols,
            freq='D',
            fill_missing_dates=True,
            fillna_value=0.0
        )
    except Exception as e:
        print(f"‚ùå Erro cr√≠tico no from_group_dataframe (Target): {e}")
        # Bare raise keeps the original traceback intact.
        raise

    # --- Critical debug: what does the static-covariates index hold? ---
    # NOTE(review): darts appears to expose group columns as static-covariate *columns*,
    # with the index holding component names — if so, index[0] is not the store code.
    # Confirm against the darts version in use.
    print("\nüîé [DEBUG] Verificando IDs gerados em 'target_series_list' (RAW):")
    for i, ts in enumerate(target_series_list[:5]):  # only the first five are printed
        if ts.static_covariates is not None:
            idx_name = ts.static_covariates.index.name
            idx_val = ts.static_covariates.index[0]
            print(f"   üëâ Series[{i}] - Index Name: '{idx_name}' | Value: '{idx_val}' (Type: {type(idx_val)})")
        else:
            print(f"   ‚ùå Series[{i}] - Static Covariates is None!")

    # Regular processing: key every usable series by the ID found in its static index.
    target_dict = {}
    for ts in target_series_list:
        static_cov = ts.static_covariates
        if static_cov is None or static_cov.empty:
            continue

        # Strip only a trailing ".0" (float-formatted integer IDs), consistent with
        # extract_id_debug; replacing ".0" anywhere would corrupt IDs such as "10.05".
        key_val = str(static_cov.index[0])
        if key_val.endswith(".0"):
            key_val = key_val[:-2]

        if static_cov.index.name == "target_vendas":
            # with_static_covariates returns a new TimeSeries, so the result must be
            # rebound; calling it without assignment silently drops the rename.
            ts = ts.with_static_covariates(static_cov.rename_axis("codigo_loja"))

        target_dict[key_val] = ts

    valid_stores = list(target_dict.keys())
    print(f"   ‚úÖ Total Lojas V√°lidas (target_dict keys): {len(valid_stores)}")
    if valid_stores:
        print(f"   Exemplo de keys: {valid_stores[:5]}")

    # 3. Local covariates (holidays) are intentionally not rebuilt here: the target list
    # carries the ID and defines the ordering under investigation.
    final_target_list = list(target_dict.values())

    # Placeholder covariates (all zeros, same time index) so the downstream pipeline runs.
    print("   Build: Gerando dummy covariates para teste...")
    full_covariates_list = [
        TimeSeries.from_times_and_values(
            ts.time_index,
            np.zeros((len(ts), 1)),
            freq='D',
            columns=['dummy_cov']
        )
        for ts in final_target_list
    ]

    return final_target_list, full_covariates_list

# COMMAND ----------

# --- 2. MODEL TRAINER DEBUGGING (Exploded Function) ---

def extract_id_debug(ts: TimeSeries, stage: str = "UNKNOWN") -> str:
    """Return the ID stored as the first static-covariates index entry of ``ts``.

    Args:
        ts: Series whose ``static_covariates`` attribute is inspected.
        stage: Free-form label included in the diagnostic prints.

    Returns:
        The first index value as a string, with one trailing ``".0"`` removed, or
        ``"UNKNOWN"`` when the static covariates are ``None``, empty, or reading
        them raises (each case prints a diagnostic line).
    """
    try:
        if ts.static_covariates is None:
            print(f"      [DEBUG ID {stage}] Static Covariates is None!")
        elif ts.static_covariates.empty:
            print(f"      [DEBUG ID {stage}] Static Covariates is Empty!")
        else:
            raw_id = str(ts.static_covariates.index[0])
            # Float-formatted integer IDs ("12.0") are reported without the suffix.
            return raw_id[:-2] if raw_id.endswith(".0") else raw_id
    except Exception as e:
        # Best-effort helper: any failure is reported and mapped to the sentinel.
        print(f"      [DEBUG ID {stage}] Exception: {e}")
    return "UNKNOWN"

def train_evaluate_walkforward_debug(
    config: Any,
    train_series_static: List[TimeSeries],
    full_series_scaled: List[TimeSeries],
    val_series_original: List[TimeSeries],
    target_pipeline: Any
) -> None:
    """Trace how store IDs travel from the scaled series into the metrics DataFrame.

    The function prints its findings and returns nothing. It extracts an ID from every
    scaled series, stops early when none can be recovered, and otherwise rebuilds a
    small sample of the metrics rows using the scaled series as stand-in predictions.

    Args:
        config: Not read by this debug version; kept for signature compatibility.
        train_series_static: Not read by this debug version; kept for signature compatibility.
        full_series_scaled: Scaled series; IDs are extracted from these, and they double
            as mock predictions.
        val_series_original: Unscaled series, paired positionally with the predictions.
        target_pipeline: Object exposing ``inverse_transform(series, partial=True)``.
    """
    print("\nüêõ [DEBUG] Iniciando train_evaluate_walkforward_debug...")

    # With no input, "unknown == total" would be 0 == 0 and falsely blame the pipeline.
    if len(full_series_scaled) == 0:
        print("\n[DEBUG] 'full_series_scaled' is empty; nothing to validate.")
        return

    # --- CHECK 1: IDs as seen at function entry (scaled series) ---
    print("\nüîé [DEBUG] CHECK 1: Inspecionando 'full_series_scaled' (Onde a l√≥gica original busca os IDs)...")
    for i, ts in enumerate(full_series_scaled[:5]):
        extracted = extract_id_debug(ts, stage="SCALED_INPUT")
        print(f"   Series[{i}] (Scaled) -> Extracted ID: '{extracted}'")
        if ts.static_covariates is None:
            print("   ‚ö†Ô∏è AVISO: static_covariates desapareceu ap√≥s scaling!")

    ordered_store_ids = [extract_id_debug(ts, stage="SCALED_FULL") for ts in full_series_scaled]

    # How many IDs could not be recovered?
    unknown_count = ordered_store_ids.count("UNKNOWN")
    print(f"\nüìä Total IDs extra√≠dos: {len(ordered_store_ids)}")
    print(f"‚ö†Ô∏è Total 'UNKNOWN': {unknown_count}")
    print(f"üìù Primeiros 10 IDs: {ordered_store_ids[:10]}")

    if unknown_count == len(ordered_store_ids):
        print("\nüö®üö® ERRO CR√çTICO IDENTIFICADO: Todos os IDs s√£o UNKNOWN ap√≥s o scaling.")
        print("   CAUSA PROV√ÅVEL: O 'ProjectPipeline' ou seus Transformers est√£o removendo/resetando o √≠ndice das StaticCovariates.")
        print("   Recomenda√ß√£o: Verificar 'src/validation/pipeline.py' e 'StaticCovariatesTransformer'.")
        return  # nothing useful can be checked without IDs

    # At least some IDs exist: simulate the validation step.
    print("\nüîÑ Simulando loop de valida√ß√£o para verificar salvamento...")

    # Mock predictions: the scaled series themselves.
    preds = full_series_scaled

    print("   Invertendo transforma√ß√£o (Inverse Transform)...")
    preds_inverse = target_pipeline.inverse_transform(preds, partial=True)

    # --- CHECK 2: IDs after the inverse transform ---
    print("\nüîé [DEBUG] CHECK 2: Inspecionando 'preds_inverse' (ap√≥s inverse transform)...")
    for i, ts in enumerate(preds_inverse[:3]):
        extracted = extract_id_debug(ts, stage="INVERSE_PRED")
        print(f"   PredsInverse[{i}] -> Extracted ID (pode ser perdido aqui, mas n√£o afeta ordena√ß√£o): '{extracted}'")

    # Simulate the positional zip that assembles the metrics DataFrame.
    print("\nü§ù [DEBUG] Simulando ZIP para montar DataFrame de m√©tricas...")

    res_dfs = []
    # IDs come from the list extracted before the inverse transform; pairing is by position.
    for i, (ts_pred, ts_real_full, store_id) in enumerate(zip(preds_inverse, val_series_original, ordered_store_ids)):
        if i >= 5:
            break  # sample of five is enough

        print(f"   Itera√ß√£o {i}: Store ID from List = '{store_id}'")

        try:
            # Restrict the actuals to the time span covered by the prediction.
            ts_real_sliced = ts_real_full.slice_intersect(ts_pred)

            df_row = pd.DataFrame({
                'data': ts_pred.time_index,
                'previsao': ts_pred.values().flatten(),
                'real': ts_real_sliced.values().flatten(),
                'codigo_loja': store_id,  # failure point when store_id is "UNKNOWN"
                'modelo': 'DEBUG_MODEL',
                'metrica_mes': '2025-01'
            })
            res_dfs.append(df_row)
            print(f"     ‚úÖ DataFrame criado para '{store_id}'. Shape: {df_row.shape}")
            if store_id == "UNKNOWN":
                print("     ‚ö†Ô∏è ALERTA: Salvando registro com codigo_loja='UNKNOWN'")

        except Exception as e:
            # Deliberately best-effort: report the failing pair and keep going.
            print(f"     ‚ùå Erro na itera√ß√£o {i}: {e}")

    if res_dfs:
        final_df = pd.concat(res_dfs)
        print("\nüìä Amostra do DataFrame Final:")
        print(final_df[['data', 'codigo_loja', 'previsao']].head())
    else:
        print("\n‚ùå Nenhuma previs√£o gerada.")

# COMMAND ----------

# --- 3. EXECUÇÃO DO FLUXO DE DEBUG ---

print("üöÄ INICIANDO PIPELINE DE DEBUG üöÄ")

# Step 1 - raw ingestion through the original DataIngestion class to obtain the base data.
ingestion = DataIngestion(spark, config)
print("üõí Carregando dados (pode demorar um pouco)...")
df_spark_wide = ingestion.create_training_set()  # Spark DataFrame

# Optional speed-up for quick tests (disabled): df_spark_wide = df_spark_wide.limit(10000)

# Step 2 - build the darts objects with the instrumented function.
df_support_global = ingestion.get_global_support()
raw_series, raw_covs = build_darts_objects_debug(df_spark_wide, df_support_global)

if raw_series:
    # Step 3 - fit and apply the project pipeline.
    print("\nüõ†Ô∏è Executando ProjectPipeline (Fit/Transform)...")
    project_pipeline = ProjectPipeline()

    # Simple split for fitting.
    # NOTE(review): darts documents drop_after as also dropping the split point itself,
    # so the fit data may end two days before TRAIN_END_DATE — confirm this is intended.
    train_cutoff_date = pd.Timestamp(config.TRAIN_END_DATE) - pd.Timedelta(days=1)
    train_for_fit = [s.drop_after(train_cutoff_date) for s in raw_series]
    cov_for_fit = list(raw_covs)  # dummy covariates, passed through untruncated

    project_pipeline.fit(train_for_fit, cov_for_fit)

    print("üîÑ Transformando s√©ries...")
    # Suspected spot where the static covariates may get lost.
    series_scaled_full, cov_scaled_full = project_pipeline.transform(raw_series, raw_covs)

    # Step 4 - instrumented trainer check.
    train_evaluate_walkforward_debug(
        config,
        train_series_static=train_for_fit,  # not read by the debug trainer
        full_series_scaled=series_scaled_full,  # key input: did the static index survive?
        val_series_original=raw_series,
        target_pipeline=project_pipeline
    )
else:
    print("‚ùå Abortando: Nenhuma s√©rie retornada.")
