In [10]:
import os
import warnings
from typing import Any

import numpy as np
import pandas as pd
import polars as pl

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings: print arrays with 4 digits of precision
np.set_printoptions(precision=4)

# Pandas settings: effectively remove row/column truncation when displaying frames
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings: show long strings, all columns and up to 200 rows
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

# NOTE(review): this silences every warning for the rest of the session,
# including deprecation warnings from the libraries used below.
warnings.filterwarnings("ignore")

# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
from rich.console import Console
from rich.theme import Theme

# Named styles for rich output; later cells pass these names via `style=` (e.g. style="info").
custom_theme = Theme(
    {
        "white": "#FFFFFF",  # Bright white
        "info": "#00FF00",  # Bright green
        "warning": "#FFD700",  # Bright gold
        "error": "#FF1493",  # Deep pink
        "success": "#00FFFF",  # Cyan
        "highlight": "#FF4500",  # Orange-red
    }
)
# Shared console used by later cells for styled printing.
console = Console(theme=custom_theme)


def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """Put an ancestor of the current working directory at the front of ``sys.path``.

    The ancestor is resolved to an absolute path, inserted at index 0 of
    ``sys.path`` (so modules found there win on import) and printed.

    Params:
    -------
    go_up: int, default=1
        Number of directory levels to climb from the current working directory.
        ``0`` refers to the working directory itself.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``go_up`` is negative.
    """
    import os
    import sys

    if go_up < 0:
        raise ValueError(f"go_up must be >= 0, got {go_up}")

    # Anchor explicitly on the working directory. The previous implementation
    # used os.path.dirname(__name__), which treats a module name as a path and
    # only worked because the result is always the empty string.
    target = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * go_up)))

    # Notebook cells get re-run: drop earlier copies so the path appears once,
    # then (re)insert it at the front to keep its import priority.
    while target in sys.path:
        sys.path.remove(target)
    sys.path.insert(0, target)
    print(target)

In [3]:
# Add the parent of the working directory to sys.path; presumably this is what
# makes the `include` package imported by later cells resolvable.
go_up_from_current_directory(go_up=1)

/Users/mac/Desktop/Projects/End-to-end-Sale-Forecasting


In [4]:
import httpx

# Quick connectivity check against a public placeholder API.
url: str = "https://jsonplaceholder.typicode.com/posts"

# Explicit timeout so the cell cannot hang indefinitely on a dead connection.
response = httpx.get(url, timeout=10)
response.raise_for_status()  # Raise an error for bad responses
# Show only the first three items of the JSON payload to keep the output short.
console.print(response.json()[:3], style="info")

In [5]:
from include.config import app_settings
from include.utilities.data_gen import RealisticSalesDataGenerator

# Generate sales data for the given date range (seed is fixed, presumably for reproducibility).
gen_data = RealisticSalesDataGenerator(start_date="2025-01-31", end_date="2025-09-30", seed=123)
# `file_paths` maps a dataset name (e.g. "sales") to the parquet paths written under `output_dir`.
file_paths: dict[str, Any] = gen_data.generate_sales_data(output_dir="./data/sales_data")
file_paths

2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-01-31
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-01
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-02
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-03
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-04
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-05
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-06
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-07
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-08
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-09
2025-09-05 15:27:15 - include.utilities.data_gen - [INFO] - Generating data for 2025-02-10

{'sales': ['./data/sales_data/sales/year=2025/month=02/day=01/sales_2025-02-01.parquet',
  './data/sales_data/sales/year=2025/month=02/day=02/sales_2025-02-02.parquet',
  './data/sales_data/sales/year=2025/month=02/day=03/sales_2025-02-03.parquet',
  './data/sales_data/sales/year=2025/month=02/day=08/sales_2025-02-08.parquet',
  './data/sales_data/sales/year=2025/month=02/day=09/sales_2025-02-09.parquet',
  './data/sales_data/sales/year=2025/month=02/day=15/sales_2025-02-15.parquet',
  './data/sales_data/sales/year=2025/month=02/day=16/sales_2025-02-16.parquet',
  './data/sales_data/sales/year=2025/month=02/day=17/sales_2025-02-17.parquet',
  './data/sales_data/sales/year=2025/month=02/day=22/sales_2025-02-22.parquet',
  './data/sales_data/sales/year=2025/month=02/day=23/sales_2025-02-23.parquet',
  './data/sales_data/sales/year=2025/month=03/day=01/sales_2025-03-01.parquet',
  './data/sales_data/sales/year=2025/month=03/day=03/sales_2025-03-03.parquet',
  './data/sales_data/sales/year

In [6]:
# Total number of generated files, summed over every dataset entry in `file_paths`.
total_files = sum(len(paths) for paths in file_paths.values())
total_files

440

In [None]:
# Load the generated parquet files with Polars, build the per-store daily
# training frame, and split it into train / validation / test sets.
import mlflow
from polars.dataframe.frame import DataFrame

from include.ml.trainer import ModelTrainer

print("Loading sales data from multiple files...")
sales_dfs: list[pl.DataFrame] = []
max_files: int = 50  # cap on how many sales files are read
skipped_sales: int = 0

for i, sales_file in enumerate(file_paths["sales"][:max_files]):
    try:
        df = pl.read_parquet(sales_file)
        sales_dfs.append(df)
    except Exception as e:
        # Best effort: one unreadable file should not abort the whole load.
        skipped_sales += 1
        print(f"  Skipping unreadable sales file {sales_file}: {e}")
        continue
    if (i + 1) % 10 == 0:
        print(f"  Loaded {i + 1} files...")
if not sales_dfs:
    raise ValueError("No readable sales parquet files were loaded; aborting training")
if skipped_sales:
    print(f"  Skipped {skipped_sales} unreadable sales file(s)")

sales_df = pl.concat(sales_dfs)
print(f"Combined sales data shape: {sales_df.shape}")

# Collapse transactions to one row per (date, store, product, category).
daily_sales: DataFrame = (
    sales_df.group_by(["date", "store_id", "product_id", "category"])
    .agg(
        pl.col("quantity_sold").sum(),
        pl.col("revenue").sum().alias("sales"),
        pl.col("cost").sum(),
        pl.col("profit").sum(),
        pl.col("discount_percent").mean(),
        pl.col("unit_price").mean(),
    )
    .sort("date", "store_id")
)

# Promotions (optional): flag (date, product) pairs that had a promotion.
# Only the first promotions file is read.
if file_paths.get("promotions"):
    try:
        promo_df = pl.read_parquet(file_paths["promotions"][0])
        promo_summary = (
            promo_df.group_by(["date", "product_id"])
            .agg(pl.col("discount_percent").max())
            .with_columns(pl.lit(1).cast(pl.Int8).alias("has_promotion"))
        )
        daily_sales = daily_sales.join(
            promo_summary.select(["date", "product_id", "has_promotion"]),
            on=["date", "product_id"],
            how="left",
        ).with_columns(pl.col("has_promotion").fill_null(0))
    except Exception as e:
        print(f"Skipping promotions merge due to error: {e}")

# The store-level aggregation below always reads `has_promotion`. If the merge
# was skipped (no files, or an error), fall back to "no promotion" — the same
# value unmatched rows receive from fill_null(0) above — instead of crashing.
if "has_promotion" not in daily_sales.columns:
    print("No promotions data merged; defaulting has_promotion to 0")
    daily_sales = daily_sales.with_columns(pl.lit(0).cast(pl.Int8).alias("has_promotion"))

# Customer traffic (optional): only the first 10 traffic files are read.
if file_paths.get("customer_traffic"):
    traffic_dfs: list[pl.DataFrame] = []
    skipped_traffic: int = 0

    for traffic_file in file_paths["customer_traffic"][:10]:
        try:
            traffic_dfs.append(pl.read_parquet(traffic_file))
        except Exception as e:
            skipped_traffic += 1
            print(f"  Skipping unreadable traffic file {traffic_file}: {e}")
    if skipped_traffic:
        print(f"  Skipped {skipped_traffic} unreadable traffic file(s)")

    if traffic_dfs:
        traffic_df = pl.concat(traffic_dfs)
        traffic_summary = traffic_df.group_by(["date", "store_id"]).agg(
            pl.col("customer_traffic").sum(), pl.col("is_holiday").max()
        )
        daily_sales = daily_sales.join(
            traffic_summary,
            on=["date", "store_id"],
            how="left",
        )
    else:
        print("No readable traffic files; skipping merge")

# Same guard for the traffic columns. Nulls mirror what the left join yields
# for rows without a traffic match.
# NOTE(review): Float64 is a guess at a dtype the trainer accepts — confirm.
missing_traffic_cols = [
    name for name in ("customer_traffic", "is_holiday") if name not in daily_sales.columns
]
if missing_traffic_cols:
    print(f"No traffic data merged; filling {missing_traffic_cols} with nulls")
    daily_sales = daily_sales.with_columns(
        [pl.lit(None, dtype=pl.Float64).alias(name) for name in missing_traffic_cols]
    )

print(f"Final training data shape: {daily_sales.shape}")
print(f"Columns: {daily_sales.columns}")

trainer = ModelTrainer()
# Roll product-level rows up to one row per (date, store) for the store-level model.
store_daily_sales: DataFrame = (
    daily_sales.group_by(["date", "store_id"])
    .agg(
        pl.col("sales").sum(),
        pl.col("quantity_sold").sum(),
        pl.col("profit").sum(),
        pl.col("has_promotion").mean(),
        pl.col("customer_traffic").first(),
        pl.col("is_holiday").first(),
    )
    .with_columns(pl.col("date").cast(pl.Date))
)
train_df, val_df, test_df = trainer.prepare_data(
    store_daily_sales,
    target_col="sales",
    group_cols=["store_id"],
    categorical_cols=["store_id"],
)
print(f"Train shape: {train_df.shape}, Val shape: {val_df.shape}, Test shape: {test_df.shape}")

Loading sales data from multiple files...
  Loaded 10 files...
  Loaded 20 files...
  Loaded 30 files...
  Loaded 40 files...
  Loaded 50 files...
Combined sales data shape: (174, 10)
Final training data shape: (174, 13)
Columns: ['date', 'store_id', 'product_id', 'category', 'quantity_sold', 'sales', 'cost', 'profit', 'discount_percent', 'unit_price', 'has_promotion', 'customer_traffic', 'is_holiday']
2025-09-05 15:56:26 - include.utilities.feature_engineering - [INFO] - Starting feature engineering pipeline
2025-09-05 15:56:26 - include.utilities.feature_engineering - [INFO] - Created 7 lag features
2025-09-05 15:56:26 - include.utilities.feature_engineering - [INFO] - Feature engineering pipeline completed. 41 total features.
2025-09-05 15:56:26 - include.ml.trainer - [INFO] - Data split - {"train_size": 60, "validation_size": 8, "test_size": 18}
Train shape: (60, 41), Val shape: (8, 41), Test shape: (18, 41)


In [18]:
# Close any MLflow run left open by a previous execution of this cell.
mlflow.end_run()

results = trainer.train_all_models(train_df, val_df, test_df, target_col="sales")
for model_name, model_results in results.items():
    # Entries without a "metrics" key are skipped rather than treated as errors.
    if "metrics" in model_results:
        print(f"\n{model_name} metrics:")
        for metric, value in model_results["metrics"].items():
            print(f"  {metric}: {value:.4f}")
print("\nVisualization charts have been generated and saved to MLflow/MinIO")
print("Charts include:")
print("  - Model metrics comparison")
print("  - Predictions vs actual values")
print("  - Residuals analysis")
print("  - Error distribution")
print("  - Feature importance comparison")

# Keep only the metrics of each model for the printed summary
# (this block was previously duplicated verbatim).
serializable_results: dict[str, dict[str, Any]] = {
    model_name: {"metrics": model_results.get("metrics", {})}
    for model_name, model_results in results.items()
}

current_run = trainer.mlflow_manager.get_run_id()
final_results: dict[str, Any] = {
    "training_results": serializable_results,
    "mlflow_run_id": current_run,
}
console.print(final_results, style="info")

2025-09-05 15:56:26 - include.utilities.mlflow_utils - [INFO] - Started MLflow run: ba4fdb57e96742fa8000536ea248c5aa
2025-09-05 15:56:26 - include.ml.trainer - [INFO] - Training XGBoost model
[0]	validation_0-rmse:62.68329
[1]	validation_0-rmse:55.46352
[2]	validation_0-rmse:51.54056
[3]	validation_0-rmse:45.53690
[4]	validation_0-rmse:42.28120
[5]	validation_0-rmse:39.15963
[6]	validation_0-rmse:36.29725
[7]	validation_0-rmse:33.63027
[8]	validation_0-rmse:30.93414
[9]	validation_0-rmse:28.67686
[10]	validation_0-rmse:26.89091
[11]	validation_0-rmse:25.05359
[12]	validation_0-rmse:23.38517
[13]	validation_0-rmse:21.83391
[14]	validation_0-rmse:20.36703
[15]	validation_0-rmse:19.08786
[16]	validation_0-rmse:17.95905
[17]	validation_0-rmse:16.84690
[18]	validation_0-rmse:15.87859
[19]	validation_0-rmse:14.94017
[20]	validation_0-rmse:14.10110
[21]	validation_0-rmse:13.41335
[22]	validation_0-rmse:12.75128
[23]	validation_0-rmse:12.13010
[24]	validation_0-rmse:11.57988
[25]	validation_0-

In [None]:
# Small toy frame used by the exploratory cell that follows.
# NOTE(review): this rebinds the global name `df`, which the sales-loading loop
# in an earlier cell also assigns.
df: pl.DataFrame = pl.DataFrame(
    data={
        "id": [1, 2, 3, 4],
        "name": ["Alice", "Bob", "Charlie", "Bob"],
        "role": ["Engineer", "Manager", "Engineer", "Manager"],
        "skill": ["Python", "Leadership", "Python", "Management"],
        "experience": [5, 2, 3, 3],
        "age": [30, 40, 35, 34],
        "target": [1, 0, 1, 1],
    }
)

df

In [None]:
# Frequency of each name and the mean target per name.
counts = df["name"].value_counts()
mean_target = df.group_by("name").agg(pl.col("target").mean())
display(mean_target)
display(counts["name"])
# value_counts() already holds one row per distinct name, so read the counts
# straight off its "count" column instead of re-filtering the frame per name.
for name_count in counts["count"]:
    print(name_count)

# Single lookup: number of rows whose name is "Alice".
counts.filter(pl.col("name").eq("Alice"))["count"].item()

### Connect to MLflow

- Set the `tracking URI` to the MLflow server.
    - The tracking URI is built from the MLflow `server address` and `port`; artifact storage additionally requires the `S3 endpoint URL` and `S3 credentials`.
    - The S3 credentials consist of the `access key`, the `secret key`, and the `bucket name`.
    - `MinIO` is used as a local S3-compatible storage service.

- Verify the connection by listing experiments.

In [None]:
# Force localhost configuration and debug
RUNNING_IN_DOCKER = False
# NOTE(review): this default looks inverted — "minio" is chosen when NOT running
# in Docker, yet the cell title says "force localhost". The value is only
# printed below (the endpoint comes from app_settings), so it is left unchanged;
# confirm the intended host.
DEFAULT_MINIO_HOST = app_settings.AWS_S3_HOST if RUNNING_IN_DOCKER else "minio"
DEFAULT_MINIO_PORT = app_settings.AWS_S3_PORT
MINIO_ENDPOINT = app_settings.mlflow_s3_endpoint_url
# This connects to the MLflow server with PostgreSQL backend
MLFLOW_URI = app_settings.mlflow_tracking_uri
AWS_KEY = app_settings.AWS_ACCESS_KEY_ID
AWS_SECRET = app_settings.AWS_SECRET_ACCESS_KEY.get_secret_value()
AWS_REGION = app_settings.AWS_DEFAULT_REGION
BUCKET = app_settings.AWS_S3_BUCKET

# Set environment variables so boto3 and the MLflow artifact client pick up the
# credentials and the S3-compatible endpoint.
os.environ["AWS_ACCESS_KEY_ID"] = AWS_KEY
os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET
os.environ["AWS_DEFAULT_REGION"] = AWS_REGION
os.environ["MLFLOW_S3_ENDPOINT_URL"] = MINIO_ENDPOINT

# Notebook outputs are saved and shared with the file: never print the full key.
masked_access_key = f"{AWS_KEY[:4]}****" if AWS_KEY else "<unset>"

print("=== CONFIGURATION DEBUG ===")
print(f"RUNNING_IN_DOCKER: {RUNNING_IN_DOCKER}")
print(f"DEFAULT_MINIO_HOST: {DEFAULT_MINIO_HOST}")
print(f"MINIO_ENDPOINT: {MINIO_ENDPOINT}")
print(f"MLFLOW_URI: {MLFLOW_URI}")
print(f"AWS_ACCESS_KEY_ID: {masked_access_key}")
print(f"BUCKET: {BUCKET}")
# Read the value back from the environment so this line verifies what was set.
print(f"Environment MLFLOW_S3_ENDPOINT_URL: {os.environ['MLFLOW_S3_ENDPOINT_URL']}")
print("=== END CONFIGURATION DEBUG ===\n")

In [None]:
# Test MLflow server connection and S3 storage
import tempfile
import traceback

import boto3
import mlflow
from botocore.exceptions import BotoCoreError, ClientError

# 1) Test S3/MinIO connection
print("Testing S3/MinIO connection...")
s3 = boto3.client(
    "s3",
    endpoint_url=MINIO_ENDPOINT,
    aws_access_key_id=AWS_KEY,
    aws_secret_access_key=AWS_SECRET,
    region_name=AWS_REGION,
)

try:
    s3.head_bucket(Bucket=BUCKET)
    print(f"✅ Bucket '{BUCKET}' is reachable")
except (ClientError, BotoCoreError) as e:
    # ClientError covers error responses from the service (missing bucket,
    # access denied). An unreachable endpoint or missing credentials raise
    # BotoCoreError subclasses instead, which would otherwise crash this
    # diagnostic cell instead of being reported.
    print(f"❌ S3/MinIO connection failed: {e}")

# 2) Test MLflow server connection
print(f"\nTesting MLflow server connection to {MLFLOW_URI}...")
mlflow.set_tracking_uri(MLFLOW_URI)
print(f"✅ MLflow tracking URI set to: {mlflow.get_tracking_uri()}")

# 3) Test that MLflow uses PostgreSQL backend (not local files)
# NOTE(review): a successful search only shows the tracking server answered;
# which backend store it uses is decided by the server's own configuration.
try:
    # This should connect to the MLflow server which uses PostgreSQL
    experiments = mlflow.search_experiments()
    print(f"✅ Connected to MLflow server. Found {len(experiments)} experiments.")
    print("✅ This confirms MLflow is using the PostgreSQL backend, not local files.")
except Exception as e:
    print(f"❌ Failed to connect to MLflow server: {e}")

print("\n" + "=" * 50)
print("IMPORTANT: If MLflow server is using PostgreSQL correctly,")
print("experiments and runs will be stored in the database,")
print("and artifacts will be stored in MinIO/S3.")
print("Local 'mlruns' folders should NOT be created.")
print("=" * 50)

In [None]:
# End-to-end MLflow smoke test: fit a tiny model, then log params, a metric,
# an artifact file and the model itself.
# os / tempfile / traceback are imported here as well so the cell does not
# depend on another cell having been run first.
import os
import tempfile
import traceback

import mlflow
import mlflow.sklearn
from botocore.exceptions import ClientError
from sklearn import datasets
from sklearn.linear_model import ElasticNet

try:
    mlflow.set_experiment("notebook_quick_test")
    X, y = datasets.load_diabetes(return_X_y=True)
    model = ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=42)
    model.fit(X, y)

    with mlflow.start_run() as run:
        mlflow.log_param("alpha", 0.1)
        mlflow.log_param("l1_ratio", 0.5)
        mlflow.log_metric("dummy_score", model.score(X, y))

        # Create a small artifact file and upload
        with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
            tmp.write("mlflow artifact test")
            tmp_path = tmp.name

        try:
            mlflow.log_artifact(tmp_path, artifact_path="test_artifacts")
        finally:
            # delete=False means nothing else cleans this file up; remove it
            # even when the upload fails so repeated runs do not leak files.
            os.remove(tmp_path)

        mlflow.sklearn.log_model(model, "model", input_example=X[:2].tolist())

        print("✅ Logged run id:", run.info.run_id)
        print("✅ Experiment id:", run.info.experiment_id)

    print("✅ MLflow logging complete — check the UI and MinIO for artifact/model.")
    print("✅ Data stored in PostgreSQL database, artifacts in MinIO S3")

except ClientError as e:
    # boto3 ClientError can surface during artifact upload
    print("❌ Boto3 ClientError during MLflow operations:", e)
    print(traceback.format_exc())
    raise
except Exception:
    print("❌ Unexpected error during MLflow logging:")
    print(traceback.format_exc())
    raise

In [None]:
def create_cyclical_features(df: pl.DataFrame, date_col: str = "date") -> pl.DataFrame:
    """Append sine/cosine encodings of month, day of month and day of week.

    Parameters
    ----------
    df : pl.DataFrame
        Input frame; it is not modified.
    date_col : str, default="date"
        Name of the column the ``.dt`` accessors are applied to.

    Returns
    -------
    pl.DataFrame
        A new frame with the original columns plus ``month_sin``, ``month_cos``,
        ``day_sin``, ``day_cos``, ``day_of_week_sin`` and ``day_of_week_cos``.

    Notes
    -----
    The angles are computed with native Polars expressions rather than
    per-element Python lambdas (``map_elements``), which keeps the work
    vectorised and avoids the UDF inefficiency warning. ``with_columns``
    already returns a new frame, so no explicit ``clone()`` is needed.
    """
    date = pl.col(date_col)

    def _angle(component: pl.Expr, period: int, offset: int = 0) -> pl.Expr:
        # Same arithmetic order as the scalar formula 2 * pi * (x - offset) / period.
        # Cast first so the small integer dtypes returned by .dt never matter.
        return 2 * np.pi * (component.cast(pl.Float64) - offset) / period

    # month (convert 1-12 to 0-11 for proper cyclical encoding)
    month_angle = _angle(date.dt.month(), 12, offset=1)
    # day (Retain original values; 1-31)
    day_angle = _angle(date.dt.day(), 31)
    # day of week (convert 1-7 to 0-6 for proper cyclical encoding)
    weekday_angle = _angle(date.dt.weekday(), 7, offset=1)

    return df.with_columns(
        month_angle.sin().alias("month_sin"),
        month_angle.cos().alias("month_cos"),
        day_angle.sin().alias("day_sin"),
        day_angle.cos().alias("day_cos"),
        weekday_angle.sin().alias("day_of_week_sin"),
        weekday_angle.cos().alias("day_of_week_cos"),
    )


# NOTE(review): `temp_df` is not defined in any cell visible above; this call
# relies on kernel state and will fail on a fresh Restart & Run All unless it
# is created elsewhere — confirm.
create_cyclical_features(temp_df, date_col="date")

In [None]:
import os
from typing import Dict, Any, Optional, Union

import numpy as np
import pandas as pd
import polars as pl
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from scipy import stats
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import plotly.io as pio

# Set default template
pio.templates.default = "plotly_white"


class PlotlyModelVisualizer:
    """Create comprehensive visualizations for model comparison and analysis using Plotly"""

    def __init__(self) -> None:
        """Set up the colour palette used by every chart.

        ``self.colors`` maps a lower-case series name (one of the model names
        or ``"actual"``) to a hex colour string.
        """
        palette = dict(
            xgboost="#FF6B6B",
            lightgbm="#A7D7D4",
            prophet="#99D145",
            ensemble="#198050",
            actual="#17222E",
        )
        self.colors = palette

    def create_performance_dashboard(
        self, metrics_dict: Dict[str, Dict[str, float]], save_path: Optional[str] = None
    ) -> go.Figure:
        """Create a performance dashboard combining metrics and rankings.

        Layout (3 x 3 grid):
          * row 1: RMSE, MAE, MAPE bars (one bar per model);
          * row 2: R² bars, then the overall model ranking spanning two columns;
          * row 3: grouped bars of normalized performance spanning the full width.

        Args:
            metrics_dict: Mapping of model name to its metrics. The keys
                ``"rmse"``, ``"mae"``, ``"mape"`` and ``"r2"`` are read; a
                missing metric is plotted as 0.
            save_path: If given, the figure is also written to this path as HTML.

        Returns:
            The assembled Plotly figure.

        Raises:
            ValueError: If ``metrics_dict`` is empty.
        """
        # Fail early with a clear message; previously an empty dict surfaced as
        # an opaque "min() arg is an empty sequence" ValueError further down.
        if not metrics_dict:
            raise ValueError("metrics_dict must contain at least one model")

        models = list(metrics_dict.keys())

        # Create subplot structure.
        # subplot_titles are assigned in row-major order to the non-empty cells,
        # so the number of non-None specs must match the number of titles. The
        # unused (2, 3) cell is folded into the ranking panel via colspan;
        # otherwise "Normalized Performance" would label that empty cell and
        # the bottom chart would have no title.
        fig = make_subplots(
            rows=3, cols=3,
            subplot_titles=[
                "RMSE", "MAE", "MAPE (%)", "R² Score",
                "Overall Model Ranking", "Normalized Performance"
            ],
            specs=[
                [{"type": "bar"}, {"type": "bar"}, {"type": "bar"}],
                [{"type": "bar"}, {"type": "bar", "colspan": 2}, None],
                [{"colspan": 3, "type": "bar"}, None, None]
            ],
            vertical_spacing=0.12,
            horizontal_spacing=0.08
        )

        # Individual metric plots: (metric key, title, lower is better, row, col)
        metrics_info = [
            ("rmse", "RMSE", True, 1, 1),
            ("mae", "MAE", True, 1, 2),
            ("mape", "MAPE (%)", True, 1, 3),
            ("r2", "R² Score", False, 2, 1),
        ]

        for metric, title, lower_better, row, col in metrics_info:
            values = [metrics_dict[model].get(metric, 0) for model in models]
            colors = [self.colors.get(model.lower(), "#95A5A6") for model in models]

            # Highlight best model with a green outline
            best_idx = values.index(min(values) if lower_better else max(values))
            edge_colors = ['green' if i == best_idx else 'rgba(0,0,0,0)' for i in range(len(models))]
            edge_widths = [3 if i == best_idx else 0 for i in range(len(models))]

            fig.add_trace(
                go.Bar(
                    x=models,
                    y=values,
                    marker=dict(
                        color=colors,
                        line=dict(color=edge_colors, width=edge_widths)
                    ),
                    text=[f"{v:.3f}" for v in values],
                    textposition="outside",
                    showlegend=False,
                    name=title
                ),
                row=row, col=col
            )

        # Model ranking (row 2, columns 2-3); the helper yields (model, score) pairs
        ranking_data = self._calculate_model_ranking(metrics_dict)
        models_sorted = [x[0] for x in ranking_data]
        scores_sorted = [x[1] for x in ranking_data]
        colors_sorted = [self.colors.get(model.lower(), "#95A5A6") for model in models_sorted]

        fig.add_trace(
            go.Bar(
                y=models_sorted,
                x=scores_sorted,
                orientation='h',
                marker=dict(color=colors_sorted),
                text=[f"{score:.2f}" for score in scores_sorted],
                textposition="outside",
                showlegend=False,
                name="Model Ranking"
            ),
            row=2, col=2
        )

        # Normalized performance comparison (bottom)
        norm_data = self._calculate_normalized_performance(metrics_dict)

        # Create grouped bar chart for normalized performance
        metrics = ["rmse", "mae", "mape", "r2"]
        colors_metrics = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

        for i, metric in enumerate(metrics):
            fig.add_trace(
                go.Bar(
                    x=models,
                    y=norm_data[metric],
                    name=metric.upper(),
                    marker=dict(color=colors_metrics[i]),
                    text=[f"{v:.2f}" for v in norm_data[metric]],
                    textposition="outside",
                    offsetgroup=i,
                    showlegend=True
                ),
                row=3, col=1
            )

        # Update layout
        fig.update_layout(
            height=900,
            title_text="Model Performance Dashboard",
            title_x=0.5,
            title_font_size=18,
            showlegend=True,
            legend=dict(x=0.7, y=0.3)
        )

        # Update axes
        fig.update_xaxes(title_text="Models", row=3, col=1)
        fig.update_yaxes(title_text="Normalized Performance", row=3, col=1)
        fig.update_xaxes(title_text="Average Rank", row=2, col=2)
        fig.update_yaxes(title_text="Models", row=2, col=2)

        if save_path:
            fig.write_html(save_path)

        return fig

    def create_prediction_quality_analysis(
        self,
        predictions_dict: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
        actual_data: Union[pd.DataFrame, pl.DataFrame],
        date_col: str = "date",
        target_col: str = "sales",
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """Create prediction quality analysis combining scatter plots and time series.

        The figure has two rows and one column per model: the top row is an
        actual-vs-predicted scatter with a dashed identity line and an R²
        annotation, the bottom row overlays actual and predicted values over
        time (plus a shaded band when bounds are available).

        Args:
            predictions_dict: Mapping of model name to a frame holding
                ``date_col`` and a ``"prediction"`` column. Optional
                ``"prediction_lower"`` / ``"prediction_upper"`` columns are
                drawn as a band. Polars frames are converted to pandas.
            actual_data: Frame holding ``date_col`` and ``target_col``; a
                Polars frame is converted to pandas.
            date_col: Name of the column used to join predictions to actuals.
            target_col: Name of the column holding the actual values.
            save_path: If given, the figure is also written to this path as HTML.

        Returns:
            The assembled Plotly figure.
        """

        # Convert to pandas if needed
        if isinstance(actual_data, pl.DataFrame):
            actual_data = actual_data.to_pandas()

        predictions_dict = {
            k: (v.to_pandas() if isinstance(v, pl.DataFrame) else v)
            for k, v in predictions_dict.items()
        }

        n_models = len(predictions_dict)

        # Create subplots: titles are listed row by row (all scatter titles,
        # then all time-series titles) to match the 2 x n_models grid.
        fig = make_subplots(
            rows=2, cols=n_models,
            subplot_titles=[f"{model} - Actual vs Predicted" for model in predictions_dict.keys()] +
                          [f"{model} - Time Series" for model in predictions_dict.keys()],
            vertical_spacing=0.12
        )

        for idx, (model_name, pred_df) in enumerate(predictions_dict.items()):
            col = idx + 1  # subplot columns are 1-based

            # Merge data: keep only dates present in both actuals and predictions
            merged = pd.merge(
                actual_data[[date_col, target_col]],
                pred_df[[date_col, "prediction"]],
                on=date_col, how="inner"
            )

            # Unknown model names fall back to a neutral grey
            color = self.colors.get(model_name.lower(), "#95A5A6")

            # Top row: Actual vs Predicted scatter
            # Shared range over both series, used as end points of the identity line
            min_val = min(merged[target_col].min(), merged["prediction"].min())
            max_val = max(merged[target_col].max(), merged["prediction"].max())

            # Scatter plot
            fig.add_trace(
                go.Scatter(
                    x=merged[target_col],
                    y=merged["prediction"],
                    mode='markers',
                    marker=dict(color=color, opacity=0.6, size=4),
                    showlegend=False,
                    name=f"{model_name} Predictions"
                ),
                row=1, col=col
            )

            # Perfect prediction line (y = x)
            fig.add_trace(
                go.Scatter(
                    x=[min_val, max_val],
                    y=[min_val, max_val],
                    mode='lines',
                    line=dict(color='red', dash='dash'),
                    showlegend=False,
                    name="Perfect Prediction"
                ),
                row=1, col=col
            )

            # Calculate R²
            r2 = r2_score(merged[target_col], merged["prediction"])

            # Add R² annotation; "domain" references place it in axis-fraction
            # coordinates (near the top-left corner of the panel)
            fig.add_annotation(
                x=0.05, y=0.95,
                text=f"R² = {r2:.3f}",
                showarrow=False,
                xref=f"x{col if col > 1 else ''} domain",
                yref=f"y{col if col > 1 else ''} domain",
                bgcolor="white",
                bordercolor="black",
                row=1, col=col
            )

            # Bottom row: Time series comparison
            # Down-sample to 500 evenly spaced rows to keep the figure light
            if len(merged) > 500:
                sample_idx = np.linspace(0, len(merged) - 1, 500, dtype=int)
                plot_data = merged.iloc[sample_idx]
            else:
                plot_data = merged

            # Actual values (legend entry only for the first column)
            fig.add_trace(
                go.Scatter(
                    x=plot_data[date_col],
                    y=plot_data[target_col],
                    mode='lines',
                    line=dict(color=self.colors["actual"], width=2),
                    name="Actual" if col == 1 else None,
                    showlegend=(col == 1),
                    legendgroup="actual"
                ),
                row=2, col=col
            )

            # Predictions (legend entry only for the first column)
            fig.add_trace(
                go.Scatter(
                    x=plot_data[date_col],
                    y=plot_data["prediction"],
                    mode='lines',
                    line=dict(color=color, width=2),
                    name=f"{model_name} Prediction" if col == 1 else None,
                    showlegend=(col == 1),
                    legendgroup=model_name
                ),
                row=2, col=col
            )

            # Add confidence intervals if available
            if "prediction_lower" in pred_df.columns and "prediction_upper" in pred_df.columns:
                plot_data_conf = pd.merge(plot_data, pred_df[[date_col, "prediction_lower", "prediction_upper"]], on=date_col, how="left")

                # Invisible upper-bound line; it must be added immediately
                # before the lower bound because 'tonexty' fills towards the
                # previously added trace.
                fig.add_trace(
                    go.Scatter(
                        x=plot_data_conf[date_col],
                        y=plot_data_conf["prediction_upper"],
                        mode='lines',
                        line=dict(width=0),
                        showlegend=False,
                        hoverinfo='skip'
                    ),
                    row=2, col=col
                )

                # Invisible lower-bound line, filled up to the upper bound with
                # the model colour at 20% opacity
                fig.add_trace(
                    go.Scatter(
                        x=plot_data_conf[date_col],
                        y=plot_data_conf["prediction_lower"],
                        mode='lines',
                        line=dict(width=0),
                        fill='tonexty',
                        fillcolor=f'rgba{tuple(list(px.colors.hex_to_rgb(color)) + [0.2])}',
                        showlegend=False,
                        hoverinfo='skip'
                    ),
                    row=2, col=col
                )

        # Update layout
        fig.update_layout(
            height=600,
            title_text="Prediction Quality Analysis",
            title_x=0.5,
            showlegend=True
        )

        # Update axes labels
        for col in range(1, n_models + 1):
            fig.update_xaxes(title_text=f"Actual {target_col.title()}", row=1, col=col)
            fig.update_yaxes(title_text="Predicted", row=1, col=col)
            fig.update_xaxes(title_text="Date", row=2, col=col)
            fig.update_yaxes(title_text=target_col.title(), row=2, col=col)

        if save_path:
            fig.write_html(save_path)

        return fig

    def create_residuals_diagnostic_panel(
        self,
        predictions_dict: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
        actual_data: Union[pd.DataFrame, pl.DataFrame],
        target_col: str = "sales",
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """Create comprehensive residuals diagnostic panel."""
        
        if isinstance(actual_data, pl.DataFrame):
            actual_data = actual_data.to_pandas()

        predictions_dict = {
            k: (v.to_pandas() if isinstance(v, pl.DataFrame) else v) 
            for k, v in predictions_dict.items()
        }

        # Calculate residuals
        residuals_data = {}
        for model_name, pred_df in predictions_dict.items():
            merged = pd.merge(actual_data[["date", target_col]], pred_df[["date", "prediction"]], on="date", how="inner")
            residuals = merged[target_col] - merged["prediction"]
            residuals_data[model_name] = {
                "residuals": residuals,
                "predictions": merged["prediction"],
                "actual": merged[target_col],
                "dates": merged["date"],
            }

        # Create subplots
        fig = make_subplots(
            rows=2, cols=3,
            subplot_titles=[
                "Residuals Distribution", "Residuals vs Fitted Values", "Scale-Location Plot",
                "Residuals Over Time", "Normal Q-Q Plot", "Residuals Distribution (Box)"
            ],
            specs=[
                [{"type": "histogram"}, {"type": "scatter"}, {"type": "scatter"}],
                [{"type": "scatter"}, {"type": "scatter"}, {"type": "box"}]
            ],
            vertical_spacing=0.12
        )

        # 1. Residuals distribution comparison
        for model_name, data in residuals_data.items():
            residuals = data["residuals"].dropna()
            if len(residuals) > 0:
                fig.add_trace(
                    go.Histogram(
                        x=residuals,
                        nbinsx=30,
                        opacity=0.6,
                        name=model_name,
                        marker=dict(color=self.colors.get(model_name.lower(), "#95A5A6")),
                        histnorm='probability density',
                        showlegend=True,
                        legendgroup=model_name
                    ),
                    row=1, col=1
                )

        # Add vertical line at x=0
        fig.add_vline(x=0, line_dash="dash", line_color="red", opacity=0.7, row=1, col=1)

        # 2. Residuals vs Fitted
        for model_name, data in residuals_data.items():
            fig.add_trace(
                go.Scatter(
                    x=data["predictions"],
                    y=data["residuals"],
                    mode='markers',
                    marker=dict(color=self.colors.get(model_name.lower(), "#95A5A6"), opacity=0.6, size=4),
                    name=model_name,
                    showlegend=False,
                    legendgroup=model_name
                ),
                row=1, col=2
            )

        # Add horizontal line at y=0
        fig.add_hline(y=0, line_dash="dash", line_color="red", opacity=0.7, row=1, col=2)

        # 3. Scale-Location plot
        for model_name, data in residuals_data.items():
            residuals = data["residuals"]
            predictions = data["predictions"]
            std_residuals = np.sqrt(np.abs(residuals / residuals.std()))
            
            fig.add_trace(
                go.Scatter(
                    x=predictions,
                    y=std_residuals,
                    mode='markers',
                    marker=dict(color=self.colors.get(model_name.lower(), "#95A5A6"), opacity=0.6, size=4),
                    name=model_name,
                    showlegend=False,
                    legendgroup=model_name
                ),
                row=1, col=3
            )

        # 4. Residuals over time
        for model_name, data in residuals_data.items():
            fig.add_trace(
                go.Scatter(
                    x=data["dates"],
                    y=data["residuals"],
                    mode='lines',
                    line=dict(color=self.colors.get(model_name.lower(), "#95A5A6"), width=1),
                    name=model_name,
                    showlegend=False,
                    legendgroup=model_name
                ),
                row=2, col=1
            )

        # Add horizontal line at y=0
        fig.add_hline(y=0, line_dash="dash", line_color="red", opacity=0.7, row=2, col=1)

        # 5. Q-Q Plot (Normal probability plot)
        for model_name, data in residuals_data.items():
            residuals = data["residuals"].dropna().values
            if len(residuals) > 0:
                # Calculate Q-Q plot points
                sorted_residuals = np.sort(residuals)
                n = len(sorted_residuals)
                theoretical_quantiles = stats.norm.ppf(np.arange(1, n + 1) / (n + 1))
                
                fig.add_trace(
                    go.Scatter(
                        x=theoretical_quantiles,
                        y=sorted_residuals,
                        mode='markers',
                        marker=dict(color=self.colors.get(model_name.lower(), "#95A5A6"), size=4),
                        name=model_name,
                        showlegend=False,
                        legendgroup=model_name
                    ),
                    row=2, col=2
                )
                
                # Add reference line
                if model_name == list(residuals_data.keys())[0]:  # Add only once
                    slope, intercept = np.polyfit(theoretical_quantiles, sorted_residuals, 1)
                    line_y = slope * theoretical_quantiles + intercept
                    fig.add_trace(
                        go.Scatter(
                            x=theoretical_quantiles,
                            y=line_y,
                            mode='lines',
                            line=dict(color='red', dash='dash'),
                            name="Reference Line",
                            showlegend=False
                        ),
                        row=2, col=2
                    )

        # 6. Residuals boxplot comparison
        for model_name, data in residuals_data.items():
            residuals = data["residuals"].dropna().values
            if len(residuals) > 0:
                fig.add_trace(
                    go.Box(
                        y=residuals,
                        name=model_name,
                        marker=dict(color=self.colors.get(model_name.lower(), "#95A5A6")),
                        showlegend=False
                    ),
                    row=2, col=3
                )

        # Add horizontal line at y=0
        fig.add_hline(y=0, line_dash="dash", line_color="red", opacity=0.7, row=2, col=3)

        # Update layout
        fig.update_layout(
            height=800,
            title_text="Residuals Diagnostic Panel",
            title_x=0.5,
            showlegend=True
        )
        
        # Update axes labels
        fig.update_xaxes(title_text="Residuals", row=1, col=1)
        fig.update_yaxes(title_text="Density", row=1, col=1)
        fig.update_xaxes(title_text="Fitted Values", row=1, col=2)
        fig.update_yaxes(title_text="Residuals", row=1, col=2)
        fig.update_xaxes(title_text="Fitted Values", row=1, col=3)
        fig.update_yaxes(title_text="√|Standardized Residuals|", row=1, col=3)
        fig.update_xaxes(title_text="Date", row=2, col=1)
        fig.update_yaxes(title_text="Residuals", row=2, col=1)
        fig.update_xaxes(title_text="Theoretical Quantiles", row=2, col=2)
        fig.update_yaxes(title_text="Sample Quantiles", row=2, col=2)
        fig.update_yaxes(title_text="Residuals", row=2, col=3)

        if save_path:
            fig.write_html(save_path)

        return fig

    def create_feature_importance_comparison(
        self,
        feature_importance_dict: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
        top_n: int = 15,
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """Build a two-row figure comparing feature importances across models.

        Top row: one horizontal bar chart per model showing its ``top_n``
        features, followed by a "Feature Consistency" bar chart. Bottom row: a
        heatmap spanning the full width with the normalised importances.

        Parameters
        ----------
        feature_importance_dict : Dict[str, Union[pd.DataFrame, pl.DataFrame]]
            Model name -> table with ``feature`` and ``importance`` columns.
            Polars frames are converted to pandas first.
        top_n : int, default=15
            Number of highest-importance features kept per model.
        save_path : Optional[str], default=None
            When truthy, the figure is also written to this path as HTML.

        Returns
        -------
        go.Figure
            The assembled figure.
        """
        # Bring every table to pandas so the rest of the method has one code path.
        frames = {}
        for model_name, table in feature_importance_dict.items():
            frames[model_name] = table.to_pandas() if isinstance(table, pl.DataFrame) else table

        model_count = len(frames)
        last_col = model_count + 1  # column holding the consistency chart

        # Row 1: one bar cell per model plus the consistency chart.
        # Row 2: a single heatmap cell stretched over every column.
        top_row = [{"type": "bar"} for _ in range(last_col)]
        bottom_row = [{"colspan": last_col, "type": "heatmap"}] + [None] * model_count

        fig = make_subplots(
            rows=2,
            cols=last_col,
            subplot_titles=[*frames.keys(), "Feature Consistency", "Feature Importance Heatmap"],
            specs=[top_row, bottom_row],
            vertical_spacing=0.15,
            horizontal_spacing=0.05,
        )

        # Per-model importance bars (top row, columns 1..model_count).
        for col, (model_name, importance_df) in enumerate(frames.items(), start=1):
            leaders = importance_df.nlargest(top_n, "importance")
            value_labels = [f"{v:.3f}" for v in leaders["importance"]]
            model_bar = go.Bar(
                y=leaders["feature"],
                x=leaders["importance"],
                orientation='h',
                marker={"color": self.colors.get(model_name.lower(), "#95A5A6")},
                text=value_labels,
                textposition="outside",
                showlegend=False,
                name=f"{model_name} Features",
            )
            fig.add_trace(model_bar, row=1, col=col)

        # Consistency chart (top row, last column); skipped when the helper
        # returns an empty mapping.
        shared_counts = self._calculate_feature_consistency(frames, top_n)
        if shared_counts:
            consistency_bar = go.Bar(
                y=list(shared_counts.keys()),
                x=list(shared_counts.values()),
                orientation='h',
                marker={"color": "#4A90E2"},
                showlegend=False,
                name="Feature Consistency",
            )
            fig.add_trace(consistency_bar, row=1, col=last_col)

        # Importance heatmap (bottom row); the helper returns None when there
        # is nothing to draw.
        heatmap_payload = self._create_importance_heatmap_data(frames, top_n)
        if heatmap_payload is not None:
            matrix, feature_labels, model_labels = heatmap_payload
            heatmap = go.Heatmap(
                z=matrix,
                x=model_labels,
                y=feature_labels,
                colorscale='RdYlBu_r',
                showscale=True,
                colorbar={"title": "Normalized Importance"},
            )
            fig.add_trace(heatmap, row=2, col=1)

        fig.update_layout(
            height=900,
            title_text=f"Feature Importance Analysis - Top {top_n} Features",
            title_x=0.5,
            showlegend=False,
        )

        if save_path:
            fig.write_html(save_path)

        return fig

    def create_model_stability_analysis(
        self,
        predictions_dict: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
        actual_data: Union[pd.DataFrame, pl.DataFrame],
        target_col: str = "sales",
        window_size: int = 30,
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """Create model stability analysis showing performance over time windows.

        For each model the predictions are inner-joined with the actuals on
        ``date``, sorted by date, and MAE / RMSE / R² are computed over every
        full window of ``window_size`` consecutive rows. Each window is
        labelled with the date of its *last* row, and the final window ends on
        the last available observation.

        Parameters
        ----------
        predictions_dict : Dict[str, Union[pd.DataFrame, pl.DataFrame]]
            Model name -> frame with ``date`` and ``prediction`` columns.
        actual_data : Union[pd.DataFrame, pl.DataFrame]
            Frame with ``date`` and ``target_col`` columns.
        target_col : str, default="sales"
            Name of the ground-truth column in ``actual_data``.
        window_size : int, default=30
            Number of rows per rolling window; must be at least 1.
        save_path : Optional[str], default=None
            When truthy, the figure is also written to this path as HTML.

        Returns
        -------
        go.Figure
            2x2 figure: rolling MAE, rolling RMSE, rolling R² and a stability
            ranking bar chart.

        Raises
        ------
        ValueError
            If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        if isinstance(actual_data, pl.DataFrame):
            actual_data = actual_data.to_pandas()

        predictions_dict = {
            k: (v.to_pandas() if isinstance(v, pl.DataFrame) else v)
            for k, v in predictions_dict.items()
        }

        # Calculate rolling metrics
        stability_data = {}
        for model_name, pred_df in predictions_dict.items():
            merged = pd.merge(actual_data[["date", target_col]], pred_df[["date", "prediction"]], on="date", how="inner")
            merged = merged.sort_values("date").reset_index(drop=True)

            rolling_mae = []
            rolling_rmse = []
            rolling_r2 = []
            dates = []

            # `end` is the exclusive upper bound of the window, so the window
            # covers rows [end - window_size, end). Iterating up to
            # len(merged) inclusive makes the last window contain the final
            # observation, and the label is the date of row `end - 1`, the
            # last row that actually contributes to the metric.
            for end in range(window_size, len(merged) + 1):
                window_actual = merged[target_col].iloc[end - window_size : end]
                window_pred = merged["prediction"].iloc[end - window_size : end]

                rolling_mae.append(mean_absolute_error(window_actual, window_pred))
                rolling_rmse.append(np.sqrt(mean_squared_error(window_actual, window_pred)))
                rolling_r2.append(r2_score(window_actual, window_pred))
                dates.append(merged["date"].iloc[end - 1])

            stability_data[model_name] = {
                "dates": dates,
                "mae": rolling_mae,
                "rmse": rolling_rmse,
                "r2": rolling_r2,
            }

        # Create subplots
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=["Rolling MAE", "Rolling RMSE", "Rolling R²", "Model Stability Ranking"],
            vertical_spacing=0.12
        )

        # Plot rolling metrics: (metric key, subplot row, subplot col)
        metric_cells = [("mae", 1, 1), ("rmse", 1, 2), ("r2", 2, 1)]

        for metric, row, col in metric_cells:
            for model_name, data in stability_data.items():
                fig.add_trace(
                    go.Scatter(
                        x=data["dates"],
                        y=data[metric],
                        mode='lines',
                        line=dict(color=self.colors.get(model_name.lower(), "#95A5A6"), width=2),
                        name=model_name,
                        # One legend entry per model: only the first panel's
                        # traces appear in the legend, the rest share the group.
                        showlegend=(row == 1 and col == 1),
                        legendgroup=model_name
                    ),
                    row=row, col=col
                )

        # Model stability ranking (bottom right)
        stability_ranking = self._calculate_stability_ranking(stability_data)
        models_sorted = [x[0] for x in stability_ranking]
        scores_sorted = [x[1] for x in stability_ranking]
        colors_sorted = [self.colors.get(model.lower(), "#95A5A6") for model in models_sorted]

        fig.add_trace(
            go.Bar(
                y=models_sorted,
                x=scores_sorted,
                orientation='h',
                marker=dict(color=colors_sorted),
                text=[f"{score:.3f}" for score in scores_sorted],
                textposition="outside",
                showlegend=False,
                name="Stability Ranking"
            ),
            row=2, col=2
        )

        # Update layout
        fig.update_layout(
            height=700,
            title_text=f"Model Stability Analysis (Rolling Window: {window_size})",
            title_x=0.5,
            showlegend=True
        )

        # Update axes
        fig.update_xaxes(title_text="Date", row=1, col=1)
        fig.update_xaxes(title_text="Date", row=1, col=2)
        fig.update_xaxes(title_text="Date", row=2, col=1)
        fig.update_xaxes(title_text="Coefficient of Variation", row=2, col=2)
        fig.update_yaxes(title_text="MAE", row=1, col=1)
        fig.update_yaxes(title_text="RMSE", row=1, col=2)
        fig.update_yaxes(title_text="R²", row=2, col=1)
        fig.update_yaxes(title_text="Models", row=2, col=2)

        if save_path:
            fig.write_html(save_path)

        return fig

    def create_comprehensive_report(
        self,
        metrics_dict: Dict[str, Dict[str, float]],
        predictions_dict: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
        actual_data: Union[pd.DataFrame, pl.DataFrame],
        feature_importance_dict: Optional[Dict[str, pd.DataFrame]] = None,
        save_dir: str = "/tmp/model_comparison_charts",
        window_size: int = 30,
    ) -> Dict[str, str]:
        """Render every comparison chart and write each one to ``save_dir`` as HTML.

        Parameters
        ----------
        metrics_dict : Dict[str, Dict[str, float]]
            Model name -> metric name -> value; passed to the performance dashboard.
        predictions_dict : Dict[str, Union[pd.DataFrame, pl.DataFrame]]
            Model name -> predictions frame; Polars frames are converted to pandas.
        actual_data : Union[pd.DataFrame, pl.DataFrame]
            Ground-truth frame; a Polars frame is converted to pandas.
        feature_importance_dict : Optional[Dict[str, pd.DataFrame]], default=None
            When truthy, the feature-importance chart is produced as well.
        save_dir : str, default="/tmp/model_comparison_charts"
            Output directory; created if it does not exist.
        window_size : int, default=30
            Rolling window for the stability chart, which is only produced
            when ``actual_data`` has more than ``2 * window_size`` rows.

        Returns
        -------
        Dict[str, str]
            Chart key -> path of the HTML file that was written.
        """
        actual_pd = actual_data.to_pandas() if isinstance(actual_data, pl.DataFrame) else actual_data

        predictions_pd = {}
        for model_name, frame in predictions_dict.items():
            predictions_pd[model_name] = frame.to_pandas() if isinstance(frame, pl.DataFrame) else frame

        os.makedirs(save_dir, exist_ok=True)
        saved_files = {}

        def _export(chart_key, figure):
            # Every chart is stored as "<chart_key>.html" inside save_dir.
            target = os.path.join(save_dir, f"{chart_key}.html")
            figure.write_html(target)
            saved_files[chart_key] = target

        # 1. Performance Dashboard
        _export("performance_dashboard", self.create_performance_dashboard(metrics_dict))

        # 2. Prediction Quality Analysis
        _export(
            "prediction_quality_analysis",
            self.create_prediction_quality_analysis(predictions_pd, actual_pd),
        )

        # 3. Residuals Diagnostic Panel
        _export(
            "residuals_diagnostic_panel",
            self.create_residuals_diagnostic_panel(predictions_pd, actual_pd),
        )

        # 4. Model Stability Analysis, only with more than two windows of data
        has_enough_rows = len(actual_pd) > window_size * 2
        if has_enough_rows:
            _export(
                "model_stability_analysis",
                self.create_model_stability_analysis(predictions_pd, actual_pd, window_size=window_size),
            )

        # 5. Feature Importance, only when importances were supplied
        if feature_importance_dict:
            _export(
                "feature_importance_comparison",
                self.create_feature_importance_comparison(feature_importance_dict),
            )

        return saved_files

    # Helper methods
    def _calculate_model_ranking(self, metrics_dict: Dict[str, Dict[str, float]]) -> list:
        """Calculate model ranking based on average rank across metrics."""
        models = list(metrics_dict.keys())
        ranking_scores = {}

        for model in models:
            total_rank = 0

            # For metrics where lower is better (RMSE, MAE, MAPE)
            for metric in ["rmse", "mae", "mape"]:
                values = [metrics_dict[m].get(metric, float("inf")) for m in models]
                sorted_values = sorted(values)
                model_value = metrics_dict[model].get(metric, float("inf"))
                rank = sorted_values.index(model_value) + 1
                total_rank += rank

            # For R² - higher is better
            r2_values = [metrics_dict[m].get("r2", -float("inf")) for m in models]
            r2_sorted_desc = sorted(r2_values, reverse=True)
            model_r2 = metrics_dict[model].get("r2", -float("inf"))
            r2_rank = r2_sorted_desc.index(model_r2) + 1
            total_rank += r2_rank

            # Average rank across all 4 metrics
            avg_rank = total_rank / 4
            ranking_scores[model] = avg_rank

        # Sort by ranking score (lowest average rank = best)
        return sorted(ranking_scores.items(), key=lambda x: x[1])

    def _calculate_normalized_performance(self, metrics_dict: Dict[str, Dict[str, float]]) -> Dict[str, list]:
        """Calculate normalized performance metrics."""
        models = list(metrics_dict.keys())
        metrics = ["rmse", "mae", "mape", "r2"]
        
        normalized_data = {}
        for metric in metrics:
            values = [metrics_dict[model].get(metric, 0) for model in models]
            if metric == "r2":  # Higher is better
                min_val, max_val = min(values), max(values)
                if max_val != min_val:
                    normalized_values = [0.1 + 0.9 * (v - min_val) / (max_val - min_val) for v in values]
                else:
                    normalized_values = [0.55] * len(values)
            else:  # Lower is better - invert
                min_val, max_val = min(values), max(values)
                if max_val != min_val:
                    normalized_values = [0.1 + 0.9 * (1 - (v - min_val) / (max_val - min_val)) for v in values]
                else:
                    normalized_values = [0.55] * len(values)
            
            normalized_data[metric] = normalized_values

        return normalized_data

    def _calculate_feature_consistency(self, feature_importance_dict: Dict[str, pd.DataFrame], top_n: int) -> Dict[str, int]:
        """Calculate feature consistency across models."""
        all_features = set()
        for df in feature_importance_dict.values():
            all_features.update(df["feature"].tolist())

        feature_counts = {}
        for feature in all_features:
            count = 0
            for df in feature_importance_dict.values():
                top_features = df.nlargest(top_n, "importance")["feature"].tolist()
                if feature in top_features:
                    count += 1
            if count > 1:  # Only show features that appear in multiple models
                feature_counts[feature] = count

        return feature_counts

    def _create_importance_heatmap_data(self, feature_importance_dict: Dict[str, pd.DataFrame], top_n: int):
        """Create data for feature importance heatmap."""
        # Get top features across all models
        all_features = set()
        for df in feature_importance_dict.values():
            top_features = df.nlargest(top_n, "importance")["feature"].tolist()
            all_features.update(top_features)

        if not all_features:
            return None

        # Create importance matrix
        models = list(feature_importance_dict.keys())
        features = list(all_features)
        importance_matrix = np.zeros((len(features), len(models)))

        for j, model in enumerate(models):
            df = feature_importance_dict[model]
            for i, feature in enumerate(features):
                importance_row = df[df["feature"] == feature]
                if not importance_row.empty:
                    importance_matrix[i, j] = importance_row["importance"].iloc[0]

        # Normalize by row (feature) for better visualization
        importance_matrix_norm = importance_matrix / (importance_matrix.max(axis=1, keepdims=True) + 1e-8)

        return importance_matrix_norm, features, models

    def _calculate_stability_ranking(self, stability_data: Dict[str, Dict]) -> list:
        """Calculate model stability ranking based on variance of rolling metrics."""
        stability_scores = {}

        for model_name, data in stability_data.items():
            # Calculate coefficient of variation (std/mean) for each metric
            mae_cv = np.std(data["mae"]) / (np.mean(data["mae"]) + 1e-8)
            rmse_cv = np.std(data["rmse"]) / (np.mean(data["rmse"]) + 1e-8)
            r2_cv = np.std(data["r2"]) / (np.mean(np.abs(data["r2"])) + 1e-8)

            # Average CV (lower is more stable)
            stability_scores[model_name] = (mae_cv + rmse_cv + r2_cv) / 3

        # Sort by stability (lower CV = more stable)
        return sorted(stability_scores.items(), key=lambda x: x[1])


# Example usage and compatibility function
def generate_model_comparison_report_plotly(
    run_id: str, 
    test_data: Union[pd.DataFrame, pl.DataFrame]
) -> Dict[str, str]:
    """Generate comparison report from MLflow run using Plotly.

    Metrics named ``<model>_<metric>`` are read from the run for the models
    ``xgboost``, ``lightgbm`` and ``ensemble``. The charts are written by
    ``PlotlyModelVisualizer.create_comprehensive_report`` and each file is
    then attached to the same MLflow run as an artifact.

    Parameters
    ----------
    run_id : str
        MLflow run ID.
    test_data : Union[pd.DataFrame, pl.DataFrame]
        Test data with ground truth. Must contain ``date`` and ``sales``
        columns; a Polars frame is converted to pandas.

    Returns
    -------
    Dict[str, str]
        Dictionary of saved file paths.

    Notes
    -----
    The predictions used for the charts are synthetic: ground truth plus
    unseeded Gaussian noise, so the prediction-based charts differ between
    calls and do not reflect the real model outputs.
    """
    import mlflow
    
    if isinstance(test_data, pl.DataFrame):
        test_data = test_data.to_pandas()

    visualizer = PlotlyModelVisualizer()

    client = mlflow.tracking.MlflowClient()
    run = client.get_run(run_id)

    # Extract metrics; a model is only included if at least one of its
    # metrics was logged on the run.
    metrics_dict = {}
    for model in ["xgboost", "lightgbm", "ensemble"]:
        model_metrics = {}
        for metric in ["rmse", "mae", "mape", "r2"]:
            metric_key = f"{model}_{metric}"
            if metric_key in run.data.metrics:
                model_metrics[metric] = run.data.metrics[metric_key]
        if model_metrics:
            metrics_dict[model] = model_metrics

    # Generate dummy predictions for visualization
    predictions_dict = {}
    rng = np.random.default_rng()
    for model in metrics_dict.keys():
        pred_df = test_data[["date"]].copy()
        noise = rng.normal(0, 5, len(test_data))
        pred_df["prediction"] = test_data["sales"] + noise
        predictions_dict[model] = pred_df

    # Generate visualizations
    saved_files = visualizer.create_comprehensive_report(metrics_dict, predictions_dict, test_data)

    # Log visualizations to the run identified by `run_id`. The fluent
    # `mlflow.log_artifact` targets the *active* run and silently starts a
    # new one when none is active, so the client API is used to attach the
    # files to the requested run explicitly.
    for name, path in saved_files.items():
        client.log_artifact(run_id, path, f"artifacts/visualizations/{name}")

    return saved_files


# Example of creating sample data for testing
def create_sample_data_for_testing(seed: int = 42):
    """Create sample data for testing the visualizer.

    All random draws come from a local ``numpy.random.Generator`` seeded with
    ``seed``, so the output is reproducible and NumPy's global random state is
    left untouched.

    Parameters
    ----------
    seed : int, default=42
        Seed for the local random generator.

    Returns
    -------
    tuple
        ``(metrics_dict, predictions_dict, actual_data, feature_importance_dict)``:

        - ``metrics_dict``: model name -> ``rmse`` / ``mae`` / ``mape`` / ``r2``.
        - ``predictions_dict``: model name -> frame with ``date`` and
          ``prediction`` columns.
        - ``actual_data``: frame with daily ``date`` values for 2023 and a
          ``sales`` column.
        - ``feature_importance_dict``: model name -> frame with ``feature``
          and ``importance`` columns, the importances summing to 1.
    """
    rng = np.random.default_rng(seed)

    # Create sample actual data: a yearly sine wave around 100 plus noise
    dates = pd.date_range(start='2023-01-01', end='2023-12-31', freq='D')
    n_days = len(dates)
    sales = 100 + 10 * np.sin(2 * np.pi * np.arange(n_days) / 365) + rng.normal(0, 5, n_days)

    actual_data = pd.DataFrame({
        'date': dates,
        'sales': sales
    })

    # Create sample metrics
    metrics_dict = {
        'xgboost': {'rmse': 5.2, 'mae': 4.1, 'mape': 8.5, 'r2': 0.85},
        'lightgbm': {'rmse': 5.8, 'mae': 4.5, 'mape': 9.2, 'r2': 0.82},
        'ensemble': {'rmse': 4.9, 'mae': 3.8, 'mape': 7.8, 'r2': 0.87}
    }

    # Create sample predictions: actuals plus noise scaled by each model's RMSE
    predictions_dict = {}
    for model in metrics_dict.keys():
        pred_df = actual_data[['date']].copy()
        noise_std = metrics_dict[model]['rmse'] * 0.8
        pred_df['prediction'] = actual_data['sales'] + rng.normal(0, noise_std, n_days)
        predictions_dict[model] = pred_df

    # Create sample feature importance
    features = ['feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5', 
               'feature_6', 'feature_7', 'feature_8', 'feature_9', 'feature_10']

    feature_importance_dict = {}
    for model in metrics_dict.keys():
        importance_values = rng.exponential(scale=0.1, size=len(features))
        importance_values = importance_values / importance_values.sum()  # Normalize

        feature_importance_dict[model] = pd.DataFrame({
            'feature': features,
            'importance': importance_values
        })

    return metrics_dict, predictions_dict, actual_data, feature_importance_dict


# Test the visualizer
if __name__ == "__main__":
    # Smoke test: build synthetic inputs, render every chart and list the
    # files that were written. The names bound here are module-level, so they
    # stay available to code that runs afterwards.

    # Create sample data
    metrics_dict, predictions_dict, actual_data, feature_importance_dict = create_sample_data_for_testing()
    
    # Initialize visualizer
    visualizer = PlotlyModelVisualizer()
    
    # Generate all reports; the HTML files go to ./plotly_charts, which is
    # relative to the current working directory.
    saved_files = visualizer.create_comprehensive_report(
        metrics_dict=metrics_dict,
        predictions_dict=predictions_dict,
        actual_data=actual_data,
        feature_importance_dict=feature_importance_dict,
        save_dir="./plotly_charts"
    )
    
    # Report each chart key together with the path of its HTML file
    print("Generated files:")
    for name, path in saved_files.items():
        print(f"  {name}: {path}")
        
    # You can also create individual charts
    # fig1 = visualizer.create_performance_dashboard(metrics_dict)
    # fig1.show()
    
    # fig2 = visualizer.create_prediction_quality_analysis(predictions_dict, actual_data)
    # fig2.show()

## Docker Container Import Testing

When working with the Airflow containers, imports work correctly when you run Python from the right directory.

### ✅ Correct way to import in Airflow containers:

```bash
# Start container shell from the correct directory
docker compose exec airflow-worker bash

# You'll be in /opt/airflow - this is the correct working directory
pwd  # Should show: /opt/airflow

# Now run Python and import
python
```

```python
# These imports will work correctly:
import pandas as pd
from include.config import app_settings
from include.utilities.data_gen import RealisticSalesDataGenerator

# Test the imports
print("All imports successful!")
print("MLFLOW_HOST:", app_settings.MLFLOW_HOST)
gen = RealisticSalesDataGenerator(start_date="2025-09-01", end_date="2025-09-02", seed=42)
print("Data generator created:", type(gen))
```

### ❌ Common mistake - don't do this:

```bash
# Don't cd into the include directory first
cd include  # This breaks imports!
python      # Imports will fail from here
```

### Why this happens:

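1. Our `PYTHONPATH` is set to `/opt/airflow/include`, which puts the modules *inside* `include` on the import path, not the `include` package itself
2. When you start `python`, it also adds the current directory to `sys.path`; from `/opt/airflow` this is what makes `include` importable as a package
3. After `cd include` the current directory is `/opt/airflow/include`, so no entry on `sys.path` contains a folder named `include` and `from include.… import …` fails
4. The solution: always run Python from the `/opt/airflow` directory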

In [None]:
# Test imports in Docker container (run this to verify everything works)
import json
import subprocess


def test_docker_imports():
    """Test that imports work correctly in the Airflow container.

    Runs a short Python script inside the ``airflow-worker`` service via
    ``docker compose exec`` (executed from the parent directory), parses the
    JSON the script prints and echoes the values.

    Returns
    -------
    bool
        ``True`` when the script ran successfully and its output was valid
        JSON. ``False`` when the command could not be started, timed out,
        exited with a non-zero status or printed something that is not a JSON
        object. Failures are reported with ``print``; none of them raise.
    """
    # Upper bound for the whole `docker compose exec` call, so the cell
    # cannot hang forever on an unresponsive Docker daemon.
    timeout_seconds = 120

    # Test command to run in the container
    test_script = """
import sys
import pandas as pd
from include.config import app_settings
from include.utilities.data_gen import RealisticSalesDataGenerator

# Test results
results = {
    "python_path_includes_include": "/opt/airflow/include" in sys.path,
    "current_working_directory": __import__("os").getcwd(),
    "pandas_version": pd.__version__,
    "mlflow_host": app_settings.MLFLOW_HOST,
    "data_generator_created": str(type(RealisticSalesDataGenerator(start_date="2025-09-01", end_date="2025-09-02", seed=42)))
}

import json
print(json.dumps(results, indent=2))
"""

    cmd = [
        "docker",
        "compose",
        "exec",
        "-T",  # no pseudo-TTY: required when stdout is captured
        "airflow-worker",
        "python",
        "-c",
        test_script,
    ]

    # Run the test in the container
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd="../",
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired:
        print(f"❌ Failed to run Docker test: timed out after {timeout_seconds} seconds")
        return False
    except OSError as e:
        # docker executable missing, or the working directory does not exist
        print(f"❌ Failed to run Docker test: {e}")
        return False

    if result.returncode != 0:
        print("❌ Docker container import test FAILED!")
        print("Error output:", result.stderr)
        return False

    try:
        test_results = json.loads(result.stdout.strip())
    except json.JSONDecodeError as e:
        # The script itself succeeded but something else wrote to stdout.
        print(f"❌ Failed to run Docker test: could not parse container output as JSON ({e})")
        print("Raw output:", result.stdout)
        return False

    if not isinstance(test_results, dict):
        print("❌ Failed to run Docker test: container output is not a JSON object")
        print("Raw output:", result.stdout)
        return False

    print("✅ Docker container import test PASSED!")
    print("\nTest Results:")
    for key, value in test_results.items():
        print(f"  {key}: {value}")
    return True


# Run the test. This needs the Docker Compose stack to be running; the call
# returns True on success and False otherwise, and as the last expression of
# the cell that value is shown as the cell output.
test_docker_imports()