In [1]:
import ast
import os

import matplotlib.pyplot as plt
import pandas as pd


def clean_master_df(df: pd.DataFrame) -> pd.DataFrame:
    """Clean and preprocess the input DataFrame."""
    required_columns = ["cat", "cat_lab", "count", "percentage", "measure_lab", "school_lab", "group"]
    for col in required_columns:
        if col not in df.columns:
            print(f"Missing column: {col}")
            return pd.DataFrame()

    df_copy = df.copy()
    columns_to_parse = ["cat", "cat_lab", "count", "percentage"]

    for column in columns_to_parse:
        df_copy[column] = df_copy[column].apply(lambda x: safe_parse(x, column))

    return df_copy.dropna(subset=["count", "percentage"])


def safe_parse(val, column):
    """Safely parse list-like strings in DataFrame columns."""
    try:
        if pd.isna(val) or not isinstance(val, str):
            return []

        val = val.strip()
        if not (val.startswith("[") and val.endswith("]")):
            val = "[" + val + "]"

        parsed_val = ast.literal_eval(val)

        if column in ["count", "percentage", "cat"] and isinstance(parsed_val, list):
            return [float(x) if x not in [None, "nan", "\\N"] else 0.0 for x in parsed_val]
        if column == "cat_lab":
            return [str(x) for x in parsed_val]
        return parsed_val
    except Exception as e:
        print(f"Error parsing value: {val}. Error: {e}")
        return []


def filter_by_all_pupils(df: pd.DataFrame) -> pd.DataFrame:
    """Filter the DataFrame for 'all pupils'."""
    return df[(df["year_group_lab"] == "All") & (df["gender_lab"] == "All") & (df["fsm_lab"] == "All")]


def make_all_pupils_responses_graph(category_label, percentages, counts_list, topic, measure_label):
    """Generate a single bar graph for all pupils."""
    if len(category_label) != len(percentages) or len(percentages) != len(counts_list):
        print(f"Mismatched data lengths for {measure_label}. Skipping graph.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    x = range(len(category_label))
    bars = ax.bar(x, percentages, color="#ff7f0e", edgecolor="black")

    for bar, percentage, count in zip(bars, percentages, counts_list, strict=False):
        ax.text(
            bar.get_x() + bar.get_width() / 2, bar.get_height() + 1, f"{percentage:.1f}%", ha="center", va="bottom",
        )
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() - 5, f"n={int(count)}", ha="center", va="top")

    ax.set_title(f"{topic}: {measure_label}", fontsize=16)
    ax.set_ylabel("Percentage", fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(category_label, rotation=45, ha="right", fontsize=12)
    plt.tight_layout()


def save_graph(file_path: str):
    """Save the graph to a file."""
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    plt.savefig(file_path, dpi=100)
    plt.close()


def make_graphs_for_all_pupils(df: pd.DataFrame, topic_name: str, school_name: str):
    """Generate graphs for all pupils by topic and school."""
    topic_df = df[df["group"] == topic_name]

    for _, row in topic_df.iterrows():
        category_label = row["cat_lab"]
        percentages = row["percentage"]
        counts_list = row["count"]

        measure_label = row["measure_lab"]

        if not category_label or not percentages:
            print(f"Skipping graph for measure '{measure_label}' (empty data).")
            continue

        make_all_pupils_responses_graph(
            category_label=category_label,
            percentages=percentages,
            counts_list=counts_list,
            topic=topic_name,
            measure_label=measure_label,
        )
        save_graph(f"outputs/{school_name}/all_pupils/{topic_name}/{measure_label}_percentages.png")


def make_all_graphs(master_df: pd.DataFrame):
    """Generate graphs for all pupils."""
    cleaned_df = clean_master_df(master_df)
    if cleaned_df.empty:
        print("No data available after cleaning.")
        return

    all_pupils_df = filter_by_all_pupils(cleaned_df)
    if all_pupils_df.empty:
        print("No data for 'all pupils'.")
        return

    schools = all_pupils_df["school_lab"].unique()
    topics = all_pupils_df["group"].unique()

    for school_name in schools:
        school_df = all_pupils_df[all_pupils_df["school_lab"] == school_name]
        for topic_name in topics:
            make_graphs_for_all_pupils(school_df, topic_name, school_name)


# Main execution
if __name__ == "__main__":
    input_file = "/Users/ellengoddard/Desktop/development-folder/beewell-graphs/kailo-beewell-graphs/data/real/symbol_nd_aggregate_responses.csv"
    master_df = pd.read_csv(input_file, encoding="utf-8", on_bad_lines="skip")
    make_all_graphs(master_df)
    print("✅ All graphs generated successfully.")


Missing column: group
No data available after cleaning.
✅ All graphs generated successfully.
