In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from logging import getLogger

from relex.data_processing import prepare_data_for_modeling
from relex.model.definition import define_model
from relex.model.training import train_model
from relex.model.logging import log_inferred_parameters
from relex.model.forcasting import forecast_model
from relex.model.visualization import (
    plot_forecast,
    plot_components_with_forecast,
)
from relex.data_io import load_data_in_wide_format
from relex.model.decomposition import (
    decompose_historical_time_series,
    decompose_forecast_time_series,
)
from relex.model.evaluation import evaluate

# Module-level logger. No handlers or level are set on it here, so whether
# its records are shown depends on logging configuration done elsewhere.
logger = getLogger(name=__name__)

# Forecast horizon handed to both data preparation and forecasting
# (as `num_steps_forecast`).
NUM_STEPS_FORECAST = 23

# Load Data
def analyze_item(item_num, merged_df):

    logger.info(f"\n--- Analyzing Item {item_num} ---")

    # Prepare data: train_df and test_df are produced by your helper function.
    historical_data, forecast_data = prepare_data_for_modeling(
        df=merged_df,
        num_steps_forecast=NUM_STEPS_FORECAST
    )

    # Pass both training and forecast data to the model.
    model = define_model(historical_data=historical_data, forecast_data=forecast_data, item_num=item_num)
    variational_posteriors = train_model(
        model=model,
        historical_data=historical_data,
        item_num=item_num,
        sales_col=f"sales_item_{item_num}",
    )

    # Draw samples from the variational posterior
    parameter_samples = variational_posteriors.sample(50)
    log_inferred_parameters(model=model, parameter_samples=parameter_samples)

    forecast_dist, forecast_mean, forecast_scale, forecast_samples = forecast_model(
        item_model=model,
        historical_data=historical_data,
        parameter_samples=parameter_samples,
        num_steps_forecast=NUM_STEPS_FORECAST,
        sales_col=f"sales_item_{item_num}",
    )

    plot_forecast(
        historical_data=historical_data,
        forecast_data=forecast_data,
        forecast_mean=forecast_mean,
        forecast_scale=forecast_scale,
        forecast_samples=forecast_samples,
        sales_col=f"sales_item_{item_num}",
        title=f"Sales Forecast for Item {item_num}",
    )

    # Decompose time series
    historical_component_means, historical_component_stddevs = decompose_historical_time_series(
        model=model,
        historical_data=historical_data,
        parameter_samples=parameter_samples,
        sales_col=f"sales_item_{item_num}",
    )
    forecast_component_means, forecast_component_stddevs = (
        decompose_forecast_time_series(
            model=model,
            forecast_dist=forecast_dist,
            parameter_samples=parameter_samples
        )
    )

    # Plot components with forecasts
    plot_components_with_forecast(
        historical_dates=historical_data.index,
        forecast_dates=forecast_data.index,
        historical_component_means=historical_component_means,
        historical_component_stddevs=historical_component_stddevs,
        forecast_component_means=forecast_component_means,
        forecast_component_stddevs=forecast_component_stddevs,
        item_num=item_num
    )

    # rmse, mape = evaluate(data, forecast_mean, historical_component_means, forecast_component_means)

    return {
        "model": model,
        "forecast_mean": forecast_mean,
        "forecast_scale": forecast_scale,
        "parameters": parameter_samples,
        "component_means": historical_component_means,
        "component_stddevs": historical_component_stddevs,
        "forecast_component_means": forecast_component_means,
        "forecast_component_stddevs": forecast_component_stddevs,
        "metrics": {"rmse": rmse, "mape": mape},
    }


if __name__ == "__main__":
    import logging

    # Make the INFO records emitted via `logger` visible. This is a no-op if
    # the root logger already has handlers (e.g. configured by the host).
    logging.basicConfig(level=logging.INFO)

    # Enable eager mode before touching any TF state. Under a TF1 runtime,
    # enable_eager_execution() raises ValueError once the default graph
    # exists, and tf.random.set_seed creates it in graph mode. Under TF2,
    # eager is already the default and this call has no effect.
    tf.compat.v1.enable_eager_execution()

    # Seed NumPy and TensorFlow for reproducibility.
    np.random.seed(42)
    tf.random.set_seed(42)

    sns.set_context("notebook", font_scale=1.0)
    sns.set_style("whitegrid")

    merged_df = load_data_in_wide_format(
        price_data_path="../data/average_price.csv",
        sales_data_path="../data/category_sales.csv",
    )

    # Kept at module level so later notebook cells can inspect `results`.
    results = {
        f"item_{item_num}": analyze_item(item_num, merged_df)
        for item_num in range(1, 5)
    }