# Probes across all layers

We score SAEs by their ability to "recover" supervised concepts from the residual stream. But to what degree are those concepts detectable by linear probes in the first place?

In [None]:
# IPython autoreload: re-import edited modules (e.g. experiments.*) before each
# cell runs, so changes to the project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import torch
import os
import pickle

from nnsight import LanguageModel

import experiments.utils as utils
from experiments.probe_training import train_probes
from experiments.pipeline_config import PipelineConfig
from experiments.dataset_info import *

In [None]:
# Pipeline configuration for probe training: which model to load, how to size
# and train the probes, and which dataset classes to probe for.
cfg = PipelineConfig()

# Model to probe (the commented-out Pythia checkpoint is a smaller alternative).
# llm_model_name = "EleutherAI/pythia-70m-deduped"
llm_model_name = "google/gemma-2-2b"

# Alternative dataset setup:
#   dataset_name = 'bias_in_bios'
#   chosen_class_indices = [0, 1,]
for setting_name, setting_value in (
    ('device', 'cuda'),
    ('model_dtype', torch.bfloat16),
    ('spurious_corr', False),
    ('probe_train_set_size', 1000),
    ('probe_test_set_size', 250),
    ('probe_context_length', 128),
    ('probe_batch_size', 250),
    ('probe_epochs', 10),
    ('probes_dir', 'probes'),
    ('dataset_name', 'amazon_reviews_1and5'),
):
    setattr(cfg, setting_name, setting_value)
# Keep the notebook namespace free of loop temporaries.
del setting_name, setting_value

# Amazon product categories to train probes for, translated to the integer
# class ids listed in amazon_category_dict.
chosen_classes = [
    "Beauty_and_Personal_Care",
    "Books",
    "Automotive",
    "Musical_Instruments",
    "Software",
    "Sports_and_Outdoors",
]
cfg.chosen_class_indices = [amazon_category_dict[category] for category in chosen_classes]

In [None]:
# Load the language model through nnsight and derive the names/configs that the
# later cells use (short model name for paths and titles, eval config for batch sizes).
# TODO: I think there may be a scoping issue with model and get_acts(), but we currently aren't using get_acts()
model = LanguageModel(llm_model_name, device_map=cfg.device, dispatch=True, torch_dtype=cfg.model_dtype)
only_model_name = llm_model_name.split("/")[-1]  # e.g. "gemma-2-2b"
model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)

# Count the transformer blocks without assuming one architecture:
# GPT-NeoX/Pythia models keep them under `gpt_neox.layers`, while the Gemma
# model used here keeps them under `model.layers`.
if hasattr(model, "gpt_neox"):
    num_layers = len(model.gpt_neox.layers)
else:
    num_layers = len(model.model.layers)
# num_layers = len(model.gpt_neox.layers) # TODO Make model agnostic

In [None]:
# Train one set of class probes per selected residual-stream layer.
# test_accs_all_layers_scr collects whatever train_probes returns for each
# layer, in the same order as `probe_layers`.

# Tag embedded in every output filename. Defined before the loop so the save
# step below still works if `probe_layers` is empty.
date = "0909"

# probe_layers = range(num_layers)  # full sweep over all layers
probe_layers = [19]

test_accs_all_layers_scr = []
for layer in probe_layers:
    print(f"Training probes for layer {layer}")

    probe_path = f"{cfg.probes_dir}/{only_model_name}/{cfg.dataset_name}_scr{cfg.spurious_corr}_probes_layer{layer}_date{date}.pkl"

    # TODO adapt train_probes to reuse tokenized datasets
    test_accs_all_layers_scr.append(
        train_probes(
            cfg.probe_train_set_size,
            cfg.probe_test_set_size,
            model,
            context_length=cfg.probe_context_length,
            probe_batch_size=cfg.probe_batch_size,
            llm_batch_size=model_eval_config.llm_batch_size,
            device=cfg.device,
            probe_output_filename=probe_path,
            dataset_name=cfg.dataset_name,
            probe_dir=cfg.probes_dir,
            llm_model_name=llm_model_name,
            epochs=cfg.probe_epochs,
            model_dtype=cfg.model_dtype,
            spurious_correlation_removal=cfg.spurious_corr,
            # column1_vals=cfg.column1_vals,
            # column2_vals=cfg.column2_vals,
            probe_layer=layer,
            chosen_class_indices=cfg.chosen_class_indices,
        )
    )

