In [1]:
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser
home = expanduser("~")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
import dataclasses
from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio
from PIL import ImageColor
from utils_sva import html_colors

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
from utils_sva import residual_stack_to_logit_diff
torch.set_grad_enabled(False)

# Create folder to save plots
images_dir = f"{home}/circuits_languages/images"
if not os.path.exists(images_dir):
    os.makedirs(images_dir)

In [3]:
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff
from load_dataset import load_sva_dataset, get_batched_dataset
spa_color = 'rgb' + str(ImageColor.getcolor(html_colors['green_drawio'], "RGB"))
eng_color = 'rgb' + str(ImageColor.getcolor(html_colors['brown_D3'], "RGB"))

### Load Model

In [4]:
n_devices = torch.cuda.device_count()

model = HookedTransformer.from_pretrained(
    "gemma-2b",
    center_unembed=True,
    center_writing_weights=False,
    fold_ln=False,
    fold_value_biases=False,
    #n_devices=2
)
model.set_use_attn_result(True)
# Get the default device used
device: torch.device = utils.get_device()

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer


### Direct Logit Attribution

In [7]:
dataset_type = 'both' # singular / plural / both
language = 'both' # english / spanish / both
num_samples = 100
batch_size = 10
start_at = 0
dataset = load_sva_dataset(model, language, dataset_type, num_samples)
batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)

batches_base_tokens = batched_dataset['batches_base_tokens']
batches_src_tokens = batched_dataset['batches_src_tokens']
batches_answer_token_indices = batched_dataset['batches_answer_token_indices']
batches = len(batches_src_tokens)

In [27]:
# DLA on accumulated residuals
import gc

base_logit_diff_list = []
src_logit_diff_list = []

batches_base_logits = []

list_logit_lens_logit_diffs= []
for batch in range(batches):
    # Get clean tokens and answer indices from batches
    base_tokens = batches_base_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    base_logits, base_cache = model.run_with_cache(base_tokens)
    answer_token_indices = answer_token_indices.to(base_logits.device)
    original_average_logit_diff = get_logit_diff(base_logits, answer_token_indices)
    answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

    # Difference of unembedding vectors
    logit_diff_directions = (
        answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
    )
    accumulated_residual, labels = base_cache.accumulated_resid(
        layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
    )
    logit_lens_logit_diffs = residual_stack_to_logit_diff(model, accumulated_residual, base_cache, logit_diff_directions, mean=False)
    list_logit_lens_logit_diffs.append(logit_lens_logit_diffs)
    

    torch.cuda.empty_cache()
    gc.collect()
cat_logit_lens_logit_diffs = torch.cat(list_logit_lens_logit_diffs, 0)

In [31]:
fig = line(
        cat_logit_lens_logit_diffs.mean(0),
        x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
        hover_name=labels,
        title="Logit Difference From Accumulate Residual Stream",
        return_fig=True,
        #labels={"x": "Layer", "y": "Logit Difference"},
    )
fig = paper_plot(fig)
fig.show()

In [63]:
import gc

from load_dataset import load_sva_dataset

def get_per_layer_logit_diff(batched_dataset, batches):

    batches_base_tokens = batched_dataset['batches_base_tokens']
    batches_src_tokens = batched_dataset['batches_src_tokens']
    batches_answer_token_indices = batched_dataset['batches_answer_token_indices']

    # base_logit_diff_list = []
    # src_logit_diff_list = []
    # batches_base_logits = []
    list_per_layer_logit_diffs= []
    for batch in range(batches):
        # Get clean tokens and answer indices from batches
        base_tokens = batches_base_tokens[batch]
        answer_token_indices = batches_answer_token_indices[batch]

        base_logits, base_cache = model.run_with_cache(base_tokens)
        answer_token_indices = answer_token_indices.to(base_logits.device)
        original_average_logit_diff = get_logit_diff(base_logits, answer_token_indices)
        answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

        # Difference of unembedding vectors
        logit_diff_directions = (
            answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
        )

        per_layer_residual, labels = base_cache.decompose_resid(
        layer=-1, pos_slice=-1, return_labels=True
        )
        per_layer_logit_diffs = residual_stack_to_logit_diff(model,per_layer_residual[:], base_cache,
                                                            logit_diff_directions, mean=False)


        list_per_layer_logit_diffs.append(per_layer_logit_diffs)
        

        torch.cuda.empty_cache()
        gc.collect()
    cat_per_layer_logit_diffs = torch.cat(list_per_layer_logit_diffs, 0)
    return cat_per_layer_logit_diffs, labels

