In [None]:
import ast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Aggregated demographic results; the code below reads the School_lab,
# measure_lab, cat_lab, percentage and count columns from this file.
DATA_PATH = "../data/real/standard_school_aggregate_demographic.csv"
df = pd.read_csv(DATA_PATH)


# Parse bracketed list literals stored as text.
def convert_to_list(value):
    """Evaluate a ``"[...]"`` string with ``ast.literal_eval``.

    Only strings that both start with ``[`` and end with ``]`` are evaluated.
    If evaluation fails with ``ValueError`` or ``SyntaxError``, the original
    string is returned. Every other input (strings without brackets,
    non-strings such as NaN) is returned unchanged.
    """
    looks_like_list = isinstance(value, str) and value.startswith("[") and value.endswith("]")
    if not looks_like_list:
        return value
    try:
        return ast.literal_eval(value)
    except (SyntaxError, ValueError):
        return value


# First pass over cat_lab: bracketed list literals become real Python lists.
df["cat_lab"] = df["cat_lab"].apply(func=convert_to_list)


# Normalise every cat_lab cell into a list of category labels.
def convert_to_list_cat_lab(value):
    """Return the category labels held in a ``cat_lab`` cell as a list.

    - A list is returned unchanged.
    - A string holding a Python list or tuple literal (e.g. ``"['Yes', 'No']"``)
      is parsed with ``ast.literal_eval`` and returned as a list.
    - A string holding a quoted string literal (e.g. ``"'Yes, No'"``) is
      unquoted and then split like plain text.
    - Any other string is split on commas and each piece is stripped, so
      ``"Yes, No, Prefer not to say"`` gives ``["Yes", "No", "Prefer not to say"]``.
      Empty pieces are kept so positions stay aligned with the list-valued
      percentage/count data indexed by position.
    - Any non-string, non-list value (e.g. NaN for a missing cell) gives ``[]``.
    """
    if isinstance(value, list):
        return value
    if not isinstance(value, str):
        return []

    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        parsed = None

    # Keep the parse result in its own name: rebinding ``value`` to a
    # non-string literal (e.g. "5" -> 5, "None" -> None) used to crash the
    # string handling below with AttributeError.
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    text = parsed if isinstance(parsed, str) else value

    # Multi-word answers such as "Prefer not to say" or
    # "I describe myself another way" contain no commas, so a plain comma
    # split already keeps them intact; no placeholder protection is needed.
    return [item.strip() for item in text.split(",")]


# Second pass over cat_lab: every cell ends up as a list (comma-separated text is split).
df["cat_lab"] = df["cat_lab"].apply(func=convert_to_list_cat_lab)


# Convert list-valued numeric cells (percentages and counts) to lists of ints.
def _round_or_zero(x):
    """Round ``x`` with the built-in ``round``; NaN and None become 0."""
    return 0 if pd.isna(x) else round(x)


def convert_percentage_to_list(value):
    """Return the numbers held in a list-valued cell as a list of ints.

    - Strings such as ``"[12.4, nan, 3.6]"`` are split on commas and each token
      is parsed with ``float`` (which accepts ``nan``/``NaN`` and surrounding
      whitespace).
    - Lists and tuples are used element by element.

    Each number is rounded with ``round`` (halves go to the nearest even
    integer), and NaN/None becomes 0. An empty literal (``"[]"``) and any
    other value (e.g. a NaN cell) give ``[]``. A string with an unparseable
    token is reported with ``print`` and also gives ``[]``.
    """
    if isinstance(value, str):
        inner = value.strip().strip("[]").strip()
        if not inner:
            # "[]" or blank text: nothing to convert, and not an error.
            return []
        try:
            return [_round_or_zero(float(token)) for token in inner.split(",")]
        except (ValueError, OverflowError) as e:
            # ValueError: a token is not a number; OverflowError: round(inf).
            # Report the ORIGINAL cell text so the bad row is recognisable.
            print(f"Error converting value '{value}': {e}")
            return []
    if isinstance(value, (list, tuple)):
        return [_round_or_zero(x) for x in value]
    return []


# Both numeric columns hold list-like text, so the same converter serves each
# (percentage first, then count).
for numeric_column in ("percentage", "count"):
    df[numeric_column] = df[numeric_column].apply(convert_percentage_to_list)

# Show the cleaned DataFrame
print(df)

# Chart palette
LEFT_BAR_COLOUR = "#ea7555"  # coral: bars for the selected school
RIGHT_BAR_COLOUR = "#f1b79f"  # pale peach: bars for "Other Schools"
BACKGROUND_COLOR = "#ffffff"  # plain white background


