In [1]:
# note: same thing for REPL
# note: we use this instead of magic because `black` will otherwise fail to format
#
# Enable autoreload to automatically reload modules when they change

from IPython import get_ipython

# do this so that formatter not messed up
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

# Import commonly used libraries
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# graphics
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

# type annotation
import jaxtyping
from jaxtyping import Float32, Int64, jaxtyped
from typeguard import typechecked as typechecker

# more itertools
import more_itertools as mi

# itertools
import itertools
import collections

# tensor manipulation
from einops import rearrange, reduce, repeat

# automatically apply jaxtyping
# %load_ext jaxtyping
# %jaxtyping.typechecker typeguard.typechecked

In [2]:
# Enable jaxtyping based typechecking
import jaxtyping
import typeguard

# Inline comment: This magic command enables runtime type checking using jaxtyping and typeguard
# ipython.run_line_magic("load_ext", "jaxtyping")

# Inline comment: This sets the typecheck mode to 'jaxtyping', which allows for more precise tensor shape checking
# ipython.run_line_magic("jaxtyping.typechecker", "typeguard.typechecked")

In [3]:
import nnsight

In [4]:
import os

nnsight_api_key = os.environ["NNSIGHT_API_KEY"]

print(f"{len(nnsight_api_key)=}")

nnsight.CONFIG.set_default_api_key(nnsight_api_key)

len(nnsight_api_key)=32


In [5]:
model = nnsight.LanguageModel("EleutherAI/gpt-j-6b", device_map="auto")
tokenizer = model.tokenizer

N_HEADS = model.config.n_head
N_LAYERS = model.config.n_layer
D_MODEL = model.config.n_embd
D_HEAD = D_MODEL // N_HEADS

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

print("Entire config: ", model.config)
# %%

# Calling tokenizer returns a dictionary, containing input ids & other data.
# If returned as a tensor, then by default it will have a batch dimension.
print(tokenizer("This must be Thursday", return_tensors="pt"))

# Decoding a list of integers, into a concatenated string.
print(tokenizer.decode([40, 1239, 714, 651, 262, 8181, 286, 48971, 12545, 13]))

# Using batch decode, on both 1D and 2D input.
print(tokenizer.batch_decode([4711, 2456, 481, 307, 6626, 510]))
print(tokenizer.batch_decode([[1212, 6827, 481, 307, 1978], [2396, 481, 428, 530]]))

# Split sentence into tokens (note we see the special Ġ character in place of prepended spaces).
print(tokenizer.tokenize("This sentence will be tokenized"))

config.json:   0%|          | 0.00/930 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

Number of heads: 16
Number of layers: 28
Model dimension: 4096
Head dimension: 256

Entire config:  GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6b",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "



In [6]:
prompt = "The Eiffel Tower is in the city of"

with model.trace(remote=True) as runner:
    with runner.invoke(prompt) as invoker:

        # Save the model's hidden states
        hidden_states = model.transformer.h[-1].output[0].save()

        # Save the model's logit output
        logits = model.lm_head.output[0, -1].save()

        # If you've worked with TransformerLens (or even regular HF models) then you
        # might be used to getting logits directly from the model output, but here we
        # generally extract logits from the model internals just like any other
        # activation because this allows us to control exactly what we return.

# Get the model's logit output, and it's next token prediction
print(f"\nlogits.shape = {logits.value.shape} = (vocab_size,)")

predicted_token_id = logits.value.argmax().item()
print(f"Predicted token ID = {predicted_token_id}")
print(f"Predicted token = {tokenizer.decode(predicted_token_id)!r}")

# Print the shape of the model's residual stream
print(f"\nresid.shape = {hidden_states.value.shape} = (batch_size, seq_len, d_model)")

2024-09-13 17:23:49,012 81d891ed-9659-4b95-93fb-134293e2cf1f - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:23:49,090 81d891ed-9659-4b95-93fb-134293e2cf1f - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:23:49,202 81d891ed-9659-4b95-93fb-134293e2cf1f - RUNNING: Your job has started running.
2024-09-13 17:23:49,537 81d891ed-9659-4b95-93fb-134293e2cf1f - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 184k/184k [00:00<00:00, 773kB/s]


logits.shape = torch.Size([50400]) = (vocab_size,)
Predicted token ID = 6342
Predicted token = ' Paris'

