<a href="https://colab.research.google.com/github/jha09pjha/jha09pjha/blob/main/attention_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# Model configuration parameters
hidden_size = 128
num_attention_heads = 16
num_key_value_heads = 4
head_dim = hidden_size // num_attention_heads  # features per attention head
max_position_embeddings = 256  # number of rows in the RoPE table
rope_theta = 10000.0  # RoPE frequency base
rms_norm_eps = 1e-5  # not used elsewhere in this cell
attention_bias = False
attention_dropout = 0.0  # not used elsewhere in this cell
use_qk_norm = True

# Sample inputs: random hidden states plus sequential positions for every batch row
batch_size, sequence_length = 2, 10
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
position_ids = torch.arange(sequence_length).repeat(batch_size, 1)
# Additive causal mask: 0 on/below the diagonal, -inf above it (future keys),
# shaped (batch, 1, seq, seq) so it broadcasts across heads when added to scores.
attention_mask = (
    torch.full((sequence_length, sequence_length), -torch.inf)
    .triu(diagonal=1)[None, None]
    .expand(batch_size, 1, -1, -1)
)

# Report configuration and input shapes
print("Configuration:")
print(
    "\n".join(
        f"  {name}: {value}"
        for name, value in (
            ("hidden_size", hidden_size),
            ("num_attention_heads", num_attention_heads),
            ("num_key_value_heads", num_key_value_heads),
            ("head_dim", head_dim),
        )
    )
)

print("\nSample Input Shapes:")
print(
    "\n".join(
        f"  {name}: {tensor.shape}"
        for name, tensor in (
            ("hidden_states", hidden_states),
            ("position_ids", position_ids),
            ("attention_mask", attention_mask),
        )
    )
)

# Linear projections for Q, K, V and the output (created in this order)
q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)

# Project, then split heads: (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
query_states = (
    q_proj(hidden_states)
    .view(batch_size, sequence_length, num_attention_heads, head_dim)
    .transpose(1, 2)
)
key_states = (
    k_proj(hidden_states)
    .view(batch_size, sequence_length, num_key_value_heads, head_dim)
    .transpose(1, 2)
)
value_states = (
    v_proj(hidden_states)
    .view(batch_size, sequence_length, num_key_value_heads, head_dim)
    .transpose(1, 2)
)

print("Projected Shapes:")
print(
    "\n".join(
        f"  {name}: {tensor.shape}"
        for name, tensor in (
            ("query_states", query_states),
            ("key_states", key_states),
            ("value_states", value_states),
        )
    )
)

# Grouped Query Attention: how many query heads share each K/V head
num_key_value_groups = num_attention_heads // num_key_value_heads
print(f"\nNum Key/Value Groups (Q heads per K/V head): {num_key_value_groups}")

# Function to calculate rotary embedding frequencies
def simple_rope_calculation(dim, max_seq_len, base=10000.0, device=None):
    """Build the complex rotary-embedding (RoPE) table.

    Args:
        dim: per-head feature size; one frequency is produced for every pair
            of features (``ceil(dim / 2)`` frequencies).
        max_seq_len: number of positions, i.e. rows of the table.
        base: RoPE theta used in ``inv_freq[k] = 1 / base ** (2k / dim)``.
        device: device on which the table is created.

    Returns:
        Complex tensor of shape ``(max_seq_len, 2 * ceil(dim / 2))`` whose
        entry ``[p, j]`` equals ``cos(a) + i*sin(a)`` with
        ``a = p * inv_freq[j mod ceil(dim / 2)]``. The right half of the columns
        duplicates the left half.
    """
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_seq_len, device=device).type_as(inv_freq)
    angles = torch.outer(positions, inv_freq)  # (max_seq_len, ceil(dim / 2))
    angles = torch.cat([angles, angles], dim=-1)
    return torch.complex(angles.cos(), angles.sin())

def new_func(inv_freq, t):
    """Return the rotation-angle grid ``angles[p, k] = t[p] * inv_freq[k]``.

    Both arguments must be 1-D (``torch.outer`` semantics); the result has
    shape ``(len(t), len(inv_freq))``.
    """
    angles = t.outer(inv_freq)
    return angles

