### 1. Dot Product Attention (Stateless)

- **Code task:** Implement a basic function for dot product attention that takes queries, keys, and values as inputs and returns the attention output (the attention-weighted sum of the values).

### 2. Scaled Dot Product Attention

- **Code task:** Modify the dot product attention function to divide the scores by the square root of the key dimension, which keeps the softmax from saturating (and its gradients from vanishing) as the dimension grows.
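
A small experiment can make the effect of the scaling visible. The sketch below is only an illustration under stated assumptions: queries and keys are random with unit variance, and `softmax_entropy` is a hypothetical helper, not part of the exercises. It compares how spread out the softmax weights are with and without the division by the square root of the key dimension.

```python
import torch

torch.manual_seed(0)


def softmax_entropy(dim: int, scale: bool) -> float:
    """Mean entropy of softmax attention weights for random unit-variance q and k."""
    q = torch.randn(256, dim)
    k = torch.randn(256, dim)
    scores = q @ k.T  # [256, 256]; each entry has variance of roughly `dim`
    if scale:
        scores = scores / dim**0.5  # bring the variance back to roughly 1
    weights = torch.softmax(scores, dim=-1)
    # Entropy near 0 means each query puts almost all weight on a single key.
    entropy = -(weights * torch.log(weights + 1e-12)).sum(dim=-1)
    return entropy.mean().item()


unscaled = softmax_entropy(512, scale=False)
scaled = softmax_entropy(512, scale=True)
print(f"mean entropy without scaling: {unscaled:.2f}, with scaling: {scaled:.2f}")
```

Without scaling, the weights collapse toward one-hot rows (low entropy); with scaling, they stay spread across many keys.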

### 3. Self-Attention (Trainable)

- **Code task:** Implement a trainable self-attention module that uses learnable weight matrices to transform queries, keys, and values.

### 4. Causal Self-Attention (Includes Mask Param)

- **Code task:** Extend the self-attention module to support causal masking, ensuring that information flow respects autoregressive constraints.

### 5. Multi-Head Attention

- **Code task:** Extend the causal self-attention module to support multiple heads: compute attention for each head, then combine the results.

### 6. Feedforward Network

- **Code task:** Implement a typical transformer feedforward network.

### 7. Trainable Transformer Block

- **Code task:** Construct a complete transformer block by combining the causal multi-head self-attention module with the feedforward network.

### 8. Residual Connections

- **Code task:** Incorporate residual connections into the transformer block to stabilize training and improve gradient flow.
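
One way to picture a residual connection is as a wrapper that adds a sublayer's output back onto its input, so the identity path always carries the gradient. The `Residual` class below is a hypothetical helper for illustration, not a prescribed solution to the task.

```python
import torch


class Residual(torch.nn.Module):
    """Wraps a sublayer f and computes x + f(x)."""

    def __init__(self, sublayer: torch.nn.Module) -> None:
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The sublayer must preserve the shape of x for the addition to work.
        return x + self.sublayer(x)


block = Residual(torch.nn.Linear(4, 4))
y = block(torch.randn(2, 3, 4))
print(y.shape)
```

If the sublayer outputs zeros, the wrapper reduces to the identity, which is why deep stacks of such blocks are easier to optimize from the start.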

### 9. Layer Normalization

- **Code task:** Add both layer normalization and RMS normalization to the transformer block, applying them in the appropriate sequence with residual connections. Ensure the implementations are modular with an initialization argument to select between the two (defaulting to RMS normalization), so either method can be easily selected and tested.
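
The selection between the two normalizations could be sketched as follows. This is a minimal illustration, not the required design: the names `RMSNorm`, `make_norm`, and `PreNormResidual` are hypothetical, and the pre-norm ordering (normalize, apply the sublayer, then add the residual) is an assumption, since the task leaves the ordering open.

```python
import torch


class RMSNorm(torch.nn.Module):
    """Rescales each vector by its root mean square; no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


def make_norm(dim: int, kind: str = "rms") -> torch.nn.Module:
    """Returns the requested normalization layer, defaulting to RMS normalization."""
    if kind == "rms":
        return RMSNorm(dim)
    if kind == "layer":
        return torch.nn.LayerNorm(dim)
    raise ValueError(f"Unknown norm kind {kind!r}; expected 'rms' or 'layer'")


class PreNormResidual(torch.nn.Module):
    """Computes x + sublayer(norm(x)); the pre-norm ordering is an assumption."""

    def __init__(self, dim: int, sublayer: torch.nn.Module, norm: str = "rms") -> None:
        super().__init__()
        self.norm = make_norm(dim, norm)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


layer = PreNormResidual(8, torch.nn.Linear(8, 8), norm="layer")
print(layer(torch.randn(2, 3, 8)).shape)
```

Keeping the choice in a single factory function means the transformer block only needs to pass one string through its constructor.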

### 10. Transformer Model (Decoder-Only)

- **Code task:** Implement a decoder-only transformer model with the following components:
  - An embedding layer to process input tokens.
  - A learnable positional encoding layer.
  - A stack of transformer blocks (as defined in previous steps).
  - A final linear layer to project outputs to the vocabulary size for logits generation.
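
The four components listed above might fit together as in the sketch below. To stay self-contained it uses a compact stand-in block, `TinyBlock`, rather than the modules built in the earlier steps. All class names, the `max_len` parameter, and the use of `torch.nn.LayerNorm` in a pre-norm layout are assumptions made for illustration.

```python
import torch


class TinyBlock(torch.nn.Module):
    """Stand-in pre-norm block: causal multi-head attention plus a feedforward net."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.proj = torch.nn.Linear(dim, dim, bias=False)
        self.norm1 = torch.nn.LayerNorm(dim)
        self.norm2 = torch.nn.LayerNorm(dim)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, dim = x.shape
        q, k, v = self.qkv(self.norm1(x)).chunk(3, dim=-1)

        def heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, seq_len, dim] -> [batch, num_heads, seq_len, head_dim]
            return t.reshape(batch, seq_len, self.num_heads, -1).transpose(1, 2)

        q, k, v = heads(q), heads(k), heads(v)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        # True strictly above the diagonal marks future positions to hide.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        scores = scores.masked_fill(future, float("-inf"))
        attn = torch.softmax(scores, dim=-1) @ v
        x = x + self.proj(attn.transpose(1, 2).reshape(batch, seq_len, dim))
        return x + self.ff(self.norm2(x))


class DecoderOnlyTransformer(torch.nn.Module):
    def __init__(
        self, vocab_size: int, max_len: int, dim: int, num_layers: int, num_heads: int
    ) -> None:
        super().__init__()
        self.tok_emb = torch.nn.Embedding(vocab_size, dim)  # token embedding
        self.pos_emb = torch.nn.Embedding(max_len, dim)  # learnable positions
        self.blocks = torch.nn.ModuleList(
            [TinyBlock(dim, num_heads) for _ in range(num_layers)]
        )
        self.lm_head = torch.nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_len = tokens.shape[-1]
        positions = torch.arange(seq_len, device=tokens.device)
        x = self.tok_emb(tokens) + self.pos_emb(positions)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(x)  # [batch, seq_len, vocab_size] logits


model = DecoderOnlyTransformer(vocab_size=11, max_len=16, dim=8, num_layers=2, num_heads=2)
logits = model(torch.randint(0, 11, (2, 5)))
print(logits.shape)
```

A quick sanity check for any such model is that changing the last input token leaves the logits at all earlier positions unchanged.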

### 11. Dropout Support

- **Code task:** Add support for dropout in the following locations of the transformer model:
  - Attention scores after the softmax operation in the attention mechanism.
  - The output of the feedforward network in each transformer block.
  - The output of the embedding layer, including positional encodings.
  - The output of the stacked transformer blocks before the final linear layer.
  - Ensure that dropout is configurable with a specified rate and is only applied during training.
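
As a rough illustration of the first location, the sketch below applies `torch.nn.Dropout` to the attention weights after the softmax. The class name `AttentionWithDropout` and its `dropout` argument are hypothetical. The other locations would presumably follow the same pattern, since `torch.nn.Dropout` is active only while the module is in training mode.

```python
import torch


class AttentionWithDropout(torch.nn.Module):
    """Scaled dot product attention with dropout on the post-softmax weights."""

    def __init__(self, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)  # configurable rate

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        weights = torch.softmax(scores, dim=-1)
        # Randomly zeroes weights in train mode; a no-op after .eval().
        weights = self.dropout(weights)
        return weights @ v


attn_drop = AttentionWithDropout(dropout=0.5)
q = k = v = torch.randn(8, 4)

attn_drop.eval()  # dropout disabled: repeated calls give identical outputs
same = torch.equal(attn_drop(q, k, v), attn_drop(q, k, v))

attn_drop.train()  # dropout enabled: outputs vary from call to call
print(same, attn_drop(q, k, v).shape)
```

Switching between `model.train()` and `model.eval()` toggles every dropout layer in the model at once, so no extra flag needs to be threaded through `forward`.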


In [43]:
import torch


def dot_product_attention(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Scaled dot product attention over the last two dimensions.

    A nonzero entry in `mask` blocks the corresponding (query, key) pair.
    """
    q_len, q_dim = queries.shape[-2:]
    k_len, k_dim = keys.shape[-2:]
    v_len, v_dim = values.shape[-2:]
    assert q_dim == k_dim, "Queries and keys must have the same embedding size"
    assert k_len == v_len, "Keys and values must be the same length"
    scores = queries @ keys.transpose(-2, -1)  # [...batch dimensions, q_len, k_len]
    scores = scores / k_dim**0.5  # counteract the growth of dot products with k_dim
    if mask is not None:
        assert mask.shape[-2:] == (
            q_len,
            k_len,
        ), "mask's last two dimensions must be equal to the query and key lengths respectively"
        # Push masked positions to a large negative value so softmax gives them ~0 weight.
        scores = scores + mask * -1e9
    weights = torch.softmax(scores, dim=-1)  # each query's weights sum to 1 over the keys
    return weights @ values  # [...batch dimensions, q_len, v_dim]


dot_product_attention(
    queries=torch.randn([2, 3]),
    keys=torch.randn([3, 3]),
    values=torch.randn([3, 4]),
    mask=torch.zeros([2, 3]),
).shape

torch.Size([2, 4])

In [35]:
class SelfAttention(torch.nn.Module):
    """Single-head self-attention with learnable query, key, and value projections."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.shape[-1] == self.dim
        ), f"Input x's last/innermost dimension {x.shape[-1]} does not match this self-attention module's embedding dim {self.dim}"
        # Queries, keys, and values are all projections of the same input x.
        return dot_product_attention(self.wq(x), self.wk(x), self.wv(x))


attn = SelfAttention(4)
# Call the module rather than .forward() so that any registered hooks run.
attn(torch.randn([2, 3, 4])).shape

torch.Size([2, 3, 4])

In [65]:
class CausalSelfAttention(torch.nn.Module):
    """Multi-head self-attention that applies a causal mask unless one is given."""

    def __init__(self, dim: int, num_heads: int = 1) -> None:
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"The number of heads {num_heads} must evenly divide the embedding dimension {dim}"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)
        self.dim = dim

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        *batch_dims, seq_len, dim = x.shape
        assert (
            dim == self.dim
        ), f"Input x's last/innermost dimension {dim} does not match this self-attention module's embedding dim {self.dim}"
        if mask is None:
            # Ones strictly above the diagonal mark the future positions to block.
            mask = torch.triu(
                torch.ones([seq_len, seq_len], dtype=x.dtype, device=x.device),
                diagonal=1,
            )

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [..., seq_len, dim] -> [..., num_heads, seq_len, head_dim]
            return t.view(
                *batch_dims, seq_len, self.num_heads, self.head_dim
            ).transpose(-3, -2)

        out = dot_product_attention(
            queries=split_heads(self.wq(x)),
            keys=split_heads(self.wk(x)),
            values=split_heads(self.wv(x)),
            mask=mask,
        )
        # Merge the heads back: [..., num_heads, seq_len, head_dim] -> [..., seq_len, dim]
        return out.transpose(-3, -2).reshape(*batch_dims, seq_len, dim)


attn = CausalSelfAttention(4, num_heads=2)
attn(torch.randn([2, 3, 4])).shape

torch.Size([2, 3, 4])

In [68]:
class FeedForward(torch.nn.Module):
    """Position-wise feedforward network: expand to 4x the width, apply GELU, project back."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.w_in = torch.nn.Linear(dim, dim * 4)
        self.w_out = torch.nn.Linear(dim * 4, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.w_in(x)
        x = torch.nn.functional.gelu(x)
        return self.w_out(x)


ff = FeedForward(4)
ff(torch.randn(2, 3, 4)).shape

torch.Size([2, 3, 4])