In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os

import experiments.utils as utils

# Default root directory of trained SAE sweeps. Later notebook cells reassign
# this to a more specific experiment directory before it is used.
DICTIONARIES_PATH = "../dictionary_learning/dictionaries"

In [None]:
def get_sparsity_penalty(config: dict, trainer_class: str) -> float:
    """Return the sparsity hyperparameter recorded in an SAE training config.

    Trainers store their sparsity setting under different keys of
    ``config["trainer"]``:

    * ``"k"`` for ``TrainerTopK``
    * ``"sparsity_penalty"`` for ``PAnnealTrainer``
    * ``"l1_penalty"`` for every other trainer class
    """
    penalty_key_by_class = {
        "TrainerTopK": "k",
        "PAnnealTrainer": "sparsity_penalty",
    }
    penalty_key = penalty_key_by_class.get(trainer_class, "l1_penalty")
    return config["trainer"][penalty_key]


def get_l0_frac_recovered(ae_paths: list[str]) -> dict[str, dict[str, float]]:
    """Collect the ``l0`` and ``frac_recovered`` metrics of each autoencoder.

    Reads ``<ae_path>/eval_results.json`` for every path. A path whose file
    does not exist is reported with a warning and left out of the result.

    Returns:
        Mapping from ae_path to ``{"l0": ..., "frac_recovered": ...}``.
    """
    metrics_by_path = {}
    for path in ae_paths:
        results_file = f"{path}/eval_results.json"
        if os.path.exists(results_file):
            with open(results_file, "r") as handle:
                evaluation = json.load(handle)
            metrics_by_path[path] = {
                metric: evaluation[metric] for metric in ("l0", "frac_recovered")
            }
        else:
            print(f"Warning: {results_file} does not exist.")
    return metrics_by_path


def add_ae_config_results(
    ae_paths: list[str], results: dict[str, dict[str, float]]
) -> dict[str, dict[str, float]]:
    """Add training hyperparameters from each autoencoder's config to ``results``.

    For every path, ``<ae_path>/config.json`` is read and ``trainer_class``,
    ``l1_penalty`` (the trainer-specific sparsity value chosen by
    ``get_sparsity_penalty``), ``lr``, ``dict_size`` and ``steps`` are stored
    in ``results[ae_path]``. ``results`` is mutated in place and also returned.

    Paths with no entry in ``results`` (for example those skipped by
    ``get_l0_frac_recovered`` because their eval file was missing) are skipped
    with a warning rather than raising ``KeyError``. A config without a
    ``steps`` entry records ``-1``, the same default used by
    ``get_spurious_correlation_plotting_dict``.
    """
    for ae_path in ae_paths:
        if ae_path not in results:
            print(f"Warning: no results entry for {ae_path}, skipping its config.")
            continue

        config_file = f"{ae_path}/config.json"
        with open(config_file, "r") as f:
            config = json.load(f)

        trainer_config = config["trainer"]
        trainer_class = trainer_config["trainer_class"]

        entry = results[ae_path]
        entry["trainer_class"] = trainer_class
        entry["l1_penalty"] = get_sparsity_penalty(config, trainer_class)
        entry["lr"] = trainer_config["lr"]
        entry["dict_size"] = trainer_config["dict_size"]
        # Older configs may lack "steps"; -1 marks it as unknown.
        entry["steps"] = trainer_config.get("steps", -1)

    return results


In [None]:
def get_probe_names(filename_filter: str, attrib_dir: int) -> tuple[str, str, str]:
    """Resolve the probe and dataset class ids used by one SCR ablation setting.

    Args:
        filename_filter: Suffix of the ``class_accuracies*.pkl`` file. Any
            value containing ``"_bias_shift"`` must be exactly
            ``"_bias_shift_dir1"`` or ``"_bias_shift_dir2"``.
        attrib_dir: Direction (1 or 2) for non-bias-shift filters. It is
            ignored for bias-shift filters.

    Returns:
        ``(ablated_probe_class_id, eval_probe_class_id, eval_data_class_id)``.
        The evaluated probe is always ``"male_professor / female_nurse"``.

    Raises:
        ValueError: If the bias-shift filter or ``attrib_dir`` is not recognised.
    """
    combined_probe = "male_professor / female_nurse"
    profession = "professor / nurse"
    gender = "male / female"

    if "_bias_shift" in filename_filter:
        # The evaluated probe is also the ablated one; only the eval data varies.
        if filename_filter == "_bias_shift_dir1":
            return combined_probe, combined_probe, profession
        if filename_filter == "_bias_shift_dir2":
            return combined_probe, combined_probe, gender
        raise ValueError("Invalid filename filter.")

    # Ablate one concept's probe and evaluate on the other concept's data.
    if attrib_dir == 1:
        return gender, combined_probe, profession
    if attrib_dir == 2:
        return profession, combined_probe, gender
    raise ValueError("Invalid attrib_dir.")

