# In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from logging import getLogger

from relex.data_processing import prepare_data_for_modeling
from relex.model.definition import define_model
from relex.model.training import train_model
from relex.model.logging import log_inferred_parameters
from relex.model.forcasting import forecast_model
from relex.model.visualization import (
    plot_forecast,
    plot_components_with_forecast,
)
from relex.data_io import load_data_in_wide_format
from relex.model.decomposition import (
    decompose_historical_time_series,
    decompose_forecast_time_series,
)
from relex.model.evaluation import evaluate

# Configure logger
logger = getLogger(__name__)

# Enable eager execution before anything else touches TensorFlow state.
# TensorFlow documents that this call must happen at program startup, before
# any graphs, ops, or tensors are created; seeding first can touch the default
# graph under TF1-style graph mode and make this call fail.
tf.compat.v1.enable_eager_execution()

# Set random seeds for reproducibility (NumPy and TensorFlow)
np.random.seed(42)
tf.random.set_seed(42)

# Plot styling shared by every figure produced below
sns.set_context("notebook", font_scale=1.0)
sns.set_style("whitegrid")

# Per-item pipeline: prepare data, train, forecast, decompose, plot, evaluate
def analyze_item(item_num, merged_df):
    """Run the complete modeling pipeline for a single item.

    The steps run in this order: prepare the item's data, define and train
    the model, draw parameter samples, forecast, plot the forecast, decompose
    the historical and forecast series into components, plot the components,
    and evaluate the forecast.

    Args:
        item_num: Item identifier. Forwarded to ``prepare_data_for_modeling``
            and ``train_model``; also shown in the printed banner and in the
            forecast plot title.
        merged_df: Data handed to ``prepare_data_for_modeling`` as ``df``.

    Returns:
        dict: The artifacts of the run under the keys ``"model"``,
        ``"forecast_mean"``, ``"forecast_scale"``, ``"parameters"``,
        ``"component_means"``, ``"component_stddevs"``,
        ``"forecast_component_means"``, ``"forecast_component_stddevs"`` and
        ``"metrics"`` (itself a dict with ``"rmse"`` and ``"mape"``).
    """
    print(f"\n--- Analyzing Item {item_num} ---")

    # Prepared data is read below via the keys "train_sales", "train_prices",
    # "dates", "test_dates" and "full_sales".
    prepared = prepare_data_for_modeling(item_num, df=merged_df)

    # Build and fit the model on the training portion.
    model = define_model(prepared["train_sales"], prepared["train_prices"])
    posteriors = train_model(item_model=model, data=prepared, item_num=item_num)

    # Draw 50 parameter samples from the object returned by training and log
    # the inferred parameters.
    samples = posteriors.sample(50)
    log_inferred_parameters(item_model=model, parameter_samples=samples)

    forecast = forecast_model(
        item_model=model, data=prepared, parameter_samples=samples
    )
    forecast_dist, forecast_mean, forecast_scale, forecast_samples = forecast

    # Get the relevant date indices.
    all_dates = pd.DatetimeIndex(prepared["dates"])
    test_dates = pd.DatetimeIndex(prepared["test_dates"])

    forecast_title = f"Sales Forecast for Item {item_num}"
    plot_forecast(
        all_dates,
        prepared["full_sales"],
        forecast_mean,
        forecast_scale,
        forecast_samples,
        title=forecast_title,
    )

    # Decompose the historical series and the forecast into components.
    historical_parts = decompose_historical_time_series(model, prepared, samples)
    component_means, component_stddevs = historical_parts

    forecast_parts = decompose_forecast_time_series(model, forecast_dist, samples)
    forecast_component_means, forecast_component_stddevs = forecast_parts

    # The leading len(train_sales) entries of the full date index are used as
    # the historical (training) dates.
    n_train = len(prepared["train_sales"])
    history_dates = all_dates[:n_train]

    # Plot components with forecasts; the returned pair is not used further.
    _fig, _rest = plot_components_with_forecast(
        dates=all_dates,
        train_dates=history_dates,
        forecast_dates=test_dates,
        component_means_dict=component_means,
        component_stddevs_dict=component_stddevs,
        forecast_component_means=forecast_component_means,
        forecast_component_stddevs=forecast_component_stddevs,
    )

    rmse, mape = evaluate(
        prepared, forecast_mean, component_means, forecast_component_means
    )
    metrics = {"rmse": rmse, "mape": mape}

    return {
        "model": model,
        "forecast_mean": forecast_mean,
        "forecast_scale": forecast_scale,
        "parameters": samples,
        "component_means": component_means,
        "component_stddevs": component_stddevs,
        "forecast_component_means": forecast_component_means,
        "forecast_component_stddevs": forecast_component_stddevs,
        "metrics": metrics,
    }


if __name__ == "__main__":

    # Load price and sales data once; every item analysis reuses it.
    merged_df = load_data_in_wide_format(
        price_data_path="../data/average_price.csv",
        sales_data_path="../data/category_sales.csv",
    )

    # Run the pipeline for items 1 through 4, keyed as "item_<n>".
    results = {}
    for item_num in range(1, 5):
        results[f"item_{item_num}"] = analyze_item(item_num, merged_df)