# Indirect Object Identification Circuit in Pythia

In [1]:
from IPython import get_ipython
from IPython.display import clear_output, display

ipython = get_ipython()
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [131]:
import os
from typing import List, Optional, Union, Dict, Tuple

import torch
from torch import Tensor
import numpy as np
import einops
from fancy_einsum import einsum
import circuitsvis as cv

import transformer_lens.utils as utils

from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

from transformers import AutoModelForCausalLM

from torch import Tensor
from jaxtyping import Float
import plotly.express as px

from functools import partial

from torchtyping import TensorType as TT

from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch
from path_patching_cm.ioi_dataset import IOIDataset, NAMES
from neel_plotly import imshow as imshow_n

from utils.visualization import imshow_p, plot_attention_heads, plot_attention

from utils.visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution,
    get_attn_head_patterns
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [132]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [133]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2c14b07520>

## Tools

In [6]:
class ComponentDict:
    def __init__(self, layer_heads, include_ln=False, include_mlps=False):
        self.components = {}
        hidden_size = 768  # Assuming a hidden size of 768
        num_heads = 12     # Assuming 12 heads per layer
        head_size = hidden_size // num_heads

        # Create a dictionary to store head indices for each layer
        layer_to_heads = {}
        for layer, head in layer_heads:
            if layer not in layer_to_heads:
                layer_to_heads[layer] = []
            layer_to_heads[layer].append(head)

        for layer, heads in layer_to_heads.items():
            for head in heads:
                # Calculate start and end indices for each specified head
                start_idx = head * head_size
                end_idx = start_idx + head_size

                # Store the slice information for each head's weights
                component_weight_key = f'gpt_neox.layers.{layer}.attention.query_key_value.weight'
                if component_weight_key not in self.components:
                    self.components[component_weight_key] = []
                self.components[component_weight_key].append((start_idx, end_idx))

                # Store the slice information for each head's biases
                component_bias_key = f'gpt_neox.layers.{layer}.attention.query_key_value.bias'
                if component_bias_key not in self.components:
                    self.components[component_bias_key] = []
                self.components[component_bias_key].append((start_idx, end_idx))

            # Add LayerNorm components if specified
            if include_ln:
                self.components[f'gpt_neox.layers.{layer}.input_layernorm'] = None
                self.components[f'gpt_neox.layers.{layer}.post_attention_layernorm'] = None

            # Add MLP components if specified
            if include_mlps:
                self.components[f'gpt_neox.layers.{layer}.mlp'] = None

    def get_component_specs(self):
        return self.components


In [7]:
def get_components_to_swap(source_model, component_dict, cache_dir):
    component_params = {}
    for name, param in source_model.named_parameters():
        comp_specs = component_dict.get_component_specs().get(name)
        if comp_specs is not None:
            # Handle multiple slices for both weights and biases
            if "bias" in name:
                # Bias is a 1D tensor
                slices = [param.detach().clone()[start:end] for start, end in comp_specs]
            else:
                # Weights are a 2D tensor
                slices = [param.detach().clone()[:, start:end] for start, end in comp_specs]
            concatenated_slices = torch.cat(slices, dim=-1)  # Concatenate on the last dimension
            component_params[name] = (concatenated_slices, comp_specs)
        elif comp_specs is None and name in component_dict.get_component_specs():
            # Handle non-sliced components
            component_params[name] = param.detach().clone()
    return component_params


In [8]:
def load_swapped_params(target_model, component_params):
    for name, param in target_model.named_parameters():
        if name in component_params:
            new_param_data, slice_info = component_params[name]
            if slice_info is not None:
                head_size = new_param_data.shape[-1] // len(slice_info)  # Adjust head size calculation
                for i, (start_idx, end_idx) in enumerate(slice_info):
                    if param.data.ndim == 2:
                        param.data[:, start_idx:end_idx] = new_param_data[:, i*head_size:(i+1)*head_size]
                    elif param.data.ndim == 1:
                        param.data[start_idx:end_idx] = new_param_data[i*head_size:(i+1)*head_size]
            elif slice_info is None:
                # For non-sliced components, replace the entire parameter
                param.data = new_param_data


