# Hybrid model TFT + Prophet is the OPTIMAL Combination

In [1]:
!pip install kaggle wandb onnx -Uq
from google.colab import drive
drive.mount('/content/drive')

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m120.7 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive


In [2]:
! mkdir ~/.kaggle
!cp /content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json ~/.kaggle/kaggle.json
! chmod 600 ~/.kaggle/kaggle.json
!kaggle competitions download -c walmart-recruiting-store-sales-forecasting
! unzip walmart-recruiting-store-sales-forecasting.zip
!unzip train.csv.zip
!unzip features.csv.zip

Downloading walmart-recruiting-store-sales-forecasting.zip to /content
  0% 0.00/2.70M [00:00<?, ?B/s]
100% 2.70M/2.70M [00:00<00:00, 860MB/s]
Archive:  walmart-recruiting-store-sales-forecasting.zip
  inflating: features.csv.zip        
  inflating: sampleSubmission.csv.zip  
  inflating: stores.csv              
  inflating: test.csv.zip            
  inflating: train.csv.zip           
Archive:  train.csv.zip
  inflating: train.csv               
Archive:  features.csv.zip
  inflating: features.csv            


In [3]:
!pip install torch mlflow dagshub scikit-learn pandas numpy matplotlib seaborn joblib -q wandb torch torchvision torchaudio -q prophet neuralforecast

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m88.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m67.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m51.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m792.3 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import pandas as pd
import numpy as np
import logging
import time
import warnings
from datetime import datetime
import os
import sys
from io import StringIO

# Statistical libraries
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import LabelEncoder

# Prophet
try:
    from prophet import Prophet
    PROPHET_AVAILABLE = True
except ImportError:
    PROPHET_AVAILABLE = False
    print("❌ Prophet not available. Install with: pip install prophet")

# PyTorch and TFT components
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("❌ PyTorch not available. Install with: pip install torch")

# MLflow for logging
try:
    import mlflow
    import mlflow.pytorch
    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False
    print("❌ MLflow not available. Install with: pip install mlflow")

warnings.filterwarnings('ignore')


In [16]:
#!/usr/bin/env python3
"""
Walmart Sales Forecasting - Hybrid Model (TFT + Prophet)
Adaptive weighting:
- Holiday periods: 60% Prophet + 40% TFT
- Regular days: 65% TFT + 35% Prophet
"""

import warnings
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, RegressorMixin
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_absolute_error, mean_squared_error
from neuralforecast import NeuralForecast
from neuralforecast.models import TFT
from prophet import Prophet
from statsmodels.tools.sm_exceptions import ValueWarning
import zipfile
import os
import logging
from datetime import datetime, timedelta

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=ValueWarning)
warnings.filterwarnings("ignore")
logging.getLogger('prophet').setLevel(logging.WARNING)
logging.getLogger('cmdstanpy').setLevel(logging.WARNING)
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)

# Data path configuration
KAGGLE_DATA_PATH = "/kaggle/input/walmart-recruiting-store-sales-forecasting/"

def calculate_wmae(y_true, y_pred, is_holiday_flag, holiday_weight=5.0):
    """Calculate Weighted Mean Absolute Error (WMAE) as per competition rules."""
    abs_errors = np.abs(y_true - y_pred)
    weights = np.where(is_holiday_flag.astype(bool), holiday_weight, 1.0)
    wmae = np.sum(weights * abs_errors) / np.sum(weights)
    return wmae

# =============================================================================
# TFT Components (from model-exp-tft-fx.ipynb)
# =============================================================================

