# Setup

In [1]:
# Point the HuggingFace cache (used by the model/dataset downloads below) at /workspace/cache/.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
%env HF_HOME=/workspace/cache/

env: HF_HOME=/workspace/cache/


## Dependencies

In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens
    %pip install torchtyping
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!


In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
import plotly.express as px

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch
import einops
import numpy as np
import gradio as gr
import pprint
import json
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
import tqdm.notebook as tqdm
import pandas as pd

from circuitsvis.attention import attention_heads
from IPython.display import HTML, IFrame

## Defining the Autoencoder

In [5]:
# Local config. The SAE loaded later is built from its own cfg downloaded from HF
# (AutoEncoder.load_from_hf). This dict is read by the helpers in this notebook:
# "model_batch_size" (sequences per sampled batch), "enc_dtype" (also used as the model's
# dtype) and "d_mlp" (reshaping MLP activations). The optimiser/buffer keys appear to be
# training-only settings that the analysis below does not read.
cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 128,
    "d_mlp": 2048,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
}
# Number of sequences drawn per batch in get_recons_loss / get_freqs.
cfg["model_batch_size"] = 64
# Derived buffer sizes (activation vectors, and sequences needed to fill the buffer).
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]

In [6]:
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
class AutoEncoder(nn.Module):
    """Sparse autoencoder with a single ReLU hidden layer over MLP activations.

    The hidden width is cfg["d_mlp"] * cfg["dict_mult"]. Each decoder row (one per
    feature) is initialised to unit L2 norm.

    Args:
        cfg: dict with at least "d_mlp", "dict_mult", "l1_coeff", "enc_dtype" (a key
            of DTYPES) and "seed". The seed is applied to torch's *global* RNG
            before initialisation.
        device: device the parameters are moved to after init. Defaults to "cuda",
            as before; pass e.g. "cpu" to build the SAE without a GPU.
    """
    def __init__(self, cfg, device="cuda"):
        super().__init__()
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))

        # Normalise every decoder row (feature direction) to unit norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(device)

    def forward(self, x):
        """Encode and reconstruct `x` (last dim d_mlp; presumably (batch, d_mlp)).

        Returns:
            (loss, x_reconstruct, acts, l2_loss, l1_loss), where
            - acts = ReLU((x - b_dec) @ W_enc + b_enc)
            - x_reconstruct = acts @ W_dec + b_dec
            - l2_loss = squared error summed over the last dim, averaged over dim 0
            - l1_loss = l1_coeff * sum(|acts|) over *all* elements
            - loss = l2_loss + l1_loss
        NOTE(review): unlike l2_loss, l1_loss is not averaged over the batch. It is left
        unchanged because the pretrained checkpoints' l1_coeff presumably assumes this form.
        """
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def remove_parallel_component_of_grads(self):
        """Remove from W_dec.grad its component parallel to each decoder row, in place.

        This way a gradient step does not (to first order) change the decoder row norms.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj

    @classmethod
    def load_from_hf(cls, version, device="cuda"):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.

        The downloaded cfg is printed, used to build the module on `device`, and the
        downloaded state dict is loaded into it.
        """
        if version == "run1":
            version = 25
        elif version == "run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        pprint.pprint(cfg)
        self = cls(cfg=cfg, device=device)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


## Utils

### Get Reconstruction Loss

In [7]:
def replacement_hook(mlp_post, hook, encoder):
    """Hook: replace the MLP activations with the autoencoder's reconstruction (2nd output)."""
    _, reconstruction, *_ = encoder(mlp_post)
    return reconstruction

