# Replicate the Encoder-Only Transformer architecture with PyTorch + Lightning

In [176]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

## Set up device-agnostic code

In [177]:
# Pick the best available backend: CUDA first, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'mps'

## Create a tokenizer

In [178]:
# Toy tokenizer: each vocabulary entry gets the integer id of its position in this list
vocabulary = [
    "<CLS>",  # mimic the special class token in BERT
    "<EOS>",  # <SEP> in BERT
    "machine",
    "learning",
    "i",
    "hate",
    "is",
    "fun",
]
token_to_id = {token: token_id for token_id, token in enumerate(vocabulary)}

# Inverse lookup (id -> token), used later to turn token ids back into readable text
id_to_token = {token_id: token for token, token_id in token_to_id.items()}
id_to_token

{0: '<CLS>',
 1: '<EOS>',
 2: 'machine',
 3: 'learning',
 4: 'i',
 5: 'hate',
 6: 'is',
 7: 'fun'}

## Create a training dataset

In [179]:
# The two training sentences, written as tokens and wrapped in <CLS> ... <EOS>
sentences = [
    ["<CLS>", "machine", "learning", "is", "fun", "<EOS>"],
    ["<CLS>", "i", "hate", "machine", "learning", "<EOS>"],
]

# Look every token up in the tokenizer to get one row of token ids per sentence
inputs = torch.tensor([[token_to_id[token] for token in sentence] for sentence in sentences])

# One class label per sentence (presumably 1 = positive, 0 = negative sentiment)
labels = torch.tensor([1, 0])
inputs, labels

(tensor([[0, 2, 3, 6, 7, 1],
         [0, 4, 5, 2, 3, 1]]),
 tensor([1, 0]))

## Create position encoding

