In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle

from typing import Union

In [None]:
# Trainer ids to inspect, keyed by submodule name.
submodule_trainers = {"resid_post_layer_3": {"trainer_ids": [10]}}

# Short model alias -> full model identifier.
model_name_lookup = {
    "pythia70m": "EleutherAI/pythia-70m-deduped",
}

# Root directory holding the trained dictionaries (relative to this notebook).
dictionaries_path = "../dictionary_learning/dictionaries"

# The sweep directory name is model_location + sweep_name.
model_location = "pythia70m"
sweep_name = "_sweep_topk_ctx128_0730"

In [None]:
submodule_name = "resid_post_layer_3"
# Only the first trainer configured for this submodule is inspected.
trainer_id = submodule_trainers[submodule_name]["trainer_ids"][0]

# Layout: <dictionaries_path>/<model_location><sweep_name>/<submodule_name>/trainer_<id>
trainer_path = os.path.join(
    dictionaries_path,
    f"{model_location}{sweep_name}",
    submodule_name,
    f"trainer_{trainer_id}",
)

# SECURITY: pickle.load can execute arbitrary code -- only open result files you trust.
with open(os.path.join(trainer_path, "class_accuracies.pkl"), mode="rb") as pkl_file:
    class_accuracies = pickle.load(pkl_file)

with open(os.path.join(trainer_path, "node_effects.pkl"), mode="rb") as pkl_file:
    node_effects = pickle.load(pkl_file)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from typing import Union

def plot_accuracy_comparison(
    class_accuracies: dict[int, dict],
    ylim: tuple[float, float] = (0.4, 1.0),
):
    """Plot clean vs. ablated probe accuracies, one figure per T_effect.

    Each figure is a grouped bar chart with one group per ablated class and one
    bar per evaluated class inside each group. The clean accuracy is drawn as a
    light-grey background bar capped by a thin black line. The ablated accuracy
    is drawn on top in orange, or in red where the ablated class is also the
    evaluated class.

    Args:
        class_accuracies: Key ``-1`` maps each evaluated class index to its clean
            accuracy (a number). Every other key is an ablated class index mapping
            ``T_effect -> {eval_class_idx: {"acc": accuracy, ...}}``. The set of
            T_effect values is read from the first ablated class and assumed to be
            shared by all ablated classes.
        ylim: y-axis limits applied to every figure. Defaults to the previously
            hard-coded ``(0.4, 1.0)``.

    An evaluated class missing from an ablation result is plotted as 0 accuracy.
    Nothing is drawn when ``class_accuracies`` has no ablated entries.
    """
    from matplotlib.patches import Patch

    clean_accuracies = class_accuracies[-1]
    abl_accuracies = {k: v for k, v in class_accuracies.items() if k != -1}
    if not abl_accuracies:
        # No ablation results, hence no T_effect thresholds to plot.
        return

    T_effects = list(next(iter(abl_accuracies.values())).keys())  # same for all classes
    ablated_class_idxs = list(abl_accuracies.keys())
    eval_class_idxs = list(clean_accuracies.keys())

    n_eval = len(eval_class_idxs)
    n_abl = len(ablated_class_idxs)
    # One slot per evaluated class plus one slot of padding between groups.
    bar_width = 0.8 / (n_eval + 1)
    x = np.arange(n_abl)
    # Offset of each evaluated class's bar from the centre of its group.
    offsets = [(i - (n_eval - 1) / 2) * bar_width for i in range(n_eval)]
    # Thickness of the black line marking the clean accuracy.
    black_bar_height = 0.002

    def ablated_accuracy(abl_class_idx, T_effect, eval_class_idx):
        # A missing evaluation counts as 0 accuracy. The previous
        # `.get(idx, 0)['acc']` raised TypeError instead, since 0 is not subscriptable.
        entry = abl_accuracies[abl_class_idx][T_effect].get(eval_class_idx)
        return 0 if entry is None else entry['acc']

    for T_effect in T_effects:
        fig, ax = plt.subplots(figsize=(12, 6))

        # Background: clean accuracy as a grey bar with a thin black cap.
        for offset, eval_class_idx in zip(offsets, eval_class_idxs):
            clean_accuracy = clean_accuracies[eval_class_idx]
            ax.bar(x + offset, [clean_accuracy] * n_abl, bar_width,
                   color='lightgray', alpha=0.5, zorder=-1)
            ax.bar(x + offset, [black_bar_height] * n_abl, bar_width,
                   bottom=[clean_accuracy - black_bar_height] * n_abl,
                   color='black', alpha=1, zorder=10)

        # Foreground: accuracy of each evaluated class after ablating each class.
        for offset, eval_class_idx in zip(offsets, eval_class_idxs):
            values = [ablated_accuracy(abl_class_idx, T_effect, eval_class_idx)
                      for abl_class_idx in ablated_class_idxs]
            colors = ["orange" if eval_class_idx != abl_class_idx else "red"
                      for abl_class_idx in ablated_class_idxs]
            ax.bar(x + offset, values, bar_width, color=colors)

        ax.set_xlabel("Ablated Class Index")
        ax.set_ylabel("Test Accuracy")
        ax.set_title(f"Probe accuracies for ablated models\nT_effect = {T_effect}")
        ax.set_xticks(x)
        ax.set_xticklabels(ablated_class_idxs)

        legend_elements = [Patch(facecolor='lightgray', edgecolor='black', label='Clean Accuracy'),
                           Patch(facecolor='orange', label='Ablated Accuracy'),
                           Patch(facecolor='red', label='Ablated = Evaluated Class')]
        ax.legend(handles=legend_elements, loc="lower right")

        # Half a group of padding on either side of the x-axis.
        ax.set_xlim(-0.5, n_abl - 0.5)
        ax.set_ylim(*ylim)

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # don't accumulate open figures across T_effect values

