In [1]:
# try to get data set right
import numpy as np

In [2]:
k = 3  # number of digits per operand; also the zero-padding width used by _get_digits

In [3]:
def _get_digits(num):
        # convert numbers into a reversed list of digits, such that index 0 refers to the ones digit
        digits = str(num)
        if len(digits) < k:
            padding = "0" * k 
            digits = padding[:k-len(digits)] + digits
        return list(digits)

In [4]:
digit = 2  # which digit to query (0-indexed from the right)
max_number = 10**k - 1  # largest k-digit operand (e.g. 999 for k = 3)
# np.random.randint excludes its upper bound, so pass max_number + 1 to make
# max_number itself a possible operand.
a1_int = np.random.randint(0, max_number + 1)
a2_int = np.random.randint(0, max_number + 1)
A_int = a1_int + a2_int

# Operands are rendered most-significant digit first, separated by spaces.
a1 = " ".join(_get_digits(a1_int))
a2 = " ".join(_get_digits(a2_int))
# Reverse the sum's digits so that A[i] is the digit with weight 10**i.
A = _get_digits(A_int)[::-1]

input_str = f"A = {a1} + {a2} , A {digit} = ?"
output = int(A[digit])
sent = f"Input: {input_str} Output: {output}"

In [5]:
# Display the example sentence built in the previous cell.
sent

'Input: A = 4 4 5 + 4 0 4 , A 2 = ? Output: 8'

In [6]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

# --- fixed vocabulary: special tokens first, then the task symbols and the digits 0-9 ---
# Tokens are matched exactly as listed (case-sensitive); no normalizer is installed here.
SPECIALS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
BASIC_TOKENS = ["Input", "Output", "A", ":", "=", "+", ",", "?"] + list("0123456789")
VOCAB_LIST = SPECIALS + BASIC_TOKENS
# Token id == position in VOCAB_LIST.
VOCAB = dict(zip(VOCAB_LIST, range(len(VOCAB_LIST))))

# ---- tokenizer model: plain word -> id lookup, unknown words map to [UNK] ----
tk = Tokenizer(WordLevel(vocab=VOCAB, unk_token="[UNK]"))

# Pre-tokenize with Whitespace so that e.g. "Input:" becomes "Input" and ":".
tk.pre_tokenizer = Whitespace()

# ---- post-processing: wrap a single sequence as [BOS] ... [EOS] ----
tk.post_processor = TemplateProcessing(
    single="[BOS] $0 [EOS]",
    special_tokens=[(tok, VOCAB[tok]) for tok in ("[BOS]", "[EOS]")],
)

# ---- expose the tokenizer through the transformers fast-tokenizer interface ----
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tk,
    bos_token="[BOS]",
    eos_token="[EOS]",
    pad_token="[PAD]",
    unk_token="[UNK]",
)

In [7]:
# Round-trip check on a hand-written example. add_special_tokens=False skips the
# [BOS]/[EOS] wrapping, so the sentence encodes to exactly 19 tokens.
x = "Input: A = 0 4 4 + 8 8 8 , A 1 = ? Output: 3"
enc = tokenizer(x, add_special_tokens=False, max_length=19, truncation=True)
print("input_ids:", enc["input_ids"])
print("tokens:", tokenizer.convert_ids_to_tokens(enc["input_ids"]))
# NOTE: the decoded string is not byte-identical to x (spacing around punctuation differs).
print("decoded:", tokenizer.decode(enc["input_ids"], skip_special_tokens=True))

input_ids: [4, 7, 6, 8, 12, 16, 16, 9, 20, 20, 20, 10, 6, 13, 8, 11, 5, 7, 15]
tokens: ['Input', ':', 'A', '=', '0', '4', '4', '+', '8', '8', '8', ',', 'A', '1', '=', '?', 'Output', ':', '3']
decoded: Input : A = 0 4 4 + 8 8 8, A 1 =? Output : 3


In [8]:
# Number of tokens in the encoded example (the sequence length the model is sized for below).
len(enc["input_ids"])

19

In [9]:
# Vocabulary size reported by the wrapped tokenizer; passed to the model config further down.
tokenizer.vocab_size

22

