In [1]:
import torch as t
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from einops import rearrange
from torch.nn import functional as F
import tqdm
import random
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional
from einops import rearrange
import wandb
from fancy_einsum import einsum

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

import requests
import os

In [2]:
# create dataset with bracket strings
# and labels for each bracket

# tokenizer for brackets
class Tokenizer:
    def __init__(self, vocab_size_without_special, seq_len):
        self.vocab_size = vocab_size_without_special
        self.seq_len = seq_len

        self.pad_token_id = vocab_size_without_special
        self.cls_token_id = vocab_size_without_special + 1
        self.sep_token_id = vocab_size_without_special + 2

    def __call__(self, text):
        return self.encode(text)

    def encode(self, nums):
        assert len(nums) <= self.seq_len
        nums = [self.cls_token_id] + nums + [self.sep_token_id] # add cls and sep tokens
        nums = nums + [self.pad_token_id] * (self.seq_len - len(nums) + 2) # pad to max len
        assert len(nums) == self.seq_len + 2
        return nums

    def decode(self, ids):
        if type(ids) == t.Tensor:
            ids = ids.tolist()
        ids = [id for id in ids if id != self.pad_token_id]
        # remove cls and sep tokens if present
        if ids[0] == self.cls_token_id:
            ids = ids[1:]
        if ids[-1] == self.sep_token_id:
            ids = ids[:-1]
        return ids

tokenizer = Tokenizer(10, 16)

eg = [1,2,3]
assert eg == tokenizer.decode(tokenizer(eg))


In [3]:
class Dataset(TensorDataset):
    def __init__(self, size, vocab_size, seq_len):
        # self.tokenizer = tokenizer
        self.size = size
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.rng = random.Random(42)
        self.train = self._make_dataset()
        super().__init__(*self.train)
    
    def _make_dataset(self):
        x = t.randint(0, self.vocab_size, (self.size, self.seq_len))
        # add cls and sep tokens
        # x = t.cat([t.ones(self.size, 1).long() * self.tokenizer.cls_token_id, x, t.ones(self.size, 1).long() * self.tokenizer.sep_token_id], dim=1)
        # pad to max len
        # x = t.cat([x, t.ones(self.size, self.tokenizer.seq_len + 2 - x.shape[1]).long() * self.tokenizer.pad_token_id], dim=1)
        # assert x.shape == (self.size, self.tokenizer.seq_len + 2)
        y = x.cummax(-1).values
        return x, y

# test dataset 
for x, y in DataLoader(Dataset(10, 15, 10), batch_size=2):
    print(x.shape, y.shape)
    print(x)
    print(y)
    break

torch.Size([2, 10]) torch.Size([2, 10])
tensor([[ 5,  6,  5, 10, 14,  9,  3, 14, 12, 13],
        [ 0,  4,  7, 12, 12,  8, 12,  3, 10,  3]])
tensor([[ 5,  6,  6, 10, 14, 14, 14, 14, 14, 14],
        [ 0,  4,  7, 12, 12, 12, 12, 12, 12, 12]])


In [4]:
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

@dataclass
class Config:
    lr: float = 1e-3
    batch_size: int = 64
    epochs: int = 10
    seed: int = 42
    use_wandb: bool = False
    vocab_size: int = 1000
    seq_len: int = 16

config = Config()
print(config.__dict__)

{'lr': 0.001, 'batch_size': 64, 'epochs': 10, 'seed': 42, 'use_wandb': False, 'vocab_size': 1000, 'seq_len': 16}


In [5]:
ht_config = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    n_heads=2,
    n_ctx=16,
    d_head=32,
    attn_only=True,
    d_vocab=config.vocab_size
)
model = HookedTransformer(ht_config, tokenizer)
# print(model)

In [6]:
trainset = Dataset(10000, config.vocab_size, config.seq_len)
valset = Dataset(1000, config.vocab_size, config.seq_len)
testset = Dataset(1000, config.vocab_size, config.seq_len)

trainloader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True)
valloader = DataLoader(valset, batch_size=config.batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=config.batch_size, shuffle=False)

In [7]:
print(config.__dict__ | ht_config.__dict__)

