# NDIF Main Demo Notebook

# Setup

In [2]:
# # Install necessary libraries (don't do this in cell, because you need ipykernel first anyway)
# %pip install -r requirements.txt

import os
import sys
from pathlib import Path
import gdown
import zipfile
from IPython.display import clear_output

root = Path("/root/function_vectors")
assert root.exists()

if not (root / "ndif-dev").exists():

    file_id = "1jS0ydba19uPXCC786_Sx1Bylqx-p0hEc"
    url = f"https://drive.google.com/uc?id={file_id}&export=download"
    output = "ndif-dev.zip"

    gdown.download(url, output, quiet=False)

    with zipfile.ZipFile(output, "r") as zip_ref:
        zip_ref.extractall(root)

    os.remove(output)

if 'ndif-dev' not in sys.path:
    sys.path.append('ndif-dev')
    # %pip install -e ndif-dev

import engine
from engine import LanguageModel
from engine.intervention import InterventionProxy

from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

# clear_output()

In [3]:
import torch as t
import einops
import circuitsvis as cv
import plotly.express as px
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import transformers
from pathlib import Path
import numpy as np
from jaxtyping import Int, Float
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
from IPython.display import display
import transformer_lens.utils as utils
import webbrowser
from rich import print as rprint
from rich.table import Table
import openai
import time
import gc
import string

from plotly_utils import imshow, line

import tests
# import function_vectors.tests as tests

device = t.device("cuda" if t.cuda.is_available() else "cpu")
assert str(device) == "cuda"

t.set_grad_enabled(False);

In [4]:
# gpt2 = LanguageModel('gpt2', device_map=device)
model = LanguageModel('EleutherAI/gpt-j-6b', device_map=device)
tokenizer = model.tokenizer

# 1️⃣ Introduction to `nnsight`

## Important syntax

Here, we'll discuss some important syntax for interacting with `nnsight` models. Since these models are extensions of HuggingFace models, some of this information (e.g. tokenization) applies to plain HuggingFace models as well as `nnsight` models, and some of it (e.g. forward passes) is specific to `nnsight`, i.e. it would work differently if you just had a standard HuggingFace model. Before each section, we'll indicate which is which.

### Model config

*This applies to HuggingFace and `nnsight` models.*

Each model comes with a config, accessible via `model.config`, which contains lots of useful information about the model (e.g. number of heads and layers, size of hidden layers, etc.). Run the code below to see this in action, and to define some useful variables for later.

In [5]:
N_HEADS = model.config.n_head
N_LAYERS = model.config.n_layer
D_MODEL = model.config.n_embd
D_HEAD = D_MODEL // N_HEADS

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

print(model.config)

Number of heads: 16
Number of layers: 28
Model dimension: 4096
Head dimension: 256

GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6b",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "transformers_ver



### Tokenizers

*This applies to HuggingFace and `nnsight` models.*

A model comes with a tokenizer, accessible with `model.tokenizer` (just like TransformerLens). Unlike TransformerLens, we won't be using utility functions like `model.to_str_toks`; instead, we'll be using the tokenizer directly. Some important functions for today's exercises are:

* `tokenizer` (i.e. just calling it on some input)
    * This takes in a string (or list of strings) and returns the tokenized version.
    * It will return a dictionary, always containing `input_ids` (i.e. the actual tokens) but also other things which are specific to the transformer model (e.g. `attention_mask` - see dropdown).
    * Other useful arguments for this function:
        * `return_tensors` - if this is `"pt"`, you'll get results returned as PyTorch tensors, rather than lists (which is the default).
        * `padding` - if True (default is False), shorter sequences are padded so that a batch of variable-length inputs can be returned together (e.g. as a single tensor). The shorter sequences get padded at the beginning (see dropdown below for more).
* `tokenizer.decode`
    * This takes in tokens, and returns the decoded string.
    * If the input is an integer, it returns the corresponding string. If the input is a list / 1D array of integers, it returns all those strings concatenated (which is sometimes not what you want).
* `tokenizer.batch_decode`
    * Equivalent to `tokenizer.decode`, but it doesn't concatenate.
    * If the input is a list / 1D integer array, it returns a list of strings. If the input is 2D, it will concatenate within each list.
* `tokenizer.tokenize`
    * Takes in a string, and returns a list of strings.

Run the code below to see some examples of these functions in action.

In [6]:
# Calling tokenizer returns a dictionary, containing input ids & other data.
# If returned as a tensor, then by default it will have a batch dimension.
print(tokenizer("This must be Thursday", return_tensors="pt"))

# Decoding a list of integers, into a concatenated string.
print(tokenizer.decode([40, 1239, 714, 651, 262, 8181, 286, 48971, 12545, 13]))

# Using batch decode, on both 1D and 2D input.
print(tokenizer.batch_decode([4711, 2456, 481, 307, 6626, 510]))
print(tokenizer.batch_decode([[1212, 6827, 481, 307, 1978], [2396, 481, 428, 530]]))

# Split sentence into tokens (note we see the special Ġ character in place of prepended spaces).
print(tokenizer.tokenize("This sentence will be tokenized"))

{'input_ids': tensor([[1212, 1276,  307, 3635]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
I never could get the hang of Thursdays.
['These', ' words', ' will', ' be', ' split', ' up']
['This sentence will be together', 'So will this one']
['This', 'Ġsentence', 'Ġwill', 'Ġbe', 'Ġtoken', 'ized']
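The difference between `decode` and `batch_decode` can be mimicked with a toy lookup table. This is a minimal sketch, assuming a hypothetical vocabulary `TOY_VOCAB` built from the token IDs and strings printed above; `toy_decode` and `toy_batch_decode` are illustrative names, and the real GPT-J tokenizer uses byte-level BPE, which this sketch ignores.

```python
from typing import Dict, List, Sequence, Union

# Hypothetical toy vocabulary, copied from the token IDs / strings printed in the cell above.
# (A real GPT-2 / GPT-J tokenizer works on bytes via BPE; this sketch ignores that.)
TOY_VOCAB: Dict[int, str] = {
    4711: "These", 2456: " words", 481: " will", 307: " be", 6626: " split", 510: " up",
    1212: "This", 6827: " sentence", 1978: " together", 2396: "So", 428: " this", 530: " one",
}


def toy_decode(ids: Union[int, Sequence[int]]) -> str:
    """Mimics `tokenizer.decode`: an int maps to one string, a 1D list is concatenated."""
    if isinstance(ids, int):
        return TOY_VOCAB[ids]
    return "".join(TOY_VOCAB[i] for i in ids)


def toy_batch_decode(batch: Sequence[Union[int, Sequence[int]]]) -> List[str]:
    """Mimics `tokenizer.batch_decode`: decodes each element separately (no concatenation across elements)."""
    return [toy_decode(item) for item in batch]


print(toy_decode([4711, 2456]))        # concatenated: "These words"
print(toy_batch_decode([4711, 2456]))  # kept separate: ['These', ' words']
print(toy_batch_decode([[1212, 6827, 481, 307, 1978], [2396, 481, 428, 530]]))
```

Note how the 2D case concatenates within each inner list, matching the behaviour described in the bullet list above.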


<details>
<summary>Note on <code>attention_mask</code> (optional)</summary>

`attention_mask` is a series of 1s and 0s. We mask attention at all 0-positions (i.e. we don't allow these tokens to be attended to). This is useful when you need to pad sequences. For example:

```python
model.tokenizer(["Hello world", "Hello"], return_tensors="pt", padding=True)
```

will return:

```
{
    'attention_mask': tensor([[1, 1], [0, 1]]),
    'input_ids': tensor([[15496,   995], [50256, 15496]])
}
```

We can see how the shorter sequence has been padded at the beginning, and attention to this token will be masked.

</details>
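Here is a minimal pure-Python sketch of the left-padding convention described above, assuming the EOS token (ID `50256`) is used as the pad token, as in the example output. `left_pad` is a hypothetical helper for illustration, not a tokenizer method.

```python
from typing import Dict, List


def left_pad(batch: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: left-pads token lists to equal length and builds the matching attention mask."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append([pad_id] * n_pad + ids)             # pad tokens go at the start
        attention_mask.append([0] * n_pad + [1] * len(ids))  # 0 = pad (masked), 1 = real token
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Token IDs from the example above: "Hello world" -> [15496, 995], "Hello" -> [15496]
print(left_pad([[15496, 995], [15496]], pad_id=50256))
# {'input_ids': [[15496, 995], [50256, 15496]], 'attention_mask': [[1, 1], [0, 1]]}
```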

### Model outputs

*This applies to HuggingFace and `nnsight` models.*

If you've worked with TransformerLens, then you'll be used to thinking of logits as the default output of a model, when you run a forward pass on that model.

HuggingFace models are a bit different. The standard way to get output from them is using the `model.generate` method. This method takes in a dictionary of inputs (which you can get from the tokenizer), and returns an object which contains a bunch of different things: the actual tokens generated by the model, plus maybe a few other things depending on what arguments you passed to `generate` (e.g. this might include logits, or hidden states).

The `nnsight` models we'll be using here are based on HuggingFace models, and we'll also be using `model.generate` which takes basically the same arguments, and produces an output object that contains the same kind of information. However, the exact way we use this method is quite different for `nnsight`...

### Running the model

*This only applies to `nnsight` models.*

Rather than just calling `model.generate`, we use a **context manager** to run the model. This is useful because we can access & do things with the internal state of the model, in the middle of the forward pass. Using this context manager is like setting up a set of detailed instructions for how the forward pass will work, and only when you exit the context manager are the instructions actually sent off & executed.
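If the idea of a context manager that only executes on exit is unfamiliar, the toy class below may help. It is an analogy under our own assumptions (the class `DeferredRun` and its `add_step` method are hypothetical), not how `nnsight` is implemented: inside the `with` block, instructions are only recorded, and `__exit__` runs them.

```python
from typing import Callable, List, Optional


class DeferredRun:
    """Toy analogy for a deferred-execution context manager (NOT nnsight's implementation)."""

    def __init__(self, inputs: List[float]):
        self.inputs = inputs
        self.steps: List[Callable[[List[float]], List[float]]] = []
        self.output: Optional[List[float]] = None  # stays None until the block exits

    def __enter__(self) -> "DeferredRun":
        return self

    def add_step(self, fn: Callable[[List[float]], List[float]]) -> None:
        # Inside the block we only *record* instructions; nothing is computed yet.
        self.steps.append(fn)

    def __exit__(self, exc_type, exc, tb) -> bool:
        # On exit, the recorded instructions are executed in order.
        if exc_type is None:
            value = self.inputs
            for fn in self.steps:
                value = fn(value)
            self.output = value
        return False  # don't swallow exceptions


with DeferredRun([1.0, 2.0]) as run:
    run.add_step(lambda v: [2 * a for a in v])
    run.add_step(lambda v: [a + 1 for a in v])
    print(run.output)  # None: nothing has run yet

print(run.output)  # [3.0, 5.0]
```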

Below is the simplest example of code to run the model (and also access the internal states of the model). Run it and look at the output, then read the explanation below.

In [7]:
import warnings

# Running this code so we don't get this printout each time we run a fwd pass
warnings.filterwarnings("ignore", message="Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.")


In [8]:
prompt = 'The Eiffel Tower is in the city of'

with model.generate(max_new_tokens=2, pad_token_id=tokenizer.eos_token_id) as generator:
    with generator.invoke(prompt) as invoker:
        hidden_states: InterventionProxy = model.transformer.h[-1].output[0].save()

# Get output, which is a tensor of token IDs, of shape (1, seq_len + max_new_tokens)
output = generator.output
print(tokenizer.batch_decode(output))

# Get hidden states, which are the value of the residual stream at last layer
print("Residual stream shape = ", hidden_states.value.shape)

['The Eiffel Tower is in the city of Paris,']
Residual stream shape =  torch.Size([1, 10, 4096])


Let's go over this piece by piece.

**First, we create a generation context block** by calling `.generate(...)` on the model object. This denotes that we wish to generate tokens given some prompts.

```python
with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:
```

Calling `.generate(...)` does not actually initialize or run the model. Only after the `with ... as generator:` block is exited is the model actually loaded and run. All operations in the block are "proxies", which essentially create a graph of operations we wish to carry out later.

The `max_new_tokens=1` argument means we do a single forward pass, rather than autoregressively generating multiple tokens (the cell above used `max_new_tokens=2`, which is why two tokens, ` Paris,`, were appended to the prompt). The `pad_token_id` argument isn't strictly necessary (this is the default behaviour anyway); it just suppresses a warning message that would otherwise be printed.

**Within the generation context,** we create invocation contexts to specify the actual prompts we want to run.

```python
with generator.invoke(prompt) as invoker:
```

**Within an invoke context**, all operations/interventions will be applied to the processing of the prompt. Models can be run on a variety of input formats: strings, lists of tokens, tensors of tokens, etc.

This is all we actually need to run a forward pass on the model. We could replace the `hidden_states` line with just `pass`, and we'd still be able to access the model output in the same way. But the most interesting part of `nnsight` is the ability to access the model's internal states (like you've probably already done with TransformerLens). Let's see how this works!

```python
hidden_states = model.transformer.h[-1].output[0].save()
```

On this line we're saying: access the last layer of the transformer `model.transformer.h[-1]`, access this layer's output `.output` (which is a tuple of tensors), index the first tensor in this tuple `.output[0]`, and save it `.save()`.

Let's break down this line in a bit more detail:

#### `model.transformer.h[-1]`

If you print out the model's architecture with `print(model)`, you'll see that it consists of `transformer` and `lm_head` (for "language modelling head"). The `transformer` module is made up of embeddings & dropout, a series of layers (called `.h`, for "hidden states"), and a final layernorm.

When you're working with different model architectures, it'll often be necessary to print out the model / visit the source code page, to see exactly how they work and what different modules are named. [Here](https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/gptj/modeling_gptj.html) is the source code page for GPT-J.

#### `.output[0]`

When you access `.output` of a module within a context manager, you're returning a **proxy** for the output of this module during inference. Doing operations on it (like indexing it) also returns proxies.

Note that modules often have their output (and input) stored in tuples. Even if there is only one output and one input, these can be stored as length-1 tuples. In this particular case, `model.transformer.h[-1]` is a `GPTJBlock` module, which outputs a length-2 tuple. The 0th element is the residual stream at the end of the block (i.e. the thing we want in this case).

<details>
<summary>Optional exercise - from the <a href="https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/gptj/modeling_gptj.html">source code page</a>, can you figure out what the second output in this tuple is?</summary>

The second output is also a tuple of tensors, of length 2. In the GPT-J source code, they are called `present`. They represent the keys and values which were calculated in this forward pass (as opposed to those that were calculated in an earlier forward pass, and cached by the model). Since we're only generating one new token, these are just the full keys and values.

</details>

When debugging, you can call `.shape` on a proxy. This will even work if the proxy represents a tuple of tensors; you'll get a tuple of all the sizes of these tensors.

You can also use `.input` to access the inputs to a module - this works in the same way (often also stored as a tuple).

#### `.save()`

This tells the computation graph to clone the value of a proxy, so that we can access it after generation has finished.

While the intervention graph we've built is being executed, the value of any proxy that is no longer needed is dereferenced and destroyed. If you've called `.save()` on a proxy, you can still access its value after this happens (i.e. outside the context manager) using the `.value` attribute.
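As a rough analogy for this save-or-discard behaviour, the toy graph below executes nodes in order, frees each intermediate value once nothing downstream needs it, and returns only the values marked with `save`. Everything here (`ToyGraph`, `add`, `save`, `run`) is hypothetical and much simpler than `nnsight`'s actual intervention graph.

```python
from typing import Any, Callable, Dict, Set, Tuple


class ToyGraph:
    """Toy model of 'dereference unless saved' (an analogy, not nnsight's internals)."""

    def __init__(self) -> None:
        # name -> (function, names of the nodes it depends on); insertion order is execution order
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        self.saved: Set[str] = set()

    def add(self, name: str, fn: Callable[..., Any], *deps: str) -> None:
        self.nodes[name] = (fn, deps)

    def save(self, name: str) -> None:
        self.saved.add(name)

    def run(self) -> Dict[str, Any]:
        # Count how many later nodes still need each value
        remaining = {name: 0 for name in self.nodes}
        for _, deps in self.nodes.values():
            for d in deps:
                remaining[d] += 1

        values: Dict[str, Any] = {}
        for name, (fn, deps) in self.nodes.items():
            values[name] = fn(*(values[d] for d in deps))
            for d in deps:
                remaining[d] -= 1
                if remaining[d] == 0 and d not in self.saved:
                    del values[d]  # no longer needed and not saved: drop it

        # After execution, only saved values survive
        return {name: v for name, v in values.items() if name in self.saved}


g = ToyGraph()
g.add("x", lambda: 3)
g.add("h", lambda x: x * 2, "x")
g.add("logits", lambda h: h + 1, "h")
g.save("h")
print(g.run())  # {'h': 6}
```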

### Exercise - visualize attention heads

```c
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-20 minutes on this exercise.
```

That was a lot, so let's put it into practice. Your first task is to extract the attention patterns from the zeroth layer of the transformer, and visualize them using circuitsvis. As a reminder, the syntax for circuitsvis is:

```python
cv.attention.attention_patterns(
    tokens=tokens,
    attention=attention,
)
```

where `tokens` is a list of strings, and `attention` is a tensor of shape `(num_heads, num_tokens, num_tokens)`.
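Before passing a pattern to circuitsvis, it can be worth sanity-checking it. The sketch below is a hypothetical helper (not part of circuitsvis) which checks the expected shape, that each query row sums to 1, and that no weight falls on future positions (as should hold for a causal model like GPT-J on an unpadded prompt). It takes nested lists, so a PyTorch tensor would first need converting with `.tolist()`.

```python
import math
from typing import List


def check_attention_pattern(tokens: List[str], attention: List[List[List[float]]], tol: float = 1e-4) -> None:
    """Hypothetical sanity check for a causal attention pattern of shape (num_heads, num_tokens, num_tokens)."""
    n = len(tokens)
    for h, head in enumerate(attention):
        assert len(head) == n, f"head {h}: expected {n} query rows, got {len(head)}"
        for q, row in enumerate(head):
            assert len(row) == n, f"head {h}, query {q}: expected {n} key columns, got {len(row)}"
            # Each query's attention weights should form a probability distribution...
            assert math.isclose(sum(row), 1.0, abs_tol=tol), f"head {h}, query {q}: row sums to {sum(row)}"
            # ...and a causal model should not attend to later positions.
            assert all(abs(w) <= tol for w in row[q + 1:]), f"head {h}, query {q}: attends to the future"


# One head, two tokens: passes silently
check_attention_pattern(["The", " E"], [[[1.0, 0.0], [0.3, 0.7]]])
```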

If you're stuck, [here's a link](https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/gptj/modeling_gptj.html) to the source code for GPT-J. Look for how the attention patterns are calculated, within the `GPTJAttention` block.

*Note - this model uses dropout on the attention probabilities, as you'll probably notice from looking at the source code in the link above. This won't affect the model's behaviour because dropout is disabled in inference mode (and using the `generate` method always puts a model in inference mode). But it is still a layer which exists in the model, so you can access its input or output just like any other module.*

<details>
<summary>Aside - inference mode</summary>

Dropout is one of the two main layers whose behaviour changes in inference mode (the other is BatchNorm).

If you want to run the model without inference mode, you can wrap your code in `with model.forward(inference=False):`. However, you don't need to worry about this for the purposes of these exercises.

</details>

If you're stuck indexing the model, see the following hint:

<details>
<summary>Hint - what module you should get attention from</summary>

You want to extract attention from `model.transformer.h[0].attn.attn_dropout.input`. If you used `.output`, it would give you the same values (although they might differ by a dummy batch dimension). Both of these will return a single tensor, because dropout layers take just one input and return just one output.
</details>

In [9]:
with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:
    with generator.invoke(prompt) as invoker:
        attn_patterns = model.transformer.h[0].attn.attn_dropout.input.save()

In [10]:
# Get string tokens (replacing special character for spaces)
str_tokens = model.tokenizer.tokenize(prompt)
str_tokens = [s.replace('Ġ', ' ') for s in str_tokens]

# Attention patterns (squeeze out the batch dimension)
attn_patterns_value = attn_patterns.value[0].squeeze(0)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attn_patterns_value,
))

Layer 0 Head Attention Patterns:


<details>
<summary>Solution (and explanation)</summary>

```python
with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:
    with generator.invoke(prompt) as invoker:
        attn_patterns = model.transformer.h[0].attn.attn_dropout.input.save()

str_tokens = model.tokenizer.tokenize(prompt)

# Attention patterns (squeeze out the batch dimension)
attn_patterns = attn_patterns.value[0].squeeze(0)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attn_patterns,
))
```

Explanation:

* Within the context managers:
    * We access the attention patterns by taking the input to the `attn_dropout`.
        * From the GPT-J source code, we can see that the attention weights are calculated by standard torch functions (and an unnamed `nn.Softmax` module) from the key and query vectors, and are then passed through the dropout layer before being used to calculate the attention layer output. So by accessing the input to the dropout layer, we get the attention weights before dropout is applied.
        * Because of the previously discussed point about dropout not working in inference mode, we could also use the output of `attn_dropout`, and get the same values.
    * We use the `.save()` method to save the attention patterns (as an object).
* Outside of the context managers:
    * We use the `tokenize` method to tokenize the prompt.
    * We use the `.value` to access the actual value of the intervention proxy `attn_patterns`.
        * This returns a tuple of length-1, so we index into it to get the actual tensor, then squeeze to remove the batch dimension.
        
</details>


As an optional bonus exercise, you can verify for yourself that these are the correct attention patterns, by calculating them from scratch using the key and query vectors. Using `model.transformer.h[0].attn.q_proj.output` will give you the query vectors, and `k_proj` for the key vectors. However, one thing to be wary of is that GPT-J uses **rotary embeddings**, which makes the computation of attention patterns from keys and queries a bit harder than it would otherwise be. See [here](https://blog.eleuther.ai/rotary-embeddings/) for an in-depth discussion of rotary embeddings, and [here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#q=rotary) for some rough intuitions.
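For the bonus exercise, here is a slow, pure-Python sketch of the computation for a single head. It assumes the conventions we believe the HuggingFace GPT-J code uses: rotary embeddings rotate consecutive pairs `(x[2i], x[2i+1])` of only the first `rotary_dim` dimensions of each head (64 of 256, per the config printed earlier) by the angle `pos * 10000**(-2i/rotary_dim)`; positions start at 0 with no padding; and scores are divided by `sqrt(d_head)` before a causally masked softmax. Treat it as a starting point to check against the source code, not a reference implementation. The `q_proj` / `k_proj` outputs would also first need splitting into heads of size `D_HEAD`. The function names `rotary` and `attention_pattern` are our own.

```python
import math
from typing import List


def rotary(x: List[float], pos: int, rotary_dim: int, base: float = 10000.0) -> List[float]:
    """Rotate pairs (x[2i], x[2i+1]) of the first `rotary_dim` dims by angle pos * base**(-2i / rotary_dim).
    Dimensions beyond `rotary_dim` are passed through unchanged."""
    out = list(x)
    for i in range(rotary_dim // 2):
        theta = pos * base ** (-2 * i / rotary_dim)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * math.cos(theta) - b * math.sin(theta)
        out[2 * i + 1] = b * math.cos(theta) + a * math.sin(theta)
    return out


def attention_pattern(q: List[List[float]], k: List[List[float]], rotary_dim: int) -> List[List[float]]:
    """Causal softmax attention pattern for ONE head; q and k are lists of per-position vectors of length d_head."""
    n, d_head = len(q), len(q[0])
    q_rot = [rotary(v, pos, rotary_dim) for pos, v in enumerate(q)]
    k_rot = [rotary(v, pos, rotary_dim) for pos, v in enumerate(k)]
    pattern = []
    for i, qi in enumerate(q_rot):
        # Scaled dot-product scores against keys at positions 0..i only (causal mask)
        scores = [sum(a * b for a, b in zip(qi, k_rot[j])) / math.sqrt(d_head) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        pattern.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return pattern


# Toy example: 4 positions, head dimension 8, rotary applied to the first 4 dims
q = [[0.1 * (p + d) for d in range(8)] for p in range(4)]
k = [[0.05 * (p - d) for d in range(8)] for p in range(4)]
for row in attention_pattern(q, k, rotary_dim=4):
    print([round(w, 3) for w in row])
```

A useful property to check: with this rotation, the dot product between a rotated query at position `m` and a rotated key at position `n` depends only on `n - m`, which is how rotary embeddings encode relative position.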

# 2️⃣ Task-encoding hidden states

(Note - this section structurally follows section 2.1 of the function vectors paper).

We begin with the following question:

> When a transformer processes an ICL (in-context-learning) prompt with exemplars demonstrating task $T$, do any hidden states encode the task itself?

Throughout these exercises, we'll be focusing on the **antonyms task**. In other words, given a prompt which includes a bunch of antonym pairs, ending with a single word, what causes the model to complete this prompt with an antonym? Is there a residual stream state that encodes the "antonym task"?

### Exercise (optional) - generate your own antonym pairs

```c
Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 10-30 minutes on this exercise - depending on your familiarity with the OpenAI Python API.
```

We've provided you with a list of word pairs, in the file `data/antonym_pairs.txt`. **Optionally, you can just skip this exercise, and run the code below to load these words in.**

Alternatively, if you want to run experiments like the ones in this paper, it can be good practice to learn how to generate prompts from GPT-4 or other models (this is how we generated the data for this exercise). To do this, you can fill in the `generate_dataset` function below, which should query GPT-4 and get a list of antonym pairs. See [here](https://platform.openai.com/docs/guides/gpt/chat-completions-api) for a guide to using the chat completions API, if you haven't already used it.

Use the two dropdowns below (in order) for some guidance.

<details>
<summary>Getting started</summary>

Here is a recommended template:

```python
response = openai.ChatCompletion.create(
    model="gpt-4",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": antonym_task},
        {"role": "assistant", "content": start_of_response},
    ]
)
```

where `antonym_task` explains the antonym task, and `start_of_response` gives the model a prompt to start from (e.g. "Sure, here are some antonyms: ..."), to guide its subsequent behaviour.

</details>

<details>
<summary>Template for the request</summary>

Here is a template you might want to use for the actual request:

```python
example_antonyms = "old:young, top:bottom, awake:asleep, future:past, "

response = openai.ChatCompletion.create(
    model="gpt-4",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Give me {N} examples of antonym pairs. They should be obvious, i.e. each word should be associated with a single correct antonym."},
        {"role": "assistant", "content": f"Sure! Here are {N} pairs of antonyms: {example_antonyms}"},
    ]
)
```

where `N` is the function argument. Note that we've provided a few example antonyms and placed them at the start of GPT-4's response (as a partial assistant message), to encourage the model to continue the list in the same format. This is a classic trick to guide the rest of the output (in fact, it's commonly used in adversarial attacks).

</details>

Note - it's possible that not all the antonyms returned will be solvable by GPT-J. In this section, we won't worry too much about this. When it comes to testing out our zero-shot intervention, we'll make sure to only use cases where GPT-J can actually solve it.

In [24]:
openai.api_key = "<your-key-here>"

def generate_dataset(N: int):

    t0 = time.time()

    # Define a few examples (for our dataset, and for our prompt)
    example_antonyms = "old:young, top:bottom, awake:asleep, future:past, "

    # Use openai's api to generate examples. We prepend the example antonyms to the assistant's response, to both
    # make sure the query is successful, and so that the assistant returns words in the same syntax as the examples.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Give me {N} examples of antonym pairs. They should be obvious, i.e. each word should be associated with a single correct antonym."},
            {"role": "assistant", "content": f"Sure! Here are {N} pairs of antonyms satisfying this specification: {example_antonyms}"},
        ]
    )
    # Add our examples to the response
    response_text: str = example_antonyms + response["choices"][0]["message"]["content"]

    # Create word pairs, by splitting on commas and colons
    word_pairs = [word_pair.split(":") for word_pair in response_text.strip(".\n").split(", ")]

    print(f"Finished in {time.time()-t0:.2f} seconds.")

    return word_pairs


# WORD_PAIRS = generate_dataset(100)

# # save the word pairs in a text file
# with open(root / "data" / "antonym_pairs.txt", "w") as f:
#     for word_pair in WORD_PAIRS:
#         f.write(f"{word_pair[0]} {word_pair[1]}\n")

In [25]:
# load the word pairs from the text file
with open(root / "data" / "antonym_pairs.txt", "r") as f:
    WORD_PAIRS = [line.split() for line in f.readlines()]

In [None]:
WORD_PAIRS[:10]

[['old', 'young'],
 ['top', 'bottom'],
 ['awake', 'asleep'],
 ['future', 'past'],
 ['beginning', 'end'],
 ['volunteer', 'compel'],
 ['best', 'worst'],
 ['big', 'small'],
 ['boring', 'exciting'],
 ['brave', 'cowardly']]

## Antonym Dataset

To handle this list of word pairs, we've given you some helpful classes.

Firstly, there's the `AntonymSequence` class, which takes in a list of word pairs and contains methods for constructing a prompt (and completion) from these words. Run the code below to see how it works.

In [14]:
class AntonymSequence:
    '''
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    '''
    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    def prompt(self):
        '''Returns the prompt, which contains all but the second element in the last word pair.'''
        p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
        return p[:-len(self.completion())]

    def completion(self):
        '''Returns the second element in the last word pair (with padded space).'''
        return " " + self.y[-1]

    def __str__(self):
        '''Prints a readable string representation of the prompt & completion (indep of template).'''
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = AntonymSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
Q: hot
A: cold

Q: yes
A: no

Q: in
A: out

Q: up
A:


Secondly, we have the `AntonymDataset` class. This is also fed a word pair list, and it has methods for generating batches of prompts and completions. It can generate both clean prompts (where each pair is actually an antonym pair) and corrupted prompts (where the answer in every pair except the last is replaced by a random word from the dataset).

In [69]:
class AntonymDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair except the last is replaced with a random word from the dataset
        seed:
            random seed, for consistency & reproducibility
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        corrupted: bool = False,
        seed: int = 0,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."
        
        self.word_pairs = word_pairs
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `AntonymSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = np.random.choice([1, -1], n_prepended+1)
            if not(bidirectional): random_orders[:] = 1
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = AntonymSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return AntonymDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

You can see how this dataset works below. **Note that the correct completions have a prepended space**, because this is how the antonym prompts are structured - the answers are tokenized as `"A: answer" -> ["A", ":", " answer"]`. Forgetting prepended spaces is a classic mistake when working with transformers!

In [16]:
dataset = AntonymDataset(WORD_PAIRS, size=10, n_prepended=2, corrupted=False)

table = Table("Prompt", "Correct completion")
for seq, completion in zip(dataset.seqs, dataset.completions):
    table.add_row(str(seq), repr(completion))

rprint(table)

Compare this output to what it looks like when `corrupted=True`. You'll see the second elements of each pair change to a random word, but the first elements (and the final pair) stay the same.

In [17]:
dataset = AntonymDataset(WORD_PAIRS, size=10, n_prepended=2, corrupted=True)

table = Table("Prompt", "Correct completion")
for seq, completions in zip(dataset.seqs, dataset.completions):
    table.add_row(str(seq), repr(completions))

rprint(table)

<details>
<summary>Aside - the <code>rich</code> library</summary>

The `rich` library is a helpful little library to display outputs more clearly in a Python notebook or terminal. It's not necessary for this workshop, but it's a nice little tool to have in your toolbox.

The most important function is `rich.print` (usually imported as `rprint`). This can print basic strings, but it also supports the following syntax for printing colors:

```python
rprint("[green]This is green text[/], this is default color")
```

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/rprint-1.png" width="350">

and for making text bold / underlined:

```python
rprint("[u dark_orange]This is underlined[/], and [b cyan]this is bold[/].")
```

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/rprint-2.png" width="350">

It can also print tables:

```python
from rich.table import Table

table = Table("Col1", "Col2", title="Title") # title is optional
table.add_row("A", "a")
table.add_row("B", "b")

rprint(table)
```

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/rprint-3.png" width="150">

The text formatting (bold, underlined, colors, etc.) is also supported within table cells.

</details>

### Exercise - run forward pass on antonym dataset

```c
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

You should fill in the `calculate_h` function below. It should:

* Take the prompts from the antonym dataset (i.e. `dataset.prompts`),
* Run a forward pass on the model, using the `nnsight` syntax we've demonstrated previously,
* Return a tuple of two things. The first is the model's completions, decoded into strings. The second is the residual stream value at the end of layer `layer`, averaged over all prompts. For example, if `layer = -1`, this is the final value of the residual stream before it's converted into logits.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/h-intervention-1.png" width="900">

You should only return the residual stream values for the very last sequence position in each prompt, i.e. the last `:` token (where the model makes the antonym prediction).
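This shape check covers the two steps above: take the final sequence position, then average over the batch. It is a minimal NumPy sketch on a random stand-in for a layer's `output[0]` tensor. The sizes are hypothetical. The torch calls `[:, -1]` and `.mean(dim=0)` behave the same way as the NumPy ones used here.

```python
import numpy as np

# Hypothetical sizes, standing in for (batch, seq_len, d_model)
batch, seq_len, d_model = 4, 7, 16
rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(batch, seq_len, d_model))

# Residual stream at the final position of every prompt: (batch, d_model)
last_pos = hidden_states[:, -1]

# Average over the batch to get a single vector h: (d_model,)
h = last_pos.mean(axis=0)

print(last_pos.shape, h.shape)  # (4, 16) (16,)
```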

<details>
<summary>Help - I'm not sure how to run (and index into) a batch of inputs.</summary>

If we pass a list of strings to the `generator.invoke` function, this will be tokenized with padding automatically.

The type of padding which is applied is **left padding**, meaning if you index at sequence position `-1`, this will get the final token in the prompt for all prompts in the list, even if the prompts have different lengths.

</details>
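The padding side determines what index `-1` picks out. The sketch below pads two prompts of different lengths, using plain lists of made-up token IDs and a hypothetical pad ID of `0`. It shows that index `-1` lands on the last real token only under left padding. `pad_batch` is a hypothetical helper, not part of `nnsight`.

```python
import numpy as np

PAD = 0  # hypothetical pad token ID

def pad_batch(seqs, side="left"):
    """Pad variable-length lists of token IDs into a rectangular array."""
    max_len = max(len(s) for s in seqs)
    rows = []
    for s in seqs:
        padding = [PAD] * (max_len - len(s))
        rows.append(padding + s if side == "left" else s + padding)
    return np.array(rows)

# Two prompts of different lengths; the last real token of each is 99
prompts = [[5, 6, 99], [7, 8, 9, 10, 99]]

left = pad_batch(prompts, side="left")
right = pad_batch(prompts, side="right")

print(left[:, -1])   # [99 99]: always the final prompt token
print(right[:, -1])  # [ 0 99]: padding for the shorter prompt
```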

In [18]:
def calculate_h(dataset: AntonymDataset, layer: int = -1) -> Tuple[List[str], Tensor]:
    '''
    Runs the model on the prompts in `dataset` (sequences of the form "old:young, vanish:appear, dark:"), generating one new token for each.

    Averages the hidden state at the final sequence position of layer `layer` over all prompts in the batch.

    Returns:
        completions: list of model completion strings (i.e. the strings the model predicts to follow the last token)
        h: average hidden state tensor at final sequence position, of shape (d_model,)
    '''
    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:
        with generator.invoke(dataset.prompts) as invoker:
            hidden_states = model.transformer.h[layer].output[0][:, -1].save()

    completions = generator.output[:, -1]
    completions = model.tokenizer.batch_decode(completions)
    h = hidden_states.value.mean(dim=0)

    return completions, h


tests.test_calculate_h(calculate_h)

All tests in `test_calculate_h` passed.


We've provided you with a helper function, which displays the model's output on the antonym dataset (and highlights the examples where the model's prediction is correct).

If you've constructed your antonyms dataset well, you should find that the model's completion is correct most of the time, and most of its mistakes are understandable (e.g. predicting `weak` rather than `fragile` as the antonym of `strong`). If we were being rigorous, we'd want to filter this dataset to make sure it only contains examples where the model can correctly perform the task - but here, we won't worry about this.

In [19]:
def display_model_completions_on_antonyms(
    dataset: AntonymDataset,
    completions: List[str],
    num_to_display: int = 20,
) -> None:
    table = Table("Prompt", "Model's completion", "Correct completion", title="Model's antonym completions (green = first token is a match)")

    for i in range(min(len(completions), num_to_display)):

        # Get model's completion, and correct completion
        completion = completions[i]
        correct_completion = dataset.completions[i]
        correct_completion_first_token = model.tokenizer.tokenize(correct_completion)[0].replace('Ġ', ' ')
        seq = dataset.seqs[i]
        
        # Color code the completion based on whether it's correct
        is_correct = (completion == correct_completion_first_token)
        completion = f"[b green]{repr(completion)}[/]" if is_correct else repr(completion)

        table.add_row(str(seq), completion, repr(correct_completion))

    rprint(table)

In [27]:
# Get uncorrupted dataset
dataset = AntonymDataset(WORD_PAIRS, size=20, n_prepended=2)

# Getting it from layer 12, because the graph suggested this was where accuracy was high
model_completions, h = calculate_h(dataset, layer=12)

# Displaying the output
display_model_completions_on_antonyms(dataset, model_completions)

### Exercise - intervene with $h$

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 10-15 minutes on this exercise.
```

You should fill in the function `intervene_with_h` below. This will involve:

* Using the `calculate_h` function you just wrote to get the h-vector (this code is already filled in below),
* Defining a zero-shot dataset, i.e. with no prepended antonym pairs,
* Running two forward passes (within the same context manager):
    * One with no intervention (i.e. the residual stream is unchanged),
    * One with an intervention, where `h` is added to the residual stream at the final sequence position, at the layer which `h` was taken from.
* Returning the zero-shot dataset, as well as the completions for the no-intervention and intervention cases respectively (see docstring).

The diagram below shows how all of this should work, when combined with the `calculate_h` function.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/h-intervention-2.png" width="950">

Hint - you can use `tokenizer.batch_decode` to turn a list of tokens into a list of strings.

<details>
<summary>Help - I'm not sure how best to get both the no-intervention and intervention completions.</summary>

You can use `with generator.invoke...` more than once within the same context manager, in order to add to your batch. This will eventually give you output of shape (2*N, seq_len), which can then be indexed and reshaped to get the completions in the no intervention & intervention cases respectively.

</details>
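The reshape step assumes the two invokes are concatenated along the batch dimension in the order they appear, which is what the solution below relies on. This NumPy sketch uses made-up token IDs as a stand-in for `generator.output[:, -1]`. It shows that a row-major `reshape(2, N)` puts the first invoke in row 0 and the second in row 1.

```python
import numpy as np

N = 3  # hypothetical number of zero-shot prompts

# Stand-in for generator.output[:, -1]: one generated token ID per batch row.
# The first invoke contributes rows 0..N-1, the second invoke rows N..2N-1.
no_intervention_ids = np.array([11, 12, 13])
intervention_ids = np.array([21, 22, 23])
last_tokens = np.concatenate([no_intervention_ids, intervention_ids])  # shape (2*N,)

# Row-major reshape: row 0 = first invoke, row 1 = second invoke
token_completions = last_tokens.reshape(2, N)
```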

<details>
<summary>Help - I'm not sure how to intervene on the hidden state.</summary>

First, you can define the tensor of hidden states (i.e. using `.output[0]`, like you've done before).

Then, you can add to this tensor directly (or add to some indexed version of it). You can use inplace operations (i.e. `tensor += h`) or redefine the tensor (i.e. `tensor = tensor + h`); either works.

You won't need to `.save()` anything here; we're just intervening rather than storing the value of the residual stream.

</details>
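The sketch below shows the broadcasting behind `hidden_states[:, -1] += h`. It uses a plain NumPy array with hypothetical sizes in place of an `nnsight` proxy. It illustrates only the indexing and broadcasting, not how `nnsight` applies the edit inside the model.

```python
import numpy as np

N, seq_len, d_model = 2, 4, 3  # hypothetical sizes
hidden_states = np.zeros((N, seq_len, d_model))
h = np.array([1.0, 2.0, 3.0])  # stand-in for the h-vector, shape (d_model,)

# In-place add: h broadcasts over the batch dimension, and only the final
# sequence position of each prompt is changed
hidden_states[:, -1] += h

print(hidden_states[:, -1])          # every row equals h
print(hidden_states[:, :-1].any())   # False: earlier positions untouched
```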

In [30]:
def intervene_with_h(
    dataset: AntonymDataset,
    layer: int,
    zero_shot_size: int,
) -> Tuple[AntonymDataset, List[str], List[str]]:
    '''
    Extracts the vector `h` using previously defined function, and intervenes by adding `h` to the
    residual stream of a set of generated zero-shot prompts.

    Inputs:
        dataset: the dataset of clean prompts from which we'll extract the `h`-vector
        layer: the layer we'll be extracting the `h`-vector from
        zero_shot_size: the number of zero-shot prompts to generate, to test our intervention
    
    Returns:
        zero_shot_dataset: the dataset of zero-shot prompts, which you should generate in this fn
        completions: list of string completions for the zero-shot prompts, without intervention
        completions_intervention: list of string completions for the zero-shot prompts, with h-intervention
    '''
    # Run previous function to get h-vector
    h = calculate_h(dataset, layer=layer)[1]
    
    # Get zero-shot dataset
    zero_shot_dataset = AntonymDataset(WORD_PAIRS, size=zero_shot_size, n_prepended=0)

    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:

        # First, run a forward pass where we don't intervene
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            pass

        # Next, run a forward pass on the zero-shot prompts where we do intervene
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            # Access the tensor (which is the first element of the output tuple)
            hidden_states = model.transformer.h[layer].output[0]
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states[:, -1] += h

    # Get the output (token IDs), reshape it into 2 rows of (no intervention, intervention)
    token_completions = generator.output[:, -1].reshape(2, zero_shot_size)
    # Decode to get the string tokens
    completions = model.tokenizer.batch_decode(token_completions[0])
    completions_intervention = model.tokenizer.batch_decode(token_completions[1])

    return zero_shot_dataset, completions, completions_intervention

Next, you can run the code below to see the results of your intervention. In cases where your model is correct, its completion is highlighted in green.

(Note, we're using the `repr` function, because a lot of the completions are line breaks, and this helps us see them more clearly!)

If you've done this correctly, you should see at least a few correct completions (~25%).

In [31]:
def display_model_completions_on_h_intervention(
    dataset: AntonymDataset,
    completions: List[str],
    completions_intervention: List[str],
    num_to_display: int = 20,
) -> None:
    table = Table(
        "Prompt", "Model's completion\n(no intervention)", "Model's completion\n(intervention)", "Correct completion",
        title="Model's antonym completions\n(green = first token is a match)"
    )

    for i in range(min(len(completions), num_to_display)):

        completion_ni = completions[i]
        completion_i = completions_intervention[i]
        correct_completion = dataset.completions[i]
        correct_completion_first_token = tokenizer.tokenize(correct_completion)[0].replace('Ġ', ' ')
        seq = dataset.seqs[i]
        
        # Color code the completion based on whether it's correct
        is_correct = (completion_i == correct_completion_first_token)
        completion_i = f"[b green]{repr(completion_i)}[/]" if is_correct else repr(completion_i)

        table.add_row(str(seq), repr(completion_ni), completion_i, repr(correct_completion))

    rprint(table)

In [32]:
dataset = AntonymDataset(WORD_PAIRS, size=20, n_prepended=2)

zero_shot_dataset, model_completions, model_completions_intervention = intervene_with_h(dataset, layer=12, zero_shot_size=20)

display_model_completions_on_h_intervention(zero_shot_dataset, model_completions, model_completions_intervention)

### Exercise - combine the functions `calculate_h` and `intervene_with_h`

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

One great feature of the `nnsight` library is its ability to parallelize forward passes and perform complex interventions within a single context manager.

In the code above, we had one function to extract the hidden states from the model, and another function where we intervened with those hidden states. But we can actually do both at once: we can compute $h$ within our forward pass, and then intervene with it on a different forward pass (using our zero-shot prompts), all within the same `model.generate` context manager. In other words, **we'll be using `with generator.invoke...` three times** in this context manager.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/h-intervention-3.png" width="1000">

You should fill in the `calculate_h_and_intervene` function below, to do this. Mostly, this should involve combining your `calculate_h` and `intervene_with_h` functions, and wrapping the forward passes in the same context manager (plus a bit of code rewriting).

Your output should be exactly the same as before (since the `AntonymDataset` class is deterministic).

<details>
<summary>Help - I'm not sure how to work with the <code>h</code> vector.</summary>

You extract `h` the same way as before, but you don't need to save it, or ever reference its `.value` attribute. It is kept as a proxy. You can still use it later in the context manager, as if it were a tensor.

You shouldn't have to `.save()` anything inside your context manager.

</details>
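With three invokes, the clean sub-batch comes first and has to be sliced off before splitting the rest. This NumPy sketch lays out a stand-in for `generator.output[:, -1]` with hypothetical sub-batch sizes and made-up token IDs. It assumes the invokes are concatenated in order: clean, zero-shot without intervention, zero-shot with intervention.

```python
import numpy as np

n_clean, zero_shot_size = 4, 3  # hypothetical sub-batch sizes

# Stand-in for generator.output[:, -1] with three invokes concatenated:
# [clean prompts | zero-shot, no intervention | zero-shot, intervention]
clean = np.full(n_clean, -1)        # predictions we don't need here
no_int = np.array([10, 11, 12])
with_int = np.array([20, 21, 22])
last_tokens = np.concatenate([clean, no_int, with_int])

# Drop the clean sub-batch, then split the rest into (no intervention, intervention)
token_completions = last_tokens[n_clean:].reshape(2, zero_shot_size)
```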

In [24]:
def calculate_h_and_intervene(
    dataset: AntonymDataset,
    layer: int,
    zero_shot_size: int,
) -> Tuple[AntonymDataset, List[str], List[str]]:
    '''
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass.

    Inputs:
        dataset: the dataset of clean prompts from which we'll extract the `h`-vector
        layer: the layer we'll be extracting the `h`-vector from
        zero_shot_size: the number of zero-shot prompts to generate, to test our intervention
    
    Returns:
        zero_shot_dataset: the dataset of zero-shot prompts, which you should generate in this fn
        model_completions: list of string completions for the zero-shot prompts, without intervention
        model_completions_intervention: list of string completions for the zero-shot prompts, with h-intervention
    '''

    # Get zero-shot dataset
    zero_shot_dataset = AntonymDataset(WORD_PAIRS, size=zero_shot_size, n_prepended=0)

    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id) as generator:
        
        # Run on the clean prompts, to get the h-vector
        with generator.invoke(dataset.prompts) as invoker:
            # Define h (we don't need to save it, because we don't need it outside `generator:`)
            hidden_states = model.transformer.h[layer].output[0]
            h = hidden_states[:, -1].mean(dim=0)

        # First, run a forward pass where we don't intervene
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            pass

        # Next, run a forward pass on the zero-shot prompts where we do intervene
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            # Access the tensor (which is the first element of the output tuple)
            hidden_states = model.transformer.h[layer].output[0]
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states[:, -1] += h

    # Get the output (token IDs), drop the rows from `dataset`, and reshape into (no intervention, intervention)
    token_completions = generator.output[len(dataset):, -1].reshape(2, zero_shot_size)
    # Decode to get the string tokens
    model_completions = model.tokenizer.batch_decode(token_completions[0])
    model_completions_intervention = model.tokenizer.batch_decode(token_completions[1])

    return zero_shot_dataset, model_completions, model_completions_intervention

In [25]:
dataset = AntonymDataset(WORD_PAIRS, size=20, n_prepended=2)

zero_shot_dataset, model_completions, model_completions_intervention = calculate_h_and_intervene(dataset, layer=12, zero_shot_size=20)

display_model_completions_on_h_intervention(zero_shot_dataset, model_completions, model_completions_intervention)



## Logit outputs

Currently, we've only seen what the model's highest-probability output is, because that's all we got from the `generator.output` object. But what if we want to look at the logits / probabilities instead, and see how those change when we intervene?

There are two ways you can get the model's logits:

1. **Add more arguments to the `generate` method.**

Just like standard HuggingFace models have extra arguments which can be supplied to the `generate` method, so do `nnsight` models. We can replace the line:

```python
with model.generate(...) as generator:
```

with:

```python
with model.generate(..., output_scores=True, return_dict_in_generate=True) as generator:
```

and then the `generator.output` object won't be a tensor containing the model's completion, instead it will be an object that contains both the completions ***and*** the logits. From this object, you can access:

* `generator.output.sequences` = model completions, i.e. a tensor of token IDs, of shape `(batch_size, seq_len)`
* `generator.output.scores` = logits, in the form of a tuple of tensors of shape `(batch_size, vocab_size)` (the tuple has one element for each generated token, so in this case it will have length 1).

2. **For general models, you can access the logits just like you would any other hidden state.**

For example, in GPT-J and similar models, `model.lm_head` gives you the final linear layer of the transformer, i.e. the one that maps from the hidden state to logits. You can then use its `.output` attribute to get a proxy for the output of this layer, and `.save()` to save it, just like you've done in previous exercises.

Either approach is fine; which one you use is up to personal preference. The solutions will use the first approach.
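To make the shapes in the first approach concrete, the sketch below builds a NumPy stand-in for `generator.output.scores` with hypothetical sizes and two generated tokens. That gives a tuple of length 2. It stacks the steps into one array and assumes greedy decoding (no sampling). Under that assumption, each generated token is the argmax of its step's scores.

```python
import numpy as np

batch, vocab, n_new = 3, 5, 2  # hypothetical sizes
rng = np.random.default_rng(0)

# Stand-in for generator.output.scores: a tuple with one (batch, vocab)
# array per generated token (so length 1 when max_new_tokens=1)
scores = tuple(rng.normal(size=(batch, vocab)) for _ in range(n_new))

# Stack into a single (batch, n_new, vocab) array to look at all steps at once
all_scores = np.stack(scores, axis=1)

# Assuming greedy decoding, the i-th new token is the argmax of scores[i]
greedy_tokens = all_scores.argmax(axis=-1)  # (batch, n_new)

print(len(scores), scores[0].shape, all_scores.shape)  # 2 (3, 5) (3, 2, 5)
```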

### Exercise - compute change in accuracy

```c
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-20 minutes on this exercise.
```

You should now rewrite the `calculate_h_and_intervene` function so that, rather than returning the string completions, it returns two lists of floats, containing the **logprobs assigned by the model to the correct antonym** in the no intervention / intervention cases respectively.

When you run the code below this function, it will display the log-probabilities, highlighted in green when the intervention increases the logprob relative to the no-intervention case. You should find that in every sequence, the logprobs on the correct token increase in the intervention. This helps make something clear: **even if the maximum-likelihood token doesn't change, this doesn't mean that the intervention isn't having a significant effect.**
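The toy example below shows how the top token can stay the same while the correct token's logprob rises substantially. It uses made-up logits for a single prompt and a hypothetical correct-token ID. `log_softmax` here is a small NumPy helper written for the sketch, not a library function.

```python
import numpy as np

def log_softmax(x):
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

correct_id = 2  # hypothetical ID of the correct antonym's first token

# Made-up logits for one prompt, before and after an intervention
logits_before = np.array([5.0, 1.0, 2.0, 0.0])
logits_after = np.array([5.0, 1.0, 4.0, 0.0])

# The top token (ID 0) is the same in both cases...
assert logits_before.argmax() == logits_after.argmax() == 0

# ...but the correct token's logprob has gone up substantially
delta = log_softmax(logits_after)[correct_id] - log_softmax(logits_before)[correct_id]
print(f"{delta:+.2f}")  # +1.74
```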

<details>
<summary>Help - I don't know how to get the correct logprobs from the logits.</summary>

First, apply log softmax to the logits, to get logprobs.

Second, you can use `tokenizer(dataset.completions)["input_ids"]` to get the token IDs of the correct completions. (Gotcha: some words might be tokenized into multiple tokens, so make sure you're just picking the first token ID for each completion.)

</details>
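The indexing step is easy to get wrong. Pairing `range(zero_shot_size)` with the list of correct IDs picks one entry per prompt, rather than taking an outer product. This NumPy sketch uses a hypothetical tokenizer output and placeholder values (not real logprobs) in the `(2, zero_shot_size, vocab)` layout used in the solution below.

```python
import numpy as np

zero_shot_size, vocab = 3, 6  # hypothetical sizes

# Hypothetical tokenizer output for the correct completions: one word is
# split into two tokens, so keep only the first ID of each
completion_token_ids = [[4], [1, 5], [3]]
correct_completion_ids = [toks[0] for toks in completion_token_ids]  # [4, 1, 3]

# Placeholder values in the (no intervention / intervention, prompt, vocab) layout
logprobs = np.arange(2 * zero_shot_size * vocab, dtype=float).reshape(2, zero_shot_size, vocab)

# Two adjacent index arrays are paired elementwise: result has shape (2, zero_shot_size)
correct = logprobs[:, np.arange(zero_shot_size), correct_completion_ids]
correct_logprobs, correct_logprobs_intervention = correct.tolist()
```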


In [33]:
def calculate_h_and_intervene_logprobs(
    dataset: AntonymDataset,
    layer: int,
    zero_shot_size: int,
) -> Tuple[AntonymDataset, List[float], List[float]]:
    '''
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass.

    Inputs:
        dataset: the dataset of clean prompts from which we'll extract the `h`-vector
        layer: the layer we'll be extracting the `h`-vector from
        zero_shot_size: the number of zero-shot prompts to generate, to test our intervention
    
    Returns:
        zero_shot_dataset: the dataset of zero-shot prompts, which you should generate in this fn
        correct_logprobs: list of correct-token logprobs for the zero-shot prompts, without intervention
        correct_logprobs_intervention: list of correct-token logprobs for the zero-shot prompts, with h-intervention
    '''

    # Get zero-shot dataset
    zero_shot_dataset = AntonymDataset(WORD_PAIRS, size=zero_shot_size, n_prepended=0)

    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, output_scores=True, return_dict_in_generate=True) as generator:
        
        # Clean prompts, to get the h-vector
        with generator.invoke(dataset.prompts) as invoker:
            # Define h (we don't need to save it, because we don't need it outside `generator:`)
            hidden_states = model.transformer.h[layer].output[0]
            h = hidden_states[:, -1].mean(dim=0)

        # Zero-shot prompts, no intervention
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            pass

        # Zero-shot prompts, intervention with h
        with generator.invoke(zero_shot_dataset.prompts) as invoker:
            # Access the tensor (which is the first element of the output tuple)
            hidden_states = model.transformer.h[layer].output[0]
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states[:, -1] += h

    # Get logits, slice to remove the `dataset` outputs, and reshape into (2, zero_shot_size, d_vocab)
    logits: Tensor = generator.output.scores[0][len(dataset):].reshape(2, zero_shot_size, -1)
    logprobs = logits.log_softmax(dim=-1)

    # Get correct completions from `dataset`, and use these to index into the logprobs
    correct_completion_ids = [toks[0] for toks in tokenizer(zero_shot_dataset.completions)["input_ids"]]
    correct_logprobs, correct_logprobs_intervention = logprobs[:, range(zero_shot_size), correct_completion_ids].tolist()

    return zero_shot_dataset, correct_logprobs, correct_logprobs_intervention

In [43]:
def display_model_logprobs_on_h_intervention(
    dataset: AntonymDataset,
    correct_logprobs: List[float],
    correct_logprobs_intervention: List[float],
    num_to_display: int = 20,
) -> None:
    table = Table(
        "Zero-shot prompt", "Model's logprob\n(no intervention)", "Model's logprob\n(intervention)", "Change in logprob",
        title="Model's antonym logprobs, with zero-shot h-intervention\n(green = intervention improves accuracy)"
    )

    for i in range(min(len(correct_logprobs), num_to_display)):

        logprob_ni = correct_logprobs[i]
        logprob_i = correct_logprobs_intervention[i]
        delta_logprob = logprob_i - logprob_ni
        zero_shot_prompt = f"{dataset[i].x[0]:>8} -> {dataset[i].y[0]}"
        
        # Color code the logprob based on whether it's increased with this intervention
        is_improvement = (delta_logprob >= 0)
        delta_logprob = f"[b green]{delta_logprob:+.2f}[/]" if is_improvement else f"{delta_logprob:+.2f}"

        table.add_row(zero_shot_prompt, f"{logprob_ni:.2f}", f"{logprob_i:.2f}", delta_logprob)

    rprint(table)

In [44]:
dataset = AntonymDataset(WORD_PAIRS, size=20, n_prepended=2)

zero_shot_dataset, correct_logprobs, correct_logprobs_intervention = calculate_h_and_intervene_logprobs(dataset, layer=12, zero_shot_size=30)

display_model_logprobs_on_h_intervention(zero_shot_dataset, correct_logprobs, correct_logprobs_intervention)

# 3️⃣ Function Vectors

In this section, we'll move from thinking about residual stream states to thinking about the **output of specific attention heads.**

First, a bit of a technical complication. Most HuggingFace models don't have the nice attention head representations that TransformerLens models do (i.e. storing vectors & attention weights separately by head). In GPT-J, the input to `out_proj` (the final linear map in the attention layer) is a tensor of attention-weighted value vectors, already concatenated across heads. Applying `out_proj` to it is equivalent to applying each head's slice of the weight matrix to that head's vector and summing the results over heads. If you can't see how this is possible, see the section "Attention Heads are Independent and Additive" from Anthropic's [Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html).
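The additivity claim above can be checked numerically. This NumPy sketch uses hypothetical sizes and a random matrix laid out like a PyTorch `nn.Linear` weight `(out_features, in_features)`. It ignores any bias term, which a biased linear layer would add once rather than per head. Applying the full matrix to the concatenated vector gives the same result as summing each head's column-slice applied to its own chunk.

```python
import numpy as np

n_heads, d_head = 4, 3           # hypothetical sizes
d_model = n_heads * d_head
rng = np.random.default_rng(0)

W_O = rng.normal(size=(d_model, n_heads * d_head))  # like out_proj.weight: (out_features, in_features)
z = rng.normal(size=(n_heads * d_head,))            # per-head vectors, concatenated

# Applying out_proj to the concatenated vector...
full_output = W_O @ z

# ...equals the sum over heads of each head's slice of W_O applied to its own chunk
per_head = [
    W_O[:, h * d_head:(h + 1) * d_head] @ z[h * d_head:(h + 1) * d_head]
    for h in range(n_heads)
]
assert np.allclose(full_output, np.sum(per_head, axis=0))
```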

How can we deal with this? The easiest way is to intervene on the input of `out_proj` instead, since this is causally the same as intervening on the output. We also reshape this input tensor so that it has a head dimension, which makes it easier to intervene on a per-head basis. In other words, you should intervene on the value which we've called `z` in the diagram below.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/rearrange-output.png" width="950">

When you actually need to calculate `a` (the output for a particular attention head), the easiest thing to do is just apply the appropriate slice of the linear map to `z` (we'll get to this in the next exercise, so don't worry about it for now).

### Exercise - implement `calculate_fn_vectors_and_intervene`

```c
Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵🔵

You should spend up to 30-45 minutes on this exercise.
```

This is probably the most important function in today's exercises. Implementing it will be pretty similar to the previous function `calculate_h_and_intervene`, but:

* Rather than extracting the value of the residual stream `h` at some particular layer, you'll be extracting the output of the attention heads: iterating over each layer and each head in the model.
    * You'll only need to run one clean forward pass to compute all these values, but you'll need to run a separate corrupted forward pass for each head.
* Rather than your 2 different datasets being (dataset, zero-shot dataset), your two datasets will be (dataset, corrupted version of that same dataset).
    * You can use the `create_corrupted_dataset` method of the `AntonymDataset` class for this.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/cie-intervention.png" width="1200">

Before you actually start writing the code, it might be helpful to answer the following questions:

<details>
<summary>What will your total batch size be (if your dataset has size <code>N</code>) ?</summary>

Your batch size will be `N * ((N_LAYERS * N_HEADS) + 2)`, because you'll need to run:

* A forward pass on the clean dataset (batch size `N`),
* A forward pass on the corrupted dataset (batch size `N`) with no intervention,
* A forward pass on the corrupted dataset (batch size `N`) once for each head, where you're intervening on that head.

</details>
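Here is the arithmetic from the answer above, plus the order in which the sub-batches come back. Both helpers are hypothetical and written only for illustration. The second one assumes the invokes run in the order clean, corrupted, then one per head with layers in the outer loop and heads in the inner loop. That order matches the `"(layer head)"` rearrange used in the solution below. The layer and head counts for GPT-J (28 and 16) are the ones quoted later in this section.

```python
def total_batch_size(n_prompts: int, n_layers: int, n_heads: int) -> int:
    """One clean sub-batch, one unintervened corrupted sub-batch,
    and one intervened corrupted sub-batch per (layer, head)."""
    return n_prompts * (n_layers * n_heads + 2)

def sub_batch_label(index: int, layers: list, n_heads: int):
    """Which forward pass the sub-batch at position `index` belongs to,
    assuming invokes are ordered clean, corrupted, then heads layer-by-layer."""
    if index == 0:
        return "clean"
    if index == 1:
        return "corrupted"
    layer_pos, head = divmod(index - 2, n_heads)
    return (layers[layer_pos], head)

# GPT-J has 28 layers and 16 heads per layer; with N = 6 prompts:
print(total_batch_size(6, 28, 16))  # 2700

layers = list(range(8, 15))  # the default LAYERS in the function below
print(sub_batch_label(2, layers, 16), sub_batch_label(2 + 16, layers, 16))  # (8, 0) (9, 0)
```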

<details>
<summary>Which proxy outputs (if any) will you need to use <code>.save()</code> on, in this function?</summary>

None. You're just doing causal interventions, and getting the logits. You don't need to save the $z$-tensors in order to causally intervene with them at later points within the context manager (see the solution for the `calculate_h_and_intervene` exercises).

</details>

A few other tips:

* When it comes to intervening, you can set the value of a reshaped tensor: `tensor.reshape(*new_shape)[index] = new_value` will change the values in `tensor` without actually reshaping it. This holds as long as `reshape` returns a view rather than a copy, which it does for contiguous tensors (for more on this, see the documentation for [`torch.Tensor.view`](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html)).
* It's good practice to insert a lot of assert statements in your code, to check the shapes are what you expect.
* If you're confused about dimensions, use `einops.rearrange` rather than `.reshape` - it's like using code annotations within your actual code!
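The sketch below illustrates the first tip with a NumPy array standing in for the `out_proj` input at the last position. The sizes and the head vector are hypothetical. NumPy's `reshape` returns a view when it can, as PyTorch's does, so assigning through the reshaped array writes into the original. In PyTorch, `view` raises an error instead of silently copying when a view isn't possible.

```python
import numpy as np

N, n_heads, d_head = 2, 3, 4          # hypothetical sizes
z = np.zeros((N, n_heads * d_head))   # like out_proj's input at the last position
a = np.ones(d_head)                   # stand-in for a cached per-head vector

# reshape returns a view for a contiguous array, so assigning through it writes into z
z.reshape(N, n_heads, d_head)[:, 1] = a

print(z.shape)  # (2, 12): z itself was never reshaped
print(z[0])     # only entries 4..7 (head 1) are now 1.0
```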

In [45]:
def calculate_fn_vectors_and_intervene(
    dataset: AntonymDataset,
    LAYERS = None,
) -> Float[Tensor, "layers heads"]:
    '''
    Returns a tensor of shape (layers, heads), containing the CIE for each head.

    TODO - remove the LAYERS argument (it's only necessary because of batch size constraints, which will
    hopefully be fixed when servers work).

    Inputs:
        dataset: AntonymDataset
            the dataset of clean prompts from which we'll extract the function vector (we'll also create a
            corrupted version of this dataset for interventions)
    '''

    LAYERS = LAYERS if (LAYERS is not None) else range(8, 15) # range(N_LAYERS)
    HEADS = range(N_HEADS)

    # Get corrupted dataset
    corrupted_dataset = dataset.create_corrupted_dataset()
    N = len(dataset)

    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, output_scores=True, return_dict_in_generate=True) as generator:

        # Run a forward pass on clean prompts, where we store attention head outputs
        a_dict = {}
        with generator.invoke(dataset.prompts) as invoker:
            for layer in LAYERS:
                # Get hidden states, reshape to get head dimension, store the mean tensor
                hidden_states = model.transformer.h[layer].attn.out_proj.input[0][:, -1]
                a = hidden_states.reshape(N, N_HEADS, D_HEAD).mean(dim=0)
                for head in HEADS:
                    a_dict[(layer, head)] = a[head]
        
        # Run a forward pass on corrupted prompts, where we don't intervene or store activations (just so we can
        # get the logits to compare with our intervention)
        with generator.invoke(corrupted_dataset.prompts) as invoker:
            pass

        # For each head, run a forward pass on corrupted prompts (here we need multiple different forward passes,
        # because we're doing different interventions each time)
        for layer in LAYERS:
            for head in HEADS:
                with generator.invoke(corrupted_dataset.prompts) as invoker:
                    # Get hidden states, reshape to get head dimension, then set it to the a-vector
                    hidden_states = model.transformer.h[layer].attn.out_proj.input[0][:, -1]
                    hidden_states.reshape(N, N_HEADS, D_HEAD)[:, head] = a_dict[(layer, head)]


    # Get output logits (which contain all `len(LAYERS) * n_heads + 2` sub-batches of size N, concatenated) and reshape into sub-batches
    output_logits = einops.rearrange(generator.output.scores[0], "(batch N) d_vocab -> batch N d_vocab", N=N)
    assert output_logits.shape[0] == (len(HEADS) * len(LAYERS)) + 2

    # Get the corrupted logits & the logits with intervention (i.e. red in the diagram). Reshape latter to get head dim
    logits_corrupted = output_logits[1]
    logits_intervention = einops.rearrange(output_logits[2:], "(layer head) N d_vocab -> layer head N d_vocab", head=N_HEADS)
    
    # Get logprobs, for correct tokens
    correct_completion_ids = [toks[0] for toks in tokenizer(dataset.completions)["input_ids"]]
    logprobs_corrupted = logits_corrupted.log_softmax(dim=-1)[range(N), correct_completion_ids]
    logprobs_intervention = logits_intervention.log_softmax(dim=-1)[:, :, range(N), correct_completion_ids]
    
    # Return mean effect of intervention, over the batch dimension
    return (logprobs_intervention - logprobs_corrupted).mean(dim=-1)

We need to run this in batches of layers because of GPU memory limits (a forward pass with 6 prompts for each of the 28*16 heads in GPT-J is big!). The resulting plot is attached as an image below, so the cell doesn't need to be rerun.
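The commented-out cell below splits GPT-J's layers into chunks, scores each chunk separately, and concatenates the per-chunk `(layers, heads)` results along the layer axis. This sketch shows that pattern with a hypothetical stand-in, `fake_scores`, in place of the real `calculate_fn_vectors_and_intervene` call. It uses `np.concatenate` in place of `t.concat`.

```python
import numpy as np

N_LAYERS, N_HEADS = 28, 16  # GPT-J's layer and head counts
layer_chunks = [range(10), range(10, 20), range(20, 28)]

def fake_scores(layers):
    """Hypothetical stand-in for calculate_fn_vectors_and_intervene(dataset, LAYERS=layers),
    returning one score per (layer, head) for the requested layers."""
    return np.array([[layer * 100 + head for head in range(N_HEADS)] for layer in layers], dtype=float)

# Concatenating along axis 0 stitches the chunks back into a (N_LAYERS, N_HEADS) grid
results = np.concatenate([fake_scores(chunk) for chunk in layer_chunks])
```

The chunks must cover every layer exactly once, in order, for row `i` of `results` to correspond to layer `i`.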

In [46]:
# # 40 seconds for 2 layers, N=6, prepend=5
# # 85 seconds for 10 layers, N=6, prepend=5
# # (165 -> CUDA error) seconds for 28 layers, N=6, prepend=5

# results = []

# dataset = AntonymDataset(WORD_PAIRS, size=6, n_prepended=3)

# for LAYERS in tqdm([range(10), range(10, 20), range(20, 28)]):
    
#     gc.collect()
#     t.cuda.empty_cache()

#     _results = calculate_fn_vectors_and_intervene(
#         dataset = dataset,
#         LAYERS = LAYERS,
#     )
#     results.append(_results)

#     gc.collect()
#     t.cuda.empty_cache()
    

# results = t.concat(results)

# imshow(
#     results.T,
#     title = "Average indirect effect of function-vector intervention on antonym task",
#     width = 1000,
#     height = 600,
#     labels = {"x": "Layer", "y": "Head"},
#     aspect = "equal",
# )

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/main-result.png" width="800">

### Exercise - calculate the function vector

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 20-35 minutes on this exercise.
```

Your next task is to actually calculate and return the function vector, so we can do a few experiments with it.

You should pick the 10 highest-scoring attention heads (according to the diagram above). You can hardcode these.
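If you'd rather not hardcode the heads, here is a minimal NumPy sketch that reads them off a `(n_layers, n_heads)` score array like the one plotted above. `top_k_heads` is a hypothetical helper, and the scores in the example are made up.

```python
import numpy as np

def top_k_heads(scores: np.ndarray, k: int = 10) -> list:
    """Return the (layer, head) pairs with the k largest scores, highest first.

    `scores` is assumed to have shape (n_layers, n_heads)."""
    flat_order = np.argsort(scores, axis=None)[::-1][:k]
    layers, heads = np.unravel_index(flat_order, scores.shape)
    return [(int(l), int(h)) for l, h in zip(layers, heads)]

scores = np.zeros((4, 3))
scores[2, 1], scores[0, 2], scores[3, 0] = 0.9, 0.5, 0.7
print(top_k_heads(scores, k=3))  # [(2, 1), (3, 0), (0, 2)]
```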

Mostly, this will involve taking your previous code and removing parts of it (since you only need to return the function vectors, not intervene with them). However, there is one difficulty here - in the previous exercises we causally intervened on the `out_proj` input (because that was easier), but here we need the actual attention head outputs. Can you see how to do this?

<details>
<summary>Answer</summary>

Once we have the value vectors post-attention (let's call them `z`), we can just apply the linear map corresponding to a slice of the `out_proj` weight matrix.
 
To be more specific:

* The weight matrices of linear layers are stored as `(out_features, in_features)`, meaning `out_proj.weight.shape = (d_model, d_model = n_heads * d_head)` (i.e. the first dimension is the output space of the attention heads, and the second dimension is the input space for each of the `z`-vectors, concatenated).
* We can rearrange this along the second dimension, then index into it to get a tensor of shape `(d_model, d_head)` corresponding to a particular head.
* Then we can take this matrix `W_O`, and calculate `W_O @ z` - this will give us the attention head's output.

</details>
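To check the reasoning in the answer above, here is a small self-contained sketch with toy sizes (the sizes and the `toy_` names are assumptions for illustration, not GPT-J's). It builds a bias-free `nn.Linear` standing in for `out_proj`, slices its weight per head, and confirms that the per-head outputs sum to the full `out_proj` output. That identity relies on the projection having no bias (which we believe matches GPT-J's `out_proj`; with a bias, it would need adding once to the sum).

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration; GPT-J's real sizes are much larger)
d_model, n_heads, d_head = 8, 2, 4

# Stand-in for attn.out_proj: a bias-free linear map from n_heads * d_head to d_model
toy_out_proj = nn.Linear(n_heads * d_head, d_model, bias=False)

# toy_z[h] is head h's z-vector (its mix of value vectors) at one sequence position
toy_z = torch.randn(n_heads, d_head)

with torch.no_grad():
    # nn.Linear stores weight as (out_features, in_features) = (d_model, n_heads * d_head),
    # so split the second dimension into (n_heads, d_head)
    toy_W_O = toy_out_proj.weight.reshape(d_model, n_heads, d_head)

    # Head h's output is its (d_model, d_head) slice of W_O applied to its z-vector
    head_outputs = torch.stack([toy_W_O[:, h] @ toy_z[h] for h in range(n_heads)])

    # Applying out_proj to the concatenated z-vectors gives the sum of the per-head outputs
    full_output = toy_out_proj(toy_z.flatten())

print(torch.allclose(head_outputs.sum(dim=0), full_output, atol=1e-5))  # True
```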

In [47]:
def calculate_fn_vector(
    dataset: AntonymDataset,
    head_list: List[Tuple[int, int]],
) -> Float[Tensor, "d_model"]:
    '''
    Returns the function vector: a tensor of shape (d_model,), equal to the sum over the given heads of each
    head's output at the final sequence position, averaged over the prompts in the dataset.

    Inputs:
        dataset: AntonymDataset
            the dataset of clean prompts from which we'll extract the function vector
        head_list: List[Tuple[int, int]]
            list of (layer, head) pairs for the attention heads we're calculating the function vector from
    '''
    # Turn head_list into a dict of {layer: list_of_heads}
    head_dict = {}
    for layer, head in head_list:
        head_dict[layer] = head_dict.get(layer, []) + [head]

    # Maps (layer, head) -> saved proxy for that head's mean z-vector, shape (d_head,)
    z_dict = {}

    with model.generate(max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, output_scores=True, return_dict_in_generate=True) as generator:

        with generator.invoke(dataset.prompts) as invoker:
            for layer, heads in head_dict.items():

                # Get the output projection layer
                out_proj = model.transformer.h[layer].attn.out_proj

                # The input to out_proj is the concatenation of all heads' z-vectors; take the final position
                hidden_states = out_proj.input[0][:, -1]
                # Split into heads and average over the batch: shape (n_heads, d_head)
                z = hidden_states.reshape(len(dataset), N_HEADS, D_HEAD).mean(dim=0)

                # Save the z-vector for each of this layer's heads
                for head in heads:
                    z_dict[(layer, head)] = z[head].save()

    # Map each head's z-vector through its slice of out_proj's weight matrix, to get the head's output
    attn_head_output_dict = {}
    for (layer, head), z in z_dict.items():
        # W_O has shape (d_model, n_heads * d_head); reshape and index to get this head's (d_model, d_head) slice
        W_O = model.local_model.transformer.h[layer].attn.out_proj.weight.data
        output = W_O.reshape(D_MODEL, N_HEADS, D_HEAD)[:, head] @ z.value
        attn_head_output_dict[(layer, head)] = output

    # The function vector is the sum of these head outputs
    fn_vector = t.stack(list(attn_head_output_dict.values())).sum(dim=0)
    assert fn_vector.shape == (D_MODEL,)

    return fn_vector

In [None]:
tests.test_calculate_fn_vector(calculate_fn_vector)

## Multi-token generation

We're now going to replicate some of the results in Table 3, in the paper:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/tab3.png" width="700">

This will involve doing something we haven't done before - **intervening on multi-token prompt generation**.

Most of the interpretability exercises in this chapter have just consisted of running single forward passes, rather than autoregressive text generation. But we're trying something different here: we're adding the function vector to the final sequence position at each forward pass during text generation, and seeing if we can get the model to output a sentence with a different meaning.

The results of Table 3 came from adding the function vector to the residual stream at the final sequence position of the original prompt, **and the final sequence position for each subsequent generation.** The reason we do this is to guide the model's behaviour over time. Our hypothesis is that the function vector induces "next-token antonym behaviour" (because it was calculated by averaging attention head outputs at the sequence position before the model made its antonym prediction in the ICL prompts).
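To make the mechanics concrete outside `nnsight`, here is a minimal sketch in plain PyTorch using `register_forward_hook` (a swapped-in technique, not how this notebook does it). `ToyLM`, `greedy_generate`, and all sizes are hypothetical. The hook adds a steering vector to a hidden layer's output at the final sequence position, on every forward pass of greedy decoding. Two simplifications to keep in mind: the toy model has no attention (each position's logits depend only on its own token), and it re-runs the whole sequence at each step rather than using a KV cache.

```python
import torch
import torch.nn as nn


class ToyLM(nn.Module):
    '''Hypothetical stand-in for a language model: embed -> hidden block -> unembed.'''

    def __init__(self, vocab_size: int = 10, d_model: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.block = nn.Linear(d_model, d_model)
        self.unembed = nn.Linear(d_model, vocab_size)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:  # ids: (batch, seq)
        return self.unembed(torch.relu(self.block(self.embed(ids))))


@torch.no_grad()
def greedy_generate(lm: ToyLM, ids: torch.Tensor, n_tokens: int, steer=None) -> torch.Tensor:
    '''Greedy decoding without a KV cache. If `steer` is given, it is added to the
    block's output at the final sequence position on every forward pass.'''

    def hook(module, inputs, output):
        output = output.clone()
        output[:, -1] += steer
        return output  # a forward hook that returns a tensor replaces the module's output

    handle = lm.block.register_forward_hook(hook) if steer is not None else None
    try:
        for _ in range(n_tokens):
            logits = lm(ids)
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
    finally:
        if handle is not None:
            handle.remove()  # always detach the hook, even if generation fails
    return ids


toy_lm = ToyLM().eval()
prompt_ids = torch.tensor([[1, 2, 3]])
steer_vec = 5 * torch.randn(16)
clean_ids = greedy_generate(toy_lm, prompt_ids, n_tokens=4)
steered_ids = greedy_generate(toy_lm, prompt_ids, n_tokens=4, steer=steer_vec)
print(clean_ids.tolist(), steered_ids.tolist())
```

Removing the hook in a `finally` block matters in a notebook: a forgotten hook would silently keep steering every later forward pass.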

### Using `nnsight` for multi-token generation

Previously, our context managers have looked like:

```python
with model.generate(max_new_tokens=1) as generator:
    with generator.invoke(prompt) as invoker:

        # Do stuff to the model's internals
```

But for multi-token generation, our context managers will look like:

```python
with model.generate(max_new_tokens=max_new_tokens) as generator:
    with generator.invoke(prompt) as invoker:

        for n in range(max_new_tokens):
            # Do stuff to the model's internals, on the n-th forward pass
            invoker.next()
```

The line `invoker.next()` marks the end of the current forward pass: any interventions defined after it are applied to the next forward pass (i.e. when generating the next token).

Mostly, what you're used to from single-token generation carries over to the multi-token case. The object `generator.output` is still a tensor containing the model's token IDs (or we can return scores instead, exactly as we did before). Using `.save()` still makes proxies available outside the context managers, although make sure you don't reuse the same variable name across different forward passes, otherwise you'll overwrite earlier values. It's easier to store your saved proxies in e.g. a list or dict.

A couple more notes:

* By default the `generate` method will generate tokens greedily, i.e. always taking the maximum-probability token at each step. For now, we don't need to worry about changing this behaviour.
* Transformer models perform **key-value caching** to speed up text generation. This means that the time taken to generate $n$ tokens is ***much*** less than $n$ times longer than generating a single token. See [this blog post](https://kipp.ly/transformer-inference-arithmetic/) on transformer inference arithmetic for more.
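The KV-caching claim above can be illustrated with some rough arithmetic. This is a sketch under a simplified cost model (our assumption, not from the source): cost is proportional to the number of token positions pushed through the model, ignoring that each position's attention cost grows with context length. `positions_processed` is a hypothetical helper.

```python
def positions_processed(prompt_len: int, n_tokens: int, kv_cache: bool) -> int:
    '''Number of token positions pushed through the model to greedily generate n_tokens.

    Simplified cost model (an assumption): cost is proportional to positions processed,
    ignoring that each position's attention cost grows with context length.
    '''
    if kv_cache:
        # First pass processes the whole prompt; each later pass processes only the newest token
        return prompt_len + (n_tokens - 1)
    # Without a cache, pass k re-processes the prompt plus the k tokens generated so far
    return sum(prompt_len + k for k in range(n_tokens))


for n in (1, 10):
    print(n, positions_processed(20, n, kv_cache=True), positions_processed(20, n, kv_cache=False))
# 1 20 20
# 10 29 245
```

Under this rough model, with a 20-token prompt, generating 10 tokens costs about 29 / 20 ≈ 1.45 times a single-token generation with caching, versus about 12 times without it.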

### Exercise - intervene with function vector, in multi-token generation

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 10-20 minutes on this exercise.
```

You should now fill in the function `intervene_with_fn_vector` below. This will take a function vector (calculated from the function you wrote above), as well as a few other arguments (see docstring), and return the model's string completions on the given prompt template, both without and with the intervention.

We hope to observe results qualitatively like the ones in Table 3, i.e. having the model define a particular word as its antonym.

In [50]:
def intervene_with_fn_vector(
    word: str,
    layer: int,
    fn_vector: Float[Tensor, "d_model"],
    prompt_template: str = 'The word "{x}" means',
    n_tokens: int = 5,
) -> Tuple[str, str]:
    '''
    Intervenes with a function vector, by adding it to the output of the given layer at the last sequence
    position, on every forward pass of generation.

    Returns: the full completion (including original prompt) for no-intervention / intervention case respectively.
    '''

    prompt = prompt_template.format(x=word)
    
    with model.generate(max_new_tokens=n_tokens, pad_token_id=tokenizer.eos_token_id) as generator:

        # No intervention
        with generator.invoke(prompt) as invoker:
            pass

        # Intervention: add the function vector at the final position, on each of the n_tokens forward passes
        with generator.invoke(prompt) as invoker:
            for _ in range(n_tokens):
                hidden_states = model.transformer.h[layer].output[0]
                hidden_states[:, -1] += fn_vector
                # Move on to the next forward pass (i.e. the next generated token)
                invoker.next()

    # Token IDs (prompt + generated tokens) for both invocations
    output = generator.output

    completions = tokenizer.batch_decode(output)

    return tuple(completions)

To test your function, run this code. You should find that the first completion seems normal, but the second completion defines the word as its antonym. **You've just successfully induced an OOD behavioural change in a 6B-parameter model!**

In [57]:
# Define our dataset, and the attention heads we'll use
dataset = AntonymDataset(WORD_PAIRS, size=20, n_prepended=5)
head_list = [(8, 0), (8, 1), (9, 14), (11, 0), (12, 10), (13, 12), (13, 13), (14, 9), (15, 5), (16, 14)]

# Extract the function vector
fn_vector = calculate_fn_vector(dataset, head_list)

# Define a word we'll use in our prompt, and check it's OOD
word = "expensive"
assert all(word not in pair for pair in WORD_PAIRS)

# Intervene with the function vector
completion, completion_intervention = intervene_with_fn_vector(
    word = word,
    layer = 9,
    fn_vector = fn_vector,
    prompt_template = 'The word "{x}" means',
    n_tokens = 10,
)
print(repr(completion))
print(repr(completion_intervention))

'The word "expensive" means different things to different people. For some, it'
'The word "expensive" means "cheap" in this case.\n\n'


In [None]:
# Note to self: we won't have people replicate results like Table 2, because this seems less interesting than what we've done so far, and it doesn't introduce any new conceptual ideas. We've already done (zero-shot, h-vector) and (shuffled, fn-vector). Replicating these results with other prompt templates and other models is left as a bonus exercise.

# TODO - change the examples in the `generate_dataset` function, so they have a prepended space

### Exercise - generalize results to another task (optional)

```c
Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

In this exercise, you get to pick a task different from the antonyms task, and see if the results still hold up (using the same set of attention heads).

We'll leave this exercise fairly open-ended, without any code templates for you to fill in. However, if you'd like some guidance you can use the dropdown below.

<details>
<summary>Guidance for exercise</summary>

Whatever your task, you'll want to generate a new set of words. You can repurpose the `generate_dataset` function from the antonyms task by supplying a different prompt and initial set of examples (this will require generating and using an OpenAI API key, if you haven't already), or you can just find an appropriate dataset online.

When you define the `AntonymDataset`, you might want to use `bidirectional=False` if your task isn't symmetric. The antonym task is symmetric, but others (e.g. the Country-Capitals task) are not.

You'll need to supply a new prompt template for the `intervene_with_fn_vector` function, but otherwise most of your code should stay the same.

</details>
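Why does symmetry matter? In the antonyms task a flipped pair is still a valid example ("hot" → "cold" and "cold" → "hot"), but in Country-Capitals a flipped pair asks the model to map a capital back to its country, which is a different task. The sketch below shows one way a "bidirectional" dataset could be built, by flipping each pair with probability 1/2. `flip_pairs` is a hypothetical helper: we haven't reproduced how `AntonymDataset` actually implements `bidirectional`.

```python
import random
from typing import List, Sequence, Tuple


def flip_pairs(
    pairs: Sequence[Tuple[str, str]], bidirectional: bool, seed: int = 0
) -> List[Tuple[str, str]]:
    '''Hypothetical helper: if bidirectional, flip each (x, y) to (y, x) with probability 1/2.'''
    if not bidirectional:
        return [tuple(p) for p in pairs]
    rng = random.Random(seed)  # local RNG, so we don't disturb the global random state
    return [(y, x) if rng.random() < 0.5 else (x, y) for x, y in pairs]


example_pairs = [("Portugal", "Lisbon"), ("Ireland", "Dublin"), ("Chile", "Santiago"), ("Japan", "Tokyo")]
print(flip_pairs(example_pairs, bidirectional=True))
print(flip_pairs(example_pairs, bidirectional=False))
```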


In [60]:
def generate_dataset_capitals(N: int):

    t0 = time.time()

    examples = "Portugal: Lisbon, Ireland: Dublin, Chile: Santiago, Japan: Tokyo, "

    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Give me {N} examples of Country-Capital pairs."},
            {"role": "assistant", "content": f"Sure! Here are {N} Country-Capital pairs: {examples}"},
        ]
    )
    response_text: str = examples + response["choices"][0]["message"]["content"]

    word_pairs = [word_pair.split(": ") for word_pair in response_text.strip(".\n").split(", ")]

    print(f"Finished in {time.time()-t0:.2f} seconds.")

    return word_pairs


CC_WORD_PAIRS = generate_dataset_capitals(100)

Finished in 35.82 seconds.
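The parsing line in `generate_dataset_capitals` assumes GPT-4 replies in exactly the `Country: Capital, ...` format. If an entry lacks a `": "` separator, `split(": ")` returns a one-element list, which could break later code that unpacks pairs. Here is a sketch of a more defensive parser (`parse_pairs` is a hypothetical helper, not part of the original) that also accepts newline separators, strips stray full stops, and drops malformed entries.

```python
import re
from typing import List, Tuple


def parse_pairs(text: str) -> List[Tuple[str, str]]:
    '''Hypothetical helper: parse "A: B, C: D, ..." text into pairs, skipping malformed entries.'''
    pairs = []
    # Split on commas and/or newlines, since the model may use either as a separator
    for chunk in re.split(r"[,\n]+", text):
        parts = [p.strip().strip(".") for p in chunk.split(":")]
        # Keep only entries with exactly one ":" and non-empty text on both sides
        if len(parts) == 2 and all(parts):
            pairs.append((parts[0], parts[1]))
    return pairs


sample_text = "Portugal: Lisbon, Ireland: Dublin, Chile: Santiago, Oops, Japan: Tokyo.\nPeru: Lima."
print(parse_pairs(sample_text))
# [('Portugal', 'Lisbon'), ('Ireland', 'Dublin'), ('Chile', 'Santiago'), ('Japan', 'Tokyo'), ('Peru', 'Lima')]
```

One limitation: splitting on commas would still mangle a capital whose name contains a comma (e.g. "Washington, D.C."), so it's worth skimming the generated pairs by eye.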


In [70]:
# Remove (Netherlands, Amsterdam) from the pairs, so it can be a holdout
country = "Netherlands"
_CC_WORD_PAIRS = [pair for pair in CC_WORD_PAIRS if pair[0] != country]

# Define our dataset, and the attention heads we'll use
dataset = AntonymDataset(_CC_WORD_PAIRS, size=20, n_prepended=5, bidirectional=False)
head_list = [(8, 0), (8, 1), (9, 14), (11, 0), (12, 10), (13, 12), (13, 13), (14, 9), (15, 5), (16, 14)]

# Extract the function vector
fn_vector = calculate_fn_vector(dataset, head_list)

# Intervene with the function vector
completion, completion_intervention = intervene_with_fn_vector(
    word = country,
    layer = 9,
    fn_vector = fn_vector,
    prompt_template = 'When you think of {x},',
    n_tokens = 20,
)
print(repr(completion))
print(repr(completion_intervention))

'When you think of Netherlands, you probably think of tulips, windmills, and cheese. But the Netherlands is also home'
'When you think of Netherlands, you probably think of tulips, windmills, and cheese. But Amsterdam is a lot more'


# 3️⃣ Bonus

There are two other interesting results from the paper, although neither of them is as important as the ones we've covered so far. If you have time, you can try to reproduce these results yourself.

### 3.2 - The Decoded Vocabulary of Function Vectors

In this section, the authors find the top words in the decoded vocabulary of the function vector (i.e. the words whose unembedding vectors have the highest dot product with the function vector), and show that these words seem conceptually related to the task. For example:

* For the antonyms task, the top words evoke the idea of antonyms, e.g. `" negate"`, `" counterpart"`, `" lesser"`.
* For the country-capitals task, the top words are actually the names of capitals, e.g. `" Moscow"`, `" Paris"`, `" Madrid"`.
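The "decoded vocabulary" described above is just a top-k over dot products with the unembedding. A minimal sketch with random toy tensors (`toy_W_U` and `toy_fn_vector` are hypothetical stand-ins for the model's unembedding matrix and a real function vector):

```python
import torch

d_model, vocab_size, k = 16, 50, 5
toy_W_U = torch.randn(d_model, vocab_size)  # hypothetical unembedding matrix: one column per token
toy_fn_vector = torch.randn(d_model)        # stand-in for a real function vector

toy_scores = toy_fn_vector @ toy_W_U        # (vocab_size,): dot product with every token's unembedding
toy_top_ids = toy_scores.topk(k).indices    # ids of the k highest-scoring tokens, highest first
print(toy_top_ids.tolist())
```

With the real model, you would map these token ids back to strings with the tokenizer's `batch_decode`.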

Can you replicate these results, both with the antonyms task and with the task you chose in the previous section?

An interesting extension - what happens if you take a task like the Country-Capitals task (which is inherently asymmetric), and get your function vector from the symmetric version of the task (i.e. the one where each of your question-answer pairs might be flipped around)? Do you still get the same behavioural results, and how (if at all) do the decoded vocabulary results change?

<details>
<summary>My results for this (spoiler!)</summary>

In the Country-Capitals task, I found:

* The bidirectional task does still work to induce behavioural changes, although slightly less effectively than for the original task.
* The top decoded vocabulary items are a mix of country names and capital names, but mostly capitals.

</details>

### 3.3 - Vector Algebra on Function Vectors

In this section, the authors investigate whether function vectors can be composed. For instance, if we have three separate ICL tasks which in some sense compose to make a fourth task, can we add together the three function vectors of the first tasks, and use this as the function vector of the fourth task?

The authors test this on a variety of different tasks. They find that it's effective on some tasks (e.g. Country-Capitals, where it outperforms function vectors), but generally isn't as effective as function vectors. Do you get these same results?
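One natural way to "compose" function vectors is analogy-style arithmetic, which involves a subtraction as well as additions. The sketch below assumes hypothetical tasks (A = pick the first item, B = pick the last item, C = copy, D = map to capital) and estimates the FV of task (B, D) as v_AD + v_BC - v_AC; check the paper for the exact tasks and combinations the authors use. Under the toy assumption that each task's FV is the sum of an "input selection" component and an "operation" component, this composition is exact, so the cosine similarity should be (numerically) 1. Real function vectors won't be perfectly additive, which is what makes the experiment interesting.

```python
import torch
import torch.nn.functional as F


def compose_fn_vectors(v_ad: torch.Tensor, v_bc: torch.Tensor, v_ac: torch.Tensor) -> torch.Tensor:
    '''Hypothetical analogy-style composition: estimate v_bd as v_ad + v_bc - v_ac.'''
    return v_ad + v_bc - v_ac


# Toy assumption: FV = selection component + operation component
dim = 32
sel_first, sel_last, op_copy, op_capital = torch.randn(4, dim)
v_first_capital = sel_first + op_capital  # task (A, D)
v_last_copy = sel_last + op_copy          # task (B, C)
v_first_copy = sel_first + op_copy        # task (A, C)
v_last_capital = sel_last + op_capital    # target task (B, D)

v_composed = compose_fn_vectors(v_first_capital, v_last_copy, v_first_copy)
cos = F.cosine_similarity(v_composed, v_last_capital, dim=0)
print(f"cosine similarity: {cos.item():.4f}")
```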

In [None]:
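# Code to calculate decoded vocabulary: unembed the function vector (this applies lm_head only, without the final layer norm)
logits = model.local_model.lm_head(fn_vector)
# Token IDs of the 20 highest-scoring tokens
top_token_ids = logits.topk(20).indices.tolist()
tokens = model.tokenizer.batch_decode(top_token_ids)
print(tokens)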