<a href="https://colab.research.google.com/github/million-in/MLA/blob/main/Multi_Head_Latent_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is an implementation of Multi head Latent Attention. A solution to dealing with KV cache.

This implementation follows just one single attention block for simplicty. and we use GPT-3 configuration. D_MODEL, D_K.

Multi head Latent Attention is mainly projecting hidden state of tokens into a row rank dimensional space or, a latent representation. wich is compressed. And then having to learners that will build Key and Value outputs for the tokens from the compressed matrix.

MLA simply comes down two three steps,

1. Low rank key-value joint compression into a latent representation.
2. Applying decouple RoPe
3. Backward projection to a higher dimension space for key and values from the latent representation.


In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.core import freeze, unfreeze

In [2]:
#  Define and initialize the embedding layer
class EmbeddingLayer(nn.Module):
    num_embeddings: int
    features: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.num_embeddings, features=self.features)

    def __call__(self, x):
        return self.embedding(x)

# Define model
D_MODEL = 12288  # Embedding dimension
NUM_HEADS = 96   # Number of attention heads
D_K = D_MODEL // NUM_HEADS  # Head dimension

model = EmbeddingLayer(num_embeddings=50257, features=D_MODEL)

# Initialize the model
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.array([0]))  # Initialize with a dummy input

# Print initialized embedding weights
print("Embedding layer initialized:")
print(params)

Embedding layer initialized:
{'params': {'embedding': {'embedding': Array([[ 6.6024400e-03,  1.5081277e-02, -1.1323430e-02, ...,
        -4.9526279e-05,  4.3605505e-03, -2.2229243e-03],
       [-8.1163282e-03,  6.7466344e-03, -3.4407864e-03, ...,
        -9.2547555e-03,  8.8060035e-05,  3.5618185e-03],
       [-7.9459362e-03, -8.6927693e-03,  9.2485854e-03, ...,
        -8.3842678e-03,  1.8406912e-03, -5.2417803e-04],
       ...,
       [-4.3475558e-03,  4.9987296e-03,  6.4304932e-03, ...,
         2.5832008e-03,  2.6981889e-03, -3.2927021e-03],
       [ 7.2161956e-03,  1.2845425e-02, -9.2969723e-03, ...,
        -4.9834945e-03, -2.9959911e-03,  9.5217573e-03],
       [-1.1894217e-02,  5.3528803e-03,  1.3001672e-02, ...,
         5.9268093e-03, -2.5876894e-04, -4.3124328e-03]], dtype=float32)}}}


In [3]:

# Test the model with a dummy set of tokens
dummy_tokens = jnp.arange(1024) % 50257  # Generate a sequence of dummy tokens
embedded_tokens = model.apply(params, dummy_tokens)

print("Test output (embedded tokens shape):", embedded_tokens.shape)
print("First embedded vector:", embedded_tokens[0])
print("Last embedded vector:", embedded_tokens[-1])


Test output (embedded tokens shape): (1024, 12288)
First embedded vector: [ 6.6024400e-03  1.5081277e-02 -1.1323430e-02 ... -4.9526279e-05
  4.3605505e-03 -2.2229243e-03]
Last embedded vector: [-0.00201211 -0.02040401  0.01133402 ...  0.00343121 -0.00958229
 -0.01465543]


In [4]:
# Compute Layer Normalization with learnable parameters
def layer_norm(x, gamma, beta, epsilon=1e-5):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.var(x, axis=-1, keepdims=True)
    normalized_x = (x - mean) / jnp.sqrt(variance + epsilon)
    return normalized_x * gamma + beta

# Initialize learnable parameters
gamma = jnp.ones((1, D_MODEL))  # Scale parameter
beta = jnp.zeros((1, D_MODEL))  # Shift parameter

normalized_embeddings = layer_norm(embedded_tokens, gamma, beta)

print("First normalized vector:", normalized_embeddings[0])
print("Last normalized vector:", normalized_embeddings[-1])

First normalized vector: [ 0.7045637   1.5899293  -1.1672673  ...  0.00996106  0.47046414
 -0.21698657]
