<a href="https://colab.research.google.com/github/mahadikprasad15/ARENA/blob/main/GPT_2_small_induction_heads%2C_previous_token_heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install transformer_lens

In [None]:
import torch
from transformer_lens import HookedTransformer

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

In [None]:
# Switch the model to evaluation mode (nn.Module.eval, i.e. training=False);
# everything below is inference-only analysis.
model.eval()

In [None]:
def induction_prompt_generator(batch=0, seq_length=10, *, low=1000,
                               vocab_size=None, bos_token_id=None,
                               generator=None):
  """Build a batch of repeated random-token prompts for probing induction heads.

  Each row is ``[BOS, t_1..t_L, t_1..t_L]`` with ``L = seq_length`` random
  token ids drawn uniformly from ``[low, vocab_size)``. The returned tensor
  therefore has shape ``(batch, 2 * seq_length + 1)``.

  The BOS id is never drawn as one of the random tokens, even when it lies
  inside ``[low, vocab_size)``. ``torch.randint``'s upper bound is exclusive,
  so sampling up to ``vocab_size`` can otherwise produce the last vocabulary
  id.

  Args:
    batch: number of prompts (rows).
    seq_length: length L of the random block that gets repeated.
    low: smallest token id to sample (skips the low, mostly punctuation/byte
      ids). Keyword-only.
    vocab_size: exclusive upper bound for sampling. Defaults to the global
      ``model.cfg.d_vocab``.
    bos_token_id: id prepended to every row. Defaults to the global model's
      ``'<|endoftext|>'`` token.
    generator: optional ``torch.Generator`` for reproducible prompts.

  Returns:
    An int64 tensor of shape ``(batch, 2 * seq_length + 1)`` on the CPU.
  """
  if vocab_size is None:
    vocab_size = model.cfg.d_vocab
  if bos_token_id is None:
    bos_token_id = model.to_single_token('<|endoftext|>')

  # If the BOS id falls inside the sampling range, draw from one fewer value and
  # shift every id at or above BOS up by one, so BOS itself is skipped.
  skip_bos = low <= bos_token_id < vocab_size
  high = vocab_size - 1 if skip_bos else vocab_size
  tokens = torch.randint(low, high, (batch, seq_length), generator=generator)
  if skip_bos:
    tokens += (tokens >= bos_token_id).long()

  bos = torch.full((batch, 1), bos_token_id, dtype=tokens.dtype)
  return torch.cat((bos, tokens, tokens), dim=1)

In [None]:
# Four prompts of the form [BOS, 10 random tokens, the same 10 tokens again];
# cache every activation for the attention-pattern analysis below.
prompts = induction_prompt_generator(batch=4, seq_length=10)
logits, cache = model.run_with_cache(prompts)

In [None]:
import plotly.express as px

# Only a small subset of layers/heads is plotted so the cell runs quickly.
max_layers_to_plot = 2  # first 2 layers
max_heads_to_plot = 4   # first 4 heads of each layer


for layer in range(min(max_layers_to_plot, model.cfg.n_layers)):
  for head in range(min(max_heads_to_plot, model.cfg.n_heads)):
    # Attention pattern of this head on the first prompt of the batch.
    attention_pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0, head]
    fig = px.imshow(
        attention_pattern.detach().cpu().numpy(),
        color_continuous_scale='dense',
        title=f"Attention Pattern: Layer {layer}, Head {head}",
    )
    fig.show()

In [None]:
# Heads singled out for closer inspection, given as (layer, head) pairs.
# Each pattern is plotted for the first prompt of the batch. The same plotting
# code was previously copy-pasted four times.
for layer_idx, head_idx in [(3, 0), (0, 5), (1, 11), (0, 1)]:
  attention_pattern = cache[f'blocks.{layer_idx}.attn.hook_pattern'][0, head_idx]
  fig = px.imshow(attention_pattern.detach().cpu().numpy(),
                  title=f"Attention Pattern: Layer {layer_idx}, Head {head_idx}",
                  color_continuous_scale='dense')
  fig.show()

