<a href="https://colab.research.google.com/github/mahadikprasad15/ARENA/blob/main/Pythia_70M_Induction_Circuits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install transformer_lens

In [None]:
import torch
import transformer_lens
import plotly.express as px
from transformer_lens import utils


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
model = transformer_lens.HookedTransformer.from_pretrained('pythia-70m', device = device)
model.eval()

In [None]:
model.to_str_tokens('This is a test example, to see the tokenization')

In [None]:
print(f'Number of layers: {model.cfg.n_layers}')
print(f'Number of heads: {model.cfg.n_heads}')
print(f'Model residual stream dimension: {model.cfg.d_model}')
print(f'Model vocab-size: {model.cfg.d_vocab}')
print(f'Dimension of the heads: {model.cfg.d_head}')

# Generating induction prompts

First, I need a function that generates induction prompts: a random sequence of `seq_length` tokens, repeated twice, with a BOS token prepended. It takes `batch` and `seq_length` as arguments, so it can produce prompts of any size and batch to test on.


In [None]:
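def generate_induction_prompts(batch = 1, seq_length = 20):
  # Random tokens, skipping id 0 (Pythia's BOS / <|endoftext|> token)
  tokens = torch.randint(1, model.cfg.d_vocab, (batch, seq_length), dtype = torch.long, device = device)
  # One BOS token at the start of every sequence
  BOS = torch.full((batch, 1), model.tokenizer.bos_token_id, dtype = torch.long, device = device)
  # Layout: [BOS, t_1 ... t_n, t_1 ... t_n] -> shape (batch, 2 * seq_length + 1)
  return torch.cat([BOS, tokens, tokens], dim = -1)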
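A quick sanity check on the prompt layout, using a stand-alone copy of the same construction (`make_repeated_tokens` is a hypothetical helper with a tiny vocabulary and no model, purely for illustration). Position 0 is BOS, the first copy occupies positions `1..seq_len` and the second copy `seq_len+1..2*seq_len`. So a token in the second copy repeats the token `seq_len` positions back, and the token that followed that earlier occurrence sits `seq_len - 1` positions back: that is where an induction head should attend.

```python
import torch

def make_repeated_tokens(batch: int, seq_len: int, d_vocab: int, bos_id: int = 0) -> torch.Tensor:
    """Hypothetical helper with the same layout: [BOS, t_1..t_n, t_1..t_n]."""
    tokens = torch.randint(1, d_vocab, (batch, seq_len), dtype=torch.long)
    bos = torch.full((batch, 1), bos_id, dtype=torch.long)
    return torch.cat([bos, tokens, tokens], dim=-1)

torch.manual_seed(0)
seq_len = 6
prompt = make_repeated_tokens(batch=2, seq_len=seq_len, d_vocab=100)

# Walk over every query position in the second copy
for q in range(seq_len + 1, 2 * seq_len + 1):
    # Duplicate-token offset: the same token appeared seq_len positions earlier
    assert torch.equal(prompt[:, q], prompt[:, q - seq_len])
    # Induction offset: key q - (seq_len - 1) holds the token that followed the
    # earlier occurrence, i.e. the correct next token at q + 1 (when it exists)
    if q + 1 <= 2 * seq_len:
        assert torch.equal(prompt[:, q - (seq_len - 1)], prompt[:, q + 1])

print(prompt.shape)  # torch.Size([2, 13])
```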

Now that we have the function, I want a first look at the attention patterns of all heads, to see which ones show a distinctive induction pattern.
I also want to look for previous-token heads.

In [None]:
layers = model.cfg.n_layers
heads = model.cfg.n_heads
prompt_tokens = generate_induction_prompts(5,20)
logits, cache = model.run_with_cache(prompt_tokens)


for layer in range(layers):
    # Pattern shape: [batch, n_heads, query_pos, key_pos]; average over the batch, one facet per head
    attention_pattern = cache[utils.get_act_name('pattern', layer)].cpu().numpy()
    fig = px.imshow(attention_pattern.mean(axis=0),
                    facet_col=0,
                    title=f'Attention Patterns: Layer {layer}',
                    labels={'x': 'Key Token', 'y': 'Query Token'},
                    color_continuous_scale='viridis',
                    width=3000,
                    height=2500
                   )
    fig.show()

From visual inspection, **several heads show a strong induction pattern, and a few look like previous-token heads**:

#### **Induction heads**
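* Layer 3 heads 1, 2, 6, 7
* Layer 2 head 6
* Layer 0 heads 3, 5
#### **Previous Token heads**
* Layer 2 head 1
* Layer 1 heads 1 and 2 also look like previous-token heads
* Layer 0 head 3 looks very interesting, and different from the others.

Worth double-checking: a layer-0 head cannot implement induction by itself, because each key in layer 0 carries only its own token's embedding, not information about the token before it. A layer-0 stripe next to the induction diagonal is more likely a duplicate-token pattern (attending `seq_len` positions back instead of `seq_len - 1`).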



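To pick out the strongest heads quantitatively rather than by eye, I compute a score for every head in every layer and look for the largest values. The induction score measures attention on the diagonal at offset `-(seq_len - 1)`, i.e. from each token in the second copy to the token that followed its earlier occurrence.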



In [None]:
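def induction_score(attention_pattern, prompt_tokens):
  # Prompt layout: [BOS, A, A] with len(A) = seq_len
  seq_len = prompt_tokens.size(1) // 2
  offset = -(seq_len - 1)
  # Each query in the second copy should attend seq_len - 1 positions back.
  # Keep only those seq_len entries of the stripe (the earlier ones belong to
  # first-copy queries, one of which sits on BOS) and average: a score in [0, 1].
  return attention_pattern.diagonal(offset = offset)[-seq_len:].mean()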



scores = torch.zeros(layers, heads, device = device)


for layer in range(layers):
  attention_pattern_layer = cache[utils.get_act_name('pattern', layer)].mean(dim=0)
  for head in range(heads):

    attention_pattern_head = attention_pattern_layer[head, :, :]
    score = induction_score(attention_pattern_head, prompt_tokens=prompt_tokens)
    scores[layer, head] = score
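To check that a diagonal-based score measures what it claims, here is a small sketch on synthetic data (no model involved). `ideal_induction_pattern`, `ideal_prev_token_pattern`, `induction_stripe_score` and `prev_token_stripe_score` are hypothetical helpers that build idealized attention patterns on a `[BOS, A, A]` prompt and score them by averaging the relevant diagonal, keeping only the second-copy entries for the induction stripe. A perfect induction head should score 1, and a perfect previous-token head should score 0 on the induction stripe.

```python
import torch

def ideal_induction_pattern(seq_len: int) -> torch.Tensor:
    """Hypothetical perfect induction head: second-copy queries attend to the token
    after the earlier occurrence; every other query attends to BOS."""
    n = 2 * seq_len + 1
    pattern = torch.zeros(n, n)
    pattern[:, 0] = 1.0
    for q in range(seq_len + 1, n):
        pattern[q, 0] = 0.0
        pattern[q, q - (seq_len - 1)] = 1.0
    return pattern

def ideal_prev_token_pattern(seq_len: int) -> torch.Tensor:
    """Hypothetical perfect previous-token head: each query attends to position q - 1."""
    n = 2 * seq_len + 1
    pattern = torch.zeros(n, n)
    pattern[0, 0] = 1.0
    for q in range(1, n):
        pattern[q, q - 1] = 1.0
    return pattern

def induction_stripe_score(pattern: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Diagonal at offset -(seq_len - 1); keep the seq_len entries whose query
    # lies in the second copy, then average
    stripe = pattern.diagonal(offset=-(seq_len - 1), dim1=-2, dim2=-1)
    return stripe[..., -seq_len:].mean(dim=-1)

def prev_token_stripe_score(pattern: torch.Tensor) -> torch.Tensor:
    # Average attention from every position to the one right before it
    return pattern.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)

seq_len = 6
ind, prev = ideal_induction_pattern(seq_len), ideal_prev_token_pattern(seq_len)
print(induction_stripe_score(ind, seq_len).item())   # 1.0
print(induction_stripe_score(prev, seq_len).item())  # 0.0
print(prev_token_stripe_score(prev).item())          # 1.0
```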

In [None]:
fig = px.imshow(scores.cpu().numpy(),
          title = 'Induction scores for all Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'dense',
          text_auto = '.2f'
            )

fig.show()
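The heatmap shows the largest scores, but it can be handy to list them explicitly. A minimal sketch, assuming a `[n_layers, n_heads]` score tensor like `scores` (a random stand-in with 6 layers and 8 heads, as in pythia-70m, is used so it runs on its own; `top_heads` is a hypothetical helper): flatten, take `torch.topk`, and map each flat index back to `(layer, head)`.

```python
import torch

def top_heads(scores: torch.Tensor, k: int = 5) -> list[tuple[int, int, float]]:
    """Return the k highest-scoring heads as (layer, head, score), best first."""
    n_heads = scores.size(1)
    values, flat_idx = torch.topk(scores.flatten(), k)
    # Row-major flattening: flat index = layer * n_heads + head
    return [(int(i) // n_heads, int(i) % n_heads, float(v)) for v, i in zip(values, flat_idx)]

# Stand-in for the real scores (6 layers x 8 heads)
torch.manual_seed(0)
fake_scores = torch.rand(6, 8)
fake_scores[3, 6] = 2.0  # plant an obviously strongest head
for layer, head, score in top_heads(fake_scores, k=3):
    print(f'L{layer}H{head}: {score:.2f}')
```

With the real tensor, `top_heads(scores.cpu(), k=5)` would list the top five heads.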

## Getting top previous token heads

In [None]:
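def previous_token_score(attention_pattern, prompt_tokens):
  # Average attention from each position to the position immediately before it
  # (prompt_tokens is unused; kept so both score functions share a signature)
  return attention_pattern.diagonal(offset = -1).mean()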


scores = torch.zeros(layers, heads, device = device)

for layer in range(layers):
  attention_pattern_layer = cache[utils.get_act_name('pattern', layer)].mean(dim=0)
  for head in range(heads):

    attention_pattern_head = attention_pattern_layer[head, :, :]
    score = previous_token_score(attention_pattern_head, prompt_tokens=prompt_tokens)
    scores[layer, head] = score
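The nested loops can also be written as one vectorized operation, since `Tensor.diagonal` accepts `dim1`/`dim2` and works on batched tensors. A sketch assuming the per-layer patterns are stacked into a single `[n_layers, batch, n_heads, q, k]` tensor; here a random causal stand-in is used, not real activations. With the real cache one could build it as `torch.stack([cache[utils.get_act_name('pattern', l)] for l in range(layers)])`. `stripe_scores` is a hypothetical helper.

```python
from typing import Optional

import torch

def stripe_scores(patterns: torch.Tensor, offset: int, n_last: Optional[int] = None) -> torch.Tensor:
    """Mean attention on the diagonal at `offset` for every layer and head.

    patterns: [n_layers, batch, n_heads, q, k]; returns [n_layers, n_heads].
    n_last optionally keeps only the last n_last entries of each diagonal.
    """
    stripe = patterns.diagonal(offset=offset, dim1=-2, dim2=-1)  # [L, B, H, diag_len]
    if n_last is not None:
        stripe = stripe[..., -n_last:]
    return stripe.mean(dim=(1, -1))  # average over batch and diagonal positions

# Random causal attention patterns as a stand-in for the cached ones
torch.manual_seed(0)
n_layers, batch, n_heads, seq_len = 6, 5, 8, 20
n = 2 * seq_len + 1
attn_logits = torch.randn(n_layers, batch, n_heads, n, n)
causal = torch.ones(n, n, dtype=torch.bool).tril()
patterns = attn_logits.masked_fill(~causal, float('-inf')).softmax(dim=-1)

induction = stripe_scores(patterns, offset=-(seq_len - 1), n_last=seq_len)
prev_token = stripe_scores(patterns, offset=-1)
print(induction.shape, prev_token.shape)  # torch.Size([6, 8]) torch.Size([6, 8])
```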

In [None]:
fig = px.imshow(scores.cpu().numpy(),
          title = 'Previous token scores for all Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'dense',
          text_auto = '.2f'
            )

fig.show()

In [None]:
accumulated_resid = cache.accumulated_resid(layer = -1, apply_ln = True)


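Next, I use `accumulated_resid` (and later `stack_head_results`) to get the residual stream after each layer and the output of each head, with the final layer norm scaling applied, and multiply by `W_U` to read off the logit each one assigns to the correct next token. Over layers this is a logit lens; over heads it is direct logit attribution.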

But this will be

In [None]:
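seq_len = prompt_tokens.size(1) // 2

# Logit lens: (accumulated_resid @ model.W_U) is [n_layers + 1, batch, pos, d_vocab].
# Keep query positions seq_len+1 .. 2*seq_len-1: they lie in the second copy and
# each one should predict the next token of the repeated sequence via induction.
accumulated_resid_second_half = (accumulated_resid @ model.W_U)[:, :, seq_len + 1:-1]
target_token_indices = prompt_tokens[:, seq_len + 2:].to(accumulated_resid_second_half.device)

# Add a component dim in front and a vocab dim at the end, broadcast over components
target_token_indices_reshaped = target_token_indices[None, :, :, None].expand(
    accumulated_resid_second_half.size(0), -1, -1, 1)

# Logit of the correct next token: [n_layers + 1, batch, seq_len - 1]
logits_for_target_tokens = accumulated_resid_second_half.gather(
    dim = -1, index = target_token_indices_reshaped).squeeze(-1)

# Average over batch and positions: one value per layer
logits_per_layer = logits_for_target_tokens.mean(dim = (-1, -2))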
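The `gather` step is the easiest place for an off-by-one error: the logits at position `t` must be scored against the token at `t + 1`. A toy check with random logits (hypothetical `target_logits` helper; shapes mirror `[component, batch, pos, d_vocab]`, and `start = seq_len + 1` would match the induction setting) confirms that gathering with the shifted tokens equals indexing each position explicitly.

```python
import torch

def target_logits(logits: torch.Tensor, tokens: torch.Tensor, start: int) -> torch.Tensor:
    """Logit each component assigns to the true next token.

    logits: [n_components, batch, pos, d_vocab]; tokens: [batch, pos].
    Uses query positions start .. pos-2 and their targets start+1 .. pos-1.
    Returns [n_components, batch, pos - 1 - start].
    """
    preds = logits[:, :, start:-1]   # predictions made at each query position
    targets = tokens[:, start + 1:]  # the token that actually comes next
    index = targets[None, :, :, None].expand(preds.size(0), -1, -1, 1)
    return preds.gather(dim=-1, index=index).squeeze(-1)

torch.manual_seed(0)
n_comp, batch, pos, d_vocab = 3, 2, 9, 11
logits = torch.randn(n_comp, batch, pos, d_vocab)
tokens = torch.randint(0, d_vocab, (batch, pos))
out = target_logits(logits, tokens, start=4)

# Explicit loop for comparison
expected = torch.empty_like(out)
for c in range(n_comp):
    for b in range(batch):
        for j, t in enumerate(range(4, pos - 1)):
            expected[c, b, j] = logits[c, b, t, tokens[b, t + 1]]
print(out.shape, torch.equal(out, expected))  # torch.Size([3, 2, 4]) True
```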

In [None]:
import pandas as pd

df = pd.DataFrame({
    # Index 0 is the embedding; index i is the residual stream after block i - 1
    'Layer': range(len(logits_per_layer)),
    'Average Logit': logits_per_layer.detach().cpu().numpy()
})

# Create the line plot
fig = px.line(df,
              x='Layer',
              y='Average Logit',
              title='Overall Average Logit for Target Tokens Across Layers')

fig.show()

In [None]:
per_head_resid = cache.stack_head_results( layer = -1, apply_ln=True)

In [None]:
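seq_len = prompt_tokens.size(1) // 2

# Direct logit attribution per head: (per_head_resid @ model.W_U) is
# [n_layers * n_heads, batch, pos, d_vocab]. Keep query positions
# seq_len+1 .. 2*seq_len-1, whose next token is predictable by induction.
per_head_resid_second_half = (per_head_resid @ model.W_U)[:, :, seq_len + 1:-1]
target_token_indices = prompt_tokens[:, seq_len + 2:].to(per_head_resid_second_half.device)

# Add a component dim in front and a vocab dim at the end, broadcast over components
target_token_indices_reshaped = target_token_indices[None, :, :, None].expand(
    per_head_resid_second_half.size(0), -1, -1, 1)

# Logit of the correct next token for every head: [n_layers * n_heads, batch, seq_len - 1]
logits_for_target_tokens = per_head_resid_second_half.gather(
    dim = -1, index = target_token_indices_reshaped).squeeze(-1)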
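Multiplying every head's output by the full `W_U` materializes a `[n_layers * n_heads, batch, pos, d_vocab]` tensor, which is large for a vocabulary of roughly 50k tokens. Since only the target-token logits are needed, an equivalent and much cheaper route is to select the target columns of `W_U` first and take dot products. A sketch with random stand-in tensors (hypothetical `target_logits_via_directions`; `W_U` is assumed to be `[d_model, d_vocab]`, as in TransformerLens) checking that both routes agree:

```python
import torch

def target_logits_via_directions(resid: torch.Tensor, W_U: torch.Tensor,
                                 tokens: torch.Tensor, start: int) -> torch.Tensor:
    """resid: [n_comp, batch, pos, d_model]; W_U: [d_model, d_vocab]; tokens: [batch, pos].
    Returns the logit of the true next token, [n_comp, batch, pos - 1 - start]."""
    directions = W_U[:, tokens[:, start + 1:]]  # [d_model, batch, n_targets]
    return torch.einsum('cbpd,dbp->cbp', resid[:, :, start:-1], directions)

torch.manual_seed(0)
n_comp, batch, pos, d_model, d_vocab = 4, 2, 9, 16, 50
resid = torch.randn(n_comp, batch, pos, d_model)
W_U = torch.randn(d_model, d_vocab)
tokens = torch.randint(0, d_vocab, (batch, pos))

cheap = target_logits_via_directions(resid, W_U, tokens, start=4)

# Reference: full unembedding followed by gather
full = (resid @ W_U)[:, :, 4:-1]
index = tokens[None, :, 5:, None].expand(n_comp, -1, -1, 1)
reference = full.gather(dim=-1, index=index).squeeze(-1)
print(torch.allclose(cheap, reference, atol=1e-5))  # True
```

The saving comes from never building the `d_vocab` axis: memory scales with the number of targets instead of the vocabulary size.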



In [None]:
logits_per_head = logits_for_target_tokens.mean(dim = (-1,-2)).reshape(-1, heads)
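`reshape(-1, heads)` relies on the stacked head results being ordered layer by layer (all heads of layer 0, then layer 1, and so on), which is how I understand `stack_head_results` orders them. A tiny check of what that ordering implies for the reshape, using hypothetical values `100 * layer + head` in place of real logits:

```python
import torch

n_layers, n_heads = 6, 8
# Hypothetical per-component values in layer-major order: value = 100 * layer + head
labels = [(l, h) for l in range(n_layers) for h in range(n_heads)]
flat = torch.tensor([100 * l + h for l, h in labels], dtype=torch.float)
grid = flat.reshape(-1, n_heads)  # [n_layers, n_heads]: row = layer, column = head
print(grid.shape, grid[3, 6].item())  # torch.Size([6, 8]) 306.0
```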

In [None]:
fig = px.imshow(logits_per_head.detach().cpu().numpy(),
          title = 'Overall Average Logit for Target Tokens Across Layers and Heads',
          labels= {'x':'Head', 'y': 'Layer'},
          color_continuous_scale = 'dense',
          text_auto = '.2f'
            )

fig.show()

In [None]:
logits_per_head.shape

In [None]:
help(cache)