In [2]:
import torch
from torch.nn.functional import scaled_dot_product_attention

In [None]:
import torch
import math
from torch.nn.functional import scaled_dot_product_attention

def create_test_data_for_extend():
    """
    Create sample data to test _run_sdpa_forward_extend function.
    
    Scenario: 3 sequences being extended with different configurations
    """
    
    # Model configuration
    num_heads = 8
    head_size = 64
    max_total_tokens = 1000
    max_num_reqs = 100
    max_context_len = 512
    
    # Batch configuration
    num_seqs = 3
    seq_lens = torch.tensor([10, 15, 8])                # Total tokens after extension
    extend_prefix_lens = torch.tensor([7, 12, 5])       # Already cached tokens
    extend_seq_lens = torch.tensor([3, 3, 3])           # New tokens being added
    req_pool_indices = torch.tensor([42, 17, 89])       # Memory pool locations
    
    # Calculate total new tokens
    num_tokens = extend_seq_lens.sum().item()  # 3 + 3 + 3 = 9
    
    print(f"=== TEST DATA CONFIGURATION ===")
    print(f"num_seqs: {num_seqs}")
    print(f"seq_lens: {seq_lens.tolist()}")
    print(f"extend_prefix_lens: {extend_prefix_lens.tolist()}")
    print(f"extend_seq_lens: {extend_seq_lens.tolist()}")
    print(f"total num_tokens: {num_tokens}")
    print(f"req_pool_indices: {req_pool_indices.tolist()}")
    print()
    
    # Create query tensor [num_tokens, num_heads, head_size]
    query = torch.randn(num_tokens, num_heads, head_size)
    
    # Mark query tokens with identifiable values for debugging
    for i in range(num_tokens):
        query[i, :, 0] = i + 100  # First dimension has token ID (100, 101, 102, ...)
    
    print(f"Query tensor shape: {query.shape}")
    print(f"Query token markers (first head, first dim): {query[:, 0, 0].tolist()}")
    print()
    
    # Create output tensor (same shape as query)
    output = torch.zeros_like(query)
    
    # Create global KV cache [max_total_tokens, num_heads, head_size]
    k_cache = torch.randn(max_total_tokens, num_heads, head_size)
    v_cache = torch.randn(max_total_tokens, num_heads, head_size)
    
    # Mark cache entries with identifiable values
    for i in range(max_total_tokens):
        k_cache[i, :, 0] = i + 1000  # Key cache markers (1000, 1001, 1002, ...)
        v_cache[i, :, 0] = i + 2000  # Value cache markers (2000, 2001, 2002, ...)
    
    # Create req_to_token mapping [max_num_reqs, max_context_len]
    req_to_token = torch.zeros(max_num_reqs, max_context_len, dtype=torch.long)
    
    # Set up token mappings for our test sequences
    # Sequence 0: pool index 42, uses tokens 100-109 in cache
    req_to_token[42, :10] = torch.arange(100, 110)
    
    # Sequence 1: pool index 17, uses tokens 200-214 in cache  
    req_to_token[17, :15] = torch.arange(200, 215)
    
    # Sequence 2: pool index 89, uses tokens 300-307 in cache
    req_to_token[89, :8] = torch.arange(300, 308)
    
    print(f"=== TOKEN MAPPINGS ===")
    print(f"Seq 0 (pool {req_pool_indices[0]}): cache tokens {req_to_token[42, :10].tolist()}")
    print(f"Seq 1 (pool {req_pool_indices[1]}): cache tokens {req_to_token[17, :15].tolist()}")
    print(f"Seq 2 (pool {req_pool_indices[2]}): cache tokens {req_to_token[89, :8].tolist()}")
    print()
    
    # Set scaling factor
    scaling = 1.0 / math.sqrt(head_size)
    
    return {
        'query': query,
        'output': output,
        'k_cache': k_cache,
        'v_cache': v_cache,
        'req_to_token': req_to_token,
        'req_pool_indices': req_pool_indices,
        'seq_lens': seq_lens,
        'extend_prefix_lens': extend_prefix_lens,
        'extend_seq_lens': extend_seq_lens,
        'scaling': scaling,
        'enable_gqa': False,
        'causal': True
    }

