In [None]:
!python3 -V

# Setup

## Dependencies

In [None]:

#!pip install transformer_lens
#!pip install gradio

In [None]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch
import numpy as np
import gradio as gr
import pprint
import json
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
import tqdm.notebook as tqdm
import plotly.express as px
import pandas as pd

## Defining the Autoencoder

In [None]:
cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 128,
    "d_mlp": 2048,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
}
cfg["model_batch_size"] = 64
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]

In [None]:
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to("cuda")

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def remove_parallel_component_of_grads(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj

    # def get_version(self):
    #     return 1+max([int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)])

    # def save(self):
    #     version = self.get_version()
    #     torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
    #     with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
    #         json.dump(cfg, f)
    #     print("Saved as version", version)

    # def load(cls, version):
    #     cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
    #     pprint.pprint(cfg)
    #     self = cls(cfg=cfg)
    #     self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
    #     return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


## Utils

### Get Reconstruction Loss

In [None]:
def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"Reconstruction Score: {score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss

### Get Frequencies

In [None]:
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).cuda()
    total = 0
    for i in tqdm.trange(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
        mlp_acts = cache[utils.get_act_name("post", 0)]
        mlp_acts = mlp_acts.reshape(-1, d_mlp)

        hidden = local_encoder(mlp_acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

## Visualise Feature Utils

In [None]:
from html import escape
import colorsys

from IPython.display import display

SPACE = "·"
NEWLINE="↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    if isinstance(values, torch.Tensor) and len(values.shape)>1:
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    if max_value is None:
        max_value = max(max(values), -min(values))+1e-3
    scaled_values = [v / max_value * saturation for v in values]

    # create html
    html = ""
    for i, s in enumerate(processed_strings):
        if i<len(scaled_values):
            v = scaled_values[i]
        else:
            v = 0
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        rgb_color = colorsys.hsv_to_rgb(
            hue, v, 1
        )  # hsv color with hue 0.66 (blue), saturation as v, value 1
        hex_color = "#%02x%02x%02x" % (
            int(rgb_color[0] * 255),
            int(rgb_color[1] * 255),
            int(rgb_color[2] * 255),
        )
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'
    if return_string:
        return html
    else:
        display(HTML(html))

def basic_feature_vis(text, feature_index, max_val=0):
    feature_in = encoder.W_enc[:, feature_index]
    feature_bias = encoder.b_enc[feature_index]
    _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)][0]
    feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if max_val==0:
        max_val = max(1e-7, feature_acts.max().item())
        # print(max_val)
    # if min_val==0:
    #     min_val = min(-1e-7, feature_acts.min().item())
    return basic_token_vis_make_str(text, feature_acts, max_val)
def basic_token_vis_make_str(strings, values, max_val=None):
    if not isinstance(strings, list):
        strings = model.to_str_tokens(strings)
    values = utils.to_numpy(values)
    if max_val is None:
        max_val = values.max()
    # if min_val is None:
    #     min_val = values.min()
    header_string = f"<h4>Max Range <b>{values.max():.4f}</b> Min Range: <b>{values.min():.4f}</b></h4>"
    header_string += f"<h4>Set Max Range <b>{max_val:.4f}</b></h4>"
    # values[values>0] = values[values>0]/ma|x_val
    # values[values<0] = values[values<0]/abs(min_val)
    body_string = create_html(strings, values, max_value=max_val, return_string=True)
    return header_string + body_string
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr
try:
    demos[0].close()
except:
    pass
demos = [None]
def make_feature_vis_gradio(feature_id, starting_text=None, batch=None, pos=None):
    if starting_text is None:
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    try:
        demos[0].close()
    except:
        pass
    with gr.Blocks() as demo:
        gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # # If empty, these two map to None
                max_val = gr.Number(label="Max Value", value=None)
                # min_val = gr.Number(label="Min Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(starting_text, feature_id))
        for inp in inputs:
            inp.change(basic_feature_vis, inputs, out)
    demo.launch(share=True)
    demos[0] = demo

### Inspecting Top Logits

