## Transformer

### Architecture

In [85]:
import torch
import torch.nn as nn
import math
import numpy as np

#### `SelfAttention`

In [61]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected by three linear layers, split into ``heads`` chunks of
    width ``embed_size // heads``, attended independently per head, merged back
    to ``embed_size`` and passed through a final linear layer.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.embed_size == self.heads * self.head_dim
        ), "Embedding size needs to be divisible by heads"

        # Full-width projections; the per-head split happens in forward().
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend ``query`` over ``keys``/``values``.

        Args:
            values: (N, value_len, embed_size)
            keys:   (N, key_len, embed_size)
            query:  (N, query_len, embed_size)
            mask:   None, or a tensor broadcastable to
                    (N, heads, query_len, key_len); entries equal to 0 mark
                    positions that must not be attended to.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        batch = query.shape[0]

        def project(layer, tensor):
            # (N, length, embed_size) -> (N, length, heads, head_dim)
            projected = layer(tensor)
            return projected.reshape(batch, tensor.shape[1], self.heads, self.head_dim)

        v_heads = project(self.values, values)
        k_heads = project(self.keys, keys)
        q_heads = project(self.queries, query)

        # Per-head query/key dot products, scaled by sqrt(head_dim).
        # Result: (N, heads, query_len, key_len)
        scores = torch.einsum("nqhd,nkhd->nhqk", q_heads, k_heads)
        scores = scores / math.sqrt(self.head_dim)

        if mask is not None:
            # A very negative score becomes a zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float("-1e20"))

        # Normalise over the key axis.
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values per head: (N, query_len, heads, head_dim),
        # then merge the heads back into one embedding axis.
        context = torch.einsum("nhql,nlhd->nqhd", weights, v_heads)
        merged = context.reshape(batch, query.shape[1], self.embed_size)

        return self.fc_out(merged)

#### `TransformerBlock`: 

`SelfAttention` -> layerNorm -> Feed-Forward -> layerNorm

In [62]:
class TransformerBlock(nn.Module):
    """Attention sublayer followed by a position-wise feed-forward sublayer.

    Each sublayer output is added to its input (skip connection), layer
    normalised, and then passed through dropout.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # embed_size -> hidden_size -> embed_size with a ReLU in between.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run one block.

        Args:
            value, key: tensors attended over, (N, len, embed_size).
            query: tensor that attends and that feeds the skip connection,
                (N, query_len, embed_size).
            mask: forwarded unchanged to the attention module.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        # Sublayer 1: attention, residual on the query, norm, dropout.
        attended = self.attention(value, key, query, mask)
        normed_attention = self.norm1(attended + query)
        hidden = self.dropout(normed_attention)

        # Sublayer 2: feed-forward, residual, norm, dropout.
        normed_ffn = self.norm2(self.feed_forward(hidden) + hidden)
        return self.dropout(normed_ffn)

#### Encoder: `num_layers` of `TransformerBlock` 

In [73]:
class Encoder(nn.Module):
    """Transformer encoder: token + learned position embeddings, then a stack of blocks.

    The position table has ``max_length`` rows, so inputs to ``forward`` must
    not be longer than ``max_length`` tokens.
    """

    def __init__(
        self,
        src_vocab_size, # Size of the source vocabulary
        num_layers, # Number of TransformerBlocks
        max_length, # Maximum length of the sentence
        embed_size,
        heads,
        forward_expansion,
        dropout,
        device,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        # Kept for backward compatibility. forward() takes the device from its
        # input so the module keeps working after a later `.to(...)` / `.cuda()`.
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    forward_expansion,
                    dropout,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token indices.

        Args:
            x: integer tensor of token indices, shape (N, seq_length).
            mask: attention mask handed to every block (see SelfAttention).

        Returns:
            Tensor of shape (N, seq_length, embed_size).
        """
        N, seq_length = x.shape

        # Position indices 0, 1, ..., seq_length - 1 for every sentence, created
        # directly on the input's device instead of on CPU followed by a copy.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        # Add word embeddings and position embeddings
        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )
        # Shape of out: (N, seq_length, embed_size)

        # In the Encoder the query, key, value are all the same
        for layer in self.layers:
            out = layer(out, out, out, mask)
            # Shape of out: (N, seq_length, embed_size)

        return out

