In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle

import experiments.utils as utils

# Where the trained dictionaries live and which sweep of them to evaluate.
dictionaries_path = "../dictionary_learning/dictionaries"
model_location = "pythia70m"
sweep_name = "_sweep0709"

model_name_lookup = dict(pythia70m="EleutherAI/pythia-70m-deduped")

# Submodules to evaluate, each with the trainer ids of its sweep
# (range(10, 12, 2) selects trainer 10 only).
# Layer 3 is currently disabled; its entry was
#   "resid_post_layer_3": {"trainer_ids": list(range(0, 12, 2))}
submodule_trainers = {
    "resid_post_layer_4": dict(trainer_ids=[*range(10, 12, 2)]),
}

ae_group_paths = utils.get_ae_group_paths(
    dictionaries_path,
    model_location,
    sweep_name,
    submodule_trainers,
)
ae_paths = utils.get_ae_paths(ae_group_paths)
print(ae_paths)

In [None]:
# Key into each class's accuracy table selecting which patched run is compared
# against the clean run (1e-3 is the same float as 0.001).
threshold = 1e-3
# Filled by the aggregation loop: AE directory -> {"l0", "frac_recovered", "average_diff"}.
results = {}
def get_classes(first_path: str) -> list[int]:
    """Return the class ids stored in one autoencoder's ``class_accuracies.pkl``.

    Args:
        first_path: Directory of a single autoencoder; it must contain
            ``class_accuracies.pkl``.

    Returns:
        The keys of the entry stored under ``-1`` in the unpickled object, in
        their stored order. The aggregation loop in this notebook treats that
        ``-1`` entry as the clean (unpatched) accuracies.
    """
    # Fix: the original body read the global `ae_path` instead of `first_path`.
    # `ae_path` is only bound by the loop further down, so a fresh kernel raised
    # NameError, and a re-run silently used the last AE of the previous run.
    class_accuracies_file = f"{first_path}/class_accuracies.pkl"
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted,
    # locally produced files.
    with open(class_accuracies_file, "rb") as f:
        class_accuracies = pickle.load(f)
    return list(class_accuracies[-1].keys())

classes = get_classes(ae_paths[0])

# For every autoencoder: load its per-class accuracies and its eval metrics, then
# record L0, fraction recovered, and the mean (clean - patched) accuracy over `classes`.
for ae_path in ae_paths:
    with open(f"{ae_path}/class_accuracies.pkl", "rb") as f:
        class_accuracies = pickle.load(f)
    with open(f"{ae_path}/eval_results.json", "r") as f:
        eval_results = json.load(f)

    l0, frac_recovered = eval_results["l0"], eval_results["frac_recovered"]

    # class_accuracies[-1] is read as the clean accuracies and
    # class_accuracies[c][threshold] as the patched ones; element [0] of each
    # per-class entry is the value that gets compared.
    diffs = [
        class_accuracies[-1][class_id][0]
        - class_accuracies[class_id][threshold][class_id][0]
        for class_id in classes
    ]

    results[ae_path] = dict(
        l0=l0,
        frac_recovered=frac_recovered,
        average_diff=sum(diffs) / len(diffs),
    )

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np

def plot_3var_graph(
    results: dict,
    xlims: tuple[float, float] = (0.0, 400.0),
    y_lims: tuple[float, float] = (0.985, 1.001),
    colorbar_label: str = "Average Diff",
    output_filename: "str | None" = None,
):
    """Scatter-plot L0 against fraction recovered, coloured by ``average_diff``.

    Args:
        results: Mapping whose values are dicts with the keys ``"l0"``,
            ``"frac_recovered"`` and ``"average_diff"``. The mapping's own keys
            are not used.
        xlims: ``(min, max)`` limits of the L0 (x) axis.
        y_lims: ``(min, max)`` limits of the fraction-recovered (y) axis.
        colorbar_label: Label for the colour scale (the ``"average_diff"``
            values). It is also used in the title.
        output_filename: If given, the figure is additionally saved to this path.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        # Without this, min()/max() below fail with an opaque "empty sequence" error.
        raise ValueError("plot_3var_graph: `results` is empty; nothing to plot.")

    points = list(results.values())
    l0_values = [p["l0"] for p in points]
    frac_recovered_values = [p["frac_recovered"] for p in points]
    average_diff_values = [p["average_diff"] for p in points]

    fig, ax = plt.subplots(figsize=(10, 6))

    # Colour scale spans exactly the observed range of the colour metric.
    norm = Normalize(vmin=min(average_diff_values), vmax=max(average_diff_values))

    scatter = ax.scatter(
        l0_values,
        frac_recovered_values,
        c=average_diff_values,
        cmap="viridis",
        s=100,
        norm=norm,
        edgecolor="black",
    )
    fig.colorbar(scatter, ax=ax, label=colorbar_label)

    ax.set_xlabel("L0")
    ax.set_ylabel("Fraction Recovered")
    # Title follows colorbar_label (previously hard-coded to "Average Diff").
    ax.set_title(f"L0 vs Fraction Recovered (color: {colorbar_label})")
    ax.set_xlim(*xlims)
    ax.set_ylim(*y_lims)

    fig.tight_layout()

    # Save via the figure object so the right figure is written even if
    # another figure became current in the meantime.
    if output_filename:
        fig.savefig(output_filename, bbox_inches="tight")
    plt.show()

# Plot all evaluated autoencoders on axes wider than the defaults (L0 up to 600, full 0-1 recovery range).
plot_3var_graph(results, xlims=(0, 600), y_lims=(0, 1))
