# Transformer From Scratch – Module Tests

This notebook walks through unit‑testing each custom module in the repository and finally runs a tiny end‑to‑end forward pass with the full `TransformerEncoderDecoder`.


In [1]:
import torch

from modules.blocks.decoder_block import TransformerDecoderBlock
from modules.blocks.encoder_block import TransformerEncoderBlock
from modules.core.embeddings import TokenEmbedding, PositionalEmbedding
from modules.core.multi_head_attention import MultiHeadedAttention
from modules.core.scaled_dot_product_attention import ScaledDotProductAttention
from modules.transformer import TransformerEncoderDecoder

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
# ---- Token & Positional Embeddings ----
# Seed torch's global RNG so the random token ids below (and any torch-based
# parameter initialisation inside the modules) repeat on Restart & Run All.
torch.manual_seed(0)

batch, seq_len, d_model = 2, 5, 32
vocab_size = 100
tok_emb = TokenEmbedding(vocab_size, d_model)
pos_emb = PositionalEmbedding(max_seq_length=50, embedding_dim=d_model)

# Random integer token ids in [0, vocab_size), shaped (batch, seq_len).
x = torch.randint(0, vocab_size, (batch, seq_len))
tok = tok_emb(x)
# unsqueeze(0) adds a leading size-1 axis so `pos` broadcasts against `tok`.
pos = pos_emb(seq_len).unsqueeze(0)
print('Token embedding:', tok.shape)
print('Positional embedding:', pos.shape)
print('Sum:', (tok + pos).shape)

# Make the visual check an actual test: fail loudly if a shape regresses.
assert tok.shape == (batch, seq_len, d_model), tok.shape
assert pos.shape == (1, seq_len, d_model), pos.shape
assert (tok + pos).shape == tok.shape, (tok + pos).shape

Token embedding: torch.Size([2, 5, 32])
Positional embedding: torch.Size([1, 5, 32])
Sum: torch.Size([2, 5, 32])


In [3]:
# ---- Scaled Dot‑Product Attention ----
# Runs the attention module on the token embeddings from the previous cell.
attn = ScaledDotProductAttention(embedding_dim=d_model)
out = attn(tok)  # (B, L, d)
print('Scaled attention output:', out.shape)

# The module is expected to preserve the input shape; assert it rather than
# relying on reading the printed value.
assert out.shape == tok.shape, out.shape

Scaled attention output: torch.Size([2, 5, 32])


In [4]:
# ---- Multi‑Head Attention ----
heads = 4
mha = MultiHeadedAttention(embedding_dim=d_model, num_heads=heads)
# The same tensor is passed for all three inputs.
out_mha = mha(tok, tok, tok)
print('Multi‑head output:', out_mha.shape)

# Output must keep the input shape; fail the cell if it does not.
assert out_mha.shape == tok.shape, out_mha.shape

Multi‑head output: torch.Size([2, 5, 32])


In [5]:
# ---- Encoder Block ----
ff_dim = 64
enc = TransformerEncoderBlock(d_model, heads, ff_dim)
# Feed token + positional embeddings; `pos` broadcasts over the batch axis.
enc_out = enc(tok + pos)
print('Encoder block output:', enc_out.shape)

# The block should return a tensor shaped like its input.
assert enc_out.shape == (batch, seq_len, d_model), enc_out.shape

Encoder block output: torch.Size([2, 5, 32])


In [6]:
# ---- Decoder Block ----
dec = TransformerDecoderBlock(d_model, heads, ff_dim)
dec_in = tok + pos  # pretend previous target embeddings
# The encoder output from the previous cell is passed as both the second and
# third arguments.
dec_out = dec(dec_in, enc_out, enc_out)
print('Decoder block output:', dec_out.shape)

# The decoder output should match the shape of the decoder input.
assert dec_out.shape == dec_in.shape, dec_out.shape

Decoder block output: torch.Size([2, 5, 32])


In [7]:
# ---- Full Transformer Encoder‑Decoder ----
# Tiny end-to-end forward pass: two encoder layers and two decoder layers,
# reusing the sizes defined in the earlier cells.
model = TransformerEncoderDecoder(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    max_seq_length=50,
    embedding_dim=d_model,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_heads=heads,
    feed_forward_dim=ff_dim
)

# Random source and target token ids, each shaped (batch, seq_len).
src = torch.randint(0, vocab_size, (batch, seq_len))
tgt = torch.randint(0, vocab_size, (batch, seq_len))
logits = model(src, tgt)
print('Transformer output logits:', logits.shape)

# One score per target-vocabulary entry at every target position.
assert logits.shape == (batch, seq_len, vocab_size), logits.shape

Transformer output logits: torch.Size([2, 5, 100])
