In [2]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

from typing import Literal

%reload_ext autoreload
%autoreload 2

In [3]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


english_activations = {}
german_activations = {}
for layer in range(3, 4):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [4]:
def get_attention_pattern_activations(cache: ActivationCache, layer: int, n_pos=-1):
    block_name = f'blocks.{layer}.attn.hook_pattern'
    original_activations = cache[block_name].mean(0)

    if n_pos == -1:
        n_pos = original_activations.shape[1]

    for i, pattern in enumerate(original_activations):
        haystack_utils.imshow(pattern[:n_pos, :n_pos].cpu(), title=f"L{layer}H{i} Attention patterns", width=700)


tokens = model.to_tokens(german_data[:20])
original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

get_attention_pattern_activations(original_cache, layer=0, n_pos=20)
get_attention_pattern_activations(original_cache, layer=1, n_pos=20)
get_attention_pattern_activations(original_cache, layer=2, n_pos=20)
# get_attention_pattern_activations(original_cache, layer=3, n_pos=20)
# get_attention_pattern_activations(original_cache, layer=4, n_pos=20)
# get_attention_pattern_activations(original_cache, layer=5, n_pos=20)
# get_attention_pattern_activations(original_cache, layer=6, n_pos=20)

## Trigram circuit

Previous token head could copy over skip-bigram information, or copy over what the previous token is so the current token can compute trigram information using a trigram table, perhaps learned in the embed, of previous tokens -> current token -> most common next token.

Test:
- If head has a positive direct effect on the output it's a skip bigram. 
- If not, generate a corrupted head cache and patch components into a clean run to see what's reading from it.

We think it's doing trigrams rather than skip bigrams so we could potentially save some time by testing for skip trigrams immediately.

Trigram: each token embedding has a table of possible previous token x implies next token z. L0H7 copies across previous token x and multiplies it with the trigram take to get next token z. The trigram table must be read in from the current position token embedding so it's not happening in the previous token head which only reads from the previous token. It either happens in a future head that only reads from the current token

Hypothesis: previous token head L0H7 is responsible for most of the trigram loss increase. Two options for how it's used:
- a global bigram table in the embedding or in a later component contains the global bigram completions that sometimes include completions which only make sense in German. The attention head copies this information over 
- an MLP layer contains German-specific bigram completions. The bigram completions are combined with the previous token and selected when the prompt is in German, most likely by direct reference to the German context neuron.

To look for the MLP layer, get a cache where L2H1 has been zero ablated or otherwise distorted, and then patch MLP components from this cache into a clean run and see which one damages the loss.

In [5]:
# Get top common german tokens excluding punctuation
token_counts = torch.zeros(model.cfg.d_vocab).cuda()
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    for token in tokens[0]:
        token_counts[token.item()] += 1

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