def get_spurious_correlation_plotting_dict(
    ae_paths: list[str],
    acc_key: str = "acc",
    filename_filters: tuple[str, ...] = ("_attrib", "_auto_interp", "_bias_shift_dir1", "_bias_shift_dir2"),
) -> tuple[dict[str, dict[str, float]], float]:
    """Build per-autoencoder spurious-correlation-removal (SCR) metrics for plotting.

    For every autoencoder directory this function reads:

    * ``eval_results.json`` for ``l0`` and ``frac_recovered``
    * ``config.json`` for the trainer hyperparameters
    * one ``class_accuracies{filter}.pkl`` per entry of ``filename_filters``

    For each threshold stored in the accuracy file it records the normalised
    score::

        (changed_acc - original_acc) / (clean_acc - original_acc)

    The terms are:

    * ``original_acc``: the clean accuracy of the
      ``"{eval_probe} probe on {eval_data} data"`` entry
    * ``clean_acc``: the clean accuracy of ``eval_data``
    * ``changed_acc``: the ``acc_key`` accuracy of that same combined entry
      after ablation at the threshold

    Result keys are ``scr{filter}_threshold_{t}`` for bias-shift filters and
    ``scr{filter}_dir{d}_threshold_{t}`` otherwise.

    Paths missing ``eval_results.json`` are skipped. A path missing any
    ``class_accuracies`` file is dropped from the results entirely.

    Returns:
        ``(results, orig_acc)``. ``orig_acc`` is the first ``original_acc``
        encountered.

    Raises:
        ValueError: If no original accuracy was found for any path.
    """
    results = {}
    orig_acc = None

    for ae_path in ae_paths:
        eval_results_file = f"{ae_path}/eval_results.json"

        if not os.path.exists(eval_results_file):
            print(f"Warning: {eval_results_file} does not exist.")
            continue

        with open(eval_results_file, "r") as f:
            eval_results = json.load(f)

        results[ae_path] = {
            "l0": eval_results["l0"],
            "frac_recovered": eval_results["frac_recovered"],
        }

        with open(f"{ae_path}/config.json", "r") as f:
            config = json.load(f)

        trainer_config = config["trainer"]
        trainer_class = trainer_config["trainer_class"]
        results[ae_path]["trainer_class"] = trainer_class
        results[ae_path]["l1_penalty"] = get_sparsity_penalty(config, trainer_class)
        results[ae_path]["lr"] = trainer_config["lr"]
        results[ae_path]["dict_size"] = trainer_config["dict_size"]
        # -1 marks configs that do not record a step count.
        results[ae_path]["steps"] = trainer_config.get("steps", -1)

        for filename_filter in filename_filters:
            class_accuracies_file = f"{ae_path}/class_accuracies{filename_filter}.pkl"

            if not os.path.exists(class_accuracies_file):
                print(
                    f"Warning: {class_accuracies_file} does not exist. Removing this path from results."
                )
                del results[ae_path]
                # Stop processing this path. With `continue`, later filters
                # would write into (or re-delete) the removed entry -> KeyError.
                break

            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(class_accuracies_file, "rb") as f:
                class_accuracies = pickle.load(f)

            # Bias-shift filters encode their single direction in the filter
            # name (passed as 0); other filters are evaluated in both directions.
            directions = [0] if "bias_shift_dir" in filename_filter else [1, 2]

            for direction in directions:
                ablated_probe_class_id, eval_probe_class_id, eval_data_class_id = get_probe_names(
                    filename_filter, direction
                )

                thresholds = class_accuracies[ablated_probe_class_id]
                if not thresholds:
                    continue

                combined_class_name = f"{eval_probe_class_id} probe on {eval_data_class_id} data"
                clean_acc = class_accuracies["clean_acc"][eval_data_class_id]["acc"]
                original_acc = class_accuracies["clean_acc"][combined_class_name]["acc"]

                if orig_acc is None:
                    orig_acc = original_acc
                    print(f"Original acc: {original_acc}")

                for threshold in thresholds:
                    changed_acc = thresholds[threshold][combined_class_name][acc_key]
                    scr_score = (changed_acc - original_acc) / (clean_acc - original_acc)

                    if direction == 0:
                        metric_key = f"scr{filename_filter}_threshold_{threshold}"
                    else:
                        metric_key = f"scr{filename_filter}_dir{direction}_threshold_{threshold}"

                    results[ae_path][metric_key] = scr_score

    if orig_acc is None:
        raise ValueError(f"original_acc not found for {ae_paths}")
    return results, orig_acc


