In [1]:
# Set up notebook.
try:
    import google.colab 
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install transformer_lens
    %pip install einops
    %pip install jaxtyping
    import os, sys
else:
    pass

In [2]:
import os; os.environ['ACCELERATE_DISABLE_RICH'] = "1"
import sys
import json
import re
import math
from collections import defaultdict, Counter
from functools import partial
from typing import Tuple, List, Optional, Dict
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from jaxtyping import Float, Int
from rich.table import Table
from rich import print as rprint
from pathlib import Path

import torch as t
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
import einops

device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(device)
MAIN = __name__ == '__main__'

cpu


In [3]:
# Define GELU approximation function.
SQRT = np.sqrt(2/np.pi)
def geluprox(x):
    return t.tanh(SQRT*(x + 0.044715*x**3))

# Define softmax.
def softmax(z):
    return t.exp(z)/t.sum(t.exp(z))

In [4]:
# Load the data.
gemara = open('./gemara_english.txt', 'rb').read().decode(encoding='utf-8')
vocab  = sorted(set(gemara))
print ('{} unique Hebrew characters'.format(len(vocab)))

99 unique Hebrew characters


## Build transformer.

### Define architecture in config.

In [5]:
dm = 128

@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 16000
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=16000, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [6]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

### Define Layers.

In [7]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        return self.W_E[tokens]

rand_int_test(Embed, [2, 4])

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 



In [8]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual):
        mean = residual.mean(dim=-1, keepdim=True)
        vary = (residual.var(dim=-1, keepdim=True, unbiased=False) +
                            self.cfg.layer_norm_eps).sqrt()
        residual = (residual - mean)/vary
        return residual * self.w + self.b

