# JAX/TPU Optimization Benchmarks

This notebook demonstrates the performance improvements from the JAX/TPU optimizations applied to CayleyPy.

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Callable

try:
    import jax
    import jax.numpy as jnp
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    jax = None
    jnp = None
    print("JAX not available - using numpy for tests")

In [None]:
# Load the CayleyPy JAX helpers used by the benchmarks below.
# Nothing is imported (and nothing is printed) when JAX itself is missing.
if JAX_AVAILABLE:
    # Tensor-level operations.
    from cayleypy.jax_tensor_ops import (
        unique_with_indices,
        isin_via_searchsorted,
        sort_with_indices,
        batch_matmul,
        vectorized_element_wise_equal,
        batch_isin_via_searchsorted,
        distributed_batch_matmul,
        memory_efficient_unique,
        optimized_chunked_operation,
    )

    # State hashers and their batch helpers.
    from cayleypy.jax_hasher import (
        JAXStateHasher,
        OptimizedJAXStateHasher,
        vectorized_hash_states,
        distributed_hash_states,
        memory_efficient_hash_large_batch,
    )

    # Report the runtime environment so the timings below can be interpreted.
    jax_version = jax.__version__
    print(f"JAX version: {jax_version}")
    device_list = jax.devices()
    print(f"Available devices: {device_list}")
    backend_name = jax.default_backend()
    print(f"Default backend: {backend_name}")

## Helper Functions

In [None]:
def _block_until_ready(value):
    """Wait for every asynchronously computed array found inside *value*.

    Lists, tuples and dict values are searched recursively. Any object that
    exposes a callable ``block_until_ready`` attribute (as JAX arrays do) is
    waited on; everything else is ignored.
    """
    if isinstance(value, dict):
        value = list(value.values())
    if isinstance(value, (list, tuple)):
        for item in value:
            _block_until_ready(item)
        return
    wait = getattr(value, "block_until_ready", None)
    if callable(wait):
        wait()


def time_function(func, *args, **kwargs):
    """Time a single call of ``func(*args, **kwargs)``.

    JAX dispatches work asynchronously, so a call can return before the
    computation has finished. The timer is therefore stopped only after the
    result reports that it is ready; otherwise only the dispatch overhead
    would be measured.

    Args:
        func: Callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(elapsed_seconds, result)`` tuple, where ``result`` is whatever
        ``func`` returned.
    """
    # perf_counter is monotonic and high resolution, unlike the wall clock.
    start_time = time.perf_counter()
    result = func(*args, **kwargs)
    _block_until_ready(result)
    end_time = time.perf_counter()
    return end_time - start_time, result


def _speedup(baseline_time, candidate_time):
    """Return ``baseline_time / candidate_time`` without dividing by zero."""
    if candidate_time > 0:
        return baseline_time / candidate_time
    # A timer reading of exactly zero means the call was below clock resolution.
    return float("inf") if baseline_time > 0 else float("nan")


def _results_match(xp, candidate, reference):
    """Compare two results with ``xp.array_equal``, element-wise for containers.

    Implementations may return a tuple or list of arrays with different
    shapes, which cannot be turned into a single array. Such containers are
    compared item by item instead.
    """
    containers = (tuple, list)
    if isinstance(candidate, containers) and isinstance(reference, containers):
        if len(candidate) != len(reference):
            return False
        return all(
            _results_match(xp, cand_item, ref_item)
            for cand_item, ref_item in zip(candidate, reference)
        )
    return bool(xp.array_equal(candidate, reference))


