In [19]:
!pip install transformers torch einops datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=8.0.0 (from datasets)
  Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.4 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets)
  Downloading multidict-6.0.4-c

In [60]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from einops import rearrange
import math
from datasets import load_dataset
import random
from tqdm import tqdm

In [3]:
token = input("Enter hf token: ")

In [4]:
token

'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'

In [5]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token)



In [6]:
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [7]:
class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.q_states, self.k_states, self.v_states = None, None, None
        
        self.attn.q_proj.register_forward_hook(self.save_q_states)
        self.attn.k_proj.register_forward_hook(self.save_k_states)
        self.attn.v_proj.register_forward_hook(self.save_v_states)

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        return output

    def rearrange_states(self, states, head_dim):
        return rearrange(states, 'b q (h d) -> b h q d', d=head_dim)

    def save_q_states(self, module, input, output):
        self.q_states = self.rearrange_states(output, self.attn.head_dim)

    def save_k_states(self, module, input, output):
        self.k_states = self.rearrange_states(output, self.attn.head_dim)

    def save_v_states(self, module, input, output):
        self.v_states = self.rearrange_states(output, self.attn.head_dim)

    def get_q_activations_at_position(self, head_no, token_position):
        return self.q_states[:, head_no, token_position, :]
    
    def get_k_activations_at_position(self, head_no, token_position):
        return self.k_states[:, head_no, token_position, :]
    
    def get_head_output_at_position(self, head_no, token_position):
        """
        Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
        """
        if self.q_states is None or self.k_states is None or self.v_states is None:
            raise ValueError("Q, K, V states have not been initialized or forward has not been called yet.")        
        cos, sin = self.attn.rotary_emb(self.v_states, seq_len=self.v_states.shape[-2])
        query_states, key_states = apply_rotary_pos_emb(self.q_states, self.k_states, cos, sin)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.attn.head_dim)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, self.v_states)
        return attn_output[:, head_no, token_position, :]

    def reset(self):
        self.q_states, self.k_states, self.v_states = None, None, None

In [8]:
class MLPWrapper(torch.nn.Module):
    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp
        self.saved_activations = None

    def forward(self, *args, **kwargs):
        output = self.mlp(*args, **kwargs)
        self.saved_activations = output.clone()
        return output

    def reset(self):
        self.saved_activations = None

In [9]:
class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.block.mlp = MLPWrapper(self.block.mlp)
        self.resid_acts = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.resid_acts = output[0]
        return output

    def get_mlp_acts(self):
        return self.block.mlp.saved_activations

    def get_attn(self, head, tok_pos):
        return self.block.self_attn.get_head_output_at_position(head, tok_pos)
    
    def get_attn_q(self, head, tok_pos):
        return self.block.self_attn.get_q_activations_at_position(head, tok_pos)
    
    def get_attn_k(self, head, tok_pos):
        return self.block.self_attn.get_k_activations_at_position(head, tok_pos)

    def reset(self):
        self.block.mlp.reset()
        self.block.self_attn.reset()
        self.resid_acts = None

In [10]:
class Llama7BHelper:
    def __init__(self, token):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=token
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=token
        ).to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(
                layer
            )

    def reset(self):
        for layer in self.model.model.layers:
            layer.reset()

    def get_mlp_acts(self, layer):
        return self.model.model.layers[layer].get_mlp_acts()

    def get_resid_acts(self, layer):
        return self.model.model.layers[layer].resid_acts

    def get_attn(self, layer, head, tok_pos):
        return self.model.model.layers[layer].get_attn(head, tok_pos)
    
    def get_attn_q(self, layer, head, tok_pos):
        return self.model.model.layers[layer].get_attn_q(head, tok_pos)
    
    def get_attn_k(self, layer, head, tok_pos):
        return self.model.model.layers[layer].get_attn_k(head, tok_pos)

    def get_logits(self, tokens):
        with torch.no_grad():
            logits = self.model(tokens).logits
            return logits

In [11]:
model = Llama7BHelper(token=token)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [24]:
num_rows = 1000  # replace with the number of rows you want
# Load the dataset with only the first `num_rows`
dataset = load_dataset("wikipedia", "20220301.en", split=f'train', streaming=True)

subset = []
for i, row in enumerate(dataset):
    subset.append(row)
    if i + 1 >= num_rows:
        break

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [64]:
def make_tokenized_dataset(dataset, token_length=100):
    tokenized_dataset = []
    for row in dataset:
        text = row["text"]
        tokens = tokenizer.encode(text, return_tensors="pt")
        # split into windows of length `token_length`
        for i in range(0, tokens.shape[1]-token_length, token_length):
            tokenized_dataset.append(tokens[0, i:i+token_length].numpy().tolist())
    return tokenized_dataset

In [65]:
tokenized_dataset = make_tokenized_dataset(subset)

In [66]:
len(tokenized_dataset)

59717

In [72]:
from glob import glob

In [73]:
len(list(glob("acts/*.pt")))

20002

In [75]:
def experiment(n: int, layer: int, head: int, pos_k: int, pos_q: int) -> None:
    all_data = random.sample(tokenized_dataset, n)
    keys = []
    queries = []
    for i, elem in tqdm(enumerate(all_data)):
        model.get_logits(torch.tensor([elem]).to(model.device))
        k_act = model.get_attn_k(layer, head, pos_k)[0].cpu()
        q_act = model.get_attn_q(layer, head, pos_q)[0].cpu()
        keys.append(k_act)
        queries.append(q_act)
    keys_all = torch.stack(keys)
    queries_all = torch.stack(queries)
    torch.save(keys_all, f'acts/k_act_l{layer}_h{head}_{pos_k}.pt')
    torch.save(queries_all, f'acts/q_act_l{layer}_h{head}_{pos_q}.pt')


In [78]:
experiment(10_000, 14, 0, 70, 75)

10000it [15:25, 10.81it/s]


In [12]:
# Test model
test_input = "The capital of France is a"
test_tokens = model.tokenizer(test_input, return_tensors="pt").input_ids.to(model.device)
test_logits = model.get_logits(test_tokens)
test_logits.shape

torch.Size([1, 7, 32000])

In [13]:
max_token_id = test_logits[0, -1, :].argmax()
decoded_token = model.tokenizer.decode(max_token_id)

In [14]:
decoded_token

'city'

In [15]:
out = model.get_attn(10, 0, -1)
out.shape

torch.Size([1, 128])

In [16]:
model.get_attn_k(10, 0, -1).shape

torch.Size([1, 128])

In [17]:
model.get_attn_q(10, 0, -1).shape

torch.Size([1, 128])