In [1]:
import os
import sys
import requests
from pathlib import Path
import numpy as np
import plotly.express as px
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from fancy_einsum import einsum
import math 

device = t.device("cuda" if t.cuda.is_available() else "cpu")

Create the toy dataset

There are two options for input sequences: noise and signal. The signal contains a 0 preceded by 2 numbers less in [1,4], and has a ground truth answer given by an arbitrary map. 

In [2]:
ans_map = {(1, 1):0, (1, 2):1, (1, 3):2, (1, 4):3, (2, 2):4, (2, 3):5, (2, 4):6, (3,3):7, (3, 4):8, (4, 4):9}
class SkipTrigramDataset(Dataset):

    def __init__(self, size: int, prob:float, prompt_length: int, seed:int = 42):
        self.size = size
        self.prob = prob
        self.prompt_length = prompt_length
        self.data = [self.generate_sample() for _ in range(size)]
        t.manual_seed(seed=seed)

    def generate_sample(self):
        random_tensor = t.randint(low = 5, high = 11, size = (self.prompt_length,))
        random_int = t.randint(low = 0, high = self.prob, size = (1,))
        if random_int.item() == self.prob - 1:
            rand_num_1 = t.randint(1, 5, size = (1,)).item()
            rand_num_2 = t.randint(1, 5, size = (1,)).item()
            random_tensor[0] = rand_num_1
            random_tensor[-2] = rand_num_2
            sample = random_tensor
            sample[-1] = 0
            label = ans_map[tuple(sorted((rand_num_1, rand_num_2)))]
            label = t.tensor([label])
        else:
            sample = random_tensor
            label = t.randint(low = 0, high = 11, size = (1,))

        return sample, label


    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.size

In [3]:
ans_map = {(1, 1):0, (1, 2):1, (1, 3):2, (1, 4):3, (2, 2):4, (2, 3):5, (2, 4):6, (3,3):7, (3, 4):8, (4, 4):9}

example_dataset = SkipTrigramDataset(size = 10_000, prob = 2, prompt_length = 4)
for i in range(10):
    print(f'Sample: {example_dataset[i][0]}, label: {example_dataset[i][1][-1]}')



Sample: tensor([10, 10, 10,  6]), label: 1
Sample: tensor([1, 7, 1, 0]), label: 0
Sample: tensor([ 6, 10,  7,  8]), label: 4
Sample: tensor([9, 5, 9, 9]), label: 10
Sample: tensor([6, 6, 7, 7]), label: 6
Sample: tensor([9, 5, 9, 8]), label: 3
Sample: tensor([10,  6,  8, 10]), label: 4
Sample: tensor([2, 5, 3, 0]), label: 5
Sample: tensor([7, 9, 6, 5]), label: 8
Sample: tensor([3, 9, 1, 0]), label: 2


In [4]:
signal = [example[1].item() for example in example_dataset if example[0][-1] == 0]
noise = [example[1].item() for example in example_dataset if example[0][-1] != 0]

print(len(signal), len(noise))

px.histogram(signal)
#px.histogram(noise)

## repeated ints < 5 are half as likely in the signal dataset. This shouldn't matter, but could check to see if those outputs (0, 4, 7, 9) are treated systematically differently by the model


4930 5070


Model Config

In [5]:
@dataclass
class model_cfg:
    d_model: int = 12
    d_vocab: int = 12
    init_range: float = 0.02
    n_ctx: int = 5
    d_head: int = 3
    n_heads: int = 3
    n_layers: int = 1

model_config = model_cfg()

print(model_config)

model_cfg(d_model=12, d_vocab=12, init_range=0.02, n_ctx=5, d_head=3, n_heads=3, n_layers=1)


One-Hot Embed and Unembed (overkill, since just the trivial embedding)

