In [None]:
import os

from pathlib import Path

import kagglehub

import sentencepiece as spm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np

In [None]:
# Hyperparameters
vocab_size = 1000    # Vocabulary size (SentencePiece vocab and model output dimension)
embed_dim = 512      # Embedding dimension
num_heads = 8        # Number of attention heads
ff_hidden_dim = 2048 # Feedforward hidden dimension
num_layers = 6       # Number of transformer decoder layers
max_seq_len = 128    # Maximum sequence length
num_epochs = 10      # Number of training epochs
learning_rate = 3e-4 # Learning rate (increased for from-scratch training)
seed = 42            # Seed for PyTorch's global random number generator

# Seed PyTorch before any model or DataLoader is created so that weight
# initialisation, batch shuffling and sampling are repeatable after
# Restart Kernel -> Run All.
torch.manual_seed(seed)

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Download the corpus through kagglehub and locate the text file inside it.
input_directory = kagglehub.dataset_download("mruanova/shakespeare")
input_filepath = os.path.join(input_directory, "shakespeare.txt")

# Directory that receives the tokenizer files and the trained weights.
output_filepath = os.path.join("results", "shakespeare-transformer")

Path(output_filepath).mkdir(parents=True, exist_ok=True)

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    Each position may attend to itself and to earlier positions only.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super(CausalSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "Embedding dimension must be divisible by number of heads"

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim), applied to the scores

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        return projected.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).
        """
        batch_size, seq_len, embed_dim = x.size()

        # Compute Q, K, V, each of shape (batch, heads, seq, head_dim)
        Q = self._split_heads(self.query(x), batch_size, seq_len)
        K = self._split_heads(self.key(x), batch_size, seq_len)
        V = self._split_heads(self.value(x), batch_size, seq_len)

        # Attention scores: (batch, heads, seq, seq)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Boolean mask of "future" positions (strictly above the diagonal).
        # It is built directly on x's device, so there is no host-to-device
        # copy per step, and because it is boolean the scores keep their own
        # dtype (a float32 additive mask would up-cast half-precision scores).
        future_positions = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        attn_scores = attn_scores.masked_fill(future_positions, float('-inf'))

        # Attention probabilities; masked positions get probability zero
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values, then merge the heads back into embed_dim
        attn_output = torch.matmul(attn_probs, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # Final output projection
        return self.out_proj(attn_output)


class TransformerDecoderBlock(nn.Module):
    """One pre-LayerNorm decoder block: causal self-attention followed by an MLP.

    Both sub-layers are wrapped as ``x + dropout(sublayer(layernorm(x)))``.

    Args:
        embed_dim: Feature dimension of the block's input and output.
        num_heads: Number of heads handed to the self-attention layer.
        ff_hidden_dim: Width of the hidden layer in the feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter
        # initialisation consumes the random stream the same way every time.
        self.self_attn = CausalSelfAttention(embed_dim, num_heads)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        expand = nn.Linear(embed_dim, ff_hidden_dim)
        contract = nn.Linear(ff_hidden_dim, embed_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        # Attention sub-layer: normalise first, then add back onto the residual stream.
        after_attn = x + self.dropout(self.self_attn(self.layernorm1(x)))

        # Feed-forward sub-layer, same pre-LN residual pattern.
        return after_attn + self.dropout(self.ffn(self.layernorm2(after_attn)))


class CausalTransformerDecoder(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through a stack of ``TransformerDecoderBlock`` layers and a final
    LayerNorm, then projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Embedding / hidden dimension.
        num_heads: Attention heads per block.
        ff_hidden_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of decoder blocks.
        max_seq_len: Longest sequence the position embedding can index.
        dropout: Dropout probability for the embeddings and inside blocks.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_hidden_dim, num_layers, max_seq_len, dropout=0.1):
        super(CausalTransformerDecoder, self).__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            TransformerDecoderBlock(embed_dim, num_heads, ff_hidden_dim, dropout)
        for _ in range(num_layers)])
        self.layernorm = nn.LayerNorm(embed_dim)  # Final layer norm
        self.output_proj = nn.Linear(embed_dim, vocab_size)

        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: Integer tensor of shape (batch_size, seq_len) with
                ``seq_len <= max_seq_len``.

        Returns:
            Tensor of shape (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = input_ids.size()

        # The position table only has max_seq_len rows. Fail with a clear
        # message rather than an opaque out-of-range embedding lookup.
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # Token + position embeddings (positions broadcast over the batch)
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embed_tokens(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        # Pass through layers
        for layer in self.layers:
            x = layer(x)

        # Final layer norm
        x = self.layernorm(x)

        # Final linear projection to vocab size
        return self.output_proj(x)

In [None]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a SentencePiece-tokenised text file.

    A SentencePiece model is trained on ``input_filepath`` and written under
    ``output_filepath`` with the prefix ``sentencepiece``. The vocabulary size
    is read from the notebook-level ``vocab_size`` variable. The corpus is
    encoded once into a flat tensor of token ids; item ``i`` is the window
    starting at token ``i``.

    Args:
        input_filepath: Path to the raw text corpus.
        output_filepath: Directory that receives the SentencePiece model files.
        seq_len: Number of tokens in each input (and label) sequence.
    """

    def __init__(self, input_filepath, output_filepath, seq_len):
        self.seq_len = seq_len

        # Read as UTF-8 explicitly instead of depending on the platform's
        # default encoding.
        with open(input_filepath, "r", encoding="utf-8") as file:
            corpus = file.read()

        model_prefix = os.path.join(output_filepath, "sentencepiece")
        spm.SentencePieceTrainer.train(
            input=input_filepath,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
        )
        self.sp = spm.SentencePieceProcessor(model_file=model_prefix + ".model")

        # Keep ONE flat tensor of ids. Windows are sliced on demand in
        # __getitem__ instead of materialising every overlapping window as
        # its own list, which would cost about seq_len times the memory.
        self.tokens = torch.tensor(self.sp.EncodeAsIds(corpus), dtype=torch.long)

    def __len__(self):
        """Number of windows of ``seq_len + 1`` tokens that fit in the corpus."""
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)``; labels are the inputs shifted by one token.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_windows = len(self)
        if idx < 0:
            idx += num_windows
        # Explicit bounds check: slicing a tensor past its end would silently
        # return a shorter window instead of failing.
        if not 0 <= idx < num_windows:
            raise IndexError("TextDataset index out of range")

        window = self.tokens[idx:idx + self.seq_len + 1]
        # Clone so callers get independent tensors, not views into the corpus.
        return window[:-1].clone(), window[1:].clone()

# Tokenise the corpus and expose it as (input_ids, labels) windows of max_seq_len tokens.
dataset = TextDataset(
    input_filepath=input_filepath,
    output_filepath=output_filepath,
    seq_len=max_seq_len,
)

# Shuffled mini-batches of 32 windows. Loading stays in the main process
# (num_workers=0) to avoid multiprocessing issues inside Jupyter notebooks.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Architecture settings gathered in one mapping, all taken from the hyperparameter cell.
model_config = dict(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
)

# Build the decoder and move its parameters to the selected device.
model = CausalTransformerDecoder(**model_config).to(device)

# Targets equal to the tokenizer's pad id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.sp.pad_id())
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Batches per epoch; used for the progress percentage and the epoch average.
num_batches = len(dataloader)

for epoch in range(num_epochs):
    model.train()  # enable dropout for training
    total_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch[0].to(device)
        labels = batch[1].to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass, flattened to (batch_size * seq_len, vocab_size) so
        # every position is scored as an independent classification
        logits = model(input_ids).view(-1, vocab_size)
        labels = labels.view(-1)  # (batch_size * seq_len)

        loss = criterion(logits, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Accumulate the scalar loss for the epoch average
        batch_loss = loss.item()
        total_loss += batch_loss

        # Report progress every 10th batch
        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Progress [{batch_idx/num_batches*100:.4f}%], Loss: {batch_loss:.4f}')

    avg_loss = total_loss / num_batches  # Average loss for the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

In [None]:
# Persist only the learned parameters (the state dict), written into the results directory.
trained_weights = model.state_dict()
model_output_path = os.path.join(output_filepath, 'model_weights.pth')
torch.save(trained_weights, model_output_path)

In [None]:
def generate_text(model, sp, prompt, max_length=100, temperature=1.0, top_k=50):
    """
    Generate text from the model given a prompt.

    The device and the context window are taken from ``model`` itself (its
    parameters and its ``max_seq_len`` attribute) rather than from notebook
    globals, so the function works with any model passed in. The model is
    switched to eval mode and left in eval mode.

    Args:
        model: Trained CausalTransformerDecoder
        sp: SentencePiece processor
        prompt: Input text string to start generation
        max_length: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random);
            0 selects the most likely token at every step (greedy decoding)
        top_k: Number of top tokens to sample from (0 = no filtering);
            values larger than the vocabulary are clamped to the vocabulary size

    Returns:
        The decoded string for the prompt tokens followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative or the prompt encodes to
            no tokens.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Encode the prompt
    prompt_ids = sp.EncodeAsIds(prompt)
    if not prompt_ids:
        # With no tokens there is no "last position" to predict from.
        raise ValueError("prompt must encode to at least one token")

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    # Context window supported by the model; None disables truncation.
    context_len = getattr(model, "max_seq_len", None)
    eos_id = sp.eos_id()

    model.eval()
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)

    with torch.no_grad():
        for _ in range(max_length):
            # Limit context to the model's maximum sequence length
            context = input_ids if context_len is None else input_ids[:, -context_len:]

            # Logits for the last position only
            logits = model(context)[0, -1, :]

            if temperature == 0:
                # Greedy decoding; avoids dividing the logits by zero
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                # Apply top-k filtering; k is clamped because torch.topk
                # raises when k exceeds the vocabulary size
                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    top_k_values, top_k_indices = torch.topk(logits, k)
                    logits = torch.full_like(logits, float('-inf'))
                    logits[top_k_indices] = top_k_values

                # Sample from the distribution
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append to the running sequence (the EOS token itself is kept)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop if we generate EOS token
            if next_token.item() == eos_id:
                break

    # Decode prompt + generated tokens
    return sp.DecodeIds(input_ids[0].tolist())


# Smoke-test generation with a short Shakespearean prompt
prompt = "To be or not to be"
print(f"Prompt: {prompt}")
print("\nGenerated text:")

# Mildly conservative sampling: temperature below 1 and the 50 most likely tokens
generated = generate_text(
    model,
    dataset.sp,
    prompt,
    max_length=100,
    temperature=0.8,
    top_k=50,
)
print(generated)