def _other_schools_category_means(rows, category):
    """Return ``(mean percentage, mean count)`` for ``category`` across ``rows``.

    ``rows`` is a DataFrame slice with list-valued ``cat_lab``, ``percentage``
    and ``count`` columns. Each row is read at the position of ``category`` in
    that row's OWN ``cat_lab`` list, so schools that order or omit categories
    differently still contribute the matching value. A statistic is NaN when
    no row supplies it.
    """
    percentages = []
    counts = []
    for labels, row_percentages, row_counts in zip(rows["cat_lab"], rows["percentage"], rows["count"]):
        if not isinstance(labels, list) or category not in labels:
            continue
        position = labels.index(category)
        if position < len(row_percentages):
            percentages.append(row_percentages[position])
        if position < len(row_counts):
            counts.append(row_counts[position])
    mean_percentage = float(np.mean(percentages)) if percentages else float("nan")
    mean_count = float(np.mean(counts)) if counts else float("nan")
    return mean_percentage, mean_count


def _n_label(count):
    """Format a response count as ``n=<int>``, or ``n=n/a`` when missing/NaN."""
    if count is None or pd.isna(count):
        return "n=n/a"
    return f"n={round(count)}"


def plot_comparison(school_name, df):
    """Plot one grouped bar chart per measure comparing a school with all others.

    For each row of ``df`` whose ``School_lab`` equals ``school_name``, draws
    the school's percentage for every response category (left bars) next to
    the mean percentage for the same category of the same ``measure_lab``
    across every other school (right bars). Each bar is annotated with a
    response count: the school's own ``count`` on the left, the mean count of
    the other schools on the right.

    Args:
        school_name: ``School_lab`` value of the school to highlight.
        df: DataFrame with ``School_lab`` and ``measure_lab`` columns plus
            list-valued ``cat_lab``, ``percentage`` and ``count`` columns whose
            entries correspond by position.

    Categories whose percentage for the selected school is not positive are
    skipped. A measure with no remaining category produces no chart.
    """
    school_data = df[df["School_lab"] == school_name]
    other_schools_data = df[df["School_lab"] != school_name]

    for _, row in school_data.iterrows():
        measure = row["measure_lab"]
        cat_labels = row["cat_lab"]
        percentages = row["percentage"]
        n_responses = row["count"]  # Counts for the selected school

        # Pick label, percentage and count by the SAME index so they stay
        # aligned; the bounds checks avoid IndexError on ragged rows.
        valid_indices = [
            i for i, pct in enumerate(percentages) if pct > 0 and i < len(cat_labels)
        ]
        if not valid_indices:
            continue
        valid_cat_labels = [cat_labels[i] for i in valid_indices]
        valid_percentages = [percentages[i] for i in valid_indices]
        actual_counts = [n_responses[i] if i < len(n_responses) else None for i in valid_indices]

        # Compare like with like: rows of other measures can reuse category
        # names (e.g. "Not sure") and must not be mixed into the averages.
        same_measure_others = other_schools_data[other_schools_data["measure_lab"] == measure]
        other_stats = [
            _other_schools_category_means(same_measure_others, cat) for cat in valid_cat_labels
        ]
        other_percentages = [pct for pct, _ in other_stats]
        other_counts = [count for _, count in other_stats]

        x = np.arange(len(valid_cat_labels))  # X positions for categories
        width = 0.35  # Bar width (fixed size)

        fig, ax = plt.subplots(figsize=(10, 6))
        fig.patch.set_facecolor(BACKGROUND_COLOR)
        ax.set_facecolor(BACKGROUND_COLOR)

        bars1 = ax.bar(x - width / 2, valid_percentages, width, label=f"{school_name}", color=LEFT_BAR_COLOUR)
        # A category that no other school reports has a NaN mean; draw it at height 0.
        bars2 = ax.bar(
            x + width / 2,
            np.nan_to_num(other_percentages),
            width,
            label="Other Schools",
            color=RIGHT_BAR_COLOUR,
        )

        # Percentage just above each selected-school bar, count near its base.
        for bar, pct, count in zip(bars1, valid_percentages, actual_counts):
            centre = bar.get_x() + bar.get_width() / 2
            ax.text(centre, bar.get_height() + 1, f"{pct}%", ha="center", va="bottom", fontsize=8)
            ax.text(centre, 1, _n_label(count), ha="center", va="bottom", fontsize=8, color="black")

        # Mean response count near the base of each 'Other Schools' bar.
        for bar, count in zip(bars2, other_counts):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                1,
                _n_label(count),
                ha="center",
                va="bottom",
                fontsize=8,
                color="black",
            )

        # Labeling and formatting
        ax.set_xlabel("Responses")
        ax.set_ylabel("Percentage of pupils")
        ax.set_title(f"{measure} Comparison")  # Measure name as title
        ax.set_xticks(x)
        ax.set_xticklabels(valid_cat_labels, rotation=-45, ha="left")

        # Gridlines behind the bars: set_axisbelow is the supported way to
        # order grid vs. artists (a zorder on the gridlines alone is overridden).
        ax.set_axisbelow(True)
        ax.grid(True, which="both", axis="y", linestyle="--", color="#CCCCCC")

        # Remove spines (frame)
        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.legend()

        plt.tight_layout()
        plt.show()
        # Release the figure so looping over many schools/measures does not
        # accumulate open figures.
        plt.close(fig)


# Draw the comparison charts for each distinct school, one school at a time.
unique_schools = df["School_lab"].unique()
for school_name in unique_schools:
    plot_comparison(school_name, df)
