In [4]:
%load_ext autoreload
# Re-import edited local modules (e.g. utils_sva, load_dataset) before each cell runs.
%autoreload 2

import os
from os.path import expanduser
# User home directory; output paths (e.g. the images folder) are built from it.
home = expanduser("~")

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy
import gc
import itertools
import dataclasses
from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio
from PIL import ImageColor
from utils_sva import html_colors

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
from utils_sva import residual_stack_to_logit_diff
# No gradients are needed here (only forward passes and cache reads), so disable
# autograd globally to save memory.
torch.set_grad_enabled(False)

# Create folder to save plots. exist_ok avoids the check-then-create race and
# makes the cell idempotent on re-runs.
images_dir = f"{home}/circuits_languages/images"
os.makedirs(images_dir, exist_ok=True)

from collections import defaultdict
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff
from load_dataset import load_sva_dataset, get_batched_dataset
# Plot colours per language as plotly "rgb(r, g, b)" strings built from the
# (r, g, b) tuple PIL returns for the named html colour.
spa_color = f"rgb{ImageColor.getcolor(html_colors['green_drawio'], 'RGB')}"
eng_color = f"rgb{ImageColor.getcolor(html_colors['brown_D3'], 'RGB')}"

def flatten(xss):
    """Concatenate an iterable of iterables into one flat list (one level deep).

    Args:
        xss: iterable whose items are themselves iterable (lists, tuples, tensors...).

    Returns:
        A new list holding every inner element, in order.
    """
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat

In [6]:
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff, compute_act_patching
from load_dataset import load_sva_dataset, get_batched_dataset

### Load Model

In [7]:
# Number of visible CUDA devices. Not passed to from_pretrained below; to shard the
# model across GPUs one would pass n_devices=... (e.g. n_devices=2).
n_devices = torch.cuda.device_count()

# Load Gemma-2B into TransformerLens: only the unembedding is centered; LayerNorm
# and value biases are NOT folded and writing weights are NOT centered.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=True,
)
# Do not cache per-head attention outputs for now.
model.set_use_attn_result(False)

# Default device selected by TransformerLens utils.
device: torch.device = utils.get_device()

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 

### MLP neuron contributions to the logit difference

In [5]:
# Cache per-head attention outputs (the 'result' hooks); get_composition reads
# utils.get_act_name('result', ...) from the cache.
model.set_use_attn_result(True)

In [6]:
def contrib_logit_diff(model, batched_dataset, hook_name, layer_index):
    """Direct contribution of every MLP neuron to the logit difference at the last position.

    For each example, the activation cached at ``hook_name`` (last token position)
    is multiplied by the matching row of ``model.W_out[layer_index]``, divided by
    the final LayerNorm scale, multiplied by the final LN weights and projected
    onto the difference between the unembedding directions of the two answer
    tokens (column 0 minus column 1 of the answer indices).

    Every step is linear per neuron, so this is computed as
    ``act / scale * (W_out @ (ln_final.w * diff_direction))`` instead of
    materialising a ``[batch, d_mlp, d_model]`` tensor.

    Args:
        model: HookedTransformer used for the forward passes.
        batched_dataset: dict holding per-batch sequences under
            'batches_base_tokens' and 'batches_answer_token_indices'
            (answer indices presumably ``[batch, 2]`` — TODO confirm).
        hook_name: activation hook feeding ``W_out[layer_index]``, e.g.
            ``utils.get_act_name('post', layer_index)``.
        layer_index: layer whose ``W_out`` is used.

    Returns:
        Tensor ``[n_examples, d_mlp]``: one contribution per example and neuron.

    Raises:
        ValueError: if ``batched_dataset`` contains no batches.
    """
    batches_base_tokens = batched_dataset['batches_base_tokens']
    batches_answer_token_indices = batched_dataset['batches_answer_token_indices']

    # [d_mlp, d_model]: each row is the direction one neuron writes to the residual stream
    w_out = model.W_out[layer_index]
    ln_final_w = model.ln_final.w

    per_batch_contribs = []
    for base_tokens, answer_token_indices in zip(batches_base_tokens, batches_answer_token_indices):
        answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
        # Difference of unembedding vectors: [batch, d_model]
        logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]

        # Cache only the two activations needed; a full cache is large and was never used.
        _, base_cache = model.run_with_cache(
            base_tokens, names_filter=[hook_name, "ln_final.hook_scale"]
        )
        neuron_acts = base_cache[hook_name][:, -1]                 # [batch, d_mlp]
        final_ln_scale = base_cache["ln_final.hook_scale"][:, -1]  # [batch, 1]

        # Logit-diff direction (with final LN weights applied) read through each W_out row.
        neuron_logit_diff_dirs = (logit_diff_directions * ln_final_w) @ w_out.T  # [batch, d_mlp]

        # Divide by the final LN scale to get exact logit-diff contributions.
        per_batch_contribs.append(neuron_acts * neuron_logit_diff_dirs / final_ln_scale)

        del base_cache
        torch.cuda.empty_cache()

    if not per_batch_contribs:
        raise ValueError("batched_dataset contains no batches")
    return torch.cat(per_batch_contribs, 0)

