if ffn_results:
    print("\n" + "="*70)
    print("FFN BENCHMARK RESULTS")
    print("="*70)
    print(f"\n{'Seq Len':<10} {'PyTorch (ms)':<15} {'Custom (ms)':<15} {'Speedup':<10} {'Memory Saved':<15}")
    print("-"*70)

    for r in ffn_results:
        # Size of the (seq_len x intermediate_size) intermediate activation the
        # fused kernel avoids allocating, assuming batch_size=1 and 4-byte
        # float32 elements (the dtype used for the benchmark inputs).
        mem_saved_mb = (r['seq_len'] * config.intermediate_size * 4) / (1024**2)
        # Attach the "x" suffix before padding so it sits next to the number
        # instead of after a run of spaces.
        speedup_str = f"{r['speedup']:.2f}x"
        print(f"{r['seq_len']:<10} {r['pytorch_mean']:<15.2f} {r['custom_mean']:<15.2f} {speedup_str:<10} {mem_saved_mb:.1f} MB")

    # Calculate average speedup
    avg_speedup = np.mean([r['speedup'] for r in ffn_results])
    print(f"\nAverage speedup: {avg_speedup:.2f}x")

    # Save FFN results. Create the directory here as well: the attention
    # "Save Results" cell only creates it when attention results exist, so
    # this cell must not rely on it having run.
    results_dir = Path('results')
    results_dir.mkdir(parents=True, exist_ok=True)
    ffn_path = results_dir / 'ffn_benchmark.json'
    with open(ffn_path, 'w') as f:
        json.dump(ffn_results, f, indent=2)
    print(f"\nFFN results saved to: {ffn_path}")
else:
    print("\nNo FFN results to display (GPU may not be available)")

## Setup

Import required modules and check GPU availability.

---

# PROMPT 4: Documentation Updates

*(Coming soon — to be added as part of the final documentation.)*

---

## Summary

This notebook demonstrates:

1. **PROMPT 1**:
   - Llama-2-7B configuration loading
   - Dummy input generation for attention and FFN
   - Memory usage estimation

2. **PROMPT 2**:
   - Custom attention kernel loading
   - Numerical correctness validation
   - Performance benchmarking across sequence lengths
   - Memory savings analysis

3. **PROMPT 3**:
   - Custom FFN kernel loading
   - Numerical correctness validation
   - Performance benchmarking with memory savings

4. **PROMPT 4**: *(to be added)*

### Key Results

- **Attention**: O(N) memory instead of the O(N²) attention matrix of standard attention, enabled by an online (streaming) softmax
- **FFN**: Eliminates intermediate tensor allocations, fused GELU activation
- **Combined**: overall speedup from using both optimized kernels (not yet measured end-to-end in this notebook)

### Next Steps

- Run this notebook on a GPU with CUDA to get actual benchmark numbers
- Compare with Flash Attention 2 baseline
- Add documentation updates (PROMPT 4)

In [None]:
# Make the project packages (configs, models, scripts) importable by putting
# the parent directory and the current directory at the front of sys.path.
# '.' is inserted last, so it takes precedence over '..'. Presumably this lets
# the notebook run from either the repo root or a subdirectory — TODO confirm.
import sys
import os
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('.'))

import torch
import torch.nn as nn
import json
from pathlib import Path

# Project modules (resolved via the sys.path entries above).
# NOTE(review): nn, get_config, get_gpu_memory_info and BenchmarkHarness are not
# referenced in the cells visible here — confirm whether they are still needed.
from configs.llama2_7b import LLAMA2_7B, get_config
from models.utils import (
    create_dummy_attention_inputs,
    create_dummy_ffn_inputs,
    estimate_memory_usage,
    validate_attention_output,
    get_gpu_memory_info,
    print_gpu_info
)
from scripts.benchmark_harness import BenchmarkHarness

print("Imports complete!")

## GPU Information

Check CUDA availability and display GPU specs.

In [None]:
print("="*70)
print("GPU Information")
print("="*70)

if torch.cuda.is_available():
    print("\nCUDA is available!")
    print_gpu_info()
    
    # Set device
    device = torch.device('cuda')
    print(f"\nUsing device: {device}")
else:
    print("\nCUDA is NOT available. Please run this notebook on a GPU machine.")
    device = torch.device('cpu')