In [6]:
class OneHotEmbed(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.embedding = nn.Embedding(self.model_cfg.d_vocab, self.model_cfg.d_vocab)
        self.embedding.weight = nn.Parameter(t.eye(self.model_cfg.d_vocab))
        self.embedding.weight.requires_grad = False

    def forward(self, input):
        # tokens: [batch, position]
        # output: [batch, position, d_model]
        input = input.to(device)
        embed_out = self.embedding(input)
        return embed_out
    
class OneHotUnembed(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.unembed = t.eye(self.model_cfg.d_vocab)

    def forward(self, input):
        # tokens: [batch, position, d_model]
        # output: [batch, position]
        input = input.to(device)
        return t.einsum("vm,bpm->bpv", self.unembed.to(device), input)  # [batch, pos, d_vocab]


In [7]:
class Attention(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg

        # set weights, biases as parameters 
        self.W_Q = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_head, model_cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((model_cfg.d_model)))

        #initialize randomly
        nn.init.normal_(self.W_Q, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_K, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_V, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_O, std=self.model_cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

        self.attn_activation_cache = {}

    def forward(self, normalized_resid_pre: t.Tensor):
        # normalized_resid_pre: [batch, position, d_model] (we're not using LN, so just input)
        # output: [batch, position, d_model]
        normalized_resid_pre = normalized_resid_pre.to(device)

        queries = einsum('n_heads d_model d_head, batch pos_q d_model -> batch pos_q n_heads d_head', self.W_Q, normalized_resid_pre) + self.b_Q
        keys = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_K, normalized_resid_pre) + self.b_K
        values = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_V, normalized_resid_pre) + self.b_V

        attn_scores = einsum('batch pos_q n_heads d_head, batch pos_k n_heads d_head -> batch n_heads pos_q pos_k', queries, keys)

        scaled_attn = attn_scores / math.sqrt(self.model_cfg.d_head)
        masked_attn = self.apply_causal_mask(scaled_attn)
        # apply softmax over keys to get attention pattern
        softmax_attn = masked_attn.softmax(dim = -1)

        # get weighted attention values
        z = einsum('batch pos_k n_heads d_head, batch n_heads pos_q pos_k -> batch pos_q n_heads d_head', values, softmax_attn)

        # layer output
        output = einsum('n_heads d_head d_model, batch pos_q n_heads d_head -> batch pos_q d_model', self.W_O, z) + self.b_O

        ## save to the cache 
        self.attn_activation_cache['hook_q'] = queries
        self.attn_activation_cache['hook_k'] = keys
        self.attn_activation_cache['hook_attn_scores'] = attn_scores
        self.attn_activation_cache['hook_pattern'] = softmax_attn
        self.attn_activation_cache['hook_v'] = values
        self.attn_activation_cache['hook_z'] = z
        self.attn_activation_cache['hook_attn_out'] = output

        return output

    def apply_causal_mask(self, attn_scores: t.Tensor):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        # output: [batch, n_heads, query_pos, key_pos]
        mask = t.triu(attn_scores, diagonal=1).bool()
        return attn_scores.masked_fill_(mask, self.IGNORE)

define a one layer attn only class      

In [8]:
class OneLayerAttnOnly(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.embed = OneHotEmbed(model_cfg)
        self.attn = Attention(model_cfg)
        self.unembed = OneHotUnembed(model_cfg)

        self.activation_cache = {}


    def forward(self, tokens):
        # tokens [batch, pos]
        # output [batch, pos, d_vocab]

        input = tokens.to(device)
        resid_pre = self.embed(input)
        resid_post = resid_pre + self.attn(resid_pre)
        out = self.unembed(resid_post)

        # add activations to the same cache as before
        self.activation_cache = self.attn.attn_activation_cache
        self.activation_cache['hook_resid_pre'] = resid_pre
        self.activation_cache['hook_resid_post'] = resid_post
        self.activation_cache['unembed'] = out

        return out
    

attn_only_1l = OneLayerAttnOnly(model_cfg).to(device)
attn_only_1l

        

OneLayerAttnOnly(
  (embed): OneHotEmbed(
    (embedding): Embedding(12, 12)
  )
  (attn): Attention()
  (unembed): OneHotUnembed()
)

In [9]:
def plot_metrics(train_loss,train_acc, test_loss, test_acc):
    fig = make_subplots(rows=2,cols=2, subplot_titles = ['Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy'])
    fig.add_trace(go.Scatter(y=train_loss), row=1, col=1)
    fig.add_trace(go.Scatter(y=train_acc), row=1, col=2)
    fig.add_trace(go.Scatter(y=test_loss), row=2, col=1)
    fig.add_trace(go.Scatter(y=test_acc), row=2, col=2)
    # get rid of legend
    fig.update_layout(showlegend=False)
    ## add x and y axis labels and change plot style 
    fig.update_xaxes(
        title = 'Epoch',
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        showgrid=False
    )
    fig.update_yaxes(
        title = 'Loss',
        showgrid=False,
        col=1
    )
    fig.update_yaxes(
        title = 'Accuracy',
        showgrid=False,
        col=2
    )
    fig.update_layout(
        font = dict(size =12),
        width=800,
        height =500,
        margin=dict(l=10, 
                    r = 20,
                    b=10,
                    t=20,
                    pad=0)
    )
    fig.show()

In [10]:
@dataclass
class train_cfg:
    ctx_length: int = 4
    trainset_size: int = 100000
    valset_size: int = 1000
    prob: int = 2
    epochs: int = 15
    batch_size: int = 512
    weight_decay: float = 0.0
    seed: int = 42
    lr: float = 1e-3

class Trainer:

    def __init__(self, train_cfg, model_cfg):
        self.train_cfg = train_cfg
        self.model = OneLayerAttnOnly(model_cfg).to(device)


    def train_dataloader(self, seed: int):
        trainset = SkipTrigramDataset(size=self.train_cfg.trainset_size,
                                      prob=self.train_cfg.prob,
                                      prompt_length=self.train_cfg.ctx_length)

        return DataLoader(trainset, batch_size=self.train_cfg.batch_size, shuffle=True)

    def val_dataloader(self, seed: int):
        valset = SkipTrigramDataset(size=self.train_cfg.valset_size,
                                    prob = self.train_cfg.prob,
                                    prompt_length=self.train_cfg.ctx_length)
        return DataLoader(valset, batch_size=self.train_cfg.batch_size, shuffle=False)

    def configure_optimizers(self):
        optimizer = t.optim.Adam(self.model.parameters(), lr=self.train_cfg.lr, weight_decay=self.train_cfg.weight_decay)
        return optimizer

In [49]:



def train(model, train_data, criterion, optimizer):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(train_data)):

        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        output = model(X)[:,-1]

        with t.no_grad():
            corr = (output.argmax(dim=-1) == y).count_nonzero()
        loss = criterion(output,y)
        loss.backward()

        optimizer.step()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()
        optimizer.zero_grad()

    cache = model.activation_cache

    return avg_loss/train_cfg.trainset_size, acc/train_cfg.trainset_size, cache
    
def test(model, test_data, criterion):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(test_data)):
        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        # don't log the gradients for testing
        with t.no_grad():
            output = model(X)[:,-1]
        loss = criterion(output,y)
        corr = (output.argmax(dim=-1) == y).count_nonzero()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()


    return avg_loss/train_cfg.valset_size, acc/train_cfg.valset_size

