# Setup

## Dependencies

In [285]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch
import numpy as np
import gradio as gr
import pprint
import json
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
import tqdm.notebook as tqdm
import plotly.express as px
import pandas as pd
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [286]:
MODEL = 'crate-3l'
MULT = 16
LAYER = 2

## Defining the Autoencoder

In [287]:
import sae
if '6l' in MODEL:
    d_base = 768
else:
    d_base = 128
cfg = {
    "seed": 1, 
    "batch_size": 10,  # Number of samples we pass through THE LM 
    "seq_len": 1024,  # Length of each input sequence for the model
    "d_in": d_base * 4,   # Input dimension for the encoder model
    "d_sae": d_base * 4 * MULT,  # Dimensionality for the sparse autoencoder (SAE)
    "l1_lambda": 1.6e-4,
    "dataset": "-",  # Name of the dataset to use
    "dataset_args": [],  # Any additional arguments for the dataset
    "dataset_kwargs": {"split": "train", "streaming": True}, 
    "dtype": torch.float32, 
    "device": "cuda:0"
}

# import os
encoder = sae.SAE(cfg)
dir_path = f"/home/ubuntu/nanogpt4crate/arthursae/rebuttal/{MODEL.lower()}-{MULT}x/{LAYER}"
# get file path (file is the only file in the directory)
file_name = os.listdir(dir_path)[0]
path = os.path.join(dir_path, file_name)
encoder.load_from_local(path=path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



## Utils

### Get Reconstruction Loss

In [288]:
def replacement_hook(mlp_post, hook, encoder):
    mlp_post[:] = encoder(mlp_post, return_mode="sae_out")
    # print("mlp_post_reconstr equal to mlp_post?", (mlp_post_reconstr == mlp_post).all())
    # print("mlp_post_reconstr:", mlp_post_reconstr)
    # print("mlp_post:", mlp_post)
    return mlp_post

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@torch.no_grad()
def get_recons_loss(num_batches=20, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", LAYER), partial(replacement_hook, encoder=local_encoder))])
        mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", LAYER), mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", LAYER), zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()
    
    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}, mean_abl_loss: {mean_abl_loss:.4f}")
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"Reconstruction Score: {score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss, mean_abl_loss

### Get Frequencies

In [289]:
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    # each dimension gets a freq score
    act_freq_scores = torch.zeros(local_encoder.d_sae, dtype=torch.float32).cuda()
    total = 0
    for i in tqdm.trange(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["batch_size"]]] # subsample
        # print(tokens)
        # print("number of tokens:", len(tokens))

        _, cache = model.run_with_cache(tokens, stop_at_layer=LAYER+1, names_filter=utils.get_act_name("post", LAYER))
        mlp_acts = cache[utils.get_act_name("post", LAYER)]
        # print("shape of mlp_acts (after model):", mlp_acts.shape)
        mlp_acts = mlp_acts.reshape(-1, d_mlp)
        # overcomplete autoencoding
        # loss, x_reconstruct, **acts**, l2_loss, l1_loss
        hidden = local_encoder(mlp_acts, return_mode="hidden_post")
        # print("shape of hidden (after local_encoder):", hidden.shape)

        act_freq_scores += (hidden > 0).sum(0)

        total+=hidden.shape[0]

    print("total:", total)
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

## Visualise Feature Utils

In [290]:
from html import escape
import colorsys

from IPython.display import display

SPACE = "·"
NEWLINE="↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    if isinstance(values, torch.Tensor) and len(values.shape)>1:
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    if max_value is None:
        max_value = max(max(values), -min(values))+1e-3
    scaled_values = [v / max_value * saturation for v in values]

    # create html
    html = ""
    for i, s in enumerate(processed_strings):
        if i<len(scaled_values):
            v = scaled_values[i]
        else:
            v = 0
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        rgb_color = colorsys.hsv_to_rgb(
            hue, v, 1
        )  # hsv color with hue 0.66 (blue), saturation as v, value 1
        hex_color = "#%02x%02x%02x" % (
            int(rgb_color[0] * 255),
            int(rgb_color[1] * 255),
            int(rgb_color[2] * 255),
        )
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'
    if return_string:
        return html
    else:
        display(HTML(html))

