# Training Artificial Neural Networks

In recent years, artificial neural networks have developed rapidly and play an important role in neuroscience research. As a high-performance computational framework for brain dynamics modeling, brainstate also supports the training of artificial neural networks, facilitating the integration of neural dynamics models with artificial neural networks.

Here, we will introduce how to train an artificial neural network using brainstate, with an example of a simple 2-layer multilayer perceptron (MLP) for handwritten digit recognition (MNIST).

In [1]:
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset

import braintools 
import brainstate 
from braintools.metric import softmax_cross_entropy_with_integer_labels

In [2]:
brainstate.__version__

'0.2.2'

## Preparing the Dataset

First, we need to obtain the dataset and wrap it into an iterable object that automatically samples and shuffles the data according to the batch size.

In [3]:
dataset = load_dataset('mnist')
X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)
X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)
X_train = (X_train > 0).astype(jnp.float32)
X_test = (X_test > 0).astype(jnp.float32)
Y_train = np.array(dataset['train']['label'], dtype=np.int32)
Y_test = np.array(dataset['test']['label'], dtype=np.int32)

In [4]:
class Dataset:
    def __init__(self, X, Y, batch_size, shuffle=True):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(X))
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        # Check if all samples have been processed
        if self.current_index >= len(self.X):
            raise StopIteration

        # Define the start and end of the current batch
        start = self.current_index
        end = start + self.batch_size
        if end > len(self.X):
            end = len(self.X)
        
        # Update current index
        self.current_index = end

        # Select batch samples
        batch_indices = self.indices[start:end]
        batch_X = self.X[batch_indices]
        batch_Y = self.Y[batch_indices]

        # Ensure batch has consistent shape
        if batch_X.ndim == 1:
            batch_X = np.expand_dims(batch_X, axis=0)

        return batch_X, batch_Y

In [5]:
# Initialize training and testing datasets
batch_size = 32
train_dataset = Dataset(X_train, Y_train, batch_size, shuffle=True)
test_dataset = Dataset(X_test, Y_test, batch_size, shuffle=False)

## Defining the Artificial Neural Network

When defining an artificial neural network in brainstate, you need to inherit the base class ``brainstate.nn.Module``. In the class method ``__init__()``, define the layers in the network (make sure to initialize the base class first using ``super().__init__()``). In the class method ``__call__()``, define the forward pass method of the network.

brainstate also supports defining operations for individual layers in the network. For these custom layers, you need to inherit from the base class ``brainstate.nn.Module``, similar to defining a network.

All quantities that need to change in the model should be encapsulated in the ``State`` object. Parameters that need to be updated during training should be encapsulated in a subclass of ``State`` called ``ParamState``. Other quantities that need to be updated during training are encapsulated in another subclass of ``State`` called ``ShortTermState``.

In [6]:
# Define linear layer
class Linear(brainstate.nn.Module):
  def __init__(self, din: int, dout: int):
    super().__init__()
    self.w = brainstate.ParamState(brainstate.random.rand(din, dout))  # Initialize weight parameters
    self.b = brainstate.ParamState(jnp.zeros((dout,)))  # Initialize bias parameters

  def __call__(self, x):
    return x @ self.w.value + self.b.value    # Perform linear transformation

In [7]:
# Define a short-term state for counting times called
class Count(brainstate.ShortTermState):
  pass

In [8]:
# Define MLP model
class MLP(brainstate.nn.Module):
  def __init__(self, din, dhidden, dout):
    super().__init__()
    self.count = Count(jnp.array(0))    # Count how many times model is called
    self.linear1 = Linear(din, dhidden)          # brainstate有常规层的实现，可以直接写 self.linear1 = brainstate.nn.Linear(din, dhidden)
    self.linear2 = Linear(dhidden, dout)
    self.flatten = brainstate.nn.Flatten(start_axis=1)   # Flatten images to 1D
    self.relu = brainstate.nn.ReLU()   # ReLU activation function

  def __call__(self, x):
    self.count.value += 1   # Increment call count

    x = self.flatten(x)
    x = self.linear1(x)
    x = self.relu(x)      # 也兼容jax函数，可以直接写 x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [9]:
# Initialize model with input, hidden, and output layer sizes
model = MLP(din=28*28, dhidden=512, dout=10)

## Optimizer Setup

``braintools.optim`` provides various optimizers to choose from.

After instantiating the optimizer, you need to specify which parameters the optimizer should update by calling ``optimizer.register_trainable_weights()``.

In this case, we use ``brainstate.nn.Module.states()`` to collect all the ``State`` objects of the network nodes and their sub-nodes in the model. We restrict the types of ``State`` collected to ``brainstate.ParamState`` (in this model, ``State`` instances may also have other types like ``Count``, which do not need to be updated by the optimizer, so we apply type restrictions).

In [10]:
# Initialize optimizer and register model parameters
optimizer = braintools.optim.SGD(lr = 1e-3)   # Initialize SGD optimizer with learning rate
optimizer.register_trainable_weights(model.states(brainstate.ParamState))   # Register parameters for optimization

  optimizer = braintools.optim.SGD(lr = 1e-3)   # Initialize SGD optimizer with learning rate


SGD(
  momentum=0.0,
  nesterov=False,
  param_states=<braintools.optim.UniqueStateManager object at 0x0000023F61D79DF0>,
  weight_decay=0.0,
  grad_clip_norm=None,
  grad_clip_value=None,
  step_count=OptimState(
    value=ShapedArray(int32[], weak_type=True)
  ),
  param_groups=[
    {
      'params': {
        ('linear1', 'b'): ParamState(
          value=ShapedArray(float32[512])
        ),
        ('linear1', 'w'): ParamState(
          value=ShapedArray(float32[784,512])
        ),
        ('linear2', 'b'): ParamState(
          value=ShapedArray(float32[10])
        ),
        ('linear2', 'w'): ParamState(
          value=ShapedArray(float32[512,10])
        )
      },
      'lr': OptimState(
        value=ShapedArray(float32[], weak_type=True)
      ),
      'weight_decay': 0.0
    }
  ],
  param_groups_opt_states=[],
  _schedulers=[],
  _lr_scheduler=<braintools.optim.ConstantLR object at 0x0000023F61D7A180>,
  _base_lr=0.001,
  _current_lr=OptimState(...),
  tx=GradientTransf

## Model Training

During model training, use the ``brainstate.transform.grad`` function to calculate gradients. This function requires the loss function and the parameters (``State``) for which gradients should be computed.

Then, the gradients are passed to the previously defined optimizer via ``update()`` for the update.

To improve computational efficiency and performance, use the ``brainstate.transform.jit`` function to decorate the training step function, enabling just-in-time compilation.

In [11]:
# Training step function
@brainstate.transform.jit
def train_step(batch):
  x, y = batch
  # Define loss function
  def loss_fn():
    return softmax_cross_entropy_with_integer_labels(model(x), y).mean()
  
  # Compute gradients of the loss with respect to model parameters
  grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
  optimizer.update(grads)   # Update parameters using optimizer

  @brainstate.transform.jit


## Model Testing

Similarly, use the ``brainstate.transform.jit`` function to decorate the testing step function, allowing for just-in-time compilation to improve computational efficiency and performance.

In [12]:
# Testing step function
@brainstate.transform.jit
def test_step(batch):
  x, y = batch
  y_pred = model(x)   # Perform forward pass
  loss = softmax_cross_entropy_with_integer_labels(y_pred, y).mean()   # Compute loss
  correct = (y_pred.argmax(1) == y).sum()   # Count correct predictions

  return {'loss': loss, 'correct': correct}

## Training Process

This completes the setup and the process for training an artificial neural network with brainstate.

In [13]:
# Execute training and testing
total_steps = 20
for epoch in range(10):
  for step, batch in enumerate(train_dataset):
    train_step(batch)   # Perform training step for each batch

  # Calculate test loss and accuracy
  test_loss, correct = 0, 0
  for step_, test_ in enumerate(test_dataset):
    logs = test_step(test_)
    test_loss += logs['loss']
    correct += logs['correct']
    test_loss += logs['loss']
  test_loss = test_loss / (step_ + 1)
  test_accuracy = correct / len(X_test)
  print(f"epoch: {epoch}, test loss: {test_loss}, test accuracy: {test_accuracy}")

print('times model called:', model.count.value)   # Output number of model calls

  grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()


epoch: 0, test loss: 402.7491455078125, test accuracy: 0.48350000381469727
epoch: 1, test loss: 286.73333740234375, test accuracy: 0.4884999990463257
epoch: 2, test loss: 113.37454986572266, test accuracy: 0.7487999796867371
epoch: 3, test loss: 51.642398834228516, test accuracy: 0.8188999891281128
epoch: 4, test loss: 102.10010528564453, test accuracy: 0.7190999984741211
epoch: 5, test loss: 47.38042068481445, test accuracy: 0.8242999911308289
epoch: 6, test loss: 68.87149047851562, test accuracy: 0.782800018787384
epoch: 7, test loss: 43.74293518066406, test accuracy: 0.8547999858856201
epoch: 8, test loss: 72.86088562011719, test accuracy: 0.7950999736785889
epoch: 9, test loss: 42.85074996948242, test accuracy: 0.8532999753952026
times model called: 21880
