In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components
import utils

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='mps')

# Download dataset

### download and inspect dataset

In [4]:
from datasets import load_dataset, DatasetDict

In [5]:
# Load the WMT14 dataset for German-English translation
dataset = load_dataset('wmt14', 'de-en')

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4508785
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 3003
    })
})

In [7]:
dataset['train'][4]

{'translation': {'de': 'Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen -, allen Opfern der Stürme, insbesondere in den verschiedenen Ländern der Europäischen Union, in einer Schweigeminute zu gedenken.',
  'en': "In the meantime, I should like to observe a minute' s silence, as a number of Members have requested, on behalf of all the victims concerned, particularly those of the terrible storms, in the various countries of the European Union."}}

In [8]:
# select a very small segment for experimentation
# Take a small subset for experimentation
small_train_dataset = dataset['train'].select(range(20))
small_val_dataset = dataset['validation'].select(range(5))

In [9]:
small_train_dataset

Dataset({
    features: ['translation'],
    num_rows: 20
})

### Tokenization

In [10]:
# as we are following the original `Attention is all you need paper` we will use Byte-Pair Encoding
from tokenizers import ByteLevelBPETokenizer

In [11]:
# Load the trained tokenizer
tokenizer = ByteLevelBPETokenizer(
    "bpe_tokenizer/vocab.json",
    "bpe_tokenizer/merges.txt"
)

In [12]:
# Test the tokenizer
print(tokenizer.encode("Das ist ein Beispiel.").ids)

print([tokenizer.id_to_token(token) for token in tokenizer.encode("Das ist ein Beispiel").ids])
# Should return something like ['<s>', 'Das', 'ist', 'ein', 'Beispiel', '</s>']

print(tokenizer.token_to_id("</s>"))
# Should return a valid token ID for '</s>'

print(tokenizer.decode(tokenizer.encode("Das ist ein Beispiel.").ids))


[789, 423, 328, 3010, 18]
['Das', 'Ġist', 'Ġein', 'ĠBeispiel']
2
Das ist ein Beispiel.


In [13]:
PAD_TOKEN_ID = tokenizer.token_to_id('<pad>')
BOS_TOKEN_ID = tokenizer.token_to_id('<s>')
EOS_TOKEN_ID = tokenizer.token_to_id('</s>')

In [14]:
# Create a pytorch dataset class
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    def __init__(self, dataset, tokenizer, bos_token_id: int = BOS_TOKEN_ID, eos_token_id: int = EOS_TOKEN_ID ,pad_token_id:int = PAD_TOKEN_ID, max_length: int = 512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.bos = bos_token_id
        self.eos = eos_token_id
        self.pad_token_id = pad_token_id
        

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_sentence = self.dataset[idx]['translation']['de']
        tgt_sentence = self.dataset[idx]['translation']['en']

        # tokenize the source and target
        src_tokens = self.tokenizer.encode(src_sentence).ids
        tgt_tokens = self.tokenizer.encode(tgt_sentence).ids

        # pad and truncate
        src_tokens = torch.tensor(self.pad_and_truncate(self.add_special_tokens(src_tokens)))
        tgt_tokens = torch.tensor(self.pad_and_truncate(self.add_special_tokens(tgt_tokens)))

        # # create attention masks
        # src_mask = (src_tokens != self.pad_token_id).int()
        # tgt_mask = (src_tokens != self.pad_token_id).int()

        # # create look ahead mask
        # look_ahead_mask = self.create_causal_mask(len(tgt_tokens))


        return {
            'src_sentence': src_sentence, 
            'tgt_sentence': tgt_sentence, 
            'src_tokens': src_tokens,
            'tgt_tokens': tgt_tokens,
            # 'src_mask': src_mask,
            # 'tgt_mask': tgt_mask,
            # 'look_ahead_mask': look_ahead_mask,
            # 'combined_mask': tgt_mask & look_ahead_mask
        }

    def pad_and_truncate(self, tokens):
        if len(tokens) < self.max_length:
            tokens = tokens + [self.pad_token_id] * (self.max_length - len(tokens))
        else:
            tokens = tokens[:self.max_length]
        
        return tokens
    
    def add_special_tokens(self, tokens):
        return [self.bos] + tokens + [self.eos]

    def create_causal_mask(self, size):
        # create an lower triangular matrix for the purposes of look ahead masking
        return torch.tril(torch.ones(size, size)).type(torch.uint8)

In [15]:
small_translation_ds = TranslationDataset(small_train_dataset, tokenizer=tokenizer, pad_token_id=PAD_TOKEN_ID, max_length=30)
small_translation_ds

<__main__.TranslationDataset at 0x292ced610>

In [16]:
small_translation_ds[0]

{'src_sentence': 'Wiederaufnahme der Sitzungsperiode',
 'tgt_sentence': 'Resumption of the session',
 'src_tokens': tensor([    0, 23062, 17719,   319, 26699,     2,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]),
 'tgt_tokens': tensor([    0,  8859, 27958,   304,   280,  9974,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1])}

In [17]:
# collate function for handling masks

def create_causal_mask(size):
    """
    Creates a causal mask (look-ahead mask) that prevents attending to future tokens.
    size: Length of the sequence.
    """
    attn_shape = (1, size, size)
    return torch.tril(torch.ones(attn_shape)).type(torch.uint8)  # Shape: (1, seq_length, seq_length)

def create_std_mask(tgt, pad_token_id = PAD_TOKEN_ID):
    tgt_mask = (tgt != pad_token_id).unsqueeze(-2)
    tgt_mask = tgt_mask & create_causal_mask(tgt.size(-1))
    return tgt_mask
    
def collate_fn(batch, pad_token_id = PAD_TOKEN_ID):
    src_batch = torch.stack([item['src_tokens'] for item in batch])
    tgt_batch = torch.stack([item['tgt_tokens'] for item in batch])

    # create source masks
    src_mask = (src_batch != pad_token_id).unsqueeze(-2).int() # shape: (bs, seq_length, 1)
    tgt = tgt_batch[:, :-1]
    tgt_y = tgt_batch[:, 1:]
    tgt_mask = create_std_mask(tgt, pad_token_id=pad_token_id)

    return {
        'src_tokens': src_batch,
        'tgt_input': tgt, 
        'tgt_output': tgt_y,
        'src_mask': src_mask, 
        'tgt_mask': tgt_mask,
    }


In [18]:
small_dl = DataLoader(small_translation_ds, collate_fn=collate_fn, batch_size=4)

for batch in small_dl:
    print(f"Source tokens:", batch['src_tokens'].shape)
    print(f"Target tokens:", batch['tgt_input'].shape)
    print(f"Target output tokens:", batch['tgt_output'].shape)
    print(f"Source mask:", batch['src_mask'].shape)
    print(f"Target mask:", batch['tgt_mask'].shape)
    break

Source tokens: torch.Size([4, 30])
Target tokens: torch.Size([4, 29])
Target output tokens: torch.Size([4, 29])
Source mask: torch.Size([4, 1, 30])
Target mask: torch.Size([4, 29, 29])


# Creating each layer step by step

### Scaled Dot-Product Attention

In [19]:
import torch.nn.functional as F
import torch
import math

def scaled_dpa(query, key, value, mask=None, verbose=False):
    """
    Implements scaled dot product attention.
    Args:
        query: (batch_size, seq_length, dim_k)
        key: (batch_size, seq_length, dim_k)
        value: (batch_size, seq_length, dim_v)
        mask: (batch_size, seq_length) or None
        verbose: Boolean default False
    Returns:
        attention_output: (batch_size, seq_length, dim_v)
        attention_weights: (batch_size, seq_length, seq_length)
    """

    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)# (bs, seq_length, seq_length)

    if verbose:
        print(f"Scores shape: {scores.shape}")
    
    # apply the mask if necessary
    if mask is not None:
        scores = scores.masked_fill(mask==0, float('-inf'))
    
    # apply softmax to get attention_weights
    attention_weights = F.softmax(scores, dim=-1) # (bs, seq_length, seq_length)

    if verbose:
        print(f"Attention weights shape: {attention_weights.shape}")
    
    output = torch.matmul(attention_weights, value)
    
    if verbose:
        print(f"Attention output shape: {output.shape}")
    
    return output, attention_weights