In [68]:
dataset_type = 'both' # singular / plural / both
num_samples = 300
batch_size = 10
lang_logit_diff = {}
for language in ['Spanish', 'English']:
    dataset = load_sva_dataset(model, language, dataset_type, num_samples)
    batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)
    batch_lang_logit_diff, labels = get_per_layer_logit_diff(batched_dataset, len(batched_dataset['batches_src_tokens']))
    lang_logit_diff[language] = batch_lang_logit_diff.mean(0).tolist()

In [69]:
spa_color = 'rgb' + str(ImageColor.getcolor(html_colors['green_drawio'], "RGB"))
eng_color = 'rgb' + str(ImageColor.getcolor(html_colors['brown_D3'], "RGB"))

In [70]:
component_labels_to_pretty_labels = [clean_blocks_labels(label) for label in labels]

df = pd.DataFrame.from_dict(lang_logit_diff)
fig = px.line(df,color_discrete_map={"English": eng_color, "Spanish": spa_color},)
fig.update_layout(legend_title_text='Language',
                  xaxis_title="Block", yaxis_title="Logit Difference", 
                  xaxis = dict(
                                tickmode = 'array',
                                tickvals = np.arange(len(component_labels_to_pretty_labels)),
                                ticktext = component_labels_to_pretty_labels
                                ),
                font=dict(
                            size=15,  # Set the font size here
                        )
                )
fig = paper_plot(fig)
fig.show()
pio.write_image(fig, f'{images_dir}/logit_diff_both_blocks.pdf',scale=6, width=800, height=350)

In [None]:
# per_head_residual, labels = cache.stack_head_results(
#     layer=-1, pos_slice=-1, return_labels=True
# )
# per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
# per_head_logit_diffs = einops.rearrange(
#     per_head_logit_diffs,
#     "(layer head_index) -> layer head_index",
#     layer=model.cfg.n_layers,
#     head_index=model.cfg.n_heads,
# )
# imshow(
#     per_head_logit_diffs[:],
#     title="Logit Difference From Each Head",
#     xaxis="Head", 
#     yaxis="Layer",
# )

Tried to stack head results when they weren't cached. Computing head results now


In [116]:
def contrib_logit_diff(model, batched_dataset, hook_name, layer_index):
    batches_base_tokens = batched_dataset['batches_base_tokens']
    batches_src_tokens = batched_dataset['batches_src_tokens']
    batches_answer_token_indices = batched_dataset['batches_answer_token_indices']
    
    batches_answer_token_indices
    neurons_contrib_logit_diff_list = []
    for batch in range(len(batches_base_tokens)):
        # Get clean tokens and answer indices from batches
        base_tokens = batches_base_tokens[batch]
        answer_token_indices = batches_answer_token_indices[batch]
        base_logits, base_cache = model.run_with_cache(base_tokens)
        #answer_token_indices = answer_token_indices.to(base_logits.device)
        answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)

        # Difference of unembedding vectors
        logit_diff_directions = (
            answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
        )

        _, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name, "ln_final.hook_scale"])
        output_hook_pos = base_cache[hook_name][:,-1]
        #list_output_hook.append(output_hook_pos)
        act_w_out_rows = einsum(
            f"batch d_ffn, d_ffn d_model \
            -> batch d_ffn d_model",
            base_cache[hook_name][:,-1],
            model.W_out[layer_index])

        # Scaling by final LN scaling value (we need this to get exact logit diffs)
        scaled_act_w_out_rows = act_w_out_rows / base_cache["ln_final.hook_scale"][:,-1].unsqueeze(1)

        # Dot product with unembeding diff directions (first apply final LN weights)
        neurons_contrib_logit_diff = einsum(
                f"batch d_ffn d_model, batch d_model \
                -> batch d_ffn",
                scaled_act_w_out_rows * model.ln_final.w,
                logit_diff_directions)
        neurons_contrib_logit_diff_list.append(neurons_contrib_logit_diff)
        torch.cuda.empty_cache()
    neurons_contrib_logit_diff = torch.cat(neurons_contrib_logit_diff_list, 0)

    return neurons_contrib_logit_diff

