In [14]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

import plotly.io as pio
import circuitsvis as cv

import torch
import einops


import transformer_lens.utils as utils
from transformer_lens import HookedTransformer

from vis import imshow, scatter, line

# Turn autograd off globally: nothing in this notebook calls backward().
torch.set_grad_enabled(False)
# Make "notebook_connected" the default plotly renderer for every figure below.
# (Kept as the last statement so the cell produces no output.)
pio.renderers.default = "notebook_connected"

In [15]:
# Pick the compute device via TransformerLens, then load GPT-2 onto it.
# `device` and `gpt2` are reused by the cells that follow.
device = utils.get_device()
gpt2 = HookedTransformer.from_pretrained("gpt2", device=device)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2 into HookedTransformer


In [16]:
# Prompt whose attention patterns are inspected in the next cells.
text = "Kami and Luka went to store. Luka bought a candy and hand it to"
tokens = gpt2.to_tokens(text)
# Run the model once and keep every intermediate activation in `cache`.
# remove_batch_dim=True is passed, and the cached pattern inspected below is
# 3-D ([12, 19, 19]), i.e. it carries no leading batch axis.
logit, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)

In [17]:
# Resolve the activation name of the attention pattern at layer index 8 and
# pull that tensor out of the cache; the trailing expression shows its shape.
aname = utils.get_act_name("pattern", 8)
pattern = cache[aname]
pattern.shape

torch.Size([12, 19, 19])

In [18]:
# Render the cached attention pattern with circuitsvis, labelling each
# position with the corresponding string token of the prompt.
str_tokens = gpt2.to_str_tokens(text)
cv.attention.attention_patterns(tokens=str_tokens, attention=pattern)

In [19]:
# Build a batch of random token sequences and repeat each one twice
# back-to-back, so the second half of every row is an exact copy of the first.
size = (10, 50)  # (batch, tokens per sequence before repetition)
# Sample from a locally seeded generator so that Restart & Run All reproduces
# the same batch (and therefore the same plots) without touching the global RNG.
rng = torch.Generator().manual_seed(0)
tokens = torch.randint(1000, 10000, size=size, generator=rng)
tokens = einops.repeat(tokens, 'b t -> b (2 t)')
logit = gpt2(tokens)

In [20]:
# Per-token loss for every sequence in the batch, then averaged over the
# batch axis so that one value per position remains.
loss = gpt2.loss_fn(logit, tokens, per_token=True)
loss_per_tok = einops.reduce(loss, pattern="b t -> t", reduction="mean")
loss_per_tok.shape

torch.Size([99])

In [21]:
# Plot the batch-averaged loss against token position using the project's
# `vis.line` helper.
line(loss_per_tok, xaxis="position", yaxis="loss")

In [22]:
# One slot per (layer, head); filled in by the `save_attn` hook further down.
n_layers, n_heads = gpt2.cfg.n_layers, gpt2.cfg.n_heads
buffer = torch.zeros(n_layers, n_heads, device=device)
buffer.shape

torch.Size([12, 12])

In [23]:
# pattern is (b h t t)
def save_attn(pattern, hook):
    """Forward hook: record each head's mean attention on the induction diagonal.

    The input is expected to be a sequence repeated twice, so the token at
    position ``i`` in the second half also occurred at ``i - half_len``. The
    entry ``[i, i - half_len + 1]`` is therefore the attention paid to the
    token that *followed* the earlier occurrence.

    Args:
        pattern: attention pattern of shape (batch, head, query_pos, key_pos).
        hook: hook point being run; ``hook.layer()`` selects the row of the
            module-level ``buffer`` that is overwritten.

    Returns:
        None. The per-head means are written into ``buffer[hook.layer(), :]``.
    """
    # Derive the length of one repetition from the pattern itself instead of
    # hard-coding it, so the hook stays correct if the sequence length changes.
    half_len = pattern.shape[-1] // 2
    # find attention weight for -seq_len + 1 position
    attn = pattern.diagonal(offset=1 - half_len, dim1=-2, dim2=-1)  # (b h t)
    # Average over batch and diagonal positions -> one score per head: (h,)
    buffer[hook.layer(), :] = attn.mean(dim=(0, 2))


def hook_filter(name):
    """Select every hook point whose name ends with "pattern"."""
    return name.endswith("pattern")


# Run the model with `save_attn` attached to all matching hook points.
# return_type=None: the call is made only for the hooks' side effect on
# `buffer`, so no model output is requested.
gpt2.run_with_hooks(
    tokens,
    return_type=None,
    fwd_hooks=[(hook_filter, save_attn)],
)

In [24]:
# Heat-map of the per-head scores collected in `buffer` (rows: layers,
# columns: heads) using the project's `vis.imshow` helper.
imshow(buffer, xaxis="head", yaxis="layer")

In [25]:
# Tokenise a natural-language sentence without a BOS token and repeat the
# token sequence twice along the position axis.
text = "i love you cami you are the best thing happened to me "
tokens = einops.repeat(
    gpt2.to_tokens(text, prepend_bos=False),
    "b t -> b (2 t)",
)


def vis_attn(pattern, hook):
    """Forward hook: show head 5 of the first batch element with circuitsvis.

    Args:
        pattern: attention pattern indexed as [batch, head, ...].
        hook: hook point being run (unused).

    Reads the module-level ``tokens`` to label the positions.
    """
    position_labels = gpt2.to_str_tokens(tokens)
    # Keep a leading head axis of size 1 so a single head is rendered.
    single_head = pattern[0, 5].unsqueeze(0)
    widget = cv.attention.attention_patterns(
        tokens=position_labels, attention=single_head
    )
    # Explicit display() is required: inside a function the widget is not the
    # cell's trailing expression, so it would not render on its own.
    display(widget)


# Attach `vis_attn` to the attention-pattern hook point of layer index 5 only
# and run the model for the hook's side effect (no output requested).
layer5_pattern_name = utils.get_act_name("pattern", 5)
gpt2.run_with_hooks(
    tokens,
    return_type=None,
    fwd_hooks=[(layer5_pattern_name, vis_attn)],
)