In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from sklearn.metrics import mean_squared_error, mean_absolute_error
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_absolute_error, mean_squared_error

## Functions

In [4]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

def plot_sarimax_forecast(
    forecasts,
    *,
    title='SARIMAX Forecasting for Next 30 Days',
    lower_label='P25-SARIMAX Forecast',
    upper_label='P75-SARIMAX Forecast',
):
    """Plot SARIMAX forecast with observed values, predictions, and interval bounds.

    Args:
        forecasts (pd.DataFrame): DataFrame containing the SARIMAX forecast results,
            indexed by time. Must include the following columns:
            - 'test': Actual observed values (test data).
            - 'predicted_mean': Predicted mean values from the SARIMAX model.
            - 'lower_bound': Lower bound of the forecast interval.
            - 'upper_bound': Upper bound of the forecast interval.
        title (str, keyword-only): Figure title. The default mentions a 30-day
            horizon; pass a title that matches the actual forecast horizon when
            it differs.
        lower_label (str, keyword-only): Legend name of the lower-bound trace.
            The default "P25" label is only accurate when the bounds come from a
            50% interval (e.g. statsmodels ``conf_int(alpha=0.5)``). Adjust it
            for other interval widths.
        upper_label (str, keyword-only): Legend name of the upper-bound trace
            (default "P75"; see ``lower_label``).

    Returns:
        None: The figure is displayed via ``fig.show()``.
    """
    fig = make_subplots()

    fig.add_trace(go.Scatter(
        x=forecasts.index,
        y=forecasts.test,
        name='Test',
        line=dict(dash='solid'),
    ))

    fig.add_trace(go.Scatter(
        x=forecasts.index,
        y=forecasts.predicted_mean,
        mode='lines+markers',
        name='Mean-SARIMAX Forecast',
        marker=dict(symbol='circle', size=6),
        line=dict(dash='dot'),
    ))

    # Both bounds share the same styling; only the source column, the legend
    # name and the marker shape differ. The order (lower, then upper) matches
    # the original trace/legend order.
    bound_traces = (
        ('lower_bound', lower_label, 'triangle-up'),
        ('upper_bound', upper_label, 'triangle-down'),
    )
    for column, label, symbol in bound_traces:
        fig.add_trace(go.Scatter(
            x=forecasts.index,
            y=forecasts[column],
            mode='lines+markers',
            name=label,
            marker=dict(symbol=symbol, size=8),
            line=dict(dash='dot'),
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Value',
        legend_title='Series',
        width=800,
        height=500,
    )

    # Display only. Returning the figure would make Jupyter render it a second
    # time when this call is the last expression of a cell.
    fig.show()

def calculate_sarimax_metrics(test, forecasts):
    """Compute and print the MAE and RMSE of SARIMAX mean predictions.

    Args:
        test (pd.Series): Observed values to compare against.
        forecasts (pd.DataFrame): SARIMAX forecast results; only its
            'predicted_mean' column is used.

    Returns:
        tuple: ``(mae, rmse)`` for the SARIMAX mean predictions.
    """
    predicted = forecasts.predicted_mean

    mae = mean_absolute_error(test, predicted)
    # RMSE is taken as the square root of sklearn's MSE.
    rmse = np.sqrt(mean_squared_error(test, predicted))

    print(f"SARIMAX MAE: {mae:.2f}, SARIMAX RMSE: {rmse:.2f}")
    return mae, rmse