<a href="https://colab.research.google.com/github/mahadikprasad15/ARENA/blob/main/Pythia-160M%20Induction%20Circuits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install TransformerLens, which provides HookedTransformer and the
# activation-patching utilities used later in this notebook.
%pip install transformer_lens

In [None]:
import torch
import transformer_lens
import plotly.express as px
from transformer_lens import utils
import tqdm
from functools import partial

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# This notebook only runs forward passes for interpretability, so autograd is
# never needed. Disabling it globally stops PyTorch from recording graphs for
# ops such as `resid @ model.W_U` (W_U requires grad), which would otherwise
# keep large intermediate tensors alive and inflate memory use.
torch.set_grad_enabled(False)

In [None]:
# Load Pythia-160M as a HookedTransformer on the same device that every other
# tensor in this notebook is moved to, then switch it to eval mode.
model = transformer_lens.HookedTransformer.from_pretrained('pythia-160m', device=device)
model.eval()

In [None]:
# Sanity check: show how the tokenizer splits a sample sentence.
sample_text = 'This is a test example, to see the tokenization'
model.to_str_tokens(sample_text)

In [None]:
# Print the model hyper-parameters that the analysis below depends on.
cfg = model.cfg
print(f'Number of layers: {cfg.n_layers}')
print(f'Number of heads: {cfg.n_heads}')
print(f'Model residual stream dimension: {cfg.d_model}')
print(f'Model vocab-size: {cfg.d_vocab}')
print(f'Dimension of the heads: {cfg.d_head}')

# Generating induction prompts

First, I need a function that generates induction prompts. It takes only `batch` and `seq_length`: it samples a random token sequence of length `seq_length`, repeats it once, and prepends a BOS token, so each prompt has length `2 * seq_length + 1`.
This lets me generate induction prompts of any length and batch size to test on.


In [None]:
def generate_induction_prompts(batch = 1, seq_length = 20, *, d_vocab = None, bos_token_id = 0):
  """Build random repeated-token prompts for probing induction heads.

  Each row is ``[BOS, t_1, ..., t_n, t_1, ..., t_n]`` with ``n = seq_length``.
  The result is a ``torch.long`` tensor of shape ``(batch, 2 * seq_length + 1)``.
  The random ids are drawn uniformly from ``[1, d_vocab)``, so id 0 is never
  sampled.

  Args:
    batch: number of prompts (rows).
    seq_length: length of the random sequence that is repeated.
    d_vocab: exclusive upper bound for the sampled ids. Defaults to
      ``model.cfg.d_vocab`` of the notebook's global ``model``.
      NOTE(review): ``cfg.d_vocab`` may include padding ids the tokenizer
      never emits -- confirm whether sampling should stop at the tokenizer's
      vocabulary size instead.
    bos_token_id: id placed at position 0 (defaults to 0).

  Returns:
    The prompt tensor on the CPU.
  """
  if d_vocab is None:
    d_vocab = model.cfg.d_vocab
  tokens = torch.randint(1, d_vocab, (batch, seq_length), dtype = torch.long)
  bos = torch.full((batch, 1), bos_token_id, dtype = torch.long)
  # The second copy of `tokens` is what an induction head can predict.
  return torch.cat([bos, tokens, tokens], dim = -1)

Now that we have the function, I want a first look at the attention patterns of all heads, to see which ones show a distinctive induction pattern.
I also need to look for previous-token heads.

In [None]:
layers, heads = model.cfg.n_layers, model.cfg.n_heads
prompt_tokens = generate_induction_prompts(5, 20)
logits, cache = model.run_with_cache(prompt_tokens)

# One figure per layer; each facet is one head's pattern averaged over the batch.
for layer in range(layers):
    pattern_name = utils.get_act_name('pattern', layer)
    attention_pattern = cache[pattern_name].cpu().numpy()
    batch_mean = attention_pattern.mean(axis=0)
    fig = px.imshow(
        batch_mean,
        facet_col=0,
        title=f'Attention Patterns: Layer {layer}',
        labels={'x': 'Key Token', 'y': 'Query Token'},
        color_continuous_scale='viridis',
        width=3000,
        height=2500,
    )
    fig.show()

To pick out the clearly strong heads, I define a function that computes these metrics for every head across all layers, so the largest ones stand out.



