In [None]:
from torch.nn.attention.flex_attention import (
    flex_attention,
)
import torch
# Rebind the imported `flex_attention` to its torch.compile-wrapped version, so
# every later call through this name uses the compiled callable.
# NOTE(review): dynamic=False presumably asks the compiler to specialize on the
# observed shapes instead of tracing dynamic ones — confirm for the installed torch.
flex_attention = torch.compile(flex_attention, dynamic=False)

# Dynamo cache-size settings, currently disabled (kept for reference):
# torch._dynamo.config.cache_size_limit = 192
# torch._dynamo.config.accumulated_cache_size_limit = 192

In [None]:
# Usage sketch: `query`, `key`, `value` and `self` are not defined in any visible
# cell of this notebook, so this line cannot run on its own here.
# NOTE(review): looks like a line copied from a module's forward() — confirm or remove.
output = flex_attention(query, key, value, block_mask=self.block_mask)

In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from torch.nn.attention.flex_attention import (
    create_block_mask,
)

# from torch.nn.attention import and_masks, or_masks

def build_attention_mask_vectorized_simplified(mask):
    """
    Build a block-diagonal attention mask from per-token chunk labels.

    Entry (q, k) of the result is 1.0 when tokens q and k of the same sequence
    carry the same non-zero label, and 0.0 otherwise.

    Args:
        mask (torch.Tensor): Shape (B, S); 0 marks padding/invalid tokens and any
            non-zero value is a chunk label.
    Returns:
        torch.Tensor: float32 attention mask of shape (B, 1, S, S)
    """
    batch, seq_len = mask.shape

    # Lay the labels out once along the query axis and once along the key axis.
    query_labels = mask.reshape(batch, seq_len, 1)  # (B, S, 1)
    key_labels = mask.reshape(batch, 1, seq_len)  # (B, 1, S)

    # A pair is allowed when the query token is not padding and both labels agree
    # (equal labels then imply the key token is not padding either).
    allowed = (query_labels == key_labels) & (query_labels != 0)  # (B, S, S)

    # Insert a singleton head axis so the mask broadcasts over attention heads.
    return allowed.to(torch.float32)[:, None]  # (B, 1, S, S)


# Naive PyTorch Attention
class NaiveAttention(nn.Module):
    """Reference scaled dot-product attention written with plain tensor ops.

    Scores are ``Q @ K^T`` multiplied by ``1 / sqrt(d_model)``; positions where
    ``mask == 0`` are set to ``-inf`` before a softmax over the key axis, and the
    resulting weights are applied to ``V``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Fixed score scaling derived from the constructor argument.
        self.scale = 1.0 / (d_model ** 0.5)

    def forward(self, Q, K, V, mask=None):
        """Return the attention output for ``Q``, ``K`` and ``V``.

        Args:
            Q, K, V: tensors whose last two axes are (seq_len, feature); leading
                axes (e.g. batch and heads) are handled by matmul broadcasting.
            mask: optional tensor broadcastable to the score shape; entries equal
                to 0 are excluded from attention. A query row whose entries are
                all 0 yields NaN because every score in it becomes ``-inf``.
        """
        logits = (Q @ K.transpose(-2, -1)) * self.scale  # [..., q_len, kv_len]

        if mask is not None:
            blocked = mask == 0
            logits = logits.masked_fill(blocked, float('-inf'))

        weights = logits.softmax(dim=-1)
        return weights @ V  # [..., q_len, feature]

# Flex Attention (Custom Attention Mechanism)
class FlexAttention(nn.Module):
    """``nn.Module`` wrapper that delegates to the module-level ``flex_attention``.

    The softmax scale ``1 / sqrt(d_model)`` computed in the constructor is passed
    to ``flex_attention`` explicitly, so this module applies the same scaling as
    ``NaiveAttention`` even when ``d_model`` differs from the size of the last
    axis of ``Q`` (which is what ``flex_attention`` would otherwise derive its
    default scale from).
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.scale = 1.0 / (d_model ** 0.5)

    def forward(self, Q, K, V, mask=None):
        """Run ``flex_attention`` on ``Q``, ``K``, ``V``.

        Args:
            Q, K, V: query/key/value tensors forwarded unchanged; the test in this
                notebook passes [batch_size, heads, seq_len, d_model] tensors.
            mask: forwarded as ``block_mask`` (``None`` means no block mask).
        Returns:
            Whatever ``flex_attention`` returns for these arguments.
        """
        return flex_attention(Q, K, V, block_mask=mask, scale=self.scale)

