# BrainTaichi Introduction

This tutorial shows how to develop custom operators with `BrainTaichi`.

## Kernel Registration Interface

Brain dynamics computation is sparse and event-driven; however, the specialized operators it requires are not yet well abstracted or standardized. As a result, we often need to write custom operators. In this tutorial, we explore how to build custom brain dynamics operators with `BrainTaichi`.

Start by importing the relevant Python packages.

```python
import braintaichi as bti
import brainstate as bst

import jax
import jax.numpy as jnp
import pytest
import platform

import taichi as ti
```

### Basic Structure of Custom Operators
In `Taichi`, the kernel behind a custom operator is an ordinary Python function marked with a decorator. Here is the basic structure:

```python
@ti.kernel
def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):
    # Internal logic of the operator
    ...
```

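The `@ti.kernel` decorator marks the function as a Taichi kernel: Taichi JIT-compiles it into native CPU or GPU code the first time it is called. Parameters annotated with `ti.types.ndarray()` accept array arguments such as NumPy arrays.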

### Defining Helper Functions
When defining complex custom operators, you can use the `@ti.func` decorator to define helper functions. These functions can only be called from Taichi scope (inside a kernel or another `@ti.func`), and Taichi inlines them at compile time:

```python
@ti.func
def helper_func(x: ti.f32) -> ti.f32:
    # Auxiliary computation
    return x * 2

@ti.kernel
def my_kernel(arg: ti.types.ndarray()):
    for i in ti.ndrange(arg.shape[0]):
        arg[i] *= helper_func(arg[i])
```
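Note what `my_kernel` actually computes: `helper_func` doubles its input, and the kernel multiplies each element by that result in place, so every `x` becomes `2 * x**2` rather than `2 * x`. The NumPy mirror below is a sketch for checking this by hand; `helper_func_ref` and `my_kernel_ref` are hypothetical names of our own, not part of Taichi or BrainTaichi.

```python
import numpy as np

def helper_func_ref(x):
    # Same arithmetic as the Taichi helper: double the input
    return x * 2

def my_kernel_ref(arg: np.ndarray) -> None:
    # In-place update mirroring `arg[i] *= helper_func(arg[i])`;
    # the right-hand side is evaluated before the multiplication
    arg *= helper_func_ref(arg)

data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
my_kernel_ref(data)
# data now holds [2.0, 8.0, 18.0], i.e. 2 * x**2 for x = 1, 2, 3
```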

### Example: Custom Event Processing Operator
The following example demonstrates how to customize an event processing operator:

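```python
@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
    # The homogeneous synaptic weight is stored in a length-1 array
    return weight[0]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
    out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=1),
                  out: ti.types.ndarray(ndim=1)):
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    ti.loop_config(serialize=True)  # run the row loop serially on CPU
    for i in range(num_rows):
        if vector[i]:  # only rows with an event contribute
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)

@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=1),
                  out: ti.types.ndarray(ndim=1)):
    # Same logic with a parallel row loop; Taichi performs
    # `+=` on global data as an atomic add
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    for i in range(num_rows):
        if vector[i]:
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)
```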

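When declaring the kernel's parameters, put the output arrays last: the inputs passed at call time fill the leading parameters and the outputs described by `outs` fill the trailing ones, so this order is required for `Taichi` to compile the operator correctly. The operator `event_ell_cpu` receives the connectivity `indices`, the event `vector`, the `weight`, and the output array `out`. For every row `i` whose event `vector[i]` is active, it adds the weight to `out[indices[i, j]]` for each column `j`.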
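To test `event_ell_cpu` on small inputs, it helps to have a plain NumPy reference to compare against. The sketch below assumes the ELL (ELLPACK) layout used above, where row `i` of `indices` lists the output positions that event `i` projects to and every row has the same length, and it assumes the homogeneous weight arrives as a length-1 array. `event_ell_reference` is a hypothetical helper, not a BrainTaichi API.

```python
import numpy as np

def event_ell_reference(indices: np.ndarray, vector: np.ndarray,
                        weight: np.ndarray, num_out: int) -> np.ndarray:
    """NumPy mirror of event_ell_cpu: for each active row i, add the
    scalar weight to out[indices[i, j]] for every column j."""
    out = np.zeros(num_out, dtype=np.float32)
    w = np.float32(weight[0])                   # assumed length-1 weight array
    active_rows = indices[vector.astype(bool)]  # rows whose event fired
    # np.add.at accumulates once per occurrence, even for repeated targets
    np.add.at(out, active_rows.ravel(), w)
    return out

# Tiny example: 3 event sources, 2 targets each, 4 outputs
indices = np.array([[0, 2], [1, 2], [2, 3]], dtype=np.int32)
vector = np.array([True, False, True])
weight = np.array([0.5], dtype=np.float32)
result = event_ell_reference(indices, vector, weight, num_out=4)
# out[0] = 0.5 and out[3] = 0.5 (one active row each); out[2] = 1.0 (both active rows)
```

`np.add.at` is used instead of `out[idx] += w` because fancy-index augmented assignment applies a repeated index only once, whereas the kernel adds the weight once per occurrence.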

### Registering and Using Custom Operators
After defining a custom operator, you register it with `BrainTaichi` so that it can be called on JAX arrays.
`BrainTaichi` provides a simple, flexible interface for registering custom operators: `XLACustomOp`. When registering, you can specify both `cpu_kernel` and `gpu_kernel`, so the same operator can run on different devices. When calling the operator, pass the `outs` parameter: a list of `jax.ShapeDtypeStruct` objects that define the shape and data type of each output.

Note: pass the inputs in the same order as the corresponding parameters are declared in the kernel.

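```python
# Taichi operator registration
prim = bti.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

# Using the operator
def test_taichi_op():
    # Create example input data (sizes and event rate are illustrative)
    s = 1000
    key_idx, key_vec = jax.random.split(jax.random.PRNGKey(0))
    indices = jax.random.randint(key_idx, (s, 100), 0, s)  # 100 targets per row
    vector = jax.random.uniform(key_vec, (s,)) < 0.1       # boolean events, ~10% active
    weight = jnp.array([1.0], dtype=jnp.float32)           # homogeneous weight

    # Call the custom operator; inputs follow the kernel's parameter order
    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

    # Output the result
    print(out)
```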

## Basic Taichi Concepts

Taichi is a domain-specific language (DSL) designed to simplify the development of high-performance visual computing and physics simulation algorithms, particularly for computer graphics researchers. Here are some of its basic concepts:

### Embedded in Python
Taichi is embedded within Python, allowing developers to leverage the simplicity and flexibility of Python while benefiting from the performance of native GPU or CPU instructions. This means that if you are familiar with Python, you can quickly start using Taichi without learning a completely new language.

### Just-in-Time (JIT) Compilation
Taichi uses JIT compiler infrastructure such as LLVM and SPIR-V to translate Python code into native GPU or CPU instructions. A kernel is compiled the first time it is called and the compiled code is reused on later calls, so you keep Python's interactive workflow while running at native speed.

### Imperative Programming Paradigm
Unlike many other DSLs that focus on specific computing patterns, Taichi adopts an imperative programming paradigm. This provides greater flexibility and allows developers to write complex computations in a single kernel, which Taichi refers to as a "mega-kernel."

### Optimizations
Taichi employs various compiler optimizations such as common subexpression elimination, dead code elimination, and control flow graph analysis. These optimizations are backend-neutral, thanks to Taichi's own intermediate representation (IR) layer.

### Community and Backend Support
Taichi has a strong and dedicated community that has contributed to the development of various backends, including Vulkan, OpenGL, and DirectX. This wide range of backend support enhances Taichi's portability and usability across different platforms.

## Kernel Optimization in Taichi

`Taichi` automatically parallelizes the for-loops at the outermost scope of a kernel, and the compiler chooses launch settings that suit the target architecture. For the last few percent of performance, however, Taichi provides several APIs for fine-tuning; on GPU, specifying a proper `block_dim` is key.

You can use `ti.loop_config` to set directives for the for-loop that immediately follows it. Available directives are:

- **parallelize**: Sets the number of threads to use on CPU.
- **block_dim**: Sets the number of threads in a block on GPU.
- **serialize**: If set to `True`, the for-loop runs serially, so you can write `break` statements inside it (only applies to range/ndrange for-loops). Equivalent to setting **parallelize** to 1.

```python
@ti.kernel
def break_in_serial_for() -> ti.i32:
    a = 0
    ti.loop_config(serialize=True)
    for i in range(100):  # This loop runs serially
        a += i
        if i == 10:
            break
    return a

break_in_serial_for()  # returns 55

n = 128
val = ti.field(ti.i32, shape=n)

@ti.kernel
def fill():
    ti.loop_config(parallelize=8, block_dim=16)
    # If the kernel is run on the CPU backend, 8 threads will be used to run it
    # If the kernel is run on the CUDA backend, each block will have 16 threads.
    for i in range(n):
        val[i] = i
```
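A `break` needs a well-defined iteration order, which a parallel loop does not have; that is why it is only allowed in a serialized loop. As a sanity check on the `# returns 55` comment, the plain-Python mirror below (a hypothetical helper, not Taichi code) performs the same serial loop: it adds `0 + 1 + ... + 10` and then stops.

```python
def break_in_serial_for_ref() -> int:
    # Plain-Python mirror of the serialized Taichi loop
    a = 0
    for i in range(100):
        a += i
        if i == 10:
            break  # stop after adding 0 + 1 + ... + 10
    return a

print(break_in_serial_for_ref())  # 55, the same as sum(range(11))
```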