# Draw one clean-vs-ablated accuracy figure per T_effect found in the loaded results.
plot_accuracy_comparison(class_accuracies)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from typing import Union

# Profession name -> class index. The 28 professions are numbered 0..27 in the
# order listed below. The two trailing entries are special labels with negative indices.
profession_dict = {
    **{
        name: idx
        for idx, name in enumerate(
            (
                "accountant", "architect", "attorney", "chiropractor",
                "comedian", "composer", "dentist", "dietitian",
                "dj", "filmmaker", "interior_designer", "journalist",
                "model", "nurse", "painter", "paralegal",
                "pastor", "personal_trainer", "photographer", "physician",
                "poet", "professor", "psychologist", "rapper",
                "software_engineer", "surgeon", "teacher", "yoga_teacher",
            )
        )
    },
    "profession": -4,
    "gender": -2,
}

# Font size for the axis labels of the percentage-change plots.
fontsize = 16

def plot_accuracy_percentage_change(
    class_accuracies: dict[int, dict],
    symmetric_ylim: bool = False,
):
    """Plot the % change in per-class accuracy caused by each ablation.

    One figure is drawn per T_effect (titled "Number of Ablated Features"). Each
    ablated profession is one group on the x-axis, with one bar per evaluated
    class. A bar's height is ``(ablated_acc - clean_acc) / clean_acc * 100``. Red
    marks the evaluated class that was itself ablated; orange marks all others.

    Args:
        class_accuracies: Key ``-1`` maps each evaluated class index to its clean
            accuracy (a number). Every other key is an ablated class index mapping
            ``T_effect -> {eval_class_idx: {"acc": accuracy, ...}}``. The entries
            ``-2`` and ``-4`` (``"gender"`` / ``"profession"`` in
            ``profession_dict``) are skipped if present.
        symmetric_ylim: If True, set y-limits to +/-110% of the largest absolute
            change over *all* bars in the figure. The default False leaves
            matplotlib's automatic limits, as before.

    Tick labels come from the module-level ``profession_dict``. Indices not found
    there fall back to ``str(idx)``. A missing evaluation counts as 0 accuracy.
    A clean accuracy of 0 yields a NaN change, which draws no bar.
    """
    from matplotlib.patches import Patch

    clean_accuracies = class_accuracies[-1]
    # Skip the clean baseline (-1) and the "gender" (-2) / "profession" (-4)
    # entries. Filtering avoids the KeyError `del` raised when those were absent.
    skipped_keys = (-1, -2, -4)
    abl_accuracies = {k: v for k, v in class_accuracies.items() if k not in skipped_keys}
    if not abl_accuracies:
        # No ablation results, hence no T_effect thresholds to plot.
        return

    T_effects = list(next(iter(abl_accuracies.values())).keys())  # same for all classes
    ablated_class_idxs = list(abl_accuracies.keys())
    eval_class_idxs = list(clean_accuracies.keys())

    # Map class index back to profession name for the x tick labels.
    idx_to_profession = {v: k for k, v in profession_dict.items()}
    tick_labels = [idx_to_profession.get(idx, str(idx)) for idx in ablated_class_idxs]

    n_eval = len(eval_class_idxs)
    n_abl = len(ablated_class_idxs)
    # One slot per evaluated class plus one slot of padding between groups.
    bar_width = 0.8 / (n_eval + 1)
    x = np.arange(n_abl)
    # Offset of each evaluated class's bar from the centre of its group.
    offsets = [(i - (n_eval - 1) / 2) * bar_width for i in range(n_eval)]

    def percentage_change(abl_class_idx, T_effect, eval_class_idx):
        clean_accuracy = clean_accuracies[eval_class_idx]
        # A missing evaluation counts as 0 accuracy. The previous
        # `.get(idx, 0)['acc']` raised TypeError instead, since 0 is not subscriptable.
        entry = abl_accuracies[abl_class_idx][T_effect].get(eval_class_idx)
        ablated_accuracy = 0 if entry is None else entry['acc']
        if clean_accuracy == 0:
            return float("nan")  # relative change undefined against a zero baseline
        return (ablated_accuracy - clean_accuracy) / clean_accuracy * 100

    for T_effect in T_effects:
        fig, ax = plt.subplots(figsize=(16, 8))  # large figure for readability

        # Collect every bar height so y-limits can account for all of them.
        all_changes = []
        for offset, eval_class_idx in zip(offsets, eval_class_idxs):
            changes = [percentage_change(abl_class_idx, T_effect, eval_class_idx)
                       for abl_class_idx in ablated_class_idxs]
            all_changes.extend(changes)
            colors = ["orange" if eval_class_idx != abl_class_idx else "red"
                      for abl_class_idx in ablated_class_idxs]
            ax.bar(x + offset, changes, bar_width, color=colors)

        # Reference line at 0% change.
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

        ax.set_xlabel("Ablated Profession", fontsize=fontsize)
        ax.set_ylabel("Percentage Change in Accuracy", fontsize=fontsize)
        ax.set_title(f"Percentage change in Class Accuracies\nNumber of Ablated Features = {T_effect}")
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        legend_elements = [Patch(facecolor='orange', label='Unintended Class Accuracy Change'),
                           Patch(facecolor='red', label='Intended Class Accuracy Change')]
        ax.legend(handles=legend_elements, loc="lower left")

        # Half a group of padding on either side of the x-axis.
        ax.set_xlim(-0.5, n_abl - 0.5)

        if symmetric_ylim and all_changes:
            # Use the largest |change| across *all* bars; the old code looked only
            # at the last evaluated class's bars.
            max_abs_change = np.nanmax(np.abs(np.asarray(all_changes, dtype=float)))
            if np.isfinite(max_abs_change) and max_abs_change > 0:
                ax.set_ylim(-1.1 * max_abs_change, 1.1 * max_abs_change)

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # don't accumulate open figures across T_effect values

# Draw one percentage-change figure per T_effect found in the loaded results.
plot_accuracy_percentage_change(class_accuracies)