# Program Augmentation


The `BrainState` framework provides a program augmentation mechanism built on the [`pygraph` syntax](./pygraph-en.ipynb). It lets users extend a basic computational model with additional capabilities such as automatic differentiation, batching, and multi-device parallelism. This tutorial explains how to use these augmentations to optimize and extend your models. We recommend reading the [`pygraph` syntax](./pygraph-en.ipynb) tutorial before this one.

## 1. Automatic Differentiation

Automatic differentiation is one of the most fundamental and important features in deep learning. The `BrainState` framework is built upon JAX’s automatic differentiation system, providing a [simple and intuitive API](../apis/augment.rst) for gradient computation.

### 1.1 Automatic Differentiation Syntax

The automatic differentiation interface provided by `brainstate` requires the user to specify the `State` collection for which gradients are needed. The basic syntax is as follows:

```python
grad_fn = brainstate.augment.grad(loss_fn, states)  # returns a new function
gradients = grad_fn(*args)                          # call it with the inputs of loss_fn
```

Here, `loss_fn` is the loss function, and `states` is the collection of parameters whose gradients are required. `grad` returns a new function that accepts the same inputs as `loss_fn` but returns the gradients of each parameter in `states`. Because `grad` requires a scalar-valued loss, other kinds of functions call for one of the other differentiation interfaces. The currently supported interfaces include:

- `brainstate.augment.grad`: Automatic differentiation for scalar loss functions using reverse-mode automatic differentiation.
- `brainstate.augment.vector_grad`: Automatic differentiation for vector loss functions using reverse-mode automatic differentiation.
- `brainstate.augment.jacrev`: Jacobian matrix of a function (scalar- or vector-valued) using reverse-mode automatic differentiation.
- `brainstate.augment.jacfwd`: Jacobian matrix of a function using forward-mode automatic differentiation.
- `brainstate.augment.jacobian`: Jacobian matrix of a function, equivalent to `brainstate.augment.jacrev`.
- `brainstate.augment.hessian`: Hessian matrix for scalar functions using reverse-mode automatic differentiation.
- For more detailed information, please refer to the [API documentation](../apis/augment.rst).
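To make the distinctions in this list concrete, here is a minimal sketch in plain JAX rather than `brainstate`: it applies `jax.grad`, `jax.jacrev`, `jax.jacfwd`, and `jax.hessian` to ordinary arrays instead of `State` collections. The "vector grad" line assumes that `vector_grad` corresponds to the gradient of the summed outputs (a vector-Jacobian product with a cotangent of ones); check the API documentation for the exact semantics.

```python
import jax
import jax.numpy as jnp


def scalar_fn(w):
    # Scalar-valued function: sum of squares
    return jnp.sum(w ** 2)


def vector_fn(w):
    # Vector-valued function: elementwise square
    return w ** 2


w = jnp.array([1.0, 2.0, 3.0])

g = jax.grad(scalar_fn)(w)                         # gradient, shape (3,)
vg = jax.grad(lambda v: jnp.sum(vector_fn(v)))(w)  # assumed analogue of vector_grad
J_rev = jax.jacrev(vector_fn)(w)                   # Jacobian via reverse mode, shape (3, 3)
J_fwd = jax.jacfwd(vector_fn)(w)                   # same Jacobian via forward mode
H = jax.hessian(scalar_fn)(w)                      # Hessian of the scalar function, shape (3, 3)
```

`jacrev` and `jacfwd` produce the same matrix; reverse mode is usually cheaper when a function has fewer outputs than inputs, and forward mode when it has more.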

The automatic differentiation interfaces provided by `brainstate` support returning the loss function value (`return_value=True`) and also support returning auxiliary data (`has_aux=True`).

When `return_value=True`, the transformed function returns a tuple whose first element is the gradients and whose second element is the loss value.

```python
gradients, loss = brainstate.augment.grad(loss_fn, states, return_value=True)(*args)
```

When `has_aux=True`, the return value is a tuple where the first element is the gradient and the second element is auxiliary data. In this case, `loss_fn` must return a tuple where the first element is the loss function value and the second element is the auxiliary data.

```python
def loss_fn(*args):
    ...
    return loss, aux

gradients, aux = brainstate.augment.grad(loss_fn, states, has_aux=True)(*args)
```

When both `return_value=True` and `has_aux=True` are set, the return value is a tuple whose first element is the gradients, whose second element is the loss value, and whose third element is the auxiliary data.

```python
def loss_fn(*args):
    ...
    return loss, aux

gradients, loss, aux = brainstate.augment.grad(loss_fn, states, return_value=True, has_aux=True)(*args)
```
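For comparison, plain JAX exposes the same two options through `jax.grad(..., has_aux=True)` and `jax.value_and_grad`. The sketch below is JAX, not `brainstate`; note that `jax.value_and_grad` returns the value *before* the gradients, the opposite of the `(gradients, loss)` order described above for `brainstate`. The tiny linear model and data are made up for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x):
    pred = x @ w
    loss = jnp.mean(pred ** 2)
    aux = {"pred_mean": jnp.mean(pred)}  # auxiliary data, not differentiated
    return loss, aux


w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Gradients plus auxiliary data
grads, aux = jax.grad(loss_fn, has_aux=True)(w, x)

# Loss value, auxiliary data, and gradients; note the (value, grads) order
(loss, aux2), grads2 = jax.value_and_grad(loss_fn, has_aux=True)(w, x)
```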

### 1.2 Basic Gradient Calculation

The functions provided by `brainstate`, such as `grad` and `vector_grad`, support first-order gradient calculations. Below is a simple example:

```python
import jax.numpy as jnp

import brainstate as bst
```

```python
# Create a simple linear layer model
model = bst.nn.Linear(2, 3)

# Prepare input data
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

# Define the loss function
def loss_fn(x, y):
    return jnp.mean((y - model(x)) ** 2)

# Retrieve model parameters
weights = model.states()

# Compute gradients
grads = bst.augment.grad(loss_fn, weights)(x, y)

# Print gradient information
print("Gradients:", grads)
```

```
Gradients: {('weight',): {'bias': Array([-0.89415467, -1.5608642 , -1.6746773 ], dtype=float32), 'weight': Array([[-0.89415467, -1.5608642 , -1.6746773 ],
       [-0.89415467, -1.5608642 , -1.6746773 ]], dtype=float32)}}
```


In this example, a simple linear layer model is created using the `bst.nn.Linear` class, which takes two input features and produces three output features. Input data `x` and target output `y` are prepared as arrays of ones with shapes corresponding to the model's input and output dimensions.

The loss function is the mean squared error between the model's predictions and the target outputs.

The model parameters (weights) are retrieved using the `model.states()` method. Gradients of the loss function with respect to the model parameters are calculated using the automatic differentiation feature provided by `bst.augment.grad`. The computed gradients are then printed to the console.
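The gradients printed above can be checked by hand. The plain-JAX sketch below (not `brainstate`) re-implements a linear layer as `x @ weight + bias` with a dict of parameters, which is an assumption about what `bst.nn.Linear` computes, and compares `jax.grad` with the closed-form gradient of the mean squared error. Because `x` is all ones, every row of the weight gradient equals the bias gradient, which matches the pattern in the output above.

```python
import jax
import jax.numpy as jnp


def linear(params, x):
    # Assumed linear layer: x @ W + b
    return x @ params["weight"] + params["bias"]


def loss_fn(params, x, y):
    return jnp.mean((y - linear(params, x)) ** 2)


key = jax.random.PRNGKey(0)
params = {"weight": jax.random.normal(key, (2, 3)), "bias": jnp.zeros(3)}
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

grads = jax.grad(loss_fn)(params, x, y)

# Closed form for the MSE: dL/dpred = -2 (y - pred) / number_of_elements
pred = linear(params, x)
dpred = -2.0 * (y - pred) / pred.size
manual = {"weight": x.T @ dpred, "bias": dpred.sum(axis=0)}
```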

### 1.3 Higher-Order Gradient Computation

The `BrainState` framework supports the computation of higher-order derivatives, which can be very useful in certain optimization tasks:

```python
# Compute the Hessian
hessian = bst.augment.hessian(loss_fn, weights)(x, y)

# Compute the Jacobian matrix
jacobian = bst.augment.jacobian(loss_fn, weights)(x, y)
```

Here, `bst.augment.hessian` computes the Hessian of the loss with respect to the model parameters, that is, its second-order derivatives, and `bst.augment.jacobian` computes the Jacobian. Because `loss_fn` returns a scalar, its Jacobian contains the same values as its gradient; the Jacobian becomes a genuine matrix only for vector-valued functions. Second-order information such as the Hessian is used by curvature-aware optimization methods, for example Newton-type methods.
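To see what these two functions return, the following plain-JAX sketch (an illustration, not the `brainstate` API) uses a least-squares loss in a weight vector `w`, whose Hessian has the closed form `2 / N * X.T @ X`. `jax.jacobian` is JAX's alias for `jax.jacrev`; for a scalar loss it returns the same values as `jax.grad`. The data `X` and `y` are arbitrary.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])


def loss_fn(w):
    return jnp.mean((X @ w - y) ** 2)


w = jnp.array([0.5, -0.5])

J = jax.jacobian(loss_fn)(w)  # scalar loss: shape (2,), same values as the gradient
g = jax.grad(loss_fn)(w)
H = jax.hessian(loss_fn)(w)   # shape (2, 2)

# Closed-form Hessian of a mean squared error that is linear in w
H_expected = 2.0 / X.shape[0] * X.T @ X
```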

### 1.4 Gradient Transformation and the Chain Rule

You can nest gradient transformations to build more complex derivatives:

```python
# Nest two differentiation transforms
def composite_grad(fn, params):
    # First order: a function that computes the gradients of `fn`
    grad_fn = bst.augment.grad(fn, params)
    # `grad` requires a scalar output, but `grad_fn` returns a collection of
    # gradient arrays, so the outer transform must be a Jacobian transform
    return lambda *args: bst.augment.jacrev(grad_fn, params)(*args)
```

Here, `composite_grad` first uses `bst.augment.grad` to obtain a function that computes the gradient of `fn` with respect to `params`. It then returns a function that differentiates this gradient function a second time. Because the gradient is not a scalar, the outer step uses `bst.augment.jacrev` instead of `grad`; the result contains second-order derivatives, mathematically the same quantity that `bst.augment.hessian` computes. Each transformation applies the chain rule internally, which is why transformations can be stacked to obtain higher-order derivatives.
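The same composition principle can be seen in plain JAX (a sketch, not `brainstate` code): `jax.grad` can be nested directly only while the intermediate result is a scalar, whereas a vector-valued gradient must be differentiated with a Jacobian transform such as `jax.jacrev`, which reproduces `jax.hessian`. The functions `f` and `g` are arbitrary examples.

```python
import jax
import jax.numpy as jnp


def f(w):
    return jnp.sum(jnp.sin(w) * w)


w = jnp.array([0.1, 0.7, 1.3])

grad_f = jax.grad(f)                # returns a vector, not a scalar
H_composed = jax.jacrev(grad_f)(w)  # differentiate the gradient again
H_direct = jax.hessian(f)(w)


def g(x):
    return x ** 3


# Direct nesting of grad works because g and g' are scalar-valued
d2g = jax.grad(jax.grad(g))(2.0)    # second derivative 6 * x at x = 2
```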

## 2. Batching Augmentation

Batching is a key technique in deep learning that enhances computational efficiency. By processing multiple samples simultaneously, it improves hardware utilization and reduces computational overhead. The `brainstate` framework supports batching for [``pygraph`` models](./pygraph-en.ipynb), allowing users to implement batching through a simple API.

### 2.1 Creating Batching Models

We can transform a model to support batching using the `brainstate.augment.vmap` function. Below is a simple example:

```python
@bst.augment.vmap
def create_linear(key):
    bst.random.set_key(key)
    return bst.nn.Linear(2, 3)

batch_size = 5

# Create a batched linear layer model
linears = create_linear(bst.random.split_key(batch_size))

linears
```

```
Linear(
  in_size=(2,),
  out_size=(3,),
  w_mask=None,
  weight=ParamState(
    value={'bias': Array([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32), 'weight': Array([[[ 0.38410592, -1.3700708 ,  1.0667006 ],
            [ 1.5859954 , -0.24795905, -1.3676361 ]],

           [[-1.3518977 , -0.8566778 , -1.5469979 ],
            [ 0.87259007,  1.465411  ,  0.07184158]],

           [[ 0.8410974 , -1.5966035 ,  0.4221514 ],
            [-0.1764095 , -1.3065816 ,  0.64682233]],

           [[-0.51042235,  1.0864646 ,  0.5021799 ],
            [-0.337543  , -0.894522  ,  1.131219  ]],

           [[ 0.5159668 ,  0.62981915,  1.1093888 ],
            [ 1.084107  , -0.3580393 ,  0.30819336]]], dtype=float32)}
  )
)
```

In this example, `create_linear` creates a linear layer, and the `@bst.augment.vmap` decorator vectorizes it over its input. `bst.random.split_key(batch_size)` generates `batch_size` random keys, and each key seeds the initialization of one model instance. As the output shows, the result `linears` is a single `Linear` module whose parameters carry a leading batch dimension of size 5: the `weight` array has shape `(5, 2, 3)` and the `bias` array has shape `(5, 3)`. Because each instance was seeded with a different key, the weights differ across the batch.
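Mapping an initializer over a batch of random keys is the same pattern used in plain JAX, where parameters are ordinary arrays rather than `State` objects. The sketch below (an illustration with a hypothetical `init_linear` helper, not `brainstate` code) produces the same leading batch dimension seen in the output above.

```python
import jax
import jax.numpy as jnp


def init_linear(key, n_in=2, n_out=3):
    # Hypothetical initializer: random weights, zero bias
    return {"weight": jax.random.normal(key, (n_in, n_out)),
            "bias": jnp.zeros(n_out)}


batch_size = 5
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per instance
params = jax.vmap(init_linear)(keys)  # every leaf gains a leading axis of size 5
```

The zero bias does not depend on the key, but with the default `out_axes=0` `vmap` still broadcasts it to shape `(5, 3)`.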

### 2.2 Using Batched Models

We can perform batching on the model's state variables along a specified dimension. Below is a simple example:


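```python
# Create batched data
x_batch = jnp.ones((batch_size, 2))
y_batch = jnp.ones((batch_size, 3))

# Apply the i-th model instance to the i-th input row (axis 0 of both)
@bst.augment.vmap(in_axes=(0, 0))
def eval(model, x):
    return model(x)

# Batched version of the loss function
def batch_loss_fn(x_batch, y_batch):
    predictions = eval(linears, x_batch)
    return jnp.mean((y_batch - predictions) ** 2)

# Compute batched gradients
weights = linears.states(bst.ParamState)
batch_grads = bst.augment.grad(batch_loss_fn, weights)(x_batch, y_batch)

batch_grads
```

```
{('weight',): {'bias': Array([[ 0.12934685, -0.34907067, -0.17345807],
         [-0.19724102, -0.05216891, -0.33002084],
         [-0.04470828, -0.5204247 ,  0.0091965 ],
         [-0.2463954 , -0.10774099,  0.0844532 ],
         [ 0.08000985, -0.09709602,  0.05567762]], dtype=float32),
  'weight': Array([[[ 0.12934685, -0.34907067, -0.17345807],
          [ 0.12934685, -0.34907067, -0.17345807]],

         [[-0.19724102, -0.05216891, -0.33002084],
          [-0.19724102, -0.05216891, -0.33002084]],

         [[-0.04470828, -0.5204247 ,  0.0091965 ],
          [-0.04470828, -0.5204247 ,  0.0091965 ]],

         [[-0.2463954 , -0.10774099,  0.0844532 ],
          [-0.2463954 , -0.10774099,  0.0844532 ]],

         [[ 0.08000985, -0.09709602,  0.05567762],
          [ 0.08000985, -0.09709602,  0.05567762]]], dtype=float32)}}
```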

{('weight',): {'bias': Array([[ 0.12934685, -0.34907067, -0.17345807],
         [-0.19724102, -0.05216891, -0.33002084],
         [-0.04470828, -0.5204247 ,  0.0091965 ],
         [-0.2463954 , -0.10774099,  0.0844532 ],
         [ 0.08000985, -0.09709602,  0.05567762]], dtype=float32),
  'weight': Array([[[ 0.12934685, -0.34907067, -0.17345807],
          [ 0.12934685, -0.34907067, -0.17345807]],
  
         [[-0.19724102, -0.05216891, -0.33002084],
          [-0.19724102, -0.05216891, -0.33002084]],
  
         [[-0.04470828, -0.5204247 ,  0.0091965 ],
          [-0.04470828, -0.5204247 ,  0.0091965 ]],
  
         [[-0.2463954 , -0.10774099,  0.0844532 ],
          [-0.2463954 , -0.10774099,  0.0844532 ]],
  
         [[ 0.08000985, -0.09709602,  0.05567762],
          [ 0.08000985, -0.09709602,  0.05567762]]], dtype=float32)}}

In this example, we first create batched inputs `x_batch` and targets `y_batch`. The `eval` function, decorated with `@bst.augment.vmap(in_axes=(0, 0))`, maps over axis 0 of both the model's states and the input, so the *i*-th model instance is applied to the *i*-th input row. `batch_loss_fn` then computes the mean squared error between these predictions and the targets over the whole batch.

To compute the gradients, we collect the model parameters with `linears.states(bst.ParamState)` and apply `bst.augment.grad` to `batch_loss_fn`. The resulting gradients keep the leading batch dimension, so each of the five model instances receives its own gradient. Because every input row is a vector of ones, each row of a `weight` gradient equals the corresponding `bias` gradient.
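A plain-JAX counterpart shows where the leading batch dimension of the gradients comes from. In this sketch (illustrative only; the dict-of-arrays parameters and the `apply` helper are assumptions standing in for the batched `Linear` module), `jax.vmap(apply, in_axes=(0, 0))` pairs the *i*-th parameter set with the *i*-th input row, so each instance's parameters only influence its own predictions.

```python
import jax
import jax.numpy as jnp

batch_size = 5
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
params = {"weight": jax.vmap(lambda k: jax.random.normal(k, (2, 3)))(keys),
          "bias": jnp.zeros((batch_size, 3))}


def apply(p, x):
    # Single instance: p["weight"] has shape (2, 3), x has shape (2,)
    return x @ p["weight"] + p["bias"]


batched_apply = jax.vmap(apply, in_axes=(0, 0))  # map parameters and inputs together


def batch_loss_fn(p, xb, yb):
    return jnp.mean((yb - batched_apply(p, xb)) ** 2)


xb = jnp.ones((batch_size, 2))
yb = jnp.ones((batch_size, 3))
grads = jax.grad(batch_loss_fn)(params, xb, yb)
```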

```python
class Foo(bst.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = bst.ParamState(jnp.arange(4))
        self.b = bst.ShortTermState(jnp.arange(4))

    def __call__(self):
        self.b.value = self.a.value * self.b.value


@bst.augment.vmap
def mul(foo):
    foo()


foo = Foo()
mul(foo)

foo.b.value
```

```
Array([0, 1, 4, 9], dtype=int32)
```

In the example above, the model `Foo` holds a `ParamState` `a` and a `ShortTermState` `b`, each an array of shape `(4,)`. The `mul` function calls the model, and `vmap` transforms `mul` so that, by default, it maps over axis 0 of the `State` values held by `foo`. Each mapped call therefore sees a single element of `a` and of `b` and computes their product; the results are written back to `b` as an array of shape `(4,)`, giving `[0, 1, 4, 9]`.

Because `vmap` handles the batch dimension of the `State` variables, the model itself can be written for a single example, and batching is applied only when the model is called.
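In plain JAX terms, the default behavior described above amounts to mapping an elementwise function over axis 0 of each array. The sketch below uses ordinary JAX arrays as stand-ins for the `State` values `a` and `b` (an illustration, not how `brainstate` implements it) and checks that each mapped call sees a scalar.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(4)
b = jnp.arange(4)


def step(a_i, b_i):
    # Inside vmap each call sees one element: a value with shape ()
    assert a_i.shape == () and b_i.shape == ()
    return a_i * b_i


new_b = jax.vmap(step)(a, b)  # default in_axes=0 for both arguments
```

Unlike the `State` version, the JAX function returns `new_b` instead of writing it back in place.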

### 2.3 Specifying Batching for States

During the modeling process, we often encounter situations where we need to batch certain `State` variables while leaving others unbatched. The `brainstate` framework provides the `brainstate.augment.StateAxes` class to specify which `State` variables require batching. `StateAxes` can be used to set the `in_axes` and `out_axes` parameters of `vmap`. Below is a simple example.

```python
import brainunit as u

class LIFNet(bst.nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.i2r = bst.nn.Linear(nin, nout)
        self.lif = bst.nn.LIF(nout)

    def update(self, x):
        r = self.i2r(x)
        return self.lif(r * u.mA)
```

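In this example, we define a simple network, `LIFNet`, consisting of a linear layer followed by a population of LIF neurons. The output of the linear layer is multiplied by `u.mA` so that it enters the neurons as a current with physical units.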


```python
n_in = 2
n_out = 3
batch_size = 5

net = LIFNet(n_in, n_out)

@bst.augment.vmap(out_axes=bst.augment.StateAxes({'new': 0, ...: None}))
def init_net(key):
    bst.random.set_key(key)

    # Initialize a batch of model state variables
    with bst.catch_new_states('new'):
        bst.nn.init_all_states(net)

    # Return a batch of the model
    return net

net = init_net(bst.random.split_key(batch_size))
```

```python
net.lif.V.value
```

```
ArrayImpl([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32) * mvolt
```

In this example, `StateAxes` specifies which `State` variables are batched. Inside `init_net`, `bst.catch_new_states('new')` catches every `State` created while the network's states are initialized and tags it `'new'`. The `out_axes` argument `StateAxes({'new': 0, ...: None})` then tells `vmap` to stack the states tagged `'new'` along axis 0, while all other states (`...: None`), such as the linear layer's parameters created earlier in `LIFNet.__init__`, remain unbatched and are shared across the batch. As a result, `net.lif.V.value` has shape `(5, 3)`: one membrane potential vector per batch element.
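Plain JAX offers the same per-output control through `out_axes`, although there it is given per return value instead of per `State` tag. In the sketch below (an illustration, not `brainstate` code; `init` is a hypothetical function), the shared weights are returned with `out_axes=None` and the per-sample states with `out_axes=0`, which mirrors the effect of `StateAxes({'new': 0, ...: None})`.

```python
import jax
import jax.numpy as jnp

batch_size = 5


def init(key):
    # Shared parameters: do not depend on the mapped key
    W = jnp.ones((2, 3))
    # Per-sample state: depends on the key, so it is batched
    V = jax.random.normal(key, (3,))
    return W, V


keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
# out_axes=None keeps W unbatched; out_axes=0 stacks V along a new axis 0
W, V = jax.vmap(init, out_axes=(None, 0))(keys)
```

JAX only accepts `out_axes=None` for outputs that do not depend on the mapped input; returning a key-dependent value with `None` raises an error.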

```python
@bst.augment.vmap(in_axes=(bst.augment.StateAxes({'new': 0, ...: None}), 0))
def batch_run(model, x):
    with bst.environ.context(dt=0.1 * u.ms):
        o = model.update(x)
    return o

xs = bst.random.rand(batch_size, 2) < 0.5

r = batch_run(net, xs)
```

```python
net.lif.V.value
```

```
ArrayImpl([[0.0107702 , 0.03013004, 0.05423805],
           [0.0107702 , 0.03013004, 0.05423805],
           [0.0107702 , 0.03013004, 0.05423805],
           [0.01397169, 0.0071168 , 0.04299236],
           [0.01397169, 0.0071168 , 0.04299236]], dtype=float32) * mvolt
```

In this example, `batch_run` calls the batched model through `vmap`. Its `in_axes` again uses `StateAxes` so that all `State` variables tagged `'new'` are mapped along axis 0 while all other `State` instances are shared, and the input `x` is also mapped along axis 0. The updated membrane potentials are written back to the batched state, as `net.lif.V.value` shows.

Through the above examples, we see that even though our model definition is single-batch, we can flexibly apply the `vmap` function to achieve batching during model invocation. This allows us to avoid concerns about batching in the model definition, enabling us to decide on batching at the point of model execution.
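The call-time counterpart in plain JAX is a per-argument `in_axes`. The sketch below is illustrative: `step` is a simplified leaky integrator, not the LIF model of `bst.nn.LIF`, and it omits spiking and reset. It shares the weight matrix with `in_axes=None`, maps the state and the inputs along axis 0, and checks the result against an explicit Python loop.

```python
import jax
import jax.numpy as jnp

batch_size, n_in, n_out = 5, 2, 3
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (n_in, n_out))  # shared across the batch
V = jnp.zeros((batch_size, n_out))           # one state vector per sample
xs = (jax.random.uniform(key_x, (batch_size, n_in)) < 0.5).astype(jnp.float32)


def step(W, v, x, tau=10.0, dt=0.1):
    # Leaky integration of the input current (no spiking or reset)
    current = x @ W
    return v + dt / tau * (-v + current)


V_new = jax.vmap(step, in_axes=(None, 0, 0))(W, V, xs)

# Reference: the same computation with an explicit loop
V_loop = jnp.stack([step(W, V[i], xs[i]) for i in range(batch_size)])
```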

## 3. Multi-Device Parallel Computation

In addition to batching, `brainstate` also supports multi-device parallel computation. We can use the `brainstate.augment.pmap` function to convert a model into one that supports parallel computation across multiple devices.

The usage of `brainstate.augment.pmap` is essentially the same as that of `brainstate.augment.vmap`. The difference lies in where the mapped computation runs: `pmap` splits the mapped axis across multiple devices and runs each slice on its own device, whereas `vmap` vectorizes the computation, turning the mapped axis into a batch dimension of array operations on a single device.

This lets you combine the computational power of several devices, such as GPUs or TPUs, to accelerate training and inference. Note that in JAX's `jax.pmap`, the size of the mapped axis must not exceed the number of available devices; the same constraint presumably applies to `brainstate.augment.pmap`.
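A minimal sketch with plain `jax.pmap` (not `brainstate.augment.pmap`, whose exact signature is documented in the API reference) shows the shape contract. The leading axis is sized to `jax.local_device_count()`, so the example also runs on a machine with a single CPU device.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
# One row per device: the mapped axis size equals the device count
xs = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)


@jax.pmap
def per_device_sum(x):
    # Runs on one device and sees a single row of shape (4,)
    return jnp.sum(x ** 2)


out = per_device_sum(xs)  # shape (n_dev,), one result per device
```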

## 4. Combining Augmentation Transformations

In practical applications, we often need to combine multiple transformations for program augmentation. The various program augmentation functions and compilation functions in `brainstate` can be used together seamlessly. Below is a simple example:

```python
ys = bst.random.rand(batch_size, 3)

def loss_fn(x, y):
    return jnp.mean((y - batch_run(net, x)) ** 2)


weights = net.states(bst.ParamState)
opt = bst.optim.Adam(1e-3)
opt.register_trainable_weights(weights)


@bst.compile.jit
def batch_train(xs, ys):
    grads, l = bst.augment.grad(loss_fn, weights, return_value=True)(xs, ys)
    opt.update(grads)
    return l


l = batch_train(xs, ys)
l
```

```
Array(0.32209805, dtype=float32)
```

In this example, we define a loss function `loss_fn` that calculates the mean squared error between the model's predictions, obtained by invoking `batch_run(net, x)`, and the target values `y`. The model's trainable parameters (weights) are registered with the Adam optimizer.

The `batch_train` function is compiled with Just-In-Time (JIT) compilation using `bst.compile.jit`. Within this function, we compute the gradients of the loss function with respect to the model parameters using `bst.augment.grad`, specifying that we want both gradients and the loss value returned.

After computing the gradients, `opt.update(grads)` applies an Adam step to the registered weights. This single training step combines batching (`vmap` inside `batch_run`), gradient computation (`grad`), and JIT compilation (`jit`). The returned `l` is the loss value computed with the parameters as they were before the update.
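The same combination can be written in plain JAX, which may help clarify what each layer contributes. In this sketch (an illustration, not the `brainstate` API), a shared linear model is mapped over the batch with `in_axes=(None, 0)`, `jax.value_and_grad` returns the loss and the gradients together, and the whole step is compiled with `jax.jit`. Plain gradient descent with a fixed learning rate `LR` stands in for the Adam optimizer used above, and the data are synthetic.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate for plain gradient descent (stand-in for Adam)


def apply(p, x):
    return x @ p["weight"] + p["bias"]


def loss_fn(p, xs, ys):
    # Shared parameters (None), batched inputs (0)
    preds = jax.vmap(apply, in_axes=(None, 0))(p, xs)
    return jnp.mean((ys - preds) ** 2)


@jax.jit
def train_step(p, xs, ys):
    loss, grads = jax.value_and_grad(loss_fn)(p, xs, ys)
    new_p = jax.tree_util.tree_map(lambda w, g: w - LR * g, p, grads)
    return new_p, loss


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.uniform(key_x, (8, 2))
true_w = jax.random.normal(key_w, (2, 3))
ys = xs @ true_w  # synthetic targets

params = {"weight": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}
losses = []
for _ in range(100):
    params, loss = train_step(params, xs, ys)
    losses.append(float(loss))
```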

## 5. Performance Optimization Recommendations

When utilizing program augmentation transformations, the following points should be considered to achieve optimal performance:

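1. **Appropriate Batch Size**: Choose a batch size that fits in device memory while keeping the device fully utilized.
2. **Gradient Accumulation**: When memory limits the batch size, compute gradients on several smaller micro-batches and accumulate them before each parameter update.
3. **Cached Compilation**: Reuse compiled functions and keep input shapes and dtypes fixed, since a change in either triggers recompilation.
4. **Memory Management**: Delete references to large arrays you no longer need (for example with `del`, or by calling `.delete()` on a JAX array) so that their device memory can be freed. Note that `jax.device_get()` copies arrays to host memory; it does not by itself release device memory.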
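Gradient accumulation (item 2) and compilation reuse (item 3) can be illustrated together in plain JAX. In this sketch (not `brainstate` code; the least-squares loss and random data are placeholders), one jitted gradient function is reused for every micro-batch. Because the loss is a mean and the micro-batches have equal size, averaging their gradients reproduces the full-batch gradient.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 3))
y = jax.random.normal(key_y, (8,))
w = jnp.zeros(3)

# Compiled once per input shape and reused for every micro-batch
grad_fn = jax.jit(jax.grad(loss_fn))

n_micro = 4
acc = jnp.zeros_like(w)
for xb, yb in zip(jnp.split(x, n_micro), jnp.split(y, n_micro)):
    acc = acc + grad_fn(w, xb, yb)
acc = acc / n_micro  # average of equal-sized micro-batch gradients

full = grad_fn(w, x, y)  # new input shape, so this call compiles once more
```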


## 6. Debugging Techniques

Debugging is an important topic when using program augmentations. The `brainstate` framework supports the debugging tools provided by JAX, such as `jax.debug.print`. Below is a simple example:

```python
import jax

# Debugging with jax.debug.print
# (reuses loss_fn, weights, and opt from the previous example)
@bst.compile.jit
def batch_train(xs, ys):
    grads, l = bst.augment.grad(loss_fn, weights, return_value=True)(xs, ys)
    jax.debug.print("Gradients: {g}", g=grads)
    opt.update(grads)
    return l
```

In this example, `batch_train` uses `jax.debug.print` to print the computed gradients at run time, even though the function is JIT-compiled and a regular `print` would only show abstract tracer values. This is useful for monitoring training dynamics and diagnosing problems in gradient computations.

For detailed usage, refer to the [JAX Debugging Documentation](https://jax.readthedocs.io/en/latest/debugging/index.html).
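The difference between `jax.debug.print` and Python's built-in `print` inside a compiled function can be seen in a small plain-JAX sketch (the function `step` and its data are illustrative). The built-in `print` runs only while JAX traces the function and shows abstract values, whereas `jax.debug.print` prints the concrete values on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(w, x):
    g = jax.grad(lambda v: jnp.sum((x @ v) ** 2))(w)
    # Runs only at trace time; prints a tracer, not the numbers
    print("traced gradient:", g)
    # Runs at execution time; prints the actual gradient norm
    jax.debug.print("gradient norm: {n}", n=jnp.linalg.norm(g))
    return w - 0.1 * g


w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
w1 = step(w, x)
w2 = step(w1, x)  # same shapes and dtypes: no retrace, so the built-in print stays silent
```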

## Conclusion

Program augmentations built on `pygraph` are a core feature of the `BrainState` framework. Used well, they can substantially improve the efficiency of both training and inference. This tutorial covered the main augmentations (automatic differentiation, batching, multi-device parallelism, and their combination) and how to use them, providing a foundation for applying them to your own models.