In [12]:
model_cfg.d_model, model_cfg.n_heads
11//3

3

In [52]:
from tqdm import tqdm 

def train_and_eval(train_cfg, model_cfg, save_caches = []):

    trainer = Trainer(train_cfg, model_cfg)
    model = trainer.model
    train_data = trainer.train_dataloader(seed=train_cfg.seed)

    test_data = trainer.val_dataloader(seed=train_cfg.seed)
    criterion = nn.CrossEntropyLoss()
    optimizer = trainer.configure_optimizers()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    caches = []

    for n in tqdm(range(train_cfg.epochs)):
        loss, acc, cache = train(model, train_data, criterion, optimizer)

        train_loss.append(loss)
        train_acc.append(acc)
        t_loss, t_acc = test(model, test_data, criterion)
        test_loss.append(t_loss)
        test_acc.append(t_acc)
        print(f"train loss = {loss}, train acc = {acc}, test loss = {t_loss}, test acc = {t_acc}")

        if n in save_caches:
            caches.append(cache)

    return train_loss, train_acc, test_loss, test_acc, caches, trainer.model



In [184]:
train_config = train_cfg(
    ctx_length=4,
    weight_decay=0,
    trainset_size = 100000,
    prob = 2,
    epochs = 50,
    batch_size = 512
)

model_config = model_cfg(
    d_model = 11,
    d_vocab = 11,
    d_head = 1,
    init_range = 0.02,
    n_ctx = 4,
    n_heads = 2,
    n_layers = 1)