In [None]:
def induction_score(attention_pattern, prompt_tokens):
  """Mean attention a head places on the induction stripe.

  Assumes ``prompt_tokens`` holds a random sequence of length
  ``S = prompt_tokens.size(1) // 2`` that is repeated once, optionally with a
  BOS token in front (``[BOS, seq, seq]`` or ``[seq, seq]``).

  In the second copy, a perfect induction head at each query attends to the
  token right *after* the previous occurrence of the current token. That
  token sits ``S - 1`` positions back, i.e. on the diagonal at offset
  ``-(S - 1)`` of the (query, key) pattern.

  Only the last ``S`` entries of that diagonal (the queries in the second
  copy) are averaged. The earlier entries belong to first-copy queries, where
  induction is impossible. With a BOS prefix, one of those earlier entries is
  a query attending to BOS, which would inflate the score of heads that
  simply park on BOS.

  Args:
    attention_pattern: tensor whose last two dims are (query_pos, key_pos).
      Any leading dims (e.g. heads) are preserved.
    prompt_tokens: the tokens the pattern was computed on; only ``size(1)``
      is used.

  Returns:
    A tensor with the leading dims of ``attention_pattern`` (0-d for a single
    2-D pattern). For softmax patterns each value lies in [0, 1].

  Raises:
    ValueError: if the prompt is too short to contain a repeat.
  """
  seq_len = prompt_tokens.size(1) // 2
  if seq_len < 1:
    raise ValueError('prompt_tokens must contain a repeated sequence of length >= 1')
  stripe = attention_pattern.diagonal(offset = -(seq_len - 1), dim1 = -2, dim2 = -1)
  # Keep only the queries in the second copy of the sequence.
  return stripe[..., -seq_len:].mean(dim = -1)



# Induction score for every (layer, head), computed on the batch-averaged pattern.
scores = torch.zeros(layers, heads, device = device)

for layer in range(layers):
  mean_patterns = cache[utils.get_act_name('pattern', layer)].mean(dim=0)
  for head, head_pattern in enumerate(mean_patterns):
    scores[layer, head] = induction_score(head_pattern, prompt_tokens=prompt_tokens)

In [None]:
# Heatmap of the induction score of every head (rows = layers, cols = heads).
induction_grid = scores.cpu().numpy()
fig = px.imshow(
    induction_grid,
    labels = {'x': 'Head', 'y': 'Layer'},
    title = 'Induction scores for all Layers and Heads',
    aspect = 'auto',
    text_auto = '.2f',
    color_continuous_scale = 'dense',
)
fig.show()

## Getting top previous token heads

In [None]:
def previous_token_score(attention_pattern, prompt_tokens):
  """Mean attention a head places on the immediately preceding token.

  Reads the diagonal at offset -1 of the (query, key) pattern, i.e. the
  attention from each query position ``q > 0`` to key ``q - 1``, and averages
  it. The mean (rather than a sum) keeps the score in [0, 1] for softmax
  patterns and makes it independent of the prompt length.

  Args:
    attention_pattern: tensor whose last two dims are (query_pos, key_pos).
      Any leading dims (e.g. heads) are preserved.
    prompt_tokens: unused. It is kept so the signature matches
      ``induction_score``.

  Returns:
    A tensor with the leading dims of ``attention_pattern`` (0-d for a single
    2-D pattern).
  """
  return attention_pattern.diagonal(offset = -1, dim1 = -2, dim2 = -1).mean(dim = -1)


# Previous-token score for every (layer, head), computed on the batch-averaged pattern.
scores = torch.zeros(layers, heads, device = device)

for layer in range(layers):
  mean_patterns = cache[utils.get_act_name('pattern', layer)].mean(dim=0)
  for head, head_pattern in enumerate(mean_patterns):
    scores[layer, head] = previous_token_score(head_pattern, prompt_tokens=prompt_tokens)

In [None]:
# Heatmap of the previous-token score of every head (rows = layers, cols = heads).
prev_token_grid = scores.cpu().numpy()
fig = px.imshow(
    prev_token_grid,
    labels = {'x': 'Head', 'y': 'Layer'},
    title = 'Previous token scores for all Layers and Heads',
    aspect = 'auto',
    text_auto = '.2f',
    color_continuous_scale = 'dense',
)
fig.show()

In [None]:
# Accumulated residual stream at each layer boundary, stacked along dim 0,
# with the final LayerNorm scaling applied so it can be unembedded directly.
accumulated_resid = cache.accumulated_resid(apply_ln = True, layer = -1)

Next, I use `accumulated_resid` and related cache methods to get the residual stream for each layer, head, and so on. I apply the final layer norm to it and multiply by `W_U` to get the logits contributed by each layer, head, etc.