#### `DecoderBlock`

Masked `SelfAttention` -> layerNorm -> `TransformerBlock`

In [74]:
class DecoderBlock(nn.Module):
    """Self-attention over the target, then a TransformerBlock over the encoder output.

    The self-attention result is added to the block input, layer normalised
    and passed through dropout; that tensor becomes the query of the inner
    TransformerBlock, whose values and keys come from the encoder.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, forward_expansion, dropout
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_value, encoder_key, src_mask, trg_mask):
        """Run one decoder block.

        Args:
            x: target-side tensor, used as values, keys and query of the
                self-attention.
            encoder_value, encoder_key: tensors passed as values and keys to
                the inner TransformerBlock.
            src_mask: mask for the inner TransformerBlock.
            trg_mask: mask for the target self-attention.

        Returns:
            Output of the inner TransformerBlock.
        """
        # Target self-attention with residual, norm and dropout.
        self_attended = self.attention(x, x, x, trg_mask)
        residual = self_attended + x
        query = self.dropout(self.norm(residual))

        # Attend from the target query over the encoder's values and keys.
        return self.transformer_block(encoder_value, encoder_key, query, src_mask)

#### Decoder: `num_layers` of `DecoderBlock`

In [75]:
class Decoder(nn.Module):
    """Transformer decoder: embeddings, a stack of DecoderBlocks, and a vocabulary projection.

    The position table has ``max_length`` rows, so inputs to ``forward`` must
    not be longer than ``max_length`` tokens.
    """

    def __init__(
        self,
        trg_vocab_size, # Size of the target vocabulary
        num_layers, # Number of DecoderBlocks
        max_length, # Maximum length of the sentence
        embed_size,
        heads,
        forward_expansion,
        dropout,
        device,
    ):
        super(Decoder, self).__init__()
        # Kept for backward compatibility. forward() takes the device from its
        # input so the module keeps working after a later `.to(...)` / `.cuda()`.
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_size,
                    heads,
                    forward_expansion,
                    dropout)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token indices against the encoder output.

        Args:
            x: integer tensor of target token indices, shape (N, seq_length).
            enc_out: encoder output, used as both values and keys in every block.
            src_mask: mask applied when attending over ``enc_out``.
            trg_mask: mask applied in the target self-attention.

        Returns:
            Unnormalised scores of shape (N, seq_length, trg_vocab_size).
        """
        N, seq_length = x.shape

        # Position indices 0, 1, ..., seq_length - 1 for every sentence, created
        # directly on the input's device instead of on CPU followed by a copy.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        # Add word embeddings and position embeddings
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        # In the Decoder the key and value are the encoder's output,
        # and the query is the output of the previous DecoderBlock
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out

#### `Transformer`

