In [1]:
import transformer_lens
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Setup

In [2]:
# Seed every random number generator used below (NumPy for word sampling,
# torch for the train/test split, probe initialisation and batch shuffling)
# so that a fresh "Restart & Run All" reproduces the same numbers.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer an NVIDIA GPU when one is visible to torch, otherwise stay on the CPU.
device_name = "cpu"
if torch.cuda.is_available():
    device_name = "cuda" # CUDA for NVIDIA GPU

device = torch.device(device_name)
print(f"Device: {device_name}")

Device: cuda


In [3]:
# Load the pretrained model and place it explicitly on the device selected in
# the setup cell, so the model and the probe trained later agree on placement
# instead of relying on the library's own default device choice.
model_name = "gpt2-small"
model = transformer_lens.HookedTransformer.from_pretrained(model_name, device=device_name)

Loaded pretrained model gpt2-small into HookedTransformer


## Generate dataset

### Load text

In [5]:
# Number of entries drawn from the word list (np.random.choice samples with
# replacement by default, so the same word may appear more than once).
N_WORDS = 10000

with open('datasets/common_words.txt', 'r') as file:
    # One candidate word per line; only the trailing newline is stripped.
    words = np.random.choice(
        [line.rstrip("\n") for line in file],
        N_WORDS,
    )

# Token count of each sampled word when it is preceded by a single space
# (no BOS token), i.e. the form it takes inside a space-joined word list.
word_len_dict = {
    word: model.to_tokens(f" {word}", prepend_bos=False).squeeze(0).shape[0]
    for word in words
}

# Element-wise lookup so whole arrays of words can be mapped to token counts.
word_len = np.vectorize(lambda word: word_len_dict[word])

### Generate token batches

In [6]:
# Fixed prompt text placed in front of every sampled word list.
PREFIX = "Hello and welcome to my blog, where I love to list words.\nWhat "
# Number of prompts per generated batch.
BATCH_SIZE = 256
# Number of words sampled for each prompt.
N_SAMPLE = 10

# Token length of the prefix (BOS included) minus one.
# NOTE(review): the "- 1" presumably accounts for the prefix's trailing space
# being absorbed into the first word's token once a word follows - confirm.
prefix_tokens = model.to_tokens(PREFIX, prepend_bos=True).squeeze(0)
prefix_len = prefix_tokens.shape[0] - 1

def generate_batch():
    """Sample one batch of prompts and per-token word-position labels.

    Each of the BATCH_SIZE prompts is PREFIX followed by N_SAMPLE words drawn
    from `words` and joined with single spaces.

    Returns:
        tokens: the output of `model.to_tokens` for the batch of prompts
            (BOS prepended).
        word_idxs: float array of shape (BATCH_SIZE, width). Entry [i, j] is
            the index (0 .. N_SAMPLE - 1) of the word that the j-th word token
            of prompt i belongs to, or -1 where the row has no more word
            tokens. `width` is at least N_SAMPLE * 3 and grows when a row
            needs more columns.

    The labels assume that every word tokenizes inside the prompt exactly as
    it does as a standalone " word" string (which is how `word_len` was
    built).
    """
    batch_words = [np.random.choice(words, N_SAMPLE) for _ in range(BATCH_SIZE)]

    tokens = model.to_tokens([PREFIX + " ".join(s) for s in batch_words], prepend_bos=True)
    # (BATCH_SIZE, N_SAMPLE) array: number of tokens each sampled word occupies.
    mapped_len = word_len(batch_words)

    # The original buffer was fixed at N_SAMPLE * 3 columns, which raised a
    # broadcast error as soon as one row's words totalled more tokens than
    # that. Keep that width as a minimum but widen it to fit the longest row.
    width = max(N_SAMPLE * 3, int(mapped_len.sum(axis=1).max()))
    word_idxs = np.full((BATCH_SIZE, width), -1.0)
    for i, r in enumerate(mapped_len):
        # Repeat each word's position once per token it occupies.
        row = np.repeat(np.arange(N_SAMPLE), r)
        word_idxs[i, :len(row)] = row

    return tokens, word_idxs

In [7]:
# Number of batches to run through the model when building the probing dataset.
DATA_BATCHES = 50

# Only forward passes are needed here; gradients are switched back on by the
# training cell further down.
torch.set_grad_enabled(False)

all_resids = []
all_word_idxs = []

for i_batch in tqdm(range(DATA_BATCHES)):
    tokens, word_idxs = generate_batch()
    # Cache only the residual stream after each block.
    _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post"))
    residuals = cache.stack_activation("resid_post")

    # Drop the prefix positions and trim the labels to the remaining length.
    residuals = residuals[:, :, prefix_len:, :]
    word_idxs = word_idxs[:, :residuals.size(dim=2)]

    # -1 marks positions that do not belong to any sampled word.
    mask = word_idxs != -1

    # The leading (layer) dimension is taken from the stacked tensor itself
    # rather than hard-coded to 12, so models with a different depth work too.
    n_layers = residuals.shape[0]
    all_resids.append(residuals.cpu().numpy()[:, mask].reshape(n_layers, -1, model.cfg.d_model))
    all_word_idxs.append(word_idxs[mask])


  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:13<00:00,  3.79it/s]


In [8]:
# Merge the per-batch pieces: activations are joined along axis 1 (the sample
# axis, after the leading layer axis) and labels are joined end to end.
x_all_layers = np.concatenate(all_resids, axis=1)
y = np.concatenate(all_word_idxs, axis=0)

