# Replicate the Decoder-Only Transformers architecture with PyTorch + Lightning

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

## Setup device-agnostic code

In [2]:
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
device

'mps'

## Create a tokenizer

In [3]:
# Create a mapping from vocabs to numbers to mimic the tokenizer
token_to_id = {
    "what": 0,
    "is": 1,
    "robot": 2,
    "awesome": 3,
    "<EOS>": 4,
}

# Create a mapping from numbers back to vocabs for interpreting the output from the transformer later
id_to_token = dict(map(reversed, token_to_id.items()))
id_to_token

{0: 'what', 1: 'is', 2: 'robot', 3: 'awesome', 4: '<EOS>'}

## Create a training dataset

**Note:** the transformer processes all tokens in parallel (only) during training (or when using teacher forcing)

In [4]:
inputs = torch.tensor([
    [
        token_to_id["what"],
        token_to_id["is"],
        token_to_id["robot"],
        token_to_id["<EOS>"],
        token_to_id["awesome"]
    ],
    
    [
        token_to_id["robot"],
        token_to_id["is"],
        token_to_id["what"],
        token_to_id["<EOS>"],
        token_to_id["awesome"],
    ]
])

labels = torch.tensor([
    [
        token_to_id["is"],
        token_to_id["robot"],
        token_to_id["<EOS>"],
        token_to_id["awesome"],
        token_to_id["<EOS>"]
    ],

    [
        token_to_id["is"],
        token_to_id["what"],
        token_to_id["<EOS>"],
        token_to_id["awesome"],
        token_to_id["<EOS>"],
    ]
])

## Create position encoding

We will use the formula used in the paper *Attention is all you need* position encoding is:
* PE_(pos, 2i) = sin(pos / 10000^(2i / d_model))
* PE_(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

In [5]:
class PositionEncoding(nn.Module):

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        pe = torch.zeros(max_len, d_model)

        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        div_term = torch.tensor(10000)**(embedding_index / d_model)

        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)

        self.register_buffer("pe", pe)
    
    def forward(self, word_embeddings):

        return word_embeddings + self.pe[:word_embeddings.size(0), :]

## Create masked Self-Attention layers

In [6]:
class Attention(nn.Module):

    def __init__(self, d_model=2):

        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None, device=device):

        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity score between the queries and values
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity score with the square root of d_model
        scaled_sims = sims / torch.tensor((k.size(self.col_dim))**0.5)

        # Mask the scaled similarity scores of the later tokens so that the earlier tokens can't cheat during training
        if mask is not None:
            # Move your mask to the target device because a manually created tensor lives in the cpu by default
            mask = mask.to(device)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)  # -1e9 is an approximation of negative infinity

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # attention_scores are the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

## Create a Decoder-only Transformer

In [7]:
class DecoderOnlyTransformer(L.LightningModule):

    def __init__(self, num_tokens, d_model, max_len):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Masked Self-Attention
        self.attention = Attention(d_model=d_model)

        # Fully Connected layer
        self.fc = nn.Linear(in_features=d_model, out_features=num_tokens)

        self.loss = nn.CrossEntropyLoss() # softmax is included

    def forward(self, token_ids):
        
        # print(f"The model is training on: {next(self.parameters()).device}")

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Create a mask matrix for masking used in masked self attention
        mask = torch.tril(torch.ones(token_ids.size(dim=0),token_ids.size(dim=0))) == 0 # the shape of mask is: [seq_len, seq_len]

        # Masked Self-Attention
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=mask)

        # Add residual connections
        residual_connection_values = position_encoded + self_attention_values

        # Run the residual connections through a fully connected layer
        fc_layer_out = self.fc(residual_connection_values)

        return fc_layer_out
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        # input_tokens comes in a batch: [batch_size, seq_len], with batch_size=1
        input_tokens, labels = batch
        outputs = self.forward(input_tokens[0])
        loss = self.loss(outputs, labels[0])

        return loss

