In [10]:
### Import
import brainpy as bp
import brainpy.math as bm
import numpy as np 
import tqdm as notebook_tqdm
import matplotlib.pyplot as plt
# Run BrainPy / JAX computations on the GPU backend.
# NOTE(review): presumably needs a CUDA-enabled jaxlib and should run before any arrays are created — confirm.
bm.set_platform('gpu')

In [11]:
class Example:
    """Toy object contrasting a plain Python attribute with a ``bm.Variable``.

    ``static`` is an ordinary Python attribute. ``dyn`` is wrapped in
    ``bm.Variable``, and the JIT-compiled ``update`` both reads and writes it.
    """

    def __init__(self):
        # Plain Python attribute (not a bm.Variable).
        self.static = 0
        # Wrapped in bm.Variable so that BrainPy tracks it as state.
        self.dyn = bm.Variable(bm.ones(1))

    @bm.cls_jit  # JIT-compile this method
    def update(self, inp):
        """Set ``dyn`` to ``dyn * inp + static``."""
        scaled = self.dyn.value * inp
        self.dyn.value = scaled + self.static

一旦 `update()` 被 JIT 編譯，編譯後的計算圖就不會再重新讀取一般的 Python 屬性（只有 `bm.Variable` 會被追蹤）。
所以即使之後修改 `.static` 的值，`update()` 使用的仍然是編譯當下的值 `0`。

In [12]:
# Fresh instance: static == 0 and dyn holds bm.ones(1).
example = Example()

In [None]:
# First call traces and compiles update() while static == 0: dyn = 1 * 1 + 0 = 1.
example.update(1)
print('after 1st update :', example.dyn)

# Mutate the plain Python attribute, then call the compiled method again.
example.static = 100
print('static is now    :', example.static)
example.update(1)
# If static were re-read this would be 1 * 1 + 100 = 101; because the value 0
# was baked in at compile time, dyn is expected to stay at 1.
print('after 2nd update :', example.dyn)

100


## A complex example: Training a network
With this basic understanding of how OO transformations work, we can now train a neural network model using these transformations.

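In this example, we teach the neural network to classify random input vectors into one of two labels (True or False). The labels are drawn independently of the inputs, so there is no real pattern to learn; the goal is only to demonstrate the training machinery. The training data are: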

In [None]:
# Seed BrainPy's global RNG so the generated data are reproducible across runs.
bm.random.seed(42)

num_in = 100        # features per sample
num_sample = 256    # number of training samples
X = bm.random.rand(num_sample, num_in)
# Binary labels (0.0 / 1.0), drawn independently of X.
Y = (bm.random.rand(num_sample) < 0.5).astype(float)

In [None]:
class Linear(bp.BrainPyObject):  # hand-written linear layer
    """Fully connected layer computing ``x @ W + b``.

    Parameters
    ----------
    n_in : int
        Number of input features.
    n_out : int
        Number of output features.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.num_in, self.num_out = n_in, n_out
        weight_shape = (n_in, n_out)
        bias_shape = (1, n_out)
        # Parameters are bm.Variables so they show up in net.vars() for grad / SGD.
        self.W = bm.Variable(bp.init.XavierNormal()(weight_shape))
        self.b = bm.Variable(bm.zeros(bias_shape))

    def __call__(self, x):
        projected = x @ self.W
        return projected + self.b


# Two-layer MLP: num_in -> NUM_HIDDEN (ReLU) -> NUM_CLASSES raw scores,
# which are later fed to bp.losses.cross_entropy_loss.
NUM_HIDDEN = 20
NUM_CLASSES = 2  # one output per label value in Y
net = bp.Sequential(
    Linear(num_in, NUM_HIDDEN),
    bm.relu,
    Linear(NUM_HIDDEN, NUM_CLASSES),
)
net

In [None]:
class Trainer(object):
    """Minimal SGD trainer for a network on the module-level data ``X`` / ``Y``.

    Parameters
    ----------
    net : bp.BrainPyObject
        Model to train. Its ``vars()`` are both differentiated and updated.
    lr : float, optional
        SGD learning rate (default ``1e-2``).
    """

    def __init__(self, net, lr=1e-2):
        self.net = net
        # With return_value=True, calling self.grad() yields (grads, loss).
        self.grad = bm.grad(self.loss, grad_vars=net.vars(), return_value=True)
        self.optimizer = bp.optim.SGD(lr=lr, train_vars=net.vars())

    @bm.cls_jit(inline=True)
    def loss(self):
        """Cross-entropy loss of ``self.net`` on a freshly shuffled copy of (X, Y)."""
        # One key for both permutations, so every sample stays paired with its label.
        key = bm.random.split_key()
        x_data = bm.random.permutation(X, key=key)
        y_data = bm.random.permutation(Y, key=key)
        # Use the trainer's own network, not the global `net`. Gradients are taken
        # w.r.t. self.net's variables, so the forward pass must go through self.net.
        # NOTE(review): the leading dict() is presumably the shared-arguments dict
        # expected by bp.Sequential in this BrainPy version — confirm.
        predictions = self.net(dict(), x_data)
        return bp.losses.cross_entropy_loss(predictions, y_data)

    @bm.cls_jit
    def train(self):
        """Run one SGD step; return the loss evaluated before the parameter update."""
        grads, loss_value = self.grad()
        self.optimizer.update(grads)
        return loss_value