In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Attention(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = embed_dim
        self.n_heads = n_heads
        # splitted embed_dim for each head
        self.d_k = embed_dim // n_heads
        # avoid overfitting
        self.dropout = nn.Dropout(dropout)

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)    

        self.o = nn.Linear(embed_dim, embed_dim)

        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        # q, k, v shape: (batch, seq_len, d_model)
        batch_size = query.size(0)

        # q, k, v shape: (batch, seq_len, d_model)
        q = self.q(query)
        k = self.k(key)
        v = self.v(value)

        # -1 should be the embed_dim
        # q, k, v shape: (batch, n_heads, seq_len, d_k)
        q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # combine query and key info
        # scores shape: (batch, n_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # mask is important to avoid using padding token
        # for encoder 如果不 mask，padding token 会干扰注意力分数，导致模型学到噪声
        # for decoder 如果不 mask，当前位置会“偷看”后面的真实词 → 信息泄露 → 模型学不到真正预测未来的能力

        # scores shape: (batch, n_heads, seq_len, seq_len) but the value is -1e9 when mask is 0, or first seq_len < second seq_len
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # normalize the last dimension as embed_dim, after softmax, the value is between 0 and 1. and -1e9 will be almost 0
        scores = F.softmax(scores, dim=-1)
        # avoid overfitting
        scores = self.dropout(scores)

        # scores shape: (batch, n_heads, seq_len, seq_len)
        # v shape: (batch, n_heads, seq_len, d_k)
        context = torch.matmul(scores, v) # (batch, heads, seq, d_k) or scores @ v
        # transpose to (batch, seq, heads, d_k), then combine heads and d_k
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # Fixed: self.embed_dim -> self.d_model
        # output shape: (batch, seq_len, embed_dim)
        output = self.o(context)  # Fixed: self.W_o -> self.o
        return output
        
class PositionwiseFeedForward(nn.Module):
    """Position-wise Feed-Forward Network (two linear layers + activation)"""
    def __init__(self, embed_dim: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # 把 d_model 维的向量 “扩展” 到更高维空间（通常 4 倍）
        self.linear1 = nn.Linear(embed_dim, d_ff)
        # 把扩展后的高维特征 “压缩” 回原来的 d_model 维度
        # 作用：提炼精华、输出更有价值的特征
        self.linear2 = nn.Linear(d_ff, embed_dim)
        # 作用：防止过拟合
        self.dropout = nn.Dropout(dropout)
        # 作用：增加非线性
        self.activation = nn.SELU()   # original paper used ReLU

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

class EncoderLayer(nn.Module):
    """One Encoder Layer = Self-Attention + Feed-Forward + residual + norm"""
    def __init__(self, embed_dim: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = Attention(embed_dim, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(embed_dim, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):  # Fixed: mask should be passed as parameter, not stored
        # x shape: (batch, seq_len, embed_dim)
        # Self-Attention sub-layer
        attn_output = self.self_attn(x, x, x, mask)  # Fixed: pass mask parameter
        
        # x keep the original info, and the attention output provide the new knowledge just learned
        x = self.norm1(x + self.dropout(attn_output))          # residual + norm (post-norm)
        
        # Feed-Forward sub-layer
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))            # residual + norm
        
        return x

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, n_heads: int, d_ff: int, dropout: float = 0.1, n_layers: int = 6):  # Fixed: removed mask from __init__
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.norm = nn.LayerNorm(embed_dim)
        # 所以 d_ff 通常 远大于 d_model，形成“先扩展、再压缩”的瓶颈结构（bottleneck），这能显著增加模型的非线性表达能力。
        self.d_ff = d_ff

        self.layers = nn.ModuleList([EncoderLayer(embed_dim, n_heads, d_ff, dropout) for _ in range(n_layers)])

    # x: [batch_size, seq_len, embed_dim]
    def forward(self, x, mask=None):  # Fixed: add mask parameter
        """
        x: (batch, seq_len, d_model)  — usually after embedding + pos encoding
        mask: (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len)  — optional padding mask
        """
        for layer in self.layers:
            x = layer(x, mask)  # Fixed: pass mask to each layer
        return self.norm(x)
        


# Encoder Implementation Review & Improvements

## Issues Found:

1. **Bug in Attention class**: Uses `self.embed_dim` but it's actually `self.d_model`
2. **Mask handling**: Mask is passed to EncoderLayer but not properly forwarded to Attention
3. **Missing components**: No input embedding or positional encoding layers
4. **Inconsistent naming**: Mix of `embed_dim` and `d_model`
5. **Mask parameter**: Encoder accepts mask but doesn't use it

## Improvements:

1. Fix the bug in Attention forward method
2. Properly handle mask propagation through layers
3. Add input embedding layer
4. Add positional encoding
5. Standardize naming conventions
6. Add better documentation

