In [None]:
#############
## IMPORTS ##
#############
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests, math



In [3]:
#####################
## IMPORT THE DATA ##
######################

# Instead of using tenorflow to import the data set, we will be pulling it
# directly from Karpath's GitHub repository. He is the one who created the 
# data set and made it famous in his RNN blog post.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
text = requests.get(url).text

In [4]:
#################
## TOKENIZAION ##
#################
# Here we are performing the tokenization. For simplicity, we will use a
# word tokenization (i.e. each unique word is a token). This means that the
# model will learn more quickly but will have a limited vocabulary. In practice,
# this data set, and other natural language data sets, are often tokenized using
# sub-word or character tokenization to allow for a larger vocabulary and better
# handling of rare words.

# Split the text pulled from the GitHub repo into words using the whitespace to
# separate the words.
words = text.split()

# Generate a list of the unique words (vocabulary) and create mappings from
# word to index and index to word.
vocab = sorted(set(words))
# Word to index mapping
stoi = {w: i for i, w in enumerate(vocab)}    
# Index to word mapping 
itos = {i: w for w, i in stoi.items()}    
# Size of the vocabulary      
vocab_size = len(vocab)

# Print some statistics about the tokenization
print("Total words:", len(words))
print("Vocab size:", vocab_size)
print("Sample tokens:", words[:20])

# Convert the entire text into a tensor of word indices (note that we need to
# use tensor here to feed the data into a PyTorch model later).
data = torch.tensor([stoi[w] for w in words], dtype=torch.long)

# Train/validation split. Here we are using 90% of the data for training and 10%
# for validation. This can be adjusted as needed.
split = int(0.9 * len(data))
train_data = data[:split]
val_data   = data[split:]

print("Train words:", len(train_data))
print("Val words:", len(val_data))

# Function to generate a batch of data for training/validation. The batch will
# consist of input sequences of length block_size and the corresponding target
# sequences (which are the input sequences shifted by one word). Block size and
# batch size are hyperparameters that can be adjusted.
def get_batch(split, block_size, batch_size, device):
    """
    Inputs:    
        split: "train" or "val" to indicate which data split to use
        block_size: Length of each input sequence
        batch_size: Number of sequences in the batch
        device: Device to place the tensors on (e.g., "cpu" or "cuda")
    Outputs:
        x: Input tensor of shape (batch_size, block_size)
        y: Target tensor of shape (batch_size, block_size)
    Formats the train/validation data into batches.
    """ 
    # select the appropriate data split
    source = train_data if split == "train" else val_data
    # pick random starting word positions
    ids = torch.randint(0, len(source) - block_size - 1, (batch_size,))
    # construct the input and target sequences
    x = torch.stack([source[i:i+block_size] for i in ids]).to(device)
    y = torch.stack([source[i+1:i+block_size+1] for i in ids]).to(device)
    return x, y

Total words: 202651
Vocab size: 25670
Sample tokens: ['First', 'Citizen:', 'Before', 'we', 'proceed', 'any', 'further,', 'hear', 'me', 'speak.', 'All:', 'Speak,', 'speak.', 'First', 'Citizen:', 'You', 'are', 'all', 'resolved', 'rather']
Train words: 182385
Val words: 20266


In [None]:
########################
## ATTENTION VARIANTS ##
########################

