# Decoder-Only Transformers (by batch)

This notebook demonstrates how to build a very simple decoder-only Transformer (which can takes batches instead of a single example in a forward pass) from scratch. Decoder-only architectures, such as those powering large language models (LLMs) like ChatGPT, focus exclusively on the generative component of the Transformer. By working through this notebook, you'll see how these models process context and generate output—providing a foundation for understanding how modern LLMs operate under the hood.

In [193]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

# Create training dataset

In [194]:
# Create a mapping from vocabs to numbers as nn.Embedding can only take integers
token_to_id = {"what": 0,
               "is": 1,
                "statquest": 2,
                "awesome": 3,
                "<EOS>": 4
              }

# Create a mapping from numbers back to vocabs to interpret the output from the transformer
id_to_token = dict(map(reversed, token_to_id.items()))


In [195]:
# Create the training dataset
# As the input is going to be word embeddings, we only need the corresponding numbers from the mapping
# The tokens used as inputs during training comes from 1. processing the prompt and 2. generating the output
inputs = torch.tensor([[token_to_id["what"],
                        token_to_id["is"],
                        token_to_id["statquest"],
                        token_to_id["<EOS>"],
                        token_to_id["awesome"]
                        ], 
                        
                        [token_to_id["statquest"],
                         token_to_id["is"],
                         token_to_id["what"],
                         token_to_id["<EOS>"],
                         token_to_id["awesome"]]])

labels = torch.tensor([[token_to_id["is"],
                         token_to_id["statquest"],
                         token_to_id["<EOS>"],
                         token_to_id["awesome"],
                         token_to_id["<EOS>"]], 
                         
                         [token_to_id["is"],
                          token_to_id["what"],
                          token_to_id["<EOS>"],
                          token_to_id["awesome"],
                          token_to_id["<EOS>"]]])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=2)

# Position Encoding

The formula for the (standard, used in the paper **Attention is all you need**) position encoding is:  
PE_(pos, 2i) = sin(pos / 10000^(2i / d_model))  
PE_(pos, 2i+1) = cos(pos / 10000^(2i / d_model))  


In [196]:
# First try
# class PositionEncoding(nn.Module):

#     def __init__(self, batch_size=2, d_model=2, max_len=6):

#         super().__init__()

#         # pe stands for position encoding
#         pe = torch.zeros(batch_size, max_len, d_model)

#         # position is now a batch of 2D column matrices of size [batch_size, max_len, 1], e.g. [[0.], [1.], [2.]] * batch_size
#         position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
#         # Add a leading batch dimension 
#         position = position.unsqueeze(0).expand(batch_size, max_len, 1)

#         # Step is set to 2 because of "2i" in the formula, note that it is still a 1D tensor (no leading batch dimension is needed, it is still broadcastable)
#         embedding_index = torch.arange(start=0, end=d_model, step=2).float()
    
#         # div_term is a row matrix (1D) with the same size as embedding_index
#         div_term = torch.tensor(10000.)**(embedding_index / d_model)

#         # Note: calculating the sin and cos values in this way only works when d_model is an even number, if d_model is odd, there will be a shape mismatch
#         # E.g. [2, 3, 1] / [2] => [2, 3, 1] / [1, 1, 2] => results in a tensor of size [2, 3, 2]
#         pe[:, :, 0::2] = torch.sin(position / div_term)
#         pe[:, :, 1::2] = torch.cos(position / div_term)

#         self.register_buffer('pe', pe)
    

#     def forward(self, word_embeddings):

#         print(word_embeddings.shape)
#         print(self.pe[:, :word_embeddings.size(1), :].shape)
#         # Note: we need to return the position encodings across all the batches
#         return word_embeddings + self.pe[:, :word_embeddings.size(1), :]


In [197]:
class PositionEncoding(nn.Module):

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        # pe stands for position encoding
        pe = torch.zeros(max_len, d_model)

        # position is a column matrix (2D) of size [max_len, 1], e.g. [[0.], [1.], [2.]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # Step is set to 2 because of "2i" in the formula, note that it is a 1D tensor, e.g. [0., 2.] as each position can have multiple embedding values
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # div_term is a row matrix (1D) with the same size as embedding_index
        div_term = torch.tensor(10000.)**(embedding_index / d_model)

        # Note: calculating the sin and cos values in this way only works when d_model is an even number, if d_model is odd, there will be a shape mismatch
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)

        self.register_buffer('pe', pe)
    

    def forward(self, word_embeddings):

        # Note: the size of word_embeddings is [batch_size, input_len, d_model]
        # and the size of self.pe[:word_embeddings.size(0), :] is [input_length, d_model]
        # e.g. [2, 4, 2] + [4, 2] => [2, 4, 2] + [1, 4, 2] => [2, 4, 2]
        return word_embeddings + self.pe[:word_embeddings.size(1), :]
    


In [198]:
# Testing
# pe = PositionEncoding(max_len=6, d_model=2)
# we = nn.Embedding(num_embeddings=5, embedding_dim=2)

# model_input = torch.tensor([token_to_id["what"],
#                             token_to_id["is"],
#                             token_to_id["statquest"],
#                             token_to_id["<EOS>"]])

# model_input = model_input.repeat(2, 1)

# word_em = we(model_input)
# pos_en = pe(word_em)
# pos_en

# Masked Self-Attention

In [199]:
class Attention(nn.Module):

    def __init__(self, d_model=2):

        super().__init__()

        # Create the weights associated with the query, key and value values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.batch_dim = 0
        self.row_dim = 1
        self.col_dim = 2

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):

        # Create the Q, K and V matrices
        # Note: now the size of q, k and v is [batch_size, input_len, d_model] (for convenience, I will assume the batch_size = 2, seq_len = 4, d_model = 2)
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity score between the query values and key values
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity score with the square root of d_model
        # Broadcasting steps: [2, 4, 4] / [] => [2, 4, 4] / [1, 1, 1] => [2, 4, 4]
        scaled_sims = sims / torch.tensor((k.size(self.col_dim))**0.5)
        print("the dimension of scaled_sims: ", scaled_sims.shape)

        device = scaled_sims.device
        
        # Mask the scaled similarity scores of the later tokens so that the earlier tokens can't cheat. Note: -1e9 is an approximation of negative infinity
        if mask is not None:
            # Move your mask to mps:0, or mask would live in cpu by default
            mask = mask.to(device)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
            print("the dimension of scaled_sims after masking: ", scaled_sims.shape)

        # Applying the softmax function to the scaled similarites determines the percentages of influence each token (in columns) should have on the others (in rows)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        print("the dimension of attention_percents: ", attention_percents.shape)

        # attention_scores are basically the contextualised embeddings
        # The dimensions of the matrix multiplication: [2, 4, 4] * [2, 4, 2] => [2, 4, 2]
        attention_scores = torch.matmul(attention_percents, v)
        print("the dimension of attention_score: ", attention_scores.shape)

        return attention_scores


# Decoder-only Transformer