In [1]:
# ============================================================================
# IMPROVED ENCODER IMPLEMENTATION
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Attention(nn.Module):
    """
    Multi-Head Self-Attention Mechanism
    
    Improvements:
    - Fixed bug: self.embed_dim -> self.d_model
    - Better variable naming consistency
    - Improved documentation
    """
    def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        
        self.d_model = embed_dim  # Standard naming: d_model
        self.n_heads = n_heads
        self.d_k = embed_dim // n_heads  # Dimension per head
        self.dropout = nn.Dropout(dropout)

        # Query, Key, Value projections
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)
        
        # Output projection
        self.o = nn.Linear(embed_dim, embed_dim, bias=False)  # Fixed: was self.W_o

        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len, d_model)
            key: (batch, seq_len, d_model)
            value: (batch, seq_len, d_model)
            mask: (batch, 1, seq_len, seq_len) or (batch, seq_len, seq_len)
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        batch_size, seq_len = query.size(0), query.size(1)

        # Linear projections
        q = self.q(query)  # (batch, seq_len, d_model)
        k = self.k(key)
        v = self.v(value)

        # Reshape and transpose for multi-head attention
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Compute attention scores
        # (batch, n_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply mask (if provided)
        if mask is not None:
            # Ensure mask has correct shape for broadcasting
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Softmax normalization
        scores = F.softmax(scores, dim=-1)
        scores = self.dropout(scores)

        # Apply attention to values
        # (batch, n_heads, seq_len, d_k)
        context = torch.matmul(scores, v)
        
        # Concatenate heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        # Output projection
        output = self.o(context)  # Fixed: was self.W_o
        return output


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network
    
    Improvements:
    - Better activation function (GELU is more common than SELU)
    - Clearer comments
    """
    def __init__(self, embed_dim: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, d_ff)
        self.linear2 = nn.Linear(d_ff, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()  # GELU is more common than SELU in modern transformers

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
        """
        return self.linear2(self.dropout(self.activation(self.linear1(x))))