# Create a custom attention mask
def create_custom_mask(seq_len, device='cpu'):
    """Return a (seq_len, seq_len) causal mask: 1.0 on and below the diagonal, 0.0 above.

    Args:
        seq_len: side length of the square mask.
        device: device on which the mask tensor is allocated.
    """
    return torch.ones(seq_len, seq_len, device=device).tril()


# Inclusive upper bound on the key/value index that every query may attend to.
prefix_length = 4


def prefix_mask(b, h, q_idx, kv_idx):
    """Mask function that exposes the leading key/value positions to every query.

    Reads the module-level ``prefix_length`` instead of repeating its value as a
    literal, so changing the constant actually changes the mask.

    Args:
        b, h, q_idx: accepted for signature compatibility; not used.
        kv_idx: key/value position (int or tensor).
    Returns:
        Truthy where ``kv_idx <= prefix_length``. The bound is inclusive, so
        ``prefix_length + 1`` positions are visible.
    """
    return kv_idx <= prefix_length


def causal(b, h, q_idx, kv_idx):
    """Causal mask function: a query sees keys at its own position or earlier.

    ``b`` and ``h`` are accepted so the signature matches the other
    ``(b, h, q_idx, kv_idx)`` mask functions in this notebook; they are not used.
    """
    return kv_idx <= q_idx


def generate_block_prefix_causal_mask_mod(prompt_mask_4d):
    """Build a mask function combining causal attention with prompt-block attention.

    Args:
        prompt_mask_4d: boolean tensor indexed as [b, h, q_idx, kv_idx]; a True
            entry lets that query/key pair attend regardless of their order.
    Returns:
        A function ``(b, h, q_idx, kv_idx)`` that is truthy when the pair is marked
        in ``prompt_mask_4d`` or when ``q_idx >= kv_idx``.
    """
    def block_prefix_causal_mask_mod(b, h, q_idx, kv_idx):
        # Look up the captured tensor first, then OR in the causal condition.
        within_prompt = prompt_mask_4d[b, h, q_idx, kv_idx]
        not_future = q_idx >= kv_idx
        return within_prompt | not_future

    return block_prefix_causal_mask_mod



# Test the attention mechanisms
def test_attention():
    """Compare ``NaiveAttention`` with ``FlexAttention`` on a small prompt-block mask.

    Builds random Q/K/V of shape [batch_size, heads, seq_len, d_model], a mask
    that allows causal attention plus full attention inside each labelled prompt
    chunk, runs both implementations and prints their output shapes and the
    largest absolute element-wise difference.
    """
    batch_size = 2
    seq_len = 10
    d_model = 8
    heads = 4

    # Seeded local generator: the run is reproducible without touching the
    # global RNG state.
    rng = torch.Generator().manual_seed(0)
    Q = torch.randn(batch_size, heads, seq_len, d_model, generator=rng)
    K = torch.randn(batch_size, heads, seq_len, d_model, generator=rng)
    V = torch.randn(batch_size, heads, seq_len, d_model, generator=rng)

    # Per-token chunk labels: 0 = not part of a prompt chunk, >0 = chunk id.
    prompt_mask = torch.tensor([
        [1, 1, 1, 0, 0, 0, 2, 2, 0, 0],
        [1, 0, 2, 2, 0, 3, 3, 0, 0, 0]
    ], dtype=torch.long)

    # Dense masks for the naive path. Local names deliberately differ from the
    # module-level `causal` function and `flex_attention` callable.
    causal_mask = create_custom_mask(seq_len, device=Q.device).expand(
        batch_size, heads, seq_len, seq_len
    )
    chunk_mask = build_attention_mask_vectorized_simplified(prompt_mask).expand(
        batch_size, heads, seq_len, seq_len
    )
    prompt_mask_4d = chunk_mask > 0

    # The mask function indexes prompt_mask_4d by batch, so the block mask must be
    # built per batch element (B=batch_size). B=None would evaluate it for batch 0
    # only and reuse that block layout for every sequence. The mask is identical
    # across heads, so H stays None (broadcast).
    mask_mod = generate_block_prefix_causal_mask_mod(prompt_mask_4d)
    block_mask = create_block_mask(
        mask_mod, batch_size, None, seq_len, seq_len, device='cpu', _compile=False
    )

    # Initialize attention mechanisms
    naive_module = NaiveAttention(d_model)
    flex_module = FlexAttention(d_model)

    # Non-zero entries of (chunk_mask + causal_mask) are exactly the pairs allowed
    # by "same prompt chunk OR causal".
    naive_output = naive_module(Q, K, V, chunk_mask + causal_mask)
    flex_output = flex_module(Q, K, V, block_mask)

    print("Naive Attention Output Shape:", naive_output.shape)
    print("Flex Attention Output Shape:", flex_output.shape)

    # Report the largest absolute deviation: a plain sum of signed differences
    # can cancel to ~0 even when individual elements disagree.
    max_abs_diff = (naive_output - flex_output).abs().max()
    print("\nDifference in Outputs:", max_abs_diff)

