In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

In [2]:
class InputEmbeddings(nn.Module):
    """
    Converts input token indices into dense vector representations

    Why its important: Models process number and not text. Embeddings map tokens to a continuous vector space where semantic similarity is reflected by distance and direction.
    How we build it: We use PyTorch's nn.Embedding layer and scale the outputs by the square root of the embedding dimension. This is mainly for training stability.
    """
    def __init__(self, embedding_size: int, vocab_size: int):
        super().__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.embedding_size)

In [3]:
class PositionalEncoding(nn.Module):
    """
    Computes the relative or absolute position of tokens in the sequence using sinusoidal functions

    Why its important: The attention mechanism is permutation invariant. Without positional encoding, the model will treat the sequence as a bag of words.
    How we build it: We use sine and cosine functions of different frequencies to generate a unique encoding for each position. These are added to token embeddings.
    """
    def __init__(self, embedding_size: int, dropout: float, sequence_len: int):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a positional encoding matrix of shape (sequence_len, embedding_size)
        positional_encoding = torch.zeros(sequence_len, embedding_size)
        position = torch.arange(0, sequence_len, dtype=torch.float).unsqueeze(1) # (sequence_len, 1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2).float() * (-math.log(10000.0) / embedding_size))

        # Apply sine to even indices and cosine to odd indices
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        positional_encoding = positional_encoding.unsqueeze(0) # (1, sequence_len, embedding_size) for batch broadcasting

        self.register_buffer("positional_encoding", positional_encoding) # Not a model parameter but a part of the state

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, sequence_len, embedding_size)
        """
        x = x + self.positional_encoding[:, :x.size(1)] # Add positional encoding
        return self.dropout(x)

In [4]:
class LayerNormalization(nn.Module):
    """
    Normalizes the inputs across the feature dimension for each data point in the batch independently

    Why its important: Reduces internal covariate shift and thus helps in stabilizing and accelerating training
    How we build it: Compute mean and std for each input across its feature dimension
    """
    def __init__(self, embedding_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(embedding_size)) # Learnable scale
        self.bias = nn.Parameter(torch.zeros(embedding_size)) # Learnable shift

    def forward(self, x):
        mean = x.mean(dim=-1, keepdims=True)
        std = x.std(dim=-1, keepdims=True, unbiased=False)
        normalised = (x - mean) / (std + self.eps)
        return self.alpha * normalised + self.bias

In [5]:
class MultiHeadAttention(nn.Module):
    """
    The multi-head attention mechanism allows the model to focus on different parts of the input sequence simultaneously.
    ANALOGY: Researching a topic (query) when you have multiple books (keys) with different content (values). Attention is like deciding which books are relevant and how much to read from each.
    """
    def __init__(self, embedding_size: int, n_heads: int, dropout: float):
        super().__init__()
        assert embedding_size % n_heads == 0, "Embedding size must be divisible by n_heads"
        self.embedding_size = embedding_size
        self.n_heads = n_heads
        self.dimensions_per_head = embedding_size // n_heads # Dimensions per head

        # Why separate projects? Each head learns different aspects
        self.w_q = nn.Linear(embedding_size, embedding_size) # Query projection
        self.w_k = nn.Linear(embedding_size, embedding_size) # Key projection
        self.w_v = nn.Linear(embedding_size, embedding_size) # Value projection
        self.w_o = nn.Linear(embedding_size, embedding_size) # Output projection

        self.dropout = nn.Dropout(dropout)

    # Attention mechanism: Core calculation
    @staticmethod
    def attention(query, key, value, mask=None, dropout=None):
        """
        Computes the scaled dot product attention
        """
        head_dimension = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dimension)
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            attention_weights = dropout(attention_weights)

        return torch.matmul(attention_weights, value), attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Query, Key, Value: Tensors of shape (batch_size, seq_len, embedding_size)
        mask: To prevent attention of certain positions
        """
        batch_size = query.size(0)

        # Linear projections and split into heads
        query = self.w_q(query).view(batch_size, -1, self.n_heads, self.dimensions_per_head).transpose(1, 2)
        key = self.w_k(key).view(batch_size, -1, self.n_heads, self.dimensions_per_head).transpose(1, 2)
        value = self.w_v(value).view(batch_size, -1, self.n_heads, self.dimensions_per_head).transpose(1, 2)

        # Apply attention
        x, self.attention_weights = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # Concatenate heads and put through final linear layer
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.embedding_size)
        return self.w_o(x)