---

# PROMPT 1: LLM Infrastructure Setup

## Model Configuration

Load Llama-2-7B configuration for benchmarking.

In [None]:
# Bind the Llama-2-7B config used by every later cell and echo its key sizes.
config = LLAMA2_7B

_config_report = (
    "="*70,
    "Llama-2-7B Configuration",
    "="*70,
    f"Hidden size:           {config.hidden_size}",
    f"Num layers:            {config.num_hidden_layers}",
    f"Num attention heads:   {config.num_attention_heads}",
    f"Num KV heads:          {config.num_key_value_heads}",
    f"Head dimension:        {config.head_dim}",
    f"Intermediate size:     {config.intermediate_size}",
    f"Max position embeddings: {config.max_position_embeddings}",
    f"Vocab size:            {config.vocab_size}",
)
for _report_line in _config_report:
    print(_report_line)

## Test Dummy Input Generation

In [None]:
# Smoke-test the dummy-input helpers for one representative sequence length.
seq_len = 512
_dummy_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Arguments shared by both helpers (single batch, same length and device).
_shared_kwargs = dict(batch_size=1, seq_len=seq_len, device=_dummy_device)

print(f"\nTesting dummy input generation for seq_len={seq_len}...\n")

# Attention inputs: Q, K and V tensors.
Q, K, V = create_dummy_attention_inputs(
    num_heads=config.num_attention_heads,
    head_dim=config.head_dim,
    **_shared_kwargs,
)
_attention_report = (
    "Attention inputs created:",
    f"  Q shape: {Q.shape}",
    f"  K shape: {K.shape}",
    f"  V shape: {V.shape}",
    f"  Device:  {Q.device}",
    f"  Q mean:  {Q.mean().item():.4f}",
    f"  Q std:   {Q.std().item():.4f}",
)
print("\n".join(_attention_report))

# FFN inputs: activations plus the two weight matrices.
x, w1, w2 = create_dummy_ffn_inputs(
    hidden_dim=config.hidden_size,
    intermediate_dim=config.intermediate_size,
    **_shared_kwargs,
)
for _report_line in (
    "\nFFN inputs created:",
    f"  Input shape: {x.shape}",
    f"  W1 shape:    {w1.shape}",
    f"  W2 shape:    {w2.shape}",
):
    print(_report_line)

## Memory Estimation

In [None]:
# Estimate memory usage across sequence lengths.
# Fixed banner: "\n="*70 repeated the two-character string "\n=" 70 times,
# printing 70 separate "=" lines instead of a blank line plus one rule.
print("\n" + "="*70)
print("Memory Usage Estimates (per layer, float32)")
print("="*70)
print(f"\n{'Seq Len':<10} {'Attention (MB)':<18} {'Attn Scores (MB)':<18} {'FFN (MB)':<12} {'Total (MB)':<12}")
print("-"*80)

# NOTE(review): the title says float32, but no dtype is passed to
# estimate_memory_usage — confirm the helper assumes 4-byte elements.
for seq_len in [256, 512, 1024, 2048, 4096]:
    mem = estimate_memory_usage(
        seq_len=seq_len,
        hidden_dim=config.hidden_size,
        num_heads=config.num_attention_heads,
        intermediate_dim=config.intermediate_size,
    )
    print(f"{seq_len:<10} {mem['attention_mb']:<18.2f} {mem['attention_scores_mb']:<18.2f} {mem['ffn_mb']:<12.2f} {mem['total_mb']:<12.2f}")

print("\nNote: Custom attention kernel avoids storing O(N²) attention matrix!")

---

# PROMPT 2: Attention Kernel Benchmarking

## Load Custom Attention Kernel

In [None]:
# Load the custom attention module and the PyTorch reference onto the GPU.
from models.custom_attention import CustomMultiHeadAttention, create_pytorch_baseline_attention

_attn_rule = "="*70
print(_attn_rule)
print("Loading Attention Kernels")
print(_attn_rule)

# Any exception while building or moving a module is reported and leaves the
# handle set to None; later cells check for that before using the module.
custom_attn = None
try:
    _attn_module = CustomMultiHeadAttention(
        hidden_size=config.hidden_size,
        num_heads=config.num_attention_heads,
    )
    custom_attn = _attn_module.cuda().eval()
    print("\nCustom attention kernel: LOADED")
    print(f"  Using custom kernel: {custom_attn.use_custom_kernel}")