#############################################
## SCALED DOT-PRODUCT ATTENTION (STANDARD) ##
##############################################
class ScaledDotProductAttention(nn.Module):
    """
        Standard scaled dot-product attention with causal masking. 
    """
    def __init__(self, n_embd, n_head):
        """
        Inputs:
            n_embd: Embedding dimension (length of input vectors)
            n_head: Number of attention heads
        Returns:
            None
        Initializes the components of the attention mechanism with one 
        or more heads
        """
        # Initialize the nn.Module superclass
        super().__init__()
        # Define the number of heads and the dimension per head. The dimension
        # per head is the embedding dimension divided by the number of heads and 
        # must be an integer.
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        # Define the linear layers for query, key, and value projections
        self.qkv = nn.Linear(n_embd, 3*n_embd, bias=False)
        # Define the output linear layer
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """
        Inputs:        
            x: Input tensor of shape (B, T, C) where
               B = batch size
               T = sequence length (number of tokens)
               C = embedding dimension
        Returns:
            Output tensor of shape (B, T, C) after applying attention
        Performs the forward pass of the attention mechanism.
        """
        # Get the batch size (B), sequence length (T), and embedding dimension (C)
        B, T, C = x.shape
        # Send the input through the query, key, and value linear layer and reshapes
        # into multiple heads. Then permute the dimensions to get the shape
        # (3, B, n_head, T, d) where d = C / n_head. Permute reorders the dimensions.
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.d).permute(2,0,3,1,4)
        # Split the qkv tensor into separate query, key, and value tensors
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Compute the raw attention scores by performing the dot product between
        # the query and key tensors, scaled by the square root of the dimension
        # per head. The result has shape (B, n_head, T, T).
        att = (q @ k.transpose(-2,-1)) / math.sqrt(self.d)
        # Create a causal mask to ensure that each position can only attend to
        # previous positions (including itself). The mask has shape (T, T). Causal
        # masking is important for autoregressive models to prevent information
        # leakage from future tokens.
        mask = torch.tril(torch.ones(T, T, device=x.device)) == 1
        # Apply the causal mask to the attention scores, setting masked positions
        # to negative infinity. This ensures that after applying softmax, these
        # positions will have zero attention weight.
        att = att.masked_fill(~mask, float('-inf'))
        # Apply the softmax function to the attention scores to obtain the
        # attention weights. The softmax is applied along the last dimension
        # (the sequence length dimension).
        att = att.softmax(dim=-1)
        # Compute the output by performing the weighted sum of the value tensor
        # using the attention weights. The result has shape (B, n_head, T, d).
        out = att @ v
        # Reshape and permute the output tensor back to shape (B, T, C).
        out = out.transpose(1,2).reshape(B,T,C)
        # Send the output through the final linear layer and return.
        return self.out(out)

#####################
## LOCAL ATTENTION ##
#####################
class LocalAttention(nn.Module):
    """
    Simple causal *local* attention with φ(x)=ELU(x)+1 feature map.
    Pedagogical sliding-window implementation (not the most optimized).
    """
    def __init__(self, n_embd, n_head, window_size=64):
        """
        Inputs:
            n_embd: Embedding dimension (length of input vectors)
            n_head: Number of attention heads
            window_size: Size of the local attention window
        Returns:
            None
        Initializes the components of the local attention mechanism.
        """
        # Initialize the nn.Module superclass
        super().__init__()
        # Define the number of heads and the dimension per head. The dimension
        # per head is the embedding dimension divided by the number of heads and 
        # must be an integer.
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        # Define the local attention window size
        self.window_size = int(window_size)
        # Define the linear layers for query, key, and value projections
        self.Wq = nn.Linear(n_embd, n_embd, bias=False)
        self.Wk = nn.Linear(n_embd, n_embd, bias=False)
        self.Wv = nn.Linear(n_embd, n_embd, bias=False)
        # Define the output linear layer
        self.out = nn.Linear(n_embd, n_embd)
        # Define the scaling factor for stability. This is a common practice
        # in attention mechanisms to prevent numerical instability.
        self.scale = 1.0 / math.sqrt(self.d)  # standard scaling for stability

    # Define the feature map phi(x) used in the attention mechanism. staticmethod
    # decorator indicates that this method does not depend on the instance
    # (self) and can be called on the class itself.
    @staticmethod
    def phi(x):
        """
        Inputs:
            x: Input tensor
        Returns:
            Transformed tensor after applying the feature map
        Applies the feature map φ(x) = ELU(x) + 1."""
        return F.elu(x, alpha=1.0) + 1.0

    def forward(self, x):
        """
        Inputs:
            x: (B, T, C)
        Returns:
          y: (B, T, C)
        Performs the forward pass of the local attention mechanism.
        """
        # Get the batch size (B), sequence length (T), and embedding dimension (C)
        B, T, C = x.shape
        # Get the number of heads (H) and dimension per head (Dh)
        H, Dh = self.n_head, self.d

        # project and split into heads: (B,H,T,Dh)
        q = self.Wq(x).reshape(B, T, H, Dh).transpose(1, 2)
        k = self.Wk(x).reshape(B, T, H, Dh).transpose(1, 2)
        v = self.Wv(x).reshape(B, T, H, Dh).transpose(1, 2)

        # apply feature map
        q_phi = self.phi(q)  # (B,H,T,Dh)
        k_phi = self.phi(k)  # (B,H,T,Dh)

        # slide over time steps
        outs = []
        for t in range(T):
            # causal window: [max(0, t-window+1) .. t]
            j0 = max(0, t - self.window_size + 1)
            j1 = t + 1

            # slice keys/values in window: (B,H,W,Dh)
            k_win = k_phi[:, :, j0:j1, :]      # φ(K)
            v_win = v[:,  :, j0:j1, :]      # V (no φ on values)
            q_t   = q_phi[:, :, t, :]          # (B,H,Dh)

            # similarities: (B,H,W) using φ(q)·φ(k)
            # add scale factor to stabilize softmax (common with dot-product attention)
            sim = torch.einsum("bhd,bhwd->bhw", q_t, k_win) * self.scale

            # softmax over window positions
            a = F.softmax(sim, dim=-1)      # (B,H,W)

            # weighted sum: (B,H,Dh)
            y_t = torch.einsum("bhw,bhwd->bhd", a, v_win)

            outs.append(y_t)

        # stack back to (B,T,C)
        y = torch.stack(outs, dim=2).transpose(1, 2).reshape(B, T, C)
        return self.out(y)

#######################
## LINEAR ATTENTION ##
#######################
class LinearAttention(nn.Module):
    """
    Simple causal linear attention with φ(x)=ELU(x)+1 feature map.
    Pedagogical prefix-scan implementation (not the most optimized).
    """
    def __init__(self, n_embd, n_head):
        """
        Inputs:
            n_embd: Embedding dimension (length of input vectors)
            n_head: Number of attention heads
        Returns:
            None
        Initializes the components of the linear attention mechanism.
        """
        # Initialize the nn.Module superclass
        super().__init__()
        # Define the number of heads and the dimension per head. The dimension
        # per head is the embedding dimension divided by the number of heads and 
        # must be an integer.
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        # Define the linear layers for query, key, and value projections
        self.Wq = nn.Linear(n_embd, n_embd, bias=False)
        self.Wk = nn.Linear(n_embd, n_embd, bias=False)
        self.Wv = nn.Linear(n_embd, n_embd, bias=False)
        # Define the output linear layer
        self.out = nn.Linear(n_embd, n_embd)
        # Small epsilon value for numerical stability in division
        self.eps = 1e-6

    # See the local attention class for explanation of staticmethod decorator
    @staticmethod
    def phi(x):
        """
        Inputs:
            x: Input tensor
        Returns:
            Transformed tensor after applying the feature map
        Applies the feature map φ(x) = ELU(x) + 1."""
        return F.elu(x, alpha=1.0) + 1.0

    def forward(self, x):
        """
        Inputs:
            x: (B, T, C)
        Returns:
          y: (B, T, C)
        Performs the forward pass of the linear attention mechanism.
        """
        # Get the batch size (B), sequence length (T), and embedding dimension (C)
        B, T, C = x.shape
        # Get the number of heads (H) and dimension per head (Dh)
        H, Dh = self.n_head, self.d

        # project and split into heads: (B,H,T,Dh)
        q = self.Wq(x).reshape(B, T, H, Dh).transpose(1, 2)   # (B,H,T,Dh)
        k = self.Wk(x).reshape(B, T, H, Dh).transpose(1, 2)
        v = self.Wv(x).reshape(B, T, H, Dh).transpose(1, 2)

        # apply feature map
        q_phi = self.phi(q)
        k_phi = self.phi(k)

        # prefix accumulators
        K_accum  = torch.zeros(B, H, Dh, device=x.device, dtype=x.dtype)
        KV_accum = torch.zeros(B, H, Dh, Dh, device=x.device, dtype=x.dtype)

        # slide over time steps
        outs = []
        for t in range(T):
            # keys/values at time t
            kt = k_phi[:, :, t, :]                                    # (B,H,Dh)
            vt = v[:, :, t, :]                                     # (B,H,Dh)
            # accumulate prefix sums
            KV_accum = KV_accum + torch.einsum("bhd,bhe->bhde", kt, vt)
            K_accum  = K_accum  + kt
            # query
            qt = q_phi[:, :, t, :]                                    # (B,H,Dh)
            # compute output at time t
            num = torch.einsum("bhd,bhde->bhe", qt, KV_accum)      # (B,H,Dh)
            # compute denominator
            den = torch.einsum("bhd,bhd->bh", qt, K_accum).unsqueeze(-1) + self.eps
            # final output at time t
            yt = num / den
            outs.append(yt)
        # stack back to (B,T,C)
        y = torch.stack(outs, dim=2).transpose(1, 2).reshape(B, T, C)  # (B,T,C)
        return self.out(y)

