# In [1]:
import ast
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from global_graph_settings import set_global_graph_settings  # Add graph settings

# Machine-specific absolute path to the aggregated demographic export.
DEMOGRAPHICS_CSV_PATH = (
    "/Users/ellengoddard/Desktop/development-folder/beewell-graphs/kailo-beewell-graphs/data/real/symbol_nd_aggregate_demographic.csv"
)
demographics_df = pd.read_csv(DEMOGRAPHICS_CSV_PATH)


# Helper function to convert values to lists
def convert_to_list(value):
    if isinstance(value, str) and value.startswith("[") and value.endswith("]"):
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return value
    return value


# Apply conversion for categorical labels
demographics_df["cat_lab"] = demographics_df["cat_lab"].apply(convert_to_list)


# Helper function to convert percentages to list and handle NaNs
def convert_percentage_to_list(value):
    if isinstance(value, str):
        try:
            value = value.strip("[]").replace("nan", "0")  # Replace 'nan' with '0'
            values = list(map(float, value.split(",")))  # Convert comma-separated string to list of floats
            return [round(x) if not pd.isna(x) else 0 for x in values]
        except Exception as e:
            print(f"Error converting value '{value}': {e}")
            return []
    elif isinstance(value, list):
        return [round(x) if not pd.isna(x) else 0 for x in value]
    return []


# Apply conversion for percentages and counts
demographics_df["percentage"] = demographics_df["percentage"].apply(convert_percentage_to_list)
demographics_df["count"] = demographics_df["count"].apply(convert_percentage_to_list)


# Normalize percentages to 100 if necessary
def normalize_to_100(percentages):
    total = sum(percentages)
    if total == 0:
        return [0] * len(percentages)  # If all are zero, return a list of zeros
    return [round((p / total) * 100, 2) for p in percentages]


# Define colors and styles for plotting
LEFT_BAR_COLOUR = "#ea7555"  # Color for 'Your School'
RIGHT_BAR_COLOUR = "#f1b79f"  # Color for 'Other Schools'
BACKGROUND_COLOR = "#ffffff"  # White background


def plot_comparison(school_name, df, output_dir):
    # Filter the "Your School" data for the current school
    school_data = df[df["School_lab"] == school_name]
    your_school_data = school_data[school_data["school_group_lab"] == "Your school"]
    # Filter the "Other Schools" data for the same school
    other_schools_data = school_data[school_data["school_group_lab"] == "Other schools"]
    # Ensure global graph settings are applied
    set_global_graph_settings()  # Apply global graph settings

    uni_measures = list(school_data["measure"].unique())
    for measure in uni_measures:
        your_school_measure = your_school_data[your_school_data["measure"] == measure]
        your_school_measure.reset_index(inplace=True, drop=True)
        other_school_measure = other_schools_data[other_schools_data["measure"] == measure]
        other_school_measure.reset_index(inplace=True, drop=True)

        measure_lab = your_school_measure.loc[0, "measure_lab"]

        cat_labels = your_school_measure.loc[0, "cat_lab"]
        percentages = your_school_measure.loc[0, "percentage"]
        n_responses = your_school_measure.loc[0, "count"]

        other_cat_labels = other_school_measure.loc[0, "cat_lab"]
        other_percentages = other_school_measure.loc[0, "percentage"]
        other_n_responses = other_school_measure.loc[0, "count"]

        if measure == "NYC (Census)":
            measure = "Year group"

        valid_indices = [i for i, p in enumerate(percentages) if p > 0]
        valid_cat_labels = [cat_labels[i] for i in valid_indices]
        valid_percentages = [percentages[i] for i in valid_indices]
        valid_n_responses = [n_responses[i] for i in valid_indices if not np.isnan(n_responses[i])]

        # Get the counts for "Your School"
        actual_counts = valid_n_responses

        # Extract the pre-calculated "Other Schools" data for the same categories
        avg_counts_other_schools = []
        avg_percentages_other_schools = []

        for cat in valid_cat_labels:
            if cat in other_cat_labels:
                avg_counts_other_schools.append(other_n_responses[other_cat_labels.index(cat)])
                avg_percentages_other_schools.append(other_percentages[other_cat_labels.index(cat)])
            else:
                avg_counts_other_schools.append(0)
                avg_percentages_other_schools.append(0)

        print(valid_cat_labels)
        print(valid_percentages)
        print(avg_percentages_other_schools)

        # # Iterate through rows in the "Your School" data
        # for _index, row in school_data.iterrows():
        #     measure = row["measure_lab"]
        #     cat_labels = row["cat_lab"]
        #     percentages = row["percentage"]
        #     n_responses = row["count"]  # The count for 'Your School'
        #     # Manually change 'NYC Census' to 'Year group' in the title if applicable
        #     if measure == "NYC (Census)":
        #         measure = "Year group"

        #     # Ensure that percentages and counts match the categories (ignore NaN values)
        #     valid_indices = [i for i, p in enumerate(percentages) if p > 0]
        #     valid_cat_labels = [cat_labels[i] for i in valid_indices]
        #     valid_percentages = [percentages[i] for i in valid_indices]
        #     valid_n_responses = [n_responses[i] for i in valid_indices if not np.isnan(n_responses[i])]

        #     # Get the counts for "Your School"
        #     actual_counts = valid_n_responses

        #     # Extract the pre-calculated "Other Schools" data for the same categories
        #     avg_counts_other_schools = []
        #     avg_percentages_other_schools = []
        #     # print(demographics_df.columns)
        #     # Loop through each category and match with the "Other Schools" data
        #     for cat in valid_cat_labels:
        #         print(cat)
        #         # Get the data for the current category in "Other Schools"
        #         other_category_data = other_schools_data[other_schools_data["cat_lab"].apply(lambda x: cat in x)]
        #         # Extract the pre-calculated counts and percentages for "Other Schools"
        #         other_counts = other_category_data["count"].dropna()
        #         other_percentages = other_category_data["percentage"].dropna()
        #         print(other_percentages)
        #         # If there are valid counts for "Other Schools", extract the first one (assumes pre-calculated data)
        #         if len(other_counts) > 0:
        #             avg_count_other_schools = other_counts.iloc[0]  # Use the first available count
        #             avg_percentage_other_schools = other_percentages.iloc[0]  # Use the first available percentage
        #         else:
        #             avg_count_other_schools = 0
        #             avg_percentage_other_schools = 0

        #         # Append the results for comparison
        #         avg_counts_other_schools.append(avg_count_other_schools)
        #         avg_percentages_other_schools.append(avg_percentage_other_schools)

        # If there are valid categories to plot
        if valid_cat_labels:
            num_categories = len(valid_cat_labels)
            width = 0.25  # Bar width (fixed size)

            # Set a consistent figure size for all plots
            fig, ax = plt.subplots(figsize=(10, 6))  # Fixed figure size for all plots
            fig.subplots_adjust(top=0.85)  # Adjust this for more space at the top

            # Plot bars for 'Your School' and 'Other Schools'
            bars1 = ax.bar(
                np.arange(num_categories) - width / 2,
                valid_percentages,
                width,
                label="Your School",
                color=LEFT_BAR_COLOUR,
                align="center",
            )
            bars2 = ax.bar(
                np.arange(num_categories) + width / 2,
                avg_percentages_other_schools,
                width,
                label="Other Schools",
                color=RIGHT_BAR_COLOUR,
                align="center",
            )

            # Add percentage values for both sets of bars
            for i, (bar1, bar2) in enumerate(zip(bars1, bars2)):
                height1 = bar1.get_height()
                height2 = bar2.get_height()

                # Display percentages for 'Your School' bars
                ax.text(
                    bar1.get_x() + bar1.get_width() / 2,
                    height1 + 6,  # Space above the top of the bar
                    f"{height1:.0f}%",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

                # Display percentages for 'Other Schools' bars
                ax.text(
                    bar2.get_x() + bar2.get_width() / 2,
                    height2 + 6,  # Space above the top of the bar
                    f"{height2:.0f}%",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

                # Adjust n_responses label placement for "Your School"
                n_responses_label = f"n={int(actual_counts[i])}"  # Convert to integer for cleaner output
                if height1 < 6:
                    ax.text(
                        bar1.get_x() + bar1.get_width() / 2,
                        height1 + 2,  # Above the bar, leaving space for percentage
                        n_responses_label,
                        ha="center",
                        va="bottom",
                        fontsize=8,
                    )
                else:
                    ax.text(
                        bar1.get_x() + bar1.get_width() / 2,
                        height1 - 2,  # Below the percentage
                        n_responses_label,
                        ha="center",
                        va="top",
                        fontsize=8,
                    )

                # For the other schools, apply the same logic
                n_responses_label_2 = f"n={int(round(avg_counts_other_schools[i]))}"  # Round and convert
                if height2 < 6:
                    ax.text(
                        bar2.get_x() + bar2.get_width() / 2,
                        height2 + 2,
                        n_responses_label_2,
                        ha="center",
                        va="bottom",
                        fontsize=8,
                    )
                else:
                    ax.text(
                        bar2.get_x() + bar2.get_width() / 2,
                        height2 - 2,
                        n_responses_label_2,
                        ha="center",
                        va="top",
                        fontsize=8,
                    )

            # Labeling and formatting
            ax.set_xlabel("Categories")
            ax.set_ylabel("Percentage of Pupils")
            ax.set_title(measure_lab, pad=20)  # Title includes school name
            ax.set_ylim(0, 100)  # Set Y-axis limit to 100
            ax.set_xticks(np.arange(num_categories))
            ax.set_xticklabels(valid_cat_labels, rotation=-45, ha="left")

            # Add gridlines behind bars
            ax.grid(visible=True, which="both", axis="y", color="#CCCCCC", zorder=0)

            # Remove spines (frame)
            for spine in ax.spines.values():
                spine.set_visible(False)

            # Add legend and adjust position
            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

            # Adjust layout and save the figure
            plt.tight_layout(pad=2)  # Add padding to prevent overlap
            output_filename = f"{school_name}_{measure}.png"
            output_path = os.path.join(output_dir, output_filename)
            plt.savefig(output_path)
            plt.close()  # Close the figure to free memory


def save_all_plots(df, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Iterate through all unique school names and plot groups to save the plots
    for school_name in df["School_lab"].unique():
        plot_comparison(school_name, df, output_dir)


# Example usage
save_all_plots(
    demographics_df,
    "/Users/ellengoddard/Desktop/development-folder/beewell-graphs/kailo-beewell-graphs/outputs"
)


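# Output of running the cell above: KeyError: 'School_lab'
# (demographics_df has no column named 'School_lab'; check demographics_df.columns against the names the plotting code expects)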