In [56]:
# MLP layer whose neurons are attributed; its 'post' activations are the values
# multiplied by W_out[layer_index] inside contrib_logit_diff.
layer_index = 16
hook_name = utils.get_act_name('post', layer_index)

dataset_type = 'both' # singular / plural / both
num_samples = 300
batch_size = 10
# language -> list (one float per neuron) of the mean contribution over examples
lang_logit_diff = {}
for language in ['Spanish', 'English']:
    dataset = load_sva_dataset(model, language, dataset_type, num_samples)
    batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)
    neurons_contrib_logit_diff = contrib_logit_diff(model, batched_dataset, hook_name, layer_index)
    # Average over examples (dim 0) -> one value per neuron
    lang_logit_diff[language] = neurons_contrib_logit_diff.mean(0).tolist()


137
163
151
149


In [57]:
# Mean per-neuron contribution to the logit difference, one line per language.
# Columns are the languages (Spanish, English); rows are neuron indices.
df = pd.DataFrame.from_dict(lang_logit_diff)
# Wide-form DataFrame: px draws one line per column against the neuron index,
# so both languages are built the same way and coloured from the same map.
fig = px.line(df, color_discrete_map={"Spanish": spa_color, "English": eng_color})
fig.update_layout(legend_title_text='Language',
                  xaxis_title="Neuron", yaxis_title="Logit Difference",
                  font=dict(
                      size=15,  # Set the font size here
                  )
                  )

fig = paper_plot(fig)
fig.show()

# To save the figure:
# pio.write_image(fig, f'{images_dir}/MLP{layer_index}_both_neurons.pdf', scale=6, width=800, height=350)

### Composition between attention heads and MLP neurons

In [24]:
def get_composition(batched_dataset, attn_layer_index, attn_head_index, MLP_layer, MLP_neuron):
    """Dot product between one attention head's output and one MLP neuron's input weights.

    For every example, takes the per-head attention output ('result' hook) of
    head ``attn_head_index`` in layer ``attn_layer_index`` at the last token
    position and projects it onto column ``MLP_neuron`` of
    ``model.W_in[MLP_layer]``.

    Uses the notebook-global ``model``. The 'result' hook is only cached after
    ``model.set_use_attn_result(True)``.

    Args:
        batched_dataset: dict with a per-batch sequence of token tensors under
            'batches_base_tokens'.
        attn_layer_index: layer of the attention head.
        attn_head_index: index of the head within that layer.
        MLP_layer: layer of the target MLP neuron.
        MLP_neuron: index of the neuron within that layer's MLP.

    Returns:
        1-D tensor ``[n_examples]`` of dot products.

    Raises:
        ValueError: if ``batched_dataset`` contains no batches.
    """
    hook_name = utils.get_act_name('result', attn_layer_index)
    # Input weights of the target neuron: [d_model]
    neuron_w_in = model.W_in[MLP_layer][:, MLP_neuron]

    head_outputs = []
    for base_tokens in batched_dataset['batches_base_tokens']:
        _, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name])
        # Keep only the chosen head at the last position ([batch, d_model]). clone()
        # detaches it from the full cached activation so that tensor can be freed.
        head_outputs.append(base_cache[hook_name][:, -1, attn_head_index].clone())
        del base_cache

    if not head_outputs:
        raise ValueError("batched_dataset contains no batches")
    output_attention_head = torch.cat(head_outputs, 0)

    # Dot product with neuron's W_in column (input weights)
    return output_attention_head @ neuron_w_in