def get_classes(first_path: str) -> list[int]:
    """Return the class ids found in ``<first_path>/class_accuracies.pkl``.

    The ids are the keys of the pickled ``"clean_acc"`` mapping, in that
    mapping's iteration order.
    """
    with open(f"{first_path}/class_accuracies.pkl", "rb") as pkl_file:
        clean_accuracies = pickle.load(pkl_file)["clean_acc"]
    return list(clean_accuracies)

def _tpp_threshold_diffs(
    class_accuracies: dict,
    classes: list,
    intended_filter_class_ids: list[int],
    unintended_filter_class_ids: list[int],
    acc_key: str,
) -> tuple[dict, dict]:
    """Collect per-threshold accuracy drops for intended and unintended classes.

    Returns ``(intended_diffs, unintended_diffs)``. Each maps a threshold to a
    list of ``clean_acc - patched_acc`` values.
    """
    for class_id in classes:
        if isinstance(class_id, str) and " probe on " in class_id:
            raise ValueError("This is deprecated, shouldn't be here.")

    # An empty filter list means "use every class".
    intended_classes = [
        c for c in classes if not intended_filter_class_ids or c in intended_filter_class_ids
    ]
    clean_accs = class_accuracies["clean_acc"]

    # Intended: effect of class c's ablation on class c's own accuracy.
    intended_diffs: dict = {}
    for class_id in intended_classes:
        clean = clean_accs[class_id]["acc"]
        for threshold in class_accuracies[class_id]:
            patched = class_accuracies[class_id][threshold][class_id][acc_key]
            intended_diffs.setdefault(threshold, []).append(clean - patched)

    # Unintended: effect of class i's ablation on every other class j.
    unintended_diffs: dict = {}
    for intended_class_id in intended_classes:
        for unintended_class_id in classes:
            if unintended_class_id == intended_class_id:
                continue
            if unintended_filter_class_ids and unintended_class_id not in unintended_filter_class_ids:
                continue
            clean = clean_accs[unintended_class_id]["acc"]
            for threshold in class_accuracies[intended_class_id]:
                patched = class_accuracies[intended_class_id][threshold][unintended_class_id][acc_key]
                unintended_diffs.setdefault(threshold, []).append(clean - patched)

    return intended_diffs, unintended_diffs


def create_tpp_plotting_dict(
    ae_paths: list[str],
    intended_filter_class_ids: list[int],
    unintended_filter_class_ids: list[int],
    filename_filters: tuple[str, ...] = ("_attrib", "_auto_interp"),
    acc_key: str = "acc",
    save_results: bool = True,
) -> dict:
    """Compute TPP metrics for each autoencoder from its class-accuracy pickles.

    For every ``ae_path`` and ``filename_filter`` this loads
    ``<ae_path>/class_accuracies{filename_filter}.pkl`` and, for every
    threshold, computes three values:

    * ``intended_diff_only``: mean over (filtered) classes ``c`` of
      ``clean_acc[c] - class_accuracies[c][threshold][c][acc_key]``
    * ``unintended_diff_only``: mean over ordered pairs ``i != j`` of
      ``clean_acc[j] - class_accuracies[i][threshold][j][acc_key]``
    * ``total_metric``: ``intended - unintended``

    Empty filter lists mean "no filtering". Missing pickle files are skipped
    with a warning. When ``save_results`` is true, each path's results, plus
    ``{"hyperparameters": {"classes": [...]}}``, are also written to
    ``<ae_path>/tpp_results.json``.

    Returns:
        Mapping ae_path -> {metric_key: value}.

    Raises:
        ValueError: On deprecated ``"... probe on ..."`` class ids, or when a
            threshold has intended but no unintended diffs.
    """
    results = {}

    for ae_path in ae_paths:
        results[ae_path] = {}
        # Reset per path. Previously `classes` was unbound (NameError on save)
        # when no accuracy file existed, or leaked from the previous path.
        classes = []

        for filename_filter in filename_filters:
            class_accuracies_file = f"{ae_path}/class_accuracies{filename_filter}.pkl"

            if not os.path.exists(class_accuracies_file):
                print(
                    f"Warning: {class_accuracies_file} does not exist. Skipping this path."
                )
                continue

            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(class_accuracies_file, "rb") as f:
                class_accuracies = pickle.load(f)

            classes = list(class_accuracies["clean_acc"].keys())
            intended_diffs, unintended_diffs = _tpp_threshold_diffs(
                class_accuracies,
                classes,
                intended_filter_class_ids,
                unintended_filter_class_ids,
                acc_key,
            )

            for threshold, intended in intended_diffs.items():
                if threshold not in unintended_diffs:
                    raise ValueError(
                        f"Threshold {threshold} has no unintended-class diffs in {class_accuracies_file}"
                    )
                unintended = unintended_diffs[threshold]

                average_intended_diff = sum(intended) / len(intended)
                average_unintended_diff = sum(unintended) / len(unintended)
                average_diff = average_intended_diff - average_unintended_diff

                key_prefix = f"tpp{filename_filter}_threshold_{threshold}"
                results[ae_path][f"{key_prefix}_total_metric"] = average_diff
                results[ae_path][f"{key_prefix}_intended_diff_only"] = average_intended_diff
                results[ae_path][f"{key_prefix}_unintended_diff_only"] = average_unintended_diff

        if save_results:
            # Same dict object as results[ae_path], so the hyperparameters are
            # also visible in the returned mapping.
            single_ae_results_dict = results[ae_path]
            single_ae_results_dict["hyperparameters"] = {"classes": classes}

            print(f"Saving results for {ae_path}")
            with open(f"{ae_path}/tpp_results.json", "w") as f:
                json.dump(single_ae_results_dict, f)

    return results


