# Testing logit attribution function
This notebook computes the logit attributions of a few head and token combinations "by hand" and compares the results to the batched implementation, `func.logit_attribution`.

In [1]:
import itertools
import random

import torch

from transformers import AutoTokenizer

from gptomics import functional as func
from gptomics import transformersio, model

In [2]:
# Short test prompt; it tokenizes to 10 tokens here (see the hidden-state shape printed further down).
small_prompt = "The quick brown fox jumped over the lazy dog."

In [24]:
from importlib import reload

# Loading the model

In [3]:
# Single source of truth for the checkpoint so the model and tokenizer cannot drift apart.
MODEL_NAME = "EleutherAI/gpt-neo-125M"
# gptomics model wrapper; the underlying module is reached via m.model below.
m = model.model_by_name(MODEL_NAME)
# Tokenizer for the same checkpoint; run_model reads it as the global `t`.
t = AutoTokenizer.from_pretrained(MODEL_NAME)

In [4]:
# (number of attention heads, number of layers) from the loaded model's config
m.model.config.num_heads, m.model.config.num_layers

(12, 12)

# Running the batch function

In [5]:
# Reference ("batch") result that the manual computations below are compared against.
# It is indexed later as attrs[block, head, dst, src].
# NOTE(review): the meaning of the positional `False` is not visible in this notebook;
# confirm against func.logit_attribution's signature.
%time attrs, tokens = func.logit_attribution("EleutherAI/gpt-neo-125M", small_prompt, False)

CPU times: user 1.83 s, sys: 331 ms, total: 2.16 s
Wall time: 2.64 s


# Picking a couple heads and token combinations

In [6]:
# `random` is imported in the notebook's top import cell.
# Fixed seed so the same two indices are drawn from each range on every run.
random.seed(203834)
head_inds = random.sample(range(m.model.config.num_heads), 2)
layer_inds = random.sample(range(m.model.config.num_layers), 2)

In [7]:
# NOTE(review): each tuple here is (head_inds[i], layer_inds[i]), but the sections
# below use the first element as the block index and the second as the head index
# (e.g. b = 6, h = 11). This is harmless for this model because num_heads and
# num_layers are both 12, but the names are swapped relative to how they are used.
list(zip(head_inds, layer_inds))

[(6, 11), (7, 0)]

# Extracting the hidden vectors and attention matrix at the right place

In [8]:
# Capture buffers filled by `hook`; one entry is appended per hooked forward call.
hidden_vectors, attn_matrices = [], []


def hook(attn_layer, hidden_states, output):
    """Forward hook that records what the hooked attention module saw and produced.

    Stores the first positional input of the module (`hidden_states[0]`) in
    `hidden_vectors` and the third element of the module's output
    (`output[2]`) in `attn_matrices`. Returns nothing, so the module's
    output is passed through unchanged.
    """
    first_input, attention = hidden_states[0], output[2]
    hidden_vectors.append(first_input)
    attn_matrices.append(attention)

In [9]:
def run_model(m, hook, block, prompt, tokenizer=None):
    """Run `prompt` through the model with `hook` attached to one attention module.

    Args:
        m: model wrapper whose underlying module is `m.model`.
        hook: forward hook registered on
            `m.model.transformer.h[block].attn.attention` for this call only.
        block: index of the transformer block to hook.
        prompt: text to tokenize and feed to the model.
        tokenizer: optional tokenizer; defaults to the notebook-level `t`.

    Returns:
        The 1-D tensor of input token ids for the prompt.
    """
    tok = t if tokenizer is None else tokenizer
    # Tokenize before registering so a tokenizer failure cannot leave a hook behind.
    input_ids = tok(prompt, return_tensors="pt").input_ids

    handle = m.model.transformer.h[block].attn.attention.register_forward_hook(hook)
    try:
        with torch.no_grad():
            m.model(input_ids, output_attentions=True)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise it
        # would keep firing (and appending captures) on every later call.
        handle.remove()

    return input_ids[0]

# Block 6, head 11

In [10]:
# Start from empty capture lists so that re-running this cell cannot leave stale
# entries behind: the cells below read the block-6 capture from index 0.
hidden_vectors.clear()
attn_matrices.clear()
token_ids = run_model(m, hook, 6, small_prompt)

In [11]:
# `itertools` and `random` are imported in the notebook's top import cell.
# All (i, j) token-position pairs with i <= j over the 10 prompt tokens.
# combinations_with_replacement yields them in the same lexicographic order as
# filtering itertools.product by i <= j, so the seeded sample below is unchanged.
N_TOKENS = 10
token_pairs = list(itertools.combinations_with_replacement(range(N_TOKENS), 2))
random.seed(3894702937)
random.sample(token_pairs, 2)

[(9, 9), (7, 8)]

## Token 9 -> 9

In [103]:
# Indices for this check: transformer block and attention head, then the
# source and destination token positions.
b, h = 6, 11
src, dst = 9, 9

In [104]:
# The block-6 capture is the first entry the hook stored; also keep a handle on
# that block's attention module for the projections used below.
hvs, att = hidden_vectors[0], attn_matrices[0]
att_layer = m.model.transformer.h[b].attn.attention

In [105]:
# Sanity-check the captured shapes: the recorded output shows (1, 10, 768) for the
# hidden states and (1, 12, 10, 10) for the attention weights.
hvs.shape, att.shape

(torch.Size([1, 10, 768]), torch.Size([1, 12, 10, 10]))