{'lr': 0.001, 'batch_size': 64, 'epochs': 10, 'seed': None, 'use_wandb': False, 'vocab_size': 1000, 'seq_len': 16, 'n_layers': 1, 'd_model': 64, 'n_ctx': 16, 'd_head': 32, 'model_name': 'custom', 'n_heads': 2, 'd_mlp': None, 'act_fn': None, 'd_vocab': 1000, 'eps': 1e-05, 'use_attn_result': False, 'use_attn_scale': True, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': 'LN', 'device': 'cpu', 'attention_dir': 'causal', 'attn_only': True, 'initializer_range': 0.1, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1000, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 16384}


In [8]:
def train(model, config: Config, ht_config: HookedTransformerConfig, train_loader: DataLoader) -> None:
    '''Train using masked language modelling.'''
    model.train()
    model.to(device)
    opt = t.optim.Adam(model.parameters(), lr=config.lr)
    if config.use_wandb:
        wandb.init(project="gpt-max", config=config.__dict__ | ht_config.__dict__)
        wandb.watch(model)
    run_name = wandb.run.name if config.use_wandb else 'gpt-max'
    # tqdm progress bar of train loader annotated with epoch number
    # make dir if doesn't exist 
    if not os.path.exists(f"models/{run_name}"):
        os.makedirs(f"models/{run_name}")
    t.save(model.state_dict(), f"./models/{run_name}/{run_name}-e-1.pt")
    
    for epoch in range(config.epochs):
        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
        for n_batch, (batch, target) in enumerate(progress_bar):
            batch = batch.to(device)
            opt.zero_grad()

            mask = (batch != tokenizer.pad_token_id).float()
            mask.requires_grad = False
            logits = model(batch.to(device))
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1))
            loss.backward()
            opt.step()

            # measure test loss 
            if n_batch % 10 == 0:
                test_loss = 0
                with t.no_grad():
                    for batch, target in testloader:
                        batch = batch.to(device)
                        logits = model(batch)
                        test_loss += F.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1)).item()
                test_loss /= len(testloader)
                progress_bar.set_postfix({"Train Loss": loss.item(), "Test Loss": test_loss})
                if config.use_wandb:
                    wandb.log({"Train Loss": loss.item(), "Test Loss": test_loss})
        t.save(model.state_dict(), f"./models/{run_name}/{run_name}-e-{epoch+1}.pt")
    if config.use_wandb:
        wandb.finish()
    requests.post("https://ntfy.sh/arena-brackets", data=f"Done training {run_name} 🎉".encode(encoding='utf-8'))

config.use_wandb = True
train(model, config, ht_config, trainloader)

Moving model to device:  cpu


