In [21]:
# class BPETokenizer:
    
#     def __init__(self, vocab_size=1000, special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]"]):
#         self.vocab_size = vocab_size
#         self.special_tokens = special_tokens
#         self.bpe_vocab = {}
#         self.merges = []
#         self.unk_token = "[UNK]"

#     def get_vocab(self):
#         return self.bpe_vocab

#     def build_vocab(self, corpus):
#         """Build the BPE vocabulary based on the corpus."""
#         # Split words into characters, with space being a delimiter for words
#         token_freqs = defaultdict(int)
#         for sentence in corpus:
#             words = sentence.strip().split()
#             for word in words:
#                 # Add spaces between characters and a word boundary symbol </w>
#                 word = " ".join(list(word)) + " </w>"
#                 token_freqs[word] += 1

#         # Build vocabulary by merging the most frequent pairs
#         for _ in range(self.vocab_size):
#             pairs = self.get_stats(token_freqs)
#             if not pairs:
#                 break

#             # Get the most frequent pair
#             best_pair = max(pairs, key=pairs.get)
#             self.merges.append(best_pair)

#             # Merge the best pair
#             token_freqs = self.merge_vocab(best_pair, token_freqs)

#         # Create final vocab mapping
#         self.bpe_vocab = {word: idx for idx, word in enumerate(self.special_tokens + list(token_freqs.keys()))}

#     def get_stats(self, token_freqs):
#         """Count frequency of token pairs in the vocabulary."""
#         pairs = defaultdict(int)
#         for word, freq in token_freqs.items():
#             tokens = word.split()
#             for i in range(len(tokens) - 1):
#                 pairs[(tokens[i], tokens[i + 1])] += freq
#         return pairs

#     def merge_vocab(self, pair, token_freqs):
#         """Merge the most frequent pair in the vocabulary."""
#         bigram = re.escape(" ".join(pair))
#         p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
#         new_vocab = {}
#         for word in token_freqs:
#             new_word = p.sub("".join(pair), word)
#             new_vocab[new_word] = token_freqs[word]
#         return new_vocab

#     def tokenize(self, text):
#         """Tokenize a sentence using the learned BPE vocabulary."""
#         words = text.strip().split()
#         tokens = []
#         for word in words:
#             word = " ".join(list(word)) + " </w>"
#             tokens.extend(self.encode_word(word))
#         return tokens

#     def encode_word(self, word):
#         """Encode a single word using the learned BPE merges."""
#         tokens = word.split()
#         for merge in self.merges:
#             while " ".join(merge) in " ".join(tokens):
#                 i = tokens.index(merge[0])
#                 if i + 1 < len(tokens) and tokens[i + 1] == merge[1]:
#                     tokens[i:i + 2] = ["".join(merge)]
#         return tokens

#     def convert_tokens_to_ids(self, tokens):
#         """Convert tokens to IDs using the BPE vocabulary."""
#         return [self.bpe_vocab.get(token, self.bpe_vocab[self.unk_token]) for token in tokens]

#     def convert_ids_to_tokens(self, ids):
#         """Convert token IDs back to tokens."""
#         inv_vocab = {idx: token for token, idx in self.bpe_vocab.items()}
#         return [inv_vocab.get(i, self.unk_token) for i in ids]

#     def encode(self, text, add_special_tokens=True):
#         """Convert text to token IDs, optionally adding special tokens."""
#         tokens = self.tokenize(text)
#         token_ids = self.convert_tokens_to_ids(tokens)
        
#         if add_special_tokens:
#             token_ids = [self.bpe_vocab["[CLS]"]] + token_ids + [self.bpe_vocab["[SEP]"]]
        
#         return token_ids

#     def decode(self, token_ids, skip_special_tokens=True):
#         """Convert token IDs back into text."""
#         tokens = self.convert_ids_to_tokens(token_ids)
#         if skip_special_tokens:
#             tokens = [t for t in tokens if t not in self.special_tokens]
#         return " ".join(tokens).replace(" </w>", "").replace(" ", "")


# with open("../data/shakespeare/input.txt", "r") as file:
#     corpus = file.readlines()

# bpe_tokenizer = BPETokenizer(vocab_size=vocab_size)
# bpe_tokenizer.build_vocab(corpus)

In [22]:
import os

