# In [None]:
# Model Inference - Walmart Store Sales Forecasting
# =================================================

import pandas as pd
import numpy as np
import joblib
import wandb
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# Load Best Model from Model Registry
# ============================================================

def load_best_model_from_registry():
    """Fetch the registered N-BEATS model artifact via wandb and unpickle it.

    A wandb run with ``job_type="inference"`` is started, the
    ``nbeats_final_model:latest`` artifact is downloaded, and
    ``nbeats_final_model.pkl`` inside the download directory is loaded with
    joblib. The run is not finished here; it stays active after the return.

    Returns:
        The object stored in the pickle file.
    """
    inference_run = wandb.init(
        job_type="inference",
        project="Walmart Recruiting - Store Sales Forecasting",
    )

    # The value returned by download() is used as the directory holding the file.
    download_dir = inference_run.use_artifact(
        'nbeats_final_model:latest', type='model'
    ).download()

    model_path = f"{download_dir}/nbeats_final_model.pkl"
    loaded_model = joblib.load(model_path)

    print(f"Model loaded from: {model_path}")
    return loaded_model

def load_local_model(model_path='nbeats_final_model.pkl'):
    """Unpickle a model from a local joblib file (fallback to the registry).

    Args:
        model_path: Path of the joblib/pickle file to read.

    Returns:
        The object stored in the file.
    """
    loaded_model = joblib.load(model_path)
    print(f"Model loaded from: {model_path}")
    return loaded_model

# ============================================================
# Data Preprocessing for Test Set
# ============================================================

class TestDataProcessor:
    """Load the competition test files and reshape them for inference.

    Typical use is ``load_test_data(...)`` followed by
    ``preprocess_test_data()``; the processed frame is also kept on
    ``self.processed_test_data``.
    """

    def __init__(self):
        pass

    def load_test_data(self, stores_path, features_path, test_path):
        """Read the three CSV files into ``self.stores``/``features``/``test``.

        Args:
            stores_path: CSV with per-store attributes (joined on ``Store``).
            features_path: CSV with per-store, per-date features
                (joined on ``Store`` and ``Date``).
            test_path: CSV with the rows to predict.

        Returns:
            ``self``, so calls can be chained.
        """
        self.stores = pd.read_csv(stores_path)
        self.features = pd.read_csv(features_path)
        self.test = pd.read_csv(test_path)
        return self

    def preprocess_test_data(self, merge_features=True, merge_stores=True):
        """Build the inference frame from the loaded test data.

        Steps: parse ``Date``, optionally left-join stores and features,
        replace missing values with 0, sort by Store/Dept/Date, add
        ``unique_id`` (``"<Store>_<Dept>"``) and rename ``Date`` to ``ds``.
        ``Store`` and ``Dept`` are dropped from the result; every other
        column is kept after ``unique_id``, ``ds``, ``Id`` and ``IsHoliday``.

        Args:
            merge_features: Join ``self.features`` when it has been loaded.
            merge_stores: Join ``self.stores`` when it has been loaded.

        Returns:
            The processed DataFrame (also stored on
            ``self.processed_test_data``).
        """
        df = self.test.copy()

        # Convert Date to datetime
        df['Date'] = pd.to_datetime(df['Date'])

        # Merge with stores data
        if merge_stores and hasattr(self, 'stores'):
            df = df.merge(self.stores, on='Store', how='left')

        # Merge with features data
        if merge_features and hasattr(self, 'features'):
            features = self.features.copy()
            # The join key must have the same dtype on both sides; pandas
            # refuses to merge a datetime64 column with a string column.
            features['Date'] = pd.to_datetime(features['Date'])
            # Columns present on both sides (e.g. IsHoliday) would otherwise
            # come back as *_x / *_y and the plain name would disappear.
            # The test frame's copy is kept.
            join_keys = ('Store', 'Date')
            duplicated = [
                col for col in features.columns
                if col in df.columns and col not in join_keys
            ]
            features = features.drop(columns=duplicated)
            df = df.merge(features, on=['Store', 'Date'], how='left')

        # Fill missing values
        df = df.fillna(0)

        # Sort by Store, Dept, Date
        df = df.sort_values(['Store', 'Dept', 'Date'])

        # The submission key is Store_Dept_Date. Build it when the test file
        # does not ship an Id column (the Kaggle test.csv has none).
        if 'Id' not in df.columns:
            df['Id'] = (
                df['Store'].astype(str) + '_'
                + df['Dept'].astype(str) + '_'
                + df['Date'].dt.strftime('%Y-%m-%d')
            )

        # Create unique_id for N-BEATS (combination of Store and Dept)
        df['unique_id'] = df['Store'].astype(str) + '_' + df['Dept'].astype(str)

        # Rename columns for N-BEATS format
        df = df.rename(columns={'Date': 'ds'})

        # Select relevant columns (no 'y' column in test data)
        columns_to_keep = ['unique_id', 'ds', 'Id']  # Keep Id for submission
        if 'IsHoliday' in df.columns:
            columns_to_keep.append('IsHoliday')

        # Keep every remaining column except the raw Store/Dept keys
        for col in df.columns.tolist():
            if col not in columns_to_keep and col not in ['Store', 'Dept']:
                columns_to_keep.append(col)

        df = df[columns_to_keep]

        self.processed_test_data = df
        return df