In [119]:
layer_index = 16
hook_name = utils.get_act_name('post', layer_index)

dataset_type = 'both' # singular / plural / both
num_samples = 300
batch_size = 10
lang_logit_diff = {}
for language in ['Spanish', 'English']:
    dataset = load_sva_dataset(model, language, dataset_type, num_samples)
    batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)
    neurons_contrib_logit_diff = contrib_logit_diff(model, batched_dataset, hook_name, layer_index)
    lang_logit_diff[language] = neurons_contrib_logit_diff.mean(0).tolist()
#output_attention_head = cache[hook_name][:,-1,attn_head_index]


141
159
151
149


In [120]:
import plotly.graph_objects as go

df = pd.DataFrame.from_dict(lang_logit_diff)
fig = px.line(df['Spanish'],color_discrete_map={"Spanish": spa_color, "English": eng_color},)
fig.update_layout(legend_title_text='Language',
                  xaxis_title="Neuron", yaxis_title="Logit Difference",
                  font=dict(
                      size=15,  # Set the font size here
                  ) 
                )
fig.add_trace(go.Scatter(x=np.arange(len(df['English'])), y=df['English'], mode='lines', name='English', line=dict(color=eng_color)))

#fig.update_traces(opacity=.4)

fig = paper_plot(fig)
fig.show()
# Update
#fig.update_layout(legend_traceorder="reversed")

pio.write_image(fig, f'{images_dir}/MLP{layer_index}_both_neurons.pdf',scale=6, width=800, height=350)

In [None]:
# layer_index = 13
# # Getting act * w_out rows -> [batch d_ffn d_model]
# act_w_out_rows = einsum(
#         f"batch d_ffn, d_ffn d_model \
#         -> batch d_ffn d_model",
#         cache[f'blocks.{layer_index}.mlp.hook_post'][:,-1],
#         model.W_out[layer_index])

# # Scaling by final LN scaling value (we need this to get exact logit diffs)
# scaled_act_w_out_rows = act_w_out_rows / cache["ln_final.hook_scale"][:,-1].unsqueeze(1)

# # Dot product with unembeding diff directions (first apply final LN weights)
# neurons_contrib_logit_diff = einsum(
#         f"batch d_ffn d_model, batch d_model \
#         -> batch d_ffn",
#         scaled_act_w_out_rows * model.ln_final.w,
#         logit_diff_directions)

# # Check added logit diffs of all neurons add up to original MLP contribution to the logit diff (mean across batch)
# print(neurons_contrib_logit_diff.sum(-1).mean())
# torch.cuda.empty_cache()

# line(neurons_contrib_logit_diff.mean(0),
# title=f"Logit Difference From Each Neuron in MLP {layer_index}",
# labels={"x": "Neuron", "y": "Logit Difference"},)

### Attention Maps

In [25]:
from circuitsvis.attention import attention_heads