In [None]:
def get_probe_clean_accuracies(
    ae_paths: list[str], filename_filter: str, acc_key: str
) -> Optional[dict]:
    """Return the ``"clean_acc"`` table from the first available accuracy file.

    Scans ``ae_paths`` in order for
    ``<ae_path>/class_accuracies{filename_filter}.pkl``. It returns the
    ``"clean_acc"`` entry of the first file that exists and warns about each
    missing one. ``acc_key`` is accepted for signature parity with the other
    helpers but is not used here.

    Returns:
        The clean-accuracy mapping, or ``None`` if no path had the file.
    """
    for ae_path in ae_paths:
        class_accuracies_file = f"{ae_path}/class_accuracies{filename_filter}.pkl"
        if not os.path.exists(class_accuracies_file):
            print(f"Warning: {class_accuracies_file} does not exist.")
            continue

        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(class_accuracies_file, "rb") as f:
            class_accuracies = pickle.load(f)

        return class_accuracies["clean_acc"]

    # Made explicit: callers must handle the "nothing found" case.
    return None

In [None]:
# Another way to generate graphs, where you manually populate sweep_name and submodule_trainers
# NOTE: this cell is a scratchpad of alternative configurations. Most names
# below are assigned several times; only the LAST assignment before the loop
# at the bottom of the cell takes effect. `sweep_name` / `submodule_trainers`
# are additionally overwritten by that loop's variables.
sweep_name = "pythia70m_test_sae"
submodule_trainers = {"resid_post_layer_3": {"trainer_ids": [0]}}

# Current recommended way to generate graphs. You can copy paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [1, 7, 11, 18]}}
}

# trainer_ids = [2, 6, 10, 14, 18]
trainer_ids = None

ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {
        # "resid_post_layer_0": {"trainer_ids": None},
        # "resid_post_layer_1": {"trainer_ids": None},
        # "resid_post_layer_2": {"trainer_ids": None},
        "resid_post_layer_3": {"trainer_ids": trainer_ids},
        # "resid_post_layer_4": {"trainer_ids": None},
    },
    "pythia70m_sweep_topk_ctx128_0730": {
        # "resid_post_layer_0": {"trainer_ids": None},
        # "resid_post_layer_1": {"trainer_ids": None},
        # "resid_post_layer_2": {"trainer_ids": None},
        # "resid_post_layer_3": {"trainer_ids": None},
        # "resid_post_layer_4": {"trainer_ids": None},
        "resid_post_layer_3": {"trainer_ids": trainer_ids},
    },
    # "pythia70m_sweep_gated_ctx128_0730": {
    #     # "resid_post_layer_0": {"trainer_ids": None},
    #     # "resid_post_layer_1": {"trainer_ids": None},
    #     # "resid_post_layer_2": {"trainer_ids": None},
    #     # "resid_post_layer_3": {"trainer_ids": None},
    #     # "resid_post_layer_4": {"trainer_ids": None},
    #     "resid_post_layer_3": {"trainer_ids": trainer_ids},
    # },
}

# ae_sweep_paths = {
#     "pythia70m_sweep_topk_ctx128_0730": {
#         # "resid_post_layer_0": {"trainer_ids": None},
#         # "resid_post_layer_1": {"trainer_ids": None},
#         # "resid_post_layer_2": {"trainer_ids": None},
#         "resid_post_layer_3": {"trainer_ids": None},
#         # "resid_post_layer_4": {"trainer_ids": None},
#     }
# }

trainer_ids = None
# trainer_ids = [1]