In [None]:
SPACE = "·"
NEWLINE="↩"
TAB = "→"
def process_token(s):
    if isinstance(s, torch.Tensor):
        s = s.item()
    if isinstance(s, np.int64):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    s = s.replace(" ", SPACE)
    s = s.replace("\n", NEWLINE+"\n")
    s = s.replace("\t", TAB)
    return s

def process_tokens(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [process_token(s) for s in l]

def process_tokens_index(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [f"{process_token(s)}/{i}" for i,s in enumerate(l)]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))
    vocab_df = pd.DataFrame({"token": full_vocab, "logit": utils.to_numpy(logit_vec)})
    if make_probs:
        vocab_df["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        vocab_df["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return vocab_df.sort_values("logit", ascending=False)

### Make Token DataFrame

In [None]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]
def make_token_df(tokens, len_prefix=5, len_suffix=1):
    str_tokens = [process_tokens(model.to_str_tokens(t)) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        # context.append([])
        # batch.append([])
        # pos.append([])
        # label.append([])
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            if p==tokens.shape[1]-1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p+1:min(tokens.shape[1]-1, p+1+len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

## Loading the Model

In [None]:
model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]])
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab

## Loading Data

In [None]:
data = load_dataset("NeelNanda/c4-code-20k", split="train")
print(type(data))
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]
print('all_tokens.shape', all_tokens.shape)

# Analysis

## Loading the Autoencoder

There are two runs on separate random seeds, along with a bunch of intermediate checkpoints

In [None]:
auto_encoder_run = "run1" # @param ["run1", "run2"]
encoder = AutoEncoder.load_from_hf(auto_encoder_run)

## Using the Autoencoder

We run the model and replace the MLP activations with those reconstructed from the autoencoder, and get 91% loss recovered

In [None]:
_ = get_recons_loss(num_batches=5, local_encoder=encoder)

## Rare Features Are All The Same

For each feature we can get the frequency at which it's non-zero (per token, averaged across a bunch of batches), and plot a histogram

In [None]:
freqs = get_freqs(num_batches = 50, local_encoder = encoder)

In [None]:
# Add 1e-6.5 so that dead features show up as log_freq -6.5
log_freq = (freqs + 10**-6.5).log10()
px.histogram(utils.to_numpy(log_freq), title="Log Frequency of Features", histnorm='percent')

We see that it's clearly bimodal! Let's define rare features as those with freq < 1e-4, and look at the cosine sim of each feature with the average rare feature - we see that almost all rare features correspond to this feature!

In [None]:
is_rare = freqs < 1e-4
rare_enc = encoder.W_enc[:, is_rare]
rare_mean = rare_enc.mean(-1)
px.histogram(utils.to_numpy(rare_mean @ encoder.W_enc / rare_mean.norm() / encoder.W_enc.norm(dim=0)), title="Cosine Sim with Ave Rare Feature", color=utils.to_numpy(is_rare), labels={"color": "is_rare", "count": "percent", "value": "cosine_sim"}, marginal="box", histnorm="percent", barmode='overlay')

## Interpreting A Feature

Let's go and investigate a non rare feature, feature 7

In [None]:
feature_id = 7 # @param {type:"number"}
batch_size = 128 # @param {type:"number"}

print(f"Feature freq: {freqs[7].item():.4f}")

Let's run the model on some text and then use the autoencoder to process the MLP activations

In [None]:
tokens = all_tokens[:batch_size]
_, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
print(type(mlp_acts))
print(cfg['d_mlp'])
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])
loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)
# This is equivalent to:
# hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
print("hidden_acts.shape", hidden_acts.shape)

We can now sort and display the top tokens, and we see that this feature activates on text like " and I" (ditto for other connectives and pronouns)! It seems interpretable!

**Aside:** Note on how to read the context column:

A line like "·himself·as·democratic·socialist·and|·he|·favors" means that the preceding 5 tokens are " himself as democratic socialist and", the current token is " he" and the next token is " favors".  · are spaces, ↩ is a newline.

This gets a bit confusing for this feature, since the pipe separators look a lot like a capital I


