In [1]:
%load_ext autoreload
%autoreload 2

from task_evaluation import TaskEvaluation
from circuit_discovery import CircuitDiscovery

import torch
import time
import plotly.express as px
import matplotlib.pyplot as plt

from task_evaluation import TaskEvaluation
from data.ioi_dataset import gen_templated_prompts
from data.greater_than_dataset import generate_greater_than_dataset
from circuit_discovery import CircuitDiscovery, only_feature
from circuit_lens import CircuitComponent
from plotly_utils import *
from data.ioi_dataset import IOI_GROUND_TRUTH_HEADS
from data.greater_than_dataset import GT_GROUND_TRUTH_HEADS
from memory import get_gpu_memory
from sklearn import metrics
from tqdm import trange

from utils import get_attn_head_roc



Loaded pretrained model gpt2-small into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [3]:
torch.set_grad_enabled(False)

dataset_prompts = gen_templated_prompts(template_idex=1, N=500)

def component_filter(component: str):
    """Return True if ``component`` is one of the component kinds kept during circuit discovery.

    Kinds currently left out of the allow-list: UNEMBED_AT_TOKEN, BIAS_O, Z_SAE_BIAS,
    TRANSCODER_ERROR, TRANSCODER_BIAS.
    """
    # Tuple membership uses equality, exactly like the list it replaces (a set would rely on hashing).
    allowed_components = (
        CircuitComponent.Z_FEATURE,
        CircuitComponent.MLP_FEATURE,
        CircuitComponent.ATTN_HEAD,
        CircuitComponent.UNEMBED,
        CircuitComponent.EMBED,
        CircuitComponent.POS_EMBED,
        CircuitComponent.Z_SAE_ERROR,
    )
    return component in allowed_components


pass_based = True

passes = 5
node_contributors = 1
first_pass_minimal = True

sub_passes = 3
do_sub_pass = False
layer_thres = 9
minimal = True


num_greedy_passes = 20
k = 1
N = 30

thres = 4

def strategy(cd: CircuitDiscovery):
    """Circuit-discovery strategy handed to ``TaskEvaluation``.

    Behaviour is driven by the notebook-level settings defined above:
    - ``pass_based`` False: ``num_greedy_passes`` rounds of ``greedily_add_top_contributors``.
    - ``pass_based`` True: ``passes`` greedy passes, each optionally followed (``do_sub_pass``)
      by ``sub_passes`` passes against all existing nodes.
    """
    if not pass_based:
        for _round in range(num_greedy_passes):
            cd.greedily_add_top_contributors(k=k, reciever_threshold=thres)
        return

    for _pass in range(passes):
        cd.add_greedy_pass(contributors_per_node=node_contributors, minimal=first_pass_minimal)
        if not do_sub_pass:
            continue
        for _sub_pass in range(sub_passes):
            cd.add_greedy_pass_against_all_existing_nodes(
                contributors_per_node=node_contributors,
                skip_z_features=True,
                layer_threshold=layer_thres,
                minimal=minimal,
            )



task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)

cd = task_eval.get_circuit_discovery_for_prompt(20)
# f = task_eval.get_features_at_heads_over_dataset(N=30)
N = 20

features_for_heads = task_eval.get_features_at_heads_over_dataset(N=N, use_set=False)
features_for_mlps = task_eval.get_features_at_mlps_over_dataset(N=N, use_set=False)
mlp_freqs = task_eval.get_mlp_freqs_over_dataset(N=N, return_freqs=True, visualize=False)
attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True, visualize=False)

Loaded pretrained model gpt2-small into HookedTransformer

Loading SAEs...


100%|██████████| 12/12 [00:08<00:00,  1.43it/s]



Loading Transcoders...


 50%|█████     | 6/12 [00:02<00:02,  2.16it/s]


KeyboardInterrupt: 

In [None]:
features_for_heads = task_eval.get_features_at_heads_over_dataset(N=N, use_set=False)
features_for_mlps = task_eval.get_features_at_mlps_over_dataset(N=N, use_set=False)
mlp_freqs = task_eval.get_mlp_freqs_over_dataset(N=N, return_freqs=True, visualize=False)
attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True, visualize=False)


In [None]:
mlp_freqs = task_eval.get_mlp_freqs_over_dataset(N=N, return_freqs=True)

In [None]:
ground_truth = IOI_GROUND_TRUTH_HEADS

attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=N, subtract_counter_factuals=False, return_freqs=True)
# Task label must match the ground truth being scored (IOI here, as in the later ROC cells), not "GT" (greater-than)
score, _, _, _ = get_attn_head_roc(ground_truth, attn_freqs.flatten().softmax(dim=-1), "IOI", visualize=True, additional_title="(No Counterfactuals)")

In [None]:
import numpy as np