But this will be

In [None]:
# Logit lens on the accumulated residual stream. For every component (dim 0
# of `accumulated_resid`), how large is the logit of the correct next token on
# the repeated half of the prompt?
#
# The naive route unembeds every position onto the whole vocabulary:
# `accumulated_resid @ model.W_U` materialises a (components, batch, pos,
# d_vocab) tensor, which can be hundreds of MB, just to gather one entry per
# position. Instead, dot each residual vector with only the unembedding
# column of its target token. The values are the same up to float rounding.
# The unembedding bias is not added, matching the `resid @ W_U` formulation.
half = prompt_tokens.size(1) // 2

# Position p predicts token p + 1, so positions half .. T-2 are scored against
# the tokens at half+1 .. T-1.
# NOTE(review): position `half` predicts the first token of the second copy,
# which context cannot predict. It is kept so the scored positions match
# `calculate_correct_logits`.
accumulated_resid_second_half = accumulated_resid[:, :, half:-1].to(device)
target_token_indices = prompt_tokens[:, half + 1:].to(device)

# Unembedding column of each target token: shape (d_model, batch, pos).
target_unembed = model.W_U[:, target_token_indices.to(model.W_U.device)].to(device)

# logits_for_target_tokens[c, b, p] = <resid[c, b, p, :], W_U[:, target[b, p]]>
logits_for_target_tokens = torch.einsum(
    'cbpd,dbp->cbp', accumulated_resid_second_half, target_unembed
)

# Mean over batch and position -> one value per residual-stream component.
logits_per_layer = logits_for_target_tokens.mean(dim = (-1, -2))

In [None]:
import pandas as pd

# One point per accumulated-residual component: the mean logit of the correct token.
layer_axis = range(len(logits_per_layer))
avg_logit_values = logits_per_layer.detach().cpu().numpy()
df = pd.DataFrame({'Layer': layer_axis, 'Average Logit': avg_logit_values})

fig = px.line(
    df,
    x='Layer',
    y='Average Logit',
    title='Overall Average Logit for Target Tokens Across Layers',
)
fig.show()

In [None]:
# Output of every attention head in every layer, stacked along dim 0, with the
# final LayerNorm scaling applied so each head's output can be unembedded.
per_head_resid = cache.stack_head_results(apply_ln=True, layer=-1)

In [None]:
# Direct logit contribution of each head to the correct next token on the
# repeated half of the prompt.
#
# The naive route is `per_head_resid @ model.W_U`, a (layers*heads, batch,
# pos, d_vocab) tensor. That is several GB for this model, just to gather one
# entry per position. Instead, dot each head's output with only the
# unembedding column of its target token. The values are the same up to float
# rounding, and the unembedding bias is not added.
half = prompt_tokens.size(1) // 2

# Position p predicts token p + 1, so positions half .. T-2 are scored against
# the tokens at half+1 .. T-1.
per_head_resid_second_half = per_head_resid[:, :, half:-1].to(device)
target_token_indices = prompt_tokens[:, half + 1:].to(device)

# Unembedding column of each target token: shape (d_model, batch, pos).
target_unembed = model.W_U[:, target_token_indices.to(model.W_U.device)].to(device)

# logits_for_target_tokens[h, b, p] = <head_out[h, b, p, :], W_U[:, target[b, p]]>
logits_for_target_tokens = torch.einsum(
    'hbpd,dbp->hbp', per_head_resid_second_half, target_unembed
)



In [None]:
# Average over batch and position, then arrange the per-head values into rows of `heads`.
mean_logit_per_component = logits_for_target_tokens.mean(dim = (-2, -1))
logits_per_head = mean_logit_per_component.reshape(-1, heads)

In [None]:
# Heatmap of each head's mean direct logit for the correct token.
head_logit_grid = logits_per_head.detach().cpu().numpy()
fig = px.imshow(
    head_logit_grid,
    labels = {'x': 'Head', 'y': 'Layer'},
    title = 'Overall Average Logit for Target Tokens Across Layers and Heads',
    aspect = 'auto',
    text_auto = '.2f',
    color_continuous_scale = 'dense',
)
fig.show()

In [None]:
from transformer_lens import patching

