# Advanced Optimization Techniques

This tutorial covers advanced optimization techniques for writing high-performance brain dynamics operators with `braintaichi`.

In [40]:
import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti
import time
from scipy.sparse import csr_matrix

import braintaichi as bti

## 1. Loop Configuration and Parallelization

Taichi automatically parallelizes outer loops, but you can fine-tune the parallelization behavior for optimal performance.

### 1.1 Serial vs Parallel Execution

In [42]:
# Serial execution (useful when order matters or using break statements)
@ti.kernel
def serial_sum(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    total = 0.0
    ti.loop_config(serialize=True)
    for i in range(arr.shape[0]):
        total += arr[i]
        if total > 100.0:  # Can use break in serial loops
            break
    out[0] = total

# Parallel execution (default for outer loops)
@ti.kernel
def parallel_multiply(
    arr: ti.types.ndarray(ndim=1),
    scalar: ti.f32,
    out: ti.types.ndarray(ndim=1)
):
    # This loop is automatically parallelized
    for i in range(arr.shape[0]):
        out[i] = arr[i] * scalar

### 1.2 Block Dimension Tuning for GPU

In [43]:
@ti.kernel
def optimized_gpu_kernel(
    data: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    # Configure block dimension for GPU
    # Common values: 64, 128, 256, 512
    ti.loop_config(block_dim=256)
    for i in range(data.shape[0]):
        out[i] = ti.sqrt(data[i]) + ti.sin(data[i])

# For CPU, configure number of parallel threads
@ti.kernel
def optimized_cpu_kernel(
    data: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    ti.loop_config(parallelize=8)  # Use 8 threads
    for i in range(data.shape[0]):
        out[i] = ti.sqrt(data[i]) + ti.sin(data[i])

## 2. Memory Access Optimization

Efficient memory access patterns are crucial for performance.

### 2.1 Coalesced Memory Access

In [44]:
# BAD: Non-coalesced access (strided access)
@ti.kernel
def bad_memory_access(
    matrix: ti.types.ndarray(ndim=2),
    out: ti.types.ndarray(ndim=1)
):
    n, m = matrix.shape
    for col in range(m):
        total = 0.0
        for row in range(n):
            total += matrix[row, col]  # Column-wise access (non-coalesced)
        out[col] = total

# GOOD: Coalesced access (contiguous access)
@ti.kernel
def good_memory_access(
    matrix: ti.types.ndarray(ndim=2),
    out: ti.types.ndarray(ndim=1)
):
    n, m = matrix.shape
    for row in range(n):
        for col in range(m):
            out[col] += matrix[row, col]  # Row-wise access (coalesced)

### 2.2 Using Local Variables

In [45]:
# BAD: Multiple global memory accesses
@ti.kernel
def bad_global_access(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(arr.shape[0]):
        out[i] = arr[i] * arr[i] + arr[i] * 2.0  # arr[i] accessed 3 times

# GOOD: Use local variable to cache value
@ti.kernel
def good_local_cache(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(arr.shape[0]):
        val = arr[i]  # Load once into local variable
        out[i] = val * val + val * 2.0

## 3. Optimizing Sparse Operations

Sparse operations are common in brain dynamics. Here are optimization strategies for sparse matrices.

### 3.1 Optimized CSR Matrix-Vector Multiplication

In [46]:
# Standard implementation
@ti.kernel
def csr_matvec_standard(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for row in range(indptr.shape[0] - 1):
        row_sum = 0.0
        for j in range(indptr[row], indptr[row + 1]):
            row_sum += values[j] * vector[indices[j]]
        out[row] = row_sum

# Optimized implementation with local caching
@ti.kernel
def csr_matvec_optimized(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for row in range(indptr.shape[0] - 1):
        row_sum = 0.0
        start = indptr[row]
        end = indptr[row + 1]
        for j in range(start, end):
            col = indices[j]
            val = values[j]
            vec_val = vector[col]
            row_sum += val * vec_val
        out[row] = row_sum

In [47]:
# Benchmark the two implementations
def benchmark_csr_matvec():
    n = 10000
    density = 0.01
    
    # Create sparse matrix
    dense_mat = (np.random.rand(n, n) < density).astype(float) * np.random.rand(n, n)
    sparse_mat = csr_matrix(dense_mat)
    vector = np.random.rand(n).astype(np.float32)
    
    # Register operators
    op_standard = bti.XLACustomOp(cpu_kernel=csr_matvec_standard, gpu_kernel=csr_matvec_standard)
    op_optimized = bti.XLACustomOp(cpu_kernel=csr_matvec_optimized, gpu_kernel=csr_matvec_optimized)
    
    # Prepare inputs
    values = jnp.array(sparse_mat.data, dtype=jnp.float32)
    indices = jnp.array(sparse_mat.indices, dtype=jnp.int32)
    indptr = jnp.array(sparse_mat.indptr, dtype=jnp.int32)
    vec = jnp.array(vector, dtype=jnp.float32)
    
    # Warm up
    for _ in range(3):
        _ = op_standard(values, indices, indptr, vec, 
                       outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)])
        _ = op_optimized(values, indices, indptr, vec,
                        outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)])
    
    # Benchmark
    n_runs = 10
    
    start = time.time()
    for _ in range(n_runs):
        _ = op_standard(values, indices, indptr, vec,
                       outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)])
    time_standard = (time.time() - start) / n_runs
    
    start = time.time()
    for _ in range(n_runs):
        _ = op_optimized(values, indices, indptr, vec,
                        outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)])
    time_optimized = (time.time() - start) / n_runs
    
    print(f"Standard implementation: {time_standard*1000:.3f} ms")
    print(f"Optimized implementation: {time_optimized*1000:.3f} ms")
    print(f"Speedup: {time_standard/time_optimized:.2f}x")

