In [2]:
import jax
import jax.numpy as jnp
import math

import flax.linen as nn

In [3]:
class Dropout1d(nn.Module):
  """Dropout that zeroes whole channels instead of individual elements.

  The Bernoulli keep-mask is drawn with shape
  ``(x.shape[0], 1, x.shape[-1])`` and broadcast against ``x``, so every
  position along axis 1 shares the same keep/drop decision for a given
  (batch, channel) pair. Kept values are rescaled by ``1 / (1 - dropout_rate)``.

  Attributes:
    dropout_rate: probability of dropping a channel.
  """

  dropout_rate: float = 0.0

  def __call__(self, x, deterministic=True):
    """Applies channel dropout to ``x``.

    Args:
      x: input array; the mask shape assumes a 3-D layout whose first axis is
        the batch and whose last axis is the channel axis.
      deterministic: when true (the default) the input is returned unchanged.

    Returns:
      ``x`` with dropped channels zeroed and the rest rescaled, or ``x``
      itself when dropout is inactive.
    """
    if (self.dropout_rate > 0.0) and not deterministic:
      if self.dropout_rate == 1.0:
        # Everything is dropped; returning zeros directly avoids the 0 / 0
        # (NaN) that the rescaling below would otherwise produce.
        return jnp.zeros_like(x)
      keep_prob = 1 - self.dropout_rate
      drop = jax.random.bernoulli(
          self.make_rng('dropout'),
          keep_prob,
          (x.shape[0], 1, x.shape[-1]),
      )
      x = x * drop / keep_prob
    return x


def repeat_kv(x, n_rep):
  """Duplicates every key/value head ``n_rep`` times along the head axis.

  Args:
    x: array of shape ``(bs, slen, n_kv_heads, head_dim)``.
    n_rep: number of consecutive copies to make of each head.

  Returns:
    ``x`` itself when ``n_rep == 1``; otherwise an array of shape
    ``(bs, slen, n_kv_heads * n_rep, head_dim)`` in which the copies of one
    head sit next to each other.
  """
  # Unpacking doubles as a rank check: anything but a 4-D input raises here.
  _bs, _slen, _n_kv_heads, _head_dim = x.shape
  if n_rep != 1:
    # Element-wise repetition along axis 2 yields h0, h0, ..., h1, h1, ...
    return jnp.repeat(x, n_rep, axis=2)
  return x

class Attention(nn.Module):
    """Multi-head self-attention with optional grouped key/value heads.

    No attention mask is applied: every position attends to every position.

    Attributes:
        dim: width of the output projection; also determines ``head_dim``
            as ``dim // n_heads``.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads; ``None`` means ``n_heads``.
            Must divide ``n_heads``.
        dropout_rate: rate used for both the attention-weight dropout and the
            channel dropout on the projected output.
        qkv_bias: whether the query/key/value projections carry a bias.
    """
    dim: int
    n_heads: int
    n_kv_heads: int | None = None
    dropout_rate: float = 0.0
    qkv_bias: bool = False

    def setup(self):
        """Resolves head counts and creates the projection and dropout layers."""
        kv_heads = self.n_kv_heads
        if kv_heads is None:
            kv_heads = self.n_heads
        self._n_kv_heads = kv_heads
        assert self.n_heads % self._n_kv_heads == 0
        # Each key/value head serves n_rep query heads.
        self.n_rep = self.n_heads // self._n_kv_heads
        self.head_dim = self.dim // self.n_heads

        q_width = self.n_heads * self.head_dim
        kv_width = self._n_kv_heads * self.head_dim
        self.wq = nn.Dense(q_width, use_bias=self.qkv_bias)
        self.wk = nn.Dense(kv_width, use_bias=self.qkv_bias)
        self.wv = nn.Dense(kv_width, use_bias=self.qkv_bias)
        self.wo = nn.Dense(self.dim, use_bias=False)
        self.attn_dropout = nn.Dropout(self.dropout_rate)
        self.resid_dropout = Dropout1d(self.dropout_rate)

    def __call__(self, x, train=False):
        """Runs self-attention over ``x``.

        Args:
            x: array of shape ``(batch, seqlen, features)``.
            train: enables both dropout layers when true.

        Returns:
            Array of shape ``(batch, seqlen, dim)``.
        """
        batch, length, _ = x.shape
        deterministic = not train

        def split_heads(projected, heads):
            # (batch, length, heads * head_dim) -> (batch, length, heads, head_dim)
            return projected.reshape(batch, length, heads, self.head_dim)

        def heads_first(t):
            # (batch, length, heads, head_dim) -> (batch, heads, length, head_dim)
            return jnp.transpose(t, (0, 2, 1, 3))

        queries = split_heads(self.wq(x), self.n_heads)
        keys = split_heads(self.wk(x), self._n_kv_heads)
        values = split_heads(self.wv(x), self._n_kv_heads)

        # Grouped-query attention: replicate K/V heads up to the query head count.
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        queries = heads_first(queries)
        keys = heads_first(keys)
        values = heads_first(values)

        # Scaled dot-product scores, shape (batch, n_heads, length, length).
        logits = (queries @ jnp.swapaxes(keys, 2, 3)) / math.sqrt(self.head_dim)
        weights = nn.softmax(logits, axis=-1)
        weights = self.attn_dropout(weights, deterministic=deterministic)
        context = weights @ values  # (batch, n_heads, length, head_dim)

        # Move time back next to batch and concatenate the heads.
        merged = jnp.transpose(context, (0, 2, 1, 3)).reshape(batch, length, -1)

        # Final projection into the residual stream.
        projected = self.wo(merged)
        return self.resid_dropout(projected, deterministic=deterministic)