train_loss, train_acc, test_loss, test_acc, caches, pre_trained = train_and_eval(train_config, model_config)

  2%|▏         | 1/50 [00:00<00:21,  2.29it/s]

train loss = 0.009545068707466126, train acc = 0.07677, test loss = 0.009633139133453369, test acc = 0.084


  4%|▍         | 2/50 [00:00<00:20,  2.36it/s]

train loss = 0.009345167083740235, train acc = 0.10393, test loss = 0.009463642120361328, test acc = 0.114


  6%|▌         | 3/50 [00:01<00:20,  2.27it/s]

train loss = 0.009077833290100097, train acc = 0.1735, test loss = 0.008978469848632813, test acc = 0.26


  8%|▊         | 4/50 [00:02<00:24,  1.85it/s]

train loss = 0.008420146040916444, train acc = 0.26673, test loss = 0.008232321739196777, test acc = 0.26


 10%|█         | 5/50 [00:02<00:22,  2.02it/s]

train loss = 0.007796144764423371, train acc = 0.26669, test loss = 0.007728257417678833, test acc = 0.261


 12%|█▏        | 6/50 [00:02<00:20,  2.12it/s]

train loss = 0.007369661312103272, train acc = 0.26757, test loss = 0.007366776466369629, test acc = 0.281


 14%|█▍        | 7/50 [00:03<00:19,  2.20it/s]

train loss = 0.007053357949256897, train acc = 0.36838, test loss = 0.007090522527694702, test acc = 0.382


 16%|█▌        | 8/50 [00:03<00:18,  2.25it/s]

train loss = 0.006804799084663391, train acc = 0.38996, test loss = 0.006862983465194702, test acc = 0.382


 18%|█▊        | 9/50 [00:04<00:17,  2.29it/s]

train loss = 0.006596111154556274, train acc = 0.39011, test loss = 0.00666984224319458, test acc = 0.383


 20%|██        | 10/50 [00:04<00:17,  2.32it/s]

train loss = 0.0064112615251541135, train acc = 0.38999, test loss = 0.006483502626419067, test acc = 0.381


 22%|██▏       | 11/50 [00:04<00:16,  2.34it/s]

train loss = 0.00622092627286911, train acc = 0.44583, test loss = 0.006290257453918457, test acc = 0.439


 24%|██▍       | 12/50 [00:05<00:16,  2.35it/s]

train loss = 0.006046051731109619, train acc = 0.45217, test loss = 0.00613145923614502, test acc = 0.441


 26%|██▌       | 13/50 [00:05<00:15,  2.37it/s]

train loss = 0.005900594561100006, train acc = 0.46056, test loss = 0.00599289870262146, test acc = 0.475


 28%|██▊       | 14/50 [00:06<00:15,  2.37it/s]

train loss = 0.005782149085998535, train acc = 0.48256, test loss = 0.005879133939743042, test acc = 0.482


 30%|███       | 15/50 [00:06<00:14,  2.36it/s]

train loss = 0.005682502586841583, train acc = 0.48236, test loss = 0.005775232076644898, test acc = 0.483


 32%|███▏      | 16/50 [00:07<00:14,  2.36it/s]

train loss = 0.005590680098533631, train acc = 0.48797, test loss = 0.005681214570999146, test acc = 0.501


 34%|███▍      | 17/50 [00:07<00:13,  2.37it/s]

train loss = 0.005507345016002655, train acc = 0.50784, test loss = 0.005596549272537232, test acc = 0.52


 36%|███▌      | 18/50 [00:07<00:13,  2.38it/s]

train loss = 0.005436308195590973, train acc = 0.51348, test loss = 0.005525225639343261, test acc = 0.52


 38%|███▊      | 19/50 [00:08<00:12,  2.39it/s]

train loss = 0.005375001130104065, train acc = 0.51348, test loss = 0.0054634280204772945, test acc = 0.52


 40%|████      | 20/50 [00:08<00:12,  2.38it/s]

train loss = 0.0053241491866111755, train acc = 0.51348, test loss = 0.005411920070648194, test acc = 0.52


 42%|████▏     | 21/50 [00:09<00:12,  2.37it/s]

train loss = 0.005276197123527527, train acc = 0.51353, test loss = 0.005365716457366943, test acc = 0.52


 44%|████▍     | 22/50 [00:09<00:11,  2.37it/s]

