In [1]:
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser

# Home directory of the current user; images_dir (defined in a later cell) is built from it.
home = os.path.expanduser("~")

In [2]:
import torch
import numpy as np
from pathlib import Path
import plotly.express as px

from typing import List, Union, Optional
from jaxtyping import Float
from IPython.display import HTML
from neel_plotly import line, imshow, scatter

import plotly.io as pio
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
# Turn off gradient tracking globally for every torch operation that follows.
torch.set_grad_enabled(False)

# Create folder to save plots
images_dir = f"{home}/circuits_languages/images"
# exist_ok=True keeps this cell idempotent on re-runs and avoids the
# check-then-create race of a separate os.path.exists() test; it also matches
# how the plot sub-folders are created further down in the notebook.
os.makedirs(images_dir, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff, compute_act_patching
from load_dataset import load_sva_dataset, get_batched_dataset, create_sva_datasets

### Load Model

In [4]:
# Load the model under study into TransformerLens and (re)generate the SVA datasets.
# Number of CUDA devices visible to this process (not used elsewhere in this cell).
n_devices = torch.cuda.device_count()
model_alias = "gemma-2-2b"
model = HookedTransformer.from_pretrained(
    model_alias,
    # NOTE(review): this model reports a logit soft cap (see the
    # model.cfg.output_logits_soft_cap cell below). Soft-capping is not invariant
    # to adding a constant to the logits, so TransformerLens may ignore this flag
    # for soft-capped models — confirm against the installed version.
    center_unembed=True,# if not model_alias.startswith("gemma-2-") else False,
    center_writing_weights=False,
    fold_ln=False,
    fold_value_biases=False,
)
model.set_use_attn_result(False)
# Get the default device used
device: torch.device = utils.get_device()

# Side effect: per the recorded output, this writes English and Spanish
# train/validation/test JSON files under ./datasets/final_datasets/.
create_sva_datasets(model)




Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  4.40it/s]


Loaded pretrained model gemma-2-2b into HookedTransformer
Dataset saved to ./datasets/final_datasets/english_train_sva_dataset.json
Dataset saved to ./datasets/final_datasets/english_validation_sva_dataset.json
Dataset saved to ./datasets/final_datasets/english_test_sva_dataset.json
Dataset saved to ./datasets/final_datasets/spanish_train_sva_dataset.json
Dataset saved to ./datasets/final_datasets/spanish_validation_sva_dataset.json
Dataset saved to ./datasets/final_datasets/spanish_test_sva_dataset.json


In [8]:
# Display the logit soft-cap value stored in the loaded model's config.
model.cfg.output_logits_soft_cap

30.0

In [9]:
# Sanity check on a Spanish subject-verb agreement prompt: prints the tokenized
# prompt/answer and the model's top next-token predictions.
utils.test_prompt('Los chefs que vieron al productor', ' van', model)

Tokenized prompt: ['<bos>', 'Los', ' chefs', ' que', ' vieron', ' al', ' productor']
Tokenized answer: [' van']


Top 0th token. Logit: 23.75 Prob: 33.07% Token: | de|
Top 1th token. Logit: 22.34 Prob:  8.06% Token: | y|
Top 2th token. Logit: 22.03 Prob:  5.93% Token: | ejecutivo|
Top 3th token. Logit: 21.99 Prob:  5.71% Token: | |
Top 4th token. Logit: 21.86 Prob:  5.01% Token: | gastron|
Top 5th token. Logit: 21.31 Prob:  2.87% Token: | del|
Top 6th token. Logit: 21.19 Prob:  2.57% Token: |,|
Top 7th token. Logit: 20.77 Prob:  1.68% Token: | en|
Top 8th token. Logit: 20.63 Prob:  1.47% Token: | se|
Top 9th token. Logit: 19.89 Prob:  0.70% Token: | estadounidense|


### Patching Experiments

In [20]:
# --- Experiment configuration ---
subject_number = 'both'  # singular / plural / both
language = 'spanish'  # english / spanish
batch_size = 32
num_samples = 128

# Load the training split for the chosen language / subject number and batch it.
dataset = load_sva_dataset(
    model,
    language,
    subject_number,
    split='train',
    num_samples=num_samples,
)
batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)

