In [1]:
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser
home = expanduser("~")

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
import dataclasses
from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f48cb570e90>

In [4]:
from utils_sva import get_batched_dataset, clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff
from safetensors import safe_open
from safetensors.torch import save_file

### Load Model

In [5]:
n_devices = torch.cuda.device_count()

model = HookedTransformer.from_pretrained(
    "gemma-2b",
    center_unembed=True,
    center_writing_weights=False,
    fold_ln=False,
    fold_value_biases=False,
    #n_devices=2
)
model.set_use_attn_result(True)
# Get the default device used
device: torch.device = utils.get_device()

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer


In [6]:
from utils_sva import read_files
# Get suitable list of verbs and nouns in singular and plural
#examples_valid_verbs_tuples = [('es', 'son'), ('tiene', 'tienen'), ('fue', 'fueron'), ('era', 'eran')]
examples_valid_verbs_tuples = [('tuvo', 'tuvieron')]
examples_valid_verbs_tuples_pred = [('fue', 'fueron'),('era', 'eran')]
examples_valid_nouns = [('cantante', 'cantantes'), ('ingeniero', 'ingenieros'), ('ministro', 'ministros'), ('piloto', 'pilotos')]
verb_list_tuples = list(set(examples_valid_verbs_tuples + read_files(model, "datasets/plausible_spa_singular_plural_past_verbs.txt")))
noun_list_tuples = list(set(examples_valid_nouns + read_files(model, "datasets/spa_singular_plural_nouns.txt")))

### Load Dataset

In [7]:
from datasets import load_dataset
import random
random.seed(10)
np.random.seed(10)
import gc
gc.collect()
torch.cuda.empty_cache()

dataset_type = 'both' # singular / plural / both
language = 'english' # english / spanish / both
num_samples = 200
batch_size = 10
start_at = 0

answers = []
src_list = []
base_list = []
src_label_list = []
base_label_list = []
ex_number_list = []
ex_lang_list = []

if language=='english' or language=='both':
    len_sv_num = 6 # sentences should have 6 tokens

    hf_dataset = load_dataset("aryaman/causalgym", split='train')
    hf_dataset = hf_dataset.filter(lambda example: example['task']=='agr_sv_num_subj-relc')#agr_sv_num_pp

    if dataset_type=='singular':
        dataset = hf_dataset.filter(lambda example: example["base_type"]=='singular')
    elif dataset_type=='plural':
        dataset = hf_dataset.filter(lambda example: example["base_type"]=='plural')
    else:
        dataset = hf_dataset

    match_counter = start_at
    i=start_at
    while match_counter<num_samples:
        for type_sentence in ['src','base']:
            for word in dataset[i][type_sentence]:
                if len(word.split())>1: # eliminate compound words like ' taxi driver'
                    break
        
        src = ''.join(dataset[i]['src']).replace('<|endoftext|>','')
        base = ''.join(dataset[i]['base']).replace('<|endoftext|>','')
        if len(src.split())==len_sv_num and len(base.split())==len_sv_num:
            src_list.append(src)
            base_list.append(base)
            src_label = dataset[i]['src_label']
            base_label = dataset[i]['base_label']
            src_label_list.append(src_label)
            base_label_list.append(base_label)
            answers.append((base_label, src_label))
            ex_lang_list.append('English')
            print(base.split()[1])
            print(f'{base} {base_label}\n{src} {src_label}')
            if base.split()[1].endswith('s'):
                # Plural
                ex_number_list.append('Plural')
            else:
                # Singular
                ex_number_list.append('Singular')
            match_counter += 1
        i += 1
        
        

if language=='spanish' or language=='both':
    # noun_list_sing = ['cantante', 'ingeniero', 'ministro', 'piloto']
    # noun_list_plural = ['cantantes', 'ingenieros', 'ministros', 'pilotos']
    noun_list_sing = [f' {noun_tuple[0]}' for noun_tuple in noun_list_tuples]
    noun_list_plural = [f' {noun_tuple[1]}' for noun_tuple in noun_list_tuples]

    verb_1_list_sing = [f' {verb_tuple[0]}' for verb_tuple in verb_list_tuples]
    verb_1_list_plural = [f' {verb_tuple[1]}' for verb_tuple in verb_list_tuples]

    verb_2_list_sing = [f' {verb_tuple[0]}' for verb_tuple in examples_valid_verbs_tuples_pred]
    verb_2_list_plural = [f' {verb_tuple[1]}' for verb_tuple in examples_valid_verbs_tuples_pred]
    print(verb_2_list_plural)

    permutations_list = [[i,j,k] for i in range(len(noun_list_sing)) for j in range(len(verb_1_list_sing)) for k in range(len(noun_list_sing)) if k!=i]
    permutations_array = np.array(permutations_list)
    np.random.shuffle(permutations_array)
    counter = 0

    for i,j,k in permutations_array:
        counter += 1
        sent_1 = f'Los{noun_list_plural[k]} que{verb_1_list_plural[j]} al{noun_list_sing[i]}'
        sent_2 = f'El{noun_list_sing[k]} que{verb_1_list_sing[j]} al{noun_list_sing[i]}'

        if dataset_type=='singular':
            rdm_num = 0
        elif dataset_type=='plural':
            rdm_num = 1
        else:
            rdm_num = int(round(random.uniform(0, 1), 0))

        # Avoid verb repetition
        verbs_indices = list(range(0,len(verb_2_list_sing)))
        already_verb = j
        available_verb_indices = verbs_indices[:already_verb] + verbs_indices[already_verb+1:]
        ver_2_idx = np.random.choice(available_verb_indices)
        
        if rdm_num== 0:
            src = sent_1
            base = sent_2
            src_label = verb_2_list_plural[ver_2_idx]
            base_label = verb_2_list_sing[ver_2_idx]
        else:
            src = sent_2
            base = sent_1
            src_label = verb_2_list_sing[ver_2_idx]
            base_label = verb_2_list_plural[ver_2_idx]
        if base_label[-1] == 'n':
            # Plural
            ex_number_list.append('Plural')
        else:
            # Singular
            ex_number_list.append('Singular')

        src_list.append(src)
        base_list.append(base)
        src_label_list.append(src_label)
        base_label_list.append(base_label)
        answers.append((base_label, src_label))
        ex_lang_list.append('Spanish')

        print(f'{base} {base_label}\n{src} {src_label}\n')
        
    
        if counter >=num_samples:
            break