# ============================================================
# Prediction Pipeline
# ============================================================

def create_submission_predictions(model, test_data):
    """Predict each store-department series separately and stack the results.

    For every distinct ``unique_id`` (in order of first appearance) the
    matching rows are sorted by ``ds`` and handed to ``model.predict``. If
    the call, or assigning its result, raises, the error is printed and the
    series gets ``Weekly_Sales = 0`` instead.

    Args:
        model: Object exposing ``predict(frame)``; the result is assigned
            as the ``Weekly_Sales`` column for that series.
        test_data: Frame with at least ``unique_id``, ``ds`` and ``Id``.

    Returns:
        DataFrame with ``Id`` and ``Weekly_Sales`` and a fresh RangeIndex.
    """
    print("Making predictions on test data...")

    per_series_frames = []

    for series_key in test_data['unique_id'].unique():
        # Rows of this series, in chronological order
        series_rows = test_data.loc[test_data['unique_id'] == series_key].copy()
        series_rows = series_rows.sort_values('ds')

        try:
            series_forecast = model.predict(series_rows)
            result = series_rows[['Id']].copy()
            result['Weekly_Sales'] = series_forecast
        except Exception as e:
            print(f"Error predicting for {series_key}: {e}")
            # Best-effort fallback so that every Id still receives a value
            result = series_rows[['Id']].copy()
            result['Weekly_Sales'] = 0

        per_series_frames.append(result)

    return pd.concat(per_series_frames, ignore_index=True)

def create_kaggle_submission(predictions, output_file='submission.csv'):
    """Write the ``Id``/``Weekly_Sales`` columns, sorted by ``Id``, to a CSV.

    Args:
        predictions: Frame containing at least ``Id`` and ``Weekly_Sales``;
            it is not modified.
        output_file: Destination path of the CSV (written without index).

    Returns:
        The sorted two-column frame that was written.
    """
    # Selecting with a column list already yields a new frame, and
    # sort_values returns another one, so the input stays untouched.
    ordered = predictions[['Id', 'Weekly_Sales']].sort_values('Id')

    ordered.to_csv(output_file, index=False)

    print(f"Submission file saved as: {output_file}")
    print(f"Submission shape: {ordered.shape}")
    print("Sample predictions:")
    print(ordered.head())

    return ordered

# ============================================================
# Alternative Prediction Method for N-BEATS
# ============================================================