def visualize_attention_patterns(
    type_pattern: str,
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
    html: Optional[bool] = True
):
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    #batch_index = 0
    print('heads', heads)

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        if type_pattern == 'attn_weights':
            updated_pattern = local_cache["attn", layer][:, head_index]
            print(updated_pattern.shape)
        else:
            # We compute attention heads weighted value vectors a_{i,j} x_j W_V
            pattern = local_cache[f'blocks.{layer}.attn.hook_pattern']
            v = local_cache[f'blocks.{layer}.attn.hook_v']
            weighted_values = einsum(
                        "batch key_pos head_index d_head, \
                        batch head_index query_pos key_pos -> \
                        batch query_pos key_pos head_index d_head",
                        v,
                        pattern,
                    )# [batch, query_pos, key_pos, head_index, d_head]

            if type_pattern == 'value_weighted':
                # Value-weighted norms
                raw_inter_token_attribution = torch.norm(weighted_values, dim=-1, p=2)
                # weighted_values_norm -> [batch query_pos key_pos head_index]

            elif type_pattern == 'output_value_weighted' or type_pattern == 'distance_based':
                # We decompose attention heads further by computing a_{i,j} x_j W_OV
                output_weighted_values = einsum(
                        "batch query_pos key_pos head_index d_head, \
                            head_index d_head d_model -> \
                            batch query_pos key_pos head_index d_model",
                        weighted_values,
                        model.W_O[layer],
                    )

                # Check sum decomposition is equivalent to cached values
                output_heads = output_weighted_values.sum(2)
                output_attention = output_heads.sum(-2) + model.b_O[layer]
                assert torch.dist(output_attention, local_cache[f'blocks.{layer}.hook_attn_out']).item() < 1e-3 * local_cache[f'blocks.{layer}.hook_attn_out'].numel()

                if type_pattern == 'output_value_weighted':
                    # Output-value-weighted norms
                    # weighted_values_norm -> [batch query_pos key_pos head_index]
                    raw_inter_token_attribution = torch.norm(output_weighted_values, dim=-1, p=1)

                elif type_pattern == 'distance_based':
                    # Distance-based
                    EPS = 1e-5
                    # distance -> [batch query_pos key_pos head_index]
                    distance = -F.pairwise_distance(output_weighted_values, output_heads.unsqueeze(2),p=2)
                    # head_output_norm -> [batch query_pos head_index]
                    head_output_norm = torch.norm(output_heads, p=2, dim=-1)
                    raw_inter_token_attribution = (distance + head_output_norm.unsqueeze(2)).clip(min=EPS)
                    

            # Normalize over key_pos
            inter_token_attribution = raw_inter_token_attribution / raw_inter_token_attribution.sum(dim=-2,keepdim=True)
            updated_pattern = inter_token_attribution[:, :, :, head_index]
            
        patterns.append(updated_pattern)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "batch head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=1
    )

    if html:
        # Convert the tokens to strings (for the axis labels)
        str_tokens = model.to_str_tokens(local_tokens)
        # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
        plot = attention_heads(
            attention=patterns, tokens=str_tokens, attention_head_names=labels
        ).show_code()
        # Display the title
        title_html = f"<h2>{title}</h2><br/>"
        # Return the visualisation as raw code
        return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"
    else:
        return patterns


In [12]:
# Get clean tokens and answer indices from batches
model.set_use_attn_result(True)

batch = 0
base_tokens = batches_base_tokens[batch]
answer_token_indices = batches_answer_token_indices[batch]

base_logits, base_cache = model.run_with_cache(base_tokens)

In [13]:
head_name_to_number = {}

head_names_list = [f"L{layer}H{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
for i, head in enumerate(head_names_list):
    head_name_to_number[head] = i

heads = ['L13H7', 'L17H4']
heads_number = [head_name_to_number[head] for head in heads]

torch.Size([10, 7, 256000])

In [26]:
top_k = 10
type_pattern = 'attn_weights' # attn_weights / value_weighted / output_value_weighted / distance_based

viz_attn = visualize_attention_patterns(
    type_pattern,
    heads_number,
    base_cache,
    base_tokens[0], # TODO: change this to general tokens
    f"Top {top_k} Positive Logit Attribution Heads",
    html=False
)

heads [111, 140]
torch.Size([10, 7, 7])
before torch.Size([10, 7, 7])
1
torch.Size([10, 7, 7])
before torch.Size([10, 7, 7])
2
after torch.Size([10, 2, 7, 7])


In [None]:
x=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],
y=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],

In [30]:
imshow(
    viz_attn[:,0].mean(0),
    title="",
    xaxis="", 
    yaxis="",
)

### Composition Attention head and MLP neurons

