# 001 - Early Exiting

We want to use SAEs features to help perform early exiting. The idea is to merge the knowledge on the task (e.g. python programming) with interpretable features.

### Setup

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install circuitsvis sae-lens transformer-lens einsum
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

In [2]:
import torch 
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

### Loading SAEs

In [4]:
from huggingface_hub import snapshot_download

REPO_ID = "jbloom/GPT2-Small-SAEs-Reformatted"
path = snapshot_download(repo_id=REPO_ID)

Fetching 41 files:   0%|          | 0/41 [00:00<?, ?it/s]

In [5]:
from sae_lens import LMSparseAutoencoderSessionloader
from huggingface_hub import snapshot_download
import os

layer = 8
SUBFOLDER = f"blocks.{layer}.hook_resid_pre"

model, sae_group, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    path = os.path.join(path, SUBFOLDER), device=device
)
sae_group.eval()
sae = sae_group[f'blocks.{layer}.hook_resid_pre']

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Testing GPT2 on Python and HTML

In [6]:
from datasets import load_dataset
from transformer_lens import utils

# Load and prepare the dataset
dataset = load_dataset('NeelNanda/c4-code-20k')['train']
tokens = utils.tokenize_and_concatenate(dataset, model.tokenizer, max_length=512)['tokens']

### Reconstruction test

We first test how the SAE handles reconstruction of python code.

In [17]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out

def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)

def reconstruction_test(batch_tokens, component, layer, sae_out):

    original = model(batch_tokens, return_type="loss").item()
    reconstruction = model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                utils.get_act_name(component, layer),
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item()
    zero = model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(utils.get_act_name(component, layer), zero_abl_hook)],
    ).item()

    print("Original:", original)
    print("Reconstruction:", reconstruction)
    print("Zero Ablation:", zero)
    print("Ratio:", (reconstruction - zero) / (original - zero))

In [18]:
batch_tokens = tokens[:16]

with torch.no_grad():
    logits, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae(
        cache[sae_group.cfg.hook_point]
    )

reconstruction_test(batch_tokens, 'resid_pre', 8, sae_out)

Original: 2.0879335403442383
Reconstruction: 2.8737123012542725
Zero Ablation: 11.773224830627441
Ratio: 0.9188688561490794


In [155]:
from transformer_lens import utils

example_prompt = "import pandas as pd\nimport numpy as"
example_answer = " np"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'import', ' pand', 'as', ' as', ' p', 'd', '\n', 'import', ' n', 'umpy', ' as']
Tokenized answer: [' np']


Top 0th token. Logit: 29.40 Prob: 99.00% Token: | np|
Top 1th token. Logit: 24.71 Prob:  0.91% Token: | p|
Top 2th token. Logit: 21.96 Prob:  0.06% Token: | n|
Top 3th token. Logit: 19.13 Prob:  0.00% Token: | u|
Top 4th token. Logit: 18.94 Prob:  0.00% Token: | m|
Top 5th token. Logit: 18.85 Prob:  0.00% Token: | pl|
Top 6th token. Logit: 18.81 Prob:  0.00% Token: | pi|
Top 7th token. Logit: 18.60 Prob:  0.00% Token: | r|
Top 8th token. Logit: 17.95 Prob:  0.00% Token: | pc|
Top 9th token. Logit: 17.69 Prob:  0.00% Token: | mm|


In [None]:
for l in range(12):
    SUBFOLDER = f"blocks.{l}.hook_resid_pre"
    print(f"\n\nLAYER {l}")

    model, sae_group, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
        path = os.path.join(path, SUBFOLDER), device=device
    )
    sae_group.eval()
    sae = sae_group[f'blocks.{l}.hook_resid_pre']

    hook_point = sae_group.cfg.hook_point

    logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae(
        cache[sae_group.cfg.hook_point]
    )

    with model.hooks(
        fwd_hooks=[
            (
                hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ]
    ): utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

**Results**

start: 99.0%
* L0: 98.8% *
* L1: 97.1% *
* L2: 52.8% *
* L3: 47.3% *
* L4: 77.0% *
* L5: 24.5%
* L6: 9.0%
* L7: 13.2%
* L8: 2.3%
* L9: 2.9%
* L10: 0.5%
* L11: 3.5%

The first objective is to find the features related to Python. We can use lienar probing for that.

In [None]:
reconstruction_test(code_tokens[:1], 'resid', 8, sae_out)

In [None]:
import circuitsvis as cv  # optional dep, install with pip install circuitsvis

# Let's make a longer prompt and see the log probabilities of the tokens
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    model(example_prompt)[0].log_softmax(dim=-1),
    model.to_string,
)
# hover on the output to see the result.