def run_comparison(name, funcs, *args, **kwargs):
    """Run and compare multiple implementations of the same function.

    The first entry of ``funcs`` is the baseline. Every implementation is
    timed once with ``time_function``; its speedup relative to the baseline
    and whether its result equals the baseline result are printed.

    Args:
        name: Label printed in the section header.
        funcs: Mapping of implementation name to callable; insertion order
            decides which one is the baseline.
        *args: Positional arguments passed to every implementation.
        **kwargs: Keyword arguments passed to every implementation.

    Returns:
        A ``(times, speedups)`` tuple of dicts keyed by implementation name.
        ``speedups`` has no entry for the baseline.

    Raises:
        ValueError: If ``funcs`` is empty.
    """
    if not funcs:
        raise ValueError("run_comparison requires at least one implementation")

    print(f"\n=== Testing {name} ===")
    results = {}
    times = {}

    baseline_impl = next(iter(funcs))

    for impl_name, func in funcs.items():
        time_taken, result = time_function(func, *args, **kwargs)
        results[impl_name] = result
        times[impl_name] = time_taken
        print(f"{impl_name}: {time_taken:.6f} seconds")

    # Calculate speedups
    baseline_time = times[baseline_impl]
    speedups = {}
    for impl_name, time_taken in times.items():
        if impl_name == baseline_impl:
            continue
        speedup = _speedup(baseline_time, time_taken)
        speedups[impl_name] = speedup
        print(f"{impl_name} speedup: {speedup:.2f}x")

    # Verify results match, using the array library that is actually available.
    xp = jnp if JAX_AVAILABLE else np
    baseline_result = results[baseline_impl]
    for impl_name, result in results.items():
        if impl_name == baseline_impl:
            continue
        try:
            match = _results_match(xp, result, baseline_result)
        except Exception as exc:
            # Best effort: a failed comparison must not abort the benchmark,
            # but the reason is reported instead of being swallowed silently.
            print(f"{impl_name} results could not be compared ({type(exc).__name__}: {exc})")
        else:
            print(f"{impl_name} results match baseline: {match}")

    return times, speedups

## Tensor Operations Benchmarks

In [None]:
if JAX_AVAILABLE:
    # Benchmark unique_with_indices against jnp.unique.
    # The input is a 13-value pattern repeated 100 times (1300 elements).
    base_pattern = [3, 1, 2, 1, 3, 2, 4, 5, 6, 7, 8, 9, 10]
    array = jnp.array(base_pattern * 100)

    def _reference_unique(x):
        # Baseline: also request the inverse indices and the counts.
        return jnp.unique(x, return_inverse=True, return_counts=True)

    def _optimized_unique(x):
        # NOTE(review): the two True flags presumably mirror return_inverse /
        # return_counts above -- confirm against unique_with_indices' signature.
        return unique_with_indices(x, True, True)

    unique_implementations = {
        "jnp.unique": _reference_unique,
        "optimized": _optimized_unique,
    }
    unique_times, unique_speedups = run_comparison(
        "unique_with_indices", unique_implementations, array
    )

In [None]:
if JAX_AVAILABLE:
    # Time a membership test at several input sizes.
    # isin_times maps implementation name -> {size: elapsed seconds}.
    sizes = [100, 1000, 10000]
    isin_times = {"standard": {}, "optimized": {}}

    for size in sizes:
        # Materialize the inputs first so their creation is not timed.
        elements = jax.block_until_ready(jnp.arange(size))
        test_elements = jax.block_until_ready(jnp.arange(0, size, 10))  # Every 10th element

        print(f"\n=== Testing isin with size {size} ===")

        # Standard implementation (only for small sizes): an element-by-element
        # Python loop, which would take too long on the largest input.
        if size <= 1000:
            start_time = time.perf_counter()
            _ = jax.block_until_ready(jnp.array([e in test_elements for e in elements]))
            end_time = time.perf_counter()
            isin_times["standard"][size] = end_time - start_time
            print(f"standard: {isin_times['standard'][size]:.6f} seconds")

        # Optimized implementation. JAX dispatches asynchronously, so wait for
        # the result before stopping the timer.
        # NOTE(review): if isin_via_searchsorted is jit-compiled, each new input
        # shape also pays compilation inside this timed region -- consider a
        # warm-up call.
        start_time = time.perf_counter()
        _ = jax.block_until_ready(isin_via_searchsorted(elements, test_elements))
        end_time = time.perf_counter()
        isin_times["optimized"][size] = end_time - start_time
        print(f"optimized: {isin_times['optimized'][size]:.6f} seconds")

        # Calculate speedup if standard is available
        if size in isin_times["standard"]:
            standard_time = isin_times["standard"][size]
            optimized_time = isin_times["optimized"][size]
            # Guard against a timer reading of exactly zero.
            speedup = standard_time / optimized_time if optimized_time > 0 else float("inf")
            print(f"Speedup: {speedup:.2f}x")