ae_sweep_paths = {
    "gemma-2-2b_sweep_standard_ctx128_ef8_0824": {
        # "resid_post_layer_12": {"trainer_ids": trainer_ids},
        "resid_post_layer_15": {"trainer_ids": trainer_ids},
        # "resid_post_layer_19": {"trainer_ids": trainer_ids},
        # "resid_post_layer_20": {"trainer_ids": trainer_ids},
    },
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824": {
        # "resid_post_layer_12": {"trainer_ids": trainer_ids},
        # "resid_post_layer_11": {"trainer_ids": trainer_ids},
        "resid_post_layer_15": {"trainer_ids": trainer_ids},
        # "resid_post_layer_19": {"trainer_ids": trainer_ids},
        # "resid_post_layer_20": {"trainer_ids": trainer_ids},
    },
    "gemma-2-2b_sweep_jumprelu_0902": {
        "resid_post_layer_15": {"trainer_ids": trainer_ids},
        # "resid_post_layer_19": {"trainer_ids": trainer_ids},
    },
}

# trainer_ids = None
# ae_sweep_paths = {
#     "gemma-2-2b_sweep_jumprelu_0902_probe_layer24_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
#     "gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer24_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
#     "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer24_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
#     "gemma-2-2b_sweep_standard_ctx128_ef2_0824_probe_layer_24_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
#     "gemma-2-2b_sweep_topk_ctx128_ef2_0824_probe_layer_24_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
# }

# ae_sweep_paths = {
#     "gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer20_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
#     "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer20_results": {
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#     },
# }

trainer_ids = None

ae_sweep_paths = {
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer_11_tpp": {
        "resid_post_layer_11_checkpoints": {"trainer_ids": trainer_ids},
    },
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer_15_tpp": {
        "resid_post_layer_15_checkpoints": {"trainer_ids": trainer_ids},
    },
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer_19_tpp": {
        "resid_post_layer_19_checkpoints": {"trainer_ids": trainer_ids},
    },
}

# Effective sweep selection for this cell. A `None` value is passed straight
# through as `submodule_trainers`; presumably it selects every submodule and
# trainer (see experiments.utils.get_ae_group_paths, to confirm).
ae_sweep_paths = {
    "gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer_19_tpp": None,
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer_19_tpp": None,
    "gemma-2-2b_sweep_jumprelu_0902_probe_layer_19_tpp": None,
}

# ae_sweep_paths = {
#     "gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer_19_tpp": None,
#     "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer_19_tpp": None,
# }

# Overrides the default DICTIONARIES_PATH from the first cell.
DICTIONARIES_PATH = (
    "../dictionary_learning/dictionaries/09_20_gemma_tpp_autointerp_topk_standard_jumprelu"
)

plot_spurious = False
plot_tpp = True
plot_checkpoints = False
plot_averaged_results = False

# Overridden by `l0_threshold = None` below.
l0_threshold = 500

model = "Gemma-2-2B"

l0_threshold = None

no_title = True



ae_paths = []

# Expand each configured sweep into its autoencoder directories via experiments.utils.
for sweep_name, submodule_trainers in ae_sweep_paths.items():
    ae_group_paths = utils.get_ae_group_paths(DICTIONARIES_PATH, sweep_name, submodule_trainers)
    ae_paths.extend(utils.get_ae_paths(ae_group_paths))



In [None]:
# Spurious-correlation (SCR) configuration. Running this cell after the TPP
# cell above overrides DICTIONARIES_PATH and the plot_* flags.
# NOTE: ae_paths is not recomputed here. The averaged-SCR helpers below build
# their own autoencoder paths from the module-level DICTIONARIES_PATH.
ae_sweep_paths = {}

DICTIONARIES_PATH = (
    "../dictionary_learning/dictionaries/09_20_gemma_spurious_autointerp_topk_standard_jumprelu"
)

plot_spurious = True
plot_tpp = False
plot_checkpoints = False
plot_averaged_results = True

# Overridden by `l0_threshold = None` below.
l0_threshold = 500

model = "Gemma-2-2B"

l0_threshold = None

In [None]:


###


# print(ae_paths)
if plot_tpp:
    # If non-empty, restrict TPP to these class ids (empty means all classes).
    intended_filter_class_ids = []
    unintended_filter_class_ids = []