def test_run_sdpa_forward_extend():
    """Test the _run_sdpa_forward_extend function with sample data."""
    
    # Get test data
    test_data = create_test_data_for_extend()
    
    # Extract parameters
    query = test_data['query']
    output = test_data['output']
    k_cache = test_data['k_cache']
    v_cache = test_data['v_cache']
    req_to_token = test_data['req_to_token']
    req_pool_indices = test_data['req_pool_indices']
    seq_lens = test_data['seq_lens']
    extend_prefix_lens = test_data['extend_prefix_lens']
    extend_seq_lens = test_data['extend_seq_lens']
    scaling = test_data['scaling']
    enable_gqa = test_data['enable_gqa']
    causal = test_data['causal']
    
    print("=== BEFORE PROCESSING ===")
    print(f"Output tensor (should be zeros): {output[:3, 0, 0].tolist()}")
    print()
    
    # Simulate the _run_sdpa_forward_extend function
    # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
    query_reshaped = query.movedim(0, query.dim() - 2)
    
    start_q, start_kv = 0, 0
    
    for seq_idx in range(seq_lens.shape[0]):
        print(f"=== PROCESSING SEQUENCE {seq_idx} ===")
        
        extend_seq_len_q = extend_seq_lens[seq_idx].item()
        prefill_seq_len_q = extend_prefix_lens[seq_idx].item()
        seq_len_kv = seq_lens[seq_idx].item()
        end_q = start_q + extend_seq_len_q
        end_kv = start_kv + seq_len_kv
        
        print(f"extend_seq_len_q: {extend_seq_len_q}")
        print(f"prefill_seq_len_q: {prefill_seq_len_q}")
        print(f"seq_len_kv: {seq_len_kv}")
        print(f"Query range: [{start_q}:{end_q}]")
        
        # Extract query for this sequence
        per_req_query = query_reshaped[:, start_q:end_q, :]
        print(f"per_req_query shape: {per_req_query.shape}")
        print(f"per_req_query markers: {per_req_query[0, :, 0].tolist()}")
        
        # Create redundant query tensor
        per_req_query_redundant = torch.zeros(
            (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
            dtype=per_req_query.dtype,
            device=per_req_query.device,
        )
        
        # Place new queries at correct positions  
        per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
        print(f"per_req_query_redundant shape: {per_req_query_redundant.shape}")
        print(f"Redundant query markers: {per_req_query_redundant[0, :, 0].tolist()}")
        
        # Get key and value from cache
        req_pool_idx = req_pool_indices[seq_idx].item()
        per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
        print(f"req_pool_idx: {req_pool_idx}")
        print(f"per_req_tokens: {per_req_tokens.tolist()}")
        
        per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
        per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
        print(f"per_req_key shape: {per_req_key.shape}")
        print(f"Key markers: {per_req_key[0, :, 0].tolist()}")
        print(f"Value markers: {per_req_value[0, :, 0].tolist()}")
        
        # Run attention
        per_req_out_redundant = (
            scaled_dot_product_attention(
                per_req_query_redundant.unsqueeze(0),
                per_req_key.unsqueeze(0),
                per_req_value.unsqueeze(0),
                scale=scaling,
                is_causal=causal,
            )
            .squeeze(0)
            .movedim(query.dim() - 2, 0)
        )
        
        print(f"per_req_out_redundant shape: {per_req_out_redundant.shape}")
        print(f"Output markers before extraction: {per_req_out_redundant[:, 0, 0].tolist()}")
        
        # Extract relevant outputs
        relevant_output = per_req_out_redundant[prefill_seq_len_q:, :, :]
        output[start_q:end_q, :, :] = relevant_output
        
        print(f"Extracted output shape: {relevant_output.shape}")
        print(f"Placed in output[{start_q}:{end_q}]")
        print(f"Output after placement: {output[start_q:end_q, 0, 0].tolist()}")
        print()
        
        start_q, start_kv = end_q, end_kv
    
    print("=== FINAL RESULTS ===")
    print(f"Final output shape: {output.shape}")
    print(f"Output sample (first head, first dim): {output[:, 0, 0].tolist()}")
    print()
    
    # Verify output is non-zero (attention worked)
    output_norm = torch.norm(output)
    print(f"Output tensor norm: {output_norm:.4f}")
    
    if output_norm > 0:
        print("✅ Test PASSED - Output is non-zero, attention computation worked")
    else:
        print("❌ Test FAILED - Output is zero, something went wrong")
    
    return output

def visualize_attention_pattern_correct():
    """
    Properly visualize attention patterns without tensor dimension issues.
    
    The key insight: we need to avoid zero queries that cause NaN when combined with causal masking.
    """
    
    print("=== CORRECTED ATTENTION PATTERN VISUALIZATION ===")
    
    # Configuration
    extend_prefix_len = 2  # 2 cached tokens
    extend_seq_len = 2     # 2 new tokens  
    seq_len_kv = 4         # Total tokens in sequence
    
    # Create meaningful queries (non-zero for positions that will be used)
    # Shape: [seq_len, head_dim] where we only care about the new token positions
    query_redundant = torch.tensor([
        [0.0],  # Position 0: cached (will be masked anyway)
        [0.0],  # Position 1: cached (will be masked anyway)  
        [1.0],  # Position 2: new token with query
        [2.0],  # Position 3: new token with query
    ])
    
    # Keys and values
    k_cache = torch.tensor([
        [1.0],  # Key for position 0
        [1.0],  # Key for position 1
        [1.0],  # Key for position 2  
        [1.0],  # Key for position 3
    ])
    
    v_cache = torch.tensor([
        [10.0],  # Value for position 0
        [20.0],  # Value for position 1
        [30.0],  # Value for position 2
        [40.0],  # Value for position 3
    ])
    
    print(f"Query (redundant): {query_redundant.squeeze().tolist()}")
    print(f"Key cache: {k_cache.squeeze().tolist()}")  
    print(f"Value cache: {v_cache.squeeze().tolist()}")
    print(f"Extend prefix length: {extend_prefix_len}")
    print(f"New token positions: {extend_prefix_len} onwards")
    print()
    
    # Compute raw attention scores: Q @ K^T
    scores = torch.matmul(query_redundant, k_cache.transpose(-2, -1))  # [4, 4]
    print(f"Raw scores shape: {scores.shape}")
    print("Raw scores (Q @ K^T):")
    for i, row in enumerate(scores.tolist()):
        print(f"  Position {i}: {row}")
    print()
    
    # Create causal mask: upper triangular with -inf
    causal_mask = torch.triu(torch.ones(4, 4), diagonal=1) * float('-inf')
    print("Causal mask:")
    for i, row in enumerate(causal_mask.tolist()):
        print(f"  Position {i}: {[f'{x:.0f}' if x != float('-inf') else '-∞' for x in row]}")
    print()
    
    # Apply causal mask
    scores_masked = scores + causal_mask
    print("Scores after causal masking:")
    for i, row in enumerate(scores_masked.tolist()):
        formatted_row = []
        for x in row:
            if x == float('-inf'):
                formatted_row.append('-∞')
            elif math.isnan(x):
                formatted_row.append('NaN')  
            else:
                formatted_row.append(f'{x:.1f}')
        print(f"  Position {i}: {formatted_row}")
    print()
    
    # Apply softmax to get attention weights
    attn_weights = torch.softmax(scores_masked, dim=-1)
    print("Attention weights (after softmax):")
    for i, row in enumerate(attn_weights.tolist()):
        formatted_row = []
        for x in row:
            if math.isnan(x):
                formatted_row.append('NaN')
            else:
                formatted_row.append(f'{x:.3f}')
        print(f"  Position {i}: {formatted_row}")
    print()
    
    # Compute final output: attention_weights @ V
    output = torch.matmul(attn_weights, v_cache)  # [4, 1]
    print(f"Final output shape: {output.shape}")
    print("Final output (attention @ V):")
    for i, val in enumerate(output.squeeze().tolist()):
        if math.isnan(val):
            print(f"  Position {i}: NaN")
        else:
            print(f"  Position {i}: {val:.2f}")
    print()
    
    # Extract only the outputs for new tokens (what the extend function would return)
    new_token_outputs = output[extend_prefix_len:]
    print(f"New token outputs (positions {extend_prefix_len}+): {new_token_outputs.squeeze().tolist()}")
    
    print("\n=== EXPLANATION ===")
    print("1. Positions 0,1: Cached tokens (prefix)")
    print("2. Positions 2,3: New tokens being processed")
    print("3. Causal mask ensures:")
    print("   - Position 0: Only attends to itself")
    print("   - Position 1: Attends to positions 0,1") 
    print("   - Position 2: Attends to positions 0,1,2")
    print("   - Position 3: Attends to positions 0,1,2,3")
    print("4. Only outputs for positions 2,3 are used (new tokens)")
    
    # Show that this matches the extend function behavior
    print("\n=== EXTEND FUNCTION SIMULATION ===")
    print("This is exactly what _run_sdpa_forward_extend does:")
    print("1. Creates redundant query with new tokens at correct positions")
    print("2. Uses full KV cache for the sequence") 
    print("3. Applies causal attention")
    print("4. Extracts outputs only for new token positions")

def simple_working_example():
    """A simple example that definitely works without any issues."""
    
    print("=== SIMPLE WORKING EXAMPLE ===")
    
    # Use PyTorch's SDPA directly (like the real function does)
    batch_size = 1
    num_heads = 2
    seq_len = 4  
    head_dim = 8
    
    # Create test tensors with proper dimensions
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)  
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    
    print(f"Query shape: {query.shape}")
    print(f"Key shape: {key.shape}")
    print(f"Value shape: {value.shape}")
    
    # Run attention with causal masking
    output = scaled_dot_product_attention(
        query, key, value,
        is_causal=True,
        scale=1.0 / math.sqrt(head_dim)
    )
    
    print(f"Output shape: {output.shape}")
    print(f"Output norm: {torch.norm(output):.4f}")
    
    # Simulate extracting new tokens (last 2 positions)
    extend_prefix_len = 2
    new_token_outputs = output[:, :, extend_prefix_len:, :]
    print(f"New token outputs shape: {new_token_outputs.shape}")
    print(f"New token outputs norm: {torch.norm(new_token_outputs):.4f}")
    
    print("✅ Simple example completed successfully!")
    print("This demonstrates that SDPA works correctly with proper tensor shapes.")