# ============================
# Multi-Query Attention (MQA)
# ============================
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention: many query heads, shared K and V across heads.
    Cuts memory bandwidth for K/V compared with full MHA.
    """
    def __init__(self, n_embd, n_head):
        """
        Inputs:
            n_embd: Embedding dimension (length of input vectors)
            n_head: Number of attention heads
        Returns:
            None
        Initializes the components of the multi-query attention mechanism.
        """
        # Initialize the nn.Module superclass
        super().__init__()
        # Define the number of heads and the dimension per head. The dimension
        # per head is the embedding dimension divided by the number of heads and 
        # must be an integer.
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        # Define the linear layers for query, key, and value projections
        self.Wq = nn.Linear(n_embd, n_embd, bias=False)  # (B,T,C) -> (B,T,H*Dh)
        self.Wk = nn.Linear(n_embd, self.d, bias=False)  # shared K: (B,T,Dh)
        self.Wv = nn.Linear(n_embd, self.d, bias=False)  # shared V: (B,T,Dh)
        # Define the output linear layer
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """
        Inputs:
            x: (B, T, C)
        Returns:
          y: (B, T, C)
        Performs the forward pass of the multi-query attention mechanism.
        """
        # Get the batch size (B), sequence length (T), and embedding dimension (C)
        B, T, C = x.shape
        # Get the number of heads (H) and dimension per head (Dh)
        H, Dh = self.n_head, self.d
        # project and split into heads: (B,H,T,Dh)
        q = self.Wq(x).reshape(B, T, H, Dh).transpose(1, 2)      # (B,H,T,Dh)

        # shared K/V (expand over heads)
        k = self.Wk(x).unsqueeze(1).expand(B, H, T, Dh).contiguous()
        v = self.Wv(x).unsqueeze(1).expand(B, H, T, Dh).contiguous()

        # PyTorch's fused scaled dot-product attention (Flash when available)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B,H,T,Dh)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.out(y)


In [6]:
##########################
## LARGE LANGUAGE MODEL ##
###########################

##############################
## SINGLE TRANSFORMER BLOCK ##
###############################
class TransformerBlock(nn.Module):
    # Initialization method for the Transformer block.
    def __init__(self, n_embd, n_head, attn_class):
        """
        Inputs:        
            n_embd: Length of the embedding vector
            n_head: Number of attention heads
            attn_class: Attention class to use (must be the name of one of the above classes)
        Returns:
            None.
        Initializes a Transformer block with layer normalization, attention, and feed-forward network.
        """ 
        # Initialize the nn.Module superclass
        super().__init__()
        # Layer normalization   
        self.ln1 = nn.LayerNorm(n_embd)
        # Attention mechanism
        self.attn = attn_class(n_embd, n_head)
        # Second layer normalization
        self.ln2 = nn.LayerNorm(n_embd)
        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.GELU(),
            nn.Linear(4*n_embd, n_embd)
        )

    def forward(self, x):
        """
        Inputs:
            x: Input tensor of shape (B, T, C) where
               B = batch size
               T = sequence length (number of tokens)
               C = embedding dimension
        Returns:
            Output tensor of shape (B, T, C) after applying the Transformer block
        Performs the forward pass of the Transformer block.
        """
        # Apply layer normalization, attention, and residual connection
        x = x + self.attn(self.ln1(x))
        # Apply layer normalization, feed-forward network, and residual connection
        x = x + self.ff(self.ln2(x))
        # Return the output tensor
        return x

class LLM(nn.Module):
    """
    Creates an LLM model using multiple Transformer blocks. Allows for different attention mechanisms.
    """
    def __init__(self, vocab_size, block_size, n_layer=4, n_embd=256, n_head=4, 
                 attn_class=ScaledDotProductAttention):
        """
        Inputs:
            vocab_size: Size of the vocabulary (number of unique tokens)
            block_size: Length of each input sequence
            n_layer: Number of Transformer blocks
            n_embd: Length of the embedding vector
            n_head: Number of attention heads
            attn_class: Attention class to use (must be the name of one of the above classes)
        Returns:
            None
        Initializes the LLM model with token and positional embeddings, multiple Transformer blocks,
        layer normalization, and a final linear layer for output.
        """
        # Initialize the nn.Module superclass
        super().__init__()
        # Store the block size (length of input sequences)
        self.block_size = block_size
        # Token embedding layer
        self.token = nn.Embedding(vocab_size, n_embd)
        # Positional embedding layer
        self.pos   = nn.Embedding(block_size, n_embd)
        # Stack of Transformer blocks, number of block is n_layer
        self.blocks = nn.Sequential(*[TransformerBlock(n_embd, n_head, attn_class) for _ in range(n_layer)])
        # Final layer normalization
        self.ln_f = nn.LayerNorm(n_embd)
        # Final linear layer 
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        """
        Inputs:
            x: Input tensor of shape (B, T) where
                 B = batch size
                 T = sequence length (number of tokens)
            targets: Optional target tensor of shape (B, T) for computing loss
        Returns:
            If targets is None, returns logits tensor of shape (B, T, vocab_size).
            If targets is provided, returns a tuple (logits, loss) where loss is the
            cross-entropy loss between logits and targets.
        Performs the forward pass of the LLM model.
        """
        # Get the batch size (B) and sequence length (T)
        B, T = x.shape
        # Create position indices for the sequence length
        pos = torch.arange(0, T, device=x.device)
        # Combine token and positional embeddings
        x = self.token(x) + self.pos(pos)[None,:,:]
        # Pass through the stack of Transformer blocks
        x = self.blocks(x)
        # Apply final layer normalization
        x = self.ln_f(x)
        # Compute logits using the final linear layer
        logits = self.head(x)
        # If targets are not provided, return logits
        if targets is None:
            return logits
        # If targets are provided, compute cross-entropy loss and return logits and loss
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss

In [7]:
#$########################
## TRAIN THE LLM MODEL ##
#########################

# Check to see if a GPU is available and set the device accordingly. This only works
# if you have a compatible GPU and the necessary CUDA libraries installed. The GPU will
# significantly speed up the training process.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define block size and batch size for training. The block sie is the length of
# each input sequence and the batch size is the number of sequences in each batch. The
# values can be adjusted as needed.
block_size = 256
batch_size = 64

# Choose the attention mechanism to use in the model. You can switch between
# different attention classes defined above. Use the name exactly as defined.
attn_type = ScaledDotProductAttention   

# Define the model using the specified attention mechanism and move it to the
# selected device (CPU or GPU).
model = LLM(vocab_size, block_size, attn_class=attn_type).to(device)

# Define the optimizer for training. Here we are using the AdamW optimizer,
# which is a variant of the Adam optimizer that includes weight decay for
# regularization. The optimizer will update the model parameters during training.
optimizer = torch.optim.AdamW(model.parameters())

# Training loop. Steps is the number of training iterations. In each iteration,
# we get a batch of data, perform a forward pass through the model, compute the
# loss, perform backpropagation, and update the model parameters using the optimizer.
steps = 1000
print("Training...")

for step in range(steps):
    # Get a batch of training data
    xb, yb = get_batch("train", block_size, batch_size, device)
    # Perform a forward pass through the model to get logits and loss
    logits, loss = model(xb, yb)
    # Backpropagation and optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Print the loss every 100 steps
    if step % 100 == 0:
        print(f"step {step} | loss {loss.item():.3f}")



Training...
step 0 | loss 10.318
step 100 | loss 6.917
step 200 | loss 5.677
step 300 | loss 4.692
step 400 | loss 3.690
step 500 | loss 2.796
step 600 | loss 1.860
step 700 | loss 1.211
step 800 | loss 0.778
step 900 | loss 0.510


In [8]:
############################
## TEST THE TRAINED MODEL ##
############################

def sample(model, start="ROMEO:", steps=100):
    """
    Inputs:    
        model: Trained LLM model
        start: Starting prompt string
        steps: Number of words to generate
    Returns:
        Generated text string after sampling from the model
    Generates text by sampling from the trained LLM model starting from the
    given prompt.
    """
    # Set the model to evaluation mode (testing mode)
    model.eval()

    # Tokenize start prompt into words
    start_words = start.split()
    start_ids = [stoi.get(w, 0) for w in start_words]

    # Build initial sequence tensor
    idx = torch.tensor([start_ids], device=device)

    # Generate words one at a time
    for i in range(steps):
        # Crop context to block_size so position embeddings never overflow
        idx_cond = idx[:, -model.block_size:]
        # Get the model's logits for the current sequence
        with torch.no_grad():
            logits = model(idx_cond)[:, -1, :]  # final word's logits
        # Greedy decoding: pick highest‑probability word. This can be changed as desired.
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
        # Append to sequence
        idx = torch.cat([idx, next_id], dim=1)

    # Convert token IDs back to words, and join into a single string, and return in
    result_words = [itos[int(i)] for i in idx[0]]
    return " ".join(result_words)


# Generate and print sample text from the trained model
print(sample(model, "To be, or not to be,", steps=50))


To be, or not to be, The which is ta'en! First Gentleman: It is a King Edward. She would the benefit of all the king a himself, this is to be done, and that her to be true, that the lord of Paulina,--a piece of Paulina,--a piece of Paulina,--a piece of Paulina,--a piece the court? for