ministers
The ministers that disguised the executive  have
The athlete that disguised the executive  has
authors
The authors that injured the secretary  are
The guard that injured the secretary  has
athlete
The athlete that disguised the executive  has
The ministers that disguised the executive  have
guard
The guard that injured the secretary  has
The authors that injured the secretary  are
ministers
The ministers that embarrassed the manager  are
The executive that embarrassed the manager  has
executive
The executive that embarrassed the manager  has
The ministers that embarrassed the manager  are
executive
The executive that injured the doctors  is
The executives that injured the doctors  have
farmer
The farmer that ignored the teachers  is
The authors that ignored the teachers  are
executives
The executives that injured the doctors  have
The executive that injured the doctors  is
authors
The authors that ignored the teachers  are
The farmer that ignored the teachers  is
managers
The

In [8]:
batches_src_tokens, batches_base_tokens, batches_answer_token_indices = get_batched_dataset(model,
                                                                                            base_list,
                                                                                            src_list,
                                                                                            answers,
                                                                                            batch_size=batch_size)

### Patching Experiments

In [9]:
# Create folder to save plots
images_dir = f"{home}/circuits_languages/images"
if not os.path.exists(images_dir):
    os.makedirs(images_dir)

In [22]:
# Only one batch (first)
batch = 0
batches = len(batches_src_tokens)

src_logit_diff_list = []
base_logit_diff_list = []
for batch in range(batches):
    base_tokens = batches_base_tokens[batch]
    src_tokens = batches_src_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    base_logits, base_cache = model.run_with_cache(base_tokens)
    src_logits, corrupted_cache = model.run_with_cache(src_tokens)
    answer_token_indices = answer_token_indices.to(base_logits.device)
    base_logit_diff = get_logit_diff(base_logits, answer_token_indices, mean=False)
    base_logit_diff_list.append(base_logit_diff)
    print(f"Base logit diff batch mean: {base_logit_diff.mean().item():.4f}")

    src_logit_diff = get_logit_diff(src_logits, answer_token_indices, mean=False)
    src_logit_diff_list.append(src_logit_diff)
    print(f"Source logit diff batch mean: {src_logit_diff.mean().item():.4f}")

src_logit_diff = torch.cat(src_logit_diff_list,0).mean(0)
base_logit_diff = torch.cat(base_logit_diff_list,0).mean(0)

# Patching metric
CLEAN_BASELINE = base_logit_diff
CORRUPTED_BASELINE = src_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(model, logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)


Base logit diff batch mean: 3.2576
Source logit diff batch mean: -3.2576
Base logit diff batch mean: 2.8009
Source logit diff batch mean: -2.9403
Base logit diff batch mean: 3.4275
Source logit diff batch mean: -3.2882
Base logit diff batch mean: 3.1141
Source logit diff batch mean: -3.3294
Base logit diff batch mean: 3.2108
Source logit diff batch mean: -2.7190
Base logit diff batch mean: 3.2316
Source logit diff batch mean: -3.5238
Base logit diff batch mean: 2.7914
Source logit diff batch mean: -2.7757
Base logit diff batch mean: 3.3525
Source logit diff batch mean: -2.3120
Base logit diff batch mean: 2.7866
Source logit diff batch mean: -3.8271
Base logit diff batch mean: 3.4785
Source logit diff batch mean: -2.9268
Base logit diff batch mean: 2.9302
Source logit diff batch mean: -3.4819
Base logit diff batch mean: 3.0944
Source logit diff batch mean: -2.8756
Base logit diff batch mean: 3.4824
Source logit diff batch mean: -3.7012
Base logit diff batch mean: 3.5261
Source logit dif

In [23]:
fig = px.bar([base_logit_diff.item(), src_logit_diff.item()])
fig.update(layout_coloraxis_showscale=False,layout_showlegend=False)
#fig.update_layout(legend_title_text='Subject Number')
fig.update_layout(
    xaxis_title="", yaxis_title="Logit Difference"
)
fig.update_layout(
    xaxis = dict(
        tickmode = 'array',
        tickvals = [0, 1],
        ticktext = ['Clean input', 'Corrupted input']
    ),
    font=dict(
        size=15,  # Set the font size here
    )
)
fig = paper_plot(fig, tickangle=0)
fig.update_layout(yaxis_range=[-3.4,3.4])
fig.show()
pio.write_image(fig, f'{images_dir}/{language}_logit_diffs.png',scale=5, width=550, height=350)


In [15]:
y_labels = [f'{str(layer)}' for layer in range(model.cfg.n_layers-1,-1,-1)]
patching_plot_sentence = model.to_str_tokens(model.to_tokens('The executives that embarrassed the manager'))
patching_plot_sentence