except Exception as e:
    print(f"\nCustom attention kernel error: {e}")
    custom_attn = None

pytorch_attn = None
try:
    _baseline_module = create_pytorch_baseline_attention(
        hidden_size=config.hidden_size,
        num_heads=config.num_attention_heads,
    )
    pytorch_attn = _baseline_module.cuda().eval()
    print("\nPyTorch baseline: LOADED")
except Exception as e:
    print(f"\nPyTorch baseline error: {e}")
    pytorch_attn = None

## Numerical Correctness Validation

In [None]:
# Validate correctness on a smaller test case.
# Fixed banner: "\n="*70 printed 70 separate "=" lines.
print("\n" + "="*70)
print("Numerical Correctness Test")
print("="*70)

# Compare against None explicitly: module truthiness is not a reliable
# "loaded" flag (container modules such as nn.Sequential are falsy when empty).
if custom_attn is not None and pytorch_attn is not None:
    test_seq_len = 256
    hidden_states = torch.randn(
        1, test_seq_len, config.hidden_size,
        dtype=torch.float32, device='cuda'
    )

    with torch.no_grad():
        custom_output = custom_attn(hidden_states)
        # Called as (query, key, value) and unpacked as (output, weights), which
        # matches nn.MultiheadAttention's convention. NOTE(review): if the
        # baseline is nn.MultiheadAttention, it needs batch_first=True, since
        # hidden_states here is (1, seq_len, hidden) — TODO confirm.
        pytorch_output, _ = pytorch_attn(hidden_states, hidden_states, hidden_states)

    # NOTE(review): the two modules are constructed independently in the
    # loading cell; unless their projection weights are initialized identically
    # (or copied across), a FAIL here may reflect differing weights rather than
    # a kernel bug — confirm how the weights are kept in sync.
    is_close, max_error, mean_error = validate_attention_output(
        custom_output, pytorch_output, rtol=1e-3, atol=1e-4
    )

    print(f"\nSequence length: {test_seq_len}")
    print(f"Output shape: {custom_output.shape}")
    print(f"Max error:  {max_error:.2e}")
    print(f"Mean error: {mean_error:.2e}")
    print(f"\nResult: {'PASS - Outputs match within tolerance' if is_close else 'FAIL - Outputs differ'}")
else:
    print("\nSkipping: models not loaded")

## Performance Benchmark

Benchmark across multiple sequence lengths.

In [None]:
import numpy as np


def _time_cuda_ms(fn, iters):
    """Call ``fn()`` ``iters`` times and return per-call GPU times in milliseconds.

    Each call is bracketed by a pair of CUDA timing events and followed by a
    device synchronize, so every sample covers the GPU work ``fn`` issued.
    """
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return times


def benchmark_single_seq_len(seq_len, warmup=10, iters=50):
    """Benchmark custom vs. PyTorch attention at a single sequence length.

    Args:
        seq_len: Number of tokens in the (batch_size=1, float32) input.
        warmup: Untimed iterations run first for both implementations.
        iters: Timed iterations per implementation.

    Returns:
        Dict with ``seq_len``, mean/std latency in ms for each implementation,
        and ``speedup`` (PyTorch mean / custom mean); ``None`` if either model
        failed to load.
    """
    # Explicit None checks: module truthiness is not a reliable "loaded" flag.
    if custom_attn is None or pytorch_attn is None:
        return None

    print(f"\nBenchmarking seq_len={seq_len}...")

    hidden_states = torch.randn(
        1, seq_len, config.hidden_size,
        dtype=torch.float32, device='cuda'
    )

    def run_pytorch():
        return pytorch_attn(hidden_states, hidden_states, hidden_states)

    def run_custom():
        return custom_attn(hidden_states)

    # Inference only: disable autograd once for warmup and both timed runs,
    # rather than re-entering no_grad inside every timed iteration.
    with torch.no_grad():
        for _ in range(warmup):
            run_custom()
            run_pytorch()
        torch.cuda.synchronize()

        pytorch_times = _time_cuda_ms(run_pytorch, iters)
        custom_times = _time_cuda_ms(run_custom, iters)

    # Plain floats keep the results JSON-friendly; each mean is computed once.
    pytorch_mean = float(np.mean(pytorch_times))
    custom_mean = float(np.mean(custom_times))
    return {
        'seq_len': seq_len,
        'pytorch_mean': pytorch_mean,
        'pytorch_std': float(np.std(pytorch_times)),
        'custom_mean': custom_mean,
        'custom_std': float(np.std(custom_times)),
        'speedup': pytorch_mean / custom_mean,
    }