if plot_averaged_results:
    # Sweep-name prefixes. average_multiple_scr_runs appends
    # `_probe_layer_{layer}_spurious_{dataset}_{class1}_{class2}` to each.
    ae_group_paths = [
        "gemma-2-2b_sweep_topk_ctx128_ef8_0824",
        "gemma-2-2b_sweep_standard_ctx128_ef8_0824",
        "gemma-2-2b_sweep_jumprelu_0902",
    ]
    probe_layers = [19]
    # Class pairs; the SCR results of one run per pair are averaged together.
    column1_vals_list = [
        ("professor", "nurse"),
        ("architect", "journalist"),
        # ("painter", "teacher"),
        # ("photographer", "physician"),
    ]
    dataset = "bias_in_bios"

    # dataset = "amazon_reviews_1and5"
    # column1_vals_list = [
    #     ("Books", "CDs_and_Vinyl"),
    #     # ("Software", "Electronics"),
    #     ("Pet_Supplies", "Office_Products"),
    # ]

    

    # Substrings: any SAE path containing one is excluded from averaging.
    # The second assignment overrides the first, so nothing is excluded.
    ignore_sae_filters = ["trainer_4"]
    ignore_sae_filters = []

In [None]:
if plot_tpp:
    # Compute TPP metrics for every SAE. With the default save_results=True
    # they are also written to each SAE directory as tpp_results.json.
    tpp_results = create_tpp_plotting_dict(
        ae_paths=ae_paths,
        intended_filter_class_ids=intended_filter_class_ids,
        unintended_filter_class_ids=unintended_filter_class_ids,
    )

In [None]:
# Set plot_averaged_results = True if you have multiple spurious-correlation runs to average over,
# e.g.: gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer_15_spurious_bias_in_bios_filmmaker_dentist
# and: gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer_15_spurious_bias_in_bios_painter_teacher


def get_all_probe_clean_accuracies(
    ae_group_path: str,
    probe_layer: int,
    column1_vals_list: list[tuple[str, str]],
    filename_filter: str,
    acc_key: str = "acc",
    intervention_type: str = "spurious",
    dataset: str = "bias_in_bios",
):
    """Gather the clean probe-accuracy table for every class pair of an SCR sweep.

    For each pair in ``column1_vals_list``, the sweep named
    ``{ae_group_path}_probe_layer_{probe_layer}_{intervention_type}_{dataset}_{pair[0]}_{pair[1]}``
    is resolved into autoencoder paths under the module-level
    ``DICTIONARIES_PATH`` via ``experiments.utils``. The pair is then mapped to
    whatever ``get_probe_clean_accuracies`` returns for those paths.

    Returns:
        Mapping from each class-pair tuple to its clean-accuracy table.
    """
    run_prefix = f"{ae_group_path}_probe_layer_{probe_layer}_{intervention_type}_{dataset}"
    clean_accs_by_pair = {}

    for class_pair in column1_vals_list:
        pair_sweep_name = f"{run_prefix}_{class_pair[0]}_{class_pair[1]}"
        # None: no submodule/trainer filter is passed to the utils helper.
        pair_group_paths = utils.get_ae_group_paths(DICTIONARIES_PATH, pair_sweep_name, None)
        pair_ae_paths = utils.get_ae_paths(pair_group_paths)
        clean_accs_by_pair[class_pair] = get_probe_clean_accuracies(
            pair_ae_paths, filename_filter, acc_key
        )

    return clean_accs_by_pair


def _canonical_scr_ae_path(
    run_ae_path: str, run_paths: list[str], ae_group_path: str, dictionaries_path: str
) -> str:
    """Map a per-class-pair SAE path to a key shared across class-pair runs.

    The first run directory name found in ``run_ae_path`` is replaced by
    ``ae_group_path``, and everything up to and including
    ``dictionaries_path`` is dropped. Paths that differ only by their
    class-pair run therefore collapse to the same key.

    Raises:
        ValueError: If ``run_ae_path`` contains none of ``run_paths``.
    """
    for run_path in run_paths:
        if run_path in run_ae_path:
            return run_ae_path.replace(run_path, ae_group_path).split(dictionaries_path)[1]
    raise ValueError(f"{run_ae_path} does not belong to any of the runs {run_paths}")