Last normalized vector: [-0.20973752 -2.1317196   1.1849535  ...  0.3590983  -1.0008328
 -1.5309834 ]


Now here we apply the first phase of our MLA. Projecting the hidden states (here normalized embeddings are acting as hidden states since this is the first layer) into a row rank dimension space. The dimension as you can see is 4 times smaller than the dimension of the model ie REDUCED_D_MODEL = D_MODEL // 4.

We do this by introducing a new downward projection matrix W(c) as parameters. to form C




In [5]:
# now here we implement mla starting from only cell 4 of layer norm
# Project normalized output to a lower-rank latent space for compression

def project_to_low_rank(x, projection_matrix):
    """Projects the input to a lower-rank latent space.

    Args:
      x: The input tensor.
      projection_matrix: The matrix used for projection.

    Returns:
      The projected tensor.
    """
    return jnp.matmul(x, projection_matrix)

# Define reduced dimension
REDUCED_D_MODEL = D_MODEL // 4

# Initialize the projection matrix
projection_key = jax.random.PRNGKey(1)
projection_matrix = jax.random.normal(projection_key, (D_MODEL, REDUCED_D_MODEL)) * jnp.sqrt(2 / (D_MODEL + REDUCED_D_MODEL)) #Glorot initialization


# Project the normalized embeddings
projected_embeddings = project_to_low_rank(normalized_embeddings, projection_matrix)

# Print the shape of the output
print("Shape of projected embeddings:", projected_embeddings.shape)
print("First projected vector:", projected_embeddings[0])
print("Last projected vector:", projected_embeddings[-1])

Shape of projected embeddings: (1024, 3072)
First projected vector: [ 2.8534365  -0.5653596   0.91363555 ... -0.28176805 -1.0861151
 -0.5680051 ]
Last projected vector: [-0.99326026 -0.58047986  0.37042296 ...  2.9131997  -0.23583116
 -1.5297287 ]


Now we Apply RoPe on half of the latent representation. This is because we definetly can not apply RoPe on the whole latent vector because that would encode only position information, while we will also use this to compute value

In [6]:
# Apply RoPE to a portion of the projected embeddings

