In [1]:
import os
# Point the Hugging Face cache at the workspace volume. This is set before the
# transformer_lens import below — presumably so the cache location is already
# in place when the HF libraries are first imported (TODO confirm).
# NOTE(review): absolute, machine-specific path; consider making it configurable.
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

# Enable plotly's offline notebook rendering for the figures produced below.
# NOTE(review): connected=True presumably loads plotly.js from the CDN instead of
# embedding it in the notebook — confirm against the plotly.offline documentation.
init_notebook_mode(connected=True)

Device: cuda


In [2]:
# Load Gemma-2B as a HookedTransformer on the selected device and switch it to
# inference mode.
model = HookedTransformer.from_pretrained('gemma-2b', device=device)
model.eval()

# Turn on the optional hook points (attention result, attention input, MLP input
# and split Q/K/V inputs), in the same order as before.
for enable_hook in (
    model.set_use_attn_result,
    model.set_use_attn_in,
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
):
    enable_hook(True)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
import os
import sys

# Make the local `atp` package importable: it lives two directories above the
# notebook's working directory.
current_dir = os.getcwd()
atp_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))

# Guard the append so re-running this cell does not pile up duplicate entries
# on sys.path.
if atp_dir not in sys.path:
    sys.path.append(atp_dir)

from atp import Patching

In [4]:
# Benchmark tasks used for the patching experiments. Each entry is a dict with:
#   name             - short human-readable task title
#   description      - one-line explanation of what the task probes
#   clean_prompt     - prompt whose expected next token is `clean_answer`
#   corrupted_prompt - minimally edited prompt whose expected next token is
#                      `corrupted_answer`
#   clean_answer / corrupted_answer - answer strings; the checks cell verifies
#                      that each one maps to a single token and that both prompts
#                      tokenize to the same length.
# NOTE(review): some answers carry a leading space (' Mary', ' him') while others
# do not ('France', 'slow') even though their prompts do not end in a space —
# confirm these are the intended next tokens for the tokenizer in use.
tasks = [
    {
        'name': 'Indirect Object Identification',
        'description': 'Find the correct token to predict when dealing with repeated objects in the sentence',
        'clean_prompt': 'John and Mary went to the store, then John gave a bottle of milk to',
        'corrupted_prompt': 'John and Mary went to the store, then Mary gave a bottle of milk to',
        'clean_answer': ' Mary',
        'corrupted_answer': ' John'
    },
    {
        'name': 'Numbers Addition',
        'description': 'Find the right outcome given by the addition of two numbers',
        'clean_prompt': '10 + 3 = 1',
        'corrupted_prompt': '10 + 6 = 1',
        'clean_answer': '3',
        'corrupted_answer': '6'
    },
    {
        'name': 'Geographic Knowledge',
        'description': 'Find the right country a city is located into.',
        'clean_prompt': 'Paris is in',
        'corrupted_prompt': 'Rome is in',
        'clean_answer': 'France',
        'corrupted_answer': 'Italy'
    },
    {
        'name': 'ICL - Interpolation',
        'description': 'Given the coordinates of two 2D points, find the y of a third point located in between.',
        'clean_prompt': '(-2, 1) (4, -2) (2, -1)\n(-1, 0) (3, 4) (1, 2)\n(1, 3) (5, 1) (3, ',
        'corrupted_prompt': '(-2, 1) (4, -2) (2, -1)\n(-1, 0) (3, 4) (1, 2)\n(1, 3) (5, 5) (3, ',
        'clean_answer': '2',
        'corrupted_answer': '4'
    },
    {
        'name': 'Question Answering',
        'description': 'Answering multiple choice questions.',
        'clean_prompt': 'Which one is red?\n\n(A) the heart\n(B) the sun\n(C) the sea\n\nAnswer: (',
        'corrupted_prompt': 'Which one is blue?\n\n(A) the heart\n(B) the sun\n(C) the sea\n\nAnswer: (',
        'clean_answer': 'C',
        'corrupted_answer': 'A'
    },
    {
        'name': 'Gender Agreement',
        'description': 'Predicting the right pronoun based on the gender.',
        'clean_prompt': 'John saw Marc at the park and waved at',
        'corrupted_prompt': 'John saw Mary at the park and waved at',
        'clean_answer': ' him',
        'corrupted_answer': ' her'
    },
    {
        'name': 'ICL - Contrary Identification',
        'description': 'Identify the contrary of a given word.',
        'clean_prompt': 'happy -> sad\nexciting -> boring\nfast ->',
        'corrupted_prompt': 'happy -> sad\nexciting -> boring\nslow ->',
        'clean_answer': 'slow',
        'corrupted_answer': 'fast'
    },
    {
        'name': 'ICL - Sentiment Analysis',
        'description': 'Identify the sentiment of sentences.',
        'clean_prompt': "This is amazing | Positive\nIt has been the worst thing I've ever seen | Negative\nIt was good overall | ",
        'corrupted_prompt': "This is amazing | Positive\nIt has been the worst thing I've ever seen | Negative\nIt was bad overall | ",
        'clean_answer': 'Positive',
        'corrupted_answer': 'Negative'
    },
]