def mean_ablate_hook(mlp_post, hook):
    """Hook: overwrite the activations, in place, with their mean over the first two axes."""
    mean_act = mlp_post.mean([0, 1])
    mlp_post[:] = mean_act
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    """Hook: zero the activations in place and return the same tensor."""
    mlp_post.zero_()
    return mlp_post

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    """Compare LM loss with layer-0 MLP activations clean, SAE-reconstructed and zero-ablated.

    Each of `num_batches` iterations samples cfg["model_batch_size"] random sequences
    from the global `all_tokens` and runs the global `model` three times. Prints the
    mean losses and the reconstruction score
    (zero_abl_loss - recons_loss) / (zero_abl_loss - loss).

    Args:
        num_batches: number of random batches to average over.
        local_encoder: SAE to splice in; defaults to the global `encoder`.
    Returns:
        (score, loss, recons_loss, zero_abl_loss) as Python floats.
    """
    sae = encoder if local_encoder is None else local_encoder
    hook_name = utils.get_act_name("post", 0)
    per_batch = []
    for _ in range(num_batches):
        sample_idx = torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]
        batch_tokens = all_tokens[sample_idx]
        clean = model(batch_tokens, return_type="loss")
        spliced = model.run_with_hooks(batch_tokens, return_type="loss", fwd_hooks=[(hook_name, partial(replacement_hook, encoder=sae))])
        zeroed = model.run_with_hooks(batch_tokens, return_type="loss", fwd_hooks=[(hook_name, zero_ablate_hook)])
        per_batch.append((clean, spliced, zeroed))
    loss, recons_loss, zero_abl_loss = torch.tensor(per_batch).mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    # Fraction of the clean-vs-zero-ablation loss gap recovered by the reconstruction.
    score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)
    print(f"Reconstruction Score: {score:.2%}")
    return score, loss, recons_loss, zero_abl_loss

### Get Frequencies