In [9]:
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_text(prompt, model, model_name, max_length=50):
    # Load pre-trained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure model is in evaluation mode
    model.eval()

    # Encode the prompt text
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Generate text
    with torch.no_grad():  # Disable gradient calculations for efficiency
        output = model.generate(input_ids, max_length=max_length)

    # Decode and return the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

### Tests

In [10]:
def test_component_dict():
    # Example: Test for layer 9, heads 0, 1, 2
    component_dict = ComponentDict(layer_heads=[(9, 0), (9, 1), (9, 2)], include_ln=False, include_mlps=False)
    component_specs = component_dict.get_component_specs()

    # Check for the presence of correct keys and slices
    assert 'gpt_neox.layers.9.attention.query_key_value.weight' in component_specs
    assert 'gpt_neox.layers.9.attention.query_key_value.bias' in component_specs
    assert len(component_specs['gpt_neox.layers.9.attention.query_key_value.weight']) == 3  # 3 heads
    assert len(component_specs['gpt_neox.layers.9.attention.query_key_value.bias']) == 3  # 3 heads

    # Add more assertions as needed to test the slices

    print("Test ComponentDict: Passed")


In [11]:
def test_get_components_to_swap():
    source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
    component_dict = ComponentDict(layer_heads=[(9, 0), (9, 1)], include_ln=False, include_mlps=False)
    component_params = get_components_to_swap(source_model, component_dict, cache_dir="./")

    hidden_size = source_model.config.hidden_size
    num_heads = source_model.config.num_attention_heads
    head_size = hidden_size // num_heads

    for name, (concatenated_param, slices) in component_params.items():
        if "query_key_value.weight" in name or "query_key_value.bias" in name:
            for i, (start, end) in enumerate(slices):
                expected_slice = source_model.get_parameter(name).data[:, start:end] if "weight" in name else source_model.get_parameter(name).data[start:end]
                actual_slice = concatenated_param[:, i*head_size:(i+1)*head_size] if concatenated_param.ndim == 2 else concatenated_param[i*head_size:(i+1)*head_size]
                assert torch.allclose(expected_slice, actual_slice), f"Slice mismatch in {name} for head {i}"

    print("Test get_components_to_swap: Passed")


In [12]:
def test_load_swapped_params():
    source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
    target_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
    component_dict = ComponentDict(layer_heads=[(9, 0), (9, 1)], include_ln=False, include_mlps=False)
    

    hidden_size = target_model.config.hidden_size
    num_heads = target_model.config.num_attention_heads
    head_size = hidden_size // num_heads

    # Modify source model's specified components
    for name, param in source_model.named_parameters():
        if name in component_dict.get_component_specs():
            param.data.add_(0.123)  # Arbitrary modification for testing

    # Perform parameter swapping
    component_params = get_components_to_swap(source_model, component_dict, cache_dir="./")
    load_swapped_params(target_model, component_params)

    # Verify the parameters are correctly updated in the target model
    for name, (concatenated_param, slices) in component_params.items():
        if "query_key_value.weight" in name or "query_key_value.bias" in name:
            target_param = target_model.get_parameter(name)
            for i, (start, end) in enumerate(slices):
                target_slice = target_param.data[:, start:end] if target_param.ndim == 2 else target_param.data[start:end]
                source_slice = concatenated_param[:, i*head_size:(i+1)*head_size] if concatenated_param.ndim == 2 else concatenated_param[i*head_size:(i+1)*head_size]
                assert torch.allclose(target_slice, source_slice), f"Mismatch in parameters for {name} head {i}"

    print("Test load_swapped_params: Passed")


In [13]:
test_component_dict()
test_get_components_to_swap()
test_load_swapped_params()


Test ComponentDict: Passed


config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

Test get_components_to_swap: Passed
Test load_swapped_params: Passed


## Model Setup

In [170]:
model_name = "EleutherAI/pythia-160m"
revision_source = "step2000"
revision_target = "step143000"
cache_dir = "model_cache"

In [171]:
source_model = AutoModelForCausalLM.from_pretrained(
    model_name, revision=revision_source, cache_dir=cache_dir
)

target_model = AutoModelForCausalLM.from_pretrained(
    model_name, revision=revision_target, cache_dir=cache_dir
)