In [106]:
# Project the hidden states to values and split them per head, then pick the
# vector for head h at the source token.
values_by_head = att_layer._split_heads(
    att_layer.v_proj(hvs), att_layer.num_heads, att_layer.head_dim
)
v = values_by_head[0, h, src]
# Scale by the attention weight from src to dst
v_ = att[0, h, dst, src] * v

In [107]:
# Columns of the output projection that read head h's slice of the concatenated
# head outputs. The width comes from the layer's head_dim instead of a hard-coded 64.
head_cols = slice(h * att_layer.head_dim, (h + 1) * att_layer.head_dim)
o = att_layer.out_proj.weight[:, head_cols]
# Contribution of this (head, src -> dst) term after the output projection
r = o @ v_

In [108]:
# Row of the LM head weight for the token id found at position dst, dotted with r
# to give the manual attribution value.
unembed = m.model.lm_head.weight[token_ids[dst]]
unembed @ r

tensor(0.1104)

In [109]:
# Batch result. attrs is indexed [block, head, dst, src] in every other comparison
# in this notebook; src == dst here, so the value is the same, but the order is
# kept consistent so the cell stays correct if the indices are changed.
attrs[b, h, dst, src]

tensor(0.1104)

## Token 7 -> 8

In [100]:
# Indices for this check: transformer block and attention head, then the
# source and destination token positions.
b, h = 6, 11
src, dst = 7, 8

In [101]:
# Captured hidden states and attention weights for block 6 (the first entry
# stored by the hook).
hvs = hidden_vectors[0]
att = attn_matrices[0]

# Attention module of block b; its head_dim replaces the hard-coded 64.
att_layer = m.model.transformer.h[b].attn.attention
head_cols = slice(h * att_layer.head_dim, (h + 1) * att_layer.head_dim)

# Form the value vector for the source token, restricted to head h
v = att_layer.v_proj(hvs)[0, src, head_cols]
# Weight by the attention weight from src to dst
v_ = v * att[0, h, dst, src]

# Fetching the relevant columns of the output matrix
o = att_layer.out_proj.weight[:, head_cols]
r = o @ v_
# Dot with the LM head row of the token id found at position dst
unembed = m.model.lm_head.weight[token_ids[dst]]
unembed @ r

tensor(0.0070)

In [102]:
# Batch result for the same term; attrs is indexed [block, head, dst, src].
# It should match the manual value computed in the previous cell.
attrs[b, h, dst, src]

tensor(0.0070)

# Block 7, head 0

In [75]:
# Keep only the block-6 capture (index 0) so that the block-7 capture lands at
# index 1 even if this cell is re-run.
del hidden_vectors[1:], attn_matrices[1:]
token_ids = run_model(m, hook, 7, small_prompt)

In [76]:
# token_pairs (all position pairs with i <= j) was already built in the block-6
# section and the prompt has not changed, so only reseed and draw a fresh sample.
random.seed(2397846)
random.sample(token_pairs, 2)

[(1, 6), (2, 8)]

## Token 1 -> 6

In [97]:
# Indices for this check: transformer block and attention head, then the
# source and destination token positions.
b, h = 7, 0
src, dst = 1, 6

In [98]:
# Captured hidden states and attention weights for block 7 (the second entry
# stored by the hook).
hvs = hidden_vectors[1]
att = attn_matrices[1]

# Attention module of block b; its head_dim replaces the hard-coded 64.
att_layer = m.model.transformer.h[b].attn.attention
head_cols = slice(h * att_layer.head_dim, (h + 1) * att_layer.head_dim)

# Form the value vector for the source token, restricted to head h
v = att_layer.v_proj(hvs)[0, src, head_cols]
# Weight by the attention weight from src to dst
v_ = v * att[0, h, dst, src]

# Fetching the relevant columns of the output matrix
o = att_layer.out_proj.weight[:, head_cols]
r = o @ v_
# Dot with the LM head row of the token id found at position dst
unembed = m.model.lm_head.weight[token_ids[dst]]
unembed @ r

tensor(0.1856)

In [99]:
# Batch result for the same term; attrs is indexed [block, head, dst, src].
# It should match the manual value computed in the previous cell.
attrs[b, h, dst, src]

tensor(0.1856)

## Token 2 -> 8

In [110]:
# Indices for this check: transformer block and attention head, then the
# source and destination token positions.
b, h = 7, 0
src, dst = 2, 8

In [111]:
# Captured hidden states and attention weights for block 7 (the second entry
# stored by the hook).
hvs = hidden_vectors[1]
att = attn_matrices[1]

# Attention module of block b; its head_dim replaces the hard-coded 64.
att_layer = m.model.transformer.h[b].attn.attention
head_cols = slice(h * att_layer.head_dim, (h + 1) * att_layer.head_dim)

# Form the value vector for the source token, restricted to head h
v = att_layer.v_proj(hvs)[0, src, head_cols]
# Weight by the attention weight from src to dst
v_ = v * att[0, h, dst, src]

# Fetching the relevant columns of the output matrix
o = att_layer.out_proj.weight[:, head_cols]
r = o @ v_
# Dot with the LM head row of the token id found at position dst
unembed = m.model.lm_head.weight[token_ids[dst]]
unembed @ r

tensor(-0.2034)

In [112]:
# Batch result for the same term; attrs is indexed [block, head, dst, src].
# It should match the manual value computed in the previous cell.
attrs[b, h, dst, src]

tensor(-0.2034)