## Attention and Transformer

Note: make sure a GPU runtime is selected!

Install the required packages

In [None]:
%%capture
%pip install flax wandb tensorboardX tiktoken

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as skdata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import jax
import jax.numpy as jnp
import optax
import flax
from flax import linen as nn

### Preliminaries

**Embeddings**

Embeddings are a way to convert discrete tokens (like words or indices) into continuous vectors. Instead of representing a word as a huge, sparse one-hot vector (all zeros except a single one), an embedding maps it into a smaller, dense vector space where similar tokens can have similar representations.

**Why use embeddings?**

- **Efficiency:** One-hot vectors are high-dimensional and sparse (e.g., a vocabulary of 100 words gives a 100-dimensional vector, with 99 zeros).  
- **Expressiveness:** Dense vectors (say, 256 dimensions) can capture more nuanced relationships between words (similar words can end up with similar vectors).  
- **Learning:** Embeddings are learnable parameters. Instead of manually designing features, the model learns the best way to represent tokens for the task at hand.

**How does it work?**

Imagine you have a vocabulary of 100 tokens. A one-hot vector for any token is a vector of length 100 with a single 1 at the token’s index and 0 everywhere else. An embedding layer is essentially a lookup table (a matrix) of shape `(vocab_size, embedding_dim)`. When you "look up" a token, you use its one-hot vector to select the corresponding row from this matrix.

$$
\text{embedding} = \mathbf{onehot} \times E
$$

Because $\mathbf{onehot}$ has all zeros except a one at the token index, this effectively selects the row of $E$ corresponding to that token.

In [None]:
vocab_size = 10
embed_dim = 3

# Create a random embedding matrix of shape (10, 3)
# Will be learned during training!
embedding_matrix = np.random.randn(vocab_size, embed_dim)
print("Embedding Matrix:\n", embedding_matrix)

# Let's say our token index is 2
token_index = 2

# Create one-hot vector for token_index 2
one_hot = np.zeros(vocab_size)
one_hot[token_index] = 1
print("One-hot vector:\n", one_hot)

# Get embedding by dot product: (10,) . (10, 3) = (3,)
embedding = one_hot.dot(embedding_matrix)
print("Resulting embedding:\n", embedding)

# Alternatively, simply index the embedding matrix:
embedding_lookup = embedding_matrix[token_index]
print("Embedding via lookup:\n", embedding_lookup)
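
In practice the lookup is applied to whole batches of token sequences at once. The following is a minimal NumPy sketch of that, assuming a hypothetical batch of token ids (`token_ids`) and a randomly initialized table (`emb_table`) that stands in for learned weights. It also checks that indexing and the one-hot matrix product give the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 3

# Random stand-in for a learned embedding table (assumption for illustration)
emb_table = rng.normal(size=(vocab_size, embed_dim))

# Hypothetical batch: 2 sequences with 4 token ids each
token_ids = np.array([[2, 5, 5, 9],
                      [0, 1, 2, 3]])

# Integer-array indexing looks up every id at once: (2, 4) -> (2, 4, 3)
batch_embeddings = emb_table[token_ids]
print(batch_embeddings.shape)  # (2, 4, 3)

# Same result through explicit one-hot vectors and a matrix product
one_hot_batch = np.eye(vocab_size)[token_ids]   # (2, 4, 10)
via_matmul = one_hot_batch @ emb_table          # (2, 4, 3)
print(np.allclose(batch_embeddings, via_matmul))  # True
```

Indexing avoids materializing the one-hot tensor, which is why embedding layers are implemented as lookups rather than matrix products.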

**Multi-Head (Dot-Product) Attention**

See the original paper https://arxiv.org/abs/1706.03762

Why attention? Intuitively, processing the sequence word by word is not always optimal:

<img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F2d632b81-5bd6-432a-a456-37f20788be20_1180x614.png" width="500">

Attention instead gives direct access to all sequence elements at each time step!

At its core, attention works by:

1. **Creating Queries, Keys, and Values:**  
   For each token, we create three vectors:
   - **Query (Q):** Represents what information the token is looking for in the rest of the sequence.
   - **Key (K):** Represents the token’s characteristics that might be useful for other tokens.
   - **Value (V):** Contains the actual information of the token.
   
   In many cases, these vectors are created by applying different learned transformations to the original embedding.

2. **Calculating Attention Scores:**  
   For a given token, we want to figure out how much attention it should pay to every token in the sequence (including itself). We do this by comparing its query with the keys of all tokens:
   - **Dot Product:** We calculate a dot product between the query of the $i$-th token and the key of each token. This measures their similarity.
   - **Scaling:** The dot product is scaled (divided by the square root of the dimension of the vectors) to prevent the numbers from getting too large, which helps stabilize learning.

3. **Applying Softmax:**  
   The raw attention scores are then passed through a softmax function. This converts the scores into probabilities (or weights) that add up to 1. The softmax emphasizes the tokens with higher similarity scores, so the $i$-th token will "attend" more to those tokens.

4. **Aggregating Values:**  
   Finally, we use these attention weights to create a new representation for token $i$:
   - Multiply each token's value vector by its corresponding attention weight.
   - Sum up these weighted vectors.  
     
   This weighted sum is the attention output for token $i$—it’s a blend of information from all tokens, focused according to the attention weights.
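
In matrix form, with the queries, keys, and values of all tokens stacked as the rows of $Q$, $K$, and $V$, the four steps above collapse into the single formula from the paper cited above, where $d_k$ denotes the dimension of the key vectors:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

The softmax is applied row by row, so row $i$ of the result is the attention output for token $i$.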


<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6e1dcdeb-e096-4ff9-bdf9-3338e4efa4b4_1916x1048.png" width="800">



Images: https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention

In [None]:
import numpy as np

# ----- Step 1: Define a simple embedding matrix for 5 tokens -----
embedding_matrix = np.array([
    [0.1, 0.2, 0.3, 0.4],   # Token 0
    [0.5, 0.6, 0.7, 0.8],   # Token 1
    [0.9, 1.0, 1.1, 1.2],   # Token 2
    [1.3, 1.4, 1.5, 1.6],   # Token 3
    [1.7, 1.8, 1.9, 2.0]    # Token 4
])

# Assume our sequence consists of tokens: [1, 2, 3]
sequence_ids = [1, 2, 3]

embeddings = embedding_matrix[sequence_ids]  # Shape: (num_tokens, embed_dim)

# For simplicity, we use the embeddings directly as Q, K, and V
Q = embeddings.copy()  # Queries for all tokens
K = embeddings.copy()  # Keys for all tokens
V = embeddings.copy()  # Values for all tokens

# Let's focus on computing the attention for token 2 (which is at index 1 in our sequence)
target_index = 1  # This corresponds to token 2

# ----- Step 2: Compute the scaled dot product scores -----
embed_dim = embeddings.shape[1]
num_tokens = embeddings.shape[0]
scores = np.zeros(num_tokens)

# Compute dot product between Q[target_index] and every key K[j]
for j in range(num_tokens):
    dot_product = 0.0
    for i in range(embed_dim):
        dot_product += Q[target_index, i] * K[j, i]

    # Scale the score by the square root of the embedding dimension for stability
    scores[j] = dot_product / np.sqrt(embed_dim)

print("Raw attention scores for token 2:", scores)

# ----- Step 3: Apply Softmax to get attention weights -----
def softmax(x):
    # Subtracting the max for numerical stability
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

attn_weights = softmax(scores)
print("Attention weights for token 2:", attn_weights)

# ----- Step 4: Compute the attention output for token 2 -----
# This is the weighted sum of the value vectors from all tokens
attention_output = np.zeros(embed_dim)
for j in range(num_tokens):
    for i in range(embed_dim):
        attention_output[i] += attn_weights[j] * V[j, i]

print("Attention output for token 2:", attention_output)
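
The loops above handle a single query token. As a rough sketch, the same computation can be written with matrix products for all tokens at once. The helper names (`softmax_rows`, `attention`) and the random projection matrices `W_q`, `W_k`, `W_v` are assumptions for illustration; in a real model the projections are learned.

```python
import numpy as np

def softmax_rows(x):
    # Row-wise softmax; subtracting the row max keeps exp() numerically stable
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention for all tokens at once
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (num_tokens, num_tokens)
    weights = softmax_rows(scores)    # each row sums to 1
    return weights @ V, weights

# The same three embeddings as in the loop version (tokens 1, 2, 3)
X = np.array([
    [0.5, 0.6, 0.7, 0.8],
    [0.9, 1.0, 1.1, 1.2],
    [1.3, 1.4, 1.5, 1.6],
])

# Case 1: Q = K = V = X reproduces the loop-based result (row 1 is token 2)
out_plain, w_plain = attention(X, X, X)
print("Attention weights for token 2:", w_plain[1])
print("Attention output for token 2:", out_plain[1])

# Case 2: separate projections for Q, K, V (random stand-ins for learned weights)
rng = np.random.default_rng(0)
d_model, d_head = X.shape[1], 2
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))
W_v = rng.normal(size=(d_model, d_head))
out_proj, w_proj = attention(X @ W_q, X @ W_k, X @ W_v)
print(out_proj.shape)  # (3, 2)
```

With projections, the output dimension is set by `W_v` rather than by the embedding size, which is what lets several smaller heads run side by side.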

### Mini Transformer in Flax

In [None]:
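# Causal (lower-triangular) mask: row i has ones in columns 0..i,
# so position i may attend only to itself and earlier positions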
jnp.tril(jnp.ones((10, 10)))
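
To see what such a mask does to the attention weights, here is a small NumPy sketch. The random `demo_scores` stand in for real scaled dot products, and the fill value `-1e9` is an assumption for illustration; libraries differ in the exact constant they use.

```python
import numpy as np

seq_len = 4
rng = np.random.default_rng(0)

# Stand-in for the scaled scores Q @ K.T / sqrt(d_k)
demo_scores = rng.normal(size=(seq_len, seq_len))

# Lower-triangular mask: query i (row) may see keys 0..i (columns)
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

# Masked-out scores get a very negative value, so softmax assigns them zero weight
masked_scores = np.where(causal_mask, demo_scores, -1e9)

# Row-wise softmax
masked_scores = masked_scores - masked_scores.max(axis=-1, keepdims=True)
causal_weights = np.exp(masked_scores)
causal_weights = causal_weights / causal_weights.sum(axis=-1, keepdims=True)

print(np.round(causal_weights, 3))
```

Every row still sums to 1, but all weight above the diagonal is zero, so no position can look at later tokens. The first row puts all its weight on the first token.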

<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Feb61939e-e9ae-416f-8b75-9ed808be0782_1456x1392.png" width="600">

Image: https://magazine.sebastianraschka.com/p/building-a-gpt-style-llm-classifier




In [None]:
class NanoLM(nn.Module):
    vocab_size: int
    num_layers: int = 6
    num_heads: int = 8
    head_size: int = 32
    dropout_rate: float = 0.2
    embed_size: int = 256
    block_size: int = 64

    @nn.compact
    def __call__(self, x, training: bool = True):
        # x: (8, 10) -- batch of 8 sequences, each of length 10 (token indices)

        seq_len = x.shape[1] # seq_len = 10

        # from index in the vocab to a dense vector
        # Output: (8, 10, 256)
        x = nn.Embed(self.vocab_size, self.embed_size)(x)

        # positional embedding
        # Output: (10, 256), then broadcast to (8, 10, 256) for addition
        x = x + nn.Embed(self.block_size, self.embed_size)(jnp.arange(seq_len))

        # for N layers
        for _ in range(self.num_layers):
            # Pre-layer normalization:
            # nn.LayerNorm()(x) normalizes each token's 256-d vector
            # Shape remains (8, 10, 256)
            x_norm = nn.LayerNorm()(x)

            # Self-Attention Block with Residual Connection
            # - Input: x_norm (8, 10, 256)
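            # - In Flax, qkv_features is the total width of the Q/K/V projections,
            #   split evenly across heads: here 32 / 8 heads = 4 dims per head
            #   (32 dims per head would need qkv_features = num_heads * head_size)
            # - out_features = 8 * 32 = 256 projects the merged heads back to the
            #   embedding size, so the output has shape (8, 10, 256)
            # - The causal mask ensures each position attends only to itself and earlier positions
            # - The output shape matches x, so the block can be repeated num_layers times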
            attn_out = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                qkv_features=self.head_size,
                out_features=self.head_size * self.num_heads,
                dropout_rate=self.dropout_rate,
            )(
                x_norm,  # queries: (8, 10, 256)
                x_norm,  # keys (also used as values): (8, 10, 256)
                mask=jnp.tril(jnp.ones((x.shape[-2], x.shape[-2]))),  # (10, 10)
                deterministic=not training,
            )

            # Residual connection: add attention output to original x
            # Shape remains (8, 10, 256)
            x = x + attn_out

            # Feedforward (MLP) Block with Residual Connection
            # extra "x +" helps with gradient flow and retains information from earlier layers
            x = x + nn.Sequential([
                nn.Dense(4 * self.embed_size), # Expand: (8, 10, 256) -> (8, 10, 1024)
                nn.relu, # Activation: (8, 10, 1024)
                nn.Dropout(self.dropout_rate, deterministic=not training), # (8, 10, 1024)
                nn.Dense(self.embed_size), # Project: (8, 10, 1024) -> (8, 10, 256)
            ])(nn.LayerNorm()(x))

        # Final Layer Normalization:
        # Normalizes final representation at each token
        # Shape: (8, 10, 256)
        x = nn.LayerNorm()(x)

        # Output Projection:
        # Projects each 256-d token representation to logits over the vocabulary
        # Dense layer: (8, 10, 256) -> (8, 10, vocab_size) = (8, 10, 100)
        return nn.Dense(self.vocab_size)(x)
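
In Flax's `nn.MultiHeadDotProductAttention`, `qkv_features` is the total width of the query/key/value projections and is divided evenly across `num_heads`. The NumPy sketch below illustrates that head splitting with small, made-up sizes. It is not Flax's implementation: biases and dropout are omitted, and the function name and random weights are assumptions for illustration.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads, causal=True):
    B, T, _ = x.shape
    qkv_features = W_q.shape[1]
    head_dim = qkv_features // num_heads

    def split_heads(t):
        # (B, T, qkv_features) -> (B, num_heads, T, head_dim)
        return t.reshape(B, T, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)

    # Scaled dot-product attention, computed independently for each head
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # (B, H, T, T)
    if causal:
        mask = np.tril(np.ones((T, T), dtype=bool))
        scores = np.where(mask, scores, -1e9)
    weights = softmax(scores)

    heads = weights @ v                                         # (B, H, T, head_dim)
    merged = heads.transpose(0, 2, 1, 3).reshape(B, T, qkv_features)
    return merged @ W_o                                         # (B, T, out_features)

rng = np.random.default_rng(0)
B, T, embed_size = 2, 5, 16
num_heads, qkv_features = 4, 8          # head_dim = 8 // 4 = 2

x_demo = rng.normal(size=(B, T, embed_size))
W_q = rng.normal(size=(embed_size, qkv_features))
W_k = rng.normal(size=(embed_size, qkv_features))
W_v = rng.normal(size=(embed_size, qkv_features))
W_o = rng.normal(size=(qkv_features, embed_size))

out = multi_head_attention(x_demo, W_q, W_k, W_v, W_o, num_heads)
print(out.shape)  # (2, 5, 16)
```

Because the output projection maps back to the embedding size, the result can be added to the residual stream regardless of how narrow the individual heads are.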

Initialize the model and run inference

In [None]:
# Model initialization
key = jax.random.PRNGKey(1337)
mini_transformer = NanoLM(vocab_size=100)

# Example input: batch of token sequences (batch_size=8, seq_len=10)
x = jnp.ones((8, 10), dtype=jnp.int32)

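# Initialize parameters (training=False keeps dropout off, so no "dropout" RNG is needed)
params = mini_transformer.init(key, x, training=False)

# Forward pass in inference mode (dropout disabled)
y = mini_transformer.apply(params, x, training=False)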

# The model predicts the next token at every (!) position, not only at the end.
# Training on ground-truth inputs this way is known as "teacher forcing":
# it maximizes the likelihood of the entire sequence.
# If we only produced an output for the last token,
# we'd lose valuable learning signals from every intermediate position.
y.shape
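
Predicting the next token at every position means the training targets are simply the inputs shifted by one. A minimal NumPy sketch of that setup follows, assuming a made-up token sequence and random `demo_logits` in place of real model output.

```python
import numpy as np

# Hypothetical token ids for one training sequence
tokens = np.array([5, 1, 4, 4, 2, 7])

# Inputs are all tokens but the last; targets are the same tokens shifted by one
inputs, targets = tokens[:-1], tokens[1:]
print(inputs)   # [5 1 4 4 2]
print(targets)  # [1 4 4 2 7]

demo_vocab = 10
rng = np.random.default_rng(0)

# Stand-in for the model output: one row of logits per input position
demo_logits = rng.normal(size=(len(inputs), demo_vocab))

# Numerically stable log-softmax over the vocabulary
shifted = demo_logits - demo_logits.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Negative log-likelihood of the true next token at each position
per_position_loss = -log_probs[np.arange(len(targets)), targets]
loss = per_position_loss.mean()
print(per_position_loss.shape)  # (5,)
```

Every position contributes its own loss term, so a sequence of length 5 yields five training signals instead of one.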

In [None]:
def count_params(params):
    leaves = jax.tree_util.tree_leaves(params)
    return sum(x.size for x in leaves)

n_params = count_params(params)
print(f"Total number of parameters: {n_params:,}")

In [None]:
def params_size_in_bytes(params):
    leaves = jax.tree_util.tree_leaves(params)
    total_bytes = sum([x.size * x.dtype.itemsize for x in leaves])
    return total_bytes

size_bytes = params_size_in_bytes(params)
print("Total parameters size: {:.2f} MB".format(size_bytes / (1024 ** 2)))

This makes sense: with 3,426,468 parameters stored as 32-bit (4-byte) floats, and 1 MB taken as $1024^2$ bytes as in the code above, the total parameter memory is:

$$
3{,}426{,}468 \times 4 \text{ bytes} = 13{,}705{,}872 \text{ bytes} \approx 13.07 \text{ MB}
$$
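
The parameter count itself can be reproduced by hand. The sketch below assumes Flax's default layer shapes (kernel plus bias for every dense projection, scale plus bias for every LayerNorm) and the hyperparameters of the model above, including `qkv_features = head_size = 32`. The helper `dense` is a hypothetical name used only here.

```python
vocab_size, embed_size, block_size = 100, 256, 64
num_layers, qkv_features = 6, 32   # qkv_features = head_size in the model above

def dense(n_in, n_out):
    # Parameters of a dense layer: kernel + bias
    return n_in * n_out + n_out

token_embed = vocab_size * embed_size   # 25,600
pos_embed = block_size * embed_size     # 16,384
layer_norm = 2 * embed_size             # scale + bias = 512

# Attention: Q, K, V projections plus the output projection
attn_params = 3 * dense(embed_size, qkv_features) + dense(qkv_features, embed_size)
# MLP: expand to 4x the embedding size, then project back
mlp_params = dense(embed_size, 4 * embed_size) + dense(4 * embed_size, embed_size)
per_layer = 2 * layer_norm + attn_params + mlp_params

total = (token_embed + pos_embed
         + num_layers * per_layer
         + layer_norm                      # final LayerNorm
         + dense(embed_size, vocab_size))  # output projection

print(f"{total:,}")                     # 3,426,468
print(f"{total * 4 / 1024**2:.2f} MB")  # 13.07 MB
```

Most of the parameters (525,568 of the 559,712 per layer) sit in the feedforward block, not in attention.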

**Summary**:

<img src="https://pbs.twimg.com/media/GCnZNRraAAE9HAx?format=png&name=small" width="600">

Source: https://x.com/srush_nlp/status/1741161984928920027