In [10]:
from torch.utils.data import Dataset, DataLoader
import torch

class DigitAdditionDataset(Dataset):
    """Random ``k``-digit addition problems that always query one fixed digit of the sum.

    Every ``__getitem__`` call draws a fresh pair of operands from NumPy's global
    RNG (the index is ignored), renders the sentence
    ``"Input: A = <a1> + <a2> , A <digit> = ? Output: <answer>"`` and tokenizes it.

    Args:
        num_samples: Value reported by ``__len__``.
        k: Number of digits per operand; operands are zero-padded to this width.
        digit: Which digit of the sum to ask for (0 = ones digit).
        tokenizer: Callable applied to the sentence; must return a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors of shape ``[1, S]``.
    """

    def __init__(self, num_samples, k, digit, tokenizer):
        self.num_samples = num_samples
        self.k = k
        self.digit = digit
        self.tokenizer = tokenizer

    def __len__(self):
        return self.num_samples

    def _digits(self, num):
        """Digits of ``num`` as strings, most-significant first, zero-padded to ``self.k``."""
        # Pad with this dataset's own k rather than a notebook-level global, so
        # a dataset built with a different k is padded correctly.
        return list(str(num).rjust(self.k, "0"))

    def __getitem__(self, idx):
        # np.random.randint excludes the upper bound; 10**k makes the largest
        # k-digit number (10**k - 1) reachable.
        upper = 10**self.k
        a1_int = np.random.randint(0, upper)
        a2_int = np.random.randint(0, upper)
        A_int = a1_int + a2_int

        a1 = " ".join(self._digits(a1_int))
        a2 = " ".join(self._digits(a2_int))
        # Reverse so that A[i] is the digit with weight 10**i.
        A = self._digits(A_int)[::-1]

        input_str = f"A = {a1} + {a2} , A {self.digit} = ?"
        output = int(A[self.digit])
        sent = f"Input: {input_str} Output: {output}"

        # Tokenize (tokenizer defaults apply for special tokens here).
        encoding = self.tokenizer(sent, return_tensors='pt', padding=False, truncation=True)

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': output
        }

class DigitAdditionDatasetAllDigits(Dataset):
    """Deterministic ``k``-digit addition problems that query every digit of each sum.

    Index ``idx`` maps to operand pair ``idx // k + pair_base`` and digit position
    ``idx % k``, so ``k`` consecutive indices share the same operands and ask for
    digits 0 .. k-1 of their sum. The pair index seeds a private RNG, which makes
    each item reproducible across calls and epochs.

    Args:
        num_samples: Total number of examples (must be a multiple of ``k``).
        k: Number of digits per operand; operands are zero-padded to this width.
        tokenizer: Callable applied to the sentence; must return a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors of shape ``[1, S]``.
        pair_base: Offset added to the pair index, e.g. to give an evaluation
            set different seeds than the training set.
    """

    def __init__(self, num_samples, k, tokenizer, pair_base: int = 0):
        self.num_samples = num_samples
        assert num_samples % k == 0, "num_samples must be multiple of k"
        self.k = k
        self.tokenizer = tokenizer
        self.pair_base = pair_base  # evaluation offset to avoid overlap between train and eval

    def __len__(self):
        # num_samples already counts individual examples (num_samples // k pairs, k queries each).
        return self.num_samples

    def _digits(self, num):
        """Digits of ``num`` as strings, most-significant first, zero-padded to ``self.k``."""
        # Pad with this dataset's own k rather than a notebook-level global.
        return list(str(num).rjust(self.k, "0"))

    def __getitem__(self, idx):
        # Determine which operand pair and which digit this index refers to.
        pair_idx = idx // self.k + self.pair_base
        digit_pos = idx % self.k

        # Seed a private RNG with the pair index so the same pair is regenerated
        # for each of its k digit queries.
        rng = np.random.RandomState(pair_idx)
        # randint excludes the upper bound; 10**k makes 10**k - 1 reachable.
        upper = 10**self.k
        a1_int = rng.randint(0, upper)
        a2_int = rng.randint(0, upper)
        A_int = a1_int + a2_int

        a1 = " ".join(self._digits(a1_int))
        a2 = " ".join(self._digits(a2_int))
        # Reverse so that A[i] is the digit with weight 10**i.
        A = self._digits(A_int)[::-1]

        input_str = f"A = {a1} + {a2} , A {digit_pos} = ?"
        output = int(A[digit_pos])
        sent = f"Input: {input_str} Output: {output}"

        # Tokenize without [BOS]/[EOS] so the answer digit is the final token.
        encoding = self.tokenizer(sent, return_tensors='pt', padding=False, add_special_tokens=False)

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': output
        }