In [58]:
# Attention head whose output is projected onto the neuron's input weights
attn_layer_index = 13
attn_head_index = 7
# MLP neuron (previously explored alternative: layer 13, neuron 2069)
MLP_layer = 16
MLP_neuron = 971

dataset_type = 'both' # singular / plural / both
num_samples = 300
batch_size = 10

# language -> {'Singular': [...], 'Plural': [...]} of head-output · neuron-W_in dot products
lang_per_number_dict = {}
for language in ['Spanish', 'English']:
    dataset = load_sva_dataset(model, language, dataset_type, num_samples)
    batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)
    dot_prod = get_composition(batched_dataset, attn_layer_index, attn_head_index, MLP_layer, MLP_neuron)

    flatten_ex_number_list = flatten(batched_dataset['batches_ex_number_list'])
    per_number_dict = defaultdict(list)
    # Group each example's dot product by its subject number. Lists may have different
    # lengths; the following cell truncates all of them to a common length, which
    # replaces the former data-specific hard-coded cap on the Singular list.
    for ex_number, value in zip(flatten_ex_number_list, dot_prod.tolist()):
        if ex_number in ('Singular', 'Plural'):
            per_number_dict[ex_number].append(value)

    lang_per_number_dict[language] = per_number_dict


139
161
151
149


In [59]:
# Truncate every (language, number) list to the shortest one so all boxes are
# computed from the same number of examples.
length_lists = [
    len(lang_per_number_dict[language][number])
    for language in ['Spanish', 'English']
    for number in ['Singular', 'Plural']
]
min_len = np.array(length_lists).min()
for language, number in itertools.product(['Spanish', 'English'], ['Singular', 'Plural']):
    lang_per_number_dict[language][number] = lang_per_number_dict[language][number][:min_len]

In [60]:
from plotly.subplots import make_subplots

fig = make_subplots(rows=1, cols=1)

# One box trace per language. With boxmode='group' each trace gets its own slot per
# x category, so a single trace per language puts the two languages side by side
# within each subject-number category (Plural, Singular) without empty slots.
# Each x list is sized from its own y list (previously the English Singular trace
# used the Plural list's length).
for language, color in [('English', eng_color), ('Spanish', spa_color)]:
    per_number = lang_per_number_dict[language]
    fig.add_trace(go.Box(
        y=per_number['Plural'] + per_number['Singular'],
        x=['Plural'] * len(per_number['Plural']) + ['Singular'] * len(per_number['Singular']),
        marker_color=color,
        name=language,
    ))

fig.update_layout(
    xaxis_title="Subject Number", yaxis_title="",
    title=f"MLP {MLP_layer} Neuron {MLP_neuron}",
    legend_title_text='Language',
    font=dict(
        size=15,  # Set the font size here
    ),
    boxmode='group'
)
fig.update_layout(legend_traceorder="reversed")

fig = paper_plot(fig, tickangle=0)
fig.show()
pio.write_image(fig, f'{images_dir}/both_MLP{MLP_layer}_neuron{MLP_neuron}_act_subj_num.png',scale=5, width=550, height=350)


In [118]:
# Project a single MLP neuron's output weights (its W_out row) onto the vocabulary
# to see which tokens it promotes / suppresses.
layer_index = 13
head_index = 4  # only referenced by commented-out code

dataset_type = 'singular' # singular / plural / both
num_samples = 100
batch_size = 10
lang_dot_prod = {}
language = 'English'
# The batched dataset provides the answer token ids used by the ranking cell.
dataset = load_sva_dataset(model, language, dataset_type, num_samples)
batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)