In [None]:
if JAX_AVAILABLE:
    # Log-log plot of the isin timings, one line per implementation.
    plt.figure(figsize=(10, 6))

    for series_label in ("standard", "optimized"):
        timings_by_size = isin_times[series_label]
        ordered_sizes = sorted(timings_by_size)
        ordered_timings = [timings_by_size[n] for n in ordered_sizes]
        plt.plot(ordered_sizes, ordered_timings, marker='o', label=series_label)

    plt.title("isin Performance Comparison")
    plt.xlabel("Array Size")
    plt.ylabel("Time (seconds)")
    # Both axes are logarithmic.
    plt.xscale('log')
    plt.yscale('log')
    plt.grid(True, which="both", ls="--", alpha=0.3)
    plt.legend()
    plt.show()

In [None]:
if JAX_AVAILABLE:
    # Compare a per-row Python loop with the batched isin helper on a 4x3 input.
    elements_batch = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    test_elements = jnp.array([2, 5, 8, 11])

    def _isin_row_by_row(x, y):
        # Baseline: one isin_via_searchsorted call per row, stacked afterwards.
        row_masks = [isin_via_searchsorted(row, y) for row in x]
        return jnp.stack(row_masks)

    def _isin_batched(x, y):
        # Candidate: a single call over the whole batch.
        return batch_isin_via_searchsorted(x, y)

    batch_implementations = {
        "loop": _isin_row_by_row,
        "vectorized": _isin_batched,
    }
    batch_times, batch_speedups = run_comparison(
        "batch_isin", batch_implementations, elements_batch, test_elements
    )

In [None]:
if JAX_AVAILABLE:
    # Time square matrix multiplication at several sizes.
    # matmul_times maps implementation name -> {size: elapsed seconds}.
    sizes = [10, 100, 500, 1000]
    matmul_times = {"standard": {}, "optimized": {}}

    for size in sizes:
        # Materialize the operands first so their creation is not timed.
        a = jax.block_until_ready(jnp.ones((size, size)))
        b = jax.block_until_ready(jnp.ones((size, size)))

        print(f"\n=== Testing matmul with size {size}x{size} ===")

        # Standard implementation. JAX dispatches asynchronously, so wait for
        # the product before stopping the timer.
        start_time = time.perf_counter()
        _ = jax.block_until_ready(jnp.matmul(a, b))
        end_time = time.perf_counter()
        matmul_times["standard"][size] = end_time - start_time
        print(f"standard: {matmul_times['standard'][size]:.6f} seconds")

        # Optimized implementation, timed the same way.
        start_time = time.perf_counter()
        _ = jax.block_until_ready(batch_matmul(a, b))
        end_time = time.perf_counter()
        matmul_times["optimized"][size] = end_time - start_time
        print(f"optimized: {matmul_times['optimized'][size]:.6f} seconds")

        # Calculate speedup, guarding against a timer reading of exactly zero.
        standard_time = matmul_times["standard"][size]
        optimized_time = matmul_times["optimized"][size]
        speedup = standard_time / optimized_time if optimized_time > 0 else float("inf")
        print(f"Speedup: {speedup:.2f}x")

In [None]:
if JAX_AVAILABLE:
    # Linear-scale plot of the matmul timings, one line per implementation.
    plt.figure(figsize=(10, 6))

    for series_label in ("standard", "optimized"):
        timings_by_size = matmul_times[series_label]
        ordered_sizes = sorted(timings_by_size)
        ordered_timings = [timings_by_size[n] for n in ordered_sizes]
        plt.plot(ordered_sizes, ordered_timings, marker='o', label=series_label)

    plt.title("Matrix Multiplication Performance Comparison")
    plt.xlabel("Matrix Size")
    plt.ylabel("Time (seconds)")
    plt.grid(True, which="both", ls="--", alpha=0.3)
    plt.legend()
    plt.show()

## Hash Function Benchmarks