import sentencepiece as spm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np

In [23]:
# Hyperparameters
vocab_size = 10770   # Vocabulary size (e.g., from tokenizer)
embed_dim = 512      # Embedding dimension
num_heads = 8        # Number of attention heads
ff_hidden_dim = 2048 # Feedforward hidden dimension
num_layers = 6       # Number of transformer decoder layers
max_seq_len = 128    # Maximum sequence length
num_epochs = 10      # Number of training epochs

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('mps')

filepath = "../data/shakespeare/input.txt"

In [24]:
class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CausalSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "Embedding dimension must be divisible by number of heads"

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = self.head_dim ** -0.5  # scaling factor for query

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()

        # Compute Q, K, V
        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Causal mask (upper triangular, prevents attending to future tokens)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(x.device)
        causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf'))
        attn_scores = attn_scores + causal_mask

        # Attention probabilities
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values
        attn_output = torch.matmul(attn_probs, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # Final output projection
        return self.out_proj(attn_output)


class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()
        self.self_attn = CausalSelfAttention(embed_dim, num_heads)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
        )
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention + residual connection
        attn_output = self.self_attn(x)
        x = x + self.dropout(attn_output)
        x = self.layernorm1(x)

        # Feedforward + residual connection
        ffn_output = self.ffn(x)
        x = x + self.dropout(ffn_output)
        x = self.layernorm2(x)

        return x


class CausalTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, ff_hidden_dim, num_layers, max_seq_len):
        super(CausalTransformerDecoder, self).__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList([
            TransformerDecoderBlock(embed_dim, num_heads, ff_hidden_dim)
        for _ in range(num_layers)])
        self.output_proj = nn.Linear(embed_dim, vocab_size)

        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.size()

        # Token + position embeddings
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embed_tokens(input_ids) + self.position_embedding(positions)

        # Pass through layers
        for layer in self.layers:
            x = layer(x)

        # Final linear projection to vocab size
        logits = self.output_proj(x)

        return logits

In [25]:
class TextDataset(Dataset):

    def __init__(self, filepath, seq_len):
        self.tokens = []
        with open(filepath, "r") as file:
            corpus = "".join(file.readlines())
        model_prefix = os.path.join(os.path.dirname(filepath), "sentencepiece")
        spm.SentencePieceTrainer.train(
            input=filepath,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
        )
        self.sp = spm.SentencePieceProcessor(model_file=model_prefix + ".model")
        tokens = self.sp.EncodeAsIds(corpus)
        for idx in range(0, len(tokens) - seq_len + 1):
            self.tokens.append(tokens[idx:idx + seq_len])

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        input_ids = self.tokens[idx]
        labels = input_ids[1:]
        labels.append(self.sp.pad_id())
        return torch.tensor(input_ids), \
               torch.tensor(labels)

dataset = TextDataset(filepath=filepath, seq_len=max_seq_len)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: ../data/shakespeare/input.txt
  input_format: 
  model_prefix: ../data/shakespeare/sentencepiece
  model_type: UNIGRAM
  vocab_size: 10770
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  

In [26]:
model = CausalTransformerDecoder(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
)
model = model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.sp.pad_id())
optimizer = optim.Adam(model.parameters(), lr=3e-5)

In [27]:
# for epoch in range(num_epochs):
#     model.train()  # Set the model to training mode
#     total_loss = 0  # Accumulate loss over the epoch

#     for batch_idx, (input_ids, labels) in enumerate(dataloader):
#         input_ids, labels = input_ids.to(device), labels.to(device)  # Move to device

#         optimizer.zero_grad()  # Clear gradients

#         # Forward pass
#         logits = model(input_ids)  # Shape: (batch_size, seq_len, vocab_size)

#         # Reshape logits and labels for the loss function
#         logits = logits.view(
#             -1, vocab_size
#         )  # Shape: (batch_size * seq_len, vocab_size)
#         labels = labels.view(-1)  # Shape: (batch_size * seq_len)

#         # Compute loss
#         loss = criterion(logits, labels)
#         print(labels)
#         total_loss += loss.item()

#         # Backward pass
#         loss.backward()
#         optimizer.step()  # Update parameters

#         if batch_idx % 10 == 0:  # Print every 10 batches
#             print(
#                 f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx/len(dataloader)*100:.4f}%], Loss: {loss.item():.4f}"
#             )

