# Setup

To set up an Anaconda environment for implementing the Transformer model in PyTorch, follow these steps:

---

### **1. Create a New Conda Environment**
Open a terminal and run:
```bash
conda create --name attention-is-all-you-need python=3.12
```

---

### **2. Activate the Environment**
```bash
conda activate attention-is-all-you-need
```

---

### **3. Install PyTorch**
For GPU (CUDA):
```bash
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
```
For CPU (if you don't have a compatible GPU):
```bash
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
Check if PyTorch is installed correctly:
```bash
python -c "import torch; print(torch.__version__)"
```

---

### **4. Install Essential Libraries**
```bash
pip install numpy pandas matplotlib tqdm
```
- `numpy`: Numerical array operations
- `pandas`: Data handling (optional, useful for datasets)
- `matplotlib`: Visualization
- `tqdm`: Progress bars for training

---

### **5. Install NLP Libraries (If Needed)**
```bash
pip install transformers datasets tokenizers sentencepiece
```
- `transformers`: Pretrained models from Hugging Face (optional)
- `datasets`: NLP datasets from Hugging Face
- `tokenizers`: Efficient tokenization
- `sentencepiece`: Subword tokenization (BPE/unigram); the original Transformer paper used byte-pair and word-piece vocabularies

---

### **6. Install Jupyter Notebook (Optional)**
If you want to develop in Jupyter:
```bash
conda install jupyter
```
Then start Jupyter:
```bash
jupyter notebook
```

---

### **7. Verify Everything**
Run the following to ensure your environment is properly set up:
```python
import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```

---

### **8. Save the Environment (Optional)**
To export your environment for reproducibility:
```bash
conda env export > environment.yml
```
To recreate it later:
```bash
conda env create -f environment.yml
```

---

# Start

In [1]:
import torch
# Sanity check: report the installed PyTorch version and whether CUDA is usable.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

PyTorch version: 2.5.1
CUDA available: False


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional, Tuple

In [3]:
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    """Lookup table that turns integer token ids into dense vectors."""

    def __init__(self, vocab_size: int, d_model: int):
        """
        Build the embedding table.

        Args:
            vocab_size (int): How many distinct token ids the table holds.
            d_model (int): Width of each embedding vector.
        """
        super().__init__()
        # One learnable row of width d_model per token id.
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Look up the embedding of every token id in ``x``.

        Args:
            x (torch.Tensor): Integer token ids, e.g. shape (batch_size, seq_len).

        Returns:
            torch.Tensor: The shape of ``x`` with a trailing d_model axis,
            e.g. (batch_size, seq_len, d_model).
        """
        return self.embedding(x)


In [4]:
def run_tests():
    """Smoke-test TokenEmbedding: output shape, dtype, lookup consistency and gradient flow."""
    # Fixture sizes
    vocab_size, d_model = 100, 16
    batch_size, seq_len = 4, 10

    # Random token ids first, then the layer under test
    token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    layer = TokenEmbedding(vocab_size, d_model)

    # Test 1: the embedding appends a d_model axis to the input shape
    embedded = layer(token_ids)
    expected_shape = (batch_size, seq_len, d_model)
    assert embedded.shape == expected_shape, f"Unexpected shape: {embedded.shape}"

    # Test 2: the result is a float32 tensor
    assert isinstance(embedded, torch.Tensor), "Output is not a tensor"
    assert embedded.dtype == torch.float32, f"Unexpected dtype: {embedded.dtype}"

    # Test 3: looking up the same id twice gives the same vector
    same_id = torch.tensor([[5]])
    first_lookup, second_lookup = layer(same_id), layer(same_id)
    assert torch.allclose(first_lookup, second_lookup), "Embeddings should be identical for the same index"

    # Test 4: two different ids give different vectors
    vec_5 = layer(torch.tensor([[5]]))
    vec_8 = layer(torch.tensor([[8]]))
    assert not torch.allclose(vec_5, vec_8), "Different indices should have different embeddings"

    # Test 5: backprop populates a gradient covering the whole table
    embedded.sum().backward()
    weight_grad = layer.embedding.weight.grad
    assert weight_grad is not None, "Gradients should not be None"
    assert weight_grad.shape == (vocab_size, d_model), "Gradient shape mismatch"

    print("âœ… All tests passed successfully!")

# Run all tests
run_tests()


âœ… All tests passed successfully!


In [5]:
# Quick manual check: a 0-dim index tensor yields a single embedding vector of length d_model.
embedding_layer = TokenEmbedding(vocab_size=10, d_model=3)
embedding_layer(torch.tensor(5))

tensor([ 0.7030,  0.0881, -0.1770], grad_fn=<EmbeddingBackward0>)

In [6]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding from "Attention Is All You Need".

    For position ``pos`` and dimension pair index ``i``:

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The table is precomputed once and stored as a non-trainable buffer.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Precomputes the positional encoding table.

        Args:
            d_model (int): Dimension of the embedding vectors.
            max_len (int): Maximum sequence length.
        """
        super().__init__()

        # Table of shape (max_len, d_model), filled in below.
        pe = torch.zeros(max_len, d_model)

        # Column vector of positions, shape (max_len, 1).
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)

        # 1 / 10000 ** (2i / d_model) for every even dimension 2i, computed in
        # log space so it stays finite. Note that `10**4**x` would parse as
        # `10 ** (4 ** x)` and must not be used here.
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))

        # Broadcast to (max_len, ceil(d_model / 2)).
        angles = positions * inv_freq

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: moves with the module (.to/.cuda) and is saved in state_dict,
        # but is never updated by the optimizer.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encoding to the input embeddings.

        Args:
            x (torch.Tensor): Tensor of shape (batch_size, seq_len, d_model) containing input embeddings.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, seq_len, d_model) with positional encodings added.

        Raises:
            ValueError: If seq_len exceeds the max_len the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )

        # (1, seq_len, d_model) so it broadcasts over the batch axis.
        # `.to` returns a new tensor, so the result must be kept.
        pe_slice = self.pe[:seq_len, :].unsqueeze(0).to(x.device)

        return x + pe_slice

In [7]:
def run_positional_encoding_tests():
    """Smoke-test PositionalEncoding: shape, dtype, non-trivial output, device and determinism."""
    d_model, seq_len, batch_size = 16, 10, 4

    # All-zero stand-in for embeddings, plus the module under test
    zeros_in = torch.zeros((batch_size, seq_len, d_model))
    encoder = PositionalEncoding(d_model=d_model)

    # Test 1: the output keeps the input shape
    encoded = encoder(zeros_in)
    wanted_shape = (batch_size, seq_len, d_model)
    assert encoded.shape == wanted_shape, f"Unexpected shape: {encoded.shape}"

    # Test 2: the result is a float32 tensor
    assert isinstance(encoded, torch.Tensor), "Output is not a tensor"
    assert encoded.dtype == torch.float32, f"Unexpected dtype: {encoded.dtype}"

    # Test 3: something was actually added to the zeros
    assert not torch.allclose(zeros_in, encoded), "Positional encoding is not being added!"

    # Test 4: module and input moved to the same device give output on that device
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    zeros_in = zeros_in.to(target)
    encoder = encoder.to(target)
    encoded = encoder(zeros_in)
    assert encoded.device == zeros_in.device, f"Device mismatch: {encoded.device} vs {zeros_in.device}"

    # Test 5: two passes over the same input agree
    first_pass = encoder(zeros_in)
    second_pass = encoder(zeros_in)
    assert torch.allclose(first_pass, second_pass), "Positional encoding should be deterministic!"

    print("âœ… All positional encoding tests passed successfully!")

# Run all tests
run_positional_encoding_tests()


âœ… All positional encoding tests passed successfully!


In [8]:
def scratchboard(max_len, d_model):
    """
    Print the term ``10 ** (4 ** (2 * pos / d_model))`` for every position.

    Args:
        max_len: Number of positions (0 .. max_len - 1) to evaluate.
        d_model: Divisor applied to ``2 * pos`` in the exponent.
    """
    # Column vector of positions, shape (max_len, 1).
    positions = torch.arange(max_len).unsqueeze(1)
    # NOTE: `**` is right-associative, so the unparenthesised `10**4**x`
    # means 10 ** (4 ** x), not 10000 ** x. The parentheses only make that
    # explicit; the values grow double-exponentially with position.
    div_term = 10 ** (4 ** (2 * positions / d_model))
    print(div_term)

scratchboard(13, 10)

tensor([[1.0000e+01],
        [2.0869e+01],
        [5.5094e+01],
        [1.9833e+02],
        [1.0751e+03],
        [1.0000e+04],
        [1.8968e+05],
        [9.2131e+06],
        [1.5473e+09],
        [1.3358e+12],
        [1.0000e+16],
        [1.2946e+21],
        [7.2048e+27]])


In [9]:
import torch
import torch.nn as nn
from typing import Optional, Tuple

class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V with optional masking."""

    def __init__(self, d_k: int):
        """
        Initializes scaled dot-product attention.

        Args:
            d_k (int): Dimension of the key vectors (used for scaling).
        """
        super().__init__()
        self.d_k = d_k

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the scaled dot-product attention.

        Args:
            query (torch.Tensor): Shape (batch_size, num_heads, seq_len, d_k)
            key (torch.Tensor): Shape (batch_size, num_heads, seq_len, d_k)
            value (torch.Tensor): Shape (batch_size, num_heads, seq_len, d_v)
            mask (Optional[torch.Tensor]): Broadcastable to the score shape,
                e.g. (batch_size, 1, seq_len, seq_len). Positions where
                ``mask.bool()`` is True (any non-zero value, including
                ``-inf``) are blocked; zero / False positions are attended to.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Attention output of shape (batch_size, num_heads, seq_len, d_v)
                - Attention weights of shape (batch_size, num_heads, seq_len, seq_len)

            A query row whose keys are all blocked gets all-zero weights and
            an all-zero output instead of NaN.
        """
        # Raw scores: QK^T / sqrt(d_k).
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        fully_masked = None
        if mask is not None:
            blocked = mask.bool()
            # Out-of-place fill keeps the unmasked scores intact for autograd.
            attention_scores = attention_scores.masked_fill(blocked, float('-inf'))
            # A row of only -inf would make softmax return NaN (and poison
            # gradients). Give such rows finite scores and zero them after.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            attention_scores = attention_scores.masked_fill(fully_masked, 0.0)

        attention_weights = torch.softmax(attention_scores, dim=-1)
        if fully_masked is not None:
            attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

        # Weighted sum of the values.
        output = torch.matmul(attention_weights, value)

        return output, attention_weights


In [10]:
import torch

def test_scaled_dot_product_attention():
    """
    Exercise ScaledDotProductAttention: output shapes, softmax normalisation,
    determinism and masking.
    """
    batch_size, num_heads, seq_len = 2, 4, 5
    d_k = d_v = 8

    # Random inputs, drawn in the order query, key, value
    q = torch.randn(batch_size, num_heads, seq_len, d_k)
    k = torch.randn(batch_size, num_heads, seq_len, d_k)
    v = torch.randn(batch_size, num_heads, seq_len, d_v)

    # Module under test
    sdpa = ScaledDotProductAttention(d_k)

    # Unmasked forward pass
    context, weights = sdpa(q, k, v, mask=None)

    # Test 1: context has one d_v vector per query position
    expected_context_shape = (batch_size, num_heads, seq_len, d_v)
    assert context.shape == expected_context_shape, \
        f"Unexpected output shape: {context.shape}"
    print("âœ… Output shape test passed!")

    # Test 2: weights form a seq_len x seq_len grid per head
    expected_weights_shape = (batch_size, num_heads, seq_len, seq_len)
    assert weights.shape == expected_weights_shape, \
        f"Unexpected attention weights shape: {weights.shape}"
    print("âœ… Attention weights shape test passed!")

    # Test 3: every row of weights sums to ~1
    row_totals = weights.sum(dim=-1)
    assert torch.allclose(row_totals, torch.ones_like(row_totals), atol=1e-5), \
        "Softmax output does not sum to 1"
    print("âœ… Softmax test passed!")

    # Test 4: a second pass over the same inputs matches the first
    context_again, weights_again = sdpa(q, k, v, mask=None)
    assert torch.allclose(context, context_again), "Output should be deterministic!"
    assert torch.allclose(weights, weights_again), "Attention weights should be deterministic!"
    print("âœ… Deterministic output test passed!")

    # Test 5: block the last key position and check that it gets ~0 weight
    last_key_mask = torch.zeros(batch_size, 1, seq_len, seq_len)
    last_key_mask[:, :, :, -1] = float('-inf')

    context_masked, weights_masked = sdpa(q, k, v, mask=last_key_mask)

    last_column = weights_masked[:, :, :, -1]
    assert torch.all(last_column < 1e-3), \
        "Masking is not applied correctly!"
    print("âœ… Masking test passed!")

    print("ðŸŽ‰ All ScaledDotProductAttention tests passed!")

# Run the test
test_scaled_dot_product_attention()

âœ… Output shape test passed!
âœ… Attention weights shape test passed!
âœ… Softmax test passed!
âœ… Deterministic output test passed!
âœ… Masking test passed!
ðŸŽ‰ All ScaledDotProductAttention tests passed!


In [11]:
import torch
import torch.nn as nn
from typing import Optional, Tuple

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: project Q/K/V, attend per head, merge and project.

    Query and key/value sequences may have different lengths, as in
    encoder-decoder (cross) attention.
    """

    def __init__(self, d_model: int, num_heads: int):
        """
        Initializes multi-head attention.

        Args:
            d_model (int): Dimension of the model (input and output size).
            num_heads (int): Number of attention heads.
        """
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Input projections for query, key and value.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are merged.
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, length, d_model) into (batch, num_heads, length, d_k)."""
        batch_size = projected.size(0)
        # -1 lets each tensor keep its own sequence length.
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Computes multi-head attention.

        Args:
            query (torch.Tensor): Shape (batch_size, q_len, d_model)
            key (torch.Tensor): Shape (batch_size, kv_len, d_model)
            value (torch.Tensor): Shape (batch_size, kv_len, d_model)
            mask (Optional[torch.Tensor]): Passed unchanged to the attention
                module; e.g. shape (batch_size, 1, q_len, kv_len).

        Returns:
            torch.Tensor: Shape (batch_size, q_len, d_model) - Multi-head attention output.
        """
        batch_size = query.size(0)

        # Project, then split the model dimension into heads.
        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # Per-head attention; the weights are not needed here.
        output, _ = self.attention(Q, K, V, mask)

        # (batch, heads, q_len, d_k) -> (batch, q_len, heads, d_k) -> (batch, q_len, d_model).
        # contiguous() is required because view() cannot work on a transposed layout.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.W_o(output)


In [12]:
def test_multi_head_attention():
    """
    Exercise MultiHeadAttention: output shape, type, determinism and a masked pass.
    """
    batch_size, seq_len = 2, 5
    d_model, num_heads = 16, 4

    # Random inputs, drawn in the order query, key, value
    q = torch.randn(batch_size, seq_len, d_model)
    k = torch.randn(batch_size, seq_len, d_model)
    v = torch.randn(batch_size, seq_len, d_model)

    # Module under test
    attention_block = MultiHeadAttention(d_model, num_heads)

    # Unmasked forward pass
    result = attention_block(q, k, v, mask=None)
    expected_shape = (batch_size, seq_len, d_model)

    # Test 1: the output has the same shape as the query
    assert result.shape == expected_shape, \
        f"Unexpected output shape: {result.shape}"
    print("âœ… Output shape test passed!")

    # Test 2: the output is a tensor
    assert isinstance(result, torch.Tensor), "Output is not a tensor"
    print("âœ… Tensor type test passed!")

    # Test 3: a second pass over the same inputs matches the first
    result_again = attention_block(q, k, v, mask=None)
    assert torch.allclose(result, result_again), "Output should be deterministic!"
    print("âœ… Deterministic output test passed!")

    # Test 4: a pass with the last key position blocked keeps the shape
    last_key_mask = torch.zeros(batch_size, 1, seq_len, seq_len)
    last_key_mask[:, :, :, -1] = float('-inf')

    result_masked = attention_block(q, k, v, mask=last_key_mask)

    assert result_masked.shape == expected_shape, \
        "Masked output shape mismatch"
    print("âœ… Masking test passed!")

    print("ðŸŽ‰ All MultiHeadAttention tests passed!")

# Run the test
test_multi_head_attention()

âœ… Output shape test passed!
âœ… Tensor type test passed!
âœ… Deterministic output test passed!
âœ… Masking test passed!
ðŸŽ‰ All MultiHeadAttention tests passed!


In [13]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied along the last axis."""

    def __init__(self, d_model: int, d_ff: int):
        """
        Args:
            d_model: Input and output feature size.
            d_ff: Hidden feature size between the two linear layers.
        """
        super().__init__()
        # Expand to d_ff, apply the non-linearity, project back to d_model.
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model)

        Returns:
            (batch_size, seq_len, d_model) - Transformed representations.
        """
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [14]:
def test_positionwise_feedforward():
    """
    Exercise PositionwiseFeedForward: shape, type, ReLU sign pattern,
    determinism and gradient flow.
    """
    batch_size, seq_len = 2, 5
    d_model = 16
    d_ff = 32  # hidden width

    # Random input, then the module under test
    features = torch.randn(batch_size, seq_len, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff)

    # Forward pass
    result = feed_forward(features)

    # Test 1: the output has the same shape as the input
    expected_shape = (batch_size, seq_len, d_model)
    assert result.shape == expected_shape, \
        f"Unexpected output shape: {result.shape}"
    print("âœ… Output shape test passed!")

    # Test 2: the output is a tensor
    assert isinstance(result, torch.Tensor), "Output is not a tensor"
    print("âœ… Tensor type test passed!")

    # Test 3: ReLU keeps exactly the positive entries of the first layer positive
    pre_activation = feed_forward.fc1(features)
    positive_before = pre_activation > 0
    positive_after = feed_forward.relu(pre_activation) > 0
    assert torch.all(positive_before == positive_after), \
        "ReLU activation is not applied correctly"
    print("âœ… ReLU activation test passed!")

    # Test 4: a second pass over the same input matches the first
    result_again = feed_forward(features)
    assert torch.allclose(result, result_again), "Output should be deterministic!"
    print("âœ… Deterministic output test passed!")

    # Test 5: backprop from the first result reaches both linear layers
    result.sum().backward()
    assert feed_forward.fc1.weight.grad is not None, "Gradients are not computed for fc1!"
    assert feed_forward.fc2.weight.grad is not None, "Gradients are not computed for fc2!"
    print("âœ… Gradient computation test passed!")

    print("ðŸŽ‰ All PositionwiseFeedForward tests passed!")

# Run the test
test_positionwise_feedforward()


âœ… Output shape test passed!
âœ… Tensor type test passed!
âœ… ReLU activation test passed!
âœ… Deterministic output test passed!
âœ… Gradient computation test passed!
ðŸŽ‰ All PositionwiseFeedForward tests passed!


In [15]:
class TransformerEncoderLayer(nn.Module):
    """Placeholder for a single encoder layer; the forward pass is not implemented yet."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float):
        # The arguments are accepted for the planned interface but not used yet.
        super().__init__()

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model) - Input to encoder layer.
            mask: (batch_size, 1, seq_len, seq_len) - Optional attention mask.

        Returns:
            (batch_size, seq_len, d_model) - Encoder layer output.

        Raises:
            NotImplementedError: Always, until the layer is implemented.
        """
        # Fail loudly instead of silently returning None where a tensor is promised.
        raise NotImplementedError("TransformerEncoderLayer.forward is not implemented yet")

In [16]:
class TransformerDecoderLayer(nn.Module):
    """Placeholder for a single decoder layer; the forward pass is not implemented yet."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float):
        # The arguments are accepted for the planned interface but not used yet.
        super().__init__()

    def forward(self, x: torch.Tensor, memory: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model) - Input to decoder layer.
            memory: (batch_size, seq_len, d_model) - Encoder output.
            src_mask: (batch_size, 1, seq_len, seq_len) - Optional encoder mask.
            tgt_mask: (batch_size, 1, seq_len, seq_len) - Optional decoder mask.

        Returns:
            (batch_size, seq_len, d_model) - Decoder layer output.

        Raises:
            NotImplementedError: Always, until the layer is implemented.
        """
        # Fail loudly instead of silently returning None where a tensor is promised.
        raise NotImplementedError("TransformerDecoderLayer.forward is not implemented yet")


class TransformerEncoder(nn.Module):
    """Placeholder for the encoder stack; the forward pass is not implemented yet."""

    def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int, dropout: float):
        # The arguments are accepted for the planned interface but not used yet.
        super().__init__()

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model) - Input sequence.
            mask: (batch_size, 1, seq_len, seq_len) - Optional mask.

        Returns:
            (batch_size, seq_len, d_model) - Encoder output.

        Raises:
            NotImplementedError: Always, until the encoder is implemented.
        """
        # Fail loudly instead of silently returning None where a tensor is promised.
        raise NotImplementedError("TransformerEncoder.forward is not implemented yet")


class TransformerDecoder(nn.Module):
    """Placeholder for the decoder stack; the forward pass is not implemented yet."""

    def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int, dropout: float):
        # The arguments are accepted for the planned interface but not used yet.
        super().__init__()

    def forward(self, x: torch.Tensor, memory: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model) - Target sequence.
            memory: (batch_size, seq_len, d_model) - Encoder output.
            src_mask: (batch_size, 1, seq_len, seq_len) - Optional encoder mask.
            tgt_mask: (batch_size, 1, seq_len, seq_len) - Optional decoder mask.

        Returns:
            (batch_size, seq_len, d_model) - Decoder output.

        Raises:
            NotImplementedError: Always, until the decoder is implemented.
        """
        # Fail loudly instead of silently returning None where a tensor is promised.
        raise NotImplementedError("TransformerDecoder.forward is not implemented yet")


class Transformer(nn.Module):
    """Placeholder for the full encoder-decoder model; the forward pass is not implemented yet."""

    def __init__(self, vocab_size: int, d_model: int, num_layers: int, num_heads: int,
                 d_ff: int, dropout: float):
        # The arguments are accepted for the planned interface but not used yet.
        super().__init__()

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            src: (batch_size, src_seq_len) - Source token indices.
            tgt: (batch_size, tgt_seq_len) - Target token indices.
            src_mask: (batch_size, 1, src_seq_len, src_seq_len) - Optional source mask.
            tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len) - Optional target mask.

        Returns:
            (batch_size, tgt_seq_len, vocab_size) - Token probabilities.

        Raises:
            NotImplementedError: Always, until the model is implemented.
        """
        # Fail loudly instead of silently returning None where a tensor is promised.
        raise NotImplementedError("Transformer.forward is not implemented yet")


class TransformerTrainer:
    """Placeholder training harness; the training and evaluation steps are not implemented yet."""

    def __init__(self, model: Transformer, learning_rate: float, weight_decay: float):
        """
        Intended to set up the optimizer and loss function for training.

        Currently a no-op: none of the arguments are stored or used.
        """
        pass

    def train_step(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Runs a single training step.

        Args:
            src: (batch_size, src_seq_len) - Source sequence.
            tgt: (batch_size, tgt_seq_len) - Target sequence.

        Returns:
            Loss value.

        Raises:
            NotImplementedError: Always, until the step is implemented.
        """
        # Fail loudly instead of silently returning None where a loss is promised.
        raise NotImplementedError("TransformerTrainer.train_step is not implemented yet")

    def evaluate(self, src: torch.Tensor, tgt: torch.Tensor) -> float:
        """
        Evaluates the model on a validation set.

        Args:
            src: (batch_size, src_seq_len) - Source sequence.
            tgt: (batch_size, tgt_seq_len) - Target sequence.

        Returns:
            BLEU score or another evaluation metric.

        Raises:
            NotImplementedError: Always, until evaluation is implemented.
        """
        # Fail loudly instead of silently returning None where a metric is promised.
        raise NotImplementedError("TransformerTrainer.evaluate is not implemented yet")