In [1]:
%load_ext autoreload
%autoreload 2


# HookedTransformer

* [TransformerLens tutorial: training a HookedTransformer from scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

class HookedTransformer:
    cfg: HookedTransformerConfig
    
    # note: unlike EasyTransformer, HookedTransformer
    #       does not allow passing in a tokenizer other than by name;
    #       a custom tokenizer needs to be set here instead
    #
    # note: it's probably easier to just operate on tokens outside of the model,
    #       which also makes it clearer where the tokenizer is used
    def set_tokenizer(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer
```

In [2]:
import transformer_lens

In [3]:
# device is chosen by transformer_lens.utils.get_device (in the run shown below it resolved to mps)
device = transformer_lens.utils.get_device()
print('Using device: {}'.format(device))

Using device: mps


In [None]:
# a tiny model: 2 layers, one 64-dim attention head per layer, ReLU MLPs, LayerNorm
cfg = transformer_lens.HookedTransformerConfig(
    # architecture
    n_layers=2,
    n_heads=1,
    d_head=64,
    d_model=64,
    d_mlp=256,
    act_fn="relu",
    normalization_type="LN",
    # vocabulary / context: d_vocab=300 covers the byte-level tokenizer
    # defined below, which has at most 256 distinct tokens
    d_vocab=300,
    n_ctx=50,
    device=device,
)
model = transformer_lens.HookedTransformer(cfg)

In [None]:
# make a naive tokenizer that maps each unique UTF-8 byte of the text (not each char) to its own token
import dataclasses

# note: just used for visualization
from gpt_from_scratch import tokenizer_utils

TokenInt = int
Byte = int

@dataclasses.dataclass(frozen=True)
class NaiveTokenizer:
    """
    Naive tokenizer that maps each unique byte of UTF-8 encoded text to a unique token.

    Token ids are assigned in ascending byte order, so the same set of bytes always
    produces the same vocabulary regardless of where the bytes appear in the text.
    Because the vocabulary is a subset of the 256 byte values, it never has more
    than 256 tokens.
    """

    # note: store both directions for fast lookup during both encoding and decoding
    byte_to_token_dict: dict[Byte, TokenInt]
    token_to_byte_dict: dict[TokenInt, Byte]

    @classmethod
    def from_text(cls, text: str) -> 'NaiveTokenizer':
        """
        Build a tokenizer whose vocabulary is the set of bytes in `text` (UTF-8 encoded).

        Bytes that do not occur in `text` get no token, so `encode` raises `KeyError`
        for any later text containing them.
        """
        # sort so ids are canonical (byte order) instead of depending on set iteration order
        unique_bytes = sorted(set(text.encode('utf-8')))

        byte_to_token_dict: dict[Byte, TokenInt] = {
            unique_byte: index for index, unique_byte in enumerate(unique_bytes)
        }
        token_to_byte_dict: dict[TokenInt, Byte] = dict(enumerate(unique_bytes))

        return cls(
            byte_to_token_dict=byte_to_token_dict,
            token_to_byte_dict=token_to_byte_dict,
        )

    def encode(self, text: str) -> list[TokenInt]:
        """
        Encode `text` as a list of token ids, one per UTF-8 byte.

        Raises:
            KeyError: if `text` contains a byte that was not present in the text
                this tokenizer was built from.
        """
        return [self.byte_to_token_dict[byte] for byte in text.encode('utf-8')]

    def decode(self, encoded_bytes: list[TokenInt]) -> str:
        """
        Decode a list of token ids (as produced by `encode`) back into a string.

        Byte sequences that are not valid UTF-8 (e.g. a multi-byte character cut in
        half) are replaced with the special marker (�) instead of raising.

        Raises:
            KeyError: if a token id is not in the vocabulary.
        """
        # the stored byte values are ints (0..255), so build the byte string with
        # `bytes(...)`; `b"".join` would raise because it needs bytes-like items
        vocab_byte_string = bytes(
            self.token_to_byte_dict[token] for token in encoded_bytes
        )

        # UTF-8 multi-byte characters need their lead byte; standard practice is
        # `errors="replace"` so undecodable bytes become the replacement marker
        return vocab_byte_string.decode('utf-8', errors='replace')

    # for compatibility with tiktoken
    def decode_single_token_bytes(self, encoded_byte: TokenInt) -> bytes:
        """
        Return the raw byte for token `encoded_byte` as a length-1 `bytes` object.

        Raises:
            KeyError: if `encoded_byte` is not in the vocabulary.
        """
        # wrap in a list: `bytes(n)` with a bare int would create n zero bytes
        return bytes([self.token_to_byte_dict[encoded_byte]])