We will use the sinusoidal position-encoding formulas from the paper *Attention Is All You Need*:
* PE_(pos, 2i) = sin(pos / 10000^(2i / d_model))
* PE_(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

In [180]:
class PositionEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to word embeddings.

    The encoding table is precomputed once for `max_len` positions and stored
    as a buffer, so it follows the module across devices but is not trained.

    Args:
        d_model: Size of each embedding vector. Odd sizes are supported; the
            last column is then a sine column without a matching cosine.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        pe = torch.zeros(max_len, d_model)

        # Column vector of positions, shape (max_len, 1)
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # The "2i" values of the formula: 0, 2, 4, ...
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        div_term = torch.tensor(10000)**(embedding_index / d_model)

        # Even columns hold the sines, odd columns hold the cosines
        pe[:, 0::2] = torch.sin(position / div_term)
        # With an odd d_model there is one cosine column fewer than sine columns
        pe[:, 1::2] = torch.cos(position / div_term[: d_model // 2])

        self.register_buffer("pe", pe)

    def forward(self, word_embeddings):
        """Return `word_embeddings` with the position encodings added.

        Args:
            word_embeddings: Tensor of shape (seq_len, d_model) or
                (..., seq_len, d_model) with seq_len <= max_len.

        Returns:
            Tensor with the same shape as `word_embeddings`.
        """
        # The sequence axis is second to last, which also covers batched input
        seq_len = word_embeddings.size(-2)

        return word_embeddings + self.pe[:seq_len, :]

## Create Self-Attention layers

In [181]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        d_model: Size of each input encoding and of the projected
            queries, keys and values.
    """

    def __init__(self, d_model=2):

        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Negative indices address the last two axes, so the same code works for
        # unbatched (seq_len, d_model) and batched (..., seq_len, d_model) input
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None, device=None):
        """Compute the attention output for the given encodings.

        Args:
            encodings_for_q: Encodings used to build the queries,
                shape (..., seq_len_q, d_model).
            encodings_for_k: Encodings used to build the keys,
                shape (..., seq_len_k, d_model).
            encodings_for_v: Encodings used to build the values,
                shape (..., seq_len_k, d_model).
            mask: Optional boolean tensor broadcastable to
                (..., seq_len_q, seq_len_k). Positions that are True are
                excluded from attention.
            device: Optional device to move the mask to. When None, the mask
                is moved to the device of the similarity scores.

        Returns:
            Tensor of shape (..., seq_len_q, d_model) holding the
            contextualised embeddings.
        """

        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity score between the queries and keys
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity score with the square root of d_model
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        # Mask the scaled similarity scores of the later tokens so that the earlier tokens can't cheat during training
        if mask is not None:
            # A manually created mask lives on the CPU by default; following the scores'
            # device keeps both tensors together wherever the model has been moved
            target_device = scaled_sims.device if device is None else device
            mask = mask.to(target_device)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)  # -1e9 is an approximation of negative infinity

        # Each row becomes a set of attention weights that sum to 1
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # attention_scores are the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

## Create an Encoder-only Transformer

In [182]:
class EncoderOnlyTransformer(L.LightningModule):
    """Minimal encoder-only transformer that classifies a token sequence.

    The pipeline is: word embedding -> position encoding -> unmasked
    self-attention -> residual connection -> linear classifier applied to the
    first token of the sequence.

    Args:
        num_tokens: Vocabulary size.
        d_model: Size of each embedding vector.
        max_len: Maximum sequence length supported by the position encoding.
    """

    def __init__(self, num_tokens, d_model, max_len):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Self-Attention
        self.attention = Attention(d_model=d_model)

        # Classifier head: maps the first token's encoding to 2 class logits
        self.cls = nn.Linear(in_features=d_model, out_features=2)

        # Calculate the loss with Cross Entropy; softmax is already included
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Return the class logits for one unbatched sequence.

        Args:
            token_ids: 1-D tensor of token ids for a single sequence.

        Returns:
            Tensor of shape (1, 2) with one logit per class.
        """

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Create Self-Attention layers
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=None) # no mask is needed

        # Add residual connections
        residual_connection_values = position_encoded + self_attention_values

        # Pass only the first token (the class token) to the classifier head
        fc_layer_out = self.cls(residual_connection_values[0].unsqueeze(dim=0))

        return fc_layer_out

    def configure_optimizers(self):
        """Return the Adam optimizer over all model parameters."""
        return Adam(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch.

        Args:
            batch: Tuple `(input_tokens, labels)`; `input_tokens` holds one
                row of token ids per sample and `labels` one class id per
                sample.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar loss tensor.
        """
        input_tokens, labels = batch

        # forward() handles a single sequence, so run every sample in the batch and
        # stack the logits; this keeps one row of logits per label for any batch size
        outputs = torch.cat([self.forward(tokens) for tokens in input_tokens], dim=0)
        loss = self.loss(outputs, labels)

        return loss

In [183]:
# Sanity-check the untrained model on the first training sample
model = EncoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)
model = model.to(device)  # the model and its input must live on the same device

first_sample = inputs[0].to(device)
predictions = model(first_sample)  # the logits (scores) for each category

print(f"Model predictions (logits) before training:\n {predictions}")
print()
print(f"Model predictions (probabilities) before training:\n {torch.softmax(predictions, dim=1)}")

Model predictions (logits) before training:
 tensor([[-0.0701, -0.9977]], device='mps:0', grad_fn=<LinearBackward0>)

Model predictions (probabilities) before training:
 tensor([[0.7166, 0.2834]], device='mps:0', grad_fn=<SoftmaxBackward0>)


## Train the model with `Lightning.Trainer.fit()`

In [184]:
### Create dataloader
# TensorDataset and DataLoader are already imported in the imports cell at the top

# Each dataset item is one (token_ids, label) pair
dataset = TensorDataset(inputs, labels)

# batch_size=1 (the DataLoader default) is spelled out on purpose:
# the model's forward pass handles a single unbatched sequence
dataloader = DataLoader(dataset, batch_size=1)
len(dataloader)

2

In [185]:
# Build the Lightning trainer (30 epochs, accelerator chosen automatically) and fit the model
trainer = L.Trainer(accelerator="auto", max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | we        | Embedding        | 16     | train
1 | pe        | PositionEncoding | 0      | train
2 | attention | Attention        | 12     | train
3 | cls       | Linear           | 6      | train
4 | loss      | CrossEntropyLoss | 0      | train
-------------------------------------------------------
34        Trainable params
0         Non-trainable params
34        Total params
0.000     Total estimated model params size (MB)
/Users/edison/Git/pytorch-lightning-deep-learning/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performan

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


## Test the model after training with both training samples

In [186]:
# By default the model ends up on the CPU once trainer.fit() returns, so move it back to the target device
model = model.to(device)
model.device

device(type='mps')

In [187]:
# Convert the token ids in the training samples back to vocabulary tokens.
# A comprehension avoids shadowing the builtin `input` with a loop variable.
inputs_vocab = [
    [id_to_token[token_id] for token_id in sample.tolist()]
    for sample in inputs
]

print(inputs_vocab)

[['<CLS>', 'machine', 'learning', 'is', 'fun', '<EOS>'], ['<CLS>', 'i', 'hate', 'machine', 'learning', '<EOS>']]


In [188]:
# Test the model after training with both training samples
with torch.no_grad():  # inference only, so skip building the autograd graph
    for i, (sample, sample_vocab) in enumerate(zip(inputs, inputs_vocab)):
        # The predicted class is the index of the larger logit
        pred = torch.argmax(model(sample.to(device)), dim=1).item()
        print(f"Training sample {i}: {sample_vocab} | Prediction: {pred}")

Training sample 0: ['<CLS>', 'machine', 'learning', 'is', 'fun', '<EOS>'] | Prediction: 1
Training sample 1: ['<CLS>', 'i', 'hate', 'machine', 'learning', '<EOS>'] | Prediction: 0
