In [1]:
from pathlib import Path
from typing import List

import sae_lens
import torch
import transformer_lens as tl
import yaml
from sae_lens import SAE

# This notebook is inference-only by default, so autograd is switched off globally.
# Code that needs gradients (e.g. the forward+backward caching helper) must turn
# them back on locally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1127fb340>

In [2]:
# GPT-2 small wrapped as a TransformerLens HookedTransformer, which exposes named hook
# points such as `blocks.{layer}.hook_resid_post` used later in this notebook.
model = tl.HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Device the model lives on; token tensors built below are moved here.
device = model.cfg.device

In [4]:
# Small HTML document used as the prompt. Later cells focus on positions whose token
# contains "<" (tag openers).
# NOTE(review): the markup is not well-formed: `<h>` is closed by `</h1>` and `<body>`
# is never closed. This may be deliberate (probing sloppy HTML). If it is a typo, fixing
# it will shift the token positions printed in the cells below.
html_prompt = """<html>
<head>
    <title>My Website</title>
</head>
<body>
    <h>Welcome to My Website</h1>
    <p>This is a paragraph.</p>
    <div class="content">
        <h2>Here is a heading</h2>
        <p>This is another paragraph.</p>
    </div>
</html>
"""
html_prompt

'<html>\n<head>\n    <title>My Website</title>\n</head>\n<body>\n    <h>Welcome to My Website</h1>\n    <p>This is a paragraph.</p>\n    <div class="content">\n        <h2>Here is a heading</h2>\n        <p>This is another paragraph.</p>\n    </div>\n</html>\n'

