# Lab

In [1]:
%load_ext watermark
%watermark -v -p numpy,pandas,polars,omegaconf --conda

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

numpy    : 1.26.4
pandas   : 2.2.1
polars   : 0.20.18
omegaconf: 2.3.0

conda environment: torch_p11



In [2]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

In [4]:
seed: int = 123

GPT_CONFIG_124M: dict[str, Any] = {
    "vocab_size": 50_257,
    "context_length": 1_024,
    "emb_dim": 768,
    "n_heads": 12,  # Number of attention heads
    "n_layers": 12,
    "drop_rate": 0.1,  # Dropout rate
    "qkv_bias": False,
}

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, qkv_bias: bool = False) -> None:
        super().__init__()

        # Size: (seq_len, emb_dim)
        self.query_weights = nn.Linear(in_feats, out_feats, bias=qkv_bias)
        self.key_weights = nn.Linear(in_feats, out_feats, bias=qkv_bias)
        self.value_weights = nn.Linear(in_feats, out_feats, bias=qkv_bias)

    def forward(self, x: Tensor) -> Tensor:
        # b_size, seq_len, emb_dim = x.shape
        # (b_size, emb_dim, seq_len) @ (seq_len, emb_dim) -> (b_size, emb_dim, emb_dim)
        query = self.query_weights(x)
        key = self.key_weights(x)
        value = self.value_weights(x)

        # Attention scores
        # (b_size, emb_dim, seq_len) @ (seq_len, emb_dim) -> (b_size, emb_dim, emb_dim)
        attn_scores: Tensor = torch.matmul(query, key.transpose(-1, -2))
        attn_weights: Tensor = F.softmax(attn_scores / key.shape[1] ** 0.5, dim=-1)
        # (seq_len, emb_dim) @ (b_size, emb_dim, emb_dim) -> (b_size, seq_len, emb_dim)
        context_vector: Tensor = torch.matmul(attn_weights, value)
        return context_vector

In [6]:
vocab_size: int = 27
embedding_dim: int = 16
context_size: int = 8
batch_size: int = 2

input_seq: Tensor = torch.rand(
    size=(batch_size, context_size, embedding_dim), dtype=torch.float32
)
self_attn: SelfAttention = SelfAttention(embedding_dim, embedding_dim)
context_vector: Tensor = self_attn(input_seq)
context_vector