In [None]:
if JAX_AVAILABLE:
    # Time three ways of hashing a batch of states at several batch sizes.
    # hash_times maps implementation name -> {batch size: elapsed seconds}.
    sizes = [100, 1000, 10000, 100000]
    state_size = 10
    hash_times = {"standard": {}, "vectorized": {}, "optimized": {}}

    # Create hashers (both built with the same state size and seed).
    standard_hasher = JAXStateHasher(state_size=state_size, random_seed=42)
    optimized_hasher = OptimizedJAXStateHasher(state_size=state_size, random_seed=42)

    for size in sizes:
        states = jnp.ones((size, state_size), dtype=jnp.int32)
        states = states.at[:, 0].set(jnp.arange(size))  # Make states unique
        # Materialize the input first so its creation is not timed.
        states = jax.block_until_ready(states)

        print(f"\n=== Testing hash_states with batch size {size} ===")

        # Standard implementation. JAX dispatches asynchronously, so each
        # result is waited on before its timer is stopped.
        start_time = time.perf_counter()
        standard_result = jax.block_until_ready(standard_hasher.hash_states(states))
        end_time = time.perf_counter()
        hash_times["standard"][size] = end_time - start_time
        print(f"standard: {hash_times['standard'][size]:.6f} seconds")

        # Vectorized implementation
        start_time = time.perf_counter()
        vectorized_result = jax.block_until_ready(vectorized_hash_states(states, standard_hasher))
        end_time = time.perf_counter()
        hash_times["vectorized"][size] = end_time - start_time
        print(f"vectorized: {hash_times['vectorized'][size]:.6f} seconds")

        # Optimized implementation
        start_time = time.perf_counter()
        optimized_result = jax.block_until_ready(optimized_hasher.hash_states_optimized(states))
        end_time = time.perf_counter()
        hash_times["optimized"][size] = end_time - start_time
        print(f"optimized: {hash_times['optimized'][size]:.6f} seconds")

        # Calculate speedups, guarding against a timer reading of exactly zero.
        standard_time = hash_times["standard"][size]
        vectorized_time = hash_times["vectorized"][size]
        optimized_time = hash_times["optimized"][size]
        vectorized_speedup = standard_time / vectorized_time if vectorized_time > 0 else float("inf")
        optimized_speedup = standard_time / optimized_time if optimized_time > 0 else float("inf")
        print(f"Vectorized speedup: {vectorized_speedup:.2f}x")
        print(f"Optimized speedup: {optimized_speedup:.2f}x")

        # Verify results match
        vectorized_match = jnp.array_equal(standard_result, vectorized_result)
        optimized_match = jnp.array_equal(standard_result, optimized_result)
        print(f"Vectorized results match: {vectorized_match}")
        print(f"Optimized results match: {optimized_match}")

In [None]:
if JAX_AVAILABLE:
    # Log-log plot of the hashing timings, one line per implementation.
    plt.figure(figsize=(10, 6))

    for method in hash_times:
        timings_by_size = hash_times[method]
        ordered_sizes = sorted(timings_by_size)
        time_values = [timings_by_size[n] for n in ordered_sizes]
        plt.plot(ordered_sizes, time_values, marker='o', label=method)

    plt.title("Hash Function Performance Comparison")
    plt.xlabel("Batch Size")
    plt.ylabel("Time (seconds)")
    # Both axes are logarithmic.
    plt.xscale('log')
    plt.yscale('log')
    plt.grid(True, which="both", ls="--", alpha=0.3)
    plt.legend()
    plt.show()

In [None]:
if JAX_AVAILABLE:
    # Plot the speedup of each alternative hasher over the standard one.
    plt.figure(figsize=(10, 6))

    sizes = sorted(hash_times["standard"])
    standard_timings = hash_times["standard"]

    for series_label in ("vectorized", "optimized"):
        candidate_timings = hash_times[series_label]
        # Speedup = standard time divided by the candidate's time at each size.
        ratios = [standard_timings[n] / candidate_timings[n] for n in sizes]
        plt.plot(sizes, ratios, marker='o', label=series_label)

    plt.title("Hash Function Speedup Comparison")
    plt.xlabel("Batch Size")
    plt.ylabel("Speedup (x)")
    # Only the x axis is logarithmic here.
    plt.xscale('log')
    plt.grid(True, which="both", ls="--", alpha=0.3)
    plt.legend()
    plt.show()

## Memory Efficiency Tests

