In [1]:
from transformer_lens import HookedTransformer
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from einops import einsum
import torch
import pandas as pd
import circuitsvis as cv

In [2]:
gpt2_small: HookedTransformer = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
gpt2_small.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'NTK_original_ctx_len': 8192,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': np.float64(8.0),
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': np.float64(0.02886751345948129),
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normali

In [4]:
# Running model with loss and logits

In [5]:
text = ["Hello world, this is somebody who has never touched a keyboard. What do you think?", "should i keep touching grass, or should i start coding? "]
# text --> batch or a string
logits = gpt2_small(text, return_type="logits")   
loss = gpt2_small(text, return_type="loss")   # prediction loss for the entire batch
# logits, loss = gpt2_small(text, return_type="both")   # return_type=None does not calculate the logits(faster way when we need only the intermediate activation) 

print(loss, logits.size())   # loss: floating, logits: [batch_size, no_tokens, d_vocab]

tensor(5.2622, device='cuda:0', grad_fn=<DivBackward0>) torch.Size([2, 19, 50257])


In [6]:
# Indexing weights

In [7]:
# weight lables: "https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/full-merm.svg"
input_ids = [2029, 220]

print(len(gpt2_small.blocks))   # no of blocks
print(gpt2_small.embed.forward(input_ids).size())   # returns token embeddings for the id [batch_size, d_model]  
print(gpt2_small.blocks[0].attn.W_Q.size())   # getting attention block weights [n_heads, d_model, d_head]
print(gpt2_small.blocks[11].mlp.W_in.size())  # MLP in transformation matrix [d_model, d_mlp]

# alternately gpt2_small.W_K, gpt2_small.W_out, etc work aswell

12
torch.Size([2, 768])
torch.Size([12, 768, 64])
torch.Size([768, 3072])


In [8]:
# Tokenization

In [9]:
# print(gpt2_small.tokenizer)   # get tokenizer details
print(gpt2_small.to_str_tokens(text))   # returns tokens
print(gpt2_small.to_tokens(text, prepend_bos=False))   # converts to token ids
print(gpt2_small.to_string([50256, 15496,   995,   428,   318]))   # converts to string

# use prepend_bos=False in methods like to_tokens, model.forward, etc to disable adding the "endoftext"
# token in the beginning.

[['<|endoftext|>', 'Hello', ' world', ',', ' this', ' is', ' somebody', ' who', ' has', ' never', ' touched', ' a', ' keyboard', '.', ' What', ' do', ' you', ' think', '?'], ['<|endoftext|>', 'should', ' i', ' keep', ' touching', ' grass', ',', ' or', ' should', ' i', ' start', ' coding', '?', ' ']]
tensor([[15496,   995,    11,   428,   318,  8276,   508,   468,  1239, 12615,
           257, 10586,    13,  1867,   466,   345,   892,    30],
        [21754,  1312,  1394, 15241,  8701,    11,   393,   815,  1312,   923,
         19617,    30,   220, 50256, 50256, 50256, 50256, 50256]],
       device='cuda:0')
<|endoftext|>Hello world this is


In [10]:
text = "Hello world, this is somebody who has never touched a keyboard. What do you think about it?"

In [11]:
logits: Tensor = gpt2_small(text, return_type="logits")
prediction = logits.argmax(dim=-1).squeeze()

tokenized_txt = gpt2_small.to_tokens(text).squeeze()

In [12]:
print(tokenized_txt, "\n", prediction)

tensor([50256, 15496,   995,    11,   428,   318,  8276,   508,   468,  1239,
        12615,   257, 10586,    13,  1867,   466,   345,   892,   546,   340,
           30], device='cuda:0') 
 tensor([198,  11,  11, 198, 318, 616, 508, 468, 587, 587, 257, 983, 878, 314,
        318, 345, 892,  30, 428,  30, 198], device='cuda:0')


In [13]:
# finding the no of correctly predicted tokens
correct = 0
for idx, pred in enumerate(prediction[:-1]):   # loops runs till the second last element
    if pred == tokenized_txt[idx+1]:           # idx+1 to skip first token, cuz its bos-token
        print(pred.item(), " --> ", gpt2_small.to_string(pred.item()))
        correct += 1
print("correct predictions: ", correct)

11  -->  ,
318  -->   is
508  -->   who
468  -->   has
257  -->   a
345  -->   you
892  -->   think
30  -->  ?
correct predictions:  8


In [14]:
logits, cache = gpt2_small.run_with_cache("Why did the chicken cross the", remove_batch_dim=True)

In [15]:
attn_pattrn_l0 = cache["pattern", 0]
print(attn_pattrn_l0.size())    # [n_heads, seq_len, seq_len] 

torch.Size([12, 7, 7])


In [16]:
# MANUALLY CALCULATING BLOCK-0 ATTENTION PATTERN

In [17]:
gpt2_small.pos_embed

PosEmbed()

In [18]:
input_ids = gpt2_small.to_tokens("Why did the chicken cross the")
l0_wQ = gpt2_small.blocks[0].attn.W_Q
l0_wK = gpt2_small.blocks[0].attn.W_K
l0_bQ = gpt2_small.blocks[0].attn.b_Q
l0_bK = gpt2_small.blocks[0].attn.b_K

embeddings = gpt2_small.embed(input_ids)[0] + gpt2_small.pos_embed(input_ids)[0]
Q_mat = einsum(embeddings, l0_wQ, "seq_len d_model, n_heads d_model d_head -> n_heads seq_len d_head")
Q_mat += l0_bQ.unsqueeze(1)
K_mat = einsum(embeddings, l0_wK, "seq_len d_model, n_heads d_model d_head -> n_heads seq_len d_head")
K_mat += l0_bK.unsqueeze(1)

QK_dot = einsum(Q_mat, K_mat, "n_heads seq_q d_head, n_heads seq_k d_head -> n_heads seq_q seq_k")

mask = torch.triu(torch.ones(input_ids.size()[-1], input_ids.size()[-1], device="cuda"), diagonal=1).bool()
masked_dot = QK_dot.masked_fill(mask, float('-inf'))/(Q_mat.size()[-1] ** 0.5)

attn_probs = torch.softmax(masked_dot, dim=-1)

In [19]:
torch.allclose(attn_pattrn_l0, attn_probs)   # konni yerri pu jeevithaalu anthey bro, enni moggalu veesina dengadhu

False

In [20]:
q, k = cache["q", 0], cache["k", 0]
seq, nhead, headsize = q.shape
layer0_attn_scores = einsum(q, k, "seqQ n h, seqK n h -> n seqQ seqK")
mask = torch.triu(torch.ones((seq, seq), dtype=bool), diagonal=1).to("cuda")
layer0_attn_scores.masked_fill_(mask, -1e9)
layer0_pattern_from_q_and_k = (layer0_attn_scores / headsize**0.5).softmax(-1)

In [21]:
torch.allclose(attn_pattrn_l0, layer0_pattern_from_q_and_k)    # donga puku, q & k tensors direct ga cache nunchi dengeysadu bro.

True

In [71]:
text = "Jack and Jill went up the hill. Jack returned back without"
# text = "Humans are huge until they stand next to an elephant"
logits, cache = gpt2_small.run_with_cache(text, remove_batch_dim=True)
attention_pattern = cache["pattern", 8]
gpt2_str_tokens = gpt2_small.to_str_tokens(text)

display(cv.attention.attention_patterns(
    tokens=gpt2_str_tokens, 
    attention=attention_pattern,
))

In [45]:
attention_pattern[0].round(decimals=3)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9770, 0.0230, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9850, 0.0020, 0.0130, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9640, 0.0010, 0.0000, 0.0350, 0.0000, 0.0000, 0.0000],
        [0.8330, 0.0020, 0.0000, 0.0060, 0.1590, 0.0000, 0.0000],
        [0.9330, 0.0000, 0.0000, 0.0040, 0.0010, 0.0620, 0.0000],
        [0.9510, 0.0010, 0.0000, 0.0280, 0.0010, 0.0000, 0.0190]],
       device='cuda:0')