# Pull the three per-batch collections out of the returned mapping; they are
# indexed in lock-step by batch number in the cells below.
batches_base_tokens, batches_src_tokens, batches_answer_token_indices = (
    batched_dataset[key]
    for key in (
        'batches_base_tokens',
        'batches_src_tokens',
        'batches_answer_token_indices',
    )
)
batches = len(batches_src_tokens)

52
48


In [21]:
# Compute full dataset results
# For every batch, run the model on the base (clean) and source (corrupted)
# prompts, collect the logit differences, and average them over the whole
# dataset. The two averages become the reference points of the patching metric.
src_logit_diff_list = []
base_logit_diff_list = []
for batch in range(batches):
    base_tokens = batches_base_tokens[batch]
    print(base_tokens.shape)
    src_tokens = batches_src_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    # Only the logits are needed here. A plain forward pass avoids building a
    # full activation cache for each run (run_with_cache stores every hooked
    # activation), which was never read and would otherwise stay alive as
    # notebook globals while the patching experiments run.
    base_logits = model(base_tokens, return_type="logits")
    src_logits = model(src_tokens, return_type="logits")
    answer_token_indices = answer_token_indices.to(base_logits.device)
    # mean=False: keep un-averaged values so they can be concatenated across
    # batches below (assumes one value per example — TODO confirm in utils_sva).
    base_logit_diff = get_logit_diff(base_logits, answer_token_indices, mean=False)
    base_logit_diff_list.append(base_logit_diff)
    print(f"Base logit diff batch mean: {base_logit_diff.mean().item():.4f}")

    src_logit_diff = get_logit_diff(src_logits, answer_token_indices, mean=False)
    src_logit_diff_list.append(src_logit_diff)
    print(f"Source logit diff batch mean: {src_logit_diff.mean().item():.4f}")

# Concatenate first, then average, so the mean is taken over all collected
# values rather than over per-batch means.
full_base_logit_diff = torch.cat(base_logit_diff_list, 0).mean(0)
full_src_logit_diff = torch.cat(src_logit_diff_list, 0).mean(0)
# Patching metric
CLEAN_BASELINE = full_base_logit_diff
CORRUPTED_BASELINE = full_src_logit_diff
def ioi_metric(logits, answer_token_indices):
    """Normalised logit-difference metric used for activation patching.

    The logit difference of ``logits`` is rescaled linearly so that a value
    equal to ``CORRUPTED_BASELINE`` maps to 0 and a value equal to
    ``CLEAN_BASELINE`` maps to 1.

    Args:
        logits: Model output logits for a batch of prompts.
        answer_token_indices: Answer token indices for the same batch; moved to
            the device of ``logits`` before use.

    Returns:
        The rescaled logit difference as computed by ``get_logit_diff`` with its
        default arguments (presumably a batch mean — confirm in utils_sva).
    """
    indices_on_device = answer_token_indices.to(logits.device)
    patched_logit_diff = get_logit_diff(logits, indices_on_device)
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    return (patched_logit_diff - CORRUPTED_BASELINE) / baseline_gap


torch.Size([32, 7])
Base logit diff batch mean: 2.1081
Source logit diff batch mean: -2.1049
torch.Size([32, 7])
Base logit diff batch mean: 2.5846
Source logit diff batch mean: -1.9166
torch.Size([32, 7])
Base logit diff batch mean: 1.7913
Source logit diff batch mean: -2.5777


In [22]:
# Bar chart of the dataset-level logit difference on clean vs. corrupted inputs.
fig = px.bar([full_base_logit_diff.item(), full_src_logit_diff.item()])
fig.update_layout(showlegend=False, coloraxis_showscale=False)
fig.update_layout(
    xaxis_title="",
    yaxis_title="Logit Difference",
    # Replace the numeric bar positions 0 / 1 with readable labels.
    xaxis_tickmode='array',
    xaxis_tickvals=[0, 1],
    xaxis_ticktext=['Clean input', 'Corrupted input'],
    font_size=15,  # Set the font size here
)
fig = paper_plot(fig, tickangle=0)
fig.update_layout(yaxis_range=[-3.6, 3.6])
fig.show()

# Save the figure as a PNG under <images_dir>/logit_diffs.
os.makedirs(f'{images_dir}/logit_diffs', exist_ok=True)
pio.write_image(
    fig,
    f'{images_dir}/logit_diffs/{model_alias}_{language}_{subject_number}.png',
    scale=5,
    width=550,
    height=350,
)