['<bos>', 'The', ' executives', ' that', ' embarrassed', ' the', ' manager']

In [16]:
# Patching residual streams
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, src_tokens, base_cache, ioi_metric)
imshow(torch.flip(resid_pre_act_patch_results, dims=[0]), 
       yaxis="Layer", 
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],
       y=y_labels,
       title="resid_pre Activation Patching")

  0%|          | 0/126 [00:00<?, ?it/s]

In [None]:
# Patching attention heads all positions
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, src_tokens, base_cache, ioi_metric)
fig_heads = imshow(torch.flip(attn_head_out_all_pos_act_patch_results, dims=[0]), 
       yaxis="Layer", 
       xaxis="Head",
       x=[f'{head}' for head in range(model.cfg.n_heads)],
       y=y_labels,
       title="attn_head_out Activation Patching (All Pos)",
       return_fig=True)
fig_heads.show()
pio.write_image(fig_heads, f'{images_dir}/patching_{language}_heads_all_pos.png',scale=5, width=400, height=500)

In [None]:
# Patching attention heads last_position
attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, src_tokens, base_cache, ioi_metric)
fig_heads = imshow(torch.flip(attn_head_out_act_patch_results[:,-1], dims=[0]), 
       yaxis="Layer", 
       xaxis="Head",
       x=[f'{head}' for head in range(model.cfg.n_heads)],
       y=y_labels,
       title="Attn Head Output",
       return_fig=True)
fig_heads.show()
pio.write_image(fig_heads, f'{images_dir}/patching_{language}_heads_last_pos.png',scale=5, width=350, height=500)

In [None]:
every_block_result = patching.get_act_patch_block_every(model, src_tokens, base_cache, ioi_metric)
fig = imshow(torch.flip(every_block_result, dims=[1]), facet_col=0,
        y=y_labels,
        facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
        title="Activation Patching Per Block", xaxis="Position", yaxis="Layer",
        zmax=1, zmin=-1, x= patching_plot_sentence,
        return_fig=True
        )
fig.update_xaxes(tickangle=45)
fig.show()
pio.write_image(fig, f'{images_dir}/patching_{language}_res_streams_attn_mlp.png',scale=5, width=800, height=500)

### Direct Logit Attribution

In [25]:
from utils_sva import residual_stack_to_logit_diff

In [27]:
# DLA on accumulated residuals
import gc

base_logit_diff_list = []
src_logit_diff_list = []

batches_base_logits = []

list_logit_lens_logit_diffs= []
for batch in range(batches):
    # Get clean tokens and answer indices from batches
    base_tokens = batches_base_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    base_logits, base_cache = model.run_with_cache(base_tokens)
    answer_token_indices = answer_token_indices.to(base_logits.device)
    original_average_logit_diff = get_logit_diff(base_logits, answer_token_indices)
    answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

    # Difference of unembedding vectors
    logit_diff_directions = (
        answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
    )
    accumulated_residual, labels = base_cache.accumulated_resid(
        layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
    )
    logit_lens_logit_diffs = residual_stack_to_logit_diff(model, accumulated_residual, base_cache, logit_diff_directions, mean=False)
    list_logit_lens_logit_diffs.append(logit_lens_logit_diffs)
    

    torch.cuda.empty_cache()
    gc.collect()
cat_logit_lens_logit_diffs = torch.cat(list_logit_lens_logit_diffs, 0)

In [31]:
fig = line(
        cat_logit_lens_logit_diffs.mean(0),
        x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
        hover_name=labels,
        title="Logit Difference From Accumulate Residual Stream",
        return_fig=True,
        #labels={"x": "Layer", "y": "Logit Difference"},
    )
fig = paper_plot(fig)
fig.show()

In [33]:
import gc

base_logit_diff_list = []
src_logit_diff_list = []

batches_base_logits = []

list_per_layer_logit_diffs= []
for batch in range(batches):
    # Get clean tokens and answer indices from batches
    base_tokens = batches_base_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    base_logits, base_cache = model.run_with_cache(base_tokens)
    answer_token_indices = answer_token_indices.to(base_logits.device)
    original_average_logit_diff = get_logit_diff(base_logits, answer_token_indices)
    answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

    # Difference of unembedding vectors
    logit_diff_directions = (
        answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
    )

    per_layer_residual, labels = base_cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
    )
    per_layer_logit_diffs = residual_stack_to_logit_diff(model,per_layer_residual[:], base_cache,
                                                        logit_diff_directions, mean=False)


    list_per_layer_logit_diffs.append(per_layer_logit_diffs)
    

    torch.cuda.empty_cache()
    gc.collect()
cat_per_layer_logit_diffs = torch.cat(list_per_layer_logit_diffs, 0)

In [40]:
import plotly.graph_objs as go
def hex_to_rgba(h, alpha):
    '''
    converts color value in hex format to rgba format with alpha transparency
    '''
    return tuple([int(h.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)] + [alpha])

add_std = False
mean = cat_per_layer_logit_diffs.mean(0).cpu()
std = cat_per_layer_logit_diffs.std(0).cpu()
y_upper = list(mean + std)
y_lower = list(mean - std)            
hex_color = '#636efa'
component_labels_to_pretty_labels = [clean_blocks_labels(label) for label in labels]

fig = line(mean, x=component_labels_to_pretty_labels,
            title="Logit Difference From Each Layer",
            #color=color,
            xaxis="Block", yaxis="Logit Difference",
            return_fig=True)