In [73]:
whole_circuit = [(4, 6), (4, 11), (6, 6), (7, 2), (7, 9), (8, 9), (8, 2), (8, 10)]
nmh = [(8, 2), (8, 10)]
top_3 = nmh + [(8, 9)]
all_s2i = [(6, 6), (7, 2), (7, 9), (8, 9)]
pre_s2i = [(6, 6), (7, 2), (7, 9)]
special_s2i = [(8, 9)]
idh = [(4, 6), (4, 11)]

component_dict = ComponentDict(layer_heads=all_s2i, include_ln=False, include_mlps=False)

component_params = get_components_to_swap(source_model, component_dict, cache_dir)
load_swapped_params(target_model, component_params)

In [295]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    #hf_model=target_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
)
model.set_use_hook_mlp_in(True)

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


## Data Setup

### Complex Dataset

In [296]:
def _logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

def _ioi_metric_noising(
        logits: Float[Tensor, "batch seq d_vocab"],
        clean_logit_diff: float,
        corrupted_logit_diff: float,
        ioi_dataset: IOIDataset,
    ) -> float:
        '''
        We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
        and -1 when performance has been destroyed (i.e. is same as ABC dataset).
        '''
        patched_logit_diff = _logits_to_ave_logit_diff(logits, ioi_dataset)
        return ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)).item()


def generate_data_and_caches(N: int, verbose: bool = False, seed: int = 42):

    ioi_dataset = IOIDataset(
        prompt_type="mixed",
        N=N,
        tokenizer=model.tokenizer,
        prepend_bos=False,
        seed=seed,
        device=str(device)
    )

    abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->ABA, BAB->BAA")

    model.reset_hooks(including_permanent=True)

    ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)
    abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)

    ioi_average_logit_diff = _logits_to_ave_logit_diff(ioi_logits_original, ioi_dataset).item()
    abc_average_logit_diff = _logits_to_ave_logit_diff(abc_logits_original, ioi_dataset).item()

    if verbose:
        print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
        print(f"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}")

    ioi_metric_noising = partial(
        _ioi_metric_noising,
        clean_logit_diff=ioi_average_logit_diff,
        corrupted_logit_diff=abc_average_logit_diff,
        ioi_dataset=ioi_dataset,
    )

    return ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising



N = 70
ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(N, verbose=True)

Average logit diff (IOI dataset): 4.1336
Average logit diff (ABC dataset): -4.0758


In [297]:
ioi_dataset.ioi_prompts[40]

{'[PLACE]': 'garden',
 '[OBJECT]': 'ring',
 'text': 'After Ruby and Joshua went to the garden, Ruby gave a ring to Joshua',
 'IO': 'Joshua',
 'S': 'Ruby',
 'TEMPLATE_IDX': 6}

In [298]:
abc_dataset.ioi_prompts[40]

{'[PLACE]': 'garden',
 '[OBJECT]': 'ring',
 'text': 'After Ruby and Joshua went to the garden, Joshua gave a ring to Joshua',
 'IO': 'Ruby',
 'S': 'Joshua',
 'TEMPLATE_IDX': 6}

In [299]:
clean_logits, clean_cache = model.run_with_cache(ioi_dataset.toks)
corrupted_logits, corrupted_cache = model.run_with_cache(abc_dataset.toks)

clean_logit_diff = _logits_to_ave_logit_diff(clean_logits, ioi_dataset)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = _logits_to_ave_logit_diff(corrupted_logits, ioi_dataset)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.1336
Corrupted logit diff: -4.0758


In [300]:
from utils.metrics import _logits_to_mean_accuracy

clean_logit_accuracy = _logits_to_mean_accuracy(clean_logits, ioi_dataset).item()
print(f"Clean logit accuracy: {clean_logit_accuracy:.4f}")

corrupted_logit_accuracy = _logits_to_mean_accuracy(corrupted_logits, ioi_dataset).item()
print(f"Corrupted logit accuracy: {corrupted_logit_accuracy:.4f}")

Clean logit accuracy: 0.9714
Corrupted logit accuracy: 0.0286


In [301]:
from utils.metrics import _logits_to_rank_0_rate