# Show both shapes, one per line.
print(x_all_layers.shape, y.shape, sep="\n")

(12, 158383, 768)
(158383,)


## Probes

In [9]:
class ProbingDataset(Dataset):
    """Dataset of (activation, label) pairs used to train and evaluate a probe.

    Args:
        act: indexable collection of activation vectors.
        y: indexable collection of labels with the same length as `act`.

    Constructing the dataset prints the number of pairs and the label
    histogram as returned by `np.unique(..., return_counts=True)`.
    """

    def __init__(self, act, y):
        n_pairs = len(act)
        assert n_pairs == len(y)
        print(f"dataset: {n_pairs} pairs loaded...")
        self.act, self.y = act, y
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        """Number of (activation, label) pairs."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the idx-th activation as a tensor and its label as a long tensor."""
        activation = torch.tensor(self.act[idx])
        label = torch.tensor(self.y[idx]).long()
        return activation, label

In [10]:
# Index of the layer whose residual activations the probe is trained on.
LAYER = 3
x = x_all_layers[LAYER]

probing_dataset = ProbingDataset(x, y)

# 80% of the pairs go to training; whatever is left becomes the test set.
train_size = int(len(probing_dataset) * 0.8)
test_size = len(probing_dataset) - train_size
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset,
    [train_size, test_size],
)
print(f"split into [test/train], [{test_size}/{train_size}]")

dataset: 158383 pairs loaded...
y: (array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([15878, 15793, 15871, 15914, 15859, 15881, 15847, 15800, 15780,
       15760]))
split into [test/train], [31677/126706]


In [11]:
class LinearProbe(nn.Module):
    """A single affine layer mapping input features to per-class scores.

    Args:
        num_input_features: size of each input vector.
        num_classes: number of output scores per input.
    """

    def __init__(self, num_input_features, num_classes):
        super().__init__()
        self.linear = nn.Linear(
            in_features=num_input_features,
            out_features=num_classes,
        )

    def forward(self, x):
        """Apply the linear layer to `x` and return the resulting scores."""
        scores = self.linear(x)
        return scores

In [13]:
# Size the probe from the data instead of hard-coding 768 / 10: the input
# width is the feature dimension of `x`, and there is one class per word
# position (labels range over 0 .. N_SAMPLE - 1).
probe = LinearProbe(x.shape[-1], N_SAMPLE).to(device)

config = {
    'learning_rate': 0.001,
    'weight_decay': 1e-3,
    'batch_size': 1024,
    'num_epochs': 50,
}

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(probe.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

dataloader = DataLoader(probe_train_dataset, batch_size=config['batch_size'], shuffle=True)

# Gradients were disabled globally while collecting activations; re-enable them.
torch.set_grad_enabled(True)

# simple training loop
bar = tqdm(range(config['num_epochs']))
for epoch in bar:
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = probe(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so the (smaller) final
        # batch does not count as much as a full one in the epoch average.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count

        # train accuracy (detach() replaces the deprecated `.data` access)
        predicted = outputs.detach().argmax(dim=1)
        total += batch_count
        correct += (predicted == labels).sum().item()

    bar.set_description(f'Epoch {epoch+1}, Loss: {running_loss/total:.6f}, Acc: {correct/total:.6f}')


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 50, Loss: 0.621112, Acc: 0.748749: 100%|██████████| 50/50 [01:37<00:00,  1.95s/it]


In [14]:
from torch.utils.data import DataLoader

# Keep the test set in its stored order (shuffle=False) so that y_pred lines
# up with probe_test_dataset.indices.
test_dataloader = DataLoader(
    probe_test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
)

total, correct = 0, 0
y_pred = []

probe.eval()
with torch.no_grad():
    for inputs, labels in tqdm(test_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Predicted class = index of the highest score in each row.
        outputs = probe(inputs)
        predicted = torch.max(outputs, dim=1).indices

        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        y_pred.append(predicted.cpu().numpy())

print(f'Test Accuracy: {correct/total:.5f}')

# Flatten the per-batch predictions into one array.
y_pred = np.concatenate(y_pred)

100%|██████████| 31/31 [00:00<00:00, 73.12it/s]

Test Accuracy: 0.73621





In [16]:
from sklearn.metrics import classification_report

# Ground-truth labels for the test split, in the same order the (unshuffled)
# test loader produced y_pred. `y` is stored as floats, so cast to int: both
# arguments then share an integer dtype and the report lists classes as
# 0..9 rather than 0.0..9.0.
y_test = y[probe_test_dataset.indices].astype(int)

print(classification_report(y_test, y_pred, digits=3))

              precision    recall  f1-score   support

         0.0      0.982     0.995     0.988      3226
         1.0      0.974     0.949     0.961      3103
         2.0      0.886     0.917     0.902      3277
         3.0      0.798     0.791     0.795      3120
         4.0      0.721     0.677     0.698      3172
         5.0      0.608     0.653     0.630      3156
         6.0      0.587     0.539     0.562      3181
         7.0      0.534     0.589     0.560      3121
         8.0      0.535     0.469     0.500      3127
         9.0      0.718     0.772     0.744      3194

    accuracy                          0.736     31677
   macro avg      0.734     0.735     0.734     31677
weighted avg      0.735     0.736     0.735     31677

