In [13]:
import sys

sys.path.append("..")
import os
import pickle

import torch
from dotenv import load_dotenv
from huggingface_hub import login
from nnsight import LanguageModel
from tqdm import tqdm

from utils_data import get_dataset_sizes, get_xy_traintest
from utils_training import find_best_reg

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()
from os import getenv

# Hugging Face access token; getenv returns None when HF_TOKEN is not defined.
HF_TOKEN = getenv("HF_TOKEN")

# This notebook tests the baseline classifier on 110_aimade_humangpt3

We find that while the SAE probes identify spurious punctuation features, our baseline
classifier also primarily activates on punctuation.

In [11]:
from datasets import load_dataset

# Text corpus the classifier is probed on: the "train" split of NeelNanda/pile-10k.
dataset = load_dataset("NeelNanda/pile-10k", split="train")
# DataFrame view of the corpus; get_texts() below reads its "text" column.
df = dataset.to_pandas()

Using the latest cached version of the dataset since NeelNanda/pile-10k couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/subhashk/.cache/huggingface/datasets/NeelNanda___pile-10k/default/0.0.0/127bfedcd5047750df5ccf3a12979a47bfa0bafa (last modified on Sun Feb 23 15:51:29 2025).


In [None]:
def find_classifier(layer=20, model_name="gemma-2-9b"):
    """Fit the baseline classifier for 110_aimade_humangpt3 and pickle it to disk.

    Args:
        layer: Layer index forwarded to ``get_xy_traintest``.
        model_name: Model name forwarded to ``get_xy_traintest``.

    Returns:
        The classifier object returned by ``find_best_reg``. It is also written to
        ``results/investigate/<dataset>/<dataset>_classifier.pkl``, which is the
        path ``load_classifier_weights`` reads from.
    """
    dataset = "110_aimade_humangpt3"
    dataset_sizes = get_dataset_sizes()
    size = dataset_sizes[dataset]
    # Cap the training set at 1024 examples; the "- 100" presumably reserves
    # examples for the test split -- TODO confirm against get_xy_traintest.
    num_train = min(size - 100, 1024)
    X_train, y_train, X_test, y_test = get_xy_traintest(
        num_train, dataset, layer, model_name=model_name
    )
    # With return_classifier=True, find_best_reg returns a pair whose second
    # element is the fitted classifier; the first element is not needed here.
    _, classifier = find_best_reg(
        X_train, y_train, X_test, y_test, return_classifier=True
    )
    # Create the dataset-specific directory the pickle is written into. Creating
    # only "results/investigate" is not enough and makes open() fail on a fresh
    # checkout.
    out_dir = f"results/investigate/{dataset}"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/{dataset}_classifier.pkl", "wb") as f:
        pickle.dump(classifier, f)
    return classifier


def load_classifier_weights(dataset="110_aimade_humangpt3"):
    """Unpickle the saved classifier for ``dataset`` and return ``coef_[0]`` as a tensor.

    The pickle is read from ``results/investigate/<dataset>/<dataset>_classifier.pkl``.
    NOTE: pickle.load executes arbitrary code; only load files you created yourself.
    """
    pickle_path = f"results/investigate/{dataset}/{dataset}_classifier.pkl"
    with open(pickle_path, "rb") as handle:
        probe = pickle.load(handle)
    first_row = probe.coef_[0]
    return torch.tensor(first_row)


# Read the saved classifier weights from disk. This needs the pickle written by
# find_classifier(), which this notebook defines but never calls.
classifier = load_classifier_weights()
# Bare expression: the notebook displays the weight tensor as the cell output.
classifier

In [None]:
# Authenticate with the Hugging Face Hub, then load the language model.
# Fail early with a readable message instead of the TypeError that
# `os.environ[...] = None` raises when HF_TOKEN is not configured.
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN is not set; define it in the environment or in a .env file "
        "before running this cell."
    )
# Also expose the token under HUGGINGFACE_TOKEN, as the original cell did.
os.environ["HUGGINGFACE_TOKEN"] = HF_TOKEN
login(token=os.environ["HUGGINGFACE_TOKEN"])

# Prefer the second GPU when there is one; otherwise fall back to the first GPU,
# then to the CPU. "cuda:1" is not a valid device on a single-GPU machine.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Other checkpoints mentioned in earlier revisions of this cell:
# 'EleutherAI/gpt-j-6B', 'meta-llama/Llama-3.1-8B', 'EleutherAI/pythia-6.9b'
full_model_name = "google/gemma-2-9b"
MODEL_NAME = full_model_name.split("/")[-1]
model = LanguageModel(
    full_model_name, device_map=device, torch_dtype=torch.bfloat16, dispatch=True
)
# Passed as `remote=` to model.trace() in get_dot_products().
remote = False
NLAYERS = model.config.num_hidden_layers

In [None]:
def get_texts():
    """Return the entries of the module-level ``df["text"]`` column as a list."""
    return [text for text in df["text"]]


def get_tokens():
    """Tokenize every text from get_texts() and cache the results on disk.

    Texts whose tokenization is longer than 1024 ids are dropped; the number of
    dropped texts is printed. Token ids and their string forms are saved with
    torch.save under ``results/investigate/``.

    Returns:
        (toks, tok_strings): two parallel lists, one entry per kept text.
    """
    max_seq_len = 1024
    tokenizer = model.tokenizer
    toks, tok_strings = [], []
    skipped = 0
    for text in tqdm(get_texts()):
        q_toks = tokenizer(text)["input_ids"]
        if len(q_toks) <= max_seq_len:
            toks.append(q_toks)
            # String form of each id, kept aligned with `toks`.
            tok_strings.append(tokenizer.convert_ids_to_tokens(q_toks))
        else:
            skipped += 1

    # Make sure the output directory exists before writing the two cache files.
    os.makedirs("results/investigate", exist_ok=True)
    torch.save(toks, "results/investigate/110_aimade_humangpt3_tokens.pt")
    torch.save(tok_strings, "results/investigate/110_aimade_humangpt3_token_strings.pt")
    print(skipped)
    return toks, tok_strings


# Tokenize the corpus and write the token caches that load_tokens() and
# load_token_strings() read. As a bare last expression, the returned
# (toks, tok_strings) tuple is also echoed as the cell output.
get_tokens()

In [None]:
def load_tokens():
    """Read back the token-id lists that get_tokens() saved to disk."""
    tokens_path = "results/investigate/110_aimade_humangpt3_tokens.pt"
    # weights_only=False enables full unpickling; only load files you produced.
    return torch.load(tokens_path, weights_only=False)


def load_token_strings():
    """Read back the token-string lists that get_tokens() saved to disk."""
    strings_path = "results/investigate/110_aimade_humangpt3_token_strings.pt"
    # weights_only=False enables full unpickling; only load files you produced.
    return torch.load(strings_path, weights_only=False)

In [None]:
def get_dot_products(layer=20, save_every=100):
    """Project per-token layer outputs onto the classifier and group them by token.

    For every cached token sequence the model is traced once, the saved
    ``model.model.layers[layer].output[0][0]`` is multiplied with the classifier
    weight vector, and each resulting scalar is appended to the list kept for
    the corresponding token string.

    Args:
        layer: Index of the decoder layer whose output is read (default 20,
            the value that was previously hard-coded).
        save_every: Write a checkpoint of the partial results every this many
            sequences; a falsy value disables intermediate checkpoints.

    Returns:
        dict mapping token string -> list of float dot products, one per
        occurrence. The same dict is saved to
        ``results/investigate/110_aimade_humangpt3/token_values.pt``.
    """
    tokens = load_tokens()
    token_strs = load_token_strings()
    classifier = load_classifier_weights().float()

    out_dir = "results/investigate/110_aimade_humangpt3"
    out_path = f"{out_dir}/token_values.pt"
    os.makedirs(out_dir, exist_ok=True)

    token_val = {}
    with torch.no_grad():
        bar = tqdm(zip(tokens, token_strs), total=len(tokens))
        for i, (token, token_str) in enumerate(bar, start=1):
            with model.trace(validate=False, remote=remote) as tracer:
                with tracer.invoke(token, scan=False):
                    # assumed to be the (seq_len, hidden_dim) hidden states of the
                    # single sequence in the batch -- TODO confirm
                    hs = model.model.layers[layer].output[0][0].save()
            hidden = hs.float()
            # Follow whatever device the activations live on instead of a
            # hard-coded "cuda:1", which fails without a second GPU. The move
            # happens once; later iterations reuse the moved tensor.
            if classifier.device != hidden.device:
                classifier = classifier.to(hidden.device)
            # One device-to-host transfer per sequence rather than one per token.
            dot_products = torch.matmul(hidden, classifier).cpu().tolist()
            for tok, value in zip(token_str, dot_products):
                token_val.setdefault(tok, []).append(value)

            # Progress-bar hint only: token with the highest mean dot product so far.
            max_tok = max(
                token_val,
                key=lambda t: sum(token_val[t]) / len(token_val[t]),
                default=None,
            )
            bar.set_postfix({"max_tok": max_tok})

            # Periodic checkpoint so a crash does not lose all progress.
            if save_every and i % save_every == 0:
                torch.save(token_val, out_path)

    # Save the final token values.
    torch.save(token_val, out_path)
    return token_val


# Run the full pass over the cached tokens and write token_values.pt. As a bare
# last expression, the returned token -> values dict is also echoed as cell output.
get_dot_products()

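# Here we print the data for Table 8: the 10 tokens with the highest mean baseline-classifier activation, among tokens that occur at least 10 times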

In [14]:
# Token string -> list of per-occurrence classifier dot products.
# NOTE(review): this reads from "../results/..." while the cells above write to
# "results/..." -- confirm the intended working directory.
t = torch.load(
    "../results/investigate/110_aimade_humangpt3/token_values.pt", weights_only=False
)

# Number of distinct token strings, and total number of token occurrences.
total_tokens = len(t)
summ = sum(len(values) for values in t.values())
print("Total tokens processed", summ)

# Mean activation per token, restricted to tokens seen at least 10 times,
# ordered from highest to lowest mean.
means = {
    k: torch.mean(torch.tensor(v)).item()
    for k, v in t.items()
    if len(v) >= 10
}
sorted_keys = sorted(means.items(), key=lambda pair: pair[1], reverse=True)

print(f"Total unique tokens: {total_tokens}\n")

# Table header followed by the ten highest-ranked tokens.
print(f"{'Token':<20} {'Mean Activation':>15} {'Occurrences':>12}")
print("-" * 47)
for key, mean in sorted_keys[:10]:
    print(f"{key:<20} {mean:>15.4f} {len(t[key]):>12}")

Total tokens processed 2833414
Total unique tokens: 93100

Token                Mean Activation  Occurrences
-----------------------------------------------
<bos>                         6.8863         7625
!).                           6.2529           10
Q                             6.2271         1436
”.                            6.0338          144
.”                            5.9111          975
.).                           5.7334           24
﻿                             5.5035           17
."                            5.4455         1057
".                            5.4132          319
}$.                           5.3990           24