clean_logit_rank_0_rate = _logits_to_rank_0_rate(clean_logits, ioi_dataset)
print(f"Clean logit rank 0 rate: {clean_logit_rank_0_rate:.4f}")

corrupted_logit_rank_0_rate = _logits_to_rank_0_rate(corrupted_logits, ioi_dataset)
print(f"Corrupted logit rank 0 rate: {corrupted_logit_rank_0_rate:.4f}")

Clean logit rank 0 rate: 0.7000
Corrupted logit rank 0 rate: 0.0000


In [302]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

In [303]:
def logit_diff_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    dataset: IOIDataset,
    flipped_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    return_tensor: bool = False,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.
    '''
    patched_logit_diff = _logits_to_ave_logit_diff(logits, dataset)
    ld = ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff  - flipped_logit_diff))
    if return_tensor:
        return ld
    else:
        return ld.item()


def logit_diff_noising(
        logits: Float[Tensor, "batch seq d_vocab"],
        dataset: IOIDataset = ioi_dataset,
        clean_logit_diff: float = clean_logit_diff,
        corrupted_logit_diff: float = corrupted_logit_diff,
        return_tensor: bool = False,
    ) -> float:
        '''
        We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
        and -1 when performance has been destroyed (i.e. is same as ABC dataset).
        '''
        patched_logit_diff = _logits_to_ave_logit_diff(logits, dataset)
        ld = ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff))

        if return_tensor:
            return ld
        else:
            return ld.item()

logit_diff_denoising_ioi = partial(logit_diff_denoising, dataset=ioi_dataset)
logit_diff_noising_ioi = partial(logit_diff_noising, dataset=ioi_dataset)

## Tool Setup

### Activation Patching

## Direct Logit Attribution

In [304]:
sio_answer_tensor = torch.cat(
    (torch.tensor(ioi_dataset.io_tokenIDs).unsqueeze(dim=1), torch.tensor(ioi_dataset.s_tokenIDs).unsqueeze(dim=1)), 
    dim=1
)

In [305]:
answer_residual_directions: Float[Tensor, "batch 2 d_model"] = model.tokens_to_residual_directions(sio_answer_tensor)
print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions: Float[Tensor, "batch d_model"] = correct_residual_directions - incorrect_residual_directions
print(f"Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([70, 2, 768])
Logit difference directions shape: torch.Size([70, 768])


In [307]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream: Float[Tensor, "batch seq d_model"] = clean_cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
scaled_residual_stream = clean_cache.apply_ln_to_stack(final_residual_stream, layer=-1)
print(f"Scaled residual stream shape: {scaled_residual_stream.shape}")
scaled_final_token_residual_stream: Float[Tensor, "batch d_model"] = scaled_residual_stream[torch.arange(final_residual_stream.size(0)), ioi_dataset.word_idx["end"]]
print(f"Final token residual stream shape: {scaled_final_token_residual_stream.shape}")

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
#scaled_final_token_residual_stream = torch.zeros_like(final_token_residual_stream)
#for i in range(final_token_residual_stream.shape[0]):
#    scaled_final_token_residual_stream[i] = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=ioi_dataset.word_idx["end"][i].item(), batch_slice=i)
#scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=pos_slice) #ioi_dataset.word_idx["end"].unsqueeze(0))
#print(scaled_final_token_residual_stream.shape)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions,
    "batch d_model, batch d_model ->"
) / 70

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {clean_logit_diff:.10f}")

#torch.testing.assert_close(average_logit_diff, clean_logit_diff)

Final residual stream shape: torch.Size([70, 21, 768])
Scaled residual stream shape: torch.Size([70, 21, 768])
Final token residual stream shape: torch.Size([70, 768])
Calculated average logit diff: 4.0873727798
Original logit difference:     4.1336488724


In [308]:
from transformer_lens import ActivationCache

def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Gets the avg logit difference between the correct and incorrect answer for a given
    stack of components in the residual stream.
    '''
    # SOLUTION
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return einops.einsum(
        scaled_residual_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    ) / batch_size


# Test function by checking that it gives the same result as the original logit difference
# t.testing.assert_close(
#     residual_stack_to_logit_diff(final_token_residual_stream, cache),
#     original_average_logit_diff
# )