In [76]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batches of token indices.

    ``forward(src, trg)`` builds a padding mask for the source and a causal
    (lower-triangular) mask for the target, encodes ``src`` and decodes ``trg``
    against the encoder output.

    Note: ``trg_pad_idx`` is stored but the target mask is purely causal; it
    does not mask padding positions in the target.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx, # Index of the padding token in the source vocabulary
        trg_pad_idx, # Index of the padding token in the target vocabulary
        num_layers,
        max_length,
        embed_size,
        heads,
        forward_expansion=4,
        dropout=0.0,
        device="cpu",
    ):

        super(Transformer, self).__init__()

        # Initialize the Encoder
        self.encoder = Encoder(
            src_vocab_size, # Size of the source vocabulary
            num_layers, # Number of TransformerBlocks
            max_length, # Maximum length of the sentence
            embed_size,
            heads,
            forward_expansion,
            dropout,
            device,
        )

        # Initialize the Decoder
        self.decoder = Decoder(
            trg_vocab_size, # Size of the target vocabulary
            num_layers, # Number of DecoderBlocks
            max_length, # Maximum length of the sentence
            embed_size,
            heads,
            forward_expansion,
            dropout,
            device,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        # Kept for backward compatibility; the mask helpers follow the device of
        # the tensors they are given rather than this attribute.
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean padding mask of shape (N, 1, 1, src_len).

        Entries are True where ``src`` differs from ``src_pad_idx``. The mask
        lives on the same device as ``src``.
        """
        # Shape of src: (N, src_len)
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # Shape of src_mask: (N, 1, 1, src_len)
        return src_mask

    def make_trg_mask(self, trg):
        """Return a causal mask of shape (N, 1, trg_len, trg_len).

        The mask is a lower-triangular matrix of ones (float), so position i
        may attend to positions 0..i only. It is created on ``trg``'s device.
        """
        N, trg_len = trg.shape

        # Build the triangle directly on the target's device, then expand it
        # (without copying) across the batch.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        trg_mask = causal.expand(N, 1, trg_len, trg_len)

        return trg_mask

    def forward(self, src, trg):
        """Run the full model.

        Args:
            src: integer tensor of source token indices, shape (N, src_len).
            trg: integer tensor of target token indices, shape (N, trg_len).

        Returns:
            The decoder output for ``trg``.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

### Train the `Transformer`

In [144]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

#### Hyperparameters

In [None]:
# Hyperparameters
# Token id layout: 0 is padding, 1..99 are the data tokens, 100 is <SOS>, 101 is <EOS>.
num_samples = 1000 # Number of samples in the dataset
max_length = 8 # Maximum length of the sequence including the <SOS> and <EOS> tokens
vocab_size = 99 + 3 # Numbers 1..99 plus three special indices (padding, <SOS>, <EOS>) -> valid ids 0..101
sos_idx = 100 # Start of sequence index
eos_idx = 101 # End of sequence index
pad_idx = 0 # Padding index
num_layers = 2 # Number of Blocks in the Encoder and Decoder
embed_size = 32 # Embedding size for the tokens (must be divisible by `heads`)
heads = 2 # Number of heads in the Multi-Head Attention
forward_expansion = 4 # Feed-forward hidden width is forward_expansion * embed_size
dropout = 0.0 # Dropout probability passed to every nn.Dropout in the model
learning_rate = 0.001 # Learning rate for the Adam optimizer
batch_size = 32 # Batch size of the training DataLoader
num_epochs = 10 # Number of passes over the training data

#### Data Generation

In [206]:
# Function to generate data for the three tasks: copy, reverse, and sort
def generate_data(num_samples, max_length, pad_idx, sos_idx, eos_idx, task):
    """Generate padded (source, target) index tensors for a toy sequence task.

    Each sample is a random sequence of 1 to ``max_length - 2`` integers drawn
    from 1..99. The source is ``<SOS> sequence <EOS>`` right-padded with
    ``pad_idx`` to ``max_length``; the target is framed and padded the same way
    after applying the task to the sequence.

    Args:
        num_samples: number of (source, target) pairs to generate.
        max_length: padded length of every row, including <SOS> and <EOS>.
        pad_idx, sos_idx, eos_idx: indices of the padding, start and end tokens.
        task: ``'copy'`` (target equals source), ``'reverse'`` or ``'sort'``
            (ascending).

    Returns:
        Tuple ``(src_data, trg_data)`` of integer tensors, each of shape
        (num_samples, max_length). With ``num_samples <= 0`` both are empty
        tensors of shape (0, max_length).

    Raises:
        ValueError: if ``task`` is not one of the three supported tasks, or if
            ``max_length`` is too small to hold <SOS>, one token and <EOS>.
    """
    # Validate up front so a bad task fails immediately, even with zero samples.
    if task == 'copy':
        build_target = lambda seq: seq
    elif task == 'reverse':
        build_target = lambda seq: seq.flip(0)
    elif task == 'sort':
        build_target = lambda seq: seq.sort().values
    else:
        raise ValueError("Invalid task. Choose from 'copy', 'reverse', or 'sort'.")

    if max_length < 3:
        raise ValueError("max_length must be at least 3 to fit <SOS>, one token, and <EOS>.")

    def frame_and_pad(tokens):
        # <SOS> tokens <EOS>, right-padded with pad_idx up to max_length.
        framed = torch.cat([torch.tensor([sos_idx]), tokens, torch.tensor([eos_idx])])
        padding = torch.full((max_length - len(framed),), pad_idx)
        return torch.cat([framed, padding])

    if num_samples <= 0:
        # torch.stack cannot stack an empty list; return well-shaped empties.
        empty = torch.empty((0, max_length), dtype=torch.long)
        return empty, empty.clone()

    src_data = []
    trg_data = []

    for _ in range(num_samples):
        # numpy's upper bound is exclusive, so lengths are 1 .. max_length - 2.
        seq_length = np.random.randint(1, max_length - 1)  # Length of the random sequence
        sequence = torch.randint(1, 100, (seq_length,))  # Numbers from 1 to 99

        src_data.append(frame_and_pad(sequence))
        trg_data.append(frame_and_pad(build_target(sequence)))

    return torch.stack(src_data), torch.stack(trg_data)

# Build one dataset per task. Each call draws fresh random sequences, so the
# three source tensors differ from one another. No random seed is set in the
# cells shown here, so the printed samples change from run to run.

# Generate data for the copying task
src_data_copy, trg_data_copy = generate_data(num_samples, max_length, pad_idx, sos_idx, eos_idx, task='copy')

# Generate data for the reversing task
src_data_reverse, trg_data_reverse = generate_data(num_samples, max_length, pad_idx, sos_idx, eos_idx, task='reverse')

# Generate data for the sorting task
src_data_sort, trg_data_sort = generate_data(num_samples, max_length, pad_idx, sos_idx, eos_idx, task='sort')

# Show the first three source/target rows of every task as a sanity check.
print("Copy Task - Source Data:")
print(src_data_copy[:3])
print("Copy Task - Target Data:")
print(trg_data_copy[:3])
print()
print("Reverse Task - Source Data:")
print(src_data_reverse[:3])
print("Reverse Task - Target Data:")
print(trg_data_reverse[:3])
print()
print("Sort Task - Source Data:")
print(src_data_sort[:3])
print("Sort Task - Target Data:")
print(trg_data_sort[:3])

Copy Task - Source Data:
tensor([[100,  87,  21, 101,   0,   0,   0,   0],
        [100,  21,  14,  99,  75,   8,  37, 101],
        [100,  53, 101,   0,   0,   0,   0,   0]])
Copy Task - Target Data:
tensor([[100,  87,  21, 101,   0,   0,   0,   0],
        [100,  21,  14,  99,  75,   8,  37, 101],
        [100,  53, 101,   0,   0,   0,   0,   0]])

Reverse Task - Source Data:
tensor([[100,  29,  89,  85,  50,  58,  72, 101],
        [100,  72,   6,  14, 101,   0,   0,   0],
        [100,  41,  50,  26,  63,  43, 101,   0]])
Reverse Task - Target Data:
tensor([[100,  72,  58,  50,  85,  89,  29, 101],
        [100,  14,   6,  72, 101,   0,   0,   0],
        [100,  43,  63,  26,  50,  41, 101,   0]])

Sort Task - Source Data:
tensor([[100,  92,  91,  79, 101,   0,   0,   0],
        [100,  35,  90,  40,  90,  86, 101,   0],
        [100,  24, 101,   0,   0,   0,   0,   0]])
Sort Task - Target Data:
tensor([[100,  79,  91,  92, 101,   0,   0,   0],
        [100,  35,  40,  86,  90,  90, 101,   0],
        [100,  24, 101,   0,   0,   0,   0,   0]])

#### Training Loop

In [207]:
# Create DataLoader
# Only the copy-task tensors are used for training in this cell.
train_dataset = TensorDataset(src_data_copy, trg_data_copy)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the Transformer model
# Source and target share the same vocabulary and padding index.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(
    src_vocab_size=vocab_size,
    trg_vocab_size=vocab_size,
    src_pad_idx=pad_idx,
    trg_pad_idx=pad_idx,
    num_layers=num_layers,
    max_length=max_length,
    embed_size=embed_size,
    heads=heads,
    forward_expansion=forward_expansion,
    dropout=dropout,
    device=device
).to(device)