In [5]:
model.set_use_attn_result(True)

In [6]:
def get_composition(batched_dataset, attn_layer_index, attn_head_index, MLP_layer, MLP_neuron):
    batches_base_tokens = batched_dataset['batches_base_tokens']

    list_output_hook = []
    #hook_name = f'blocks.{attn_layer_index}.attn.hook_result'
    hook_name = utils.get_act_name('result', attn_layer_index)
    #hook_name = utils.get_act_name('resid_pre', attn_layer_index)
    #output_attention_head = cache[hook_name][:,-1,attn_head_index]
    for batch in range(len(batches_base_tokens)):
        # Get clean tokens and answer indices from batches
        base_tokens = batches_base_tokens[batch]
        base_logits, base_cache = model.run_with_cache(base_tokens, names_filter=[hook_name])
        output_hook_pos = base_cache[hook_name][:,-1]
        list_output_hook.append(output_hook_pos)
    #key_attention_head = cache[f'blocks.{attn_layer_index}.attn.hook_q'][:,-1,attn_head_index]
    cat_output_hook = torch.cat(list_output_hook, 0)
    output_attention_head = cat_output_hook[:,attn_head_index]

    # Dot product with neuron's W_in column (input weights)
    dot_prod = einsum(
            f"batch d_model, d_model \
            -> batch ",
            output_attention_head,
            model.W_in[MLP_layer][:,MLP_neuron])
    return dot_prod



In [70]:
from collections import defaultdict

def flatten(xss):
    return [x for xs in xss for x in xs]


lang_per_number_dict = {}

# Attention head
attn_layer_index = 13
attn_head_index = 7
# MLP neuron
MLP_layer = 13
MLP_neuron = 2069


images_dir = f"{home}/circuits_languages/images"
dataset_type = 'both' # singular / plural / both
num_samples = 100
batch_size = 10
lang_dot_prod = {}
for language in ['Spanish', 'English']:
    dataset = load_sva_dataset(model, language, dataset_type, num_samples)
    batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)
    dot_prod = get_composition(batched_dataset, attn_layer_index, attn_head_index, MLP_layer, MLP_neuron)
    
    flatten_batches_base_tokens = flatten(batched_dataset['batches_base_tokens'])
    flatten_ex_number_list = flatten(batched_dataset['batches_ex_number_list'])
    per_number_dict = defaultdict(list)
    for batch_element in range(dot_prod.shape[0]):
        print(f'{model.to_string(flatten_batches_base_tokens[batch_element])}, dot product: {dot_prod[batch_element]:.2}')
        if flatten_ex_number_list[batch_element] == 'Singular':
            if len(per_number_dict['Singular']) == 149:
                # Workaraound to get lists of the same length
                continue
            else:
                per_number_dict['Singular'].append(dot_prod[batch_element].item())
            
        elif flatten_ex_number_list[batch_element] == 'Plural':
            per_number_dict['Plural'].append(dot_prod[batch_element].item())

    lang_per_number_dict[language] = per_number_dict





48
52
<bos>El detective que tuvo al investigador, dot product: -1.4
<bos>El detective que puso al ministro, dot product: -1.3
<bos>Los diputados que encontraron al director, dot product: 3.8
<bos>El vendedor que dijo al policía, dot product: -1.4
<bos>El piloto que vio al director, dot product: -1.9
<bos>El abogado que dio al científico, dot product: -1.6
<bos>El agente que dio al detective, dot product: -1.3
<bos>El policía que pasó al piloto, dot product: -1.6
<bos>El diseñador que dijo al corredor, dot product: -0.98
<bos>El rey que vio al chef, dot product: -1.9
<bos>Los periodistas que tuvieron al vendedor, dot product: 4.4
<bos>El escritor que dio al corredor, dot product: -0.49
<bos>El técnico que tuvo al rey, dot product: -1.3
<bos>Los ministros que pusieron al maestro, dot product: 2.5
<bos>Los escritores que dijeron al piloto, dot product: 2.0
<bos>El técnico que puso al vendedor, dot product: -1.7
<bos>Los productores que encontraron al sacerdote, dot product: 3.7
<bos>El ag