In [None]:
def corrupt_induction_prompt(clean_tokens, *, d_vocab = None):
  """Return a copy of ``clean_tokens`` with the first copy re-randomised.

  Positions ``1 .. T // 2`` are overwritten with fresh random ids drawn from
  ``[1, d_vocab)``. For ``[BOS, seq, seq]`` prompts this is the whole first
  copy of ``seq``. Position 0 and the remaining positions are kept, so the
  second half (generically) no longer repeats anything earlier in the prompt.
  The input tensor is not modified.

  Args:
    clean_tokens: 2-D ``(batch, T)`` token tensor.
    d_vocab: exclusive upper bound for the sampled ids. Defaults to
      ``model.cfg.d_vocab`` of the notebook's global ``model``.

  Returns:
    The corrupted token tensor, same shape/dtype/device as ``clean_tokens``.
  """
  if d_vocab is None:
    d_vocab = model.cfg.d_vocab
  corrupt = clean_tokens.clone()
  half = corrupt.size(1) // 2
  # Sample directly on the prompt's device/dtype so no cross-device copy is needed.
  corrupt[:, 1:half + 1] = torch.randint(
      1, d_vocab, (corrupt.size(0), half), dtype = corrupt.dtype, device = corrupt.device
  )
  return corrupt


# Clean run: the original repeated-sequence prompts, moved to the device.
clean_tokens = prompt_tokens.to(device)
# Corrupted run: the same prompts with the first copy re-randomised.
corrupt_tokens = corrupt_induction_prompt(prompt_tokens)
corrupt_tokens = corrupt_tokens.to(device)

In [None]:
# Forward passes (logits only) on the clean and the corrupted prompts.
clean_logits, corrupt_logits = model(clean_tokens), model(corrupt_tokens)

In [None]:
def calculate_correct_logits(logits, tokens):
  """Mean logit assigned to the correct next token on the repeated half.

  With ``T = logits.size(1)``, the logits at positions ``T // 2 .. T - 2`` are
  scored against ``tokens[:, 1 : tokens.size(1) // 2 + 1]``. For
  ``[BOS, seq, seq]`` prompts those ids are the first copy of ``seq``, which
  equal the tokens that actually follow the scored positions.

  Returns:
    The mean of the selected logits as a Python float.
  """
  query_start = logits.size(1) // 2
  answer_end = tokens.size(1) // 2 + 1
  scored_logits = logits[:, query_start:-1, :]
  answers = tokens[:, 1:answer_end].unsqueeze(-1)
  correct = torch.gather(scored_logits, dim = -1, index = answers)
  return correct.mean().item()


# Baselines for the patching metric. Both runs are scored against the clean
# answers, because the corrupted prompts keep the clean second half.
clean_score = calculate_correct_logits(logits=clean_logits, tokens=clean_tokens)
corrupt_score = calculate_correct_logits(logits=corrupt_logits, tokens=clean_tokens)

In [None]:
def patching_metric(ablated_logits):
  """Normalised induction performance of a patched/ablated run.

  Returns 0.0 when the run scores like the corrupted baseline and 1.0 when it
  scores like the clean baseline. The value is wrapped in a tensor so that
  callers can call ``.item()`` on it.
  """
  recovered = calculate_correct_logits(ablated_logits, clean_tokens) - corrupt_score
  full_gap = clean_score - corrupt_score
  return torch.tensor(recovered / full_gap)

In [None]:
def zero_ablation(tensor, hook, head):
  """Forward hook that silences one attention head.

  Returns a copy of ``tensor`` in which every entry for ``head`` (the third
  axis) is zero. The input tensor is left unchanged. ``hook`` is accepted
  only to match the hook-function calling convention and is not used.
  """
  ablated = tensor.clone()
  ablated[:, :, head, :].zero_()
  return ablated

def mean_ablation(tensor, hook, head):
  """Forward hook that replaces one head's activations with their batch mean.

  Intended for ``hook_z`` activations of shape
  ``(batch, pos, head_index, d_head)``. For the selected ``head``, each
  (position, d_head) entry is replaced by its average across the batch. This
  removes prompt-specific information while keeping the head's typical
  per-position output.

  Averaging only over the batch of *this head* is what makes this a mean
  ablation. Averaging the whole tensor would collapse every batch element,
  position, head and feature into a single scalar.

  Returns a new tensor; the input is not modified. ``hook`` is accepted only
  to match the hook-function calling convention and is not used.
  """
  target = tensor.clone()
  # (1, pos, d_head) mean broadcasts back over the batch dimension.
  target[:, :, head, :] = tensor[:, :, head, :].mean(dim = 0, keepdim = True)
  return target

