# Encoder–Decoder Transformers

This notebook demonstrates how to build a very simple encoder–decoder Transformer from scratch, following the foundational ideas of the paper “Attention Is All You Need.” The encoder–decoder architecture underpins many modern machine translation systems, as well as diverse NLP tasks. By working through this notebook, you’ll see how the encoder processes the input to capture essential context, and how the decoder then generates the output—providing a solid grounding in how this classic Transformer design operates under the hood.

In [190]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

# Create training dataset


In [191]:
# Create two mappings (for encoder and decoder) from vocabs to numbers as nn.Embedding can only take integers
token_to_id_eng = {"<SOS>": 0,
                   "let's": 1,
                   "go": 2,
                   "love": 3,
                   "you": 4,
                   "<EOS>": 5}

token_to_id_spa = {"<SOS>": 0,
                   "ir" : 1,
                   "vamos": 2,
                   "te": 3,
                   "amo": 4,
                   "<EOS>": 5}
                  
# Create a mapping from numbers back to spanish vocabs in order to interpret the output from the transformer
id_to_token_spa = dict(map(reversed, token_to_id_spa.items()))

# This is actually not needed
id_to_token_eng = dict(map(reversed, token_to_id_eng.items()))


In [192]:
# Create the training pytorch dataset
# As the input is going to be word embeddings, we only need the corresponding numbers from the mapping

inputs = torch.tensor([[token_to_id_eng["let's"],
                        token_to_id_eng["go"],
                        token_to_id_eng["<EOS>"],
                        token_to_id_spa["<SOS>"],
                        token_to_id_spa["ir"],
                        token_to_id_spa["vamos"]],
                        
                       [token_to_id_eng["love"],
                        token_to_id_eng["you"],
                        token_to_id_eng["<EOS>"],
                        token_to_id_spa["<SOS>"],
                        token_to_id_spa["te"],
                        token_to_id_spa["amo"]]])

labels = torch.tensor([[token_to_id_spa["ir"],
                        token_to_id_spa["vamos"],
                        token_to_id_spa["<EOS>"]], 
                         
                       [token_to_id_spa["te"],
                        token_to_id_spa["amo"],
                        token_to_id_spa["<EOS>"]]])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

# Position Encoding

The formula for the (standard, used in the paper **Attention is all you need**) position encoding is:  
PE_(pos, 2i) = sin(pos / 10000^(2i / d_model))  
PE_(pos, 2i+1) = cos(pos / 10000^(2i / d_model))  


In [193]:
class PositionEncoding(nn.Module):

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        # pe stands for position encoding
        pe = torch.zeros(max_len, d_model)

        # position is a column matrix (2D) of size [max_len, 1], e.g. [[0.], [1.], [2.]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # Step is set to 2 because of "2i" in the formula, note that it is a 1D tensor, e.g. [0., 2.] as each position can have multiple embedding values
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # div_term is a row matrix (1D) with the same size as embedding_index
        div_term = torch.tensor(10000.)**(embedding_index / d_model)

        # Note: calculating the sin and cos values in this way only works when d_model is an even number, if d_model is odd, there will be a shape mismatch
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)

        self.register_buffer('pe', pe)
    

    def forward(self, word_embeddings):

        # Note: we might not need all the position encodings, as the number of tokens might not hit the maximum length (max_len)
        return word_embeddings + self.pe[:word_embeddings.size(0), :]
    


# Attention (including encoder/decoder (masked) self-attention and encoder-decoder attention)

In [194]:
class Attention(nn.Module):

    def __init__(self, d_model=2):

        super().__init__()

        # Create the weights associated with the query, key and value values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):

        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity score between the query values and key values
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity score with the square root of d_model
        scaled_sims = sims / torch.tensor((k.size(self.col_dim))**0.5)

        device = scaled_sims.device
        # Move your mask to mps:0, or mask would live in cpu by default
        
        # Mask the scaled similarity scores of the later tokens so that the earlier tokens can't cheat. Note: -1e9 is an approximation of negative infinity
        if mask is not None:
            mask = mask.to(device)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Applying the softmax function to the scaled similarites determines the percentages of influence each token (in columns) should have on the others (in rows)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # attention_scores are basically the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


# Encoder

In [195]:
class Encoder(nn.Module):

    def __init__(self, num_tokens, d_model, max_len):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Encoder Self-Attention
        self.attention = Attention(d_model=d_model)

        # Fully Connected layer, it's commented out here for simplicity
        # self.fc = nn.Linear(in_features=d_model, out_features=d_model)

    # The size of token_ids just needs to be a 1D tensor (without batching)
    def forward(self, token_ids):

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Self-Attention
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=None)

        # Add residual connections
        residual_connection_values = position_encoded + self_attention_values

        # Run the residual connections through a fully connected layer
        # fc_layer_out = self.fc(residual_connection_values)

        # Return the residual connections
        return residual_connection_values


# Decoder