In [71]:
length_lists = []
for language in ['Spanish', 'English']:
    for number in ['Singular', 'Plural']:
            length_lists.append(len(lang_per_number_dict[language][number]))
min_len = np.array(length_lists).min()
for language in ['Spanish', 'English']:
    for number in ['Singular', 'Plural']:
            lang_per_number_dict[language][number] = lang_per_number_dict[language][number][:min_len]

In [11]:
from collections import defaultdict
per_number_dict = defaultdict(list)

for batch_element in range(dot_prod.shape[0]):
    print(f'{model.to_string(flatten_batches_base_tokens[batch_element])}, dot product: {dot_prod[batch_element]:.2}')
    if flatten_ex_number_list[batch_element] == 'Singular':
        if len(per_number_dict['Singular']) == 149:
            # Workaraound to get lists of the same length
            continue
        else:
            per_number_dict['Singular'].append(dot_prod[batch_element].item())
        
    elif flatten_ex_number_list[batch_element] == 'Plural':
        per_number_dict['Plural'].append(dot_prod[batch_element].item())

<bos>The ministers that disguised the executive, dot product: 2.5
<bos>The authors that injured the secretary, dot product: 2.2
<bos>The athlete that disguised the executive, dot product: -1.2
<bos>The guard that injured the secretary, dot product: -0.98
<bos>The ministers that embarrassed the manager, dot product: 2.8
<bos>The executive that embarrassed the manager, dot product: -0.95
<bos>The executive that injured the doctors, dot product: -0.49
<bos>The farmer that ignored the teachers, dot product: -1.2
<bos>The executives that injured the doctors, dot product: 3.5
<bos>The authors that ignored the teachers, dot product: 3.6
<bos>The managers that admired the author, dot product: 2.5
<bos>The actors that ignored the author, dot product: 3.0
<bos>The executive that admired the author, dot product: -0.93
<bos>The secretary that ignored the author, dot product: -0.99
<bos>The consultant that ignored the actor, dot product: -0.9
<bos>The authors that ignored the actor, dot product: 3.

In [109]:
import plotly.express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots

fig = make_subplots(rows=1, cols=1)

df = pd.DataFrame.from_dict(lang_per_number_dict['English'])
#fig = px.box(df, title=f'',color='Plural')
fig.add_trace(go.Box(y=lang_per_number_dict['English']['Plural'],
                    x=['Plural']*len(lang_per_number_dict['English']['Plural']),
                    marker_color=eng_color,
                    showlegend=True,
                    name='English'))
                    
fig.add_trace(go.Box(y=lang_per_number_dict['English']['Singular'],
                    x=['Singular']*len(lang_per_number_dict['English']['Plural']),
                    marker_color=eng_color,
                    showlegend=False))

fig.add_trace(go.Box(y=lang_per_number_dict['Spanish']['Plural'],
                    x=['Plural']*len(lang_per_number_dict['Spanish']['Plural']),
                    marker_color=spa_color,
                    showlegend=False))
fig.add_trace(go.Box(y=lang_per_number_dict['Spanish']['Singular'],
                    x=['Singular']*len(lang_per_number_dict['Spanish']['Singular']),
                    marker_color=spa_color,
                    name='Spanish'))

fig.update_layout(
    xaxis_title="Subject Number", yaxis_title="",
    legend_title_text='Language',
    font=dict(
        size=15,  # Set the font size here
    ),
    boxmode='group'
)
fig.update_layout(legend_traceorder="reversed")

fig = paper_plot(fig, tickangle=0)
fig.show()
pio.write_image(fig, f'{images_dir}/both_MLP{MLP_layer}_neuron{MLP_neuron}_act_subj_num.pdf',scale=5, width=550, height=350)


In [121]:
layer_index = 16
head_index = 4

# Compute final LN normalized attention head output, attention block
# MLP output, or specific neuron

# normalized_attn_hook_result = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.attn.hook_result'][:,-1,head_index], layer=-1, pos_slice=-1
# )

