# Transformer encoder-decoder stack: shape trace and parameter counts

In [1]:
import torch
import torch.nn as nn

In [2]:
# ========== Assignment Hyperparameters ==========
# Shared configuration read by the model-construction and demo cells below.

BATCH_SIZE = 32  # first dimension of the random token-id tensor used in the demo
SEQ_LEN = 1024  # second dimension of the token-id tensor; also max_len of the positional table
VOCAB_SIZE = 1000  # rows in the embedding table and exclusive upper bound of the random token ids
EMBED_SIZE = 512  # embedding width; passed as d_model / emb_dim to every module
HIDDEN_SIZE = 768  # width of the hidden Linear layer inside each block's feed-forward network
NUM_HEADS = 8  # number of heads given to every nn.MultiheadAttention
ENCODER_LAYERS = 2  # number of EncoderBlock instances in the encoder stack
DECODER_LAYERS = 3  # number of DecoderBlock instances in the decoder stack

In [3]:
# ========= Embedding & Positional Encoding =======
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The encoding table is computed once for ``max_len`` positions and stored
    as a non-trainable buffer named ``pe`` with shape ``(1, max_len, d_model)``.
    Even feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension of the inputs. Both even
            and odd sizes are supported.
        max_len: Number of positions to precompute. Inputs passed to
            ``forward`` should not have more than ``max_len`` positions.
    """

    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        # One angle per (position, frequency) pair: (max_len, ceil(d_model / 2)).
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed columns, so trim the angle table to match; for an even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension has size ``d_model``, e.g. ``(batch, seq, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The buffer is moved to x's device so the addition also works when
        # the module itself was left on a different device.
        return x + self.pe[:, :x.size(1)].to(x.device)

# Module-level lookup table and positional encoder shared by the demo cell.
embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBED_SIZE)
pos_encoding = PositionalEncoding(d_model=EMBED_SIZE, max_len=SEQ_LEN)

# ========== Encoder Block ==========
class EncoderBlock(nn.Module):
    """Single encoder layer that prints the tensor shape after every stage.

    The layer applies self-attention and a two-layer feed-forward network,
    each followed by a residual addition and a LayerNorm.

    Args:
        emb_dim: Feature size of the inputs and outputs.
        hidden_dim: Width of the hidden layer of the feed-forward network.
        num_heads: Number of attention heads.
    """

    def __init__(self, emb_dim, hidden_dim, num_heads):
        super().__init__()
        # Stage 1: batch-first self-attention and its LayerNorm.
        self.mha = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(emb_dim)
        # Stage 2: expand -> ReLU -> contract feed-forward and its LayerNorm.
        expand = nn.Linear(emb_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, emb_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x, layer_num=0):
        """Run the layer on ``x`` and print each intermediate shape.

        Args:
            x: Batch-first input tensor; used as query, key and value.
            layer_num: Label included in the printed messages.

        Returns:
            Output of the second residual + LayerNorm stage.
        """
        # Self-attention; only the attended values are kept, not the weights.
        attended = self.mha(x, x, x)[0]
        print(f"Encoder Layer {layer_num}: after MHA:", attended.shape)
        normed = self.norm1(x + attended)
        print(f"Encoder Layer {layer_num}: after Add+Norm1:", normed.shape)
        # Position-wise feed-forward with its own residual connection.
        projected = self.ffn(normed)
        print(f"Encoder Layer {layer_num}: after FFN:", projected.shape)
        result = self.norm2(normed + projected)
        print(f"Encoder Layer {layer_num}: after Add+Norm2:", result.shape)
        return result

    def parameter_count(self):
        """Return the number of trainable scalar parameters in this layer."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