In [6]:
class PositionWiseFFN(nn.Module):
    """
    A simple feed-forward network applied to each position in the sequence independently

    Why its important: The self attention output is a linear combination of values. The FFN introduces non-linearity and allows for more complex transformations.
    How we build it: We use two linear layers with an expansion factor (4x) and a ReLU activation in between.
    """
    def __init__(self, embedding_size: int, hidden_size: int, dropout: float):
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

In [7]:
class EncoderBlock(nn.Module):
    """
    Combines multi-head attention and the feed-forward network with residual connections and layer normalization.

    Why its important: Transforms input sequences into contextualized representations. Stacking blocks allows the model to build up increasingly abstract and complex representations of the input.
    How we build it: Wrap the MultiHeadAttention and PositionWiseFFN with residual connections and layer norm for stable training.
    """
    def __init__(self, embedding_size: int, n_head: int, hidden_size: int, dropout: float):
        super().__init__()
        self.self_attention = MultiHeadAttention(embedding_size, n_heads, dropout)
        self.feed_forward = PositionWiseFFN(embedding_size, hidden_size, dropout)
        self.norm1 = LayerNormalization(embedding_size)
        self.norm2 = LayerNormalization(embedding_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self attention with residual connection and layer norm
        attended = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward with residual connection and layer norm
        feed_forward = self.feed_forward(x)
        x = self.norm2(x + self.dropout(feed_forward))
        return x

In [8]:
class DecoderBlock(nn.Module):
    """
    Similar to encoder but includes an additional cross-attention layer to attend to the encoders output.
    A single block of decoder comprising of self-attention, cross-attention (encoder-decoder attention) and a feed-forward layer.

    Why its important: It generates output sequence one token at a time, using both the previously generated tokens (self-attention) and the encoded input (cross-attention).
        For sequence-to-sequence tasks like translation, the decoder must condition its output on both target sequence so far and the entire input sequence.
    How we build it: Add a third sublayer for encoder-decoder attention on top of the two sublayers found in the encoder.
    """
    def __init__(self, embedding_size: int, n_heads: int, hidden_size: int, dropout: float):
        super().__init__()
        self.self_attention = MultiHeadAttention(embedding_size, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(embedding_size, n_heads, dropout)
        self.feed_forward = PositionWiseFFN(embedding_size, n_heads, dropout)
        self.norm1 = LayerNormalization(embedding_size)
        self.norm2 = LayerNormalization(embedding_size)
        self.norm3 = LayerNormalization(embedding_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked sef-attention (prevents attending to future tokens)
        attended_self = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attended_self))

        # Cross-attention to encoder output
        attended_cross = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attended_cross))

        # Feed-forward
        feed_forward = self.feed_forward(x)
        x = self.norm3(x + self.dropout(feed_forward))
        return x

In [24]:
class Transformer(nn.Module):
    """
    Combine embedding layers, stack of encoder blocks and a stack of decoder blocks

    Why its important: Integrate all components for end to end training.
    How we build it: Chain embedding layers, encoder (stack of N encoder blocks), decoder (stack of N decoder blocks) followed by final linear projection.
    """
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, src_max_len: int, tgt_max_len: int, embedding_size: int, n_heads: int, layers: int, hidden_size: int, dropout: float):
        super().__init__()
        self.encoder_embedding = InputEmbeddings(embedding_size, src_vocab_size)
        self.decoder_embedding = InputEmbeddings(embedding_size, tgt_vocab_size)
        self.src_pos_encoding = PositionalEncoding(embedding_size, dropout, src_max_len)
        self.tgt_pos_encoding = PositionalEncoding(embedding_size, dropout, tgt_max_len)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(embedding_size, n_heads, hidden_size, dropout) for _ in range(layers)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(embedding_size, n_heads, hidden_size, dropout) for _ in range(layers)])

        self.final_linear = nn.Linear(embedding_size, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encode(self, src, src_mask):
        src_embedded = self.dropout(self.src_pos_encoding(self.encoder_embedding(src)))
        for block in self.encoder_blocks:
            src_embedded = block(src_embedded, src_mask)
        return src_embedded

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        tgt_embedded = self.dropout(self.tgt_pos_encoding(self.decoder_embedding(tgt)))
        for block in self.decoder_blocks:
            tgt_embedded = block(tgt_embedded, encoder_output, src_mask, tgt_mask)
        return tgt_embedded

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.final_linear(decoder_output)

In [25]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

dataset_name = 'Trelis/tiny-shakespeare'
dataset = load_dataset(dataset_name)

# Access the splits
train_dataset = dataset['train']
test_dataset = dataset['test']

# To see an example from the training set
print(train_dataset[0]['Text'])

# Combine all text from training dataset to build vocabulary
text_data = "".join([example['Text'] for example in train_dataset])

# Create a character level vocabulary
vocab = sorted(list(set(text_data)))
vocab_size = len(vocab)
char_to_idx = {chr: idx for idx, chr in enumerate(vocab)}
idx_to_char = {idx: chr for idx, chr in enumerate(vocab)}

# Encode the entire dataset
def encode_text(text):
    return torch.tensor([char_to_idx.get(ch, 0) for ch in text], dtype=torch.long)

# Encode train and test data
train_text = "".join([example['Text'] for example in train_dataset])
test_text = "".join([example['Text'] for example in test_dataset])

train_data = encode_text(train_text)
val_data = encode_text(test_text)

print(f"Training data length: {len(train_data)}")
print(f"Validation data length: {len(val_data)}")

def get_batch(split, batch_size, block_size):
    data = train_data if split == "train" else val_data
    # Ensure we don't go out of bounds
    max_start_idx = len(data) - block_size - 1
    if max_start_idx <= 0:
        raise ValueError(f"Data too short for block_size {block_size}")

    idx = torch.randint(0, max_start_idx, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in idx])
    y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    return x, y

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [29]:
# Training hyperparameters
batch_size = 32
block_size = 128
embedding_size = 256
n_heads = 8
layers = 6
hidden_size = 1024
dropout = 0.1
learning_rate = 0.0003
epochs = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Initialize model, loss and optimizer
model = Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    src_max_len=block_size,
    tgt_max_len=block_size,
    embedding_size=embedding_size,
    n_heads=n_heads,
    layers=layers,
    hidden_size=hidden_size,
    dropout=dropout
).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Using device: cpu
Model parameters: 7,981,937