In [4]:
import time
import numpy as np

def performance_test():
    """Benchmarks the jitted ``Attention`` forward pass in train and inference mode.

    Builds one ``Attention`` layer, warms up both jitted variants, then times
    ``num_runs`` synchronous calls of each and prints mean/std/min/max latency,
    the relative training overhead, the parameter count and tokens per second.
    Nothing is returned; the printed report is the result.
    """
    # Test configuration
    batch_size = 256
    seq_len = 128
    dim = 12 * 64
    n_heads = 12
    n_kv_heads = 12  # equal to n_heads, so no key/value sharing in this run

    model = Attention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout_rate=0.1)

    # Independent keys for the input data, the parameter init and the dropout
    # stream: reusing a single key for several draws correlates them.
    data_key, init_key, key = jax.random.split(jax.random.PRNGKey(42), 3)
    x = jax.random.normal(data_key, (batch_size, seq_len, dim))

    # Only the Dense layers own parameters, so a deterministic init pass is
    # sufficient and does not depend on a 'dropout' RNG being available.
    params = model.init(init_key, x, train=False)

    @jax.jit
    def forward_train(params, x, key):
        return model.apply(params, x, train=True, rngs={'dropout': key})

    @jax.jit
    def forward_inference(params, x):
        return model.apply(params, x, train=False)

    def elapsed_ms(fn, *args):
        # JAX dispatches asynchronously; block so the full computation is timed.
        start_time = time.perf_counter()
        jax.block_until_ready(fn(*args))
        return (time.perf_counter() - start_time) * 1000  # Convert to ms

    print("=== Attention Performance Test ===")
    print(f"Config: batch_size={batch_size}, seq_len={seq_len}, dim={dim}")
    print(f"Heads: n_heads={n_heads}, n_kv_heads={n_kv_heads}")
    print()

    # Warmup runs (the first call of each function triggers compilation)
    print("Warming up JIT compilation...")
    warmup_runs = 5
    for i in range(warmup_runs):
        key, subkey = jax.random.split(key)
        _ = forward_train(params, x, subkey)
        _ = forward_inference(params, x)
        print(f"Warmup {i+1}/{warmup_runs} complete")

    # Drain everything still queued so it cannot leak into the first sample.
    key, subkey = jax.random.split(key)
    jax.block_until_ready(forward_train(params, x, subkey))
    jax.block_until_ready(forward_inference(params, x))
    print("Warmup complete!\n")

    # Performance measurement
    num_runs = 50

    # Training mode performance; key splitting stays outside the timed region.
    print("Measuring training performance...")
    train_times = []
    for _ in range(num_runs):
        key, subkey = jax.random.split(key)
        train_times.append(elapsed_ms(forward_train, params, x, subkey))

    # Inference mode performance
    print("Measuring inference performance...")
    inference_times = [elapsed_ms(forward_inference, params, x) for _ in range(num_runs)]

    # Calculate statistics
    train_mean = np.mean(train_times)
    train_std = np.std(train_times)
    train_min = np.min(train_times)
    train_max = np.max(train_times)

    inference_mean = np.mean(inference_times)
    inference_std = np.std(inference_times)
    inference_min = np.min(inference_times)
    inference_max = np.max(inference_times)

    # Report results
    print("\n=== Performance Results ===")
    print(f"Training mode ({num_runs} runs):")
    print(f"  Mean: {train_mean:.3f}ms ± {train_std:.3f}ms")
    print(f"  Range: {train_min:.3f}ms - {train_max:.3f}ms")
    print()
    print(f"Inference mode ({num_runs} runs):")
    print(f"  Mean: {inference_mean:.3f}ms ± {inference_std:.3f}ms")
    print(f"  Range: {inference_min:.3f}ms - {inference_max:.3f}ms")
    print()
    print(f"Training overhead: {((train_mean - inference_mean) / inference_mean * 100):.1f}%")

    # Calculate throughput
    total_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
    tokens_per_sec_train = (batch_size * seq_len) / (train_mean / 1000)
    tokens_per_sec_inference = (batch_size * seq_len) / (inference_mean / 1000)

    print(f"\nModel info:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Tokens/sec (training): {tokens_per_sec_train:,.0f}")
    print(f"  Tokens/sec (inference): {tokens_per_sec_inference:,.0f}")