In [11]:
# Sanity checks for every task:
#   * clean and corrupted prompts must tokenize to the same number of tokens,
#   * both answers must each be a single token,
#   * an output directory `task-<i>` exists for the artefacts written later.
for i, task in enumerate(tasks):
    print(f"TASK {i} - ", end='')

    n_clean = len(model.to_str_tokens(task['clean_prompt']))
    n_corrupted = len(model.to_str_tokens(task['corrupted_prompt']))
    assert n_clean == n_corrupted, (
        f"clean prompt has {n_clean} tokens but corrupted prompt has {n_corrupted}"
    )

    # to_single_token itself fails when the string is not exactly one token, so
    # it is called for that check only. Asserting on the returned id (as before)
    # would wrongly fail for a valid token whose id is 0.
    model.to_single_token(task['clean_answer'])
    model.to_single_token(task['corrupted_answer'])

    # exist_ok makes the cell safe to re-run without a separate existence test.
    os.makedirs(f"task-{i}", exist_ok=True)

    print("Passed")

TASK 0 - Passed
TASK 1 - Passed
TASK 2 - Passed
TASK 3 - Passed
TASK 4 - Passed
TASK 5 - Passed
TASK 6 - Passed
TASK 7 - Passed


In [28]:
# Select which entry of `tasks` the remaining cells analyse; `task_id` is also
# used to build the `task-<id>/` output paths below.
task_id = 0
task = tasks[task_id]

In [30]:
# Greedy-decode two new tokens for each prompt to confirm the model actually
# produces the expected answers before patching.
generation_kwargs = dict(max_new_tokens=2, temperature=0)

clean_out = model.generate(task['clean_prompt'], **generation_kwargs)
corrupted_out = model.generate(task['corrupted_prompt'], **generation_kwargs)

print(f"Clean\n{clean_out}")
print(f"\nCorrupted\n{corrupted_out}")

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Clean
John and Mary went to the store, then John gave a bottle of milk to Mary.

Corrupted
John and Mary went to the store, then Mary gave a bottle of milk to John.


In [31]:
# Run the patching experiment for the selected task.
# `method` and `component` are also reused later to name the exported HTML file.
method = 'atp'
component = 'attn_all'

# Patching comes from the local `atp` package imported above.
patching = Patching(model, method)

# Arguments are passed positionally as (clean prompt, clean answer, corrupted
# prompt, corrupted answer); the call prints the clean/corrupted logit
# differences (see the cell output).
# NOTE(review): results appear to be stored on `patching.patch`, which the
# following cells read — confirm against the Patching implementation.
patching.patching(
    task['clean_prompt'], task['clean_answer'], 
    task['corrupted_prompt'], task['corrupted_answer'], 
    component=component
)