if add_std:
    # Add standard deviations in graph
    fig.add_trace(
            go.Scatter(
                x = component_labels_to_pretty_labels+component_labels_to_pretty_labels[::-1],
                y = y_upper+y_lower[::-1],
                fill = 'toself',
                fillcolor = 'rgba' + str(hex_to_rgba(
                    h=hex_color,
                    alpha=0.25
                )),
                line = dict(
                    color = 'rgba(0,0,0,0)'
                ),
                hoverinfo = "skip",
                showlegend = False,
            )
        )
fig.update_layout(
    font=dict(
        size=15,  # Set the font size here
    )
)
fig = paper_plot(fig, tickangle=90)
fig.show()
pio.write_image(fig, f'{images_dir}/logit_diff_{language}_blocks.pdf',scale=6, width=800, height=350)

In [None]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs[:],
    title="Logit Difference From Each Head",
    xaxis="Head", 
    yaxis="Layer",
)

Tried to stack head results when they weren't cached. Computing head results now


In [41]:
layer_index = 13
hook_name = utils.get_act_name('post', layer_index)
neurons_contrib_logit_diff_list = []
#output_attention_head = cache[hook_name][:,-1,attn_head_index]
for batch in range(batches):
    # Get clean tokens and answer indices from batches
    base_tokens = batches_base_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]
    base_logits, base_cache = model.run_with_cache(base_tokens)
    #answer_token_indices = answer_token_indices.to(base_logits.device)
    answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

    # Difference of unembedding vectors
    logit_diff_directions = (
        answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
    )

    _, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name, "ln_final.hook_scale"])
    output_hook_pos = base_cache[hook_name][:,-1]
    #list_output_hook.append(output_hook_pos)
    act_w_out_rows = einsum(
        f"batch d_ffn, d_ffn d_model \
        -> batch d_ffn d_model",
        base_cache[hook_name][:,-1],
        model.W_out[layer_index])

    # Scaling by final LN scaling value (we need this to get exact logit diffs)
    scaled_act_w_out_rows = act_w_out_rows / base_cache["ln_final.hook_scale"][:,-1].unsqueeze(1)

    # Dot product with unembeding diff directions (first apply final LN weights)
    neurons_contrib_logit_diff = einsum(
            f"batch d_ffn d_model, batch d_model \
            -> batch d_ffn",
            scaled_act_w_out_rows * model.ln_final.w,
            logit_diff_directions)
    neurons_contrib_logit_diff_list.append(neurons_contrib_logit_diff)
    torch.cuda.empty_cache()

In [42]:
neurons_contrib_logit_diff = torch.cat(neurons_contrib_logit_diff_list, 0)
neurons_contrib_logit_diff.shape

In [45]:
fig = line(neurons_contrib_logit_diff.mean(0),
title=f"Logit Difference From Each Neuron in MLP {layer_index}",
xaxis="Neuron", yaxis="Logit Difference",return_fig=True)
fig.update_layout(
    font=dict(
        size=15,  # Set the font size here
    )
)
fig = paper_plot(fig)
fig.show()
pio.write_image(fig, f'{images_dir}/MLP{layer_index}_neurons.pdf',scale=6, width=800, height=350)

In [None]:
# layer_index = 13
# # Getting act * w_out rows -> [batch d_ffn d_model]
# act_w_out_rows = einsum(
#         f"batch d_ffn, d_ffn d_model \
#         -> batch d_ffn d_model",
#         cache[f'blocks.{layer_index}.mlp.hook_post'][:,-1],
#         model.W_out[layer_index])

# # Scaling by final LN scaling value (we need this to get exact logit diffs)
# scaled_act_w_out_rows = act_w_out_rows / cache["ln_final.hook_scale"][:,-1].unsqueeze(1)

# # Dot product with unembeding diff directions (first apply final LN weights)
# neurons_contrib_logit_diff = einsum(
#         f"batch d_ffn d_model, batch d_model \
#         -> batch d_ffn",
#         scaled_act_w_out_rows * model.ln_final.w,
#         logit_diff_directions)

# # Check added logit diffs of all neurons add up to original MLP contribution to the logit diff (mean across batch)
# print(neurons_contrib_logit_diff.sum(-1).mean())
# torch.cuda.empty_cache()

# line(neurons_contrib_logit_diff.mean(0),
# title=f"Logit Difference From Each Neuron in MLP {layer_index}",
# labels={"x": "Neuron", "y": "Logit Difference"},)

### Attention Maps

In [25]:
from circuitsvis.attention import attention_heads