# Run the performance test; the report it prints is this cell's output.
performance_test()

=== Attention Performance Test ===
Config: batch_size=256, seq_len=128, dim=768
Heads: n_heads=12, n_kv_heads=12

Warming up JIT compilation...
Warmup 1/5 complete
Warmup 2/5 complete
Warmup 3/5 complete
Warmup 4/5 complete
Warmup 5/5 complete
Warmup complete!

Measuring training performance...
Measuring inference performance...

=== Performance Results ===
Training mode (50 runs):
  Mean: 1.854ms ± 0.009ms
  Range: 1.832ms - 1.890ms

Inference mode (50 runs):
  Mean: 0.972ms ± 0.037ms
  Range: 0.942ms - 1.130ms

Training overhead: 90.7%

Model info:
  Total parameters: 2,359,296
  Tokens/sec (training): 17,672,133
  Tokens/sec (inference): 33,700,979


In [20]:
import jax.experimental.pallas.ops.tpu.flash_attention as flash_attention

class FlashAttention(nn.Module):
    """Self-attention whose score/softmax/value step runs in the Pallas TPU flash-attention kernel.

    The projections, head layout and output dropout mirror ``Attention``, so
    parameters can be shared between the two via the ``wq``/``wk``/``wv``/``wo``
    entries.

    Differences from ``Attention``:
      * the fused kernel is invoked without any dropout, so ``attn_dropout`` is
        created but never applied; with ``train=True`` and ``dropout_rate > 0``
        only the output (``resid_dropout``) is stochastic;
      * an optional additive ``attention_bias`` and a ``causal`` flag are
        forwarded to the kernel.

    Attributes:
        dim: width of the output projection; ``head_dim`` is ``dim // n_heads``.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads; ``None`` means ``n_heads``.
        dropout_rate: rate for the output channel dropout.
        qkv_bias: whether the query/key/value projections carry a bias.
        causal: forwarded to the kernel's ``causal`` argument.
    """
    dim: int
    n_heads: int
    n_kv_heads: int | None = None
    dropout_rate: float = 0.0
    qkv_bias: bool = False
    causal: bool = False  # default matches Attention, which applies no mask

    def setup(self):
        """Resolves head counts and creates the projection and dropout layers."""
        self._n_kv_heads = self.n_heads if self.n_kv_heads is None else self.n_kv_heads
        assert self.n_heads % self._n_kv_heads == 0
        self.n_rep = self.n_heads // self._n_kv_heads
        self.head_dim = self.dim // self.n_heads

        self.wq = nn.Dense(self.n_heads * self.head_dim, use_bias=self.qkv_bias)
        self.wk = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
        self.wv = nn.Dense(self._n_kv_heads * self.head_dim, use_bias=self.qkv_bias)
        self.wo = nn.Dense(self.dim, use_bias=False)

        # Kept for attribute parity with Attention; not used in __call__
        # because the fused kernel call below has no dropout step.
        self.attn_dropout = nn.Dropout(self.dropout_rate)
        self.resid_dropout = Dropout1d(self.dropout_rate)

    def __call__(self, x, attention_bias=None, train=False):
        """Runs flash self-attention over ``x``.

        Args:
            x: array of shape ``(batch, seqlen, features)``.
            attention_bias: optional bias passed to the kernel as ``ab``.
            train: enables the output dropout when true.

        Returns:
            Array of shape ``(batch, seqlen, dim)``.
        """
        bsz, seqlen, _ = x.shape

        # QKV projections, split into heads
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)
        xv = xv.reshape(bsz, seqlen, self._n_kv_heads, self.head_dim)

        # Grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # make heads into a batch dimension
        xq = xq.swapaxes(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xk = xk.swapaxes(1, 2)
        xv = xv.swapaxes(1, 2)

        sm_scale = 1.0 / math.sqrt(self.head_dim)

        # Derive block sizes from the input instead of hard-coding 64/128.
        # NOTE(review): the kernel is understood to reject a block larger than
        # the dimension it tiles and to require the key blocks to divide the
        # sequence length - confirm against the installed JAX version.
        # For batch >= 64 that is a multiple of 64 and seqlen == 128 these are
        # exactly the previous constants (64 / 128 / 128 / 128).
        block_b = math.gcd(bsz, 64)  # largest power-of-two divisor of bsz, capped at 64
        seq_block = min(seqlen, 128)
        block_sizes = flash_attention.BlockSizes(
            block_b=block_b,
            block_q=seq_block,
            block_k_major=seq_block,
            block_k=seq_block,
        )

        output = flash_attention.flash_attention(
            q=xq,
            k=xk,
            v=xv,
            ab=attention_bias,
            causal=self.causal,
            sm_scale=sm_scale,
            block_sizes=block_sizes,
        )

        # restore time as batch dimension and concat heads
        output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output, deterministic=not train)
        return output