In [None]:
def induction_score(cache, seq_len=None, *, n_layers=None, n_heads=None):
  """Score every attention head on how strongly it behaves like an induction head.

  The prompts are assumed to have the layout built by
  ``induction_prompt_generator``: ``[BOS, t_1..t_L, t_1..t_L]`` (length 2L + 1).
  At a position in the repeated half that holds ``t_k``, an induction head
  attends to the token that *followed* the earlier ``t_k``. That token is
  ``L - 1`` positions back, so the relevant stripe is the diagonal at offset
  ``-(L - 1)``. Offset ``-L`` would instead pick up attention to the earlier
  copy of the same token, which is duplicate-token behaviour, not induction.

  The score is the mean attention weight on that stripe. It is taken over the
  ``L`` destination positions of the repeated half and over the batch, so it
  lies in [0, 1].

  Args:
    cache: mapping indexable by ``'blocks.{layer}.attn.hook_pattern'`` whose
      values are indexed as ``[batch, head, query_pos, key_pos]``.
    seq_len: length L of the repeated block. Inferred from the pattern size as
      ``(n_ctx - 1) // 2`` when omitted.
    n_layers, n_heads: how many layers / heads to score. Keyword-only;
      default to the global ``model.cfg``.

  Returns:
    ``(scores_dict, scores_tensor)``. ``scores_dict`` maps ``(layer, head)`` to
    a 0-d score tensor. ``scores_tensor`` has shape ``(n_layers, n_heads)``.

  Raises:
    ValueError: if the repeated block length is smaller than 1.
  """
  if n_layers is None:
    n_layers = model.cfg.n_layers
  if n_heads is None:
    n_heads = model.cfg.n_heads

  attention_score_tensor = torch.zeros(n_layers, n_heads)
  attention_scores_dict = {}
  for layer in range(n_layers):
    patterns = cache[f'blocks.{layer}.attn.hook_pattern']
    rep_len = seq_len if seq_len is not None else (patterns.shape[-1] - 1) // 2
    if rep_len < 1:
      raise ValueError(f'repeated sequence length must be >= 1, got {rep_len}')
    for head in range(n_heads):
      stripe = patterns[:, head].diagonal(offset=-(rep_len - 1), dim1=-2, dim2=-1)
      # Only the last rep_len entries have their query in the repeated half.
      attention_score = stripe[..., -rep_len:].mean()
      attention_scores_dict[layer, head] = attention_score
      attention_score_tensor[layer, head] = attention_score

  return attention_scores_dict, attention_score_tensor


scores_dict, scores_tensor = induction_score(cache)
# Keep the ten heads with the highest induction score, strongest first.
sorted_scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)[:10]

print('Top 10 induction scores:')
print('\n')
for (layer_idx, head_idx), score in sorted_scores:
  print(f'Layer {layer_idx}, Head {head_idx} --- Score: {score.item():.2f}')

In [None]:
# Heatmap of the induction score of every (layer, head) pair.
fig = px.imshow(
    scores_tensor,
    title="Induction Scores by Head",
    labels=dict(x="Head", y="Layer"),
    color_continuous_scale='dense',
    text_auto=".2f",
    height=700,
    width=1900,
)
fig.show()



In [None]:
from  functools import partial

In [None]:
def zero_ablation(attention, hook, head):
  """Forward-hook that zeroes the output of a single attention head in place.

  Args:
    attention: activation tensor whose third axis indexes heads (used here on
      ``hook_z``). It is modified in place.
    hook: the hook point passed in by the hooking framework (unused).
    head: index of the head to silence.

  Returns:
    The same ``attention`` tensor, with ``head`` zeroed.
  """
  head_slice = (slice(None), slice(None), head, slice(None))
  attention[head_slice] = 0
  return attention


def mean_ablation(attention, hook, head):
  """Forward-hook that replaces one head's output with its mean over the batch.

  Intended for ``blocks.{layer}.attn.hook_z``. The tensor is assumed to be
  laid out as ``[batch, pos, head, d_head]``; the head axis (2) is confirmed by
  the indexing, while the other axis names are a TODO to confirm. Averaging
  only over the batch keeps each (position, d_head) component at its average
  value. Taking a global ``.mean()`` would instead flatten the whole head into
  a single scalar. With a batch of size 1 this hook is a no-op.

  Args:
    attention: activation tensor, modified in place.
    hook: the hook point passed in by the hooking framework (unused).
    head: index of the head to ablate.

  Returns:
    The same ``attention`` tensor, with ``head`` mean-ablated.
  """
  # keepdim=True gives shape (1, pos, d_head), which broadcasts over the batch.
  attention[:, :, head, :] = attention[:, :, head, :].mean(dim=0, keepdim=True)
  return attention

# Baseline log-probs: a fresh batch of 5 prompts, each built from 20 random
# tokens repeated once, run through the unablated model.
prompts = induction_prompt_generator(batch=5, seq_length=20)
baseline_logits, baseline_cache = model.run_with_cache(prompts)

def log_probs_calculate(logits, tokens=None, seq_len=None):
  """Return the mean log-prob the model assigns to the correct repeated tokens.

  Expects prompts laid out as ``[BOS, t_1..t_L, t_1..t_L]``, as built by
  ``induction_prompt_generator``. The logits at position ``p`` are the model's
  prediction for the token at ``p + 1``. Each token is therefore scored with
  the log-softmax of the position *before* it.

  Only the last ``L - 1`` tokens of the repeated half are averaged, because
  these are the ones an induction head can predict. The first repeated token
  follows ``t_L``, which has no earlier occurrence.

  Args:
    logits: model output for ``tokens``, indexed ``[batch, pos, vocab]``.
    tokens: token ids ``[batch, pos]`` the logits were computed on. Defaults
      to the global ``prompts`` for backward compatibility.
    seq_len: length L of the repeated block. Inferred as ``(pos - 1) // 2``
      when omitted.

  Returns:
    The mean log-probability as a Python float.

  Raises:
    ValueError: if the repeated block is shorter than 2 tokens.
  """
  if tokens is None:
    tokens = prompts
  tokens = tokens.to(logits.device)
  if seq_len is None:
    seq_len = (tokens.shape[-1] - 1) // 2
  if seq_len < 2:
    raise ValueError(f'seq_len must be at least 2, got {seq_len}')

  log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
  # Pair each position's prediction with the token that actually follows it.
  correct_log_probs = (
      log_probs[:, :-1].gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
  )
  return correct_log_probs[:, -(seq_len - 1):].mean().item()


baseline_logprobs = log_probs_calculate(baseline_logits)
print(
    f'The baseline logprobs on correct tokens: {baseline_logprobs}'
)

In [None]:
# Zero-ablate each of the top induction-scoring heads in turn and record the
# mean log-prob (via log_probs_calculate, on the current `prompts` batch) that
# the ablated model assigns to the correct tokens. Raw logits were previously
# stored here, so `ablated_logprobs` did not hold log-probs at all.
ablated_logprobs = {}

for (layer, head), _ in sorted_scores:
  hook_function = partial(zero_ablation, head=head)
  zero_ablated_logits = model.run_with_hooks(
      prompts,
      fwd_hooks=[(f'blocks.{layer}.attn.hook_z', hook_function)],
  )
  ablated_logprobs[layer, head] = log_probs_calculate(zero_ablated_logits)


In [None]:
# Change relative to the unablated baseline, keyed by (layer, head).
difference = {
    heads: ablated_value - baseline_logprobs
    for heads, ablated_value in ablated_logprobs.items()
}



