In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

# Pick the compute device in priority order: Apple-silicon MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from functools import partial
tqdm.pandas()

init_notebook_mode(connected=True)


Device: cuda


In [2]:
# Load GPT-2 small explicitly onto `device`: the SAEs further down are loaded
# with `device=device`, so the model must live on the same device for the
# SAE hooks to be able to consume its activations.
model = HookedTransformer.from_pretrained('gpt2', device=device)

# Inference mode (disables dropout and similar train-only behaviour).
model.eval()

# Switch on the optional hook points (attention result, attention input,
# MLP input, and separate q/k/v inputs) so they can be hooked later.
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2 into HookedTransformer


In [23]:
from huggingface_hub import snapshot_download

# Hugging Face Hub repository that holds the pretrained SAE checkpoints.
REPO_ID = "jbloom/GPT2-Small-SAEs-Reformatted"

# Download the repository snapshot; `path` is the local directory that the
# per-layer SAE folders are later resolved against.
path = snapshot_download(REPO_ID)

Fetching 41 files:   0%|          | 0/41 [00:00<?, ?it/s]

In [24]:
from sae_lens import LMSparseAutoencoderSessionloader
from tqdm import tqdm
import os

# One SAE per transformer layer, indexed by layer number.
saes = []

for layer in tqdm(range(model.cfg.n_layers)):
    # Each SAE lives in a sub-folder named after the hook point it was trained on.
    hook_name = f"blocks.{layer}.hook_resid_pre"

    # The loader returns three values; only the SAE group is used here.
    _, sae_group, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
        path=os.path.join(path, hook_name), device=device
    )
    sae_group.eval()

    # The group is keyed by hook name; keep just this layer's SAE.
    saes.append(sae_group[hook_name])


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda



The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.

  8%|▊         | 1/12 [00:15<02:49, 15.43s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 17%|█▋        | 2/12 [00:26<02:09, 12.96s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 25%|██▌       | 3/12 [00:39<01:55, 12.85s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 33%|███▎      | 4/12 [00:46<01:25, 10.63s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 42%|████▏     | 5/12 [01:03<01:30, 12.91s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 50%|█████     | 6/12 [01:13<01:11, 11.90s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 58%|█████▊    | 7/12 [01:29<01:05, 13.17s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 67%|██████▋   | 8/12 [01:35<00:43, 10.88s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 75%|███████▌  | 9/12 [01:42<00:28,  9.62s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 83%|████████▎ | 10/12 [01:48<00:17,  8.51s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


 92%|█████████▏| 11/12 [01:54<00:07,  7.75s/it]

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


100%|██████████| 12/12 [04:21<00:00, 21.78s/it]


In [4]:
# CoNLL-2003 ships with a custom loading script. `datasets` warns that
# `trust_remote_code=True` will become mandatory for it, so opt in explicitly.
# NOTE(review): this executes code from the dataset repository — inspect
# https://hf.co/datasets/eriktks/conll2003 before running in a sensitive environment.
data = load_dataset('eriktks/conll2003', trust_remote_code=True)


The repository for eriktks/conll2003 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/eriktks/conll2003
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.



Downloading builder script:   0%|          | 0.00/9.57k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

In [22]:
def _word_ids_for_model_tokens(model_tokens):
    """Return, for every model token, the index of the dataset word it belongs to.

    The prompt is built by joining the dataset words with single spaces, so any
    token after the first one that starts with a space is treated as the start
    of the next word.

    Args:
        model_tokens: list of token strings (without the leading token that
            ``to_str_tokens`` emits first).

    Returns:
        A list of ints, same length as ``model_tokens``, non-decreasing from 0.
    """
    word_ids = []
    word_id = 0
    for pos, tok in enumerate(model_tokens):
        # `startswith` (instead of `tok[0]`) also tolerates an empty token string.
        if pos != 0 and tok.startswith(' '):
            word_id += 1
        word_ids.append(word_id)
    return word_ids


def _rebuild_words(model_tokens, word_ids):
    """Concatenate consecutive model tokens that share a word index.

    Args:
        model_tokens: list of token strings.
        word_ids: word index of each token, as produced by
            ``_word_ids_for_model_tokens``.

    Returns:
        The list of reconstructed words, each stripped of surrounding whitespace.
    """
    words = []
    current_id = 0
    current_word = ''
    for tok, word_id in zip(model_tokens, word_ids):
        if word_id == current_id:
            current_word += tok
        else:
            # Word boundary: flush the finished word and start a new one.
            words.append(current_word.strip())
            current_word = tok
            current_id = word_id
    # Flush the last word.
    words.append(current_word.strip())
    return words


# Sanity check over the whole training split: the token -> word mapping must
# allow the original dataset words to be rebuilt exactly from the model tokens.
# tqdm wraps the dataset itself (not the enumerate object) so it can read the
# length and render a real progress bar with a total.
for i, x in enumerate(tqdm(data['train'])):
    prompt = ' '.join(x['tokens'])

    tokens = x['tokens']
    # Drop the first string token emitted by the model tokenizer
    # (presumably the prepended BOS token — TODO confirm).
    model_tokens = model.to_str_tokens(prompt)[1:]

    # Map tokens to model tokens
    mapping = _word_ids_for_model_tokens(model_tokens)

    # Token reconstruction test
    reconstructed_tokens = _rebuild_words(model_tokens, mapping)

    assert tokens == reconstructed_tokens, i

14041it [00:13, 1068.40it/s]


In [25]:
def reconstr_hook(x, hook, sae, cache=None, grad_cache=None):
    """Forward hook that swaps an activation for its SAE output and records SAE features.

    Args:
        x: the activation tensor passed in at the hook point.
        hook: the hook point object; only ``hook.name`` is read and used as the
            cache key.
        sae: callable applied to ``x``. It must return a sequence whose first
            item is the replacement activation and whose second item is the
            tensor to cache (named ``f_act`` here, presumably the SAE feature
            activations); any further items are ignored.
        cache: optional dict that receives the detached second output under
            ``hook.name``. Defaults to the notebook-level ``sae_cache``.
        grad_cache: optional dict that receives a copy of the gradient flowing
            into the second output during a backward pass. Defaults to the
            notebook-level ``sae_grad_cache``.

    Returns:
        The first output of ``sae(x)``, which replaces ``x`` in the forward pass.
    """
    # Fall back to the notebook-level dicts so existing callers keep working.
    # These globals are resolved at call time, so they only need to exist
    # before the hook actually runs.
    if cache is None:
        cache = sae_cache
    if grad_cache is None:
        grad_cache = sae_grad_cache

    # Read the name once so the backward-time closure does not depend on `hook`.
    name = hook.name

    sae_out, f_act, *_ = sae(x)

    def capture_grad(grad):
        # Clone so the stored gradient is independent of autograd's buffer.
        grad_cache[name] = grad.clone()

    # A tensor hook can only be registered on tensors that take part in autograd.
    if f_act.requires_grad:
        f_act.register_hook(capture_grad)

    # Store a graph-free copy so the cache does not keep the autograd graph alive.
    cache[name] = f_act.detach()
    return sae_out

In [33]:
# Fresh caches, filled by `reconstr_hook` and keyed by hook name:
#   sae_cache      -> detached SAE activations from the forward pass
#   sae_grad_cache -> gradients w.r.t. those activations (filled on backward)
sae_cache = {}
sae_grad_cache = {}

# Prototype on the first training sentence only. The caches are keyed by hook
# name alone, so looping over more examples would just overwrite them; select
# the example explicitly instead of looping and breaking after one iteration.
x = data['train'][0]
prompt = ' '.join(x['tokens'])

tokens = x['tokens']
# Drop the first string token emitted by the model tokenizer
# (presumably the prepended BOS token — TODO confirm).
model_tokens = model.to_str_tokens(prompt)[1:]

# Map tokens to model tokens: a token (other than the first) that starts with
# a space opens the next dataset word.
mapping = []
word_id = 0
for pos, tok in enumerate(model_tokens):
    # `startswith` (instead of `tok[0]`) also tolerates an empty token string.
    if pos != 0 and tok.startswith(' '):
        word_id += 1
    mapping.append(word_id)

# Forward pass with every layer's `resid_pre` activation routed through that
# layer's SAE; `[0]` takes the first (only) item of the batch.
sae_hooks = [
    (
        utils.get_act_name('resid_pre', layer),
        partial(reconstr_hook, sae=saes[layer]),
    )
    for layer in range(model.cfg.n_layers)
]
out = model.run_with_hooks(
    model.to_tokens(prompt),
    fwd_hooks=sae_hooks,
)[0]

0it [00:00, ?it/s]


In [35]:
# Shape of the cached layer-0 SAE activations; the key is built the same way
# as the hook names registered for the forward pass.
layer0_hook_name = utils.get_act_name('resid_pre', 0)
sae_cache[layer0_hook_name].shape

torch.Size([1, 10, 24576])

In [14]:
# CoNLL-2003 NER label -> id mapping (the ids stored in `ner_tags`): {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}

AttributeError: 'DatasetDict' object has no attribute 'info'

In [37]:
# NER tag ids of the most recently processed training example (`x`).
x['ner_tags']

[3, 0, 7, 0, 0, 0, 7, 0, 0]

In [38]:
# Word index assigned to each model token of the most recently processed example.
mapping

[0, 1, 2, 3, 4, 5, 6, 7, 8]