def test_mathematical_equivalence():
    """Prints how closely ``FlashAttention`` reproduces ``Attention`` on shared weights.

    Two configurations are checked with dropout disabled: one with as many
    key/value heads as query heads, and one grouped-query setup with half as
    many. In each case the ``Attention`` parameters are initialised once and the
    same ``wq``/``wk``/``wv``/``wo`` entries are handed to ``FlashAttention``, so
    any difference in the outputs comes from the attention computation itself.
    Nothing is asserted or returned; the differences are printed.
    """
    print("=== Mathematical Equivalence Test ===")

    # Test configuration
    batch_size = 2
    seq_len = 64
    dim = 192
    n_heads = 6
    n_kv_heads = 6

    # Test data
    key = jax.random.PRNGKey(42)
    x = jax.random.normal(key, (batch_size, seq_len, dim))

    def run_pair(kv_heads):
        # Build both layers with the same config and return (regular, flash)
        # inference outputs computed from identical weights.
        regular = Attention(dim=dim, n_heads=n_heads, n_kv_heads=kv_heads, dropout_rate=0.0, qkv_bias=False)
        flash = FlashAttention(dim=dim, n_heads=n_heads, n_kv_heads=kv_heads, dropout_rate=0.0, qkv_bias=False, causal=False)

        regular_params = regular.init(key, x, train=False)
        # Reuse the regular layer's weights directly; a separate init of the
        # flash layer would be discarded immediately.
        flash_params = {
            'params': {
                name: regular_params['params'][name]
                for name in ('wq', 'wk', 'wv', 'wo')
            }
        }

        # Inference mode, so no dropout is involved.
        regular_result = regular.apply(regular_params, x, train=False)
        flash_result = flash.apply(flash_params, x, train=False)
        return regular_result, flash_result

    regular_out, flash_out = run_pair(n_kv_heads)

    # Check mathematical equivalence
    max_diff = jnp.max(jnp.abs(regular_out - flash_out))
    rel_diff = max_diff / jnp.max(jnp.abs(regular_out))

    print(f"Output shapes: Regular {regular_out.shape}, Flash {flash_out.shape}")
    print(f"Max absolute difference: {max_diff}")
    print(f"Max relative difference: {rel_diff}")
    print(f"Outputs are identical: {jnp.allclose(regular_out, flash_out, rtol=1e-6, atol=1e-6)}")

    # Same input, but with fewer key/value heads than query heads
    print("\nTesting with grouped query attention...")
    regular_out_gqa, flash_out_gqa = run_pair(3)

    max_diff_gqa = jnp.max(jnp.abs(regular_out_gqa - flash_out_gqa))
    print(f"GQA Max absolute difference: {max_diff_gqa}")
    print(f"GQA Outputs are identical: {jnp.allclose(regular_out_gqa, flash_out_gqa, rtol=1e-6, atol=1e-6)}")


