## minbpe

In [15]:
from minbpe.minbpe.base import Tokenizer, get_stats, merge


class BasicTokenizer(Tokenizer):
    def __init__(self):
        super().__init__()

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256, "vocab_size is less than 256"
        num_merges = vocab_size - 256

        ids = list(text.encode("utf-8"))

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab

    def encode(self, text):
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda x: stats.get(x, float("inf")))
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
    
    def decode(self, ids):
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text


In [11]:
with open("minbpe/tests/taylorswift.txt", "r") as f:
    text = f.read()

In [16]:
basic_tokenizer = BasicTokenizer()
basic_tokenizer.train(text, 300, verbose=True)

merge 1/44: (101, 32) -> 256 (b'e ') had 2981 occurrences
merge 2/44: (44, 32) -> 257 (b', ') had 2961 occurrences
merge 3/44: (100, 32) -> 258 (b'd ') had 2617 occurrences
merge 4/44: (46, 32) -> 259 (b'. ') had 2560 occurrences
merge 5/44: (114, 32) -> 260 (b'r ') had 2428 occurrences
merge 6/44: (50, 48) -> 261 (b'20') had 2365 occurrences
merge 7/44: (115, 32) -> 262 (b's ') had 2053 occurrences
merge 8/44: (105, 110) -> 263 (b'in') had 2006 occurrences
merge 9/44: (111, 110) -> 264 (b'on') had 1815 occurrences
merge 10/44: (114, 105) -> 265 (b'ri') had 1805 occurrences
merge 11/44: (116, 32) -> 266 (b't ') had 1802 occurrences
merge 12/44: (116, 104) -> 267 (b'th') had 1737 occurrences
merge 13/44: (101, 258) -> 268 (b'ed ') had 1736 occurrences
merge 14/44: (257, 261) -> 269 (b', 20') had 1705 occurrences
merge 15/44: (97, 110) -> 270 (b'an') had 1487 occurrences
merge 16/44: (97, 114) -> 271 (b'ar') had 1360 occurrences
merge 17/44: (101, 260) -> 272 (b'er ') had 1356 occurrence

In [24]:
ids = basic_tokenizer.encode("hello world!!!? (안녕하세요!) lol123 😉")
ids

[104,
 101,
 108,
 108,
 111,
 32,
 119,
 111,
 114,
 108,
 100,
 33,
 33,
 33,
 63,
 32,
 40,
 236,
 149,
 136,
 235,
 133,
 149,
 237,
 149,
 152,
 236,
 132,
 184,
 236,
 154,
 148,
 33,
 41,
 32,
 108,
 111,
 108,
 49,
 50,
 51,
 32,
 240,
 159,
 152,
 137]

In [25]:
basic_tokenizer.decode(ids)

'hello world!!!? (안녕하세요!) lol123 😉'

In [4]:
import regex as re
from minbpe.minbpe.base import Tokenizer, get_stats, merge

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


class RegexTokenizer(Tokenizer):
    def __init__(self):
        super().__init__()
        self.compiled_pattern = re.compile(GPT4_SPLIT_PATTERN)

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256, "vocab_size is less than 256"
        num_merges = vocab_size - 256

        # split text into chunks
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(chunk.encode("utf-8")) for chunk in text_chunks] 

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab

    def encode(self, text):
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_ids = list(chunk.encode("utf-8")) 
            while len(chunk_ids) >= 2:
                stats = get_stats(chunk_ids)
                pair = min(stats, key=lambda x: stats.get(x, float("inf")))
                if pair not in self.merges:
                    break
                idx = self.merges[pair]
                chunk_ids = merge(chunk_ids, pair, idx)
            ids.extend(chunk_ids)
        return ids
    
    def decode(self, ids):
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text


In [5]:
with open("tests/taylorswift.txt", "r") as f:
    text = f.read()

In [12]:
regex_tokenizer = RegexTokenizer()
regex_tokenizer.train(text, 300, verbose=True)

