In [1]:
import os
# Point the Hugging Face home/cache directory at /workspace/huggingface.
# This is set before the model libraries below are imported — presumably so they
# pick the location up at import time (TODO confirm).
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

# Choose the compute device by priority: Apple MPS, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

# Initialise plotly's offline notebook mode for this session.
# NOTE(review): connected=True presumably loads plotly.js from a CDN instead of
# embedding it in the notebook — confirm against the plotly.offline documentation.
init_notebook_mode(connected=True)

Device: cuda


In [2]:
!huggingface-cli login --token 

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /workspace/huggingface/token
Login successful


In [3]:
# Load gemma-2b into a HookedTransformer on the selected device and put it in eval mode.
model = HookedTransformer.from_pretrained('gemma-2b', device=device)
model.eval()

# Switch on the optional hook inputs/outputs named by these setters, in this order:
# attention result, attention input, MLP input, and split Q/K/V input.
for enable_hook in (
    model.set_use_attn_result,
    model.set_use_attn_in,
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
):
    enable_hook(True)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [29]:
import os
import sys

# Make the directory two levels above the notebook's working directory importable,
# so that `from atp import Patching` can resolve.
current_dir = os.getcwd()
atp_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
# Guard the append so that re-running this cell does not pile duplicate entries
# onto sys.path.
if atp_dir not in sys.path:
    sys.path.append(atp_dir)

from atp import Patching

# Wrap the loaded model in the project's Patching helper.
# The second argument 'atp' presumably selects the patching method
# (attribution patching?) — TODO confirm against the atp module.
patching = Patching(model, 'atp')

In [30]:
# Catalogue of patching tasks. Each entry pairs a clean prompt with a minimally
# different corrupted prompt, along with the answer token expected for each.
tasks = [
    # Task 0: the subject who gives the item is swapped, so the expected recipient flips.
    dict(
        name='Indirect Object Identification',
        description='Find the correct token to predict when dealing with repeated objects in the sentence',
        clean_prompt='John and Mary went to the store, then John gave a bottle of milk to',
        corrupted_prompt='John and Mary went to the store, then Mary gave a bottle of milk to',
        clean_answer=' Mary',
        corrupted_answer=' John',
    ),
    # Task 1: the second addend is changed, so the final digit of the sum changes.
    dict(
        name='Numbers Addition',
        description='Find the right outcome given by the addition of two numbers',
        clean_prompt='10 + 3 = 1',
        corrupted_prompt='10 + 6 = 1',
        clean_answer='3',
        corrupted_answer='6',
    ),
]
    

In [31]:
# Select which entry of `tasks` to analyse; index 0 is the
# 'Indirect Object Identification' task. task_id is also used later to name the
# exported HTML file.
task_id = 0
task = tasks[task_id]

In [32]:
# Sanity check: generate a single new token for the clean prompt.
# temperature=0 is presumably greedy (deterministic) decoding — TODO confirm
# against HookedTransformer.generate.
# In the recorded output the completion ends with ' Mary', which matches
# task['clean_answer'] for task 0.
model.generate(
    task['clean_prompt'], 
    max_new_tokens=1,
    temperature=0
)

  0%|          | 0/1 [00:00<?, ?it/s]

'John and Mary went to the store, then John gave a bottle of milk to Mary'

In [33]:
# Run patching on the clean/corrupted prompt pair (with their expected answers)
# for the 'attn_all' component.
# NOTE(review): in the recorded run this cell printed the clean and corrupted logit
# differences and then failed with a CUDA OutOfMemoryError during "Patching..." on a
# ~24 GiB GPU, so it did not complete as saved. Memory use needs to be reduced
# before the results can be reproduced.
patching.patching(
    task['clean_prompt'], task['clean_answer'], 
    task['corrupted_prompt'], task['corrupted_answer'], 
    component='attn_all'
)

Clean logit difference: 3.917
Corrupted logit difference: -4.342
Patching...


OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.68 GiB total capacity; 23.41 GiB already allocated; 11.25 MiB free; 23.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [28]:
# Build the figure for the patching results and export it as a standalone HTML
# file named after the selected task (e.g. task_0.html).
# The returned object is presumably a plotly figure, given write_html — TODO confirm.
# NOTE(review): this cell's execution count (28) is older than the preceding cells
# (29-33), so any saved figure comes from an earlier run, not from the patching call
# recorded in this notebook.
fig = patching.plot()
fig.write_html(f"task_{task_id}.html")