In [1]:
import ast
import csv
import io
import math
import os
import re
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd
from global_graph_settings import set_global_graph_settings
from measure_labels_dict import labels
from responses_double_comparison_graph import make_subgroup_comparison_graph
from school_responses_all_pupils_graphs import make_all_pupils_responses_graph
from topic_name_dict import topic_name_dict


def clean_master_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a cleaned copy of the raw aggregate-responses table.

    Steps:
      1. Tag the "who did you talk to" measures with a ``group`` value
         (``staff_talk``, ``peer_talk`` or ``home_talk``) by exact match on
         ``measure_lab``.
      2. Parse the list-like columns ``cat``, ``cat_lab``, ``count`` and
         ``percentage`` with ``safe_parse``.
      3. Remove the literal two-character sequence ``\\n`` from ``measure_lab``.
      4. Drop rows whose ``count`` or ``percentage`` is NA.

    Args:
        df: The raw DataFrame. It is not modified.

    Returns:
        A new, cleaned DataFrame.

    """
    df_copy = df.copy()

    # Exact ``measure_lab`` texts that make up each "talk to someone" group.
    talk_groups = {
        "staff_talk": [
            "Talked about feeling down with... an adult at school",
            "Did you feel listened to when you spoke with... an adult at school",
            "How would you feel about speaking with... an adult at school",
            "Did you receive advice that you found helpful from... an adult at school",
        ],
        "peer_talk": [
            "Talked about feeling down with... another person your age",
            "Did you feel listened to when you spoke with... another person your age",
            "How would you feel about speaking with... another person your age",
            "Did you receive advice that you found helpful from... another person your age",
        ],
        "home_talk": [
            "Talked about feeling down with... one of your parents/carers",
            "Did you feel listened to when you spoke with... one of your parents/carers",
            "Did you receive advice that you found helpful from... one of your parents/carers",
            "How would you feel about speaking with... one of your parents/carers",
        ],
    }
    for group_name, measure_labels in talk_groups.items():
        df_copy.loc[df_copy["measure_lab"].isin(measure_labels), "group"] = group_name

    columns_to_parse = ["cat", "cat_lab", "count", "percentage"]

    for column in columns_to_parse:
        df_copy[column] = df_copy[column].apply(lambda x: safe_parse(x, column))

    # Remove escaped "\n" sequences (a literal backslash followed by "n") from
    # the labels. ``regex=False`` is explicit because pandas changed the default
    # in 2.0; with the old ``regex=True`` default this pattern matched a real
    # newline character instead.
    df_copy["measure_lab"] = df_copy["measure_lab"].str.replace("\\n", "", regex=False)

    # NOTE(review): safe_parse returns [] (never NA) for missing values, so this
    # dropna does not appear to remove anything; rows with empty lists survive.
    return df_copy.dropna(subset=["count", "percentage"])


# Names of the pupil subgroups that graphs can be broken down by.
SubGroups = Literal[
    "all_pupils",
    "year_group",
    "gender",
    "fsm",
    "sen",
]


def parse_list_string(s: str) -> list[str]:
    """Split a bracketed, comma-separated string into its items.

    One pair of enclosing square brackets is removed if present. The rest is
    read as a single CSV record with initial spaces skipped. Items wrapped in
    double quotes may therefore contain commas and apostrophes. Single quotes
    are NOT treated as quoting and are kept as part of the item.

    Example:
        '["Very safe", "Fairly safe", "Don\'t know"]'
        -> ['Very safe', 'Fairly safe', "Don't know"]

    Args:
        s (str): The text to split.

    Returns:
        list: The items as strings; ``[]`` when nothing is left after the
        brackets are removed.

    """
    text = s.strip()
    if text.startswith("[") and text.endswith("]"):
        text = text.removeprefix("[").removesuffix("]")

    # The csv module copes with quoted items that contain separators.
    records = csv.reader(io.StringIO(text), skipinitialspace=True)
    # An empty input yields no record at all.
    return next(records, [])


def _to_float_list(items) -> list[float]:
    """Convert *items* to floats, mapping None, NaN and unconvertible items to 0.0."""
    numbers = []
    for item in items:
        try:
            number = float(item)
        except (ValueError, TypeError):
            # float(None) raises TypeError; non-numeric text raises ValueError.
            number = 0.0
        numbers.append(0.0 if math.isnan(number) else number)
    return numbers


def safe_parse(val, column):
    """Safely parse a list-like cell value from one of the list columns.

    Args:
        val: The raw cell value. Usually a string such as ``"[1, 2, \\N]"``,
            but it may also be NaN, ``None``, ``""`` or an already-parsed
            list/tuple (e.g. when a DataFrame is cleaned twice).
        column (str): The name of the column being parsed. ``"cat_lab"`` is
            parsed as a list of strings. ``"count"``, ``"percentage"`` and
            ``"cat"`` are converted to lists of floats.

    Returns:
        list: The parsed list. Missing or unparseable cells give ``[]``. In the
        numeric columns, ``None``, NaN, ``\\N``/``nan`` and unconvertible items
        become ``0.0``.

    """
    # Handle NaN values (only for actual NaN, not string 'nan')
    if isinstance(val, float) and math.isnan(val):
        return []

    # Already-parsed values must not be round-tripped through str(): for
    # 'cat_lab', str(["a"]) gives "['a']", and the CSV-based parser would keep
    # the single quotes as part of every label.
    if isinstance(val, (list, tuple)):
        if column in ("count", "percentage", "cat"):
            return _to_float_list(val)
        return list(val)

    # Handle empty strings and None
    if val == "" or val is None:
        return []

    try:
        val = str(val).strip()

        # Ensure the string starts with '[' and ends with ']'
        if not (val.startswith("[") and val.endswith("]")):
            val = "[" + val + "]"

        if column == "cat_lab":
            # Use custom parser for 'cat_lab' to handle quotes and commas
            return parse_list_string(val)

        # For 'count' and 'percentage' columns:
        # replace '\N' and 'nan' with '0' to handle missing or invalid values
        val = val.replace("\\N", "0").replace("nan", "0")

        # Safely evaluate the string to a Python literal (no code execution)
        parsed_val = ast.literal_eval(val)

        if column in ("count", "percentage", "cat") and isinstance(parsed_val, list):
            return _to_float_list(parsed_val)
        return parsed_val
    except (ValueError, SyntaxError) as e:
        print(f"Error parsing value in column '{column}': {val}. Error: {e}")
        return []  # Return empty list if parsing fails


def filter_rows_by_school(df: pd.DataFrame, school_name: str):
    """Return a new DataFrame holding only the rows whose ``school_lab`` equals *school_name*."""
    is_school = df["school_lab"] == school_name
    return df.loc[is_school].copy()


def get_school_dfs(list_of_schools: list[str], df: pd.DataFrame):
    """Map each school name in *list_of_schools* to the rows of *df* for that school."""
    return {school: filter_rows_by_school(df, school) for school in list_of_schools}


def filter_rows_by_topic(df: pd.DataFrame, topic_name: str):
    """Return a new DataFrame holding only the rows whose ``group`` equals *topic_name*."""
    in_topic = df["group"].eq(topic_name)
    return df.loc[in_topic].copy()


def filter_by_all_pupils(df: pd.DataFrame):
    """Return the rows that are not broken down by any subgroup.

    A row qualifies when its year-group, gender, FSM and SEN labels are all
    ``"All"``.
    """
    label_columns = ["year_group_lab", "gender_lab", "fsm_lab", "sen_lab"]
    everyone = (df[label_columns] == "All").all(axis="columns")
    return df.loc[everyone].copy()


def filter_by_year_group(df: pd.DataFrame):
    """Return the rows for the "Year 8" and "Year 10" breakdowns of ``year_group_lab``."""
    wanted = ["Year 8", "Year 10"]
    return df.loc[df["year_group_lab"].isin(wanted)].copy()


def filter_by_gender(df: pd.DataFrame):
    """Return the rows for the "Girl" and "Boy" breakdowns of ``gender_lab``."""
    wanted = ["Girl", "Boy"]
    return df.loc[df["gender_lab"].isin(wanted)].copy()


def filter_by_fsm(df: pd.DataFrame):
    """Return the rows for the "FSM" and "Non-FSM" breakdowns of ``fsm_lab``."""
    wanted = ["FSM", "Non-FSM"]
    return df.loc[df["fsm_lab"].isin(wanted)].copy()


def filter_by_sen(df: pd.DataFrame):
    """Return the rows for the "SEN" and "Non-SEN" breakdowns of ``sen_lab``."""
    wanted = ["SEN", "Non-SEN"]
    return df.loc[df["sen_lab"].isin(wanted)].copy()


def filter_rows_by_subgroup(df: pd.DataFrame, subgroup_name: SubGroups):
    if subgroup_name == "all_pupils":
        return filter_by_all_pupils(df)
    if subgroup_name == "year_group":
        return filter_by_year_group(df)
    if subgroup_name == "gender":
        return filter_by_gender(df)
    if subgroup_name == "fsm":
        return filter_by_fsm(df)
    if subgroup_name == "sen":
        return filter_by_sen(df)
    return None


def get_measures(df: pd.DataFrame):
    """Return the distinct values of the ``measure`` column (``Series.unique``)."""
    measure_column = df["measure"]
    return measure_column.unique()


def clean_filename(s, max_length=255):
    """Turn an arbitrary label into a safe file-name stem.

    Characters that are illegal in Windows file names (``<>:"/\\|?*``) are
    removed. ASCII control characters (e.g. newlines, tabs) become spaces.
    Surrounding whitespace is trimmed, and the remaining spaces are replaced
    with underscores.

    Args:
        s: The label to clean.
        max_length: Maximum length of the result in characters (default 255,
            the usual per-component file-name limit). Callers append a suffix
            such as ``"_percentages.png"`` after this truncation.

    Returns:
        The cleaned, truncated string.

    """
    # Remove any illegal characters
    s = re.sub(r'[<>:"/\\|?*]', "", s)
    # Control characters (e.g. embedded newlines in labels) are invalid in
    # Windows file names; treat them as word separators.
    s = re.sub(r"[\x00-\x1f]", " ", s)
    # Trim *before* replacing spaces, otherwise leading/trailing spaces would
    # survive as underscores.
    s = s.strip()
    # Replace spaces with underscores
    s = s.replace(" ", "_")
    # Truncate filename if it's too long
    return s[:max_length]


def save_graph(file_path, dpi=100):
    """Save the current matplotlib figure to *file_path* and close it.

    Missing parent directories are created. The figure is closed even when
    saving fails, so failed graphs do not accumulate open figures.

    Args:
        file_path: Destination path of the image.
        dpi: Resolution passed to ``plt.savefig`` (default 100).

    """
    try:
        directory = os.path.dirname(file_path)
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        # exist_ok avoids failing if the directory appears between check and create.
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(file_path, dpi=dpi)
    finally:
        plt.close()


def single_comparison_graph(
    df: pd.DataFrame,
    measure_name: str,
    subgroup_name: str,
    topic_name: str,
    school: str,
):
    """Draw and save one all-pupils response graph per row of *df*.

    Each row is expected to provide ``cat_lab`` (category labels), ``count`` and
    ``percentage`` (parallel lists), plus ``measure_lab`` and ``group``. Rows
    whose label or percentage list is empty, or whose lengths differ, are
    skipped with a message. Graphs are saved to
    ``outputs/{school}/responses/{subgroup_name}/{group}/<measure_lab>_percentages.png``.
    A failure while drawing or saving one row is reported and the loop continues.

    Args:
        df: Rows for a single measure.
        measure_name: Key used to look up the display label in ``labels``;
            falls back to the row's ``measure_lab``.
        subgroup_name: Subgroup name used in the output path.
        topic_name: Not used; the topic comes from each row's ``group`` value.
        school: School name used in the output path.

    """
    for _, data in df.iterrows():
        labels_list = data["cat_lab"]
        counts = data["count"]
        percentages = [float(x) for x in data["percentage"]]
        # Get the full measure label from the dictionary
        measure_label = labels.get(measure_name, data["measure_lab"])
        topic_dirty = data["group"]
        topic = topic_name_dict.get(topic_dirty, topic_dirty)
        subgroup = subgroup_name

        if not labels_list or not percentages:
            print(f"Skipping measure '{measure_label}' due to empty data.")
            continue

        if len(labels_list) != len(percentages):
            print(f"Skipping measure '{measure_label}' due to mismatched lengths.")
            continue

        try:
            make_all_pupils_responses_graph(
                category_label=labels_list,
                percentages=percentages,
                counts_list=counts,
                topic=topic,
                measure_label=measure_label,
                legend_title="Pupils",
                legend_label="All pupils",
            )
            # Clean the measure label for use in the filename
            filename = clean_filename(data["measure_lab"]) + "_percentages.png"
            save_graph(
                f"outputs/{school}/responses/{subgroup}/{topic_dirty}/{filename}",
            )
        except Exception as e:
            # Best-effort per row: close the current figure (if any) that the
            # failed attempt may have left open, so figures don't pile up over
            # a long batch, then report and move on.
            plt.close()
            print(f"Error creating graph for {measure_label}: {e}")


def make_graphs_for_single_comparison(
    df: pd.DataFrame,
    subgroup_name: str,
    topic_name: str,
    school: str,
):
    """Draw the single (non-split) graphs for every measure present in *df*.

    Prints a notice and does nothing when *df* is empty.
    """
    if df.empty:
        print(f"No data for {subgroup_name} in topic {topic_name} at school {school}")
        return

    for measure_key in get_measures(df):
        rows_for_measure = df[df["measure"] == measure_key]
        single_comparison_graph(
            rows_for_measure,
            measure_key,
            subgroup_name,
            topic_name,
            school,
        )


def multiple_comparison_graph(
    df: pd.DataFrame,
    measure_name: str,
    labels_list: list[str],
    subgroup_name: str,
    topic_name: str,
    school: str,
):
    """Draw a bar chart comparing one percentage per subgroup label.

    The figure is saved to
    ``outputs/{school}/responses/{subgroup_name}/{topic_name}/<label>_percentages_comparison.png``.

    Args:
        df: Rows for a single measure.
        measure_name: Measure key; its display label is looked up in ``labels``.
        labels_list: Subgroup labels to compare (e.g. ``["Girl", "Boy"]``).
        subgroup_name: One of ``"gender"``, ``"fsm"``, ``"sen"`` or
            ``"year_group"``. Any other value selects no rows.
        topic_name: Topic used in the output path.
        school: School name used in the output path.

    Raises:
        AttributeError: If any label in *labels_list* has no rows in *df*.

    NOTE(review): ``percentages_list`` is never filled from *df*. As written,
    this function therefore always prints "Skipping measure ... due to empty
    data" and returns without drawing. It is not called in the visible code.
    TODO: decide which value per label should be plotted before relying on it.

    """
    # Column that holds the subgroup label for each supported subgroup.
    label_columns = {
        "gender": "gender_lab",
        "fsm": "fsm_lab",
        "sen": "sen_lab",
        "year_group": "year_group_lab",
    }
    label_column = label_columns.get(subgroup_name)
    percentages_list = []

    if label_column is not None:
        for label in labels_list:
            # Check every label, not only the last one: an empty selection
            # means there is nothing to compare, so don't make a graph.
            if df[df[label_column] == label].empty:
                raise AttributeError("DF is empty.")

    # Get the full measure label from the dictionary
    measure_label = labels.get(measure_name, measure_name)
    topic = topic_name
    subgroup = subgroup_name

    if not labels_list or not percentages_list:
        print(f"Skipping measure '{measure_label}' due to empty data.")
        return

    if len(labels_list) != len(percentages_list):
        print(f"Skipping measure '{measure_label}' due to mismatched lengths.")
        return

    try:
        plt.figure(figsize=(10, 6))
        bars = plt.bar(
            labels_list,
            percentages_list,
            color="#ff7f0e",
            edgecolor="black",
        )
        plt.title(
            f"{measure_label} - Percentage Comparison for {subgroup}",
            fontsize=16,
            fontweight="bold",
            wrap=True,
        )
        plt.xlabel(subgroup.capitalize(), fontsize=14)
        plt.ylabel("Percentage", fontsize=14)
        plt.xticks(rotation=45, ha="right", fontsize=12)
        plt.yticks(fontsize=12)
        plt.grid(
            visible=True,
            which="both",
            linestyle="--",
            linewidth=0.5,
            color="gray",
            axis="y",
        )

        # Add data labels on top of the bars
        for bar, percentage in zip(bars, percentages_list, strict=False):
            height = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                height + max(percentages_list) * 0.01,
                f"{percentage:.1f}%",
                ha="center",
                va="bottom",
                fontsize=12,
                wrap=True,
            )

        plt.tight_layout()
        # Clean the measure label for use in the filename
        filename = clean_filename(measure_label) + "_percentages_comparison.png"
        save_graph(f"outputs/{school}/responses/{subgroup}/{topic}/{filename}")
    except Exception as e:
        # Close the figure opened above so a failure doesn't leak it.
        plt.close()
        print(f"Error creating comparison graph for {measure_label}: {e}")


def get_label_list_rows(*, df, labels_list, subgroup_name):
    """Return one filtered DataFrame per label in *labels_list*, in order.

    The label column is chosen from *subgroup_name* ("gender", "fsm", "sen" or
    "year_group"). Any other subgroup yields an empty list.
    """
    column_for_subgroup = {
        "gender": "gender_lab",
        "fsm": "fsm_lab",
        "sen": "sen_lab",
        "year_group": "year_group_lab",
    }
    if subgroup_name not in column_for_subgroup:
        return []

    label_column = column_for_subgroup[subgroup_name]
    return [df[df[label_column] == label] for label in labels_list]


def make_graph_for_subgroup_row(
    df: pd.DataFrame,
    subgroup_name: SubGroups,
    topic_name: str,
    school_name: str,
):
    """Draw side-by-side comparison graphs for a two-way subgroup split.

    For each measure in *df*, the first row for each of the two subgroup labels
    (e.g. "Girl" and "Boy") supplies its per-category percentage and count
    lists. These are paired element-wise and passed to
    ``make_subgroup_comparison_graph``.

    e.g., topic_name = Autonomy
    e.g., school_name = School A
    e.g., subgroup_name = year_group

    Graphs are saved under
    ``outputs/{school_name}/responses/{subgroup_name}/{topic_name}/``. A failure
    for one measure is reported and skipped rather than aborting the whole run.

    """
    if df.empty:
        print(
            f"No data for {subgroup_name} in topic {topic_name} at school {school_name}",
        )
        return

    # The two groups compared for each subgroup, in plotting order.
    comparison_labels = {
        "gender": ["Girl", "Boy"],
        "fsm": ["FSM", "Non-FSM"],
        "sen": ["SEN", "Non-SEN"],
        "year_group": ["Year 8", "Year 10"],
    }
    labels_list = comparison_labels.get(subgroup_name)
    if labels_list is None:
        print(f"Unknown subgroup: {subgroup_name}. Skipping.")
        return

    # The display topic name does not depend on the measure; look it up once.
    display_topic_name = topic_name_dict.get(topic_name, topic_name)

    for measure in get_measures(df):
        measure_df: pd.DataFrame = df[df["measure"] == measure]

        rows = get_label_list_rows(
            df=measure_df,
            labels_list=labels_list,
            subgroup_name=subgroup_name,
        )

        try:
            # Each cell holds a list (one value per response category); take
            # the first matching row for each of the two groups.
            percentages_group1 = rows[0]["percentage"].iloc[0]
            percentages_group2 = rows[1]["percentage"].iloc[0]
            counts_group1 = rows[0]["count"].iloc[0]
            counts_group2 = rows[1]["count"].iloc[0]
        except (IndexError, AttributeError) as e:
            print(f"Error accessing data for measure '{measure}': {e}")
            continue

        # Pair the two groups category by category: [(g1_c1, g2_c1), ...].
        # NOTE(review): strict=False silently truncates to the shorter list.
        percentages_tuple_list = list(zip(percentages_group1, percentages_group2, strict=False))
        counts_tuple_list = list(zip(counts_group1, counts_group2, strict=False))

        try:
            make_subgroup_comparison_graph(
                category_label=measure_df.iloc[0]["cat_lab"],
                percentages=percentages_tuple_list,
                counts_list=counts_tuple_list,
                topic=display_topic_name,
                measure_label=labels.get(measure, measure),
                comparison_groups=labels_list,
                legend_title="Pupils",
            )

            filename = clean_filename(measure) + "_percentages_comparison.png"
            filepath = f"outputs/{school_name}/responses/{subgroup_name}/{topic_name}/{filename}"
            save_graph(filepath)
        except Exception as e:
            # Best-effort per measure: close the current figure (if any) left
            # open by the failed attempt, report, and continue.
            plt.close()
            print(f"Error creating comparison graph for {measure}: {e}")


def make_single_comparison_graph(
    df: pd.DataFrame,
    subgroup_name: SubGroups,
    topic_name: str,
    school_name: str,
):
    """Draw single (non-split) graphs for *df*.

    Thin wrapper around ``make_graphs_for_single_comparison`` that passes
    *school_name* through as its ``school`` argument.
    """
    make_graphs_for_single_comparison(
        df=df,
        subgroup_name=subgroup_name,
        topic_name=topic_name,
        school=school_name,
    )


def make_graphs_for_all_pupils(df: pd.DataFrame, topic_name, school_name):
    """Draw the all-pupils graphs for one topic at one school."""
    make_single_comparison_graph(df, "all_pupils", topic_name, school_name)


def make_graphs_for_subgroup(
    df: pd.DataFrame,
    subgroup_name: SubGroups,
    topic_name: str,
    school_name: str,
):
    """Route *df* to the graph builder that fits *subgroup_name*.

    ``"all_pupils"`` produces single graphs. Every other subgroup produces
    side-by-side comparison graphs. An empty *df* only prints a notice.
    """
    if df.empty:
        message = f"No data for {subgroup_name} in topic {topic_name} at school {school_name}"
        print(message)  # noqa: T201
        return

    location = {"topic_name": topic_name, "school_name": school_name}
    if subgroup_name != "all_pupils":
        # Split subgroups use the filtered data directly.
        make_graph_for_subgroup_row(df, subgroup_name=subgroup_name, **location)
    else:
        make_graphs_for_all_pupils(df, **location)


def make_all_graphs(master_df: pd.DataFrame) -> None:
    """Make all graphs in the standard school responses notebook.

    Applies the global graph settings, cleans *master_df* with
    ``clean_master_df``, then draws graphs for every school × subgroup × topic
    combination, printing progress as it goes.
    """
    set_global_graph_settings()

    cleaned = clean_master_df(master_df)

    school_names = cleaned["school_lab"].unique()
    topic_names = cleaned["group"].unique()
    subgroup_names = ["all_pupils", "year_group", "gender", "fsm", "sen"]

    rows_by_school = get_school_dfs(school_names, cleaned)

    for school_name, school_df in rows_by_school.items():
        print(f"🚌 Processing school: {school_name}")  # noqa: T201
        for subgroup_name in subgroup_names:
            print(f"🏘️ Processing subgroup: {subgroup_name}")  # noqa: T201
            subgroup_df = filter_rows_by_subgroup(school_df, subgroup_name)
            for topic_name in topic_names:
                topic_df = filter_rows_by_topic(subgroup_df, topic_name)
                print(f"📊 Making graph for {subgroup_name}, {topic_name} for School {school_name}")  # noqa: T201
                make_graphs_for_subgroup(
                    topic_df,
                    subgroup_name=subgroup_name,
                    topic_name=topic_name,
                    school_name=school_name,
                )


# Read in data. Malformed lines are still skipped, but pandas now warns about
# each one instead of dropping it silently.
master_df = pd.read_csv(
    "../data/real/standard_school_aggregate_responses.csv",
    encoding="utf-8",
    encoding_errors="ignore",
    on_bad_lines="warn",
)

# Cleaned copy kept for inspection in later cells.
cleaned_df = clean_master_df(master_df)

# make_all_graphs() cleans its input itself, so give it the raw data: cleaning
# an already-cleaned frame re-parses the already-parsed list columns.
make_all_graphs(master_df)
print("✅ Finished making all graphs.")  # noqa: T201


🚌 Processing school: BRAUNTON SCHOOL AND C.C.
🏘️ Processing subgroup: all_pupils
📊 Making graph for all_pupils, autonomy for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, life_satisfaction for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, optimism for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, wellbeing for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, esteem for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, stress for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, appearance for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, negative for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, lonely for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, support for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, sleep for School BRAUNTON SCHOOL AND C.C.
📊 Making graph for all_pupils, physical for School BRAUNTON SCHOOL AND C.C.
📊 Mak