rand_float_test(LayerNorm, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 



In [9]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        #batch, seq_len = tokens.shape
        shapey = tokens.shape
        batch = shapey[0]
        seq_len = shapey[1]
        #return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)
        return self.W_pos[:seq_len].unsqueeze(0).expand(batch, -1, -1)

rand_int_test(PosEmbed, [2, 4])

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 



In [10]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

    def forward(
        self, x: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]: #normalized_resid_pre
        # x: residual stream, (batch, seq_len, d_model)
        # Compute key, query, and value tensors.
        # n heads, b batches, sequence length s, embedding e (d_model), head size h.
        K = t.einsum('nmh,bsm->bsnh', self.W_K, x) + self.b_K # (b, s, n, h)
        Q = t.einsum('nmh,bsm->bsnh', self.W_Q, x) + self.b_Q # (b, s, n, h)
        V = t.einsum('nmh,bsm->bsnh', self.W_V, x) + self.b_V # (b, s, n, h)
        # Compute attention scores.
        QKt = t.einsum('bsnh,btnh->bnst', Q, K) # (b, n, s, s)
        # Compute attention probabilities, with causal mask.
        A = self.apply_causal_mask(QKt/self.cfg.d_head**0.5).softmax(-1) # (b, n, s, s)
        # Sum over value vectors, weighted by A.
        z = t.einsum("bnst,btnh->bsnh", A, V) # (b, s, n, h)
        # Sum over parallel attention heads.
        attention = t.einsum("nhm,bsnh->bsm", self.W_O, z)
        return attention

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        onesy = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(onesy, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 



In [11]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, x: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        neural = t.einsum('mn,bsm->bsn', self.W_in, x) + self.b_in
        neural = geluprox(neural)
        return t.einsum('nm,bsn->bsm', self.W_out, neural) + self.b_out

rand_float_test(MLP, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 



In [12]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, x: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        y = self.attn(self.ln1(x)) + x
        z = self.mlp(self.ln2(y)) + y
        return z

rand_float_test(TransformerBlock, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 



In [13]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, x: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return t.einsum('mv,bsm->bsv', self.W_U, x) + self.b_U

rand_float_test(Unembed, [2, 4, 768])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 16000]) 



In [14]:
class RoboRav(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg       = cfg
        self.embed     = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks    = nn.ModuleList([TransformerBlock(cfg) for
                                        _ in range(cfg.n_layers)])
        self.ln_final  = LayerNorm(cfg)
        self.unembed   = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits

## Train on Gemara. Start by tokenizing text.

In [16]:
class TalmudTokenizer:
    def __init__(self, vocab_size: int = 16000):
        self.vocab_size = vocab_size
        self.vocab: Dict[str, int] = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
        self.inverse_vocab: Dict[int, str] = {v: k for k, v in self.vocab.items()}
        self.merges: Dict[Tuple[str, str], str] = {}
        self.space_prefix = 'Ġ'

    def _get_stats(self, vocab):
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i + 1]] += freq
        return pairs

    def _merge_vocab(self, pair, v_in):
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in v_in:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = v_in[word]
        return v_out

    def train(self, text: str):
        print("Starting tokenizer training...")
        
        # Preprocess text to add space prefix, including for the first word
        words = [self.space_prefix + word for word in text.split()]
        
        # Also add non-prefixed versions of words to the vocabulary
        non_prefixed_words = text.split()
        
        # Initialize vocab with character tokens
        chars = set(''.join(words + non_prefixed_words))
        for char in chars:
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
                self.inverse_vocab[len(self.vocab) - 1] = char

        print(f"Initial vocabulary size: {len(self.vocab)}")
        
        # Convert words to space-separated character sequences
        vocab = Counter(' '.join(word) for word in words)
        vocab.update(' '.join(word) for word in non_prefixed_words)
        
        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                print(f"No more pairs to merge after {i} iterations")
                break
            
            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            self.merges[best] = ''.join(best)
            new_token = ''.join(best)
            
            if new_token not in self.vocab:
                self.vocab[new_token] = len(self.vocab)
                self.inverse_vocab[len(self.vocab) - 1] = new_token
            
            if len(self.vocab) >= self.vocab_size:
                print(f"Reached target vocabulary size after {i+1} iterations")
                break
            
            if i % 100 == 0:
                print(f"Completed {i} merges. Current vocab size: {len(self.vocab)}")

        print(f"Final vocabulary size: {len(self.vocab)}")
        print(f"Number of merges: {len(self.merges)}")

    def _tokenize_word(self, word: str) -> List[str]:
        if word in self.vocab:
            return [word]
        
        word = ' '.join(word)
        tokens = []
        while len(word) > 0:
            subword = word
            while len(subword) > 0:
                if subword in self.vocab:
                    tokens.append(subword)
                    word = word[len(subword):].lstrip()
                    break
                subword = subword[:-1]
            if len(subword) == 0:
                tokens.append(word[0])
                word = word[1:].lstrip()
        return tokens

    def tokenize(self, text: str) -> List[int]:
        words = text.split()
        tokens = []
        for i, word in enumerate(words):
            if i == 0 or word.startswith(self.space_prefix):
                tokens.extend(self._tokenize_word(word))
            else:
                tokens.extend(self._tokenize_word(self.space_prefix + word))
        return [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens]

    def decode(self, token_ids: List[int]) -> str:
        tokens = [self.inverse_vocab.get(id, "<UNK>") for id in token_ids]
        text = ''.join(tokens).replace(self.space_prefix, ' ')
        return text.strip()

    def save(self, path: str):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)
        with open(os.path.join(path, 'merges.json'), 'w', encoding='utf-8') as f:
            json.dump({' '.join(k): v for k, v in self.merges.items()}, f, ensure_ascii=False, indent=2)
        with open(os.path.join(path, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump({'vocab_size': self.vocab_size, 'space_prefix': self.space_prefix}, f, indent=2)

    @classmethod
    def load(cls, path: str):
        with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:
            config = json.load(f)
        tokenizer = cls(vocab_size=config['vocab_size'])
        tokenizer.space_prefix = config['space_prefix']
        
        with open(os.path.join(path, 'vocab.json'), 'r', encoding='utf-8') as f:
            tokenizer.vocab = json.load(f)
        tokenizer.inverse_vocab = {int(v): k for k, v in tokenizer.vocab.items()}
        
        with open(os.path.join(path, 'merges.json'), 'r', encoding='utf-8') as f:
            merges = json.load(f)
            tokenizer.merges = {tuple(k.split()): v for k, v in merges.items()}
        
        return tokenizer

In [26]:
class SequenceDataset(Dataset):
    def __init__(self, tokens, sequence_length):
        self.tokens = tokens
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.tokens) - self.sequence_length

    def __getitem__(self, idx):
        chunk = self.tokens[idx:idx + self.sequence_length + 1]
        return t.tensor(chunk[:-1], dtype=t.long), t.tensor(chunk[1:], dtype=t.long)

def prepare_data_for_training(tokens, sequence_length, batch_size, val_split=0.1):
    dataset = SequenceDataset(tokens, sequence_length)
    
    # Split into train and validation sets
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader

In [27]:
import time

talmud_text = gemara
# Train and save the tokenizer
"""
tokenizer = TalmudTokenizer(vocab_size=16000)
tokenizer.train(talmud_text)
tokenizer.save("talmud_tokenizer")
"""

print("Tokenizer saved.")

# Load the saved tokenizer
loaded_tokenizer = TalmudTokenizer.load("talmud_tokenizer")
print("Tokenizer loaded.")

# Test the loaded tokenizer
test_sentence = "Rav Pappa said to Rabbi Akiva, from where do we learn about sandwiches?"
encoded = loaded_tokenizer.tokenize(test_sentence)
decoded = loaded_tokenizer.decode(encoded)

print(f"\nTest sentence: {test_sentence}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

# Verify that the loaded tokenizer produces the same results as the original
original_encoded = loaded_tokenizer.tokenize(test_sentence)
print(f"\nOriginal encoded: {original_encoded}")
print(f"Loaded tokenizer encoded: {encoded}")
print(f"Encodings match: {original_encoded == encoded}")

# Print some statistics
print(f"\nVocabulary size: {len(loaded_tokenizer.vocab)}")
print(f"Number of merges: {len(loaded_tokenizer.merges)}")

Tokenizer saved.
Tokenizer loaded.

Test sentence: Rav Pappa said to Rabbi Akiva, from where do we learn about sandwiches?
Encoded: [212, 2444, 286, 137, 260, 5519, 256, 470, 512, 571, 2264, 1163, 89, 75, 68, 98, 24, 15, 7, 83, 53, 86, 75, 74]
Decoded: Rav Pappa said to Rabbi Akiva, from where do we learn about sandwiches?

Original encoded: [212, 2444, 286, 137, 260, 5519, 256, 470, 512, 571, 2264, 1163, 89, 75, 68, 98, 24, 15, 7, 83, 53, 86, 75, 74]
Loaded tokenizer encoded: [212, 2444, 286, 137, 260, 5519, 256, 470, 512, 571, 2264, 1163, 89, 75, 68, 98, 24, 15, 7, 83, 53, 86, 75, 74]
Encodings match: True

Vocabulary size: 16000
Number of merges: 15899


In [28]:
talmudtokens = loaded_tokenizer.tokenize(talmud_text)

In [29]:
print(len(talmudtokens))

12687031


In [30]:
train_loader, test_loader = prepare_data_for_training(talmudtokens, 256, 512)

### Initialize model.

In [31]:
dm = 128
cli = Config(
    debug=False,
    d_model=dm,
    n_heads=4,
    d_head=8,
    d_mlp=4*dm,
    n_layers=2,
    n_ctx=256,
    d_vocab=16000
)
roborav = RoboRav(cli)
loss_fn = nn.CrossEntropyLoss()  # Replace with your loss function
optimizer = optim.Adam(roborav.parameters(), lr=0.001)

### Define training model.

In [32]:
def train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, device):
    train_losses = []
    val_losses = []
    
    for epoch in range(num_epochs):
        model.train()
        epoch_train_losses = []
        
        # Training loop
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            # Reshape outputs and targets
            batch_size, seq_length, vocab_size = outputs.shape
            outputs = outputs.view(-1, vocab_size)  # Reshape to [batch_size * seq_length, vocab_size]
            targets = targets.view(-1)  # Reshape to [batch_size * seq_length]
            
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            
            epoch_train_losses.append(loss.item())
        
        avg_train_loss = np.mean(epoch_train_losses)
        train_losses.append(avg_train_loss)
        
        # Validation loop
        model.eval()
        epoch_val_losses = []
        with t.no_grad():
            for batch in val_loader:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                
                outputs = model(inputs)
                
                # Reshape outputs and targets
                batch_size, seq_length, vocab_size = outputs.shape
                outputs = outputs.view(-1, vocab_size)  # Reshape to [batch_size * seq_length, vocab_size]
                targets = targets.view(-1)  # Reshape to [batch_size * seq_length]
                
                loss = loss_fn(outputs, targets)
                epoch_val_losses.append(loss.item())
        
        avg_val_loss = np.mean(epoch_val_losses)
        val_losses.append(avg_val_loss)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        
        # Save the model after each epoch
        t.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
    
    # Plot and save the loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
    plt.plot(range(1, num_epochs+1), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.savefig('loss_curve.png')
    plt.close()
    
    # Save the losses to a file
    np.savez('losses.npz', train_losses=train_losses, val_losses=val_losses)
    return train_losses, val_losses

In [1]:
#train(roborav, train_loader, test_loader, loss_fn, optimizer, 1, device)

### Sample from the trained transformer.