# normalized_attn_block = cache.apply_ln_to_stack(
#     cache['blocks.0.hook_attn_out'][:,-1,:], layer=-1, pos_slice=-1
# )

# # [50, 7, 2048]
# normalized_mlp_out = cache.apply_ln_to_stack(
#     cache[f'blocks.{layer_index}.hook_mlp_out'][:,-1], layer=-1, pos_slice=-1
# )


neuron_idx = 7540#2069
neuron = model.W_out[layer_index][neuron_idx].unsqueeze(0)

# Project component output into the vocabulary space
# Select model component output!
tokens_component = einsum(
        f"batch d_model, d_model vocab_size \
        -> batch vocab_size",
        neuron * model.ln_final.w,
        model.unembed.W_U)
torch.cuda.empty_cache()

In [122]:
from collections import defaultdict

# Get top/bottom tokens
top_k = 500
largest = True # True: promoted tokens / False: suppressed tokens
ranking_dict = defaultdict(list)
top_k_tokens_component = torch.topk(tokens_component, top_k, dim=-1, largest=largest).indices#.cpu().tolist()
for batch_element in range(top_k_tokens_component.shape[0]):
        top_k_str_tokens_component = model.tokenizer.convert_ids_to_tokens(top_k_tokens_component[batch_element])
        if largest == False:
                print(f'Suppresed tokens: {top_k_str_tokens_component}')
        else:
                print(f'Top promoted tokens: {top_k_str_tokens_component}')
        # rank_corr_pred = (torch.sort(tokens_component, descending=largest).indices[batch_element] == answer_token_indices[batch_element][0].item()).nonzero(as_tuple=True)[0].item()
        # rank_wrong_pred = (torch.sort(tokens_component, descending=largest).indices[batch_element] == answer_token_indices[batch_element][1].item()).nonzero(as_tuple=True)[0].item()
        
        # if largest == False:
        #         # We show indices starting from the last
        #         rank_corr_pred*=-1
        #         rank_wrong_pred*=-1
        # ranking_dict['corr_pred'].append(rank_corr_pred)
        # ranking_dict['wrong_pred'].append(rank_wrong_pred)
        # print(f'Ranking correct prediction: {rank_corr_pred}')
        # print(f'Ranking wrong prediction: {rank_wrong_pred}\n')


Top promoted tokens: ['▁son', 'Son', '▁Son', 'son', '▁SON', 'SON', '▁daughter', 'idyl', '▁sons', '▁sun', 'ebenarnya', 'eseorang', 'jectures', 'daughter', 'idist', 'illaries', 'Jäh', 'daughters', '▁sú', '▁sunt', 'orance', 'Jährige', 'unks', '▁grandson', 'chitis', 'idorm', '▁Sun', 'han', '▁Daughter', '▁***!', 'ierzch', '▁Sunt', 'undial', 'irvana', 'ocarcinoma', 'épendance', '▁han', 'elashes', 'loroethene', '▁SUN', 'umerable', '▁yPos', 'setViewportView', 'haustible', 'ophosph', 'sons', 'incón', 'épendant', 'oole', '▁so', 'iddhar', '▁zodiaco', 'vellous', 'ğaz', '▁للمعارف', 'Sun', '▁су', 'daction', '▁minY', 'oriasis', '▁daughters', 'wdriver', 'conium', '▁Gland', 'sun', '▁kõik', 'ôles', 'enumii', 'soni', 'Daughter', 'TIMORE', '▁Sons', 'lillah', '▁granddaughter', 'ellation', 'igroup', 'lalom', 'omatous', '▁sū', 'itaries', '▁rám', 'namic', 'tiously', 'chst', 'razine', 'HAN', 'raltar', 'ervations', 'outons', 'ceptre', '▁inimes', 'eterminate', '▁lanka', 'tplatz', '▁ceramica', 'ulir', '▁nanti', '

In [None]:
df = pd.DataFrame.from_dict(ranking_dict)
import plotly.express as px
fig = px.box(df, title=f'Ranking correct and wrong verb forms L{layer_index}H{head_index} Spanish')
fig.show()