In [1]:
import torch as t
import torch.nn as nn
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
from transformer_lens import HookedTransformer, utils, HookedTransformerConfig
from InductionHeadMethods import InductionTask as IT
from InductionHeadMethods import InductionAttribution as IA
from torch import Tensor
from transformer_lens.utils import to_numpy
from plotly.express import imshow

# Turn off gradient tracking globally; nothing in this notebook calls backward().
t.set_grad_enabled(False)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

In [2]:
# Name of the pretrained checkpoint to load into a HookedTransformer.
MODEL_NAME = "gpt2-small"
model: HookedTransformer = HookedTransformer.from_pretrained(MODEL_NAME)



Loaded pretrained model gpt2-small into HookedTransformer


In [15]:
# Length parameter handed to the repeated-token generator; also reused later
# as the slice offset into the cached activations.
seq_len = 16

# Keep only the first element of whatever generate_repeated_tokens returns.
# NOTE(review): presumably element 0 is the first sequence of a batch — confirm
# against InductionHeadMethods.InductionTask.
repeated_batch = IT.generate_repeated_tokens(model, seq_len)
tokens = repeated_batch[0]

# Forward pass that also records intermediate activations in `cache`.
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

In [18]:
results = IA.get_results(cache, model)

# NOTE(review): this rebinds `tokens` to a slice of itself, so re-running this
# cell without first re-running the token-generation cell slices it again.
tokens = tokens[seq_len:]

# Drop the first `seq_len` positions from the embedding output and from each
# entry of `results`, so all inputs are sliced from the same offset as `tokens`.
embed_tail = cache["embed"][seq_len:]
result_tails = [res[seq_len:] for res in results]
logit_attr = IA.logit_attribution(embed_tail, result_tails, model.W_U, tokens)

In [20]:
def convert_tokens_to_string(model, tokens, batch_index=0):
    """Build one display label per token, of the form ``|<decoded>|_<position>``.

    Args:
        model: Object exposing ``tokenizer.decode``; it is called once per token.
        tokens: Token container with a ``.shape`` attribute. When it has exactly
            two dimensions, only row ``batch_index`` is used.
        batch_index: Row to select from a two-dimensional ``tokens``.

    Returns:
        A list of label strings, in token order, each suffixed with the token's
        position within the selected row.
    """
    row = tokens[batch_index] if len(tokens.shape) == 2 else tokens

    labels = []
    for position, token_id in enumerate(row):
        decoded = model.tokenizer.decode(token_id)
        labels.append(f"|{decoded}|_{position}")
    return labels


# Remove any size-1 dimensions so the tokens can be sliced along one axis.
tokens = tokens.squeeze()

# Row labels: one per token except the last one.
y_labels = convert_tokens_to_string(model, tokens[:-1])

# Column labels: a leading "Direct" term followed by one entry per
# (layer, head) pair in layer-major order.
head_labels = [
    f"L{layer}H{head}"
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
x_labels = ["Direct", *head_labels]

# An empty title is passed to plotly as None.
title = ""

# Figure size scales with the label counts: 18 px per row, 24 px per column.
imshow(
    to_numpy(logit_attr),
    x=x_labels,
    y=y_labels,
    labels=dict(x="Term", y="Position", color="logit"),
    title=title or None,
    height=18 * len(y_labels),
    width=24 * len(x_labels),
)