class CircuitPrediction:
    """Per-component summary of how often each component appears in the discovered circuits.

    Components are the 12 attention heads plus the MLP of each of the 12 layers (156 in total),
    ordered layer by layer as ``L{layer}_H0 .. L{layer}_H11, MLP{layer}``.

    Args:
        attn_freqs: per-layer iterable of per-head frequencies; each entry must support ``.item()``
            (e.g. a (12, 12) tensor).
        mlp_freqs: per-layer iterable of MLP frequencies supporting ``.item()``.
        features_for_heads: indexed ``[layer][head]``; each entry is an iterable of feature ids.
        features_for_mlps: indexed ``[layer]``; each entry is an iterable of feature ids.
    """

    def __init__(self, attn_freqs, mlp_freqs, features_for_heads, features_for_mlps):
        self.attn_freqs = attn_freqs
        self.mlp_freqs = mlp_freqs
        self.features_for_heads = features_for_heads
        self.features_for_mlps = features_for_mlps

        # Labels first: the hypergraph is keyed by them.
        self.component_labels = self.get_component_labels()
        self.circuit_hypergraph = self.create_circuit_hypergraph()

    def create_circuit_hypergraph(self):
        """Build ``{label: {"freq": <summed frequency>, "features": <concatenated feature ids>}}``."""
        hypergraph = {}
        for label in self.component_labels:
            hypergraph[label] = {"freq": 0, "features": []}

        for layer_idx, head_freqs in enumerate(self.attn_freqs):
            for head_idx, head_freq in enumerate(head_freqs):
                node = hypergraph[f"L{layer_idx}_H{head_idx}"]
                node["freq"] += head_freq.item()
                node["features"].extend(self.features_for_heads[layer_idx][head_idx])

        for layer_idx, layer_freq in enumerate(self.mlp_freqs):
            node = hypergraph[f"MLP{layer_idx}"]
            node["freq"] += layer_freq.item()
            node["features"].extend(self.features_for_mlps[layer_idx])

        return hypergraph

    def get_component_labels(self):
        """Return the component labels: for each layer, its 12 heads followed by its MLP."""
        labels = []
        for layer_idx in range(12):
            labels.extend(f"L{layer_idx}_H{head_idx}" for head_idx in range(12))
            labels.append(f"MLP{layer_idx}")
        return labels

    def get_all_features_from_attn_layer(self, layer: int):
        """Concatenate the feature lists of all 12 heads in ``layer`` (duplicates are kept)."""
        return [
            feature
            for head_idx in range(12)
            for feature in self.features_for_heads[layer][head_idx]
        ]

    def get_circuit_at_threshold(self, threshold: float, visualize: bool = False):
        """Return a float 0/1 vector over ``component_labels``: 1 where the frequency exceeds ``threshold``."""
        circuit = np.array(
            [float(self.circuit_hypergraph[label]["freq"] > threshold) for label in self.component_labels],
            dtype=float,
        )
        if visualize:
            self.visualize_circuit(circuit)
        return circuit

    @staticmethod
    def _to_layer_grid(values):
        """Lay out per-component values (in label order) as a (12 layers, 12 heads + MLP) array."""
        grid = np.zeros((12, 13))
        for flat_idx, value in enumerate(values):
            layer_idx, column = divmod(flat_idx, 13)
            grid[layer_idx, column] = value
        return grid

    @staticmethod
    def _show_layer_grid(grid, title, width):
        """Render a (12, 13) component grid as a plotly heatmap."""
        column_labels = [f"A{i}" for i in range(1, 13)] + ["MLP"]
        fig = px.imshow(grid, labels=dict(x="Attention Head", y="Layer"), width=width,
                        title=title, x=column_labels, y=list(range(12)),
                        color_continuous_scale="blues")
        fig.show()

    def visualize_circuit(self, circuit: torch.Tensor, additional_title=""):
        """Plot a flat per-component circuit vector as a layer x component heatmap."""
        self._show_layer_grid(self._to_layer_grid(circuit), additional_title, 500)

    def component_frequency_array(self, visualize: bool = False):
        """Return (and optionally plot) the (12, 13) grid of component frequencies."""
        frequencies = self._to_layer_grid(
            self.circuit_hypergraph[label]["freq"] for label in self.component_labels
        )
        if visualize:
            self._show_layer_grid(frequencies, "Frequency of components", 450)
        return frequencies

    def unique_feature_array(self, visualize: bool = False):
        """Return (and optionally plot) the (12, 13) grid of distinct-feature counts per component."""
        unique_counts = self._to_layer_grid(
            len(set(self.circuit_hypergraph[label]["features"])) for label in self.component_labels
        )
        if visualize:
            self._show_layer_grid(unique_counts, "Unique features", 450)
        return unique_counts

In [None]:
cp = CircuitPrediction(attn_freqs, mlp_freqs, features_for_heads, features_for_mlps)
_ = cp.unique_feature_array(visualize=True)

In [None]:
unique_feature_array = cp.unique_feature_array(visualize=False)[:, :-1]
ground_truth = IOI_GROUND_TRUTH_HEADS.numpy()
unique_feature_array.shape, ground_truth.shape

In [None]:
from sklearn.metrics import roc_auc_score

# Flatten both
unique_feature_array = unique_feature_array.flatten()
ground_truth = ground_truth.flatten()

roc_auc_score(ground_truth, unique_feature_array)

In [None]:
# Actual frequency array
freq_array = cp.component_frequency_array(visualize=False)[:, :-1]
freq_array = torch.tensor(freq_array).flatten().softmax(dim=-1).numpy()
print(freq_array.shape)

roc_auc_score(ground_truth, freq_array)

In [None]:
prop_feature_array = unique_feature_array * freq_array

roc_auc_score(ground_truth, prop_feature_array)

In [None]:
score, _, _, _ = get_attn_head_roc(ground_truth, prop_feature_array, "IOI", visualize=True, additional_title="(No Counterfactuals)")

In [None]:
score, _, _, _ = get_attn_head_roc(ground_truth, unique_feature_array, "IOI", visualize=True, additional_title="(No Counterfactuals)")

In [None]:
score, _, _, _ = get_attn_head_roc(ground_truth, freq_array, "IOI", visualize=True, additional_title="(No Counterfactuals)")

## Autointerpretability

Given a component (e.g. L5_H5 or MLP7) and a list of features for that component, we want to use `CircuitLens` to find the max-activating examples for each feature. We will then feed each feature's max-activating examples to a language model and run them through the autointerpretability pipeline.

In [None]:
dataset_prompts[0]['text']

In [None]:
from typing import List, Dict
from circuit_lens import CircuitLens
from circuit_lens import ComponentLens
from torch import Tensor

prompt = dataset_prompts[0]['text'] + dataset_prompts[0]['correct']
circuit_lens = CircuitLens(prompt=prompt)

In [None]:
def get_mlp_feature_activation(circuit_lens, seq_index: int, layer: int, feature: int):
    """Return the activation of MLP feature ``feature`` in ``layer`` at position ``seq_index``.

    The lookup is done in the layer's slice of the active-feature arrays returned by
    ``circuit_lens.get_active_features(seq_index)``.

    Returns:
        The feature's value as a Python float, or 0.0 if it is not active there.
    """
    active = circuit_lens.get_active_features(seq_index)

    # This layer's MLP features occupy a contiguous slice of the flat feature/value arrays.
    lo = active.get_mlp_start_index(layer)
    hi = lo + active.keys[layer]['mlp']

    matches = (active.features[lo:hi] == feature).nonzero(as_tuple=True)[0]
    if len(matches) == 0:
        return 0.0
    return active.values[lo:hi][matches[0]].item()

