# Including Attention, MLP, and Model Architecture Designs

Our current state of development:

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Four-sentence toy corpus; punctuation is already separated by spaces so a
# plain whitespace split yields the tokens.
corpus = [
    "hello , how are you ?",
    "hello , how is your day ?",
    "how are you ?",
    "how is your day ?",
]

# The special tokens occupy the first three ids (0, 1, 2), in this order.
vocab = {special: idx for idx, special in enumerate(("<pad>", "<bos>", "<eos>"))}

In [None]:
# Distinct whitespace-separated tokens found anywhere in the corpus.
tokens = {word for sentence in corpus for word in sentence.split()}
tokens  # last expression of the cell: displays the set

In [None]:
# Give every corpus token an id after the three special tokens (ids 0-2).
# Iterate in sorted order: the iteration order of a set of strings changes
# between interpreter runs (string hashes are randomized), which would
# silently reshuffle the token ids from one session to the next.
for ix, val in enumerate(sorted(tokens), start=3):
    vocab[val] = ix
vocab_size = len(vocab)

In [None]:
# Inverse lookup table: token id -> token string.
id2token = dict(zip(vocab.values(), vocab.keys()))

In [None]:
def simple_tokenize(text):
    """Encode ``text`` as a list of token ids framed by ``<bos>`` and ``<eos>``.

    ``text`` is split on whitespace and each piece is looked up in the
    module-level ``vocab``; a piece that is not in ``vocab`` raises ``KeyError``.
    """
    ids = [vocab["<bos>"]]
    ids.extend(vocab[piece] for piece in text.split())
    ids.append(vocab["<eos>"])
    return ids

def detokenize(ids):
    """Turn an iterable of token ids back into a space-joined string.

    Each id is coerced with ``int()`` before the ``id2token`` lookup, so ids
    taken straight from a model output (0-d integer tensors, which hash by
    identity and would never match a dict key) work as well as plain ints.
    An id that is not in ``id2token`` raises ``KeyError``.
    """
    return " ".join(id2token[int(i)] for i in ids)

Our current model:

In [None]:
class TinyLM(nn.Module):
    """Tiny causal language model built on ``nn.TransformerEncoderLayer``.

    Token ids are embedded, passed through ``n_layers`` transformer layers
    under a causal mask, layer-normalised and projected to vocabulary logits.

    Args:
        vocab_size: size of the embedding table and of the output logits.
        d_model: embedding / hidden width.
        n_layers: number of stacked transformer layers.
        nhead: attention heads per layer; ``d_model`` must be divisible by it.
        dim_feedforward: hidden width of the MLP inside each layer.
    """

    def __init__(self, vocab_size, d_model, n_layers, nhead=2, dim_feedforward=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=True,
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        """Return next-token logits.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).
        """
        x = self.embed(input_ids)  # (batch, seq_len, d_model)

        seq_len = x.size(1)
        # Additive causal mask of shape (seq_len, seq_len):
        # mask[i, j] = -inf if j > i (position i can't attend to the future),
        # 0 elsewhere.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
            diagonal=1
        )

        for layer in self.layers:
            x = layer(x, src_mask=mask)

        x = self.ln_f(x)
        logits = self.out_head(x)  # (batch, seq_len, vocab_size)
        return logits


Our current training loop

In [None]:
# model = TinyLM(vocab_size=vocab_size, d_model=32, n_layers=2)

# criterion = nn.CrossEntropyLoss()          # for token prediction
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# num_epochs = 200  # small corpus, will overfit fast

# for epoch in range(num_epochs):
#     model.train()
#     optimizer.zero_grad()

#     # forward
#     logits = model(input_batch)  # (batch, seq_len, vocab_size)

#     # reshape for CrossEntropyLoss: (batch * seq_len, vocab_size)
#     logits_flat = logits.view(-1, vocab_size)
#     targets_flat = target_batch.view(-1)   # (batch * seq_len,)

#     loss = criterion(logits_flat, targets_flat)

#     # backward
#     loss.backward()
#     optimizer.step()

#     if (epoch + 1) % 20 == 0:
#         print(f"Epoch {epoch+1}/{num_epochs} - loss: {loss.item():.4f}")


# .... Now .... 

The following snippet, taken from the model above, is our "transformer block" that contains the attention and the MLP:
>            nn.TransformerEncoderLayer(
>                d_model=d_model,
>                nhead=2,
>                dim_feedforward=64,
>                batch_first=True,
>            )