In [None]:
if JAX_AVAILABLE:
    # The baseline below calls chunked_operation, which no other cell imports;
    # without this import the cell fails with a NameError.
    # NOTE(review): assumes chunked_operation lives next to
    # optimized_chunked_operation in cayleypy.jax_tensor_ops -- confirm.
    from cayleypy.jax_tensor_ops import chunked_operation

    # Test memory-efficient operations
    array_size = 10000
    chunk_size = 1000
    array = jnp.arange(array_size)

    def sum_operation(chunk):
        """Sum a chunk along its first axis."""
        return jnp.sum(chunk, axis=0)

    chunked_times, chunked_speedups = run_comparison("chunked_operation", {
        "standard": lambda x: chunked_operation(x, sum_operation, chunk_size),
        "optimized": lambda x: optimized_chunked_operation(x, sum_operation, chunk_size, True)
    }, array)

In [None]:
if JAX_AVAILABLE:
    # Benchmark memory_efficient_unique against jnp.unique.
    # The input is the range [0, array_size // 10) repeated ten times.
    array_size = 100000
    repeated_block = jnp.arange(array_size // 10)
    array = jnp.concatenate([repeated_block] * 10)  # Repeated elements

    def _unique_with_small_budget(x):
        # Small memory limit to force chunking
        return memory_efficient_unique(x, 0.1)

    unique_implementations = {
        "standard": jnp.unique,
        "memory_efficient": _unique_with_small_budget,
    }
    unique_times, unique_speedups = run_comparison(
        "memory_efficient_unique", unique_implementations, array
    )

## Summary of Performance Improvements

In [None]:
if JAX_AVAILABLE:
    def _speedups_by_size(timings, baseline_key, candidate_key):
        """Return {size: baseline time / candidate time} for sizes timed by both.

        Sizes whose candidate time is not positive are skipped, because no
        ratio can be computed for them.
        """
        baseline = timings[baseline_key]
        candidate = timings[candidate_key]
        return {
            size: baseline[size] / candidate[size]
            for size in baseline
            if size in candidate and candidate[size] > 0
        }

    # Collect all speedups
    all_speedups = {
        "isin": _speedups_by_size(isin_times, "standard", "optimized"),
        "matmul": _speedups_by_size(matmul_times, "standard", "optimized"),
        "hash_vectorized": _speedups_by_size(hash_times, "standard", "vectorized"),
        "hash_optimized": _speedups_by_size(hash_times, "standard", "optimized"),
    }

    # Print summary table
    print("\n=== Summary of Performance Improvements ===")
    print("Operation | Size | Speedup")
    print("-" * 40)

    for operation, speedups in all_speedups.items():
        for size, speedup in speedups.items():
            print(f"{operation} | {size} | {speedup:.2f}x")

    # Calculate average speedups. Operations without any measurement are left
    # out; averaging an empty dict would divide by zero.
    avg_speedups = {
        operation: sum(speedups.values()) / len(speedups)
        for operation, speedups in all_speedups.items()
        if speedups
    }

    print("\n=== Average Speedups ===")
    for operation, avg_speedup in avg_speedups.items():
        print(f"{operation}: {avg_speedup:.2f}x")

In [None]:
if JAX_AVAILABLE:
    # Bar chart of the average speedup per operation.
    plt.figure(figsize=(10, 6))

    plt.title("Average Speedup by Operation")
    plt.xlabel("Operation")
    plt.ylabel("Average Speedup (x)")

    operations = list(avg_speedups)
    speedup_values = [avg_speedups[operation_name] for operation_name in operations]

    plt.bar(operations, speedup_values)
    # Reference line at 1x, i.e. no speedup.
    plt.axhline(y=1.0, color='r', linestyle='-', alpha=0.3)
    plt.grid(axis='y', alpha=0.3)

    # Write each value slightly above its bar.
    for bar_index, bar_height in enumerate(speedup_values):
        label_text = f"{bar_height:.2f}x"
        plt.text(bar_index, bar_height + 0.1, label_text, ha='center')

    plt.tight_layout()
    plt.show()

## Conclusion

The JAX/TPU optimizations have significantly improved the performance of CayleyPy's tensor operations and hash functions. Key improvements include:

1. **Vectorization**: Using `vmap` for batch operations provides substantial speedups
2. **JIT Compilation**: Proper use of `@jit` and static shapes improves compilation efficiency
3. **Memory Efficiency**: Chunked operations allow processing of larger datasets
4. **TPU Optimization**: Specialized implementations for TPU hardware

These optimizations make CayleyPy more efficient for large-scale graph processing on TPU hardware.