merge 1/44: (101, 114) -> 256 (b'er') had 2359 occurrences
merge 2/44: (50, 48) -> 257 (b'20') had 2187 occurrences
merge 3/44: (111, 114) -> 258 (b'or') had 2076 occurrences
merge 4/44: (105, 110) -> 259 (b'in') had 2006 occurrences
merge 5/44: (101, 100) -> 260 (b'ed') had 1876 occurrences
merge 6/44: (32, 116) -> 261 (b' t') had 1824 occurrences
merge 7/44: (111, 110) -> 262 (b'on') had 1815 occurrences
merge 8/44: (104, 101) -> 263 (b'he') had 1772 occurrences
merge 9/44: (32, 83) -> 264 (b' S') had 1633 occurrences
merge 10/44: (97, 114) -> 265 (b'ar') had 1519 occurrences
merge 11/44: (97, 110) -> 266 (b'an') had 1487 occurrences
merge 12/44: (32, 65) -> 267 (b' A') had 1335 occurrences
merge 13/44: (261, 263) -> 268 (b' the') had 1169 occurrences
merge 14/44: (97, 108) -> 269 (b'al') had 1164 occurrences
merge 15/44: (114, 105) -> 270 (b'ri') had 1156 occurrences
merge 16/44: (118, 260) -> 271 (b'ved') had 1104 occurrences
merge 17/44: (115, 116) -> 272 (b'st') had 1089 occurren

In [22]:
ids = regex_tokenizer.encode("hello world!!!? (안녕하세요!) lol123 😉")
ids

[263,
 108,
 108,
 111,
 32,
 119,
 111,
 114,
 108,
 100,
 33,
 33,
 33,
 63,
 293,
 236,
 149,
 136,
 235,
 133,
 149,
 237,
 149,
 152,
 236,
 132,
 184,
 236,
 154,
 148,
 33,
 41,
 32,
 108,
 111,
 108,
 49,
 50,
 51,
 32,
 240,
 159,
 152,
 137]

In [23]:
regex_tokenizer.decode(ids)

'hello world!!!? (안녕하세요!) lol123 😉'

## GPT2

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass

# vanilla, non-DDP run
ddp_rank = 0
ddp_local_rank = 0
ddp_world_size = 1
master_process = True
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, 4 * config.n_embed)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embed, config.n_embed)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class CasualSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)  # qkv projections
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        y = F.scaled_dot_product_attention(q, k, v, is_casual=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  #  (B, T, C)
        y = self.c_proj(y)
        return y


class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = CasualSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024  # max sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12  # number of layers
    n_head: int = 12  # number of heads
    n_embd: int = 768  # embedding dimension


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embed),
                wpe=nn.Embedding(config.block_size, config.n_embed),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embed),
            )
        )
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)

        tok_emb = self.transformer.wte(idx)

        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.transformer.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained model weights from huggingface."""
        pass

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(
                f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
            )
            print(
                f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
            )
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
        )
        return optimizer




## Llama-3

[
    "tok_embeddings.weight",
    "layers.0.attention.wq.weight",
    "layers.0.attention.wk.weight",
    "layers.0.attention.wv.weight",
    "layers.0.attention.wo.weight",
    "layers.0.feed_forward.w1.weight",
    "layers.0.feed_forward.w3.weight",
    "layers.0.feed_forward.w2.weight",
    "layers.0.attention_norm.weight",
    "layers.0.ffn_norm.weight",
    "layers.1.attention.wq.weight",
    "layers.1.attention.wk.weight",
    "layers.1.attention.wv.weight",
    "layers.1.attention.wo.weight",
    "layers.1.feed_forward.w1.weight",
    "layers.1.feed_forward.w3.weight",
    "layers.1.feed_forward.w2.weight",
    "layers.1.attention_norm.weight",
    "layers.1.ffn_norm.weight",
    "layers.2.attention.wq.weight"
]

{'dim': 4096,
 'n_layers': 32,
 'n_heads': 32,
 'n_kv_heads': 8,
 'vocab_size': 128256,
 'multiple_of': 1024,
 'ffn_dim_multiplier': 1.3,
 'norm_eps': 1e-05,
 'rope_theta': 500000.0}
 

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass


class FeedForward(nn.Module):

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        hidden_dim = int(2 * hidden_dim / 3)  # Applying your specific transformation
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)

        self.w1 = nn.Linear(config.n_embed, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.n_embed, hidden_dim, bias=False)
        self.silu = nn.SiLU()
        self.w2 = nn.Linear(hidden_dim, config.n_embed, bias=False)

    def forward(self, x):
        swish = self.w1(self.silu(x))
        x_v = self.w3(x)
        x = swish * x_v
        x = self.w2(x)
        return x
    
# https://medium.com/@vi.ai_/exploring-and-building-the-llama-3-architecture-a-deep-dive-into-components-coding-and-43d4097cfbbb

class CasualSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_embed = config.n_embed
        self.seq_len = config.block_size
        self.rope_theta = config.rope_theta

    def _precomputed_theta_pos_frequencies(self, device):
        # As written in the paper, the dimension of embedding must be even
        assert self.n_embed % 2 == 0, "n_embed must be even"
        # Built the theta parameters
        # According to the formula theta_i = 10000 ^ (-2(i-1)/dim) for i = [1,2,3,..dim/2]
        theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.n_embed, 2) / self.n_embed)).to(device)  # Shape : (head_dim / 2)
        # Construct the positions (the "m" parameter)
        m = torch.arange(self.seq_len, device=device)  # shape: (seq_len)
        # multiply each theta by each position using the outer product
        freq = torch.outer(m, theta).float()  # (seq_len, head_dim / 2)
        # we can computer complex numbers in the polar form c = R * exp(i * m * theta), where R = 1 as follow
        freq_complex = torch.polar(torch.ones_like(freq), freq)  # (seq_len, head_dim / 2)
        return freq_complex

    def _apply_rotary_embeddings(self, x, freq_complex):
        # TODO: batch?
        q_per_token_split_into_pairs = x.view(x.shape[0], -1, 2)
        q_per_token_as_complex = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex * freq_complex[:len(x)])
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        return q_per_token_rotated

    def forward(self, x):
        freq_complex = self._precomputed_theta_pos_frequencies(x.device)


class RMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, seq_len, dim) -> (B, seq_len, 1)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # dim : (B, seq_len, dim) -> (B, seq_len, dim)
        return self.weight * self._norm(x.float()).type_as(x)


class Block(nn.Module):
    """Encoder block"""

    def __init__(self, config):
        super().__init__()
        self.rmsnorm_1 = RMSNorm(config.n_embed, eps=config.norm_eps)
        self.attn = CasualSelfAttention(config)
        self.rmsnorm_2 = RMSNorm(config.n_embed, eps=config.norm_eps)
        self.ff = FeedForward(config)

    def forward(self, x):
        x = x + self.attn(self.rmsnorm_1(x))
        x = x + self.ff(self.rmsnorm_2(x))
        return x



@dataclass
class Llama3Config:
    block_size: int = 4096  # max sequence length
    vocab_size: int = 128256  # number of tokens: 128,000 BPE merges + 256 bytes tokens
    n_layer: int = 32  # number of layers
    n_head: int = 32  # number of heads
    n_embd: int = 4096  # embedding dimension = dim
    multiple_of: int = 1024
    ffn_dim_multiplier: float = 1.3
    norm_eps: float = 1e-5
    rope_theta: float = 500000.0


class Llama3(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embed),
                wpe=nn.Embedding(config.block_size, config.n_embed),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embed),
            )
        )
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)

        tok_emb = self.transformer.wte(idx)

        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.transformer.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss