<a href="https://colab.research.google.com/github/ochaudha/sample/blob/main/TransformerEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import Dataset, DataLoader

# --- Positional Encoding Class ---
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding ("Attention Is All You Need", Sec. 3.5).

    A fixed table ``pe`` of shape (max_len, 1, d_model) is precomputed, where
    even feature columns hold sin(pos / 10000^(2i/d_model)) and odd columns hold
    the matching cos. ``forward`` adds the first ``seq_len`` rows to a
    sequence-first input, broadcasting over the batch dimension.

    Args:
        d_model: Feature width of the embeddings (odd values are supported).
        max_len: Longest sequence length the table covers.
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: 1 / 10000^(2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than frequencies,
        # so trim div_term to the number of odd columns (d_model // 2).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, 1, d_model): sequence-first, broadcastable across the batch.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)  # buffer: saved/moved with the module, not trained

    def forward(self, x):
        """
        Args:
            x: Sequence-first tensor (sequence_length, batch_size, d_model).
        Returns:
            ``x`` plus the positional encodings for its first
            ``sequence_length`` positions (same shape as ``x``).
        Raises:
            ValueError: if ``sequence_length`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            # Without this check a short table could broadcast silently
            # (e.g. max_len == 1) instead of failing.
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

# --- Transformer Encoder Layer Class ---
class TransformerEncoderLayer(nn.Module):
    """
    One post-norm Transformer encoder layer.

    Sub-layer 1: multi-head self-attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: Linear -> ReLU -> dropout -> Linear -> dropout -> residual add
    -> LayerNorm.

    Tensors are sequence-first (sequence_length, batch_size, d_model), matching
    ``batch_first=False`` on the attention module.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        # Parameterised layers are created in this order (attention, linear1,
        # linear2); reordering them would change random initialisation for a
        # fixed seed.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)  # inside the feed-forward network
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        # Dropout applied to each sub-layer's output before the residual add.
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs):
        """
        Args:
            src: Layer input (sequence_length, batch_size, d_model).
            src_mask: Optional attention mask forwarded as ``attn_mask``.
            src_key_padding_mask: Optional (batch_size, sequence_length) mask;
                True marks key positions attention must ignore.
            **kwargs: Accepted and ignored. Recent PyTorch versions of
                nn.TransformerEncoder pass extra hints such as ``is_causal``
                to every layer. Causal structure must come from ``src_mask``.
        Returns:
            Tensor of the same shape as ``src``.
        """
        src = self.norm1(src + self._self_attention_block(src, src_mask, src_key_padding_mask))
        return self.norm2(src + self._feed_forward_block(src))

    def _self_attention_block(self, x, attn_mask, key_padding_mask):
        """Self-attention with query = key = value = x, then residual dropout."""
        attended = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return self.dropout1(attended)

    def _feed_forward_block(self, x):
        """Position-wise Linear -> ReLU -> dropout -> Linear, then residual dropout."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))

# --- Transformer Encoder Model Class ---
class TransformerEncoder(nn.Module):
    """
    Encoder-only Transformer that outputs vocabulary logits per position.

    Pipeline: token embedding (scaled by sqrt(d_model)) -> sinusoidal positional
    encoding -> dropout -> stack of ``num_encoder_layers`` TransformerEncoderLayers
    -> linear projection to ``vocab_size`` logits. All tensors are sequence-first.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        d_model: Embedding / model width.
        nhead: Attention heads per layer.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability, used inside every layer and on the
            embedding + positional-encoding sum.
        max_len: Longest sequence the positional-encoding table supports.
    """
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, dropout,
                 max_len=5000):
        super(TransformerEncoder, self).__init__()

        # Token embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding table covering up to max_len positions
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # The paper applies dropout to the sum of embeddings and positional
        # encodings before the first encoder layer.
        self.dropout = nn.Dropout(dropout)
        # Template layer; nn.TransformerEncoder clones it num_encoder_layers times.
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        # The nested-tensor fast path only applies to torch.nn.TransformerEncoderLayer.
        # With this custom layer, PyTorch would just warn and fall back, so it is
        # disabled explicitly.
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_encoder_layers, enable_nested_tensor=False
        )
        self.d_model = d_model

        # Linear layer mapping d_model features to vocabulary logits
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Args:
            src: Token indices (sequence_length, batch_size).
            src_mask: Optional (sequence_length, sequence_length) attention mask.
                Pass a causal mask for next-token prediction; without it, every
                position attends to the whole sequence.
            src_key_padding_mask: Optional (batch_size, sequence_length) mask;
                True marks padding keys attention must ignore.
        Returns:
            Logits of shape (sequence_length, batch_size, vocab_size).
        """
        # 1. Embed tokens and scale by sqrt(d_model), as in the paper
        src = self.embedding(src) * math.sqrt(self.d_model)
        # 2. Add positional encodings, then apply embedding dropout
        src = self.dropout(self.pos_encoder(src))
        # 3. Run the encoder stack
        output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        # 4. Project to vocabulary logits
        return self.output_layer(output)