def basic_feature_vis(text, feature_index, max_val=0):
    feature_in = encoder.W_enc[:, feature_index]
    feature_bias = encoder.b_enc[feature_index]
    _, cache = model.run_with_cache(text, stop_at_layer=LAYER+1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)][0]
    feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if max_val==0:
        max_val = max(1e-7, feature_acts.max().item())
        # print(max_val)
    # if min_val==0:
    #     min_val = min(-1e-7, feature_acts.min().item())
    return basic_token_vis_make_str(text, feature_acts, max_val)
def basic_token_vis_make_str(strings, values, max_val=None):
    if not isinstance(strings, list):
        strings = model.to_str_tokens(strings)
    values = utils.to_numpy(values)
    if max_val is None:
        max_val = values.max()
    # if min_val is None:
    #     min_val = values.min()
    header_string = f"<h4>Max Range <b>{values.max():.4f}</b> Min Range: <b>{values.min():.4f}</b></h4>"
    header_string += f"<h4>Set Max Range <b>{max_val:.4f}</b></h4>"
    # values[values>0] = values[values>0]/ma|x_val
    # values[values<0] = values[values<0]/abs(min_val)
    body_string = create_html(strings, values, max_value=max_val, return_string=True)
    return header_string + body_string
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr
try:
    demos[0].close()
except:
    pass
demos = [None]
def make_feature_vis_gradio(feature_id, starting_text=None, batch=None, pos=None):
    if starting_text is None:
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    try:
        demos[0].close()
    except:
        pass
    with gr.Blocks() as demo:
        gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # # If empty, these two map to None
                max_val = gr.Number(label="Max Value", value=None)
                # min_val = gr.Number(label="Min Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(starting_text, feature_id))
        for inp in inputs:
            inp.change(basic_feature_vis, inputs, out)
    demo.launch(share=True)
    demos[0] = demo

### Inspecting Top Logits

In [291]:
SPACE = "·"
NEWLINE="↩"
TAB = "→"
def process_token(s):
    if isinstance(s, torch.Tensor):
        s = s.item()
    if isinstance(s, np.int64):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    s = s.replace(" ", SPACE)
    s = s.replace("\n", NEWLINE+"\n")
    s = s.replace("\t", TAB)
    return s

