In [1]:
import os
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import transformer_lens

from probe_model import LinearProbe, Trainer, TrainerConfig

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import importlib
import probe_model
importlib.reload(probe_model)
from probe_model import LinearProbe, Trainer, TrainerConfig
import spacy

In [125]:
from torch.nn.functional import pad

# Setup

In [3]:
# Prefer an NVIDIA GPU (CUDA) when one is visible to torch, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Device: {device_name}")

Device: cuda


In [4]:
model_name = "gpt2-small"
# Load explicitly onto the device selected in the previous cell, so the model, the token
# tensors it produces (model.to_tokens) and the probe all live on the same device.
model = transformer_lens.HookedTransformer.from_pretrained(model_name, device=device_name)

Loaded pretrained model gpt2-small into HookedTransformer


In [10]:
# Load spaCy's small English pipeline ("en_core_web_sm"); generate_word_idxs uses it to
# split text into words.
spacy_model = spacy.load("en_core_web_sm")

# Prewar Dataset

In [177]:
# Number of token positions every sentence is padded/truncated to. Despite the name this is a
# sequence length, not a batch size; the name is kept because other cells may refer to it.
BATCH_SIZE = 128


def generate_word_idxs(text, max_len=None):
    """Tokenize ``text`` with GPT-2 and label each token with the index of the word it belongs to.

    Words are the alphabetic spaCy tokens of ``text`` (``t.is_alpha``), so punctuation and
    numbers are not words. GPT-2 tokens are matched to words greedily, left to right: a token
    belongs to the current word if it occurs in the rest of that word, otherwise to the next
    word if it occurs there, otherwise to no word.

    Args:
        text: the sentence to process.
        max_len: number of token positions in the output; defaults to ``BATCH_SIZE``.
            Longer token sequences are truncated and shorter ones padded. Tokens and labels
            are always truncated/padded together, so they stay aligned.

    Returns:
        ``(tokens, word_idxs)``. ``tokens`` is a 1-D tensor of ``max_len`` GPT-2 token ids,
        padded with id 0. ``word_idxs`` is an int ndarray of the same length holding each
        token's word index, or -1 for tokens in no word (BOS, whitespace, punctuation, padding).
        If ``text`` contains no alphabetic words, returns ``([], [])``.
    """
    if max_len is None:
        max_len = BATCH_SIZE

    gpt_tokens = model.to_tokens(text).squeeze(0)
    gpt_tokens_str = [model.to_single_str_token(int(t)) for t in gpt_tokens]

    doc = spacy_model(text)
    words = [t.text for t in doc if t.is_alpha]
    if not words:
        return [], []

    word_idxs = []
    i = 0
    cur = 0      # index of the word currently being matched
    sub_idx = 0  # character offset in words[cur] at which the next token may match
    while i < len(gpt_tokens_str):
        t = gpt_tokens_str[i].strip()
        # whitespace-only tokens (e.g. newlines) belong to no word
        if not t:
            word_idxs.append(-1)
            i += 1
            continue
        pos = words[cur].find(t, sub_idx)
        if pos != -1:
            # token is (part of) the current word; resume matching right after the match
            word_idxs.append(cur)
            sub_idx = pos + len(t)
            i += 1
        elif cur + 1 < len(words) and t in words[cur + 1]:
            # token starts the next word: advance the word and re-examine the same token
            cur += 1
            sub_idx = 0
        else:
            # not in the current or next word: give up on this token and leave it unlabeled
            word_idxs.append(-1)
            i += 1

    # Exactly one label was appended per token, so both sequences have equal length here.
    # Truncate and pad both to max_len together. Previously a negative pad silently
    # truncated only the tokens, leaving word_idxs longer than the token row.
    word_idxs = word_idxs[:max_len]
    word_idxs += [-1] * (max_len - len(word_idxs))
    gpt_tokens = gpt_tokens[:max_len]
    # NOTE: 0 is an ordinary GPT-2 token id, not a dedicated pad token; padded positions
    # are labeled -1 above so they never count as a word.
    gpt_tokens_padded = pad(gpt_tokens, (0, max_len - gpt_tokens.shape[0]))

    return gpt_tokens_padded, np.array(word_idxs)

In [178]:
# Sanity check on a short sentence: returns the padded token ids and the per-token word
# indices (-1 = token is not part of an alphabetic word, or is padding).
generate_word_idxs("hey this is an example sentence, !")

(tensor([50256, 20342,   428,   318,   281,  1672,  6827,    11,  5145,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [191]:
N_LINES = 500
# Sentences per forward pass. Caching activations for every sentence in a single pass makes
# memory grow with N_LINES; batching bounds it.
FORWARD_BATCH_SIZE = 32

with open('datasets/prewar.txt', 'r', encoding='utf-8') as f:
    prewar = [line.rstrip("\n") for line in f]

all_tokens = []
all_word_idxs = []
for sentence in tqdm(prewar[:N_LINES]):
    tokens, word_idxs = generate_word_idxs(sentence)
    # generate_word_idxs returns ([], []) for lines without alphabetic words; skip those.
    if len(word_idxs) == 0:
        continue
    # append, not +=: `list += tensor` would splice in the tensor's 0-d scalar elements
    # instead of adding one token row per sentence.
    all_tokens.append(tokens)
    all_word_idxs.append(word_idxs)

all_tokens = torch.stack(all_tokens)  # one padded token row per kept sentence

# Forward pass only. Use a scoped no_grad rather than torch.set_grad_enabled(False), which
# would disable autograd globally and leak into the probe-training cells below.
resid_chunks = []
with torch.no_grad():
    for start in range(0, all_tokens.shape[0], FORWARD_BATCH_SIZE):
        # Cache only the resid_post hooks; a full cache also stores every other activation.
        _, cache = model.run_with_cache(
            all_tokens[start:start + FORWARD_BATCH_SIZE],
            names_filter=lambda name: name.endswith("resid_post"),
        )
        # move each chunk off the accelerator as soon as it is computed
        resid_chunks.append(cache.stack_activation("resid_post").cpu())
        del cache
# residuals contains num_layers, batch (sentences), seq length, hidden dimension
residuals = torch.cat(resid_chunks, dim=1)

print(f"model.cfg.d_model {model.cfg.d_model}")

# flatten sentences x positions into one row per token, per layer
residuals_np = residuals.numpy()
all_resids = residuals_np.reshape(model.cfg.n_layers, -1, model.cfg.d_model)

x_all_layers = all_resids
print(all_resids.shape)
y = np.concatenate(all_word_idxs)
print(y.shape)
# every token row in x must have exactly one label in y
assert y.shape[0] == x_all_layers.shape[1], (y.shape, x_all_layers.shape)

  0%|          | 0/500 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [186]:
class ProbingDataset(Dataset):
    """Map-style dataset pairing activation vectors with integer probe labels.

    Args:
        act: indexable collection of activations; ``act[i]`` must be accepted by
            ``torch.tensor``.
        y: indexable collection of labels with the same length as ``act``.

    Items are ``(torch.tensor(act[i]), torch.tensor(y[i]).long())``.
    """

    def __init__(self, act, y):
        assert len(act) == len(y)
        print(f"dataset: {len(act)} pairs loaded...")
        self.act, self.y = act, y
        # label histogram: handy for spotting class imbalance or unexpected label values
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = torch.tensor(self.act[idx])
        label = torch.tensor(self.y[idx]).long()
        return features, label

In [None]:
LAYER = 3
# Number of word positions the probe classifies; must equal the probe's output size.
N_WORD_CLASSES = 10
SPLIT_SEED = 0  # makes the train/test split reproducible across runs

x = x_all_layers[LAYER, :, :]
# Keep only tokens labeled with one of the first N_WORD_CLASSES words. Label -1 marks
# BOS / whitespace / unaligned / padding tokens (most rows, since every sentence is padded).
# Labels >= N_WORD_CLASSES are not valid class indices for a 10-output probe. Presumably
# the Trainer uses a classification loss, where such targets on CUDA fail with
# "device-side assert triggered".
valid = (y >= 0) & (y < N_WORD_CLASSES)
probing_dataset = ProbingDataset(x[valid], y[valid])
train_size = int(0.8 * len(probing_dataset))
test_size = len(probing_dataset) - train_size
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"split into [test/train], [{test_size}/{train_size}]")

In [188]:
# Probe output size: one class per word position 0..n_classes-1.
n_classes = 10
# Fail fast on the CPU if any label is beyond the probe's output range. Presumably the
# Trainer uses a classification loss, where an out-of-range class index on CUDA surfaces
# only as an opaque "device-side assert triggered" error that poisons the CUDA context.
# NOTE(review): whether label -1 is ignored depends on probe_model.Trainer — confirm.
max_label = int(np.max(probing_dataset.y))
assert max_label < n_classes, f"labels go up to {max_label} but the probe has only {n_classes} outputs"
# input width = the model's residual-stream width (768 for gpt2-small)
probe = LinearProbe(device, model.cfg.d_model, n_classes)

# NOTE(review): "randwords_159k" looks copied from another experiment; checkpoints of this
# prewar probe will be written into that experiment's folder — confirm the intended path.
folder = f"ckpts/{model_name}/randwords_159k/layer{LAYER}"
config = TrainerConfig(num_epochs=40, ckpt_path=folder)
trainer = Trainer(device, probe, probe_train_dataset, probe_test_dataset, config)

In [None]:
# An earlier cell may have disabled autograd globally (torch.set_grad_enabled(False)).
# Re-enable it for the duration of training so the probe's parameters receive gradients.
with torch.enable_grad():
    trainer.train()

In [None]:
# Print the report returned by probe_model.Trainer.generate_report (its contents are defined there).
print(trainer.generate_report())