In [20]:
# Batch size = 1, Sequence length = 5, Embedding dimension = 4 (d_k)
batch_size = 3
seq_length = 5

# example scores
scores = torch.rand(batch_size, seq_length, seq_length)
print(scores)

# Optional mask
mask = torch.tensor([
    [1, 1, 1, 0, 0], 
    [1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1],
])
print(mask)

mask = mask.unsqueeze(1)
print(mask.shape)

scores = scores.masked_fill(mask==0, float('-inf'))
print(scores)


tensor([[[0.2889, 0.3645, 0.5073, 0.0601, 0.1955],
         [0.4960, 0.7317, 0.2633, 0.5181, 0.3663],
         [0.5461, 0.5656, 0.1713, 0.6215, 0.9109],
         [0.1746, 0.3636, 0.8319, 0.7131, 0.2682],
         [0.4274, 0.7795, 0.9693, 0.7363, 0.5564]],

        [[0.9173, 0.4745, 0.6934, 0.0855, 0.6556],
         [0.4114, 0.0378, 0.4886, 0.3944, 0.3892],
         [0.0564, 0.7735, 0.9761, 0.2099, 0.1457],
         [0.7007, 0.1584, 0.1256, 0.8583, 0.9621],
         [0.7026, 0.4526, 0.0107, 0.3048, 0.5285]],

        [[0.5153, 0.2527, 0.0817, 0.3672, 0.7689],
         [0.5652, 0.5058, 0.1088, 0.6928, 0.2403],
         [0.0115, 0.8514, 0.6130, 0.9758, 0.5974],
         [0.7339, 0.7190, 0.9975, 0.1526, 0.9539],
         [0.8057, 0.1836, 0.2223, 0.7273, 0.5501]]])