# Define loss function and optimizer
# ignore_index=pad_idx: target positions equal to pad_idx do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for src, trg in train_loader:
        src = src.to(device)
        trg = trg.to(device)

        # Forward pass
        # Teacher forcing: the decoder input is the target without its last
        # token, and the labels are the target shifted left by one position.
        output = model(src, trg[:, :-1])
        # Flatten to (N * (max_length - 1), vocab_size) logits and matching labels.
        output = output.reshape(-1, output.shape[2])
        trg = trg[:, 1:].reshape(-1)

        # Compute loss
        loss = criterion(output, trg)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Report the mean per-batch loss of this epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}')

print("Training complete.")

Epoch [1/10], Loss: 4.249279007315636
Epoch [2/10], Loss: 4.05519562959671
Epoch [3/10], Loss: 3.8383244201540947
Epoch [4/10], Loss: 3.419331967830658
Epoch [5/10], Loss: 2.7885498106479645
Epoch [6/10], Loss: 2.1006357446312904
Epoch [7/10], Loss: 1.4595107808709145
Epoch [8/10], Loss: 0.943718858063221
Epoch [9/10], Loss: 0.580320998094976
Epoch [10/10], Loss: 0.35647930298000574
Training complete.


### Evaluating the model

In [211]:
# Example input sequence
# At most max_length - 2 tokens fit, since <SOS> and <EOS> are added below.
input_sequence = [5, 23, 45, 55]