def predict_with_nbeats_direct(test_data, train_data_path='data/train.csv'):
    """
    Alternative method: Use N-BEATS with historical data to predict future

    Reads the training CSV, builds the ``unique_id``/``ds``/``y`` layout, and
    then, for every ``unique_id`` found in ``test_data``, fits a fresh pass of
    the same ``NeuralForecast`` object on that series' history and maps the
    forecast onto the series' test rows by position.

    Fallbacks, per series:
      * fewer than 52 history rows -> the history mean (0 if there is no
        history at all) is used for every test row;
      * any exception during fit/predict/mapping -> 0 for every test row.

    Args:
        test_data: Frame with at least ``unique_id``, ``ds`` and ``Id``.
        train_data_path: CSV with ``Store``, ``Dept``, ``Date`` and
            ``Weekly_Sales`` columns.

    Returns:
        DataFrame with ``Id`` and ``Weekly_Sales`` for all processed series,
        re-indexed from 0.
    """
    
    # Load training data to get historical context
    train_data = pd.read_csv(train_data_path)
    
    # Preprocess training data
    train_data['Date'] = pd.to_datetime(train_data['Date'])
    train_data['unique_id'] = train_data['Store'].astype(str) + '_' + train_data['Dept'].astype(str)
    train_data = train_data.rename(columns={'Date': 'ds', 'Weekly_Sales': 'y'})
    
    # Get the date range for predictions (only used for the log line below)
    test_start_date = test_data['ds'].min()
    test_end_date = test_data['ds'].max()
    
    print(f"Predicting from {test_start_date} to {test_end_date}")
    
    # For each unique store-department combination
    unique_ids = test_data['unique_id'].unique()
    all_predictions = []
    
    # Imported here so the rest of the module works without neuralforecast
    from neuralforecast import NeuralForecast
    from neuralforecast.models import NBEATS
    
    # Create N-BEATS model with best parameters
    # NOTE(review): these values are hard-coded here, not read from the
    # training run — confirm they match the registered model's configuration.
    model = NBEATS(
        max_steps=100,  # Reduced for inference
        h=53,  # Prediction horizon
        input_size=52,
        batch_size=256,
        learning_rate=1e-3,
        random_seed=42,
        enable_progress_bar=False
    )
    
    # NOTE(review): the pandas alias 'W' is anchored on Sunday; confirm the
    # weekday of the training dates. Forecast dates are not used below
    # (mapping is positional), so this presumably only affects the
    # timestamps NeuralForecast generates — TODO confirm.
    nf = NeuralForecast(models=[model], freq='W')
    
    for i, unique_id in enumerate(unique_ids):
        print(f"Processing {unique_id} ({i+1}/{len(unique_ids)})")
        
        try:
            # Get historical data for this store-department
            hist_data = train_data[train_data['unique_id'] == unique_id].copy()
            hist_data = hist_data[['unique_id', 'ds', 'y']].sort_values('ds')
            
            if len(hist_data) < 52:  # Need minimum history
                print(f"Insufficient history for {unique_id}, using fallback")
                test_subset = test_data[test_data['unique_id'] == unique_id]
                pred_df = test_subset[['Id']].copy()
                # Constant forecast: mean of whatever history exists, else 0
                pred_df['Weekly_Sales'] = hist_data['y'].mean() if len(hist_data) > 0 else 0
                all_predictions.append(pred_df)
                continue
            
            # Fit model on historical data
            # (the model is re-fitted for every series, one series at a time)
            nf.fit(hist_data)
            
            # Make predictions
            forecasts = nf.predict()
            
            # Map predictions to test data
            test_subset = test_data[test_data['unique_id'] == unique_id].copy()
            test_subset = test_subset.sort_values('ds')
            
            # Take first len(test_subset) predictions
            # NOTE(review): this pairs the k-th forecast step with the k-th
            # test row. It assumes the series' history ends right before its
            # first test date and that the test rows have no gaps — verify.
            if len(forecasts) >= len(test_subset):
                predictions = forecasts['NBEATS'].iloc[:len(test_subset)].values
            else:
                # Pad with last prediction if needed
                predictions = forecasts['NBEATS'].values
                predictions = np.pad(predictions, (0, len(test_subset) - len(predictions)), 
                                   mode='constant', constant_values=predictions[-1])
            
            pred_df = test_subset[['Id']].copy()
            pred_df['Weekly_Sales'] = predictions
            all_predictions.append(pred_df)
            
        except Exception as e:
            print(f"Error with {unique_id}: {e}")
            # Fallback prediction
            # Best-effort: the failing series is submitted as all zeros
            test_subset = test_data[test_data['unique_id'] == unique_id]
            pred_df = test_subset[['Id']].copy()
            pred_df['Weekly_Sales'] = 0
            all_predictions.append(pred_df)
    
    # Combine all predictions
    final_predictions = pd.concat(all_predictions, ignore_index=True)
    return final_predictions

