# In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from logging import getLogger

from relex.data_processing import prepare_data_for_modeling
from relex.model.definition import define_model
from relex.model.training import train_model
from relex.model.logging import log_inferred_parameters
from relex.model.forcasting import forecast_model
from relex.model.visualization import (
    plot_forecast,
    plot_components_with_forecast,
)
from relex.data_io import load_data_in_wide_format

# Module-level logger, named after this module.
logger = getLogger(__name__)

# Switch on eager execution before anything else touches TensorFlow state.
# NOTE(review): on TF 1.x this call must happen at program startup, and
# seeding first initialises the default graph, which makes the call fail;
# on TF 2.x eager mode is already the default, so this should be a no-op.
tf.compat.v1.enable_eager_execution()

# Fix the NumPy and TensorFlow seeds so repeated runs are reproducible.
np.random.seed(42)
tf.random.set_seed(42)

# Seaborn styling shared by every plot produced below.
sns.set_context("notebook", font_scale=1.0)
sns.set_style("whitegrid")

# Per-item analysis pipeline (the data itself is loaded in the __main__ block)
def _mean_absolute_percentage_error(actual, error):
    """Return the mean absolute percentage error of ``error`` w.r.t. ``actual``.

    Positions where ``actual`` is zero are left out, because the percentage
    error is undefined there and would otherwise turn the whole metric into
    ``inf``/``nan``. Returns ``nan`` when no non-zero actual value exists.
    """
    actual_arr, error_arr = np.broadcast_arrays(np.asarray(actual), np.asarray(error))
    nonzero = actual_arr != 0
    if not nonzero.any():
        return np.nan
    return np.mean(np.abs(error_arr[nonzero] / actual_arr[nonzero])) * 100


def _variance_percentages(component_means, reference_series):
    """Return each component's variance as a percentage of the reference variance."""
    reference_variance = np.var(reference_series)
    return {
        name: np.var(mean) / reference_variance * 100
        for name, mean in component_means.items()
    }


def _print_percentages(header, percentages):
    """Print ``header`` followed by one ``name: xx.xx%`` line per entry."""
    print(header)
    for name, percentage in percentages.items():
        print(f"{name}: {percentage:.2f}%")


def _summarise_forecast_components(forecast_component_dists):
    """Collect per-component forecast means and stddevs, keyed by component name.

    Also prints the raw array shapes of every component for inspection.
    """
    means = {}
    stddevs = {}

    print("\nComponent dimensions:")
    for component, dist in forecast_component_dists.items():
        mean_val = dist.mean().numpy()
        std_val = dist.stddev().numpy()
        print(
            f"Component {component.name}: mean shape {mean_val.shape}, stddev shape {std_val.shape}"
        )

        # Arrays with more than one dimension are reduced by taking index 0
        # of the last axis.
        # NOTE(review): this presumes any extra axis is a trailing one of
        # size 1 - confirm against the shapes printed above.
        if mean_val.ndim > 1:
            mean_val = mean_val[..., 0]
        if std_val.ndim > 1:
            std_val = std_val[..., 0]

        means[component.name] = mean_val
        stddevs[component.name] = std_val

    return means, stddevs