def visualize_attention_patterns(
    type_pattern: str,
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
    html: Optional[bool] = True
):
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    #batch_index = 0
    print('heads', heads)

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        if type_pattern == 'attn_weights':
            updated_pattern = local_cache["attn", layer][:, head_index]
            print(updated_pattern.shape)
        else:
            # We compute attention heads weighted value vectors a_{i,j} x_j W_V
            pattern = local_cache[f'blocks.{layer}.attn.hook_pattern']
            v = local_cache[f'blocks.{layer}.attn.hook_v']
            weighted_values = einsum(
                        "batch key_pos head_index d_head, \
                        batch head_index query_pos key_pos -> \
                        batch query_pos key_pos head_index d_head",
                        v,
                        pattern,
                    )# [batch, query_pos, key_pos, head_index, d_head]

            if type_pattern == 'value_weighted':
                # Value-weighted norms
                raw_inter_token_attribution = torch.norm(weighted_values, dim=-1, p=2)
                # weighted_values_norm -> [batch query_pos key_pos head_index]

            elif type_pattern == 'output_value_weighted' or type_pattern == 'distance_based':
                # We decompose attention heads further by computing a_{i,j} x_j W_OV
                output_weighted_values = einsum(
                        "batch query_pos key_pos head_index d_head, \
                            head_index d_head d_model -> \
                            batch query_pos key_pos head_index d_model",
                        weighted_values,
                        model.W_O[layer],
                    )

                # Check sum decomposition is equivalent to cached values
                output_heads = output_weighted_values.sum(2)
                output_attention = output_heads.sum(-2) + model.b_O[layer]
                assert torch.dist(output_attention, local_cache[f'blocks.{layer}.hook_attn_out']).item() < 1e-3 * local_cache[f'blocks.{layer}.hook_attn_out'].numel()

                if type_pattern == 'output_value_weighted':
                    # Output-value-weighted norms
                    # weighted_values_norm -> [batch query_pos key_pos head_index]
                    raw_inter_token_attribution = torch.norm(output_weighted_values, dim=-1, p=1)

                elif type_pattern == 'distance_based':
                    # Distance-based
                    EPS = 1e-5
                    # distance -> [batch query_pos key_pos head_index]
                    distance = -F.pairwise_distance(output_weighted_values, output_heads.unsqueeze(2),p=2)
                    # head_output_norm -> [batch query_pos head_index]
                    head_output_norm = torch.norm(output_heads, p=2, dim=-1)
                    raw_inter_token_attribution = (distance + head_output_norm.unsqueeze(2)).clip(min=EPS)
                    

            # Normalize over key_pos
            inter_token_attribution = raw_inter_token_attribution / raw_inter_token_attribution.sum(dim=-2,keepdim=True)
            updated_pattern = inter_token_attribution[:, :, :, head_index]
            
        print('before',updated_pattern.shape)
        patterns.append(updated_pattern)
        print(len(patterns))

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "batch head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=1
    )
    #patterns = torch.stack(patterns, 1)
    print('after',patterns.shape)

    if html:
        # Convert the tokens to strings (for the axis labels)
        str_tokens = model.to_str_tokens(local_tokens)
        # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
        plot = attention_heads(
            attention=patterns, tokens=str_tokens, attention_head_names=labels
        ).show_code()
        # Display the title
        title_html = f"<h2>{title}</h2><br/>"
        # Return the visualisation as raw code
        return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"
    else:
        return patterns


In [19]:
head_name_to_number = {}

head_names_list = [f"L{layer}H{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
for i, head in enumerate(head_names_list):
    head_name_to_number[head] = i

heads = ['L13H7', 'L17H4']
heads_number = [head_name_to_number[head] for head in heads]

In [12]:
# for batch in range(batches):
# # Get clean tokens and answer indices from batches
batch = 0
base_tokens = batches_base_tokens[batch]
answer_token_indices = batches_answer_token_indices[batch]

base_logits, base_cache = model.run_with_cache(base_tokens)

In [13]:
base_logits.shape

torch.Size([10, 7, 256000])

In [26]:
top_k = 10
type_pattern = 'attn_weights' # attn_weights / value_weighted / output_value_weighted / distance_based

tokens = base_tokens
# top_positive_logit_attr_heads = torch.topk(
#     per_head_logit_diffs.flatten(), k=top_k
# ).indices

viz_attn = visualize_attention_patterns(
    type_pattern,
    heads_number,
    base_cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
    html=False
)

heads [111, 140]
torch.Size([10, 7, 7])
before torch.Size([10, 7, 7])
1
torch.Size([10, 7, 7])
before torch.Size([10, 7, 7])
2
after torch.Size([10, 2, 7, 7])


In [None]:
x=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],
y=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],

In [30]:
imshow(
    viz_attn[:,0].mean(0),
    title="",
    xaxis="", 
    yaxis="",
)

NameError: name 'viz_attn' is not defined

In [88]:
HTML(positive_html)

### Composition Attention head and MLP neurons

In [48]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def run_pca(X, n_components):
    # Standardize data before applying PCA
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    pca = PCA(n_components=n_components)
    X_embedded = pca.fit_transform(X)
    return X_embedded, pca, scaler

def compute_diff_means(dataset_rep, class_list, class_1, class_2):
    # Get representations for each class
    indices_1 = [i for i, x in enumerate(class_list) if x == class_1]
    indices_2 = [i for i, x in enumerate(class_list) if x == class_2]
    rep_class_1 = dataset_rep[indices_1,:]
    rep_class_2 = dataset_rep[indices_2,:]
    # Average across batch dimension
    mean_class_1 = rep_class_1.mean(0)
    mean_class_2 = rep_class_2.mean(0)

    return mean_class_1 - mean_class_2

In [116]:
# Attention head
attn_layer_index = 13
attn_head_index = 7
list_output_hook = []
#hook_name = f'blocks.{attn_layer_index}.attn.hook_result'
hook_name = utils.get_act_name('result', attn_layer_index)
#hook_name = utils.get_act_name('resid_pre', attn_layer_index)
#output_attention_head = cache[hook_name][:,-1,attn_head_index]
for batch in range(batches):
    # Get clean tokens and answer indices from batches
    base_tokens = batches_base_tokens[batch]
    base_logits, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name])
    output_hook_pos = base_cache[hook_name][:,-1]
    list_output_hook.append(output_hook_pos)
#key_attention_head = cache[f'blocks.{attn_layer_index}.attn.hook_q'][:,-1,attn_head_index]
cat_output_hook = torch.cat(list_output_hook, 0)


