# In [None]:
import ast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from global_graph_settings import set_global_graph_settings  # Add graph settings

# Read the per-school aggregated demographics table into a DataFrame
demographics_df = pd.read_csv(
    filepath_or_buffer="../data/real/standard_school_aggregate_demographic.csv",
)


# Helper function to convert values to lists
def convert_to_list(value):
    """Parse a bracketed string such as ``"[1, 2]"`` into a Python list.

    Values that are not strings wrapped in square brackets, and bracketed
    strings that are not valid Python literals, are returned untouched.
    """
    if not isinstance(value, str):
        return value
    if not value.startswith("[") or not value.endswith("]"):
        return value

    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Looks like a list but is not a literal (e.g. unquoted words): keep the text
        return value
    return parsed


# First pass over cat_lab: cells that are bracketed list literals (e.g. "['a', 'b']")
# become Python lists; every other cell is left exactly as it was read.
demographics_df["cat_lab"] = demographics_df["cat_lab"].apply(convert_to_list)


# Helper function to convert to list and handle multi-word phrases
def convert_to_list_cat_lab(value):
    """Normalise a ``cat_lab`` cell into a list of category labels.

    Args:
        value: A list (returned as is), a string, or anything else.

    Returns:
        * the list itself when ``value`` is already a list;
        * the parsed list when ``value`` is the text of a Python list literal;
        * otherwise, for strings, the comma-separated pieces with surrounding
          whitespace removed;
        * ``[]`` for every non-string, non-list input (e.g. NaN or None).
    """
    if isinstance(value, list):
        return value
    if not isinstance(value, str):
        return []

    # Keep the parse result in its own name: the text may be a valid literal
    # that is not a list (e.g. "7, 8, 9" is a tuple, "11" is an int), and the
    # comma-splitting below needs a string to work on.
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, list):
        return parsed

    # A quoted string literal is unwrapped before splitting; any other
    # non-list literal falls back to the raw cell text.
    text = parsed if isinstance(parsed, str) else value

    multi_word_phrases = [
        "Prefer not to say",
        "Currently unsure",
        "I describe myself another way",
        "Not sure",
        "Gay/lesbian",
        "Heterosexual/straight",
        "I dont know",
        "Non-binary",
    ]

    # Wrap known phrases in {{...}} markers before splitting and strip the
    # markers afterwards. None of the phrases contains a comma, so with a
    # comma delimiter this only matters for text that already has "{{"/"}}".
    for phrase in multi_word_phrases:
        text = text.replace(phrase, f"{{{{{phrase}}}}}")

    return [item.strip().replace("{{", "").replace("}}", "") for item in text.split(",")]


# Second pass over cat_lab: cells that are still strings are split on commas into
# label lists, and cells that are neither strings nor lists become empty lists.
demographics_df["cat_lab"] = demographics_df["cat_lab"].apply(convert_to_list_cat_lab)


# Function to convert percentages to list and handle NaNs properly
def convert_percentage_to_list(value):
    """Convert a cell holding a bracketed number list into rounded integers.

    Args:
        value: Text such as ``"[12.4, nan, 30.6]"``, an actual list of
            numbers, or anything else.

    Returns:
        A list of ints with every NaN replaced by 0. An empty list is returned
        for an empty or blank list string, for text that cannot be parsed
        (a message is printed in that case), and for any input that is
        neither a string nor a list.
    """
    if isinstance(value, list):
        return [0 if pd.isna(x) else round(x) for x in value]
    if not isinstance(value, str):
        return []

    # Tolerate whitespace around the brackets, and treat "[]" / blank text as
    # an empty list rather than as a parse failure (float("") raises).
    inner = value.strip().strip("[]").strip()
    if not inner:
        return []

    try:
        # 'nan' entries become 0 so positions stay aligned with the labels
        numbers = [float(token) for token in inner.replace("nan", "0").split(",")]
        return [0 if pd.isna(x) else round(x) for x in numbers]
    except (ValueError, OverflowError) as e:
        # ValueError: non-numeric token; OverflowError: round() of an infinity
        print(f"Error converting value '{value}': {e}")
        return []


# Turn the bracketed-string cells of both numeric columns into lists of rounded ints
for _numeric_column in ("percentage", "count"):
    demographics_df[_numeric_column] = demographics_df[_numeric_column].apply(convert_percentage_to_list)

# Show the fully parsed table
print(demographics_df)

# Chart palette
LEFT_BAR_COLOUR = "#ea7555"  # Coral, used for the 'Your school' bars
RIGHT_BAR_COLOUR = "#f1b79f"  # Peach, used for the 'Other schools' bars
BACKGROUND_COLOR = "#ffffff"  # White


def _mean_for_category(rows, column, category):
    """Return the mean of ``column``'s entry for ``category`` across ``rows``.

    Each row's entry is read at the position ``category`` occupies in that
    row's own ``cat_lab`` list, so rows that order their categories
    differently, or lack the category, do not distort the result. A row whose
    ``column`` list is too short contributes 0; NaN entries are skipped.
    Returns 0.0 when no row supplies a value.
    """
    values = []
    for labels, series in zip(rows["cat_lab"], rows[column]):
        if category not in labels:
            continue
        position = labels.index(category)
        entry = series[position] if position < len(series) else 0
        if not pd.isna(entry):
            values.append(entry)
    return float(np.mean(values)) if values else 0.0