### Logit Lens

In [325]:
accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, return_labels=True)
accumulated_residual_final_token = accumulated_residual[:, torch.arange(accumulated_residual.size(1)), ioi_dataset.word_idx["end"]]
print(f"Shape of accumulated residual: {accumulated_residual_final_token.shape}")
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_final_token, clean_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

Shape of accumulated residual: torch.Size([13, 70, 768])


### Layer Attribution

In [330]:
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, return_labels=True)
per_layer_residual_final_token = per_layer_residual[:, torch.arange(per_layer_residual.size(1)), ioi_dataset.word_idx["end"]]
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual_final_token, clean_cache)

line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

### Head Attribution

In [339]:
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, return_labels=True)
print(f"Shape of per head residual: {per_head_residual.shape}")
per_head_residual_final_token = per_head_residual[:, torch.arange(per_head_residual.size(1)), ioi_dataset.word_idx["end"]]
print(f"Shape of per head residual: {per_head_residual_final_token.shape}")
per_head_residual_final_token = einops.rearrange(
    per_head_residual_final_token,
    "(layer head) ... -> layer head ...",
    layer=model.cfg.n_layers
)
print(f"Shape of per head residual: {per_head_residual_final_token.shape}")
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual_final_token, clean_cache)

imshow(per_head_logit_diffs, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")

Shape of per head residual: torch.Size([144, 70, 21, 768])
Shape of per head residual: torch.Size([144, 70, 768])
Shape of per head residual: torch.Size([12, 12, 70, 768])


In [340]:
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=15, range_x=[0, 1])

Total logit diff contribution above threshold: 0.77


In [341]:
nmh_candidates = [(8, 10), (10, 7)]

