In [1]:
%%capture
# %%capture prevents this cell from printing a ton of STDERR stuff to the screen

## First, check to see if lightning is installed, if not, install it.
import importlib
import subprocess
import sys

try:
    importlib.import_module("lightning")
except ImportError:
    # `pip.main()` is not a supported API (it was removed from pip's public
    # interface in pip 10). The supported way to drive pip from a program is
    # to run it as a subprocess of the *current* interpreter, which also
    # guarantees the package lands in the environment this kernel is using.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lightning"])

In [2]:
import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module, nn.Embedding() and nn.Linear()
import torch.nn.functional as F # This gives us the softmax() and argmax()
from torch.optim import Adam # This is the optimizer we will use

import lightning as L # Lightning makes it easier to write, optimize and scale our code
from torch.utils.data import TensorDataset, DataLoader # We'll store our data in DataLoaders

----

# The input and output vocabularies and data

In [3]:
## Input (English) vocabulary: each token is mapped to its position in the list,
## so '<SOS>' (start of sequence) = 0, 'lets' = 1, 'to' = 2 and 'go' = 3.
input_vocab = {token: token_id
               for token_id, token in enumerate(['<SOS>', 'lets', 'to', 'go'])}

## Output (Spanish) vocabulary, built the same way:
## '<SOS>' = 0, 'ir' = 1, 'vamos' = 2, 'y' = 3 and '<EOS>' (end of sequence) = 4.
output_vocab = {token: token_id
                for token_id, token in enumerate(['<SOS>', 'ir', 'vamos', 'y', '<EOS>'])}

## The two training phrases, encoded with the input vocabulary.
inputs = torch.tensor([[input_vocab['lets'], input_vocab['go']],
                       [input_vocab['to'], input_vocab['go']]])

## The Spanish translations, encoded with the output vocabulary.
labels = torch.tensor([[output_vocab['vamos']],
                       [output_vocab['ir']]])

## Pair each input with its label; the DataLoader uses its default batch size of 1.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

Now that we have created the input and output datasets and the **DataLoader** that will feed them to the model during training, let's start building the model itself.

<a id="position"></a>
# Position Encoding

In [4]:
class PositionEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to token embeddings.

    A (max_len, d_model) table is computed once in __init__: even-numbered
    columns hold sin(position * frequency) and odd-numbered columns hold
    cos(position * frequency). The table is registered as a buffer, so it
    moves with the module across devices but is not a trainable parameter.

    Args:
        d_model: number of embedding values per token.
        max_len: maximum number of tokens (rows in the table).
    """

    def __init__(self, d_model=2, max_len=3):

        super().__init__()

        pe = torch.zeros(max_len, d_model)

        ## Column vector of positions: 0, 1, ..., max_len - 1
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        ## One frequency per sin/cos pair: 1 / 10000^(2i / d_model)
        div_term = 1/torch.tensor(10000.0)**(torch.arange(start=0, end=d_model, step=2).float() / d_model)

        pe[:, 0::2] = torch.sin(position * div_term)
        ## When d_model is odd there is one fewer cos column than sin columns,
        ## so trim the frequencies to match; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return x with the position encodings added.

        x is expected to have shape (..., seq_len, d_model). The sequence
        length is read from the second-to-last dimension, so both a single
        sequence (seq_len, d_model) and a batch (batch, seq_len, d_model)
        are handled; the table broadcasts across any leading dimensions.
        """
        return x + self.pe[:x.size(-2), :]

<a id="attention"></a>
# Attention

In [5]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention.

    Queries, keys and values are produced by three bias-free linear layers,
    each mapping d_model -> d_model. The same class serves self-attention
    (all three inputs identical) and encoder-decoder attention (queries from
    one source, keys and values from another).

    Args:
        d_model: number of embedding values per token.
    """

    def __init__(self, d_model=2):

        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        ## Index rows/columns from the end so the same code works for a single
        ## sequence (seq_len, d_model) and for a batch (batch, seq_len, d_model).
        ## For 2-D input, -2 and -1 are the same axes as 0 and 1.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Compute attention values.

        Args:
            encodings_for_q: tensor (..., q_len, d_model) used to create queries.
            encodings_for_k: tensor (..., k_len, d_model) used to create keys.
            encodings_for_v: tensor (..., k_len, d_model) used to create values.
            mask: optional boolean tensor broadcastable to (..., q_len, k_len);
                True marks positions that must NOT be attended to.

        Returns:
            Tensor of shape (..., q_len, d_model).
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        ## Similarity of every query with every key
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        ## Scale by sqrt(d_k). A plain Python float is enough here; there is
        ## no need to wrap it in a tensor.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            ## A very negative value becomes ~0 after softmax.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9) # I've also seen -1e20 and -9e15 used in masking

        ## Each row sums to 1: how much each query attends to each key
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

<a id="encoder"></a>
# The Encoder Class

In [6]:
class Encoder(nn.Module):
    """Encoder: word embedding -> position encoding -> self-attention + residual.

    Args:
        num_tokens: size of the input vocabulary.
        d_model: number of embedding values per token.
        max_len: maximum number of input tokens.
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=3):
        super().__init__()

        ## Seed before any layer is created so the random initial weights
        ## are the same on every run.
        L.seed_everything(seed=42)

        ## Layers are created in this order on purpose: changing the order
        ## would change which random numbers each layer receives.
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)

    def forward(self, token_ids):
        """Encode a sequence of input token ids and return the encoded values."""
        ## Embed the tokens and add the position encodings in one go
        encoded = self.pe(self.we(token_ids))

        ## Self-attention: queries, keys and values all come from the same input
        attended = self.self_attention(encoded, encoded, encoded)

        ## Residual connection
        return encoded + attended

