### Full finished code, for reference

You may want to refer directly to the git repo instead though.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [9]:
torch.manual_seed(1337)

with open('kinyas_kayra_clean.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
import json
import random
from collections import Counter
import re
from typing import List, Dict, Tuple, Optional, Set

class ByteLevelBPE:
    def __init__(self, merges: Optional[List[Tuple[str, str]]] = None, 
                 token_to_id: Optional[Dict[str, int]] = None, 
                 special_tokens: Optional[List[str]] = None):
        """Byte-level BPE tokenizer for Turkish text with UTF-8 support.
        
        Args:
            merges: List of merge rules as tuples
            token_to_id: Vocabulary mapping tokens to ids
            special_tokens: Special tokens to add to vocabulary
        """
        self.merges = merges or []
        self.merges_set = set(self.merges)
        self.token_to_id = token_to_id or {}
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
        self.vocab = None

        # Special tokens with default ones
        self.special_tokens = special_tokens or ['<pad>', '<unk>', '<sos>', '<eos>']
        self._initialize_special_tokens()
        
    def _initialize_special_tokens(self) -> None:
        """Initialize special tokens in vocabulary."""
        for tok in self.special_tokens:
            if tok not in self.token_to_id:
                self.token_to_id[tok] = len(self.token_to_id)
        self.id_to_token.update({v: k for k, v in self.token_to_id.items()})
        self.special_token_ids = {tok: self.token_to_id[tok] for tok in self.special_tokens}
        self.special_tokens_set = set(self.special_tokens)

    def train(self, text: str, num_merges: int = 5000, verbose: bool = True) -> None:
        """Train tokenizer on given text.
        
        Args:
            text: Training text corpus
            num_merges: Number of merge operations to perform
            verbose: Whether to print progress
        """
        self.vocab = self._get_vocab(text)
        self.merges = []
        
        for i in range(num_merges):
            pairs = self._get_stats(self.vocab)
            if not pairs:
                break
                
            best = max(pairs, key=pairs.get)
            self.vocab = self._merge_vocab(best, self.vocab)
            self.merges.append(best)
            
            if verbose and (i % 500 == 0 or i == num_merges - 1):
                print(f"Merge {i + 1}: {best} (frequency: {pairs[best]})")
                
        self.merges_set = set(self.merges)
        self._build_token_vocab()

    def _get_vocab(self, text: str) -> Counter:
        """Initialize vocabulary from text by splitting into UTF-8 bytes."""
        vocab = Counter()
        words = re.findall(r'\S+|\s+', text)
        
        for word in words:
            word_bytes = list(word.encode('utf-8'))
            word_bytes_str = [f"{b:03d}" for b in word_bytes]
            tokenized = ' '.join(word_bytes_str + ['</w>'])
            vocab[tokenized] += 1
            
        return vocab

    def _get_stats(self, vocab: Counter) -> Counter:
        """Get frequency statistics for possible merges."""
        pairs = Counter()
        
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                if symbols[i + 1] == '</w>':
                    continue
                pairs[(symbols[i], symbols[i + 1])] += freq
                
        return pairs

    def _merge_vocab(self, pair: Tuple[str, str], vocab_in: Counter) -> Counter:
        """Merge the given pair in the vocabulary."""
        vocab_out = Counter()
        replacement = pair[0] + pair[1]
        
        for word, freq in vocab_in.items():
            symbols = word.split()
            new_symbols = []
            i = 0
            
            while i < len(symbols):
                if (i < len(symbols) - 1 and 
                    (symbols[i], symbols[i + 1]) == pair and
                    len(symbols[i]) == 3 and  # Ensure proper byte tokens
                    len(symbols[i + 1]) == 3):
                    new_symbols.append(replacement)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
                    
            new_word = ' '.join(new_symbols)
            vocab_out[new_word] = freq
            
        return vocab_out

    def _build_token_vocab(self) -> None:
        """Build the final vocabulary from base bytes and merges."""
        # Reset with special tokens
        self.token_to_id = {tok: i for i, tok in enumerate(self.special_tokens)}
        
        # Add base byte tokens (000-255)
        for i in range(256):
            tok = f"{i:03d}"
            if tok not in self.token_to_id:
                self.token_to_id[tok] = len(self.token_to_id)
        
        # Add merged tokens from vocabulary
        for word in self.vocab.keys():
            for token in word.split():
                if token != '</w>' and token not in self.token_to_id:
                    # Validate token format (multiples of 3 digits)
                    if len(token) % 3 == 0 and all(c.isdigit() for c in token):
                        self.token_to_id[token] = len(self.token_to_id)
        
        # Add all possible merges
        for a, b in self.merges:
            merged = a + b
            if merged not in self.token_to_id:
                self.token_to_id[merged] = len(self.token_to_id)
        
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
        self.special_token_ids = {tok: self.token_to_id[tok] for tok in self.special_tokens}

    def encode(self, text: str, dropout: float = 0.0, 
               add_special_tokens: bool = True) -> List[int]:
        """Encode text into token IDs.
        
        Args:
            text: Input text to encode
            dropout: Merge operation dropout probability
            add_special_tokens: Whether to add special tokens
            
        Returns:
            List of token IDs
        """
        words = re.findall(r'\S+|\s+', text)
        encoded = []
        
        if add_special_tokens:
            encoded.append(self.special_token_ids['<sos>'])
        
        for word in words:
            word_bytes = [f"{b:03d}" for b in word.encode('utf-8')] + ['</w>']
            
            # Apply BPE merges
            while len(word_bytes) > 1:
                pairs = self._get_word_pairs(word_bytes)
                mergeable = [p for p in pairs if p in self.merges_set]
                
                if not mergeable:
                    break
                    
                best = self._select_best_merge(pairs, dropout)
                if not best:
                    break
                    
                word_bytes = self._apply_merge(word_bytes, best)
            
            # Add to encoded tokens
            for token in word_bytes:
                if token == '</w>':
                    continue
                encoded.append(self.token_to_id.get(token, self.special_token_ids['<unk>']))
        
        if add_special_tokens:
            encoded.append(self.special_token_ids['<eos>'])
            
        return encoded

    def _get_word_pairs(self, word_bytes: List[str]) -> List[Tuple[str, str]]:
        """Get all possible adjacent pairs in word."""
        return [(word_bytes[i], word_bytes[i + 1]) for i in range(len(word_bytes) - 1)]

    def _select_best_merge(self, pairs: List[Tuple[str, str]], 
                          dropout: float) -> Optional[Tuple[str, str]]:
        """Select the best merge according to merge list with optional dropout."""
        for merge in self.merges:
            if merge in pairs:
                if dropout > 0 and random.random() < dropout:
                    continue  # Skip this merge due to dropout
                return merge
        return None

    def _apply_merge(self, word_bytes: List[str], 
                    best: Tuple[str, str]) -> List[str]:
        """Apply the merge operation to the word bytes."""
        new_word = []
        i = 0
        while i < len(word_bytes):
            if (i < len(word_bytes) - 1 and 
                (word_bytes[i], word_bytes[i + 1]) == best):
                new_word.append(word_bytes[i] + word_bytes[i + 1])
                i += 2
            else:
                new_word.append(word_bytes[i])
                i += 1
        return new_word

    def decode(self, token_ids: List[int], 
               skip_special_tokens: bool = True) -> str:
        """Decode token IDs back to text.
        
        Args:
            token_ids: List of token IDs to decode
            skip_special_tokens: Whether to skip special tokens
            
        Returns:
            Decoded text string
        """
        byte_seq = []
        
        for token_id in token_ids:
            token = self.id_to_token.get(token_id)
            if token is None or (skip_special_tokens and token in self.special_tokens_set):
                continue
                
            # Split token into 3-digit byte parts
            parts = []
            i = 0
            while i < len(token):
                part = token[i:i+3]
                if len(part) == 3 and part.isdigit():
                    parts.append(part)
                    i += 3
                else:
                    # Handle incomplete parts (shouldn't happen with proper tokens)
                    i += 1
            
            # Convert to bytes
            for part in parts:
                try:
                    byte_val = int(part)
                    if 0 <= byte_val <= 255:
                        byte_seq.append(byte_val)
                except ValueError:
                    continue
        
        # Decode with error handling
        try:
            return bytes(byte_seq).decode('utf-8', errors='strict')
        except UnicodeDecodeError:
            # Fallback to replace invalid sequences
            return bytes(byte_seq).decode('utf-8', errors='replace')

    def save_vocab(self, file_prefix: str) -> None:
        """Save vocabulary and merges to files.
        
        Args:
            file_prefix: Prefix for vocab and merges files
        """
        with open(f"{file_prefix}_merges.json", 'w', encoding='utf-8') as f:
            json.dump(self.merges, f, ensure_ascii=False, indent=2)
            
        with open(f"{file_prefix}_vocab.json", 'w', encoding='utf-8') as f:
            json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)

    def load_vocab(self, file_prefix: str) -> None:
        """Load vocabulary and merges from files.
        
        Args:
            file_prefix: Prefix for vocab and merges files
        """
        with open(f"{file_prefix}_merges.json", 'r', encoding='utf-8') as f:
            self.merges = [tuple(merge) for merge in json.load(f)]
            self.merges_set = set(self.merges)
            
        with open(f"{file_prefix}_vocab.json", 'r', encoding='utf-8') as f:
            self.token_to_id = json.load(f)
            
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
        self._initialize_special_tokens()

    def get_vocab_size(self) -> int:
        """Get the size of the vocabulary."""
        return len(self.token_to_id)

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into subword tokens (for inspection)."""
        token_ids = self.encode(text, add_special_tokens=False)
        return [self.id_to_token[token_id] for token_id in token_ids]

    def inspect_tokenization(self, text: str) -> None:
        """Print detailed tokenization information for debugging."""
        print(f"Original text: {repr(text)}")
        tokens = self.tokenize(text)
        print("Tokens:", tokens)
        print("Token IDs:", self.encode(text, add_special_tokens=False))
        decoded = self.decode(self.encode(text))
        print("Decoded:", repr(decoded))
        print("Byte sequence:", [ord(c) for c in decoded])

tokenizer = ByteLevelBPE()
tokenizer.train(text, num_merges=1000)
tokenizer.save_vocab('turkish_bpe')

# here are all the unique characters that occur in this text
#chars = sorted(list(set(text)))
#vocab_size = len(chars)
# create a mapping from characters to integers
#stoi = { ch:i for i,ch in enumerate(chars) }
#itos = { i:ch for i,ch in enumerate(chars) }
#encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
#decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

Merge 1: ('196', '177') (frequency: 45258)
Merge 501: ('196177', '110') (frequency: 9525)
Merge 1000: ('196177', '110') (frequency: 9525)


In [6]:
tokenizer.inspect_tokenization("şçöğü")

# Metni encode/decode etme
encoded = tokenizer.encode("merhaba dünya")
decoded = tokenizer.decode(encoded)
print(decoded)  # "merhaba dünya"

# Vocab kaydetme/yükleme
#tokenizer.save_vocab("turkish_bpe")
tokenizer.load_vocab("bpe_tokenizer")

Original text: 'şçöğü'
Tokens: ['197159', '195167', '195', '182', '196159', '195188']
Token IDs: [269, 268, 199, 186, 266, 261]
Decoded: 'şçöğü'
Byte sequence: [351, 231, 246, 287, 252]
merhaba dünya


In [None]:
test_cases = [
    "İstanbul'da şehir içi ulaşım çok karmaşık",
    "Pijamalı hasta yağız şoföre çabucak güvendi",
    "Fahiş fiyatlarla mücadele ederken güğümsü renkler içinde"
]

for _text in test_cases:
    print(f"\nTesting: {_text}")
    tokenizer.inspect_tokenization(_text)
    encoded = tokenizer.encode(_text)
    decoded = tokenizer.decode(encoded)
    print("Success!" if decoded == _text else "Failed!")


Testing: İstanbul'da şehir içi ulaşım çok karmaşık
Original text: "İstanbul'da şehir içi ulaşım çok karmaşık"
Tokens: ['196', '176', '115', '116', '097110', '098', '117', '108', '039', '100', '097', '032', '197159', '101', '104', '105', '114', '032', '105', '195167', '105', '032', '117', '108', '097', '197159', '196177', '109', '032', '195167', '111', '107', '032', '107', '097114', '109', '097', '197159', '196177', '107']
Token IDs: [200, 180, 119, 120, 263, 102, 121, 112, 43, 104, 101, 36, 269, 105, 108, 109, 118, 36, 109, 268, 109, 36, 121, 112, 101, 269, 265, 113, 36, 268, 115, 111, 36, 111, 260, 113, 101, 269, 265, 111]
Decoded: "İstanbul'da şehir içi ulaşım çok karmaşık"
Byte sequence: [304, 115, 116, 97, 110, 98, 117, 108, 39, 100, 97, 32, 351, 101, 104, 105, 114, 32, 105, 231, 105, 32, 117, 108, 97, 351, 305, 109, 32, 231, 111, 107, 32, 107, 97, 114, 109, 97, 351, 305, 107]
Success!

Testing: Pijamalı hasta yağız şoföre çabucak güvendi
Original text: 'Pijamalı hasta yağız şoför

In [10]:
from tqdm import tqdm

def encode_text_with_bpe_ids(bpe_obj, text):
    tokens = []
    words = re.findall(r'\S+|\s+', text)
    for word in tqdm(words, desc="Encoding with BPE"):
        tokens.extend(bpe_obj.encode(word, add_special_tokens=False))
    return tokens

tokens = encode_text_with_bpe_ids(tokenizer, text)
data = torch.tensor(tokens, dtype=torch.long)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

print(f"Total token: {len(data)}")
print(f"Train data size: {len(train_data)}")
print(f"Val data size: {len(val_data)}")

# Train and test splits
#data = torch.tensor(encode(text), dtype=torch.long)
#n = int(0.9*len(data)) # first 90% will be train, rest val
#train_data = data[:n]
#val_data = data[n:]

Encoding with BPE: 100%|██████████| 286013/286013 [00:03<00:00, 93714.66it/s]

Total token: 1001484
Train data size: 901335
Val data size: 100149





In [11]:
# data loading
def get_batch(split):
    data_split = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_split) - block_size, (batch_size,))
    x = torch.stack([data_split[i:i+block_size] for i in ix])
    y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

In [12]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [13]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [14]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [15]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [16]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [None]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4

grad_clip = 1.0
best_val_loss = float('inf')
patience_counter = 0
patience = 3

eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.5
# ------------

vocab_size = len(tokenizer.token_to_id)
vocab_size

273

In [20]:
# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab_size)

            logits = logits / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                min_v = v[:, -1].unsqueeze(1)
                logits = torch.where(logits < min_v, torch.full_like(logits, -float('Inf')), logits)

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [21]:
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-2)

11.243793 M parameters


In [None]:
import torch
import math
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/bpe_transformer")

def get_lr(it, warmup_iters=500, max_lr=1e-3, total_iters=5000):
    if it < warmup_iters:
        return max_lr * it / warmup_iters
    elif it > total_iters:
        return 0.0
    else:
        decay_ratio = (it - warmup_iters) / (total_iters - warmup_iters)
        return max_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ratio))

for iter in range(max_iters):
    # Learning rate scheduler
    lr = get_lr(iter)
    for g in optimizer.param_groups:
        g['lr'] = lr

    # Değerlendirme ve log
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        train_loss = losses['train']
        val_loss = losses['val']

        print(f"Step {iter}: Train {train_loss:.4f}, Val {val_loss:.4f}, LR {lr:.6f}")
        writer.add_scalar("Loss/train", train_loss, iter)
        writer.add_scalar("Loss/val", val_loss, iter)
        writer.add_scalar("Learning Rate", lr, iter)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pt')
            print("The new model is better than the old model. The best model has been updated.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break

    # Eğitim adımı
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()


#for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
#    if iter % eval_interval == 0 or iter == max_iters - 1:
#        losses = estimate_loss()
#        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
#    xb, yb = get_batch('train')

    # evaluate the loss
#    logits, loss = model(xb, yb)
#    optimizer.zero_grad(set_to_none=True)
#    loss.backward()
#    optimizer.step()


In [None]:
prompt = "kinyas kayra"
prompt_tokens = tokenizer.encode(prompt)  # split yapma
context = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)

generated_ids = model.generate(context, max_new_tokens=100, temperature=0.7, top_k=50)[0].tolist()
print("Generated text:")
print(tokenizer.decode(generated_ids))


# generate from the model
#context = torch.zeros((1, 1), dtype=torch.long, device=device)
#print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

KeyError: '032'

In [None]:
print('032' in bpe.token_to_id) 

False
