In [1]:
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch
import gdown
from einops import einsum
import circuitsvis as cv

In [2]:
"""
- 2-layer model
- attention only
- no layer normalization and no biases
- positional embeddings are added to the query and key vectors inside each attention layer (not to the token embeddings),
  so no positional information enters the value vectors, and hence none enters the residual stream [see: shortformer]
"""

cfg = HookedTransformerConfig(
    # --- architecture ---
    n_layers=2,
    n_heads=12,
    d_head=64,
    d_model=768,
    attn_only=True,  # defaults to False
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    # --- attention and positions ---
    attention_dir="causal",
    positional_embedding_type="shortformer",
    n_ctx=2048,
    # --- vocabulary ---
    d_vocab=50278,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    # --- misc ---
    use_attn_result=True,
    seed=398,
)

In [6]:
# Directory containing the pretrained checkpoint file that the model-loading cell reads.
# NOTE(review): this is an absolute, machine-specific path; adjust it before running elsewhere.
weights_dir = "/home/happyuser/main/3m/jupyter/transformerlens-practice/learn-mech-interp/essentials"
# One-off download of the checkpoint from Google Drive via gdown (left disabled).
# url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
# output = str(weights_dir)
# gdown.download(url, output)

In [7]:
# Instantiate the architecture described by `cfg`, then overwrite its parameters
# with the checkpoint stored under `weights_dir`.
model = HookedTransformer(cfg)
# Map the checkpoint onto the device the config resolved to instead of a
# hard-coded "cuda", so the notebook also runs on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch supports it).
pretrained_weights = torch.load(weights_dir+"/essentialsewk26ptl.part", map_location=cfg.device)
# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(pretrained_weights)

<All keys matched successfully>

In [8]:
# Sample prompt; adjacent string literals are concatenated into one string.
text = (
    "We think that powerful, significantly superhuman machine intelligence "
    "is more likely than not to be created this century. "
    "If current machine learning techniques were scaled up to this level, "
    "we think they would by default produce systems that are deceptive or manipulative, "
    "and that no solid plans are known for how to avoid this."
)

# Single forward pass that also records intermediate activations; the batch
# dimension is stripped from the cached activations.
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

In [9]:
# Token strings used to label the axes of the attention visualisations.
str_tokens = model.to_str_tokens(text)
# Render one circuitsvis attention-pattern widget per layer, using the
# patterns cached for `text`.
for layer in range(model.cfg.n_layers):
    attention_pattern = cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))

In [10]:
# Cached attention patterns of layer index 1 (the second layer).
attention_pattern = cache["pattern", 1]
# First sub-diagonal of head 0: entry i is attention_pattern[0][i + 1, i],
# rounded to three decimals for readability.
print(torch.diag(attention_pattern[0], diagonal=-1).round(decimals=3))

tensor([0.7370, 0.1320, 0.0130, 0.0790, 0.0840, 0.0510, 0.3310, 0.0740, 0.0220,
        0.0080, 0.0670, 0.1480, 0.0310, 0.0660, 0.0470, 0.1030, 0.0330, 0.0380,
        0.0010, 0.0720, 0.0120, 0.0400, 0.0600, 0.0610, 0.0110, 0.0720, 0.1220,
        0.0320, 0.0010, 0.0010, 0.0190, 0.0330, 0.0040, 0.0380, 0.0280, 0.0050,
        0.0340, 0.0090, 0.0140, 0.0730, 0.0170, 0.0280, 0.0050, 0.0290, 0.1070,
        0.1040, 0.0160, 0.0140, 0.0160, 0.1560, 0.0340, 0.0140, 0.0030, 0.0700,
        0.0630, 0.0020, 0.2190, 0.0620, 0.0590, 0.0050, 0.0200],
       device='cuda:0')


In [11]:
def current_attn_detector(cache, thresh=0.5):
    """List the heads whose average main-diagonal attention exceeds `thresh`.

    Layer and head counts are read from the notebook-level `model`.
    `cache["pattern", layer]` is indexed directly by head, so the cache is
    assumed to hold patterns without a batch dimension (TODO confirm for
    other callers).

    Args:
        cache: activation cache supporting `cache["pattern", layer]`.
        thresh: score a head must exceed to be reported.

    Returns:
        List of "layer.head" strings, in layer-then-head order.
    """
    flagged = []
    for layer_idx in range(model.cfg.n_layers):
        patterns = cache["pattern", layer_idx]
        for head_idx in range(model.cfg.n_heads):
            pattern = patterns[head_idx]
            # Main diagonal = weight each position puts on itself.
            self_weights = torch.diag(pattern, diagonal=0)
            score = torch.sum(self_weights) / pattern.size()[-1]
            if score > thresh:
                flagged.append(f"{layer_idx}.{head_idx}")
    return flagged
    
def prev_attn_detector(cache, thresh=0.5):
    """List the heads whose average first-sub-diagonal attention exceeds `thresh`.

    The score of a head is the plain mean of `head_patt[i + 1, i]` over all i,
    i.e. only weights that were actually measured. (The earlier version dropped
    the first sub-diagonal entry, added a constant 1 in its place and divided
    n - 1 terms by n, which biased the score.)

    Layer and head counts are read from the notebook-level `model`.
    `cache["pattern", layer]` is indexed directly by head, so the cache is
    assumed to hold patterns without a batch dimension (TODO confirm for
    other callers).

    Args:
        cache: activation cache supporting `cache["pattern", layer]`.
        thresh: score a head must exceed to be reported.

    Returns:
        List of "layer.head" strings, in layer-then-head order.
    """
    res_list = []
    for layer in range(model.cfg.n_layers):
        layer_patts = cache["pattern", layer]
        for head in range(model.cfg.n_heads):
            head_patt = layer_patts[head]
            # Entry i of the first sub-diagonal is head_patt[i + 1, i].
            prev_token_attn = torch.diag(head_patt, diagonal=-1)
            if prev_token_attn.numel() == 0:
                # A one-token sequence has no previous-token entries to score.
                continue
            mean = prev_token_attn.mean()
            if mean > thresh:
                res_list.append(f"{layer}.{head}")
    return res_list
    

def first_attn_detector(cache, thresh=0.5):
    """List the heads whose average weight in column 0 exceeds `thresh`.

    Column 0 of a head's pattern is summed over all rows and divided by the
    pattern's last dimension. Presumably columns index the attended-to
    position, making this "attention to the first token" (confirm against the
    library's pattern layout).

    Layer and head counts are read from the notebook-level `model`.
    `cache["pattern", layer]` is indexed directly by head, so the cache is
    assumed to hold patterns without a batch dimension (TODO confirm for
    other callers).

    Args:
        cache: activation cache supporting `cache["pattern", layer]`.
        thresh: score a head must exceed to be reported.

    Returns:
        List of "layer.head" strings, in layer-then-head order.
    """
    flagged = []
    for layer_idx in range(model.cfg.n_layers):
        patterns = cache["pattern", layer_idx]
        for head_idx in range(model.cfg.n_heads):
            pattern = patterns[head_idx]
            first_column = pattern[:, 0]
            score = torch.sum(first_column) / pattern.size()[-1]
            if score > thresh:
                flagged.append(f"{layer_idx}.{head_idx}")
    return flagged


In [12]:
# Run each detector on the cached patterns with the same threshold and print
# one summary line per detector.
for label, detector in (
    ("Heads attending to current token  = ", current_attn_detector),
    ("Heads attending to previous token = ", prev_attn_detector),
    ("Heads attending to first token    = ", first_attn_detector),
):
    print(label, ", ".join(detector(cache, thresh=0.4)))

Heads attending to current token  =  0.9
Heads attending to previous token =  0.7
Heads attending to first token    =  0.3, 1.4, 1.10


In [95]:
# Searching for induction heads

In [16]:
"""
Property to note: a striking thing about models with induction heads is that, given a repeated sequence of
random tokens, they can predict the second (repeated) half of the sequence.

"""

# Let the model continue a prompt that contains the same string twice;
# the generated text is displayed as the cell's output.
model.generate("asdfghjkl asdfghjkl ")

  0%|          | 0/10 [00:00<?, ?it/s]

'asdfghjkl asdfghjkl ersk):asdfngduklwdf'

In [18]:
def generate_repeated_tokens(model, seq_len, batch=1):
    """Build prompts of the form [BOS, r_1..r_n, r_1..r_n] with random tokens r_i.

    The result is placed on the device recorded in `model.cfg.device` rather
    than on a hard-coded "cuda", so the helper also works on hosts without a GPU.

    Args:
        model: object exposing `tokenizer.bos_token_id`, `cfg.d_vocab` and
            `cfg.device`.
        seq_len: number of random tokens in each half (n above).
        batch: number of independent prompts to generate.

    Returns:
        int64 tensor of shape (batch, 2 * seq_len + 1).
    """
    # One BOS token in front of every prompt.
    prefix = torch.full((batch, 1), model.tokenizer.bos_token_id, dtype=torch.long)
    # Random token ids drawn uniformly from [0, d_vocab).
    rep_tokens_half = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    # The same half is appended twice to create the repetition.
    rep_tokens = torch.cat([prefix, rep_tokens_half, rep_tokens_half], dim=-1).to(model.cfg.device)
    return rep_tokens

def run_and_cache_model_repeated_tokens(model, seq_len, batch=1):
    """Generate repeated-token prompts and run `model` on them with caching.

    Args:
        model: model passed to `generate_repeated_tokens` and then run via
            `run_with_cache`.
        seq_len: length of the random half handed to `generate_repeated_tokens`.
        batch: number of prompts.

    Returns:
        Tuple `(tokens, logits, cache)`: the generated prompts and the two
        values returned by `model.run_with_cache(tokens)`.
    """
    tokens = generate_repeated_tokens(model, seq_len, batch)
    logits, activations = model.run_with_cache(tokens)
    return tokens, logits, activations


In [40]:
# Run the model on one prompt whose random half is 11 tokens long.
rep_tokens, rep_logits, rep_cache = run_and_cache_model_repeated_tokens(model, seq_len=11)

In [43]:
# Show the generated token ids next to their decoded text.
print(rep_tokens, model.to_string(rep_tokens))

tensor([[    0, 47092,  5480, 30791,  5351, 35279, 30281, 27427, 26889, 22688,
          2178, 30813, 47092,  5480, 30791,  5351, 35279, 30281, 27427, 26889,
         22688,  2178, 30813]], device='cuda:0') ['<|endoftext|>filed preparedumbentailsoste competenceitatingRM septoptpressedfiled preparedumbentailsoste competenceitatingRM septoptpressed']


In [106]:
def get_log_probs(logits, tokens):
    """Log-probability assigned to each actual next token of the first prompt.

    Only batch element 0 is used. The logits at position i are scored against
    the token at position i + 1, so the output has one entry fewer than the
    prompt has tokens.

    Args:
        logits: tensor indexed as `logits[0]` -> (position, vocab).
        tokens: tensor indexed as `tokens[0]` -> (position,).

    Returns:
        1-D tensor with the log-probability of every token after the first.
    """
    # The last position has no following token; the first token has no predictor.
    predictions = logits[0][:-1]
    targets = tokens[0][1:]
    per_position = torch.log_softmax(predictions, dim=-1)

    # Pick, for every position, the entry of the token that actually followed.
    positions = torch.arange(len(targets))
    return per_position[positions, targets]

In [108]:
log_probs = get_log_probs(rep_logits, rep_tokens)
# Assuming rep_tokens comes from the seq_len=11 run (BOS + 11 random tokens +
# the same 11 again), log_probs has 22 entries: the first 11 score the random
# first half and the last 11 score the repeated second half.
log_probs[:11].mean(), log_probs[11:].mean()

(tensor(-14.6901, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-3.8410, device='cuda:0', grad_fn=<MeanBackward0>))

In [111]:
"""
It is the second-layer heads previously observed attending to the first token that now produce this induction
behaviour on the twice-repeated random-token sequences.
"""

# Regenerate a single repeated-token prompt, this time with a 20-token random half.
seq_len = 20
batch = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch)
# Strip the batch dimension from the cached activations so that
# rep_cache["pattern", layer] can be indexed directly by head below.
rep_cache.remove_batch_dim()
# Token strings used to label the visualisations.
rep_str = model.to_str_tokens(rep_tokens)
model.reset_hooks()
# rep_logits still carries its batch dimension; get_log_probs reads element 0.
log_probs = get_log_probs(rep_logits, rep_tokens).squeeze()