def get_random_selection(tensor, n=12):
    # Hacky replacement for np.random.choice
    return tensor[torch.randperm(len(tensor))[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    # Generate a batch of random prompts ending with a specific ngram
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = []
    for i in range(n):
        prompt = get_random_selection(top_tokens[:max(50, length)], n=length).cuda()
        prompt = torch.cat([prompt, end_tokens])
        prompts.append(prompt)
    prompts = torch.stack(prompts)
    return prompts

prompts = generate_random_prompts(" Vorschlägen", n=100, length=20)

  0%|          | 0/200 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', ' zu', 'ch', 'n', 'st', 're', 'z', ' von', ' für', 'äsident', ' Pr', 'ischen', 't', 'ü', 'icht', 'gen', ' ist', ' auf', ' dass', 'ge', 'ig', ' im', 'in', ' über', 'g', ' das', 'te', ' er', 'men', ' w', 'es', ' an', 'ß', ' wir', ' eine', 'f', ' W', 'hen', 'w', ' Europ', ' ich', 'ungen', 'ren', 'le', ' dem', 'ten', ' ein', 'e', ' Z', ' Ver', 'der', ' B', ' mit', ' dies', 'h', ' nicht', 'ungs', 's', ' G', ' z', 'it', ' Herr', ' es', 'l', ' S', 'ich', 'lich', ' An', 'heit', 'ie', ' Er', ' zur', ' V', ' ver', 'u', 'hr', 'chaft', 'Der', ' Ich', ' Ab', ' haben', 'i', 'ant', 'chte', ' mö', 'er', ' K', 'igen', ' Ber', 'ür', ' Fra', 'em']


In [6]:
# see if zero ablating the clean prev token head increases loss
# can't mean ablate the pattern because it always looks at previous tokens
# could mean ablate the output to it passes forwards gibberish as the previous token
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=50)
random_prompts = haystack_utils.generate_random_prompts(" Vorschlägen", model, common_tokens, 100, length=20)[:, :-4]
model.set_use_attn_result(True)

l0_heads = 'blocks.0.attn.hook_result'
_, original_cache = model.run_with_cache(random_prompts, return_type="loss", loss_per_token=True)
original_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

mean_l0h7_out = original_cache[l0_heads][:, :, 7, :].mean((0, 1))
def ablate_l2h7_hook(value, hook):
    value[:, -2, 7, :] = mean_l0h7_out
    return value

with model.hooks([(l0_heads, ablate_l2h7_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

print((ablated_loss[:, -1] - original_loss[:, -1]).mean())

  0%|          | 0/200 [00:00<?, ?it/s]

tensor(2.9198, device='cuda:0')


In [7]:
original_loss, original_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
with model.hooks([(l0_heads, ablate_l2h7_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    if "blocks." in hook.name and hook.layer() != 0 and ("attn" in hook.name or "mlp" in hook.name):
        value = ablated_cache[hook.name]
        print(hook.name)
    return value   
deactivate_component_hooks = [(lambda pattern : True, deactivate_components_hook)]

with model.hooks(fwd_hooks=deactivate_component_hooks):
    activated_direct_effect_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

print((activated_direct_effect_loss - ablated_loss).mean(0)[-1])

# Not a skip bigram

blocks.1.attn.hook_q
blocks.1.attn.hook_k
blocks.1.attn.hook_v
blocks.1.attn.hook_rot_q
blocks.1.attn.hook_rot_k
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_z
blocks.1.attn.hook_result
blocks.1.hook_attn_out
blocks.1.mlp.hook_pre
blocks.1.mlp.hook_post
blocks.1.hook_mlp_out
blocks.2.attn.hook_q
blocks.2.attn.hook_k
blocks.2.attn.hook_v
blocks.2.attn.hook_rot_q
blocks.2.attn.hook_rot_k
blocks.2.attn.hook_attn_scores
blocks.2.attn.hook_pattern
blocks.2.attn.hook_z
blocks.2.attn.hook_result
blocks.2.hook_attn_out
blocks.2.mlp.hook_pre
blocks.2.mlp.hook_post
blocks.2.hook_mlp_out
blocks.3.attn.hook_q
blocks.3.attn.hook_k
blocks.3.attn.hook_v
blocks.3.attn.hook_rot_q
blocks.3.attn.hook_rot_k
blocks.3.attn.hook_attn_scores
blocks.3.attn.hook_pattern
blocks.3.attn.hook_z
blocks.3.attn.hook_result
blocks.3.hook_attn_out
blocks.3.mlp.hook_pre
blocks.3.mlp.hook_post
blocks.3.hook_mlp_out
blocks.4.attn.hook_q
blocks.4.attn.hook_k
blocks.4.attn.hook_v
blocks.4.attn

In [8]:
original_loss, original_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
with model.hooks([(l0_heads, ablate_l2h7_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    value = ablated_cache[hook.name]
    return value  

def activate_components_hook(value, hook):
    value = original_cache[hook.name]
    return value

components = []
for component_prefix in ["mlp.hook_post", "hook_attn_out"]:
    for block in range(6):
        components.append(f"blocks.{block}.{component_prefix}")

print("Total effects of deactivating component reads from L0H7")
total_effects = []
for component in components:
    deactivate_component_hooks = [(component, deactivate_components_hook)]
    with model.hooks(fwd_hooks=deactivate_component_hooks):
        activated_direct_effect_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
        total_effects.append((activated_direct_effect_loss - original_loss).mean(0)[-1])

for component, total_effect in zip(components, total_effects):
    if total_effect > 0.5:
        print(component, total_effect.item())

print("Direct effects of deactivating component reads from L0H7")
direct_effects = []
for component in components:
    deactivate_component_hooks = [(component, deactivate_components_hook)]
    activate_other_components_hooks = [(other_component, activate_components_hook) for other_component in components if other_component != component]
    with model.hooks(fwd_hooks=deactivate_component_hooks+activate_other_components_hooks):
        activated_direct_effect_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
        direct_effects.append((activated_direct_effect_loss - original_loss).mean(0)[-1])

for component, direct_effect in zip(components, direct_effects):
    if direct_effect > 0.5:
        print(component, direct_effect.item())

print("Total effects of deactivating component reads through L0H7 -> MLP1")
def ablate_mlp_hook(value, hook):
    value = ablated_cache[hook.name]
    return value
with model.hooks([('blocks.1.mlp.hook_post', ablate_mlp_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    value = ablated_cache[hook.name]
    return value  

total_effects = []
for component in components:
    deactivate_component_hooks = [(component, deactivate_components_hook)]
    with model.hooks(fwd_hooks=deactivate_component_hooks):
        activated_direct_effect_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
        total_effects.append((activated_direct_effect_loss - original_loss).mean(0)[-1])

for component, total_effect in zip(components, total_effects):
    if total_effect > 0.5:
        print(component, total_effect.item())

print("Total effects of deactivating component reads through L0H7 -> MLP1 and more -> MLP2")
with model.hooks([('blocks.2.mlp.hook_post', ablate_mlp_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    value = ablated_cache[hook.name]
    return value  

total_effects = []
for component in components:
    deactivate_component_hooks = [(component, deactivate_components_hook)]
    with model.hooks(fwd_hooks=deactivate_component_hooks):
        activated_direct_effect_loss, _ = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
        total_effects.append((activated_direct_effect_loss - original_loss).mean(0)[-1])

for component, total_effect in zip(components, total_effects):
    if total_effect > 0.5:
        print(component, total_effect.item())

Total effects of deactivating component reads from L0H7
blocks.1.mlp.hook_post 2.918806791305542
blocks.2.mlp.hook_post 0.8668676614761353
blocks.5.mlp.hook_post 1.6885840892791748
blocks.0.hook_attn_out 2.9198434352874756
Direct effects of deactivating component reads from L0H7
blocks.1.mlp.hook_post 1.445307731628418
blocks.2.mlp.hook_post 1.8252601623535156
blocks.5.mlp.hook_post 1.6885840892791748
Total effects of deactivating component reads through L0H7 -> MLP1
blocks.1.mlp.hook_post 2.918806791305542
blocks.2.mlp.hook_post 1.889991283416748
blocks.5.mlp.hook_post 0.6976528763771057
Total effects of deactivating component reads through L0H7 -> MLP1 and more -> MLP2
blocks.2.mlp.hook_post 1.889991283416748
blocks.3.mlp.hook_post 0.5435132384300232


In [9]:
from collections import deque

original_loss, original_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
with model.hooks([(l0_heads, ablate_l2h7_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    value = ablated_cache[hook.name]
    return value  

def activate_components_hook(value, hook):
    value = original_cache[hook.name]
    return value

components = []
for component_prefix in ["mlp.hook_post", "attn.hook_result"]:
    for block in range(6):
        components.append(f"blocks.{block}.{component_prefix}")

class Node():
    def __init__(self, layer: int, component: Literal["mlp", "attn"], index: None | int | list[int], children: list["Node"] = []):
        self.layer = layer
        self.component = component
        self.index = index # neuron or head index
        self.name = f"blocks.{layer}.attn.hook_result" if component == "attn" else f"blocks.{layer}.mlp.hook_post"
        self.children = children

def activate_path(path: list[Node], ablated_cache, original_cache):
    '''First node in path patched in from ablated cache. Future nodes patched in from previous nodes.
    If in path, calculate as normal starting from ablated cache. Else read from original_cache until layer
    with final node in path. Then continue without hooks.'''    
    def get_activate_components_hooks():
        def activate_components_hook(value, hook):
            value = original_cache[hook.name]
            return value
            
        node_names = [node.name for node in path]
        other_components = []
        for component in components:
            if component not in node_names and int(component.split(".")[1]) <= path[-1].layer:
                other_components.append(component)
        return [(other_component, activate_components_hook) for other_component in other_components]

    def get_deactivate_components_hooks():
        hooks = []
        for node in path:
            if node.component == 'mlp':
                component = f'blocks.{node.layer}.mlp.hook_post'
            else:
                component = f'blocks.{node.layer}.attn.hook_result'

            def deactivate_components_hook(value, hook):
                if node.index is None:
                    value = ablated_cache[hook.name]
                elif "mlp" in hook.name:
                    value[:, :, node.index] = ablated_cache[hook.name][:, :, node.index]

                    mask = torch.ones_like(value, dtype=bool)
                    mask[:, :, node.index, :] = False
                    value[mask] = original_cache[hook.name][mask]
                else:
                    value[:, :, node.index, :] = ablated_cache[hook.name][:, :, node.index, :]
                    mask = torch.ones_like(value, dtype=bool)
                    mask[:, :, node.index, :] = False
                    value[mask] = original_cache[hook.name][mask]
                return value  
            hooks.append((component, deactivate_components_hook))
        return hooks
    
    with model.hooks(fwd_hooks=get_deactivate_components_hooks()+get_activate_components_hooks()):
        patched_loss, patched_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
    return patched_loss, patched_cache

paths = [
    [Node(0, "attn", 7)],
    [Node(0, "attn", 7), Node(1, "mlp", None), Node(2, "mlp", None)]
]
paths[1][0].children = [paths[1][1]]
paths[1][1].children = [paths[1][2]]

for path in paths:
    patched_loss, patched_cache = activate_path(path, ablated_cache, original_cache)
    diff = (patched_loss - original_loss).mean(0)[-1]
    print(diff.item())

2.9198434352874756
2.4992523193359375


In [10]:
from collections import deque

original_loss, original_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
with model.hooks([(l0_heads, ablate_l2h7_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

def deactivate_components_hook(value, hook):
    value = ablated_cache[hook.name]
    return value  

def activate_components_hook(value, hook):
    value = original_cache[hook.name]
    return value

components = []
for component_prefix in ["mlp.hook_post", "attn.hook_result"]:
    for block in range(6):
        components.append(f"blocks.{block}.{component_prefix}")

class Node():
    def __init__(self, layer: int, component: Literal["mlp", "attn"], index: None | int | list[int], children: list["Node"] = []):
        self.layer = layer
        self.component = component
        self.index = index # neuron or head index
        self.name = f"blocks.{layer}.attn.hook_result" if component == "attn" else f"blocks.{layer}.mlp.hook_post"
        self.children = children

def get_node_list(root):
    if not root: return []
    queue = deque([root])
    result = []
    
    while queue:
        for _ in range(len(queue)):
            node = queue.popleft()
            result.append(node)
            queue += node.children
    return result

def get_activate_components_hooks(root, original_cache):
    def activate_components_hook(value, hook):
        value = original_cache[hook.name]
        return value
        
    node_list = get_node_list(root)
    node_names = [node.name for node in node_list]
    other_components = []
    for component in components:
        if component not in node_names and int(component.split(".")[1]) <= node_list[-1].layer:
            other_components.append(component)
    return [(other_component, activate_components_hook) for other_component in other_components]

def activate_tree(root: Node, ablated_cache, original_cache):
    '''First node in path patched in from ablated cache. Future nodes patched in from previous nodes.
    If in path, calculate from previous cache starting from ablated cache. Else read from original_cache until layer
    with final node in path. Then continue without hooks.
    
    Maximum efficiency would be to proceed in layer order, using a DFS that terminates when it reaches a layer but saves
    each node in the layer to a queue. Use a current layer iterator

    Each node should have a cache attached
    '''
    def get_deactivate_components_hooks(current_layer_nodes, patched_cache):
        hooks = []
        for node in current_layer_nodes:
            def deactivate_components_hook(value, hook):
                if node.index is None:
                    value = patched_cache[hook.name]
                elif "mlp" in hook.name:
                    value[:, :, node.index] = patched_cache[hook.name][:, :, node.index]

                    mask = torch.ones_like(value, dtype=bool)
                    mask[:, :, node.index, :] = False
                    value[mask] = original_cache[hook.name][mask]
                else:
                    value[:, :, node.index, :] = patched_cache[hook.name][:, :, node.index, :]
                    mask = torch.ones_like(value, dtype=bool)
                    mask[:, :, node.index, :] = False
                    value[mask] = original_cache[hook.name][mask]
                return value  
            hooks.append((node.name, deactivate_components_hook))
            return hooks
    
    def patch_nodes(root):
        if not root: return []
        queue = deque([root])
        current_layer = root.layer
        patched_cache = ablated_cache
        activate_components_hooks = get_activate_components_hooks(root, original_cache)
        current_layer_nodes = []
        
        while queue:
            for _ in range(len(queue)):
                if queue[0].layer == current_layer:
                    node = queue.popleft()
                    current_layer_nodes.append(node)
                    queue += node.children
            
            # todo clean run, patched in each parent node of children nodes 
            half_clean_cache = patched_cache
            with model.hooks(activate_components_hooks+get_deactivate_components_hooks(current_layer_nodes, patched_cache)):
                patched_loss, patched_cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)
            current_layer += 1
        return patched_loss, patched_cache
    
    return patch_nodes(root)

mlp2 = Node(2, "mlp", None)
mlp1 = Node(1, "mlp", None, [mlp2])
attn0 = Node(0, "attn", 7, [mlp1])

patched_loss, patched_cache = activate_tree(attn0, ablated_cache, original_cache)
diff = (patched_loss - original_loss).mean(0)[-1]
print(diff.item())

patched_loss, patched_cache = activate_tree(mlp2, ablated_cache, original_cache)
diff = (patched_loss - original_loss).mean(0)[-1]
print(diff.item())

AttributeError: 'str' object has no attribute 'component'