In [2]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache
from typing import List
from tqdm.auto import tqdm
import plotly.graph_objects as go

# Plotly needs a different renderer for VSCode/Notebooks vs Colab; this
# notebook selects the "notebook_connected" renderer.
import plotly.io as pio
pio.renderers.default = "notebook_connected"
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference only: switch off gradient tracking globally.
# NOTE(review): the two calls below look redundant with each other — confirm
# and keep only one.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils
from haystack_utils import load_txt_data, get_mlp_activations, line, imshow

%reload_ext autoreload
%autoreload 2

In [3]:
# Load the "pythia-70m-v0" checkpoint as a HookedTransformer on `device`, with fold_ln=True.
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [4]:
# Load the French and English text samples that the later cells run the model on.
kde_french = load_txt_data("kde4_french.txt")
kde_english = load_txt_data("kde4_english.txt")

kde4_french.txt: Loaded 1007 examples with 505 to 5345 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


In [5]:
# Mean MLP activations over 100 prompts per language. The second argument (3) is
# presumably the layer index — confirm against haystack_utils.get_mlp_activations.
# `english_activations` is later indexed by neuron and written into
# `blocks.3.mlp.hook_post` by ablate_neuron_hook.
french_activations = get_mlp_activations(kde_french, 3, model, num_prompts=100, mean=True)
english_activations = get_mlp_activations(kde_english, 3, model, num_prompts=100, mean=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

## Attention Heads - L4 and L5

Plot the difference in attention patterns on French data between the unmodified model and the model with L3N609 ablated (set to its mean activation on English data).

In [None]:
def get_head_activations(
    prompts: List[str],
    layer: int,
    model: HookedTransformer,
    num_prompts: int = -1,
    context_crop_start=10,
    context_crop_end=400,
    mean=True
):
    """
    Runs the model over a list of prompts and collects the attention patterns of one layer.

    Each prompt is run on its own, so the cropped attention pattern has a different
    size per prompt. Every pattern is zero-padded (bottom/right) to
    `context_crop_end - context_crop_start` so that the patterns can be stacked.

    For the mean we keep a running count of how many prompts reached each position.
    Entry [row, col] of a pattern only exists for prompts whose cropped length
    exceeds max(row, col), so the summed patterns are divided elementwise by
    counts[max(row, col)]. Entries that no prompt reached stay at 0 instead of
    becoming NaN (0 / 0).

    A different implementation is done for MLPs in Haystack_cleaned using a mask.

    Args:
        prompts: Prompt strings to run through the model.
        layer: Index of the block whose `attn.hook_pattern` is read.
        model: Model used to tokenize and run the prompts.
        num_prompts: Number of leading prompts to use; -1 uses all of them. A value
            larger than `len(prompts)` also uses all of them.
        context_crop_start: First query/key position kept from each pattern.
        context_crop_end: One past the last query/key position kept.
        mean: If True return the per-position mean over prompts, otherwise the
            stacked, zero-padded patterns.

    Returns:
        mean=True:  tensor of shape [head, max_ctx, max_ctx].
        mean=False: tensor of shape [num_prompts, head, max_ctx, max_ctx].

    Raises:
        ValueError: if there are no prompts to run.
    """
    max_ctx = context_crop_end - context_crop_start
    pattern_label = f'blocks.{layer}.attn.hook_pattern'
    selected_prompts = prompts if num_prompts == -1 else prompts[:num_prompts]

    # Counted on the CPU so the function does not assume a GPU; moved next to the
    # patterns only when the scaling matrix is built.
    position_counts = torch.zeros(max_ctx)
    padded_patterns = []
    for prompt in tqdm(selected_prompts):
        tokens = model.to_tokens(prompt)
        # Only the pattern hook is needed, so don't cache every other activation.
        _, cache = model.run_with_cache(tokens, names_filter=pattern_label)
        # cache[pattern_label].shape == [batch head query_pos key_pos]
        pattern = cache[pattern_label][:, :, context_crop_start:context_crop_end, context_crop_start:context_crop_end]
        position_counts[:pattern.shape[2]] += 1
        # Zero-pad the key axis (last) and the query axis (second to last) up to max_ctx.
        missing = max_ctx - pattern.shape[-1]
        padded_patterns.append(torch.nn.functional.pad(pattern, (0, missing, 0, missing)))

    if not padded_patterns:
        raise ValueError("get_head_activations needs at least one prompt")

    patterns = torch.concat(padded_patterns, dim=0)
    if not mean:
        return patterns

    # scaling_matrix[row, col] = number of prompts that reached position max(row, col).
    positions = torch.arange(max_ctx)
    furthest_position = torch.maximum(positions[:, None], positions[None, :])
    scaling_matrix = position_counts[furthest_position]
    # Unreached entries have sum 0 and count 0; clamp so they yield 0 rather than NaN.
    scaling_matrix = scaling_matrix.clamp(min=1).to(patterns.device)

    # Shape sanity checks (kept from the original implementation).
    print(scaling_matrix.shape)
    print(patterns.shape)
    return patterns.sum(dim=0) / scaling_matrix  # [head max_ctx max_ctx]

In [6]:
neurons = [609]


def ablate_neuron_hook(value, hook):
    """Forward hook that overwrites the neurons listed in `neurons`.

    The entries of the last axis of `value` at the indices in `neurons` are
    replaced in place with `english_activations[neurons]` (the activations
    computed from the English data), and `value` is returned.
    """
    replacement = english_activations[neurons]
    value[:, :, neurons] = replacement
    return value


def get_attention_pattern_activations(original_cache: ActivationCache, ablated_cache: ActivationCache, layer: int, n_pos=-1):
    """
    Plot, for every head of `layer`, the batch-averaged attention pattern of the
    original run minus that of the ablated run.

    Args:
        original_cache: Cache from the unmodified forward pass.
        ablated_cache: Cache from the forward pass with the ablation hook active.
        layer: Block index whose `attn.hook_pattern` entry is compared.
        n_pos: Number of leading query/key positions to show; -1 shows all of them.

    Returns nothing; one `imshow` figure is produced per head.
    """
    hook_name = f'blocks.{layer}.attn.hook_pattern'

    # Average over the batch axis first, then take original minus ablated.
    mean_original = original_cache[hook_name].mean(0)
    mean_ablated = ablated_cache[hook_name].mean(0)
    difference = mean_original - mean_ablated

    limit = difference.shape[1] if n_pos == -1 else n_pos

    for i in range(difference.shape[0]):
        cropped = difference[i][:limit, :limit].cpu()
        imshow(cropped, title=f"L{layer}H{i} Difference between attention patterns with and without French neuron set to mean \"non-French data\" value <br> French dataset <br> Blue means the activation is more prevalent when French neuron enabled")

## Plot attention patterns

In [7]:
# Run the first five French prompts through the unmodified model, caching activations.
tokens = model.to_tokens(kde_french[:5])
original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
# Same prompts again with ablate_neuron_hook attached to the `blocks.3.mlp.hook_post` hook point.
with model.hooks(fwd_hooks=[(f'blocks.3.mlp.hook_post', ablate_neuron_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

# Plot the per-head pattern differences for layer 4, cropped to the first 200 positions.
get_attention_pattern_activations(original_cache, ablated_cache, layer=4, n_pos=200)

In [8]:
# Same comparison for layer 5, reusing the caches computed for the layer-4 plots.
get_attention_pattern_activations(original_cache, ablated_cache, layer=5, n_pos=200)

## Investigate

These heads all output into the unembed directly. We should see what their output values align with in the unembed.

Theory: the attention heads that attend to the BOS token when French neuron is enabled are language specific heads.
How to disprove: if we ablate these heads on French text and the loss increases significantly, the heads are relevant to French text.

BOS when French neuron enabled: L4H1, L4H2, L4H3
BOS when French neuron disabled: L4H0

L5H1 is some kind of previous tokens head that attends to different recent tokens in French vs. English. POS dependent?
L5H2 is one of a couple of heads with vertical French activation stripes.
L5H7 shows a clear pattern of more self-attention when L3N609 is enabled and more previous-token attention when it's disabled.

In [9]:
# Mean layer-4 attention patterns over the first 100 French prompts, keeping positions 0-399.
french_patterns = get_head_activations(kde_french, 4, model, num_prompts=100, context_crop_start=0, context_crop_end=400, mean=True)

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([400, 400])
torch.Size([100, 8, 400, 400])


In [10]:
# Sanity check: shape of the mean patterns and their largest value.
print(french_patterns.shape)
print(french_patterns.max())

torch.Size([8, 400, 400])
tensor(1., device='cuda:0')


In [11]:
# Ablate the French BOS heads and measure how much the loss on French text rises.
#
# Results recorded so far:
#   - naive zero ablation: 2% increase in loss
#   - mean ablation (mean calculated over French data): 1.43% increase in loss
from collections import defaultdict

model = model.cuda()

# layer -> head indices, taken from the notes above: heads that attend to BOS
# when the French neuron is enabled / disabled.
bos_french = defaultdict(list)
bos_french[4] = [1, 2, 3]

bos_non_french = defaultdict(list)
bos_non_french[4] = [0]

# Hook point holding the layer-4 attention patterns.
act_label_l4 = 'blocks.4.attn.hook_pattern'
def disable_head_hook(value, hook):
  """Forward hook: mean-ablate the layer-4 French BOS heads.

  Overwrites, in place, the attention patterns of the heads in `bos_french[4]`
  with their mean patterns on French text (`french_patterns`), cropped to the
  current sequence length, and returns `value`.
  """
  heads = bos_french[4]
  n_query = value.shape[2]
  # value is [batch, head, query_pos, key_pos]; the leading None adds the batch
  # axis so the mean patterns broadcast over every prompt in the batch.
  mean_patterns = french_patterns[heads, :n_query, :n_query][None]
  value[:, heads, :, :] = mean_patterns
  return value


def compare_loss_with_ablated(data, ablate_hook, hook_name=None, max_tokens=400):
  """Print the mean loss over `data` with and without `ablate_hook` applied.

  Args:
    data: Sequence of prompt strings; each is tokenized and run on its own.
    ablate_hook: Forward hook `(value, hook) -> value` applied at `hook_name`.
    hook_name: Hook point the hook is attached to. Defaults to `act_label_l4`.
    max_tokens: Each prompt is truncated to this many tokens. The default of
      400 is the value this function previously hard-coded.

  Raises:
    ValueError: if `data` is empty (the averages would be undefined).

  Prints the mean original loss, the mean ablated loss and the relative
  increase in percent. Returns nothing.
  """
  if len(data) == 0:
    raise ValueError("compare_loss_with_ablated needs at least one sample")
  if hook_name is None:
    hook_name = act_label_l4

  average_original_loss = 0
  average_ablated_loss = 0

  for sample in data:
    # No explicit .cuda(): other cells in this notebook pass the output of
    # to_tokens straight to the model, so the tokens are used as returned.
    tokens = model.to_tokens(sample)[:, :max_tokens]
    average_original_loss += model(tokens, return_type="loss")
    average_ablated_loss += model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(hook_name, ablate_hook)])

  average_original_loss /= len(data)
  average_ablated_loss /= len(data)

  relative_increase = (average_ablated_loss - average_original_loss) / average_original_loss

  print(f"Full model loss: {average_original_loss:.6f}")
  # The hook may ablate any component (here: attention heads), so the label is generic.
  print(f"Ablated model loss: {average_ablated_loss:.6f}")
  print(f"% increase: {relative_increase.item() * 100:.6f}")

# Loss increase on French text from mean-ablating the layer-4 French BOS heads (disable_head_hook)
compare_loss_with_ablated(kde_french, disable_head_hook)

Moving model to device:  cuda


Full model loss: 3.591627
Ablated MLP layer loss: 3.642997
% increase: 1.430275


In [12]:
# Repeat the process with a synthetic attention pattern - 1 on BOS for everything
# 1.6% increase in loss

# Synthetic pattern of shape [1, number of BOS heads, 400, 400]: every query
# position puts all of its attention on key position 0.
bos_pattern_l4 = torch.zeros(1, len(bos_french[4]), 400, 400).cuda()
bos_pattern_l4[..., 0] = 1


def synthetic_bos_pattern_hook(value, hook):
  """Forward hook: replace the layer-4 French BOS heads with the synthetic pattern.

  Overwrites, in place, the patterns of the heads in `bos_french[4]` with
  `bos_pattern_l4` cropped to the current sequence length, and returns `value`.
  """
  heads = bos_french[4]
  n_query = value.shape[2]
  # value is [batch, head, query_pos, key_pos]
  synthetic = bos_pattern_l4[:, :, :n_query, :n_query]
  value[:, heads, :, :] = synthetic
  return value

# Loss increase on French text from replacing the BOS heads' patterns with the synthetic BOS pattern
compare_loss_with_ablated(kde_french, synthetic_bos_pattern_hook)

Full model loss: 3.591627
Ablated MLP layer loss: 3.650235
% increase: 1.631784


Let's try replacing each BOS attention head's pattern with the synthetic BOS pattern in turn, to see whether the synthetic pattern is a good stand-in for any individual head.

Results: loss increases of 0.3%, 0.27%, 0.85% respectively.


In [13]:
# Synthetic single-head pattern: all attention on key position 0 for every query position.
bos_pattern = torch.zeros((1, 400, 400)).cuda()
bos_pattern[:, :, 0] = 1

for i in bos_french[4]:
    # The hook is redefined on every iteration and reads the loop variable `i`
    # when it runs; it is only used inside this iteration, so it always sees
    # the current head index.
    def single_bos_pattern_hook(value, hook):
        pos = value.shape[2]
        # value is [batch, head, query_pos, key_pos]; only head `i` is replaced
        value[:, i, :, :] = bos_pattern[:, :pos, :pos]
        return value

    # Loss increase on French text from replacing this single head with the synthetic BOS pattern
    compare_loss_with_ablated(kde_french, single_bos_pattern_hook)

Full model loss: 3.591627
Ablated MLP layer loss: 3.602838
% increase: 0.312127
Full model loss: 3.591627
Ablated MLP layer loss: 3.601381
% increase: 0.271561
Full model loss: 3.591627
Ablated MLP layer loss: 3.622281
% increase: 0.853484