def compare_attention_implementations():
    """Times jitted inference of ``Attention`` against ``FlashAttention`` and prints the result.

    Both layers use the same configuration and input but independently
    initialised parameters (only latency is compared, not outputs). After a
    synchronised warm-up, each implementation is called ``num_runs`` times with
    a blocking wait per call; mean, standard deviation and the
    regular-over-flash speedup are printed. Nothing is returned.
    """
    import time

    # Test configuration
    batch_size = 256
    seq_len = 128
    dim = 12 * 64
    n_heads = 12
    n_kv_heads = 12

    print("=== Attention Implementation Comparison ===")
    print(f"Config: batch_size={batch_size}, seq_len={seq_len}, dim={dim}")
    print(f"Heads: n_heads={n_heads}, n_kv_heads={n_kv_heads}")
    print()

    # Create models with identical config
    regular_attn = Attention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout_rate=0.0)
    flash_attn = FlashAttention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout_rate=0.0, causal=False)

    # Test data
    key = jax.random.PRNGKey(42)
    x = jax.random.normal(key, (batch_size, seq_len, dim))

    # Initialize parameters
    regular_params = regular_attn.init(key, x, train=False)
    flash_params = flash_attn.init(key, x, train=False)

    # Create jitted functions
    @jax.jit
    def regular_forward(params, x):
        return regular_attn.apply(params, x, train=False)

    @jax.jit
    def flash_forward(params, x):
        return flash_attn.apply(params, x, train=False)

    def collect_times_ms(forward, params, runs):
        # One blocking call per sample so each measurement covers the whole
        # computation rather than just the asynchronous dispatch.
        samples = []
        for _ in range(runs):
            start = time.perf_counter()
            jax.block_until_ready(forward(params, x))
            samples.append((time.perf_counter() - start) * 1000)
        return samples

    # Warmup
    print("Warming up...")
    regular_warm = flash_warm = None
    for _ in range(100):
        regular_warm = regular_forward(regular_params, x)
        flash_warm = flash_forward(flash_params, x)
    # Dispatch is asynchronous: without this wait the queued warm-up work would
    # be charged to the first timed sample and inflate its mean and spread.
    jax.block_until_ready((regular_warm, flash_warm))

    # Performance comparison
    num_runs = 1000
    regular_times = collect_times_ms(regular_forward, regular_params, num_runs)
    flash_times = collect_times_ms(flash_forward, flash_params, num_runs)

    # Results
    regular_mean = jnp.mean(jnp.array(regular_times))
    flash_mean = jnp.mean(jnp.array(flash_times))
    speedup = regular_mean / flash_mean

    print(f"Regular Attention: {regular_mean:.3f}ms ± {jnp.std(jnp.array(regular_times)):.3f}ms")
    print(f"Flash Attention: {flash_mean:.3f}ms ± {jnp.std(jnp.array(flash_times)):.3f}ms")
    print(f"Speedup: {speedup:.2f}x")

# Test mathematical equivalence first
# NOTE(review): the equivalence check below is currently disabled, so only the
# timing comparison runs in this cell - TODO confirm why it was commented out.
# test_mathematical_equivalence()

# Visual separator between the two sections of this cell's output.
print("\n" + "="*50 + "\n")

# Then compare performance
# Any exception is caught and printed instead of propagating; the second
# message is a guess at the cause, not a diagnosis derived from the exception.
try:
    compare_attention_implementations()
except Exception as e:
    print(f"Comparison failed: {e}")
    print("Flash attention may not be available on this platform.")



=== Attention Implementation Comparison ===
Config: batch_size=256, seq_len=128, dim=768
Heads: n_heads=12, n_kv_heads=12

Warming up...
Regular Attention: 1.027ms ± 1.825ms
Flash Attention: 2.944ms ± 0.011ms
Speedup: 0.35x