# ========== Decoder Block ==========
class DecoderBlock(nn.Module):
    """Single decoder layer that prints the tensor shape after every stage.

    The layer applies masked self-attention, encoder-decoder attention and a
    two-layer feed-forward network, each followed by a residual addition and
    a LayerNorm.

    Args:
        emb_dim: Feature size of the inputs and outputs.
        hidden_dim: Width of the hidden layer of the feed-forward network.
        num_heads: Number of attention heads in both attention modules.
    """

    def __init__(self, emb_dim, hidden_dim, num_heads):
        super().__init__()
        self.mha1 = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)  # Masked Self-Attention
        self.norm1 = nn.LayerNorm(emb_dim)
        self.mha2 = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)  # Encoder-Decoder Attention
        self.norm2 = nn.LayerNorm(emb_dim)
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim)
        )
        self.norm3 = nn.LayerNorm(emb_dim)

    def forward(self, x, enc_out, layer_num=0, tgt_mask=None):
        """Run the layer and print each intermediate shape.

        Args:
            x: Batch-first decoder input of shape ``(batch, tgt_len, emb_dim)``.
            enc_out: Batch-first encoder output, used as key and value of the
                encoder-decoder attention.
            layer_num: Label included in the printed messages.
            tgt_mask: Optional mask forwarded as ``attn_mask`` to the
                self-attention. When ``None`` (the default) a causal mask is
                built so that position ``i`` only attends to positions
                ``<= i``.

        Returns:
            Output of the third residual + LayerNorm stage, same shape as ``x``.
        """
        if tgt_mask is None:
            # Boolean mask where True marks a disallowed (future) position:
            # everything strictly above the diagonal.
            tgt_len = x.size(1)
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        # Masked Self-Attention
        attn1_out, _ = self.mha1(x, x, x, attn_mask=tgt_mask)
        print(f"Decoder Layer {layer_num}: after Masked MHA:", attn1_out.shape)
        x = self.norm1(x + attn1_out)
        print(f"Decoder Layer {layer_num}: after Add+Norm1:", x.shape)
        # Encoder-Decoder Attention: queries from the decoder, keys/values
        # from the encoder output (no mask; every source position is visible).
        attn2_out, _ = self.mha2(x, enc_out, enc_out)
        print(f"Decoder Layer {layer_num}: after Enc-Dec MHA:", attn2_out.shape)
        x = self.norm2(x + attn2_out)
        print(f"Decoder Layer {layer_num}: after Add+Norm2:", x.shape)
        # Feed Forward
        ffn_out = self.ffn(x)
        print(f"Decoder Layer {layer_num}: after FFN:", ffn_out.shape)
        x = self.norm3(x + ffn_out)
        print(f"Decoder Layer {layer_num}: after Add+Norm3:", x.shape)
        return x

    def parameter_count(self):
        """Return the number of trainable scalar parameters in this layer."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# ========== Encoder Stack ==========
class EncoderStack(nn.Module):
    """Sequence of ``EncoderBlock`` layers applied one after another.

    Args:
        num_layers: How many encoder blocks to create.
        emb_dim: Feature size passed to every block.
        hidden_dim: Feed-forward hidden width passed to every block.
        num_heads: Attention head count passed to every block.
    """

    def __init__(self, num_layers, emb_dim, hidden_dim, num_heads):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderBlock(emb_dim, hidden_dim, num_heads))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order, printing a header for each."""
        for i in range(len(self.layers)):
            block = self.layers[i]
            print(f"\n--- Encoder Block {i+1} ---")
            # Blocks are labelled starting from 1 in the printed trace.
            x = block(x, layer_num=i+1)
        return x

    def parameter_count(self):
        """Return the summed ``parameter_count()`` of all blocks."""
        total = 0
        for block in self.layers:
            total += block.parameter_count()
        return total