In [352]:
import pandas as pd
def scatter_attention_and_contribution(
    model,
    head,
    prompts,
    end_positions,
    io_positions,
    s_positions,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
):

    df = []

    layer, head_idx = head
    # Get the attention output to the residual stream for the head
    _, cache = model.run_with_cache(prompts)
    per_head_residual, _ = cache.stack_head_results(
        layer=-1, return_labels=True
    )
    scaled_residual_stack = cache.apply_ln_to_stack(
        per_head_residual, layer=-1
    )
    print(scaled_residual_stack.shape)
    scaled_residual_stack_final_token = scaled_residual_stack[:, torch.arange(scaled_residual_stack.size(1)), end_positions]
    head_resid = scaled_residual_stack_final_token[layer * model.cfg.n_heads + head_idx]

    # Loop over each prompt
    for i in range(len(answer_residual_directions)):
        # Get attention values
        tokens, attn, names = get_attn_head_patterns(model, prompts[i], [head])

        # For IO
        # Get the attention contribution in the residual directions
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][0]
        )

        # Get the attention probability to the IO answer
        prob = attn[0, 14, io_positions[i]]
        df.append([prob, dot, "IO", prompts[i]])

        # For S
        # Get the attention contribution in the residual directions
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][1]
        )
        # Get the attention probability to the S answer
        prob = attn[0, 14, s_positions[i]]
        df.append([prob, dot, "S", prompts[i]])

    # Plot the results
    viz_df = pd.DataFrame(
        df, columns=[f"Attn Prob on Name", f"Dot w Name Embed", "Name Type", "text"]
    )
    fig = px.scatter(
        viz_df,
        x=f"Attn Prob on Name",
        y=f"Dot w Name Embed",
        color="Name Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Name Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    else:
        fig.show()

In [357]:
scatter_attention_and_contribution(model, (8, 10), ioi_dataset.toks, ioi_dataset.word_idx['end'], ioi_dataset.word_idx['IO'], ioi_dataset.word_idx['S2'], answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now
torch.Size([144, 70, 21, 768])


In [None]:
top_k = 2
top_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [None]:
scatter_attention_and_contribution(model, (8, 9), prompts, io_positions, s_positions, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


## Activation Patching for Model Component Importance

### Attention Heads

In [111]:
results = act_patch(
    model=model,
    orig_input=abc_dataset.toks,
    new_cache=clean_cache,
    patching_nodes=IterNode("z"), # iterating over all heads' output in all layers
    patching_metric=logit_diff_denoising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [112]:
imshow_p(
    results['z'] * 100,
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

### Head Output by Component

In [113]:
# iterating over all heads' output in all layers
results = act_patch(
    model=model,
    orig_input=abc_dataset.toks,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=logit_diff_denoising_ioi,
    verbose=True,
)

  0%|          | 0/720 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)
results['q'].shape = (layer=12, head=12)
results['k'].shape = (layer=12, head=12)
results['v'].shape = (layer=12, head=12)
results['pattern'].shape = (layer=12, head=12)


In [114]:
assert results.keys() == {"z", "q", "k", "v", "pattern"}
#assert all([r.shape == (12, 12) for r in results.values()])

imshow_p(
    torch.stack(tuple(results.values())) * 100,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    margin={"r": 100, "l": 100}
)

### Residual Stream & Layer Outputs

In [19]:
results = act_patch(
    model=model,
    orig_input=abc_dataset.toks,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    patching_metric=logit_diff_denoising_ioi,
    verbose=True,
)

  0%|          | 0/756 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=21, layer=12)
results['attn_out'].shape = (seq_pos=21, layer=12)
results['mlp_out'].shape = (seq_pos=21, layer=12)


In [20]:
assert results.keys() == {"resid_pre", "attn_out", "mlp_out"}
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(ioi_dataset.toks[0]))]
imshow_p(
    torch.stack([r.T for r in results.values()]) * 100, # we transpose so layer is on the y-axis
    facet_col=0,
    facet_labels=["resid_pre", "attn_out", "mlp_out"],
    title="Patching at resid stream & layer outputs (corrupted -> clean)",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1400,
    height=600,
    margin={"r": 100, "l": 100}
)

## Circuit Sketching

### First Level

#### Heads Influencing Logit Diff Directly

In [115]:
model.cfg.use_attn_in = True

In [116]:
path_patch_resid_post = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=logit_diff_noising_ioi,
    verbose=True
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [117]:
imshow_p(
    path_patch_resid_post['z'] * 100,
    title="Direct effect on logit diff (patch from head output -> final resid)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

In [118]:
plot_attention_heads(-path_patch_resid_post['z'].cuda(), top_n=10, range_x=[0, 1.0])

Total logit diff contribution above threshold: 1.03


In [119]:
top_k = 3
DISPLAY_IDX = 0
top_heads = torch.topk(-path_patch_resid_post['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, model.to_string((ioi_dataset.toks[DISPLAY_IDX][:ioi_dataset.word_idx["end"][DISPLAY_IDX]+1])), heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [121]:
DE_NMH = [(8, 10), (8, 2)]
DE_S2I = [(7, 6)]
DE_PUNC_CONJ = [(10, 1)]

In [36]:
logit_diff_denoising_ioi_t = partial(logit_diff_denoising, dataset=ioi_dataset, return_tensor=True)
logit_diff_noising_ioi_t = partial(logit_diff_noising, dataset=ioi_dataset, return_tensor=True)

In [37]:
ioi_metric = logit_diff_denoising_ioi
attn_head_pattern_all_pos_act_patch_results = patching.get_act_patch_attn_head_pattern_all_pos(model, abc_dataset.toks, ioi_cache, logit_diff_denoising_ioi_t)
# imshow(attn_head_pattern_all_pos_act_patch_results, 
#        yaxis="Layer", 
#        xaxis="Head", 
#        title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [38]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, abc_dataset.toks, ioi_cache, logit_diff_denoising_ioi_t)
# imshow(attn_head_out_all_pos_act_patch_results, 
#        yaxis="Layer", 
#        xaxis="Head", 
#        title="IOI Metric for 'attn_head_out' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [40]:
from utils.visualization_utils import l_scatter
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

#### NMH Knockout

##### All Heads

In [92]:
heads_to_ablate = DE_NMH

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, :, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(ioi_dataset.toks)
print(f"Original IOI Metric: {_logits_to_ave_logit_diff(clean_logits, ioi_dataset).item():.4f}")
print(f"Post ablation IOI Metric: {_logits_to_ave_logit_diff(ablated_logits, ioi_dataset).item()}")

Heads to ablate: [(8, 10)]
Original IOI Metric: 4.1336
Post ablation IOI Metric: 2.110731363296509


In [93]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow_n(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"}, zmin=-1.5, zmax=1.5, title="Post-Ablation Direct Logit Attribution of Heads")
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


NameError: name 'logit_diff_directions' is not defined

In [126]:
exclusions = [(6, 6), (7, 9), (8, 9)] + [(9, 1), (9, 5)]
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
for layer, head in exclusions:
    per_head_ablated_logit_diffs[layer, head] = 0

plot_attention_heads(
    per_head_ablated_logit_diffs/clean_logit_diff, 
    title="Logit Diff Contribution From Backup Heads", 
    top_n=15, 
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.37


##### Individual Heads

In [73]:
# Get indices of all heads where the ablation had a positive effect
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
backup_nmh_candidates = np.argwhere(delta.cpu().detach().numpy() > 0.05)
backup_nmh_candidates = [tuple(h) for h in backup_nmh_candidates]
backup_nmh_candidates = [h for h in backup_nmh_candidates if h not in exclusions]
print(f"Backup NMH Candidates: {backup_nmh_candidates}")
for l, h in backup_nmh_candidates:
    for layer, head in heads_to_ablate:
        ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
        model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
    scatter_attention_and_contribution(model, (l, h), prompts, io_positions, s_positions, answer_residual_directions)

Backup NMH Candidates: [(9, 6)]
Tried to stack head results when they weren't cached. Computing head results now


RuntimeError: The expanded size of the tensor (1) must match the existing size (8) at non-singleton dimension 0.  Target sizes: [1, 15, 64].  Tensor sizes: [8, 15, 64]

In [76]:
attn_head_pattern_all_pos_act_patch_results['z'].shape

torch.Size([12, 12])

In [77]:
top_k = 5
top_heads = torch.topk(-attn_head_pattern_all_pos_act_patch_results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [78]:
# V-weighted version
plot_attention(
    model, 
    prompts[0],
    nmh_candidates,
    clean_cache,
    weighted=True)

### Contributors to NMHs

#### Attention Out by Position

In [122]:
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DE_S2I],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)
results = einops.rearrange(results['attn_out'], "seq layer -> layer seq")

  0%|          | 0/252 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=21, layer=12)


In [123]:
imshow_n(
        results * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(ioi_dataset.toks[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

#####

#### Attention Out by Head

In [126]:
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DE_S2I],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [127]:
imshow_p(
        results["z"][:10] * 100,
        title=f"Direct effect on NMH' queries",
        labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        border=True,
        width=700,
        margin={"r": 100, "l": 100}
    )

In [128]:
plot_attention_heads(-results['z'].cuda(), top_n=10, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.22


In [47]:
IDX = 55
model.to_string((ioi_dataset.toks[IDX][:ioi_dataset.word_idx["end"][IDX]+1]))

'After Peter and Brian went to the hospital, Peter gave a snack to'

In [48]:
ioi_dataset.word_idx["end"].shape

torch.Size([70])

In [129]:
top_k = 2
DISPLAY_IDX = 45
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, model.to_string((ioi_dataset.toks[DISPLAY_IDX][:ioi_dataset.word_idx["end"][DISPLAY_IDX]+1])), heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

### Second Level

#### Attention Pattern for Second-Level Heads

In [50]:
second_level_positive_heads = [(6, 6), (7, 2), (7, 9)]
IE_S2I = second_level_positive_heads

tokens, attn, names = get_attn_head_patterns(model, model.to_string((ioi_dataset.toks[DISPLAY_IDX][:ioi_dataset.word_idx["end"][DISPLAY_IDX]+1])), second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

In [66]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching")

In [51]:
s2i_candidates = [(6, 6), (7, 2), (7, 9), (8, 9)]
#s2i_candidates = [(8, 9)]

#### S2I Knockout

##### All Heads

In [69]:
heads_to_ablate = s2i_candidates

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(6, 6), (7, 2), (7, 9), (8, 9)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.5408056974411011


In [70]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


#### Path Patching for S2-Inhibition Candidates

In [52]:
receiver_heads = second_level_positive_heads

results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in s2i_candidates],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [53]:
imshow_p(
        results["z"][:10] * 100,
        title=f"Direct effect on S2Is' values",
        labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        border=True,
        width=700,
        margin={"r": 100, "l": 100}
    )

In [54]:
plot_attention_heads(-results['z'].cuda(), top_n=10, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.37


### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [58]:
third_level_positive_heads = [(4, 6), (4, 11)]
DISPLAY_IDX = 0
tokens, attn, names = get_attn_head_patterns(model, model.to_string((ioi_dataset.toks[DISPLAY_IDX][:ioi_dataset.word_idx["end"][DISPLAY_IDX]+1])), third_level_positive_heads)
#tokens, attn, names = get_attn_head_patterns(model, model.to_string((ioi_dataset.toks[DISPLAY_IDX][:ioi_dataset.word_idx["end"][DISPLAY_IDX]+1])), second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

## Save the Circuit

In [63]:
# define circuit
from collections import namedtuple
import pickle

CircuitComponent = namedtuple(
    "CircuitComponent", ["heads", "position", "receiver_type"]
)

circuit = {
    "name-movers": CircuitComponent(
        DE_NMH, -1, "hook_q"
    ),
    "s2-inhibition": CircuitComponent(s2i_candidates, ioi_dataset.word_idx["S2"], "hook_v"),
    # "duplicate-name": CircuitComponent([(7, 15), (9, 1)], 10, 'head_v'),
    # "induction": CircuitComponent([], 10, 'head_v')
}

# Specify a filename for saving the circuit dictionary
circuit_filename = 'results/circuits/pythia_160m_circuit.pkl'

# Save the circuit dictionary using pickle
with open(circuit_filename, 'wb') as f:
    pickle.dump(circuit, f)

In [None]:

# save 410m circuit
circuit = {
    "name-movers": CircuitComponent(
        [(17, 10), (17, 6), (17, 11), (18, 0), (18, 8), (18, 13), (18, 14)],
        -1,
        "hook_q",
    ),
    "s2-inhibition": CircuitComponent(
        [(11, 4), (13, 1), (13, 5), (16, 0)], 10, "hook_v"
    ),
    # "duplicate-name": CircuitComponent([], 10, 'head_v'),
    # "induction": CircuitComponent([], 10, 'head_v')
}

In [None]:
# save 1.4b circuit
circuit = {
    "name-movers": CircuitComponent(
        [(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 7)], -1, "hook_q"
    ),
    "s2-inhibition": CircuitComponent([(10, 7)], 10, "hook_v"),
    # "duplicate-name": CircuitComponent([(7, 15), (9, 1)], 10, 'head_v'),
    # "induction": CircuitComponent([], 10, 'head_v')
}

In [None]:
# save 2.8b circuit
circuit = {
    "name-movers": CircuitComponent(
        [(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 7)], -1, "hook_q"
    ),
    "s2-inhibition": CircuitComponent([(10, 7)], 10, "hook_v"),
    # "duplicate-name": CircuitComponent([(7, 15), (9, 1)], 10, 'head_v'),
    # "induction": CircuitComponent([], 10, 'head_v')
}

In [6]:
circuit_file = "pythia_1-4b_circuit.pkl"
circuit_root = "results/circuits/"
with open(circuit_root + circuit_file, 'rb') as f:
    circuit = pickle.load(f)

In [15]:
res = torch.load(f"results/pythia-1.4b-no-dropout/value_perf.pt")

In [20]:
res[18, 12]

tensor([-4.5751, -4.5505], device='cuda:0')

In [12]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-1.4b",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
)
model.set_use_hook_mlp_in(True)

model.safetensors:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-1.4b into HookedTransformer


In [17]:
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 8192,
 'd_model': 2048,
 'd_vocab': 50304,
 'd_vocab_out': 50304,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.017677669529663688,
 'model_name': 'pythia-1.4b',
 'n_ctx': 2048,
 'n_devices': 1,
 'n_heads': 16,
 'n_key_value_heads': None,
 'n_layers': 24,
 'n_params': 1207959552,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPTNeoXForCausalLM',
 'parallel_attn_mlp': True,
 'positional_embedding_type': 'rotary',
 'post_embedding_ln': False,
 'rotary_dim': 32,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'EleutherAI/pythia-1.4b',