# Prepare the input sequence: Add <SOS> and <EOS> tokens, pad the sequence to the maximum length
input_tensor = torch.tensor([sos_idx] + input_sequence + [eos_idx] + [pad_idx] * (max_length - len(input_sequence) - 2)).unsqueeze(0)  # Add batch dimension

# Move the input tensor to the appropriate device
# `device` is the one chosen in the training cell, where the model lives.
input_tensor = input_tensor.to(device)
print("Input Tensor:", input_tensor)

Input Tensor: tensor([[100,   5,  23,  45,  55, 101,   0,   0]])


In [212]:
# Function to generate output from the model
def generate_output(model, input_tensor, max_length, pad_idx, sos_idx, eos_idx):
    """Greedily decode one output sequence for a single input sequence.

    Starting from ``<SOS>``, the model is called repeatedly on the tokens
    generated so far and the highest-scoring next token is appended, until
    ``eos_idx`` is produced or ``max_length`` tokens have been written.

    Args:
        model: callable as ``model(src, trg)`` returning scores of shape
            (N, trg_len, vocab_size); it is switched to eval mode and left
            in that mode.
        input_tensor: source indices of shape (1, src_len). Only a batch of
            one is supported, because the chosen token is read with ``.item()``.
        max_length: length of the returned target tensor.
        pad_idx: value filling the positions that are never generated.
        sos_idx: start token written at position 0.
        eos_idx: token that stops generation once emitted.

    Returns:
        Integer tensor of shape (1, max_length) on ``input_tensor``'s device.
    """
    model.eval()
    with torch.no_grad():
        # Allocate the target buffer on the input's device instead of relying
        # on a notebook-global `device` variable.
        trg_tensor = torch.full(
            (1, max_length), pad_idx, dtype=torch.long, device=input_tensor.device
        )

        # Set the first token of the target tensor to the start token
        trg_tensor[0, 0] = sos_idx

        for i in range(1, max_length):
            # Feed only the tokens generated so far (positions 0 .. i - 1).
            output = model(input_tensor, trg_tensor[:, :i])

            # Highest-scoring token at the last generated position.
            next_token = output.argmax(2)[:, -1].item()

            # Set the next token in the target tensor
            trg_tensor[0, i] = next_token

            # Stop if the next token is the end token
            if next_token == eos_idx:
                break

        return trg_tensor

# Generate output
# Greedy decoding with the trained model on the example input prepared above.
output_tensor = generate_output(model, input_tensor, max_length, pad_idx, sos_idx, eos_idx)

print("Output Tensor:", output_tensor)

Output Tensor: tensor([[100,   5,  23,  45,  55, 101,   0,   0]])