def apply_rope(x, freqs_cis):
    """Applies Rotary Positional Embeddings (RoPE) to the input.

    Args:
        x: The input tensor to apply RoPE to.  Shape: (seq_len, dim)
        freqs_cis: Precomputed complex exponentials. Shape: (seq_len, dim // 2)
                   This contains the values  exp(i * m * theta_t)  for
                   positions m and frequencies 1 / (10000^(2t/d)).
    Returns:
      The input tensor with RoPE applied.
    """
    x_complex = x.astype(jnp.complex64)
    x_rotated = x_complex * freqs_cis
    return x_rotated.real.astype(x.dtype)  # Return real part, same dtype as input

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Precomputes the complex exponentials for RoPE.

    Args:
        dim: The dimension of the embeddings (must be even for RoPE).
        seq_len: The maximum sequence length.
        theta: The base for the geometric progression.

    Returns:
        freqs_cis: Complex exponentials, shape (seq_len, dim // 2).
    """
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(jnp.float32) / dim))
    t = jnp.arange(seq_len)  # type: ignore
    freqs = jnp.outer(t, freqs)  # type: ignore
    freqs_cis = jnp.exp(1j * freqs)
    return freqs_cis


# 1. Split the projected embeddings
split_index = REDUCED_D_MODEL // 2
content_embeddings = projected_embeddings[:, :split_index]
positional_embeddings = projected_embeddings[:, split_index:]

# 2. Precompute freqs_cis for RoPE
seq_len = projected_embeddings.shape[0]  # Get the sequence length
freqs_cis = precompute_freqs_cis(positional_embeddings.shape[1] * 2, seq_len)  # *2 to account the split
freqs_cis = freqs_cis[:,:positional_embeddings.shape[1]] #new line added

# 3. Apply RoPE to the positional embeddings
roped_positional_embeddings = apply_rope(positional_embeddings, freqs_cis)

# 4. Combine the content and RoPE-transformed positional embeddings
combined_embeddings = jnp.concatenate([content_embeddings, roped_positional_embeddings], axis=-1)

print("Shape of content embeddings:", content_embeddings.shape)
print("Shape of original positional embeddings:", positional_embeddings.shape)
print("Shape of RoPE'd positional embeddings:", roped_positional_embeddings.shape)
print("Shape of combined embeddings:", combined_embeddings.shape)
print("First combined vector:", combined_embeddings[0])
print("Last combined vector:", combined_embeddings[-1])

Shape of content embeddings: (1024, 1536)
Shape of original positional embeddings: (1024, 1536)
Shape of RoPE'd positional embeddings: (1024, 1536)
Shape of combined embeddings: (1024, 3072)
First combined vector: [ 2.8534365  -0.5653596   0.91363555 ... -0.28176805 -1.0861151
 -0.5680051 ]
Last combined vector: [-0.99326026 -0.58047986  0.37042296 ...  2.8974118  -0.23456831
 -1.5216347 ]


In [7]:
# Initialize Multi-Head Attention Matrices

def initialize_attention_head(key, d_model, reduced_d_model, d_k):
    """Initializes the projection matrices for a single attention head.

    Args:
        key: A JAX PRNGKey for random number generation.
        d_model: The input dimension of the model.
        reduced_d_model: The dimension of reduced space (from previous cell)
        d_k: The dimension of the key, query, and value projections.

    Returns:
        A dictionary containing the initialized query, key, and value
        projection matrices for the head.
    """
    k_q, k_k, k_v = jax.random.split(key, 3)  # Split key for each matrix

    # Xavier/Glorot initialization for each matrix
    query_projection = jax.random.normal(k_q, (d_model, d_k)) * jnp.sqrt(2 / (d_model + d_k))
    key_projection = jax.random.normal(k_k, (reduced_d_model, d_k)) * jnp.sqrt(2 / (reduced_d_model + d_k))
    value_projection = jax.random.normal(k_v, (reduced_d_model, d_k)) * jnp.sqrt(2 / (reduced_d_model + d_k))

    return {
        "query": query_projection,
        "key": key_projection,
        "value": value_projection,
    }


# Initialize all attention heads
attention_heads = []
main_key = jax.random.PRNGKey(2)  # New key for the attention heads

for i in range(NUM_HEADS):
    head_key = jax.random.fold_in(main_key, i)  # Generate a unique key for each head
    head_params = initialize_attention_head(head_key, D_MODEL, REDUCED_D_MODEL, D_K)
    attention_heads.append(head_params)

# Print the shapes of the matrices for a single head (since they are all the same)
print(f"Head (Shapes are the same for all heads):")
print(f"  Query Projection Shape: {attention_heads[0]['query'].shape}")
print(f"  Key Projection Shape: {attention_heads[0]['key'].shape}")
print(f"  Value Projection Shape: {attention_heads[0]['value'].shape}")
print("-" * 20)

Head (Shapes are the same for all heads):
  Query Projection Shape: (12288, 128)
  Key Projection Shape: (3072, 128)
  Value Projection Shape: (3072, 128)
--------------------


Now since we have a compresedd representation in a low dimensiona space and we then have applied a decouple RoPe to include positional encodings. The question is how do we calculate attention.

Here we apply the last step of decompression. First we project the normalized input into query projection matrix and apply rope. yeap for query we definetly dont change anything. we compute it as normal.

Then for value and Key we introduce new learner W(k) and W(v). where these matrices learn to form a key and value vector for each token from the latent vector C.

Remember how we split C to apply decouple loss if you want you can calcuate value vectors from the part of C without just positional embedings as we dont need them in Value vectors.

Now thats MLA in full now whats next is just computing attention on the constructed Keys and Value vectors.

In [8]:
# Compute Query, Key, and Value Vectors with RoPE (Using Provided RoPE Function - Query Only)

def compute_attention_inputs(query_input, key_value_input, attention_heads):
    """Computes the query, key, and value vectors for all attention heads.

    Args:
        query_input: The input for query projection (normalized embeddings).
        key_value_input: The input for key and value projection (combined embeddings).
        attention_heads: A list of dictionaries, each containing the projection
                         matrices for an attention head.

    Returns:
        A tuple containing lists of query, key, and value vectors for all heads.
    """
    all_queries = []
    all_keys = []
    all_values = []

    for head in attention_heads:
        # Query projection
        query = jnp.matmul(query_input, head['query'])

        # Key and Value projections
        key = jnp.matmul(key_value_input, head['key'])
        value = jnp.matmul(key_value_input, head['value'])

        all_queries.append(query)
        all_keys.append(key)
        all_values.append(value)

    return all_queries, all_keys, all_values

def apply_rope(x, print_transforms=False):
    original_shape = x.shape
    x = x.reshape(1, x.shape[0], 1, x.shape[1])
    batch_size, seq_len, num_heads, d_k = x.shape


    theta = 10000.0 ** (-jnp.arange(0, d_k, 2) / d_k)
    pos = jnp.arange(seq_len)[:, None]

    sin = jnp.sin(pos * theta)
    cos = jnp.cos(pos * theta)


    sin = sin[None, :, None, :]
    cos = cos[None, :, None, :]

    x_even, x_odd = x[..., ::2], x[..., 1::2]

    if print_transforms:
        print("\nRoPE Transformations:")
        print(f"Theta shape (frequencies): {theta.shape}")
        print(f"First few theta values: {theta[:5]}")


        example_cos = cos[0, 0, 0]
        example_sin = sin[0, 0, 0]
        example_even = x_even[0, 0, 0]
        example_odd = x_odd[0, 0, 0]


        print(f"\nFirst position rotation values:")
        print(f"cos values: {example_cos[:5]}")
        print(f"sin values: {example_sin[:5]}")
        print(f"\nExample rotation for first vector pair:")
        print(f"Original values (even, odd): ({example_even[0]}, {example_odd[0]})")
        print(f"Rotated values: ({(example_even[0] * example_cos[0] - example_odd[0] * example_sin[0]).item()}, "
              f"{(example_even[0] * example_sin[0] + example_odd[0] * example_cos[0]).item()})")

    x_rotated = jnp.stack([
        x_even * cos - x_odd * sin,
        x_even * sin + x_odd * cos
    ], axis=-1)

    return x_rotated.reshape(original_shape)  # Reshape back to original 2D




# Compute the query, key, and value vectors (NO RoPE applied here)
all_queries, all_keys, all_values = compute_attention_inputs(
    normalized_embeddings, combined_embeddings, attention_heads
)

# Apply RoPE *only* to queries after projection
all_queries_roped = [apply_rope(q) for q in all_queries]


# Print shapes for verification (taking the first head as an example)
print("Shapes for the first attention head:")
print(f"  Query Shape (after RoPE): {all_queries_roped[0].shape}")
print(f"  Key Shape: {all_keys[0].shape}")  # No RoPE
print(f"  Value Shape: {all_values[0].shape}") # No RoPE
print("-" * 20)



Shapes for the first attention head:
  Query Shape (after RoPE): (1024, 128)
  Key Shape: (1024, 128)
  Value Shape: (1024, 128)
--------------------


In [9]:
# Cell 9: Calculate Attention Weights and Apply to Values

def calculate_attention(queries, keys, values, d_k):
    """Calculates the scaled dot-product attention.

    Args:
        queries: A list of query vectors for all heads.
        keys: A list of key vectors for all heads.
        values: A list of value vectors for all heads.
        d_k: The dimension of the key/query vectors (per head).

    Returns:
        A list of attention-weighted value vectors for all heads.
    """
    attention_outputs = []
    for q, k, v in zip(queries, keys, values):
        # Calculate attention weights.  (seq_len_q, d_k) @ (d_k, seq_len_k) -> (seq_len_q, seq_len_k)
        attention_scores = jnp.matmul(q, k.transpose()) / jnp.sqrt(d_k)
        attention_weights = jax.nn.softmax(attention_scores, axis=-1)  # Softmax over the last axis (seq_len_k)

        # Apply attention weights to the values. (seq_len_q, seq_len_k) @ (seq_len_k, d_v) -> (seq_len_q, d_v)
        attention_output = jnp.matmul(attention_weights, v)
        attention_outputs.append(attention_output)
    return attention_outputs


# Calculate the attention outputs using the RoPE-transformed queries, and original keys/values
attention_outputs = calculate_attention(all_queries_roped, all_keys, all_values, D_K)

# Print the shape of the output for the first head
print("Shape of attention output for the first head:", attention_outputs[0].shape)
# Print shapes for all heads
print("Shapes of all attention outputs", [o.shape for o in attention_outputs])

Shape of attention output for the first head: (1024, 128)
Shapes of all attention outputs [(1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128), (1024, 128),

In [10]:
# Cell 10: Concatenate Attention Outputs (Debugging - No Projection)

def concatenate_attention_outputs(attention_outputs):
    """Concatenates the attention outputs from all heads.

    Args:
        attention_outputs: A list of attention-weighted value vectors for all heads.

    Returns:
        The concatenated attention outputs.
    """

    # Concatenate the outputs along the head dimension.
    # Each attention_output has shape (seq_len, d_k)
    # After concatenation: (seq_len, num_heads * d_k)
    concatenated_outputs = jnp.concatenate(attention_outputs, axis=-1)
    return concatenated_outputs




# Concatenate the attention outputs
concatenated_attention = concatenate_attention_outputs(attention_outputs)

print("Shape of concatenated attention outputs:", concatenated_attention.shape)

Shape of concatenated attention outputs: (1024, 12288)


In [11]:
# Cell 11: Initialize Final Projection Matrix (Corrected Variable Name)

def initialize_projection_matrix(d_model, key):
    """Initializes the final linear projection matrix.

    Args:
        d_model: The model's embedding dimension.
        key: A JAX PRNGKey for random number generation.

    Returns:
        The initialized projection matrix.
    """
    final_projection = jax.random.normal(key, (d_model, d_model)) * jnp.sqrt(2 / (d_model + d_model))
    return final_projection



# Initialize the projection matrix
projection_key = jax.random.PRNGKey(4)  # Use a new key
final_projection_matrix = initialize_projection_matrix(D_MODEL, projection_key)

print("Shape of the final projection matrix:", final_projection_matrix.shape)

Shape of the final projection matrix: (12288, 12288)


In [12]:
# Cell 12: Apply Final Projection

def apply_final_projection(concatenated_outputs, projection_matrix):
    """Applies the final linear projection to the concatenated attention outputs.

    Args:
        concatenated_outputs: The concatenated attention outputs from all heads.
        projection_matrix: The final linear projection matrix.

    Returns:
        The output of the multi-head attention block.
    """
    output = jnp.matmul(concatenated_outputs, projection_matrix)
    return output

# Apply the final projection
final_output = apply_final_projection(concatenated_attention, final_projection_matrix)

print("Shape of the final output:", final_output.shape)
print("First vector of the final output", final_output[0])
print("Last vector of the final output", final_output[-1])

Shape of the final output: (1024, 12288)
First vector of the final output [-0.36862063  0.31341684  0.11350989 ...  0.12969272 -0.08512364
 -0.00763207]
Last vector of the final output [ 0.2255859  -0.12203582  0.31405017 ...  0.31826627  0.03736952
 -0.49951586]


Now eventually MLA reduces KV cache cause first you going to cache only one latent vector representation which is eventually compressed and smaller than normal KV cache.

Yes MLA introduce new parameters in the system but this is great because now you have new parameters that are going to learn new behavoirs. and these 2 new matrices Wc and C. I dont count Wk and WV. since traditionaly also you would have K, V matrices to project the hidden states into the the D_k dimensions.

For more math behind MLA check out deep seek paper https://arxiv.org/pdf/2405.04434