In [1]:
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser
# Home directory of the current user; the output paths in this notebook are built from it.
home = os.path.expanduser("~")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
import dataclasses
from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
from utils_sva import residual_stack_to_logit_diff
# Disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Create folder to save plots.
# `exist_ok=True` replaces the separate existence check: it is safe to re-run,
# has no check-then-create race, and still raises if the path exists but is
# not a directory.
images_dir = f"{home}/circuits_languages/images"
os.makedirs(images_dir, exist_ok=True)

In [7]:
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff, compute_act_patching
from load_dataset import load_sva_dataset, get_batched_dataset

### Load Model

In [4]:
# Number of CUDA devices visible to this process (see the optional
# `n_devices` argument left commented out below).
n_devices = torch.cuda.device_count()

# Load Gemma-2B as a HookedTransformer. Weight-processing flags are passed
# explicitly: only unembedding centering is enabled; LayerNorm folding,
# value-bias folding and writing-weight centering are all turned off.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=True,
    # n_devices=2
)

# Per-head attention results are switched off here; a later cell turns them
# back on before caching `hook_result`.
model.set_use_attn_result(False)

# Get the default device used
device: torch.device = utils.get_device()

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer


### Composition between an attention head and MLP neurons

In [None]:
# The cells below cache the per-head `result` hook, so enable it on the model.
model.set_use_attn_result(True)

# --- Dataset configuration ---
dataset_type = 'both'  # singular / plural / both
language = 'spanish'  # english / spanish / both
num_samples = 300
batch_size = 10
start_at = 0

# Load the subject-verb agreement dataset and split it into batches.
dataset = load_sva_dataset(model, language, dataset_type, num_samples)
batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)

# Pull out the per-batch fields used by the following cells.
batches_base_tokens, batches_src_tokens, batches_answer_token_indices = (
    batched_dataset[field]
    for field in (
        'batches_base_tokens',
        'batches_src_tokens',
        'batches_answer_token_indices',
    )
)
batches = len(batches_src_tokens)

In [7]:
# Attention head
# For every batch of base prompts, run the model and keep the `result` hook of
# one attention layer at the last token position; the per-batch slices are
# concatenated along dim 0 at the end.
images_dir = f"{home}/circuits_languages/images"
batches = len(batches_base_tokens)
attn_layer_index = 13
attn_head_index = 7
hook_name = utils.get_act_name('result', attn_layer_index)

list_output_hook = []
for batch, base_tokens in enumerate(batches_base_tokens):
    # `names_filter` restricts the returned cache to the single hook we read.
    base_logits, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name])
    # Keep only the final position of each prompt.
    output_hook_pos = base_cache[hook_name][:, -1]
    list_output_hook.append(output_hook_pos)

cat_output_hook = torch.cat(list_output_hook, dim=0)


In [8]:
# Pick the entry for the chosen head along dim 1 of the stacked hook outputs
# (presumably the head axis of `hook_result` — TODO confirm).
output_attention_head = cat_output_hook.select(dim=1, index=attn_head_index)

In [9]:
# MLP neuron
MLP_layer = 13
MLP_neuron = 2069

# Dot product of each example's attention-head output with the neuron's input
# weights (column `MLP_neuron` of `W_in` for layer `MLP_layer`).
# The equation is a plain literal: the previous f-string had no placeholders,
# and its backslash continuation put indentation whitespace into the pattern.
# NOTE(review): the head output is used as-is; the model was loaded with
# fold_ln=False, so any pre-MLP normalisation is presumably not accounted for
# in this product — confirm this is the intended composition score.
dot_prod = einsum(
    "batch d_model, d_model -> batch",
    output_attention_head,
    model.W_in[MLP_layer][:, MLP_neuron],
)

In [10]:
def flatten(xss):
    """Return one flat list holding the items of every inner iterable of ``xss``.

    Only one level of nesting is removed; the order of the items is preserved.
    """
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat
# Un-batch the base prompts and their number labels so that position i in
# either list refers to the same example.
flatten_batches_base_tokens, flatten_ex_number_list = (
    flatten(batched_dataset[field])
    for field in ('batches_base_tokens', 'batches_ex_number_list')
)


In [11]:
from collections import defaultdict

# Group the per-example dot products by the grammatical number of the subject.
per_number_dict = defaultdict(list)