In [None]:
generation = model.generate(
    "import pandas as pd\n",
    stop_at_eos=True,
    temperature=0,
    verbose=True,
    max_new_tokens=32,
)
logits, cache = model.run_with_cache(generation)
cv.logits.token_log_probs(
    model.to_tokens(generation),
    model(generation)[0].log_softmax(dim=-1),
    model.to_string,
)

  0%|          | 0/32 [00:00<?, ?it/s]

### Probing

In [None]:
code_features = []
wiki_features = []
batch_size = 4

for b in tqdm(range(0, 100, batch_size)):
    # Code
    with torch.no_grad():
        _, cache = model.run_with_cache(code_tokens[b:b+batch_size])

        sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae(
            cache[sae.cfg.hook_point]
        )

    code_features.append(feature_acts.cpu())
    del cache

    # Wiki
    with torch.no_grad():
        _, cache = model.run_with_cache(wiki_tokens[b:b+batch_size])

        sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae(
            cache[sae.cfg.hook_point]
        )

    wiki_features.append(feature_acts.cpu())
    del cache

code_features = torch.cat(code_features, dim=0)
wiki_features = torch.cat(wiki_features, dim=0)

100%|██████████| 25/25 [00:31<00:00,  1.24s/it]


In [None]:
import numpy as np
idx = torch.tensor(np.random.choice(range(1024), 128, replace=False))

In [None]:
n_features = code_features.shape[-1]

X = torch.cat([
    code_features[:, idx], wiki_features[:, idx]
], dim=0).reshape(-1, n_features)

y = torch.cat([
    torch.ones(X.shape[0] // 2), torch.zeros(X.shape[0] // 2)
])

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.5, random_state=42, stratify=y)

f1_scores = {}

# Iterate over each variable
for i in tqdm(range(X.shape[1])):
    # Reset model parameters
    lr = LogisticRegression(penalty=None)

    lr.fit(X_train[:, i, None], y_train)

    # Evaluate the lr on the validation set
    y_pred = lr.predict(X_val[:, i, None])
    f1 = f1_score(y_val, y_pred)

    # Store the F1 score for the current variable
    f1_scores[i] = f1

100%|██████████| 24576/24576 [06:27<00:00, 63.38it/s]


In [None]:
# Sort the variables based on F1 score in descending order
sorted_vars = sorted(f1_scores.items(), key=lambda x: x[1], reverse=True)

# Get the top variable index and its F1 score
top_var_idx, top_var_f1 = sorted_vars[0]

print(f"The top variable is {top_var_idx} with an F1 score of {top_var_f1:.4f}")

The top variable is 21130 with an F1 score of 0.7168


In [None]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

get_neuronpedia_quick_list([21130], layer=8)

'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%228-res-jb%22%2C%20%22index%22%3A%20%2221130%22%7D%5D'

## Fine-tuning

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, lambda_=1e-5):
        super(SparseAutoencoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lambda_ = lambda_

        # Initialize weights and biases as nn.Parameters
        self.W_enc = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(hidden_size))
        self.W_dec = nn.Parameter(torch.randn(hidden_size, input_size) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(input_size))

    def forward(self, x):
        encoded = F.linear(x, self.W_enc.T, self.b_enc)
        encoded = F.relu(encoded)
        decoded = F.linear(encoded, self.W_dec.T, self.b_dec)
        return encoded, decoded

    def loss_function(self, x):
        encoded, decoded = self.forward(x)
        l2_loss = F.mse_loss(decoded, x, reduction='sum')
        l1_loss = self.lambda_ * torch.sum(torch.abs(encoded))
        total_loss = l2_loss + l1_loss
        return total_loss, l2_loss, l1_loss

    def from_hooked(self, sae):
        assert sae.W_enc.shape == self.W_enc.shape
        assert sae.b_enc.shape == self.b_enc.shape
        assert sae.W_dec.shape == self.W_dec.shape
        assert sae.b_dec.shape == self.b_dec.shape
        
        self.W_enc.data = sae.W_enc.data.clone()
        self.b_enc.data = sae.b_enc.data.clone()
        self.W_dec.data = sae.W_dec.data.clone()
        self.b_dec.data = sae.b_dec.data.clone()

### Training

In [12]:
torch_sae = SparseAutoencoder(
    input_size=sae.d_in,
    hidden_size=sae.d_sae,
    lambda_=sae.cfg.l1_coefficient
)

torch_sae.from_hooked(sae)