train loss = 0.0052374373960495, train acc = 0.52213, test loss = 0.005325216054916382, test acc = 0.547


 46%|████▌     | 23/50 [00:09<00:11,  2.36it/s]

train loss = 0.0052013473773002625, train acc = 0.53335, test loss = 0.005288618326187134, test acc = 0.547


 48%|████▊     | 24/50 [00:10<00:13,  1.87it/s]

train loss = 0.005169225826263428, train acc = 0.53206, test loss = 0.005255936861038208, test acc = 0.52


 50%|█████     | 25/50 [00:11<00:12,  2.00it/s]

train loss = 0.0051414821600914, train acc = 0.52179, test loss = 0.005227025270462036, test acc = 0.52


 52%|█████▏    | 26/50 [00:11<00:11,  2.08it/s]

train loss = 0.005112387092113495, train acc = 0.5149, test loss = 0.005200297355651856, test acc = 0.52


 54%|█████▍    | 27/50 [00:12<00:10,  2.17it/s]

train loss = 0.005089211678504944, train acc = 0.52044, test loss = 0.0051757657527923585, test acc = 0.547


 56%|█████▌    | 28/50 [00:12<00:09,  2.24it/s]

train loss = 0.005068411169052124, train acc = 0.54331, test loss = 0.005154607295989991, test acc = 0.547


 58%|█████▊    | 29/50 [00:12<00:09,  2.28it/s]

train loss = 0.0050457572078704835, train acc = 0.54452, test loss = 0.005134851455688477, test acc = 0.547


 60%|██████    | 30/50 [00:13<00:08,  2.32it/s]

train loss = 0.005026319859027862, train acc = 0.54452, test loss = 0.005116821527481079, test acc = 0.547


 62%|██████▏   | 31/50 [00:13<00:08,  2.34it/s]

train loss = 0.005009707839488983, train acc = 0.54454, test loss = 0.005097129821777344, test acc = 0.547


 64%|██████▍   | 32/50 [00:14<00:08,  2.17it/s]

train loss = 0.004989639592170715, train acc = 0.54457, test loss = 0.005077495813369751, test acc = 0.547


 66%|██████▌   | 33/50 [00:14<00:07,  2.23it/s]

train loss = 0.00496987776517868, train acc = 0.54459, test loss = 0.0050588083267211915, test acc = 0.547


 68%|██████▊   | 34/50 [00:15<00:07,  2.26it/s]

train loss = 0.004949586892127991, train acc = 0.5446, test loss = 0.005038021087646484, test acc = 0.547


 70%|███████   | 35/50 [00:15<00:06,  2.29it/s]

train loss = 0.004929448466300964, train acc = 0.5446, test loss = 0.005017090559005738, test acc = 0.547


 72%|███████▏  | 36/50 [00:15<00:06,  2.31it/s]

train loss = 0.0049099418115615846, train acc = 0.5446, test loss = 0.004999287366867065, test acc = 0.547


 74%|███████▍  | 37/50 [00:16<00:05,  2.33it/s]

train loss = 0.004894062378406524, train acc = 0.5446, test loss = 0.004982048511505127, test acc = 0.547


 76%|███████▌  | 38/50 [00:16<00:05,  2.35it/s]

train loss = 0.00487897686958313, train acc = 0.5446, test loss = 0.00496993899345398, test acc = 0.547


 78%|███████▊  | 39/50 [00:17<00:04,  2.36it/s]

train loss = 0.0048679304242134095, train acc = 0.5446, test loss = 0.0049573581218719485, test acc = 0.547


 80%|████████  | 40/50 [00:17<00:04,  2.36it/s]

train loss = 0.004856181149482727, train acc = 0.5446, test loss = 0.0049439573287963865, test acc = 0.547


 82%|████████▏ | 41/50 [00:18<00:03,  2.37it/s]

train loss = 0.0048461181139945984, train acc = 0.5446, test loss = 0.0049365942478179934, test acc = 0.547


 84%|████████▍ | 42/50 [00:18<00:03,  2.37it/s]

train loss = 0.004840160622596741, train acc = 0.5446, test loss = 0.004926082849502564, test acc = 0.547


 86%|████████▌ | 43/50 [00:18<00:03,  2.27it/s]