# One attention-pattern widget per layer for the repeated-token prompt.
for layer in range(model.cfg.n_layers):
    attention_pattern = rep_cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=rep_str, attention=attention_pattern))

In [171]:
def induction_detector(cache, thresh=0.5):
    """List the heads whose induction-stripe attention averages above `thresh`.

    The half length is derived as `(n - 1) // 2` from the pattern size n, which
    matches prompts made of one prefix token followed by a sequence repeated
    twice. The score is the mean of the diagonal at offset `-(half_len - 1)`,
    i.e. entries `pattern[i + half_len - 1, i]`.

    Layer and head counts are read from the notebook-level `model`.
    `cache["pattern", layer]` is indexed directly by head, so the cache is
    assumed to hold patterns without a batch dimension (TODO confirm for
    other callers).

    Args:
        cache: activation cache supporting `cache["pattern", layer]`.
        thresh: score a head must exceed to be reported.

    Returns:
        List of "layer.head" strings, in layer-then-head order.
    """
    detected = []
    for layer_idx in range(model.cfg.n_layers):
        per_head = cache["pattern", layer_idx]
        for head_idx in range(model.cfg.n_heads):
            pattern = per_head[head_idx]
            half_len = (pattern.shape[-1] - 1) // 2
            stripe = torch.diag(pattern, diagonal=1 - half_len)
            if stripe.mean() > thresh:
                detected.append(f"{layer_idx}.{head_idx}")
    return detected

In [172]:
# Heads of the 2-layer model whose induction-stripe score exceeds 0.6 on the
# repeated-token prompt (rep_cache must already have its batch dimension removed).
induction_detector(rep_cache, thresh=0.6)

['1.4', '1.10']

In [20]:
# detect induction heads in gpt2-small

In [21]:
# Load the pretrained GPT-2 small weights into a HookedTransformer.
gpt2_small = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [27]:
# Repeated-token prompt (11-token random half) run through GPT-2 small; the
# returned cache keeps its batch dimension.
gpt2rep_tokens, gpt2rep_logits, gpt2rep_cache = run_and_cache_model_repeated_tokens(gpt2_small, 11)

In [31]:
def show_and_detect_induction(cache, tokens, thresh=0.6, show_head=False):
    """Detect induction-like heads in a batched cache and optionally render them.

    Layer and head counts are read from the notebook-level `gpt2_small`. Only
    batch element 0 of each layer's pattern is inspected. The score is the mean
    of the diagonal at offset `-(half_len - 1)`, with `half_len = (n - 1) // 2`
    for a pattern of size n (prefix token + sequence repeated twice).

    Args:
        cache: activation cache supporting `cache["pattern", layer]`, with a
            leading batch dimension.
        tokens: token labels passed to the circuitsvis widget.
        thresh: score a head must exceed to be reported.
        show_head: when True, display the pattern of every reported head.

    Returns:
        List of "layer.head" strings, in layer-then-head order.
    """
    detected = []
    for layer_idx in range(gpt2_small.cfg.n_layers):
        per_head = cache["pattern", layer_idx][0]
        for head_idx in range(gpt2_small.cfg.n_heads):
            pattern = per_head[head_idx]
            half_len = (pattern.shape[-1] - 1) // 2
            score = torch.diag(pattern, diagonal=1 - half_len).mean()
            if not score > thresh:
                continue
            detected.append(f"{layer_idx}.{head_idx}")
            if show_head:
                widget = cv.attention.attention_pattern(tokens=tokens, attention=pattern)
                display(widget, metadata={"width": "100px", "height": "100px"})
    return detected

In [33]:
# Token strings used to label the optional per-head visualisations.
gpt2rep_str = gpt2_small.to_str_tokens(gpt2rep_tokens)
show_and_detect_induction(gpt2rep_cache, tokens=gpt2rep_str, show_head=False)   # pass show_head=True to also render each detected head

['5.1', '5.5', '6.9', '7.2', '7.10', '8.1', '9.6', '10.1', '10.7']