# Differentiation

This section shows how to apply automatic differentiation to variables in a function or a class object. Gradients are computed and used throughout modern machine learning systems, so this section covers:

- how to calculate derivatives of arbitrarily complex functions,
- how to compute higher-order gradients.

In [1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

## Preliminary

Every autograd function in BrainPy accepts several keyword arguments. The examples below all use [brainpy.math.grad()](../apis/auto/math/generated/brainpy.math.autograd.grad.rst); the other autograd functions accept the same settings.

### ``argnums`` and ``grad_vars``

The autograd functions in BrainPy can compute derivatives with respect to *function arguments* (specified through ``argnums``) or *non-argument variables* (specified through ``grad_vars``). Consider this linear readout model:

In [8]:
class Linear(bp.Base):
    def __init__(self):
        super(Linear, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        r = bm.dot(self.w, x) + self.b
        return r.sum()
    
l = Linear()

To take the derivative with respect to the argument "x" of the ``update`` function, set ``argnums``:

In [9]:
grad = bm.grad(l.update, argnums=0)

grad(bm.ones(10))

JaxArray(DeviceArray([0.0940454 , 0.24210012, 0.10360408, 0.2991985 , 0.22486198,
                      0.9399384 , 0.88925755, 0.84567535, 0.94881094, 0.7843926 ],            dtype=float32))

To take derivatives with respect to the parameters "self.w" and "self.b" instead, pass them through ``grad_vars``:

In [10]:
grad = bm.grad(l.update, grad_vars=(l.w, l.b))

grad(bm.ones(10))

(JaxArray(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)),
 JaxArray(DeviceArray([1.], dtype=float32)))

To get derivatives with respect to both the argument "x" and the parameters "self.w" and "self.b", use ``argnums`` and ``grad_vars`` together. The gradient function then returns ``(var_grads, arg_grads)``, where ``var_grads`` holds the gradients of ``grad_vars`` and ``arg_grads`` holds the gradients of the arguments selected by ``argnums``.

In [12]:
grad = bm.grad(l.update, grad_vars=(l.w, l.b), argnums=0)

var_grads, arg_grads = grad(bm.ones(10))

In [13]:
var_grads

(JaxArray(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)),
 JaxArray(DeviceArray([1.], dtype=float32)))

In [14]:
arg_grads

JaxArray(DeviceArray([0.0940454 , 0.24210012, 0.10360408, 0.2991985 , 0.22486198,
                      0.9399384 , 0.88925755, 0.84567535, 0.94881094, 0.7843926 ],            dtype=float32))
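The values above follow from the model itself: for `r = sum(w · x + b)`, the derivative with respect to `x` is `w`, the derivative with respect to `w` is `x`, and the derivative with respect to `b` is 1. The following plain-Python sketch checks this with central differences. It does not use BrainPy; the helper `numeric_grad` and the weight values are illustrative assumptions, not part of any library.

```python
def readout(w, x, b):
    """Scalar output of a one-unit linear readout: sum_i w[i] * x[i] + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def numeric_grad(f, vec, eps=1e-6):
    """Central-difference gradient of a scalar function f at a list of floats."""
    grads = []
    for i in range(len(vec)):
        up = vec[:i] + [vec[i] + eps] + vec[i + 1:]
        down = vec[:i] + [vec[i] - eps] + vec[i + 1:]
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


w = [0.1, 0.2, 0.3]   # hypothetical weights (the tutorial draws them at random)
x = [1.0, 1.0, 1.0]   # input of ones, as in the cells above
b = 0.0

# Differentiating with respect to the argument x recovers the weights.
grad_x = numeric_grad(lambda v: readout(w, v, b), x)
# Differentiating with respect to the parameter w recovers the input.
grad_w = numeric_grad(lambda v: readout(v, x, b), w)
```

This is why `arg_grads` above equals the randomly drawn weights, while the gradient of `self.w` is a row of ones.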

### ``return_value``

As shown above, an autograd function returns a new function that computes gradients of the original function's return value. Sometimes the return value itself is also needed, not just its gradients. In that case, set ``return_value=True``; the gradient function then returns ``(gradient, value)``.

In [15]:
grad = bm.grad(l.update, argnums=0, return_value=True)

gradient, value = grad(bm.ones(10))

In [16]:
gradient

JaxArray(DeviceArray([0.0940454 , 0.24210012, 0.10360408, 0.2991985 , 0.22486198,
                      0.9399384 , 0.88925755, 0.84567535, 0.94881094, 0.7843926 ],            dtype=float32))

In [17]:
value

DeviceArray(5.3718853, dtype=float32)

### ``has_aux``

In some situations, intermediate values computed inside the function are also of interest. Setting ``has_aux=True`` returns them alongside the gradients. The constraint is that the function must return a pair ``(loss, aux_data)``. For instance,

In [18]:
class LinearAux(bp.Base):
    def __init__(self):
        super(LinearAux, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        dot = bm.dot(self.w, x)
        r = (dot + self.b).sum()
        return r, (r, dot)  # here the aux data is a tuple holding the loss and the dot value;
                            # in general, aux data can be arbitrarily complex.
    
l2 = LinearAux()

In [19]:
grad = bm.grad(l2.update, argnums=0, has_aux=True)

gradient, aux = grad(bm.ones(10))

In [20]:
gradient

JaxArray(DeviceArray([0.7740855 , 0.6669129 , 0.74336326, 0.7743118 , 0.08353662,
                      0.1557033 , 0.27870536, 0.3860656 , 0.14068758, 0.46460104],            dtype=float32))

In [21]:
aux

(DeviceArray(4.4679728, dtype=float32),
 JaxArray(DeviceArray([4.4679728], dtype=float32)))

When several keywords (``argnums``, ``grad_vars``, ``has_aux``, or ``return_value``) are set simultaneously, the return format of the gradient function is described in the corresponding API documentation. For the function used above, see [brainpy.math.grad()](../apis/auto/math/generated/brainpy.math.autograd.grad.rst).
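To make the combination of flags concrete, here is a minimal plain-Python sketch of a gradient wrapper that supports ``has_aux`` and ``return_value``. The wrapper `numeric_grad` and the function `loss_with_aux` are hypothetical and use central differences rather than BrainPy's machinery. The return order `(grad, value, aux)` is an assumption chosen to mirror the `bm.jacobian(..., return_value=True, has_aux=True)` call shown later in this tutorial; the API documentation remains the authoritative reference.

```python
def numeric_grad(fun, has_aux=False, return_value=False, eps=1e-6):
    """Hypothetical wrapper: central-difference derivative of `fun` at a float x.

    If has_aux is True, `fun` must return (loss, aux) and only `loss`
    is differentiated.
    """
    def loss_only(x):
        out = fun(x)
        return out[0] if has_aux else out

    def grad_fun(x):
        g = (loss_only(x + eps) - loss_only(x - eps)) / (2 * eps)
        result = [g]
        if return_value:
            result.append(loss_only(x))   # value of the loss at x
        if has_aux:
            result.append(fun(x)[1])      # auxiliary data, passed through untouched
        return result[0] if len(result) == 1 else tuple(result)

    return grad_fun


def loss_with_aux(x):
    square = x * x
    loss = square + 3.0 * x
    return loss, {"square": square}       # (loss, aux_data)


g, value, aux = numeric_grad(loss_with_aux, has_aux=True, return_value=True)(2.0)
```

Here `g` approximates the derivative `2 * x + 3` at `x = 2`, `value` is the loss itself, and `aux` is returned without being differentiated.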

## ``brainpy.math.grad()``

[brainpy.math.grad()](../apis/auto/math/generated/brainpy.math.autograd.grad.rst) takes a function or object ($f : \mathbb{R}^n \to \mathbb{R}$) and returns a new function ($\nabla f : \mathbb{R}^n \to \mathbb{R}^n$) that computes the gradient of the original function or object. Note that ``brainpy.math.grad()`` only supports functions that return a scalar value.

### Pure functions

For a pure function, the gradient is taken with respect to the first argument by default:

In [2]:
def f(a, b):
    return a * 2 + b

grad_f1 = bm.grad(f)

In [3]:
grad_f1(2., 1.)

DeviceArray(2., dtype=float32)

However, this can be controlled via the `argnums` argument.

In [4]:
grad_f2 = bm.grad(f, argnums=(0, 1))

grad_f2(2., 1.)

(DeviceArray(2., dtype=float32), DeviceArray(1., dtype=float32))
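For intuition about what "automatic" differentiation means here, the sketch below reproduces both results with dual numbers, a simple forward-mode technique. This is not how BrainPy computes gradients internally; the `Dual` class and the `forward_grad` helper are hypothetical names introduced only for illustration.

```python
class Dual:
    """A number that carries a value and its derivative (forward-mode AD)."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def f(a, b):
    return a * 2 + b


def forward_grad(fun, argnums=0):
    """Hypothetical helper: derivative of `fun` w.r.t. the selected positional argument(s)."""
    def single(args, k):
        # Seed derivative 1 on argument k and 0 on all other arguments.
        duals = [Dual(v, 1.0 if i == k else 0.0) for i, v in enumerate(args)]
        return fun(*duals).deriv

    def grad_fun(*args):
        if isinstance(argnums, int):
            return single(args, argnums)
        return tuple(single(args, k) for k in argnums)

    return grad_fun


print(forward_grad(f)(2., 1.))                  # 2.0
print(forward_grad(f, argnums=(0, 1))(2., 1.))  # (2.0, 1.0)
```

As with `argnums` in `bm.grad`, an integer selects one argument and a tuple returns one derivative per selected argument.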

### Class objects

For a class object or a bound method, the gradient is taken with respect to the variables given in ``grad_vars`` and the arguments selected by ``argnums``:

In [22]:
class F(bp.Base):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()
    
f = F()

``grad_vars`` can be a JaxArray, or a list, tuple, or dict of JaxArrays.

In [23]:
bm.grad(f, grad_vars=f.train_vars())(10.)

{'F0.a': TrainVar(DeviceArray([2.], dtype=float32)),
 'F0.b': TrainVar(DeviceArray([2.], dtype=float32))}

In [24]:
bm.grad(f, grad_vars=[f.a, f.b])(10.)

[TrainVar(DeviceArray([2.], dtype=float32)),
 TrainVar(DeviceArray([2.], dtype=float32))]

If any variables are modified while the function runs, pass them through the ``dyn_vars`` argument.

In [8]:
class F2(bp.Base):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()

In [9]:
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)

TrainVar(DeviceArray([2.], dtype=float32))

If you are also interested in the gradient of the input value, use the ``argnums`` argument. In this case, calling the gradient function returns ``(grads_of_grad_vars, grads_of_args)``.

In [14]:
class F3(bp.Base):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()

In [16]:
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)

grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_args : 3.0


In [19]:
f3 = F3()
grads_of_gv, (grad_of_arg0, grad_of_arg1) = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
print("grad_of_arg1 :", grad_of_arg1)

grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
grad_of_arg1 : 10.0


In general, we recommend passing every dynamically changed variable (whether or not it is updated inside the gradient function) through the ``dyn_vars`` argument.

### Auxiliary data

Often we also want the value of the loss, or some intermediate variables computed during the gradient computation. For these situations, set ``return_value=True`` to return the loss value, and set ``has_aux=True`` to return auxiliary data.

In [11]:
# return loss

grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)

print('grad: ', grad)
print('loss: ', loss)

grad:  TrainVar(DeviceArray([2.], dtype=float32))
loss:  12.0


In [21]:
class F4(bp.Base):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)
    

f4 = F4()
    
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)

print('grad: ', grad)
print('aux_data: ', aux_data)

grad:  TrainVar(DeviceArray([1.], dtype=float32))
aux_data:  (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32)))


```note
Any function passed to ``brainpy.math.grad()`` must return a scalar value. Otherwise, an error is raised.
```

In [23]:
try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)

<class 'TypeError'> Gradient only defined for scalar-output functions. Output was [0. 0.].


In [24]:
# this is right

bm.grad(lambda x: x.mean())(bm.zeros(2))

JaxArray(DeviceArray([0.5, 0.5], dtype=float32))

## ``brainpy.math.vector_grad()``

To take gradients of a vector-valued output, use the [brainpy.math.vector_grad()](../apis/auto/math/generated/brainpy.math.autograd.vector_grad.rst) function. For example,

In [8]:
def f(a, b): 
    return bm.sin(b) * a

Gradients for vectors.

In [None]:
# vectors

a = bm.arange(5.)
b = bm.random.random(5)

In [9]:
bm.vector_grad(f)(a, b)

JaxArray(DeviceArray([0.829829  , 0.3382971 , 0.13563846, 0.5101524 , 0.28861028],            dtype=float32))

In [10]:
bm.vector_grad(f, argnums=(0, 1))(a, b)

(JaxArray(DeviceArray([0.829829  , 0.3382971 , 0.13563846, 0.5101524 , 0.28861028],            dtype=float32)),
 JaxArray(DeviceArray([0.       , 0.9410394, 1.9815168, 2.580252 , 3.8297865], dtype=float32)))

Gradients for matrices.

In [11]:
# matrix

a = bm.arange(6.).reshape((2, 3))
b = bm.random.random((2, 3))

In [12]:
bm.vector_grad(f, argnums=1)(a, b)

JaxArray(DeviceArray([[0.       , 0.6934817, 1.9375703],
                      [2.142562 , 2.5830717, 4.9865813]], dtype=float32))

In [13]:
bm.vector_grad(f, argnums=(0, 1))(a, b)

(JaxArray(DeviceArray([[0.09120136, 0.72047424, 0.24790175],
                       [0.6999546 , 0.7635338 , 0.07321358]], dtype=float32)),
 JaxArray(DeviceArray([[0.       , 0.6934817, 1.9375703],
                       [2.142562 , 2.5830717, 4.9865813]], dtype=float32)))
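The outputs above have the same shape as the inputs, and they are consistent with differentiating the sum of the output elements: for `f(a, b) = sin(b) * a`, that gives `sin(b)` with respect to `a` and `a * cos(b)` with respect to `b`. The plain-Python sketch below checks this reading with central differences. It makes no claim about how ``vector_grad`` is implemented; the helper `grad_of_sum` and the values of `b` are illustrative assumptions.

```python
import math


def f(a, b):
    """Elementwise sin(b) * a on two equal-length lists."""
    return [math.sin(bi) * ai for ai, bi in zip(a, b)]


def grad_of_sum(fun, args, k, eps=1e-6):
    """Central-difference gradient of sum(fun(*args)) with respect to args[k]."""
    grads = []
    for i in range(len(args[k])):
        up = [list(v) for v in args]
        down = [list(v) for v in args]
        up[k][i] += eps
        down[k][i] -= eps
        grads.append((sum(fun(*up)) - sum(fun(*down))) / (2 * eps))
    return grads


a = [0.0, 1.0, 2.0, 3.0, 4.0]
b = [0.1, 0.2, 0.3, 0.4, 0.5]   # hypothetical values; the tutorial draws b at random

da = grad_of_sum(f, (a, b), 0)   # approximately sin(b)
db = grad_of_sum(f, (a, b), 1)   # approximately a * cos(b)
```

Because each output element depends only on the matching input elements, the gradient of the sum keeps one entry per input element, which matches the elementwise results shown above.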

Similar to [brainpy.math.grad()](../apis/auto/math/generated/brainpy.math.autograd.grad.rst), ``brainpy.math.vector_grad()`` also supports taking derivatives with respect to variables in a class object. Here is a simple example.

In [23]:
class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.ones(5)
    self.y = bm.ones(5)

  def __call__(self):
    return self.x ** 2 + self.y ** 3 + 10

t = Test()

In [24]:
bm.vector_grad(t, grad_vars=t.x)()

JaxArray(DeviceArray([2., 2., 2., 2., 2.], dtype=float32))

In [25]:
bm.vector_grad(t, grad_vars=(t.x, ))()

(JaxArray(DeviceArray([2., 2., 2., 2., 2.], dtype=float32)),)

In [26]:
bm.vector_grad(t, grad_vars=(t.x, t.y))()

(JaxArray(DeviceArray([2., 2., 2., 2., 2.], dtype=float32)),
 JaxArray(DeviceArray([3., 3., 3., 3., 3.], dtype=float32)))

Other options such as ``return_value`` and ``has_aux`` work the same way in [brainpy.math.vector_grad()](../apis/auto/math/generated/brainpy.math.autograd.vector_grad.rst) as in [brainpy.math.grad()](../apis/auto/math/generated/brainpy.math.autograd.grad.rst).

## ``brainpy.math.jacobian()``

Another way to take gradients of a vector-valued output is [brainpy.math.jacobian()](../apis/auto/math/generated/brainpy.math.autograd.jacobian.rst). ``brainpy.math.jacobian()`` automatically computes the Jacobian matrix $\partial f(x) \in \mathbb{R}^{m \times n}$ of a given function $f : \mathbb{R}^n \to \mathbb{R}^m$ at a given point $x \in \mathbb{R}^n$. We do not go into the details of the implementation and usage of ``brainpy.math.jacobian()`` here. Instead, we show two examples: one for a pure function and one for a class object.

Given the following function, 

In [34]:
import jax.numpy as jnp

def f1(x, y):
    a = 4 * x[1] ** 2 - 2 * x[2]
    r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
    return r, a

In [35]:
_x = bm.array([1., 2., 3.])
_y = bm.array([10., 5.])
    
grads, vec, aux = bm.jacobian(f1, return_value=True, has_aux=True)(_x, _y)

In [36]:
grads

JaxArray(DeviceArray([[10.        ,  0.        ,  0.        ],
                      [ 0.        ,  0.        , 25.        ],
                      [ 0.        , 16.        , -2.        ],
                      [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32))

In [37]:
vec

DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)

In [38]:
aux

DeviceArray(10., dtype=float32)
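The Jacobian above can be cross-checked without any autograd library. The sketch below approximates each column with central differences on a plain-Python copy of `f1` (without the auxiliary output). The helper `numeric_jacobian` is a hypothetical name, and the result is only an approximation of what ``bm.jacobian`` computes.

```python
import math


def f1(x, y):
    """Plain-Python copy of the function above, returning only the vector output."""
    return [
        x[0] * y[0],
        5 * x[2] * y[1],
        4 * x[1] ** 2 - 2 * x[2],
        x[2] * math.sin(x[0]),
    ]


def numeric_jacobian(fun, x, eps=1e-6):
    """Central-difference Jacobian: entry [i][j] approximates d fun(x)[i] / d x[j]."""
    columns = []
    for j in range(len(x)):
        up = list(x)
        down = list(x)
        up[j] += eps
        down[j] -= eps
        f_up, f_down = fun(up), fun(down)
        columns.append([(u - d) / (2 * eps) for u, d in zip(f_up, f_down)])
    # columns[j][i] holds d f_i / d x_j; transpose so that rows index outputs.
    return [list(row) for row in zip(*columns)]


x = [1.0, 2.0, 3.0]
y = [10.0, 5.0]
jac = numeric_jacobian(lambda v: f1(v, y), x)
```

Each row of `jac` corresponds to one output component and each column to one entry of `x`, the same $m \times n$ layout as the `grads` matrix printed above.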

Given the following class object,

In [39]:
class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.array([1., 2., 3.])

  def __call__(self, y):
    a = self.x[0] * y[0]
    b = 5 * self.x[2] * y[1]
    c = 4 * self.x[1] ** 2 - 2 * self.x[2]
    d = self.x[2] * jnp.sin(self.x[0])
    r = jnp.asarray([a, b, c, d])
    return r, (c, d)

In [41]:
t = Test()
f_grad = bm.jacobian(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)

(var_grads, arg_grads), value, aux = f_grad(_y)

In [43]:
var_grads

JaxArray(DeviceArray([[10.        ,  0.        ,  0.        ],
                      [ 0.        ,  0.        , 25.        ],
                      [ 0.        , 16.        , -2.        ],
                      [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32))

In [44]:
arg_grads

JaxArray(DeviceArray([[ 1.,  0.],
                      [ 0., 15.],
                      [ 0.,  0.],
                      [ 0.,  0.]], dtype=float32))

In [45]:
value

DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)

In [46]:
aux

(DeviceArray(10., dtype=float32), DeviceArray(2.5244129, dtype=float32))

For more automatic differentiation APIs, please see our [API documentation](../apis/auto/math/autograd.rst).