# Decoder-Only Transformers

This notebook demonstrates how to build a very simple decoder-only Transformer from scratch. Decoder-only architectures, such as those powering large language models (LLMs) like ChatGPT, focus exclusively on the generative component of the Transformer. By working through this notebook, you'll see how these models process context and generate output—providing a foundation for understanding how modern LLMs operate under the hood.

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

# Create training dataset

In [30]:
# nn.Embedding only accepts integer indices, so give every word in the vocabulary an id,
# numbered in the order the words are listed
token_to_id = {
    token: index
    for index, token in enumerate(["what", "is", "statquest", "awesome", "<EOS>"])
}

# Invert the mapping so integer outputs of the transformer can be read back as words
id_to_token = {index: token for token, index in token_to_id.items()}


In [None]:
# Build the toy training set. nn.Embedding consumes integer ids, so every word is
# translated through token_to_id before it is stored in a tensor.
# Each input row is "<prompt> <EOS> <response>", i.e. the tokens seen while processing the
# prompt and then while generating the output. Each label row is the input row shifted
# left by one token (ending in <EOS>), so position t is trained to predict token t + 1.
inputs = torch.tensor([
    [token_to_id[word] for word in ("what", "is", "statquest", "<EOS>", "awesome")],
    [token_to_id[word] for word in ("statquest", "is", "what", "<EOS>", "awesome")],
])

labels = torch.tensor([
    [token_to_id[word] for word in ("is", "statquest", "<EOS>", "awesome", "<EOS>")],
    [token_to_id[word] for word in ("is", "what", "<EOS>", "awesome", "<EOS>")],
])

# Pair every input row with its label row; DataLoader is left at its default settings
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

# Position Encoding

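The standard position encoding, introduced in the paper **Attention Is All You Need**, is:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

$$PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

where $pos$ is the token's position in the sequence, $i$ indexes the (sin, cos) pairs of embedding values, and $d_{\text{model}}$ is the embedding size.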


In [32]:
class PositionEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to word embeddings.

    Even embedding columns receive sin(pos / 10000^(2i / d_model)) and odd
    columns the matching cos term. The table is computed once for positions
    0..max_len-1 and registered as a buffer, so it is not trained.

    Args:
        d_model: Number of values in each word embedding. Odd values are
            supported: the last, unpaired column then only gets a sine term.
        max_len: Maximum number of tokens that can be position-encoded.
    """

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        # pe stands for position encoding: one row per position, one column per embedding value
        pe = torch.zeros(max_len, d_model)

        # position is a column matrix (2D) of size [max_len, 1], e.g. [[0.], [1.], [2.]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # Step is set to 2 because of "2i" in the formula, e.g. [0., 2.]; each entry serves one sin/cos column pair
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # div_term is a 1D tensor with the same size as embedding_index
        div_term = torch.tensor(10000.)**(embedding_index / d_model)

        # Even columns: there are ceil(d_model / 2) of them, one per div_term entry
        pe[:, 0::2] = torch.sin(position / div_term)
        # Odd columns: there are only floor(d_model / 2) of them, so slice div_term to match.
        # For even d_model this slice is the whole div_term; for odd d_model it avoids a shape mismatch.
        pe[:, 1::2] = torch.cos(position / div_term[: d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Return word_embeddings with the encodings for their positions added.

        Args:
            word_embeddings: Tensor whose last two dimensions are
                [seq_len, d_model], e.g. [seq_len, d_model] or
                [batch_size, seq_len, d_model].

        Raises:
            ValueError: If seq_len is larger than the max_len used at construction.
        """
        seq_len = word_embeddings.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )

        # Only the first seq_len encodings are needed (the input may be shorter than max_len);
        # they broadcast across a leading batch dimension if one is present
        return word_embeddings + self.pe[:seq_len, :]
    


# Masked Self-Attention

In [33]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with an optional mask.

    Args:
        d_model: Size of the incoming encodings and of the query, key and value vectors.
    """

    def __init__(self, d_model=2):

        super().__init__()

        # Create the weights associated with the query, key and value values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # The layer is written for unbatched 2D inputs: rows are tokens, columns are embedding values
        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Return the attention-weighted values, one row per query token.

        Args:
            encodings_for_q, encodings_for_k, encodings_for_v: 2D tensors of size
                [seq_len, d_model] used to build the queries, keys and values.
            mask: Optional boolean tensor matching the [seq_len_q, seq_len_k]
                similarity matrix; True marks positions that must NOT be
                attended to. None disables masking.
        """
        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity score between the query values and key values
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity score by the square root of the key size (d_model)
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # The mask may have been created on the CPU while the model runs on an accelerator (e.g. mps)
            mask = mask.to(scaled_sims.device)
            # Hide later tokens so earlier tokens can't cheat; -1e9 approximates negative infinity
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Softmax turns the scaled similarities into the percentage of influence each token (column) has on each token (row)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # attention_scores are the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


# Decoder-only Transformer

In [34]:
class DecoderOnlyTransformer(L.LightningModule):
    """A minimal decoder-only Transformer trained with Lightning.

    Forward pass: word embeddings -> position encodings -> one masked
    self-attention layer with a residual connection -> a fully connected
    layer that scores every token in the vocabulary.

    Args:
        num_tokens: Vocabulary size, i.e. the number of distinct token ids.
        d_model: Size of the word embeddings (and of the attention vectors).
        max_len: Maximum sequence length covered by the position encodings.
        lr: Learning rate for Adam; defaults to 0.1, the previously hard-coded value.
    """

    def __init__(self, num_tokens, d_model, max_len, lr=0.1):

        super().__init__()

        self.lr = lr

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Masked Self-Attention
        self.attention = Attention(d_model=d_model)

        # Fully Connected layer mapping each contextualised embedding to one score per vocabulary token
        self.fc = nn.Linear(in_features=d_model, out_features=num_tokens)

        # Cross entropy applies softmax internally, so forward() returns raw scores
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Score the next token at every position of one unbatched sequence.

        Args:
            token_ids: 1D tensor of token ids of size [seq_len] (no batch
                dimension, unlike nn.LSTM which expects [seq_len, batch_size, input_size]).

        Returns:
            2D tensor of raw scores of size [seq_len, num_tokens].
        """
        seq_len = token_ids.size(dim=0)

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Mask of size [seq_len, seq_len]: True above the diagonal, so no token can attend
        # to a later token. Built on token_ids' device so it does not need to be copied later.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=token_ids.device)) == 0

        # Masked Self-Attention
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=mask)

        # Add residual connections
        residual_connection_values = position_encoded + self_attention_values

        # Map each position's [d_model] vector to [num_tokens] scores,
        # so fc_layer_out has size [seq_len, num_tokens]
        fc_layer_out = self.fc(residual_connection_values)

        return fc_layer_out

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Return the mean cross-entropy loss over the sequences in the batch."""
        # input_tokens and labels are 2D tensors of size [batch_size, seq_len]
        input_tokens, labels = batch

        # forward() handles one unbatched sequence, so score each sequence separately and average.
        # With a batch size of 1 this is exactly the single-sequence loss; larger batches are no
        # longer silently truncated to their first sequence.
        losses = [
            self.loss(self.forward(tokens), target)
            for tokens, target in zip(input_tokens, labels)
        ]
        return torch.stack(losses).mean()


In [35]:
# Before training, see what the randomly initialised model outputs, just for fun
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)

# Create a prompt: "what is statquest <EOS>"
model_input = torch.tensor([token_to_id[word] for word in ("what", "is", "statquest", "<EOS>")])

input_length = model_input.size(dim=0)

# Inference only: no_grad stops autograd from recording a graph we never backpropagate through
with torch.no_grad():
    # predictions holds the raw scores from the fully connected layer, size [input_length, num_tokens]
    predictions = model(model_input)

# Only the prediction at the last prompt position matters. keepdim=True keeps predicted_id a
# 1D tensor of size [1], so it can later be concatenated onto model_input and predicted_ids
predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
# Collects every predicted id generated during inference
predicted_ids = predicted_id


In [36]:
# Generate output one token at a time until <EOS> is predicted or max_len tokens are reached
max_len = 6
eos_id = token_to_id["<EOS>"]

with torch.no_grad():
    for _ in range(input_length, max_len):
        # Stop once the model predicts <EOS>; .item() turns the 1-element tensor into a plain int
        if predicted_id.item() == eos_id:
            break

        # Otherwise append the newly generated token to the input so the model has the full context
        model_input = torch.cat((model_input, predicted_id))
        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
        # Append (not prepend) so predicted_ids stays in generation order
        predicted_ids = torch.cat((predicted_ids, predicted_id))

# Print the output
print("Predicted Tokens:\n")
for token_id in predicted_ids:
    # Iterating a 1D tensor yields 0D tensors such as tensor(0), so .item() is needed for the lookup
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 awesome
	 awesome


In [None]:
# The untrained model's output is meaningless, so fit it to the toy dataset
# for 30 epochs before generating again
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | we        | Embedding        | 10     | train
1 | pe        | PositionEncoding | 0      | train
2 | attention | Attention        | 12     | train
3 | fc        | Linear           | 15     | train
4 | loss      | CrossEntropyLoss | 0      | train
-------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
/Users/edison/Git/pytorch-playground/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/ed

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [38]:
# Run the same code after training
# Test the model with the first prompt: "what is statquest <EOS>"
model_input = torch.tensor([token_to_id[word] for word in ("what", "is", "statquest", "<EOS>")])

input_length = model_input.size(dim=0)

# Inference only: no_grad stops autograd from recording a graph we never backpropagate through
with torch.no_grad():
    # predictions holds the raw scores from the fully connected layer, size [input_length, num_tokens]
    predictions = model(model_input)

# Only the prediction at the last prompt position matters. keepdim=True keeps predicted_id a
# 1D tensor of size [1], so it can later be concatenated onto model_input and predicted_ids
predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
# Collects every predicted id generated during inference
predicted_ids = predicted_id


In [39]:
# Generate output one token at a time until <EOS> is predicted or max_len tokens are reached
eos_id = token_to_id["<EOS>"]

with torch.no_grad():
    for _ in range(input_length, max_len):
        # Stop once the model predicts <EOS>; .item() turns the 1-element tensor into a plain int
        if predicted_id.item() == eos_id:
            break

        # Otherwise append the newly generated token to the input so the model has the full context
        model_input = torch.cat((model_input, predicted_id))
        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

# Print the output
print("Predicted Tokens:\n")
for token_id in predicted_ids:
    # Iterating a 1D tensor yields 0D tensors such as tensor(0), so .item() is needed for the lookup
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 <EOS>


In [40]:
# Test the model with the second prompt: "statquest is what <EOS>"
model_input = torch.tensor([token_to_id[word] for word in ("statquest", "is", "what", "<EOS>")])

input_length = model_input.size(dim=0)

# Inference only: no_grad stops autograd from recording a graph we never backpropagate through
with torch.no_grad():
    # predictions holds the raw scores from the fully connected layer, size [input_length, num_tokens]
    predictions = model(model_input)

# Only the prediction at the last prompt position matters. keepdim=True keeps predicted_id a
# 1D tensor of size [1], so it can later be concatenated onto model_input and predicted_ids
predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
# Collects every predicted id generated during inference
predicted_ids = predicted_id


In [41]:
# Generate output one token at a time until <EOS> is predicted or max_len tokens are reached
eos_id = token_to_id["<EOS>"]

with torch.no_grad():
    for _ in range(input_length, max_len):
        # Stop once the model predicts <EOS>; .item() turns the 1-element tensor into a plain int
        if predicted_id.item() == eos_id:
            break

        # Otherwise append the newly generated token to the input so the model has the full context
        model_input = torch.cat((model_input, predicted_id))
        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

# Print the output
print("Predicted Tokens:\n")
for token_id in predicted_ids:
    # Iterating a 1D tensor yields 0D tensors such as tensor(0), so .item() is needed for the lookup
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 <EOS>
