In [6]:
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
import torch
import torch.nn as nn 
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

# !pip install lightning
import lightning as L


In [14]:
# Vocabulary for the toy example: every token the model can read or emit.
token_to_id = {
    "what": 0,
    "are": 1,
    "llm": 2,      # little language model
    "awesome": 3,
    "<EOS>": 4,    # end-of-sequence marker; follows the prompt and ends the response
}

# Reverse lookup used to turn predicted ids back into readable tokens.
id_to_token = {token_id: token for token, token_id in token_to_id.items()}

In [15]:
# Training data: each sequence is "<prompt> <EOS> <response>".
#   "what are llm" -> "awesome"
#   "llm are what" -> "awesome"
# Tokens come from both processing the prompt and generating the output, so the
# model is trained on the whole sequence.
training_sequences = [
    ["what", "are", "llm", "<EOS>", "awesome"],
    ["llm", "are", "what", "<EOS>", "awesome"],
]

# Inputs are the full token sequences.
inputs = torch.tensor(
    [[token_to_id[token] for token in sequence] for sequence in training_sequences]
)

# Labels are the next-token targets: each sequence shifted left by one, with a
# trailing <EOS> so the model learns to stop after the response.
labels = torch.tensor(
    [[token_to_id[token] for token in sequence[1:] + ["<EOS>"]] for sequence in training_sequences]
)

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=1)

In [16]:
class PositionEncoding(nn.Module):
    """Fixed sinusoidal position encoding that is added to word embeddings.

    A (max_len, d_model) lookup table is built once. Even columns hold
    sin(pos * 10000^(-2i/d_model)) and odd columns hold
    cos(pos * 10000^(-2i/d_model)), where 2i is the even column index.

    Args:
        d_model: number of word-embedding values per token.
        max_len: maximum number of tokens the Transformer can process.
    """

    def __init__(self, d_model=2, max_len=6):
        super().__init__()  # initialise nn.Module bookkeeping before adding buffers
        pe = torch.zeros(max_len, d_model)  # lookup table of position-encoding values

        # Column vector of positions 0..max_len-1, shape (max_len, 1).
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        # Even embedding indices 0, 2, 4, ... (the "2i" in the formula).
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column,
        # so use only as many frequencies as there are odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Registered as a buffer: not trained, but it follows the module through
        # .to()/.cuda() and is stored in the state_dict.
        self.register_buffer("pe", pe)

    def forward(self, word_embeddings):
        """Add position encodings to embeddings of shape (..., seq_len, d_model).

        The sequence axis is the second-to-last one, so both an unbatched
        (seq_len, d_model) tensor and a batched (batch, seq_len, d_model)
        tensor are supported.
        """
        seq_len = word_embeddings.size(-2)
        return word_embeddings + self.pe[:seq_len, :]
    



In [17]:
# Attention
# Create a class which does the matrix multiplication of queries, keys, and values
class Attention(nn.Module):
    """Single-head scaled dot-product attention.

    Args:
        d_model: number of embedding values per token; queries, keys and values
            keep this size.
    """

    def __init__(self, d_model=2):
        super().__init__()
        self.d_model = d_model

        # Weights that create the query (q), key (k) and value (v) for each token.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Token (row) and embedding (column) axes, counted from the end so that
        # inputs with leading batch dimensions are handled too. For 2-D inputs
        # these are the same axes as 0 and 1.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Compute attention values for every query token.

        Args:
            encodings_for_q, encodings_for_k, encodings_for_v: token encodings
                of shape (..., seq_len, d_model).
            mask: optional boolean tensor broadcastable to the score matrix
                (..., seq_len_q, seq_len_k). True marks positions that must not
                be attended to, for example future tokens or <PAD> tokens.

        Returns:
            Tensor of shape (..., seq_len_q, d_model).
        """
        # Create queries, keys and values from the encodings of each token.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Similarity between every query and every key: q @ k^T.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale by sqrt(d_k) so large embeddings don't saturate the softmax.
        # A Python float avoids allocating a (CPU) tensor on every call.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Replace masked scores with the most negative finite value for the
            # dtype so softmax gives them ~0 probability. A fixed -1e10 overflows
            # float16 under mixed precision.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        # Fraction of each token's value used in each query's attention value.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # Weighted sum of the values.
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores



In [18]:
class DecoderOnlyTransformer(L.LightningModule):
    """Minimal decoder-only Transformer.

    Token embedding + position encoding -> one causally masked self-attention
    layer with a residual connection -> a linear layer producing one logit per
    vocabulary token at every position.

    Args:
        num_tokens: vocabulary size (embedding rows and output logits).
        d_model: number of embedding values per token.
        max_len: maximum sequence length supported by the position encoding.
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=6):
        super().__init__()

        # Set the seed so weight initialisation is reproducible between runs.
        L.seed_everything(seed=42)

        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Return logits of shape (seq_len, num_tokens) for a 1-D tensor of token ids."""
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        # Causal mask: True strictly above the diagonal, i.e. for keys that come
        # after the query token. Those positions are excluded from attention.
        seq_len = token_ids.size(dim=0)
        mask = torch.ones(
            (seq_len, seq_len), dtype=torch.bool, device=self.device
        ).triu(diagonal=1)

        self_attention_values = self.self_attention(
            position_encoded, position_encoded, position_encoded, mask=mask
        )

        # Residual connection: add the attention output back to its input.
        residual_connection_values = position_encoded + self_attention_values

        return self.fc_layer(residual_connection_values)

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 0.1."""
        return Adam(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        """Mean next-token cross-entropy over every sequence in the batch."""
        input_tokens, labels = batch

        # forward() scores one unbatched sequence at a time, so evaluate each
        # sequence in the batch rather than silently training on only the first.
        losses = [
            self.loss(self(sequence_tokens), sequence_labels)
            for sequence_tokens, sequence_labels in zip(input_tokens, labels)
        ]
        loss = torch.stack(losses).mean()

        self.log("train_loss", loss, prog_bar=True, batch_size=input_tokens.size(0))
        return loss

In [20]:
## First, create a model from DecoderOnlyTransformer()
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)