#     avg_loss = total_loss / len(dataloader)  # Average loss for the epoch
#     print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

In [28]:
model

CausalTransformerDecoder(
  (embed_tokens): Embedding(10770, 512)
  (position_embedding): Embedding(128, 512)
  (layers): ModuleList(
    (0-5): 6 x TransformerDecoderBlock(
      (self_attn): CausalSelfAttention(
        (query): Linear(in_features=512, out_features=512, bias=True)
        (key): Linear(in_features=512, out_features=512, bias=True)
        (value): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (output_proj): Linear(in_features=512, out_features=10770, bias=True)
)

In [29]:
model.load_state_dict(torch.load('../data/shakespeare/model_weights.pth', weights_only=True, map_location=torch.device('cpu')))
model.eval()

CausalTransformerDecoder(
  (embed_tokens): Embedding(10770, 512)
  (position_embedding): Embedding(128, 512)
  (layers): ModuleList(
    (0-5): 6 x TransformerDecoderBlock(
      (self_attn): CausalSelfAttention(
        (query): Linear(in_features=512, out_features=512, bias=True)
        (key): Linear(in_features=512, out_features=512, bias=True)
        (value): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (output_proj): Linear(in_features=512, out_features=10770, bias=True)
)

In [30]:
sp = spm.SentencePieceProcessor(model_file="../data/shakespeare/sentencepiece.model")

In [31]:
import torch
import torch.nn.functional as F

def top_k_sampling(logits, k):
    """
    Perform top-k sampling on logits.
    
    Args:
        logits: Tensor of shape (vocab_size,)
        k: Number of top tokens to sample from.
    
    Returns:
        next_token: The sampled token index.
    """
    # Get top-k logits and their indices
    top_k_logits, top_k_indices = torch.topk(logits, k)

    # Convert logits to probabilities
    top_k_probs = F.softmax(top_k_logits, dim=-1)

    # Sample a token from the top-k probabilities
    next_token = torch.multinomial(top_k_probs, num_samples=1)

    # Map back to original vocabulary indices
    return top_k_indices[next_token.item()]

def sample_from_model(model, sp, max_length, start_sequence, k=10, temperature=1.0):
    """
    Generate a sequence from the trained transformer model using top-k sampling.
    
    Args:
        model: Trained transformer model.
        sp: SentencePiece tokenizer.
        max_length: Maximum sequence length.
        start_sequence: Initial sequence to start the generation.
        k: Number of top tokens to sample from (for top-k sampling).
        temperature: Controls randomness in sampling; higher values make the output more random.

    Returns:
        generated_sequence: List of token IDs (generated sequence).
    """
    model.eval()  # Set the model to evaluation mode
    generated_sequence = start_sequence[:]
    input_ids = torch.tensor(start_sequence, device=device).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass through the model to get logits
            logits = model(input_ids)  # Shape: (batch_size, seq_len, vocab_size)
            
            # Take logits of the last token in the sequence
            logits = logits[:, -1, :] / temperature  # Shape: (batch_size, vocab_size)

            # Apply top-k sampling to select the next token
            next_token_id = top_k_sampling(logits.squeeze(), k)

            # Stop if the model outputs an end token (e.g., using the pad token as end token)
            if next_token_id == sp.pad_id():  # Assuming sp.pad_id() is the end token
                break

            # Append the new token to the generated sequence
            generated_sequence.append(next_token_id.cpu().item())

            # Update input for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0).unsqueeze(0)], dim=1)

    return generated_sequence

# Example usage:
# Assuming `model` is your trained CausalTransformerDecoder model and `sp` is your SentencePiece tokenizer
start_sequence = [sp.piece_to_id('[CLS]')]  # Replace with your actual start token

# Generate a sequence using top-k sampling
generated_sequence = sample_from_model(model, sp, max_length=50, start_sequence=start_sequence, k=10, temperature=1.0)

# Decode the token IDs back to text
generated_text = sp.decode(generated_sequence)
print(f'Generated Text: {generated_text}')

Generated Text:  ⁇ TON: And speaking it, he wistly look'd on me, And who should say, 'I would thou wert the man' That would divorce this terror from my heart;' Meaning the king at Pomfret. Come, let's go