In [None]:
def prev_token_attention_score(cache, *, n_layers=None, n_heads=None):
  """Score every head on how much it attends to the immediately preceding token.

  The score is the mean, over the batch and over all query positions that
  have a predecessor, of the attention weight at ``(q, q - 1)``. It lies in
  [0, 1].

  Args:
    cache: mapping indexable by ``'blocks.{layer}.attn.hook_pattern'`` whose
      values are indexed as ``[batch, head, query_pos, key_pos]``.
    n_layers, n_heads: how many layers / heads to score. Keyword-only;
      default to the global ``model.cfg``, so the function also works on
      caches from other configurations.

  Returns:
    ``(scores_dict, scores_tensor)``. ``scores_dict`` maps ``(layer, head)`` to
    a 0-d score tensor. ``scores_tensor`` has shape ``(n_layers, n_heads)``.
  """
  if n_layers is None:
    n_layers = model.cfg.n_layers
  if n_heads is None:
    n_heads = model.cfg.n_heads

  attention_scores = {}
  attention_score_tensor = torch.zeros(n_layers, n_heads)
  for layer in range(n_layers):
    patterns = cache[f'blocks.{layer}.attn.hook_pattern']
    for head in range(n_heads):
      # The offset -1 diagonal holds the weight each query puts on its predecessor.
      score = patterns[:, head].diagonal(offset=-1, dim1=-2, dim2=-1).mean()
      attention_scores[layer, head] = score
      attention_score_tensor[layer, head] = score
  return attention_scores, attention_score_tensor


In [None]:
# Previous-token heads: score every head on the cached induction prompts.
score_dict, score_tensor = prev_token_attention_score(cache)

# The five heads that attend most strongly to the immediately preceding token.
prev_token_heads = sorted(
    score_dict.items(), key=lambda pair: pair[1], reverse=True
)[:5]
prev_token_heads

In [None]:
# Heatmap of the previous-token score of every (layer, head) pair.
fig = px.imshow(
    score_tensor,
    title="Previous Token Scores by Head",
    labels=dict(x="Head", y="Layer"),
    color_continuous_scale='dense',
    text_auto=".2f",
    height=700,
    width=1900,
)
fig.show()


In [None]:
from transformer_lens import FactoredMatrix

# Inspect the OV circuit of layer 3, head 0.
layer = 3
head_index = 0

W_E, W_U = model.W_E, model.W_U
W_V = model.W_V[layer, head_index]
W_O = model.W_O[layer, head_index]

# W_V @ W_O, kept in factored form rather than materialized.
OV_circuit = FactoredMatrix(W_V, W_O)
# Token embedding -> OV -> unembedding: a token-to-logit map, still factored
# (indexed and materialized with `.AB` below).
full_OV_circuit = W_E @ OV_circuit @ W_U

In [None]:
# Pick 200 random token ids and use the same ids for the input (row) and output
# (column) axes. A copying head should light up along the diagonal.
indices = torch.randint(0, model.cfg.d_vocab, (200,))
full_OV_circuit_sample = full_OV_circuit[indices, indices].AB

px.imshow(
    full_OV_circuit_sample.detach().cpu().numpy(),
    title="Full OV circuit for copying head",
    labels=dict(x="Logits on output token", y="Input token"),
    color_continuous_scale='dense',
    height=600,
    width=700,
)

In [None]:
def top_1_acc(full_OV_circuit: FactoredMatrix, batch_size: int = 1000) -> float:
    """Return the fraction of rows whose largest entry lies on the diagonal.

    Rows are materialized ``batch_size`` at a time (via ``.AB``) so the full
    square matrix never has to exist in memory at once.

    Args:
        full_OV_circuit: square factored matrix to test.
        batch_size: number of rows materialized per step.

    Returns:
        Fraction in [0, 1] of rows ``i`` whose argmax column is ``i``.
    """
    n_rows = full_OV_circuit.shape[0]
    hits = 0.0
    for row_ids in torch.arange(n_rows, device=device).split(batch_size):
        rows = full_OV_circuit[row_ids].AB
        hits += (rows.argmax(dim=1) == row_ids).sum().item()
    return hits / n_rows


print(
    "Fraction of time that the best logit is on diagonal: "
    f"{top_1_acc(full_OV_circuit, batch_size=1000):.4f}"
)