for batch_element in range(dot_prod.shape[0]):
    print(f'{model.to_string(flatten_batches_base_tokens[batch_element])}, dot product: {dot_prod[batch_element]:.2}')
    number_label = flatten_ex_number_list[batch_element]
    # Examples with any other label are ignored, as before.
    if number_label in ('Singular', 'Plural'):
        per_number_dict[number_label].append(dot_prod[batch_element].item())

# Workaround to get lists of the same length (the DataFrame built from this
# dict needs equal-length columns). Truncating both lists to the shorter one
# replaces the hard-coded cap of 149 'Singular' entries, which was only
# correct for one specific dataset size.
balanced_len = min(len(per_number_dict['Singular']), len(per_number_dict['Plural']))
for number_label in ('Singular', 'Plural'):
    del per_number_dict[number_label][balanced_len:]

<bos>The ministers that disguised the executive, dot product: 2.5
<bos>The authors that injured the secretary, dot product: 2.2
<bos>The athlete that disguised the executive, dot product: -1.2
<bos>The guard that injured the secretary, dot product: -0.98
<bos>The ministers that embarrassed the manager, dot product: 2.8
<bos>The executive that embarrassed the manager, dot product: -0.95
<bos>The executive that injured the doctors, dot product: -0.49
<bos>The farmer that ignored the teachers, dot product: -1.2
<bos>The executives that injured the doctors, dot product: 3.5
<bos>The authors that ignored the teachers, dot product: 3.6
<bos>The managers that admired the author, dot product: 2.5
<bos>The actors that ignored the author, dot product: 3.0
<bos>The executive that admired the author, dot product: -0.93
<bos>The secretary that ignored the author, dot product: -0.99
<bos>The consultant that ignored the actor, dot product: -0.9
<bos>The authors that ignored the actor, dot product: 3.

In [12]:
# Number of 'Plural' dot products collected (displayed as the cell output).
len(per_number_dict['Plural'])

149

In [13]:
# Number of 'Singular' dot products collected (displayed as the cell output).
len(per_number_dict['Singular'])

149

In [20]:
import plotly.express as px

# Box plot of the dot products: one column (box) per subject number.
df = pd.DataFrame.from_dict(per_number_dict)
fig = px.box(df, title='')
fig.update_layout(
    xaxis_title="Subject Number",
    yaxis_title="",
    font={"size": 15},  # Set the font size here
)
# Apply the project's shared figure styling before displaying.
fig = paper_plot(fig, tickangle=0)
fig.show()

In [21]:
# Save the current figure as a PNG in the images folder; the file name encodes
# the language and the MLP layer/neuron being analysed.
pio.write_image(
    fig,
    f'{images_dir}/{language}_MLP{MLP_layer}_neuron{MLP_neuron}_act_subj_num.png',
    scale=5,
    width=550,
    height=350,
)

In [12]:
layer_index = 17
head_index = 4

# Compute final LN normalized attention head output, attention block
# MLP output, or specific neuron.
# The commented recipes below are alternative components; they read from an
# activation cache named `cache`, which is not created in the cells shown here.

# normalized_attn_hook_result = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.attn.hook_result'][:,-1,head_index], layer=-1, pos_slice=-1
# )

# normalized_attn_block = cache.apply_ln_to_stack(
#     cache['blocks.0.hook_attn_out'][:,-1,:], layer=-1, pos_slice=-1
# )

# # [50, 7, 2048]
# normalized_mlp_out = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.hook_mlp_out'][:,-1], layer=-1, pos_slice=-1
# )


# Output weights of one MLP neuron, with a leading batch axis of size 1 so it
# fits the `batch d_model` operand of the einsum below.
neuron_idx = 1138
neuron = model.W_out[layer_index][neuron_idx].unsqueeze(0)

# Project component output into the vocabulary space
# Select model component output!
# Only the final-LayerNorm weight vector is multiplied in here (no
# normalisation statistics are computed) before the unembedding.
# The equation is a plain literal: the previous f-string had no placeholders,
# and its backslash continuation put indentation whitespace into the pattern.
tokens_component = einsum(
    "batch d_model, d_model vocab_size -> batch vocab_size",
    neuron * model.ln_final.w,
    model.unembed.W_U,
)
torch.cuda.empty_cache()