train loss = 0.004827325692176819, train acc = 0.5446, test loss = 0.004913423776626587, test acc = 0.547


 88%|████████▊ | 44/50 [00:19<00:02,  2.17it/s]

train loss = 0.004818317217826843, train acc = 0.5446, test loss = 0.004904419183731079, test acc = 0.547


 90%|█████████ | 45/50 [00:19<00:02,  2.18it/s]

train loss = 0.004810175030231476, train acc = 0.5446, test loss = 0.004899523496627808, test acc = 0.547


 92%|█████████▏| 46/50 [00:20<00:01,  2.24it/s]

train loss = 0.004807646968364715, train acc = 0.5446, test loss = 0.004896140098571778, test acc = 0.547


 94%|█████████▍| 47/50 [00:20<00:01,  2.28it/s]

train loss = 0.004804169480800629, train acc = 0.5446, test loss = 0.004892194032669067, test acc = 0.547


 96%|█████████▌| 48/50 [00:21<00:00,  2.31it/s]

train loss = 0.004800879147052765, train acc = 0.5446, test loss = 0.004887486934661865, test acc = 0.547


 98%|█████████▊| 49/50 [00:21<00:00,  2.32it/s]

train loss = 0.004799224174022674, train acc = 0.5446, test loss = 0.004888589143753052, test acc = 0.547


100%|██████████| 50/50 [00:22<00:00,  2.27it/s]

train loss = 0.004796865525245666, train acc = 0.5446, test loss = 0.004883102416992188, test acc = 0.547





In [185]:

plot_metrics(train_loss,train_acc, test_loss, test_acc) 

In [105]:
model_config

model_cfg(d_model=11, d_vocab=11, init_range=0.02, n_ctx=4, d_head=3, n_heads=2, n_layers=1)

Examining head outputs

In [186]:
from transformer_lens import utils, HookedTransformer, ActivationCache

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [213]:
def get_head_to_resid(model, cache, head, seq):
  model(seq)
  z = cache["hook_z"][:, :, head, :]
  wo = model.attn.W_O[head, :, :]
  bo = model.attn.b_O
  out = einsum("batch pos d_head, d_head d_model -> batch pos d_model", z, wo) + bo
  return out

def get_head_to_logits(model, head_to_resid):
  return model.unembed(head_to_resid)

In [188]:
list(ans_map.keys())

[(1, 1),
 (1, 2),
 (1, 3),
 (1, 4),
 (2, 2),
 (2, 3),
 (2, 4),
 (3, 3),
 (3, 4),
 (4, 4)]

In [236]:
def show_pattern(cache, seq):
    X = t.Tensor(seq).unsqueeze(0).int()
    pre_trained(X)
    ax_label = [f"{i}: {z.item()}" for i, z in enumerate(X[0])]
    imshow(cache['hook_pattern'][0], facet_col = 0, x = ax_label, y = ax_label)



The following attention patterns show that if you change the noise token, sometimes the destination noise tokens will attend to themselves randomly, but the last destination token (0) will attend to the other two tokens less than 5, favoring one over the other. The hierarchy of prefered tokens is switched between the heads, so the ordering is accounted for. 


** If the signal token is repeated, it will attend equally to both. 


In [246]:
show_pattern(pre_trained.activation_cache, [1,5,2,0])
show_pattern(pre_trained.activation_cache, [1,8,2,0])
show_pattern(pre_trained.activation_cache, [3,5,3,0])
show_pattern(pre_trained.activation_cache, [3,8,3,0])

In [253]:
#seqs = [t.Tensor([i,5,j,0]).unsqueeze(0).int() for i,j in list(ans_map.keys())]
seqs = [[i,5,j,0] for i,j in list(ans_map.keys())]
vals = ans_map.values()
assert [pre_trained(t.Tensor(seq).unsqueeze(0).int())[0,-1].argmax().item() for seq in seqs] == list(ans_map.values())
seqs, vals