# Run benchmarks
print("\n" + "="*70)
print("Attention Performance Benchmark")
print("="*70)

seq_lengths = [512, 1024, 2048, 4096]
results = []

for seq_len in seq_lengths:
    try:
        result = benchmark_single_seq_len(seq_len)
    except RuntimeError as e:
        # e.g. CUDA out-of-memory at long sequence lengths (torch's OOM error
        # subclasses RuntimeError): report it and continue with the next length.
        print(f"  Error: {e}")
        continue
    if result is not None:
        results.append(result)

## Results Summary

In [None]:
if results:
    print("\n" + "="*70)
    print("BENCHMARK RESULTS")
    print("="*70)
    print(f"\n{'Seq Len':<10} {'PyTorch (ms)':<15} {'Custom (ms)':<15} {'Speedup':<10}")
    print("-"*60)

    for r in results:
        # Format the "x" suffix with the number (":<10.2f}x" padded with spaces
        # before the suffix, e.g. "1.23      x").
        print(f"{r['seq_len']:<10} {r['pytorch_mean']:<15.2f} {r['custom_mean']:<15.2f} {r['speedup']:.2f}x")

    # Calculate average speedup
    avg_speedup = np.mean([r['speedup'] for r in results])
    print(f"\nAverage speedup: {avg_speedup:.2f}x")

    # Memory savings: size of the per-layer score matrix
    # (num_heads x seq_len x seq_len) that standard attention materializes,
    # assuming batch_size=1 and 4-byte float32 elements.
    print("\n" + "="*70)
    print("Memory Savings (Attention Matrix Avoided)")
    print("="*70)
    print(f"\n{'Seq Len':<10} {'Attention Matrix (MB)':<25} {'Saved':<10}")
    print("-"*50)

    for seq_len in [512, 1024, 2048, 4096]:
        attn_matrix_mb = (config.num_attention_heads * seq_len * seq_len * 4) / (1024**2)
        saved_gb = attn_matrix_mb / 1024
        # Keep "GB" attached to the number (previously padded apart).
        print(f"{seq_len:<10} {attn_matrix_mb:<25.2f} ~{saved_gb:.2f} GB")
else:
    print("\nNo results to display (GPU may not be available)")

## Save Results