def average_multiple_scr_runs(
    ae_group_path: str,
    dictionaries_path: str,
    probe_layer: int,
    column1_vals_list: list[tuple[str, str]],
    ignore_sae_filters: list[str],
    acc_key: str = "acc",
    intervention_type: str = "spurious",
    dataset: str = "bias_in_bios",
) -> tuple[dict, float, dict]:
    """Average SCR metrics for one SAE group over several class-pair runs.

    Each pair in ``column1_vals_list`` names a sweep
    ``{ae_group_path}_probe_layer_{probe_layer}_{intervention_type}_{dataset}_{pair[0]}_{pair[1]}``.
    That sweep is resolved under the module-level ``DICTIONARIES_PATH`` and
    scored with ``get_spurious_correlation_plotting_dict``.

    Per-SAE results are merged onto a run-independent key (see
    ``_canonical_scr_ae_path``). SAE paths containing any substring in
    ``ignore_sae_filters`` are excluded.

    Averaging rules:

    * Every float-valued entry is averaged over the runs that actually
      reported it. An SAE dropped from one run (e.g. missing accuracy file)
      is therefore not under-weighted. When all runs report a value, this
      equals dividing by ``len(column1_vals_list)``.
    * Float hyperparameters such as ``lr`` are averaged too.
    * Non-float entries keep the value from the first run encountered.

    Args:
        dictionaries_path: Prefix stripped from result keys (the notebook
            passes ``DICTIONARIES_PATH + "/"``). It is not used to locate runs.

    Returns:
        ``(final_results, average_orig_acc, class_acc_dict)``.
        ``class_acc_dict`` maps each class pair to that run's original accuracy.

    Raises:
        ValueError: If ``column1_vals_list`` is empty, a run yields no original
            accuracy, or a result path cannot be matched to its run.
    """
    if not column1_vals_list:
        raise ValueError("column1_vals_list must contain at least one class pair.")

    ae_base_path = f"{ae_group_path}_probe_layer_{probe_layer}_{intervention_type}_{dataset}"
    run_paths = [f"{ae_base_path}_{vals[0]}_{vals[1]}" for vals in column1_vals_list]

    all_results = {}
    original_accs = []
    class_acc_dict = {}

    for column1_vals, ae_run_path in zip(column1_vals_list, run_paths):
        ae_group_paths = utils.get_ae_group_paths(DICTIONARIES_PATH, ae_run_path, None)
        ae_paths = utils.get_ae_paths(ae_group_paths)

        temp_results, orig_acc = get_spurious_correlation_plotting_dict(ae_paths, acc_key)
        # Validate before the value is recorded or averaged.
        if orig_acc is None:
            raise ValueError(f"Original acc is None for {ae_run_path}")

        class_acc_dict[column1_vals] = orig_acc
        all_results.update(temp_results)
        original_accs.append(orig_acc)

    average_orig_acc = sum(original_accs) / len(original_accs)

    final_results = {}
    float_counts = {}  # canonical path -> {key: number of runs contributing}

    for run_ae_path, run_results in all_results.items():
        if any(pattern in run_ae_path for pattern in ignore_sae_filters):
            continue

        ae_path = _canonical_scr_ae_path(run_ae_path, run_paths, ae_group_path, dictionaries_path)

        if ae_path not in final_results:
            # Copy so averaging never mutates the per-run result dicts.
            final_results[ae_path] = dict(run_results)
            float_counts[ae_path] = {
                key: 1 for key, value in run_results.items() if isinstance(value, float)
            }
            continue

        merged = final_results[ae_path]
        counts = float_counts[ae_path]
        for key, value in run_results.items():
            if not isinstance(value, float):
                continue
            if key in counts:
                merged[key] += value
                counts[key] += 1
            else:
                merged[key] = value
                counts[key] = 1

    for ae_path, merged in final_results.items():
        for key, count in float_counts[ae_path].items():
            merged[key] /= count

    return final_results, average_orig_acc, class_acc_dict


if plot_averaged_results:
    # Results and original accuracies accumulated over every probe layer.
    all_averaged_results = {}
    all_orig_accs = []

    for probe_layer in probe_layers:
        averaged_results = {}
        orig_accs = []

        for ae_group_path in ae_group_paths:
            single_averaged_results, single_average_orig_acc, class_act_dict = (
                average_multiple_scr_runs(
                    ae_group_path,
                    DICTIONARIES_PATH + "/",
                    probe_layer,
                    column1_vals_list,
                    ignore_sae_filters,
                    acc_key="acc",
                    dataset=dataset,
                )
            )

            averaged_results.update(single_averaged_results)
            orig_accs.append(single_average_orig_acc)

        average_orig_acc = sum(orig_accs) / len(orig_accs)

        print(average_orig_acc)

        # Accumulate once per probe layer. Previously these two lines ran after
        # the loop, so only the last layer was kept.
        # NOTE(review): averaged_results keys do not include the probe layer,
        # so entries for the same SAE from different layers overwrite each
        # other. Confirm before using several probe_layers.
        all_averaged_results.update(averaged_results)
        all_orig_accs.append(average_orig_acc)

In [None]:
# Inspect the averaged SAE keys and the metric names of the first one. Guarded
# so the cell does not raise NameError/StopIteration when averaging is
# disabled or produced no results.
if plot_averaged_results and all_averaged_results:
    print(all_averaged_results.keys())
    first_key = next(iter(all_averaged_results))
    print(all_averaged_results[first_key].keys())

In [None]:
# Write the averaged SCR results to all_scr_results.json in the working directory.
if plot_averaged_results:
    with open("all_scr_results.json", "w") as f:
        json.dump(all_averaged_results, f)