tensor([[1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1]])
torch.Size([3, 1, 5])
tensor([[[0.2889, 0.3645, 0.5073,   -inf,   -inf],
         [0.4960, 0.7317, 0.2633,   -inf,   -inf],
         [0.5461, 0.5656, 0.1

In [21]:
# test scaled dpa
# Example of how to use scaled_dpa with random tensors

# Batch size = 1, Sequence length = 5, Embedding dimension = 4 (d_k)
batch_size = 3
seq_length = 5
embedding_dim = 4

# Random queries, keys, and values
query = torch.rand(batch_size, seq_length, embedding_dim).to(device)
key = torch.rand(batch_size, seq_length, embedding_dim).to(device)
value = torch.rand(batch_size, seq_length, embedding_dim).to(device)

print(f"Query shape: {query.shape}")

# Optional mask
mask = torch.tensor([
    [1, 1, 1, 0, 0], 
    [1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1],
])
mask = mask.unsqueeze(1).to(device)

print(f"Mask shape: {mask.shape}")

# Test scaled_dpa
output, attention_weights = scaled_dpa(query, key, value, mask, verbose=True)

print("Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)


Query shape: torch.Size([3, 5, 4])
Mask shape: torch.Size([3, 1, 5])
Scores shape: torch.Size([3, 5, 5])
Attention weights shape: torch.Size([3, 5, 5])
Attention output shape: torch.Size([3, 5, 4])
Attention Output:
 tensor([[[0.2269, 0.3303, 0.5288, 0.5487],
         [0.2299, 0.3100, 0.5195, 0.5397],
         [0.2408, 0.2945, 0.5300, 0.5479],
         [0.2357, 0.3024, 0.5256, 0.5445],
         [0.2238, 0.3349, 0.5260, 0.5466]],

        [[0.4412, 0.2081, 0.4441, 0.5420],
         [0.4452, 0.2067, 0.4396, 0.5500],
         [0.4453, 0.2067, 0.4395, 0.5501],
         [0.4298, 0.2120, 0.4568, 0.5192],
         [0.4321, 0.2112, 0.4542, 0.5239]],

        [[0.3537, 0.3539, 0.4510, 0.4363],
         [0.3498, 0.3251, 0.4349, 0.3921],
         [0.3582, 0.3571, 0.4370, 0.4307],
         [0.3551, 0.3546, 0.4511, 0.4361],
         [0.3521, 0.3384, 0.4405, 0.4132]]], device='mps:0')
Attention Weights:
 tensor([[[0.4430, 0.3145, 0.2425, 0.0000, 0.0000],
         [0.4025, 0.3140, 0.2835, 0.0000, 0.0

In [22]:
# testing with mask
def create_padding_mask(seq):
    """
    Creates a padding mask (1 for valid tokens, 0 for padding tokens).
    seq: Tensor of shape (batch_size, seq_length)
    """
    return (seq != 0).unsqueeze(1).unsqueeze(2)  # Shape: (batch_size, 1, 1, seq_length)

def create_causal_mask(size):
    """
    Creates a causal mask (look-ahead mask) that prevents attending to future tokens.
    size: Length of the sequence.
    """
    return torch.tril(torch.ones(size, size)).type(torch.uint8)  # Shape: (seq_length, seq_length)

# Test scaled_dpa with padding and causal masks

# Batch size = 1, Sequence length = 5, Embedding dimension = 4 (d_k)
batch_size = 1
seq_length = 5
embedding_dim = 4

# Random queries, keys, and values
query = torch.rand(batch_size, seq_length, embedding_dim).to(device)
key = torch.rand(batch_size, seq_length, embedding_dim).to(device)
value = torch.rand(batch_size, seq_length, embedding_dim).to(device)

# Create a random sequence with padding (0 represents padding token)
src_tokens = torch.tensor([[1, 2, 3, 0, 0]]).to(device)  # Example with 2 padding tokens

# Create a padding mask
padding_mask = create_padding_mask(src_tokens).to(device)  # Shape: (batch_size, 1, 1, seq_length)

# Create a causal mask (look-ahead mask)
causal_mask = create_causal_mask(seq_length).to(device)  # Shape: (seq_length, seq_length)

# Combine the masks (for testing both padding and causal masking together)
combined_mask = padding_mask & causal_mask.unsqueeze(0).to(device)

# Test scaled_dpa with the mask
output, attention_weights = scaled_dpa(query, key, value, combined_mask, verbose=True)

print("Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)
print("Padding Mask:\n", padding_mask)
print("Causal Mask:\n", causal_mask)
print("Combined Mask:\n", combined_mask)



Scores shape: torch.Size([1, 5, 5])
Attention weights shape: torch.Size([1, 1, 5, 5])
Attention output shape: torch.Size([1, 1, 5, 4])
Attention Output:
 tensor([[[[0.0599, 0.9591, 0.5854, 0.6830],
          [0.0922, 0.6035, 0.3744, 0.4882],
          [0.2372, 0.7045, 0.3788, 0.3680],
          [0.2349, 0.7237, 0.3904, 0.3791],
          [0.2515, 0.7268, 0.3861, 0.3616]]]], device='mps:0')
Attention Weights:
 tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4658, 0.5342, 0.0000, 0.0000, 0.0000],
          [0.3558, 0.3557, 0.2885, 0.0000, 0.0000],
          [0.3855, 0.3269, 0.2875, 0.0000, 0.0000],
          [0.3615, 0.3194, 0.3191, 0.0000, 0.0000]]]], device='mps:0')
Padding Mask:
 tensor([[[[ True,  True,  True, False, False]]]], device='mps:0')
Causal Mask:
 tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]], device='mps:0', dtype=torch.uint8)
Combined Mask:
 tensor([[[[1, 0, 0, 0, 0],
         

### Multi-head attention

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, d_model: int, dropout=0.1, verbose=False):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.verbose = verbose

        if self.verbose:
            print(f"Num heads: {num_heads}")
            print(f"Embedding dimension: {d_model}")
            print(f"per head dimension: {self.d_k}")
    
        # linear layers to project the inputs to query, key, and value
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout) 
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        # query shape is bs, seq_length, d_model
        # key shape is bs, seq_length, d_model
        # value shape is bs, d_model, d_model
        batch_size = query.size(0)
        seq_length = query.size(1)

        if mask is not None:
            mask = mask.unsqueeze(1) # Same mask applied to all heads. 

        if self.verbose and mask is not None:
            print(f"Mask shape (after unsqueezing at 1): {mask.shape}")

        # apply linear layers
        query = self.query_linear(query)   # shape bs, seq_length, d_model
        key = self.key_linear(key) #shape: bs, seq_length, d_model
        value = self.value_linear(value) # shape: bs, d_model, d_model

        if self.verbose:
            print(f"Query shape: {query.shape}")
            print(f"Key shape: {key.shape}")
            print(f"Value shape: {value.shape}")
        
        # reshape and split into multiple heads
        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # (bs, num_heads, seq_length, d_k)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # (bs, num_heads, seq_length, d_k)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) #(bs, num_heads, seq_length, d_k)

        if self.verbose:
            print(f"Shapes after projections for query, key, value...")
            print(f"{query.shape}, {key.shape}, {value.shape}")

        attn_output, attn_weights = scaled_dpa(query, key, value, mask, verbose = self.verbose)

        # we've separated the query key and value into separate heads and then computed the scaled dot-product attention for each head.
        # Now we must put them back together. 
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        if self.verbose:
            print(f"Attention output shape after concat: {attn_output.shape}")

        # apply the final linear layer transformation
        output = self.output_linear(attn_output)
        if self.verbose:
            print(f"Output shape: {output.shape}")

        return output, attn_weights

In [24]:
# visualizing how the transpsoe works
batch_size = 1
seq_length = 4
d_model = 8
num_heads = 2
d_k = d_model // num_heads
query = torch.arange(1, seq_length*d_model + 1).view(batch_size, seq_length, d_model).to(device)
print(query.shape)
print(f"unchanged query: {query}")

query = query.view(batch_size, -1, num_heads, d_k)
print(f"prior to transpose query: {query}")

query = query.transpose(1, 2)
print(f"transposed query: {query}")
print(query.shape)

torch.Size([1, 4, 8])
unchanged query: tensor([[[ 1,  2,  3,  4,  5,  6,  7,  8],
         [ 9, 10, 11, 12, 13, 14, 15, 16],
         [17, 18, 19, 20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29, 30, 31, 32]]], device='mps:0')
prior to transpose query: tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8]],

         [[ 9, 10, 11, 12],
          [13, 14, 15, 16]],

         [[17, 18, 19, 20],
          [21, 22, 23, 24]],

         [[25, 26, 27, 28],
          [29, 30, 31, 32]]]], device='mps:0')
transposed query: tensor([[[[ 1,  2,  3,  4],
          [ 9, 10, 11, 12],
          [17, 18, 19, 20],
          [25, 26, 27, 28]],

         [[ 5,  6,  7,  8],
          [13, 14, 15, 16],
          [21, 22, 23, 24],
          [29, 30, 31, 32]]]], device='mps:0')
torch.Size([1, 2, 4, 4])


In [25]:
query = torch.arange(1, seq_length*d_model + 1).view(batch_size, seq_length, d_model)

query = query.view(batch_size, num_heads, -1, d_k)
print(f"Directly reshaping query: {query}")

Directly reshaping query: tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8],
          [ 9, 10, 11, 12],
          [13, 14, 15, 16]],

         [[17, 18, 19, 20],
          [21, 22, 23, 24],
          [25, 26, 27, 28],
          [29, 30, 31, 32]]]])


In [26]:
# Test MultiHeadAttention with random inputs

# Define parameters
num_heads = 8
d_model = 64
seq_length = 5
batch_size = 1

# Random inputs for query, key, and value
query = torch.rand(batch_size, seq_length, d_model).to(device)
key = torch.rand(batch_size, seq_length, d_model).to(device)
value = torch.rand(batch_size, seq_length, d_model).to(device)

# No mask for now (can add later)
mask = None

# Create MultiHeadAttention object
multihead_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, verbose=True).to(device)

# Pass the inputs through multi-head attention
output, attention_weights = multihead_attn(query, key, value, mask)

print("Multi-Head Attention Output:\n", output)
print("Attention Weights:\n", attention_weights.shape)