def analyze_item(item_num, merged_df):
    """Fit, forecast, decompose and evaluate the sales model for one item.

    The function prepares the item's data, builds and trains the model,
    draws 50 samples from the fitted variational posterior, forecasts,
    plots the forecast and the per-component decomposition, and prints
    accuracy metrics and variance shares.

    Args:
        item_num: Item identifier; forwarded to the ``relex`` helpers and
            used in plot titles and printed output.
        merged_df: Data handed to ``prepare_data_for_modeling`` as ``df``.

    Returns:
        A dict with the keys ``model``, ``forecast_mean``, ``forecast_scale``,
        ``parameters``, ``component_means``, ``component_stddevs``,
        ``component_forecast_means``, ``component_forecast_stddevs`` and
        ``metrics`` (itself a dict holding ``rmse`` and ``mape``).
    """
    print(f"\n--- Analyzing Item {item_num} ---")

    # Prepare data
    data = prepare_data_for_modeling(item_num, df=merged_df)
    num_train = len(data["train_sales"])

    # Build & train the model
    item_model = define_model(data["train_sales"], data["train_prices"])
    variational_posteriors = train_model(item_model=item_model, data=data, item_num=item_num)

    # Draw samples from the variational posterior
    parameter_samples = variational_posteriors.sample(50)
    log_inferred_parameters(item_model=item_model, parameter_samples=parameter_samples)

    forecast_dist, forecast_mean, forecast_scale, forecast_samples = forecast_model(
        item_model=item_model, data=data, parameter_samples=parameter_samples)

    # Convert dates to pd.DatetimeIndex for easier handling
    dates_pd = pd.DatetimeIndex(data["dates"])

    # Forecast dates continue the series in fixed 7-day steps from the last
    # training date. A plain 7-day step is used instead of freq="W": the "W"
    # alias is anchored to Sundays, so pandas would move the first forecast
    # date forward to the next Sunday whenever the data uses another weekday.
    one_week = pd.Timedelta(days=7)
    forecast_dates = pd.date_range(
        start=dates_pd[num_train - 1] + one_week,
        periods=data["num_forecast_steps"],
        freq=one_week,
    )

    # Plot the forecast (the returned figure/axes are not needed here)
    plot_forecast(
        dates_pd,
        data["full_sales"],
        forecast_mean,
        forecast_scale,
        forecast_samples,
        title=f"Sales Forecast for Item {item_num}",
    )

    # Decompose the time series into components - only for training data
    print("Decomposing historical time series into components...")
    component_dists = tfp.sts.decompose_by_component(
        item_model,
        observed_time_series=data["train_sales"],
        parameter_samples=parameter_samples,
    )
    component_means = {k.name: c.mean().numpy() for k, c in component_dists.items()}
    component_stddevs = {k.name: c.stddev().numpy() for k, c in component_dists.items()}

    # Now decompose the forecast into components
    print("Decomposing forecast into components...")
    forecast_component_dists = tfp.sts.decompose_forecast_by_component(
        model=item_model,
        forecast_dist=forecast_dist,
        parameter_samples=parameter_samples,
    )
    component_forecast_means, component_forecast_stddevs = _summarise_forecast_components(
        forecast_component_dists
    )

    # Plot components with forecasts
    plot_components_with_forecast(
        dates=dates_pd,  # All dates
        train_dates=dates_pd[:num_train],  # Historical dates
        forecast_dates=forecast_dates,  # Forecast dates
        component_means_dict=component_means,
        component_stddevs_dict=component_stddevs,
        component_forecast_means_dict=component_forecast_means,
        component_forecast_stddevs_dict=component_forecast_stddevs,
    )
    plt.suptitle(f"Component Decomposition for Item {item_num}")
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # leave room for the suptitle
    plt.show()

    # Calculate forecast accuracy metrics
    forecast_error = data["test_sales"] - forecast_mean
    rmse = np.sqrt(np.mean(forecast_error**2))
    mape = _mean_absolute_percentage_error(data["test_sales"], forecast_error)

    print("\nForecast Accuracy Metrics:")
    print(f"RMSE: {rmse:.2f}")
    print(f"MAPE: {mape:.2f}%")

    # Component variance relative to the variance of the training series
    _print_percentages(
        "\nComponent Contribution (% variance explained):",
        _variance_percentages(component_means, data["train_sales"]),
    )

    # Component forecast variance relative to the variance of the forecast mean
    _print_percentages(
        "\nComponent Contribution to Forecast (% variance explained):",
        _variance_percentages(component_forecast_means, forecast_mean),
    )

    return {
        "model": item_model,
        "forecast_mean": forecast_mean,
        "forecast_scale": forecast_scale,
        "parameters": parameter_samples,
        "component_means": component_means,
        "component_stddevs": component_stddevs,
        "component_forecast_means": component_forecast_means,
        "component_forecast_stddevs": component_forecast_stddevs,
        "metrics": {"rmse": rmse, "mape": mape},
    }


if __name__ == "__main__":
    # Script entry point: load the price and sales CSVs once, then run the
    # full analysis for items 1 through 4.
    merged_df = load_data_in_wide_format(
        price_data_path="../data/average_price.csv",
        sales_data_path="../data/category_sales.csv",
    )

    # Results are keyed as "item_<n>" so each analysis can be looked up later.
    results = {}
    for item_num in range(1, 5):
        results[f"item_{item_num}"] = analyze_item(item_num, merged_df)