# ========== Decoder Stack ==========
class DecoderStack(nn.Module):
    """Sequence of ``DecoderBlock`` layers that all read the same encoder output.

    Args:
        num_layers: How many decoder blocks to create.
        emb_dim: Feature size passed to every block.
        hidden_dim: Feed-forward hidden width passed to every block.
        num_heads: Attention head count passed to every block.
    """

    def __init__(self, num_layers, emb_dim, hidden_dim, num_heads):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(DecoderBlock(emb_dim, hidden_dim, num_heads))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, enc_out):
        """Feed ``x`` through every block in order, printing a header for each.

        ``enc_out`` is handed unchanged to each block.
        """
        for i in range(len(self.layers)):
            block = self.layers[i]
            print(f"\n--- Decoder Block {i+1} ---")
            # Blocks are labelled starting from 1 in the printed trace.
            x = block(x, enc_out, layer_num=i+1)
        return x

    def parameter_count(self):
        """Return the summed ``parameter_count()`` of all blocks."""
        total = 0
        for block in self.layers:
            total += block.parameter_count()
        return total



In [4]:
# ========== Demo: Forward Pass & Parameter Reporting ==========
# Seed so the random token ids and the encoder/decoder weights created in this
# cell are the same on every run. (The embedding table was created in an
# earlier cell and is not affected by this seed.)
torch.manual_seed(0)

# Fake input data (like token indices)
x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))

# Build both stacks before entering no_grad so construction is untouched.
encoder = EncoderStack(ENCODER_LAYERS, EMBED_SIZE, HIDDEN_SIZE, NUM_HEADS)
decoder = DecoderStack(DECODER_LAYERS, EMBED_SIZE, HIDDEN_SIZE, NUM_HEADS)

# This cell only inspects shapes, so run without autograd: otherwise every
# layer's (SEQ_LEN x SEQ_LEN) attention weights are kept alive for a backward
# pass that never happens.
with torch.no_grad():
    # Embedding and positional encoding
    out = embedding(x)
    print("After embedding:", out.shape)
    out = pos_encoding(out)
    print("After positional encoding:", out.shape)

    # Encoder stack
    print(f"\n=== EncoderStack: Total parameters: {encoder.parameter_count():,}")
    enc_out = encoder(out)

    # Decoder stack (uses a copy of input for simplicity)
    dec_in = embedding(x)
    dec_in = pos_encoding(dec_in)
    print(f"\n=== DecoderStack: Total parameters: {decoder.parameter_count():,}")
    dec_out = decoder(dec_in, enc_out)

# Both stacks must preserve the (batch, sequence, embedding) shape end to end.
assert enc_out.shape == (BATCH_SIZE, SEQ_LEN, EMBED_SIZE)
assert dec_out.shape == (BATCH_SIZE, SEQ_LEN, EMBED_SIZE)

After embedding: torch.Size([32, 1024, 512])
After positional encoding: torch.Size([32, 1024, 512])

=== EncoderStack: Total parameters: 3,680,768

--- Encoder Block 1 ---
Encoder Layer 1: after MHA: torch.Size([32, 1024, 512])
Encoder Layer 1: after Add+Norm1: torch.Size([32, 1024, 512])
Encoder Layer 1: after FFN: torch.Size([32, 1024, 512])
Encoder Layer 1: after Add+Norm2: torch.Size([32, 1024, 512])

--- Encoder Block 2 ---
Encoder Layer 2: after MHA: torch.Size([32, 1024, 512])
Encoder Layer 2: after Add+Norm1: torch.Size([32, 1024, 512])
Encoder Layer 2: after FFN: torch.Size([32, 1024, 512])
Encoder Layer 2: after Add+Norm2: torch.Size([32, 1024, 512])

=== DecoderStack: Total parameters: 8,676,096

--- Decoder Block 1 ---
Decoder Layer 1: after Masked MHA: torch.Size([32, 1024, 512])
Decoder Layer 1: after Add+Norm1: torch.Size([32, 1024, 512])
Decoder Layer 1: after Enc-Dec MHA: torch.Size([32, 1024, 512])
Decoder Layer 1: after Add+Norm2: torch.Size([32, 1024, 512])
Decoder 