In [26]:
# Tokenize the prompt and make sure the sequence starts with exactly one BOS token.
html_tokens = model.tokenizer(html_prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
bos_token_id = model.tokenizer.bos_token_id
if bos_token_id is not None:
    # Some tokenizers already insert BOS when add_special_tokens=True; only prepend it
    # when it is missing so we never end up with a doubled BOS.
    already_has_bos = html_tokens.shape[1] > 0 and bool((html_tokens[:, 0] == bos_token_id).all())
    if not already_has_bos:
        # Build the BOS column directly with the same dtype/device/batch size as the ids.
        bos_column = torch.full(
            (html_tokens.shape[0], 1),
            bos_token_id,
            dtype=html_tokens.dtype,
            device=html_tokens.device,
        )
        html_tokens = torch.cat([bos_column, html_tokens], dim=1)
model.to_str_tokens(html_tokens[0])

['<|endoftext|>',
 '<',
 'html',
 '>',
 '\n',
 '<',
 'head',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' <',
 'title',
 '>',
 'My',
 ' Website',
 '</',
 'title',
 '>',
 '\n',
 '</',
 'head',
 '>',
 '\n',
 '<',
 'body',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' <',
 'h',
 '>',
 'Welcome',
 ' to',
 ' My',
 ' Website',
 '</',
 'h',
 '1',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' <',
 'p',
 '>',
 'This',
 ' is',
 ' a',
 ' paragraph',
 '.</',
 'p',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' <',
 'div',
 ' class',
 '="',
 'content',
 '">',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' <',
 'h',
 '2',
 '>',
 'Here',
 ' is',
 ' a',
 ' heading',
 '</',
 'h',
 '2',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' <',
 'p',
 '>',
 'This',
 ' is',
 ' another',
 ' paragraph',
 '.</',
 'p',
 '>',
 '\n',
 ' ',
 ' ',
 ' ',
 ' </',
 'div',
 '>',
 '\n',
 '</',
 'html',
 '>',
 '\n']

In [29]:
str_tokens = model.to_str_tokens(html_tokens[0])
# Positions whose token contains "<" (e.g. "<", " <", "</", ".</"), excluding the BOS
# "<|endoftext|>". The last position is skipped because `metric` scores the prediction
# of the *next* token (labels = html_tokens[:, 1:]), which does not exist for it and
# would otherwise raise an IndexError.
tokens_we_care_about = [
    pos
    for pos, tok in enumerate(str_tokens[:-1])
    if "<" in tok and tok != "<|endoftext|>"
]
tokens_we_care_about

[1, 5, 12, 17, 21, 25, 32, 39, 47, 54, 61, 75, 83, 95, 102, 109, 113]

In [30]:
def get_cache_fwd_and_bwd(model, tokens, metric, layers):
    """Run one forward+backward pass, caching residual-stream activations and their gradients.

    Hooks are attached at ``blocks.{layer}.hook_resid_post`` for every layer in
    ``layers``. Forward hooks record the activations; backward hooks record the
    gradient of ``metric(logits)`` at those hook points.

    Args:
        model: a TransformerLens ``HookedTransformer`` (needs ``add_hook``,
            ``reset_hooks``, ``zero_grad`` and to be callable on token ids).
        tokens: token ids; the model is called on ``tokens.clone()``.
        metric: callable mapping the model's logits to a scalar tensor.
        layers: iterable of layer indices to hook (iterated once, so a generator
            works too).

    Returns:
        ``(metric_value, activation_cache, gradient_cache)``. The value is a Python
        float, and both caches are ``tl.ActivationCache`` objects keyed by hook name.

    Side effects:
        All hooks are removed and parameter ``.grad`` tensors are cleared on exit,
        including when the forward or backward pass raises.
    """
    model.reset_hooks()
    cache = {}
    grad_cache = {}

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    def backward_cache_hook(act, hook):
        # For "bwd" hooks, `act` is the gradient arriving at this hook point.
        grad_cache[hook.name] = act.detach()

    try:
        # Single pass over `layers` so generator inputs also get their backward hooks.
        for layer in layers:
            hook_name = f"blocks.{layer}.hook_resid_post"
            model.add_hook(hook_name, forward_cache_hook, "fwd")
            model.add_hook(hook_name, backward_cache_hook, "bwd")
        # enable_grad() restores the caller's grad mode on exit, even on error
        # (this notebook disables grads globally).
        with torch.enable_grad():
            logits = model(tokens.clone())
            value = metric(logits)
            value.backward()
    finally:
        # Never leave hooks attached. Also drop the parameter grads that backward()
        # populated, so they don't accumulate across calls or pin memory.
        model.reset_hooks()
        model.zero_grad(set_to_none=True)
    return (
        value.item(),
        tl.ActivationCache(cache, model),
        tl.ActivationCache(grad_cache, model),
    )

In [31]:
def metric(logits: torch.Tensor, idxs_we_care_about: List[int] = tokens_we_care_about, labels: torch.Tensor = html_tokens[:, 1:]):
    """Mean log-probability the model assigns to the true next token at selected positions.

    Args:
        logits: model output, indexed as ``logits[:, position]``. Presumably
            (batch, seq, vocab); TODO confirm.
        idxs_we_care_about: positions whose next-token prediction is scored. The
            default is the value of ``tokens_we_care_about`` when this cell ran.
        labels: target ids aligned with ``logits`` positions. The default
            ``html_tokens[:, 1:]`` makes ``labels[:, i]`` the token after position ``i``.

    Returns:
        Scalar tensor holding the mean of the selected log-probs. Higher is better;
        this is the negative of the mean cross-entropy at those positions.
    """
    chosen_log_probs = torch.log_softmax(logits[:, idxs_we_care_about], dim=-1)
    next_token_ids = labels[:, idxs_we_care_about].unsqueeze(-1)
    return chosen_log_probs.gather(-1, next_token_ids).mean()

In [32]:
# Hook every layer's resid_post and run one forward+backward pass through `metric`.
# NOTE: `metric` returns the mean log-prob of the correct next token (higher is better),
# so despite its name `loss` is a log-probability, not a loss to minimise.
loss, fwd_cache, bwd_cache = get_cache_fwd_and_bwd(
    model, html_tokens, metric, list(range(model.cfg.n_layers))
)
loss

-1.0637294054031372

In [33]:
# Positions (indices into html_tokens[0]) whose next-token predictions the metric scores.
tokens_we_care_about

[1, 5, 12, 17, 21, 25, 32, 39, 47, 54, 61, 75, 83, 95, 102, 109, 113]

In [34]:
import neel_utils as nutils

# Plain forward pass (grads are globally disabled) to inspect next-token predictions.
logits = model(html_tokens)

In [35]:
# Top predicted next tokens (with probabilities) at the first selected position,
# i.e. right after the opening "<" of "<html>".
nutils.show_df(nutils.create_vocab_df(logits[0, tokens_we_care_about[0]], make_probs=True).head(15))

Unnamed: 0,token,logit,log_prob,prob
79,p,12.082162,-2.573916,0.076236
7146,div,11.820868,-2.83521,0.058706
9600,img,11.78215,-2.873928,0.056477
28112,!--,11.644079,-3.011999,0.049193
1671,br,11.336342,-3.319736,0.036162
64,a,11.318685,-3.337394,0.035529
12626,span,11.107184,-3.548894,0.028756
20635,·Earlier,11.064755,-3.591323,0.027562
11576,strong,10.947535,-3.708544,0.024513
6494,html,10.894014,-3.762064,0.023236


### Load SAEs

In [None]:
# Locate the pretrained-SAE registry inside the *installed* sae_lens package rather than
# hard-coding one machine's site-packages path (Python version / OS specific).
PRETRAINED_SAES_YAML = Path(sae_lens.__file__).parent / "pretrained_saes.yaml"
with open(PRETRAINED_SAES_YAML, "r") as file:
    pretrained_saes = yaml.safe_load(file)
print(pretrained_saes.keys())


RELEASE = "gpt2-small-res-jb"

# NOTE(review): these SAEs are trained on `hook_resid_pre`, while the cached activations
# and gradients above are at `hook_resid_post`. In TransformerLens,
# blocks.L.hook_resid_post should equal blocks.{L+1}.hook_resid_pre. Confirm that
# downstream code pairs saes[layer] with the matching cached tensor.
saes = {}
for layer in range(model.cfg.n_layers):
    saes[layer] = sae_lens.SAE.from_pretrained(
        release=RELEASE,
        sae_id=f"blocks.{layer}.hook_resid_pre",
        # Use the model's device instead of hard-coding "cuda". This keeps SAEs and
        # activations co-located and lets the notebook run on CPU/MPS machines.
        device=str(device),
    )[0].to(torch.float32)