In [None]:
import numpy as np
import torch
import plotly_express as px

from transformer_lens import HookedTransformer

# Model Loading
from sae_lens import SAE
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# Virtual Weight / Feature Statistics Functions
from sae_lens.analysis.feature_statistics import (
    get_all_stats_dfs,
    get_W_U_W_dec_stats_df,
)

# Enrichment Analysis Functions
from sae_lens.analysis.tsea import (
    get_enrichment_df,
    manhattan_plot_enrichment_scores,
    plot_top_k_feature_projections_by_token_and_category,
)
from sae_lens.analysis.tsea import (
    get_baby_name_sets,
    get_letter_gene_sets,
    generate_pos_sets,
    get_test_gene_sets,
    get_gene_set_from_regex,
)

In [None]:
model = HookedTransformer.from_pretrained("gpt2-small")

# NOTE: this is the older way of loading an SAE. It is kept because feature
# sparsity is not yet loadable through the new interface; once it is, this
# loading path can be removed.
# Both dictionaries are keyed by hook-point name ("blocks.<n>.hook_resid_pre").
gpt2_saes = {}
gpt2_sparsities = {}

for layer in range(12):
    hook_name = f"blocks.{layer}.hook_resid_pre"
    print(f"Downloading from layer {layer}")
    # `from_pretrained` is unpacked into three values here: the SAE itself,
    # its original config dict and the per-feature sparsity.
    sae, original_cfg_dict, sparsity = SAE.from_pretrained(
        release="gpt2-small-res-jb", sae_id=hook_name, device="cpu"
    )
    gpt2_saes[hook_name], gpt2_sparsities[hook_name] = sae, sparsity



In [None]:
layer = 8
hook_name = f"blocks.{layer}.hook_resid_pre"

# Look up the SAE for the chosen layer, its decoder weights and the matching
# feature sparsities; everything is moved to the CPU for the analysis below.
sparse_autoencoder = gpt2_saes[hook_name]
W_dec = sparse_autoencoder.W_dec.detach().cpu()
log_feature_sparsity = gpt2_sparsities[hook_name].cpu()

# Statistics of the logit weight distributions (no cosine similarity).
W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(
    W_dec,
    model,
    cosine_sim=False,
)

# Feature sparsity is added as an extra column since it is often interesting.
W_U_stats_df_dec["sparsity"] = log_feature_sparsity
display(W_U_stats_df_dec)


In [None]:
from datasets import load_dataset 

# Load the WikiText-2 (raw) dataset.
# The OpenWebText loader below is kept, commented out, as an alternative source:
# proportion = 0.000001
# dataset = load_dataset("openwebtext", trust_remote_code=True)#, split=f"train[:{int(proportion * 100)}%]")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Helper that pulls the raw text field out of the dataset's training split
def extract_texts(dataset, num_samples=None):
    """Return the "text" field of the samples in ``dataset["train"]``.

    Args:
        dataset: Mapping with a "train" entry that iterates over samples,
            each of which is indexable by the key "text".
        num_samples: Optional cap on how many samples to read. A falsy value
            (``None`` or ``0``) means "read everything".

    Returns:
        A list of the collected text values, in iteration order.
    """
    collected = []
    for seen, record in enumerate(dataset["train"], start=1):
        collected.append(record["text"])
        # Stop as soon as the requested number of samples has been gathered.
        if num_samples and seen >= num_samples:
            break
    return collected

# Read every training text (pass num_samples to extract_texts to read fewer)
# and keep only the non-empty strings.
text_data = [line for line in extract_texts(dataset) if line != '']


In [None]:

# Sample text data
# text_data = [
#     "The quick brown fox jumps over the lazy dog.",
#     "GPT-2 is a transformer-based model developed by OpenAI.",
#     "Sparse autoencoders can be used for feature extraction."
# ]

# Tokenize and encode the whole corpus in a single call.
# padding=True pads every text to the longest one so the result is one
# rectangular tensor. truncation=True matches the batched tokenisation used
# when the latents are written to the HDF5 file: it caps each text at the
# tokenizer's maximum length, so a later forward pass on `inputs` is not
# handed a sequence longer than the tokenizer allows.
tokenizer = model.tokenizer
inputs = tokenizer(text_data, return_tensors="pt", padding=True, truncation=True)
print('tokenisation complete')


In [None]:
from tqdm import tqdm  # For progress tracking
import h5py

# Purpose of this cell: for every text in the corpus, take the residual-stream
# input of each of the 12 blocks at the text's last real token, encode it with
# that layer's SAE, and append the (rounded) latents to "latent_variables.h5".

# num_samples = 100  # Adjust this number as needed
# text_data = extract_texts(dataset, num_samples=num_samples)
# The corpus is rebuilt here so the cell does not rely on an earlier cell's
# `text_data`; empty strings are dropped.
text_data = extract_texts(dataset)
text_data = [i for i in text_data if i != '']

tokenizer = model.tokenizer

# Define max length for truncation
# max_length = 512

# Outputs recorded by the forward hooks, in the order the hooks fire.
hidden_states = []


def hook_fn(module, input, output):
    """Forward hook: record the hooked module's output."""
    hidden_states.append(output)