Clean logit difference: 3.917
Corrupted logit difference: -4.342
Patching...


In [32]:
# Zero out (in place) every attribution whose magnitude is below 0.2 so that
# only the stronger effects remain in the figure and in later cells.
negligible = patching.patch.abs() < 0.2
patching.patch[negligible] = 0

# Render the patching figure and save it under the current task's directory.
fig = patching.plot()
fig.write_html(f"task-{task_id}/{method}_{component}.html")

In [33]:
# Collect the (layer, head) pairs whose strongest remaining attribution along the
# last axis of `patching.patch` is positive. Rows of `patching.patch` are indexed
# as n_heads * layer + head.
# The layer/head counts are read from the model config instead of being
# hard-coded (8 heads x 18 layers), and the per-row maximum is computed once
# rather than once per (layer, head) pair.
# NOTE(review): heads whose strongest attribution is negative are excluded by
# the `> 0` test — confirm this is intended.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
row_max = patching.patch.max(-1).values
heads = [
    (l, h)
    for h in range(n_heads)
    for l in range(n_layers)
    if row_max[n_heads * l + h] > 0
]

In [34]:
# Cache all activations of a single forward pass on the clean prompt; gradients
# are not needed for this pass.
clean_tokens = model.to_tokens(task['clean_prompt'])
with torch.no_grad():
    _, cache = model.run_with_cache(clean_tokens)

In [37]:
import plotly.express as px
import numpy as np

# Directory collecting one HTML heatmap per selected head; exist_ok keeps the
# cell re-runnable without a separate existence test.
os.makedirs(f"task-{task_id}/patterns", exist_ok=True)

# Axis tick labels: every clean-prompt token followed by its position index.
labels = [f"{tok} ({i})" for i, tok in enumerate(model.to_str_tokens(task['clean_prompt']))]

for l, h in heads:
    # Pattern of head h in layer l for the single cached prompt (batch index 0),
    # moved to the CPU for plotting.
    data = cache[f'blocks.{l}.attn.hook_pattern'][0, h].cpu()

    fig = px.imshow(
        data,
        labels=dict(x="Keys", y="Queries", color="Attention Score"),
        x=labels,
        y=labels,
        title=f'Attention pattern at head {h} of layer {l}',  # fixed typo: "patter"
        color_continuous_scale="Blues"
    )

    # Adjust the layout for better readability
    fig.update_xaxes(tickangle=35)
    fig.update_layout(coloraxis_colorbar=dict(title="Score"))

    # Save the figure as a standalone HTML file (it is not displayed inline).
    fig.write_html(f"task-{task_id}/patterns/L{l}H{h}.html")

## SAEs

In [2]:
import torch
from transformer_lens import HookedTransformer
from sae_lens import SparseAutoencoder, ActivationsStore

# Load a pretrained sparse autoencoder for Gemma-2B's residual stream and build
# an activation store from its config.
# NOTE(review): as recorded in the output below, this cell currently fails with
# "ValueError: Release jbloom/Gemma-2b-Residual-Stream-SAEs not found in
# pretrained SAEs directory" — presumably a mismatch between the installed
# sae_lens version and this loading API; confirm against the sae_lens docs.
# NOTE(review): gradients are disabled globally from here on, and `model` is
# reloaded without the `device=` argument used in the earlier loading cell.
torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("gemma-2b")
sparse_autoencoder = SparseAutoencoder.from_pretrained(
  "gemma-2b-res-jb", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  "blocks.12.hook_resid_post" # change this to another specific SAE ID in the release if desired. 
)
activation_store = ActivationsStore.from_config(model, sparse_autoencoder.cfg)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


ValueError: Release jbloom/Gemma-2b-Residual-Stream-SAEs not found in pretrained SAEs directory.