# Function to apply rotary embeddings to query and key states
def apply_rotary_emb_torch(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key feature pairs by their rotary (RoPE) angles.

    Consecutive feature pairs ``(x[..., 2k], x[..., 2k+1])`` are treated as one
    complex number and multiplied by ``freqs_cis[position, k]``.

    Args:
        xq: queries, ``(batch, q_heads, seq, head_dim)`` with even ``head_dim``.
        xk: keys, ``(batch, kv_heads, seq, head_dim)``.
        freqs_cis: complex table indexed ``[position, k]`` with at least
            ``head_dim // 2`` columns; extra trailing columns are ignored, so
            the doubled-width output of ``simple_rope_calculation`` works.
        position_ids: integer positions, ``(batch, seq)`` or ``(seq,)`` (a 1-D
            tensor is shared by every batch row). Defaults to ``0 .. seq - 1``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` / ``xk``.
    """
    freqs_cis = freqs_cis.to(xq.device)

    if position_ids is None:
        # No explicit positions: use 0..seq-1 for every batch row.
        position_ids = torch.arange(xq.shape[-2], device=freqs_cis.device)
    position_ids = position_ids.to(freqs_cis.device)
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)  # broadcast over the batch

    # (batch, seq, F) -> (batch, 1, seq, F) so the angles broadcast over heads.
    freqs = freqs_cis[position_ids][:, None, :, :]

    # View feature pairs as complex numbers: (..., head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs = freqs[..., : xq_.shape[-1]]

    xq_out = torch.view_as_real(xq_ * freqs).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# Calculate and apply RoPE: build the rotation table, then rotate Q and K
freqs_cis = simple_rope_calculation(
    head_dim,
    max_position_embeddings,
    base=rope_theta,
    device=hidden_states.device,
)
print(f"Calculated freqs_cis shape: {freqs_cis.shape}")

query_states_rope, key_states_rope = apply_rotary_emb_torch(query_states, key_states, freqs_cis)

print("\nShapes after RoPE:")
print(
    "\n".join(
        f"  {label}: {tensor.shape}"
        for label, tensor in (
            ("query_states_rope", query_states_rope),
            ("key_states_rope", key_states_rope),
        )
    )
)

# Simple L2 Normalization class
class SimpleL2Norm(nn.Module):
    """Scale each vector along the last dimension to unit root-mean-square.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps)`` with no learnable weight.
    Despite the name, the result has unit RMS rather than unit Euclidean length.

    Args:
        eps: small constant added under the square root for stability.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Reduce in at least float32: squaring fp16 values above ~256 overflows
        # (fp16 max is about 65504), which would make rsqrt return 0 and zero
        # out the whole vector. The result is cast back to the input dtype.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        normed = x_c * torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.to(x.dtype)

# Optionally normalise queries and keys before computing attention scores
if not use_qk_norm:
    query_states_final, key_states_final = query_states_rope, key_states_rope
    print("\nSkipped QK Norm")
else:
    qk_norm = SimpleL2Norm()
    query_states_final = qk_norm(query_states_rope)
    key_states_final = qk_norm(key_states_rope)
    print("\nApplied QK Norm")

print("\nShapes before attention score calculation:")
print(
    "\n".join(
        f"  {label}: {tensor.shape}"
        for label, tensor in (
            ("query_states_final", query_states_final),
            ("key_states_final", key_states_final),
        )
    )
)

# Function to repeat key/value states for GQA
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    Maps ``(batch, kv_heads, seq, head_dim)`` to
    ``(batch, kv_heads * n_rep, seq, head_dim)``. Copies of one head are
    adjacent: input head ``i`` becomes output heads
    ``i * n_rep .. i * n_rep + n_rep - 1``. When ``n_rep == 1`` the input
    tensor object itself is returned (a 4-D input is still required).
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

# Repeat key and value states for GQA so every query head has a matching K/V head
key_states_repeated = repeat_kv(key_states_final, num_key_value_groups)
value_states_repeated = repeat_kv(value_states, num_key_value_groups)

print("\nShapes after repeating K/V for GQA:")
print(f"  key_states_repeated: {key_states_repeated.shape}")
print(f"  value_states_repeated: {value_states_repeated.shape}")

# Scaled dot-product scores (Q @ K^T / sqrt(head_dim)): (batch, heads, q_len, k_len).
# Kept in their own variable so the raw scores are still available after softmax.
scaling_factor = 1.0 / math.sqrt(head_dim)
attn_scores = torch.matmul(query_states_final, key_states_repeated.transpose(2, 3)) * scaling_factor

# Apply the additive attention mask (its values are added to the scores)
if attention_mask is not None:
    print(f"\nApplying attention mask with shape: {attention_mask.shape}")
    # Trim the mask's key axis to the actual number of keys.
    causal_mask = attention_mask[:, :, :, :key_states_repeated.shape[-2]]
    attn_scores = attn_scores + causal_mask
else:
    print("\nNo attention mask applied.")

# Softmax in float32 for numerical stability, then cast back to the activation dtype
attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query_states.dtype)

# Attention output: probability-weighted sum of the value vectors
attn_output = torch.matmul(attn_weights, value_states_repeated)

print("\nAttention Calculation Shapes:")
print(f"  attn_weights (raw scores): {attn_scores.shape}")
print(f"  attn_weights (after softmax): {attn_weights.shape}")
print(f"  attn_output: {attn_output.shape}")

# Merge heads: (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, sequence_length, hidden_size)

# Apply output projection
final_attn_output = o_proj(attn_output)

print("\nFinal Output Shapes:")
print(f"  attn_output (reshaped): {attn_output.shape}")
print(f"  final_attn_output: {final_attn_output.shape}")

# Simple Feed-Forward Network class
class SimpleFeedForward(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Args:
        hidden_size: feature size of the input and the output.
        intermediate_size: width of the gated hidden layer.

    All three projections are bias-free.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)

# Simplified Llama-like Attention module with Feed-Forward Network
class SimplifiedLlama4Attention(nn.Module):
    """Grouped-query self-attention with RoPE and optional QK norm, followed by an FFN.

    The feed-forward network is applied directly to the attention output (no
    residual connections and no pre-normalisation), so this module is a
    simplified sketch of the two sub-layers rather than a full decoder layer.

    Args:
        config: mapping with required keys ``hidden_size``,
            ``num_attention_heads``, ``num_key_value_heads``,
            ``max_position_embeddings``, ``rope_theta``, ``attention_bias`` and
            ``use_qk_norm``. Optional keys: ``intermediate_size`` (FFN width,
            default ``4 * hidden_size``), ``rms_norm_eps`` (QK-norm epsilon,
            default ``1e-6``) and ``attention_dropout`` (dropout probability on
            the attention probabilities during training, default ``0.0``).

    Raises:
        ValueError: if ``hidden_size`` is not divisible by
            ``num_attention_heads``, or ``num_attention_heads`` is not divisible
            by ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.num_attention_heads = config['num_attention_heads']
        self.num_key_value_heads = config['num_key_value_heads']
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.max_position_embeddings = config['max_position_embeddings']
        self.rope_theta = config['rope_theta']
        self.attention_bias = config['attention_bias']
        self.use_qk_norm = config['use_qk_norm']
        self.intermediate_size = config.get('intermediate_size', 4 * self.hidden_size)
        self.rms_norm_eps = config.get('rms_norm_eps', 1e-6)
        self.attention_dropout = config.get('attention_dropout', 0.0)

        if (self.head_dim * self.num_attention_heads) != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        # GQA shares each K/V head among a whole number of query heads; otherwise
        # repeat_kv would produce fewer heads than the queries have.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        # Linear projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=self.attention_bias)

        # Complex RoPE table indexed [position, frequency]. It is a plain attribute
        # (not a registered buffer), so forward moves it to the input's device.
        self.freqs_cis = simple_rope_calculation(self.head_dim, self.max_position_embeddings, base=self.rope_theta)

        # Optional QK Normalization
        if self.use_qk_norm:
            self.qk_norm = SimpleL2Norm(eps=self.rms_norm_eps)

        # Feed-Forward Network layer
        self.feed_forward = SimpleFeedForward(self.hidden_size, self.intermediate_size)

    def _apply_rotary(self, query_states, key_states, position_ids):
        """Rotate query/key feature pairs by the RoPE angles at ``position_ids``.

        ``position_ids`` may be ``(batch, seq)`` or ``(seq,)``; a 1-D tensor is
        shared by every batch row. Returns tensors with the input shapes/dtypes.
        """
        freqs_cis = self.freqs_cis.to(query_states.device)
        position_ids = position_ids.to(freqs_cis.device)
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        # (batch, seq, F) -> (batch, 1, seq, F) so the angles broadcast over heads.
        freqs = freqs_cis[position_ids][:, None, :, :]

        # View consecutive feature pairs as complex numbers: (..., head_dim // 2).
        q_c = torch.view_as_complex(query_states.float().reshape(*query_states.shape[:-1], -1, 2))
        k_c = torch.view_as_complex(key_states.float().reshape(*key_states.shape[:-1], -1, 2))
        # The table can be wider than head_dim // 2; only the leading columns are used.
        freqs = freqs[..., : q_c.shape[-1]]

        q_out = torch.view_as_real(q_c * freqs).flatten(-2)
        k_out = torch.view_as_real(k_c * freqs).flatten(-2)
        return q_out.type_as(query_states), k_out.type_as(key_states)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """Run grouped-query attention, then the feed-forward network.

        Args:
            hidden_states: ``(batch, seq, hidden_size)`` input.
            attention_mask: 4-D additive mask broadcastable to
                ``(batch, num_attention_heads, seq, seq)``, or ``None``.
            position_ids: integer positions ``(batch, seq)`` or ``(seq,)`` used
                to index the RoPE table; ``None`` means ``0 .. seq - 1``.

        Returns:
            ``(output, attn_weights)``: the FFN output
            ``(batch, seq, hidden_size)`` and the post-softmax attention
            probabilities ``(batch, num_attention_heads, seq, seq)``.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Project, then split heads: (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(
            batch_size, sequence_length, self.num_attention_heads, self.head_dim
        ).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            batch_size, sequence_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            batch_size, sequence_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Apply rotary embeddings at the caller-supplied positions
        if position_ids is None:
            position_ids = torch.arange(sequence_length, device=hidden_states.device)
        query_states_rope, key_states_rope = self._apply_rotary(query_states, key_states, position_ids)

        # Apply QK Normalization
        if self.use_qk_norm:
            query_states_final = self.qk_norm(query_states_rope)
            key_states_final = self.qk_norm(key_states_rope)
        else:
            query_states_final = query_states_rope
            key_states_final = key_states_rope

        # Repeat key/value states for GQA
        key_states_repeated = repeat_kv(key_states_final, self.num_key_value_groups)
        value_states_repeated = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores
        scaling_factor = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(query_states_final, key_states_repeated.transpose(2, 3)) * scaling_factor

        # Apply attention mask (trimmed to the actual key length)
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, :key_states_repeated.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax in float32 for numerical stability, then back to the activation dtype
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_probs = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_probs, value_states_repeated)

        # Merge heads: (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, sequence_length, self.hidden_size)

        # Apply output projection
        final_attn_output = self.o_proj(attn_output)

        # FFN is applied directly to the attention output (no residual / norm).
        output_with_ffn = self.feed_forward(final_attn_output)

        return output_with_ffn, attn_weights

# Configuration consumed by SimplifiedLlama4Attention (same keys and order as before)
config_dict = dict(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    max_position_embeddings=max_position_embeddings,
    rope_theta=rope_theta,
    attention_bias=attention_bias,
    use_qk_norm=use_qk_norm,
    intermediate_size=4 * hidden_size,
)

# Build the module and run it on the sample inputs
simplified_attn_module = SimplifiedLlama4Attention(config_dict)
final_output_simplified, final_weights_simplified = simplified_attn_module(
    hidden_states, attention_mask, position_ids
)

print("\nOutput shape from simplified module:", final_output_simplified.shape)
print("Attention weights shape from simplified module:", final_weights_simplified.shape)

Configuration:
  hidden_size: 128
  num_attention_heads: 16
  num_key_value_heads: 4
  head_dim: 8

Sample Input Shapes:
  hidden_states: torch.Size([2, 10, 128])
  position_ids: torch.Size([2, 10])
  attention_mask: torch.Size([2, 1, 10, 10])
Projected Shapes:
  query_states: torch.Size([2, 16, 10, 8])
  key_states: torch.Size([2, 4, 10, 8])
  value_states: torch.Size([2, 4, 10, 8])

Num Key/Value Groups (Q heads per K/V head): 4
Calculated freqs_cis shape: torch.Size([256, 8])

Shapes after RoPE:
  query_states_rope: torch.Size([2, 16, 10, 8])
  key_states_rope: torch.Size([2, 4, 10, 8])

Applied QK Norm

Shapes before attention score calculation:
  query_states_final: torch.Size([2, 16, 10, 8])
  key_states_final: torch.Size([2, 4, 10, 8])

Shapes after repeating K/V for GQA:
  key_states_repeated: torch.Size([2, 16, 10, 8])
  value_states_repeated: torch.Size([2, 16, 10, 8])

Applying attention mask with shape: torch.Size([2, 1, 10, 10])

Attention Calculation Shapes:
  attn_weight

In [None]:
class SimpleFeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward network for transformer blocks.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` where ``act_fn``
    is SiLU. It uses three bias-free linear projections (gate, up, down).

    Args:
        hidden_size: feature size of the input and the output.
        intermediate_size: feature size of the gated hidden layer.
    """
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = F.silu

    def forward(self, x):
        # The activated gate branch scales the up branch element-wise, then
        # down_proj maps back to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))