# ============================================================
# Main Inference Pipeline
# ============================================================

def main_inference():
    """Run inference end to end and log the resulting submission to wandb.

    Loads and preprocesses the test files under ``data/``, then tries the
    registry model. Any exception raised while loading the model or while
    predicting with it switches to ``predict_with_nbeats_direct``. The
    predictions are written to ``submission.csv``, summary statistics are
    logged, and the file is uploaded as the ``kaggle_submission`` artifact.

    Returns:
        The submission DataFrame produced by ``create_kaggle_submission``.
    """
    print("Starting inference pipeline...")

    print("Loading test data...")
    processor = TestDataProcessor()
    # load_test_data returns the processor itself, so the calls can be chained
    test_data = processor.load_test_data(
        stores_path='data/stores.csv',
        features_path='data/features.csv',
        test_path='data/test.csv',
    ).preprocess_test_data()
    print(f"Test data shape: {test_data.shape}")

    try:
        # Preferred path: model stored in the wandb registry
        print("Attempting to load model from wandb registry...")
        registry_model = load_best_model_from_registry()
        predictions = create_submission_predictions(registry_model, test_data)
    except Exception as e:
        print(f"Error loading from registry: {e}")
        print("Using alternative prediction method...")

        # Fallback path: fit N-BEATS directly on the training history
        predictions = predict_with_nbeats_direct(test_data)

    submission = create_kaggle_submission(predictions)

    wandb.init(
        project="Walmart Recruiting - Store Sales Forecasting",
        job_type="inference"
    )

    # Summary statistics of the predicted sales
    sales = submission['Weekly_Sales']
    wandb.log({
        'submission_size': len(submission),
        'prediction_mean': sales.mean(),
        'prediction_std': sales.std(),
        'prediction_min': sales.min(),
        'prediction_max': sales.max()
    })

    # "submission.csv" is the default output name of create_kaggle_submission
    submission_artifact = wandb.Artifact(name="kaggle_submission", type="prediction")
    submission_artifact.add_file("submission.csv")
    wandb.log_artifact(submission_artifact)

    wandb.finish()

    print("Inference completed successfully!")
    return submission

# ============================================================
# Model Performance Analysis
# ============================================================