Num heads: 8
Embedding dimension: 64
per head dimension: 8
Query shape: torch.Size([1, 5, 64])
Key shape: torch.Size([1, 5, 64])
Value shape: torch.Size([1, 5, 64])
Shapes after projections for query, key, value...
torch.Size([1, 8, 5, 8]), torch.Size([1, 8, 5, 8]), torch.Size([1, 8, 5, 8])
Scores shape: torch.Size([1, 8, 5, 5])
Attention weights shape: torch.Size([1, 8, 5, 5])
Attention output shape: torch.Size([1, 8, 5, 8])
Attention output shape after concat: torch.Size([1, 5, 64])
Output shape: torch.Size([1, 5, 64])
Multi-Head Attention Output:
 tensor([[[ 0.2293, -0.1196, -0.4825,  0.0844, -0.1668,  0.0588, -0.0671,
           0.1175,  0.3363, -0.0163,  0.1162,  0.1029, -0.0682,  0.2925,
           0.0719,  0.0330,  0.1565, -0.2045, -0.2230, -0.0716,  0.1627,
          -0.0956, -0.0742,  0.1855,  0.0179,  0.1026,  0.1620, -0.0013,
           0.2692, -0.2667,  0.1745,  0.0689,  0.1884,  0.3444,  0.2481,
          -0.0640,  0.1578, -0.0885, -0.0780, -0.1707, -0.4239, -0.3066,
     

In [27]:
# Test MultiHeadAttention with a padding mask and causal mask

# Define parameters
num_heads = 2
d_model = 8
seq_length = 4
batch_size = 1

# Random inputs for query, key, and value
query = torch.rand(batch_size, seq_length, d_model).to(device)
key = torch.rand(batch_size, seq_length, d_model).to(device)
value = torch.rand(batch_size, seq_length, d_model).to(device)

# Create a random sequence with padding (0 represents padding token)
src_tokens = torch.tensor([[1, 2, 3, 0]]).to(device)  # Example with 1 padding token

# Create a padding mask
padding_mask = create_padding_mask(src_tokens).to(device)  # Shape: (batch_size, 1, 1, seq_length)

# Create a causal mask (look-ahead mask)
causal_mask = create_causal_mask(seq_length).to(device)  # Shape: (seq_length, seq_length)

# Combine the masks (bitwise AND to use both padding and causal masks)
combined_mask = padding_mask & causal_mask.unsqueeze(0)
combined_mask.to(device)

# Create MultiHeadAttention object
multihead_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, verbose=True).to(device)

# Pass the inputs through multi-head attention with a mask
output, attention_weights = multihead_attn(query, key, value, combined_mask)

print("\nMulti-Head Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)
print("Padding Mask:\n", padding_mask)
print("Causal Mask:\n", causal_mask)
print("Combined Mask:\n", combined_mask)

Num heads: 2
Embedding dimension: 8
per head dimension: 4
Mask shape (after unsqueezing at 1): torch.Size([1, 1, 1, 4, 4])
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Shapes after projections for query, key, value...
torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4])
Scores shape: torch.Size([1, 2, 4, 4])
Attention weights shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape after concat: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])

Multi-Head Attention Output:
 tensor([[[ 0.2070, -0.3633,  0.5263,  0.6589,  0.3929, -0.1892,  0.0745,
          -0.3641],
         [ 0.1263, -0.3127,  0.4872,  0.6978,  0.3566, -0.2125,  0.0865,
          -0.2827],
         [ 0.3505,  0.0285,  0.2862,  0.4353,  0.0862, -0.2340, -0.0065,
           0.0970],
         [ 0.3877,  0.0840,  0.3126,  0.3966,  0.0394, -0.1514,  0.0640,
           0.1300]]], dev

### Encoder layer

Now we implement the Encoder layer

In [28]:
class PositionwiseFFN(nn.Module):
    def __init__(self, d_ff: int, d_model: int, dropout: float = 0.1):
        super(PositionwiseFFN, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))
    
class EncoderLayer(nn.Module):
    def __init__(self, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = False):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.ffn = PositionwiseFFN(d_ff=d_ff, d_model=d_model, dropout=dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

    def forward(self, x, mask=None):
        if self.verbose:
            print(f"Input to Encoder Layer: {x.shape}")
        
        # Multi-head attention with residual connection and layer normalization
        attn_output, _ = self.mha(x, x, x, mask)
        if self.verbose:
            print(f"attn_output shape: {attn_output.shape}")
        out1 = self.layernorm1(x + self.dropout(attn_output))

        # Feedforward with residual connection and layer normalization
        ffn_output = self.ffn(out1)
        out2 = self.layernorm2(out1 + self.dropout(ffn_output))  # Fixed: add out1, not x

        if self.verbose:
            print(f"Output from Encoder Layer: {out2.shape}")
        
        return out2


In [29]:
# Test EncoderLayer with random inputs

# Define parameters
num_heads = 2
d_model = 8
d_ff = 16
seq_length = 4
batch_size = 1

# Random input sequence
x = torch.rand(batch_size, seq_length, d_model).to(device)

# Create a random padding mask (e.g., if needed)
padding_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]])).to(device)  # Example with padding
print(f"Padding mask: {padding_mask.int()}")

# Create EncoderLayer object
encoder_layer = EncoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, verbose=True).to(device)

# Pass the input through the encoder layer
output = encoder_layer(x, mask=padding_mask)

print("\nOutput from Encoder Layer:\n", output)


Padding mask: tensor([[[[1, 1, 1, 0]]]], device='mps:0', dtype=torch.int32)
Num heads: 2
Embedding dimension: 8
per head dimension: 4
Input to Encoder Layer: torch.Size([1, 4, 8])
Mask shape (after unsqueezing at 1): torch.Size([1, 1, 1, 1, 4])
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Shapes after projections for query, key, value...
torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4])
Scores shape: torch.Size([1, 2, 4, 4])
Attention weights shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape after concat: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])
attn_output shape: torch.Size([1, 4, 8])
Output from Encoder Layer: torch.Size([1, 4, 8])