def _count_or_zero(counts, position):
    """Return ``counts[position]``, or 0 when it is missing or NaN."""
    if position >= len(counts):
        return 0
    count = counts[position]
    return 0 if pd.isna(count) else count


def _annotate_bar(ax, bar, percentage, count) -> None:
    """Write the percentage above ``bar`` and ``n=<count>`` at its base."""
    centre = bar.get_x() + bar.get_width() / 2
    ax.text(
        centre,
        percentage + 2,  # Slightly above the bar so it does not overlap with n
        f"{percentage}%",
        ha="center",
        va="bottom",
        fontsize=8,
    )
    ax.text(centre, 1, f"n={count}", ha="center", va="bottom", fontsize=8, color="black")


def plot_comparison(school_name, df: pd.DataFrame) -> None:
    """Draw one grouped bar chart per measure comparing a school with the rest.

    For every row of ``df`` whose ``School_lab`` equals ``school_name``, the
    categories with a percentage above zero are plotted as two bar groups:
    the school's own percentages ("Your school") and the mean percentage of
    the same category over all other schools' rows for the same
    ``measure_lab`` ("Other schools"). Each bar is labelled with its
    percentage and with ``n=`` (the school's count, or the rounded mean count
    of the other schools). Rows without any positive percentage are skipped.

    Args:
        school_name: Value of ``School_lab`` identifying the school to plot.
        df: Table with ``School_lab``, ``measure_lab``, ``cat_lab``,
            ``percentage`` and ``count`` columns, the last three holding
            lists aligned by position.
    """
    school_data = df[df["School_lab"] == school_name]
    other_schools_data = df[df["School_lab"] != school_name]
    set_global_graph_settings()  # Apply global graph settings

    for _index, row in school_data.iterrows():
        measure = row["measure_lab"]
        cat_labels = row["cat_lab"]
        percentages = row["percentage"]
        n_responses = row["count"]

        # Keep only categories that have a label and a positive percentage
        valid_indices = [i for i, p in enumerate(percentages) if p > 0 and i < len(cat_labels)]
        if not valid_indices:
            continue

        valid_cat_labels = [cat_labels[i] for i in valid_indices]
        valid_percentages = [percentages[i] for i in valid_indices]
        # One count per plotted category: a missing/NaN count becomes 0 instead
        # of being dropped, which would shift the later labels onto wrong bars.
        actual_counts = [_count_or_zero(n_responses, i) for i in valid_indices]

        # Compare against the other schools' rows for this measure only, so a
        # label shared by several measures does not mix unrelated data.
        same_measure_others = other_schools_data[other_schools_data["measure_lab"] == measure]
        other_percentages = [
            round(_mean_for_category(same_measure_others, "percentage", cat)) for cat in valid_cat_labels
        ]
        other_counts = [round(_mean_for_category(same_measure_others, "count", cat)) for cat in valid_cat_labels]

        x = np.arange(len(valid_cat_labels))  # X positions for categories
        width = 0.35  # Bar width (fixed size)

        fig, ax = plt.subplots(figsize=(10, 6))
        fig.subplots_adjust(top=1)

        bars1 = ax.bar(x - width / 2, valid_percentages, width, label="Your school", color=LEFT_BAR_COLOUR)
        bars2 = ax.bar(x + width / 2, other_percentages, width, label="Other schools", color=RIGHT_BAR_COLOUR)

        for bar, percentage, count in zip(bars1, valid_percentages, actual_counts):
            _annotate_bar(ax, bar, percentage, count)
        for bar, percentage, count in zip(bars2, other_percentages, other_counts):
            _annotate_bar(ax, bar, percentage, count)

        # Labeling and formatting
        ax.set_xlabel("Responses")
        ax.set_ylabel("Percentage of pupils")
        ax.set_ylim(0, 100)  # Set Y-axis limit to 100
        ax.set_title(f"{measure}")
        fig.suptitle(school_name, fontsize=10, x=0.5, y=1.05)  # Temporary label for school
        ax.set_xticks(x)
        ax.set_xticklabels(valid_cat_labels, rotation=-45, ha="left")

        # Solid horizontal gridlines
        ax.grid(visible=True, which="both", axis="y", color="#CCCCCC", zorder=0)

        # Remove spines (frame)
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Place the legend outside the axes, to the right
        ax.legend(bbox_to_anchor=(1, 1), loc="upper left")

        plt.tight_layout()
        plt.show()
        # Release the figure so plotting many schools does not pile up open figures
        plt.close(fig)


# Loop through all schools and generate comparisons
all_school_names = demographics_df["School_lab"].unique()
for school_name in all_school_names:
    # One set of comparison charts per distinct school
    plot_comparison(school_name, demographics_df)