<a id="decoder"></a>
# The Decoder Class

In [16]:
class Decoder(nn.Module):
    """Decoder: masked self-attention, encoder-decoder attention, then a
    fully connected layer that produces one score per output token.

    Args:
        num_tokens: size of the output vocabulary (also the number of
            scores produced for each position).
        d_model: number of embedding values per token.
        max_len: maximum number of output tokens.
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=3):

        super().__init__()

        ## Seed before any layer is created so the random initial weights
        ## are the same on every run.
        L.seed_everything(seed=43)

        self.we = nn.Embedding(num_embeddings=num_tokens,
                               embedding_dim=d_model)

        self.pe = PositionEncoding(d_model=d_model,
                                   max_len=max_len)

        self.self_attention = Attention(d_model=d_model)

        self.enc_dec_attention = Attention(d_model=d_model)

        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

        self.row_dim = 0
        self.col_dim = 1


    def forward(self, token_ids, encoder_values):
        """Return pre-softmax scores for each position in `token_ids`.

        Args:
            token_ids: output token ids generated so far (or the
                teacher-forcing tokens during training).
            encoder_values: the values returned by the encoder.

        Returns:
            One row of scores per token in `token_ids`, one column per
            output vocabulary entry.
        """
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        ## Causal mask: True above the diagonal, so a token cannot attend to
        ## tokens that come after it. The mask is created directly on the
        ## device of `position_encoded` to avoid building it on the CPU and
        ## copying it over on every forward pass.
        seq_len = token_ids.size(self.row_dim)
        mask = torch.tril(torch.ones((seq_len, seq_len),
                                     device=position_encoded.device)) == 0

        self_attention_values = self.self_attention(position_encoded,
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        residual_connection_values = position_encoded + self_attention_values

        ## Encoder-decoder attention: queries come from the decoder,
        ## keys and values come from the encoder.
        enc_dec_attention_values = self.enc_dec_attention(residual_connection_values,
                                                          encoder_values,
                                                          encoder_values)

        residual_connection_values = enc_dec_attention_values + residual_connection_values

        fc_layer_output = self.fc_layer(residual_connection_values)

        return fc_layer_output


<a id="transformer"></a>
# The Transformer Class

In [12]:
class Transformer(L.LightningModule):
    """Encoder-decoder transformer wrapped as a LightningModule.

    Args:
        input_size: number of tokens in the input vocabulary.
        output_size: number of tokens in the output vocabulary.
        d_model: number of embedding values per token.
        max_len: maximum sequence length for the encoder and the decoder.
        sos_token_id: id of the start-of-sequence token (0 in both
            vocabularies used in this notebook).
        eos_token_id: id of the end-of-sequence token in the output
            vocabulary (4 in this notebook).
    """

    def __init__(self, input_size, output_size, d_model=2, max_len=3,
                 sos_token_id=0, eos_token_id=4):

        super().__init__()

        ## Size the embeddings from the constructor arguments instead of
        ## reaching out to the notebook's global vocabularies, so the class
        ## does what its signature says and works with any vocabulary.
        self.encoder = Encoder(num_tokens=input_size, d_model=d_model, max_len=max_len)
        self.decoder = Decoder(num_tokens=output_size, d_model=d_model, max_len=max_len)

        self.sos_token_id = sos_token_id
        self.eos_token_id = eos_token_id

        self.loss = nn.CrossEntropyLoss()


    def forward(self, inputs, labels):
        """Run the encoder on `inputs`, then the decoder on `labels`.

        `labels` are the tokens fed INTO the decoder (the teacher-forcing
        sequence during training). Returns the decoder's pre-softmax scores.
        """
        encoder_values = self.encoder(inputs)
        output_presoftmax = self.decoder(labels, encoder_values)

        return output_presoftmax


    def configure_optimizers(self):
        """Optimize all parameters with Adam at a learning rate of 0.1."""
        return Adam(self.parameters(), lr=0.1)


    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training example.

        Only the first example of the batch is used (`input_i[0]`,
        `label_i[0]`), which matches a DataLoader with batch size 1.
        """
        input_i, label_i = batch  # collect input

        ## Create the special tokens on the same device as the batch
        sos_for_input = torch.tensor([self.sos_token_id], device=input_i.device)
        sos_for_label = torch.tensor([self.sos_token_id], device=label_i.device)
        eos_for_label = torch.tensor([self.eos_token_id], device=label_i.device)

        input_tokens = torch.cat((sos_for_input, input_i[0]))       # <SOS> + input
        teacher_forcing = torch.cat((sos_for_label, label_i[0]))    # <SOS> + known output
        expected_output = torch.cat((label_i[0], eos_for_label))    # known output + <EOS>

        # Forward pass
        output_i = self.forward(input_tokens, teacher_forcing)
        loss = self.loss(output_i, expected_output)

        return loss