In [117]:
ex_number_list_int = [0 if number=='Singular' else 1 for number in ex_number_list]
class_1 = 'Singular'
class_2 = 'Plural'

In [118]:
cat_output_hook.shape

torch.Size([200, 8, 2048])

In [57]:
method = 'diff_means'

if method == 'diff_means':
    directions = compute_diff_means(cat_output_hook, ex_number_list, 'Singular', 'Plural')
elif method == 'pca':
    pca_dirs_heads_list = []
    if 'resid' in hook_name:
        pca_matrix, pca, scaler = run_pca(cat_output_hook.cpu().numpy(), n_components=2)
        directions = torch.tensor(pca.components_).contiguous()
    else:
        for head in range(model.cfg.n_heads):
            cat_output_hook_head = cat_output_hook[:,head]
            pca_matrix, pca, scaler = run_pca(cat_output_hook_head.cpu().numpy(), n_components=2)
            pca_dirs_head = torch.tensor(pca.components_).contiguous()
            print(pca.explained_variance_ratio_.sum())
            pca_dirs_heads_list.append(pca_dirs_head.unsqueeze(1))

        directions = torch.cat(pca_dirs_heads_list, 1)



In [58]:
# Save diff in means directions
tensors_dir = f"{home}/mats/sva_tensors"
if not os.path.exists(tensors_dir):
    os.makedirs(tensors_dir)
tensors_dict = {}
if method == 'pca':
    tensors_dict[f"directions"] = directions
elif method == 'diff_means':
    tensors_dict[f"direction"] = directions
save_file(tensors_dict, f"{tensors_dir}/{model.cfg.model_name}_{dataset_type}_{language}_{hook_name}_{method}_singular_plural.safetensors")

In [121]:
cat_output_hook_head = cat_output_hook[:,attn_head_index]
pca_matrix, pca, scaler = run_pca(cat_output_hook_head.cpu().numpy(), n_components=2)

In [58]:
# MLP neuron
MLP_layer = 13
MLP_neuron = 2069

w_in_neuron_weights = model.W_in[MLP_layer][:,MLP_neuron]
w_in_neuron_weights = w_in_neuron_weights.cpu().detach() - scaler.mean_ / np.sqrt(scaler.var_)

w_in_neuron_pca_proj = pca.transform(w_in_neuron_weights.unsqueeze(0).cpu().numpy())


In [63]:
# pca_matrix = np.vstack([pca_matrix, w_in_neuron_pca_proj])
# number_proj = number + [2]

In [122]:
from PIL import ImageColor

html_colors = {
    'darkgreen' : '#138808',
    'green_drawio' : '#82B366',
    'dark_green_drawio' : '#557543',
    'dark_red_drawio' : '#990000',
    'blue_drawio' : '#6C8EBF',
    'orange_drawio' : '#D79B00',
    'red_drawio' : '#FF9999',
    'grey_drawio' : '#303030'}

def to_group(number_val, lang_val, alpha):
    if number_val=='Singular' and lang_val=='Spanish':
        return ('Spanish Singular', 'rgba' + str(tuple(list(ImageColor.getcolor(html_colors['dark_green_drawio'], "RGB")) + [alpha])))
    
    elif number_val=='Plural' and lang_val=='Spanish':
        return ('Spanish Plural', 'rgba' + str(tuple(list(ImageColor.getcolor(html_colors['green_drawio'], "RGB")) + [alpha])))
    
    elif number_val=='Singular' and lang_val=='English':
        return ('English Singular', 'rgba' + str(tuple(list(ImageColor.getcolor(html_colors['dark_red_drawio'], "RGB")) + [alpha]))
    )
    elif number_val=='Plural' and lang_val=='English':
        return ('English Plural', 'rgba' + str(tuple(list(ImageColor.getcolor(html_colors['red_drawio'], "RGB")) + [alpha])))
    else:
        print('ERROR!')


In [123]:
alpha = 1
rgba_list = []
labels = []
for number_val, lang_val in zip(ex_number_list, ex_lang_list):
    rgba_list.append(to_group(number_val, lang_val, alpha)[1])
    labels.append(to_group(number_val, lang_val, alpha)[0])

In [120]:
# df = pd.DataFrame(np.concatenate([pca_matrix, np.expand_dims(np.array(rgba_list),axis=-1), np.expand_dims(np.array(labels),axis=-1)], axis=-1), columns=['PC1', 'PC2', 'color', 'label'])
# df

In [124]:
fig = scatter(pca_matrix[:,0], pca_matrix[:,1],
                color=ex_number_list, return_fig=True,
                #symbol=labels,
                xaxis="PC1", yaxis="PC2",)
fig.update(layout_coloraxis_showscale=False,layout_showlegend=True)
fig.update_layout(legend_title_text='Subject Number')
fig.update_traces(marker=dict(size=10))
fig = paper_plot(fig)
fig.show()

In [38]:
pio.write_image(fig, f'images/{language}_pca_L{attn_layer_index}H{attn_head_index}.pdf',scale=10, width=500, height=350)

In [60]:
tensors_dict = {}
with safe_open(f"{tensors_dir}/{model.cfg.model_name}_{dataset_type}_{language}_pca_L{attn_layer_index}H{attn_head_index}.safetensors", framework="pt", device=0) as f:
        for k in f.keys():
            tensors_dict[k] = f.get_tensor(k)
tensors_dict['PC'][0].shape

torch.Size([2048])