In [10]:
def analyze_xla_compilation():
    """Prints a compiled-HLO summary, intermediate shapes and per-length timings for both attention layers.

    Three reports are printed; nothing is returned:
      1. For each of ``Attention`` and ``FlashAttention``, how many lines of the
         compiled HLO text contain the substrings ``pad``, ``reshape``, ``dot``
         (and ``custom`` for the flash layer), with the first five matches.
         This is a plain substring match over whole lines, so metadata such as
         an ``op_name`` containing ``dot_general`` also counts.
      2. The shapes of the Q/K/V projections and the score matrix of the regular
         layer, plus a float32 size estimate of those tensors.
      3. A single timed call of each jitted function for several sequence
         lengths, after one untimed call per length.

    Each report is wrapped in a ``try`` so a failure is printed and the
    remaining reports still run.
    """
    print("=== XLA Compilation Analysis ===")

    # Test configuration
    batch_size = 8
    seq_len = 127  # Odd number to see padding effects
    dim = 768
    n_heads = 12
    n_kv_heads = 12

    print(f"Input config: batch_size={batch_size}, seq_len={seq_len}, dim={dim}")
    print(f"Heads: n_heads={n_heads}, n_kv_heads={n_kv_heads}")
    print()

    # Create models
    regular_attn = Attention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout_rate=0.0)
    flash_attn = FlashAttention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout_rate=0.0, causal=False)

    # Test data
    key = jax.random.PRNGKey(42)
    x = jax.random.normal(key, (batch_size, seq_len, dim))

    # Initialize parameters
    regular_params = regular_attn.init(key, x, train=False)
    flash_params = flash_attn.init(key, x, train=False)

    # Create functions for analysis
    def regular_forward(params, x):
        return regular_attn.apply(params, x, train=False)

    def flash_forward(params, x):
        return flash_attn.apply(params, x, train=False)

    def report_ops(hlo_text, categories):
        # Print, per (label, needle) pair, how many HLO lines contain the
        # needle (case-insensitive) and the first five of them.
        hlo_lines = hlo_text.split('\n')
        for position, (label, needle) in enumerate(categories):
            matches = [line for line in hlo_lines if needle in line.lower()]
            lead = "" if position == 0 else "\n"
            print(f"{lead}{label} operations found: {len(matches)}")
            for i, op in enumerate(matches[:5]):  # Show first 5
                print(f"  {i+1}: {op.strip()}")

    def timed_call_ms(fn, params, inputs):
        # Blocking wall-clock time of one call, in milliseconds.
        start = time.perf_counter()
        jax.block_until_ready(fn(params, inputs))
        return (time.perf_counter() - start) * 1000

    # Compile by calling once
    print("Compiling regular attention...")
    regular_compiled = jax.jit(regular_forward)
    jax.block_until_ready(regular_compiled(regular_params, x))  # Trigger compilation

    print("Compiling flash attention...")
    flash_compiled = jax.jit(flash_forward)
    jax.block_until_ready(flash_compiled(flash_params, x))  # Trigger compilation

    # Get HLO text representation
    try:
        # Obtain both texts before printing so a failure yields no partial report.
        regular_hlo = regular_compiled.lower(regular_params, x).compile().as_text()
        flash_hlo = flash_compiled.lower(flash_params, x).compile().as_text()

        print("Regular Attention HLO Analysis:")
        print("=" * 40)
        report_ops(
            regular_hlo,
            [("Padding", 'pad'), ("Reshape", 'reshape'), ("Dot", 'dot')],
        )

        print("\n" + "=" * 40)
        print("Flash Attention HLO Analysis:")
        print("=" * 40)
        report_ops(
            flash_hlo,
            [("Padding", 'pad'), ("Reshape", 'reshape'), ("Dot", 'dot'), ("Custom", 'custom')],
        )

    except Exception as e:
        print(f"Could not extract HLO: {e}")

    # Memory usage analysis
    print("\n" + "=" * 50)
    print("Memory Layout Analysis:")
    print("=" * 50)

    # Check tensor shapes at different stages
    def analyze_intermediate_shapes(params, x, model_name):
        print(f"\n{model_name} Intermediate Shapes:")

        # Manually trace through to see shapes
        bsz, seqlen, _ = x.shape
        print(f"  Input: {x.shape}")

        # Only the regular layer is traced step by step.
        if model_name != "Regular":
            return

        def project(layer_name):
            # Apply the Dense layer straight from its stored parameters.
            # Submodules created in setup() are only reachable inside
            # init/apply, so `model.wq` cannot be used on the unbound model.
            layer = params['params'][layer_name]
            projected = x @ layer['kernel']
            if 'bias' in layer:
                projected = projected + layer['bias']
            return projected

        head_dim = dim // n_heads
        with jax.disable_jit():
            xq = project('wq')
            xk = project('wk')
            xv = project('wv')

            print(f"  Q projection: {xq.shape}")
            print(f"  K projection: {xk.shape}")
            print(f"  V projection: {xv.shape}")

            # After reshape
            xq_reshaped = xq.reshape(bsz, seqlen, n_heads, head_dim)
            xk_reshaped = xk.reshape(bsz, seqlen, n_kv_heads, head_dim)
            xv_reshaped = xv.reshape(bsz, seqlen, n_kv_heads, head_dim)

            print(f"  Q reshaped: {xq_reshaped.shape}")
            print(f"  K reshaped: {xk_reshaped.shape}")
            print(f"  V reshaped: {xv_reshaped.shape}")

            # After transpose
            xq_transposed = xq_reshaped.swapaxes(1, 2)
            xk_transposed = xk_reshaped.swapaxes(1, 2)
            xv_transposed = xv_reshaped.swapaxes(1, 2)

            print(f"  Q transposed: {xq_transposed.shape}")
            print(f"  K transposed: {xk_transposed.shape}")
            print(f"  V transposed: {xv_transposed.shape}")

            # Attention scores
            scores = jnp.matmul(xq_transposed, xk_transposed.swapaxes(2, 3))
            print(f"  Attention scores: {scores.shape}")

            # Memory usage estimation from plain Python element counts
            total_elements = xq.size + xk.size + xv.size + scores.size
            total_memory = total_elements * 4  # Assuming float32

            print(f"  Estimated memory (MB): {total_memory / (1024 * 1024):.2f}")

    try:
        analyze_intermediate_shapes(regular_params, x, "Regular")
        analyze_intermediate_shapes(flash_params, x, "Flash")
    except Exception as e:
        print(f"Shape analysis failed: {e}")

    # Check for sequence length padding effects
    print("\n" + "=" * 50)
    print("Sequence Length Padding Test:")
    print("=" * 50)

    # Test with different sequence lengths to see padding effects
    test_seq_lens = [64, 127, 128, 129, 256, 511, 512, 513]

    for test_seq_len in test_seq_lens:
        test_x = jax.random.normal(key, (batch_size, test_seq_len, dim))

        try:
            # A new input shape makes jit retrace and recompile; run each
            # function once untimed so the numbers below measure execution.
            jax.block_until_ready(regular_compiled(regular_params, test_x))
            jax.block_until_ready(flash_compiled(flash_params, test_x))

            regular_time = timed_call_ms(regular_compiled, regular_params, test_x)
            flash_time = timed_call_ms(flash_compiled, flash_params, test_x)

            speedup = regular_time / flash_time if flash_time > 0 else float('inf')

            print(f"  seq_len={test_seq_len:3d}: Regular={regular_time:.2f}ms, Flash={flash_time:.2f}ms, Speedup={speedup:.2f}x")

        except Exception as e:
            print(f"  seq_len={test_seq_len:3d}: Error - {e}")