In [78]:
# Test the model
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6).to(device) # move the model to mps

# Create a prompt
model_input = inputs[0].to(device)

input_length = model_input.size(dim=0)

predictions = model(model_input) # the logits (scores) for each vocab in the tokenizer
print(f"Model predictions (logits) before training:\n {predictions}")
print()
print(f"Model predictions (probabilities) before training:\n {torch.softmax(predictions, dim=1)}")

predictions_vocab = list(map(lambda x: id_to_token[x.item()], torch.argmax(torch.softmax(predictions, dim=1), dim=1).cpu()))
print(f"\nModel predictions (vocab) before training:\n {predictions_vocab}")

Model predictions (logits) before training:
 tensor([[-0.3656, -0.2974,  0.1898,  0.5190, -0.1002],
        [-0.8329, -0.4573,  0.0647,  0.4218, -0.2724],
        [ 0.2757,  0.4283, -0.2465,  0.1536, -0.4414],
        [-0.5160, -0.1763, -0.0577,  0.3177, -0.3525],
        [ 0.0559, -0.4981,  0.7168,  0.9466,  0.4487]], device='mps:0',
       grad_fn=<LinearBackward0>)

Model predictions (probabilities) before training:
 tensor([[0.1326, 0.1420, 0.2311, 0.3213, 0.1730],
        [0.0983, 0.1432, 0.2413, 0.3449, 0.1723],
        [0.2421, 0.2820, 0.1436, 0.2142, 0.1182],
        [0.1340, 0.1881, 0.2118, 0.3083, 0.1577],
        [0.1346, 0.0774, 0.2607, 0.3280, 0.1994]], device='mps:0',
       grad_fn=<SoftmaxBackward0>)

Model predictions (vocab) before training:
 ['awesome', 'awesome', 'is', 'awesome', 'awesome']


## Train the model with `Lightning.Trainer.fit()`

In [79]:
### Create dataloader
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)
len(dataloader)

2

In [None]:
# Create an instance of trainer
trainer = L.Trainer(max_epochs=30,
                    accelerator="auto")
trainer.fit(model,
            train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | we        | Embedding        | 10     | train
1 | pe        | PositionEncoding | 0      | train
2 | attention | Attention        | 12     | train
3 | fc        | Linear           | 15     | train
4 | loss      | CrossEntropyLoss | 0      | train
-------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


## Test the model after training with both training samples

In [86]:
model.to(device) # the model will be moved to cpu after calling trainer.fit() by default
model.device

device(type='mps')

In [94]:
# Convert the token ids in the training samples to vocabs
inputs_vocab = []
for input in inputs:
    input_vocab = list(map(lambda x: id_to_token[x.item()], input))
    inputs_vocab.append(input_vocab)

print(inputs_vocab)

[['what', 'is', 'robot', '<EOS>', 'awesome'], ['robot', 'is', 'what', '<EOS>', 'awesome']]


In [104]:
# Test the model after training with both training samples
preds_vocab=[]

for input in inputs:
    # Make a prediction with the model
    pred = torch.argmax(model(input.to(device)), dim=1)
    # Convert the token ids back to vocabs
    preds_vocab.append(list(map(lambda x: id_to_token[x.item()], pred)))

# print(preds_vocab)

for i, input_pred in enumerate(zip(inputs_vocab, preds_vocab)):
    print(f"Training sample {i}:\n")
    
    input, pred = input_pred
    for j in range(len(input)):
        print(f"\tInput: {input[j]} | Pred: {pred[j]}")
    print()

Training sample 0:

	Input: what | Pred: is
	Input: is | Pred: robot
	Input: robot | Pred: <EOS>
	Input: <EOS> | Pred: awesome
	Input: awesome | Pred: <EOS>

Training sample 1:

	Input: robot | Pred: is
	Input: is | Pred: what
	Input: what | Pred: <EOS>
	Input: <EOS> | Pred: awesome
	Input: awesome | Pred: <EOS>

