In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Attention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects ``query``/``key``/``value`` with separate linear layers, splits each
    projection into ``n_heads`` heads of width ``d_k = embed_dim // n_heads``,
    computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` per head with optional masks,
    then merges the heads and applies an output projection.

    Args:
        embed_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = embed_dim
        self.n_heads = n_heads
        # Width of each head after splitting embed_dim across n_heads.
        self.d_k = embed_dim // n_heads
        # Regularises the attention weights.
        self.dropout = nn.Dropout(dropout)

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)

        self.o = nn.Linear(embed_dim, embed_dim)

        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_k)."""
        # -1 infers the sequence length, so query and key lengths may differ.
        return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: (batch, seq_q, d_model).
            key: (batch, seq_k, d_model).
            value: (batch, seq_k, d_model).
            attn_mask: optional *additive* float mask broadcastable to
                (batch, n_heads, seq_q, seq_k). An example is a (seq_q, seq_k)
                causal mask holding 0 for allowed and -inf for blocked positions.
            key_padding_mask: optional (batch, seq_k) mask. 0 marks a padding key
                to ignore; non-zero marks a real token (the same convention as a
                Hugging Face tokenizer's ``attention_mask``).

        Returns:
            Tensor of shape (batch, seq_q, d_model). A query row whose keys are
            all masked gets a zero context vector, so its output is
            ``self.o.bias``, instead of NaN.
        """
        batch_size = query.size(0)

        # Project, then split into heads: (batch, n_heads, seq, d_k).
        q = self._split_heads(self.q(query), batch_size)
        k = self._split_heads(self.k(key), batch_size)
        v = self._split_heads(self.v(value), batch_size)

        # Similarity of every query with every key: (batch, n_heads, seq_q, seq_k).
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Without masks, padding tokens add noise to the attention scores, and a
        # decoder can "peek" at future tokens.
        if attn_mask is not None:
            scores = scores + attn_mask
        if key_padding_mask is not None:
            # (batch, seq_k) -> (batch, 1, 1, seq_k) to broadcast over heads and queries.
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2) == 0,
                float('-inf')
            )

        fully_masked = None
        if attn_mask is not None or key_padding_mask is not None:
            # softmax over a row that is entirely -inf is 0/0 = NaN. That NaN would
            # reach every position of later layers via 0 * NaN in weights @ v.
            # Give such rows finite scores first, so softmax and its gradient stay
            # finite, then zero their weights below.
            fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
            scores = scores.masked_fill(fully_masked, 0.0)

        # Normalise over the key positions (last dim); blocked keys get weight 0.
        weights = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            weights = weights.masked_fill(fully_masked, 0.0)
        weights = self.dropout(weights)

        # (batch, n_heads, seq_q, seq_k) @ (batch, n_heads, seq_k, d_k)
        # -> (batch, n_heads, seq_q, d_k)
        context = torch.matmul(weights, v)
        # Merge heads: (batch, seq_q, n_heads, d_k) -> (batch, seq_q, d_model).
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.o(context)
        
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> SELU -> Dropout -> Linear.

    The same two-layer MLP is applied independently at every sequence position.
    ``linear1`` widens ``embed_dim`` to ``d_ff`` (typically a few times larger),
    and ``linear2`` projects the result back to ``embed_dim``.
    The original Transformer paper used ReLU; this module uses SELU.
    """

    def __init__(self, embed_dim: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Expansion: embed_dim -> d_ff.
        self.linear1 = nn.Linear(embed_dim, d_ff)
        # Projection back: d_ff -> embed_dim.
        self.linear2 = nn.Linear(d_ff, embed_dim)
        # NOTE(review): plain Dropout after SELU does not preserve SELU's
        # self-normalising property (that needs AlphaDropout).
        self.dropout = nn.Dropout(dropout)
        # Non-linearity between the two projections.
        self.activation = nn.SELU()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class DecoderLayer(nn.Module):
    """One Decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm
    (post-norm, as in the original Transformer).
    """
    def __init__(self, embed_dim: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = Attention(embed_dim, num_heads, dropout)
        self.cross_attn = Attention(embed_dim, num_heads, dropout)

        self.feed_forward = PositionwiseFeedForward(embed_dim, d_ff, dropout)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, self_attn_mask=None, self_key_padding_mask=None,
                cross_attn_mask=None, cross_key_padding_mask=None):
        """Apply the layer.

        Args:
            x: (batch, tgt_len, embed_dim) target-side hidden states.
            encoder_output: (batch, src_len, embed_dim) encoder states used as
                keys/values for cross-attention.
            self_attn_mask: optional additive mask for self-attention, e.g. a
                (tgt_len, tgt_len) causal mask.
            self_key_padding_mask: optional (batch, tgt_len), 0 = padding.
            cross_attn_mask: optional additive mask for cross-attention.
            cross_key_padding_mask: optional (batch, src_len), 0 = padding.

        Returns:
            (batch, tgt_len, embed_dim).
        """
        # Self-attention over the target sequence. The masks must be forwarded,
        # otherwise a causal mask never takes effect and future tokens leak.
        self_attn_output = self.self_attn(x, x, x, self_attn_mask, self_key_padding_mask)
        # Residual keeps the original signal; attention adds the new context.
        x = self.norm1(x + self.dropout(self_attn_output))

        # Cross-attention: target queries attend to the encoder output.
        cross_attn_output = self.cross_attn(x, encoder_output, encoder_output,
                                            cross_attn_mask, cross_key_padding_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # Feed-forward sub-layer.
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))

        return x

class Decoder(nn.Module):
    """Stack of ``n_layers`` DecoderLayer modules followed by a final LayerNorm.

    The decoder consumes already-embedded target vectors. Token embedding and
    positional encoding happen outside this module; ``vocab_size`` is stored but
    not used here. Every layer attends to ``encoder_output`` through
    cross-attention.

    Masks (all optional; the same masks are passed to every layer):

    * ``mask`` is the self-attention mask. It is *additive* and must broadcast
      to (batch, n_heads, tgt_len, tgt_len). For autoregressive decoding, pass a
      causal mask (see ``Decoder.causal_mask``) so position i cannot attend to
      positions after i; without it the decoder "peeks" at future tokens.
    * Key padding masks are (batch, seq_len) with 1 = real token and
      0 = padding, e.g. the ``attention_mask`` returned by a Hugging Face
      tokenizer. Without them the model attends to meaningless padding
      positions. Sequences in a batch are padded to the longest one.

    ``encoder_output`` is the encoder's final output (not its per-layer
    outputs), shaped (batch, src_len, embed_dim). ``src_len`` may differ from the
    target length. In translation, for example, the encoder and decoder inputs
    are different sequences, and the decoder input is shifted right by one
    token. Decoder-only models (GPT, Llama) have no encoder or cross-attention.

    ``d_ff`` is usually much larger than ``embed_dim`` (often 4x). The
    feed-forward sub-layer expands to ``d_ff`` and projects back, which adds
    non-linear capacity.
    """
    def __init__(self, vocab_size: int, embed_dim: int, n_heads: int, d_ff: int, dropout: float = 0.1, n_layers: int = 6):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.norm = nn.LayerNorm(embed_dim)
        # Hidden width of each layer's feed-forward sub-layer.
        self.d_ff = d_ff

        self.layers = nn.ModuleList([DecoderLayer(embed_dim, n_heads, d_ff, dropout) for _ in range(n_layers)])

    @staticmethod
    def causal_mask(seq_len: int, device=None, dtype=torch.float32):
        """Return a (seq_len, seq_len) additive causal mask.

        Entries on and below the diagonal are 0 (allowed); entries above it are
        -inf (blocked), so position i only attends to positions <= i.
        """
        return torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, x, encoder_output, mask=None, self_key_padding_mask=None,
                cross_attn_mask=None, cross_key_padding_mask=None):
        """Run ``x`` through every decoder layer, then apply the final LayerNorm.

        Args:
            x: (batch, tgt_len, embed_dim), usually embedding + positional encoding.
            encoder_output: (batch, src_len, embed_dim).
            mask: optional additive self-attention mask broadcastable to
                (batch, n_heads, tgt_len, tgt_len). Examples are
                ``Decoder.causal_mask(tgt_len)`` or a (batch, 1, tgt_len, tgt_len)
                tensor.
            self_key_padding_mask: optional (batch, tgt_len), 0 = padding.
            cross_attn_mask: optional additive mask broadcastable to
                (batch, n_heads, tgt_len, src_len).
            cross_key_padding_mask: optional (batch, src_len), 0 = padding.

        Returns:
            (batch, tgt_len, embed_dim).
        """
        for layer in self.layers:
            x = layer(
                x,
                encoder_output,
                self_attn_mask=mask,
                self_key_padding_mask=self_key_padding_mask,
                cross_attn_mask=cross_attn_mask,
                cross_key_padding_mask=cross_key_padding_mask,
            )
        return self.norm(x)
        

# Smoke test: run a randomly initialised Decoder on random "embedded" inputs.
vocab_size = 1000
d_model = 512
n_heads = 8
d_ff = 2048
n_layers = 6
batch_size = 2
src_seq_len = 16
tgt_seq_len = 24

# Create the decoder.
decoder = Decoder(
    vocab_size=vocab_size,
    embed_dim=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    n_layers=n_layers,
    dropout=0.1
)
# torch.no_grad() only disables autograd; eval() is what turns dropout off.
decoder.eval()

# The Decoder takes already-embedded vectors (not token ids).
decoder_input = torch.randn(batch_size, tgt_seq_len, d_model)
encoder_output = torch.randn(batch_size, src_seq_len, d_model)

# No padding in this demo.
padding_mask = None

# Additive causal mask: 0 on/below the diagonal, -inf above, so target
# position i cannot attend to later positions.
causal_mask = torch.triu(torch.full((tgt_seq_len, tgt_seq_len), float('-inf')), diagonal=1)

# Forward pass
with torch.no_grad():
    output = decoder(decoder_input, encoder_output, causal_mask)
    print(f"\nOutput shape: {output.shape}")

assert output.shape == (batch_size, tgt_seq_len, d_model)


Output shape: torch.Size([2, 24, 512])