## Now create the input for the transformer (the prompt followed by <EOS>),
## on the same device as the model.
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["are"],
                            token_to_id["llm"],
                            token_to_id["<EOS>"]], device=model.device)
input_length = model_input.size(dim=0)

## Total sequence length we allow while generating; it matches max_len above
## so the position-encoding table always has enough rows.
max_length = 6

## Inference only: no_grad() stops autograd from recording a graph we never
## backpropagate through.
with torch.no_grad():
    ## Now get predictions from the model
    predictions = model(model_input)
    ## NOTE: "predictions" is the output from the fully connected layer,
    ##      not a softmax() function. Since we select the item with the
    ##      largest value, argmax gives the same answer without softmax.
    ## ALSO NOTE: "predictions" has one row of predicted values per input
    ##      token. We only want the most recent prediction, so we take the
    ##      last row (index -1).
    predicted_id = predictions[-1, :].argmax().unsqueeze(0)
    ## predicted_ids collects every generated token id.
    predicted_ids = predicted_id

    ## Greedily predict output tokens until we get an <EOS> token.
    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax().unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

## Now print the predicted output phrase.
print("Predicted Tokens:\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Seed set to 42


Predicted Tokens:

	 <EOS>


In [21]:


## Train on the two example sequences for a fixed number of passes.
NUM_EPOCHS = 30
trainer = L.Trainer(max_epochs=NUM_EPOCHS)
trainer.fit(model=model, train_dataloaders=dataloader)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable param

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [22]:
## Ask the trained model the first question: "what are llm <EOS>".
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["are"],
                            token_to_id["llm"],
                            token_to_id["<EOS>"]], device=model.device)
input_length = model_input.size(dim=0)

## Greedy decoding without autograd tracking: keep appending the highest-scoring
## next token until the model emits <EOS> or the sequence reaches max_length.
with torch.no_grad():
    predictions = model(model_input)
    predicted_id = predictions[-1, :].argmax().unsqueeze(0)
    predicted_ids = predicted_id

    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax().unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 <EOS>


In [23]:
## Now let's ask the other question: "llm are what <EOS>".
model_input = torch.tensor([token_to_id["llm"],
                            token_to_id["are"],
                            token_to_id["what"],
                            token_to_id["<EOS>"]], device=model.device)
input_length = model_input.size(dim=0)

## Greedy decoding without autograd tracking: keep appending the highest-scoring
## next token until the model emits <EOS> or the sequence reaches max_length.
with torch.no_grad():
    predictions = model(model_input)
    predicted_id = predictions[-1, :].argmax().unsqueeze(0)
    predicted_ids = predicted_id

    for _ in range(input_length, max_length):
        if predicted_id.item() == token_to_id["<EOS>"]:  # the prediction is <EOS>, so we are done
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = predictions[-1, :].argmax().unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Predicted Tokens:

	 awesome
	 <EOS>