In [None]:
from safetensors import safe_open
tensors_dir = f"{home}/sva_tensors"
if not os.path.exists(tensors_dir):
    os.makedirs(tensors_dir)

hook_names = [utils.get_act_name('resid_pre',layer) for layer in range(model.cfg.n_layers)]
tensors_dict = {}
for hook_name in hook_names:
    cached_activations = cache[hook_name].contiguous()
    
    tensors_dict[f"{hook_name}"] = cached_activations
    

save_file(tensors_dict, f"{tensors_dir}/{model.cfg.model_name}_{dataset_type}_{language}.safetensors")

In [13]:


tensors_language_dict = {}
for language in ['english', 'spanish']:
    tensors_dict = {}
    with safe_open(f"{tensors_dir}/{model.cfg.model_name}_{dataset_type}_{language}.safetensors", framework="pt", device=0) as f:
        for k in f.keys():
            tensors_dict[k] = f.get_tensor(k)
    tensors_language_dict[language] = tensors_dict

In [19]:
layer = 9

tensors_english = tensors_language_dict['english'][utils.get_act_name('resid_pre',layer)]
tensors_spanish = tensors_language_dict['spanish'][utils.get_act_name('resid_pre',layer)]

In [89]:
from plotly_utils import scatter


In [None]:
# MLP neuron
MLP_layer = 13
MLP_neuron = 2069

w_in_neuron_weights = model.W_in[MLP_layer][:,MLP_neuron]
w_in_neuron_pca_proj = pca.transform(w_in_neuron_weights.unsqueeze(0).cpu().numpy())


In [None]:
output = F.cosine_similarity(w_in_neuron_weights.unsqueeze(0), output_attention_head)
output

tensor([ 0.2965, -0.1273, -0.1302,  0.2883, -0.1387, -0.1468,  0.2199, -0.1080,
         0.2397, -0.1037, -0.1163,  0.2514, -0.1226,  0.2502, -0.1355, -0.0943,
        -0.1305,  0.3085, -0.0981, -0.1355, -0.1496, -0.0867, -0.1072,  0.2390,
        -0.0756, -0.1068,  0.2407,  0.2361, -0.1332,  0.2980, -0.0903, -0.1207,
         0.3070,  0.2489, -0.1237, -0.1544, -0.0814, -0.1183, -0.1197,  0.2313,
         0.2215, -0.1219,  0.2387, -0.1190, -0.1393, -0.0879, -0.1244,  0.2290,
         0.2167,  0.2612], device='cuda:0')

In [None]:
pca_matrix = np.vstack([pca_matrix, w_in_neuron_pca_proj])
number.append(2)

In [25]:
sing_idx = np.where(np.array(number)==0)[0]
sing_idx