([[1, 5, 1, 0],
  [1, 5, 2, 0],
  [1, 5, 3, 0],
  [1, 5, 4, 0],
  [2, 5, 2, 0],
  [2, 5, 3, 0],
  [2, 5, 4, 0],
  [3, 5, 3, 0],
  [3, 5, 4, 0],
  [4, 5, 4, 0]],
 dict_values([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

In [293]:
for seq,val in zip(seqs,vals):
    res = []
    X = t.Tensor(seq).unsqueeze(0).int()
    for i in range(model_cfg.n_heads - 1):
        temp = get_head_to_resid(pre_trained, pre_trained.activation_cache, i, X)
        temp = get_head_to_logits(pre_trained, temp)
        res.append(temp[0,-1])

    res = t.stack(res)
    imshow(res, y = list(range(model_config.n_heads)), title = f"{seq} -> {val}")

Get the attention ranking of each token per head

In [292]:
seq = seqs[2]
keys = seq[:-1]

X = t.Tensor(seq).unsqueeze(0).int()
logits = pre_trained(X)
scores = pre_trained.activation_cache['hook_attn_scores'][0]
rankings = {}

for head in range(model_config.n_heads):
    ## get scores of attention from last query pos to all other (key) positions
    rankings[head] = sorted(list(zip(range(len(keys)),keys, scores[head,-1,:-1].detach())), key = lambda x: x[-1], reverse = True)

rankings



{0: [(0, 1, tensor(12.5718)), (2, 3, tensor(3.8726)), (1, 5, tensor(-2.3139))],
 1: [(2, 3, tensor(11.3493)), (0, 1, tensor(2.3414)), (1, 5, tensor(-3.3475))]}

In [290]:
def get_head_to_resid(model, cache, head, seq):
  model(seq)
  z = cache["hook_z"][:, :, head, :]
  wo = model.attn.W_O[head, :, :]
  bo = model.attn.b_O
  out = einsum("batch pos d_head, d_head d_model -> batch pos d_model", z, wo) + bo
  return out

def get_head_to_logits(model, head_to_resid):
  return model.unembed(head_to_resid)

model_cfg(d_model=11, d_vocab=11, init_range=0.02, n_ctx=4, d_head=1, n_heads=2, n_layers=1)

Get logit outputs for each head, to each possible output token

In [318]:
head_outputs = t.zeros((model_config.n_heads, model_cfg.n_ctx - 1, model_cfg.d_vocab - 1))
for i in range(1,5):
    seq = [i,0]
    X = t.Tensor(seq).unsqueeze(0).int()
    for head in range(model_config.n_heads):
        temp = get_head_to_resid(pre_trained, pre_trained.activation_cache, head, X)
        temp = get_head_to_logits(pre_trained, temp)
        #print(head_outputs[head, i-1].shape, temp[0,-1].shape)
        head_outputs[head, i-1,:] = temp[0,-1,:]


head_outputs

tensor([[[ 16.9513,  -0.6493,  26.3321,  29.9996, -12.9128, -19.9840, -25.8277,
           11.3394,  -5.3936, -28.0967,  11.8657],
         [-16.5686,   0.8765, -25.8378, -29.5320,  13.0714,  19.7836,  25.5719,
          -11.3452,   5.3020,  27.7958, -11.9892],
         [  5.8210,  -0.1426,   9.0090,  10.2320,  -4.2847,  -6.7791,  -8.7604,
            3.8069,  -1.8421,  -9.5375,   3.9446],
         [-13.2610,   0.7260, -20.6899, -23.6578,  10.5074,  15.8595,  20.5001,
           -9.1068,   4.2466,  22.2806,  -9.6354]],

        [[ 16.3343,  27.6433, -11.7089,  -2.0150,  24.9799, -13.4253,  11.5354,
          -19.1631, -19.9949,  -3.1470,  15.4797],
         [ 21.2373,  35.9965, -15.3294,  -2.6890,  32.5319, -17.4741,  15.0672,
          -24.9737, -26.0497,  -4.0642,  20.1989],
         [-21.5916, -36.9704,  16.2964,   3.1982, -33.4363,  17.8936, -15.7836,
           25.7833,  26.8397,   3.9474, -21.0240],
         [ -3.0534,  -5.3872,   2.6074,   0.6500,  -4.8824,   2.5849,  -2.4301,
 

In [321]:
print(head_outputs.shape)
imshow(head_outputs, facet_col = 0, y = [1,2,3,4])

torch.Size([2, 4, 11])


In [299]:
model_config.d_vocab

11