if __name__ == "__main__":
    # Run the simple working example first
    simple_working_example()
    print("\n" + "="*60 + "\n")
    
    # Run the corrected visualization
    visualize_attention_pattern_correct()
    print("\n" + "="*60 + "\n")
    
    # Run the main test
    test_run_sdpa_forward_extend()


In [13]:
testdata = create_test_data_for_extend()

=== TEST DATA CONFIGURATION ===
num_seqs: 3
seq_lens: [10, 15, 8]
extend_prefix_lens: [7, 12, 5]
extend_seq_lens: [3, 3, 3]
total num_tokens: 9
req_pool_indices: [42, 17, 89]

Query tensor shape: torch.Size([9, 8, 64])
Query token markers (first head, first dim): [100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0]

=== TOKEN MAPPINGS ===
Seq 0 (pool 42): cache tokens [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
Seq 1 (pool 17): cache tokens [200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214]
Seq 2 (pool 89): cache tokens [300, 301, 302, 303, 304, 305, 306, 307]



In [15]:
q = testdata['query']
q[0, 1, :]  # Access the first element in the first head and first dimension

tensor([ 1.0000e+02, -1.6576e+00, -2.0029e-02, -4.2802e-01,  1.3332e+00,
         9.9924e-01, -8.8784e-01, -4.4433e-01, -5.4830e-01,  5.9202e-01,
         8.4495e-01, -8.9862e-01, -2.5020e-02, -7.3002e-01, -2.1434e+00,
         6.3140e-01,  5.9199e-01,  2.7241e+00,  1.6477e+00,  8.3054e-01,
         1.6440e+00, -1.1733e+00,  6.0557e-01,  1.8598e-01, -1.8381e+00,
         5.8024e-01,  4.4517e-01, -2.8377e-01, -1.4314e+00,  1.1170e-02,
         1.3062e+00, -1.5280e+00, -7.7959e-01, -9.0410e-01,  1.5533e+00,
         6.5476e-01, -1.0017e+00, -3.6100e-01,  1.6596e+00, -6.9675e-01,
        -6.5668e-02,  8.1829e-01, -1.2362e+00, -8.3845e-01,  3.2537e-01,
        -1.8887e-01,  1.4897e-01, -9.9384e-01,  7.2069e-01,  1.8491e+00,
        -9.2455e-01,  6.9416e-01, -1.2503e+00, -3.9238e-01,  1.7303e+00,
        -2.0750e+00, -1.0172e+00, -4.9965e-01, -7.6738e-01,  1.5230e+00,
         6.1120e-01,  4.1065e-01, -7.2143e-01,  1.1637e+00])

In [31]:
_run_sdpa_forward_extend(
    query=testdata['query'],
    output=testdata['output'],
    k_cache=testdata['k_cache'],
    v_cache=testdata['v_cache'],
    req_to_token=testdata['req_to_token'],
    req_pool_indices=testdata['req_pool_indices'],
    seq_lens=testdata['seq_lens'],
    extend_prefix_lens=testdata['extend_prefix_lens'],
    extend_seq_lens=testdata['extend_seq_lens'],
    scaling=testdata['scaling'],
    enable_gqa=testdata['enable_gqa'],
    causal=testdata['causal'],
)

start_q: 0, end_q: 3
per_req_query shape: torch.Size([8, 3, 64])
per_req_query_redudant shape: torch.Size([8, 10, 64])
per_req_query_redudant after assignment shape: torch.Size([8, 10, 64])
start_q: 3, end_q: 6
per_req_query shape: torch.Size([8, 3, 64])
per_req_query_redudant shape: torch.Size([8, 15, 64])
per_req_query_redudant after assignment shape: torch.Size([8, 15, 64])
start_q: 6, end_q: 9
per_req_query shape: torch.Size([8, 3, 64])
per_req_query_redudant shape: torch.Size([8, 8, 64])
per_req_query_redudant after assignment shape: torch.Size([8, 8, 64])


tensor([[[ 2.1070e+03, -1.7710e+00, -1.1017e+00,  ..., -4.2823e-01,
           1.4145e-01, -2.0578e+00],
         [ 2.1070e+03,  1.5543e-01,  1.3160e-01,  ...,  4.8716e-01,
           9.3702e-01,  1.3232e-01],
         [ 2.1070e+03,  2.2299e-01,  1.4251e-02,  ..., -3.9241e-01,
          -9.8241e-01,  5.5761e-02],
         ...,
         [ 2.1070e+03,  8.1138e-01, -1.0179e+00,  ...,  9.7894e-01,
          -7.9975e-01, -8.9252e-01],
         [ 2.1070e+03,  1.9310e+00, -6.5271e-01,  ...,  5.2144e-01,
          -3.0845e-01, -3.9937e-01],
         [ 2.1070e+03, -8.9899e-01, -2.8282e-02,  ...,  2.2493e+00,
           4.6407e-01, -1.1943e+00]],

        [[ 2.1080e+03,  1.1631e+00,  9.3131e-01,  ...,  9.6877e-01,
          -5.1081e-01, -1.7506e+00],
         [ 2.1080e+03, -5.8837e-01, -5.5931e-01,  ..., -4.2122e-01,
          -1.2030e-01,  8.3382e-01],
         [ 2.1080e+03,  3.8540e-01, -5.2342e-02,  ...,  2.2130e+00,
          -5.2050e-01, -5.6921e-01],
         ...,
         [ 2.1080e+03,  1

In [21]:
query = testdata['query']
print(query.shape) 
query = query.movedim(0, query.dim() - 2)  
print(query.shape)  




torch.Size([9, 8, 64])
torch.Size([8, 9, 64])


In [26]:
seq_lens= testdata['seq_lens']
print(seq_lens)
print(seq_lens.shape)
print(seq_lens.shape[0])  # Number of sequences


seq_len_kv = seq_lens[0]  # Example for the first sequence
print(seq_len_kv)  # Should print the length of the first sequence

tensor([10, 15,  8])
torch.Size([3])
3
tensor(10)


In [27]:
extend_seq_lens = testdata['extend_seq_lens']
print(extend_seq_lens)
print(extend_seq_lens.shape)
print(extend_seq_lens[0]) # Number of sequences


extend_seq_len_q = extend_seq_lens[0]  # Example for the first sequence
start_q = 0
end_q = start_q + extend_seq_len_q
print(end_q)  # Should print the length of the first sequence

tensor([3, 3, 3])
torch.Size([3])
tensor(3)
tensor(3)
