In [1]:
# Automatically reload imported modules before each cell executes, so edits to the
# local helper modules imported below (utils_sva, load_dataset) are picked up live.
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser
# The user's home directory; used to build the output folder for saved figures.
home = expanduser("~")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from fancy_einsum import einsum
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional

from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
from utils_sva import residual_stack_to_logit_diff
# No gradients are needed for the forward passes / activation patching in this notebook.
torch.set_grad_enabled(False)

# Create the folder where plots are saved. `exist_ok=True` makes this a no-op when the
# folder already exists and avoids the check-then-create race of `os.path.exists` + `makedirs`.
images_dir = f"{home}/circuits_languages/images"
os.makedirs(images_dir, exist_ok=True)

In [3]:
from utils_sva import clean_blocks_labels, paper_plot
from utils_sva import get_logit_diff, compute_act_patching
from load_dataset import load_sva_dataset, get_batched_dataset

### Load Model

In [4]:
# NOTE(review): number of visible CUDA devices; computed but not passed to from_pretrained below.
n_devices = torch.cuda.device_count()

# Load gemma-2b into TransformerLens. Of the weight-processing options, only unembedding
# centring is enabled; fold_ln, center_writing_weights and fold_value_biases are all off.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=True,
    # n_devices=2,
)
model.set_use_attn_result(False)

# Default device as reported by transformer_lens.utils.
device: torch.device = utils.get_device()

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer


In [5]:
# Sanity check: print the tokenisation, the rank of ' van' and the model's top predictions
# for the next token after this Spanish prompt.
utils.test_prompt(
    prompt='Los chefs que vieron al productor',
    answer=' van',
    model=model,
)

Tokenized prompt: ['<bos>', 'Los', ' chefs', ' que', ' vieron', ' al', ' productor']
Tokenized answer: [' van']


Top 0th token. Logit: 15.52 Prob: 21.16% Token: | de|
Top 1th token. Logit: 14.17 Prob:  5.46% Token: | y|
Top 2th token. Logit: 14.12 Prob:  5.22% Token: | del|
Top 3th token. Logit: 13.43 Prob:  2.61% Token: | |
Top 4th token. Logit: 13.04 Prob:  1.77% Token: |,|
Top 5th token. Logit: 12.91 Prob:  1.56% Token: | en|
Top 6th token. Logit: 12.69 Prob:  1.25% Token: | ejecutivo|
Top 7th token. Logit: 12.55 Prob:  1.08% Token: | David|
Top 8th token. Logit: 12.37 Prob:  0.91% Token: | que|
Top 9th token. Logit: 12.19 Prob:  0.76% Token: | Michael|


### Patching Experiments

In [8]:
# Experiment configuration.
dataset_type = 'both'   # singular / plural / both
language = 'english'    # english / spanish / both
num_samples = 300
batch_size = 30
start_at = 0            # NOTE(review): not referenced by any visible cell

dataset = load_sva_dataset(model, language, dataset_type, num_samples)
batched_dataset = get_batched_dataset(model, dataset, batch_size=batch_size)

# Unpack the per-batch collections produced by get_batched_dataset.
(batches_base_tokens,
 batches_src_tokens,
 batches_answer_token_indices) = (
    batched_dataset[key]
    for key in ('batches_base_tokens', 'batches_src_tokens', 'batches_answer_token_indices')
)
batches = len(batches_src_tokens)  # number of batches iterated over below

151
149


In [9]:
# Compute the clean (base) and corrupted (source) logit differences over the whole dataset.
# Their dataset means become the normalisation baselines of the patching metric.
src_logit_diff_list = []
base_logit_diff_list = []
for batch in range(batches):
    base_tokens = batches_base_tokens[batch]
    src_tokens = batches_src_tokens[batch]
    answer_token_indices = batches_answer_token_indices[batch]

    # Plain forward passes: only the logits are needed here. `run_with_cache` stored the
    # activations of every hook point, those caches were never read, and the previous
    # batch's caches stayed referenced while the next batch ran, wasting GPU memory.
    # If a later cell needs an ActivationCache, call model.run_with_cache there.
    base_logits = model(base_tokens)
    src_logits = model(src_tokens)
    answer_token_indices = answer_token_indices.to(base_logits.device)

    base_logit_diff = get_logit_diff(base_logits, answer_token_indices, mean=False)
    base_logit_diff_list.append(base_logit_diff)
    print(f"Base logit diff batch mean: {base_logit_diff.mean().item():.4f}")

    src_logit_diff = get_logit_diff(src_logits, answer_token_indices, mean=False)
    src_logit_diff_list.append(src_logit_diff)
    print(f"Source logit diff batch mean: {src_logit_diff.mean().item():.4f}")

full_base_logit_diff = torch.cat(base_logit_diff_list, 0).mean(0)
full_src_logit_diff = torch.cat(src_logit_diff_list, 0).mean(0)
# Patching metric baselines.
CLEAN_BASELINE = full_base_logit_diff
CORRUPTED_BASELINE = full_src_logit_diff
def ioi_metric(logits, answer_token_indices):
    """Normalised logit-difference metric used for activation patching.

    Maps the logit difference of `logits` linearly so that CORRUPTED_BASELINE -> 0 and
    CLEAN_BASELINE -> 1. (The name is historical; the task here is subject-verb agreement.)

    Args:
        logits: model output logits for a patched run.
        answer_token_indices: answer token indices forwarded to get_logit_diff; moved to
            the logits' device first.

    Returns:
        (get_logit_diff(logits, ...) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)
    """
    indices_on_device = answer_token_indices.to(logits.device)
    patched_diff = get_logit_diff(logits, indices_on_device)
    shifted = patched_diff - CORRUPTED_BASELINE
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    return shifted / baseline_gap