In [None]:
# Training loop
for epoch in range(epochs):
    model.train()
    x, y = get_batch('train', batch_size, block_size)
    x, y = x.to(DEVICE), y.to(DEVICE)

    # For language modeling, we use the same sequence for src and tgt
    # But we need to shift the target for teacher forcing
    logits = model(x, x[:, :-1])  # src: full sequence, tgt: sequence without last token

    # Reshape for loss calculation
    # logits: (batch_size, seq_len-1, vocab_size)
    # targets: (batch_size, seq_len-1)
    loss = criterion(logits.reshape(-1, vocab_size), y[:, 1:].reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 1 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

Epoch 1/100, Loss: 4.3385
Epoch 2/100, Loss: 3.6365
Epoch 3/100, Loss: 3.5255
Epoch 4/100, Loss: 3.4619
Epoch 5/100, Loss: 3.3663
Epoch 6/100, Loss: 3.3703
Epoch 7/100, Loss: 3.3303
Epoch 8/100, Loss: 3.3616
Epoch 9/100, Loss: 3.3421
Epoch 10/100, Loss: 3.3060
Epoch 11/100, Loss: 3.3744
Epoch 12/100, Loss: 3.3474
Epoch 13/100, Loss: 3.3089
Epoch 14/100, Loss: 3.3364
Epoch 15/100, Loss: 3.3359
Epoch 16/100, Loss: 3.3264
Epoch 17/100, Loss: 3.3231
Epoch 18/100, Loss: 3.3204
Epoch 19/100, Loss: 3.2921
Epoch 20/100, Loss: 3.2990
Epoch 21/100, Loss: 3.2822
Epoch 22/100, Loss: 3.3046
Epoch 23/100, Loss: 3.2927
Epoch 24/100, Loss: 3.2764
Epoch 25/100, Loss: 3.3526
Epoch 26/100, Loss: 3.3020
Epoch 27/100, Loss: 3.2281
Epoch 28/100, Loss: 3.2731
Epoch 29/100, Loss: 3.2444
Epoch 30/100, Loss: 3.2395
Epoch 31/100, Loss: 3.2090
Epoch 32/100, Loss: 3.2058
Epoch 33/100, Loss: 3.2784
Epoch 34/100, Loss: 3.2741
Epoch 35/100, Loss: 3.1877
Epoch 36/100, Loss: 3.2211
Epoch 37/100, Loss: 3.2206


In [None]:
# Test the model
def generate_text(model, start_string, max_length=100):
    model.eval()
    input_ids = encode_text(start_string).unsqueeze(0).to(DEVICE)
    generated = input_ids

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(generated, generated[:, :-1])
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            generated = torch.cat([generated, next_token], dim=1)

    return "".join([idx_to_char[idx.item()] for idx in generated[0]])

# Generate some text after training
test_start = "KING:"
generated_text = generate_text(model, test_start, max_length=50)
print(f"\nGenerated text starting with '{test_start}':")
print(generated_text)