In [None]:
# Zero ablation on heads: silence each head's output (hook_z) one at a time on
# the clean prompts, and record 1 - patching_metric. That is the fraction of
# the clean-vs-corrupt logit gap lost when the head is removed.
results = torch.zeros(layers, heads, device = device)

for layer in tqdm.tqdm(range(layers)):
  z_hook_name = utils.get_act_name('z', layer)
  for head in range(heads):
    ablated_logits = model.run_with_hooks(
        clean_tokens,
        fwd_hooks = [(z_hook_name, partial(zero_ablation, head = head))],
    )
    results[layer, head] = 1 - patching_metric(ablated_logits)



In [None]:
# Heatmap of the gap fraction lost when zero-ablating each head (rows = layers).
zero_ablation_grid = results.detach().cpu().numpy()
fig = px.imshow(
    zero_ablation_grid,
    labels = {'x': 'Heads', 'y': 'Layers'},
    title = 'Results of ablating z for each head',
    text_auto = '.2f',
    aspect = 'auto',
    color_continuous_scale = 'dense',
)
fig.show()

In [None]:
# Mean ablation on heads (not zero ablation): replace each head's output
# (hook_z) with the mean value computed by `mean_ablation`, one head at a
# time. Record 1 - patching_metric, i.e. the fraction of the clean-vs-corrupt
# logit gap that is lost.

results = torch.zeros(layers, heads, device = device)

for layer in tqdm.tqdm(range(layers)):
  for head in range(heads):

    hook_function = partial(mean_ablation, head = head)

    ablated_logits = model.run_with_hooks(clean_tokens, fwd_hooks = [(utils.get_act_name('z', layer), hook_function)])

    results[layer, head] = 1 - patching_metric(ablated_logits)



In [None]:
# Heatmap of the gap fraction lost when mean-ablating each head (rows = layers).
mean_ablation_grid = results.detach().cpu().numpy()
fig = px.imshow(
    mean_ablation_grid,
    labels = {'x': 'Heads', 'y': 'Layers'},
    title = 'Results of mean ablating z for each head',
    text_auto = '.2f',
    aspect = 'auto',
    color_continuous_scale = 'dense',
)
fig.show()

In [None]:
# Cache all activations of the clean run; the patching calls below take this
# cache together with the corrupted tokens.
(clean_logits, clean_cache) = model.run_with_cache(clean_tokens)

In [None]:
# Patch the clean resid_pre into the corrupted run at each (layer, position)
# and score each patched run with patching_metric.
resid_pre_patching = patching.get_act_patch_resid_pre(
    model = model,
    corrupted_tokens = corrupt_tokens,
    clean_cache = clean_cache,
    patching_metric = patching_metric,
)

In [None]:
# Diverging heatmap centred on 0: how much each patched (layer, position) recovers.
resid_pre_grid = resid_pre_patching.detach().cpu().numpy()
fig = px.imshow(
    resid_pre_grid,
    labels = {'x': 'Token Position', 'y': 'Layer'},
    title = 'Patching Residual Stream Before Block (resid_pre)',
    aspect = "auto",
    color_continuous_midpoint = 0,
    color_continuous_scale = 'RdBu',
)
fig.update_layout(coloraxis_colorbar_title="Patching Metric")
fig.show()

In [None]:
# For every head, patch each of its activation types from the clean run into
# the corrupted run (all positions at once) and score with patching_metric.
heads_patching = patching.get_act_patch_attn_head_all_pos_every(
    model = model,
    corrupted_tokens = corrupt_tokens,
    clean_cache = clean_cache,
    metric = patching_metric,
)

In [None]:

facet_labels = ['Output', 'Query', 'Key', 'Value', 'Pattern']

# One facet per patched activation type (dim 0), wrapped three per row.
heads_patching_grid = heads_patching.detach().cpu().numpy()
fig = px.imshow(
    heads_patching_grid,
    facet_col = 0,
    facet_col_wrap = 3,
    labels = {'x': 'Heads', 'y': 'Layer'},
    title = 'Patching Heads',
    aspect = "auto",
    color_continuous_midpoint = 0,
    color_continuous_scale = 'RdBu',
)

# Replace the default facet titles with the activation-type names above.
for idx, label in enumerate(facet_labels):
    fig.layout.annotations[idx].text = label

fig.update_layout(coloraxis_colorbar_title="Patching Metric")
fig.show()

In [None]:
# Print the built-in documentation of transformer_lens.patching for reference.
help(patching)