In [262]:
class DecoderOnlyTransformer(L.LightningModule):

    def __init__(self, num_tokens, d_model, max_len, batch_size=2):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Masked Self-Attention
        self.attention = Attention(d_model=d_model)

        # Fully Connected layer
        self.fc = nn.Linear(in_features=d_model, out_features=num_tokens)

        # Calculate the loss with Cross Entropy; softmax is already included
        self.loss = nn.CrossEntropyLoss()

        # Specify batch size
        self.batch_size = batch_size

    # The size of token_ids is a 2D tensor [batch_size, seq_len], unlike nn.LSTM, which requires the size of the input tensor to be [seq_len, batch_size, input_size]
    def forward(self, token_ids):

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Create a mask matrix for masking used in masked self attention
        # The size of mask should be [seq_len, seq_len]
        mask = torch.tril(torch.ones(token_ids.size(dim=1),token_ids.size(dim=1))) == 0
        # Add a leading batch dimension to mask
        mask = mask.unsqueeze(dim=0).expand(self.batch_size, token_ids.size(dim=1), token_ids.size(dim=1))
        print("the dimension of mask: ", mask.shape)

        # Masked Self-Attention
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=mask)

        # Add residual connections
        print("the dimension of position encoding: ", position_encoded.shape)
        residual_connection_values = position_encoded + self_attention_values

        # Run the residual connections through a fully connected layer
        # fc_layer_out has the same size as residual_connected_values, which is a 2D tensor of size [seq_len, num_tokens]
        fc_layer_out = self.fc(residual_connection_values)
        print("the dimension of the fc layer output: ", fc_layer_out.shape)

        # Return the fully connected layer
        return fc_layer_out
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        
        # input_tokens is a 2D tensor of size [batch_size, seq_len]
        input_tokens, labels = batch
        print(batch)
        print(input_tokens)
        # outputs is fc_layer_out, so they share the same size
        outputs = self.forward(input_tokens)
        # Cross Entropy loss will automatically apply softmax to the outputs
        # Sum the loss across the samples in the batch
        loss = torch.sum(self.loss(outputs, labels))

        return loss


In [303]:
# Test if the forward pass is working with batch_size=2
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6, batch_size=2)

model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])

# Put one more (duplicated) sample in the batch
# Note: -1 in the expand function means keep that dimension unchanged
model_input = model_input.unsqueeze(0).expand(2, -1)

# The dimension is changed to 1 now as the leading dimension refers to the batch_size
input_length = model_input.size(dim=1)
print("the input_length is: ", input_length)

# predictions is the raw score output by the fulled connected layer
predictions = model(model_input)
# predictions is now a 3D tensor of size [batcg_size, seq_len (or input_length), num_tokens (or vocab_size)]
print(predictions)

# So now predicted_id is a 1D tensor of size [2] after argmax ([2, 5] => [2])
predicted_id = torch.argmax(predictions[:,-1,:], dim=-1)
print(predicted_id)


the input_length is:  4
the dimension of mask:  torch.Size([2, 4, 4])
the dimension of scaled_sims:  torch.Size([2, 4, 4])
the dimension of scaled_sims after masking:  torch.Size([2, 4, 4])
the dimension of attention_percents:  torch.Size([2, 4, 4])
the dimension of attention_percents:  torch.Size([2, 4, 2])
the dimension of position encoding:  torch.Size([2, 4, 2])
the dimension of the fc layer output:  torch.Size([2, 4, 5])
tensor([[[-2.0690,  1.2700,  0.7473,  1.8396, -0.7087],
         [-1.1143,  1.0249,  0.9409,  0.5370, -0.4872],
         [-0.5126,  0.8604,  0.9683, -0.2167, -0.3221],
         [ 0.9594,  0.3338, -0.1442, -1.2195,  0.4005]],

        [[-2.0690,  1.2700,  0.7473,  1.8396, -0.7087],
         [-1.1143,  1.0249,  0.9409,  0.5370, -0.4872],
         [-0.5126,  0.8604,  0.9683, -0.2167, -0.3221],
         [ 0.9594,  0.3338, -0.1442, -1.2195,  0.4005]]],
       grad_fn=<ViewBackward0>)
tensor([0, 0])


In [304]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | we        | Embedding        | 10     | train
1 | pe        | PositionEncoding | 0      | train
2 | attention | Attention        | 12     | train
3 | fc        | Linear           | 15     | train
4 | loss      | CrossEntropyLoss | 0      | train
-------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

[tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0'), tensor([[1, 2, 4, 3, 4],
        [1, 0, 4, 3, 4]], device='mps:0')]
tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0')
the dimension of mask:  torch.Size([2, 5, 5])
the dimension of scaled_sims:  torch.Size([2, 5, 5])
the dimension of scaled_sims after masking:  torch.Size([2, 5, 5])
the dimension of attention_percents:  torch.Size([2, 5, 5])
the dimension of attention_percents:  torch.Size([2, 5, 2])
the dimension of position encoding:  torch.Size([2, 5, 2])
the dimension of the fc layer output:  torch.Size([2, 5, 5])
[tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0'), tensor([[1, 2, 4, 3, 4],
        [1, 0, 4, 3, 4]], device='mps:0')]
tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0')
the dimension of mask:  torch.Size([2, 5, 5])
the dimension of scaled_sims:  torch.Size([2, 5, 5])
the dimension of scaled_sims after masking:  torch.Size([2, 5, 5])
the dimension of att

`Trainer.fit` stopped: `max_epochs=60` reached.


[tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0'), tensor([[1, 2, 4, 3, 4],
        [1, 0, 4, 3, 4]], device='mps:0')]
tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0')
the dimension of mask:  torch.Size([2, 5, 5])
the dimension of scaled_sims:  torch.Size([2, 5, 5])
the dimension of scaled_sims after masking:  torch.Size([2, 5, 5])
the dimension of attention_percents:  torch.Size([2, 5, 5])
the dimension of attention_percents:  torch.Size([2, 5, 2])
the dimension of position encoding:  torch.Size([2, 5, 2])
the dimension of the fc layer output:  torch.Size([2, 5, 5])
[tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0'), tensor([[1, 2, 4, 3, 4],
        [1, 0, 4, 3, 4]], device='mps:0')]
tensor([[0, 1, 2, 4, 3],
        [2, 1, 0, 4, 3]], device='mps:0')
the dimension of mask:  torch.Size([2, 5, 5])
the dimension of scaled_sims:  torch.Size([2, 5, 5])
the dimension of scaled_sims after masking:  torch.Size([2, 5, 5])
the dimension of att

In [305]:
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])

# Put one more (duplicated) sample in the batch
# Note: -1 in the expand function means keep that dimension unchanged
model_input = model_input.unsqueeze(0).expand(2,-1)

# The dimension is changed to 1 now as the leading dimension refers to the batch_size
input_length = model_input.size(dim=1)
print("the input_length is: ", input_length)

# predictions is the raw score output by the fulled connected layer
predictions = model(model_input)
# predictions is now a 3D tensor of size [batch_size, seq_len (or input_length), num_tokens (or vocab_size)]
print(predictions)

# So now predicted_id is a 1D tensor of size [2] after argmax ([2, 5] => [2])
predicted_id = torch.argmax(predictions[:,-1,:], dim=-1)
print(predicted_id)


the input_length is:  4
the dimension of mask:  torch.Size([2, 4, 4])
the dimension of scaled_sims:  torch.Size([2, 4, 4])
the dimension of scaled_sims after masking:  torch.Size([2, 4, 4])
the dimension of attention_percents:  torch.Size([2, 4, 4])
the dimension of attention_percents:  torch.Size([2, 4, 2])
the dimension of position encoding:  torch.Size([2, 4, 2])
the dimension of the fc layer output:  torch.Size([2, 4, 5])
tensor([[[ -5.0525,   4.2033,  -4.8357,   0.4521,  -5.2534],
         [ 10.5997,   9.8727,  -1.4766, -11.3479,   1.2655],
         [ -3.9753,  17.5173, -13.3411,  -5.6371, -15.5727],
         [-24.4906, -19.2599,   2.0944,  21.8119,   0.3338]],

        [[ -5.0525,   4.2033,  -4.8357,   0.4521,  -5.2534],
         [ 10.5997,   9.8727,  -1.4766, -11.3479,   1.2655],
         [ -3.9753,  17.5173, -13.3411,  -5.6371, -15.5727],
         [-24.4906, -19.2599,   2.0944,  21.8119,   0.3338]]],
       grad_fn=<ViewBackward0>)
tensor([3, 3])