class DateFeatureCreator(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.copy()
        if "Date" not in X.columns:
            raise ValueError("DateFeatureCreator requires 'Date' column in input X.")

        # Ensure 'Date' is datetime type before operations
        if not pd.api.types.is_datetime64_any_dtype(X['Date']):
            X['Date'] = pd.to_datetime(X['Date'])

        # Using to_period('W') and then converting to integer week number
        # rank(method="dense") ensures consecutive integers for weeks
        X["week"] = (X["Date"].dt.to_period("W").rank(method="dense").astype(int) - 1)

        # Cyclical features for different periodicities
        X["sin_13"] = np.sin(2 * np.pi * X["week"] / 13) # Roughly quarterly seasonality
        X["cos_13"] = np.cos(2 * np.pi * X["week"] / 13)
        X["sin_23"] = np.sin(2 * np.pi * X["week"] / 23) # A different, less common periodicity
        X["cos_23"] = np.cos(2 * np.pi * X["week"] / 23)

        # Drop the original 'Date' column as its information is now in cyclical features
        X = X.drop(columns=["Date"], errors='ignore')
        return X

class ColumnDropper(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X.drop(columns=self.columns, errors="ignore")

class ColumnTransformerWithNames(ColumnTransformer):
    """
    A wrapper around ColumnTransformer to retain column names and return a DataFrame.
    Handles OneHotEncoder output specifically.
    """
    def __init__(self, transformers, remainder='drop'):
        super().__init__(transformers=transformers, remainder=remainder)
        self.output_columns_ = None

    def fit(self, X, y=None):
        super().fit(X, y)
        self.output_columns_ = self._get_feature_names_out_internal(X)
        return self

    def _get_feature_names_out_internal(self, X):
        column_names = []
        for name, transformer, columns in self.transformers_:
            if transformer == 'drop':
                continue
            elif transformer == 'passthrough':
                # Ensure passthrough columns are correctly identified from original X
                if isinstance(columns, str):
                    column_names.append(columns)
                else:
                    column_names.extend(list(columns))
            else:
                if hasattr(transformer, 'get_feature_names_out'):
                    if isinstance(columns, str):
                        col_names = [columns]
                    else:
                        col_names = list(columns)
                    column_names.extend(list(transformer.get_feature_names_out(col_names)))
                else:
                    if isinstance(columns, str): # Fallback for transformers without get_feature_names_out
                        column_names.append(columns)
                    else:
                        column_names.extend(list(columns))
        return column_names

    def transform(self, X):
        transformed_array = super().transform(X)
        if self.output_columns_ is None:
             raise RuntimeError("ColumnTransformerWithNames must be fitted before transform.")

        # Convert to dense array if it's a sparse matrix (older sklearn default for OHE)
        if hasattr(transformed_array, 'toarray'):
            transformed_array = transformed_array.toarray()

        # Ensure that the index is preserved from the input X
        # This is CRITICAL for maintaining alignment with y
        return pd.DataFrame(transformed_array, index=X.index, columns=self.output_columns_)

    def fit_transform(self, X, y=None):
        transformed_array = super().fit_transform(X, y)
        self.output_columns_ = self._get_feature_names_out_internal(X)

        # Convert to dense array if it's a sparse matrix (older sklearn default for OHE)
        if hasattr(transformed_array, 'toarray'):
            transformed_array = transformed_array.toarray()

        # Ensure that the index is preserved from the input X
        # This is CRITICAL for maintaining alignment with y
        return pd.DataFrame(transformed_array, index=X.index, columns=self.output_columns_)

class MultiIndexKeeper(BaseEstimator, TransformerMixin):
    def __init__(self, index_cols=["Date", "Store", "Dept"]):
        self.index_cols = index_cols

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.copy()

        if 'Date' in X.columns and not pd.api.types.is_datetime64_any_dtype(X['Date']):
            X['Date'] = pd.to_datetime(X['Date'])

        missing_cols = [col for col in self.index_cols if col not in X.columns]
        if missing_cols:
            raise ValueError(f"MultiIndexKeeper: Missing columns in input X: {missing_cols}")

        # IMPORTANT: When setting index, ensure 'Date', 'Store', 'Dept' columns
        # are not dropped, as they are needed later by NeuralForecast as covariates.
        # This is already handled by drop=False.
        X.set_index(self.index_cols, drop=False, inplace=True)
        return X

class TFTRegressor(BaseEstimator, RegressorMixin):
    def __init__(self, input_chunk_length=52, output_chunk_length=39, epochs=25, batch_size=32, random_seed=42):
        self.input_chunk_length = input_chunk_length
        self.output_chunk_length = output_chunk_length
        self.epochs = epochs
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.nf_ = None
        self.model_ = None
        self.trained_df_ = None # Store the DataFrame used for training

    def fit(self, X, y):
        # Ensure y has a name for proper merging if it doesn't already
        if y.name is None:
            y.name = 'y'

        y_multiindexed = pd.Series(y.values, index=X.index, name='y')

        df = X.copy()
        df['y'] = y_multiindexed

        if not pd.api.types.is_datetime64_any_dtype(df.index.get_level_values('Date')):
            raise ValueError("MultiIndex 'Date' level is not datetime type. Ensure MultiIndexKeeper makes it datetime.")
        df['ds'] = df.index.get_level_values('Date')

        df["unique_id"] = df["Store"].astype(str) + "_" + df["Dept"].astype(str)

        # Store the prepared DataFrame for use in predict
        self.trained_df_ = df.copy() # Store the full df (including y and features)

        # --- DEBUGGING STEP: Check for NaNs immediately before NeuralForecast fit ---
        nan_check = df.isnull().sum()
        cols_with_nans = nan_check[nan_check > 0].index.tolist()
        if cols_with_nans:
            print(f"DEBUG: Found NaNs in the following columns before NeuralForecast fit: {cols_with_nans}")
            # Print head including index for better context
            print(df.loc[df[cols_with_nans[0]].isnull(), cols_with_nans].head())
            raise ValueError(f"Found missing values in {cols_with_nans}.")
        # --- END DEBUGGING STEP ---

        self.model_ = TFT(
            h=self.output_chunk_length,
            input_size=self.input_chunk_length,
            batch_size=self.batch_size,
            random_seed=self.random_seed,
        )

        self.nf_ = NeuralForecast(models=[self.model_], freq="W-FRI")

        self.nf_.fit(df=df)
        return self

    def predict(self, X):
        # X here is the transformed X_val from the pipeline, with MultiIndex ('Date', 'Store', 'Dept')

        # 1. Prepare future covariates from X (which is X_val after preprocessing)
        df_future_covariates_raw = X.copy() # This X contains all features for the validation period

        if not pd.api.types.is_datetime64_any_dtype(df_future_covariates_raw.index.get_level_values('Date')):
            raise ValueError("MultiIndex 'Date' level is not datetime type in predict. Ensure MultiIndexKeeper makes it datetime.")

        df_future_covariates_raw['ds'] = df_future_covariates_raw.index.get_level_values('Date')
        df_future_covariates_raw["unique_id"] = df_future_covariates_raw["Store"].astype(str) + "_" + df_future_covariates_raw["Dept"].astype(str)

        # Identify all covariate columns (all columns in self.trained_df_ except 'ds', 'unique_id', 'y')
        covariate_cols = [col for col in self.trained_df_.columns if col not in ['ds', 'unique_id', 'y']]

        # Select only the relevant future covariate columns and the required 'ds', 'unique_id'
        df_future_covariates_selected = df_future_covariates_raw[['ds', 'unique_id'] + covariate_cols].copy()

        # 2. Generate the full expected future dataframe for all series for the forecast horizon
        expected_future_df_template = self.nf_.make_future_dataframe(self.trained_df_)

        # 3. Merge the generated template with our actual future covariates (X_val)
        futr_df_complete = pd.merge(
            expected_future_df_template,
            df_future_covariates_selected,
            on=['unique_id', 'ds'],
            how='left'
        )

        nan_check_futr = futr_df_complete.isnull().sum()
        cols_with_nans_futr = nan_check_futr[nan_check_futr > 0].index.tolist()
        if cols_with_nans_futr:
            print(f"DEBUG: Found NaNs in futr_df_complete for columns: {cols_with_nans_futr}. Filling with 0.")
            futr_df_complete[cols_with_nans_futr] = futr_df_complete[cols_with_nans_futr].fillna(0)

        # 4. Perform the prediction
        # Ensure that `df` has all required unique_ids and `futr_df` aligns.
        # This is the point where the model generates predictions.
        forecast_df = self.nf_.predict(df=self.trained_df_, futr_df=futr_df_complete)

        forecast_df = forecast_df.rename(columns={'TFT': 'yhat'})

        if 'unique_id' in forecast_df.columns:
            forecast_df[['Store', 'Dept']] = forecast_df['unique_id'].str.split('_', expand=True)
            forecast_df['Store'] = forecast_df['Store'].astype(float).astype(int)
            forecast_df['Dept'] = forecast_df['Dept'].astype(float).astype(int)

        forecast_df['Date'] = pd.to_datetime(forecast_df['ds'])

        # Ensure the index columns are correct before setting the index
        # Also, make sure that 'Store' and 'Dept' are properly integer type before setting multi-index
        forecast_df_indexed = forecast_df.set_index(['Date', 'Store', 'Dept'])[['yhat']]

        # This is where NaNs can be introduced if forecast_df_indexed doesn't cover all X.index
        final_predictions = forecast_df_indexed.reindex(X.index)

        y_pred = final_predictions['yhat'].values.flatten()

        # FIX: Fill any NaNs in the final predictions array with 0 before evaluation
        if np.isnan(y_pred).any():
            print("DEBUG: Found NaNs in final y_pred after reindex. Filling with 0.")
            y_pred = np.nan_to_num(y_pred, nan=0.0)

        y_pred[y_pred < 0] = 0

        return y_pred

# =============================================================================
# Prophet Components (from model_exp_FX_Prophet.ipynb)
# =============================================================================

class WalmartProphetPreprocessingPipeline:
    """
    Preprocessing pipeline for Prophet models.
    Focuses on preparing data in the 'ds' (Date) and 'y' (Weekly_Sales) format,
    and handling holidays.
    """

    def __init__(self):
        self.fitted = False
        self.holidays_df = None

    def fit(self, train_data):
        """Fit the preprocessing pipeline (prepare holidays)."""
        print("🔧 Preparing Prophet specific data (holidays)...")

        # Prophet's holidays DataFrame: requires 'holiday', 'ds' columns
        # We define common US holidays that align with Walmart's IsHoliday flag
        self.holidays_df = pd.DataFrame([
            # Super Bowl: IsHoliday=True
            {'holiday': 'SuperBowl', 'ds': '2010-02-12'},
            {'holiday': 'SuperBowl', 'ds': '2011-02-11'},
            {'holiday': 'SuperBowl', 'ds': '2012-02-10'},
            # Labor Day: IsHoliday=True
            {'holiday': 'LaborDay', 'ds': '2010-09-10'},
            {'holiday': 'LaborDay', 'ds': '2011-09-09'},
            {'holiday': 'LaborDay', 'ds': '2012-09-07'},
            # Thanksgiving: IsHoliday=True
            {'holiday': 'Thanksgiving', 'ds': '2010-11-26'},
            {'holiday': 'Thanksgiving', 'ds': '2011-11-25'},
            {'holiday': 'Thanksgiving', 'ds': '2012-11-23'},
            # Christmas: IsHoliday=True (often last week of year in dataset)
            {'holiday': 'Christmas', 'ds': '2010-12-31'},
            {'holiday': 'Christmas', 'ds': '2011-12-30'},
            {'holiday': 'Christmas', 'ds': '2012-12-28'},
        ])
        self.holidays_df['ds'] = pd.to_datetime(self.holidays_df['ds'])

        # Ensure only holidays present in the training data date range are considered
        min_date = train_data['Date'].min()
        max_date = train_data['Date'].max()
        self.holidays_df = self.holidays_df[
            (self.holidays_df['ds'] >= min_date) &
            (self.holidays_df['ds'] <= max_date)
        ]

        print("✅ Pipeline fitted on training data with holiday-aware settings")
        self.fitted = True
        return self

    def transform(self, data, is_validation=False):
        """Transform data into Prophet's required format (ds, y)."""
        if not self.fitted:
            raise ValueError("Pipeline must be fitted before transform!")

        print(f"🔄 Transforming {'validation' if is_validation else 'training'} data...")

        df = data.copy()
        # Rename columns to Prophet's requirements
        df = df.rename(columns={'Date': 'ds', 'Weekly_Sales': 'y'})

        # Ensure 'y' (Weekly_Sales) is not negative, as sales cannot be negative
        df['y'] = df['y'].apply(lambda x: max(0, x))

        print(f"✅ Transform complete. Shape: {df.shape}")
        return df

def train_prophet_models(train_data_prophet, holidays_df, min_observations=50):
    """
    Trains Prophet models for each unique (Store, Dept) combination.
    """
    print(f"📈 Training Prophet models for each Store-Dept combination...")
    print(f"   ⏰ No time limit - training all combinations")

    unique_series_keys = train_data_prophet[['Store', 'Dept']].drop_duplicates().values
    total_combinations = len(unique_series_keys)
    print(f"   📊 Training models for {total_combinations} combinations")
    print(f"   🎯 Training Prophet for all combinations")

    models = {}
    successful_models = 0
    skipped_models_insufficient_data = 0
    failed_models_training_error = 0

    for i, (store_id, dept_id) in enumerate(unique_series_keys):
        # Progress update, similar to ARIMA
        # Print at start (index 0), every 200 models, and at the very end
        if i % 200 == 0 or i == total_combinations - 1:
            print(f"   ✅ Trained {i+1}/{total_combinations} models ({successful_models} successful, {skipped_models_insufficient_data + failed_models_training_error} failed)")

        series_data = train_data_prophet[
            (train_data_prophet['Store'] == store_id) &
            (train_data_prophet['Dept'] == dept_id)
        ].copy()

        # Check for minimum observations
        if len(series_data) < min_observations:
            skipped_models_insufficient_data += 1
            continue

        try:
            # Initialize Prophet model
            m = Prophet(
                yearly_seasonality=True,
                weekly_seasonality=True,
                holidays=holidays_df
            )

            # Fit the model
            m.fit(series_data)
            models[(store_id, dept_id)] = m
            successful_models += 1

        except Exception as e:
            failed_models_training_error += 1

    print(f"✅ Prophet training complete!")
    print(f"   🎯 Successful models: {successful_models}")
    print(f"   ❌ Failed models: {skipped_models_insufficient_data + failed_models_training_error}")
    print(f"   📊 Coverage: {successful_models}/{total_combinations} ({successful_models/total_combinations*100:.1f}%)")

    return models

def make_prophet_predictions(models, val_data_prophet, train_data=None):
    """
    Makes predictions using trained Prophet models for the validation period.
    """
    print("📈 Making Prophet predictions...")

    predictions = []
    actuals = []
    holidays_flags = []

    successful_predictions_count = 0
    skipped_predictions_no_model = 0
    failed_predictions_error = 0

    unique_val_series_keys = val_data_prophet[['Store', 'Dept']].drop_duplicates().values
    total_val_combinations = len(unique_val_series_keys)

    for i, (store_id, dept_id) in enumerate(unique_val_series_keys):
        # Get actual validation data for this series
        current_val_series_actuals = val_data_prophet[
            (val_data_prophet['Store'] == store_id) &
            (val_data_prophet['Dept'] == dept_id)
        ].copy()

        if current_val_series_actuals.empty:
            continue # No validation data for this series

        # Prepare future DataFrame for Prophet prediction
        # The future DataFrame should cover the exact dates in the validation set
        future_dates = pd.DataFrame({'ds': current_val_series_actuals['ds']})

        if (store_id, dept_id) in models:
            try:
                m = models[(store_id, dept_id)]
                forecast = m.predict(future_dates)
                yhat = forecast['yhat'].values

                # Ensure predictions are not negative
                yhat[yhat < 0] = 0

                predictions.extend(yhat)
                actuals.extend(current_val_series_actuals['y'].values)
                holidays_flags.extend(current_val_series_actuals['IsHoliday'].values)
                successful_predictions_count += len(yhat)

            except Exception as e:
                failed_predictions_error += len(current_val_series_actuals)
                predictions.extend(np.zeros(len(current_val_series_actuals)))
                actuals.extend(current_val_series_actuals['y'].values)
                holidays_flags.extend(current_val_series_actuals['IsHoliday'].values)
        else:
            skipped_predictions_no_model += len(current_val_series_actuals)
            predictions.extend(np.zeros(len(current_val_series_actuals)))
            actuals.extend(current_val_series_actuals['y'].values)
            holidays_flags.extend(current_val_series_actuals['IsHoliday'].values)

    print(f"✅ Predictions complete!")
    print(f"   🎯 Prophet predictions: {successful_predictions_count}")
    print(f"   ⏭️ Skipped (no model): {skipped_predictions_no_model}")

    return np.array(predictions), np.array(actuals), np.array(holidays_flags).astype(bool)

# =============================================================================
# Hybrid Model Implementation
# =============================================================================

class HybridTFTProphetRegressor(BaseEstimator, RegressorMixin):
    """
    Hybrid model combining TFT and Prophet with adaptive weighting:
    - Holiday periods: 60% Prophet + 40% TFT
    - Regular days: 65% TFT + 35% Prophet
    """

    def __init__(self, input_chunk_length=52, output_chunk_length=39, epochs=25, batch_size=32,
                 random_seed=42, min_observations=50):
        # TFT parameters
        self.input_chunk_length = input_chunk_length
        self.output_chunk_length = output_chunk_length
        self.epochs = epochs
        self.batch_size = batch_size
        self.random_seed = random_seed

        # Prophet parameters
        self.min_observations = min_observations

        # Model components
        self.tft_model_ = None
        self.prophet_models_ = None
        self.prophet_pipeline_ = None
        self.trained_data_ = None

        # Hybrid weights
        self.holiday_prophet_weight = 0.60
        self.holiday_tft_weight = 0.40
        self.regular_tft_weight = 0.65
        self.regular_prophet_weight = 0.35

    def fit(self, X, y):
        """Fit both TFT and Prophet models."""
        print("🤖 Training Hybrid Model (TFT + Prophet)...")

        # Store ORIGINAL training data for Prophet (before preprocessing that drops Date column)
        # We need to get the data before DateFeatureCreator drops the Date column
        original_data = X.copy()

        # The X here is already preprocessed (after MultiIndexKeeper, DateFeatureCreator, etc.)
        # We need to reconstruct the original format for Prophet
        # Since MultiIndexKeeper preserves the original columns with drop=False, we can use those
        if isinstance(X.index, pd.MultiIndex):
            # Get the original Date, Store, Dept from the MultiIndex
            original_data['Date'] = X.index.get_level_values('Date')
            original_data['Store'] = X.index.get_level_values('Store')
            original_data['Dept'] = X.index.get_level_values('Dept')

        # Reconstruct IsHoliday from one-hot encoded columns
        if 'IsHoliday_True' in original_data.columns and 'IsHoliday_False' in original_data.columns:
            original_data['IsHoliday'] = original_data['IsHoliday_True'].astype(bool)
        elif 'IsHoliday' not in original_data.columns:
            # If we can't reconstruct it, check if it's available elsewhere
            print("Warning: Could not reconstruct IsHoliday column")
            original_data['IsHoliday'] = False  # Default fallback

        original_data['Weekly_Sales'] = y

        # Reset MultiIndex for Prophet (it doesn't need MultiIndex)
        if isinstance(original_data.index, pd.MultiIndex):
            original_data = original_data.reset_index(drop=True)

        # Store for Prophet
        self.original_data_ = original_data

        # 1. Train Prophet model FIRST (using original data format)
        print("\n📊 Step 1: Training Prophet component...")

        # Use the original data for Prophet
        prophet_data = self.original_data_.copy()

        # Prophet expects specific columns: Date, Store, Dept, Weekly_Sales, IsHoliday
        required_prophet_cols = ['Date', 'Store', 'Dept', 'Weekly_Sales', 'IsHoliday']
        missing_cols = [col for col in required_prophet_cols if col not in prophet_data.columns]
        if missing_cols:
            print(f"Available columns: {list(prophet_data.columns)}")
            raise ValueError(f"Prophet component missing required columns: {missing_cols}")

        # Initialize Prophet pipeline
        self.prophet_pipeline_ = WalmartProphetPreprocessingPipeline()
        self.prophet_pipeline_.fit(prophet_data)

        # Transform data for Prophet
        prophet_train_data = self.prophet_pipeline_.transform(prophet_data, is_validation=False)

        # Train Prophet models
        self.prophet_models_ = train_prophet_models(
            prophet_train_data,
            self.prophet_pipeline_.holidays_df,
            min_observations=self.min_observations
        )

        print("✅ Prophet models training complete!")

        # 2. Train TFT model SECOND (using the preprocessed X)
        print("\n📊 Step 2: Training TFT component...")
        self.tft_model_ = TFTRegressor(
            input_chunk_length=self.input_chunk_length,
            output_chunk_length=self.output_chunk_length,
            epochs=self.epochs,
            batch_size=self.batch_size,
            random_seed=self.random_seed
        )
        self.tft_model_.fit(X, y)
        print("✅ TFT model training complete!")

        print("✅ Hybrid model training complete!")

        return self

    def predict(self, X):
        """Make hybrid predictions using adaptive weighting."""
        print("🔮 Making hybrid predictions...")

        # 1. Get Prophet predictions FIRST (need to reconstruct original format)
        print("   📊 Getting Prophet predictions...")

        # Reconstruct original data format for Prophet
        prophet_val_data = X.copy()

        # Get the original Date, Store, Dept from the MultiIndex
        if isinstance(X.index, pd.MultiIndex):
            prophet_val_data['Date'] = X.index.get_level_values('Date')
            prophet_val_data['Store'] = X.index.get_level_values('Store')
            prophet_val_data['Dept'] = X.index.get_level_values('Dept')

        # Reconstruct IsHoliday from one-hot encoded columns
        if 'IsHoliday_True' in prophet_val_data.columns and 'IsHoliday_False' in prophet_val_data.columns:
            prophet_val_data['IsHoliday'] = prophet_val_data['IsHoliday_True'].astype(bool)
        elif 'IsHoliday' not in prophet_val_data.columns:
            # If we can't reconstruct it, check if it's available elsewhere
            print("Warning: Could not reconstruct IsHoliday column in prediction")
            prophet_val_data['IsHoliday'] = False  # Default fallback

        prophet_val_data['Weekly_Sales'] = 0  # Dummy values for Prophet format

        # Reset MultiIndex for Prophet
        if isinstance(prophet_val_data.index, pd.MultiIndex):
            prophet_val_data = prophet_val_data.reset_index(drop=True)

        # Prophet expects specific columns: Date, Store, Dept, Weekly_Sales, IsHoliday
        required_prophet_cols = ['Date', 'Store', 'Dept', 'Weekly_Sales', 'IsHoliday']
        missing_cols = [col for col in required_prophet_cols if col not in prophet_val_data.columns]
        if missing_cols:
            print(f"Available columns: {list(prophet_val_data.columns)}")
            raise ValueError(f"Prophet component missing required columns: {missing_cols}")

        # Transform for Prophet
        prophet_val_transformed = self.prophet_pipeline_.transform(prophet_val_data, is_validation=True)

        # Get Prophet predictions
        prophet_predictions, _, _ = make_prophet_predictions(
            self.prophet_models_,
            prophet_val_transformed
        )

        # 2. Get TFT predictions SECOND (using preprocessed X)
        print("   📊 Getting TFT predictions...")
        tft_predictions = self.tft_model_.predict(X)

        # 3. Apply adaptive weighting
        print("   🎯 Applying adaptive weighting...")

        # Get holiday flags from reconstructed data (use the same logic as above)
        if 'IsHoliday_True' in X.columns and 'IsHoliday_False' in X.columns:
            is_holiday = X['IsHoliday_True'].values.astype(bool)
        else:
            print("Warning: Could not determine holiday status for weighting, using False")
            is_holiday = np.zeros(len(X), dtype=bool)

        # Initialize hybrid predictions
        hybrid_predictions = np.zeros_like(tft_predictions)

        # Apply different weights for holidays vs regular days
        holiday_mask = is_holiday
        regular_mask = ~is_holiday

        # Holiday periods: 60% Prophet + 40% TFT
        hybrid_predictions[holiday_mask] = (
            self.holiday_prophet_weight * prophet_predictions[holiday_mask] +
            self.holiday_tft_weight * tft_predictions[holiday_mask]
        )

        # Regular days: 65% TFT + 35% Prophet
        hybrid_predictions[regular_mask] = (
            self.regular_tft_weight * tft_predictions[regular_mask] +
            self.regular_prophet_weight * prophet_predictions[regular_mask]
        )

        # Ensure no negative predictions
        hybrid_predictions[hybrid_predictions < 0] = 0

        # Report weighting statistics
        n_holiday = np.sum(holiday_mask)
        n_regular = np.sum(regular_mask)
        total_points = len(hybrid_predictions)

        print(f"   📊 Weighting applied:")
        print(f"      Holiday points: {n_holiday:,} ({n_holiday/total_points*100:.1f}%) - 60% Prophet + 40% TFT")
        print(f"      Regular points: {n_regular:,} ({n_regular/total_points*100:.1f}%) - 65% TFT + 35% Prophet")

        print("✅ Hybrid predictions complete!")

        return hybrid_predictions

def main():
    """Main experiment execution for Hybrid model."""
    print("🚀 Starting Walmart Sales Forecasting with Hybrid Model (TFT + Prophet)")
    print("🎯 Adaptive weighting: Holiday (60% Prophet + 40% TFT), Regular (65% TFT + 35% Prophet)")
    print("=" * 80)

    try:
        # --- 1. Data Loading ---
        print("📊 Loading datasets...")

        # Check if running in Kaggle environment
        if os.path.exists(KAGGLE_DATA_PATH):
            # Kaggle environment
            train_zip_path = os.path.join(KAGGLE_DATA_PATH, 'train.csv.zip')
            features_zip_path = os.path.join(KAGGLE_DATA_PATH, 'features.csv.zip')
            stores_csv_path = os.path.join(KAGGLE_DATA_PATH, 'stores.csv')

            print("   📂 Unzipping necessary data files...")
            with zipfile.ZipFile(train_zip_path, 'r') as zip_ref:
                zip_ref.extractall('.')
            print(f"      - Extracted: {train_zip_path}")

            with zipfile.ZipFile(features_zip_path, 'r') as zip_ref:
                zip_ref.extractall('.')
            print(f"      - Extracted: {features_zip_path}")

            stores_df = pd.read_csv(stores_csv_path)
        else:
            # Local environment - assume files are in current directory
            print("   📂 Loading from local files...")
            if os.path.exists('train.csv.zip'):
                with zipfile.ZipFile('train.csv.zip', 'r') as zip_ref:
                    zip_ref.extractall('.')
            if os.path.exists('features.csv.zip'):
                with zipfile.ZipFile('features.csv.zip', 'r') as zip_ref:
                    zip_ref.extractall('.')
            stores_df = pd.read_csv('stores.csv')

        # Load the unzipped CSVs
        train_df = pd.read_csv('train.csv')
        features_df = pd.read_csv('features.csv')

        # Convert Date columns to datetime early for consistency
        train_df['Date'] = pd.to_datetime(train_df['Date'])
        features_df['Date'] = pd.to_datetime(features_df['Date'])

        print(f"   📈 Train data: {train_df.shape}")
        print(f"   📊 Features data: {features_df.shape}")
        print(f"   🏪 Stores data: {stores_df.shape}")

        # --- 2. Data Merging and Initial Cleaning ---
        print("\n🧹 Merging data and initial cleaning...")
        merged_df = pd.merge(train_df, features_df, on=['Store', 'Date', 'IsHoliday'], how='left')
        train_full = pd.merge(merged_df, stores_df, on=['Store'], how='left')

        # Fill NaN in MarkDown columns with 0, assuming no markdown if not specified
        markdown_cols = [f'MarkDown{i}' for i in range(1, 6)]
        for col in markdown_cols:
            if col in train_full.columns:
                train_full[col] = train_full[col].fillna(0)

        # Remove rows with negative Weekly_Sales
        initial_rows = len(train_full)
        train_full = train_full[train_full['Weekly_Sales'] > 0]
        print(f"   🗑️ Removed {initial_rows - len(train_full)} rows with negative Weekly_Sales.")

        # Define column lists
        numerical_cols = ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment'] + [f'MarkDown{i}' for i in range(1, 6)]
        categorical_ohe_cols = ["Type", "IsHoliday"]
        passthrough_cols = ["Store", "Dept"]

        # Ensure continuous series and fill NaNs for NeuralForecast
        print("   Filling missing dates and sales for time series continuity...")

        # Create a full set of (Store, Dept, Date) unique combinations
        unique_store_dept_dates = train_full[['Store', 'Dept', 'Date']].drop_duplicates()

        # Generate all expected dates for each (Store, Dept)
        df_list = []
        for (store, dept), group in unique_store_dept_dates.groupby(['Store', 'Dept']):
            series_min_date = group['Date'].min()
            series_max_date = group['Date'].max()
            full_series_dates = pd.date_range(start=series_min_date, end=series_max_date, freq='W-FRI')

            temp_df = pd.DataFrame({
                'Store': store,
                'Dept': dept,
                'Date': full_series_dates
            })
            df_list.append(temp_df)

        complete_series_df = pd.concat(df_list, ignore_index=True)

        # Merge the complete series dates with the original train_full data
        train_full_cleaned = pd.merge(
            complete_series_df,
            train_full,
            on=['Store', 'Dept', 'Date'],
            how='left'
        )

        # Fill NaNs in 'Weekly_Sales' with 0
        nan_sales_before_fill = train_full_cleaned['Weekly_Sales'].isnull().sum()
        train_full_cleaned['Weekly_Sales'] = train_full_cleaned['Weekly_Sales'].fillna(0)
        print(f"   Filled {nan_sales_before_fill} NaN Weekly_Sales values with 0 for series continuity.")

        # Re-merge the original features_df and stores_df to fill in associated data
        train_full_cleaned = pd.merge(train_full_cleaned, features_df.drop(columns=['IsHoliday'], errors='ignore'), on=['Store', 'Date'], how='left', suffixes=('', '_feats'))
        train_full_cleaned = pd.merge(train_full_cleaned, stores_df, on=['Store'], how='left', suffixes=('', '_stores'))

        # Combine IsHoliday if it was duplicated
        if 'IsHoliday_feats' in train_full_cleaned.columns:
            train_full_cleaned['IsHoliday'] = train_full_cleaned['IsHoliday'].fillna(train_full_cleaned['IsHoliday_feats'])
            train_full_cleaned = train_full_cleaned.drop(columns=['IsHoliday_feats'])

        # Fill NaNs in features columns
        for col in numerical_cols:
            if col in train_full_cleaned.columns:
                train_full_cleaned[col] = train_full_cleaned.groupby(['Store', 'Dept'])[col].transform(lambda x: x.fillna(x.mean()))
                train_full_cleaned[col] = train_full_cleaned[col].fillna(train_full_cleaned[col].mean())

        # Ensure no negative sales after any filling process
        train_full_cleaned['Weekly_Sales'][train_full_cleaned['Weekly_Sales'] < 0] = 0

        # Sort by date, store, and department for time series consistency
        train_full = train_full_cleaned.sort_values(by=['Date', 'Store', 'Dept']).reset_index(drop=True)

        print(f"   ✅ Merged and cleaned data: {train_full.shape}")
        print(f"   📅 Date range: {train_full['Date'].min()} to {train_full['Date'].max()}")
        print(f"   Sanity check: NaNs in Weekly_Sales after cleaning: {train_full['Weekly_Sales'].isnull().sum()}")

        # --- 3. Data Splitting (80/20 Time-based) ---
        print("\n📅 Step 1: Creating temporal split (80/20)...")

        df_sorted = train_full.sort_values('Date').reset_index(drop=True)

        unique_dates = sorted(df_sorted['Date'].unique())
        total_weeks = len(unique_dates)
        train_ratio = 0.8
        train_weeks = int(total_weeks * train_ratio)

        if train_weeks < 1:
            train_weeks = 1
        if train_weeks >= total_weeks:
            train_weeks = total_weeks - 1

        split_date = unique_dates[train_weeks - 1]

        X_train = df_sorted[df_sorted['Date'] <= split_date].drop(columns=['Weekly_Sales']).copy()
        y_train = df_sorted[df_sorted['Date'] <= split_date]['Weekly_Sales'].copy()
        X_val = df_sorted[df_sorted['Date'] > split_date].drop(columns=['Weekly_Sales']).copy()
        y_val = df_sorted[df_sorted['Date'] > split_date]['Weekly_Sales'].copy()

        print(f"   📊 Split date: {split_date}")
        print(f"   📈 Train: {len(X_train):,} records ({X_train['Date'].min()} to {X_train['Date'].max()})")
        print(f"   📉 Val: {len(X_val):,} records ({X_val['Date'].min()} to {X_val['Date'].max()})")

        # --- 4. Pipeline Definition ---
        print("\n⚙️ Step 2: Defining preprocessing pipeline...")

        numerical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='mean'))
        ])

        categorical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='most_frequent')),
            ('onehot', OneHotEncoder(handle_unknown='ignore'))
        ])

        preprocessor = ColumnTransformerWithNames(transformers=[
            ('num', numerical_transformer, numerical_cols),
            ('cat', categorical_transformer, categorical_ohe_cols),
            ('pass', 'passthrough', passthrough_cols)
        ], remainder='drop')

        pipeline = Pipeline([
            ("multi_index_keeper", MultiIndexKeeper(index_cols=["Date", "Store", "Dept"])),
            ("date_feature_creator", DateFeatureCreator()),
            ("preprocessor", preprocessor),
            ("hybrid_regressor", HybridTFTProphetRegressor(
                input_chunk_length=52,
                output_chunk_length=39,
                epochs=25,
                batch_size=32,
                random_seed=42
            ))
        ])

        print("   ✅ Pipeline defined.")

        # --- 5. Model Training ---
        print("\n🧠 Step 3: Training Hybrid model...")
        pipeline.fit(X_train, y_train)
        print("   ✅ Hybrid Model Training Complete!")

        # --- 6. Model Evaluation ---
        print("\n📊 Step 4: Evaluating model on validation set...")
        y_pred_val = pipeline.predict(X_val)

        y_pred_val[y_pred_val < 0] = 0

        is_holiday_val = X_val['IsHoliday'].values.astype(bool)

        if len(y_val) > 0:
            val_mae = mean_absolute_error(y_val, y_pred_val)
            val_rmse = np.sqrt(mean_squared_error(y_val, y_pred_val))
            val_wmae = calculate_wmae(y_val, y_pred_val, is_holiday_val)
        else:
            val_mae, val_rmse, val_wmae = 0, 0, 0
            print("   ⚠️ Warning: No data points for evaluation. Metrics set to 0.")

        holiday_mask_val = is_holiday_val.astype(bool)
        holiday_mae_val = mean_absolute_error(y_val[holiday_mask_val], y_pred_val[holiday_mask_val]) if holiday_mask_val.any() else np.nan
        non_holiday_mae_val = mean_absolute_error(y_val[~holiday_mask_val], y_pred_val[~holiday_mask_val]) if (~holiday_mask_val).any() else np.nan

        print("\n" + "=" * 60)
        print("🎯 EXPERIMENT HYBRID RESULTS SUMMARY")
        print("=" * 60)

        print("\n📊 Validation Metrics:")
        print(f"   WMAE (Competition Metric): ${val_wmae:,.2f}")
        print(f"   MAE: ${val_mae:,.2f}")
        print(f"   RMSE: ${val_rmse:,.2f}")

        print("\n📊 Holiday Breakdown:")
        print(f"   Holiday MAE: ${holiday_mae_val:,.2f} ({int(holiday_mask_val.sum()):,} samples)")
        print(f"   Non-Holiday MAE: ${non_holiday_mae_val:,.2f} ({int((~holiday_mask_val).sum()):,} samples)")

        print("\n📊 Hybrid Model Configuration:")
        print(f"   Holiday periods: 60% Prophet + 40% TFT")
        print(f"   Regular periods: 65% TFT + 35% Prophet")

        print("\n🎉 Experiment Hybrid: Complete!")

    except Exception as e:
        print(f"❌ Experiment failed: {e}")
        raise