Output from Encoder Layer:
 tensor([[[-0.7111,  1.5372, -1.4903,  0.1783,  0.3779, -1.3133,  0.6901,
           0.7312],
         [-1.7991,  1.2210, -0.5939, -0.6746,  0.7649,  0.0175, -0.2782,
 

In [30]:
# Create a padding mask (0 indicates padding)
src_tokens = torch.tensor([[1, 2, 3, 0]]).to(device)  # Example sequence with padding
padding_mask = create_padding_mask(src_tokens).to(device)

# Test EncoderLayer with padding mask
encoder_layer = EncoderLayer(num_heads=2, d_model=8, d_ff=16, dropout=0.1, verbose=True).to(device)
x = torch.rand(1, 4, 8).to(device)  # Random input sequence

# Pass through the encoder layer with the mask
output = encoder_layer(x, mask=padding_mask)
print("Output from encoder layer with padding mask:", output)


Num heads: 2
Embedding dimension: 8
per head dimension: 4
Input to Encoder Layer: torch.Size([1, 4, 8])
Mask shape (after unsqueezing at 1): torch.Size([1, 1, 1, 1, 4])
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Shapes after projections for query, key, value...
torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4])
Scores shape: torch.Size([1, 2, 4, 4])
Attention weights shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape: torch.Size([1, 1, 2, 4, 4])
Attention output shape after concat: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])
attn_output shape: torch.Size([1, 4, 8])
Output from Encoder Layer: torch.Size([1, 4, 8])
Output from encoder layer with padding mask: tensor([[[-0.7593, -0.3763, -0.1341, -1.4512,  0.6835,  0.4469, -0.4748,
           2.0653],
         [-0.5853, -0.7003,  0.2369, -1.6870,  0.4599, -0.1346,  0.4533,
           1.9570],
         [-0.6862,  0.3053,  0.3419, -2.3

### Decoder layer

decoder layer implementation

In [31]:
class DecoderLayer(nn.Module):
    def __init__(self, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose=False):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.src_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.ffn = PositionwiseFFN(d_ff=d_ff, d_model=d_model, dropout=dropout)
        self.layernorms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):

        if self.verbose:
            print(f"Input shape x: {x.shape}")
            print(f"Encoder output shape: {enc_output.shape}\n")
        # masked self-attention over the target (with look-ahead mask)

        if self.verbose:
            print(f"Passing through self-attention")
        self_attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.layernorms[0](x + self.dropout(self_attn_output))

        if self.verbose:
            print(f"\nPassing Through encoder-decoder attention")
        # encoder-decoder attention over the encoder output (attend to source)
        enc_dec_attn_output, _ = self.src_attn(x, enc_output, enc_output, src_mask)
        x = self.layernorms[1](x + self.dropout(enc_dec_attn_output))

        if self.verbose:
            print(f"\nFinal feedforward of layer")
        # feedforward with residual connection and layer normalization
        ffn_output = self.ffn(x)
        x = self.layernorms[2](x + self.dropout(ffn_output))
        
        if self.verbose:
            print(f"\nOutput shape: {x.shape}")

        return x



In [32]:
def create_causal_mask(seq_length):
    """
    Creates a causal mask (look-ahead mask) that prevents attending to future tokens.
    size: Length of the sequence.
    """
    return torch.tril(torch.ones(seq_length, seq_length)).type(torch.uint8)  # Shape: (seq_length, seq_length)

In [33]:
# Random input sequence for target (decoder input)
tgt = torch.rand(1, 4, 8).to(device)  # (batch_size=1, seq_length=4, d_model=8)

# Random encoder output (assuming same dimensions for simplicity)
enc_output = torch.rand(1, 4, 8).to(device)

# Create masks
tgt_mask = create_causal_mask(seq_length=4).unsqueeze(0).to(device)  # Causal mask for target
src_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]])).to(device)  # Padding mask for source

# Initialize the decoder layer
decoder_layer = DecoderLayer(num_heads=2, d_model=8, d_ff=16, dropout=0.1, verbose=True).to(device)

# Pass through the decoder layer
output = decoder_layer(tgt, enc_output, src_mask=src_mask, tgt_mask=tgt_mask)
print("Output from decoder layer:", output)
    

Num heads: 2
Embedding dimension: 8
per head dimension: 4
Num heads: 2
Embedding dimension: 8
per head dimension: 4
Input shape x: torch.Size([1, 4, 8])
Encoder output shape: torch.Size([1, 4, 8])

Passing through self-attention
Mask shape (after unsqueezing at 1): torch.Size([1, 1, 4, 4])
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Shapes after projections for query, key, value...
torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4])
Scores shape: torch.Size([1, 2, 4, 4])
Attention weights shape: torch.Size([1, 2, 4, 4])
Attention output shape: torch.Size([1, 2, 4, 4])
Attention output shape after concat: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])

Passing Through encoder-decoder attention
Mask shape (after unsqueezing at 1): torch.Size([1, 1, 1, 1, 4])
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Shapes after projections for query,

### Positional Encoding

In [34]:
import math

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :].requires_grad_(False)
        return self.dropout(x)

### Encoder

In [35]:
class Encoder(nn.Module):
    def __init__(self, num_blocks: int, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = False):
        super(Encoder, self).__init__()
        self.num_blocks = num_blocks
        self.verbose = verbose

        # encoder layers
        self.encoder_blocks = nn.ModuleList([
            EncoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=verbose) for _ in range(num_blocks)
        ])

        # final layer normalization layer
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask = None):

        if self.verbose:
            print(f"Input of shape: {x.shape}")
        
        for i, block in enumerate(self.encoder_blocks):
            if self.verbose:
                print(f"\n------------ Passing Through Encoder block {i + 1} ----------------")
            
            x = block(x, mask=src_mask)

        # apply final layer normalization
        x = self.layernorm(x)

        if self.verbose:
            print(f"\nFinal output shape is: {x.shape}")

        return x

In [36]:
# Random input sequence (batch_size=1, seq_length=4, d_model=8)
src = torch.rand(1, 4, 8)

# Create a padding mask for the source sequence
src_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]]))

# Initialize the encoder with 2 blocks for testing
encoder = Encoder(num_blocks=2, num_heads=2, d_model=8, d_ff=16, dropout=0.1, max_length=30, verbose=True)

# Pass through the encoder
output = encoder(src, src_mask=src_mask)
print("Final output from encoder:", output)


TypeError: Encoder.__init__() got an unexpected keyword argument 'max_length'

In [37]:
# test with actual examples
batch_size = 4
small_dl = DataLoader(small_translation_ds, batch_size = batch_size, shuffle=True, collate_fn=collate_fn)

for batch in small_dl:
    print(batch.keys())
    src_tokens = batch['src_tokens'].to(device)  # The tokenized source sentences
    tgt_input = batch['tgt_input'].to(device)  # The tokenized target sentences
    tgt_output = batch['tgt_output'].to(device)
    src_mask = batch['src_mask'].to(device)
    tgt_mask = batch['tgt_mask'].to(device)

    print(f"Source tokens: {src_tokens.shape}")
    print(f"tgt_input: {tgt_input.shape}")
    print(f"tgt_output: {tgt_output.shape}")
    print(f"src_mask: {src_mask.shape}")
    print(f"tgt_mask: {tgt_mask.shape}")

    break  # Just getting the first batch for demonstration

encoder = Encoder(num_blocks=6, num_heads=8, d_model=512, d_ff=2048, verbose=True).to(device)

embedding = nn.Embedding(tokenizer.get_vocab_size(), 512).to(device)
pos_encoder = PositionalEncoding(512, dropout=0.1, max_len=512).to(device)