In [None]:
token_df = make_token_df(tokens)
print(hidden_acts[:, feature_id].shape)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

It's easy to misread evidence like the above, so it's useful to take some text and edit it and see how this changes the model's activations. Here's a hacky interactive tool to play around with some text.

In [None]:
model.cfg

In [None]:
s = "The 1899 Kentucky gubernatorial election was held on November 7, 1899. The Republican incumbent, William Bradley, was term-limited. The Democrats chose William Goebel. Republicans nominated William Taylor. Taylor won by a vote of 193,714 to 191,331. The vote was challenged on grounds of voter fraud, but the Board of Elections, though stocked with pro-Goebel members, certified the result. Democratic legislators began investigations, but before their committee could report, Goebel was shot by an unknown assassin (event pictured) on January 30, 1900. Democrats voided enough votes to swing the election to Goebel, Taylor was deposed, and Goebel was sworn into office on January 31. He died on February 3. The lieutenant governor of Kentucky, J. C. W. Beckham, became governor, and battled Taylor in court. Beckham won on appeal, and Taylor fled to Indiana, fearing arrest as an accomplice. The only persons convicted in connection with the killing were later pardoned; the assassin's identity remains a mystery"
t = model.to_tokens(s)
print(t)

In [None]:

starting_text = "Hero and I will head to Samantha and Mark's, then he and she will. Then I or you" # @param {type:"string"}
make_feature_vis_gradio(feature_id, starting_text)

A final piece of evidence: This is a one layer model, so the neurons can only matter by directly impacting the final logits! We can directly look at how the decoder weights for this feature affect the logits, and see that it boosts `'ll`! This checks out, I and he'll etc is a common construction.

In [None]:
logit_effect = encoder.W_dec[feature_id] @ model.W_out[0] @ model.W_U
create_vocab_df(logit_effect).head(20).style.background_gradient("coolwarm")

In [None]:
# Let's try to pass through inputs with the same label from a sentiment analysis dataset to see what features are activated the most
# We'll use a Twitter sentiment analysis dataset

import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
from datasets import Dataset

# Example for turning pandas DF to huggingface Dataset
df = pd.DataFrame({'a': [0,1,2], 'b': [3,4,5]})
dataset = ds.dataset(pa.Table.from_pandas(df).to_batches())
### convert to Huggingface dataset
hg_dataset = Dataset(pa.Table.from_pandas(df))

# Load the dataset as an HG dataset
dataset = load_dataset('csv', data_files="twitter_training.csv", split="train")
dataset = dataset.rename_column('Tweet content', 'text')
dataset = dataset.filter(lambda x: x['text'] != None)
print(dataset)

# Take out just the Tweet content
#tweet_content = dataset['Tweet content']
# Rename the column to 'text' to match the model
#tweet_content = tweet_content.rename_column('Tweet content', 'text')

tokenized_data = utils.tokenize_and_concatenate(dataset, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(seed=42)
all_tokens = tokenized_data["tokens"]
print('all_tokens.shape', all_tokens.shape)

batch_size = 128
example_tokens = all_tokens[:batch_size] 

_, cache = model.run_with_cache(example_tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])

loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)
print("hidden_acts.shape", hidden_acts.shape)

In [None]:
token_df = make_token_df(example_tokens)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

In [None]:
# For each set of tokens, check which features are activated the most

def argsort(seq, reverse=False):
    return sorted(range(len(seq)), key=seq.__getitem__, reverse=reverse)

batch_df = token_df
# Go through every feature and keep track of the most activated feature
activations = [None for _ in range(encoder.d_hidden)]
mean_activations = [None for _ in range(encoder.d_hidden)]
for feature_id in range(encoder.d_hidden):

    feature_act = utils.to_numpy(hidden_acts[:, feature_id])
    
    activations[feature_id] = feature_act
    # Measure max activation based on average activation (for now)
    mean_activations[feature_id] = feature_act.mean()

# Sort the features by the mean of the max activations
sorted_feature_idxs = argsort(mean_activations, reverse=True)

