In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from typing import Iterable

# HookedTransformer

* [TransformerLens - Tutorial - Trains HookedTransformer from Scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

# note: it's probably easier to just operate on tokens outside of the model,
#       that'll also make it clearer where tokenizer is used
#
# okay wrapping a pretrained tokenizer *can* be done:
# - https://huggingface.co/learn/nlp-course/chapter6/8#building-a-bpe-tokenizer-from-scratch
# - but none of the models support just naive encoding
#   - https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.BPE
class HookedTransformer:
    cfg: HookedTransformerConfig

    # note: actually does an `isinstance` check in the constructor
    tokenizer: transformers.PreTrainedTokenizerBase | None
```

In [3]:
import transformer_lens

from jaxtyping import Int64, Float32

import numpy as np
import plotly.express as px
import plotly.io as pio

import torch
import torch.utils.data

In [4]:
device = transformer_lens.utils.get_device()

print(f'Using device: {device}')

Using device: mps


### Setup Sample Generator

In [5]:
import string
import itertools
import more_itertools

class SpecialToken:
    # note: as assume a BOS token because transformerlens expects it
    BOS = '<'
    # we use a EOS token for convenience
    EOS = '>'

# TODO(bschoen): Allow this to generalize in the future
def generate_sample() -> Iterable[str]:
    """Generate palindrom samples like `<abc|cba>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    length = 3
    
    for combination in itertools.product(characters, repeat=length):

        combination_str = ''.join(combination)
        reversed_str = ''.join(reversed(combination_str))

        yield SpecialToken.BOS + combination_str + '|' + reversed_str + SpecialToken.EOS

# show a few examples
[x for x in more_itertools.take(10, generate_sample())]

['<aaa|aaa>',
 '<aab|baa>',
 '<aac|caa>',
 '<aad|daa>',
 '<aae|eaa>',
 '<aaf|faa>',
 '<aag|gaa>',
 '<aah|haa>',
 '<aai|iaa>',
 '<aaj|jaa>']

### Setup Tokenizer

In [6]:
from gpt_from_scratch.naive_tokenizer import NaiveTokenizer

vocab = string.ascii_lowercase + '|' + SpecialToken.BOS + SpecialToken.EOS

tokenizer = NaiveTokenizer.from_text(vocab)

In [7]:
from gpt_from_scratch import tokenizer_utils

# test tokenizer
input_text = '<abc|cba><bdd|ddb>'
tokenizer_utils.show_token_mapping(tokenizer, input_text)

Input:		<abc|cba><bdd|ddb>
Tokenized:	[44m[97m<[0m[45m[97ma[0m[46m[97mb[0m[42m[97mc[0m[43m[97m|[0m[41m[97mc[0m[44m[97mb[0m[45m[97ma[0m[46m[97m>[0m[42m[97m<[0m[43m[97mb[0m[41m[97md[0m[44m[97md[0m[45m[97m|[0m[46m[97md[0m[42m[97md[0m[43m[97mb[0m[41m[97m>[0m
Token ID | Token Bytes | Token String
---------+-------------+--------------
       0 | [38;5;2m3C[0m | '<'
          [48;5;1m[38;5;15m<[0mabc|cba><bdd|ddb>
          U+003C LESS-THAN SIGN (1 bytes: [38;5;2m3C[0m)
       2 | [38;5;2m61[0m | 'a'
          <[48;5;1m[38;5;15ma[0mbc|cba><bdd|ddb>
          U+0061 LATIN SMALL LETTER A (1 bytes: [38;5;2m61[0m)
       3 | [38;5;2m62[0m | 'b'
          <a[48;5;1m[38;5;15mb[0mc|cba><bdd|ddb>
          U+0062 LATIN SMALL LETTER B (1 bytes: [38;5;2m62[0m)
       4 | [38;5;2m63[0m | 'c'
          <ab[48;5;1m[38;5;15mc[0m|cba><bdd|ddb>
          U+0063 LATIN SMALL LETTER C (1 bytes: [38;5;2m63[0m)
      28 | [38;5;2m

### Setup Model

In [8]:
# now we know our vocab size from our sample generation

cfg = transformer_lens.HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=16,
    # The number of attention heads.
    # If not specified, will be set to d_model // d_head.
    # (This is represented by a default value of -1)
    n_heads=4,
    # The dimensionality of the feedforward mlp network.
    # Defaults to 4 * d_model, and in an attn-only model is None.
    # d_mlp=16,
    # note: transformerlens does the same thing if this is not set
    d_vocab=len(tokenizer.byte_to_token_dict),
    # length of the first sample is our context length
    n_ctx=len(more_itertools.nth(generate_sample(), 1)),
    act_fn="relu",
    normalization_type="LN",
    # note: must be set, otherwise tries to default to cuda / cpu (not mps)
    device=device.type,
)

print(f'Num params: {cfg.n_params}')

cfg

Num params: 98304


HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 4.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 16,
 'd_mlp': 256,
 'd_model': 64,
 'd_vocab': 29,
 'd_vocab_out': 29,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'mps',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.1,
 'load_in_4bit': False,
 'model_name': 'custom',
 'n_ctx': 9,
 'n_devices': 1,
 'n_heads': 4,
 'n_key_value_heads': None,
 'n_layers': 2,
 'n_params': 98304,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': None,
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'r

### Setup Loss Function

In [9]:
def loss_fn(logits, target):
    # standard cross entropy loss
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
    )

### Evaluate On Test

In [10]:
def evaluate_loss_on_test_batches(
    model: transformer_lens.HookedTransformer,
    data_loader: torch.utils.data.DataLoader,
) -> float:

    # Set the model to evaluation mode
    model.eval()  

    losses = []

    with torch.no_grad():  # Disable gradient computation
        
        for batch in data_loader:
            
            x, y = batch

            x, y = x.to(device), y.to(device)

            logits = model(x)

            loss = loss_fn(logits, y)

            losses.append(loss.item())

    # Set the model back to training mode
    model.train() 

    return sum(losses) / len(losses)

### Setup Data Loaders

In [14]:
class AutoregressiveDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[str], tokenizer: NaiveTokenizer) -> None:
        self.samples = samples
        self.tokenizer = tokenizer  # Assuming tokenizer is defined in the global scope

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens = self.tokenizer.encode(sample)
        
        # Convert to tensor and add batch dimension
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

def make_batch_dataloader(
    samples: list[str],
    tokenizer: NaiveTokenizer,
    batch_size: int,
) -> torch.utils.data.DataLoader:

    dataset = AutoregressiveDataset(samples=samples, tokenizer=tokenizer)
    
    # Create DataLoader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # drop the last batch if it's incomplete
        drop_last=True,
    )
    
    return dataloader

# Example usage:
# batch_generator = make_batch_generator(tokenizer, batch_size=4)
# for x, y in batch_generator:
#     # x is input, y is target (x shifted by 1)
#     pass


In [15]:
# split into test and train
all_samples = list(generate_sample())

# note: 4394 batches = (26 * 26 * 26) / 4
print(f'{len(all_samples)} samples')

# max_samples = 10
# print(f'Capping at {max_samples} batches first to make sure we can overfit')
# all_samples = all_samples[:max_samples]

test_train_ratio = 0.1

test_size = int(test_train_ratio * len(all_samples))
    
# put remaining ones into train
train_size = len(all_samples) - test_size

train_samples = all_samples[:train_size]
test_samples = all_samples[train_size:]

print(f'{len(train_samples)=}')
print(f'{len(test_samples)=}')

# now we can finally construct dataloaders
batch_size = 4

train_loader = make_batch_dataloader(
    samples=train_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)
test_loader = make_batch_dataloader(
    samples=test_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)


17576 samples
len(train_samples)=15819
len(test_samples)=1757


### Training

In [16]:
import tqdm

import torch.optim

import wandb


# create new model instance
model = transformer_lens.HookedTransformer(cfg)

# setup optimizers
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

num_epochs = 4000

# setup wandb
config = cfg.to_dict()
config.update({
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'lr': 1e-4,
    'max_grad_norm': 1.0,
    'wd': 0.1,
})
wandb.init(
    project="toy-problem-hooked-transformer",
    config=config,
)

losses = []
test_losses = []

for epoch, batch in tqdm.tqdm(
    zip(
        range(num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # ex: torch.Size([4, 9, 29])
    logits: Float32[torch.Tensor, "batch_size n_ctx d_vocab"] = model(tokens)

    # print(f"Logits:\n{logits.shape}")
    loss = loss_fn(logits, target)

    loss.backward()

    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    optimizer.zero_grad()

    scheduler.step()

    losses.append(loss.item())

    # TODO(bschoen): Shouldn't you actually divide loss by batch size?
    if epoch % 1000 == 0:
        print('Evaluating test loss')

        test_loss = evaluate_loss_on_test_batches(model, test_loader)

        test_losses.append(test_loss)

        print(f"Epoch {epoch}: Train loss: {loss.item():.6f}, Test loss: {test_loss:.6f}")

        wandb.log({
            'epoch': epoch,
            'train_loss': loss.item(),
            'test_loss': test_loss,
        })

wandb.finish()

# log locally to sanity check
px.line(losses, labels={"x": "Epoch", "y": "Train Loss"})


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011138140745202287, max=1.0…

0it [00:00, ?it/s]

Evaluating test loss


5it [00:02,  2.86it/s]

Epoch 0: Train loss: 3.551395, Test loss: 3.716435


995it [00:21, 53.38it/s]

Evaluating test loss


1007it [00:23, 12.06it/s]

Epoch 1000: Train loss: 1.291783, Test loss: 1.434423


1998it [00:42, 52.78it/s]

Evaluating test loss


2010it [00:44, 11.98it/s]

Epoch 2000: Train loss: 1.229445, Test loss: 1.367708


3000it [01:03, 53.31it/s]

Evaluating test loss


3006it [01:05,  9.02it/s]

Epoch 3000: Train loss: 1.188691, Test loss: 1.370201


4000it [01:24, 47.50it/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▃▆█
test_loss,█▁▁▁
train_loss,█▁▁▁

0,1
epoch,3000.0
test_loss,1.3702
train_loss,1.18869


In [18]:
# Look at some example output
import circuitsvis as cv

# create a custom to_string function since using our own tokenizer
def token_to_string(token: int) -> str:
    return tokenizer.decode([token])

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

example_sample = x[0]

# grab the first part of it, ex: `<abc|`
example_prompt = example_sample[:8]

example_prompt = example_prompt.to(device)

print(f'Using {example_prompt} from {example_sample} (from test set)')

# note: already encoded
input_tokens = example_prompt

logits_batch, cache = model.run_with_cache(input_tokens)

logits = logits_batch[0]

log_probs = logits.log_softmax(dim=-1)

cv.logits.token_log_probs(
    token_indices=input_tokens,
    log_probs=log_probs,
    to_string=token_to_string,
)
 

Using tensor([ 0, 27,  9,  9, 28,  9,  9, 27], device='mps:0') from tensor([ 0, 27,  9,  9, 28,  9,  9, 27]) (from test set)


In [None]:
transformer_lens.utils.test_prompt(
    prompt='Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to',
    answer=' Jill',
    model=model,
    prepend_space_to_answer=True, # default
    print_details=True, # default
    prepend_bos=None, # default
    top_k=10, # default
)

In [None]:
# make sure we can overfit