In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components

In [None]:
# Choose the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
_device_name = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
device = torch.device(_device_name)

device

# Download dataset

### Download and inspect the dataset

In [None]:
from datasets import load_dataset, DatasetDict

In [None]:
# WMT14 German-English parallel corpus; its 'train' and 'validation' splits are used below.
dataset = load_dataset(path='wmt14', name='de-en')

In [None]:
dataset  # display the loaded DatasetDict

In [None]:
dataset['train'][4]  # one raw training record; the text pair lives under its 'translation' key ('de' / 'en')

In [None]:
# Tiny train/validation subsets so the whole pipeline can be exercised quickly.
N_SMALL_TRAIN, N_SMALL_VAL = 20, 5
small_train_dataset = dataset['train'].select(range(N_SMALL_TRAIN))
small_val_dataset = dataset['validation'].select(range(N_SMALL_VAL))

In [None]:
small_train_dataset  # display the 20-example training subset

### Tokenization

In [None]:
# as we are following the original `Attention is all you need paper` we will use Byte-Pair Encoding
from tokenizers import ByteLevelBPETokenizer

In [None]:
# Load the previously trained byte-level BPE tokenizer from its vocabulary and merge-rule files.
vocab_path = "bpe_tokenizer/vocab.json"
merges_path = "bpe_tokenizer/merges.txt"
tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)

In [None]:
# Sanity-check the tokenizer on one short German sentence (the same sentence throughout).
sample_sentence = "Das ist ein Beispiel."
sample_ids = tokenizer.encode(sample_sentence).ids
print(sample_ids)

# Sub-word pieces for the same ids.
# NOTE(review): TranslationDataset adds BOS/EOS itself (add_special_tokens), which suggests
# encode() does not insert '<s>'/'</s>' -- expect only sub-word pieces here; confirm.
print([tokenizer.id_to_token(token) for token in sample_ids])

# Should print a valid (non-None) id if '</s>' is in the vocabulary.
print(tokenizer.token_to_id("</s>"))

# Round trip: decoding the ids should give back the input sentence.
print(tokenizer.decode(sample_ids))


In [None]:
# Ids of the special tokens used for padding and for marking sequence start / end.
PAD_TOKEN_ID = tokenizer.token_to_id('<pad>')
BOS_TOKEN_ID = tokenizer.token_to_id('<s>')
EOS_TOKEN_ID = tokenizer.token_to_id('</s>')

# token_to_id gives None for tokens outside the vocabulary; fail here with a clear message
# instead of deep inside the data pipeline (torch.tensor on a list containing None).
_missing_special = [
    name
    for name, token_id in (('<pad>', PAD_TOKEN_ID), ('<s>', BOS_TOKEN_ID), ('</s>', EOS_TOKEN_ID))
    if token_id is None
]
if _missing_special:
    raise ValueError(f"special tokens missing from the tokenizer vocabulary: {_missing_special}")

In [None]:
print(BOS_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID)  # printed in the order BOS, PAD, EOS

In [None]:
# Create a pytorch dataset class
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    """German -> English translation pairs, tokenized into fixed-length id tensors.

    Each item comes from ``dataset[idx]['translation']`` (``'de'`` = source, ``'en'`` = target).
    Each side is wrapped in BOS ... EOS and then padded or truncated to exactly ``max_length``
    ids. When a sentence is too long, only its *content* is cut, so EOS is kept whenever
    ``max_length >= 2`` and the model still sees where sentences end.

    Args:
        dataset: indexable, sized collection of records like ``{'translation': {'de': str, 'en': str}}``.
        tokenizer: object whose ``encode(text).ids`` returns a list of token ids.
        bos_token_id / eos_token_id / pad_token_id: special-token ids.
        max_length: length of every returned token tensor (BOS/EOS/padding included).
    """

    def __init__(self, dataset, tokenizer, bos_token_id: int = BOS_TOKEN_ID, eos_token_id: int = EOS_TOKEN_ID, pad_token_id: int = PAD_TOKEN_ID, max_length: int = 512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.bos = bos_token_id
        self.eos = eos_token_id
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Read the record once; both sides come from the same 'translation' dict.
        pair = self.dataset[idx]['translation']
        src_sentence = pair['de']
        tgt_sentence = pair['en']

        return {
            'src_sentence': src_sentence,
            'tgt_sentence': tgt_sentence,
            'src_tokens': self._encode(src_sentence),
            'tgt_tokens': self._encode(tgt_sentence),
        }

    def _encode(self, sentence):
        """Tokenize, cut the content so BOS/EOS still fit, add BOS/EOS, pad to max_length."""
        ids = self.tokenizer.encode(sentence).ids
        content_budget = max(self.max_length - 2, 0)  # room left after BOS and EOS
        ids = self.add_special_tokens(ids[:content_budget])
        # pad_and_truncate still caps the length for the degenerate max_length < 2 case.
        return torch.tensor(self.pad_and_truncate(ids))

    def pad_and_truncate(self, tokens):
        """Right-pad `tokens` with pad ids, or cut it, so it has exactly max_length entries."""
        if len(tokens) < self.max_length:
            tokens = tokens + [self.pad_token_id] * (self.max_length - len(tokens))
        else:
            tokens = tokens[:self.max_length]

        return tokens

    def add_special_tokens(self, tokens):
        """Return [BOS] + tokens + [EOS]."""
        return [self.bos] + tokens + [self.eos]

    def create_causal_mask(self, size):
        """Lower-triangular (size, size) uint8 look-ahead mask; entry [i, j] is 1 iff j <= i."""
        return torch.tril(torch.ones(size, size)).type(torch.uint8)

In [None]:
# Wrap the tiny training subset; short sequences keep the printouts readable.
small_translation_ds = TranslationDataset(
    small_train_dataset,
    tokenizer=tokenizer,
    pad_token_id=PAD_TOKEN_ID,
    max_length=30,
)
small_translation_ds

In [None]:
small_translation_ds[0]  # first item: raw sentences plus padded source/target token tensors

In [None]:
# Collate function: stack dataset items into a batch and build the attention masks.

def create_causal_mask(size):
    """
    Creates a causal mask (look-ahead mask) that prevents attending to future tokens.
    size: Length of the sequence.
    Returns: uint8 tensor of shape (1, size, size); entry [0, i, j] is 1 iff j <= i.
    """
    attn_shape = (1, size, size)
    return torch.tril(torch.ones(attn_shape)).type(torch.uint8)

def create_std_mask(tgt, pad_token_id = PAD_TOKEN_ID):
    """Decoder self-attention mask that combines padding and look-ahead masking.

    Args:
        tgt: (batch_size, seq_length) decoder-input token ids.
        pad_token_id: id treated as padding.
    Returns:
        (batch_size, seq_length, seq_length) uint8 mask; [b, i, j] is 1 iff j <= i and
        key j of sequence b is not padding.
    """
    size = tgt.size(-1)
    key_is_token = (tgt != pad_token_id).unsqueeze(-2)  # (bs, 1, seq_length)
    # Build the look-ahead part here instead of calling `create_causal_mask`: that global name
    # is redefined later in this notebook (with a different shape) and is looked up at call time.
    look_ahead = torch.tril(torch.ones(1, size, size, device=tgt.device)).type(torch.uint8)
    return key_is_token & look_ahead

def collate_fn(batch, pad_token_id = PAD_TOKEN_ID):
    """Stack TranslationDataset items and build teacher-forcing inputs/targets and masks.

    Returns a dict with:
        src_tokens: (bs, seq_length) source ids
        tgt_input:  (bs, seq_length - 1) target ids without the last position (decoder input)
        tgt_output: (bs, seq_length - 1) target ids shifted left by one (prediction targets)
        src_mask:   (bs, 1, seq_length) int32 mask, 1 for non-pad source tokens
        tgt_mask:   (bs, seq_length - 1, seq_length - 1) padding + look-ahead mask
    """
    src_batch = torch.stack([item['src_tokens'] for item in batch])
    tgt_batch = torch.stack([item['tgt_tokens'] for item in batch])

    # Source mask, shape (bs, 1, seq_length): broadcasts over the query axis inside attention.
    src_mask = (src_batch != pad_token_id).unsqueeze(-2).int()

    # Teacher forcing: the decoder reads positions [0, n-1) and predicts positions [1, n).
    # .contiguous() because column slices are strided, and callers flatten them with .view().
    tgt = tgt_batch[:, :-1].contiguous()
    tgt_y = tgt_batch[:, 1:].contiguous()
    tgt_mask = create_std_mask(tgt, pad_token_id=pad_token_id)

    return {
        'src_tokens': src_batch,
        'tgt_input': tgt,
        'tgt_output': tgt_y,
        'src_mask': src_mask,
        'tgt_mask': tgt_mask,
    }


In [None]:
small_dl = DataLoader(small_translation_ds, collate_fn=collate_fn, batch_size=4)

# (label, batch key) pairs to report; only the first batch is inspected.
_shape_report = (
    ("Source tokens:", 'src_tokens'),
    ("Target tokens:", 'tgt_input'),
    ("Target output tokens:", 'tgt_output'),
    ("Source mask:", 'src_mask'),
    ("Target mask:", 'tgt_mask'),
)
for batch in small_dl:
    for label, field in _shape_report:
        print(label, batch[field].shape)
    break

# Creating each layer step by step

### Scaled Dot-Product Attention

In [None]:
import torch.nn.functional as F
import torch
import math

def scaled_dpa(query, key, value, mask=None, verbose=False, dropout=None):
    """
    Implements scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Works with any leading batch dimensions, e.g. (bs, seq, d) or (bs, heads, seq, d).
    Args:
        query: (..., seq_length_q, dim_k)
        key: (..., seq_length_k, dim_k)
        value: (..., seq_length_k, dim_v)
        mask: None, or a tensor broadcastable to (..., seq_length_q, seq_length_k);
              positions where mask == 0 receive no attention. Examples: (bs, 1, seq_length_k)
              for padding, (bs, seq_length_q, seq_length_k) for a look-ahead mask.
        verbose: print intermediate shapes when True.
        dropout: optional callable (e.g. an nn.Dropout) applied to the attention weights
                 before they are multiplied with `value`.
    Returns:
        attention_output: (..., seq_length_q, dim_v)
        attention_weights: (..., seq_length_q, seq_length_k)
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # (..., seq_q, seq_k)

    if verbose:
        print(f"Scores shape: {scores.shape}")

    if mask is not None:
        # The most negative finite value (not -inf): masked entries still get exactly zero
        # weight, but a row whose keys are all masked becomes uniform instead of NaN.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)  # (..., seq_q, seq_k)

    if verbose:
        print(f"Attention weights shape: {attention_weights.shape}")

    if dropout is not None:
        attention_weights = dropout(attention_weights)

    output = torch.matmul(attention_weights, value)

    if verbose:
        print(f"Attention output shape: {output.shape}")

    return output, attention_weights

In [None]:
# Demo: how a key-padding mask blanks out attention scores.
# 3 sequences in the batch, each of length 5 (scores only, so no embedding dimension).
batch_size = 3
seq_length = 5

# Stand-in attention scores, shape (batch_size, seq_length, seq_length).
scores = torch.rand(batch_size, seq_length, seq_length)
print(scores)

# Key mask per sequence: 1 = may be attended to, 0 = masked out.
mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1],
])
print(mask)

# Insert a query axis -> (batch_size, 1, seq_length), so every query row shares its sequence's mask.
mask = mask[:, None, :]
print(mask.shape)

masked_positions = mask == 0
scores = scores.masked_fill(masked_positions, float('-inf'))
print(scores)


In [None]:
# Try scaled_dpa on random tensors with a per-sequence key mask.

# 3 sequences of length 5, with d_k = d_v = 4.
batch_size, seq_length, embedding_dim = 3, 5, 4
qkv_shape = (batch_size, seq_length, embedding_dim)

# Random queries, keys and values (drawn in that order).
query = torch.rand(qkv_shape).to(device)
key = torch.rand(qkv_shape).to(device)
value = torch.rand(qkv_shape).to(device)

print(f"Query shape: {query.shape}")

# 1 = attend, 0 = masked. The inserted middle axis, giving (batch_size, 1, seq_length), broadcasts over queries.
mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1],
])[:, None, :].to(device)

print(f"Mask shape: {mask.shape}")

output, attention_weights = scaled_dpa(query, key, value, mask, verbose=True)

print("Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)


In [None]:
# Mask helpers for the attention tests below.
def create_padding_mask(seq, pad_token_id=0):
    """
    Creates a padding mask: True for real tokens, False for padding tokens.

    seq: (batch_size, seq_length) token ids.
    pad_token_id: id treated as padding (the toy examples in this notebook use 0).
    Returns: (batch_size, 1, seq_length) -- the same layout as `src_mask` from `collate_fn`.
    MultiHeadAttention inserts the head axis itself, so an extra axis here would turn the
    attention scores 5-D and mix up heads and batch elements.
    """
    return (seq != pad_token_id).unsqueeze(-2)

def create_causal_mask(size):
    """
    Creates a causal (look-ahead) mask that prevents attending to future tokens.

    size: length of the sequence.
    Returns: (size, size) uint8 lower-triangular matrix; entry [i, j] is 1 iff j <= i.
    """
    return torch.tril(torch.ones(size, size)).type(torch.uint8)

# Exercise scaled_dpa with a padding mask and a causal mask combined.

# One sequence of length 5 with 4-dimensional query/key/value vectors.
batch_size = 1
seq_length = 5
embedding_dim = 4

# Random queries, keys and values, drawn in that order.
query, key, value = (
    torch.rand(batch_size, seq_length, embedding_dim).to(device) for _ in range(3)
)

# Toy token ids; 0 plays the role of the padding token (last two positions).
src_tokens = torch.tensor([[1, 2, 3, 0, 0]]).to(device)

padding_mask = create_padding_mask(src_tokens).to(device)
causal_mask = create_causal_mask(seq_length).to(device)

# A key is visible only if it is not padding AND not in the future.
combined_mask = padding_mask & causal_mask.unsqueeze(0).to(device)

output, attention_weights = scaled_dpa(query, key, value, combined_mask, verbose=True)

for title, tensor in (
    ("Attention Output:\n", output),
    ("Attention Weights:\n", attention_weights),
    ("Padding Mask:\n", padding_mask),
    ("Causal Mask:\n", causal_mask),
    ("Combined Mask:\n", combined_mask),
):
    print(title, tensor)



### Multi-head attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (Vaswani et al., 2017, section 3.2.2).

    Projects query/key/value with separate linear layers and splits d_model into `num_heads`
    heads of size d_k = d_model // num_heads. It then runs `scaled_dpa` on all heads at once,
    concatenates the heads, and applies a final linear projection.

    Args:
        num_heads: number of attention heads; must divide d_model.
        d_model: model / embedding dimension.
        dropout: dropout probability applied to the attention weights in training mode.
        verbose: print intermediate shapes when True.
    """
    def __init__(self, num_heads: int, d_model: int, dropout=0.1, verbose=False):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.verbose = verbose

        if self.verbose:
            print(f"Num heads: {num_heads}")
            print(f"Embedding dimension: {d_model}")
            print(f"per head dimension: {self.d_k}")

        # Linear projections for query, key and value (all heads at once).
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        # Dropout on the attention probabilities (active only in training mode).
        self.dropout = nn.Dropout(p=dropout)
        self.output_linear = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """(bs, seq, d_model) -> (bs, num_heads, seq, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (bs, seq_length_q, d_model)
            key:   (bs, seq_length_kv, d_model)
            value: (bs, seq_length_kv, d_model)
            mask:  None, or a mask broadcastable to (bs, seq_length_q, seq_length_kv), e.g. a
                   (bs, 1, seq_length_kv) padding mask or a (bs, seq_length_q, seq_length_kv)
                   look-ahead mask. A head axis is inserted here, so one mask serves all heads.
        Returns:
            output: (bs, seq_length_q, d_model)
            attn_weights: (bs, num_heads, seq_length_q, seq_length_kv), before dropout.
        """
        batch_size = query.size(0)

        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask applied to all heads

        if self.verbose and mask is not None:
            print(f"Mask shape (after unsqueezing at 1): {mask.shape}")

        # Linear projections; shapes stay (bs, seq_len, d_model).
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)

        if self.verbose:
            print(f"Query shape: {query.shape}")
            print(f"Key shape: {key.shape}")
            print(f"Value shape: {value.shape}")

        # Split into heads: (bs, num_heads, seq_len, d_k).
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)

        if self.verbose:
            print(f"Shapes after projections for query, key, value...")
            print(f"{query.shape}, {key.shape}, {value.shape}")

        attn_output, attn_weights = scaled_dpa(query, key, value, mask, verbose=self.verbose)

        # Attention dropout (as in the Annotated Transformer). scaled_dpa has already multiplied
        # by `value`, so redo that product with the dropped-out weights. In eval mode dropout is
        # the identity, so this step is skipped there.
        if self.training and self.dropout.p > 0:
            attn_output = torch.matmul(self.dropout(attn_weights), value)

        # Merge the heads again: (bs, num_heads, seq_q, d_k) -> (bs, seq_q, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        if self.verbose:
            print(f"Attention output shape after concat: {attn_output.shape}")

        output = self.output_linear(attn_output)
        if self.verbose:
            print(f"Output shape: {output.shape}")

        return output, attn_weights

In [None]:
# Visualize how view + transpose splits d_model into heads.
batch_size, seq_length, d_model, num_heads = 1, 4, 8, 2
d_k = d_model // num_heads

# Values 1..32 laid out as (batch_size, seq_length, d_model), so every element is identifiable.
query = torch.arange(1, seq_length * d_model + 1).view(batch_size, seq_length, d_model).to(device)
print(query.shape)
print(f"unchanged query: {query}")

# Split the feature axis into (num_heads, d_k) -> (batch_size, seq_length, num_heads, d_k).
query = query.view(batch_size, -1, num_heads, d_k)
print(f"prior to transpose query: {query}")

# Move heads ahead of positions -> (batch_size, num_heads, seq_length, d_k); each head sees every position.
query = query.transpose(1, 2)
print(f"transposed query: {query}")
print(query.shape)

In [None]:
# Contrast: reshaping straight to (batch, heads, seq, d_k) with no transpose gives each head a
# contiguous chunk of the flattened tensor, so one "head" mixes features from different positions.
query = torch.arange(1, seq_length * d_model + 1).view(batch_size, seq_length, d_model)
query = query.reshape(batch_size, num_heads, -1, d_k)
print(f"Directly reshaping query: {query}")

In [None]:
# Run MultiHeadAttention on random inputs without a mask.

num_heads, d_model, seq_length, batch_size = 8, 64, 5, 1

# Random query, key and value tensors, each (batch_size, seq_length, d_model), drawn in that order.
query, key, value = [torch.rand(batch_size, seq_length, d_model).to(device) for _ in range(3)]

mask = None  # no masking in this check

multihead_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, verbose=True).to(device)
output, attention_weights = multihead_attn(query, key, value, mask)

print("Multi-Head Attention Output:\n", output)
print("Attention Weights:\n", attention_weights.shape)


In [None]:
# Test MultiHeadAttention with a padding mask and a causal mask combined.

num_heads = 2
d_model = 8
seq_length = 4
batch_size = 1

# Random query, key and value tensors, each (batch_size, seq_length, d_model).
query = torch.rand(batch_size, seq_length, d_model).to(device)
key = torch.rand(batch_size, seq_length, d_model).to(device)
value = torch.rand(batch_size, seq_length, d_model).to(device)

# Toy token ids; 0 stands for the padding token (last position).
src_tokens = torch.tensor([[1, 2, 3, 0]]).to(device)

padding_mask = create_padding_mask(src_tokens).to(device)
causal_mask = create_causal_mask(seq_length).to(device)  # (seq_length, seq_length)

# Bitwise AND: a key is visible only if it is not padding and not in the future.
# Tensor.to() returns a new tensor rather than moving in place, so assign its result.
combined_mask = (padding_mask & causal_mask.unsqueeze(0)).to(device)

multihead_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, verbose=True).to(device)

output, attention_weights = multihead_attn(query, key, value, combined_mask)

print("\nMulti-Head Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)
print("Padding Mask:\n", padding_mask)
print("Causal Mask:\n", causal_mask)
print("Combined Mask:\n", combined_mask)

### Encoder layer

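Now we implement the encoder layer: multi-head self-attention followed by a position-wise feed-forward network, each wrapped in dropout, a residual connection, and layer normalization.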

In [None]:
class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear(d_model -> d_ff), ReLU, dropout, Linear(d_ff -> d_model).

    Applied independently at every position; input and output have shape (..., d_model).
    """
    def __init__(self, d_ff: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.linear2 = nn.Linear(d_ff, d_model)  # project back to the model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
    
class EncoderLayer(nn.Module):
    """One Transformer encoder layer with post-layer-norm residual sub-layers.

        h   = LayerNorm1(x + Dropout(SelfAttention(x)))
        out = LayerNorm2(h + Dropout(FFN(h)))

    Args:
        num_heads, d_model: attention configuration (d_model must be divisible by num_heads).
        d_ff: hidden width of the position-wise feed-forward network.
        dropout: dropout probability, passed to the sub-modules and used on the residual branches.
        verbose: print shapes when True.
    """
    def __init__(self, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = False):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.ffn = PositionwiseFFN(d_ff=d_ff, d_model=d_model, dropout=dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

    def forward(self, x, mask=None):
        """x: (bs, seq_length, d_model); mask is forwarded to self-attention. Returns the same shape as x."""
        if self.verbose:
            print(f"Input to Encoder Layer: {x.shape}")

        # Sub-layer 1: self-attention, then residual and norm.
        attn_output = self.mha(x, x, x, mask)[0]
        if self.verbose:
            print(f"attn_output shape: {attn_output.shape}")
        attended = self.layernorm1(x + self.dropout(attn_output))

        # Sub-layer 2: the FFN reads sub-layer 1's output, and the residual adds that same output.
        encoded = self.layernorm2(attended + self.dropout(self.ffn(attended)))

        if self.verbose:
            print(f"Output from Encoder Layer: {encoded.shape}")

        return encoded


In [None]:
# Run a single EncoderLayer on random input with a padding mask.

num_heads, d_model, d_ff = 2, 8, 16
seq_length, batch_size = 4, 1

x = torch.rand(batch_size, seq_length, d_model).to(device)

# The last position (id 0) is treated as padding.
padding_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]])).to(device)
print(f"Padding mask: {padding_mask.int()}")

encoder_layer = EncoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, verbose=True).to(device)
output = encoder_layer(x, mask=padding_mask)

print("\nOutput from Encoder Layer:\n", output)


In [None]:
# Same check with the layer configured inline: 2 heads, d_model=8, d_ff=16, dropout=0.1.
src_tokens = torch.tensor([[1, 2, 3, 0]]).to(device)  # trailing 0 = padding
padding_mask = create_padding_mask(src_tokens).to(device)

encoder_layer = EncoderLayer(num_heads=2, d_model=8, d_ff=16, dropout=0.1, verbose=True).to(device)
x = torch.rand(1, 4, 8).to(device)  # (batch=1, seq=4, d_model=8)

output = encoder_layer(x, mask=padding_mask)
print("Output from encoder layer with padding mask:", output)


### Decoder layer

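Decoder layer implementation: masked self-attention over the target, encoder–decoder attention over the encoder output, and a position-wise feed-forward network. Each of the three is followed by dropout, a residual connection, and layer normalization.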

In [None]:
class DecoderLayer(nn.Module):
    """One Transformer decoder layer with three post-layer-norm residual sub-layers:
    masked self-attention, encoder-decoder attention, and a position-wise FFN.

    Args:
        num_heads, d_model: attention configuration.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability, passed to the sub-modules and used on the residual branches.
        verbose: print progress and shapes when True.
    """
    def __init__(self, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose=False):
        super().__init__()
        self.self_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.src_attn = MultiHeadAttention(num_heads=num_heads, d_model=d_model, dropout=dropout, verbose=verbose)
        self.ffn = PositionwiseFFN(d_ff=d_ff, d_model=d_model, dropout=dropout)
        self.layernorms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

    def _residual(self, index, x, sublayer_output):
        """LayerNorm number `index` applied to x + Dropout(sublayer_output)."""
        return self.layernorms[index](x + self.dropout(sublayer_output))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        x: target-side hidden states; enc_output: encoder output the decoder attends to.
        tgt_mask masks the self-attention, src_mask masks the encoder-decoder attention.
        Returns a tensor with the same shape as x.
        """
        if self.verbose:
            print(f"Input shape x: {x.shape}")
            print(f"Encoder output shape: {enc_output.shape}\n")

        # 1) Masked self-attention over the target.
        if self.verbose:
            print(f"Passing through self-attention")
        x = self._residual(0, x, self.self_attn(x, x, x, tgt_mask)[0])

        # 2) Encoder-decoder attention: queries from the target, keys/values from the encoder.
        if self.verbose:
            print(f"\nPassing Through encoder-decoder attention")
        x = self._residual(1, x, self.src_attn(x, enc_output, enc_output, src_mask)[0])

        # 3) Position-wise feed-forward.
        if self.verbose:
            print(f"\nFinal feedforward of layer")
        x = self._residual(2, x, self.ffn(x))

        if self.verbose:
            print(f"\nOutput shape: {x.shape}")

        return x



In [None]:
def create_causal_mask(seq_length, device=None):
    """
    Creates a causal mask (look-ahead mask) that prevents attending to future tokens.

    seq_length: length of the sequence.
    device: optional device on which to create the mask (default: torch's default device).
    Returns: (seq_length, seq_length) uint8 lower-triangular matrix; entry [i, j] is 1 iff j <= i.
    """
    return torch.tril(torch.ones(seq_length, seq_length, device=device)).type(torch.uint8)

In [None]:
# Decoder layer smoke test with stand-in inputs of shape (batch=1, seq=4, d_model=8).
tgt = torch.rand(1, 4, 8).to(device)         # embedded target tokens
enc_output = torch.rand(1, 4, 8).to(device)  # encoder output (same size, for simplicity)

# Look-ahead mask for the target (with a leading batch axis), padding mask for the source.
tgt_mask = create_causal_mask(seq_length=4).unsqueeze(0).to(device)
src_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]])).to(device)

decoder_layer = DecoderLayer(num_heads=2, d_model=8, d_ff=16, dropout=0.1, verbose=True).to(device)
output = decoder_layer(tgt, enc_output, src_mask=src_mask, tgt_mask=tgt_mask)
print("Output from decoder layer:", output)
    

### Positional Encoding

In [None]:
import math

class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal positional encodings of Vaswani et al. (2017), then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding dimension. Odd values work: the last cos column is simply absent.
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Computed once, in log space: div_term = 10000^(-2i / d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch
        # Registered as a buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """x: (batch_size, seq_length, d_model) embeddings. Returns the same shape."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(1)} of PositionalEncoding"
            )
        x = x + self.pe[:, :seq_len, :].requires_grad_(False)
        return self.dropout(x)

### Encoder

In [None]:
class Encoder(nn.Module):
    """Stack of `num_blocks` EncoderLayers followed by a final LayerNorm.

    Args:
        num_blocks: number of stacked encoder layers.
        num_heads, d_model, d_ff, dropout: forwarded to every EncoderLayer.
        verbose: print progress and shapes when True.
    """
    def __init__(self, num_blocks: int, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = False):
        super().__init__()
        self.num_blocks = num_blocks
        self.verbose = verbose

        layers = [
            EncoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=verbose)
            for _ in range(num_blocks)
        ]
        self.encoder_blocks = nn.ModuleList(layers)

        self.layernorm = nn.LayerNorm(d_model)  # applied once, after the last layer

    def forward(self, x, src_mask = None):
        """x: (bs, seq_length, d_model) embedded source; src_mask is passed to every layer as `mask`."""
        if self.verbose:
            print(f"Input of shape: {x.shape}")

        for block_number, block in enumerate(self.encoder_blocks, start=1):
            if self.verbose:
                print(f"\n------------ Passing Through Encoder block {block_number} ----------------")
            x = block(x, mask=src_mask)

        x = self.layernorm(x)

        if self.verbose:
            print(f"\nFinal output shape is: {x.shape}")

        return x

In [None]:
# Encoder smoke test on random embeddings of shape (batch_size=1, seq_length=4, d_model=8).
src = torch.rand(1, 4, 8)

# Padding mask for the source sequence (the last position, id 0, is padding).
src_mask = create_padding_mask(torch.tensor([[1, 2, 3, 0]]))

# Two blocks keep the printout short. Encoder has no max_length parameter; sequence length
# limits live in PositionalEncoding (max_len).
encoder = Encoder(num_blocks=2, num_heads=2, d_model=8, d_ff=16, dropout=0.1, verbose=True)

output = encoder(src, src_mask=src_mask)
print("Final output from encoder:", output)


In [None]:
# Run the encoder on real tokenized examples.
batch_size = 4
small_dl = DataLoader(small_translation_ds, batch_size = batch_size, shuffle=True, collate_fn=collate_fn)

# Inspect the tensors in one collated batch.
for batch in small_dl:
    print(batch.keys())
    src_tokens, tgt_input, tgt_output, src_mask, tgt_mask = (
        batch[name].to(device)
        for name in ('src_tokens', 'tgt_input', 'tgt_output', 'src_mask', 'tgt_mask')
    )

    print(f"Source tokens: {src_tokens.shape}")
    print(f"tgt_input: {tgt_input.shape}")
    print(f"tgt_output: {tgt_output.shape}")
    print(f"src_mask: {src_mask.shape}")
    print(f"tgt_mask: {tgt_mask.shape}")
    break

# Base-size encoder plus token embedding and positional encoding (d_model = 512).
encoder = Encoder(num_blocks=6, num_heads=8, d_model=512, d_ff=2048, verbose=True).to(device)
embedding = nn.Embedding(tokenizer.get_vocab_size(), 512).to(device)
pos_encoder = PositionalEncoding(512, dropout=0.1, max_len=512).to(device)

# Embed, add positions, and encode a single batch.
for batch in small_dl:
    src_tokens = batch['src_tokens'].to(device)
    src_mask = batch['src_mask'].to(device)

    print(f"\nSource token shape: {src_tokens.shape}")

    src_embed = pos_encoder(embedding(src_tokens))
    encoder_output = encoder(src_embed, src_mask)
    break


### Decoder

In [None]:
class Decoder(nn.Module):
    """Stack of `num_blocks` DecoderLayers followed by a final LayerNorm.

    Args:
        num_blocks: number of stacked decoder layers.
        num_heads, d_model, d_ff, dropout: forwarded to every DecoderLayer.
        verbose: print progress when True (note: the default here is True, unlike Encoder).
    """
    def __init__(self, num_blocks: int, num_heads: int, d_model: int, d_ff: int, dropout: float = 0.1, verbose: bool = True):
        super().__init__()
        self.decoder_blocks = nn.ModuleList(
            [
                DecoderLayer(num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=verbose)
                for _ in range(num_blocks)
            ]
        )
        self.layernorm = nn.LayerNorm(d_model)
        self.verbose = verbose

    def forward(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        """Run the embedded target `tgt` through every layer, then normalize."""
        hidden = tgt
        for block_index, block in enumerate(self.decoder_blocks):
            if self.verbose:
                print(f"\n------------- Passing Through Decoder Block {block_index+1} ----------------")
            hidden = block(hidden, enc_output, src_mask, tgt_mask)

        return self.layernorm(hidden)

In [None]:
# Decoder smoke test on mock (already embedded) inputs, using the base-model sizes.
batch_size, seq_length = 4, 5
d_model, num_heads, num_blocks, d_ff = 512, 8, 6, 2048

tgt_embed = torch.rand(batch_size, seq_length, d_model)   # stand-in for embedded target tokens
enc_output = torch.rand(batch_size, seq_length, d_model)  # stand-in for the encoder output

# No padding in the mock source: every key is visible. Shape (batch_size, 1, seq_length).
src_mask = torch.ones(batch_size, 1, seq_length)
print(src_mask.shape)

# Lower-triangular look-ahead mask per batch element: (batch_size, seq_length, seq_length).
tgt_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))
print(tgt_mask.shape)

decoder = Decoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=0.1, verbose=True)
decoder_output = decoder(tgt_embed, enc_output, src_mask=src_mask, tgt_mask=tgt_mask)

print("Decoder output shape:", decoder_output.shape)


# Translator Class

Now we instantiate an encoder and a decoder and chain them together, to confirm that all the pieces work end to end. Then we abstract this into an encoder–decoder sequence-to-sequence model.

In [None]:
class Generator(nn.Module):
    """Output layer: project d_model vectors onto the vocabulary, then apply log-softmax.

    Input (..., d_model) -> output (..., vocab_size) log-probabilities.
    """
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)


In [None]:
# Wire embedding -> positional encoding -> encoder -> decoder -> generator by hand on one batch.

# Hyper-parameters.
d_model = 512
d_ff = 2048
dropout = 0.1
num_blocks = 6
num_heads = 8
max_len = 30  # matches the max_length of small_translation_ds

batch_size = 4

small_dl = DataLoader(small_translation_ds, batch_size = batch_size, shuffle=True, collate_fn=collate_fn)

vocab_size = tokenizer.get_vocab_size()
embedding = nn.Embedding(vocab_size, d_model)
pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len)
encoder = Encoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=True)
decoder = Decoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=True)
generator = Generator(d_model=d_model, vocab_size=vocab_size)

for batch in small_dl:
    print(f"Source tokens:", batch['src_tokens'].shape)
    print(f"Target input tokens:", batch['tgt_input'].shape)
    print(f"Target output tokens:", batch['tgt_output'].shape)
    print(f"Source mask:", batch['src_mask'].shape)
    print(f"Target mask:", batch['tgt_mask'].shape)

    src_tokens = batch['src_tokens']
    src_mask = batch['src_mask']
    tgt_input = batch['tgt_input']
    tgt_output = batch['tgt_output']
    tgt_mask = batch['tgt_mask']

    # Encoder side.
    src_embed = pos_encoder(embedding(src_tokens))
    encoder_output = encoder(src_embed, src_mask)
    print(f"Encoder output: {encoder_output.shape}")

    # Decoder side (the same embedding table is used for the target).
    tgt_embed = pos_encoder(embedding(tgt_input))
    dec_output = decoder(tgt=tgt_embed, enc_output = encoder_output, src_mask=src_mask, tgt_mask=tgt_mask)

    # Log-probabilities over the vocabulary, and the greedy token choice at each position.
    output = generator(dec_output)
    predicted_tokens = output.argmax(dim=-1)

    print(output)
    print(tgt_output)

    print(output.shape)
    print(predicted_tokens.shape)
    print(tgt_output.shape)
    break


In [None]:
# The hand-wired pipeline above, packaged as a single nn.Module.
class EncoderDecoder(nn.Module):
    """Sequence-to-sequence Transformer assembled from pre-built parts.

    One `embedding` table (plus `pos_encoder`) is shared by the source and the target side.
    NOTE(review): embeddings are not multiplied by sqrt(d_model) here, whereas the paper
    (section 3.4) does scale them.
    """
    def __init__(self, encoder: Encoder, decoder: Decoder, generator: Generator, embedding: nn.Embedding, pos_encoder: PositionalEncoding, verbose: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
        self.embedding = embedding
        self.pos_encoder = pos_encoder
        self.verbose = verbose

    def _embed(self, tokens):
        """Token ids (bs, seq) -> embeddings with positional encoding added."""
        return self.pos_encoder(self.embedding(tokens))

    def forward(self, src_tokens, tgt_input, src_mask, tgt_mask):
        """Return next-token log-probabilities, shape (bs, tgt_len, vocab_size)."""
        memory = self.encoder(self._embed(src_tokens), src_mask)
        decoded = self.decoder(tgt=self._embed(tgt_input), enc_output=memory, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.generator(decoded)



# Training

In [None]:
# Build the training and validation datasets.
# A fixed seed makes the 20k-example training subset reproducible across runs.
# Both use TranslationDataset's default max_length=512, so every sequence is padded to 512 ids.
SHUFFLE_SEED = 42
N_TRAIN_EXAMPLES = 20000
train_subset = dataset['train'].shuffle(seed=SHUFFLE_SEED).select(range(N_TRAIN_EXAMPLES))
train_ds = TranslationDataset(train_subset, tokenizer=tokenizer, bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)
val_ds = TranslationDataset(dataset['validation'], tokenizer=tokenizer, bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)

In [None]:
# Batch the datasets; only the training data is shuffled.
batch_size = 16
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)

In [None]:
# Peek at the keys of one training batch.
batch = next(iter(train_dl))
print(batch.keys())

In [None]:
# Instantiate the full model for training.

# Hyper-parameters of the base model in "Attention Is All You Need".
d_model = 512
d_ff = 2048
dropout = 0.1
num_blocks = 6
num_heads = 8
max_len = 512  # longest sequence PositionalEncoding supports; TranslationDataset pads to 512 by default

# Same preference order as the setup cell (MPS, then CUDA, then CPU), so CUDA machines
# are not silently moved to the CPU here.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

vocab_size = tokenizer.get_vocab_size()
embedding = nn.Embedding(vocab_size, d_model).to(device)
pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len).to(device)
encoder = Encoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=False).to(device)
decoder = Decoder(num_blocks=num_blocks, num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, verbose=False).to(device)
generator = Generator(d_model=d_model, vocab_size=vocab_size).to(device)

model = EncoderDecoder(encoder, decoder, generator, embedding, pos_encoder, verbose=False).to(device)

In [None]:
# Learning-rate schedule from "Attention Is All You Need" (section 5.3):
#   lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
# i.e. linear warmup for `warmup_steps` steps, then inverse-square-root decay.
warmup_steps = 4000

def get_lr(step_num, model_dim=None, warmup=None):
    """Return the learning rate for optimizer step `step_num` (1-based).

    Args:
        step_num: current step. Values below 1 are treated as step 1, because step 0 would
            raise ZeroDivisionError in step ** -0.5.
        model_dim: model dimension; defaults to the notebook-level `d_model`.
        warmup: number of warmup steps; defaults to `warmup_steps` above.
    """
    dim = d_model if model_dim is None else model_dim
    warm = warmup_steps if warmup is None else warmup
    step = max(step_num, 1)
    return dim ** -0.5 * min(step ** -0.5, step * warm ** -1.5)


In [None]:
from tqdm.auto import tqdm  # notebook widget inside Jupyter, text bar elsewhere

# Optimizer and criterion.
# The learning rate is overwritten every step by get_lr(), so the optimizer starts at lr=0.
learning_rate = 1e-4  # NOTE(review): unused; the per-step LR comes from get_lr()
optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)

# EncoderDecoder.forward returns the generator output (log-probabilities), hence NLLLoss.
# Padding positions in the targets are excluded from the loss; otherwise the model is
# also trained to predict <pad>. Assumes collate_fn pads `tgt_output` with PAD_TOKEN_ID
# (TODO confirm). Falls back to PyTorch's default ignore_index if there is no pad token.
criterion = nn.NLLLoss(ignore_index=PAD_TOKEN_ID if PAD_TOKEN_ID is not None else -100)

num_epochs = 5
step_num = 0  # global optimizer step, drives the LR schedule across epochs


def compute_batch_loss(batch):
    """Move one collated batch to `device`, run the model and return the mean token NLL."""
    src_tokens = batch['src_tokens'].to(device)
    tgt_input = batch['tgt_input'].to(device)
    tgt_output = batch['tgt_output'].to(device)
    src_mask = batch['src_mask'].to(device)
    tgt_mask = batch['tgt_mask'].to(device)

    log_probs = model(src_tokens, tgt_input, src_mask, tgt_mask)
    # Flatten to (N, vocab) vs (N,) targets; reshape (unlike view) also accepts
    # non-contiguous tensors.
    return criterion(log_probs.reshape(-1, log_probs.size(-1)), tgt_output.reshape(-1))


for epoch in range(num_epochs):
    # --- training ---
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_dl):
        step_num += 1
        lr = get_lr(step_num)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.zero_grad()
        loss = compute_batch_loss(batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / max(len(train_dl), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average training loss: {avg_loss: .4f}")

    # --- validation ---
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_dl):
            total_val_loss += compute_batch_loss(batch).item()

    avg_val_loss = total_val_loss / max(len(val_dl), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average validation loss: {avg_val_loss:.4f}")

In [None]:
# Display the model's repr (its sub-module structure) as the cell output.
model

# Inference

In [None]:
# Sanity check for inference: grab the first few validation examples and show them.
num_examples = 10
examples = [val_ds[idx] for idx in range(num_examples)]
examples

In [None]:
# Ensure the model sits on `device` before inference; `.to` returns the model,
# so the cell also displays it.
model.to(device)

In [None]:
def greedy_decoding(model: EncoderDecoder, src_tokens: torch.Tensor, tokenizer: ByteLevelBPETokenizer = tokenizer, pad_token_id: int = PAD_TOKEN_ID, max_len=100, bos_token_id: int = BOS_TOKEN_ID, eos_token_id: int = EOS_TOKEN_ID):
    """Greedily translate one source sentence and print the source and the translation.

    Runs the encoder once, then repeatedly feeds the growing target prefix through
    the decoder and appends the highest-scoring next token.

    Args:
        model: EncoderDecoder whose embedding, pos_encoder, encoder, decoder and
            generator sub-modules are called directly. It is switched to eval mode.
        src_tokens: source token ids for a single sentence (presumably 1-D; a batch
            dimension is added here). Padding ids are masked out.
        tokenizer: used to decode ids back to text for printing.
        pad_token_id: id treated as padding when building the source mask.
        max_len: maximum number of tokens generated after the BOS token.
        bos_token_id: token the target sequence starts with.
        eos_token_id: generation stops as soon as this token is produced.

    Returns:
        None. The result is printed.
    """
    model.eval()
    # Run on the device the model's parameters live on, so the caller need not
    # keep a separate `device` variable in sync.
    run_device = next(model.parameters()).device

    # No gradients are needed for decoding; avoids building an autograd graph per step.
    with torch.no_grad():
        src_batch = src_tokens.unsqueeze(0).to(run_device)    # (1, src_len)
        src_mask = (src_batch != pad_token_id).unsqueeze(0)   # (1, 1, src_len)

        # Encode the source once; the decoder re-attends to it at every step.
        src_embed = model.pos_encoder(model.embedding(src_batch))
        encoder_output = model.encoder(src_embed, src_mask)

        # Initialize the target sentence with the BOS token.
        tgt_tokens = torch.tensor([bos_token_id], dtype=torch.long, device=run_device)

        for _ in range(max_len):
            tgt_len = tgt_tokens.size(0)
            # Causal (lower-triangular) mask so position i only sees positions <= i.
            # NOTE(review): float 0/1 mask; confirm it matches the dtype collate_fn uses in training.
            tgt_mask = torch.tril(torch.ones(1, tgt_len, tgt_len, device=run_device))

            tgt_embed = model.pos_encoder(model.embedding(tgt_tokens).unsqueeze(0))
            dec_output = model.decoder(tgt_embed, encoder_output, src_mask, tgt_mask)
            log_probs = model.generator(dec_output)

            # Greedy choice at the last position; shape (1,), so it concatenates onto the 1-D prefix.
            next_token = torch.argmax(log_probs[:, -1, :], dim=-1)
            tgt_tokens = torch.cat([tgt_tokens, next_token])

            if next_token.item() == eos_token_id:
                break

    print(f"Source sentence: {tokenizer.decode([num for num in src_batch.squeeze(0).tolist() if num != pad_token_id], skip_special_tokens=True)}")
    print(f"Translation: {tokenizer.decode(tgt_tokens.tolist())}")

    