array([  2,   3,   5,   6,   7,  12,  13,  14,  16,  20,  21,  23,  26,
        27,  28,  32,  33,  34,  37,  40,  41,  42,  46,  47,  48,  49,
        52,  53,  56,  60,  61,  62,  63,  66,  69,  71,  72,  74,  75,
        78,  79,  82,  85,  86,  87,  90,  91,  94,  97,  98,  99, 102,
       105, 106, 107, 111, 112, 114, 117, 118, 119, 122, 123, 126, 127,
       131, 132, 136, 137, 139, 140, 142, 145, 146, 149, 150, 153, 154,
       155, 159, 160, 162, 163, 166, 168, 169, 172, 175, 176, 177, 180,
       183, 184, 185, 188, 189, 192, 193, 197, 198, 200, 201, 204, 207,
       209, 212, 213, 215, 216, 220, 221, 222, 223, 228, 229, 230, 231,
       236, 237, 240, 241, 242, 243, 246, 247, 250, 253, 255, 256, 259,
       260, 262, 263, 267, 268, 270, 271, 274, 277, 280, 281, 283, 285,
       288, 289, 290, 291, 294, 296, 298, 299, 301, 302, 304, 305, 307,
       309, 310, 312, 314, 315, 318, 319, 323, 324, 326, 328, 329, 332,
       333, 335, 339, 341, 342, 343, 344, 348, 350, 353, 355, 35

In [26]:
np.where(np.take(pca_matrix[:,1],sing_idx)>200.0)

(array([162]),)

In [31]:
utils.test_prompt('The secretary that liked the secretaries', ' was', model)

Tokenized prompt: ['<bos>', 'The', ' secretary', ' that', ' liked', ' the', ' secretaries']
Tokenized answer: [' was']


Top 0th token. Logit: 19.28 Prob: 13.78% Token: |.|
Top 1th token. Logit: 18.91 Prob:  9.55% Token: |,|
Top 2th token. Logit: 18.73 Prob:  7.98% Token: | that|
Top 3th token. Logit: 18.21 Prob:  4.74% Token: | and|
Top 4th token. Logit: 18.03 Prob:  3.95% Token: | was|
Top 5th token. Logit: 18.01 Prob:  3.86% Token: |

|
Top 6th token. Logit: 17.66 Prob:  2.74% Token: | is|
Top 7th token. Logit: 17.64 Prob:  2.67% Token: |
|
Top 8th token. Logit: 17.51 Prob:  2.35% Token: | of|
Top 9th token. Logit: 17.26 Prob:  1.83% Token: | job|


In [None]:
# Attention head
attn_layer_index = 13
attn_head_index = 7
# MLP neuron
MLP_layer = 13
MLP_neuron = 2069

# Output attention head
output_attention_head = cache[f'blocks.{attn_layer_index}.attn.hook_result'][:,-1,attn_head_index]

# Dot product with neuron's W_in column (input weights)
dot_prod = einsum(
        f"batch d_model, d_model \
        -> batch ",
        output_attention_head,
        model.W_in[MLP_layer][:,MLP_neuron])

for batch_element in range(len(base_list)):
    print(f'{base_list[batch_element]}, dot product: {dot_prod[batch_element]:.2}')

The ministers that disguised the executive, dot product: 2.5
The authors that injured the secretary, dot product: 2.2
The athlete that disguised the executive, dot product: -1.2
The guard that injured the secretary, dot product: -0.98
The ministers that embarrassed the manager, dot product: 2.9
The executive that embarrassed the manager, dot product: -0.95
The executive that injured the doctors, dot product: -0.5
The farmer that ignored the teachers, dot product: -1.2
The executives that injured the doctors, dot product: 3.5
The authors that ignored the teachers, dot product: 3.5
The managers that admired the author, dot product: 2.5
The actors that ignored the author, dot product: 3.0
The executive that admired the author, dot product: -0.93
The secretary that ignored the author, dot product: -0.99
The consultant that ignored the actor, dot product: -0.9
The authors that ignored the actor, dot product: 3.6
The manager that ignored the athlete, dot product: -1.1
The athletes that ignor

In [None]:
layer_index = 17
head_index = 4

# Compute final LN normalized attention head output, attention block
# MLP output, or specific neuron

normalized_attn_hook_result = cache.apply_ln_to_stack(
    cache[f'blocks.{layer_index}.attn.hook_result'][:,-1,head_index], layer=-1, pos_slice=-1
)

normalized_attn_block = cache.apply_ln_to_stack(
    cache['blocks.0.hook_attn_out'][:,-1,:], layer=-1, pos_slice=-1
)

# [50, 7, 2048]
normalized_mlp_out = cache.apply_ln_to_stack(
    cache[f'blocks.{layer_index}.hook_mlp_out'][:,-1], layer=-1, pos_slice=-1
)

neuron_idx = 7540
neuron = model.W_out[layer_index][neuron_idx].unsqueeze(0)

# Project component output into the vocabulary space
# Select model component output!
tokens_component = einsum(
        f"batch d_model, d_model vocab_size \
        -> batch vocab_size",
        normalized_attn_hook_result * model.ln_final.w,
        model.unembed.W_U)
torch.cuda.empty_cache()

In [None]:
from collections import defaultdict

# Get top/bottom tokens
top_k = 500
largest = True # True: promoted tokens / False: suppressed tokens
ranking_dict = defaultdict(list)
top_k_tokens_component = torch.topk(tokens_component, top_k, dim=-1, largest=largest).indices#.cpu().tolist()
for batch_element in range(top_k_tokens_component.shape[0]):
        top_k_str_tokens_component = model.tokenizer.convert_ids_to_tokens(top_k_tokens_component[batch_element])
        if largest == False:
                print(f'Suppresed tokens: {top_k_str_tokens_component}')
        else:
                print(f'Top promoted tokens: {top_k_str_tokens_component}')
        rank_corr_pred = (torch.sort(tokens_component, descending=largest).indices[batch_element] == answer_token_indices[batch_element][0].item()).nonzero(as_tuple=True)[0].item()
        rank_wrong_pred = (torch.sort(tokens_component, descending=largest).indices[batch_element] == answer_token_indices[batch_element][1].item()).nonzero(as_tuple=True)[0].item()
        
        if largest == False:
                # We show indices starting from the last
                rank_corr_pred*=-1
                rank_wrong_pred*=-1
        ranking_dict['corr_pred'].append(rank_corr_pred)
        ranking_dict['wrong_pred'].append(rank_wrong_pred)
        print(f'Ranking correct prediction: {rank_corr_pred}')
        print(f'Ranking wrong prediction: {rank_wrong_pred}\n')


Top promoted tokens: ['<bos>', '▁were', '▁are', '▁aren', 'are', 'were', '▁Walkover', '▁weren', '▁Paglinawan', '▁came', '▁come', '<strong>', '▁Are', '▁sind', '▁грудня', '▁SEDS', '\n', '▁LCCN', '▁Topf', 'LEncoder', 'Από', '▁will', '▁CURLOPT', '▁ARE', 'Asimismo', '▁▁▁▁▁', 'PhysRevLett', '▁as', 'mathbf', '▁▁', '▁sono', '▁took', '\n\n', '▁січня', 'mathrm', 'quad', '▁▁▁▁', 'Για', '▁Baillargeon', 'EndProject', '.', '\xad', '▁▁▁', '▁هستند', '▁BoxFit', '▁лютого', '▁NKC', '▁اطلع', 'Ventajas', '▁Seeder', '▁CascadeType', '▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁кӀ', '▁않', '▁декабря', '▁▁▁▁▁▁▁▁▁', '▁були', '▁include', '▁▁▁▁▁▁', 'PhysRev', '▁Assignee', '▁have', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', ',', 'end', '▁', 'come', 'CompleteListener', 'siveness', '▁and', '▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', 'клопе', '੍ਹ', '▁(', '▁березня', 'databind', '▁января', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', 'Slf', '▁▁▁▁▁▁▁▁', '▁Weise', 'DateTimeField', 'made', 'strophy', 'will', '▁|', 'ClassNotFound', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', 'Οι', '▁Samstag'

In [None]:
df = pd.DataFrame.from_dict(ranking_dict)
import plotly.express as px
fig = px.box(df, title=f'Ranking correct and wrong verb forms L{layer_index}H{head_index} Spanish')
fig.show()