In [None]:
# Persist the attention benchmark results as JSON under ./results and echo them.
if results:
    results_dir = Path('results')
    results_dir.mkdir(exist_ok=True)
    _attention_json = results_dir / 'attention_benchmark.json'

    with open(_attention_json, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {_attention_json}")
    print("\nSaved results:")
    print(json.dumps(results, indent=2))

# Load the custom FFN module and the PyTorch reference FFN onto the GPU.
from models.custom_ffn import CustomFFN, PyTorchFFN, count_ffn_operations

_ffn_rule = "="*70
print(_ffn_rule)
print("Loading FFN Kernels")
print(_ffn_rule)

# Any exception while building or moving a module is reported and leaves the
# handle set to None; later cells check for that before using the module.
custom_ffn = None
try:
    _ffn_module = CustomFFN(
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
    )
    custom_ffn = _ffn_module.cuda().eval()
    print("\nCustom FFN kernel: LOADED")
    print(f"  Using custom kernel: {custom_ffn.use_custom_kernel}")
except Exception as e:
    print(f"\nCustom FFN kernel error: {e}")
    custom_ffn = None

pytorch_ffn = None
try:
    _ffn_baseline = PyTorchFFN(
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
    )
    pytorch_ffn = _ffn_baseline.cuda().eval()
    print("\nPyTorch baseline: LOADED")
except Exception as e:
    print(f"\nPyTorch baseline error: {e}")
    pytorch_ffn = None

# Validate FFN correctness
print("\n" + "="*70)
print("FFN Numerical Correctness Test")
print("="*70)

# Compare against None explicitly: module truthiness is not a reliable
# "loaded" flag (container modules such as nn.Sequential are falsy when empty).
if custom_ffn is not None and pytorch_ffn is not None:
    test_seq_len = 256
    hidden_states = torch.randn(
        1, test_seq_len, config.hidden_size,
        dtype=torch.float32, device='cuda'
    )

    with torch.no_grad():
        custom_output = custom_ffn(hidden_states)
        pytorch_output = pytorch_ffn(hidden_states)

    # NOTE(review): validate_attention_output is reused as a generic
    # comparison returning (is_close, max_error, mean_error); confirm it makes
    # no attention-specific assumptions about the output shape.
    # NOTE(review): custom_ffn and pytorch_ffn are constructed independently in
    # the loading cell; unless their weights are initialized identically (or
    # copied across), a FAIL may reflect differing weights, not a kernel bug.
    is_close, max_error, mean_error = validate_attention_output(
        custom_output, pytorch_output, rtol=1e-3, atol=1e-4
    )

    print(f"\nSequence length: {test_seq_len}")
    print(f"Output shape: {custom_output.shape}")
    print(f"Max error:  {max_error:.2e}")
    print(f"Mean error: {mean_error:.2e}")
    print(f"\nResult: {'PASS - Outputs match within tolerance' if is_close else 'FAIL - Outputs differ'}")
else:
    print("\nSkipping: models not loaded")

def _time_cuda_ms(fn, iters):
    """Call ``fn()`` ``iters`` times and return per-call GPU times in milliseconds.

    Each call is bracketed by a pair of CUDA timing events and followed by a
    device synchronize, so every sample covers the GPU work ``fn`` issued.
    """
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return times


def benchmark_ffn_single_seq_len(seq_len, warmup=10, iters=50):
    """Benchmark custom vs. PyTorch FFN at a single sequence length.

    Args:
        seq_len: Number of tokens in the (batch_size=1, float32) input.
        warmup: Untimed iterations run first for both implementations.
        iters: Timed iterations per implementation.

    Returns:
        Dict with ``seq_len``, mean/std latency in ms for each implementation,
        and ``speedup`` (PyTorch mean / custom mean); ``None`` if either model
        failed to load.
    """
    # Explicit None checks: module truthiness is not a reliable "loaded" flag.
    if custom_ffn is None or pytorch_ffn is None:
        return None

    print(f"\nBenchmarking FFN seq_len={seq_len}...")

    hidden_states = torch.randn(
        1, seq_len, config.hidden_size,
        dtype=torch.float32, device='cuda'
    )

    def run_pytorch():
        return pytorch_ffn(hidden_states)

    def run_custom():
        return custom_ffn(hidden_states)

    # Inference only: disable autograd once for warmup and both timed runs,
    # rather than re-entering no_grad inside every timed iteration.
    with torch.no_grad():
        for _ in range(warmup):
            run_custom()
            run_pytorch()
        torch.cuda.synchronize()

        pytorch_times = _time_cuda_ms(run_pytorch, iters)
        custom_times = _time_cuda_ms(run_custom, iters)

    # Plain floats keep the results JSON-friendly; each mean is computed once.
    pytorch_mean = float(np.mean(pytorch_times))
    custom_mean = float(np.mean(custom_times))
    return {
        'seq_len': seq_len,
        'pytorch_mean': pytorch_mean,
        'pytorch_std': float(np.std(pytorch_times)),
        'custom_mean': custom_mean,
        'custom_std': float(np.std(custom_times)),
        'speedup': pytorch_mean / custom_mean,
    }

# Run FFN benchmarks
print("\n" + "="*70)
print("FFN Performance Benchmark")
print("="*70)

ffn_seq_lengths = [512, 1024, 2048, 4096]
ffn_results = []

for seq_len in ffn_seq_lengths:
    try:
        result = benchmark_ffn_single_seq_len(seq_len)
    except RuntimeError as e:
        # e.g. CUDA out-of-memory at long sequence lengths (torch's OOM error
        # subclasses RuntimeError): report it and continue with the next length.
        print(f"  Error: {e}")
        continue
    if result is not None:
        ffn_results.append(result)