In [13]:
from collections import defaultdict

# Get top/bottom tokens
top_k = 500
largest = True # True: promoted tokens / False: suppressed tokens
ranking_dict = defaultdict(list)
top_k_tokens_component = torch.topk(tokens_component, top_k, dim=-1, largest=largest).indices

# Ordering of the whole vocabulary for every row. It does not depend on the
# loop variable, so it is computed once here instead of twice per iteration.
sorted_token_ids = torch.sort(tokens_component, descending=largest).indices

# NOTE(review): `answer_token_indices` is not defined in any cell visible in
# this notebook (only `batches_answer_token_indices` is). On a fresh kernel
# this cell will raise NameError unless it is defined elsewhere — confirm and
# define it explicitly. Entry [i][0] is used as the correct verb form and
# [i][1] as the wrong one.
for batch_element in range(top_k_tokens_component.shape[0]):
    top_k_str_tokens_component = model.tokenizer.convert_ids_to_tokens(top_k_tokens_component[batch_element])
    if largest:
        print(f'Top promoted tokens: {top_k_str_tokens_component}')
    else:
        print(f'Suppresed tokens: {top_k_str_tokens_component}')

    # Rank = position of the answer token in the sorted vocabulary of this row.
    row_order = sorted_token_ids[batch_element]
    corr_token_id = answer_token_indices[batch_element][0].item()
    wrong_token_id = answer_token_indices[batch_element][1].item()
    rank_corr_pred = (row_order == corr_token_id).nonzero(as_tuple=True)[0].item()
    rank_wrong_pred = (row_order == wrong_token_id).nonzero(as_tuple=True)[0].item()

    if not largest:
        # We show indices starting from the last
        rank_corr_pred *= -1
        rank_wrong_pred *= -1
    ranking_dict['corr_pred'].append(rank_corr_pred)
    ranking_dict['wrong_pred'].append(rank_wrong_pred)
    print(f'Ranking correct prediction: {rank_corr_pred}')
    print(f'Ranking wrong prediction: {rank_wrong_pred}\n')


Top promoted tokens: ['▁!...', 'tinyos', 'tshell', 'uffy', 'jarah', '▁Pexels', '▁pixabay', '▁tenda', 'GEBURTSDATUM', '▁Bhi', '▁Méri', '▁Dubuque', 'vergne', '▁uefa', 'Discografia', '▁effe', '▁thut', '▁Stavanger', 'rungsseite', '▁inverno', 'EditorBrowsable', 'Ẳ', '▁Pixabay', '▁Evansville', 'olkien', '▁greate', '▁grammi', '▁Thần', '▁liverpool', '▁peugeot', '!:)', 'Autoritní', 'ChatColor', '▁morrow', '▁tartalomajánló', '▁chelsea', '▁overcrow', '!!:', '▁autorytatywna', '▁Tradu', '▁lidl', '▁nct', '▁fays', '▁nabo', 'Портали', 'Karakteristik', '▁:)</', '▁fhe', "▁:')", '▁fuf', '▁whil', '▁racconta', 'alpin', '▁princi', 'uscany', '▁desir', '▁ktm', 'rivit', 'memoized', '▁:)))', '▁outono', '▁onsdag', 'umplimiento', '▁Roskov', '▁København', '▁venice', '▁bangkok', '▁herre', '▁waer', '▁shou', '!!</', 'ategorias', 'verifyException', '▁poff', '▁betweenstory', '▁Mulher', '▁tõ', '▁iStock', '▁Aggi', '▁Gurgaon', 'uxedo', '▁Pressed', '▁fign', '▁thar', 'bibnamefont', '▁CreateTagHelper', '▁Occidente', 'zegor',

In [None]:
import plotly.express as px

# Box plot of the ranks collected for the correct and the wrong verb form.
df = pd.DataFrame.from_dict(ranking_dict)
# The language in the title comes from the `language` setting instead of a
# hard-coded "Spanish", so the label follows the dataset that was loaded.
# NOTE(review): the title labels the component as L{layer}H{head}; confirm it
# matches the component actually projected into `tokens_component`.
fig = px.box(df, title=f'Ranking correct and wrong verb forms L{layer_index}H{head_index} {language.capitalize()}')
fig.show()