In [17]:
## First, a reminder of our input and output vocabularies...
# input_vocab = {'<SOS>': 0, # Start
#                'lets': 1,
#                'to': 2,
#                'go': 3}

# output_vocab = {'<SOS>': 0, # Start
#                 'ir': 1,
#                 'vamos': 2,
#                 'y': 3,
#                 '<EOS>': 4} # End
max_length = 3

## Create a transformer object...
transformer = Transformer(len(input_vocab), len(output_vocab), d_model=2, max_len=max_length)

## This is inference only, so turn off gradient tracking: no autograd graph
## is built and the printed tensor does not carry gradient information.
with torch.no_grad():

    ## Encode the user input: <SOS> lets go
    encoder_values = transformer.encoder(torch.tensor([input_vocab['<SOS>'],
                                                       input_vocab['lets'],
                                                       input_vocab['go']]))

    ## Initialize predicted_ids with the <SOS> token
    predicted_ids = torch.tensor([output_vocab['<SOS>']])

    for step in range(max_length):

        prediction = transformer.decoder(predicted_ids, encoder_values)

        ## The last row holds the scores for the next token; pick the best one.
        predicted_id = torch.argmax(prediction[-1, :]).unsqueeze(0)
        ## add the predicted token id to the list of predicted ids.
        predicted_ids = torch.cat((predicted_ids, predicted_id))

        if predicted_id.item() == output_vocab['<EOS>']: ## if the prediction is <EOS> then we are done.
            break

print("\npredicted_ids:", predicted_ids)


predicted_ids: tensor([0, 2, 0, 1])


Without training, the transformer predicts **\<SOS> vamos \<SOS> ir**, but we wanted it to predict **\<SOS> vamos \<EOS>**. Since the untrained transformer didn't correctly translate the English phrase into Spanish, we'll have to train it.

<a id="train"></a>
# Train the Transformer

In [18]:
## Build a fresh, untrained transformer for the training run
transformer = Transformer(input_size=len(input_vocab),
                          output_size=len(output_vocab),
                          d_model=2,
                          max_len=3)

In [19]:
## Train for 30 passes over the training data
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=transformer, train_dataloaders=dataloader)

Training: |          | 0/? [00:00<?, ?it/s]

<a id="use"></a>
# Use the Trained Transformer

In [20]:
## First, a reminder of our input and output vocabularies...
# input_vocab = {'<SOS>': 0, # Start
#                'lets': 1,
#                'to': 2,
#                'go': 3}

# output_vocab = {'<SOS>': 0, # Start
#                 'ir': 1,
#                 'vamos': 2,
#                 'y': 3,
#                 '<EOS>': 4} # End

max_length = 3
## NOTE: row_dim and col_dim are not used by the code in this cell; they are
## kept only in case other cells rely on them.
row_dim = 0
col_dim = 1

## This is inference only, so turn off gradient tracking: no autograd graph
## is built and the printed tensor does not carry gradient information.
with torch.no_grad():

    ## Encode the user input: <SOS> lets go
    encoder_values = transformer.encoder(torch.tensor([input_vocab['<SOS>'],
                                                       input_vocab['lets'],
                                                       input_vocab['go']]))

    ## Start the output with the <SOS> token
    predicted_ids = torch.tensor([output_vocab['<SOS>']])

    for step in range(max_length):

        prediction = transformer.decoder(predicted_ids, encoder_values)

        ## The last row holds the scores for the next token; pick the best one.
        predicted_id = torch.argmax(prediction[-1, :]).unsqueeze(0)

        predicted_ids = torch.cat((predicted_ids, predicted_id))

        if predicted_id.item() == output_vocab['<EOS>']: ## stop once <EOS> is predicted
            break

print("\npredicted_ids:", predicted_ids)


predicted_ids: tensor([0, 2, 4])


And the output is **\<SOS> vamos \<EOS>**, which is exactly what we want.