# Alternatives (disabled): final-LN-normalized output of an attention head, the
# attention block, or the MLP. They need a `cache` from model.run_with_cache.

# normalized_attn_hook_result = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.attn.hook_result'][:,-1,head_index], layer=-1, pos_slice=-1
# )

# normalized_attn_block = cache.apply_ln_to_stack(
#     cache['blocks.0.hook_attn_out'][:,-1,:], layer=-1, pos_slice=-1
# )

# # [50, 7, 2048]
# normalized_mlp_out = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.hook_mlp_out'][:,-1], layer=-1, pos_slice=-1
# )


neuron_idx = 2069  # neuron index within layer_index's MLP
# W_out row of the neuron with a leading batch axis: [1, d_model]
neuron = model.W_out[layer_index][neuron_idx].unsqueeze(0)

# Project component output into the vocabulary space: apply the final LN weights,
# then multiply by the unembedding -> [1, vocab_size]
# Select model component output!
tokens_component = einsum(
        f"batch d_model, d_model vocab_size \
        -> batch vocab_size",
        neuron * model.ln_final.w,
        model.unembed.W_U)
torch.cuda.empty_cache()

100
0


In [3]:
# Get top/bottom tokens
top_k = 500
largest = True # True: promoted tokens / False: suppressed tokens
ranking_dict = defaultdict(list)
top_k_tokens_component = torch.topk(tokens_component, top_k, dim=-1, largest=largest).indices

# Full vocabulary ordering per row (most promoted first if largest, else most suppressed
# first). Sorted once here instead of re-sorting the whole vocabulary three times per row.
sorted_token_ids = torch.sort(tokens_component, descending=largest).indices
# Answer token ids per batch; column 0 = correct verb form, column 1 = wrong form.
answer_token_indices = batched_dataset['batches_answer_token_indices']

for batch_element in range(top_k_tokens_component.shape[0]):
    top_k_str_tokens_component = model.tokenizer.convert_ids_to_tokens(top_k_tokens_component[batch_element])
    if not largest:
        print(f'Suppresed tokens: {top_k_str_tokens_component}')
    else:
        print(f'Top promoted tokens: {top_k_str_tokens_component}')

    component_order = sorted_token_ids[batch_element]
    print(component_order)

    # NOTE(review): batch_element indexes rows of tokens_component (one per projected
    # component) but also selects a *batch* of answers; with a single neuron row
    # this is always the first batch.
    batch_answers = answer_token_indices[batch_element]
    # Position of each answer token in the ordering (the ordering is a permutation of
    # the vocabulary, so each answer token matches exactly once).
    rank_corr_pred = (component_order.unsqueeze(0) == batch_answers[:, 0].unsqueeze(1)).nonzero(as_tuple=True)[1]
    rank_wrong_pred = (component_order.unsqueeze(0) == batch_answers[:, 1].unsqueeze(1)).nonzero(as_tuple=True)[1]

    # if not largest:
    #     # We show indices starting from the last
    #     rank_corr_pred *= -1
    #     rank_wrong_pred *= -1
    # ranking_dict['corr_pred'].append(rank_corr_pred)
    # ranking_dict['wrong_pred'].append(rank_wrong_pred)
    print(f'Ranking correct prediction: {rank_corr_pred}')
    print(f'Ranking wrong prediction: {rank_wrong_pred}\n')

NameError: name 'defaultdict' is not defined

In [120]:
# Show the correct-verb answer tokens (column 0) of the batch used in the ranking cell.
model.to_str_tokens(answer_token_indices[batch_element][:,0])

[' has', ' has', ' has', ' is', ' is', ' was', ' was', ' was', ' has', ' was']

In [None]:
# Disabled: box plot of answer-token ranks. It needs the ranking_dict.append lines in
# the ranking cell to be uncommented so that ranking_dict is populated.
# df = pd.DataFrame.from_dict(ranking_dict)
# fig = px.box(df, title=f'Ranking correct and wrong verb forms L{layer_index}H{head_index} Spanish')
# fig.show()