for batch in small_dl:
    src_tokens = batch['src_tokens'].to(device)
    src_mask = batch['src_mask'].to(device)

    print(f"\nSource token shape: {src_tokens.shape}")

    src_embed = embedding(src_tokens)
    src_embed = pos_encoder(src_embed)
    encoder_output = encoder(src_embed, src_mask)
    break


dict_keys(['src_tokens', 'tgt_input', 'tgt_output', 'src_mask', 'tgt_mask'])
Source tokens: torch.Size([4, 30])
tgt_input: torch.Size([4, 29])
tgt_output: torch.Size([4, 29])
src_mask: torch.Size([4, 1, 30])
tgt_mask: torch.Size([4, 29, 29])
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64

Source token shape: torch.Size([4, 30])
Input of shape: torch.Size([4, 30, 512])

------------ Passing Through Encoder block 1 ----------------
Input to Encoder Layer: torch.Size([4, 30, 512])
Mask shape (after unsqueezing at 1): torch.Size([4, 1, 1, 30])
Query shape: torch.Size([4, 30, 512])
Key shape: torch.Size([4, 30, 512])
Value shape: torch.Size([4, 30, 512])
Shapes after projection

### Decoder

In [40]:
class Decoder(nn.Module):
    def __init__(self, num_blocks: int, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = True):
        super(Decoder, self).__init__()
        self.decoder_blocks = nn.ModuleList([
            DecoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=verbose) for _ in range(num_blocks)
        ])
        self.layernorm = nn.LayerNorm(d_model)
        self.verbose = verbose

    def forward(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        for i, block in enumerate(self.decoder_blocks):
            if self.verbose:
                print(f"\n------------- Passing Through Decoder Block {i+1} ----------------")
            tgt = block(tgt, enc_output, src_mask, tgt_mask)

        return self.layernorm(tgt)

In [41]:
import torch
import torch.nn as nn

# Mock data for testing
batch_size = 4
seq_length = 5
d_model = 512
num_heads = 8
num_blocks = 6
d_ff = 2048

# Random embedded target tokens (already embedded, just mock data)
tgt_embed = torch.rand(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)

# Random encoder output (to simulate the output from the encoder)
enc_output = torch.rand(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)

# Create padding mask (mock data, assume no padding tokens for simplicity)
src_mask = torch.ones(batch_size, 1, seq_length)  # Shape: (batch_size, 1, seq_length)
print(src_mask.shape)

tgt_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))
print(tgt_mask.shape)


# Initialize the decoder without embedding
decoder = Decoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=0.1, verbose=True)

# Pass the mock data through the decoder
decoder_output = decoder(tgt_embed, enc_output, src_mask=src_mask, tgt_mask=tgt_mask)

print("Decoder output shape:", decoder_output.shape)


torch.Size([4, 1, 5])
torch.Size([4, 5, 5])
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64

------------- Passing Through Decoder Block 1 ----------------
Input shape x: torch.Size([4, 5, 512])
Encoder output shape: torch.Size([4, 5, 512])

Passing through self-attention
Mask shape (after unsqueezing at 1): torch.

# Translator Class

Now we are going to instantiate an encoder and decoder class and string them together to confirm that everything works together. Than we will abstract and create an Encoder-Decoder sequence to sequence model. 

In [42]:
class Generator(nn.Module):
    """Define the linear + softmax step for generating token probabilities.
        Layer projects vector on to vocab space and then applys a log_softmax. 
    """
    def __init__(self, d_model, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim=-1)


In [43]:
# parameters
d_model = 512
d_ff = 2048
dropout = 0.1
num_blocks = 6
num_heads = 8
max_len = 30

batch_size = 4

small_dl = DataLoader(small_translation_ds, batch_size = batch_size, shuffle=True, collate_fn=collate_fn)

embedding = nn.Embedding(tokenizer.get_vocab_size(), d_model)
pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len)

encoder = Encoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=True)

decoder = Decoder(num_blocks=6, num_heads=8, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=True)

generator = Generator(d_model=d_model, vocab_size=tokenizer.get_vocab_size())

for batch in small_dl:
    print(f"Source tokens:", batch['src_tokens'].shape)
    print(f"Target input tokens:", batch['tgt_input'].shape)
    print(f"Target output tokens:", batch['tgt_output'].shape)
    print(f"Source mask:", batch['src_mask'].shape)
    print(f"Target mask:", batch['tgt_mask'].shape)

    src_tokens = batch['src_tokens']
    src_mask = batch['src_mask']
    tgt_input = batch['tgt_input']
    tgt_output = batch['tgt_output']
    src_mask = batch['src_mask']
    tgt_mask = batch['tgt_mask']

    src_embed = embedding(src_tokens)
    src_embed = pos_encoder(src_embed)
    encoder_output = encoder(src_embed, src_mask)

    print(f"Encoder output: {encoder_output.shape}")

    tgt_embed = embedding(tgt_input)
    tgt_embed = pos_encoder(tgt_embed)
    dec_output = decoder(tgt=tgt_embed, enc_output = encoder_output, src_mask=src_mask, tgt_mask=tgt_mask)

    output = generator(dec_output)
    predicted_tokens = torch.argmax(output, dim=-1)

    print(output)
    print(tgt_output)

    print(output.shape)
    print(predicted_tokens.shape)
    print(tgt_output.shape)
    break


Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding dimension: 512
per head dimension: 64
Num heads: 8
Embedding d

In [44]:
# now we abstract the above into a EncoderDecoder class. 
class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, generator: Generator, embedding: nn.Embedding, pos_encoder: PositionalEncoding, verbose: bool = False):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
        self.embedding = embedding
        self.pos_encoder = pos_encoder
        self.verbose = verbose

    def forward(self, src_tokens, tgt_input, src_mask, tgt_mask):
        # Encoder
        src_embed = self.embedding(src_tokens)
        src_embed = self.pos_encoder(src_embed)
        encoder_output = self.encoder(src_embed, src_mask)

        # Decoder
        tgt_embed = self.embedding(tgt_input)
        tgt_embed = self.pos_encoder(tgt_embed)
        dec_output = self.decoder(tgt=tgt_embed, enc_output=encoder_output, src_mask=src_mask, tgt_mask = tgt_mask)

        output_log_probs = self.generator(dec_output)

        return output_log_probs



# Training

In [45]:
# create dataset
train_ds = TranslationDataset(dataset['train'].shuffle().select(range(20000)), tokenizer=tokenizer, bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)
val_ds = TranslationDataset(dataset['validation'], tokenizer=tokenizer, bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)

In [46]:
# create dataloaders
batch_size = 16
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_fn)

In [47]:
for batch in train_dl:
    print(batch.keys())
    break

