# Variables

In BrainPy, the [JIT compilation](../apis/auto/math/generated/brainpy.math.jit.jit.rst) for class objects relies on [Variable](../apis/auto/math/generated/brainpy.math.jaxarray.Variable.rst). In this section, we are going to understand:

- what a ``Variable`` is
- the subtypes of ``Variable``

In [1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

## Variable

``brainpy.math.Variable`` is a pointer that refers to a [tensor](./tensors.ipynb) and stores its value. Labeling a tensor as a Variable marks it as a dynamical variable: the concrete data it holds can be changed.

During JIT compilation, tensors that are not marked as Variables are compiled as static data. Later changes to such a tensor will either have no effect or raise an error.
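The "static data" behaviour can be illustrated with plain JAX, on which BrainPy is built. This is a minimal sketch, not BrainPy code: the names `scale` and `scaled` are made up for the illustration, and it assumes BrainPy's JIT inherits the same tracing behaviour. An array that a jitted function reads from the enclosing scope is baked into the compiled function on the first call, so rebinding it afterwards has no effect.

```python
import jax
import jax.numpy as jnp

# A plain array read from the enclosing scope (hypothetical name).
scale = jnp.ones(3)

@jax.jit
def scaled(x):
    # `scale` is looked up while tracing and compiled in as a constant.
    return x * scale

x = jnp.ones(3)
first = scaled(x)             # traces and compiles with scale == [1, 1, 1]

scale = jnp.full((3,), 2.0)   # rebinding the name after compilation
second = scaled(x)            # the cached compiled function still uses the old constant

print(first, second)
```

Both calls return the same values, because the second call reuses the compiled function. Marking data as a Variable is how BrainPy tells the compiler that the data is expected to change.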

- **Create a Variable**

Passing a tensor to ``brainpy.math.Variable`` creates a Variable, for example:

In [2]:
b1 = bm.random.random(5)
b1

JaxArray(DeviceArray([0.78549385, 0.7267256 , 0.25824118, 0.95101726, 0.56174445], dtype=float32))

In [3]:
b2 = bm.Variable(b1)
b2

Variable(DeviceArray([0.78549385, 0.7267256 , 0.25824118, 0.95101726, 0.56174445], dtype=float32))

- **Access the value in a Variable**

The concrete data of a Variable can be obtained through ``.value``.

In [4]:
b2.value

DeviceArray([0.78549385, 0.7267256 , 0.25824118, 0.95101726, 0.56174445], dtype=float32)

In [5]:
(b2.value == b1).all()

DeviceArray(True, dtype=bool)

- **Supported operations on a Variable**

A Variable supports almost all the operations available for a [tensor](./tensors.ipynb). In fact, ``brainpy.math.Variable`` is a subclass of ``brainpy.math.ndarray``.

In [6]:
isinstance(b2, bm.ndarray)

True

In [7]:
isinstance(b2, bm.JaxArray)

True

In [8]:
# `bp.math.ndarray` is an alias for `bp.math.JaxArray` in 'jax' backend

bm.ndarray is bm.JaxArray

True

```{note}
After performing any operation on a Variable, the result is a JaxArray (``brainpy.math.ndarray`` is an alias for ``brainpy.math.JaxArray``), not a new Variable. In other words, a Variable only serves as a reference to a value.
```

In [9]:
b2 + 1.

JaxArray(DeviceArray([1.7854939, 1.7267256, 1.2582412, 1.9510173, 1.5617445], dtype=float32))

In [10]:
b2 ** 2

JaxArray(DeviceArray([0.6170006 , 0.52813005, 0.06668851, 0.90443385, 0.31555682], dtype=float32))

In [11]:
bm.floor(b2)

JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
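One practical consequence of the note above is that `v = v + 1` rebinds the name to a plain array, while an update has to be written back into the Variable. The sketch below uses a toy stand-in built on NumPy, not BrainPy's implementation: `ToyVariable` is a hypothetical class that only mimics the "reference to a value" behaviour.

```python
import numpy as np

class ToyVariable:
    """Hypothetical stand-in for a Variable: a slot that refers to an array."""

    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)

    def __add__(self, other):
        # Arithmetic returns a plain array, not another ToyVariable.
        return self.value + other

v = ToyVariable([1.0, 2.0])
alias = v                 # a second reference to the same slot

out = v + 1.0             # plain ndarray; `v` itself is unchanged
v.value = out             # write the result back into the slot

print(type(out).__name__, alias.value)
```

Because the update goes through the slot, every reference to the same object (here `alias`) sees the new data. Rebinding the name instead (`v = v + 1.0`) would leave `alias` pointing at the old data.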

- **Subtypes of Variable**

``brainpy.math.Variable`` has several subtypes, including ``brainpy.math.TrainVar`` and ``brainpy.math.Parameter``. Users can also define their own subtypes. The following sections describe the built-in ones.

## TrainVar

``brainpy.math.TrainVar`` is a trainable variable (a subclass of ``brainpy.math.Variable``). Trainable variables are usually the ones for which gradients are computed and to which the corresponding updates are applied. However, users can also use TrainVar for other purposes.

In [12]:
b = bm.random.rand(4)

b

JaxArray(DeviceArray([0.9473914 , 0.27128887, 0.01305449, 0.79503417], dtype=float32))

In [13]:
bm.TrainVar(b)

TrainVar(DeviceArray([0.9473914 , 0.27128887, 0.01305449, 0.79503417], dtype=float32))

## Parameter

``brainpy.math.Parameter`` labels a parameter whose value changes dynamically. It is also a subclass of ``brainpy.math.Variable``. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the ``Collector.subsets`` method (please see [Base class](./base.ipynb)).

In [14]:
b = bm.random.rand(1)

b

JaxArray(DeviceArray([0.85893416], dtype=float32))

In [15]:
bm.Parameter(b)

Parameter(DeviceArray([0.85893416], dtype=float32))

## RandomState

``brainpy.math.random.RandomState`` is also a subclass of ``brainpy.math.Variable``. This is because a RandomState must store the dynamically changing **key** information (see [JAX random number designs](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers)). Every time a RandomState performs a random sampling, its key changes. It is therefore natural to label a RandomState as a Variable.

In [16]:
state = bm.random.RandomState(seed=1234)

state

RandomState(DeviceArray([   0, 1234], dtype=uint32))

In [17]:
# perform a "random" sampling 
state.random(1)

state  # the value changed

RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))

In [18]:
# perform a "sample" sampling 
state.sample(1)

state  # the value changed too

RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))
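The reason the key has to change can be seen in plain JAX, where keys are passed explicitly. The sketch below is not BrainPy code, and how RandomState advances its key internally is not shown in this section. It only illustrates the underlying design: sampling twice with the same key gives the same numbers, so a fresh key must be derived before each draw.

```python
import jax

key = jax.random.PRNGKey(1234)          # raw key data is [0, 1234]
old_key = key

a = jax.random.uniform(key, (1,))
b = jax.random.uniform(key, (1,))       # same key, so the same "random" number

# Derive fresh keys; the carried key is replaced, much like a stored state.
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey, (1,))

print(a, b, c)
print(old_key, key)
```

A RandomState keeps this carried key as its value, which is why its printed content differs after every sampling call.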

Every instance of RandomState can create a new seed from the current seed with ``.split_key()``. 

In [19]:
state.split_key()

DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with ``.split_keys(n)``. This is used internally by [pmap](../apis/auto/math/generated/brainpy.math.parallels.pmap.rst) and [vmap](../apis/auto/math/generated/brainpy.math.parallels.vmap.rst) to ensure that random numbers are different in parallel threads. 

In [20]:
state.split_keys(2)

DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)

In [21]:
state.split_keys(5)

DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)
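A rough picture of why several keys are needed for parallel mapping is given below, written in plain JAX rather than with BrainPy's wrappers. It assumes `split_keys(n)` plays a role similar to `jax.random.split(key, n)`. The function name `draw` is made up for the illustration.

```python
import jax

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)         # one key per mapped instance, shape (4, 2)

def draw(k):
    # Each instance samples with its own key.
    return jax.random.normal(k, (3,))

samples = jax.vmap(draw)(keys)          # shape (4, 3)

print(keys)
print(samples)
```

If every instance received the same key instead, all rows of `samples` would be identical.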

The ``brainpy.math.random`` module provides a default RandomState: `DEFAULT`.

In [22]:
bm.random.DEFAULT

RandomState(DeviceArray([1597671135, 2052649380], dtype=uint32))

The module-level random functions such as ``randint()``, ``rand()``, and ``shuffle()`` use this DEFAULT state. To reset the default RandomState, use the ``seed()`` function.

In [23]:
bm.random.seed(654321)

bm.random.DEFAULT

RandomState(DeviceArray([     0, 654321], dtype=uint32))