tensor([[[-0.4171,  0.3059,  0.0703,  0.5011, -0.3046, -0.0724,  0.0533,
           0.2701,  0.3423, -0.1482,  0.1561, -0.1429, -0.1945, -0.3114,
           0.1051, -0.4797],
         [-0.4133,  0.3059,  0.0693,  0.4982, -0.3075, -0.0707,  0.0502,
           0.2698,  0.3414, -0.1518,  0.1533, -0.1413, -0.1940, -0.3147,
           0.1037, -0.4768],
         [-0.4148,  0.3054,  0.0728,  0.5028, -0.3058, -0.0762,  0.0506,
           0.2726,  0.3429, -0.1534,  0.1544, -0.1431, -0.1975, -0.3136,
           0.1112, -0.4802],
         [-0.4140,  0.3058,  0.0721,  0.5005, -0.3061, -0.0759,  0.0493,
           0.2709,  0.3421, -0.1531,  0.1538, -0.1436, -0.1974, -0.3152,
           0.1090, -0.4789],
         [-0.4137,  0.3057,  0.0704,  0.4968, -0.3051, -0.0744,  0.0506,
           0.2679,  0.3405, -0.1492,  0.1553, -0.1450, -0.1951, -0.3144,
           0.1035, -0.4777],
         [-0.4123,  0.3047,  0.0694,  0.4985, -0.3070, -0.0703,  0.0537,
           0.2725,  0.3402, -0.1508,  0.1551, -0.141

In [7]:
class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        context_size: int,
        dropout_pct: float = 0.0,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()

        # Size: (seq_len, emb_dim)
        self.query_weights = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.key_weights = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.value_weights = nn.Linear(d_model, d_model, bias=qkv_bias)

        self.register_buffer(
            "mask", torch.triu(torch.ones(context_size, context_size), diagonal=1)
        )
        self.dropout = nn.Dropout(p=dropout_pct)

    def forward(self, x: Tensor) -> Tensor:
        b_size, seq_len, emb_dim = x.shape
        # (b_size, emb_dim, seq_len) @ (seq_len, emb_dim) -> (b_size, emb_dim, emb_dim)
        query = self.query_weights(x)
        key = self.key_weights(x)
        value = self.value_weights(x)

        # Attention scores
        # (b_size, emb_dim, seq_len) @ (seq_len, emb_dim) -> (b_size, emb_dim, emb_dim)
        attn_scores: Tensor = torch.matmul(query, key.transpose(-1, -2))
        # Apply mask (inplace). The slicing ensures that the seq_len is consistent across the batch.
        attn_scores.masked_fill_(self.mask.bool()[:seq_len, :seq_len], -torch.inf)

        attn_weights: Tensor = F.softmax(attn_scores / key.shape[1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # (seq_len, emb_dim) @ (b_size, emb_dim, emb_dim) -> (b_size, seq_len, emb_dim)
        context_vector: Tensor = torch.matmul(attn_weights, value)
        return context_vector

In [8]:
torch.manual_seed(seed)

input_seq: Tensor = torch.rand(
    size=(batch_size, context_size, embedding_dim), dtype=torch.float32
)
causal_self_attn: CausalSelfAttention = CausalSelfAttention(
    d_model=embedding_dim, context_size=context_size, dropout_pct=0.1
)
context_vector: Tensor = causal_self_attn(input_seq)
context_vector.shape

torch.Size([2, 8, 16])

In [9]:
class SequentialMultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        context_size: int,
        num_heads: int,
        dropout: float = 0.0,
        qkv_bias: bool = False,
    ):
        super().__init__()

        self.heads = nn.ModuleList(
            [
                CausalSelfAttention(d_model, context_size, dropout, qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        # Concat along the feature (emb) dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [10]:
torch.manual_seed(seed)

multi_head_attn: SequentialMultiHeadAttention = SequentialMultiHeadAttention(
    d_model=embedding_dim,
    context_size=context_size,
    num_heads=3,
    dropout=0.1,
)
print(f"{input_seq.shape = }")
print(f"{multi_head_attn = }")
output: Tensor = multi_head_attn(input_seq)
print(f"{output.shape = }")

input_seq.shape = torch.Size([2, 8, 16])
multi_head_attn = SequentialMultiHeadAttention(
  (heads): ModuleList(
    (0-2): 3 x CausalSelfAttention(
      (query_weights): Linear(in_features=16, out_features=16, bias=False)
      (key_weights): Linear(in_features=16, out_features=16, bias=False)
      (value_weights): Linear(in_features=16, out_features=16, bias=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
output.shape = torch.Size([2, 8, 48])


<hr><br><br>

### Multi-head Attention

- Instead of relying on a `single attention mechanism`, `multi-head attention` uses multiple "`heads`" that work in parallel.
- Each `head` analyzes the input sequence from a `slightly different perspective`.
- These individual analyses are then `combined` (concatenated) to create a `richer understanding` of the relationships between elements in the sequence.

#### Here's a breakdown of the key points with clarification:

- **`Causal self-attention`**: This refers to a type of attention where an element in the sequence only attends to the elements that come before it in the sequence.

- **`Multiple heads in parallel`**: The core concept of `Multi-head Attention`. Instead of one attention mechanism, multiple "heads" analyze the data simultaneously.

- **`Input sequence split and processed`**: Each head gets a portion of the original input data (`d_model`) based on the number of heads (`num_heads`). This creates a lower dimension for each head (`head_dim`) for processing.

- **`Concatenation`**: After each head analyzes its portion of the data, the results are combined (concatenated) to create a richer representation that captures insights from all the heads.
  - E.g. 
    - With a `d_model` of 64 (original input has 64 features) and 4 heads, each head gets 16 dimensions (features) to process (64 / 4). 
    - These 4 heads analyze the data in `parallel`, and then their outputs are `combined` to create a `final representation` with potentially deeper understanding than a single head could achieve.

In [11]:
class MultiHeadAttention(nn.Module):
    """
    A Multi-Head Attention layer for use in neural network architectures.

    Args:
        d_model (int): The dimension of the input and output features.
        context_size (int): The size of the context window (neighborhood considered for attention).
        num_heads (int): The number of heads used in the Multi-Head Attention.
        dropout_pct (float, optional): The dropout probability for the attention weights. Defaults to 0.1.
        qkv_bias (bool, optional): Whether to add bias terms to the linear transformations for queries, keys,
        and values. Defaults to False.

    Raises:
        AssertionError: If `d_model` is not divisible by `num_heads`.

    Shapes:
        - Input: (batch_size, seq_len, d_model)
        - Output: (batch_size, seq_len, d_model)

    Note:
        B, T, C: (batch, seq_len, d_model)

    Example:
        >>> import torch
        >>> model = MultiHeadAttention(d_model=512, context_size=32, num_heads=8)
        >>> input_tensor = torch.randn(16, 100, 512)
        >>> output_tensor = model(input_tensor)
        >>> print(output_tensor.shape)
        torch.Size([16, 100, 512])
    """

    def __init__(
        self,
        d_model: int,
        context_size: int,
        num_heads: int,
        dropout_pct: float = 0.1,
        qkv_bias: bool = False,
    ):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # Dim of each head
        self.context_size = context_size
        self.dropout = nn.Dropout(dropout_pct)

        self.query_W = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.key_W = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.value_W = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, x: Tensor) -> Tensor:
        """Split the features at each head by reshaping and transposing them.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, num_heads, seq_length, head_dim).
        """
        # B, T, C
        batch_size, seq_len, _ = x.size()

        # After transposing: (B, n_heads, T, h_dim)
        x_split: Tensor = x.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        return x_split

    def _concat_heads(self, x: Tensor) -> Tensor:
        """
        Concatenates the heads of the input tensor along the last dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (B, n_heads, T, h_dim).

        Returns:
            torch.Tensor: Concatenated tensor of shape (B, T, n_heads * h_dim).
        """
        B, n_heads, T, h_dim = x.size()
        # After transposing: (B, T, n_heads * h_dim)
        # self.d_model = n_heads * h_dim
        x_concat: Tensor = x.transpose(1, 2).contiguous().view(B, T, (n_heads * h_dim))
        return x_concat

    def forward(self, x: Tensor) -> Tensor:
        # Compute the query, key and value features
        # (B, T, C) @ (C, C) -> (B, T, C)
        queries: Tensor = self.query_W(x)  # (B, T, C)
        keys: Tensor = self.key_W(x)  # (B, T, C)
        values: Tensor = self.value_W(x)  # (B, T, C)

        # Split the features
        # C = n_heads * h_dim
        # (B, T, C) -> (B, n_heads, T, h_dim)
        queries = self._split_heads(queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        # Calculate the attention
        # (B, n_heads, T, h_dim) @ (B, n_heads, h_dim, T) -> (B, n_heads, T, T)
        attn_scores: Tensor = queries @ keys.transpose(-1, -2)  # (B, n_heads, T, T)
        attn_weights: Tensor = F.softmax(attn_scores / keys.shape[-1], dim=-1)
        attn_weights = self.dropout(attn_weights)  # (B, n_heads, T, T)

        # (B, n_heads, T, T) @ (B, n_heads, T, h_dim) -> (B, n_heads, T, h_dim)
        context_vectors: Tensor = attn_weights @ values  # (B, n_heads, T, h_dim)
        # Concatenate the attention and the features
        context_vectors = self._concat_heads(context_vectors)  # (B, T, n_heads * h_dim)
        # (B, T, C) @ (C, C) -> (B, T, C)
        context_vectors = self.out_proj(context_vectors)  # (B, T, C)
        return context_vectors

In [12]:
# (B, T, D) @ (D, D) -> (B, T, D)
torch.manual_seed(seed)

multi_head_attn: MultiHeadAttention = MultiHeadAttention(
    d_model=embedding_dim,
    context_size=context_size,
    num_heads=2,
    dropout_pct=0.1,
)
print(f"{input_seq.shape = }")
print(f"{multi_head_attn = }")
output: Tensor = multi_head_attn(input_seq)
print(f"{output.shape = }")

output

input_seq.shape = torch.Size([2, 8, 16])
multi_head_attn = MultiHeadAttention(
  (dropout): Dropout(p=0.1, inplace=False)
  (query_W): Linear(in_features=16, out_features=16, bias=False)
  (key_W): Linear(in_features=16, out_features=16, bias=False)
  (value_W): Linear(in_features=16, out_features=16, bias=False)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)
output.shape = torch.Size([2, 8, 16])


tensor([[[-0.3011, -0.4950, -0.2240, -0.1188,  0.0195,  0.1083,  0.2722,
           0.2290,  0.1635, -0.0806, -0.0211, -0.2224, -0.0863,  0.0350,
          -0.3421, -0.4037],
         [-0.2906, -0.5310, -0.2491, -0.1227,  0.0060,  0.1080,  0.2924,
           0.2448,  0.1576, -0.0843, -0.0560, -0.2239, -0.1180,  0.0230,
          -0.3593, -0.4231],
         [-0.2966, -0.4273, -0.1875, -0.1295,  0.0417,  0.1205,  0.2007,
           0.2339,  0.1815, -0.0691,  0.0205, -0.1844, -0.0510,  0.0703,
          -0.3123, -0.3564],
         [-0.2888, -0.5329, -0.2493, -0.1240,  0.0080,  0.1064,  0.2928,
           0.2439,  0.1564, -0.0811, -0.0586, -0.2241, -0.1182,  0.0230,
          -0.3588, -0.4239],
         [-0.3018, -0.4952, -0.2254, -0.1186,  0.0177,  0.1094,  0.2727,
           0.2299,  0.1639, -0.0833, -0.0211, -0.2217, -0.0873,  0.0346,
          -0.3431, -0.4041],
         [-0.2925, -0.4761, -0.1932, -0.1313,  0.0508,  0.1081,  0.2533,
           0.1996,  0.1579, -0.0480, -0.0030, -0.225