Base logit diff batch mean: 3.1620
Source logit diff batch mean: -3.1620
Base logit diff batch mean: 3.1855
Source logit diff batch mean: -3.1908
Base logit diff batch mean: 2.9768
Source logit diff batch mean: -2.9716
Base logit diff batch mean: 3.1677
Source logit diff batch mean: -3.0948
Base logit diff batch mean: 3.4246
Source logit diff batch mean: -3.4975
Base logit diff batch mean: 3.1949
Source logit diff batch mean: -3.1949
Base logit diff batch mean: 3.1293
Source logit diff batch mean: -3.1293
Base logit diff batch mean: 3.1453
Source logit diff batch mean: -3.1427
Base logit diff batch mean: 3.2583
Source logit diff batch mean: -3.2608
Base logit diff batch mean: 2.7124
Source logit diff batch mean: -2.8059


In [10]:
# Bar chart of the dataset-mean logit difference on clean vs. corrupted inputs.
bar_values = [full_base_logit_diff.item(), full_src_logit_diff.item()]
fig = px.bar(bar_values)
fig.update(layout_coloraxis_showscale=False, layout_showlegend=False)
fig.update_layout(
    xaxis_title="", yaxis_title="Logit Difference"
)
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        tickvals=[0, 1],
        ticktext=['Clean input', 'Corrupted input'],
    ),
    font=dict(
        size=15,
    ),
)
fig = paper_plot(fig, tickangle=0)
# Symmetric y-range: keep the usual +/-3.4, but widen it when a bar would otherwise be
# clipped (e.g. for a language/dataset with larger logit differences).
y_limit = max(3.4, 1.05 * max(abs(value) for value in bar_values))
fig.update_layout(yaxis_range=[-y_limit, y_limit])
fig.show()
pio.write_image(fig, f'{images_dir}/{language}_logit_diffs.png', scale=5, width=550, height=350)


In [9]:
# Layer tick labels from the last layer down to layer 0 (top to bottom of the heatmaps).
y_labels = [str(layer) for layer in reversed(range(model.cfg.n_layers))]

# Example sentences whose string tokens label the position axis of the patching plots.
eng_patching_plot_sentence = model.to_str_tokens(model.to_tokens('The executives that embarrassed the manager'))
spa_patching_plot_sentence = model.to_str_tokens(model.to_tokens('Los empleados que vieron al periodista'))
# NOTE(review): for language == 'both' the English sentence is used for the x-axis labels.
if language == 'spanish':
    patching_plot_sentence = spa_patching_plot_sentence
else:
    patching_plot_sentence = eng_patching_plot_sentence

In [10]:
# Activation patching over every batch. `patching_type` selects what gets patched:
#   'resid_streams'  - residual streams
#   'heads_all_pos'  - attention heads, all positions at the same time
#   'heads_last_pos' - attention heads, last position only
#   'full'           - residual streams, attention block outputs and MLP outputs
# Despite its name, `total_resid_pre_act_patch_results` holds the result for any type.
patching_type = 'full'
total_resid_pre_act_patch_results = compute_act_patching(
    model, ioi_metric, patching_type,
    batches_base_tokens, batches_src_tokens, batches_answer_token_indices,
    len(batches_src_tokens),  # presumably the number of batches to process
)

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

In [11]:
# Plot the activation-patching results. Each result tensor is flipped along its layer axis
# so the last layer is drawn at the top, matching `y_labels`. Every branch sets both `fig`
# and `width` explicitly, so the export below never picks up stale values from a previous run.
HEAD_PLOT_TITLES = {
    'heads_last_pos': "Attn Head Output (Last Pos)",
    'heads_all_pos': "Attn Head Output (All Pos)",
}

if patching_type == 'resid_streams':
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[0]),
                 yaxis="Layer",
                 xaxis="Position",
                 x=[f"{tok} {i}" for i, tok in enumerate(patching_plot_sentence)],
                 y=y_labels,
                 title="resid_pre Activation Patching",
                 return_fig=True)
    # This branch never set a width before, which made write_image fail on a fresh kernel.
    width = 500
elif patching_type == 'full':
    # dim 0 indexes the three facets (facet_col=0), so the layer axis to flip is dim 1.
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[1]), facet_col=0,
                 y=y_labels,
                 facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
                 title="Activation Patching Per Block", xaxis="Position", yaxis="Layer",
                 zmax=1, zmin=-1, x=patching_plot_sentence,
                 return_fig=True)
    fig.update_xaxes(tickangle=45)
    width = 800
elif patching_type in HEAD_PLOT_TITLES:
    fig = imshow(torch.flip(total_resid_pre_act_patch_results, dims=[0]),
                 yaxis="Layer",
                 xaxis="Head",
                 x=[f'{head}' for head in range(model.cfg.n_heads)],
                 y=y_labels,
                 title=HEAD_PLOT_TITLES[patching_type],
                 return_fig=True)
    width = 350
else:
    raise ValueError(
        f"Unknown patching_type {patching_type!r}; expected 'resid_streams', 'full', "
        "'heads_last_pos' or 'heads_all_pos'"
    )

fig.show()
pio.write_image(fig, f'{images_dir}/patching_{language}_{patching_type}.png', scale=5, width=width, height=500)