# In [16]:  (Jupyter cell prompt left over from the notebook export)
import pandas as pd
import plotly.graph_objects as go

# Build one stacked horizontal bar chart per year comparing the MAE of the
# Baseline and SARIMAX models for every branch/quarter combination. Each chart
# is written to an HTML file and also displayed.

# Load the dataset
# NOTE(review): the literal below is a raw string *and* doubles its
# backslashes, so the path really contains repeated separators. It is left
# byte-identical here; confirm it resolves as intended before normalising it.
file_path = r"C:\\Users\\c.hakker\\OneDrive - VISTA college\\Senior Stuff\\Opleiding Data science\\Data\\baseline-mean-errors-80072ned.xlsx"
data = pd.read_excel(file_path)

# Identify the model from the title: anything mentioning "Baseline" is the
# baseline model, everything else is treated as SARIMAX. `na=False` keeps rows
# with a missing title from raising; those rows end up with a missing branch.
is_baseline = data['sbi_title'].str.contains('Baseline', regex=False, na=False)
data['Model'] = is_baseline.map({True: 'Baseline', False: 'SARIMAX'})

# Strip the model prefix so both models share the same branch name.
# `regex=False` makes the literal match explicit regardless of pandas version.
data['Cleaned_Branch'] = (
    data['sbi_title']
    .str.replace("Baseline ", "", regex=False)
    .str.replace("SARIMAX ", "", regex=False)
)

# Convert 'MAE' column to numeric (unparseable values become NaN)
data['MAE'] = pd.to_numeric(data['MAE'], errors='coerce')

# Define custom colors for each branch
sarimax_colors = {'C Manufacturing': '#3f9abf', 'G Trade': '#006789', 'Q Healthcare': '#5fb5db'}
baseline_colors = {'C Manufacturing': '#deaa00', 'G Trade': '#fe9001', 'Q Healthcare': '#febf01'}

# Model name -> branch palette. Insertion order is also the stacking and
# legend order (Baseline first, then SARIMAX).
model_palettes = {'Baseline': baseline_colors, 'SARIMAX': sarimax_colors}

# Loop through years (2022 and 2023)
for year in [2022, 2023]:
    # Filter data for the year
    year_data = data[data['Year'] == year]

    # Aggregate data by averaging duplicate entries for the same quarter, branch, and model
    cleaned_data = (
        year_data
        .groupby(['Cleaned_Branch', 'quarter', 'Model'])
        .agg({'MAE': 'mean'})
        .reset_index()
    )

    # Pivot the data for visualization: one row per branch/quarter, one column per model
    pivot_data = cleaned_data.pivot(
        index=['Cleaned_Branch', 'quarter'], columns='Model', values='MAE'
    ).reset_index()

    # Default a missing MAE to 0. This has to cover both a model column that is
    # absent altogether and individual NaN cells (a branch/quarter that only
    # has one of the two models); otherwise the bar label would read "nan".
    for model in model_palettes:
        if model not in pivot_data.columns:
            pivot_data[model] = 0.0
        pivot_data[model] = pivot_data[model].fillna(0)

    # Create a combined y-axis label. The rows are sorted in *descending*
    # order on purpose: the sorted labels are used as the y-axis category
    # order below, and a categorical y-axis is expected to be laid out from the
    # bottom up, so the chart reads Q1 -> Q4 from top to bottom.
    pivot_data['y_label'] = pivot_data['Cleaned_Branch'] + " Q" + pivot_data['quarter'].astype(str)
    pivot_data = pivot_data.sort_values(by=['Cleaned_Branch', 'quarter'], ascending=[False, False])

    # Create the Plotly figure
    fig = go.Figure()

    # Add one trace per (branch, model) holding all of that branch's quarters.
    # A single trace per series yields exactly one legend entry, and toggling
    # that entry hides/shows every quarter of the series at once.
    # `sort=False` keeps the branch order established by the sort above.
    for branch, branch_rows in pivot_data.groupby('Cleaned_Branch', sort=False):
        quarter_labels = branch_rows['y_label'].tolist()
        for model, palette in model_palettes.items():
            mae_values = branch_rows[model].tolist()
            fig.add_trace(go.Bar(
                y=quarter_labels,
                x=mae_values,
                name=f"{branch} {model} ({year})",
                # Branches without a configured color fall back to Plotly's
                # default color cycle instead of raising KeyError.
                marker_color=palette.get(branch),
                orientation='h',
                text=[f"{value:.2f}" for value in mae_values],
                textposition='inside',
            ))

    # Update the layout
    fig.update_layout(
        title=f"MAE Stacked Insights: Baseline vs SARIMAX Performance Across Quarters ({year})",
        xaxis_title=None,  # Remove x-axis text
        xaxis_showticklabels=False,  # Remove x-axis tick labels
        yaxis_title="Branch and Quarter",
        # Pin the y-axis categories to the sorted label order
        yaxis=dict(categoryorder='array', categoryarray=pivot_data['y_label'].tolist()),
        barmode='stack',
        legend_title="Model",
        template="plotly_white"
    )

    # Save the plot as an HTML file
    output_file = f"Stacked_MAE_Comparison_{year}.html"
    fig.write_html(output_file)

    # Show the plot for the year
    fig.show()