# Create dataset and dataloader
#train_dataset = DigitAdditionDataset(num_samples=10000, k=k, digit=digit, tokenizer=tokenizer)
#train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
#tokenizer.save_pretrained("./addition_tokenizer")

('./addition_tokenizer/tokenizer_config.json',
 './addition_tokenizer/special_tokens_map.json',
 './addition_tokenizer/tokenizer.json')

In [11]:
# 192,000 examples = 64,000 operand pairs x 3 digit queries, served in fixed order.
train_dataset = DigitAdditionDatasetAllDigits(
    num_samples=32 * 3 * 2000, k=3, tokenizer=tokenizer
)
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=32)

# Peek at the first two batches only; the names from the last inspected batch
# (input_ids, attention_mask, labels) are reused by the cells below.
count = 0
for count, batch in enumerate(train_dataloader, start=1):
    input_ids, attention_mask, labels = (
        batch['input_ids'],
        batch['attention_mask'],
        batch['label'],
    )
    print(input_ids.shape, attention_mask.shape, labels.shape)
    if count >= 2:
        break

torch.Size([32, 19]) torch.Size([32, 19]) torch.Size([32])
torch.Size([32, 19]) torch.Size([32, 19]) torch.Size([32])


In [12]:
# Decode the first three rows of the current batch back to text.
tuple(tokenizer.decode(input_ids[row]) for row in range(3))

('Input : A = 2 6 5 + 1 2 5, A 2 =? Output : 3',
 'Input : A = 9 2 1 + 7 0 3, A 0 =? Output : 4',
 'Input : A = 9 2 1 + 7 0 3, A 1 =? Output : 2')

In [13]:
import torch, torch.nn as nn, torch.nn.functional as F
from typing import Optional, Any
from dataclasses import dataclass

# Lookup table from activation name to the functional implementation;
# "swish" is an alias for "silu".
ACT2FN = dict(
    relu=F.relu,
    gelu=F.gelu,
    silu=F.silu,
    swish=F.silu,
)


@dataclass
class AttentionConfig:
    """Hyperparameters for a single ``Attention`` module.

    ``device`` defaults to ``None`` like every other config in this notebook,
    so a default-constructed config no longer assumes a CUDA device exists.
    """

    D: int = 768  # model width (must be divisible by n_heads)
    layer_idx: Optional[int] = None  # kept for compatibility; Attention takes layer_idx as its own argument
    n_heads: int = 4  # number of attention heads
    causal: bool = True  # mask out attention to later positions
    device: Optional[torch.device] = None  # optional device hint stored on the module


class HookPoint(nn.Module):
    """Identity module used as a named tap in the forward pass.

    It has no parameters and returns its input unchanged; wrapping a tensor in a
    module gives it a stable place where forward hooks can be registered.
    """

    def forward(self, x):
        return x