def process_tokens(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [process_token(s) for s in l]

def process_tokens_index(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [f"{process_token(s)}/{i}" for i,s in enumerate(l)]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))
    vocab_df = pd.DataFrame({"token": full_vocab, "logit": utils.to_numpy(logit_vec)})
    if make_probs:
        vocab_df["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        vocab_df["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return vocab_df.sort_values("logit", ascending=False)

### Make Token DataFrame

In [292]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]
def make_token_df(tokens, len_prefix=10, len_suffix=10):
    str_tokens = [process_tokens(model.to_str_tokens(t)) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        # context.append([])
        # batch.append([])
        # pos.append([])
        # label.append([])
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            if p==tokens.shape[1]-1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p+1:min(tokens.shape[1]-1, p+1+len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

## Loading the Model

In [293]:
if 'crate' in MODEL:
    model = HookedTransformer.from_pretrained(MODEL, fold_ln=False).to(cfg["dtype"])
else:
    model = HookedTransformer.from_pretrained(MODEL).to(cfg["dtype"])
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab

architecture crate
entering crate
loading model crate-3l
loading model /home/ubuntu/nanogpt4crate/out/ckpt-crate-3l-overparam.pt
entering nanogpt
dict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.prenorm_1.weight', 'transformer.h.0.prenorm_1.bias', 'transformer.h.0.attn.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.prenorm_2.weight', 'transformer.h.0.prenorm_2.bias', 'transformer.h.0.ista.weight', 'transformer.h.1.prenorm_1.weight', 'transformer.h.1.prenorm_1.bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.prenorm_2.weight', 'transformer.h.1.prenorm_2.bias', 'transformer.h.1.ista.weight', 'transformer.h.2.prenorm_1.weight', 'transformer.h.2.prenorm_1.bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.c_proj.w


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



Loaded pretrained model crate-3l into HookedTransformer
Changing model dtype to torch.float32


In [294]:
tokens = "Beijing is the capital of"
input_ids = model.tokenizer.encode(tokens, return_tensors="pt").cuda()
create_vocab_df(model(input_ids)[0][-1], make_probs=True).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit,log_prob,prob
262,·the,13.38105,-1.301797,0.272043
2807,·China,11.085363,-3.597484,0.027393
257,·a,10.666135,-4.016712,0.018012
3794,·India,10.477721,-4.205126,0.014919
2253,·America,10.265337,-4.41751,0.012064
674,·our,10.24741,-4.435437,0.01185
428,·this,10.17963,-4.503217,0.011073
663,·its,10.093251,-4.589596,0.010157
7595,·Brazil,9.771369,-4.911478,0.007362
198,↩,9.603951,-5.078897,0.006227


In [295]:
tokens = "When Mary and John went to the store, John gave a drink to"
input_ids = model.tokenizer.encode(tokens, return_tensors="pt").cuda()
create_vocab_df(model(input_ids)[0][-1], make_probs=True).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit,log_prob,prob
262,·the,13.737627,-1.493005,0.224696
257,·a,12.071457,-3.159175,0.042461
465,·his,11.751668,-3.478964,0.030839
683,·him,11.496007,-3.734625,0.023882
607,·her,11.386432,-3.8442,0.021404
307,·be,11.196205,-4.034427,0.017696
502,·me,11.037497,-4.193135,0.015099
616,·my,10.655001,-4.575631,0.0103
651,·get,10.532378,-4.698254,0.009111
467,·go,10.144191,-5.086441,0.00618


In [296]:
print(model.generate("I like to go to school because"))

  0%|          | 0/10 [00:00<?, ?it/s]

I like to go to school because my parents were volunteering as a really significant pair of


## Loading Data

In [297]:
data = np.memmap('/home/ubuntu/nanogpt4crate/data/pile/val.bin', dtype=np.uint16, mode='r')
end_index = int(len(data) // 1024 * 1024)
print(end_index)
data = torch.from_numpy((data).astype(np.int64))[:end_index]
all_tokens = data.reshape(-1, cfg["seq_len"]).cuda()
print("Number of all tokens:", len(all_tokens))

294434816
Number of all tokens: 287534


# Analysis

## Using the Autoencoder

We run the model and replace the MLP activations with those reconstructed from the autoencoder, and get 91% loss recovered

In [298]:
# while True:
score, loss, recons_loss, zero_abl_loss, mean_abl_loss = get_recons_loss(num_batches=30, local_encoder=encoder)
    # if round(score, 2) == 3.45:
    #     break

loss: 3.5049, recons_loss: 3.7706, zero_abl_loss: 9.2351, mean_abl_loss: 11.1273
Reconstruction Score: 95.36%


## Rare Features Are All The Same

For each feature we can get the frequency at which it's non-zero (per token, averaged across a bunch of batches), and plot a histogram

In [299]:
freqs = get_freqs(num_batches = 20, local_encoder = encoder)

  0%|          | 0/20 [00:00<?, ?it/s]

total: 204800
Num dead tensor(0.0010, device='cuda:0')


In [300]:
# Add 1e-6.5 so that dead features show up as log_freq -6.5
log_freq = (freqs + 10**-6.5).log10()
px.histogram(utils.to_numpy(log_freq), title="Log Frequency of Features", histnorm='percent')

We see that it's clearly bimodal! Let's define rare features as those with freq < 1e-4, and look at the cosine sim of each feature with the average rare feature - we see that almost all rare features correspond to this feature!

这个(2048,)的稀疏特征条很有价值

In [301]:
is_rare = freqs < 1e-5
# W_enc is the dictionary
rare_enc = encoder.W_enc[:, is_rare]
rare_mean = rare_enc.mean(-1)
cos_sim = utils.to_numpy(rare_mean @ encoder.W_enc / rare_mean.norm() / encoder.W_enc.norm(dim=0))
px.histogram(cos_sim, title="Cosine Sim with Ave Rare Feature", color=utils.to_numpy(is_rare), labels={"color": "is_rare", "count": "percent", "value": "cosine_sim"}, marginal="box", histnorm="percent", barmode='overlay')

## Interpreting A Feature

Let's go and investigate a non rare feature, feature 7

In [302]:
very_active = []

for feature_id in range(512):
    if freqs[feature_id].item() > 1e-2:
        very_active.append(feature_id)
        if len(very_active) >= 10:
            break
print(very_active)

feature_id = 9
batch_size = 50 # @param {type:"number"}

print(f"Feature freq: {freqs[feature_id].item():.4f}")
if freqs[feature_id].item() <= 1e-3:
    print("This is a rare feature!")

tokens = all_tokens[torch.randperm(len(all_tokens))[:batch_size]]
_, cache = model.run_with_cache(tokens, stop_at_layer=LAYER+1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
print("mlp_acts shape:", mlp_acts.shape)
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_in"])
_, hidden_acts = encoder(mlp_acts_flattened, return_mode="both")
# This is equivalent to:
# hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
print("hidden_acts.shape", hidden_acts.shape)

token_df = make_token_df(tokens)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

[5, 13, 14, 20, 21, 24, 25, 33, 37, 38]
Feature freq: 0.0033
mlp_acts shape: torch.Size([50, 1024, 512])
hidden_acts.shape torch.Size([51200, 8192])


Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,feature
51160,.,./984,"·at·475,·485,·116·S.Ct|.|·2240.·Indeed,·courts·""start·with·the",49,984,49/984,0.0
51161,·2,·2/985,"·475,·485,·116·S.Ct.|·2|240.·Indeed,·courts·""start·with·the·assumption",49,985,49/985,0.0
51162,240,240/986,",·485,·116·S.Ct.·2|240|.·Indeed,·courts·""start·with·the·assumption·that",49,986,49/986,0.0
51163,.,./987,"·485,·116·S.Ct.·2240|.|·Indeed,·courts·""start·with·the·assumption·that·the",49,987,49/987,0.0
51164,·Indeed,·Indeed/988,"5,·116·S.Ct.·2240.|·Indeed|,·courts·""start·with·the·assumption·that·the·historic",49,988,49/988,0.0
51165,",",",/989",",·116·S.Ct.·2240.·Indeed|,|·courts·""start·with·the·assumption·that·the·historic·police",49,989,49/989,0.0
51166,·courts,·courts/990,"·116·S.Ct.·2240.·Indeed,|·courts|·""start·with·the·assumption·that·the·historic·police·powers",49,990,49/990,0.0
51167,"·""","·""/991","·S.Ct.·2240.·Indeed,·courts|·""|start·with·the·assumption·that·the·historic·police·powers·of",49,991,49/991,0.0
51152,",",",/976","·and·safety·of·their·citizens.·Id.·at·475|,|·485,·116·S.Ct.·2240",49,976,49/976,0.0
51153,·48,·48/977,"·safety·of·their·citizens.·Id.·at·475,|·48|5,·116·S.Ct.·2240.",49,977,49/977,0.0


It's easy to misread evidence like the above, so it's useful to take some text and edit it and see how this changes the model's activations. Here's a hacky interactive tool to play around with some text.

In [303]:
# starting_text = "Most programming languages are easy to use, like OpenCL, OpenGL, Qt, C, etc." # @param {type:"string"}
# make_feature_vis_gradio(feature_id, starting_text)

A final piece of evidence: This is a one layer model, so the neurons can only matter by directly impacting the final logits! We can directly look at how the decoder weights for this feature affect the logits, and see that it boosts `'ll`! This checks out, I and he'll etc is a common construction.

In [304]:
# logit_effect = encoder.W_dec[feature_id] @ model.blocks[0].mlp.weight @ model.W_U
# create_vocab_df(logit_effect).head(20).style.background_gradient("coolwarm")

## Manually set activation to a high value

In [305]:
input_text = "1,2,3,4,5"
# run the model with cache
tokens = model.to_tokens(input_text)

def activating_sae_hook(mlp_post, hook, encoder):
    mlp_post[:] = encoder(mlp_post, return_mode="sae_out", activate_feature=feature_id)
    return mlp_post

# _, cache = model.run_with_cache(tokens, stop_at_layer=LAYER+1, names_filter=utils.get_act_name("post", 0))
# mlp_acts = cache[utils.get_act_name("post", 0)]
# print("mlp_acts shape:", mlp_acts.shape)
# mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_in"])
# _, hidden_acts = encoder(mlp_acts_flattened, return_mode="both")
# manually set the feature to a high value
res_sae_activated = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=[(utils.get_act_name("post", LAYER), partial(activating_sae_hook, encoder=encoder))])

# print("res_activated shape:", res_activated[-1].shape)
create_vocab_df(res_sae_activated[0][-1]).head(20).style.background_gradient("coolwarm")

Unnamed: 0,token,logit
8379,·sequence,15.138222
16311,·sequences,14.341653
26789,·similarity,13.963812
43366,·divergence,12.588655
50250,·amplification,12.483156
19114,·alignment,12.312242
32702,·motif,12.234316
19617,·coding,12.124865
23005,·mutations,12.079557
17670,·variants,12.026893


## Solely CRATE

In [306]:
very_active = []

for feature_id in range(512):
    if freqs[feature_id].item() > 1e-2:
        very_active.append(feature_id)
        if len(very_active) >= 10:
            break
print(very_active)

feature_id = 21
batch_size = 200 # @param {type:"number"}

print(f"Feature freq: {freqs[feature_id].item():.4f}")
if freqs[feature_id].item() <= 1e-3:
    print("This is a rare feature!")

tokens = all_tokens[torch.randperm(len(all_tokens))[:batch_size]]
_, cache = model.run_with_cache(tokens, stop_at_layer=LAYER+1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
print("mlp_acts shape:", mlp_acts.shape)
mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_in"])
# _, hidden_acts = encoder(mlp_acts_flattened, return_mode="both")
# print("hidden_acts.shape", hidden_acts.shape)

token_df = make_token_df(tokens)
token_df["feature"] = utils.to_numpy(mlp_acts_flattened[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

[5, 13, 14, 20, 21, 24, 25, 33, 37, 38]
Feature freq: 0.0233


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.12 GiB. GPU 0 has a total capacity of 39.39 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 36.98 GiB memory in use. Of the allocated memory 35.01 GiB is allocated by PyTorch, and 1.48 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
input_text = "1,2,3,4,5"
# run the model with cache
tokens = model.to_tokens(input_text)

def activating_hook(mlp_post, hook, encoder):
    print("mlp_post shape:", mlp_post.shape)
    mlp_post[:, :, feature_id] = 100
    return mlp_post

# manually set the feature to a high value
res_activated = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=[(utils.get_act_name("post", LAYER), partial(activating_hook, encoder=encoder))])

# print("res_activated shape:", res_activated[-1].shape)
create_vocab_df(res_activated[0][-1]).head(20).style.background_gradient("coolwarm")

mlp_post shape: torch.Size([1, 10, 512])


Unnamed: 0,token,logit
30109,[[,14.78965
198,↩,14.781439
58,[,13.484408
220,·,12.952336
2,#,12.80312
23428,VAL,11.601645
9792,FT,10.868013
5420,ref,10.781559
7753,file,10.532041
12,-,10.437276
