# Building a Decoder-Only Transformer

We will build a decoder-only transformer, the type of model that ChatGPT is built on.

For this example, all we want is for the transformer to respond to two different prompts:
* What is StatQuest? --> Answer: Awesome! 
* StatQuest is what? --> Answer: Awesome!

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

from torch.optim import Adam 
from torch.utils.data import TensorDataset, DataLoader

import lightning as L 

These dictionaries will make it easy to format the input to the Transformer and interpret the output from the Transformer.

In [2]:
token_to_id = {'what':0,
               'is':1,
               'statquest':2,
               'awesome':3,
               '<EOS>':4}

id_to_token = dict(map(reversed, token_to_id.items()))

Convert the prompts and responses into a dataset. 
![](https://github.com/kcsanjeeb/AI_Codes/blob/main/public/transformer-1.png?raw=true)


# Input and Labels

In [3]:
# Inputs and labels for the two training examples
# First prompt: What is StatQuest? --> Awesome
# Second prompt: StatQuest is what? --> Awesome

# Input tensor (2 examples, sequence_length=5)
inputs = torch.tensor([
    [token_to_id["what"], token_to_id["is"], token_to_id["statquest"], token_to_id["<EOS>"], token_to_id["awesome"]],
    [token_to_id["statquest"], token_to_id["is"], token_to_id["what"], token_to_id["<EOS>"], token_to_id["awesome"]]
])

# Labels tensor: the next token at each input position (same shape as inputs)
labels = torch.tensor([
    [token_to_id["is"], token_to_id["statquest"], token_to_id["<EOS>"], token_to_id["awesome"], token_to_id["<EOS>"]],
    [token_to_id["is"], token_to_id["what"], token_to_id["<EOS>"], token_to_id["awesome"], token_to_id["<EOS>"]]
])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)  # batch_size defaults to 1: one example per training step
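The labels are the inputs shifted one position to the left, with `<EOS>` appended, so that at every position the model is trained to predict the next token. A minimal sketch of building both tensors from a prompt and a response; the helper name `make_example` is hypothetical and not part of the original notebook:

```python
import torch

token_to_id = {'what': 0, 'is': 1, 'statquest': 2, 'awesome': 3, '<EOS>': 4}


def make_example(prompt, response):
    """Hypothetical helper: build (input ids, label ids) for next-token prediction."""
    # The input is the prompt, an <EOS> separator, then the response
    sequence = prompt + ['<EOS>'] + response
    ids = [token_to_id[t] for t in sequence]
    # Each label is the token that follows the input at that position;
    # the final label is <EOS>, marking the end of the response
    label_ids = ids[1:] + [token_to_id['<EOS>']]
    return ids, label_ids


pairs = [(['what', 'is', 'statquest'], ['awesome']),
         (['statquest', 'is', 'what'], ['awesome'])]
examples = [make_example(p, r) for p, r in pairs]
inputs = torch.tensor([x for x, _ in examples])
labels = torch.tensor([y for _, y in examples])
print(inputs)
print(labels)
```

The resulting tensors should match the hand-written `inputs` and `labels` above, which makes it easy to add more prompt/response pairs without writing each label by hand.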

### Positional Encoding

#### The Original Equation
$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$ 
$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$ 

Where:
* `pos` is the position of the token in the sequence
* `i` indexes the pairs of embedding dimensions: dimension `2i` uses sine and dimension `2i+1` uses cosine
* `d_model` (dimension of the model) is the number of word-embedding values per token

The alternating sin/cos pattern allows the model to learn relative positions through linear transformations.
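To see why (a standard trigonometric argument, added here for illustration), write $\omega_i = 1/10000^{2i/d_{\text{model}}}$ for the frequency shared by dimensions $2i$ and $2i+1$, so that $PE_{(pos, 2i)} = \sin(\omega_i\, pos)$ and $PE_{(pos, 2i+1)} = \cos(\omega_i\, pos)$. The angle-addition identities $\sin(a+b) = \sin a\cos b + \cos a\sin b$ and $\cos(a+b) = \cos a\cos b - \sin a\sin b$ give, for any offset $k$:

$$ \begin{bmatrix} PE_{(pos+k, 2i)} \\ PE_{(pos+k, 2i+1)} \end{bmatrix} = \begin{bmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{bmatrix} \begin{bmatrix} PE_{(pos, 2i)} \\ PE_{(pos, 2i+1)} \end{bmatrix} $$

The matrix depends only on the offset $k$, not on $pos$, so the encoding $k$ positions later is the same linear transformation (a rotation) of the current encoding at every position.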

* `max_len` : Maximum number of tokens our Transformer can process (input & output combined)

In [4]:
class PositionEncoding(nn.Module):
    def __init__(self, d_model=2, max_len=6):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  
        position = torch.arange(start=0,end=max_len,step=1).float().unsqueeze(1)
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()
        div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        
    def forward(self,word_embeddings):
        return word_embeddings + self.pe[:word_embeddings.size(0), :]

* `pe = torch.zeros(max_len, d_model)`: We start by creating a matrix called `pe` (for position encodings) that is full of 0's. `pe` has `max_len` rows and `d_model` columns.
* `position = torch.arange(start=0,end=max_len,step=1).float().unsqueeze(1)`:
    * Creates a column matrix, `position`, that holds the position `pos` of each token
    * `torch.arange` creates the sequence of numbers 0, 1, ..., `max_len - 1`
    * `.float()` converts the numbers to floating point
    * `.unsqueeze(1)` turns the sequence into a column matrix with `max_len` rows and 1 column
* `embedding_index = torch.arange(start=0, end=d_model, step=2).float()`:
    * Creates a row of values, `embedding_index`, that holds `2i` for each sine/cosine pair: 0, 2, 4, ...
* `div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)`: Computes $1/10000^{2i/d_{\text{model}}}$ for each pair. Multiplying the column `position` by the row `div_term` broadcasts to a `max_len` × `d_model/2` matrix whose entries are $pos/10000^{2i/d_{\text{model}}}$, the argument of sin and cos in the equations above.
* `pe[:, 0::2] = torch.sin(position * div_term)`: Fills the even-numbered columns (0, 2, 4, ...) with sine values. With `d_model=2`, this is just the first column.
* `pe[:, 1::2] = torch.cos(position * div_term)`: Fills the odd-numbered columns (1, 3, 5, ...) with cosine values. With `d_model=2`, this is just the second column.
* `self.register_buffer('pe', pe)`: Stores `pe` as a buffer: it is not a trainable parameter, but it is saved with the model's state and moved to the GPU along with the rest of the model if we use one.
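As a sanity check, the vectorized code can be compared with a direct, one-entry-at-a-time evaluation of the original equations. The sketch below copies the `PositionEncoding` class so that it runs on its own, uses `d_model=4` (instead of the notebook's 2, so that more than one sine/cosine pair is exercised), and adds a hypothetical `loop_pe` helper for the comparison. It also shows that the registered buffer appears in the module's `state_dict` but not among its trainable parameters.

```python
import math

import torch
import torch.nn as nn


class PositionEncoding(nn.Module):
    # Copy of the notebook's class so this block runs on its own
    def __init__(self, d_model=2, max_len=6):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()
        div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        return word_embeddings + self.pe[:word_embeddings.size(0), :]


def loop_pe(d_model, max_len):
    """Hypothetical helper: evaluate the original equations one entry at a time."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(d_model // 2):
            angle = pos / 10000 ** (2 * i / d_model)
            table[pos][2 * i] = math.sin(angle)        # PE(pos, 2i)
            table[pos][2 * i + 1] = math.cos(angle)    # PE(pos, 2i+1)
    return torch.tensor(table)


module = PositionEncoding(d_model=4, max_len=6)
print(module.pe)  # one row per position, one column per embedding dimension
print(torch.allclose(module.pe, loop_pe(4, 6), atol=1e-6))

# The buffer is part of the saved state, but it is not a trainable parameter
print('pe' in module.state_dict(), len(list(module.parameters())))
```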


### Masked Self Attention 
* We need to calculate the queries, keys, and values for each token.

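* `self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)`:
    * `in_features` is the number of values that go into the layer for each token, so we set it to `d_model`
    * `out_features` is the number of values the layer produces for each token, so we set it to `d_model` as well
    * PyTorch stores the weight matrix with shape (`out_features`, `in_features`) and multiplies each input by its transpose; since both are `d_model` here, the matrix is square
    * `bias=False` means the layer only multiplies by the weights and adds no bias term
* `W_q`: the (initially untrained) weights used to calculate the **queries**
* `W_k`: the weights used to calculate the **keys**
* `W_v`: the weights used to calculate the **values**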
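Because `in_features` and `out_features` are both `d_model` in this notebook, the weight matrix is square, which hides how `nn.Linear` lays it out. This short check uses unequal sizes (2 in, 3 out, chosen only for illustration) to show the weight's shape and that, with `bias=False`, calling the layer is the same as multiplying by the transposed weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Unequal sizes (2 in, 3 out) make the weight layout visible; the notebook uses d_model for both
layer = nn.Linear(in_features=2, out_features=3, bias=False)
print(layer.weight.shape)   # torch.Size([3, 2]): (out_features, in_features)

# Four tokens, each with 2 position-encoded values
encodings = torch.randn(4, 2)
outputs = layer(encodings)
print(outputs.shape)        # torch.Size([4, 3]): one 3-value output per token

# With bias=False, the layer is exactly a matrix product with the transposed weight
print(torch.allclose(outputs, encodings @ layer.weight.T))
```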

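Encoder-decoder transformers also have **encoder-decoder attention**, in which the **keys** and **values** are calculated from the tokens encoded by the **encoder**, and the **queries** are calculated from the tokens encoded by the **decoder**. This is why the `Attention` class below takes separate inputs for the queries, keys, and values; in our decoder-only transformer, all three come from the same position-encoded tokens.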
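A rough sketch of the shapes involved in encoder-decoder attention, using random tensors and plain matrices in place of trained `nn.Linear` layers (the sizes and variable names are illustrative assumptions, not from the notebook, and no mask is applied). The number of output rows follows the queries, so each decoder token gets one output, computed from attention over all encoder tokens:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 2

# Illustrative stand-ins: 5 tokens encoded by the encoder, 3 tokens encoded by the decoder
encoder_encodings = torch.randn(5, d_model)
decoder_encodings = torch.randn(3, d_model)

# Random matrices stand in for the trained weights W_q, W_k, W_v
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

q = decoder_encodings @ W_q   # queries come from the decoder
k = encoder_encodings @ W_k   # keys come from the encoder
v = encoder_encodings @ W_v   # values come from the encoder

attention_percents = F.softmax(q @ k.T / math.sqrt(k.size(1)), dim=1)
attention_scores = attention_percents @ v

print(attention_percents.shape)  # torch.Size([3, 5]): each decoder token attends over 5 encoder tokens
print(attention_scores.shape)    # torch.Size([3, 2]): one d_model-sized output per decoder token
```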

#### Attention Equation
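$$ \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V} $$ 

Here $d_k$ is the number of values in each key (`k.size(1)` in the code below). In *masked* self-attention, the scaled similarities between each token and the tokens that come after it are replaced with a very large negative number before the softmax, so those future tokens receive essentially zero attention.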
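A small numeric illustration of the masked version of this equation, using random queries, keys, and values (for illustration only) and the same lower-triangular mask that `DecoderOnlyTransformer.forward` builds later in this notebook. After the softmax, each row sums to 1 and every entry above the diagonal, which corresponds to a future token, is 0:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 4, 2

# Random stand-ins for the queries, keys, and values of 4 tokens
q = torch.randn(seq_len, d_model)
k = torch.randn(seq_len, d_model)
v = torch.randn(seq_len, d_model)

# Same construction as the model: True marks a future token that must be hidden
mask = torch.tril(torch.ones(seq_len, seq_len)) == 0

scaled_sims = q @ k.T / k.size(1) ** 0.5
scaled_sims = scaled_sims.masked_fill(mask, -1e9)   # future positions get a huge negative score
attention_percents = F.softmax(scaled_sims, dim=1)  # each row becomes percentages summing to 1
attention_scores = attention_percents @ v

print(attention_percents)
```

Because the first token can only attend to itself, its output is simply its own value vector.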

In [5]:
class Attention(nn.Module):
    def __init__(self,d_model=2):
        super().__init__()
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)    
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)  
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)  
        
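        # Inputs are unbatched (tokens x d_model): rows are dim 0, columns are dim 1
        self.row_dim = 0 
        self.col_dim = 1
    
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # Project the encodings into queries, keys, and values
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        
        # Similarity between every query and every key: Q K^T
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by the square root of d_k, the number of values per key
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        
        # Masked (future) positions get a huge negative value, so softmax gives them ~0 weight
        if mask is not None: 
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        
        # Convert each row of similarities into percentages that sum to 1
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        
        # Each output is the attention-weighted sum of the values
        attention_scores = torch.matmul(attention_percents, v)
        
        return attention_scores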

In [6]:
class DecoderOnlyTransformer(L.LightningModule):
    def __init__(self, num_tokens=4, d_model=2, max_len=6):
        super().__init__()
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model,max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, token_ids):
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)
        
        # Causal mask: True above the diagonal, so each token can only attend to
        # itself and earlier tokens. It is created on the same device as the model.
        mask = torch.tril(torch.ones((token_ids.size(0), token_ids.size(0)), 
                                   device=self.device))
        mask = mask == 0
        
        self_attention_values = self.self_attention(position_encoded, 
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        residual_connection_values = position_encoded + self_attention_values
        fc_layer_output = self.fc_layer(residual_connection_values)
        return fc_layer_output
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        input_tokens, labels = batch 
        # batch_size is 1, so each batch holds one example; [0] drops the batch dimension
        output = self.forward(input_tokens[0])
        loss = self.loss(output, labels[0])
        return loss

* `num_tokens`: The number of tokens in the vocabulary.
* `d_model`: The number of word-embedding values we want to use to represent each token.
* `max_len`: The maximum length of the input plus the output.
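With `num_tokens=5` (the size of `token_to_id`) and `d_model=2`, the parameter counts that Lightning reports in the training summary further down follow directly from these settings. A quick check of that arithmetic, rebuilding only the layers that have parameters (the `count` helper is hypothetical, not part of the notebook):

```python
import torch.nn as nn

num_tokens, d_model = 5, 2   # len(token_to_id) and the d_model used in the notebook

word_embeddings = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
attention = [nn.Linear(d_model, d_model, bias=False) for _ in range(3)]   # W_q, W_k, W_v
fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)


def count(*modules):
    """Hypothetical helper: total number of parameter values in the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())


print(count(word_embeddings))   # num_tokens * d_model
print(count(*attention))        # 3 * d_model * d_model
print(count(fc_layer))          # d_model * num_tokens weights + num_tokens biases
print(count(word_embeddings, *attention, fc_layer))
```

The printed values are 10, 12, 15, and 37, matching the `we`, `self_attention`, `fc_layer`, and total rows of the summary; the position encoding and loss contribute no parameters.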

In [7]:
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6 )
model_input =  torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

predictions = model(model_input)
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id

max_length = 6 
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]):
        break
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens :\n")
for id in predicted_ids:
    print("\t", id_to_token[id.item()])

Predicted Tokens :

	 statquest
	 statquest
	 awesome


In [11]:
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/Users/sanjeeb/Desktop/Harbin Institute of Technology/Artificial Intelligence/AI_Codes/.venv/lib/python3.1

Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 183.86it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=30` reached.




# Using the Trained Model

In [12]:
model_input =  torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

predictions = model(model_input)
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id

max_length = 6 
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]):
        break
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens :\n")
for id in predicted_ids:
    print("\t", id_to_token[id.item()])

Predicted Tokens :

	 awesome
	 <EOS>


In [13]:
model_input =  torch.tensor([token_to_id["statquest"],
                            token_to_id["is"],
                            token_to_id["what"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

predictions = model(model_input)
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id

max_length = 6 
for i in range(input_length, max_length):
    if (predicted_id == token_to_id["<EOS>"]):
        break
    model_input = torch.cat((model_input, predicted_id))
    
    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens :\n")
for id in predicted_ids:
    print("\t", id_to_token[id.item()])

Predicted Tokens :

	 awesome
	 <EOS>