In [23]:
# Heatmap row labels: layer indices listed from the last layer down to layer 0.
y_labels = [str(layer) for layer in reversed(range(model.cfg.n_layers))]

# One example sentence per language, tokenized to strings, for the x-axis labels.
eng_patching_plot_sentence = model.to_str_tokens(
    model.to_tokens('The executives that embarrassed the manager')
)
spa_patching_plot_sentence = model.to_str_tokens(
    model.to_tokens('Los empleados que vieron al periodista')
)

# Use the sentence matching the language selected for the experiment.
if language == 'spanish':
    patching_plot_sentence = spa_patching_plot_sentence
else:
    patching_plot_sentence = eng_patching_plot_sentence

In [24]:
# Run activation patching over every batch for the selected patching_type.
# Values handled by the plotting cell below:
#   resid_streams : residual streams
#   heads_all_pos : attn heads all positions at the same time
#   heads_last_pos: attn heads last position
#   full          : (resid streams, attn block outs and mlp outs)
# NOTE: despite its name, total_resid_pre_act_patch_results holds the result for
# whichever patching_type is selected, not only the residual-stream case.
patching_type = 'resid_streams'
total_resid_pre_act_patch_results = compute_act_patching(model,
                                                        ioi_metric,
                                                        patching_type,
                                                        batches_base_tokens,
                                                        batches_src_tokens,
                                                        batches_answer_token_indices,
                                                        len(batches_src_tokens))

torch.Size([32, 7])


100%|██████████| 182/182 [00:21<00:00,  8.65it/s]


torch.Size([32, 7])


100%|██████████| 182/182 [00:21<00:00,  8.55it/s]


torch.Size([32, 7])


100%|██████████| 182/182 [00:21<00:00,  8.39it/s]


In [25]:
# Plot the activation-patching results as a heatmap for the selected patching_type.
# The layer axis of the results is flipped so that rows line up with y_labels,
# which run from the last layer down to layer 0.
# `width` is set in every branch so the (currently disabled) image export at the
# bottom never picks up a missing or stale value.
if patching_type == 'resid_streams':
    # Patching residual streams
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[0]),
                 yaxis="Layer",
                 xaxis="Position",
                 # Token text plus its position index (presumably so repeated
                 # tokens still get distinct tick labels).
                 x=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],
                 y=y_labels,
                 title="resid_pre Activation Patching",
                 return_fig=True)
    # No explicit export width was chosen for this plot type; None lets plotly
    # fall back to its default width.
    width = None
elif patching_type == 'full':
    # Here dim 0 is used for the facets (facet_col=0), so the flip is on dim 1.
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[1]), facet_col=0,
                 y=y_labels,
                 facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
                 title="Activation Patching Per Block", xaxis="Position", yaxis="Layer",
                 zmax=1, zmin=-1, x=patching_plot_sentence,
                 return_fig=True)
    width = 800
    fig.update_xaxes(tickangle=45)
elif patching_type == 'heads_last_pos':
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[0]),
                 yaxis="Layer",
                 xaxis="Head",
                 x=[f'{head}' for head in range(model.cfg.n_heads)],
                 y=y_labels,
                 title="Attn Head Output (Last Pos)",
                 return_fig=True)
    width = 350
elif patching_type == 'heads_all_pos':
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[0]),
                 yaxis="Layer",
                 xaxis="Head",
                 x=[f'{head}' for head in range(model.cfg.n_heads)],
                 y=y_labels,
                 title="Attn Head Output (All Pos)",
                 return_fig=True)
    width = 350
else:
    # Without this guard an unrecognised value would silently re-display the
    # `fig` left over from an earlier cell.
    raise ValueError(
        f"Unknown patching_type {patching_type!r}; expected one of "
        "'resid_streams', 'full', 'heads_last_pos', 'heads_all_pos'"
    )

fig.show()
# Make the figure transparent
fig.update_layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)
os.makedirs(f'{images_dir}/patching', exist_ok=True)
# Image export is currently disabled; uncomment to save the figure.
#pio.write_image(fig, f'{images_dir}/patching/{model_alias}_{language}_{patching_type}.png',scale=5, width=width, height=500)