[34m[1mwandb[0m: Currently logged in as: [33mjnhl[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 157/157 [00:02<00:00, 60.94it/s, Train Loss=3.28, Test Loss=3.15]
Epoch 1: 100%|██████████| 157/157 [00:02<00:00, 65.53it/s, Train Loss=0.81, Test Loss=0.854]
Epoch 2: 100%|██████████| 157/157 [00:02<00:00, 64.91it/s, Train Loss=0.345, Test Loss=0.454]
Epoch 3: 100%|██████████| 157/157 [00:02<00:00, 68.55it/s, Train Loss=0.216, Test Loss=0.28] 
Epoch 4: 100%|██████████| 157/157 [00:02<00:00, 69.04it/s, Train Loss=0.225, Test Loss=0.272]
Epoch 5: 100%|██████████| 157/157 [00:02<00:00, 68.43it/s, Train Loss=0.178, Test Loss=0.21]  
Epoch 6: 100%|██████████| 157/157 [00:02<00:00, 62.39it/s, Train Loss=0.123, Test Loss=0.164] 
Epoch 7: 100%|██████████| 157/157 [00:03<00:00, 47.15it/s, Train Loss=0.0575, Test Loss=0.144]
Epoch 8: 100%|██████████| 157/157 [00:03<00:00, 51.74it/s, Train Loss=0.0213, Test Loss=0.182]
Epoch 9: 100%|██████████| 157/157 [00:02<00:00, 70.53it/s, Train Loss=0.0562, Test Loss=0.146]


VBox(children=(Label(value='0.001 MB of 0.036 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.027022…

0,1
Test Loss,█▇▆▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train Loss,█▇▇▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Loss,0.14618
Train Loss,0.05622


In [9]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)

In [10]:
with t.inference_mode():
    tests = t.randint(0, config.vocab_size, (20, 13))
    for test in tests:
        print(test)
        print(model(test.unsqueeze(0).to(device)).argmax(-1))
        print()

tensor([467, 655, 605,  67, 369,  94, 474, 218, 798, 169, 374, 816, 107])
tensor([[467, 655, 655, 655, 655, 655, 655, 655, 798, 798, 798, 816, 816]])

tensor([332, 784, 158, 359, 176, 349, 689, 996, 728, 988, 829,  31, 715])
tensor([[332, 784, 784, 784, 784, 784, 784, 996, 996, 996, 996, 996, 996]])

tensor([700, 961, 618, 702, 268, 569,  66, 712, 105, 176,   0,  32, 570])
tensor([[700, 961, 961, 961, 961, 961, 961, 961, 961, 961, 961, 961, 961]])

tensor([318, 473, 159, 728, 713, 184, 789, 310, 495, 219, 892, 532, 739])
tensor([[318, 473, 473, 728, 728, 728, 789, 789, 789, 789, 892, 892, 892]])

tensor([123,  32, 663, 810, 364, 330, 806, 955, 513, 791,  91, 803, 848])
tensor([[123, 123, 663, 810, 810, 810, 806, 955, 955, 955, 955, 955, 955]])

tensor([343, 655, 159,  13, 355, 465, 366, 959, 117, 156, 583, 172, 302])
tensor([[343, 655, 655, 655, 655, 655, 655, 959, 959, 959, 959, 959, 959]])

tensor([174, 968, 203,  48, 108, 679, 411, 586, 747, 780, 209,  60, 231])
tensor([[174, 968, 9

In [11]:
with t.inference_mode():
    tests = t.randint(0, config.vocab_size, (20, 13))
    for test in tests:
        print(model(test.unsqueeze(0).to(device)).shape)
        break


torch.Size([1, 13, 1000])


In [12]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)

In [13]:
def view_post_final_ln_logit_hook(ts, hook: HookPoint):
    '''View the logits after the final layer norm.'''
    # clean = clean_cache[hook.name]
    # need to use the hook that gets the thing before the final layer norm
    return ts


In [14]:
trainset.tensors[0].shape

torch.Size([10000, 16])

In [15]:
logits, cache = model.run_with_cache(trainset.tensors[0])
logits = logits[0,0,:]
print(logits.shape)

torch.Size([1000])


In [16]:
vocablen = logits.shape[-1]
# diffdir = logits - ((logits.sum(-1, keepdim=True) - logits) / (vocablen - 1))

# diffdir shape is (batch_size, seq_len)

diffdir = logits - ((logits.sum(-1, keepdim=True) - logits) / (vocablen - 1))
diffdir.shape

torch.Size([1000])

In [17]:
def view_post_final_ln_logit_hook(ts, hook: HookPoint):
    '''View the logits after the final layer norm.'''
    # clean = clean_cache[hook.name]
    # need to use the hook that gets the thing before the final layer norm
    print(ts.shape)
    return ts

ln_final = model.ln_final
inputs = t.randint(0, config.vocab_size, (1, 13)).to(device)

model.run_with_hooks(inputs, fwd_hooks=[("ln_final.hook_normalized", view_post_final_ln_logit_hook)])


torch.Size([1, 13, 64])


tensor([[[-0.7695, -0.2882,  0.6947,  ..., -1.3604, -2.7450,  0.4886],
         [-0.6885, -0.5564,  0.1146,  ..., -1.3411, -2.5370,  0.8246],
         [-0.7988, -0.1549,  0.5057,  ..., -1.4457, -2.5238,  1.1742],
         ...,
         [-2.6439, -5.3273, -6.2290,  ...,  3.4536,  3.6812,  0.7984],
         [-2.9570, -5.3918, -6.3698,  ...,  3.5195,  3.4453,  1.1000],
         [-2.5041, -5.4830, -6.3181,  ...,  4.2196,  3.7131,  0.9899]]],
       grad_fn=<AddBackward0>)

In [18]:
def get_post_final_ln_dir(model) -> t.Tensor: 
    """
    Returns the "direction" that the post final layer norm layer points to to get the logits.

    We get this by doing a calculation on the logits
    """
    return model.decoder.weight[0, :] - model.decoder.weight[1, :]

In [22]:
data = trainset.tensors[0].to(device)
logits, cache = model.run_with_cache(data)

In [23]:
from typing import Union, Tuple
from sklearn.linear_model import LinearRegression

def get_ln_fit(
    model, data, seq_pos: Union[None, int]
) -> Tuple[LinearRegression, t.Tensor]:
    '''
    if seq_pos is None, find best fit for all sequence positions. Otherwise, fit only for given seq_pos.

    Returns: A tuple of a (fitted) sklearn LinearRegression object and a dimensionless tensor containing the r^2 of the fit (hint: wrap a value in torch.tensor() to make a dimensionless tensor)
    '''
    inp = cache['blocks.0.hook_resid_post']
    out = cache['ln_final.hook_normalized']
    if seq_pos is None:
        inp = rearrange(inp, 'b s e -> (b s) e')
        out = rearrange(out, 'b s e -> (b s) e')
    else:
        inp = inp[:, seq_pos, :]
        out = out[:, seq_pos, :]
    print(f'inp shape: {inp.shape}')
    print(f'out shape: {out.shape}')

    linear_regression = LinearRegression().fit(inp, out)
    r2 = linear_regression.score(inp, out)
    return linear_regression, t.tensor(r2)
    


if True:
    (final_ln_fit, r2) = get_ln_fit(model, data, seq_pos=0)
    print("r^2: ", r2)
    print(r2.shape)

inp shape: torch.Size([10000, 64])
out shape: torch.Size([10000, 64])
r^2:  tensor(0.9762, dtype=torch.float64)
torch.Size([])


In [24]:
def get_L(model, data):
    (final_ln_fit, r2) = get_ln_fit(model, data, model.norm, seq_pos=0)
    L = t.from_numpy(final_ln_fit.coef_)
    return L

def get_pre_final_ln_dir(model, data) -> t.Tensor:
    post_final_ln_dir = get_post_final_ln_dir(model)
    L = get_L(model, data)
    return t.einsum("i,ij->j", post_final_ln_dir, L)

In [30]:
def get_out_by_head(model, data, layer: int) -> t.Tensor:
    '''
    Get the output of the heads in a particular layer when the model is run on the data.
    Returns a tensor of shape (batch, num_heads, seq, embed_width)
    '''
    module = model.layers[layer].self_attn.W_O
    r = cache[f'blocks.{layer}.hook_z']
    # print(f'r shape: {r.shape}')
    r = rearrange(r, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=model.nhead)
    # print(f'r shape: {r.shape}')
    W_O = module.weight
    # print(f'weight shape: {W_O.shape}')
    W_O = rearrange(W_O, 'emb (nheads headsize) -> nheads emb headsize', nheads=model.nhead)
    # print(f'weight shape: {W_O.shape}')
    out = einsum('batch nheads seq headsize, nheads emb headsize -> batch nheads seq emb', r, W_O)
    # print(f'out shape: {out.shape}')
    return out

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']

In [None]:
def get_out_by_components(model, data) -> t.Tensor:
    '''
    Computes a tensor of shape [10, dataset_size, seq_pos, emb] representing the output of the model's components when run on the data.
    The first dimension is  [embeddings, head 0.0, head 0.1, mlp 0, head 1.0, head 1.1, mlp 1, head 2.0, head 2.1, mlp 2]
    '''
    out = []
    out.append(cache['hook_pos_embed'])
    for i in range(model.nlayers):
        out.append(get_out_by_head(model, data, i)[:, 0, :, :])
        out.append(get_out_by_head(model, data, i)[:, 1, :, :])
    return t.stack(out, dim=0)

if MAIN:
    w5d5_tests.test_get_out_by_component(get_out_by_components, model, data)

In [27]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def hists_per_comp(magnitudes, data, n_layers=3, xaxis_range=(-1, 1)):
    num_comps = magnitudes.shape[0]
    titles = {
        (1, 1): "embeddings",
        (2, 1): "head 0.0",
        (2, 2): "head 0.1",
        (2, 3): "mlp 0",
        (3, 1): "head 1.0",
        (3, 2): "head 1.1",
        (3, 3): "mlp 1",
        (4, 1): "head 2.0",
        (4, 2): "head 2.1",
        (4, 3): "mlp 2"
    }
    assert num_comps == len(titles)

    fig = make_subplots(rows=n_layers+1, cols=3)
    for ((row, col), title), mag in zip(titles.items(), magnitudes):
        if row == n_layers+2: break
        fig.add_trace(go.Histogram(x=mag[data.isbal].numpy(), name="Balanced", marker_color="blue", opacity=0.5, legendgroup = '1', showlegend=title=="embeddings"), row=row, col=col)
        fig.add_trace(go.Histogram(x=mag[~data.isbal].numpy(), name="Unbalanced", marker_color="red", opacity=0.5, legendgroup = '2', showlegend=title=="embeddings"), row=row, col=col)
        fig.update_xaxes(title_text=title, row=row, col=col, range=xaxis_range)
    fig.update_layout(width=1200, height=250*(n_layers+1), barmode="overlay", legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4), title="Histograms of component significance")
    fig.show()

if True:
    out_by_components = get_out_by_components(model, data)[:, :, 0, :].detach()
    unbalanced_dir = get_pre_final_ln_dir(model, data).detach()
    magnitudes = einsum("component sample emb, emb -> component sample", out_by_components, unbalanced_dir)
    magnitudes = magnitudes - magnitudes[:, data.isbal].mean(-1, keepdim=True)

    assert "magnitudes" in locals(), "You need to define `magnitudes`"
    hists_per_comp(magnitudes, data, xaxis_range=[-10, 20])

NameError: name 'get_out_by_components' is not defined