In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

# Create training dataset

In [2]:
# nn.Embedding looks tokens up by integer index, so each token in the
# vocabulary gets an ID equal to its position in the list below.
token_to_id = {
    token: index
    for index, token in enumerate(["what", "is", "statquest", "awesome", "<EOS>"])
}

# Inverse lookup (ID -> token), used to turn the IDs produced by the
# transformer back into readable tokens.
id_to_token = {index: token for token, index in token_to_id.items()}


In [3]:
# Build the training data as token IDs, because the model embeds IDs rather
# than raw strings. Each input row is a prompt, the <EOS> that ends the
# prompt, and then the token generated in response to it.
inputs = torch.tensor([
    [token_to_id[token] for token in ("what", "is", "statquest", "<EOS>", "awesome")],
    [token_to_id[token] for token in ("statquest", "is", "what", "<EOS>", "awesome")],
])

# Each label row is the matching input row shifted left by one position and
# closed with <EOS>, i.e. the token that should be predicted at every step.
labels = torch.tensor([
    [token_to_id[token] for token in ("is", "statquest", "<EOS>", "awesome", "<EOS>")],
    [token_to_id[token] for token in ("is", "what", "<EOS>", "awesome", "<EOS>")],
])

# Pair every input row with its label row; the DataLoader keeps its defaults.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset=dataset)

# Position Encoding

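The standard position encoding, introduced in the paper **Attention Is All You Need**, is defined as:  
$PE_{(pos,\ 2i)} = \sin\left(pos / 10000^{2i / d_{model}}\right)$  
$PE_{(pos,\ 2i+1)} = \cos\left(pos / 10000^{2i / d_{model}}\right)$  
Here $pos$ is the position of a token in the sequence and $i$ indexes the sine/cosine pairs of embedding dimensions.  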


In [4]:
class PositionEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to word embeddings.

    The encoding table is computed once in ``__init__`` and stored as the
    buffer ``pe`` of shape ``[max_len, d_model]``:

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: number of embedding values per token.
        max_len: maximum number of positions (tokens) the table covers.
    """

    def __init__(self, d_model=2, max_len=6):

        super().__init__()

        # pe stands for position encoding: one row per position, one column per embedding value
        pe = torch.zeros(max_len, d_model)

        # position is a column matrix (2D) of size [max_len, 1], e.g. [[0.], [1.], [2.]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # Step is 2 because of the "2i" in the formula; this is a 1D tensor, e.g. [0., 2.],
        # holding one entry per sine/cosine pair of embedding columns
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # div_term is 1D with the same size as embedding_index; dividing the [max_len, 1]
        # position column by it broadcasts to one value per (position, pair)
        div_term = torch.tensor(10000.)**(embedding_index / d_model)

        # Even columns take the sine values, odd columns the cosine values.
        # When d_model is odd there is one fewer odd column than even columns, so the
        # cosine term only uses the first d_model // 2 divisors (a no-op for even d_model).
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term[: d_model // 2])

        # register_buffer needs the attribute name as its first argument; registering the
        # table as a buffer makes it available as self.pe and moves it with the module
        # (e.g. to a GPU) without making it a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, word_embeddings):
        """Return ``word_embeddings`` with the position encodings added.

        Args:
            word_embeddings: tensor whose first dimension is the number of tokens and
                whose last dimension is ``d_model`` (i.e. unbatched ``[seq_len, d_model]``).

        Returns:
            Tensor of the same shape as ``word_embeddings``.
        """

        # The number of tokens may be smaller than max_len, so only the first
        # word_embeddings.size(0) rows of the table are used.
        return word_embeddings + self.pe[:word_embeddings.size(0), :]
    


# Masked Self-Attention

In [5]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with an optional mask.

    Inputs are expected to be unbatched 2D tensors of shape ``[seq_len, d_model]``:
    dimension 0 (``row_dim``) indexes tokens and dimension 1 (``col_dim``) indexes
    embedding values.

    Args:
        d_model: number of embedding values per token; also the size of the
            query, key and value vectors.
    """

    def __init__(self, d_model=2):

        super().__init__()

        # Separate bias-free weight matrices for the queries, keys and values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Compute attention values for every query token.

        Args:
            encodings_for_q: encodings used to build the queries.
            encodings_for_k: encodings used to build the keys.
            encodings_for_v: encodings used to build the values.
            mask: optional boolean tensor broadcastable to the
                ``[num_queries, num_keys]`` similarity matrix; positions that are
                ``True`` are hidden from the corresponding query.

        Returns:
            Tensor with one row of attention values per query token.
        """

        # Each projection uses its own weight matrix (W_q for queries, W_k for keys,
        # W_v for values) so that all three sets of weights are used and trained.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Similarity between every query (rows) and every key (columns)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale by the square root of the key size (d_model)
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        # Hide later tokens so earlier tokens can't cheat.
        # Note: -1e9 is an approximation of negative infinity, so softmax gives ~0 there.
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Softmax over the columns turns each row of similarities into the percentages
        # of influence each token (in columns) has on the row's token
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # Weighted sum of the values using those percentages
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


# Decoder-only Transformer

In [None]:
class DecoderOnlyTransformer(L.LightningModule):

    def __init__(self, num_tokens, d_model, max_len):
        """Create the layers of the decoder-only transformer.

        Args:
            num_tokens: size of the vocabulary (number of distinct token IDs).
            d_model: number of embedding values per token.
            max_len: maximum number of positions passed to PositionEncoding.
        """

        super().__init__()

        # Token ID -> learned embedding vector of length d_model
        self.we = nn.Embedding(num_tokens, d_model)

        # Position information that gets added to the word embeddings
        self.pe = PositionEncoding(d_model, max_len)

        # Attention layer (used as masked self-attention by this model)
        self.attention = Attention(d_model)

        # Projects each d_model-sized vector to one score per vocabulary token
        self.fc = nn.Linear(d_model, num_tokens)

        # Cross entropy works on raw scores because it applies softmax itself
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Create a mask matrix for masking used in masked self attention
        mask = torch.ones()
        # Masked Self-Attention

        # Run the self attention values through a fully connected layer

        # Return the fully connected layer