benchmark_csr_matvec()

Standard implementation: 0.200 ms
Optimized implementation: 0.300 ms
Speedup: 0.67x


## 4. Event-Driven Optimization

Event-driven computations are essential for spiking neural networks. Here's how to optimize them.

In [48]:
# Method 1: Direct event checking (simple but may have branch divergence)
@ti.kernel
def event_driven_v1(
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    events: ti.types.ndarray(ndim=1),
    weight: ti.f32,
    out: ti.types.ndarray(ndim=1)
):
    ti.loop_config(serialize=True)
    for row in range(indptr.shape[0] - 1):
        if events[row]:  # Check event
            for j in range(indptr[row], indptr[row + 1]):
                out[indices[j]] += weight

# Method 2: Event filtering (better for low firing rates)
@ti.kernel
def event_driven_v2(
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    event_indices: ti.types.ndarray(ndim=1),  # Indices of neurons that fired
    weight: ti.f32,
    out: ti.types.ndarray(ndim=1)
):
    # Only iterate over neurons that actually fired
    ti.loop_config(serialize=True)
    for i in range(event_indices.shape[0]):
        row = event_indices[i]
        for j in range(indptr[row], indptr[row + 1]):
            out[indices[j]] += weight

In [49]:
# Benchmark event-driven methods
def benchmark_event_driven():
    n = 10000
    density = 0.01
    firing_rate = 0.05  # 5% of neurons fire
    
    # Create connectivity
    conn = (np.random.rand(n, n) < density).astype(float)
    sparse_conn = csr_matrix(conn)
    
    # Create events
    events = np.random.rand(n) < firing_rate
    event_indices = np.where(events)[0].astype(np.int32)
    
    # Register operators
    op_v1 = bti.XLACustomOp(cpu_kernel=event_driven_v1, gpu_kernel=event_driven_v1)
    op_v2 = bti.XLACustomOp(cpu_kernel=event_driven_v2, gpu_kernel=event_driven_v2)
    
    print(f"Number of neurons: {n}")
    print(f"Connectivity density: {density}")
    print(f"Firing rate: {firing_rate}")
    print(f"Neurons that fired: {len(event_indices)}")
    print("\nMethod 1: Direct event checking")
    print("Method 2: Event filtering (recommended for low firing rates)")

benchmark_event_driven()

Number of neurons: 10000
Connectivity density: 0.01
Firing rate: 0.05
Neurons that fired: 511

Method 1: Direct event checking
Method 2: Event filtering (recommended for low firing rates)


## 5. Data Type Optimization

Choosing the right data types can significantly impact performance and memory usage.