class Attention(nn.Module):  # BSD -> BSD
    """Multi-head self-attention with a per-head output projection and hook points.

    Input and output are ``[B, S, D]``. Queries, keys and values come from three
    bias-free linear maps; each head's result is projected to model width with
    its own slice of ``W_O`` (shape ``[n_heads, head_dim, D]``) and the per-head
    contributions are summed. ``W_O`` starts at zero, so a freshly built module
    outputs zeros.

    Hook points:
        hook_attn_pattern: attention probabilities, ``[B, nh, S_q, S_k]``.
        hook_attn_output_per_head: per-head outputs, ``[B, S_q, nh, D]``.

    Args:
        layer_idx: Index used to address this layer's slot in a KV cache.
        config: ``AttentionConfig`` providing ``D``, ``n_heads``, ``causal`` and ``device``.
    """

    def __init__(self, layer_idx: int, config: "AttentionConfig"):
        super().__init__()
        self.layer_idx = layer_idx
        self.D = config.D
        self.n_heads = config.n_heads
        assert self.D % self.n_heads == 0
        self.head_dim = self.D // self.n_heads
        self.Wq = nn.Linear(self.D, self.D, bias=False)
        self.Wk = nn.Linear(self.D, self.D, bias=False)
        self.Wv = nn.Linear(self.D, self.D, bias=False)
        self.causal = config.causal
        # Output projection stored per head; zero-initialised for stability.
        self.W_O = nn.Parameter(torch.zeros(self.n_heads, self.head_dim, self.D))
        self.device = config.device
        # Hook points
        self.hook_attn_pattern = HookPoint()
        self.hook_attn_output_per_head = HookPoint()

    def forward(
        self, x: torch.Tensor, kv_cache: Optional[Any] = None
    ) -> torch.Tensor:  # input is [B, S, D]
        """Apply attention to ``x``.

        Args:
            x: Activations of shape ``[B, S, D]``.
            kv_cache: Optional cache object. This method calls
                ``kv_cache.update(layer_idx, K, V)`` and then reads
                ``kv_cache.keys[layer_idx]`` / ``kv_cache.values[layer_idx]`` up to
                ``kv_cache.current_length``. The cache class is not defined in this
                notebook; it is assumed that after ``update`` the slice contains
                the new tokens as its last ``S`` entries (TODO confirm).

        Returns:
            Tensor of shape ``[B, S, D]``.
        """
        B, S, D = x.shape

        # Project and split into heads: [B, S, D] -> [B, nh, S, hd]
        Q = self.Wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.Wk(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.Wv(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        # Update the KV cache and attend over everything cached so far.
        layer_idx = self.layer_idx
        if kv_cache is not None and layer_idx is not None:
            kv_cache.update(layer_idx, K, V)
            K = kv_cache.keys[layer_idx][:, :, : kv_cache.current_length, :]
            V = kv_cache.values[layer_idx][:, :, : kv_cache.current_length, :]

        # [B, nh, S_q, hd] @ [B, nh, hd, S_k] -> [B, nh, S_q, S_k]
        # A plain Python scalar avoids pinning the scale to a configured device
        # that may differ from where the activations actually live.
        scale = self.head_dim ** 0.5
        logits = (Q @ K.transpose(-2, -1)) / scale
        if self.causal:
            q_len, k_len = logits.shape[-2:]
            # With a cache the queries are the LAST q_len positions of the k_len
            # keys, so query row i sits at absolute position i + offset. Without
            # a cache offset is 0 and this is the usual strict upper triangle.
            offset = max(k_len - q_len, 0)
            q_pos = torch.arange(q_len, device=logits.device) + offset
            k_pos = torch.arange(k_len, device=logits.device)
            mask = k_pos.unsqueeze(0) > q_pos.unsqueeze(1)  # [S_q, S_k], broadcast over B and nh
            logits_masked = logits.masked_fill(mask, float("-inf"))
        else:
            logits_masked = logits

        A = F.softmax(logits_masked, dim=-1)  # [B, nh, S_q, S_k]
        # Hook attention pattern
        A = self.hook_attn_pattern(A)

        preout = torch.einsum(
            "bnxy,bnyd->bnxd", A, V
        )  # [B, nh, S_q, hd]

        # Project every head to model width with its own W_O slice.
        attn_output_per_head = torch.einsum(
            "bnxd,ndh->bnxh", preout, self.W_O
        )  # [B, nh, S_q, D]
        # Reorder to [B, S_q, nh, D] and hook
        attn_output_per_head_seq = attn_output_per_head.transpose(1, 2)
        attn_output_per_head_seq = self.hook_attn_output_per_head(attn_output_per_head_seq)
        # Sum across heads -> [B, S_q, D]
        attn_out = attn_output_per_head_seq.sum(dim=2)
        return attn_out  # [B, S, D]


@dataclass
class MLPConfig:
    """Hyperparameters for the ``MLP`` block."""

    D: int  # model width (input and output size of the block)
    hidden_multiplier: int = 4  # hidden width is D * hidden_multiplier
    act: str = "gelu"  # key into ACT2FN selecting the activation
    device: Optional[torch.device] = None  # optional device hint; MLP falls back to cuda/cpu when None


# most important fact about MLP: it operates on each token independently, ie. D --> D
class MLP(nn.Module):
    """Feed-forward block: up-projection, activation, down-projection (all bias-free).

    The linear layers act on the last dimension only, so ``[B, S, D]`` in gives
    ``[B, S, D]`` out. ``down_proj`` is zero-initialised, so a fresh block outputs
    zeros. ``self.device`` is only recorded here; nothing in this class moves
    parameters to it.
    """

    def __init__(self, config: MLPConfig):
        super().__init__()
        self.D = config.D
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device

        hidden_width = self.D * config.hidden_multiplier
        self.up_proj = nn.Linear(self.D, hidden_width, bias=False)
        self.down_proj = nn.Linear(hidden_width, self.D, bias=False)
        # Start the output projection at zero for stability.
        self.down_proj.weight.data.zero_()
        self.act = ACT2FN[config.act]
        # Tap on the post-activation hidden state.
        self.hook_mlp_mid = HookPoint()

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:  # BSD -> BSD automatically on last dim
        hidden = self.up_proj(x)
        activated = self.hook_mlp_mid(self.act(hidden))  # [B, S, D*mult]
        return self.down_proj(activated)


@dataclass
class LNConfig:
    """Hyperparameters for the ``LN`` layer-normalisation module."""

    D: int  # size of the normalised (last) dimension
    eps: float = 1e-9  # added to the variance before taking the square root
    device: Optional[torch.device] = None  # optional device hint; LN falls back to cuda/cpu when None


class LN(nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift.

    ``std_scale`` (initialised to ones) multiplies the normalised activations and
    ``mean_scale`` (initialised to zeros) is added afterwards. The variance is
    taken with ``Tensor.var``'s default settings.
    """

    def __init__(self, config: LNConfig):
        super().__init__()
        self.D = config.D
        self.eps = config.eps
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device
        self.mean_scale = nn.Parameter(torch.zeros(self.D))
        self.std_scale = nn.Parameter(torch.ones(self.D))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x is [B, S, D]
        # Per-position statistics, each [B, S, 1].
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = (x.var(dim=-1, keepdim=True) + self.eps).pow(0.5)
        normalised = centered / denom
        return normalised * self.std_scale + self.mean_scale


@dataclass
class TransformerLayerConfig:
    """Hyperparameters shared by the attention, MLP and LN sub-modules of one layer."""

    D: int = 768  # model width
    n_heads: int = 4  # number of attention heads
    device: Optional[torch.device] = None  # optional device hint; the layer falls back to cuda/cpu when None


class TransformerLayer(nn.Module):
    """One pre-LN transformer block: ``x + attn(ln1(x))`` followed by ``+ mlp(ln2(.))``.

    The residual stream is exposed through three hook points: before attention
    (``hook_resid_pre``), between attention and MLP (``hook_resid_mid``) and after
    the MLP (``hook_resid_post``).
    """

    def __init__(self, layer_idx: int, config: TransformerLayerConfig):
        super().__init__()
        self.D = config.D
        self.layer_idx = layer_idx
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device

        # Sub-modules are created in a fixed order: attention, MLP, then the two norms.
        self.attn = Attention(
            self.layer_idx,
            AttentionConfig(D=self.D, n_heads=config.n_heads, device=self.device),
        )
        self.mlp = MLP(MLPConfig(D=self.D, device=self.device))
        norm_config = LNConfig(D=self.D, device=self.device)
        self.ln1 = LN(norm_config)
        self.ln2 = LN(norm_config)
        # Residual stream hook points
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(
        self, x: torch.Tensor, kv_cache: Optional[Any] = None, return_attn: bool = False
    ) -> torch.Tensor:  # x is BSD
        """Run the block; with ``return_attn=True`` return ``(x, attn_out)`` instead of ``x``."""
        resid_pre = self.hook_resid_pre(x)
        attn_out = self.attn(self.ln1(resid_pre), kv_cache=kv_cache)
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)
        mlp_out = self.mlp(self.ln2(resid_mid))
        resid_post = self.hook_resid_post(resid_mid + mlp_out)
        if not return_attn:
            return resid_post
        return resid_post, attn_out


@dataclass
class PositionalEmbeddingConfig:
    """Hyperparameters for the learned ``PositionalEmbedding`` table."""

    max_seq_len: int  # number of rows (positions) in the table
    D: int  # embedding width
    device: Optional[torch.device] = None  # optional device hint; the module falls back to cuda/cpu when None


class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding added to the incoming activations.

    Holds a ``[max_seq_len, D]`` table drawn from a standard normal; ``forward``
    adds its first ``S`` rows to a ``[B, S, D]`` input.
    """

    def __init__(self, config: PositionalEmbeddingConfig):
        super().__init__()
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device
        table = torch.randn(config.max_seq_len, config.D)
        self.pos_embedding = nn.Parameter(table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x is [B, S, D]
        _, seq_len, _ = x.shape
        positions = self.pos_embedding[:seq_len]  # [S, D]; broadcast over the batch
        return x + positions


@dataclass
class EmbeddingLayerConfig:
    """Hyperparameters for the token ``EmbeddingLayer``."""

    vocab_size: int  # number of rows in the embedding table
    D: int  # embedding width
    device: Optional[torch.device] = None  # optional device hint; the module falls back to cuda/cpu when None


class EmbeddingLayer(nn.Module):
    """Token embedding: a ``[vocab_size, D]`` table indexed by token id.

    The table is drawn from a standard normal at construction.
    """

    def __init__(self, config: EmbeddingLayerConfig):
        super().__init__()
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device
        table = torch.randn(config.vocab_size, config.D)
        self.embedding = nn.Parameter(table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Index the table with the ids in x; the result gains a trailing D dimension.
        table = self.embedding
        return table[x]


@dataclass
class UnembeddingLayerConfig:
    """Hyperparameters for the ``UnembeddingLayer`` output projection."""

    vocab_size: int  # number of output logits per position
    D: int  # width of the incoming activations
    device: Optional[torch.device] = None  # optional device hint; the module falls back to cuda/cpu when None


class UnembeddingLayer(nn.Module):
    """Bias-free linear map from model width ``D`` to vocabulary logits.

    The weight starts at zero, so a freshly built model emits all-zero logits.
    """

    def __init__(self, config: UnembeddingLayerConfig):
        super().__init__()
        if config.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = config.device
        self.V = config.vocab_size
        projection = nn.Linear(config.D, self.V, bias=False)
        projection.weight.data.zero_()  # initialize to zero for stability
        self.unembedding = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x is [B, S, D]
        logits = self.unembedding(x)
        return logits


@dataclass
class TransformerConfig:
    """Top-level hyperparameters for ``Transformer``."""

    hidden_dim: int = 768  # model width, passed to every sub-module as D
    depth: int = 2  # number of TransformerLayer blocks
    n_heads: int = 4  # attention heads per layer
    vocab_size: int = 50257  # size of the embedding and unembedding tables
    max_seq_len: int = 128  # rows in the positional-embedding table
    device: Optional[torch.device] = None  # forwarded unchanged to every sub-module config


class Transformer(nn.Module):
    """Decoder-style transformer: embeddings -> ``depth`` layers -> final LN -> logits.

    Args:
        config: ``TransformerConfig`` with the model width, depth, head count,
            vocabulary size, maximum sequence length and an optional device hint
            that is forwarded to every sub-module config.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.depth = config.depth
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size

        emb_config = EmbeddingLayerConfig(
            vocab_size=config.vocab_size, D=config.hidden_dim, device=config.device
        )
        pos_emb_config = PositionalEmbeddingConfig(
            max_seq_len=config.max_seq_len, D=config.hidden_dim, device=config.device
        )
        unemb_config = UnembeddingLayerConfig(
            vocab_size=config.vocab_size, D=config.hidden_dim, device=config.device
        )

        self.emb = EmbeddingLayer(emb_config)
        self.pos_emb = PositionalEmbedding(pos_emb_config)

        self.ln_final = LN(LNConfig(D=config.hidden_dim, device=config.device))
        self.unemb = UnembeddingLayer(unemb_config)

        layer_config = TransformerLayerConfig(
            D=config.hidden_dim, n_heads=config.n_heads, device=config.device
        )
        # Each TransformerLayer hands its index to its Attention module, so no
        # separate pass is needed to set layer.attn.layer_idx.
        self.layers = nn.ModuleList(
            [TransformerLayer(idx, layer_config) for idx in range(config.depth)]
        )

        self.device = config.device

    def forward(
        self, x: torch.Tensor, kv_cache: Optional[Any] = None, return_attn: bool = False
    ) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            x: Token ids of shape ``[B, S]``.
            kv_cache: Optional cache object; when given, positional embeddings
                start at ``kv_cache.current_length`` and the cache is passed to
                every layer.
            return_attn: When True, also return the stacked per-layer attention
                outputs.

        Returns:
            Logits of shape ``[B, S, vocab_size]``; with ``return_attn=True`` a
            tuple ``(logits, attn)`` where ``attn`` stacks each layer's attention
            output along a new leading dimension.

        Raises:
            ValueError: If decoding with a cache would run past the positional
                embedding table.
        """
        x = self.emb(x)
        if kv_cache is not None:
            # When decoding, only add positional embeddings for the new tokens.
            pos_offset = kv_cache.current_length
            pos_end = pos_offset + x.size(1)
            max_len = self.pos_emb.pos_embedding.size(0)
            if pos_end > max_len:
                # Slicing past the table yields too few (possibly zero) rows; an
                # empty slice would broadcast silently and produce an empty
                # activation instead of failing.
                raise ValueError(
                    f"positions {pos_offset}..{pos_end - 1} exceed max_seq_len={max_len}"
                )
            pos_emb = self.pos_emb.pos_embedding[pos_offset:pos_end].unsqueeze(0)
            x = x + pos_emb
        else:
            x = self.pos_emb(x)

        all_attn = []
        for layer in self.layers:
            if return_attn:
                x, attn = layer(x, kv_cache=kv_cache, return_attn=True)
                all_attn.append(attn)
            else:
                x = layer(x, kv_cache=kv_cache)

        x = self.ln_final(x)
        logits = self.unemb(x)
        if return_attn:
            return logits, torch.stack(all_attn, dim=0)
        return logits

In [14]:
# Small model sized for this task: 6 heads of width 128, two layers, and a
# positional table that covers exactly the 19-token sentences produced above.
cfg = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=19,
    hidden_dim=6 * 128,
    n_heads=6,
    depth=2,
    device="cpu",
)

model = Transformer(config=cfg)

In [15]:
# Smoke test: one gradient-free forward pass over the last inspected batch.
model.eval()
with torch.no_grad():
    logits = model(input_ids)
logits.shape  # should be [B, S, Vocab_size]

torch.Size([32, 19, 22])

In [16]:
# Shape of the batch fed to the model: [B, S].
input_ids.shape

torch.Size([32, 19])

In [17]:
# Answer digits for the batch, as returned under the dataset's 'label' key.
labels

tensor([3, 4, 2, 6, 0, 1, 5, 4, 1, 5, 3, 6, 9, 3, 1, 2, 9, 1, 9, 4, 6, 8, 3, 7,
        5, 7, 2, 2, 9, 2, 8, 6])

In [18]:
# Logits at the second-to-last position: under next-token prediction these are the
# model's scores for the final token, i.e. the answer digit after "Output :".
last_token_logits = logits[:, -2, :]  # [B, vocab_size]

# Targets are the token ids of the final token of each sequence (the answer digit).
loss = F.cross_entropy(last_token_logits, input_ids[:, -1])

In [19]:
# One optimisation step of next-token prediction over every position of the batch.
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Clear gradients left on the parameters by an earlier run of this cell;
# backward() accumulates into .grad rather than overwriting it.
optimizer.zero_grad()
logits = model(input_ids)
# Position t predicts token t+1: drop the last logit and the first target.
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, logits.size(-1)),
    input_ids[:, 1:].reshape(-1),
)
loss.backward()
optimizer.step()