# Example usage
layer = 7
seq_index = -2
feature = 1414
activation = get_mlp_feature_activation(circuit_lens, seq_index, layer, feature)
print(f"Activation of feature {feature} in MLP layer {layer}: {activation}")
# for feature in range(24576):
#     #feature = 100
#     activation = get_mlp_feature_activation(circuit_lens, seq_index, layer, feature)
#     if activation > 0:
#         print(f"Activation of feature {feature} in MLP layer {layer}: {activation}")

In [None]:
def get_attention_head_activation(circuit_lens, seq_index: int, layer: int, feature: int):
    """Return the activation of attention feature ``feature`` in ``layer`` at position ``seq_index``.

    The lookup is done in the layer's slice of the active-feature arrays returned by
    ``circuit_lens.get_active_features(seq_index)``.

    Returns:
        The feature's value as a Python float, or 0.0 if it is not active there.
    """
    active_features = circuit_lens.get_active_features(seq_index)

    # This layer's attention features occupy a contiguous slice of the flat feature/value arrays.
    start_index = active_features.get_attn_start_index(layer)
    end_index = start_index + active_features.keys[layer]['attn']

    attn_features = active_features.features[start_index:end_index]
    attn_values = active_features.values[start_index:end_index]

    # Return the value of the first matching entry, if any.
    indices = (attn_features == feature).nonzero(as_tuple=True)[0]
    if len(indices) > 0:
        return attn_values[indices[0]].item()

    return 0.0

# Example usage
layer = 5
seq_index = -1
feature = 44256
activation = get_attention_head_activation(circuit_lens, seq_index, layer, feature)
print(f"Activation of feature {feature} in attention layer {layer}: {activation}")
# for feature in range(24576):
#     activation = get_attention_head_activation(circuit_lens, seq_index, layer, feature)
#     if activation > 0:
#         print(f"Activation of feature {feature} in attention layer {layer}: {activation}")

In [None]:
cp = CircuitPrediction(attn_freqs, mlp_freqs, features_for_heads, features_for_mlps)
layer_5_features = [x for x in list(set(cp.get_all_features_from_attn_layer(5))) if x != -1]
layer_5_features

In [None]:
prompt = 'Then, Jose and Eric had a lot of fun at the store. Eric gave a computer to Jose'
circuit_lens = CircuitLens(prompt=prompt)

layer = 5
feats = [x for x in list(set(cp.get_all_features_from_attn_layer(5))) if x != -1]

feat_act_dict = {}
for feat in feats:
    act = get_attention_head_activation(circuit_lens, -1, layer, feat)
    feat_act_dict[feat] = act

In [None]:
feat_act_dict

So how are we going to do this? What do we need to be able to do?
* Specify a list of components, and a list of features for each component.
* Over a large corpus of tokens (probably about 1 million), cache the active features (this is good - Danny has already set up the infrastructure for this)!
* Because the active features are stored compactly, they don't take up much space, and they let us simply look up the max-activating examples for a specific feature in a specific layer.
* We still need a way to persist the cached active features so they can be reused.

In [None]:
from circuit_lens import ActiveFeatures

circuit_lens = CircuitLens(prompt=prompt)
active_features = circuit_lens.get_active_features(-2)
values = active_features.values
features = active_features.features
keys = active_features.keys
print(values.shape, features.shape)
print(keys)

test_af = ActiveFeatures(vectors=[], values=values, features=features, keys=keys)
test_af.get_attn_start_index(5)

In [None]:
# Pick the layer first so the start index and the feature count refer to the same layer
# (previously `layer` was read before being assigned, relying on a value left over from an earlier cell).
layer = 5

# Starting index of this layer's attention features in the flat active-feature arrays
start_index = active_features.get_attn_start_index(layer)

# Number of attention features active in this layer
num_attn_features = active_features.keys[layer]['attn']

# Slice out this layer's attention features and their values
attn_features = active_features.features[start_index:start_index + num_attn_features]
attn_values = active_features.values[start_index:start_index + num_attn_features]

attn_features.shape

In [None]:
head_lens = circuit_lens.get_head_seq_lens_for_z_feature(layer=layer, seq_index=-2, feature=attn_features[0], visualize=False, k=50)
head_lens

In [None]:
head_lens[0][0].run_data['head']

In [None]:
# Need to convert the head_lens into a list of tuples
# Each tuple is (head, contribution) to that specific feature
head_lens_list = [(head[0].run_data['head'], head[1]) for head in head_lens]
# Sum the contributions of each head
head_contributions = {}
for head, contribution in head_lens_list:
    if head in head_contributions:
        head_contributions[head] += contribution
    else:
        head_contributions[head] = contribution
head_contributions

In [None]:
print(keys)

In [None]:
prompt = 'Then, Jose and Eric had a lot of fun at the store. Eric gave a computer to Jose'
circuit_lens = CircuitLens(prompt=prompt)

In [None]:
from circuit_lens import ActiveFeatures