In [None]:
# Visualize the MOST activated feature (based on whatever metric we used)
highest_act_idx = sorted_feature_idxs[0]
print(highest_act_idx)
batch_df['feature'] = utils.to_numpy(hidden_acts[:, highest_act_idx])
batch_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

In [None]:
logit_effect = encoder.W_dec[highest_act_idx] @ model.W_out[0] @ model.W_U
create_vocab_df(logit_effect).head(20).style.background_gradient("coolwarm")


In the cell above we can see that for our randomly sampled batch of data this is the feature that has the highest mean activation. Unfortunately, it doesn't seem to mean much.

In [None]:
# Now that we know how to do this let's separate our data into posiitive, neutral, and negative sentiments
# For now just try negative
negative_segment = dataset.filter(lambda x: x['sentiment'] == 'Negative')

tokenized_negative = utils.tokenize_and_concatenate(negative_segment, model.tokenizer, max_length=128)
tokenized_negative = tokenized_negative.shuffle(seed=42)
all_tokens_negative = tokenized_negative["tokens"]

batch_size = 256
batch_negative = all_tokens_negative[:batch_size]

_, cache = model.run_with_cache(batch_negative, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])

loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)

token_df_negative = make_token_df(batch_negative)
batch_df = token_df_negative
# Go through every feature and keep track of the most activated feature
activations = [None for _ in range(encoder.d_hidden)]
mean_activations = [None for _ in range(encoder.d_hidden)]
for feature_id in range(encoder.d_hidden):

    feature_act = utils.to_numpy(hidden_acts[:, feature_id])
    
    activations[feature_id] = feature_act
    # Measure max activation based on average activation (for now)
    mean_activations[feature_id] = feature_act.mean()

# Sort the features by the mean of the max activations
sorted_feature_idxs = argsort(mean_activations, reverse=True)

In [None]:
# Visualize the MOST activated feature (based on whatever metric we used)
highest_act_idx = sorted_feature_idxs[0]
print(highest_act_idx)
batch_df['feature'] = utils.to_numpy(hidden_acts[:, highest_act_idx])
batch_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

In [None]:
logit_effect = encoder.W_dec[highest_act_idx] @ model.W_out[0] @ model.W_U
create_vocab_df(logit_effect).head(20).style.background_gradient("coolwarm")

Ok so this doesn't seem to mean much either... but that is to be expected since the training data for the SAE is probably not sufficient, and the context of all the negative-labelled data is probably not very cohesive.


In [115]:
# Refactor the feature search process into a function

def argsort(seq, reverse=False):
        return sorted(range(len(seq)), key=seq.__getitem__, reverse=reverse)

def feature_search(dataset, model, autoencoder, batch_size=128, token_length=128, activation_rank=0):
    '''
    Search through features of a pre-trained SAE.
    'dataset' must first be loaded as an HuggingFace Dataset:
    from 
    dataset = load_dataset('file_type', data_files="local_file.file_type", split="whatever_split")
    '''
    tokenized_data = utils.tokenize_and_concatenate(dataset, model.tokenizer, max_length=token_length)
    tokenized_data = tokenized_data.shuffle(seed=42)
    all_tokens = tokenized_data["tokens"]

    example_tokens = all_tokens[:batch_size] 

    _, cache = model.run_with_cache(example_tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)]
    mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])

    loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = autoencoder(mlp_acts_flattened)
    print("hidden_acts.shape", hidden_acts.shape)

    batch_df = make_token_df(example_tokens)

    # Go through every feature and keep track of the most activated feature
    activations = [None for _ in range(autoencoder.d_hidden)]
    mean_activations = [None for _ in range(autoencoder.d_hidden)]
    for feature_id in range(autoencoder.d_hidden):

        feature_act = utils.to_numpy(hidden_acts[:, feature_id])

        activations[feature_id] = feature_act

        # Measure most activation based on average of top 3 activations
        # NOTE: fix here, this is not taking the largest activations
        # sorted_feature_idxs = np.argsort(feature_act)
        # sorted_feature_act = feature_act[sorted_feature_idxs]
        # mean_activations[feature_id] = sorted_feature_act[-3:].mean()

        # Measure most activation based on average activation (for now)
        mean_activations[feature_id] = feature_act.mean()

    # Sort the features by the mean of the max activations
    sorted_feature_idxs = argsort(mean_activations, reverse=True)
    #print(mean_activations[sorted_feature_idxs[0]])

    # Visualize the MOST activated feature (based on whatever metric we used)
    highest_act_idx = sorted_feature_idxs[activation_rank]
    print(sorted_feature_idxs)
    batch_df['feature'] = utils.to_numpy(hidden_acts[:, highest_act_idx])
    

    logit_effect = autoencoder.W_dec[highest_act_idx] @ model.W_out[0] @ model.W_U
    

    return highest_act_idx, batch_df, logit_effect