# --- Dummy Data Setup ---
# Toy vocabulary; "<pad>" sits at index 0 and marks padding positions.
vocab = ["<pad>", "hello", "world", "transformer", "encoder", "pytorch", "<sos>", "<eos>"]
vocab_size = len(vocab)
# Lookup tables between tokens and integer ids, in both directions.
word_to_idx = {token: position for position, token in enumerate(vocab)}
idx_to_word = dict(enumerate(vocab))
pad_idx = word_to_idx["<pad>"]

# Toy sentences, each already wrapped in <sos> ... <eos>.
dummy_data = [
    ["<sos>", "hello", "world", "<eos>"],
    ["<sos>", "transformer", "encoder", "pytorch", "<eos>"],
    ["<sos>", "hello", "encoder", "<eos>"],
]

# Every sentence is right-padded to the longest one.
max_len = max(map(len, dummy_data))
tokenized_data = [
    [word_to_idx[token] for token in sentence] + [pad_idx] * (max_len - len(sentence))
    for sentence in dummy_data
]

# Batch-major ids [batch_size, sequence_length]; the model consumes the
# sequence-first view [sequence_length, batch_size].
batch_major = torch.tensor(tokenized_data, dtype=torch.long)
input_tensor = batch_major.t()

# Key padding mask [batch_size, sequence_length]: True where attention must
# ignore the key because it is padding.
src_key_padding_mask = batch_major.eq(pad_idx)

# --- Model Instantiation ---
# Deliberately tiny hyperparameters for a toy run.
d_model, nhead = 64, 4      # embedding width; attention heads (nn.MultiheadAttention needs d_model % nhead == 0)
num_encoder_layers = 2      # depth of the encoder stack
dim_feedforward = 128       # hidden width of each position-wise feed-forward network
dropout = 0.1               # dropout probability used inside the model

model = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
)

# --- Training (Example with a dummy next-token prediction task) ---
# Targets are the inputs shifted left by one position. The final position has
# no successor, so it is filled with <pad>, which the loss ignores.
target_tensor = torch.roll(input_tensor, shifts=-1, dims=0)
target_tensor[-1, :] = pad_idx

# The encoder is bidirectional by default. For next-token prediction, position t
# must not attend to positions > t; otherwise it can copy its target straight
# from position t + 1. In this [seq_len, seq_len] boolean mask, True blocks
# attention to that key.
seq_len = input_tensor.size(0)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Cross-entropy over the vocabulary; padding targets do not contribute.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
print("\n--- Starting Training ---")
for epoch in range(num_epochs):
    model.train()  # enable dropout
    optimizer.zero_grad()

    # Logits: [sequence_length, batch_size, vocab_size]
    output = model(input_tensor, src_mask=causal_mask, src_key_padding_mask=src_key_padding_mask)

    # Flatten to [sequence_length * batch_size, vocab_size] and
    # [sequence_length * batch_size] for CrossEntropyLoss.
    output = output.reshape(-1, vocab_size)
    target = target_tensor.reshape(-1)

    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training finished.")

# --- Inference (Example) ---
model.eval()  # disable dropout
print("\n--- Starting Inference Example ---")
with torch.no_grad():
    # First sequence of the batch, kept 2-D: [sequence_length, 1]
    sample_input = input_tensor[:, 0].unsqueeze(1)
    # Matching key padding mask: [1, sequence_length]
    sample_mask = src_key_padding_mask[0].unsqueeze(0)

    # Reuse the training causal mask so that each position sees only its prefix,
    # exactly as during training.
    output_logits = model(sample_input, src_mask=causal_mask, src_key_padding_mask=sample_mask)

    print("\nInference Output Logits (for the first sequence):")
    print(f"Shape of output_logits: {output_logits.shape}")

    # Greedy prediction of the next token at every position.
    predicted_token_ids = torch.argmax(output_logits, dim=-1).squeeze(1)
    predicted_words = [idx_to_word[idx.item()] for idx in predicted_token_ids]

    original_sequence = [idx_to_word[idx.item()] for idx in input_tensor[:, 0]]
    print(f"Original Sequence (first batch item): {original_sequence}")
    print(f"Predicted Tokens (for the first sequence): {predicted_words}")

    # Contextual embeddings before the vocabulary projection are available from
    # model.transformer_encoder if needed. output_logits is the model's final output.


--- Starting Training ---
Epoch [1/10], Loss: 2.0706
Epoch [2/10], Loss: 1.8092
Epoch [3/10], Loss: 1.5785
Epoch [4/10], Loss: 1.2701
Epoch [5/10], Loss: 1.1294
Epoch [6/10], Loss: 0.9963
Epoch [7/10], Loss: 0.8781
Epoch [8/10], Loss: 0.7436
Epoch [9/10], Loss: 0.6703
Epoch [10/10], Loss: 0.6024
Training finished.

--- Starting Inference Example ---

Inference Output Logits (for the first sequence):
Shape of output_logits: torch.Size([5, 1, 8])
Original Sequence (first batch item): ['<sos>', 'hello', 'world', '<eos>', '<pad>']
Predicted Tokens (for the first sequence): ['hello', 'world', '<eos>', 'world', 'world']
