<a href="https://colab.research.google.com/github/hollymandel/Mistral7B_Induction_Heads/blob/main/ih_babi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Impact of ablating induction heads on Mistral 7B performance in a question-answering task**

In this notebook I measure the performance of Mistral 7B on a single-word question-answering task from bAbI. After some prompt engineering, the unmodified model ("unablated") answers about 74% of questions correctly. I then ablate the N most induction-head-like attention heads (see Olsson et al. 2022, "In-context Learning and Induction Heads", https://transformer-circuits.pub), as ranked by the scores computed in `ih_sweep.ipynb`, and rerun the assessment ("high score ablations"). Finally, I instead ablate the N least induction-head-like heads, subject to the constraint that the number of ablated heads in each layer matches the high-score set ("low score ablations").

Here are the results for N = 16. I chose a fairly large N so that any effect would be easier to observe. However, ablating the high-scoring heads did not hurt accuracy more than ablating the low-scoring controls. It would be interesting to dissect the difference between bAbI and the repeated-sequence task that was used to score induction heads. With more compute, it would also be interesting to ablate each head individually and compare its effect on accuracy with its induction score.

- Unablated: 1481/2000 (74.05%)
- High score ablations: 1476/2000 (73.80%)
- Low score ablations: 1481/2000 (74.05%)
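To gauge whether a difference of a few questions out of 2000 is meaningful, here is a minimal sketch that attaches a normal-approximation (Wald) 95% confidence interval to each reported count. The helper `accuracy_with_ci` is hypothetical, and it treats each condition as an independent sample. That is conservative here, because all three conditions are scored on the same 2000 stories; a paired test on per-story outcomes is the sharper comparison.

```python
import math

def accuracy_with_ci(correct, total, z=1.96):
    """Accuracy with a normal-approximation (Wald) confidence interval; z=1.96 gives ~95%."""
    p = correct / total
    se = math.sqrt(p * (1 - p) / total)  # binomial standard error of a proportion
    return p, p - z * se, p + z * se

# Counts as reported in the results list above.
for name, correct in [("unablated", 1481), ("high score ablations", 1476), ("low score ablations", 1481)]:
    p, lo, hi = accuracy_with_ci(correct, 2000)
    print(f"{name}: {p:.2%} (95% CI {lo:.2%} to {hi:.2%})")
```

For the unablated count the interval half-width is roughly 1.9 percentage points, far wider than the 0.25-point gap between conditions.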

In [1]:
""" evaluating the model using the bAbI question-answering task from Facebook """

!pip -q install datasets
from datasets import load_dataset
dataset = load_dataset("facebook/babi_qa","en-10k-qa1",revision="main")

In [2]:
!pip -q install transformers
!pip -q install bitsandbytes accelerate xformers einops # necessary for quantization

import torch
import transformers
import matplotlib.pyplot as plt
import random
import numpy as np
import pickle

In [3]:
""" using a quantized and sharded version of Mistral 7B. Thanks to Hugo Fernandez for
engineering the model to run on a single T4. Note that this block takes 5-10 minutes to run."""

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16)

model_id = "Hugofernandez/Mistral-7B-v0.1-colab-sharded"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    quantization_config = bnb_config,
    device_map = "auto",
    attn_implementation="eager")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
start_token = tokenizer.bos_token
end_token = tokenizer.eos_token

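The docstring above credits 4-bit loading with letting the model run on a single T4, which has 16 GB of memory. A rough back-of-the-envelope check follows. It assumes about 7.24 billion parameters for Mistral 7B, which is consistent with the six checkpoint shards totaling roughly 29 GB (fp32 storage). It counts weight memory only, ignoring quantization constants, activations, the KV cache, and any modules bitsandbytes keeps in 16-bit.

```python
# Assumption: ~7.24e9 parameters for Mistral 7B.
N_PARAMS = 7.24e9

def weight_memory_gb(n_params, bits_per_param):
    """Approximate weight-only memory in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

for label, bits in [("fp32", 32), ("fp16/bf16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

At 16 bits the weights alone (~14.5 GB) would nearly fill a T4, leaving little room for activations. At 4 bits they need roughly 3.6 GB.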

In [5]:
N_HEADS_ABLATE = 16

with open("ih_scores_dict.pkl", "rb") as f:
  ih_scores_dict = pickle.load(f)

def top_values(scores_dict, n):
  """ extract the top n scorers from scores_dict. Scores_dict is formatted as
  key: vector of scores. Output formatted as (score, key, index in vector). """
  flattened = []
  for key, vector in scores_dict.items():
      for idx, value in enumerate(vector):
          flattened.append((value, key, idx))

  sorted_flattened = sorted(flattened, key=lambda x: x[0], reverse=True)

  return sorted_flattened[:n]

top_ih = top_values(ih_scores_dict, n = N_HEADS_ABLATE) # List[(score, layer, head index)]

# count how many of the top heads fall in each layer: {layer: count}
from collections import Counter
top_ih_layers = Counter(layer for _, layer, _ in top_ih)

# pick lowest-scoring heads subject to matching layer multiplicity. Also formatted as
# List[(score, layer, head index)]
matched_list = []
for layer, count in top_ih_layers.items():
  vector = ih_scores_dict[layer].copy()
  labelled = [ (value, layer, idx) for idx, value in enumerate(vector) ]
  sorted_labelled = sorted(labelled, key=lambda x: x[0], reverse=False)
  matched_list.extend(sorted_labelled[:count])
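The selection logic above can be checked on a small example. This sketch folds it into a single hypothetical helper, `select_heads`, and runs it on made-up scores for 3 layers of 4 heads. As in the cell above, entries are (score, layer, head index).

```python
from collections import Counter
import numpy as np

def select_heads(scores_dict, n):
    """Hypothetical helper: return (top, matched), where top holds the n highest-scoring
    heads overall and matched holds, for each layer in top, the same number of
    lowest-scoring heads from that layer."""
    flat = [(s, layer, h) for layer, v in scores_dict.items() for h, s in enumerate(v)]
    top = sorted(flat, key=lambda t: t[0], reverse=True)[:n]
    per_layer = Counter(layer for _, layer, _ in top)  # layer multiplicity to match
    matched = []
    for layer, count in per_layer.items():
        ranked = sorted(((s, layer, h) for h, s in enumerate(scores_dict[layer])), key=lambda t: t[0])
        matched.extend(ranked[:count])  # lowest-scoring heads in this layer
    return top, matched

# Made-up scores: 3 layers x 4 heads.
toy = {0: np.array([0.1, 0.9, 0.2, 0.0]),
       1: np.array([0.8, 0.3, 0.05, 0.7]),
       2: np.array([0.0, 0.1, 0.2, 0.3])}
top, matched = select_heads(toy, n=3)
print([(layer, h) for _, layer, h in top])
print([(layer, h) for _, layer, h in matched])
```

Matching the per-layer counts means the control ablates heads at the same depths as the high-score set, so any difference should not simply reflect which layers were hit. Nothing prevents the two sets from overlapping if a single layer contributes more than half of its heads to the top set.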

In [6]:
import re
import numpy as np

PROMPT_HEADER = "Read the following passage and answer the questions with a single word.\n\n"

def create_prompt(story, header = PROMPT_HEADER, start = 0, end = np.inf):
  """ Mistral 7B performs poorly on the "raw task", so this function prepends an instructional
  header and presents the earlier questions of the story, with their answers, as worked examples
  before asking the final question. Note that these examples are not independent: the passages
  for earlier questions are also context needed to answer later ones."""
  BLOCK_LENGTH = 3 # in en-10k-qa1 every block is two statements followed by one question

  top = min(len(story["text"]), end) # one past the last line included in the prompt
  prompt = header
  for i, (text, answer) in enumerate(zip(story["text"], story["answer"])):
    if i < start or i >= end:
      continue
    # story["text"] repeats [statement, statement, question];
    # story["answer"] holds the answer on question lines
    if (i+1) % BLOCK_LENGTH != 0:
      if (i+1) % BLOCK_LENGTH == 1:
        prompt += "Passage: " # first statement of a block
      prompt += f"{text}\n"
    else:
      prompt += f"\nQuestion: {text}\n"
      prompt += "Answer: "
      if i != top-1: # leave the final question unanswered for the model
        prompt += f"{answer}\n"

  return prompt

""" Example prompt """
create_prompt(dataset["train"]["story"][0])

'Read the following passage and answer the questions with a single word.\n\nPassage: Mary moved to the bathroom.\nJohn went to the hallway.\n\nQuestion: Where is Mary?\nAnswer: bathroom\nPassage: Daniel went back to the hallway.\nSandra moved to the garden.\n\nQuestion: Where is Daniel?\nAnswer: hallway\nPassage: John moved to the office.\nSandra journeyed to the bathroom.\n\nQuestion: Where is Daniel?\nAnswer: hallway\nPassage: Mary moved to the hallway.\nDaniel travelled to the office.\n\nQuestion: Where is Daniel?\nAnswer: office\nPassage: John went back to the garden.\nJohn moved to the bedroom.\n\nQuestion: Where is Sandra?\nAnswer: '
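`create_prompt` identifies question lines by position, which relies on every block in en-10k-qa1 being two statements followed by one question. A hypothetical variant, `build_prompt`, instead treats any line with a non-empty `answer` as a question. This assumes that `story["answer"]` is an empty string on statement lines, and the sketch always stops at the story's final question (no `start`/`end`). On a toy story it produces the same layout as `create_prompt`.

```python
def build_prompt(story, header="Read the following passage and answer the questions with a single word.\n\n"):
    """Hypothetical variant of create_prompt: a line counts as a question when its
    answer is non-empty (assumption about the data); the last question is left unanswered."""
    texts, answers = story["text"], story["answer"]
    last_q = max(i for i, a in enumerate(answers) if a)  # index of the final question
    prompt, in_passage = header, False
    for i, (text, answer) in enumerate(zip(texts, answers)):
        if i > last_q:
            break
        if answer:  # question line
            prompt += f"\nQuestion: {text}\nAnswer: "
            prompt += "" if i == last_q else f"{answer}\n"
            in_passage = False
        else:       # statement line; label only the first statement of each passage
            prompt += ("" if in_passage else "Passage: ") + f"{text}\n"
            in_passage = True
    return prompt

# Toy story in the same shape as the dataset's "story" field (made-up content).
toy_story = {
    "text": ["Mary moved to the bathroom.", "John went to the hallway.", "Where is Mary?",
             "Daniel went back to the hallway.", "Sandra moved to the garden.", "Where is Daniel?"],
    "answer": ["", "", "bathroom", "", "", "hallway"],
}
print(repr(build_prompt(toy_story)))
```

Detecting questions by their answers should carry over to bAbI tasks whose passages vary in length, where a fixed `BLOCK_LENGTH` would not.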

In [7]:
MODEL_SIZE = 128 # per-head dimension of Mistral 7B (hidden size 4096 / 32 attention heads), not the model width

def score_model(model, story, margin = 3, verbose = 0, start = 0, end = np.inf):
  """ Format the prompt using the prompt header and several examples, generate `margin` new tokens,
  and mark the answer as correct if the expected word appears among the words of the output. This
  is a bit liberal, to allow for formatting tokens, etc. """

  if end == np.inf:
    end = len(story["id"]) # `end` must land just after a question line; other values are not handled

  prompt = create_prompt(story, start = start, end = end)
  prompt_encode = tokenizer.encode(prompt, return_tensors = "pt", padding=True, truncation=True).to(model.device)

  prompt_len = len(prompt_encode[0])
  output = model.generate(prompt_encode, max_new_tokens=margin, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
  decode = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) # new tokens only

  parsed = re.findall(r'\w+', decode) # split the continuation into words

  if verbose >= 2:
    print(prompt)
  if verbose >= 1:
    print(parsed)
    print(f"desired: {story['answer'][end-1]}")

  return story["answer"][end-1] in parsed

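def insert_ablation_hook(head_index):
  """ Forward hook that zeroes channels [head_index*MODEL_SIZE, (head_index+1)*MODEL_SIZE) of the
  hooked module's output in place. Caveat: ablation_experiment registers it on post_attention_layernorm,
  whose output is the normalized residual stream fed to the MLP. Those channels do not correspond to a
  single attention head; per-head outputs sit in MODEL_SIZE-wide blocks only at the input of
  self_attn.o_proj. """
  def hook(module, input, output):
      output[:,:,head_index*MODEL_SIZE:(head_index+1)*MODEL_SIZE].data.zero_()
  return hook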

def ablation_experiment(model, ablation_list, print_pd = 5):
  """ Ablate the heads in ablation_list, compute performance with score_model on the entire training
  split, and remove the hooks no matter what happens. ablation_list is assumed to be formatted as
  (score, layer, head), though the score is not used here. """

  try:
    ablation_hooks = {}
    for i, (_, layer, head) in enumerate(ablation_list):
      ablation_hooks[i] = model.model.layers[layer].post_attention_layernorm.register_forward_hook(insert_ablation_hook(head))

    win = 0
    lose = 0
    lose_list = []
    for i, story in enumerate(dataset["train"]["story"]):
      if score_model(model, story, verbose=0, end=np.inf):
        win += 1
      else:
        lose += 1
        lose_list.append(i)

      if i % print_pd == 0:
        print(f"score: {win}/{win+lose}")

  finally:
    try:
      for hook in ablation_hooks.values():
        hook.remove()
      print("successfully removed all hooks")
    except Exception as e:
      raise Exception("Critical Error: failure to remove ablation hook") from e

  return win, lose, lose_list
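In the Hugging Face Mistral implementation, per-head attention outputs are concatenated in blocks of 128 (the head dimension) and only then mixed by `self_attn.o_proj`. Zeroing one head's contribution therefore means zeroing its block of the o_proj input, not channels of `post_attention_layernorm`'s output (the normalized residual stream entering the MLP). Below is a minimal sketch of that approach using a PyTorch forward pre-hook. `make_head_ablation_pre_hook` is a hypothetical helper, and the check uses a small random `nn.Linear` as a stand-in for o_proj rather than the quantized model.

```python
import torch
import torch.nn as nn

HEAD_DIM = 128  # Mistral 7B head dimension (hidden size 4096 / 32 query heads)

def make_head_ablation_pre_hook(head_indices, head_dim=HEAD_DIM):
    """Hypothetical helper: build a forward pre-hook that zeroes the listed heads'
    blocks of an attention output projection's input before it is applied."""
    def pre_hook(module, args):
        x = args[0].clone()  # copy so the caller's tensor is left untouched
        for h in head_indices:
            x[..., h * head_dim:(h + 1) * head_dim] = 0  # drop head h's output
        return (x,) + tuple(args[1:])  # the returned tuple replaces the positional args
    return pre_hook

# Toy check with a random stand-in for o_proj: 4 heads of width 8 projected to 32.
torch.manual_seed(0)
n_heads, head_dim, d_model = 4, 8, 32
o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)
x = torch.randn(2, 5, n_heads * head_dim)  # (batch, seq, n_heads * head_dim)

handle = o_proj.register_forward_pre_hook(make_head_ablation_pre_hook([1], head_dim))
with torch.no_grad():
    ablated = o_proj(x)
handle.remove()  # always remove hooks once the experiment is done

with torch.no_grad():
    W = o_proj.weight  # (d_model, n_heads * head_dim)
    # o_proj(x) is a sum of per-head terms; subtract head 1's term directly
    expected = x @ W.T - x[..., 8:16] @ W[:, 8:16].T
    unablated = o_proj(x)  # hook removed, so this is the unmodified projection

print(torch.allclose(ablated, expected, atol=1e-5))
```

In the notebook this would mean registering something like `model.model.layers[layer].self_attn.o_proj.register_forward_pre_hook(make_head_ablation_pre_hook([head]))` inside `ablation_experiment`, with the same `remove()` cleanup. Mistral 7B uses grouped-query attention (32 query heads sharing 8 key/value heads), and the o_proj input is laid out per query head, so head indices 0-31 index it directly, assuming the induction scores are indexed by query head as attention patterns are.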

In [None]:
""" run ablation experiment """

win_hs_abl, lose_hs_abl, lose_hs_list = ablation_experiment(model, top_ih) # 16: 1489/2000
# win_ls_abl, lose_ls_abl, lose_ls_list = ablation_experiment(model, matched_list) # 16: 1476/2000
# win_unabl, lose_unabl, lose_unabl_list = ablation_experiment(model, []) # 16: 1481/2000
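Because every condition is scored on the same 2000 stories, the `lose_list` returned by `ablation_experiment` allows a paired comparison. Below is a minimal sketch of an exact McNemar test, which uses only the stories that one condition gets wrong and the other gets right. `mcnemar_exact` is a hypothetical helper, not part of the original notebook.

```python
from math import comb

def mcnemar_exact(lose_a, lose_b):
    """Hypothetical helper: exact two-sided McNemar test on paired per-story outcomes.
    lose_a, lose_b: indices of stories answered incorrectly under conditions A and B.
    Returns (wrong only under A, wrong only under B, p-value)."""
    a, b = set(lose_a), set(lose_b)
    only_a, only_b = len(a - b), len(b - a)  # discordant pairs
    n = only_a + only_b
    if n == 0:
        return only_a, only_b, 1.0
    k = min(only_a, only_b)
    # two-sided exact binomial test on the discordant pairs with success probability 0.5
    p = 2 * sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return only_a, only_b, min(1.0, p)

print(mcnemar_exact([3, 7, 11, 20], [3, 7, 15]))  # -> (2, 1, 1.0)
```

For example, call `mcnemar_exact(lose_hs_list, lose_ls_list)` after running both the high- and low-score ablations. A large p-value is consistent with the null result summarized at the top, though it does not show that the ablations have no effect.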