# Run the comparison once; test_attention() prints its results to the cell output.
test_attention()

Naive Attention Output Shape: torch.Size([2, 4, 10, 8])
Flex Attention Output Shape: torch.Size([2, 4, 10, 8])

Difference in Outputs: tensor(0.)


In [16]:
import torch

def build_attention_mask_vectorized(mask):
    """
    Vectorized implementation to build attention mask from chunked input mask.

    Every maximal run of consecutive valid tokens is treated as one chunk; tokens
    may attend only to tokens of the same chunk, and padding rows/columns are 0.

    Args:
        mask (torch.Tensor): Shape (B, S), where 1 = valid token, 0 = padding/invalid.
            Float, integer and boolean dtypes are accepted.
    Returns:
        torch.Tensor: float32 attention mask of shape (B, 1, S, S)
    Raises:
        ValueError: if ``mask`` is not 2-dimensional.
    """
    if mask.dim() != 2:
        raise ValueError(f"mask must have shape (B, S), got {tuple(mask.shape)}")

    # Boolean tensors do not support `-`, so move them to an integer dtype before
    # the transition arithmetic below.
    if mask.dtype == torch.bool:
        mask = mask.to(torch.int64)

    # 1. Identify chunk starts (0->1 transitions); the left pad makes a valid
    #    first token count as a transition too.
    padded_mask = torch.nn.functional.pad(mask, (1, 0), value=0)
    diff = padded_mask[:, 1:] - padded_mask[:, :-1]
    chunk_starts = diff == 1

    # 2. Running count of chunk starts gives each token its chunk id; multiplying
    #    by the mask zeroes the ids of padding positions.
    chunk_ids = torch.cumsum(chunk_starts, dim=1) * mask  # (B, S)

    # 3. Pairwise comparison: same chunk id and the query token is valid.
    same_chunk = chunk_ids.unsqueeze(-1) == chunk_ids.unsqueeze(-2)  # (B, S, S)
    valid_token = chunk_ids.unsqueeze(-1) != 0  # (B, S, 1)
    attention_mask = (same_chunk & valid_token).to(torch.float32)

    # Add head dimension for multi-head attention compatibility
    return attention_mask.unsqueeze(1)  # (B, 1, S, S)

# Example usage: two sequences of length 11, with 1 = valid token and 0 = padding.
# build_attention_mask_vectorized treats each run of consecutive 1s as its own chunk.
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0],
    [1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0]
], dtype=torch.float32)

# Result has shape (B, 1, S, S); printed in full below.
attention_mask = build_attention_mask_vectorized(mask)
# print("Vectorized Attention Mask Shape:", attention_mask.shape)
print(attention_mask)

tensor([[[[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0., 0

In [17]:
def build_attention_mask_vectorized_simplified(mask):
    """
    Build a block-diagonal attention mask from per-token chunk labels.

    Entry (q, k) of the result is 1.0 when tokens q and k of the same sequence
    carry the same non-zero label, and 0.0 otherwise.

    Args:
        mask (torch.Tensor): Shape (B, S); 0 marks padding/invalid tokens and any
            non-zero value is a chunk label.
    Returns:
        torch.Tensor: float32 attention mask of shape (B, 1, S, S)
    """
    batch, seq_len = mask.shape

    # Lay the labels out once along the query axis and once along the key axis.
    query_labels = mask.reshape(batch, seq_len, 1)  # (B, S, 1)
    key_labels = mask.reshape(batch, 1, seq_len)  # (B, 1, S)

    # A pair is allowed when the query token is not padding and both labels agree
    # (equal labels then imply the key token is not padding either).
    allowed = (query_labels == key_labels) & (query_labels != 0)  # (B, S, S)

    # Insert a singleton head axis so the mask broadcasts over attention heads.
    return allowed.to(torch.float32)[:, None]  # (B, 1, S, S)

# Example usage: non-zero entries are pre-assigned chunk labels, 0 marks padding.
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0],
    [1, 0, 2, 2, 0, 3, 3, 0, 0, 0, 0]
], dtype=torch.float32)

# Call the simplified builder defined in this cell. The previous version of this
# cell called build_attention_mask_vectorized, so the new function was never run.
attention_mask = build_attention_mask_vectorized_simplified(mask)
# print("Vectorized Attention Mask Shape:", attention_mask.shape)
print(attention_mask)

tensor([[[[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 1., 1., 0., 0