How can we replace this with more illuminating code?

Let's first revise TinyLM to use a more general block of our own:

In [None]:
class TinyLM(nn.Module):
    """Toy language model: embedding -> stacked ``TinyAttnBlock`` -> norm -> logits.

    Same skeleton as the ``nn.TransformerEncoderLayer`` version, except that
    every layer is now our own ``TinyAttnBlock`` (feed-forward width 64).
    """

    def __init__(self, vocab_size, d_model, n_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # One custom block per layer, standing in for
        # nn.TransformerEncoderLayer(d_model=d_model, nhead=2,
        #                            dim_feedforward=64, batch_first=True).
        blocks = [TinyAttnBlock(d_model, d_ff=64) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        hidden = self.embed(input_ids)  # (batch, seq_len, d_model)

        # No mask is built or passed at this level any more: each layer is
        # called with the hidden states only.
        for block in self.layers:
            hidden = block(hidden)

        return self.out_head(self.ln_f(hidden))


We need to define the appropriate TinyAttnBlock:

In [None]:
class TinyAttnBlock(nn.Module):
    """Transformer block: self-attention followed by an MLP.

    Each sub-layer is applied to a layer-normalised copy of its input and its
    result is added back onto that input (residual connection).
    """

    def __init__(self, d_model, d_ff=64):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # MLP: expand to d_ff, apply ReLU, project back to d_model.
        mlp_layers = (
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply attention and then the MLP, each with a residual connection."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))

And we can be explicit about the SelfAttention:

In [None]:
class SelfAttention(nn.Module):
    """Single-head causal self-attention with explicit Q, K and V projections.

    Args:
        d_model: width of the input, of every projection and of the output.
    """

    def __init__(self, d_model):
        super().__init__()
        # separate projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Let every position attend to itself and to earlier positions.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        B, T, C = x.size()

        q = self.W_q(x)  # (B, T, C)
        k = self.W_k(x)  # (B, T, C)
        v = self.W_v(x)  # (B, T, C)

        # scaled dot-product attention
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(C)  # (B, T, T)

        # Causal mask: future[i, j] is True when j > i (a future position).
        # A boolean mask with masked_fill keeps `scores` in its own dtype;
        # adding a float32 mask would promote bf16/fp16 scores to float32 and
        # make the `attn @ v` product fail on mismatched dtypes.
        pos = torch.arange(T, device=x.device)
        future = pos.unsqueeze(0) > pos.unsqueeze(1)        # (T, T)
        scores = scores.masked_fill(future, float("-inf"))  # broadcast over batch

        attn = F.softmax(scores, dim=-1)       # (B, T, T)
        y = attn @ v                           # (B, T, C)

        y = self.out(y)                        # (B, T, C)
        return y


In [None]:
# Build the TinyLM defined above with 32-dimensional embeddings and 2 layers.
model = TinyLM(vocab_size=vocab_size, d_model=32, n_layers=2)

In [None]:
# Print the module tree (sub-modules and their constructor arguments).
print(model)

In [None]:
# Total number of scalar parameters in the model (displayed as the cell output).
sum(p.numel() for p in model.parameters())

This is enough for us to make some key comments about different model architectures:
* **Encoder-only models (BERT-style)**
  * tokens can see each other in both directions; you usually only mask out padding
  * in our SelfAttention: no causal mask (don't block out the "future")
* **Decoder-only models (GPT-style)**
  * tokens can only see past and current positions -> causal mask + optional padding mask
  * in our SelfAttention: use the causal mask (`-inf` in upper triangle)
* **Encoder-decoder models (T5-style)**
  * Encoder blocks are the same as encoder-only model
  * Decoder blocks have:
     * Causal self-attention, just like the decoder-only model
     * Cross-attention (attention decoder pays to the encoder output)
       * Q comes from the decoder hidden states
       * K and V come from the encoder outputs
       * No causal mask (encoder tokens are not ordered relative to decoder tokens)

## Encoder self-attention

In [None]:
class EncoderSelfAttention(nn.Module):
    """Single-head bidirectional self-attention: no causal mask is applied."""

    def __init__(self, d_model):
        super().__init__()
        # Q, K, V projections followed by the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        """Attend over all positions of ``x``.

        Args:
            x: tensor of shape (B, T, C).
            attn_mask: optional additive mask of shape (B, 1, T) or (B, T, T),
                holding 0 where attention is allowed and -inf where it is not
                (typically used for padding).

        Returns:
            Tensor of shape (B, T, C).
        """
        _, _, width = x.size()

        queries = self.W_q(x)  # (B, T, C)
        keys = self.W_k(x)
        values = self.W_v(x)

        # Scaled dot-product scores, shape (B, T, T).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(width)

        # Unlike the decoder variant, the only mask here is the optional one
        # supplied by the caller.
        if attn_mask is not None:
            scores = scores + attn_mask

        weights = torch.softmax(scores, dim=-1)  # (B, T, T)
        return self.out(torch.matmul(weights, values))


In [None]:
class EncoderBlock(nn.Module):
    """Encoder layer: bidirectional self-attention followed by an MLP.

    Each sub-layer is applied to a layer-normalised copy of its input and its
    result is added back onto that input (residual connection).
    """

    def __init__(self, d_model, d_ff=64):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = EncoderSelfAttention(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # MLP: expand to d_ff, apply ReLU, project back to d_model.
        mlp_layers = (
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, x, attn_mask=None):
        """Run the block; ``attn_mask`` is handed to the self-attention as is."""
        attended = x + self.self_attn(self.ln1(x), attn_mask=attn_mask)
        return attended + self.ff(self.ln2(attended))


## Decoder self-attention

In [None]:
class DecoderSelfAttention(nn.Module):
    """Single-head causal self-attention with an optional extra additive mask.

    Args:
        d_model: width of the input, of every projection and of the output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        """Let every position attend to itself and to earlier positions.

        Args:
            x: tensor of shape (B, T, C).
            attn_mask: optional additive mask (e.g. for padding) that is added
                to the (B, T, T) scores on top of the causal mask; 0 keeps a
                position, -inf masks it.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(C)  # (B, T, T)

        # KEY DIFFERENCE from the encoder: include the causal mask.
        # future[i, j] is True when j > i. A boolean mask with masked_fill
        # keeps `scores` in its own dtype; adding a float32 mask would promote
        # bf16/fp16 scores to float32 and make `attn @ v` fail on mismatched
        # dtypes.
        pos = torch.arange(T, device=x.device)
        future = pos.unsqueeze(0) > pos.unsqueeze(1)        # (T, T)
        scores = scores.masked_fill(future, float("-inf"))  # broadcast to (B, T, T)

        # optional extra mask (e.g., padding), same idea as encoder
        if attn_mask is not None:
            scores = scores + attn_mask

        attn = F.softmax(scores, dim=-1)
        y = attn @ v
        y = self.out(y)
        return y


## Decoder block with cross-attention included

In [None]:
class CrossAttention(nn.Module):
    """Single-head cross-attention: decoder queries over encoder keys/values."""

    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)  # applied to the decoder states
        self.W_k = nn.Linear(d_model, d_model)  # applied to the encoder outputs
        self.W_v = nn.Linear(d_model, d_model)  # applied to the encoder outputs
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x_dec, x_enc, enc_attn_mask=None):
        """Attend from each decoder position to the encoder positions.

        Args:
            x_dec: decoder hidden states, shape (B, T_dec, C).
            x_enc: encoder outputs, shape (B, T_enc, C).
            enc_attn_mask: optional additive mask over encoder positions,
                shape (B, 1, T_enc) or (B, T_dec, T_enc).

        Returns:
            Tensor of shape (B, T_dec, C).
        """
        _, _, width = x_dec.size()
        _, _, _ = x_enc.size()  # both inputs must be 3-D

        queries = self.W_q(x_dec)  # (B, T_dec, C)
        keys = self.W_k(x_enc)     # (B, T_enc, C)
        values = self.W_v(x_enc)   # (B, T_enc, C)

        # Scaled dot-product scores, shape (B, T_dec, T_enc). No causal mask.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(width)

        if enc_attn_mask is not None:
            # e.g. hide encoder padding positions
            scores = scores + enc_attn_mask

        weights = torch.softmax(scores, dim=-1)  # (B, T_dec, T_enc)
        return self.out(torch.matmul(weights, values))


In [None]:
class DecoderBlock(nn.Module):
    """Decoder layer: causal self-attention, cross-attention, then an MLP.

    Each of the three sub-layers is applied to a layer-normalised copy of the
    running decoder state and its result is added back (residual connection).
    """

    def __init__(self, d_model, d_ff=64):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = DecoderSelfAttention(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttention(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        # MLP: expand to d_ff, apply ReLU, project back to d_model.
        mlp_layers = (
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, x_dec, x_enc, self_mask=None, enc_mask=None):
        """Update the decoder states ``x_dec`` using the encoder outputs ``x_enc``.

        ``self_mask`` is passed to the self-attention as ``attn_mask`` and
        ``enc_mask`` to the cross-attention as ``enc_attn_mask``.
        """
        state = x_dec

        # 1) causal self-attention over the decoder tokens
        state = state + self.self_attn(self.ln1(state), attn_mask=self_mask)

        # 2) cross-attention onto the encoder outputs
        state = state + self.cross_attn(self.ln2(state), x_enc, enc_attn_mask=enc_mask)

        # 3) position-wise MLP
        return state + self.ff(self.ln3(state))


## Encoder-decoder model

This model combines both of the above. Rather than showing only the attention blocks, as we did above, here we show the full (mini) model.

In [None]:
class TinySeq2Seq(nn.Module):
    """Minimal encoder-decoder model built from EncoderBlock / DecoderBlock.

    Source and target tokens have separate embedding tables. The encoder stack
    produces outputs that every decoder layer cross-attends to, and the
    normalised decoder states are projected to vocabulary logits.

    All mask arguments are optional additive masks (0 to keep a position,
    -inf to hide it) and default to ``None``, which means "no extra masking".
    They exist so that padded batches can hide their padding positions.

    Args:
        vocab_size: size of both embedding tables and of the output logits.
        d_model: embedding / hidden width.
        n_layers: number of encoder layers and of decoder layers.
    """

    def __init__(self, vocab_size, d_model, n_layers):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)

        self.encoder_layers = nn.ModuleList([
            EncoderBlock(d_model, d_ff=64)
            for _ in range(n_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(d_model, d_ff=64)
            for _ in range(n_layers)
        ])

        self.enc_ln_f = nn.LayerNorm(d_model)
        self.dec_ln_f = nn.LayerNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size)

    # separate encode/decode helpers (nice for generation)
    def encode(self, src_ids, src_mask=None):
        """Encode source ids (B, T_src) into encoder outputs (B, T_src, C).

        ``src_mask`` is passed to every encoder layer as its ``attn_mask``,
        so it must broadcast against (B, T_src, T_src) scores.
        """
        x = self.src_embed(src_ids)  # (B, T_src, C)
        for layer in self.encoder_layers:
            x = layer(x, attn_mask=src_mask)
        x = self.enc_ln_f(x)
        return x  # encoder outputs

    def decode(self, tgt_ids, enc_out, tgt_mask=None, enc_mask=None):
        """Decode target ids (B, T_tgt) against ``enc_out`` (B, T_src, C).

        ``tgt_mask`` is passed to every decoder layer as ``self_mask`` (added
        on top of the built-in causal mask) and ``enc_mask`` as the
        cross-attention mask over encoder positions.

        Returns:
            Logits of shape (B, T_tgt, vocab_size).
        """
        x = self.tgt_embed(tgt_ids)  # (B, T_tgt, C)
        for layer in self.decoder_layers:
            x = layer(x, enc_out, self_mask=tgt_mask, enc_mask=enc_mask)
        x = self.dec_ln_f(x)
        logits = self.out_head(x)    # (B, T_tgt, vocab_size)
        return logits

    def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None):
        """Encode ``src_ids`` and return decoder logits for ``tgt_ids``.

        ``src_mask`` is used both for encoder self-attention and for decoder
        cross-attention, so it should broadcast against (B, T_src, T_src) and
        (B, T_tgt, T_src) scores alike, e.g. shape (B, 1, T_src).
        """
        enc_out = self.encode(src_ids, src_mask=src_mask)
        logits = self.decode(tgt_ids, enc_out, tgt_mask=tgt_mask, enc_mask=src_mask)
        return logits


In [None]:
# Build the encoder-decoder toy model (32-dim embeddings, 2 encoder and
# 2 decoder layers); this rebinds `model`, replacing the TinyLM instance.
model = TinySeq2Seq(vocab_size=vocab_size, d_model=32, n_layers=2)