# Program Augmentation


The `BrainState` framework provides a program augmentation mechanism based on the [`pygraph` syntax](./pygraph-en.ipynb). It lets users add features such as automatic differentiation, batching, and multi-device parallelism on top of a basic computational model. This tutorial explains how to use these augmentations to optimize and extend your models. We recommend reading the [`pygraph` syntax](./pygraph-en.ipynb) tutorial first.

## 1. Automatic Differentiation

Automatic differentiation is one of the most fundamental and important features in deep learning. The `BrainState` framework is built upon JAX’s automatic differentiation system, providing a [simple and intuitive API](../apis/augment.rst) for gradient computation.

### 1.1 Automatic Differentiation Syntax

The automatic differentiation interface provided by `brainstate` requires the user to specify the `State` collection for which gradients are needed. The basic syntax is as follows:

```python
grad_fn = brainstate.augment.grad(loss_fn, states)
gradients = grad_fn(*args)  # call with the same arguments as loss_fn
```

Here, `loss_fn` is the loss function and `states` is the collection of `State` objects (parameters) whose gradients are required. `grad` does not compute gradients directly. Instead, it returns a new function that accepts the same arguments as `loss_fn` and returns the gradients with respect to each `State` in `states`. `grad` requires `loss_fn` to return a scalar. For other kinds of outputs, use one of the other differentiation interfaces. The currently supported interfaces are:

- `brainstate.augment.grad`: Automatic differentiation for scalar loss functions using reverse-mode automatic differentiation.
- `brainstate.augment.vector_grad`: Automatic differentiation for vector loss functions using reverse-mode automatic differentiation.
- `brainstate.augment.jacrev`: Jacobian matrix of (possibly vector-valued) functions using reverse-mode automatic differentiation.
- `brainstate.augment.jacfwd`: Jacobian matrix of (possibly vector-valued) functions using forward-mode automatic differentiation.
- `brainstate.augment.jacobian`: Jacobian matrix, equivalent to `brainstate.augment.jacrev`.
- `brainstate.augment.hessian`: Hessian matrix (second-order derivatives) of scalar functions.
- For more detailed information, please refer to the [API documentation](../apis/augment.rst).
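
The `brainstate` interfaces are built on JAX's transforms. The rough sketch below shows what each quantity is by applying the corresponding plain JAX functions to an ordinary parameter array, not to `State` objects. The functions `loss` and `residuals` are made up for illustration. For the `vector_grad` case, the sketch assumes (our reading, not confirmed here) that the "gradient" of a vector-valued function means pulling a cotangent of ones back through it, which is the gradient of the sum of its outputs.

```python
import jax
import jax.numpy as jnp

# Hypothetical functions of a parameter vector w (a plain array, not a State).
def loss(w):
    """Scalar loss 0.5 * ||w||^2, whose gradient is w itself."""
    return 0.5 * jnp.sum(w ** 2)

def residuals(w):
    """Vector-valued function: elementwise scaling, so its Jacobian is diagonal."""
    return w * jnp.array([1.0, 2.0, 3.0])

w = jnp.array([1.0, -2.0, 0.5])

g = jax.grad(loss)(w)              # like grad: gradient of a scalar function
J_rev = jax.jacrev(residuals)(w)   # like jacrev: 3x3 Jacobian via reverse mode
J_fwd = jax.jacfwd(residuals)(w)   # like jacfwd: same Jacobian via forward mode

# Like vector_grad (assumed semantics): pull a cotangent of ones back through
# the vector-valued function, i.e. the gradient of sum(residuals(w)).
_, vjp_fn = jax.vjp(residuals, w)
(vg,) = vjp_fn(jnp.ones(3))

print(g)      # equals w
print(J_rev)  # diag([1, 2, 3])
print(vg)     # [1, 2, 3]
```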

The automatic differentiation interfaces provided by `brainstate` support returning the loss function value (`return_value=True`) and also support returning auxiliary data (`has_aux=True`).

When `return_value=True`, the return value is a tuple where the first element is the gradient and the second element is the loss function value.

```python
gradients, loss = brainstate.augment.grad(loss_fn, states, return_value=True)(*args)
```

When `has_aux=True`, the return value is a tuple where the first element is the gradient and the second element is auxiliary data. In this case, `loss_fn` must return a tuple where the first element is the loss function value and the second element is the auxiliary data.

```python
def loss_fn(*args):
    ...
    return loss, aux

gradients, aux = brainstate.augment.grad(loss_fn, states, has_aux=True)(*args)
```

When both `return_value=True` and `has_aux=True` are set, the return value is a tuple of three elements: the gradients, the loss value, and the auxiliary data.

```python
def loss_fn(*args):
    ...
    return loss, aux

gradients, loss, aux = brainstate.augment.grad(loss_fn, states, return_value=True, has_aux=True)(*args)
```

### 1.2 Basic Gradient Calculation

The functions provided by `brainstate`, such as `grad` and `vector_grad`, support first-order gradient calculations. Below is a simple example:

```python
import jax.numpy as jnp

import brainunit as u
import brainstate as bst
```

```python
# Create a simple linear layer model
model = bst.nn.Linear(2, 3)

# Prepare input data
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

# Define the loss function
def loss_fn(x, y):
    return jnp.mean((y - model(x)) ** 2)

# Retrieve model parameters
weights = model.states()

# Compute gradients
grads = bst.augment.grad(loss_fn, weights)(x, y)

# Print gradient information
print("Gradients:", grads)
```

```
Gradients: {('weight',): {'bias': Array([ 0.5462675, -1.7040147, -1.0706702], dtype=float32), 'weight': Array([[ 0.5462675, -1.7040147, -1.0706702],
       [ 0.5462675, -1.7040147, -1.0706702]], dtype=float32)}}
```


In this example, a simple linear layer model is created using the `bst.nn.Linear` class, which takes two input features and produces three output features. Input data `x` and target output `y` are prepared as arrays of ones with shapes corresponding to the model's input and output dimensions.

The loss function is the mean squared error between the model's predictions and the target outputs.

The model parameters (weights) are retrieved using the `model.states()` method. Gradients of the loss function with respect to the model parameters are calculated using the automatic differentiation feature provided by `bst.augment.grad`. The computed gradients are then printed to the console.
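
A hand derivation explains the structure of the printed values. For `y_hat = x @ W + b` and a loss that averages `(y - y_hat) ** 2` over its `N` output elements, `dL/dy_hat = 2 * (y_hat - y) / N`. The bias gradient sums this over the batch, and the weight gradient is `x.T @ dL/dy_hat`. Because `x` is all ones, every row of the weight gradient equals the bias gradient, as in the output above. The sketch below checks this with plain JAX on a hypothetical parameter dictionary with randomly drawn weights. It assumes `bst.nn.Linear` computes `x @ weight + bias` with an `(in, out)` weight. That matches the printed shapes but is not otherwise confirmed here.

```python
import jax
import jax.numpy as jnp

def linear(params, x):
    """Affine map, assumed to mirror bst.nn.Linear: x @ weight + bias."""
    return x @ params['weight'] + params['bias']

def mse(params, x, y):
    """Mean squared error over all output elements."""
    return jnp.mean((y - linear(params, x)) ** 2)

params = {
    'weight': jax.random.normal(jax.random.PRNGKey(0), (2, 3)),  # (in, out)
    'bias': jnp.zeros(3),
}
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

# Gradients from automatic differentiation
grads = jax.grad(mse)(params, x, y)

# Hand-derived gradients of the same loss
d_out = 2.0 * (linear(params, x) - y) / y.size  # dL/d(y_hat), shape (1, 3)
manual = {'weight': x.T @ d_out,                 # shape (2, 3)
          'bias': d_out.sum(axis=0)}             # shape (3,)

# With x all ones, each row of the weight gradient equals the bias gradient
print(grads['weight'][0], grads['bias'])
```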

### 1.3 Higher-Order Gradient Computation

The `BrainState` framework also supports higher-order derivatives, which are useful in tasks such as second-order optimization:

```python
# Compute the Hessian
hessian = bst.augment.hessian(loss_fn, weights)(x, y)

# Compute the Jacobian matrix
jacobian = bst.augment.jacobian(loss_fn, weights)(x, y)
```

Here, `bst.augment.hessian` computes the Hessian matrix (the second-order derivatives) of the loss with respect to the model parameters. `bst.augment.jacobian` computes the Jacobian, which describes how the function's output changes with each parameter. The Jacobian is a first-order quantity. For a scalar loss such as `loss_fn` it contains the same values as the gradient, and it becomes more informative for vector-valued functions. Second-order information such as the Hessian is used by curvature-aware methods, for example Newton-type optimizers.
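
A quadratic loss has a closed-form Hessian, which makes it a convenient sanity check. Take a linear model `y_hat = X @ w` and the loss `mean((y - X @ w) ** 2)` over `n` samples. The Hessian with respect to `w` is `(2 / n) * X.T @ X`, independent of `w`. The sketch below uses plain JAX on made-up data. It also shows that for a scalar loss, `jax.jacobian` returns the same values as `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Made-up design matrix X (n = 3 samples, 2 features) and targets y
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])

def loss(w):
    """Mean squared error of a linear model without bias."""
    return jnp.mean((y - X @ w) ** 2)

w = jnp.array([0.5, -0.5])

H = jax.hessian(loss)(w)               # 2x2 matrix of second derivatives
H_closed = 2.0 / X.shape[0] * X.T @ X  # closed form (2 / n) X^T X

J = jax.jacobian(loss)(w)              # for a scalar loss: same as the gradient
g = jax.grad(loss)(w)
```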

### 1.4 Gradient Transformation and the Chain Rule

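You can also compose gradient transformations, for example to differentiate a gradient function again:

```python
# Compose two gradient transformations
def composite_grad(fn, params):
    # First order: grad_fn returns the gradients (a PyTree of arrays, not a scalar)
    grad_fn = bst.augment.grad(fn, params)
    # Second order: grad_fn's output is not a scalar, so a Jacobian transform
    # (rather than grad) is used to differentiate it again
    return lambda *args: bst.augment.jacrev(grad_fn, params)(*args)
```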

`composite_grad` first builds `grad_fn`, which returns the gradients of `fn` with respect to `params`. It then returns a function that differentiates `grad_fn` again with respect to the same parameters, which yields second-order derivatives (the entries of the Hessian). Because `grad` only accepts functions with a scalar output, the outer transformation must be a Jacobian transform such as `jacrev` or `jacfwd`. Composing transformations in this way lets you build more complex derivative computations from the basic interfaces.
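
Plain JAX shows why the outer transform must handle non-scalar outputs. The gradient function returns an array, so `jax.grad` cannot be applied to it again. `jax.jacfwd` (forward mode over the inner reverse mode) can be applied, and the result is the Hessian. The function `f` is made up for illustration. The `brainstate` transforms are assumed to follow the same output rules as the JAX transforms they wrap.

```python
import jax
import jax.numpy as jnp

def f(w):
    """Hypothetical smooth scalar function of a parameter vector."""
    return jnp.sum(jnp.sin(w) * w ** 2)

w = jnp.array([0.1, 0.7, -1.2])
grad_f = jax.grad(f)             # R^3 -> R^3: not a scalar function

# Differentiating the gradient function again yields second derivatives.
second = jax.jacfwd(grad_f)(w)   # forward-over-reverse = Hessian of f

# jax.grad requires a scalar output, so applying it to grad_f fails.
try:
    jax.grad(grad_f)(w)
    grad_of_grad_ok = True
except TypeError:
    grad_of_grad_ok = False
```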

## 2. Batching Augmentation

Batching is a key technique in deep learning that enhances computational efficiency. By processing multiple samples simultaneously, it improves hardware utilization and reduces computational overhead. The `brainstate` framework supports batching for [``pygraph`` models](./pygraph-en.ipynb), allowing users to implement batching through a simple API.

Compared to `jax.vmap`, `brainstate.augment.vmap` introduces two additional parameters, `in_states` and `out_states`, which specify which `State` objects in the model are batched. Both can be dictionaries whose keys are the batch axes and whose values are the `State` objects to batch (any PyTree of `State` objects). As the examples below show, you can also pass the `State` objects directly, in which case they are batched along axis 0.

Similar to `jax.vmap`, the `brainstate.augment.vmap` function also accepts the `in_axes` and `out_axes` parameters, which specify which dimensions of non-``State`` parameters should be batched. The usage of `in_axes` and `out_axes` is the same as in `jax.vmap`.
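
Because `in_axes` and `out_axes` behave as in `jax.vmap`, their effect can be shown with plain JAX. In this sketch, which uses a made-up `affine` function, the weights are shared across the batch (`None`) and the inputs are mapped over axis 0. `out_axes=1` places the batch dimension last in the result.

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    """Apply a (3, 2) weight matrix to a single 3-feature sample."""
    return x @ w

w = jnp.ones((3, 2))
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 samples

# w is shared (None); xs is mapped over its leading axis (0)
ys = jax.vmap(affine, in_axes=(None, 0))(w, xs)                # shape (4, 2)
ys_t = jax.vmap(affine, in_axes=(None, 0), out_axes=1)(w, xs)  # shape (2, 4)
```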

Below are some simple examples:

```python
batch_size = 4

# Create batched input data and the model
x_batch = jnp.ones((batch_size, 3)) * u.mA
model = bst.nn.LIF(3)
model.init_state(batch_size)


# Evaluate the model on each element of the batch
@bst.augment.vmap(
    in_states=model.states()  # batch all States
)
def eval(x):
    with bst.environ.context(dt=0.1 * u.ms):
        return model(x)

eval(x_batch)
```

```
Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
```

`brainstate.augment.vmap` automatically detects `State` objects whose batching is not declared correctly and raises an error.

```python
class Foo(bst.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = bst.ParamState(jnp.arange(4))
        self.b = bst.ShortTermState(jnp.arange(4))

    def __call__(self):
        self.b.value = self.a.value * self.b.value


foo = Foo()
r = bst.augment.vmap(foo, in_states=foo.a)()
```

```
BatchAxisError: The value of State ShortTermState(
  value=Traced<ShapedArray(int32[4])>with<BatchTrace> with
    val = Array([[0, 0, 0, 0],
         [0, 1, 2, 3],
         [0, 2, 4, 6],
         [0, 3, 6, 9]], dtype=int32)
    batch_dim = 0
) is batched, but it is not in the out_states.
```

In the example above, the model `Foo` declares only `a` as batched. Because `b` is updated with a value computed from the batched `a`, `b` also becomes batched, but it is not declared in `out_states`. `brainstate` detects this and raises a `BatchAxisError`. The fix is to declare `b` as batched as well, for example by passing it in `out_states`:

```python
foo = Foo()
r = bst.augment.vmap(foo, in_states=foo.a, out_states=foo.b)()

foo.b.value
```

```
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
```
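
Plain JAX reproduces why `b` ends up batched. Map over the elements of `a` and pass `b` unbatched (`in_axes=None`). Each batch element then computes `a[i] * b`, so the stacked result is the outer product of `a` and `b`, which is exactly the array printed above. The function name `step` is illustrative only.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(4)  # plays the role of the batched ParamState foo.a
b = jnp.arange(4)  # plays the role of the unbatched ShortTermState foo.b

def step(a_i, b):
    """One batch element of Foo.__call__: the new value of b."""
    return a_i * b

# a is mapped over axis 0; b is broadcast to every batch element
new_b = jax.vmap(step, in_axes=(0, None))(a, b)
print(new_b)
# [[0 0 0 0]
#  [0 1 2 3]
#  [0 2 4 6]
#  [0 3 6 9]]
```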

## 3. Multi-Device Parallel Computation

In addition to batching, `brainstate` supports multi-device parallel computation. The `brainstate.augment.pmap` function converts a model into one that runs in parallel across multiple devices.

`brainstate.augment.pmap` is used in essentially the same way as `brainstate.augment.vmap`. The difference is where the work runs. `pmap` splits the mapped axis across multiple devices, one slice per device. `vmap` vectorizes the computation over a batch dimension on a single device.

This lets you use the combined compute of several accelerators, such as GPUs or TPUs, to speed up training and inference. As with `jax.pmap`, the size of the mapped axis is limited by the number of available devices.
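
The rough illustration below uses plain `jax.pmap` on a stateless function. It is not the `brainstate.augment.pmap` API, whose `State`-related arguments are assumed to mirror those of `vmap`. The mapped (leading) axis must not be larger than the number of available devices, so the sketch sizes it with `jax.local_device_count()`; it therefore also runs on a single CPU. `jax.lax.psum` sums across devices by axis name, and `'devices'` is a name chosen for this example.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
# One row of 3 values per device; pmap maps over the leading axis.
xs = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3) + 1.0

def normalize(x):
    """Divide each device's shard by the sum over all devices' shards."""
    total = jax.lax.psum(jnp.sum(x), axis_name='devices')
    return x / total

out = jax.pmap(normalize, axis_name='devices')(xs)
```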

## 4. Combining Augmentation Transformations

In practical applications, we often need to combine multiple transformations for program augmentation. The various program augmentation functions and compilation functions in `brainstate` can be used together seamlessly. Below is a simple example:

```python
batch_size = 5
xs = bst.random.rand(batch_size, 3)
ys = bst.random.rand(batch_size, 4)

net = bst.nn.Linear(3, 4)

@bst.augment.vmap
def batch_run(x):
    return net(x)


def loss_fn(x, y):
    return jnp.mean((y - batch_run(x)) ** 2)


weights = net.states(bst.ParamState)
opt = bst.optim.Adam(1e-3)
opt.register_trainable_weights(weights)


@bst.compile.jit
def batch_train(xs, ys):
    grads, l = bst.augment.grad(loss_fn, weights, return_value=True)(xs, ys)
    opt.update(grads)
    return l


l = batch_train(xs, ys)
l
```

```
Array(1.2410938, dtype=float32)
```

In this example, `batch_run` is vectorized with `bst.augment.vmap`. Because no `in_states` are given, the parameters of `net` are not batched, so every sample in the batch uses the same weights. The loss function `loss_fn` computes the mean squared error between the predictions `batch_run(x)` and the targets `y`. The trainable parameters (`ParamState` objects) of `net` are registered with the Adam optimizer.

The `batch_train` function is compiled just-in-time with `bst.compile.jit`. Inside it, `bst.augment.grad` with `return_value=True` returns both the gradients of the loss with respect to the parameters and the loss value.

The optimizer's `update` method then applies the gradients to the parameters. The example combines batching (`vmap`), gradient computation (`grad`), and JIT compilation (`jit`) in a single training step. The returned `l` is the loss evaluated before the parameter update.

## 5. Performance Optimization Recommendations

When utilizing program augmentation transformations, the following points should be considered to achieve optimal performance:

1. **Appropriate Batch Size**: Choose a suitable batch size based on device memory and computational capacity.
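2. **Gradient Accumulation**: When memory limits the batch size, accumulate gradients over several smaller micro-batches.
3. **Cached Compilation**: Reuse compiled functions to reduce compilation overhead.
4. **Memory Management**: Device memory is freed once no Python references to an array remain. Use `jax.device_get()` to copy the results you need to the host. Then drop references to large device arrays, or call their `.delete()` method, so the memory can be released promptly.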
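
Recommendations 2 and 3 can be sketched together in plain JAX. The gradient function is compiled once with `jax.jit` and reused for every micro-batch, and the micro-batch gradients are averaged. For a mean-reduced loss and equally sized micro-batches, the average equals the full-batch gradient, which the sketch checks. The loss, the data, and the helper `accumulated_grad` are hypothetical. In a `brainstate` program the same pattern would presumably be applied to the gradients returned by `bst.augment.grad`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a linear model (hypothetical example loss)."""
    return jnp.mean((x @ w - y) ** 2)

# Compiled once; later calls with the same shapes reuse the cached executable.
grad_fn = jax.jit(jax.grad(loss))

def accumulated_grad(w, x, y, n_micro):
    """Average the gradients of n_micro equally sized micro-batches."""
    total = jnp.zeros_like(w)
    for xb, yb in zip(jnp.split(x, n_micro), jnp.split(y, n_micro)):
        total = total + grad_fn(w, xb, yb)
    return total / n_micro

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (8, 3))
y = jax.random.normal(k2, (8,))
w = jnp.zeros(3)

g_acc = accumulated_grad(w, x, y, n_micro=4)
g_full = jax.grad(loss)(w, x, y)
```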


## 6. Debugging Techniques

Debugging transformed code needs some care, because transformations such as `jit` trace the Python function instead of executing it step by step. The `brainstate` framework fully supports JAX's debugging tools, such as `jax.debug.print`. Below is a simple example:

```python
import jax

# Debugging with jax.debug.print
@bst.compile.jit
def batch_train(xs, ys):
    grads, l = bst.augment.grad(loss_fn, weights, return_value=True)(xs, ys)
    jax.debug.print("Gradients: {g}", g=grads)
    opt.update(grads)
    return l
```

For detailed usage, users can refer to the [JAX Debugging Documentation](https://jax.readthedocs.io/en/latest/debugging/index.html).


In this example, `batch_train` uses `jax.debug.print` to output the computed gradients on every call. A plain Python `print` inside a jitted function runs only while the function is traced and shows abstract tracers rather than values. `jax.debug.print` instead prints the actual runtime values. This makes it useful for monitoring training dynamics and diagnosing problems such as vanishing or exploding gradients.

## Conclusion

Program augmentations based on `pygraph` are one of the core features of the `BrainState` framework. Used well, they can make model training and inference considerably more efficient. This tutorial covered the main augmentation features and how to use them, providing a foundation for further exploration.