dict_keys(['src_tokens', 'tgt_input', 'tgt_output', 'src_mask', 'tgt_mask'])


In [48]:
# initiate models

# parameters
d_model = 512
d_ff = 2048
dropout = 0.1
num_blocks = 6
num_heads = 8
max_len = 512
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

embedding = nn.Embedding(tokenizer.get_vocab_size(), d_model).to(device)
pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len).to(device)
encoder = Encoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=False).to(device)
decoder = Decoder(num_blocks=6, num_heads=8, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=False).to(device)
generator = Generator(d_model=d_model, vocab_size=tokenizer.get_vocab_size()).to(device)

model = EncoderDecoder(encoder, decoder, generator, embedding, pos_encoder, verbose=False).to(device)

In [49]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

In [50]:
device

device(type='mps')

In [51]:
# Create learning rate scheduler, following `Attention is All You Need` for now. 
# lr = d_model ** (-0.5) * min(step_num ** (-0.5), step_num * warmup_steps ** (-1.5))
warmup_steps = 4000

def get_lr(step_num):
    return d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)



In [52]:
from tqdm import tqdm
from tqdm.notebook import tqdm

# training loop
# optimizer and criterion
learning_rate = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.NLLLoss()

num_epochs = 5

for epoch in range(num_epochs):

    model.train()
    total_loss = 0
    step_num = 0

    for batch in tqdm(train_dl):
        step_num += 1

        for param_group in optimizer.param_groups:
            param_group['lr'] = get_lr(step_num)

        src_tokens = batch['src_tokens'].to(device)
        tgt_input = batch['tgt_input'].to(device)
        tgt_output = batch['tgt_output'].to(device)
        src_mask = batch['src_mask'].to(device)
        tgt_mask = batch['tgt_mask'].to(device)

        # print(src_tokens.device)
        # print(src_mask.device)

        # zero the gradients
        optimizer.zero_grad()

        output_logits = model(src_tokens, tgt_input, src_mask, tgt_mask)

        loss = criterion(output_logits.view(-1, output_logits.size(-1)), tgt_output.view(-1))
        
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dl)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average training loss: {avg_loss: .4f}")

    # validation
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_dl):
            src_tokens = batch['src_tokens'].to(device)
            tgt_input = batch['tgt_input'].to(device)
            tgt_output = batch['tgt_output'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_mask = batch['tgt_mask'].to(device)

            output = model(src_tokens, tgt_input, src_mask, tgt_mask)

            loss = criterion(output.view(-1, output.size(-1)), tgt_output.view(-1))
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss / len(val_dl)
    print(F"Epoch {epoch + 1}/{num_epochs}, Average validation loss: {avg_val_loss:.4f}")

  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 1/5, Average training loss:  0.6738


  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 1/5, Average validation loss: 0.3290


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 2/5, Average training loss:  0.3582


  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 2/5, Average validation loss: 0.3121


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 3/5, Average training loss:  0.3347


  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 3/5, Average validation loss: 0.3031


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 4/5, Average training loss:  0.3166


  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 4/5, Average validation loss: 0.2938


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 5/5, Average training loss:  0.3002


  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 5/5, Average validation loss: 0.2897


In [53]:
model

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

# Inference

In [54]:
# check if inference works
num_examples = 10
examples = []
for i in range(num_examples):
    examples.append(val_ds[i])
examples

[{'src_sentence': 'Eine republikanische Strategie, um der Wiederwahl von Obama entgegenzutreten',
  'tgt_sentence': 'A Republican strategy to counter the re-election of Obama',
  'src_tokens': tensor([    0,  2530,  2878, 12244,  8708,  4789,    16,   577,   319,  4755,
           3815,   408, 11741,  7738, 27237,     2,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
         

In [55]:
model.to(device)

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

In [56]:
def greedy_decoding(model: EncoderDecoder, src_tokens: torch.Tensor, tokenizer: ByteLevelBPETokenizer = tokenizer, pad_token_id: int = PAD_TOKEN_ID, max_len=100):
    model.eval()

    # embed the source tokens and create the src_mask
    src_tokens = src_tokens.unsqueeze(0).to(device)
    src_mask = (src_tokens != pad_token_id).unsqueeze(0).to(device) # shape: (1, seq_length)

    src_embed = model.embedding(src_tokens)
    src_embed = model.pos_encoder(src_embed)

    # store encoder hidden states for the src_tokens
    encoder_output = model.encoder(src_embed, src_mask)

    # initizlie target sentence with BOS token
    tgt_tokens = torch.tensor([BOS_TOKEN_ID], dtype=torch.long).to(device)

    # Autoregressive loop to generate sentence
    for _ in range(max_len):
        # create target mask
        tgt_seq_len = tgt_tokens.size(0)
        tgt_mask = torch.tril(torch.ones(1, tgt_seq_len, tgt_seq_len)).to(device)
        #print(f"target mask shape: {tgt_mask.shape}")

        #print(f"Tokens at beginning: {tgt_tokens}")
        tgt_embed = model.embedding(tgt_tokens).unsqueeze(0)
        #print(f"Token embeddings shape: {tgt_embed.shape}")
        tgt_embed = model.pos_encoder(tgt_embed)

        output_logits = model.decoder(tgt_embed, encoder_output, src_mask, tgt_mask)
        output_log_probs = model.generator(output_logits)
        # print(output_log_probs.shape)
        # print(output_log_probs)
        next_token = torch.argmax(output_log_probs[:, -1, :], dim=-1)
        # print(next_token.shape)
        #print(f"Next token: {next_token.item()}")
        #print(tgt_tokens.shape)
        # append next
        tgt_tokens = torch.cat([tgt_tokens, next_token])

        if next_token.item() == EOS_TOKEN_ID or tgt_tokens.size(0) >= 50:
            break

        #print(f"Resulting Target Tokens: {tgt_tokens}")

    print(f"Source sentence: {tokenizer.decode([num for num in src_tokens.squeeze(0).tolist() if num != pad_token_id], skip_special_tokens=True)}")
    print(f"Translation: {tokenizer.decode(tgt_tokens.tolist())}")

    

In [None]:
def top_k_sampling(logits, k=10):
    values, indices = torch.topk(logits, k)

In [None]:
tokenizer.decode(src_tokens.tolist(), skip_special_tokens=True)

In [None]:
tokenizer.id_to_token(BOS_TOKEN_ID)

In [57]:
import numpy as np
random_integers = np.random.choice(300, 10, replace=False).tolist()
examples = [val_ds[i]['src_tokens'] for i in random_integers]
for example in examples:
    greedy_decoding(model, example)

Source sentence: <s>Die Konservierung von Fleisch durch Räuchern, Trocknen oder Salzen kann zur Bildung von Karzinogenen führen.</s>
Translation: <s>The town of the town of the town can be used to the same time, or to the use of the other hand.</s>
Source sentence: <s>Deshalb empfehle ich den Test ab einem Alter von 50 Jahren bzw. 40 Jahren, wenn man einen direkten Verwandten hat, der bereits an Prostatakrebs erkrankt war.</s>
Translation: <s>I therefore believe that the fact that a few years of the years of the years, if the last two years have been made a few years of the last year.</s>
Source sentence: <s>Man schlägt ihnen eine aktive Überwachung vor und bietet ihnen bei Fortschreiten der Krankheit eine Behandlung an.</s>
Translation: <s>You can enjoy a high-quality and a long-term and a long-term way.</s>
Source sentence: <s>"Die Durchführung eines Früherkennungstest führt nicht zu Krebs."</s>
Translation: <s>The fact is to make a few words of a few words of the most important issu

In [76]:
src_tokens = val_ds[170]['src_tokens'].unsqueeze(0).to(device)
src_sentence = val_ds[170]['src_sentence']
src_mask = (src_tokens != PAD_TOKEN_ID).unsqueeze(0).to(device)
src_mask.shape, src_tokens.shape

(torch.Size([1, 1, 512]), torch.Size([1, 512]))

In [77]:
random_integers

[93, 60, 66, 102, 238, 45, 41, 294, 178, 297]

In [78]:
src_embed = model.embedding(src_tokens)
src_embed = model.pos_encoder(src_embed)
encoder_output = model.encoder(src_embed, src_mask)

In [79]:
tgt_tokens = torch.tensor([BOS_TOKEN_ID], dtype=torch.long).to(device)
# tgt_mask = torch.ones(1, 1).to(device)

for _ in range(50):

    # create target mask
    tgt_seq_len = tgt_tokens.size(0)
    tgt_mask = torch.tril(torch.ones(1, tgt_seq_len, tgt_seq_len)).to(device)
    print(f"target mask shape: {tgt_mask.shape}")

    print(f"Tokens at beginning: {tgt_tokens}")
    tgt_embed = model.embedding(tgt_tokens).unsqueeze(0)
    print(print(f"Token embeddings shape: {tgt_embed.shape}"))
    tgt_embed = model.pos_encoder(tgt_embed)

    output_logits = model.decoder(tgt_embed, encoder_output, src_mask, tgt_mask)
    output_log_probs = model.generator(output_logits)
    # print(output_log_probs.shape)
    # print(output_log_probs)
    next_token = torch.argmax(output_log_probs[:, -1, :], dim=-1)
    # print(next_token.shape)
    print(f"Next token: {next_token.item()}")
    print(tgt_tokens.shape)
    # append next
    tgt_tokens = torch.cat([tgt_tokens, next_token])

    if next_token.item() == EOS_TOKEN_ID or tgt_tokens.size(0) >= 50:
        break

    print(f"Resulting Target Tokens: {tgt_tokens}")
    

target mask shape: torch.Size([1, 1, 1])
Tokens at beginning: tensor([0], device='mps:0')
Token embeddings shape: torch.Size([1, 1, 512])
None
Next token: 45
torch.Size([1])
Resulting Target Tokens: tensor([ 0, 45], device='mps:0')
target mask shape: torch.Size([1, 2, 2])
Tokens at beginning: tensor([ 0, 45], device='mps:0')
Token embeddings shape: torch.Size([1, 2, 512])
None
Next token: 520
torch.Size([2])
Resulting Target Tokens: tensor([  0,  45, 520], device='mps:0')
target mask shape: torch.Size([1, 3, 3])
Tokens at beginning: tensor([  0,  45, 520], device='mps:0')
Token embeddings shape: torch.Size([1, 3, 512])
None
Next token: 956
torch.Size([3])
Resulting Target Tokens: tensor([  0,  45, 520, 956], device='mps:0')
target mask shape: torch.Size([1, 4, 4])
Tokens at beginning: tensor([  0,  45, 520, 956], device='mps:0')
Token embeddings shape: torch.Size([1, 4, 512])
None
Next token: 7014
torch.Size([4])
Resulting Target Tokens: tensor([   0,   45,  520,  956, 7014], device='m

In [80]:
tgt_tokens.tolist()

[0,
 45,
 520,
 956,
 7014,
 393,
 435,
 326,
 503,
 264,
 1353,
 4839,
 393,
 938,
 520,
 264,
 1362,
 3069,
 427,
 264,
 1362,
 3069,
 304,
 280,
 1641,
 978,
 18,
 2]

In [81]:
print(src_sentence)
tokenizer.decode(tgt_tokens.tolist())

"Ich habe keine Wünsche mehr im Leben", sagt sie, bevor sie akzeptiert, dass man ihr eine Maske aufsetzt, die ihr beim Atmen hilft.


'<s>I have no doubt that it is not a good thing that they have a great deal with a great deal of the same time.</s>'

In [83]:
examples

[tensor([    0,   567, 18680,   875,   408, 13262,   745,   370,   314,   412,
           547,    16, 28152,   472,   653,  4206,  1582,   836,   665,  6417,
           408,  4518, 11877,  2683,   261,  2660,    18,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,  

In [84]:
# collect encodings for the src sentences
encodings = []
for src_tokens in examples:
    src_tokens = src_tokens.unsqueeze(0).to(device)
    src_mask = (src_tokens != PAD_TOKEN_ID).unsqueeze(0).to(device)
    
    with torch.no_grad():
        src_embed = model.embedding(src_tokens)
        src_embed = model.pos_encoder(src_embed)
        encoder_output = model.encoder(src_embed, src_mask)

    encodings.append(encoder_output)

encoder_diffs = []
for i in range(9):
    for j in range(i+1, 10):
        diff = torch.norm(encodings[i] - encodings[j])
        encoder_diffs.append(diff)
        print(f"Difference between example {i} and example {j}: {diff}")
        


Difference between example 0 and example 1: 225.04786682128906
Difference between example 0 and example 2: 214.9784393310547
Difference between example 0 and example 3: 222.360107421875
Difference between example 0 and example 4: 259.3667907714844
Difference between example 0 and example 5: 191.9576416015625
Difference between example 0 and example 6: 214.24774169921875
Difference between example 0 and example 7: 224.41213989257812
Difference between example 0 and example 8: 185.67291259765625
Difference between example 0 and example 9: 227.0283966064453
Difference between example 1 and example 2: 251.41867065429688
Difference between example 1 and example 3: 259.7427062988281
Difference between example 1 and example 4: 282.7121276855469
Difference between example 1 and example 5: 205.73074340820312
Difference between example 1 and example 6: 205.10382080078125
Difference between example 1 and example 7: 259.2537841796875
Difference between example 1 and example 8: 231.93775939941406
D

In [85]:
max(encoder_diffs) - min(encoder_diffs)


tensor(110.6548, device='mps:0')

In [None]:
train_ds[98]