In [8]:
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    """Estimate how often each SAE feature fires on layer-0 MLP activations.

    Samples `num_batches` batches of cfg["model_batch_size"] random sequences from the
    global `all_tokens`, encodes every token's MLP activation and counts activations > 0.
    Prints the *fraction* of features that never fired (labelled "Num dead").

    Args:
        num_batches: number of random batches to sample.
        local_encoder: SAE to evaluate; defaults to the global `encoder`.
    Returns:
        float32 tensor of shape (d_hidden,), where each entry is the fraction of sampled
        tokens on which that feature was active. It lives on the encoder's device.
    """
    if local_encoder is None:
        local_encoder = encoder
    hook_name = utils.get_act_name("post", 0)
    # Accumulate on the encoder's device rather than assuming CUDA.
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32, device=local_encoder.W_enc.device)
    total = 0
    for _ in tqdm.trange(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=hook_name)
        mlp_acts = cache[hook_name]
        # Flatten (batch, pos) using the activations' own width. This replaces the
        # global `d_mlp`, which is only defined in a later cell.
        mlp_acts = mlp_acts.reshape(-1, mlp_acts.shape[-1])

        hidden = local_encoder(mlp_acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total += hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

## Visualise Feature Utils

In [9]:
from html import escape
import colorsys

from IPython.display import display

SPACE = "·"
NEWLINE = "↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    """Render tokens as inline HTML spans whose background colour encodes a value.

    Positive (and zero) values are shaded blue and negative values red. The colour
    intensity is |value| / max_value * saturation, clamped to [0, 1].

    Args:
        strings: token strings. They are HTML-escaped; newlines, tabs and spaces are made visible.
        values: one scalar per token (list, numpy array or torch tensor of any shape;
            tensors are flattened).
        max_value: value mapped to full `saturation`; defaults to max(|values|) + 1e-3.
        saturation: intensity used for a value equal to max_value.
        allow_different_length: skip the length check. Tokens without a value get 0.
        return_string: return the HTML string instead of displaying it.
    Returns:
        The HTML string if return_string is True, otherwise None (the HTML is displayed).
    """
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    # Convert tensors of any rank to a flat list of floats (previously 1-D tensors
    # were left as tensors).
    if isinstance(values, torch.Tensor):
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    if max_value is None:
        # max(|v|) equals the old max(max(v), -min(v)); `default` keeps empty input safe.
        max_value = max((abs(v) for v in values), default=0.0) + 1e-3
    scaled_values = [v / max_value * saturation for v in values]

    html = ""
    for i, s in enumerate(processed_strings):
        v = scaled_values[i] if i < len(scaled_values) else 0
        hue = 0 if v < 0 else 0.66  # red for negative values, blue otherwise
        # The sign only selects the hue. HSV saturation must lie in [0, 1]; otherwise
        # hsv_to_rgb returns channels outside [0, 1] and the hex code is invalid.
        sat = min(abs(v), 1.0)
        rgb_color = colorsys.hsv_to_rgb(hue, sat, 1)
        hex_color = "#%02x%02x%02x" % tuple(int(c * 255) for c in rgb_color)
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'
    if return_string:
        return html
    display(HTML(html))

def basic_feature_vis(text, feature_index, max_val=0):
    """Return HTML showing one SAE feature's activation on every token of `text`.

    Uses the global `model` (layer-0 MLP post activations) and the global `encoder`.

    Args:
        text: input string (only the first batch element of the cache is used).
        feature_index: SAE feature to visualise.
        max_val: colour-scale maximum. 0 or None means "use the largest activation,
            floored at 1e-7". An empty gradio Number field sends None.
    Returns:
        HTML string from basic_token_vis_make_str.
    """
    hook_name = utils.get_act_name("post", 0)
    # Pure visualisation: no autograd graph needed.
    with torch.no_grad():
        feature_in = encoder.W_enc[:, feature_index]
        feature_bias = encoder.b_enc[feature_index]
        _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=hook_name)
        mlp_acts = cache[hook_name][0]
        feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if not max_val:
        # Previously only 0 was handled. None from gradio fell through, and a feature
        # that never fires then produced a zero colour scale (ZeroDivisionError).
        max_val = max(1e-7, feature_acts.max().item())
    return basic_token_vis_make_str(text, feature_acts, max_val)
def basic_token_vis_make_str(strings, values, max_val=None):
    """Return an HTML header (value range) followed by coloured tokens.

    Args:
        strings: list of token strings; anything else is split with the global
            model's to_str_tokens.
        values: one value per token (converted with utils.to_numpy).
        max_val: value mapped to full colour. Defaults to values.max(), floored at
            1e-7 (the same floor basic_feature_vis uses) so an all-zero or all-negative
            input does not give create_html a zero or negative scale.
    """
    if not isinstance(strings, list):
        strings = model.to_str_tokens(strings)
    values = utils.to_numpy(values)
    if max_val is None:
        max_val = max(1e-7, float(values.max()))
    header_string = f"<h4>Max Range <b>{values.max():.4f}</b> Min Range: <b>{values.min():.4f}</b></h4>"
    header_string += f"<h4>Set Max Range <b>{max_val:.4f}</b></h4>"
    body_string = create_html(strings, values, max_value=max_val, return_string=True)
    return header_string + body_string
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr

# Close a demo left running by a previous execution of this cell. On the first run
# `demos` does not exist yet (NameError) or holds None (AttributeError); either way
# there is nothing to close.
try:
    demos[0].close()
except Exception:
    pass
# Module-level slot holding the currently running demo, so a relaunch can close it.
demos = [None]
def make_feature_vis_gradio(feature_id, starting_text=None, batch=None, pos=None):
    """Launch a publicly shared gradio app showing an SAE feature on editable text.

    Args:
        feature_id: SAE feature shown initially (editable in the UI).
        starting_text: initial text. If None, it is decoded from the global
            `all_tokens[batch, 1:pos+1]` (position 0 is skipped, presumably BOS), so
            `batch` and `pos` must then be given.
        batch, pos: row and last position in `all_tokens` used when starting_text is None.

    Side effects: closes the previously launched demo, launches a new one with
    share=True and stores it in demos[0].

    Raises:
        ValueError: if starting_text is None and batch or pos is missing.
    """
    if starting_text is None:
        if batch is None or pos is None:
            raise ValueError("make_feature_vis_gradio needs starting_text or both batch and pos")
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    if demos[0] is not None:
        try:
            demos[0].close()
        except Exception:
            # Best effort: the previous server may already be gone.
            pass
    with gr.Blocks() as demo:
        gr.HTML(value="Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # If empty, this maps to None (handled by basic_feature_vis)
                max_val = gr.Number(label="Max Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(starting_text, feature_id))
        # Re-render whenever any input changes.
        for inp in inputs:
            inp.change(basic_feature_vis, inputs, out)
    demo.launch(share=True)
    demos[0] = demo

### Inspecting Top Logits

In [10]:
SPACE = "·"
NEWLINE = "↩"
TAB = "→"
def process_token(s):
    """Turn one token (id or string) into a display string with visible whitespace.

    Accepts a single-element torch tensor, any numpy integer scalar, a Python int
    (these are decoded with the global `model`), or a str. Spaces become SPACE, tabs
    become TAB, and newlines become NEWLINE followed by a real newline.
    """
    if isinstance(s, torch.Tensor):
        s = s.item()
    # np.integer covers int32/int64/...; previously only np.int64 was recognised, so
    # other numpy integer scalars reached str.replace and crashed.
    if isinstance(s, np.integer):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    return s.replace(" ", SPACE).replace("\n", NEWLINE + "\n").replace("\t", TAB)

def process_tokens(l):
    """Map a token sequence to display strings via process_token.

    A str is first split with the global model's to_str_tokens. For a tensor of rank
    > 1, dim 0 is squeezed (a no-op unless that dim has size 1).
    """
    tokens = l
    if isinstance(tokens, str):
        tokens = model.to_str_tokens(tokens)
    elif isinstance(tokens, torch.Tensor) and tokens.dim() > 1:
        tokens = tokens.squeeze(0)
    return list(map(process_token, tokens))

def process_tokens_index(l):
    """Like process_tokens, but suffix each display string with "/<position>" so it is unique."""
    return [f"{tok}/{idx}" for idx, tok in enumerate(process_tokens(l))]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    """Return a DataFrame of every vocab token and its logit, highest logit first.

    Args:
        logit_vec: one logit per vocab entry (converted with utils.to_numpy).
        make_probs: also add "log_prob" and "prob" columns (softmax over logit_vec).
        full_vocab: pre-processed token strings; built from the global `model` when None.
    """
    if full_vocab is None:
        vocab_ids = torch.arange(model.cfg.d_vocab)
        full_vocab = process_tokens(model.to_str_tokens(vocab_ids))
    columns = {"token": full_vocab, "logit": utils.to_numpy(logit_vec)}
    if make_probs:
        columns["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        columns["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return pd.DataFrame(columns).sort_values("logit", ascending=False)

### Make Token DataFrame

In [11]:
def list_flatten(nested_list):
    """Concatenate the sub-sequences of `nested_list` into one flat list (one level only)."""
    flat = []
    for sub in nested_list:
        flat.extend(sub)
    return flat
def make_token_df(tokens, len_prefix=5, len_suffix=1):
    """Build one DataFrame row per (batch, position) token, with a short context string.

    Args:
        tokens: token ids iterated row by row (presumably a (batch, pos) tensor); each
            row is converted with the global model's to_str_tokens and process_tokens.
        len_prefix: number of preceding tokens shown before the `|token|` marker.
        len_suffix: number of following tokens shown after it.
    Returns:
        DataFrame in row-major (batch, pos) order with columns str_tokens,
        unique_token ("tok/pos"), context ("prefix|tok|suffix"), batch, pos and
        label ("batch/pos").
    """
    str_tokens = [process_tokens(model.to_str_tokens(t)) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context, batch, pos, label = [], [], [], []
    for b, row in enumerate(str_tokens):
        for p, current in enumerate(row):
            prefix = "".join(row[max(0, p - len_prefix):p])
            # Slicing clamps at the end of the row, so the last token needs no special
            # case. The old `min(len - 1, ...)` bound also wrongly dropped the final
            # token from the second-to-last token's suffix.
            suffix = "".join(row[p + 1:p + 1 + len_suffix])
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

## Loading the Model

In [12]:
# Load the 1-layer GELU model. Its dtype is taken from the *encoder* dtype in cfg,
# presumably so its MLP activations match the SAE parameters' dtype.
model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]])
# Architecture sizes of the loaded model.
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab

Loaded pretrained model gelu-1l into HookedTransformer
Changing model dtype to torch.float32


## Loading Data

In [13]:
# Load c4-code-20k and tokenise it into fixed-length sequences (max_length=128).
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
# Fixed shuffle seed: later cells index specific rows (e.g. all_tokens[917]), which only
# point at the same text if this shuffle is deterministic.
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]

## Autoencoder: Finding Higher Frequency Tokens

In [14]:
# "run1" -> checkpoint 25, "run2" -> checkpoint 47 (mapped inside AutoEncoder.load_from_hf).
auto_encoder_run = "run1" # @param ["run1", "run2"]
encoder = AutoEncoder.load_from_hf(auto_encoder_run)

{'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 2048,
 'dict_mult': 8,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'lr': 0.0001,
 'model_batch_size': 512,
 'num_tokens': 2000000000,
 'seed': 52,
 'seq_len': 128}


In [15]:
# Reconstruction score = fraction of the loss gap between zero-ablating the MLP and the
# clean model that is recovered by splicing in the SAE reconstruction.
_ = get_recons_loss(num_batches=20, local_encoder=encoder)

loss: 3.2575, recons_loss: 3.7475, zero_abl_loss: 8.7729
Reconstruction Score: 91.12%


In [16]:
# Per-feature firing frequency (fraction of sampled tokens with activation > 0).
# NOTE: the "Num dead" value printed by get_freqs is the *fraction* of never-firing features.
freqs = get_freqs(num_batches = 50, local_encoder = encoder)
# Features firing on fewer than 1e-4 of tokens are flagged as rare (heuristic threshold).
is_rare = freqs < 1e-4

  0%|          | 0/50 [00:00<?, ?it/s]

Num dead tensor(6.1035e-05, device='cuda:0')


### A `('` feature in GELU-1L


- 5018: a feature that boosts the logits of modes of communication; head 4 is interesting


- 2410: polysemantic feature
- 2393: error token
- 1625: blue token
- 5090: tokens following `of`

In [17]:
# Feature under study (5090 = "tokens following `of`" in the notes above).
feature_id = 5090 # @param {type:"number"}
# Sequences per chunk; the activation cell below takes the first 10 chunks.
batch_size = 128 # @param {type:"number"}

print(f"Feature freq: {freqs[feature_id].item():.5f}")
print(f"Is Rare: {is_rare[feature_id]}")

Feature freq: 0.01118
Is Rare: False


In [18]:
# Encode the layer-0 MLP activations of the first 10 * batch_size sequences with the SAE.
# NOTE(review): hidden_acts is (n_tokens, d_hidden), here 163840 x 16384 fp32 (~10.7 GB);
# slice fewer sequences or keep only hidden_acts[:, feature_id] if memory is tight.
# NOTE: this rebinds the global `tokens`, which later cells index into.
with torch.no_grad():
    tokens = all_tokens[batch_size*0:batch_size*10]
    _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)]
    mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])
    loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)
    # This is equivalent to:
    # hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
    print("hidden_acts.shape", hidden_acts.shape)

hidden_acts.shape torch.Size([163840, 16384])


Maximum activating examples

In [19]:
# One row per token of `tokens` (row-major, the same order as hidden_acts); attach this
# feature's activation and show the 20 highest-activating tokens.
token_df = make_token_df(tokens)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,feature
117396,·technique,·technique/20,packet·sniffing·type·of|·technique|·has,917,20,917/20,1.536394
67267,·size,·size/67,·Returns·the·filter·kernel·of|·size|·(,525,67,525/67,1.530362
82757,·size,·size/69,·can·be·a·vector·of|·size|·one,646,69,646/69,1.464729
143900,·shape,·shape/28,:`ndarray`·of|·shape|·`,1124,28,1124/28,1.374878
8054,·culture,·culture/118,am·stressed·the·importance·of|·culture|·in,62,118,62/118,1.318117
6557,·culture,·culture/29,"’s·the·hub·of|·culture|,",51,29,51/29,1.271067
158835,·bike,·bike/115,·smells·of·the·rubber·of|·bike|·tires,1240,115,1240/115,1.250245
105191,·flavor,·flavor/103,·its’·massive·explosion·of|·flavor|!,821,103,821/103,1.24441
82891,·cooperative,·cooperative/75,·management·and·rollout·of|·cooperative|·marketing,647,75,647/75,1.208105
96031,·dimension,·dimension/31,nSparse·histogram·of|·dimension|·2,750,31,750/31,1.205151


Logit weights: which output logits this SAE feature boosts the most

In [20]:
# Direct logit effect of the feature's decoder direction: W_dec row -> MLP out (W_out) -> unembed (W_U).
# NOTE(review): skips whatever normalisation sits before W_U (presumably ln_final), so
# this is a linear approximation.
logit_effect = encoder.W_dec[feature_id] @ model.W_out[0] @ model.W_U
create_vocab_df(logit_effect).head(10).style.background_gradient("coolwarm")

Unnamed: 0,token,logit
3876,lation,1.497491
23030,lessness,1.497241
46767,·awaited,1.401943
2961,·needed,1.365549
45470,reth,1.307851
21571,heid,1.301885
31519,·borne,1.274774
44266,ickness,1.236776
5375,encing,1.221099
1295,ality,1.211532


Direct path and de-embedding

##### Tried using backward hooks, but working out how to set up the intermediate activations for the backward pass needs more debugging time. Using a quick hack for now.

In [21]:
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# NOTE(review): this cell was not executed in the saved run (In [None]), and s,
# token_example, cache and w_enc are all redefined by cells further down. Unlike the
# later example cell, run_with_cache here is not wrapped in torch.no_grad().
s = model.to_string(tokens[1088, :56])
token_example = model.to_tokens(s)
logits, cache = model.run_with_cache(token_example, stop_at_layer=1)
print(s)

w_enc = encoder.W_enc[:, feature_id]

In [None]:
# tried to use backward hooks but need to spend time to debug the issue here. Will use the quick hack for now.

# model.set_use_attn_result(True)
# model.set_use_attn_in(True)
# model.set_use_hook_mlp_in(True)

# filter_not_qkv_input = lambda name: "_input" not in name
# model.reset_hooks()
# cache = {}
# def forward_cache_hook(act, hook):
#     cache[hook.name] = act.detach()
# model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

# grad_cache = {}
# def backward_cache_hook(act, hook):
#     grad_cache[hook.name] = act.detach()
# model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

# logits = model(token_example)

# cache['blocks.0.mlp.hook_post'].requires_grad_(True)
# x_mlpout = cache['blocks.0.mlp.hook_post'][0, -1]
# feature_act = einops.einsum(x_mlpout, w_enc, "hdim, hdim -> ")
# feature_act.backward()

# these are the cached activations of the model
# fwd_cache = ActivationCache(cache, model)

# do the backward pass the on SAE feature activations
# x_mlpout = cache['blocks.0.mlp.hook_post'][0, -1]
# w_enc = encoder.W_enc[:, feature_id]
# feature_act = einops.einsum(x_mlpout, w_enc, "hdim, hdim -> ")
# feature_act.backward()

# bwd_cache = ActivationCache(grad_cache, model)

# model.reset_hooks()

##### Quick hack using a separate MLP model for now

In [21]:
# run this cell just once, don't need to reload the MLP for each feature
class MLP(nn.Module):
    """Stand-alone replica of the layer-0 MLP up to its activation: GELU(x @ W_in + b_in).

    Used to differentiate MLP post-activations with respect to the (normalised) MLP input.

    Args:
        W_in: (d_in, d_out) input weight; defaults to the global model's layer-0 W_in.
        b_in: (d_out,) bias; defaults to the global model's layer-0 b_in.
    """
    def __init__(self, W_in=None, b_in=None):
        super().__init__()
        if W_in is None:
            W_in = model.W_in[0]
        if b_in is None:
            b_in = model.b_in[0]
        d_in, d_out = W_in.shape
        self.fc = nn.Linear(d_in, d_out)
        # Exact (erf) GELU; this must match the model's activation function.
        self.act = torch.nn.GELU()

        # nn.Linear stores its weight as (out, in), hence the transpose. The weights are
        # detached because only gradients with respect to the input are needed.
        self.fc.weight.data = W_in.T.detach()
        self.fc.bias.data = b_in.detach()

    def forward(self, x):
        return self.act(self.fc(x))

# Replica of the layer-0 MLP built from the global model's weights.
mlp_model = MLP()
# NOTE(review): device is hard-coded. The copied weights already live on the model's
# device, so this is presumably a no-op when the model is on cuda:0.
mlp_model.to("cuda:0")

MLP(
  (fc): Linear(in_features=512, out_features=2048, bias=True)
  (act): GELU(approximate='none')
)

In [22]:
# Re-tokenise the top-activating example (row 917, positions 1..20 of all_tokens) so the
# feature's token (" technique", 917/20 in the table above) is the last position;
# to_tokens prepends BOS again. NOTE(review): decode -> re-encode may not always
# reproduce the original ids.
# NOTE(review): with stop_at_layer=1 the first return value is presumably the residual
# stream after block 0, not logits.
with torch.no_grad():
    s = model.to_string(all_tokens[917, 1:21])
    token_example = model.to_tokens(s)
    logits, cache = model.run_with_cache(token_example, stop_at_layer=1)
    print(s)

 in the POPvX. This 'data' file swapping/packet sniffing type of technique


In [23]:
# Inspect the re-tokenised ids; the leading 1 is BOS (cf. the to_str_tokens output in the next cell).
token_example


tensor([[    1,   276,   254, 41261,    88,    58,    16,   826,   684,  2137,
             9,  1819,  1810,  5276,    17, 27581, 30367,   273,  1479,   274,
          5686]], device='cuda:0')

In [24]:
# Tokenisation sanity check: to_tokens prepends BOS (id 1) and splits " aren't" into two tokens.
print(model.to_tokens(" aren't"))
print(model.to_str_tokens(" aren't"))

tensor([[   1, 6231,  625]], device='cuda:0')
['<|BOS|>', ' aren', "'t"]


In [25]:
# Linearise the SAE feature at the last position: take the gradient of the feature
# projection with respect to the MLP input (the ln2-normalised residual stream).

# This is the input x to the MLP layer, as a fresh leaf tensor to differentiate against
# (replaces the deprecated torch.autograd.Variable wrapper).
x_mid = cache['blocks.0.ln2.hook_normalized'][0, -1].detach().clone().requires_grad_(True)
print("Input Shape: ", x_mid.shape)

# Encoder feature direction. Detached: only gradients with respect to x_mid are needed,
# and a live slice of encoder.W_enc would accumulate into W_enc.grad on every re-run.
w_enc = encoder.W_enc[:, feature_id].detach()

x_mlpout = mlp_model(x_mid)
print("MLP Output Shape: ", x_mlpout.shape)

# NOTE: this is only the projection onto w_enc. The SAE's real activation also subtracts
# b_dec, adds b_enc and applies ReLU. Those terms are constant, and the ReLU gate is 1
# while the feature is active, so the gradient with respect to x_mid is unchanged.
feature_act = einops.einsum(x_mlpout, w_enc, "hdim, hdim -> ")
# The same projection of the cached MLP activations: checks that the replica matches the model.
actual_feature_act = einops.einsum(cache['blocks.0.mlp.hook_post'][0, -1], w_enc, "hdim, hdim -> ")
print("Feature Activation: ", feature_act)
print("Actual Feature Activation: ", actual_feature_act)

# Gradient of the projection with respect to the input = linear approximation of the
# feature. torch.autograd.grad leaves no .grad state on mlp_model's parameters.
#! this is the feature vector in MLP input space
(n_mid,) = torch.autograd.grad(feature_act, x_mid)

Input Shape:  torch.Size([512])
MLP Output Shape:  torch.Size([2048])
Feature Activation:  tensor(1.9229, device='cuda:0', grad_fn=<ViewBackward0>)
Actual Feature Activation:  tensor(1.9229, device='cuda:0', grad_fn=<ViewBackward0>)


This is an approximation of how much each token in the model's vocabulary contributes to activating the original SAE feature.

In [26]:
def feature_score_df(score_vec, make_probs=False, full_vocab=None, ascending=False):
    """Return a DataFrame of every vocab token with its score, sorted by score.

    Args:
        score_vec: one score per vocab token (converted with utils.to_numpy).
        make_probs: accepted for signature parity with create_vocab_df; unused.
        full_vocab: pre-processed token strings. They are built from the global `model`
            only when None (previously this argument was silently ignored, so the
            ~d_vocab token strings were rebuilt on every call).
        ascending: sort order; the default puts the highest score first.
    """
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))
    vocab_df = pd.DataFrame({"token": full_vocab, "feature_scores": utils.to_numpy(score_vec)})
    return vocab_df.sort_values("feature_scores", ascending=ascending)

De-embedding token scores for the direct path to the linearized SAE feature

In [27]:
# direct path -> through de-embedding we can analyze how much each token contributes to the feature in the MLP input space
# Each vocab token's embedding is dotted with n_mid. NOTE(review): this ignores attention,
# positional embeddings and ln2 normalisation, so it is only an approximation.
direct_path_scores = einops.einsum(model.W_E, n_mid, "n_vocab dim, dim -> n_vocab")
feature_score_df(direct_path_scores).head(10).style.background_gradient("coolwarm")

Unnamed: 0,token,feature_scores
44535,·pleural,0.330833
31259,·McN,0.330375
16029,�,0.326469
9784,·Administ,0.321786
1666,·associ,0.319722
34144,·Tort,0.318106
17445,·Iss,0.310321
4238,·Mich,0.309697
27281,·ż,0.306936
43179,·noct,0.303841


#### Attention: Analyzing the OV Circuit and the QK Circuit

In [28]:
# Reminder of the example text being analysed.
print(s)

 in the POPvX. This 'data' file swapping/packet sniffing type of technique


In [29]:
str_tokens = model.to_str_tokens(s)

patterns = cache['blocks.0.attn.hook_pattern'][0] # attention pattern for the layer0
print("attention pattern shape: ", patterns.shape)
# One label per head, derived from the pattern itself rather than hard-coding 8 heads.
labels = [str(i) for i in range(patterns.shape[0])]

plot = attention_heads(attention=patterns, tokens=str_tokens, attention_head_names=labels).show_code()

# Display the title
title = "Attention Pattern Analysis"
title_html = f"<h2>{title}</h2><br/>"

# Raw HTML for the visualisation, rendered by the HTML(html) cell below.
html = f"<div style='max-width: 700px;'>{title_html + plot}</div>"

attention pattern shape:  torch.Size([8, 21, 21])


In [30]:
# Token strings, for reading the position indices in the plots below.
print(str_tokens)

['<|BOS|>', ' in', ' the', ' POP', 'v', 'X', '.', ' This', " '", 'data', "'", ' file', ' sw', 'apping', '/', 'packet', ' sniff', 'ing', ' type', ' of', ' technique']


In [31]:
# Render the attention-pattern HTML built two cells above.
HTML(html)

In [32]:
# Direct attribution of each (head, source position) to the linearised feature n_mid at
# destination position j, through each head's OV circuit weighted by its attention pattern.
with torch.no_grad():
    W_OV = einops.einsum(model.W_V[0], model.W_O[0], "head d1 hdim, head hdim d2 -> head d1 d2")
    print("OV Shape: ", W_OV.shape)

    # ln1-normalised residual stream: the input each head's value projection reads.
    x_pre = cache['blocks.0.ln1.hook_normalized'][0]
    print("pre attention scores: ", x_pre.shape)

    # Destination token. It must be the position n_mid was computed at (x_mid used
    # position -1), i.e. the last position (20 for the 21-token example).
    j = x_pre.shape[0] - 1

    headwise_OV_scores = einops.einsum(x_pre, W_OV, "pos dim1, head dim1 dim2 -> head pos dim2")
    print("Headwise OV Scores: ", headwise_OV_scores.shape)

    # NOTE(review): n_mid lives in ln2-normalised space, whereas head outputs are in the raw
    # residual stream. ln2's centering/scale at position j and the value bias are not
    # accounted for, so these are approximate attributions.
    direct_score_pre_attn = einops.einsum(n_mid, headwise_OV_scores, "dim, head pos dim -> head pos")
    direct_score_pst_attn = einops.einsum(direct_score_pre_attn, cache['blocks.0.attn.hook_pattern'][0, :, j], "head pos, head pos -> head pos")
    print("Direct Attribution Scores Shape: ", direct_score_pst_attn.shape)

OV Shape:  torch.Size([8, 512, 512])
pre attention scores:  torch.Size([21, 512])
Headwise OV Scores:  torch.Size([8, 21, 512])
Direct Attribution Scores Shape:  torch.Size([8, 21])


In [33]:
# Heatmap of per-head direct attribution, one column per source position ("token index").
attribution_grid = utils.to_numpy(direct_score_pst_attn.detach().cpu())
position_labels = [f"{tok} {i}" for i, tok in enumerate(str_tokens)]
fig = px.imshow(
    attribution_grid,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Head"},
    title="Direct Attribution Scores",
    x=position_labels,
)

# Saved to a PNG because showing the figure did not render properly in VSCode.
fig.write_image("direct_score_attn_heads.png")

In [172]:
# NOTE(review): the saved output below (" virtual", execution count 172) is stale and out of
# order. For the current example, str_tokens[16] is ' sniff' (see the printed list above).
print(str_tokens[16])

 virtual


Perform de-embedding for the OV circuit

In [34]:
# Pull n_mid back through each head's OV circuit: W_OV[h] @ n_mid is the direction in
# the head's (ln1-normalised) input space that most increases the linearised feature
# when attended to.
headwise_pre_attn_vector = einops.einsum(n_mid, W_OV, "dim2, head dim1 dim2 -> head dim1")
print("Headwise Pre Vector Shape: ", headwise_pre_attn_vector.shape)

# perform de-embedding of the token
# Score every vocab embedding against each head's vector. NOTE(review): ignores ln1
# normalisation and any positional embedding, so this is an approximation.
direct_attn_scores = einops.einsum(model.W_E, headwise_pre_attn_vector, "n_vocab dim, head dim -> head n_vocab")

Headwise Pre Vector Shape:  torch.Size([8, 512])


In [35]:
# Top vocab tokens by contribution through head 0's OV circuit.
feature_score_df(direct_attn_scores[0], ascending=False ).head(10).style.background_gradient("coolwarm")

Unnamed: 0,token,feature_scores
274,·of,0.589367
4539,·Of,0.329345
4387,Of,0.269217
1153,of,0.2581
40537,·στο,0.16445
1107,frac,0.15697
12970,·nas,0.155688
17184,·το,0.154302
7496,·san,0.145928
3368,·OF,0.145797