In [None]:
# Plot an averaged SCR metric against training steps for checkpointed SAEs.
# NOTE(review): plot_steps_vs_average_diff, custom_metric1 and
# image_filename_prefix are not defined anywhere in the visible notebook, and
# `title` is only assigned when no_title is True. Enabling this cell as-is
# likely raises NameError; define them first.
if plot_averaged_results and plot_checkpoints:
    y_label = "Probe Accuracy Increase"

    if no_title:
        title = None

    plot_steps_vs_average_diff(
        all_averaged_results,
        steps_key="steps",
        avg_diff_key=custom_metric1,
        title=title,
        y_label=y_label,
        output_filename=f"{image_filename_prefix}_checkpoints.png",
    )

In [None]:
def spurious_probe_acc_table(
    class_acc_dict: dict[tuple[str, str], dict[str, dict[str, float]]],
    evaled_probe_class_id: str,
    model_name: str,
    acc_key: str = "acc",
):
    """Print a Markdown table of clean probe accuracies, one row per class pair.

    ``class_acc_dict`` maps a class pair to its ``clean_acc`` table, keyed by
    class id and then by ``acc_key``. Each row shows three accuracies:

    * the ``evaled_probe_class_id`` entry ("Train Accuracy")
    * that probe on "professor / nurse" data ("Profession Accuracy")
    * that probe on "male / female" data ("Gender Accuracy")
    """
    print(f"# {model_name} Clean Probe Accuracies")
    print()
    # Print the header in Markdown format
    print("| Class 1 / Class 2      | Train Accuracy | Profession Accuracy | Gender Accuracy |")
    print(
        "|------------------------|---------------------|---------------------|-----------------|"
    )

    for column1_vals, clean_accs in class_acc_dict.items():
        class_name = f"{column1_vals[0]} / {column1_vals[1]}"

        train_acc = clean_accs[evaled_probe_class_id][acc_key]

        # Combined keys follow the "<probe> probe on <data> data" naming scheme.
        dir1_acc_key = f"{evaled_probe_class_id} probe on professor / nurse data"
        dir2_acc_key = f"{evaled_probe_class_id} probe on male / female data"

        dir1_acc = clean_accs[dir1_acc_key][acc_key]
        dir2_acc = clean_accs[dir2_acc_key][acc_key]

        print(
            f"| {class_name:<22} | {train_acc:.4f}              | {dir1_acc:.4f}              | {dir2_acc:.4f}           |"
        )


def single_class_probe_acc_table(
    class_acc_dict: dict[tuple[str, str], dict[str, dict[str, float]]],
    model_name: str,
    acc_key: str = "acc",
):
    """Print a Markdown table of single-concept probe accuracies per class pair.

    ``class_acc_dict`` maps a class pair to its ``clean_acc`` table. Each row
    shows the accuracy stored under "professor / nurse" ("Profession
    Accuracy") and under "male / female" ("Gender Accuracy").
    """
    print(f"# {model_name} Clean Probe Accuracies")
    print()
    # Print the header in Markdown format
    print("| Class 1 / Class 2      | Profession Accuracy | Gender Accuracy |")
    print("|------------------------|---------------------|-----------------|")

    for column1_vals, clean_accs in class_acc_dict.items():
        class_name = f"{column1_vals[0]} / {column1_vals[1]}"

        dir1_acc = clean_accs["professor / nurse"][acc_key]
        dir2_acc = clean_accs["male / female"][acc_key]

        # Three cells to match the three-column header. The old row had an
        # extra empty "| |" cell that shifted the values one column right.
        print(f"| {class_name:<22} | {dir1_acc:.4f}              | {dir2_acc:.4f}           |")


evaled_probe = "male_professor / female_nurse"

if plot_averaged_results:
    # NOTE(review): `filename_filter` was never assigned at module level in
    # the visible notebook, so this cell raised NameError. "_attrib" is the
    # first filter get_spurious_correlation_plotting_dict reads. The
    # "clean_acc" table is presumably identical across filter files; confirm.
    filename_filter = "_attrib"

    for probe_layer in probe_layers:
        # Clean probe accuracies are read from the first SAE group only.
        class_act_dict = get_all_probe_clean_accuracies(
            ae_group_paths[0], probe_layer, column1_vals_list, filename_filter, dataset=dataset
        )

        # A pair maps to None when no accuracy file was found; fail with a
        # clear message instead of a TypeError inside the table printers.
        missing_pairs = [pair for pair, accs in class_act_dict.items() if accs is None]
        if missing_pairs:
            raise ValueError(f"No clean probe accuracies found for class pairs: {missing_pairs}")

        spurious_probe_acc_table(class_act_dict, evaled_probe, model)
        single_class_probe_acc_table(class_act_dict, model)