if __name__ == "__main__":
    main()


🚀 Starting Walmart Sales Forecasting with Hybrid Model (TFT + Prophet)
🎯 Adaptive weighting: Holiday (60% Prophet + 40% TFT), Regular (65% TFT + 35% Prophet)
📊 Loading datasets...
   📂 Loading from local files...
   📈 Train data: (421570, 5)
   📊 Features data: (8190, 12)
   🏪 Stores data: (45, 3)

🧹 Merging data and initial cleaning...
   🗑️ Removed 1358 rows with negative Weekly_Sales.
   Filling missing dates and sales for time series continuity...
   Filled 26507 NaN Weekly_Sales values with 0 for series continuity.
   ✅ Merged and cleaned data: (446719, 27)
   📅 Date range: 2010-02-05 00:00:00 to 2012-10-26 00:00:00
   Sanity check: NaNs in Weekly_Sales after cleaning: 0

📅 Step 1: Creating temporal split (80/20)...
   📊 Split date: 2012-04-06 00:00:00
   📈 Train: 357,562 records (2010-02-05 00:00:00 to 2012-04-06 00:00:00)
   📉 Val: 89,157 records (2012-04-13 00:00:00 to 2012-10-26 00:00:00)

⚙️ Step 2: Defining preprocessing pipeline...
   ✅ Pipeline defined.