def prompt_to_active_features(circuit_lens: CircuitLens, seq_pos: int, prompt=None, k: int = 50):
    """Collect the active features at one position plus per-head contributions to each attention feature.

    Args:
        circuit_lens: lens built for the prompt being analysed.
        seq_pos: position passed to ``get_active_features`` and ``get_head_seq_lens_for_z_feature``.
        prompt: text stored under the ``'prompt'`` key. If omitted, falls back to the notebook-global
            ``prompt`` (the previous implicit behaviour), which is only right if it is the string the
            lens was built from.
        k: forwarded to ``get_head_seq_lens_for_z_feature``.

    Returns:
        dict with ``'prompt'``, ``'values'``, ``'features'``, ``'keys'`` (from the active features) and
        ``'head_feature_contributions'``: ``{layer: {feature_id: {head: summed contribution}}}``.
    """
    if prompt is None:
        # Backward-compatible fallback to the module-level `prompt` the original version read implicitly.
        prompt = globals().get("prompt")

    active_features = circuit_lens.get_active_features(seq_pos)

    head_feature_contributions = {}

    for layer in range(12):
        # This layer's attention features occupy a contiguous slice of the flat feature array.
        start_index = active_features.get_attn_start_index(layer)
        num_attn_features = active_features.keys[layer]['attn']
        attn_features = active_features.features[start_index:start_index + num_attn_features]

        layer_head_feature_contributions = {}

        # torch.unique really de-duplicates; set() over a tensor does not, because 0-d tensors hash by identity.
        for attn_feature in torch.unique(attn_features):
            feature_id = attn_feature.item()
            if feature_id == -1:  # -1 entries are skipped (presumably padding — TODO confirm)
                continue

            head_lens = circuit_lens.get_head_seq_lens_for_z_feature(
                layer=layer, seq_index=seq_pos, feature=attn_feature, visualize=False, k=k
            )

            # Sum contributions per head. Out-of-place `+` so tensors held by head_lens are never mutated.
            head_contributions = {}
            for entry in head_lens:
                head = entry[0].run_data['head']
                contribution = entry[1]
                if head in head_contributions:
                    head_contributions[head] = head_contributions[head] + contribution
                else:
                    head_contributions[head] = contribution

            layer_head_feature_contributions[feature_id] = head_contributions

        head_feature_contributions[layer] = layer_head_feature_contributions

    return {
        'prompt': prompt,
        'values': active_features.values,
        'features': active_features.features,
        'keys': active_features.keys,
        'head_feature_contributions': head_feature_contributions,
    }

prompt_dict = prompt_to_active_features(circuit_lens, -2)
prompt_dict

In [None]:
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

device = 'cpu'
batch_size = 1
# Load the transformer model and activation store
hook_point = "blocks.8.hook_resid_pre" # this doesn't matter
saes, _ = get_gpt2_res_jb_saes(hook_point)
sparse_autoencoder = saes[hook_point]
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device
sparse_autoencoder.cfg.hook_point = f"blocks.{layer}.attn.hook_z"
sparse_autoencoder.cfg.store_batch_size = batch_size

loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg)

print(f"Loader cfg batch size = {sparse_autoencoder.cfg.store_batch_size} (batch size = {batch_size})")

# don't overwrite the sparse autoencoder with the loader's sae (newly initialized)
tl_model, _, activation_store = loader.load_sae_training_group_session()

In [None]:
tokens = activation_store.get_batch_tokens()

In [None]:
tl_model.to_string(tokens)

In [None]:
total_tokens = 1_000
seq_len = 128  # assumed to match the activation store's context size — TODO confirm

all_active_features = []

for _ in range(total_tokens // seq_len):
    tokens = activation_store.get_batch_tokens()
    # to_string returns a list of strings for a 2-D (batch, seq) tensor; batch_size is 1,
    # so decode the single row to get the str that CircuitLens receives everywhere else.
    prompt = tl_model.to_string(tokens[0])

    # Build the lens once per prompt and reuse it for every position:
    # get_active_features takes the position as an argument.
    circuit_lens = CircuitLens(prompt=prompt)

    for seq_pos in trange(seq_len - 1):
        all_active_features.append(prompt_to_active_features(circuit_lens, seq_pos))

    del circuit_lens

## Different approach - storing ZSAE and transcoder activations directly

In [None]:
import einops
def tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming = False,
    max_length = 1024,
    column_name = "text",
    add_bos_token = True,
):
    """Tokenize a text dataset and pack the tokens into fixed-length rows.

    Within each ``map`` batch, the texts are joined with the EOS token, split into 20 character
    chunks, tokenized together with padding, and the padding is stripped again. The resulting
    token stream is cut into rows of ``max_length`` (the last partial row of each batch is dropped).

    Args:
        dataset: HuggingFace text dataset (map-style or streaming/iterable).
        tokenizer: tokenizer exposing ``eos_token``, ``bos_token_id``, ``pad_token`` and
            ``pad_token_id``. If it has no pad token, a ``<PAD>`` token is added, which mutates
            the tokenizer.
        streaming: accepted for signature compatibility; unused, because ``dataset.map`` is
            never run with multiple processes here.
        max_length: length of each output row, including the BOS token when ``add_bos_token``.
        column_name: name of the text column; all other columns are dropped.
        add_bos_token: prepend ``tokenizer.bos_token_id`` to every row.

    Returns:
        The mapped dataset with a single ``"tokens"`` column. Rows are NumPy arrays / lists,
        not torch tensors (``set_format`` is intentionally not called).

    Note: because text is chunked by characters, a token can be split at a chunk boundary, and
    very small datasets (less than one row per batch) yield no rows at all.
    """
    # Streaming datasets may not know their schema (features can be None); nothing to drop then.
    for key in list(dataset.features or {}):
        if key != column_name:
            dataset = dataset.remove_columns(key)

    if tokenizer.pad_token is None:
        # Padding is only used to batch-tokenize the chunks and is stripped again below,
        # so the model's d_vocab does not need to change.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # Leave room for the BOS token if one will be prepended.
    seq_len = max_length - 1 if add_bos_token else max_length

    def tokenize_function(examples):
        text = examples[column_name]
        # One long string, documents separated by EOS tokens
        full_text = tokenizer.eos_token.join(text)
        # Split into 20 roughly equal character chunks so the tokenizer can batch them
        num_chunks = 20
        chunk_length = (len(full_text) - 1) // num_chunks + 1
        chunks = [
            full_text[i * chunk_length : (i + 1) * chunk_length]
            for i in range(num_chunks)
        ]
        # NumPy output because HuggingFace map doesn't want tensors returned
        tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
        # Drop padding tokens
        tokens = tokens[tokens != tokenizer.pad_token_id]
        num_batches = len(tokens) // seq_len
        # Drop the trailing tokens that do not fill a whole row, then pack row-major
        tokens = tokens[: seq_len * num_batches].reshape(num_batches, seq_len)
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=[column_name],
    )
    return tokenized_dataset