# Save the per-layer test accuracies as a pickle.
# NOTE(review): unlike the probe files, this path contains neither the model
# name nor the layer list, so runs for different models/layers overwrite it.
test_accs_all_layers_path_scr = f"{cfg.probes_dir}/{cfg.dataset_name}_scr{cfg.spurious_corr}_test_accs_date{date}.pkl"
os.makedirs(cfg.probes_dir, exist_ok=True)  # don't fail if nothing created the dir yet
with open(test_accs_all_layers_path_scr, "wb") as f:
    pickle.dump(test_accs_all_layers_scr, f)

In [None]:
# Regroup the per-layer results into one accuracy curve per class:
# test_accs_per_class[c][i] is class c's value at the i-th trained layer.
# Only element [0] of each per-class entry from train_probes is kept
# (presumably the test accuracy -- TODO confirm against train_probes).
all_classes = list(test_accs_all_layers_scr[0].keys()) if test_accs_all_layers_scr else []

test_accs_per_class = {}
for layer_idx, layer_accs in enumerate(test_accs_all_layers_scr):
    # Every layer must report the same classes in the same order, otherwise the
    # per-class curves would be misaligned across layers.
    if list(layer_accs.keys()) != all_classes:
        raise ValueError(
            f"Result {layer_idx} has classes {list(layer_accs.keys())}, expected {all_classes}"
        )
    for c, accs in layer_accs.items():
        test_accs_per_class.setdefault(c, []).append(accs[0])
test_accs_per_class

In [None]:
import matplotlib.pyplot as plt

# Static (matplotlib) view of per-class probe test accuracy versus layer.
# Layer numbers the accuracies belong to; keep in sync with the training cell.
# A length mismatch makes plt.plot raise instead of silently mislabelling x.
plotted_layers = [19]
for c, accs in test_accs_per_class.items():
    # Markers keep single-layer runs visible: a lone point draws no line segment.
    plt.plot(plotted_layers, accs, marker='o', label=f'{c}. {full_amazon_int_to_str[c]}')
# Axis labels and legend only need to be set once, after all series are drawn.
plt.xlabel("Layer")
plt.ylabel("Test Accuracy")
plt.legend(title='Class')
plt.title(f"Test accuracy of class probes across residual stream\nfor {only_model_name}")

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_class_accuracies(test_accs_per_class, full_amazon_int_to_str, only_model_name, threshold=None,
                          layers=None, dataset_name=None):
    """Plot per-class probe test accuracy against layer as an interactive plotly figure.

    Args:
        test_accs_per_class: mapping class id -> list of accuracies, one per trained layer.
        full_amazon_int_to_str: mapping class id -> readable class name, used in the legend.
        only_model_name: short model name shown in the title.
        threshold: if given, only classes whose best accuracy exceeds it are drawn.
        layers: layer numbers the accuracies correspond to (x-axis values). When
            omitted, list positions 0..n-1 are used.
        dataset_name: dataset shown in the title; defaults to the global
            ``cfg.dataset_name``.

    Raises:
        ValueError: if ``layers`` is given and its length differs from a class's
            accuracy list (the points would otherwise land on the wrong layers).
    """
    if dataset_name is None:
        dataset_name = cfg.dataset_name

    fig = go.Figure()

    # One trace per class, optionally filtered by best accuracy.
    for c, accs in test_accs_per_class.items():
        if not (threshold is None or max(accs) > threshold):
            continue
        x = list(range(len(accs))) if layers is None else list(layers)
        if len(x) != len(accs):
            raise ValueError(
                f"Class {c}: got {len(accs)} accuracies but {len(x)} layer numbers {x}"
            )
        fig.add_trace(go.Scatter(
            x=x,
            y=accs,
            mode='lines+markers',  # markers keep single-layer curves visible
            name=f'{c}. {full_amazon_int_to_str[c]}'
        ))

    fig.update_layout(
        title=f"Probe test acc, {only_model_name}," +
              (f" (max accuracy > {threshold})," if threshold is not None else "") +
              f" data: {dataset_name}",
        xaxis_title="Layer",
        yaxis_title="Test Accuracy",
        legend_title="Class",
        hovermode="x unified"
    )

    # Integer ticks: layers are discrete.
    fig.update_xaxes(tick0=0, dtick=1)

    fig.show()


# only display class if max acc above thresh
display_thresh = 0.8
plot_class_accuracies(
    test_accs_per_class,
    full_amazon_int_to_str,
    only_model_name,
    threshold=display_thresh,
    layers=[19],  # layers the accuracies were collected for; keep in sync with the training cell
)