# Just to make sure the function works
feature_id, batch_df, logit_effect = feature_search(negative_segment, model, encoder, batch_size=128,activation_rank=0)
print(feature_id)

hidden_acts.shape torch.Size([16384, 16384])
[12344, 15631, 9780, 16359, 11960, 7052, 16350, 15192, 13810, 2802, 933, 16139, 15325, 4022, 11873, 10068, 10943, 5305, 2970, 9157, 8719, 6001, 16025, 1839, 13196, 16272, 15203, 1374, 2645, 8992, 9005, 4147, 14402, 8063, 15348, 11777, 9214, 2740, 12715, 7016, 16328, 8564, 1528, 8093, 14035, 12530, 8174, 1348, 15481, 3358, 10303, 2132, 187, 472, 9245, 15465, 12928, 9346, 11073, 14410, 3341, 4700, 7649, 5766, 326, 11042, 5815, 1771, 6116, 9981, 1564, 1444, 744, 4628, 15564, 8515, 5194, 11022, 9410, 7120, 3755, 1487, 13256, 15413, 13843, 13110, 10009, 1244, 12345, 1959, 2192, 8716, 14483, 2531, 3872, 1238, 15294, 6424, 11645, 12903, 1399, 8442, 1607, 14921, 6442, 4780, 14365, 5067, 16218, 12977, 3394, 6198, 1908, 9177, 4170, 5844, 6231, 3709, 9984, 6843, 11584, 11675, 1563, 14370, 6865, 7586, 13999, 475, 3520, 5905, 4494, 930, 6324, 12015, 9697, 8830, 7432, 13044, 15588, 6190, 15781, 9257, 13451, 161, 8405, 13764, 7873, 12328, 14588, 1527, 4805

In [103]:
#print('Feature that activates the most on this batch')
batch_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,feature
11258,·Cel,·Cel/122,·his·right·leg<|EOS|>Lo|·Cel|so,87,122,87/122,1.012925
373,Game,Game/117,<|EOS|>@Rainbow6|Game|·Who,2,117,2/117,1.011477
12668,Game,Game/124,<|EOS|>@·bennite|Game|·fix,98,124,98/124,1.000981
16092,·PS,·PS/92,·cinderblock·console....|·PS|5,125,92,125/92,0.982856
4095,itches,itches/127,·fix·your·vc·gl|itches|,31,127,31/127,0.976438
5597,·PS,·PS/93,·the·white·of·the·new|·PS|5,43,93,43/93,0.946677
11232,·Cel,·Cel/96,·right·leg··<|EOS|>Lo|·Cel|so,87,96,87/96,0.943884
11102,·ur,·ur/94,bow6Game·can·fix|·ur|·fucking,86,94,86/94,0.943391
4056,itches,itches/88,·the·biggest·vc·gl|itches|·for,31,88,31/88,0.921062
347,Game,Game/91,?<|EOS|>@·Rainbow6|Game|·Who,2,91,2/91,0.91759


In [104]:
#print('\n\n')
#print('Top 20 tokens that activate the feature the most')
create_vocab_df(logit_effect).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit
38663,syntax,1.243342
11192,rosis,1.188068
46002,EMPT,1.171864
37027,IOException,1.136177
46975,ftime,1.109961
5173,Exception,1.090886
12483,·fucking,1.07424
36721,Constructor,1.03999
43384,·-*-,1.032145
37900,Exists,1.024349


In [None]:
positive_segment = dataset.filter(lambda x: x['sentiment'] == 'Positive')

feature_id_positive, batch_df_positive, logit_effect_positive = feature_search(positive_segment, model, encoder, token_length=256)
print(feature_id_positive)

In [None]:
batch_df_positive.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

In [None]:
create_vocab_df(logit_effect_positive).head(20).style.background_gradient("coolwarm")

In [112]:
tweet_emotion_data = load_dataset('csv', data_files='tweet_emotions.csv', split='train')
tweet_emotion_data = tweet_emotion_data.rename_column('content', 'text')
print(tweet_emotion_data)
love_segment = tweet_emotion_data.filter(lambda x: x['sentiment'] == 'love')

Dataset({
    features: ['tweet_id', 'sentiment', 'text'],
    num_rows: 40000
})


In [121]:
feature_id_love, batch_df_love, logit_effect_love = feature_search(love_segment, model, encoder, token_length=128, activation_rank=1)
print(feature_id_love)

hidden_acts.shape torch.Size([16384, 16384])


[12344, 9780, 16359, 15631, 16025, 16139, 13810, 2802, 11960, 7890, 4022, 8719, 11873, 16350, 7052, 13489, 10068, 8564, 1348, 1528, 2970, 15346, 8093, 8174, 5194, 933, 8005, 15325, 1839, 6865, 15465, 7370, 14921, 7016, 5656, 11073, 472, 1399, 2645, 6142, 8063, 14970, 11767, 4922, 14994, 13969, 6231, 8757, 14637, 12806, 14588, 11703, 4329, 1487, 12530, 1462, 12802, 7586, 4114, 14751, 2132, 14369, 10303, 12855, 14602, 12708, 2779, 2346, 4147, 7677, 14402, 13851, 16100, 10718, 2531, 15203, 13200, 3198, 15716, 6832, 161, 4151, 4700, 5804, 9950, 11533, 16328, 15978, 9278, 1527, 10887, 7548, 15564, 9981, 14370, 8900, 12345, 4750, 5845, 12928, 14696, 549, 2192, 8992, 4382, 6248, 14354, 1313, 5681, 13229, 8887, 4494, 16374, 13907, 15253, 8405, 2662, 12133, 15781, 11584, 13256, 13244, 4939, 5815, 10875, 4627, 1024, 16345, 10221, 4170, 2900, 3028, 15665, 153, 3755, 13999, 14039, 9772, 14035, 14437, 10397, 9410, 3621, 7049, 843, 11357, 7197, 15315, 5305, 4805, 6185, 9596, 11675, 12559, 4946, 5844

In [122]:
batch_df_love.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,feature
3321,ined,ined/121,"stock,·NY.·Fr|ined|·had",25,121,25/121,1.126806
4057,·of,·of/89,Vampire·tarot|·of|·the,31,89,31/89,1.118899
9455,.,./111,·finally·home·from·the·city|.|·,73,111,73/111,1.117948
9683,oux,oux/83,nie·is·the·one·Si|oux|S,75,83,75/83,1.101768
3455,",",",/127","ancin,·travelin|,|",26,127,26/127,1.069804
4058,·the,·the/90,ampire·tarot·of|·the|·eternal,31,90,31/90,1.054958
10303,.,./63,·bien·small·world·small·world|.|·yo,80,63,80/63,1.054359
1005,.,./109,·That·would·be·the·one|.|·,7,109,7/109,1.053843
3315,·Wood,·Wood/115,Home·from·gallery·opening·in|·Wood|stock,25,115,25/115,1.048529
7257,·Mr,·Mr/89,'m·sure·that·would·make|·Mr|.,56,89,56/89,1.048109


In [123]:
create_vocab_df(logit_effect_love).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit
39427,·wastewater,1.386535
27001,·excav,1.352543
35814,·maritime,1.301021
22094,China,1.291364
44690,·Azerba,1.29108
43403,·pioneering,1.265554
44971,ionization,1.26279
23181,·inaug,1.247584
44492,Construction,1.243941
18691,STATE,1.236436


In [125]:
# Let's try filtering for text in tweet_emotions tha that contain the word 'hate'
hate_segment = tweet_emotion_data.filter(lambda x: 'hate' in x['text'])

Filter:   0%|          | 0/40000 [00:00<?, ? examples/s]

In [141]:
feature_id_hate, batch_df_hate, logit_effect_hate = feature_search(hate_segment, model, encoder, token_length=128, 
                                                                   activation_rank=2)

hidden_acts.shape torch.Size([12672, 16384])
[12344, 15631, 7052, 9780, 16359, 16139, 13810, 11960, 2802, 2970, 16350, 11873, 2645, 16025, 13489, 4022, 4922, 11073, 15465, 14970, 3028, 7677, 13200, 15325, 8719, 8063, 8093, 7890, 5656, 1839, 1528, 933, 8564, 10068, 8757, 472, 6424, 16218, 7016, 6865, 8174, 1348, 6142, 2740, 5194, 16100, 4114, 6231, 14994, 14751, 15346, 8005, 6248, 10303, 12657, 2132, 14035, 12997, 14402, 14696, 10009, 8233, 11130, 4329, 1399, 10312, 12133, 16069, 7548, 5815, 9950, 15203, 8935, 12790, 13765, 15781, 3198, 1487, 2779, 15716, 12802, 2346, 1254, 12928, 9157, 3621, 4147, 8887, 8182, 5388, 12855, 15564, 2192, 15207, 4613, 4700, 11570, 4170, 4627, 12559, 13451, 15315, 12708, 7586, 12530, 4577, 5766, 11767, 12505, 15978, 13644, 13256, 5681, 4494, 5305, 10875, 6843, 11777, 13911, 8716, 16328, 4441, 14921, 8992, 549, 5806, 14588, 10364, 6001, 12773, 3077, 10887, 10526, 1462, 14854, 15828, 10827, 15143, 13907, 7439, 3326, 13851, 1800, 10027, 14884, 8156, 14365, 116

In [142]:
batch_df_hate.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,feature
2777,hh,hh/89,·her·house..·ahhhh|hh|·i,21,89,21/89,0.856751
8380,yyyy,yyyy/60,..i·hateee·today|yyyy|y,65,60,65/60,0.851003
1158,·hate,·hate/6,·fuck...ahaha·i|·hate|·,9,6,9/6,0.843697
1683,·hates,·hates/19,·i·dont·know·why·she|·hates|·me,13,19,13/19,0.838199
7627,akes,akes/75,·i·know·i·hate·f|akes|<|EOS|>,59,75,59/75,0.827615
1157,·i,·i/5,<|BOS|>·fuck...ahaha|·i|·hate,9,5,9/5,0.80235
7511,·hates,·hates/87,·my·pic·again··it|·hates|·me,58,87,58/87,0.794212
469,ions,ions/85,·i·hate·making·decs|ions|!,3,85,3/85,0.780964
1163,.,./11,·hate·my·life·sometimes|.|·why,9,11,9/11,0.779551
9903,·english,·english/47,.·it·should·just·be|·english|·all,77,47,77/47,0.777446


In [143]:
create_vocab_df(logit_effect_hate).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit
1835,$$,3.065956
44101,!!!!!!!!,3.016877
24724,????,2.972375
22654,·lol,2.889661
14257,·♪,2.880017
13692,·ya,2.809452
33397,damn,2.791361
22805,·crap,2.779339
22709,·huh,2.745423
32453,llll,2.722019


In [None]:
# Let's try a sentiment analysis dataset from Huuggingface
# Using the news_sentiment_newsmtsc dataset

import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
from datasets import Dataset

# Load the dataset
dataset = load_dataset("news_sentiment_newsmtsc", split="train")

_, cache = model.run_with_cache(example_tweet_tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)][0]
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])

loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)
print("hidden_acts.shape", hidden_acts.shape)
