# Quick Start Guide

This tutorial walks through the basics of using `braintaichi` to write high-performance brain dynamics operators as Taichi kernels and call them from JAX.

## Installation

First, make sure you have installed `braintaichi`:

```bash
pip install braintaichi
```

Or install from source:

```bash
git clone https://github.com/chaoming0625/braintaichi.git
cd braintaichi
pip install -e .
```

## Import Required Libraries

Let's start by importing the necessary libraries:

```python
import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti
from scipy.sparse import csr_matrix

import braintaichi as bti
```

## Example 1: Simple Vector Addition

We begin with a simple example: element-wise vector addition, implemented as a Taichi kernel and registered as a custom operator.

```python
@ti.kernel
def vector_add(
    a: ti.types.ndarray(ndim=1),
    b: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(a.shape[0]):
        out[i] = a[i] + b[i]

# Register the custom operator
vector_add_op = bti.XLACustomOp(
    cpu_kernel=vector_add,
    gpu_kernel=vector_add
)
```

```python
# Test the operator
n = 10
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.ones(n, dtype=jnp.float32)

# `outs` declares the shape and dtype of each output buffer;
# the call returns a list with one array per declared output
result = vector_add_op(
    a, b,
    outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)]
)

print("Input a:", a)
print("Input b:", b)
print("Result:", result)
```

```text
Input a: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
Input b: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Result: [Array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], dtype=float32)]
```


## Example 2: Sparse Matrix-Vector Multiplication

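Synaptic connectivity in brain networks is typically sparse: each neuron connects to only a small fraction of the others. Let's implement a sparse matrix-vector multiplication operator using the CSR (Compressed Sparse Row) format, which stores a matrix as three arrays: `values` (the non-zero entries, row by row), `indices` (the column index of each entry), and `indptr` (where each row's entries start and end in the other two arrays).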
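To make the three CSR arrays concrete, here is a small NumPy sketch that builds them for a 3×4 matrix by scanning it row by row. It mirrors what `scipy.sparse.csr_matrix` stores in its `data`, `indices`, and `indptr` attributes; the helper name `dense_to_csr` is hypothetical and used only for illustration.

```python
import numpy as np


def dense_to_csr(dense):
    """Hypothetical helper: convert a 2-D dense array into CSR arrays."""
    values, indices, indptr = [], [], [0]
    for row in dense:
        cols = np.nonzero(row)[0]      # column indices of the non-zeros in this row
        values.extend(row[cols])       # their values, in column order
        indices.extend(cols)
        indptr.append(len(values))     # row i occupies values[indptr[i]:indptr[i + 1]]
    return (np.array(values, dtype=dense.dtype),
            np.array(indices, dtype=np.int32),
            np.array(indptr, dtype=np.int32))


dense = np.array([[0., 2., 0., 1.],
                  [0., 0., 0., 0.],
                  [3., 0., 4., 0.]])
values, indices, indptr = dense_to_csr(dense)
print(values)   # [2. 1. 3. 4.]
print(indices)  # [1 3 0 2]
print(indptr)   # [0 2 2 4]
```

Row 1 has no non-zeros, so `indptr[1] == indptr[2]` and the kernel's inner loop `range(indptr[row], indptr[row + 1])` simply runs zero times for that row.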

```python
@ti.kernel
def csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    # Iterate over each row
    for row in range(indptr.shape[0] - 1):
        row_sum = 0.0
        # Iterate over non-zero elements in the row
        for j in range(indptr[row], indptr[row + 1]):
            col = indices[j]
            row_sum += values[j] * vector[col]
        out[row] = row_sum

# Register the operator
csr_matvec_op = bti.XLACustomOp(
    cpu_kernel=csr_matvec,
    gpu_kernel=csr_matvec
)
```
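Because the Taichi kernel is just two nested loops, a plain NumPy version of the same loops makes a convenient reference for testing. The sketch below, under the hypothetical name `csr_matvec_reference`, follows `csr_matvec` line by line and is checked against a dense product on a tiny matrix.

```python
import numpy as np


def csr_matvec_reference(values, indices, indptr, vector):
    """Hypothetical pure-NumPy reference mirroring the csr_matvec kernel."""
    n_rows = indptr.shape[0] - 1
    out = np.zeros(n_rows, dtype=values.dtype)
    for row in range(n_rows):                      # one output entry per row
        row_sum = 0.0
        for j in range(indptr[row], indptr[row + 1]):
            row_sum += values[j] * vector[indices[j]]
        out[row] = row_sum
    return out


# CSR form of [[0, 2, 0, 1], [0, 0, 0, 0], [3, 0, 4, 0]]
values = np.array([2., 1., 3., 4.])
indices = np.array([1, 3, 0, 2])
indptr = np.array([0, 2, 2, 4])
vector = np.array([1., 2., 3., 4.])

dense = np.array([[0., 2., 0., 1.],
                  [0., 0., 0., 0.],
                  [3., 0., 4., 0.]])
result = csr_matvec_reference(values, indices, indptr, vector)
print(result)                           # [ 8.  0. 15.]
print(np.allclose(result, dense @ vector))  # True
```

Comparing a custom operator against a reference like this on small inputs catches indexing mistakes before scaling up, in the same spirit as the SciPy comparison in the next cell.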

```python
# Create a sparse matrix
n_rows, n_cols = 100, 100
density = 0.1
dense_matrix = (np.random.rand(n_rows, n_cols) < density).astype(float)
dense_matrix *= np.random.rand(n_rows, n_cols)

# Convert to CSR format
sparse_matrix = csr_matrix(dense_matrix)

# Create input vector
input_vector = np.random.rand(n_cols).astype(np.float32)

# Run the custom operator
result = csr_matvec_op(
    jnp.array(sparse_matrix.data, dtype=jnp.float32),
    jnp.array(sparse_matrix.indices, dtype=jnp.int32),
    jnp.array(sparse_matrix.indptr, dtype=jnp.int32),
    jnp.array(input_vector, dtype=jnp.float32),
    outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)]
)

# Verify the result
expected = sparse_matrix @ input_vector
print("Custom operator result:", result[0][:5])
print("Expected result:", expected[:5])
print("Maximum difference:", np.max(np.abs(np.array(result[0]) - expected)))
```

```text
Custom operator result: [1.8271513 2.4389558 1.7941285 2.318161  1.7279923]
Expected result: [1.82715139 2.43895573 1.79412849 2.31816102 1.72799216]
Maximum difference: 3.8372544519660323e-07
```
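The small non-zero difference is a precision effect: the operator receives float32 inputs and writes a float32 output, while SciPy computes `expected` in float64 (the matrix was built with `astype(float)`). As a scale check, float32 values between 2 and 4 are spaced 2⁻²² ≈ 2.4e-7 apart, so a maximum difference of about 3.8e-7 is on the order of a couple of rounding steps. The NumPy sketch below reproduces the effect on random data of a similar shape; the seed and sizes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)               # arbitrary seed
n_rows, nnz_per_row = 100, 10                # roughly Example 2: 100 columns, density 0.1
weights = rng.random((n_rows, nnz_per_row))  # non-zero values of each row
inputs = rng.random((n_rows, nnz_per_row))   # vector entries they multiply

# Row sums accumulated in float32 (like the operator) and in float64 (like SciPy)
sums32 = (weights.astype(np.float32) * inputs.astype(np.float32)).sum(axis=1, dtype=np.float32)
sums64 = (weights * inputs).sum(axis=1)

print("float32 spacing in [2, 4):", np.spacing(np.float32(2.5)))  # 2**-22, about 2.4e-7
print("max |float32 - float64|:", np.max(np.abs(sums32.astype(np.float64) - sums64)))
```

If float64 accuracy matters, both the input arrays and the `outs` dtype would need to be float64.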


## Example 3: Event-Driven Computation

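Spiking networks are often event-driven: synaptic input changes only when presynaptic neurons fire, and in any time step only a few of them do. Let's implement an event-driven synaptic transmission operator that processes only the rows of neurons that fired.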
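The kernel below treats row `i` of the CSR matrix as the outgoing synapses of presynaptic neuron `i`: when that neuron fires, each weight `values[j]` is added to postsynaptic neuron `indices[j]`. In matrix terms this is the transposed product `W.T @ s` with a binary spike vector `s`, computed while skipping silent rows. Here is a NumPy sketch of the same accumulation, checked against the dense `W.T @ s`; the function name is hypothetical.

```python
import numpy as np


def event_csr_transpose_matvec(values, indices, indptr, events, n_post):
    """Hypothetical NumPy sketch of event-driven accumulation: out = W.T @ events."""
    out = np.zeros(n_post, dtype=values.dtype)
    for row in np.flatnonzero(events):               # visit only neurons that fired
        start, end = indptr[row], indptr[row + 1]
        # Scatter-add this neuron's outgoing weights onto its targets
        np.add.at(out, indices[start:end], values[start:end])
    return out


rng = np.random.default_rng(1)                             # arbitrary seed
W = (rng.random((50, 40)) < 0.2) * rng.random((50, 40))    # 50 presynaptic x 40 postsynaptic
rows, cols = np.nonzero(W)                                 # row-major order, as CSR requires
values, indices = W[rows, cols], cols
indptr = np.concatenate([[0], np.cumsum(np.bincount(rows, minlength=50))])
events = rng.random(50) < 0.1                              # boolean spike vector

out = event_csr_transpose_matvec(values, indices, indptr, events, n_post=40)
print(np.allclose(out, W.T @ events))                      # True
```

With about 10% of neurons firing, only about a tenth of the rows are visited, which is where the event-driven formulation saves work compared with a full matrix-vector product.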

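```python
@ti.kernel
def event_csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    events: ti.types.ndarray(ndim=1),  # Boolean array indicating which neurons fired
    out: ti.types.ndarray(ndim=1)
):
    # Start from zero so the result does not depend on the initial contents of `out`
    for i in range(out.shape[0]):
        out[i] = 0.0
    # Only process rows (presynaptic neurons) where events occurred.
    # The outer loop is serialized; `+=` on an ndarray is atomic in Taichi,
    # so a parallel loop would also give correct sums (up to summation order).
    ti.loop_config(serialize=True)
    for row in range(indptr.shape[0] - 1):
        if events[row]:  # Only process if neuron fired
            for j in range(indptr[row], indptr[row + 1]):
                col = indices[j]
                out[col] += values[j]  # Add the weight to postsynaptic neuron `col`

# Register the operator
event_csr_op = bti.XLACustomOp(
    cpu_kernel=event_csr_matvec,
    gpu_kernel=event_csr_matvec
)
```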

```python
# Create test data
n_neurons = 1000
density = 0.1

# Create sparse connectivity matrix
connectivity = (np.random.rand(n_neurons, n_neurons) < density).astype(float)
connectivity *= np.random.rand(n_neurons, n_neurons) * 0.5  # Synaptic weights
sparse_conn = csr_matrix(connectivity)

# Generate random spike events (each neuron fires with probability 0.1)
events = np.random.rand(n_neurons) < 0.1

# Run the event-driven operator
result = event_csr_op(
    jnp.array(sparse_conn.data, dtype=jnp.float32),
    jnp.array(sparse_conn.indices, dtype=jnp.int32),
    jnp.array(sparse_conn.indptr, dtype=jnp.int32),
    jnp.array(events, dtype=jnp.bool_),
    outs=[jax.ShapeDtypeStruct((n_neurons,), dtype=jnp.float32)]
)

print(f"Number of neurons that fired: {events.sum()}")
print("Synaptic input statistics:")
print(f"  Mean: {np.mean(result[0]):.4f}")
print(f"  Max: {np.max(result[0]):.4f}")
print(f"  Non-zero entries: {np.sum(np.array(result[0]) > 0)}")
```

```text
Number of neurons that fired: 98
Synaptic input statistics:
  Mean: 2.4380
  Max: 5.0876
  Non-zero entries: 1000
```
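These statistics can be sanity-checked with a back-of-the-envelope calculation. Each postsynaptic neuron is connected to each of the 98 firing neurons with probability 0.1, and each weight is uniform on [0, 0.5) with mean 0.25, so its expected input is 98 × 0.1 × 0.25 ≈ 2.45, close to the printed mean of 2.4380. A neuron receives no input only if none of the 98 firing neurons connects to it, which has probability `0.9 ** 98` (about 3.3e-5), so all 1000 entries being non-zero is the expected outcome.

```python
n_fired = 98          # from the printed output above
density = 0.1         # connection probability used to build `connectivity`
mean_weight = 0.25    # mean of np.random.rand(...) * 0.5

expected_mean_input = n_fired * density * mean_weight
p_no_input = (1 - density) ** n_fired        # no firing neuron connects to a given target
expected_zero_entries = 1000 * p_no_input

print(f"expected mean input:   {expected_mean_input:.3f}")  # 2.450
print(f"P(no input):           {p_no_input:.1e}")
print(f"expected zero entries: {expected_zero_entries:.3f}")
```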


## Example 4: Using Built-in Operators

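`braintaichi` also ships pre-implemented operators for common brain dynamics computations, including sparse (CSR/COO) products, event-driven products, and just-in-time connectivity (`jitc_*`) operators. The cell below lists the package's public names.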

```python
# Check the public names exported by braintaichi
print("Public names in braintaichi:")
print([attr for attr in dir(bti) if not attr.startswith('_')])
```

```text
Public names in braintaichi:
['XLACustomOp', 'coo_to_csr', 'coomv', 'cpu_ops', 'csr_to_coo', 'csr_to_dense', 'csrmm', 'csrmv', 'defjvp', 'event_csrmm', 'event_csrmv', 'get_homo_weight_matrix', 'get_normal_weight_matrix', 'get_uniform_weight_matrix', 'jitc_event_mv_prob_homo', 'jitc_event_mv_prob_normal', 'jitc_event_mv_prob_uniform', 'jitc_mv_prob_homo', 'jitc_mv_prob_normal', 'jitc_mv_prob_uniform', 'rand', 'register_general_batching']
```


## Performance Tips

Here are some key tips for optimizing your custom operators:

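1. **Parallelize outer loops**: Taichi automatically parallelizes for-loops at the outermost scope of a kernel, so place the large iteration space (for example, the rows) there
2. **Use `ti.loop_config(serialize=True)` when needed**: it makes the loop that follows run sequentially, which is required when iterations must execute in order or when the loop contains a `break` statement (not allowed in a parallelized loop). Augmented assignments such as `out[col] += ...` are already atomic in parallel loops, so serialization is not needed just to avoid write conflicts
3. **Choose appropriate data types**: Use `ti.f32` for single precision, `ti.f64` for double precision
4. **Avoid Python objects inside kernels**: Use Taichi native types only
5. **Batch operations**: Process several inputs in one kernel call (for example, a sparse matrix times a dense matrix instead of repeated matrix-vector products) to reduce per-call overhead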
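Tip 5 can be made concrete with sparse products: when the same sparse matrix multiplies several vectors, stacking the vectors as columns lets one pass over the CSR arrays serve all of them. The NumPy sketch below (function names hypothetical, not part of `braintaichi`) compares one batched pass with a loop of matrix-vector products; in a Taichi kernel the batched form also means one kernel call instead of one per vector.

```python
import numpy as np


def csr_matvec_np(values, indices, indptr, x):
    """Hypothetical: one matrix-vector product, i.e. one pass over the CSR arrays."""
    n_rows = len(indptr) - 1
    out = np.zeros(n_rows)
    for row in range(n_rows):
        s, e = indptr[row], indptr[row + 1]
        out[row] = values[s:e] @ x[indices[s:e]]
    return out


def csr_matmat_np(values, indices, indptr, X):
    """Hypothetical batched version: one pass serves every column of X."""
    n_rows = len(indptr) - 1
    out = np.zeros((n_rows, X.shape[1]))
    for row in range(n_rows):
        s, e = indptr[row], indptr[row + 1]
        out[row] = values[s:e] @ X[indices[s:e]]   # (nnz,) @ (nnz, batch) -> (batch,)
    return out


rng = np.random.default_rng(2)                             # arbitrary seed
W = (rng.random((30, 20)) < 0.2) * rng.random((30, 20))
rows, cols = np.nonzero(W)
values, indices = W[rows, cols], cols
indptr = np.concatenate([[0], np.cumsum(np.bincount(rows, minlength=30))])
X = rng.random((20, 8))                                    # 8 input vectors as columns

batched = csr_matmat_np(values, indices, indptr, X)        # 1 pass
looped = np.stack([csr_matvec_np(values, indices, indptr, X[:, k])  # 8 passes
                   for k in range(X.shape[1])], axis=1)
print(np.allclose(batched, looped), np.allclose(batched, W @ X))    # True True
```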

## Integration with JAX

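Operators built with `braintaichi` integrate with JAX: they can be called inside JAX transformations such as `jax.jit`, as the example below shows. Automatic differentiation additionally requires derivative rules for each custom operator (see `defjvp` in the listing above).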
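JAX cannot look inside a Taichi kernel, so a custom operator becomes differentiable only once its derivative rules are supplied. For the CSR matvec y = A x the rules are simple because y is linear in x: the JVP with tangent dx is A dx (the same operator applied to the tangent), and the reverse-mode product with a cotangent ct is Aᵀ ct. The NumPy sketch below checks both facts numerically with a dense stand-in for A; it does not call any `braintaichi` API.

```python
import numpy as np

rng = np.random.default_rng(3)                        # arbitrary seed
A = (rng.random((6, 5)) < 0.4) * rng.random((6, 5))   # dense stand-in for the CSR matrix
x = rng.random(5)
dx = rng.random(5)                                    # tangent direction


def f(v):
    return A @ v                                      # forward operator y = A v


eps = 1e-6
jvp_fd = (f(x + eps * dx) - f(x)) / eps               # finite-difference JVP
jvp_rule = f(dx)                                      # linear op: JVP is the same op on the tangent
print(np.allclose(jvp_fd, jvp_rule, atol=1e-6))       # True

# Reverse mode needs the transpose: <ct, A dx> == <A.T ct, dx>
ct = rng.random(6)
print(np.isclose(ct @ jvp_rule, (A.T @ ct) @ dx))     # True
```

The operator is also linear in `values`, so the JVP with respect to the weights is the same matvec with the weight tangent in place of `values`.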

```python
# Example: using a braintaichi operator inside a JAX JIT-compiled function
@jax.jit
def neural_network_step(weights_data, weights_indices, weights_indptr, inputs):
    """Compute the synaptic input for one simulation step."""
    # Number of rows (postsynaptic neurons); array shapes are static under jit
    n_rows = weights_indptr.shape[0] - 1
    # Apply synaptic weights
    synaptic_input = csr_matvec_op(
        weights_data,
        weights_indices,
        weights_indptr,
        inputs,
        outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)]
    )
    return synaptic_input

# Test the JIT-compiled function
n = 100
sparse_mat = csr_matrix((np.random.rand(n, n) < 0.1).astype(float) * np.random.rand(n, n))
inputs = jnp.array(np.random.rand(n), dtype=jnp.float32)

result = neural_network_step(
    jnp.array(sparse_mat.data, dtype=jnp.float32),
    jnp.array(sparse_mat.indices, dtype=jnp.int32),
    jnp.array(sparse_mat.indptr, dtype=jnp.int32),
    inputs
)

print("Neural network output:", result[0][:5])
```

```text
Neural network output: [4.365637  1.9171182 1.9304688 2.2530348 2.110351 ]
```


## Next Steps

Now that you've learned the basics, you can:

1. Read **braintaichi_intro.ipynb** for details on the kernel registration interface
2. Explore **complete_example.ipynb** for more complex use cases
3. Check out **advanced_optimization.ipynb** for performance optimization techniques
4. Visit the [API Documentation](https://braintaichi.readthedocs.io/) for a detailed reference

For more examples, check the source code:
- [Event operators](https://github.com/chaoming0625/braintaichi/tree/main/braintaichi/_eventop)
- [Sparse operators](https://github.com/chaoming0625/braintaichi/tree/main/braintaichi/_sparseop)
- [JIT connection operators](https://github.com/chaoming0625/braintaichi/tree/main/braintaichi/_jitconnop)