resid.shape = torch.Size([1, 10, 4096]) = (batch_size, seq_len, d_model)





In [7]:
from typing import TypeVar, TypeVarTuple, Any

T = TypeVar("T")
U = TypeVar("U")


def get_first_arg_from_input(proxy: T) -> U:

    # get back a tuple, (args, kwargs)
    proxy_inputs: tuple[tuple[TypeVarTuple], dict[str, Any]] = proxy.input

    proxy_inputs_args: tuple[TypeVarTuple] = proxy_inputs[0]

    return proxy_inputs_args[0]


with model.trace(remote=True) as runner:

    with runner.invoke(prompt) as invoker:

        # note: `input` returns (args, kwargs)

        attn_patterns_proxy = model.transformer.h[0].attn.attn_dropout.input[0][0]

        attn_patterns = attn_patterns_proxy.save()

2024-09-13 17:24:36,306 dc961223-64f3-4681-987f-70866ac69c7f - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:24:36,326 dc961223-64f3-4681-987f-70866ac69c7f - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:24:36,358 dc961223-64f3-4681-987f-70866ac69c7f - RUNNING: Your job has started running.
2024-09-13 17:24:36,478 dc961223-64f3-4681-987f-70866ac69c7f - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.44k/1.44k [00:00<00:00, 2.82MB/s]


In [9]:
import circuitsvis as cv

# Get string tokens (replacing special character for spaces)
str_tokens = model.tokenizer.tokenize(prompt)
str_tokens = [s.replace("Ġ", " ") for s in str_tokens]

# Attention patterns (squeeze out the batch dimension)
attn_patterns_value = attn_patterns.value.squeeze(0)

print("Layer 0 Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attn_patterns_value,
    )
)

Layer 0 Head Attention Patterns:


In [10]:
# However, one thing to be wary of is that GPT-J uses rotary embeddings, which makes
# the computation of attention patterns from keys and queries a bit harder than it
# would otherwise be

ANTONYM_PAIRS = [
    ("happy", "sad"),
    ("light", "dark"),
    ("hot", "cold"),
    ("big", "small"),
    ("fast", "slow"),
    ("hard", "soft"),
    ("rich", "poor"),
    ("full", "empty"),
    ("up", "down"),
    ("strong", "weak"),
    ("brave", "cowardly"),
    ("young", "old"),
    ("new", "old"),
    ("clean", "dirty"),
    ("near", "far"),
    ("sharp", "blunt"),
    ("quiet", "loud"),
    ("hard", "easy"),
    ("thick", "thin"),
    ("wet", "dry"),
    ("open", "closed"),
    ("happy", "sad"),
    ("love", "hate"),
    ("success", "failure"),
    ("yes", "no"),
    ("buy", "sell"),
    ("true", "false"),
    ("defend", "attack"),
    ("accept", "refuse"),
    ("included", "excluded"),
    ("acceptance", "rejection"),
    ("advance", "retreat"),
    ("gain", "loss"),
    ("believe", "doubt"),
    ("attract", "repel"),
    ("increase", "decrease"),
    ("win", "lose"),
    ("visible", "invisible"),
    ("active", "inactive"),
    ("complex", "simple"),
    ("ignore", "acknowledge"),
    ("encourage", "discourage"),
    ("assemble", "disperse"),
    ("mature", "immature"),
    ("gain", "lose"),
    ("new", "used"),
    ("cooperate", "compete"),
    ("begin", "end"),
    ("create", "destroy"),
    ("expand", "contract"),
    ("develop", "regress"),
    ("succeed", "fail"),
    ("connect", "disconnect"),
    ("expand", "shrink"),
    ("introduce", "withdraw"),
    ("safety", "danger"),
    ("satisfaction", "discontent"),
    ("freedom", "restriction"),
    ("strength", "weakness"),
    ("joy", "sorrow"),
    ("truth", "falsehood"),
    ("acceptance", "rejection"),
]

ANTONYM_PAIRS = [list(x) for x in ANTONYM_PAIRS]