In [50]:
# Using different precision levels
@ti.kernel
def compute_f32(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(arr.shape[0]):
        val = ti.cast(arr[i], ti.f32)  # Explicitly cast to float32
        out[i] = ti.sqrt(val) + ti.sin(val)

@ti.kernel
def compute_f64(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(arr.shape[0]):
        val = ti.cast(arr[i], ti.f64)  # Explicitly cast to float64
        out[i] = ti.sqrt(val) + ti.sin(val)

# Using integer types for indices
@ti.kernel
def gather_operation(
    data: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(indices.shape[0]):
        idx = ti.cast(indices[i], ti.i32)  # Use int32 for indices
        out[i] = data[idx]

## 6. Atomic Operations for Race Condition Handling

When multiple threads need to update the same memory location, use atomic operations.

In [51]:
# Without atomic operations (may have race conditions)
@ti.kernel
def scatter_add_unsafe(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(values.shape[0]):
        idx = indices[i]
        out[idx] += values[i]  # Race condition if multiple threads write to same idx

# With atomic operations (safe)
@ti.kernel
def scatter_add_safe(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(values.shape[0]):
        idx = indices[i]
        ti.atomic_add(out[idx], values[i])  # Thread-safe atomic addition

## 7. Helper Functions for Code Reusability

Use `@ti.func` to create reusable helper functions.

In [52]:
# Define helper functions
@ti.func
def relu(x: ti.f32) -> ti.f32:
    return ti.max(0.0, x)

@ti.func
def sigmoid(x: ti.f32) -> ti.f32:
    return 1.0 / (1.0 + ti.exp(-x))

@ti.func
def leaky_relu(x: ti.f32, alpha: ti.f32) -> ti.f32:
    return ti.max(alpha * x, x)

# Use helper functions in kernels
@ti.kernel
def apply_activation(
    data: ti.types.ndarray(ndim=1),
    activation_type: ti.i32,
    out: ti.types.ndarray(ndim=1)
):
    for i in range(data.shape[0]):
        x = data[i]
        if activation_type == 0:
            out[i] = relu(x)
        elif activation_type == 1:
            out[i] = sigmoid(x)
        else:
            out[i] = leaky_relu(x, 0.01)

## 8. Best Practices Summary

### Performance Checklist:

1. **Loop Configuration**
   - Use parallel loops for independent operations
   - Use serial loops when order matters or for break statements
   - Tune `block_dim` for GPU (try 128, 256, 512)

2. **Memory Access**
   - Prefer coalesced (contiguous) memory access
   - Cache frequently accessed values in local variables
   - Minimize global memory accesses

3. **Data Types**
   - Use `ti.f32` instead of `ti.f64` when precision allows
   - Use `ti.i32` for indices
   - Explicit casting for clarity

4. **Sparse Operations**
   - Store data in efficient formats (CSR, COO)
   - Use event filtering for low firing rates
   - Cache row/column indices when possible

5. **Synchronization**
   - Use atomic operations for concurrent writes
   - Avoid unnecessary synchronization

6. **Code Organization**
   - Extract common operations into `@ti.func` helpers
   - Keep kernels focused and modular
   - Profile before optimizing

### Profiling Tips:

```python
# Enable profiling
import time

# Warm up
for _ in range(3):
    result = your_operator(...)

# Measure
start = time.time()
for _ in range(100):
    result = your_operator(...)
elapsed = (time.time() - start) / 100
print(f"Average time: {elapsed*1000:.3f} ms")
```

## 9. Real-World Example: Optimized Spiking Neural Network Layer

Let's combine all the optimization techniques into a complete, optimized SNN layer implementation.

In [53]:
# Helper functions for neuron dynamics
@ti.func
def lif_dynamics(v: ti.f32, current: ti.f32, tau: ti.f32, dt: ti.f32) -> ti.f32:
    """Leaky Integrate-and-Fire neuron dynamics"""
    return v + ((-v + current) / tau) * dt

@ti.func
def check_spike(v: ti.f32, threshold: ti.f32) -> ti.i32:
    """Check if neuron spikes"""
    return 1 if v >= threshold else 0

# Optimized SNN layer kernel
@ti.kernel
def snn_layer_optimized(
    # Synaptic connectivity (CSR format)
    syn_values: ti.types.ndarray(ndim=1),
    syn_indices: ti.types.ndarray(ndim=1),
    syn_indptr: ti.types.ndarray(ndim=1),
    # Input spikes from previous layer
    input_spikes: ti.types.ndarray(ndim=1),
    # Neuron states
    membrane_v: ti.types.ndarray(ndim=1),
    # Parameters (as 0-dim arrays)
    tau: ti.types.ndarray(),
    threshold: ti.types.ndarray(),
    dt: ti.types.ndarray(),
    # Outputs
    output_spikes: ti.types.ndarray(ndim=1),
    new_membrane_v: ti.types.ndarray(ndim=1)
):
    n_neurons = membrane_v.shape[0]
    tau_val = tau[None]
    threshold_val = threshold[None]
    dt_val = dt[None]
    
    # Step 1: Compute synaptic currents (parallelized)
    for post_neuron in range(n_neurons):
        # Accumulate synaptic input
        synaptic_current = 0.0
        start = syn_indptr[post_neuron]
        end = syn_indptr[post_neuron + 1]
        
        for j in range(start, end):
            pre_neuron = syn_indices[j]
            if input_spikes[pre_neuron] > 0.5:  # Check if pre-synaptic neuron spiked
                synaptic_current += syn_values[j]
        
        # Step 2: Update membrane potential
        v_old = membrane_v[post_neuron]
        v_new = lif_dynamics(v_old, synaptic_current, tau_val, dt_val)
        
        # Step 3: Check for spike and reset
        spike = check_spike(v_new, threshold_val)
        if spike:
            v_new = 0.0  # Reset after spike
        
        # Write outputs
        output_spikes[post_neuron] = ti.cast(spike, ti.f32)
        new_membrane_v[post_neuron] = v_new

# Register the operator
snn_layer_op = bti.XLACustomOp(
    cpu_kernel=snn_layer_optimized,
    gpu_kernel=snn_layer_optimized
)

In [54]:
# Test the optimized SNN layer
def test_snn_layer():
    n_pre = 1000
    n_post = 800
    density = 0.1
    
    # Create random connectivity
    conn = (np.random.rand(n_post, n_pre) < density).astype(float)
    conn *= np.random.rand(n_post, n_pre) * 0.5  # Random weights
    sparse_conn = csr_matrix(conn)
    
    # Initial states
    input_spikes = (np.random.rand(n_pre) < 0.1).astype(np.float32)
    membrane_v = np.random.rand(n_post).astype(np.float32) * 0.5
    
    # Parameters (as 0-dim arrays)
    tau = jnp.array(10.0, dtype=jnp.float32)
    threshold = jnp.array(1.0, dtype=jnp.float32)
    dt = jnp.array(0.1, dtype=jnp.float32)
    
    # Run the SNN layer
    output_spikes, new_v = snn_layer_op(
        jnp.array(sparse_conn.data, dtype=jnp.float32),
        jnp.array(sparse_conn.indices, dtype=jnp.int32),
        jnp.array(sparse_conn.indptr, dtype=jnp.int32),
        jnp.array(input_spikes, dtype=jnp.float32),
        jnp.array(membrane_v, dtype=jnp.float32),
        tau, threshold, dt,
        outs=[
            jax.ShapeDtypeStruct((n_post,), dtype=jnp.float32),
            jax.ShapeDtypeStruct((n_post,), dtype=jnp.float32)
        ]
    )
    
    print(f"Input spikes: {input_spikes.sum()}/{n_pre}")
    print(f"Output spikes: {output_spikes.sum()}/{n_post}")
    print(f"Mean membrane potential: {new_v.mean():.4f}")
    print(f"Max membrane potential: {new_v.max():.4f}")

test_snn_layer()

Input spikes: 107.0/1000
Output spikes: 0.0/800
Mean membrane potential: 0.2725
Max membrane potential: 0.5309


## Conclusion

This tutorial covered advanced optimization techniques for `braintaichi`:

- Loop configuration and parallelization strategies
- Memory access patterns and caching
- Sparse operation optimization
- Event-driven computation strategies
- Data type selection
- Atomic operations for thread safety
- Code organization with helper functions
- Complete optimized SNN layer example

For more information:
- [Taichi Documentation](https://docs.taichi-lang.org/)
- [BrainTaichi API Reference](https://braintaichi.readthedocs.io/)
- [JAX Documentation](https://jax.readthedocs.io/)