# Run XLA analysis
# `time` is used inside analyze_xla_compilation(); importing it here keeps this
# cell runnable without relying on an earlier cell having imported it.
import time
analyze_xla_compilation()

=== XLA Compilation Analysis ===
Input config: batch_size=8, seq_len=127, dim=768
Heads: n_heads=12, n_kv_heads=12

Compiling regular attention...
Compiling flash attention...
Regular Attention HLO Analysis:
Padding operations found: 0

Reshape operations found: 0

Dot operations found: 13
  1: ROOT %convolution-base-dilated.2 = bf16[8,12,127,64]{2,3,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.1, %fusion.33), window={size=8x12 stride=7x11 lhs_dilate=8x12}, dim_labels=01bf_01io->01bf, metadata={op_name="jit(regular_forward)/jit(main)/Attention/dot_general" source_file="/tmp/ipykernel_11883/206515483.py" source_line=65}
  2: %convolution-base-dilated.3 = f32[8,12,127,127]{2,3,1,0:T(8,128)S(1)} convolution(%fusion.31, %fusion.34), window={size=8x12 stride=7x11 lhs_dilate=8x12}, dim_labels=01bf_01io->01bf, metadata={op_name="jit(regular_forward)/jit(main)/Attention/dot_general" source_file="/tmp/ipykernel_11883/206515483.py" source_line=62}
  3: ROOT %convolution.4 = bf16[8,127,768]{1,2,0:T