# Training Artificial Neural Networks

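Artificial neural networks have developed rapidly in recent years and now also play an important role in neuroscience research. brainstate, a high-performance computing framework for brain dynamics modeling, also supports training artificial neural networks, which makes it easy to connect neural dynamics models with artificial neural networks.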

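Here we use a simple two-layer multilayer perceptron (MLP) trained on the MNIST handwritten digit recognition task to show how to train an artificial neural network with brainstate.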

```python
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset

import brainstate
from braintools.metric import softmax_cross_entropy_with_integer_labels
```

```python
brainstate.__version__
```

```
'0.1.0'
```

## Preparing the Dataset

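First load the dataset, then wrap it in an iterable object that yields batches of a given size and automatically reshuffles the samples at the start of every pass.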

```python
# Load MNIST with the Hugging Face datasets library and convert it to NumPy arrays
dataset = load_dataset('mnist')
X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)
X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)
# Binarize the pixels: any nonzero intensity becomes 1.0
X_train = (X_train > 0).astype(jnp.float32)
X_test = (X_test > 0).astype(jnp.float32)
Y_train = np.array(dataset['train']['label'], dtype=np.int32)
Y_test = np.array(dataset['test']['label'], dtype=np.int32)
```

```python
class Dataset:
    """Iterable over (X, Y) mini-batches, optionally reshuffled at the start of every pass."""

    def __init__(self, X, Y, batch_size, shuffle=True):
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(X))
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        # Restart from the first sample and reshuffle for each new pass
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        # Check if all samples have been processed
        if self.current_index >= len(self.X):
            raise StopIteration

        # Define the start and end of the current batch (the last batch may be smaller)
        start = self.current_index
        end = min(start + self.batch_size, len(self.X))

        # Update current index
        self.current_index = end

        # Select batch samples
        batch_indices = self.indices[start:end]
        batch_X = self.X[batch_indices]
        batch_Y = self.Y[batch_indices]

        # Ensure batch has consistent shape
        if batch_X.ndim == 1:
            batch_X = np.expand_dims(batch_X, axis=0)

        return batch_X, batch_Y
```

```python
# Initialize training and testing datasets
batch_size = 32
train_dataset = Dataset(X_train, Y_train, batch_size, shuffle=True)
test_dataset = Dataset(X_test, Y_test, batch_size, shuffle=False)
```
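Because the last batch is kept even when it is smaller than ``batch_size``, the number of batches per epoch is the ceiling of the sample count divided by the batch size. The short sketch below assumes the standard MNIST split of 60,000 training and 10,000 test images and reproduces that arithmetic. It also predicts how often the model is called over the 10 epochs of the training loop: once per training batch inside ``train_step`` and once per test batch inside ``test_step``. ``num_batches`` is a hypothetical helper written for this illustration.

```python
import math

def num_batches(n_samples: int, batch_size: int) -> int:
    """Number of batches the Dataset iterator yields (the last partial batch is kept)."""
    return math.ceil(n_samples / batch_size)

# Assumed sizes of the standard MNIST train/test split
n_train, n_test, batch_size, epochs = 60_000, 10_000, 32, 10

train_batches = num_batches(n_train, batch_size)   # 1875 full batches
test_batches = num_batches(n_test, batch_size)     # 312 full batches plus one partial batch
last_test_batch = n_test - (test_batches - 1) * batch_size

# One forward pass per training batch and one per test batch, every epoch
total_calls = epochs * (train_batches + test_batches)
print(train_batches, test_batches, last_test_batch, total_calls)   # → 1875 313 16 21880
```

The predicted total agrees with the call count printed at the end of the training run.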

## Defining the Neural Network

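To define an artificial neural network in brainstate, subclass the base class ``brainstate.nn.Module``. Define the network's layers in the ``__init__()`` method (call the base-class initializer ``super().__init__()`` first), and define the forward pass in the ``__call__()`` method.

Individual layers that perform a specific operation are defined the same way: they also subclass ``brainstate.nn.Module``, and their usage mirrors that of a full network.

Every quantity in the model that changes must be wrapped in a ``State`` object. Model parameters that are updated during training are wrapped in ``ParamState`` (a subclass of ``State``); other quantities that change during training are wrapped in ``ShortTermState``, another subclass of ``State``.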

```python
# Define a linear layer
class Linear(brainstate.nn.Module):
  def __init__(self, din: int, dout: int):
    super().__init__()
    self.w = brainstate.ParamState(brainstate.random.rand(din, dout))  # Weights, drawn uniformly from [0, 1)
    self.b = brainstate.ParamState(jnp.zeros((dout,)))  # Bias, initialized to zero

  def __call__(self, x):
    return x @ self.w.value + self.b.value    # Linear transformation x @ W + b
```
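The ``Linear`` layer above draws its weights from ``brainstate.random.rand``, i.e. uniformly from [0, 1), without scaling by the layer width. As a rough illustration, the NumPy sketch below compares the logit magnitudes this initialization produces in a 784-512-10 network with those of a Glorot/Xavier-style uniform initialization (a swapped-in technique, not something the tutorial uses). The inputs are random binary vectors with about 20% of pixels on, an assumed stand-in for binarized MNIST digits. Logits in the thousands make the softmax cross-entropy very large, which is consistent with the test losses in the hundreds seen in the training log.

```python
import numpy as np

rng = np.random.default_rng(0)

def forward(x, w1, b1, w2, b2):
    """Two-layer MLP forward pass: linear -> ReLU -> linear."""
    h = np.maximum(x @ w1 + b1, 0.0)
    return h @ w2 + b2

def glorot_uniform(din, dout):
    """Glorot/Xavier uniform init: U(-a, a) with a = sqrt(6 / (din + dout))."""
    a = np.sqrt(6.0 / (din + dout))
    return rng.uniform(-a, a, size=(din, dout))

din, dhidden, dout = 28 * 28, 512, 10
# Hypothetical stand-in for a batch of binarized images: ~20% of pixels are 1
x = (rng.random((32, din)) < 0.2).astype(np.float32)

# Initialization used by the tutorial's Linear layer: U[0, 1) weights, zero bias
raw = forward(x, rng.random((din, dhidden)), np.zeros(dhidden),
              rng.random((dhidden, dout)), np.zeros(dout))
# Width-scaled alternative
scaled = forward(x, glorot_uniform(din, dhidden), np.zeros(dhidden),
                 glorot_uniform(dhidden, dout), np.zeros(dout))

raw_scale = np.abs(raw).mean()
scaled_scale = np.abs(scaled).mean()
print(f"mean |logit|, U[0,1) init: {raw_scale:.1f}")
print(f"mean |logit|, Glorot init: {scaled_scale:.3f}")
```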

```python
# Define a short-term state that counts how many times the model is called
class Count(brainstate.ShortTermState):
  pass
```

```python
# Define the MLP model
class MLP(brainstate.nn.Module):
  def __init__(self, din, dhidden, dout):
    super().__init__()
    self.count = Count(jnp.array(0))    # Count how many times the model is called
    self.linear1 = Linear(din, dhidden)   # brainstate ships common layers, so brainstate.nn.Linear(din, dhidden) would also work here
    self.linear2 = Linear(dhidden, dout)
    self.flatten = brainstate.nn.Flatten(start_axis=1)   # Flatten each 28x28 image into a 784-vector
    self.relu = brainstate.nn.ReLU()   # ReLU activation function

  def __call__(self, x):
    self.count.value += 1   # Increment the call count

    x = self.flatten(x)
    x = self.linear1(x)
    x = self.relu(x)      # JAX functions also work here, e.g. x = jax.nn.relu(x)
    x = self.linear2(x)
    return x
```

```python
# Initialize model with input, hidden, and output layer sizes
model = MLP(din=28*28, dhidden=512, dout=10)
```
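For a sense of scale, each ``Linear(din, dout)`` layer holds a din × dout weight matrix plus a dout-element bias. The small helper below (a hypothetical utility, not a brainstate function) counts the trainable parameters of the 784-512-10 MLP created above; the ``Count`` state is not a parameter and is not included.

```python
def linear_param_count(din: int, dout: int) -> int:
    """Parameters in one Linear layer: weight matrix (din x dout) plus bias (dout)."""
    return din * dout + dout

layer_sizes = [28 * 28, 512, 10]   # din, dhidden, dout used for the MLP above
per_layer = [linear_param_count(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)   # → [401920, 5130] 407050
```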

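## Optimizer Setup
``brainstate.optim`` provides a variety of optimizers to choose from.

After creating an optimizer, pass the model parameters it should update to ``optimizer.register_trainable_weights()``.

Here, ``brainstate.nn.Module.states()`` collects the ``State`` objects of every node and sub-node in the model, restricted to the type ``brainstate.ParamState``. The type restriction is needed because this model also contains a ``Count`` state, which the optimizer should not update.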

```python
# Initialize the optimizer and register the model parameters
optimizer = brainstate.optim.SGD(lr=1e-3)   # SGD optimizer with learning rate 1e-3
optimizer.register_trainable_weights(model.states(brainstate.ParamState))   # Register only ParamState objects
```
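The type filter passed to ``model.states()`` is what keeps the ``Count`` state out of the optimizer. The plain-Python sketch below mimics that idea with hypothetical stand-in classes; they are not brainstate's actual ``State`` classes, and ``collect_states`` is an illustrative helper rather than the library's implementation. Because ``Count`` derives from ``ShortTermState`` rather than ``ParamState``, an ``isinstance``-based filter on ``ParamState`` selects only the weights and biases.

```python
class State:
    """Hypothetical stand-in for a mutable state container."""
    def __init__(self, value):
        self.value = value

class ParamState(State):        # trainable parameters
    pass

class ShortTermState(State):    # other quantities that change while the model runs
    pass

class Count(ShortTermState):    # mirrors the Count class defined above
    pass

def collect_states(states: dict, state_type: type) -> dict:
    """Keep only the states that are instances of state_type (subclasses included)."""
    return {path: s for path, s in states.items() if isinstance(s, state_type)}

# Hypothetical flattened view of the MLP's states, keyed by attribute path
model_states = {
    ("count",): Count(0),
    ("linear1", "w"): ParamState([[0.0]]),
    ("linear1", "b"): ParamState([0.0]),
    ("linear2", "w"): ParamState([[0.0]]),
    ("linear2", "b"): ParamState([0.0]),
}

params = collect_states(model_states, ParamState)
print(sorted(params))   # → [('linear1', 'b'), ('linear1', 'w'), ('linear2', 'b'), ('linear2', 'w')]
```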

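## Model Training
During training, the gradients are computed with ``brainstate.augment.grad``, which takes the loss function and the ``State`` parameters with respect to which gradients should be computed.

The gradients are then passed to the previously defined optimizer through ``update()``, which updates the parameters.

The single-step training function is decorated with ``brainstate.compile.jit`` so that it is just-in-time compiled, which improves computational efficiency and performance.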

```python
# Training step function
@brainstate.compile.jit
def train_step(batch):
  x, y = batch
  # Loss function: mean softmax cross-entropy over the batch
  def loss_fn():
    return softmax_cross_entropy_with_integer_labels(model(x), y).mean()

  # Compute gradients of the loss with respect to the model parameters
  grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
  optimizer.update(grads)   # Update parameters using the optimizer
```
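To make the gradient step concrete, the NumPy sketch below works through the same kind of computation for a single linear softmax layer, with the gradient derived by hand instead of obtained from ``brainstate.augment.grad``. It assumes the usual definition of mean softmax cross-entropy with integer labels, whose gradient with respect to the logits is (softmax(z) − onehot(y)) / N for a batch of N samples. It checks one entry against a finite difference and then applies the plain SGD rule w ← w − lr · ∂L/∂w. All data here is random and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax_xent(logits, labels):
    """Mean softmax cross-entropy with integer labels (numerically stable)."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def loss_and_grad(w, b, x, y):
    """Loss of a linear softmax classifier and its gradients w.r.t. w and b."""
    logits = x @ w + b
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    dlogits = probs.copy()
    dlogits[np.arange(len(y)), y] -= 1.0   # softmax - onehot
    dlogits /= len(y)                      # gradient of the mean over the batch
    return softmax_xent(logits, y), x.T @ dlogits, dlogits.sum(axis=0)

x = rng.normal(size=(8, 5))
y = rng.integers(0, 3, size=8)
w = rng.normal(scale=0.1, size=(5, 3))
b = np.zeros(3)

loss, gw, gb = loss_and_grad(w, b, x, y)

# Finite-difference check of one weight gradient
eps = 1e-6
w_shift = w.copy()
w_shift[0, 0] += eps
numeric = (loss_and_grad(w_shift, b, x, y)[0] - loss) / eps

# One SGD step: parameter <- parameter - lr * gradient
lr = 0.1
w_new, b_new = w - lr * gw, b - lr * gb
new_loss = loss_and_grad(w_new, b_new, x, y)[0]
print(f"analytic {gw[0, 0]:.6f}, numeric {numeric:.6f}")
print(f"loss before {loss:.4f}, after one SGD step {new_loss:.4f}")
```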

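## Model Testing
The single-step testing function is likewise decorated with ``brainstate.compile.jit`` so that it is just-in-time compiled, which improves computational efficiency and performance.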

```python
# Testing step function
@brainstate.compile.jit
def test_step(batch):
  x, y = batch
  y_pred = model(x)   # Perform forward pass
  loss = softmax_cross_entropy_with_integer_labels(y_pred, y).mean()   # Mean loss over the batch
  correct = (y_pred.argmax(1) == y).sum()   # Count correct predictions

  return {'loss': loss, 'correct': correct}
```
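``test_step`` returns the batch-mean loss and the number of correct predictions. Because the final test batch holds only 16 of the 10,000 images, averaging the per-batch means over the number of batches, as the training loop does, weights those 16 samples slightly more than the rest. The NumPy sketch below uses made-up losses, logits and labels to show how argmax-based correct counts are accumulated, and contrasts the unweighted average of batch means with a batch-size-weighted one, which equals the exact mean over all samples.

```python
import numpy as np

rng = np.random.default_rng(1)

# Made-up per-sample losses, logits and labels for 10 samples split into batches of 4
per_sample_loss = rng.random(10)
logits = rng.normal(size=(10, 3))
labels = rng.integers(0, 3, size=10)
batch_size = 4

sum_of_means, weighted_sum, correct, n_batches = 0.0, 0.0, 0, 0
for start in range(0, len(labels), batch_size):
    sl = slice(start, start + batch_size)
    batch_mean = per_sample_loss[sl].mean()                      # like test_step's 'loss'
    batch_correct = (logits[sl].argmax(1) == labels[sl]).sum()   # like test_step's 'correct'
    sum_of_means += batch_mean
    weighted_sum += batch_mean * len(labels[sl])                 # weight by batch size
    correct += int(batch_correct)
    n_batches += 1

unweighted = sum_of_means / n_batches   # average of per-batch means
weighted = weighted_sum / len(labels)   # exact mean over all samples
accuracy = correct / len(labels)
print(f"unweighted {unweighted:.4f}, weighted {weighted:.4f}, accuracy {accuracy:.2f}")
```

The difference is small when only the last batch is short, so the simpler unweighted average is usually an acceptable approximation.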

## Running Training

```python
# Run training and evaluation
for epoch in range(10):
  for batch in train_dataset:
    train_step(batch)   # One parameter update per training batch

  # Compute the test loss and accuracy
  test_loss, correct, n_batches = 0, 0, 0
  for batch in test_dataset:
    logs = test_step(batch)
    test_loss += logs['loss']
    correct += logs['correct']
    n_batches += 1
  test_loss = test_loss / n_batches   # Average of the per-batch mean losses
  test_accuracy = correct / len(X_test)
  print(f"epoch: {epoch}, test loss: {test_loss}, test accuracy: {test_accuracy}")

print('times called:', model.count.value)   # Total number of forward passes
```

Recorded output. In the run that produced this log, each test batch's loss was added to ``test_loss`` twice, so the test losses shown are twice the average per-batch loss; the accuracies and the call count are not affected.

```
epoch: 0, test loss: 410.2366638183594, test accuracy: 0.24890001118183136
epoch: 1, test loss: 278.79864501953125, test accuracy: 0.6233000159263611
epoch: 2, test loss: 75.72823333740234, test accuracy: 0.7638000249862671
epoch: 3, test loss: 59.49712371826172, test accuracy: 0.7830000519752502
epoch: 4, test loss: 38.07597351074219, test accuracy: 0.8623000383377075
epoch: 5, test loss: 54.225074768066406, test accuracy: 0.8329000473022461
epoch: 6, test loss: 74.46405792236328, test accuracy: 0.7676000595092773
epoch: 7, test loss: 35.6864128112793, test accuracy: 0.867900013923645
epoch: 8, test loss: 140.0616912841797, test accuracy: 0.7529000639915466
epoch: 9, test loss: 42.05353927612305, test accuracy: 0.8574000597000122
times called: 21880
```
