In [10]:
!pip install transformers torch einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
[0m

In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from einops import rearrange
import math

In [4]:
token = input("Enter hf token: ")

In [5]:
token

'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'

In [8]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token)



tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [12]:
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [13]:
class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.q_states, self.k_states, self.v_states = None, None, None
        
        self.attn.q_proj.register_forward_hook(self.save_q_states)
        self.attn.k_proj.register_forward_hook(self.save_k_states)
        self.attn.v_proj.register_forward_hook(self.save_v_states)

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        return output

    def rearrange_states(self, states, head_dim):
        return rearrange(states, 'b q (h d) -> b h q d', d=head_dim)

    def save_q_states(self, module, input, output):
        self.q_states = self.rearrange_states(output, self.attn.head_dim)

    def save_k_states(self, module, input, output):
        self.k_states = self.rearrange_states(output, self.attn.head_dim)

    def save_v_states(self, module, input, output):
        self.v_states = self.rearrange_states(output, self.attn.head_dim)

    def get_head_output_at_position(self, head_no, token_position):
        """
        Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
        """
        if self.q_states is None or self.k_states is None or self.v_states is None:
            raise ValueError("Q, K, V states have not been initialized or forward has not been called yet.")        
        cos, sin = self.attn.rotary_emb(self.v_states, seq_len=self.v_states.shape[-2])
        query_states, key_states = apply_rotary_pos_emb(self.q_states, self.k_states, cos, sin)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.attn.head_dim)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, self.v_states)
        return attn_output[:, head_no, token_position, :]

    def reset(self):
        self.q_states, self.k_states, self.v_states = None, None, None

In [14]:
class MLPWrapper(torch.nn.Module):
    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp
        self.saved_activations = None

    def forward(self, *args, **kwargs):
        output = self.mlp(*args, **kwargs)
        self.saved_activations = output.clone()
        return output

    def reset(self):
        self.saved_activations = None

In [15]:
class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.block.mlp = MLPWrapper(self.block.mlp)
        self.resid_acts = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.resid_acts = output[0]
        return output

    def get_mlp_acts(self):
        return self.block.mlp.saved_activations

    def get_attn(self, head, tok_pos):
        return self.block.self_attn.get_head_output_at_position(head, tok_pos)

    def reset(self):
        self.block.mlp.reset()
        self.block.self_attn.reset()
        self.resid_acts = None

In [16]:
class Llama7BHelper:
    def __init__(self, token):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=token
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=token
        ).to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(
                layer
            )

    def reset(self):
        for layer in self.model.model.layers:
            layer.reset()

    def get_mlp_acts(self, layer):
        return self.model.model.layers[layer].get_mlp_acts()

    def get_resid_acts(self, layer):
        return self.model.model.layers[layer].resid_acts

    def get_attn(self, layer, head, tok_pos):
        return self.model.model.layers[layer].get_attn(head, tok_pos)

    def get_logits(self, tokens):
        with torch.no_grad():
            logits = self.model(tokens).logits
            return logits

In [17]:
model = Llama7BHelper(token=token)



config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [22]:
# Test model
test_input = "The capital of France is a"
test_tokens = model.tokenizer(test_input, return_tensors="pt").input_ids.to(model.device)
test_logits = model.get_logits(test_tokens)
test_logits.shape

torch.Size([1, 7, 32000])

In [23]:
max_token_id = test_logits[0, -1, :].argmax()
decoded_token = model.tokenizer.decode(max_token_id)

In [24]:
decoded_token

'city'