🧠 Step 3: Training

INFO:lightning_fabric.utilities.seed:Seed set to 42
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                    | Type                     | Params | Mode 
-----------------------------------------------------------------------------
0 | loss                    | MAE                      | 0      | train
1 | padder_train            | ConstantPad1d            | 0      | train
2 | scaler                  | TemporalNorm             | 0      | train
3 | embedding               | TFTEmbedding             | 512    | train
4 | temporal_encoder        | TemporalCovariateEncoder | 613 K  | train
5 | temporal_fusion_decoder | TemporalFusionDecoder    | 256 K

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=1000` reached.


✅ TFT model training complete!
✅ Hybrid model training complete!
   ✅ Hybrid Model Training Complete!

📊 Step 4: Evaluating model on validation set...
🔮 Making hybrid predictions...
   📊 Getting Prophet predictions...
🔄 Transforming validation data...
✅ Transform complete. Shape: (89157, 19)
📈 Making Prophet predictions...
✅ Predictions complete!
   🎯 Prophet predictions: 88083
   ⏭️ Skipped (no model): 1074
   📊 Getting TFT predictions...
DEBUG: Found NaNs in futr_df_complete for columns: ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment', 'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5', 'Type_A', 'Type_B', 'Type_C', 'IsHoliday_False', 'IsHoliday_True', 'Store', 'Dept']. Filling with 0.


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

DEBUG: Found NaNs in final y_pred after reindex. Filling with 0.
   🎯 Applying adaptive weighting...
   📊 Weighting applied:
      Holiday points: 2,957 (3.3%) - 60% Prophet + 40% TFT
      Regular points: 86,200 (96.7%) - 65% TFT + 35% Prophet
✅ Hybrid predictions complete!

🎯 EXPERIMENT HYBRID RESULTS SUMMARY

📊 Validation Metrics:
   WMAE (Competition Metric): $5,146.98
   MAE: $5,338.29
   RMSE: $9,717.69

📊 Holiday Breakdown:
   Holiday MAE: $4,502.58 (6,617 samples)
   Non-Holiday MAE: $5,405.28 (82,540 samples)

📊 Hybrid Model Configuration:
   Holiday periods: 60% Prophet + 40% TFT
   Regular periods: 65% TFT + 35% Prophet

🎉 Experiment Hybrid: Complete!