In [11]:
class ICLSequence:
    """
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    """

    def __init__(self, word_pairs: list[tuple[str, str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    def prompt(self):
        """Returns the prompt, which contains all but the second element in the last word pair."""
        p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
        return p[: -len(self.completion())]

    def completion(self):
        """Returns the second element in the last word pair (with padded space)."""
        return " " + self.y[-1]

    def __str__(self):
        """Prints a readable string representation of the prompt & completion (indep of template)."""
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(
            ", "
        )


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
Q: hot
A: cold

Q: yes
A: no

Q: in
A: out

Q: up
A:


In [12]:
import copy


class ICLDataset:
    """
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Note:

        Note that the correct completions have a prepended space!!!

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair is replaced with a random word
        seed:
            random seed, for consistency & reproducibility
    """

    def __init__(
        self,
        word_pairs: list[str],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended + 1 <= len(
            word_pairs
        ), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = copy.deepcopy(word_pairs)
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(
                len(self.word_pairs), n_prepended + 1, replace=False
            )
            random_orders = np.random.choice([1, -1], n_prepended + 1)
            if not (bidirectional):
                random_orders[:] = 1
            word_pairs = [
                self.word_pairs[pair][::order]
                for pair, order in zip(random_pairs, random_orders)
            ]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        """Creates a corrupted version of the dataset (with same random seed)."""
        return ICLDataset(
            self.word_pairs,
            self.size,
            self.n_prepended,
            self.bidirectional,
            corrupted=True,
            seed=self.seed,
        )

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [15]:
import rich
import rich.table

dataset = ICLDataset(ANTONYM_PAIRS, size=10, n_prepended=2, corrupted=False)

table = rich.table.Table("Prompt", "Correct completion")
for seq, completion in zip(dataset.seqs, dataset.completions):
    table.add_row(str(seq), repr(completion))

rich.print(table)

In [17]:
dataset = ICLDataset(ANTONYM_PAIRS, size=10, n_prepended=2, corrupted=True)

table = rich.table.Table("Prompt", "Correct completion")
for seq, completions in zip(dataset.seqs, dataset.completions):
    table.add_row(str(seq), repr(completions))

rich.print(table)

In [18]:
def calculate_h(
    model: nnsight.LanguageModel,
    dataset: ICLDataset,
    layer: int = -1,
) -> tuple[str, torch.Tensor]:
    """
    Averages over the model's hidden representations on each of the prompts in
    `dataset` at layer `layer`, to produce a single vector `h`.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ICLDataset
            the dataset whose prompts `dataset.prompts` you're extracting the
            activations from (at the last seq pos)
        layer: int
            the layer you're extracting activations from

    Returns:
        completions: str]
            list of model completion strings (i.e. the strings the model predicts to
            follow the last token)
        h: Tensor
            average hidden state tensor at final sequence position, of shape (d_model,)

    """
    with model.trace(remote=True) as runner:
        with runner.invoke(dataset.prompts) as invoker:

            h = model.transformer.h[layer].output[0][:, -1].mean(dim=0).save()

            logits = model.lm_head.output[:, -1]
            token_ids = logits.argmax(dim=-1).save()

    completions = model.tokenizer.batch_decode(token_ids.value)

    return completions, h.value

In [19]:
def display_model_completions_on_antonyms(
    model: nnsight.LanguageModel,
    dataset: ICLDataset,
    completions: str,
    num_to_display: int = 20,
) -> None:
    table = rich.table.Table(
        "Prompt (tuple representation)",
        "Model's completion\n(green=correct)",
        "Correct completion",
        title="Model's antonym completions",
    )

    for i in range(min(len(completions), num_to_display)):

        # Get model's completion, and correct completion
        completion = completions[i]
        correct_completion = dataset.completions[i]
        correct_completion_first_token = model.tokenizer.tokenize(correct_completion)[
            0
        ].replace("Ġ", " ")
        seq = dataset.seqs[i]

        # Color code the completion based on whether it's correct
        is_correct = completion == correct_completion_first_token
        completion = (
            f"[b green]{repr(completion)}[/]" if is_correct else repr(completion)
        )

        table.add_row(str(seq), completion, repr(correct_completion))

    rich.print(table)

In [20]:
# Get uncorrupted dataset
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=2)

# Getting it from layer 12, as in the description in section 2.1 of paper
model_completions, h = calculate_h(model, dataset, layer=12)

# Displaying the output
display_model_completions_on_antonyms(model, dataset, model_completions)

2024-09-13 17:55:04,071 8c9fd27b-cff0-4c02-bf74-f711ed01253b - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:55:04,130 8c9fd27b-cff0-4c02-bf74-f711ed01253b - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:55:04,199 8c9fd27b-cff0-4c02-bf74-f711ed01253b - RUNNING: Your job has started running.
2024-09-13 17:55:04,417 8c9fd27b-cff0-4c02-bf74-f711ed01253b - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 9.82k/9.82k [00:00<00:00, 157kB/s]


In [21]:
def intervene_with_h(
    model: nnsight.LanguageModel,
    zero_shot_dataset: ICLDataset,
    h: torch.Tensor,
    layer: int,
) -> tuple[list[str], list[str]]:
    """
    Extracts the vector `h` using previously defined function, and intervenes by adding `h` to the
    residual stream of a set of generated zero-shot prompts.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        h: Tensor
            the `h`-vector we'll be adding to the residual stream
        layer: int
            the layer we'll be extracting the `h`-vector from

    Returns:
        completions_zero_shot: list[str]
            list of string completions for the zero-shot prompts, without intervention
        completions_intervention: list[str]
            list of string completions for the zero-shot prompts, with h-intervention
    """
    with model.trace(remote=True) as runner:

        # First, run a forward pass where we don't intervene, just save token id completions
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            token_completions_zero_shot = (
                model.lm_head.output[:, -1].argmax(dim=-1).save()
            )

        # Next, run a forward pass on the zero-shot prompts where we do intervene
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states = model.transformer.h[layer].output[0]
            hidden_states[:, -1] += h
            # Also save completions
            token_completions_intervention = (
                model.lm_head.output[:, -1].argmax(dim=-1).save()
            )

    # Decode to get the string tokens
    completions_zero_shot = model.tokenizer.batch_decode(
        token_completions_zero_shot.value
    )
    completions_intervention = model.tokenizer.batch_decode(
        token_completions_intervention.value
    )

    return completions_zero_shot, completions_intervention

In [22]:
# Note, it's very important that we set a different random seed for the zero shot
# dataset, otherwise we'll be intervening on examples which were actually in the
# dataset we used to compute h!

layer = 12
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

# Run previous function to get h-vector
h = calculate_h(model, dataset, layer=layer)[1]

# Run new function to intervene with h-vector
completions_zero_shot, completions_intervention = intervene_with_h(
    model,
    zero_shot_dataset,
    h,
    layer=layer,
)

print("\nZero-shot completions: ", completions_zero_shot)
print("Completions with intervention: ", completions_intervention)

2024-09-13 17:55:39,657 54e99311-d381-40d1-9a3b-5134666a2277 - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:55:39,678 54e99311-d381-40d1-9a3b-5134666a2277 - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:55:39,706 54e99311-d381-40d1-9a3b-5134666a2277 - RUNNING: Your job has started running.
2024-09-13 17:55:39,922 54e99311-d381-40d1-9a3b-5134666a2277 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 9.82k/9.82k [00:00<00:00, 145kB/s]
2024-09-13 17:55:41,182 bb78d644-d2cb-4ee7-83e3-cd3af8d6c585 - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:55:41,215 bb78d644-d2cb-4ee7-83e3-cd3af8d6c585 - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:55:41,238 bb78d644-d2cb-4ee7-83e3-cd3af8d6c585 - RUNNING: Your job has started running.
2024-09-13 17:55:41,373 bb78d644-d2cb-4ee7-83e3-cd3af8d6c585 - COMPLETED: Your job has been completed.
Downloading result: 100%|█


Zero-shot completions:  [' sad', ' win', ' destroy', ' slow', ' 1', ' gain', ' blunt', ' light', ' rep', ' open', ' strong', ' accept', ' yes', ' rich', ' hate', ' lose', ' quiet', ' encourage', ' thick', ' fast']
Completions with intervention:  [' happy', ' win', ' destroy', ' fast', ' disperse', ' loss', ' blunt', ' dark', ' rep', ' closed', ' weak', ' accept', ' safety', ' poor', ' love', ' lose', ' quiet', ' discourage', ' thin', ' slow']





In [25]:
def display_model_completions_on_h_intervention(
    dataset: ICLDataset,
    completions: list[str],
    completions_intervention: list[str],
    num_to_display: int = 20,
) -> None:
    table = rich.table.Table(
        "Prompt",
        "Model's completion\n(no intervention)",
        "Model's completion\n(intervention)",
        "Correct completion",
        title="Model's antonym completions",
    )

    for i in range(min(len(completions), num_to_display)):

        completion_ni = completions[i]
        completion_i = completions_intervention[i]
        correct_completion = dataset.completions[i]
        correct_completion_first_token = tokenizer.tokenize(correct_completion)[
            0
        ].replace("Ġ", " ")
        seq = dataset.seqs[i]

        # Color code the completion based on whether it's correct
        is_correct = completion_i == correct_completion_first_token
        completion_i = (
            f"[b green]{repr(completion_i)}[/]" if is_correct else repr(completion_i)
        )

        table.add_row(
            str(seq), repr(completion_ni), completion_i, repr(correct_completion)
        )

    rich.print(table)

In [26]:
display_model_completions_on_h_intervention(
    zero_shot_dataset,
    completions_zero_shot,
    completions_intervention,
)

In [27]:
def calculate_h_and_intervene(
    model: nnsight.LanguageModel,
    dataset: ICLDataset,
    zero_shot_dataset: ICLDataset,
    layer: int,
) -> tuple[list[str], list[str]]:
    """
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass. Returns the completions from this intervention.

    Inputs:
        model: LanguageModel
            the model we're using to generate completions
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the `h`-vector
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        layer: int
            the layer we'll be extracting the `h`-vector from

    Returns:
        completions_zero_shot: list[str]
            list of string completions for the zero-shot prompts, without intervention
        completions_intervention: list[str]
            list of string completions for the zero-shot prompts, with h-intervention
    """
    with model.trace(remote=True) as runner:

        # Run on the clean prompts, to get the h-vector
        with runner.invoke(dataset.prompts) as invoker:
            # Define h (we don't need to save it, cause we don't need it outside `runner:`)
            hidden_states = model.transformer.h[layer].output[0]
            h = hidden_states[:, -1].mean(dim=0)

        # First, run a forward pass where we don't intervene
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            token_completions_zero_shot = (
                model.lm_head.output[:, -1].argmax(dim=-1).save()
            )

        # Next, run a forward pass on the zero-shot prompts where we do intervene
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states = model.transformer.h[layer].output[0]
            hidden_states[:, -1] += h
            # Also save completions
            token_completions_intervention = (
                model.lm_head.output[:, -1].argmax(dim=-1).save()
            )

    # Decode to get the string tokens
    completions_zero_shot = model.tokenizer.batch_decode(
        token_completions_zero_shot.value
    )
    completions_intervention = model.tokenizer.batch_decode(
        token_completions_intervention.value
    )

    return completions_zero_shot, completions_intervention

In [28]:
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

completions_zero_shot, completions_intervention = calculate_h_and_intervene(
    model,
    dataset,
    zero_shot_dataset,
    layer=layer,
)

display_model_completions_on_h_intervention(
    zero_shot_dataset,
    completions_zero_shot,
    completions_intervention,
)

2024-09-13 17:56:39,912 6cc48d41-5503-4a89-a7c1-a5efb4374c6e - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:56:39,944 6cc48d41-5503-4a89-a7c1-a5efb4374c6e - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:56:39,974 6cc48d41-5503-4a89-a7c1-a5efb4374c6e - RUNNING: Your job has started running.
2024-09-13 17:56:40,422 6cc48d41-5503-4a89-a7c1-a5efb4374c6e - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.75k/1.75k [00:00<00:00, 4.64MB/s]


In [34]:
def calculate_h_and_intervene_logprobs(
    model: nnsight.LanguageModel,
    dataset: ICLDataset,
    zero_shot_dataset: ICLDataset,
    layer: int,
) -> tuple[list[float], list[float]]:
    """
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass. Returns the logprobs on correct tokens from this intervention.

    Inputs:
        model: LanguageModel
            the model we're using to generate completions
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the `h`-vector
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        layer: int
            the layer we'll be extracting the `h`-vector from

    Returns:
        correct_logprobs: list[float]
            list of correct-token logprobs for the zero-shot prompts, without intervention
        correct_logprobs_intervention: list[float]
            list of correct-token logprobs for the zero-shot prompts, with h-intervention
    """
    # Get correct completions from `dataset`, to be used for indexing into the logprobs
    correct_completion_ids = [
        toks[0] for toks in tokenizer(zero_shot_dataset.completions)["input_ids"]
    ]

    with model.trace(remote=True) as runner:

        # Run on the clean prompts, to get the h-vector
        with runner.invoke(dataset.prompts) as invoker:
            # Define h (we don't need to save it, cause we don't need it outside `runner:`)
            hidden_states = model.transformer.h[layer].output[0]
            h = hidden_states[:, -1].mean(dim=0)

        # First, run a forward pass where we don't intervene
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            # We save correct-token logprobs, not all logits - this means less for us to download!
            logprobs = model.lm_head.output[:, -1].log_softmax(dim=-1)
            correct_logprobs_zero_shot = logprobs[
                torch.arange(len(zero_shot_dataset)), correct_completion_ids
            ].save()

        # Next, run a forward pass on the zero-shot prompts where we do intervene
        with runner.invoke(zero_shot_dataset.prompts) as invoker:
            # Add the h-vector to the residual stream, at the last sequence position
            hidden_states = model.transformer.h[layer].output[0]
            hidden_states[:, -1] += h
            # We save correct-token logprobs, not all logits - this means less for us to download!
            logprobs = model.lm_head.output[:, -1].log_softmax(dim=-1)
            correct_logprobs_intervention = logprobs[
                torch.arange(len(zero_shot_dataset)), correct_completion_ids
            ].save()

    return (
        correct_logprobs_zero_shot.value.tolist(),
        correct_logprobs_intervention.value.tolist(),
    )

In [35]:
def display_model_logprobs_on_h_intervention(
    dataset: ICLDataset,
    correct_logprobs_zero_shot: list[float],
    correct_logprobs_intervention: list[float],
    num_to_display: int = 20,
) -> None:
    table = rich.table.Table(
        "Zero-shot prompt",
        "Model's logprob\n(no intervention)",
        "Model's logprob\n(intervention)",
        "Change in logprob",
        title="Model's antonym logprobs, with zero-shot h-intervention\n(green = intervention improves accuracy)",
    )

    for i in range(min(len(correct_logprobs_zero_shot), num_to_display)):

        logprob_ni = correct_logprobs_zero_shot[i]
        logprob_i = correct_logprobs_intervention[i]
        delta_logprob = logprob_i - logprob_ni
        zero_shot_prompt = f"{dataset[i].x[0]:>8} -> {dataset[i].y[0]}"

        # Color code the logprob based on whether it's increased with this intervention
        is_improvement = delta_logprob >= 0
        delta_logprob = (
            f"[b green]{delta_logprob:+.2f}[/]"
            if is_improvement
            else f"{delta_logprob:+.2f}"
        )

        table.add_row(
            zero_shot_prompt, f"{logprob_ni:.2f}", f"{logprob_i:.2f}", delta_logprob
        )

    rich.print(table)


dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

correct_logprobs_zero_shot, correct_logprobs_intervention = (
    calculate_h_and_intervene_logprobs(model, dataset, zero_shot_dataset, layer=layer)
)

display_model_logprobs_on_h_intervention(
    zero_shot_dataset, correct_logprobs_zero_shot, correct_logprobs_intervention
)

2024-09-13 17:58:48,485 4f031cd1-3ead-46ff-9f36-a4b85a93cc4e - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 17:58:48,516 4f031cd1-3ead-46ff-9f36-a4b85a93cc4e - APPROVED: Your job was approved and is waiting to be run.
2024-09-13 17:58:48,548 4f031cd1-3ead-46ff-9f36-a4b85a93cc4e - RUNNING: Your job has started running.
2024-09-13 17:58:49,014 4f031cd1-3ead-46ff-9f36-a4b85a93cc4e - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.62k/1.62k [00:00<00:00, 3.71MB/s]


In [37]:
import einops


def calculate_fn_vectors_and_intervene(
    model: nnsight.LanguageModel,
    dataset: ICLDataset,
    layers: list[int] | None = None,
) -> jaxtyping.Float[torch.Tensor, "layers heads"]:
    """
    Returns a tensor of shape (layers, heads), containing the CIE for each head.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the function vector (we'll also create a
            corrupted version of this dataset for interventions)

        layers: list[int] | None
            the layers which this function will calculate the score for (if None, we assume all layers)
    """
    layers = range(model.config.n_layer) if (layers is None) else layers
    heads = range(model.config.n_head)

    # Get corrupted dataset
    corrupted_dataset = dataset.create_corrupted_dataset()
    N = len(dataset)

    # Get correct token ids, so we can get correct token logprobs
    correct_completion_ids = [
        toks[0] for toks in tokenizer(dataset.completions)["input_ids"]
    ]

    with model.trace(remote=True) as runner:

        # Run a forward pass on clean prompts, where we store attention head outputs
        z_dict = {}
        with runner.invoke(dataset.prompts) as invoker:

            for layer in layers:

                # Get hidden states, reshape to get head dimension, store the mean tensor
                z = model.transformer.h[layer].attn.out_proj.input[0][0][:, -1]
                z_reshaped = z.reshape(N, N_HEADS, D_HEAD).mean(dim=0)

                for head in heads:
                    z_dict[(layer, head)] = z_reshaped[head]

            # Get correct token logprobs
            logits_clean = model.lm_head.output[:, -1]

        # Run a forward pass on corrupted prompts, where we don't intervene or store activations (just so we can
        # get the correct-token logprobs to compare with our intervention)
        with runner.invoke(corrupted_dataset.prompts) as invoker:

            logits = model.lm_head.output[:, -1]

            correct_logprobs_corrupted = logits.log_softmax(dim=-1)[
                torch.arange(N), correct_completion_ids
            ].save()

        # For each head, run a forward pass on corrupted prompts (here we need multiple different forward passes,
        # because we're doing different interventions each time)
        correct_logprobs_dict = {}
        for layer in layers:
            for head in heads:
                with runner.invoke(corrupted_dataset.prompts) as invoker:

                    # Get hidden states, reshape to get head dimension, then set it to the a-vector
                    z = model.transformer.h[layer].attn.out_proj.input[0][0][:, -1]
                    z.reshape(N, N_HEADS, D_HEAD)[:, head] = z_dict[(layer, head)]

                    # Get logprobs at the end, which we'll compare with our corrupted logprobs
                    logits = model.lm_head.output[:, -1]
                    correct_logprobs_dict[(layer, head)] = logits.log_softmax(dim=-1)[
                        torch.arange(N), correct_completion_ids
                    ].save()

    # Get difference between intervention logprobs and corrupted logprobs, and take mean over batch dim
    all_correct_logprobs_intervention = einops.rearrange(
        torch.stack([v.value for v in correct_logprobs_dict.values()]),
        "(layers heads) batch -> layers heads batch",
        layers=len(layers),
    )

    # shape [layers heads batch]
    logprobs_diff = all_correct_logprobs_intervention - correct_logprobs_corrupted.value

    # Return mean effect of intervention, over the batch dimension
    return logprobs_diff.mean(dim=-1)

In [39]:
from gpt_from_scratch import plotly_utils

import time

# Get the best available PyTorch device
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

print(
    f"Using device: {device}"
)  # Inline comment: Print the selected device for confirmation

dataset = ICLDataset(ANTONYM_PAIRS, size=4, n_prepended=2)


def batch_process_layers(n_layers, batch_size):
    for i in range(0, n_layers, batch_size):
        yield range(n_layers)[i : i + batch_size]


results = torch.empty((0, N_HEADS), device=device)

# If this fails to run, reduce the batch size so the fwd passes are split up more
for layers in batch_process_layers(N_LAYERS, batch_size=4):

    if layers[0] == 12:
        break

    print(f"Computing layers in {layers} ...")
    t0 = time.time()
    results = torch.concat(
        [results, calculate_fn_vectors_and_intervene(model, dataset, layers).to(device)]
    )
    print(f"... finished in {time.time()-t0:.2f} seconds.\n")


plotly_utils.imshow(
    results.T,
    title="Average indirect effect of function-vector intervention on antonym task",
    width=1000,
    height=600,
    labels={"x": "Layer", "y": "Head"},
    aspect="equal",
)

Using device: mps
Computing layers in range(0, 4) ...


2024-09-13 18:32:03,132 5077ebdc-650c-4f85-96ae-f2dd43118faf - RECEIVED: Your job has been received and is waiting approval.
2024-09-13 18:32:03,527 5077ebdc-650c-4f85-96ae-f2dd43118faf - APPROVED: Your job was approved and is waiting to be run.


KeyboardInterrupt: 