In [1]:
import torch
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7074f7766f20>

In [2]:
from sae_lens import SAE
import transformer_lens as tl

model = tl.HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
html_prompt = """<html>
<head>
    <title>My Website</title>
</head>
<body>
    <h>Welcome to My Website</h1>
    <p>This is a paragraph.</p>
    <div class="content">
        <h2>Here is a heading</h2>
        <p>This is another paragraph.</p>
    </div>
</html>
"""
html_prompt

'<html>\n<head>\n    <title>My Website</title>\n</head>\n<body>\n    <h>Welcome to My Website</h1>\n    <p>This is a paragraph.</p>\n    <div class="content">\n        <h2>Here is a heading</h2>\n        <p>This is another paragraph.</p>\n    </div>\n</html>\n'

In [4]:
import yaml
import sae_lens

with open("/usr/local/lib/python3.10/dist-packages/sae_lens/pretrained_saes.yaml", "r") as file:
    pretrained_saes = yaml.safe_load(file)
print(pretrained_saes.keys())


RELEASE = "gpt2-small-res-jb"

saes = {}
for layer in range(model.cfg.n_layers):
    saes[layer] = sae_lens.SAE.from_pretrained(
        release=RELEASE,
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device="cuda",
    )[0].to(torch.float32)

dict_keys(['gpt2-small-res-jb', 'gpt2-small-hook-z-kk', 'gpt2-small-mlp-tm', 'gpt2-small-res-jb-feature-splitting', 'gpt2-small-resid-post-v5-32k', 'gpt2-small-resid-post-v5-128k', 'gemma-2b-res-jb', 'gemma-2b-it-res-jb', 'mistral-7b-res-wg', 'gpt2-small-resid-mid-v5-32k', 'gpt2-small-resid-mid-v5-128k', 'gpt2-small-mlp-out-v5-32k', 'gpt2-small-mlp-out-v5-128k', 'gpt2-small-attn-out-v5-32k', 'gpt2-small-attn-out-v5-128k', 'gemma-scope-2b-pt-res', 'gemma-scope-2b-pt-res-canonical', 'gemma-scope-2b-pt-mlp', 'gemma-scope-2b-pt-mlp-canonical', 'gemma-scope-2b-pt-att', 'gemma-scope-2b-pt-att-canonical', 'gemma-scope-9b-pt-res', 'gemma-scope-9b-pt-res-canonical', 'gemma-scope-9b-pt-att', 'gemma-scope-9b-pt-att-canonical', 'gemma-scope-9b-pt-mlp', 'gemma-scope-9b-pt-mlp-canonical', 'gemma-scope-9b-it-res', 'gemma-scope-9b-it-res-canonical', 'gemma-scope-27b-pt-res', 'gemma-scope-27b-pt-res-canonical', 'pythia-70m-deduped-res-sm', 'pythia-70m-deduped-mlp-sm', 'pythia-70m-deduped-att-sm', 'gpt2

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
html_tokens = model.tokenizer(html_prompt, return_tensors="pt").input_ids.to("cuda")
print(model.to_str_tokens(html_tokens[0]))

['<', 'html', '>', '\n', '<', 'head', '>', '\n', ' ', ' ', ' ', ' <', 'title', '>', 'My', ' Website', '</', 'title', '>', '\n', '</', 'head', '>', '\n', '<', 'body', '>', '\n', ' ', ' ', ' ', ' <', 'h', '>', 'Welcome', ' to', ' My', ' Website', '</', 'h', '1', '>', '\n', ' ', ' ', ' ', ' <', 'p', '>', 'This', ' is', ' a', ' paragraph', '.</', 'p', '>', '\n', ' ', ' ', ' ', ' <', 'div', ' class', '="', 'content', '">', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' <', 'h', '2', '>', 'Here', ' is', ' a', ' heading', '</', 'h', '2', '>', '\n', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' <', 'p', '>', 'This', ' is', ' another', ' paragraph', '.</', 'p', '>', '\n', ' ', ' ', ' ', ' </', 'div', '>', '\n', '</', 'html', '>', '\n']


In [6]:
from transformer_lens import (
    HookedTransformerConfig,
    HookedTransformer,
    FactoredMatrix,
    ActivationCache,
)

def get_cache_fwd_and_bwd(model, tokens, metric, layers):
    model.reset_hooks()
    cache = {}

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    for layer in layers:
        model.add_hook(f"blocks.{layer}.hook_resid_post", forward_cache_hook, "fwd")

    grad_cache = {}

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()

    for layer in layers:
        model.add_hook(f"blocks.{layer}.hook_resid_post", backward_cache_hook, "bwd")
    torch.set_grad_enabled(True)
    logits = model(tokens.clone())
    value = metric(logits)
    value.backward()
    torch.set_grad_enabled(False)
    model.reset_hooks()
    return (
        value.item(),
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
    )

In [7]:
model(html_tokens).shape

torch.Size([1, 116, 50257])

In [8]:
def metric(logits):
    """Next token prediction"""
    # cross entropy loss for next token prediction
    ce_loss = torch.nn.CrossEntropyLoss()
    loss = ce_loss(logits.view(-1, logits.size(-1)), html_tokens.view(-1))
    return loss

In [11]:
loss, fwd_cache, bwd_cache = get_cache_fwd_and_bwd(
    model, html_tokens, metric, list(range(model.cfg.n_layers))
)
loss

8.581077575683594

In [10]:
# nutils.show_df(nutils.create_vocab_df(logits[0, 55]).head(20))

8.581077575683594