class EncoderLayer(nn.Module):
    """
    Single Encoder Layer
    
    Improvements:
    - Proper mask handling (pass mask to attention)
    - Pre-norm vs Post-norm option
    - Better residual connection handling
    """
    def __init__(
        self, 
        embed_dim: int, 
        num_heads: int, 
        d_ff: int, 
        dropout: float = 0.1,
        pre_norm: bool = True  # Pre-norm is more stable than post-norm
    ):
        super().__init__()
        self.self_attn = Attention(embed_dim, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(embed_dim, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.pre_norm = pre_norm

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len, seq_len) or None
        Returns:
            x: (batch, seq_len, d_model)
        """
        if self.pre_norm:
            # Pre-norm: normalize before attention/FFN
            # Self-Attention sub-layer
            attn_output = self.self_attn(
                self.norm1(x), self.norm1(x), self.norm1(x), mask
            )
            x = x + self.dropout1(attn_output)
            
            # Feed-Forward sub-layer
            ff_output = self.feed_forward(self.norm2(x))
            x = x + self.dropout2(ff_output)
        else:
            # Post-norm: normalize after attention/FFN (original Transformer)
            attn_output = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout1(attn_output))
            
            ff_output = self.feed_forward(x)
            x = self.norm2(x + self.dropout2(ff_output))
        
        return x


class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding
    
    Improvement: Added positional encoding class
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class Encoder(nn.Module):
    """
    Transformer Encoder
    
    Improvements:
    - Added input embedding layer
    - Added positional encoding
    - Proper mask handling
    - Better parameter organization
    """
    def __init__(
        self, 
        vocab_size: int, 
        embed_dim: int, 
        n_heads: int, 
        d_ff: int, 
        n_layers: int = 6,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        
        # Input embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len, dropout)
        
        # Encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(embed_dim, n_heads, d_ff, dropout, pre_norm) 
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len) - token indices
            mask: (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len) - padding mask
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Input embedding
        x = self.embedding(x) * math.sqrt(self.embed_dim)  # Scale embeddings
        
        # Add positional encoding
        x = self.pos_encoding(x)
        
        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, mask)
        
        # Final normalization
        return self.norm(x)
    
    def create_padding_mask(self, x, pad_idx=0):
        """
        Create padding mask from input tokens
        
        Args:
            x: (batch, seq_len) - token indices
            pad_idx: padding token index (default: 0)
        Returns:
            mask: (batch, 1, seq_len, seq_len) - 1 for valid positions, 0 for padding
        """
        mask = (x != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
        mask = mask.expand(x.size(0), 1, x.size(1), x.size(1))  # (batch, 1, seq_len, seq_len)
        return mask


# ============================================================================
# COMPARISON: Original vs Improved
# ============================================================================

print("=" * 70)
print("Key Improvements Summary")
print("=" * 70)
print("""
1. ✅ Fixed Bug: self.embed_dim -> self.d_model in Attention forward
2. ✅ Fixed Bug: self.W_o -> self.o in Attention forward  
3. ✅ Added Input Embedding layer
4. ✅ Added Positional Encoding layer
5. ✅ Proper mask handling (pass mask through layers)
6. ✅ Pre-norm option (more stable than post-norm)
7. ✅ Better documentation and type hints
8. ✅ Helper method: create_padding_mask()
9. ✅ Consistent naming (d_model throughout)
10. ✅ Better activation (GELU instead of SELU)
""")

Key Improvements Summary

1. ✅ Fixed Bug: self.embed_dim -> self.d_model in Attention forward
2. ✅ Fixed Bug: self.W_o -> self.o in Attention forward  
3. ✅ Added Input Embedding layer
4. ✅ Added Positional Encoding layer
5. ✅ Proper mask handling (pass mask through layers)
6. ✅ Pre-norm option (more stable than post-norm)
7. ✅ Better documentation and type hints
8. ✅ Helper method: create_padding_mask()
9. ✅ Consistent naming (d_model throughout)
10. ✅ Better activation (GELU instead of SELU)



In [2]:
# ============================================================================
# Test the Improved Encoder
# ============================================================================

# Test parameters
vocab_size = 1000
d_model = 512
n_heads = 8
d_ff = 2048
n_layers = 6
batch_size = 2
seq_len = 10

# Create improved encoder
encoder = Encoder(
    vocab_size=vocab_size,
    embed_dim=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    n_layers=n_layers,
    dropout=0.1,
    pre_norm=True
)

# Create dummy input (token indices)
input_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
print(f"Input tokens shape: {input_tokens.shape}")

# Create padding mask (simulate padding at end)
padding_mask = encoder.create_padding_mask(input_tokens)
print(f"Padding mask shape: {padding_mask.shape}")

# Forward pass
with torch.no_grad():
    output = encoder(input_tokens, mask=padding_mask)
    print(f"\nOutput shape: {output.shape}")
    print(f"Expected: ({batch_size}, {seq_len}, {d_model})")
    print(f"✅ Shape matches: {output.shape == (batch_size, seq_len, d_model)}")

# Test without mask
output_no_mask = encoder(input_tokens)
print(f"\nOutput without mask shape: {output_no_mask.shape}")
print(f"✅ Works without mask: {output_no_mask.shape == (batch_size, seq_len, d_model)}")

Input tokens shape: torch.Size([2, 10])
Padding mask shape: torch.Size([2, 1, 10, 10])

Output shape: torch.Size([2, 10, 512])
Expected: (2, 10, 512)
✅ Shape matches: True

Output without mask shape: torch.Size([2, 10, 512])
✅ Works without mask: True


In [None]:
# ============================================================================
# Detailed Review Comments
# ============================================================================

print("=" * 70)
print("DETAILED REVIEW COMMENTS")
print("=" * 70)

print("""
ORIGINAL CODE ISSUES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. ❌ CRITICAL BUG in Attention.forward():
   - Line uses: self.embed_dim but should be self.d_model
   - Line uses: self.W_o but should be self.o
   - This will cause AttributeError at runtime!

2. ❌ Mask not properly handled:
   - EncoderLayer accepts mask parameter but doesn't pass it to Attention
   - Encoder accepts mask but never uses it
   - Mask shape handling could be improved

3. ❌ Missing components:
   - No input embedding layer (expects pre-embedded input)
   - No positional encoding (critical for transformers!)
   - No way to create padding masks

4. ⚠️  Inconsistent naming:
   - Mix of embed_dim and d_model
   - Standard convention is d_model

5. ⚠️  Post-norm vs Pre-norm:
   - Original uses post-norm (less stable)
   - Pre-norm is more stable and commonly used now

6. ⚠️  Activation function:
   - Uses SELU (less common)
   - GELU is more standard in modern transformers

IMPROVEMENTS MADE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

✅ Fixed all bugs
✅ Added input embedding
✅ Added positional encoding  
✅ Proper mask handling
✅ Pre-norm option (with post-norm fallback)
✅ Better documentation
✅ Helper methods for mask creation
✅ Consistent naming conventions
✅ More standard activation (GELU)

USAGE EXAMPLE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

encoder = Encoder(
    vocab_size=10000,
    embed_dim=512,
    n_heads=8,
    d_ff=2048,
    n_layers=6,
    dropout=0.1
)

# Input: token indices (not embeddings!)
tokens = torch.randint(0, 10000, (batch_size, seq_len))

# Create mask
mask = encoder.create_padding_mask(tokens)

# Forward pass
output = encoder(tokens, mask=mask)
""")