hooks = []
try:
    # Attach one hook per layer. Registration happens inside the `try` so the
    # `finally` below removes the hooks even if something fails part-way;
    # otherwise re-running the cell would stack duplicate hooks.
    for layer in range(12):
        hook = model.blocks[layer].hook_resid_pre.register_forward_hook(hook_fn)
        hooks.append(hook)

    # Open an HDF5 file to store the latent variables
    with h5py.File("latent_variables.h5", "w") as hdf5_file:

        # One resizable dataset per layer. The width is read from the SAE's
        # decoder matrix (number of rows of W_dec) instead of being hard-coded.
        for layer in range(12):
            latent_size = gpt2_saes[f"blocks.{layer}.hook_resid_pre"].W_dec.shape[0]
            hdf5_file.create_dataset(
                f"layer_{layer}",
                (0, latent_size),
                maxshape=(None, latent_size),
                chunks=True,
            )

        # Process the corpus in smaller batches
        batch_size = 8  # Adjust batch size based on your GPU/CPU memory capacity
        for i in tqdm(range(0, len(text_data), batch_size)):
            batch_texts = text_data[i:i + batch_size]
            inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

            # With padding=True, shorter texts are padded up to the longest one
            # in the batch, so a fixed position such as -1 can point at padding
            # rather than at the text's last token. Use the attention mask to
            # find, per row, the highest position holding a real token; this
            # works for either padding side.
            # Assumes the tokenizer returns an `attention_mask` (the Hugging
            # Face default) - TODO confirm for this tokenizer.
            attention_mask = inputs["attention_mask"]
            positions = torch.arange(attention_mask.shape[1])
            last_token_idx = (attention_mask * positions).argmax(dim=1)
            batch_idx = torch.arange(attention_mask.shape[0])

            # Start from an empty list so hidden_states[layer] lines up with
            # the hook registered for that layer.
            hidden_states.clear()

            # Run the input through GPT-2
            with torch.no_grad():
                _ = model(inputs['input_ids'])

            # Process hidden states
            for layer in range(12):
                sae = gpt2_saes[f"blocks.{layer}.hook_resid_pre"]
                hidden_state = hidden_states[layer].detach().cpu()

                # Representation at each text's last real token
                final_token_representation = hidden_state[batch_idx, last_token_idx, :]

                # Encode with the autoencoder; values are rounded to 2 decimals
                # before being stored.
                latent_vars = np.round(sae.encode(final_token_representation).detach().cpu().numpy(), 2)

                # Append latent variables to the HDF5 file
                layer_dataset = hdf5_file[f"layer_{layer}"]
                current_size = layer_dataset.shape[0]
                new_size = current_size + latent_vars.shape[0]
                layer_dataset.resize(new_size, axis=0)
                layer_dataset[current_size:new_size, :] = latent_vars

            # Clear hidden states
            hidden_states.clear()
            # Flush data to disk periodically
            # if i % 10 == 0:  # Adjust the frequency of flushing as needed
                # hdf5_file.flush()
finally:
    # Remove hooks
    for hook in hooks:
        hook.remove()

print("Latent variables saved successfully.")

In [None]:
# Collect the output of every block's `hook_resid_pre` for one forward pass.
# NOTE(review): `inputs` is whatever the most recently executed tokenisation
# cell assigned (the whole-corpus tokenisation, or the last batch of the HDF5
# cell) - confirm which one is intended before relying on the result.
hidden_states = []


def hook_fn(module, input, output):
    """Forward hook: record the hooked module's output."""
    hidden_states.append(output)


hooks = []
try:
    # Attach hooks to each layer. Registration is inside the `try` so the
    # `finally` below removes the hooks even when the forward pass raises;
    # otherwise re-running the cell would stack duplicate hooks.
    for layer in range(12):
        hook = model.blocks[layer].hook_resid_pre.register_forward_hook(hook_fn)
        hooks.append(hook)

    # Run the input through GPT-2
    with torch.no_grad():
        _ = model(inputs['input_ids'])
finally:
    # Remove hooks
    for hook in hooks:
        hook.remove()


In [None]:
# Sanity check: number of texts, shape of the token-id tensor, and the full
# tokenizer output, one per line.
print(len(text_data), inputs['input_ids'].shape, inputs, sep="\n")

In [None]:
# Per-layer containers for the SAE latents, keyed "layer_<n>".
latent_variables = {f"layer_{layer}": [] for layer in range(12)}

# Phase 1: encode each layer's collected hidden state with that layer's SAE.
for layer in range(12):
    layer_key = f"layer_{layer}"
    sae = gpt2_saes[f"blocks.{layer}.hook_resid_pre"]
    hidden_state = hidden_states[layer].detach().cpu()
    latent_vars = sae.encode(hidden_state).detach().cpu().numpy()
    latent_variables[layer_key].append(latent_vars)

# Phase 2: collapse each layer's list of chunks into one numpy array.
for layer_key, chunks in latent_variables.items():
    latent_variables[layer_key] = np.concatenate(chunks, axis=0)

# Phase 3: write one .npy file per layer.
for layer in range(12):
    np.save(f"latent_variables_layer_{layer}.npy", latent_variables[f"layer_{layer}"])

print("Latent variables saved successfully.")