In [None]:
from transformer_lens import HookedTransformer, utils
model = HookedTransformer.from_pretrained('gpt2-small')

In [None]:
from datasets import load_dataset
from huggingface_hub import HfApi

dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800*2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

In [None]:
owt_tokens_torch.shape

In [None]:
from circuit_lens import get_model_encoders

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)

In [None]:
import tqdm
from transformer_lens.utils import to_numpy

# def get_feature_scores_transcoder(model, encoder, tokens_arr, feature_idx, batch_size=64, act_name='resid_pre', 
# 								  use_raw_scores=False, use_decoder=False, feature_post=None, ignore_endoftext=False):
# 	act_name = encoder.cfg.hook_point
# 	layer = encoder.cfg.hook_point_layer
		
# 	scores = []
# 	endoftext_token = model.tokenizer.eos_token 
# 	for i in tqdm.tqdm(range(0, tokens_arr.shape[0], batch_size)):
# 		with torch.no_grad():
# 			_, cache = model.run_with_cache(tokens_arr[i:i+batch_size], stop_at_layer=layer+1, names_filter=[
# 				act_name
# 			])
# 			mlp_acts = cache[act_name]
# 			mlp_acts_flattened = mlp_acts.reshape(-1, encoder.W_enc.shape[0])
# 			if feature_post is None:
# 				feature_post = encoder.W_enc[:, feature_idx] if not use_decoder else encoder.W_dec[feature_idx]
# 			bias = -(encoder.b_dec @ feature_post) if use_decoder else encoder.b_enc[feature_idx] - (encoder.b_dec @ feature_post)
# 			if use_raw_scores:
# 				cur_scores = (mlp_acts_flattened @ feature_post) + bias
# 			else:
# 				hidden_acts = encoder.encode(mlp_acts_flattened)
# 				cur_scores = hidden_acts[:, feature_idx]
# 				del hidden_acts
# 			if ignore_endoftext:
# 					cur_scores[tokens_arr[i:i+batch_size].reshape(-1) == endoftext_token] = -torch.inf
# 		scores.append(to_numpy(cur_scores.reshape(-1, tokens_arr.shape[1])).astype(np.float16))
# 	return np.concatenate(scores)


# def get_feature_scores_zsae(model, encoder, tokens_arr, feature_idx, batch_size=64, act_name='attn.hook_z',
# 							use_raw_scores=False, feature_post=None, ignore_endoftext=False):
# 	print(encoder.cfg)
# 	layer = encoder.cfg['layer']

# 	scores = []
# 	endoftext_token = model.tokenizer.eos_token

# 	name_filter = f'blocks.{layer}.attn.hook_z'

# 	for i in tqdm.tqdm(range(0, tokens_arr.shape[0], batch_size)):
# 		with torch.no_grad():
# 			_, cache = model.run_with_cache(tokens_arr[i:i+batch_size], stop_at_layer=layer+1, names_filter=[
# 				name_filter
# 			])
# 			mlp_acts = cache[name_filter]
# 			mlp_acts_flattened = mlp_acts.reshape(-1, encoder.W_enc.shape[0])
# 			if feature_post is None:
# 				feature_post = encoder.W_enc[:, feature_idx]
# 			bias = encoder.b_enc[feature_idx] - (encoder.b_dec @ feature_post)
# 			if use_raw_scores:
# 				cur_scores = (mlp_acts_flattened @ feature_post) + bias
# 			else:
# 				hidden_acts = encoder.encode(mlp_acts_flattened)
# 				cur_scores = hidden_acts[:, feature_idx]
# 				del hidden_acts
# 			if ignore_endoftext:
# 				cur_scores[tokens_arr[i:i+batch_size].reshape(-1) == endoftext_token] = -torch.inf
# 		scores.append(to_numpy(cur_scores.reshape(-1, tokens_arr.shape[1])).astype(np.float16))

# 	return np.concatenate(scores)

def get_feature_scores_transcoder(model, encoder, tokens_arr, feature_indices, batch_size=64, act_name='resid_pre', 
                                  use_raw_scores=False, use_decoder=False, feature_post=None, ignore_endoftext=False):
    act_name = encoder.cfg.hook_point
    layer = encoder.cfg.hook_point_layer
    
    scores = []
    endoftext_token = model.tokenizer.eos_token 
    for i in tqdm.tqdm(range(0, tokens_arr.shape[0], batch_size)):
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens_arr[i:i+batch_size], stop_at_layer=layer+1, names_filter=[act_name])
            mlp_acts = cache[act_name]
            mlp_acts_flattened = mlp_acts.reshape(-1, encoder.W_enc.shape[0])
            if feature_post is None:
                feature_post = encoder.W_enc[:, feature_indices] if not use_decoder else encoder.W_dec[:, feature_indices]
            bias = -(encoder.b_dec @ feature_post) if use_decoder else encoder.b_enc[feature_indices] - (encoder.b_dec @ feature_post)
            if use_raw_scores:
                cur_scores = (mlp_acts_flattened @ feature_post) + bias
            else:
                hidden_acts = encoder.encode(mlp_acts_flattened)
                cur_scores = hidden_acts[:, feature_indices]
                del hidden_acts
            if ignore_endoftext:
                cur_scores[tokens_arr[i:i+batch_size].reshape(-1) == endoftext_token] = -torch.inf
        scores.append(to_numpy(cur_scores.reshape(-1, len(feature_indices), tokens_arr.shape[1])).astype(np.float16))
    return np.concatenate(scores, axis=0)