def analyze_predictions(submission):
    """Print summary statistics of the predictions and plot their distribution.

    A histogram and a box plot of ``Weekly_Sales`` are drawn side by side,
    saved to ``prediction_analysis.png`` and shown.

    Args:
        submission: Frame with a ``Weekly_Sales`` column.

    Returns:
        The same ``submission`` object, unchanged.
    """
    sales = submission['Weekly_Sales']
    divider = "=" * 50

    print("\n" + divider)
    print("PREDICTION ANALYSIS")
    print(divider)

    # Basic statistics, all formatted as dollar amounts
    print(f"Total predictions: {len(submission):,}")
    summary_rows = (
        ("Mean prediction", sales.mean()),
        ("Median prediction", sales.median()),
        ("Std deviation", sales.std()),
        ("Min prediction", sales.min()),
        ("Max prediction", sales.max()),
    )
    for label, value in summary_rows:
        print(f"{label}: ${value:,.2f}")

    # Negative sales are flagged here; clipping happens elsewhere
    negative_count = int((sales < 0).sum())
    if negative_count > 0:
        print(f"\nWarning: {negative_count} negative predictions found!")
        print("Consider post-processing to ensure non-negative sales")

    # Imported lazily so the module can be used without matplotlib
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 4))

    # Left panel: histogram
    plt.subplot(1, 2, 1)
    plt.hist(sales, bins=50, alpha=0.7)
    plt.title('Distribution of Predictions')
    plt.xlabel('Weekly Sales ($)')
    plt.ylabel('Frequency')

    # Right panel: box plot
    plt.subplot(1, 2, 2)
    plt.boxplot(sales)
    plt.title('Prediction Box Plot')
    plt.ylabel('Weekly Sales ($)')

    plt.tight_layout()
    plt.savefig('prediction_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

    return submission

def post_process_predictions(submission, min_sales=0, max_sales=None):
    """Clamp ``Weekly_Sales`` to ``[min_sales, max_sales]`` and report changes.

    The frame is modified in place and the same object is returned, so
    callers that ignore the return value still see the clipped values.
    Missing (NaN) predictions are left untouched and are not counted as
    modified.

    Args:
        submission: Frame with a ``Weekly_Sales`` column.
        min_sales: Lower bound, or ``None`` to skip the lower clip.
        max_sales: Upper bound, or ``None`` to skip the upper clip.

    Returns:
        The (mutated) ``submission`` frame.
    """
    print("Post-processing predictions...")

    # Snapshot of the column only; enough to count the modifications later
    before = submission['Weekly_Sales'].copy()

    # Clip negative values
    if min_sales is not None:
        below_min = submission['Weekly_Sales'] < min_sales
        submission.loc[below_min, 'Weekly_Sales'] = min_sales
        print(f"Clipped {below_min.sum()} predictions to minimum value {min_sales}")

    # Clip extremely high values if specified
    if max_sales is not None:
        above_max = submission['Weekly_Sales'] > max_sales
        submission.loc[above_max, 'Weekly_Sales'] = max_sales
        print(f"Clipped {above_max.sum()} predictions to maximum value {max_sales}")

    # NaN != NaN is True, so a plain inequality would report every missing
    # prediction as modified. Rows that are NaN before and after are excluded.
    after = submission['Weekly_Sales']
    changed = (before != after) & ~(before.isna() & after.isna())
    print(f"Total predictions modified: {changed.sum()}")

    return submission

# ============================================================
# Ensemble Predictions (if multiple models available)
# ============================================================

def ensemble_predictions(model_predictions_list, weights=None):
    """Combine several prediction frames into one weighted average.

    Rows are combined by position: every frame must list the same Ids in
    the same order. The ``Id`` column of the result is taken from the first
    frame.

    Args:
        model_predictions_list: Non-empty sequence of frames, each with
            ``Id`` and ``Weekly_Sales`` columns and the same number of rows.
        weights: One weight per frame; normalised to sum to 1. Defaults to
            equal weights.

    Returns:
        DataFrame with ``Id`` and the averaged ``Weekly_Sales`` (float).

    Raises:
        ValueError: If no frames are given, the number of weights does not
            match, the weights sum to zero, or the frames differ in length.
    """
    if not len(model_predictions_list):
        raise ValueError("At least one prediction set is required")

    if weights is None:
        weights = [1.0] * len(model_predictions_list)

    if len(weights) != len(model_predictions_list):
        raise ValueError("Number of weights must match number of prediction sets")

    # Normalize weights
    weights = np.asarray(weights, dtype=float)
    weight_total = weights.sum()
    if weight_total == 0:
        raise ValueError("Weights must not sum to zero")
    weights = weights / weight_total

    reference = model_predictions_list[0]

    # Always accumulate in float: an integer Weekly_Sales column (e.g. the
    # all-zero fallback) would otherwise give an integer accumulator that
    # cannot take the weighted, fractional contributions in place.
    ensemble_pred = np.zeros(len(reference), dtype=float)

    for predictions, weight in zip(model_predictions_list, weights):
        values = predictions['Weekly_Sales'].to_numpy(dtype=float)
        if len(values) != len(ensemble_pred):
            raise ValueError("All prediction sets must have the same number of rows")
        ensemble_pred += weight * values

    # Create final submission
    final_submission = reference[['Id']].copy()
    final_submission['Weekly_Sales'] = ensemble_pred

    print(f"Ensemble created from {len(model_predictions_list)} models")
    print(f"Weights: {weights}")

    return final_submission

# ============================================================
# Validation on Historical Data
# ============================================================

def validate_on_historical_data(model, validation_period_weeks=8):
    """Score ``model`` on the most recent weeks of ``data/train.csv``.

    The validation window holds the last ``validation_period_weeks`` weekly
    dates, the latest date included. The model is asked to predict the
    validation rows (without ``Weekly_Sales``) and MAE, RMSE and, when an
    ``IsHoliday`` column exists, the holiday-weighted MAE (weight 5 on
    holiday rows, 1 otherwise) are reported.

    Args:
        model: Object exposing ``predict(frame)`` that returns one value per
            row of ``frame``.
        validation_period_weeks: Number of trailing weeks to validate on.

    Returns:
        ``{'mae': ..., 'rmse': ..., 'wmae': ...}`` (``wmae`` is ``None``
        without ``IsHoliday``), or ``None`` if prediction or scoring failed.
    """
    print(f"Validating on last {validation_period_weeks} weeks of historical data...")

    # Load full training data
    train_data = pd.read_csv('data/train.csv')
    train_data['Date'] = pd.to_datetime(train_data['Date'])

    # Split into train and validation.
    # The window includes max_date itself, so N weekly dates reach back only
    # N - 1 weeks; subtracting N weeks would select N + 1 dates.
    max_date = train_data['Date'].max()
    validation_start = max_date - pd.Timedelta(weeks=validation_period_weeks - 1)

    val_data = train_data[train_data['Date'] >= validation_start].copy()
    hist_data = train_data[train_data['Date'] < validation_start].copy()

    print(f"Historical data: {len(hist_data):,} records")
    print(f"Validation data: {len(val_data):,} records")

    # Prepare validation data for prediction
    val_data['unique_id'] = val_data['Store'].astype(str) + '_' + val_data['Dept'].astype(str)
    val_data = val_data.rename(columns={'Date': 'ds'})

    # Make predictions (simplified version)
    try:
        predictions = model.predict(val_data.drop(columns=['Weekly_Sales']))

        # Compare by position with plain arrays so a differently indexed
        # prediction Series cannot be silently re-aligned.
        actual = val_data['Weekly_Sales'].to_numpy(dtype=float)
        predicted = np.asarray(predictions, dtype=float).reshape(-1)
        if actual.size == 0:
            raise ValueError("validation window contains no rows")
        if predicted.shape != actual.shape:
            raise ValueError(
                f"expected {actual.size} predictions, got {predicted.size}"
            )

        # Metrics use numpy (already imported), so a missing optional
        # package cannot be misreported as a validation failure.
        abs_error = np.abs(actual - predicted)
        mae = float(abs_error.mean())
        rmse = float(np.sqrt(np.mean((actual - predicted) ** 2)))

        # Calculate WMAE if holiday info available
        weighted_mae = None
        if 'IsHoliday' in val_data.columns:
            weights = np.where(val_data['IsHoliday'].to_numpy(), 5, 1)
            weighted_mae = float(np.sum(weights * abs_error) / np.sum(weights))
            print(f"Validation WMAE: {weighted_mae:.2f}")

        print(f"Validation MAE: {mae:.2f}")
        print(f"Validation RMSE: {rmse:.2f}")

        return {
            'mae': mae,
            'rmse': rmse,
            'wmae': weighted_mae
        }

    except Exception as e:
        # Deliberately best-effort: validation must not abort the pipeline
        print(f"Validation failed: {e}")
        return None

# ============================================================
# Complete Inference Workflow
# ============================================================

def complete_inference_workflow():
    """Run inference, analysis, clipping, export and reporting in sequence.

    Returns:
        The frame written to ``final_submission.csv``.
    """
    banner = "=" * 60
    print(banner)
    print("WALMART SALES FORECASTING - INFERENCE PIPELINE")
    print(banner)

    # 1. Produce raw predictions (also writes and logs submission.csv)
    raw_submission = main_inference()

    # 2. Print statistics and save the distribution plots
    analyzed_submission = analyze_predictions(raw_submission)

    # 3. Remove negative sales values
    clipped_submission = post_process_predictions(analyzed_submission, min_sales=0)

    # 4. Write the cleaned predictions under their final name
    final_submission = create_kaggle_submission(clipped_submission, 'final_submission.csv')

    # 5. Write the markdown summary
    create_inference_report(final_submission)

    return final_submission

def create_inference_report(submission):
    """Write ``inference_report.md`` summarising the given predictions.

    The report holds fixed model information, summary statistics and simple
    quality counts of ``Weekly_Sales``, the list of generated files, and the
    current timestamp.

    Args:
        submission: Frame with a ``Weekly_Sales`` column.
    """
    sales = submission['Weekly_Sales']

    # Quality counters shown in the "Data Quality Checks" section
    negative_count = (sales < 0).sum()
    zero_count = (sales == 0).sum()
    missing_count = sales.isnull().sum()

    report = f"""
# N-BEATS Model Inference Report
## Walmart Store Sales Forecasting

### Model Information
- **Model Type**: N-BEATS (Neural Basis Expansion Analysis for Time Series)
- **Framework**: PyTorch + NeuralForecast
- **Prediction Horizon**: 53 weeks
- **Input Window**: 52 weeks

### Prediction Summary
- **Total Predictions**: {len(submission):,}
- **Mean Weekly Sales**: ${sales.mean():,.2f}
- **Median Weekly Sales**: ${sales.median():,.2f}
- **Standard Deviation**: ${sales.std():,.2f}
- **Min Prediction**: ${sales.min():,.2f}
- **Max Prediction**: ${sales.max():,.2f}

### Data Quality Checks
- **Negative Predictions**: {negative_count}
- **Zero Predictions**: {zero_count}
- **Missing Values**: {missing_count}

### Files Generated
- `final_submission.csv`: Kaggle submission file
- `prediction_analysis.png`: Prediction distribution plots
- `inference_report.md`: This report

### Notes
- Predictions are based on historical patterns learned by the N-BEATS model
- Model was trained on historical sales data with weekly frequency
- Holiday effects are implicitly captured through the training data
- Post-processing applied to ensure non-negative sales values

Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

    with open('inference_report.md', 'w') as report_file:
        report_file.write(report)

    print("Inference report saved as: inference_report.md")

# ============================================================
# Usage Examples
# ============================================================

if __name__ == "__main__":
    # Execute the whole pipeline, then summarise what was produced
    final_submission = complete_inference_workflow()

    rule = "=" * 60
    print("\n" + rule)
    print("INFERENCE COMPLETED SUCCESSFULLY!")
    print(rule)
    print(f"Final submission shape: {final_submission.shape}")
    print("Files generated:")
    for generated_file in (
        "final_submission.csv",
        "prediction_analysis.png",
        "inference_report.md",
    ):
        print(f"- {generated_file}")
    print("\nUpload 'final_submission.csv' to Kaggle for evaluation.")

# ============================================================
# Additional Utility Functions
# ============================================================

def compare_with_baseline(submission, baseline_file=None):
    """Compare predictions with a baseline, matching rows on ``Id``.

    Args:
        submission: Frame with ``Id`` and ``Weekly_Sales`` columns.
        baseline_file: CSV with ``Id`` and ``Weekly_Sales`` columns. When
            ``None``, the baseline is the mean of ``Weekly_Sales`` in
            ``data/train.csv`` for every row.

    Returns:
        Series ``submission - baseline`` carrying the index of
        ``submission``; NaN where the baseline has no row for an ``Id``.
    """
    if baseline_file is None:
        # Create simple baseline (historical average)
        train_data = pd.read_csv('data/train.csv')
        baseline_mean = train_data['Weekly_Sales'].mean()

        baseline_submission = submission[['Id']].copy()
        baseline_submission['Weekly_Sales'] = baseline_mean

        print(f"Baseline (historical mean): ${baseline_mean:.2f}")
    else:
        baseline_submission = pd.read_csv(baseline_file)

    # Subtracting two Series aligns them on the index, not on Id. A baseline
    # read from a file in a different row order (or a submission whose index
    # was shuffled by sorting) would be compared against the wrong rows, so
    # the baseline value is looked up per Id instead.
    baseline_by_id = (
        baseline_submission.drop_duplicates(subset='Id')
        .set_index('Id')['Weekly_Sales']
    )
    baseline_values = submission['Id'].map(baseline_by_id)

    # Compare predictions
    diff = submission['Weekly_Sales'] - baseline_values

    print("\nComparison with Baseline:")
    print(f"Mean difference: ${diff.mean():.2f}")
    print(f"Predictions higher than baseline: {(diff > 0).sum():,} ({(diff > 0).mean()*100:.1f}%)")
    print(f"Predictions lower than baseline: {(diff < 0).sum():,} ({(diff < 0).mean()*100:.1f}%)")

    return diff

def export_predictions_by_store(submission, test_data_path='data/test.csv'):
    """Summarise the predictions per store and save them as a CSV.

    The test file supplies ``Store``/``Dept`` for every ``Id``. If it has no
    ``Id`` column (the Kaggle ``test.csv`` does not), the key is derived as
    ``<Store>_<Dept>_<YYYY-MM-DD>`` from its ``Store``, ``Dept`` and ``Date``
    columns.

    Args:
        submission: Frame with ``Id`` and ``Weekly_Sales`` columns.
        test_data_path: CSV with the test rows.

    Returns:
        One row per store with ``Predictions_Count``, ``Mean_Sales``,
        ``Std_Sales``, ``Total_Sales`` and ``Dept_Count`` (rounded to two
        decimals); the same table is written to ``predictions_by_store.csv``.
    """
    # Load test data to get store information
    test_data = pd.read_csv(test_data_path)

    # Without this the merge below fails with a KeyError on 'Id'
    if 'Id' not in test_data.columns:
        test_data['Id'] = (
            test_data['Store'].astype(str) + '_'
            + test_data['Dept'].astype(str) + '_'
            + pd.to_datetime(test_data['Date']).dt.strftime('%Y-%m-%d')
        )

    # Merge with predictions; rows without a prediction keep NaN sales
    detailed_predictions = test_data.merge(submission, on='Id', how='left')

    # Group by store
    store_summary = (
        detailed_predictions.groupby('Store')
        .agg(
            Predictions_Count=('Weekly_Sales', 'count'),
            Mean_Sales=('Weekly_Sales', 'mean'),
            Std_Sales=('Weekly_Sales', 'std'),
            Total_Sales=('Weekly_Sales', 'sum'),
            Dept_Count=('Dept', 'nunique'),
        )
        .round(2)
        .reset_index()
    )

    # Save store-level summary
    store_summary.to_csv('predictions_by_store.csv', index=False)
    print("Store-level predictions saved as: predictions_by_store.csv")

    return store_summary