In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from typing import Iterable

# HookedTransformer

* [TransformerLens - Tutorial - Trains HookedTransformer from Scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

# note: it's probably easier to just operate on tokens outside of the model,
#       that'll also make it clearer where tokenizer is used
#
# okay wrapping a pretrained tokenizer *can* be done:
# - https://huggingface.co/learn/nlp-course/chapter6/8#building-a-bpe-tokenizer-from-scratch
# - but none of the models support just naive encoding
#   - https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.BPE
class HookedTransformer:
    cfg: HookedTransformerConfig

    # note: actually does an `isinstance` check in the constructor
    tokenizer: transformers.PreTrainedTokenizerBase | None
```

In [3]:
import transformer_lens

from jaxtyping import Int64, Float32

import numpy as np
import plotly.express as px
import plotly.io as pio

import torch
import torch.utils.data

In [4]:
# plotting code copied over from transformer_lens tutorial notebook

def line(tensor: torch.Tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor: torch.Tensor, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

In [5]:
device = transformer_lens.utils.get_device()

print(f'Using device: {device}')

Using device: mps


### Setup Sample Generator

In [6]:
import string
import itertools
import more_itertools

class SpecialToken:
    # note: as assume a BOS token because transformerlens expects it
    BOS = '<'
    # we use a EOS token for convenience
    EOS = '>'

# TODO(bschoen): Allow this to generalize in the future
def generate_sample() -> Iterable[str]:
    """Generate palindrom samples like `<abc|cba>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    length = 3
    
    for combination in itertools.product(characters, repeat=length):

        combination_str = ''.join(combination)
        reversed_str = ''.join(reversed(combination_str))

        yield SpecialToken.BOS + combination_str + '|' + reversed_str + SpecialToken.EOS

# show a few examples
[x for x in more_itertools.take(10, generate_sample())]

['<aaa|aaa>',
 '<aab|baa>',
 '<aac|caa>',
 '<aad|daa>',
 '<aae|eaa>',
 '<aaf|faa>',
 '<aag|gaa>',
 '<aah|haa>',
 '<aai|iaa>',
 '<aaj|jaa>']

### Setup Tokenizer

In [7]:
from gpt_from_scratch.naive_tokenizer import NaiveTokenizer

vocab = string.ascii_lowercase + '|' + SpecialToken.BOS + SpecialToken.EOS

tokenizer = NaiveTokenizer.from_text(vocab)

In [8]:
from gpt_from_scratch import tokenizer_utils

# test tokenizer
input_text = '<abc|cba><bdd|ddb>'
tokenizer_utils.show_token_mapping(tokenizer, input_text)

Input:		<abc|cba><bdd|ddb>
Tokenized:	[44m[97m<[0m[45m[97ma[0m[41m[97mb[0m[46m[97mc[0m[43m[97m|[0m[42m[97mc[0m[44m[97mb[0m[45m[97ma[0m[41m[97m>[0m[46m[97m<[0m[43m[97mb[0m[42m[97md[0m[44m[97md[0m[45m[97m|[0m[41m[97md[0m[46m[97md[0m[43m[97mb[0m[42m[97m>[0m
Token ID | Token Bytes | Token String
---------+-------------+--------------
       0 | [38;5;2m3C[0m | '<'
          [48;5;1m[38;5;15m<[0mabc|cba><bdd|ddb>
          U+003C LESS-THAN SIGN (1 bytes: [38;5;2m3C[0m)
       2 | [38;5;2m61[0m | 'a'
          <[48;5;1m[38;5;15ma[0mbc|cba><bdd|ddb>
          U+0061 LATIN SMALL LETTER A (1 bytes: [38;5;2m61[0m)
       3 | [38;5;2m62[0m | 'b'
          <a[48;5;1m[38;5;15mb[0mc|cba><bdd|ddb>
          U+0062 LATIN SMALL LETTER B (1 bytes: [38;5;2m62[0m)
       4 | [38;5;2m63[0m | 'c'
          <ab[48;5;1m[38;5;15mc[0m|cba><bdd|ddb>
          U+0063 LATIN SMALL LETTER C (1 bytes: [38;5;2m63[0m)
      28 | [38;5;2m

### Setup Model

In [23]:
# now we know our vocab size from our sample generation

cfg = transformer_lens.HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    d_head=4,
    # The number of attention heads.
    # If not specified, will be set to d_in // d_head.
    # (This is represented by a default value of -1)
    n_heads=4,
    # The dimensionality of the feedforward mlp network.
    # Defaults to 4 * d_in, and in an attn-only model is None.
    # d_mlp=16,
    # note: transformerlens does the same thing if this is not set
    d_vocab=len(tokenizer.byte_to_token_dict),
    # length of the first sample is our context length
    n_ctx=len(more_itertools.nth(generate_sample(), 1)),
    act_fn="relu",
    normalization_type="LN",
    # note: must be set, otherwise tries to default to cuda / cpu (not mps)
    device=device.type,
)

print(f'Num params: {cfg.n_params}')

cfg

Num params: 3072


HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 2.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 4,
 'd_mlp': 64,
 'd_model': 16,
 'd_vocab': 29,
 'd_vocab_out': 29,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'mps',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.2,
 'load_in_4bit': False,
 'model_name': 'custom',
 'n_ctx': 9,
 'n_devices': 1,
 'n_heads': 4,
 'n_key_value_heads': None,
 'n_layers': 1,
 'n_params': 3072,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': None,
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rela

### Setup Loss Function

In [12]:
def loss_fn(logits, target):
    # standard cross entropy loss
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
    )

### Evaluate On Test

In [13]:
def evaluate_loss_on_test_batches(
    model: transformer_lens.HookedTransformer,
    data_loader: torch.utils.data.DataLoader,
) -> float:

    # Set the model to evaluation mode
    model.eval()  

    losses = []

    with torch.no_grad():  # Disable gradient computation
        
        for batch in data_loader:
            
            x, y = batch

            x, y = x.to(device), y.to(device)

            logits = model(x)

            loss = loss_fn(logits, y)

            losses.append(loss.item())

    # Set the model back to training mode
    model.train() 

    return sum(losses) / len(losses)

### Setup Data Loaders

In [14]:
class AutoregressiveDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[str], tokenizer: NaiveTokenizer) -> None:
        self.samples = samples
        self.tokenizer = tokenizer  # Assuming tokenizer is defined in the global scope

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens = self.tokenizer.encode(sample)
        
        # Convert to tensor and add batch dimension
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

def make_batch_dataloader(
    samples: list[str],
    tokenizer: NaiveTokenizer,
    batch_size: int,
) -> tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:

    dataset = AutoregressiveDataset(samples=samples, tokenizer=tokenizer)
    
    # Create DataLoader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # drop the last batch if it's incomplete
        drop_last=True,
    )
    
    return dataset, dataloader

# Example usage:
# batch_generator = make_batch_generator(tokenizer, batch_size=4)
# for x, y in batch_generator:
#     # x is input, y is target (x shifted by 1)
#     pass


In [16]:
# split into test and train
all_samples = list(generate_sample())

# note: 4394 batches = (26 * 26 * 26) / 4
print(f'{len(all_samples)} samples')

# max_samples = 10
# print(f'Capping at {max_samples} batches first to make sure we can overfit')
# all_samples = all_samples[:max_samples]

test_train_ratio = 0.1

test_size = int(test_train_ratio * len(all_samples))
    
# put remaining ones into train
train_size = len(all_samples) - test_size

train_samples = all_samples[:train_size]
test_samples = all_samples[train_size:]

print(f'{len(train_samples)=}')
print(f'{len(test_samples)=}')

# now we can finally construct dataloaders
batch_size = 4

train_dataset, train_loader = make_batch_dataloader(
    samples=train_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)
test_dataset, test_loader = make_batch_dataloader(
    samples=test_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)


17576 samples
len(train_samples)=15819
len(test_samples)=1757


### Training

In [24]:
import tqdm

import torch.optim

import wandb

# TODO(bschoen): Try out optuna wrapper for this


# create new model instance
model = transformer_lens.HookedTransformer(cfg)

# setup optimizers
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

num_epochs = 20000

# setup wandb
config = cfg.to_dict()
config.update({
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'lr': 1e-4,
    'max_grad_norm': 1.0,
    'wd': 0.1,
})
wandb.init(
    project="toy-problem-hooked-transformer",
    config=config,
)

losses = []
test_losses = []

for epoch, batch in tqdm.tqdm(
    zip(
        range(num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # ex: torch.Size([4, 9, 29])
    logits: Float32[torch.Tensor, "batch_size n_ctx d_vocab"] = model(tokens)

    # print(f"Logits:\n{logits.shape}")
    loss = loss_fn(logits, target)

    loss.backward()

    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    optimizer.zero_grad()

    scheduler.step()

    losses.append(loss.item())

    # TODO(bschoen): Shouldn't you actually divide loss by batch size?
    if epoch % 500 == 0:
        print('Evaluating test loss')

        test_loss = evaluate_loss_on_test_batches(model, test_loader)

        test_losses.append(test_loss)

        print(f"Epoch {epoch}: Train loss: {loss.item():.6f}, Test loss: {test_loss:.6f}")

        wandb.log({
            'epoch': epoch,
            'train_loss': loss.item(),
            'test_loss': test_loss,
        })

wandb.finish()

# log locally to sanity check
px.line(losses, labels={"x": "Epoch", "y": "Train Loss"})


0it [00:00, ?it/s]

Evaluating test loss


10it [00:01,  9.03it/s]

Epoch 0: Train loss: 3.478746, Test loss: 3.807300


496it [00:06, 91.22it/s]

Evaluating test loss


516it [00:08, 26.58it/s]

Epoch 500: Train loss: 2.992052, Test loss: 3.120214


1000it [00:13, 83.40it/s]

Evaluating test loss


1019it [00:15, 24.90it/s]

Epoch 1000: Train loss: 2.737082, Test loss: 2.881124


1498it [00:20, 91.33it/s]

Evaluating test loss


1517it [00:22, 25.28it/s]

Epoch 1500: Train loss: 2.621378, Test loss: 2.728826


1998it [00:27, 88.97it/s]

Evaluating test loss


2016it [00:29, 25.17it/s]

Epoch 2000: Train loss: 2.517176, Test loss: 2.604682


2493it [00:34, 87.20it/s]

Evaluating test loss


2511it [00:36, 23.07it/s]

Epoch 2500: Train loss: 2.332778, Test loss: 2.555676


2994it [00:41, 87.11it/s]

Evaluating test loss


3012it [00:43, 23.26it/s]

Epoch 3000: Train loss: 2.316447, Test loss: 2.488751


3494it [00:49, 74.91it/s]

Evaluating test loss


3511it [00:50, 22.44it/s]

Epoch 3500: Train loss: 2.192639, Test loss: 2.403550


3997it [00:56, 89.21it/s]

Evaluating test loss


4015it [00:57, 24.30it/s]

Epoch 4000: Train loss: 2.163207, Test loss: 2.343389


4497it [01:03, 90.54it/s]

Evaluating test loss


4516it [01:04, 25.70it/s]

Epoch 4500: Train loss: 2.136429, Test loss: 2.267579


4996it [01:10, 91.23it/s]

Evaluating test loss


5015it [01:11, 24.27it/s]

Epoch 5000: Train loss: 1.937449, Test loss: 2.175023


5495it [01:17, 88.68it/s]

Evaluating test loss


5513it [01:18, 25.03it/s]

Epoch 5500: Train loss: 1.888417, Test loss: 2.091810


5996it [01:24, 90.77it/s]

Evaluating test loss


6015it [01:25, 25.85it/s]

Epoch 6000: Train loss: 1.811999, Test loss: 2.014765


6495it [01:31, 84.28it/s]

Evaluating test loss


6514it [01:32, 25.22it/s]

Epoch 6500: Train loss: 1.724781, Test loss: 1.953917


6997it [01:38, 90.79it/s]

Evaluating test loss


7016it [01:39, 25.19it/s]

Epoch 7000: Train loss: 1.638892, Test loss: 1.873010


7498it [01:45, 89.85it/s]

Evaluating test loss


7516it [01:46, 25.38it/s]

Epoch 7500: Train loss: 1.513364, Test loss: 1.802080


8000it [01:52, 89.67it/s]

Evaluating test loss


8009it [01:53, 18.87it/s]

Epoch 8000: Train loss: 1.492503, Test loss: 1.737086


8492it [01:59, 88.51it/s]

Evaluating test loss


8511it [02:00, 24.72it/s]

Epoch 8500: Train loss: 1.422508, Test loss: 1.682528


9000it [02:06, 85.77it/s]

Evaluating test loss


9009it [02:07, 18.53it/s]

Epoch 9000: Train loss: 1.369791, Test loss: 1.620792


9494it [02:13, 90.04it/s]

Evaluating test loss


9513it [02:14, 25.75it/s]

Epoch 9500: Train loss: 1.340623, Test loss: 1.600983


9997it [02:20, 83.57it/s]

Evaluating test loss


10016it [02:21, 25.56it/s]

Epoch 10000: Train loss: 1.336905, Test loss: 1.568584


10496it [02:27, 89.12it/s]

Evaluating test loss


10512it [02:28, 22.64it/s]

Epoch 10500: Train loss: 1.340092, Test loss: 1.574866


11000it [02:34, 90.32it/s]

Evaluating test loss


11010it [02:35, 19.80it/s]

Epoch 11000: Train loss: 1.278600, Test loss: 1.525385


11493it [02:41, 87.92it/s]

Evaluating test loss


11511it [02:42, 24.79it/s]

Epoch 11500: Train loss: 1.250157, Test loss: 1.514312


11999it [02:48, 87.44it/s]

Evaluating test loss


12018it [02:49, 24.99it/s]

Epoch 12000: Train loss: 1.266643, Test loss: 1.517130


12495it [02:54, 93.88it/s]

Evaluating test loss


12514it [02:56, 26.33it/s]

Epoch 12500: Train loss: 1.310005, Test loss: 1.494483


13000it [03:01, 88.90it/s]

Evaluating test loss


13018it [03:03, 24.44it/s]

Epoch 13000: Train loss: 1.243109, Test loss: 1.511725


13496it [03:08, 88.26it/s]

Evaluating test loss


13512it [03:10, 22.55it/s]

Epoch 13500: Train loss: 1.241693, Test loss: 1.498240


13991it [03:15, 91.67it/s]

Evaluating test loss


14010it [03:17, 25.22it/s]

Epoch 14000: Train loss: 1.255840, Test loss: 1.507804


14492it [03:22, 90.23it/s]

Evaluating test loss


14512it [03:24, 26.03it/s]

Epoch 14500: Train loss: 1.240428, Test loss: 1.489433


14997it [03:29, 87.11it/s]

Evaluating test loss


15015it [03:31, 23.94it/s]

Epoch 15000: Train loss: 1.232544, Test loss: 1.473856


15500it [03:36, 89.29it/s]

Evaluating test loss


15518it [03:38, 25.22it/s]

Epoch 15500: Train loss: 1.208100, Test loss: 1.468849


15995it [03:43, 86.93it/s]

Evaluating test loss


16013it [03:45, 24.57it/s]

Epoch 16000: Train loss: 1.231310, Test loss: 1.498882


16495it [03:50, 89.60it/s]

Evaluating test loss


16512it [03:52, 22.97it/s]

Epoch 16500: Train loss: 1.266035, Test loss: 1.472799


16991it [03:57, 90.16it/s]

Evaluating test loss


17011it [03:59, 26.22it/s]

Epoch 17000: Train loss: 1.241320, Test loss: 1.499916


17498it [04:04, 89.40it/s]

Evaluating test loss


17516it [04:06, 25.28it/s]

Epoch 17500: Train loss: 1.231346, Test loss: 1.476767


17997it [04:12, 92.61it/s]

Evaluating test loss


18017it [04:13, 26.57it/s]

Epoch 18000: Train loss: 1.218808, Test loss: 1.486324


18493it [04:18, 92.53it/s]

Evaluating test loss


18512it [04:20, 25.86it/s]

Epoch 18500: Train loss: 1.220689, Test loss: 1.460379


18995it [04:25, 88.33it/s]

Evaluating test loss


19013it [04:27, 24.32it/s]

Epoch 19000: Train loss: 1.242467, Test loss: 1.457452


19496it [04:32, 79.27it/s]

Evaluating test loss


19514it [04:34, 23.74it/s]

Epoch 19500: Train loss: 1.266019, Test loss: 1.464777


20000it [04:40, 71.39it/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_loss,█▆▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▆▆▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19500.0
test_loss,1.46478
train_loss,1.26602


In [82]:
# Look at some example output
import circuitsvis as cv

# create a custom to_string function since using our own tokenizer
def token_to_string(token: int) -> str:
    return tokenizer.decode([token])

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

example_sample = x[0]

# grab the first part of it, ex: `<abc|`
example_prompt = example_sample # [:8]

example_prompt = example_prompt.to(device)

print(f'Using {example_prompt} from {example_sample} (from test set)')

# note: already encoded
input_tokens = example_prompt

logits_batch, cache = model.run_with_cache(input_tokens)

logits = logits_batch[0]

log_probs = logits.log_softmax(dim=-1)

cv.logits.token_log_probs(
    token_indices=input_tokens,
    log_probs=log_probs,
    to_string=token_to_string,
)
 

Using tensor([ 0, 25, 16, 20, 28, 20, 16, 25], device='mps:0') from tensor([ 0, 25, 16, 20, 28, 20, 16, 25]) (from test set)


### Looking at it with CircuitsViz

In [83]:
# before even going to SAE, let's look at circuitsviz here
import circuitsvis as cv

import circuitsvis.activations
import circuitsvis.attention
import circuitsvis.logits
import circuitsvis.tokens
import circuitsvis.topk_samples
import circuitsvis.topk_tokens

In [84]:
# first let's see what we have
import tabulate

print(f'{len(input_tokens)=}')

# show the first few elements of the `HookedTransformerConfig`, since that has things like `d_model`, num heads, etc
print(tabulate.tabulate([(k, v) for k, v in cfg.__dict__.items()][:10]))

print(tabulate.tabulate([(k, v.shape) for k, v in cache.items()]))

len(input_tokens)=8
----------  ------
n_layers    1
d_model     16
n_ctx       9
d_head      4
model_name  custom
n_heads     4
d_mlp       64
act_fn      relu
d_vocab     29
eps         1e-05
----------  ------
------------------------------  ------------------------
hook_embed                      torch.Size([1, 8, 16])
hook_pos_embed                  torch.Size([1, 8, 16])
blocks.0.hook_resid_pre         torch.Size([1, 8, 16])
blocks.0.ln1.hook_scale         torch.Size([1, 8, 1])
blocks.0.ln1.hook_normalized    torch.Size([1, 8, 16])
blocks.0.attn.hook_q            torch.Size([1, 8, 4, 4])
blocks.0.attn.hook_k            torch.Size([1, 8, 4, 4])
blocks.0.attn.hook_v            torch.Size([1, 8, 4, 4])
blocks.0.attn.hook_attn_scores  torch.Size([1, 4, 8, 8])
blocks.0.attn.hook_pattern      torch.Size([1, 4, 8, 8])
blocks.0.attn.hook_z            torch.Size([1, 8, 4, 4])
blocks.0.hook_attn_out          torch.Size([1, 8, 16])
blocks.0.hook_resid_mid         torch.Size([1, 8, 16])
bloc

#### circuitsvis.activations

In [85]:
# tokens := List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`)
# activations := Activations of the shape [tokens x layers x neurons] if single sample or list of [tokens x layers x neurons] if multiple samples

# take first batch for now
activations = cache['blocks.0.hook_mlp_out'][0]
print(f'{activations.shape=}')

# reshape [tokens x neurons] -> [tokens x 1 x neurons]
#  - `-1` means to automatically infer the size of the last dimension
activations_view = activations.view(len(input_tokens), cfg.n_layers, -1)

print(f'{activations_view.shape=}')

# convert to strings (which this function expects)
input_tokens_as_strings = [token_to_string(x.item()) for x in input_tokens]

# TODO(bschoen): Is there a way to essentially stack these? Claude can probably give the React for that

# so here we can visualize activations for a `torch.Size([1, 8, 16])`, which is most
# of them since this is the size of the embedding dimension
circuitsvis.activations.text_neuron_activations(
    tokens=[token_to_string(x.item()) for x in input_tokens],
    activations=activations_view,
)

activations.shape=torch.Size([8, 16])
activations_view.shape=torch.Size([8, 1, 16])


#### circuitsvis.attention

In [44]:
# note `attention_pattern` and `attention_patterns` are deprecated in favor of `attention_heads`
circuitsvis.attention.attention_heads?

[0;31mSignature:[0m
[0mcircuitsvis[0m[0;34m.[0m[0mattention[0m[0;34m.[0m[0mattention_heads[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mattention[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mlist[0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mndarray[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtokens[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mattention_head_names[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_value[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_value[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnegative_color[0m

In [64]:
# tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length as the list of values.
# attention: Attention head activations of the shape [dest_tokens x src_tokens]
# max_value: Maximum value. Used to determine how dark the token color is when positive (i.e. based on how close it is to the maximum value).
# min_value: Minimum value. Used to determine how dark the token color is when negative (i.e. based on how close it is to the minimum value).
# negative_color: Color for negative values
# positive_color: Color for positive values.
#show_axis_labels: Whether to show axis labels.
# mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention.



# take first batch
# ex: torch.Size([4, 8, 8]) -> [n_heads, n_ctx, n_ctx]
# note: `blocks.0.attn.hook_attn_scores` is too early (not normalized?)
attention = cache['blocks.0.attn.hook_pattern'][0]

print(f'{attention.shape=}')

circuitsvis.attention.attention_heads(
    tokens=input_tokens_as_strings,
    attention=attention,
    max_value=1,
    min_value=-1,
    negative_color='blue',
    positive_color='red',
    mask_upper_tri=True,
)

attention.shape=torch.Size([4, 8, 8])


#### circuitsvis.logits

In [66]:
# this is the normal one we usually show, i.e.
# cv.logits.token_log_probs(
#     token_indices=input_tokens,
#     log_probs=log_probs,
#     to_string=token_to_string,
# )

#### circuitsvis.tokens

In [94]:
# for example, we'll look at each 

# take first batch, ex: torch.Size([8, 16])
pos_embed = cache['hook_pos_embed'][0]

# low level function for coloring tokens according to single value
for i in range(cfg.d_model):
    display(circuitsvis.tokens.colored_tokens(
        tokens=input_tokens_as_strings,
        values=pos_embed[:, i],
        negative_color='blue',
        positive_color='red',
    ))

    # only display a few for example
    # if i >= 2:
    #    break


In [95]:
# take first batch
# ex: torch.size([8, 16]) = [n_ctx, d_model]
attention_out = cache['blocks.0.hook_attn_out'][0]

circuitsvis.tokens.colored_tokens_multi(
    tokens=input_tokens_as_strings,
    values=attention_out,
    labels=[str(x) for x in range(cfg.d_model)],
)

In [97]:
circuitsvis.tokens.visualize_model_performance(
    tokens=input_tokens,
    str_tokens=input_tokens_as_strings,
    logits=logits,
)

#### circuitsvis.topk_samples

In [98]:
circuitsvis.topk_samples.topk_samples??

[0;31mSignature:[0m
[0mcircuitsvis[0m[0;34m.[0m[0mtopk_samples[0m[0;34m.[0m[0mtopk_samples[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mtokens[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mactivations[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mfloat[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mzeroth_dimension_name[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mstr[0m[0;34m][0m [0;34m=[0m [0;34m'Layer'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfirst_dimension_name[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mstr[0m[0;34m][0m [0;34m=[0m [0;34m'Neuron'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mzeroth_dimension_labels[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mLis

#### circuitsvis.topk_tokens

In [99]:
circuitsvis.topk_tokens.topk_tokens??

[0;31mSignature:[0m
[0mcircuitsvis[0m[0;34m.[0m[0mtopk_tokens[0m[0;34m.[0m[0mtopk_tokens[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mtokens[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mactivations[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mnumpy[0m[0;34m.[0m[0mndarray[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_k[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m10[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfirst_dimension_name[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'Layer'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mthird_dimension_name[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'Neuron'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msample_labels[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfirst_dimension_lab

## SAE

In [None]:
for layer_index in range(cfg.n_layers):
    imshow(
        transformer_lens.utils.to_numpy(cache["attn", layer_index].mean([0, 1])),
        title=f"Layer {layer_index} Attention Pattern",
        height=400,
        width=400,
    )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Loss = Float32[torch.Tensor, ""]
MSELoss = Float32[torch.Tensor, ""]
WeightedSparsityLoss = Float32[torch.Tensor, ""]

Logits = Float32[torch.Tensor, "n_ctx d_vocab"]
BatchedLogits = Float32[torch.Tensor, "batch n_ctx d_vocab"]

ModelActivations = Float32[torch.Tensor, "n_ctx d_model"]
BatchedModelActivations = Float32[torch.Tensor, "batch n_ctx d_model"]

FlattenedModelActivations = Float32[torch.Tensor, "d_sae_in"]

BatchedFlattenedModelActivations = Float32[torch.Tensor, "batch d_sae_in"]
BatchedSAEActivations = Float32[torch.Tensor, "batch d_sae_model"]

@dataclasses.dataclass
class SAEOutput:
    sae_activations: BatchedSAEActivations
    reconstructed_model_activations: BatchedFlattenedModelActivations

def sparse_loss_kl_divergence(
    flattened_model_activations: BatchedModelActivations,
    sae_output: SAEOutput,
    sparsity_target: float,
    sparsity_weight: float,
    epsilon: float = 1e-10,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # same as dense loss (this is constant?)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )
    
    # KL divergence for sparsity
    avg_activation = torch.mean(sae_output.sae_activations, dim=0)

    # print(f'[pre-clamping] {avg_activation=}')

    # Add epsilon for numerical stability
    avg_activation = torch.clamp(avg_activation, epsilon, 1 - epsilon)

    kl_div = sparsity_target * torch.log(sparsity_target / avg_activation) + \
             (1 - sparsity_target) * torch.log((1 - sparsity_target) / (1 - avg_activation))
    kl_div = torch.sum(kl_div)

    # `sparsity_weight` decides how much we weight `KL-Divergence`
    sparsity_penalty = sparsity_weight * kl_div

    # print(f'[sparse_loss] {mse_loss=}, {avg_activation=}, {kl_div=}, {sparsity_penalty=}')
    
    return mse_loss + sparsity_penalty, mse_loss, sparsity_penalty

In [None]:
import dataclasses

import lightning.pytorch



@dataclasses.dataclass
class SparseAutoencoderConfig:
    d_in: int
    d_model: int
    sparsity_target: float = 0.05

# TODO(bschoen): Start using the config pattern, it stays typesafe and allows
#                easy logging to things like wandb
class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        cfg: SparseAutoencoderConfig,
    ) -> None:

        print(f'Creating SparseAutoencoder with {cfg}')

        super(SparseAutoencoder, self).__init__()

        self.d_in = cfg.d_in
        self.d_model = cfg.d_model

        self.encoder = nn.Linear(cfg.d_in, cfg.d_model)
        self.decoder = nn.Linear(cfg.d_model, cfg.d_in)

        # Target average activation of hidden neurons
        # Motivation: Encourages each neuron to be active for ex: ~5% of inputs, promoting specialization
        self.sparsity_target = cfg.sparsity_target
    
    def forward(
        self,
        x: BatchedFlattenedModelActivations,
    ) -> SAEOutput:
        
        # TODO(bschoen): Which activation function should we use?
        
        encoded = F.gelu(self.encoder(x))
        
        decoded = self.decoder(encoded)
        
        return SAEOutput(
            sae_activations=encoded,
            reconstructed_model_activations=decoded,
        )
    
@dataclasses.dataclass
class LightningSparseAutoencoderConfig:
    
    model_config: transformer_lens.HookedTransformerConfig
    sae_config: SparseAutoencoderConfig
    learning_rate: float
    sparsity_weight: float
    
# note: this kind of lightning adapter is a common pattern: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#starter-example
class LightningSparseAutoencoder(lightning.pytorch.LightningModule):
    def __init__(
        self,
        cfg: LightningSparseAutoencoderConfig,
    ) -> None:

        super(LightningSparseAutoencoder, self).__init__()
        
        self.model = transformer_lens.HookedTransformer(cfg=cfg.model_config)
        self.sae = SparseAutoencoder(cfg=cfg.sae_config)
        self.cfg = cfg

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx: int) -> LossValue:
        inputs, target = batch

        self.model
        output = self(inputs, target)
        loss = torch.nn.functional.cr(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

In [None]:
hook_id = 'blocks.0.hook_mlp_out'

cache[hook_id].shape

In [None]:
# Training loop
sae_num_epochs = 10000
sae_expansion_factor = 4

learning_rate = 1e-4

# both arbitrary for now
# - Start small: A common approach is to begin with a relatively small sparsity weight, 
#                typically in the range of 1e-5 to 1e-3. This allows the model to
#                learn meaningful representations before enforcing strong sparsity
#                constraints.
sparsity_weight: float = 1e-3  # Weight of the sparsity loss in the total loss
sparsity_target: float = 0.05  # Target average activation of hidden neurons

print(f'Training SAE for {hook_id}...')
sae_d_in = (cfg.n_ctx - 1) * cfg.d_model  # -1 since not predicting first token
sae_d_model = sae_d_in * sae_expansion_factor

sae_model = SparseAutoencoder(
    d_in=sae_d_in,
    d_model=sae_d_model,
    sparsity_target=sparsity_target,
)
sae_model.to(device)

sae_optimizer = optim.Adam(sae_model.parameters(), lr=learning_rate)
sae_criterion = nn.MSELoss()

sae_config = {
    'sae_num_epochs': sae_num_epochs,
    'sae_expansion_factor': sae_expansion_factor,
    'learning_rate': learning_rate,
    'sparsity_weight': sparsity_weight,
    'sparsity_target': sparsity_target,
    'sae_d_in': sae_d_in,
    'sae_d_model': sae_d_model,
    'hook_id': hook_id,
}
wandb.init(
    project="toy-problem-hooked-transformer-sae",
    config=sae_config,
)

# put model itself into eval mode so doesn't change
model.eval()

# go through the training data again, this time training the sae on the activations
for epoch, batch in tqdm.tqdm(
    zip(
        range(sae_num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # run through the model (with cache) to get the activations
    logits, cache = model.run_with_cache(tokens)

    # ex: torch.Size([4, 8, 16])
    activations = cache[hook_id]

    # ex: torch.Size([4, 128])
    flattened_activations = activations.reshape(activations.size(0), -1)

    sae_optimizer.zero_grad()

    # now the SAE model is given the *activations*
    encoded, decoded = sae_model(flattened_activations)

    # print(f'Encoded: {encoded.shape}')
    # print(f'Decoded: {decoded.shape}')

    # compute loss
    reconstruction_loss = sae_criterion(decoded, flattened_activations)
    sparsity_loss = sae_model.get_sparsity_loss(encoded)

    total_loss = reconstruction_loss + (sparsity_weight * sparsity_loss)

    total_loss.backward()

    sae_optimizer.step()

    if epoch % 500 == 0:
        print(
            f"Step {epoch}, "
            f"Total Loss: {total_loss.item():.6f}, "
            f"Reconstruction Loss: {reconstruction_loss.item():.6f}, "
            f"Sparsity Loss: {sparsity_weight *sparsity_loss.item():.6f}",
        )

        wandb.log({
            'epoch': epoch,
            'total_loss': total_loss.item(),
            'reconstruction_loss': reconstruction_loss.item(),
            'sparsity_loss': sparsity_loss.item(),
            'weighted_sparsity_loss': sparsity_weight * sparsity_loss.item(),
        })

wandb.finish()

#### Dictionary Learning Implementation

See [simple_dictionary_learning.ipynb](simple_dictionary_learning.ipynb) for a details

#### Extracting the learned dictionary

In [None]:
# Creating SparseAutoencoder with d_in=128, d_model=512, sparsity_target=0.05
dictionary: Float32[torch.Tensor, "sae_hidden sae_in"] = sae_model.encoder.weight.detach()

# ex: Dictionary shape: torch.Size([512, 128])
print(f'Dictionary shape: {dictionary.shape}')

In [None]:
# Reshape dictionary elements to match original activation shape
# (essentially `unflatting`)
reshaped_dictionary = dictionary.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

# Motivation: Extract the learned features (dictionary elements) from the encoder weights
# ex: Dictionary shape: torch.Size([512, 8, 16])
print(f"Dictionary shape: {reshaped_dictionary.shape}")


In [None]:
# It's always worth checking this sort of thing when you do this by hand
# to check that you haven't got the wrong site, or are missing a
# scaling factor or something like this. 
#
# This is like the overfitting thing

In [None]:
# let's look at an example batch from `test`

# set both to eval mode
model.eval()
sae_model.eval()

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

_, cache = model.run_with_cache(x)

activations = cache[hook_id]

print(f'Activations shape: {activations.shape}')

# flatten it
flattened_activations = activations.reshape(activations.size(0), -1)

print(f'{flattened_activations.shape=}')

encoded, decoded = sae_model(flattened_activations)

# renamed
sae_activations = encoded
reconstructed_activations = decoded

print(f'{sae_activations.shape=}')
print(f'{reconstructed_activations.shape=}')

# now we can get the dictionary
# dictionary = sae_model.encoder.weight.detach()

# print(f'Dictionary shape: {dictionary.shape}')

# now we can get the sparse coefficients
# alpha = dictionary.T @ flattened_activations

In [None]:
# good, 0.995, basically all the variance is explained by our SAE
numerator = torch.mean((reconstructed_activations[:, 1:] - flattened_activations[:, 1:]) ** 2)
denominator = (flattened_activations[:, 1:].to(torch.float32).var())

explained_variance = 1 - (numerator / denominator)

print(f'{explained_variance.item()=:.4f}')

In [None]:
(sae_activations.mean(dim=0) > 0).sum()

In [None]:
# collect max activations


with torch.no_grad():

    # go through the training data again, but don't cycle, no reason to go through more than once
    for batch in tqdm.tqdm(train_loader):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # run through the model (with cache) to get the activations
        logits, cache = model.run_with_cache(tokens)

        # ex: torch.Size([4, 8, 16])
        activations = cache[hook_id]

        # ex: torch.Size([4, 128])
        flattened_activations = activations.reshape(activations.size(0), -1)

        # now the SAE model is given the *activations*
        encoded, decoded = sae_model(flattened_activations)

        sae_activations = encoded

        # sae_activations.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

        # max_activations = torch.max(encoded, dim=1)


        break

In [None]:
sae_activations.shape

In [None]:
sae_activations[0].shape

In [None]:
8 * 16