def get_feature_scores_zsae(model, encoder, tokens_arr, feature_indices, batch_size=64, act_name='attn.hook_z',
                            use_raw_scores=False, feature_post=None, ignore_endoftext=False):
    """Score attention-output (hook_z) SAE features at every token position.

    The model is run in batches up to (and including) ``encoder.cfg['layer']``, the
    ``blocks.{layer}.attn.hook_z`` activations are flattened to one row per token, and
    one score per (sequence, feature, position) is returned.

    Args:
        model: model exposing ``run_with_cache(tokens, stop_at_layer=..., names_filter=[...])``
            and ``tokenizer.eos_token_id``.
        encoder: SAE exposing ``cfg['layer']``, ``W_enc``, ``b_enc``, ``b_dec`` and ``encode``.
        tokens_arr: 2-D integer token tensor of shape (n_sequences, seq_len).
        feature_indices: a single feature index (int) or a sequence of indices.
        batch_size: number of sequences per forward pass.
        act_name: unused; the hook is always ``blocks.{layer}.attn.hook_z``.
            Kept so existing calls that pass it keep working.
        use_raw_scores: if True, score with ``acts @ feature_post + bias`` instead of
            ``encoder.encode(acts)``.
        feature_post: optional projection matrix for ``use_raw_scores``; defaults to
            ``W_enc[:, feature_indices]``.
        ignore_endoftext: set scores at end-of-text token positions to -inf.

    Returns:
        float16 numpy array of shape (n_sequences, n_features, seq_len).
    """
    layer = encoder.cfg['layer']
    hook_name = f'blocks.{layer}.attn.hook_z'

    # Callers pass both a bare index and a list; normalize to a list of ints.
    if isinstance(feature_indices, (int, np.integer)):
        feature_indices = [feature_indices]
    feature_indices = [int(f) for f in feature_indices]
    n_features = len(feature_indices)
    n_seqs, seq_len = tokens_arr.shape[0], tokens_arr.shape[1]
    if n_seqs == 0:
        # np.concatenate([]) would raise; return an empty result with the usual layout.
        return np.zeros((0, n_features, seq_len), dtype=np.float16)

    # The projection and bias do not depend on the batch, so build them once.
    if use_raw_scores:
        if feature_post is None:
            feature_post = encoder.W_enc[:, feature_indices]
        bias = encoder.b_enc[feature_indices] - (encoder.b_dec @ feature_post)

    # Token ids are integers, so compare against the EOS *id*; ``eos_token`` is the
    # token's string form and never matches an integer tensor.
    eos_token_id = model.tokenizer.eos_token_id

    scores = []
    for start in tqdm.tqdm(range(0, n_seqs, batch_size)):
        batch_tokens = tokens_arr[start:start + batch_size]
        with torch.no_grad():
            _, cache = model.run_with_cache(batch_tokens, stop_at_layer=layer + 1, names_filter=[hook_name])
            # Flatten all per-head values of a token into one row of width W_enc.shape[0].
            z_flat = cache[hook_name].reshape(-1, encoder.W_enc.shape[0])
            if use_raw_scores:
                batch_scores = (z_flat @ feature_post) + bias
            else:
                # Index right away so the full hidden-activation matrix is not kept alive.
                batch_scores = encoder.encode(z_flat)[:, feature_indices]
            if ignore_endoftext:
                eos_mask = (batch_tokens.reshape(-1) == eos_token_id).to(batch_scores.device)
                batch_scores[eos_mask] = -torch.inf
        # Rows are ordered (sequence, position) with one column per feature; regroup to
        # (sequence, feature, position). Reshaping straight to that shape would interleave
        # positions and features whenever more than one feature is requested.
        batch_scores = batch_scores.reshape(batch_tokens.shape[0], seq_len, n_features).permute(0, 2, 1)
        scores.append(to_numpy(batch_scores).astype(np.float16))

    return np.concatenate(scores, axis=0)

In [None]:
# Score z-SAE feature 20100 (the SAE at index 8 of z_saes) on the first 1,024 token sequences.
sae = z_saes[8]
feature_idx = [20100]
feature_scores = get_feature_scores_zsae(
    model,
    sae,
    owt_tokens_torch[:1024],
    feature_idx,
    batch_size=4,
    act_name='z',
)

In [None]:
# Take the single scored feature as a (n_sequences, seq_len) matrix under a new name, so
# `feature_scores` keeps its 3-D shape and re-running this cell gives the same result.
zsae_scores = feature_scores[:, 0, :]

# Print every (sequence, position) where the feature's score is not zero.
for i, j in np.argwhere(zsae_scores != 0):
    print(f"Batch {i}, Seq {j}: {zsae_scores[i, j]}")

In [None]:
transcoder = transcoders[8]
# Pass a list of feature indices: the scoring helper takes len(feature_indices), so a
# bare int raises TypeError.
feature_idx = [20104]

feature_scores = get_feature_scores_transcoder(model, transcoder, owt_tokens_torch[:1024], feature_idx, batch_size=4, act_name='resid_pre')

In [None]:
# Number of (sequence, feature, position) entries with a non-zero score.
nonzero_count = np.count_nonzero(feature_scores)
nonzero_count

In [None]:
# Per-component feature lists and activation frequencies over the IOI prompts.
# They are expensive to compute, so values already in the kernel are reused; on a
# fresh kernel they are computed here so the cell can run on its own.
FREQ_N_PROMPTS = 50
_freq_vars = ('attn_freqs', 'mlp_freqs', 'features_for_heads', 'features_for_mlps')
if not all(name in globals() for name in _freq_vars):
    task_eval = TaskEvaluation(prompts=dataset_prompts, circuit_discovery_strategy=strategy, allowed_components_filter=component_filter)
    features_for_heads = task_eval.get_features_at_heads_over_dataset(N=FREQ_N_PROMPTS, use_set=False)
    features_for_mlps = task_eval.get_features_at_mlps_over_dataset(N=FREQ_N_PROMPTS, use_set=False)
    mlp_freqs = task_eval.get_mlp_freqs_over_dataset(N=FREQ_N_PROMPTS, return_freqs=True, visualize=False)
    attn_freqs = task_eval.get_attn_head_freqs_over_dataset(N=FREQ_N_PROMPTS, subtract_counter_factuals=False, return_freqs=True, visualize=False)

cp = CircuitPrediction(attn_freqs, mlp_freqs, features_for_heads, features_for_mlps)

_ = cp.component_frequency_array(visualize=True)

In [None]:
# Unique features attributed to MLP3, dropping the -1 placeholder entries.
features = list({x for x in cp.circuit_hypergraph['MLP3']['features'] if x != -1})
features

In [None]:
# Size of the tokenized text sample (expected: n_sequences x seq_len).
owt_tokens_torch.size()

In [None]:
# Score `features` (collected for MLP3 above) with the transcoder at index 3 over every sequence.
transcoder = transcoders[3]
feature_scores = get_feature_scores_transcoder(
    model,
    transcoder,
    owt_tokens_torch,
    features,
    batch_size=64,
    act_name='resid_pre',
)

In [None]:
# Unique features attributed to attention head L5_H5, dropping the -1 placeholder entries.
features = list({x for x in cp.circuit_hypergraph['L5_H5']['features'] if x != -1})
features

In [None]:
# Unique features attributed to attention head L0_H1 (no -1 placeholders), scored with z_saes[0].
features = list({x for x in cp.circuit_hypergraph['L0_H1']['features'] if x != -1})
print(features)
sae = z_saes[0]
feature_scores = get_feature_scores_zsae(
    model, sae, owt_tokens_torch, features,
    batch_size=64, act_name='z',
)

In [None]:
# (n_sequences, n_features, seq_len) as returned by the scoring helper.
np.shape(feature_scores)

In [None]:
import numpy as np
import torch
from IPython.display import HTML, display
import html

def get_top_k_activating_examples(feature_scores, tokens, model, k=5):
    """Find the k highest-scoring (sequence, position) entries of a score matrix.

    Args:
        feature_scores: 2-D numpy array of shape (n_sequences, seq_len), one score per token.
        tokens: token ids aligned with ``feature_scores``; ``tokens[i]`` must support ``.tolist()``.
        model: object whose ``to_string(token_id)`` decodes a single token.
        k: number of entries to return; ``k <= 0`` returns nothing.

    Returns:
        (token_strs, scores_per_seq, seq_indices): for each top entry in descending
        score order, the decoded tokens of its whole sequence, that sequence's full
        score list, and the position of the top-scoring token within it. A sequence
        appears once per top entry it contains.
    """
    # Sort descending first, then take k: ``argsort()[-k:]`` returns every index when k == 0.
    top_k_indices = feature_scores.flatten().argsort()[::-1][:max(k, 0)]

    # Convert the flat indices back to (sequence, position) indices.
    top_k_batch_indices, top_k_seq_indices = np.unravel_index(top_k_indices, feature_scores.shape)

    top_k_tokens_str = [[model.to_string(tok) for tok in tokens[b].tolist()] for b in top_k_batch_indices]
    top_k_scores_per_seq = [feature_scores[b].tolist() for b in top_k_batch_indices]

    return top_k_tokens_str, top_k_scores_per_seq, top_k_seq_indices

def highlight_scores_in_html(token_strs, scores, seq_idx, max_color='#ff8c00', zero_color='#ffffff', show_score=True):
    """Render one token sequence as HTML, shading each token's background by its score.

    Scores are min-max normalized over the finite entries of ``scores``. The lowest
    score maps to ``zero_color`` and the highest to ``max_color`` (linear RGB blend).
    Non-finite scores (e.g. -inf at masked end-of-text positions) are clamped:
    -inf and NaN map to ``zero_color``, +inf to ``max_color``. If the finite scores
    are all equal, every token is drawn in ``zero_color``.

    Args:
        token_strs: decoded tokens of one sequence.
        scores: one numeric score per token.
        seq_idx: position of the top-scoring token; not used for rendering.
        max_color: ``#rrggbb`` color for the highest score.
        zero_color: ``#rrggbb`` color for the lowest score.
        show_score: also print each score (2 decimals) after its token.

    Returns:
        (html, clean_text): the HTML snippet (prefixed with a small <style> block) and
        a plain-text version with tokens joined by " | ". Returns ("", "") if the
        two inputs have different lengths.
    """
    if len(token_strs) != len(scores):
        print("Length mismatch between tokens and scores")
        return "", ""

    # Normalize over finite scores only: a single -inf would otherwise make every value
    # NaN, and a constant row would divide 0/0. Both previously crashed int() below.
    scores_arr = np.asarray(scores, dtype=np.float64)
    finite = np.isfinite(scores_arr)
    scores_min = scores_arr[finite].min() if finite.any() else 0.0
    scores_max = scores_arr[finite].max() if finite.any() else 0.0
    score_range = scores_max - scores_min
    if score_range > 0:
        scores_normalized = (scores_arr - scores_min) / score_range
    else:
        scores_normalized = np.zeros_like(scores_arr)
    # Clamp -inf/+inf into [0, 1] and map NaN to 0 so each color is a valid RGB triple.
    scores_normalized = np.nan_to_num(np.clip(scores_normalized, 0.0, 1.0), nan=0.0)

    def _hex_to_rgb(color):
        # '#rrggbb' -> array([r, g, b]) of ints.
        return np.array([int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)])

    max_color_vec = _hex_to_rgb(max_color)
    zero_color_vec = _hex_to_rgb(zero_color)

    color_vecs = np.einsum('i, j -> ij', scores_normalized, max_color_vec) + np.einsum('i, j -> ij', 1 - scores_normalized, zero_color_vec)
    color_strs = [f"#{int(x[0]):02x}{int(x[1]):02x}{int(x[2]):02x}" for x in color_vecs]

    if show_score:
        tokens_html = "".join([
            f"""<span class='token' style='background-color: {color_strs[i]}'>{html.escape(token_str)}<span class='feature_val'> ({scores[i]:.2f})</span></span>"""
            for i, token_str in enumerate(token_strs)
        ])
        clean_text = " | ".join([
            f"{token_str} ({scores[i]:.2f})"
            for i, token_str in enumerate(token_strs)
        ])
    else:
        tokens_html = "".join([
            f"""<span class='token' style='background-color: {color_strs[i]}'>{html.escape(token_str)}</span>"""
            for i, token_str in enumerate(token_strs)
        ])
        clean_text = " | ".join(token_strs)

    head = """
    <style>
        span.token {
            font-family: monospace;
            border-style: solid;
            border-width: 1px;
            border-color: #dddddd;
        }
    </style>
    """
    return head + tokens_html, clean_text

def display_top_k_activating_examples(model, feature_scores, tokens, k=5, show_score=True):
    """Display (in the notebook) the sequences containing the k highest-scoring positions.

    Args:
        model: object with ``to_string(token_id)``, used to decode tokens.
        feature_scores: 2-D numpy array (n_sequences, seq_len) of per-token scores.
        tokens: token ids aligned with ``feature_scores``.
        k: maximum number of examples to show.
        show_score: also print each token's score next to it.

    Returns:
        (examples_html, examples_clean_text): one HTML string and one plain-text
        string per displayed example, in descending score order.
    """
    top_k_tokens_str, top_k_scores_per_seq, top_k_seq_indices = get_top_k_activating_examples(feature_scores, tokens, model, k=k)

    examples_html = []
    examples_clean_text = []

    # Iterate over what was actually returned: there are fewer than k examples when
    # feature_scores has fewer than k entries, and range(k) would raise IndexError.
    for token_strs, seq_scores, seq_idx in zip(top_k_tokens_str, top_k_scores_per_seq, top_k_seq_indices):
        example_html, clean_text = highlight_scores_in_html(token_strs, seq_scores, seq_idx, show_score=show_score)
        display(HTML(example_html))
        examples_html.append(example_html)
        examples_clean_text.append(clean_text)

    return examples_html, examples_clean_text

# Example usage: show the 25 highest-scoring positions for the feature at index 2 of the
# scored features. feature_scores[:, 2, :] and owt_tokens_torch are both (n_sequences, seq_len).
example_html, examples_clean_text = display_top_k_activating_examples(
    model,
    feature_scores[:, 2, :],
    owt_tokens_torch,
    k=25,
    show_score=True,
)

In [None]:
from openai import AzureOpenAI
import yaml

# Azure OpenAI endpoint, key and API version are read from a local config.yaml.
# Use a context manager so the file handle is closed after parsing.
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

# Instruction prefix for feature auto-interpretation; get_response appends the examples.
# The trailing backslashes join the lines into a single paragraph. The format description
# matches the clean text built by highlight_scores_in_html with show_score=True:
# "token (score) | token (score) | ...".
BASE_PROMPT = """
We're studying neurons in a neural network, trying to identify their roles. \
Look at the parts/tokens of the document this particular neuron activates \
highly for and summarize in a single sentence what the neuron is \
looking for. Don't list examples of words. We will show short text excerpts, \
each written as a list of tokens (parts of words) separated by " | ", with \
each token followed by its activation in parentheses. The format is word (score). \
Your task is to summarize what the highly activating tokens have in common, \
taking their context into account.
"""

def get_response(llm_client, prompt, model="gpt4_large"):
    """Ask the chat model to summarize what a feature responds to.

    Args:
        llm_client: OpenAI-compatible client exposing ``chat.completions.create``.
        prompt: the example texts, either a list of strings (joined with blank
            lines) or a single, already-joined string.
        model: chat model / deployment name to query.

    Returns:
        The content of the first completion choice, as a string.
    """
    # A bare string must not go through join(), which would put blank lines between its characters.
    examples = prompt if isinstance(prompt, str) else "\n\n".join(prompt)
    messages = [
        {"role": "user", "content": BASE_PROMPT + examples}
    ]
    response = llm_client.chat.completions.create(
        model=model,
        messages=messages,
    )
    return f"{response.choices[0].message.content}"


# Show the instruction prefix that get_response puts in front of every example batch.
print(BASE_PROMPT, end="\n")


In [None]:
# Ask the LLM to interpret the plain-text examples rendered above.
get_response(llm_client, prompt=examples_clean_text)

## Pretty bird

In [None]:
from autointerpretability import *

# Import what this cell uses directly instead of relying on an earlier cell having run.
import yaml
from openai import AzureOpenAI

# Close the config file once it is parsed.
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

# Stream OpenWebText, shuffle it, tokenize into 128-token sequences and keep a fixed sample.
N_OWT_SEQUENCES = 12800 * 2
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(N_OWT_SEQUENCES)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)

cp = get_circuit_prediction(task='ioi')

In [None]:
# Size of the sampled token tensor (expected: n_sequences x 128).
owt_tokens_torch.size()

In [None]:
# Score the two selected features with the z-SAE at index 8 of z_saes.
features = [16513, 7861]
sae = z_saes[8]
feature_scores = get_feature_scores(
    model, sae, owt_tokens_torch, features,
    batch_size=64, act_name='z',
)

In [None]:
# Component whose features are interpreted: 'L{layer}_H{head}' for an attention head,
# 'MLP{layer}' for an MLP. It must be defined here; otherwise the cell depends on stale kernel state.
component_name = 'L8_H1'
# features = list(set([x for x in cp.circuit_hypergraph[component_name]['features'] if x != -1]))
features = [16513, 7861]
# Index into `features` of the feature to display and interpret (must be < len(features)).
feature_pos = 0

if component_name.startswith('L'):
    # Attention head: use the z-SAE for the head's layer.
    layer = int(component_name[1:].split('_')[0])
    sae = z_saes[layer]
    feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=64, act_name='z')
else:
    # MLP: use the transcoder for that layer.
    layer = int(component_name[len('MLP'):])
    transcoder = transcoders[layer]
    feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=64, act_name='resid_pre')

example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_pos, :], owt_tokens_torch, k=8, show_score=True)
# BASE_PROMPT describes the plain 'token (score)' format, so send the clean text rather than the HTML.
feature_interpretation = get_response(llm_client, examples_clean_text)