<a href="https://colab.research.google.com/github/nares10/transformer-implementation/blob/main/transformer_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Adds fixed sinusoidal positional information to token embeddings.

        Args:
            d_model: Size of the embedding (last) dimension. Both even and
                odd sizes are supported.
            dropout: Dropout probability applied after the encodings are added.
            max_len: Number of positions for which encodings are precomputed.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create constant 'pe' matrix with values dependent on
        # position and dimension
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term  # Shape: [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)  # even indices
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angles to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])  # odd indices
        pe = pe.unsqueeze(0)  # Shape: [1, max_len, d_model]
        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, seq_len, d_model]
        Returns:
            Tensor of the same shape with positional encodings added,
            passed through dropout.
        """
        # Slice the table to the sequence length; the leading dimension of
        # size 1 broadcasts over the batch.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [4]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        q, k, v: shape [batch, heads, seq_len, d_k]
        mask: optional tensor broadcastable to the score shape
            [batch, heads, q_len, k_len]. Positions where ``mask == 0`` are
            prevented from receiving attention.
    Returns:
        Tuple ``(output, attention)``: the attention-weighted sum of ``v`` and
        the softmax attention weights.
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the score dtype rather than a
        # fixed -1e9, which is not representable in float16. A finite value
        # (instead of -inf) keeps fully masked rows free of NaNs.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    attention = F.softmax(scores, dim=-1)
    return torch.matmul(attention, v), attention

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Multi-Head Attention mechanism.

        Args:
            d_model: Model (embedding) dimension; must be divisible by num_heads.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights
                during training.
        """
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        # Define linear layers for queries, keys, and values.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        # [batch, seq_len, d_model] -> [batch, heads, seq_len, d_k]
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q, k, v: Tensors of shape [batch_size, seq_len, d_model]; k and v
                share a sequence length, which may differ from that of q.
            mask: Optional mask forwarded to scaled_dot_product_attention.
        Returns:
            Tensor of shape [batch_size, q_seq_len, d_model].
        """
        batch_size = q.size(0)
        # Perform linear projections and split into multiple heads
        q = self._split_heads(self.q_linear(q), batch_size)
        k = self._split_heads(self.k_linear(k), batch_size)
        v = self._split_heads(self.v_linear(v), batch_size)

        # Compute scaled dot-product attention for each head
        x, attn = scaled_dot_product_attention(q, k, v, mask)

        # The dropout module was previously constructed but never applied.
        # Drop attention weights while training; in eval mode (or with p == 0)
        # the result computed above is used unchanged.
        if self.training and self.dropout.p > 0:
            x = torch.matmul(self.dropout(attn), v)

        # Concatenate heads and put through final linear layer
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out(x)

In [6]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """
        Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

        Args:
            d_model: Input and output feature size.
            d_ff: Width of the hidden layer.
            dropout: Dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape [..., d_model] to a tensor of the same shape."""
        hidden = self.linear1(x)       # expand to d_ff
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)    # project back to d_model

In [7]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        """
        One encoder block: multi-head self-attention followed by a
        feed-forward network, each wrapped in dropout, a residual connection
        and layer normalization (applied after the residual sum).
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run x through the block; mask is passed to self-attention."""
        # Sublayer 1: self-attention (queries, keys and values are all x).
        residual = x
        attended = self.dropout1(self.self_attn(residual, residual, residual, mask))
        x = self.norm1(residual + attended)

        # Sublayer 2: position-wise feed-forward network.
        residual = x
        transformed = self.dropout2(self.ff(residual))
        return self.norm2(residual + transformed)

In [8]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        """
        One decoder block made of three sublayers:
          1. Self-attention over the target sequence.
          2. Attention over the encoder output (encoder-decoder attention).
          3. Feed-forward network.
        Each sublayer output goes through dropout, is added to its input and
        then layer-normalized.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    @staticmethod
    def _add_and_norm(x, sublayer_output, drop, norm):
        # Residual connection around a sublayer, followed by layer norm.
        return norm(x + drop(sublayer_output))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target-side input to this layer.
            enc_output: Encoder output used as keys and values in sublayer 2.
            src_mask: Mask for the encoder-decoder attention.
            tgt_mask: Mask for the target self-attention.
        """
        # 1) Target self-attention.
        x = self._add_and_norm(
            x, self.self_attn(x, x, x, tgt_mask), self.dropout1, self.norm1)
        # 2) Queries come from the decoder, keys/values from the encoder.
        x = self._add_and_norm(
            x, self.enc_dec_attn(x, enc_output, enc_output, src_mask),
            self.dropout2, self.norm2)
        # 3) Feed-forward network.
        return self._add_and_norm(x, self.ff(x), self.dropout3, self.norm3)

In [9]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1, max_len=5000):
        """
        Encoder-decoder Transformer.

        Source and target tokens get separate embedding tables but share one
        PositionalEncoding module. A final linear layer maps decoder states
        to target-vocabulary logits.
        """
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout, max_len)

        encoder_stack = [
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ]
        self.encoder_layers = nn.ModuleList(encoder_stack)

        decoder_stack = [
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ]
        self.decoder_layers = nn.ModuleList(decoder_stack)

        self.output_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: Source token ids, [batch_size, src_seq_len]
            tgt: Target token ids, [batch_size, tgt_seq_len]
            src_mask, tgt_mask: Optional attention masks. They are forwarded
                as given; no mask is constructed here when they are None.
        Returns:
            Logits of shape [batch_size, tgt_seq_len, tgt_vocab_size]
        """
        # Token embeddings plus positional information.
        memory = self.positional_encoding(self.src_embedding(src))
        decoded = self.positional_encoding(self.tgt_embedding(tgt))

        # Encoder stack: each layer consumes the previous layer's output.
        for encoder_layer in self.encoder_layers:
            memory = encoder_layer(memory, src_mask)

        # Decoder stack: every layer attends to the final encoder output.
        for decoder_layer in self.decoder_layers:
            decoded = decoder_layer(decoded, memory, src_mask, tgt_mask)

        # Project decoder states to vocabulary logits (no softmax applied).
        return self.output_layer(decoded)

In [10]:
# Toy dimensions for a quick smoke test of the model.
batch_size = 2
src_seq_len = tgt_seq_len = 10
src_vocab_size = tgt_vocab_size = 1000


In [11]:
# Build the model; every other hyperparameter keeps its default value.
model = Transformer(src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size)

In [12]:
# Random token ids in [0, vocab_size) stand in for real source/target batches.
src = torch.randint(low=0, high=src_vocab_size, size=(batch_size, src_seq_len))
tgt = torch.randint(low=0, high=tgt_vocab_size, size=(batch_size, tgt_seq_len))

In [13]:
# Forward pass without attention masks.
output = model(src=src, tgt=tgt)

In [14]:
# Expected: [batch_size, tgt_seq_len, tgt_vocab_size]
print("Output shape:", output.size())

Output shape: torch.Size([2, 10, 1000])