In [196]:
class Decoder(nn.Module):

    # Actually, we can further categorise these parameters (encoder- and decoder-)
    def __init__(self, num_tokens, d_model, max_len):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Decoder Self-Attention
        self.self_attention = Attention(d_model=d_model)

        # Encoder-Decoder Attention
        # Note: it needs to be separated from the decoder self-attention as the weights are different
        self.encoder_decoder_attention = Attention(d_model=d_model)

    # The size of token_ids just needs to be a 1D tensor (without batching)
    def forward(self, token_ids, embeddings_for_k, embeddings_for_v):
        
        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Create a mask matrix for masking used in decoder self attention
        # The size of mask should be [decoder_seq_len, decoder_seq_len]
        mask = torch.tril(torch.ones(token_ids.size(dim=0),token_ids.size(dim=0))) == 0

        # Decoder Self-Attention
        self_attention_values = self.self_attention(position_encoded,
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        # Add residual connections
        residual_connection_values = position_encoded + self_attention_values

        # Encdoer-Decoder Attention
        encoder_decoder_attention_values = self.encoder_decoder_attention(position_encoded,
                                                                          embeddings_for_k,
                                                                          embeddings_for_v,
                                                                          mask=None)
        # Add residual connections
        residual_connection_values = encoder_decoder_attention_values + residual_connection_values

        return residual_connection_values


# Transformer

In [205]:
class Transformer(L.LightningModule):

    # Note: these parameters can actually be further categorised as encoder's and decoder's, but we keep it simple here
    def __init__(self, num_tokens, d_model, max_len):

        super().__init__()

        # Encoder
        self.encoder = Encoder(num_tokens=num_tokens, d_model=d_model, max_len=max_len)

        # Decoder
        self.decoder = Decoder(num_tokens=num_tokens, d_model=d_model, max_len=max_len)

        # Fully Connected layer
        self.fc = nn.Linear(in_features=d_model, out_features=num_tokens)

        # Calculate the loss with Cross Entropy; softmax is already included
        self.loss = nn.CrossEntropyLoss()

        # For the purpose of token generation
        self.max_len = max_len

    def forward(self, token_ids):

        # Note: token_ids here include both english and spanish token ids
        # So we need to find where the english sentence ends first, note that [0] is needed or it will return a tuple instead
        end_idx = torch.where(token_ids == token_to_id_spa["<SOS>"])[0].item()
        
        # Get the contextualised embeddings from the encoder
        contextualised_embeddings = self.encoder(token_ids[:end_idx])

        # Get the residual connection values from the decoder for further processing
        residual_connection_values = self.decoder(token_ids[end_idx:],
                                                  contextualised_embeddings,
                                                  contextualised_embeddings)
        
        # Run the residual connections through a fully connected layer
        fc_layer_out = self.fc(residual_connection_values)

        # Return the fully connected layer
        return fc_layer_out
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        
        # input_tokens is a 2D tensor of size [batch_size, seq_len]
        input_tokens, labels = batch
        # print(batch, '\n')
        # print(input_tokens[0])
        # outputs is fc_layer_out, so they share the same size
        outputs = self.forward(input_tokens[0])
        # Cross Entropy loss will automatically apply softmax to the outputs
        loss = self.loss(outputs, labels[0])

        return loss
    
    # Inference
    def generate(self, src_token_ids):

        # Encode the source language sentence
        contextualised_embeddings = self.encoder(src_token_ids)

        # Remind the decoder to start decoding by feeding it a <SOS> token
        # Note: target_token_ids needs to be 1D
        target_token_ids = torch.tensor([token_to_id_spa["<SOS>"]])

        # Start decoding
        residual_connection_values = self.decoder(target_token_ids,
                                                  contextualised_embeddings,
                                                  contextualised_embeddings)
        
        # Run the residual connections through a fully connected layer
        fc_layer_out = self.fc(residual_connection_values)

        # Run the output through a argmax layer (softmax is not needed anymore as no derivatives are required here) and then apply argmax to it to get the prediction
        # Note: we need to make prediction_id a 1D tensor for the concatenation later
        prediction_id = torch.tensor([torch.argmax(fc_layer_out)])
        # prediction_ids = prediction_id

        # Now we can start predicting the next token recursively
        for i in range(self.max_len-1):

            # Include the newly generated token into the input first so the decoder has full context
            target_token_ids = torch.cat((target_token_ids, prediction_id))

            # Check if the newly prediction_id is actually the one for <EOS>, if so, break the loop
            if (prediction_id == token_to_id_spa["<EOS>"]):
                break
            
            # Start decoding
            residual_connection_values = self.decoder(target_token_ids,
                                                      contextualised_embeddings,
                                                      contextualised_embeddings)
            
            # Run the residual connections through a fully connected layer
            # Note we don't to predict the first token again from <SOS>, so we just to pass the last attention output values (token) to the fc layer
            fc_layer_out = self.fc(residual_connection_values[-1:])
            prediction_id = torch.tensor([torch.argmax(fc_layer_out)])

        
        # Print the output
        print("Predicted Tokens:\n")
        for id in target_token_ids:
            print("\t", id_to_token_spa[id.item()])


# Training (teacher forcing)

In [206]:
model = Transformer(num_tokens=len(token_to_id_eng), d_model=2, max_len=5)

In [207]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | Encoder          | 24     | train
1 | decoder | Decoder          | 36     | train
2 | fc      | Linear           | 18     | train
3 | loss    | CrossEntropyLoss | 0      | train
-----------------------------------------------------
78        Trainable params
0         Non-trainable params
78        Total params
0.000     Total estimated model params size (MB)
/Users/edison/Git/pytorch-playground/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/edison/Git/pytorch-playground/myenv/lib/python3.11/site-packages/l

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


# Inference (autoregressive)

In [208]:
# Test case 1
test_input = torch.tensor([token_to_id_eng["let's"],
                           token_to_id_eng["go"]])

model.generate(test_input)

Predicted Tokens:

	 <SOS>
	 ir
	 vamos
	 <EOS>


In [209]:
# Test case 2
test_input = torch.tensor([token_to_id_eng["love"],
                           token_to_id_eng["you"]])

model.generate(test_input)

Predicted Tokens:

	 <SOS>
	 te
	 amo
	 <EOS>


In [210]:
# Test case 3
test_input = torch.tensor([token_to_id_eng["go"],
                           token_to_id_eng["love"],
                           token_to_id_eng["you"]])

model.generate(test_input)

Predicted Tokens:

	 <SOS>
	 amo
	 te
	 vamos
	 <EOS>