In [13]:
import torch
from transformers import AutoModel
from peft import LoraConfig, PeftModel
import wandb
from tqdm import tqdm

num_epochs = 100

# Set up the LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["encoder", "decoder"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Create the LoRA model
lora_model = PeftModel(torch_sae, lora_config)

In [14]:
# Initialize wandb
wandb.init(project="sae-fine-tuning")

num_epochs = 1

# Set up the training loop
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=3e-4)
for epoch in range(num_epochs):
    for batch in tqdm(tokens):
        inputs = batch['tokens']

        with torch.no_grad():
            _, cache = model.run_with_cache(inputs)

        act = cache['blocks.8.hook_resid_pre'].reshape(-1, sae.d_in) # bs * pos, d_in
        del cache
        
        loss, l2_loss, l1_loss = lora_model.loss_function(act)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log losses to wandb
        wandb.log({"loss": loss.item(), "mse_loss": l2_loss.item(), "l1_loss": l1_loss.item()})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavide-ghilardi0[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 80770/80770 [37:20<00:00, 36.04it/s]


### Evaluation

In [10]:
lora_sae = SparseAutoencoder(sae.d_in, sae.d_sae)
lora_sae.from_hooked(sae)
state_dict = torch.load('../models/lora_model.pth')

# Adjust LoRA weights
lora_encoder = torch.matmul(
    state_dict["base_model.model.encoder.lora_B.default.weight"],
    state_dict["base_model.model.encoder.lora_A.default.weight"]
)

lora_decoder = torch.matmul(
    state_dict["base_model.model.decoder.lora_B.default.weight"],
    state_dict["base_model.model.decoder.lora_A.default.weight"]
)

lora_sae.W_enc += lora_encoder.T
lora_sae.W_dec += lora_decoder.T

In [19]:
batch_tokens = tokens[:16]

with torch.no_grad():
    logits, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    encoded, decoded = lora_sae(
        cache[sae_group.cfg.hook_point]
    )

reconstruction_test(batch_tokens, 'resid_pre', 8, decoded)

Original: 2.0879335403442383
Reconstruction: 2.3907034397125244
Zero Ablation: 11.773224830627441
Ratio: 0.968739205637311


In [148]:
from transformer_lens import utils

example_prompt = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>My Webpage</title>
    <style>
        .container {
            max-width: 800px;
            margin: 20px auto;
            padding: 0 20px;
        }
    </style>
</head>
<body>
    <header>
        <h1>Welcome to My Webpage</h1>
    </header>
    <div class="container">
        <section>
            <h2>About Us</h2>
            <p>This is a sample webpage.</p>
        </section>
        <section>
            <h2>Contact Us</h2>
            <p>You can reach us at example@email.com</p>
        </section>
    </"""
example_answer = "num"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '\n', '<', '!', 'DO', 'CT', 'Y', 'PE', ' html', '>', '\n', '<', 'html', ' lang', '="', 'en', '">', '\n', '<', 'head', '>', '\n', ' ', ' ', ' ', ' <', 'meta', ' chars', 'et', '="', 'UTF', '-', '8', '">', '\n', ' ', ' ', ' ', ' <', 'meta', ' name', '="', 'view', 'port', '"', ' content', '="', 'width', '=', 'device', '-', 'width', ',', ' initial', '-', 'scale', '=', '1', '.', '0', '">', '\n', ' ', ' ', ' ', ' <', 'title', '>', 'My', ' Web', 'page', '</', 'title', '>', '\n', ' ', ' ', ' ', ' <', 'style', '>', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' .', 'container', ' {', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' max', '-', 'width', ':', ' 800', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' margin', ':', ' 20', 'px', ' auto', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' padding', ':', ' 0', ' 20', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' }', '\n', ' ', ' ', ' ', ' </',

Top 0th token. Logit: 19.97 Prob: 54.19% Token: |body|
Top 1th token. Logit: 19.50 Prob: 34.18% Token: |div|
Top 2th token. Logit: 17.37 Prob:  4.06% Token: |section|
Top 3th token. Logit: 15.92 Prob:  0.95% Token: |html|
Top 4th token. Logit: 15.74 Prob:  0.79% Token: |foot|
Top 5th token. Logit: 15.70 Prob:  0.76% Token: |p|
Top 6th token. Logit: 15.67 Prob:  0.74% Token: |header|
Top 7th token. Logit: 15.30 Prob:  0.51% Token: |block|
Top 8th token. Logit: 14.18 Prob:  0.17% Token: |head|
Top 9th token. Logit: 14.07 Prob:  0.15% Token: | body|


In [149]:
SUBFOLDER = f"blocks.8.hook_resid_pre"

hook_point = sae_group.cfg.hook_point

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae(
    cache[sae_group.cfg.hook_point]
)

with model.hooks(
    fwd_hooks=[
        (
            hook_point,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
): utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '\n', '<', '!', 'DO', 'CT', 'Y', 'PE', ' html', '>', '\n', '<', 'html', ' lang', '="', 'en', '">', '\n', '<', 'head', '>', '\n', ' ', ' ', ' ', ' <', 'meta', ' chars', 'et', '="', 'UTF', '-', '8', '">', '\n', ' ', ' ', ' ', ' <', 'meta', ' name', '="', 'view', 'port', '"', ' content', '="', 'width', '=', 'device', '-', 'width', ',', ' initial', '-', 'scale', '=', '1', '.', '0', '">', '\n', ' ', ' ', ' ', ' <', 'title', '>', 'My', ' Web', 'page', '</', 'title', '>', '\n', ' ', ' ', ' ', ' <', 'style', '>', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' .', 'container', ' {', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' max', '-', 'width', ':', ' 800', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' margin', ':', ' 20', 'px', ' auto', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' padding', ':', ' 0', ' 20', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' }', '\n', ' ', ' ', ' ', ' </',

Top 0th token. Logit: 23.53 Prob: 43.35% Token: |body|
Top 1th token. Logit: 22.65 Prob: 17.97% Token: |section|
Top 2th token. Logit: 22.44 Prob: 14.60% Token: |h|
Top 3th token. Logit: 22.29 Prob: 12.61% Token: |p|
Top 4th token. Logit: 21.31 Prob:  4.73% Token: |div|
Top 5th token. Logit: 20.34 Prob:  1.79% Token: |span|
Top 6th token. Logit: 19.29 Prob:  0.63% Token: |li|
Top 7th token. Logit: 18.96 Prob:  0.45% Token: |style|
Top 8th token. Logit: 18.80 Prob:  0.38% Token: |head|
Top 9th token. Logit: 18.71 Prob:  0.35% Token: |header|


In [150]:
SUBFOLDER = f"blocks.8.hook_resid_pre"

hook_point = sae_group.cfg.hook_point

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
encoded, decoded = lora_sae(
    cache[sae_group.cfg.hook_point]
)

with model.hooks(
    fwd_hooks=[
        (
            hook_point,
            partial(reconstr_hook, sae_out=decoded),
        )
    ]
): utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '\n', '<', '!', 'DO', 'CT', 'Y', 'PE', ' html', '>', '\n', '<', 'html', ' lang', '="', 'en', '">', '\n', '<', 'head', '>', '\n', ' ', ' ', ' ', ' <', 'meta', ' chars', 'et', '="', 'UTF', '-', '8', '">', '\n', ' ', ' ', ' ', ' <', 'meta', ' name', '="', 'view', 'port', '"', ' content', '="', 'width', '=', 'device', '-', 'width', ',', ' initial', '-', 'scale', '=', '1', '.', '0', '">', '\n', ' ', ' ', ' ', ' <', 'title', '>', 'My', ' Web', 'page', '</', 'title', '>', '\n', ' ', ' ', ' ', ' <', 'style', '>', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' .', 'container', ' {', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' max', '-', 'width', ':', ' 800', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' margin', ':', ' 20', 'px', ' auto', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' padding', ':', ' 0', ' 20', 'px', ';', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' }', '\n', ' ', ' ', ' ', ' </',

Top 0th token. Logit: 22.64 Prob: 71.55% Token: |body|
Top 1th token. Logit: 20.92 Prob: 12.81% Token: |div|
Top 2th token. Logit: 20.84 Prob: 11.90% Token: |section|
Top 3th token. Logit: 18.51 Prob:  1.15% Token: |p|
Top 4th token. Logit: 17.62 Prob:  0.47% Token: |header|
Top 5th token. Logit: 17.53 Prob:  0.43% Token: |html|
Top 6th token. Logit: 17.09 Prob:  0.28% Token: |span|
Top 7th token. Logit: 16.45 Prob:  0.15% Token: |block|
Top 8th token. Logit: 16.28 Prob:  0.12% Token: |h|
Top 9th token. Logit: 15.94 Prob:  0.09% Token: | body|


In [87]:
from fancy_einsum import einsum

token = model.to_tokens(' np', prepend_bos=False)[0]

base_sim = einsum(
    'i j, j k -> i k',
    sae.W_dec, model.W_U[:, token]
)

lora_sim = einsum(
    'i j, j k -> i k',
    lora_sae.W_dec, model.W_U[:, token]
)

In [88]:
top_base_features = base_sim.argsort(dim=0, descending=True)[:10]
bot_base_features = base_sim.argsort(dim=0, descending=False)[:10]

top_lora_features = lora_sim.argsort(dim=0, descending=True)[:10]
bot_lora_features = lora_sim.argsort(dim=0, descending=False)[:10]

In [133]:
top_lora_features

tensor([[21963],
        [ 6804],
        [ 8535],
        [13321],
        [10244],
        [ 1233],
        [ 6965],
        [   67],
        [14759],
        [13384]], device='cuda:0')

In [134]:
top_base_features

tensor([[21963],
        [ 6804],
        [ 8535],
        [13321],
        [14615],
        [14759],
        [  722],
        [ 1233],
        [13384],
        [10244]], device='cuda:0')

In [139]:
tokens['tokens'].shape, all_tokens.shape

(torch.Size([80770, 512]), torch.Size([24576, 128]))

In [29]:
max_delta = (lora_sae.W_dec - sae.W_dec).norm(dim=-1).argsort(descending=True)[:10]
max_delta

tensor([ 4853, 19120, 10236, 11945, 21030, 22418,  3030, 17487, 18573, 11640],
       device='cuda:0')

In [31]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

hook_point = sae_group.cfg.hook_point
test_feature_idx_gpt = max_delta.cpu().tolist()

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    batch_size=2048,
    minibatch_size_tokens=128,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model,
    tokens=tokens[:1000],  # type: ignore
    cfg=feature_vis_config_gpt,
)

Forward passes to cache data for vis:   0%|          | 0/8 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/10 [00:00<?, ?it/s]

In [32]:
filename = f"lora_sae_features.html"
sae_vis_data_gpt.save_feature_centric_vis(filename)

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

## Eearly exiting 

In [6]:
from tqdm import tqdm
from sae_lens.training.activations_store import ActivationsStore

def get_tokens(
    activation_store: ActivationsStore,
    n_batches_to_sample_from: int = 2**10,
    n_prompts_to_select: int = 4096 * 6,
):
    all_tokens_list = []
    pbar = tqdm(range(n_batches_to_sample_from))
    for _ in pbar:
        batch_tokens = activation_store.get_batch_tokens()
        batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][
            : batch_tokens.shape[0]
        ]
        all_tokens_list.append(batch_tokens)

    all_tokens = torch.cat(all_tokens_list, dim=0)
    all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
    return all_tokens[:n_prompts_to_select]

all_tokens = get_tokens(activation_store)

100%|██████████| 1024/1024 [00:18<00:00, 54.77it/s]


In [10]:
sub_tokens = all_tokens[:1024]
batch_size = 16
activations = []

for b in tqdm(range(0, len(sub_tokens), batch_size)):
    with torch.no_grad():
        _, cache = model.run_with_cache(sub_tokens[b:b+batch_size, :-1])
    
    activations.append(cache[sae_group.cfg.hook_point].cpu())
    del cache

activations = torch.cat(activations).reshape(-1, 768)
labels = sub_tokens[:, 1:].reshape(-1, 1)

100%|██████████| 64/64 [00:02<00:00, 29.10it/s]


In [20]:
token_pred = []
sae_bs = 4096

for b in range(0, len(activations), sae_bs):
    sae_out, features_act, *_ = sae(activations[b:b+sae_bs].to('cuda'))
    token_pred.append(model.unembed(sae_out[:, None]).argmax(-1))

token_pred = torch.cat(token_pred)

In [64]:
train_labels = (token_pred == labels).type(torch.int64)[:, 0]

In [59]:
import torch
import torch.nn as nn

# Define the logistic regression model
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
        
    def forward(self, x):
        out = self.linear(x)
        return out

In [60]:
input_size = sae.d_sae
num_classes = 2

# Create an instance of the logistic regression model
lr = LogisticRegression(input_size, num_classes).to('cuda')

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(lr.parameters(), lr=0.001)

In [62]:
# Training loop
num_epochs = 10
sae_batch_size = 4096

for epoch in range(num_epochs):
    for b in range(0, len(activations), sae_batch_size):
        # SAE run
        with torch.no_grad():
            sae_out, features_act, *_ = sae(activations[b:b+sae_bs].to('cuda'))
        
        # Forward pass
        outputs = lr(features_act).softmax(-1)
        loss = criterion(outputs, train_labels[b:b+sae_bs])
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print the loss every 10 epochs
        if (epoch+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn