# Arrays

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_math/array.ipynb)

@[Chaoming Wang](https://github.com/chaoming0625)
@[Xiaoyu Chen](mailto:c-xy17@tsinghua.org.cn)

```{note}
If you already have basic knowledge of [NumPy](https://numpy.org/) (the ``array`` here is the same as the ``ndarray`` in NumPy), you can skip this section.
```

In this section, we are going to learn:

- What is an ``array``?
- How can we create an ``array``?
- What operations are supported on an ``array``?

```python
import brainpy.math as bm

bm.set_platform('cpu')
```

## What is an ``array``?

An array is a homogeneous multidimensional container: a table of elements (usually numbers), all of the same type, indexed by a tuple of non-negative integers. The dimensions of an array are called **axes**.

In the following picture, the 1D array (`[7, 2, 9, 10]`) has only one axis. There are 4 elements along this axis, so the shape of the array is `(4,)`.

By contrast, the 2D array

```python
[[5.2, 3.0, 4.5], 
 [9.1, 0.1, 0.3]]
```
has 2 axes. The first axis has a length of 2 and the second has a length of 3. Therefore, the shape of the 2D array is `(2, 3)`. 

Similarly, the 3D array has 3 axes, with lengths 4, 3, and 2, respectively, so its shape is `(4, 3, 2)`.

<center><img src="../_static/numpy_arrays.png" width="600 px"></center>

An array has several important attributes:

- **.ndim**: the number of axes (dimensions) of the array.

- **.shape**: the dimensions of the array. This is a tuple of integers indicating the size of the array along each dimension. For a matrix with n rows and m columns, the shape will be `(n, m)`. The length of the shape tuple is therefore the number of axes, `ndim`.

- **.size**: the total number of elements of the array. This is equal to the product of the elements of `shape`.

- **.dtype**: an object describing the type of the elements in the array. One can create or specify dtypes using standard Python types.

```python
a = bm.arange(15).reshape((3, 5))

a
```

```
JaxArray(DeviceArray([[ 0,  1,  2,  3,  4],
                      [ 5,  6,  7,  8,  9],
                      [10, 11, 12, 13, 14]], dtype=int32))
```

```python
a.shape
```

```
(3, 5)
```

```python
a.ndim
```

```
2
```

```python
a.dtype
```

```
dtype('int32')
```
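The relationships between these attributes can be checked directly. The sketch below uses plain NumPy, which the ``array`` here mirrors, so it runs without BrainPy; NumPy's default integer type may differ from the `int32` shown above.

```python
import math

import numpy as np

# NumPy stands in for brainpy.math here; same construction as the cell above
a = np.arange(15).reshape((3, 5))

print(a.ndim)   # 2: number of axes
print(a.shape)  # (3, 5): length along each axis
print(a.size)   # 15: total number of elements
print(a.dtype)  # element type (platform dependent, e.g. int64)

# ndim is the length of the shape tuple, and size is the product of its entries
assert a.ndim == len(a.shape)
assert a.size == math.prod(a.shape)
```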

## How to create an ``array``?

There are several ways to create an array.

### 1. ``array()``, ``zeros()`` and ``ones()``

The most basic method is to convert Python sequences into arrays with ``bm.array()``. For example:

```python
bm.array([2, 3, 4])
```

```
JaxArray(DeviceArray([2, 3, 4], dtype=int32))
```

```python
bm.array([(1.5, 2, 3), (4, 5, 6)])
```

```
JaxArray(DeviceArray([[1.5, 2. , 3. ],
                      [4. , 5. , 6. ]], dtype=float32))
```
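``bm.array()`` infers the shape from the nesting of the sequences and picks a single dtype that can hold every element; in the second cell, the single float `1.5` turns the whole array into floats. The same inference, sketched with NumPy (which reports `float64` where JAX, and therefore BrainPy, defaults to 32-bit types):

```python
import numpy as np

# NumPy stands in for brainpy.math here
ints = np.array([2, 3, 4])
mixed = np.array([(1.5, 2, 3), (4, 5, 6)])

print(ints.shape, ints.dtype.kind)    # (3,) i   -> one axis, integer elements
print(mixed.shape, mixed.dtype.kind)  # (2, 3) f -> two axes, floating-point elements

# An explicit dtype overrides the inference
as_float = np.array([2, 3, 4], dtype=np.float32)
print(as_float)  # [2. 3. 4.]
```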

Often, the elements of an array are initially unknown, but its size is known. In that case, you can use placeholder functions to create arrays with initial contents, such as:

```python
# "bm.zeros()" creates an array full of zeros
bm.zeros((3, 4))
```

```
JaxArray(DeviceArray([[0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.]], dtype=float32))
```

```python
# "bm.ones()" creates an array full of ones
bm.ones((3, 4))
```

```
JaxArray(DeviceArray([[1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.]], dtype=float32))
```
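Placeholder functions take the shape as a tuple (or a single integer for a 1D array) and accept a `dtype` argument. A NumPy sketch; NumPy defaults to `float64`, whereas the outputs above show `float32` because JAX uses 32-bit types by default. The `zeros_like` call is shown for NumPy only; check the BrainPy API reference before assuming an identical counterpart.

```python
import numpy as np

# NumPy stands in for brainpy.math here
z = np.zeros((3, 4))            # floating-point zeros by default
o = np.ones(4, dtype=np.int32)  # an integer shape gives a 1D array

print(z.shape, z.dtype)  # (3, 4) float64
print(o)                 # [1 1 1 1]

# "like" variants copy the shape and dtype of an existing array
z2 = np.zeros_like(o)
print(z2.shape, z2.dtype)  # (4,) int32
```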

### 2. ``linspace()`` and ``arange()``

Two other commonly used functions for creating 1D arrays are ``bm.linspace()`` and ``bm.arange()``.

``bm.arange()`` creates arrays with regularly incrementing values. It receives "start", "end", and "step" arguments; the "end" value itself is not included.

```python
# With one argument A, arange uses start = 0, end = A, and step = 1.
bm.arange(10)
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
```

```python
# With two arguments A, B, arange uses start = A, end = B, and step = 1.
# The optional dtype argument sets the element type.
bm.arange(2, 10, dtype=float)
```

```
JaxArray(DeviceArray([2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32))
```

```python
# With three arguments A, B, C, arange uses start = A, end = B, and step = C.
bm.arange(2, 3, 0.1)
```

```
JaxArray(DeviceArray([2.       , 2.1      , 2.1999998, 2.2999997, 2.3999996,
                      2.4999995, 2.5999994, 2.6999993, 2.7999992, 2.8999991],            dtype=float32))
```
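The sketch below checks that the three call forms are shorthand for one another and that the end value is excluded. It uses NumPy's `np.arange`, which follows the same argument rules, so it runs without BrainPy:

```python
import numpy as np

# NumPy stands in for brainpy.math here
# One argument: start = 0, step = 1
assert np.array_equal(np.arange(10), np.arange(0, 10, 1))
# Two arguments: step = 1
assert np.array_equal(np.arange(2, 10), np.arange(2, 10, 1))

r = np.arange(2, 10)
print(r[0], r[-1], len(r))  # 2 9 8 -> the end value 10 is not included
```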

When the step is a floating-point number, finite floating-point precision makes it generally impossible to predict the number of elements obtained by ``bm.arange()``. For this reason, it is usually better to use ``bm.linspace()``, which receives "start", "end", and "num" (the number of elements) arguments and, unlike ``bm.arange()``, includes the end point by default.

```python
bm.linspace(2, 3, 10)
```

```
JaxArray(DeviceArray([2.       , 2.1111112, 2.2222223, 2.3333333, 2.4444447,
                      2.5555556, 2.6666665, 2.777778 , 2.8888888, 3.       ],            dtype=float32))
```
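A classic case of the unpredictable length: stepping from 1 to 1.3 by 0.1 looks like it should give three elements, but `(1.3 - 1) / 0.1` evaluates to slightly more than 3 in double precision, so NumPy's `arange` returns four. `linspace` sidesteps the issue by taking the element count directly. This is shown with NumPy's float64; the exact outcome for ``bm.arange()`` in float32 may differ, which is precisely the point.

```python
import numpy as np

# NumPy (float64) stands in for brainpy.math here
print((1.3 - 1) / 0.1)  # 3.0000000000000004, not exactly 3

steps = np.arange(1, 1.3, 0.1)
print(len(steps))  # 4: the "excluded" end point sneaks in through rounding

# linspace takes the number of elements instead of a step,
# so the length is always exactly what you ask for
points = np.linspace(1, 1.3, 4)
print(len(points))        # 4
print(points[-1] == 1.3)  # True: the end point is included by default
```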

### 3. Random sampling

The ``brainpy.math`` module provides convenient [random sampling functions](../apis/math/comparison.html#random-sampling), including simple random data generation methods, permutation and distribution functions, and random generator functions. Here are several examples.

- ``brainpy.math.random.rand(d0, d1, ..., dn)``

This function generates random values of the given shape, drawn uniformly from the half-open interval [0.0, 1.0).

```python
bm.random.rand(5, 2)
```

```
JaxArray(DeviceArray([[0.07216692, 0.2877245 ],
                      [0.29642737, 0.43941212],
                      [0.9228879 , 0.14709306],
                      [0.8345591 , 0.8134085 ],
                      [0.6776234 , 0.42747045]], dtype=float32))
```

- ``brainpy.math.random.randn(d0, d1, ..., dn)``

This function returns samples of the given shape from the "standard normal" distribution (mean 0, variance 1).

```python
bm.random.randn(5, 2)
```

```
JaxArray(DeviceArray([[-1.2456564 ,  0.93986976],
                      [ 1.2722825 , -0.5604058 ],
                      [ 0.42648995,  1.2291526 ],
                      [ 0.7283678 ,  0.12521066],
                      [ 0.29101485,  2.0382316 ]], dtype=float32))
```

- ``brainpy.math.random.randint(low, high[, size, dtype])``

This function generates random integers in the half-open interval [low, high), i.e., ``low`` is inclusive and ``high`` is exclusive.

```python
bm.random.randint(0, 3, size=10)
```

```
JaxArray(DeviceArray([0, 0, 0, 1, 2, 0, 1, 1, 1, 1], dtype=int32))
```

- ``brainpy.math.random.random([size])``

This function generates random floats in the half-open interval [0.0, 1.0).

```python
bm.random.random((3, 2))
```

```
JaxArray(DeviceArray([[0.30104673, 0.08174968],
                      [0.46729672, 0.30544508],
                      [0.37684703, 0.7211865 ]], dtype=float32))
```
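Because both intervals above are half-open, the upper bound itself is never drawn: `randint(0, 3)` only produces 0, 1, and 2, and `random()` never returns exactly 1.0. A quick empirical check, assuming NumPy's `Generator.integers` and `Generator.random` as stand-ins (they use the same half-open bounds, though the names differ from BrainPy's):

```python
import numpy as np

# NumPy's Generator stands in for brainpy.math.random here
rng = np.random.default_rng(seed=0)  # seeded so the check is reproducible

# integers(low, high) draws from [low, high): high itself is never returned
ints = rng.integers(0, 3, size=10_000)
print(sorted(set(ints.tolist())))  # [0, 1, 2]

# random() draws floats from [0.0, 1.0)
floats = rng.random((3, 2))
print(floats.shape)  # (3, 2)
print(bool((floats >= 0.0).all() and (floats < 1.0).all()))  # True
```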

The ``brainpy.math`` module also provides permutation functions.

- ``brainpy.math.random.shuffle()``

This function is used to modify a sequence in-place by shuffling its contents.

```python
bm.random.shuffle(bm.arange(10))
```

```
JaxArray(DeviceArray([5, 9, 4, 0, 3, 2, 6, 8, 7, 1], dtype=int32))
```

- ``brainpy.math.random.permutation()``

This function randomly permutes a sequence, or returns a permuted range.

```python
bm.random.permutation(bm.arange(10))
```

```
JaxArray(DeviceArray([1, 4, 8, 5, 2, 7, 3, 9, 0, 6], dtype=int32))
```

The ``brainpy.math`` module also provides functions to sample from distributions.

- ``brainpy.math.random.beta(a, b[, size])``

This function is used to draw samples from a Beta distribution.

```python
bm.random.beta(2, 3, 10)
```

```
JaxArray(DeviceArray([0.5404635 , 0.5812938 , 0.22380598, 0.20873289, 0.5460086 ,
                      0.11106081, 0.69154614, 0.28446677, 0.24081689, 0.6313805 ],            dtype=float32))
```

- ``brainpy.math.random.exponential([scale, size])``

This function is used to draw samples from an exponential distribution.

```python
bm.random.exponential(1, 10)
```

```
JaxArray(DeviceArray([0.828246  , 0.21370666, 0.32322478, 1.7631028 , 2.4889555 ,
                      1.133313  , 0.37169302, 0.32336047, 0.34985116, 0.2314292 ],            dtype=float32))
```
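One way to sanity-check a sampler is to compare the sample mean with the theoretical one: a Beta(a, b) distribution has mean a / (a + b), and an exponential distribution with scale β has mean β. A sketch using NumPy's `Generator.beta` and `Generator.exponential` as stand-ins; NumPy's `exponential` is parameterized by `scale`, matching the signature listed above.

```python
import numpy as np

# NumPy's Generator stands in for brainpy.math.random here
rng = np.random.default_rng(seed=0)
n = 100_000

beta_samples = rng.beta(2, 3, size=n)
exp_samples = rng.exponential(scale=1.0, size=n)

# Theoretical means: 2 / (2 + 3) = 0.4 for Beta(2, 3), and scale = 1.0 for the exponential
print(beta_samples.mean())
print(exp_samples.mean())
```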

For more sampling methods, please see [random sampling functions](../apis/math/comparison.html#random-sampling).

### 4. Other methods

Moreover, there are many other methods we can use to create arrays, including:

- Conversion from other Python structures (e.g., lists and tuples)
- Intrinsic NumPy array creation functions (e.g. ``arange``, ``ones``, ``zeros``, etc.)
- Use of special library functions (e.g., ``random``)
- Replicating, joining, or mutating existing arrays
- Reading arrays from disk, either from standard or custom formats
- Creating arrays from raw bytes through the use of strings or buffers

For details of these methods, please see [NumPy tutorial: Array creation](https://numpy.org/doc/stable/user/basics.creation.html). Most of these methods are supported in BrainPy. 
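As an illustration of the "replicating, joining" category, here is a NumPy sketch with `np.concatenate`, `np.stack`, and `np.tile`. ``brainpy.math`` is expected to offer functions of the same names, but treat that as an assumption and check the API reference for your version.

```python
import numpy as np

# NumPy stands in for brainpy.math here
x = np.array([1, 2])
y = np.array([3, 4])

joined = np.concatenate([x, y])  # join along an existing axis -> shape (4,)
stacked = np.stack([x, y])       # join along a new first axis -> shape (2, 2)
tiled = np.tile(x, 3)            # replicate x three times -> shape (6,)

print(joined.tolist())   # [1, 2, 3, 4]
print(stacked.tolist())  # [[1, 2], [3, 4]]
print(tiled.tolist())    # [1, 2, 1, 2, 1, 2]
```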

## Supported operations on ``array``

All the operations in BrainPy are based on arrays. Therefore, it is necessary to know what operations are supported on each array object.

### Basic operations

Arithmetic operators on arrays apply element-wise. Let's take "+", "-", "\*", and "/" as examples.

We first create two arrays:

```python
data = bm.array([1, 2])

data
```

```
JaxArray(DeviceArray([1, 2], dtype=int32))
```

```python
ones = bm.ones(2)

ones
```

```
JaxArray(DeviceArray([1., 1.], dtype=float32))
```

![](../_static/array_dataones.png)

```python
data + ones
```

```
JaxArray(DeviceArray([2., 3.], dtype=float32))
```

![](../_static/array_plus_ones.png)

```python
data - ones
```

```
JaxArray(DeviceArray([0., 1.], dtype=float32))
```

```python
data * data
```

```
JaxArray(DeviceArray([1, 4], dtype=int32))
```

```python
data / data
```

```
JaxArray(DeviceArray([1., 1.], dtype=float32))
```

![](../_static/array_sub_mult_divide.png)
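Note the dtypes in these outputs: combining the integer `data` with the floating-point `ones` gives floats, `data * data` stays integer, and true division `/` produces floats. NumPy applies the same promotion for these cases (with 64-bit types by default), as the sketch below checks, along with the fact that each operator simply pairs up elements position by position:

```python
import numpy as np

# NumPy stands in for brainpy.math here
data = np.array([1, 2])
ones = np.ones(2)

# Element-wise: result[i] == data[i] + ones[i]
assert (data + ones).tolist() == [d + o for d, o in zip(data.tolist(), ones.tolist())]

print((data + ones).dtype.kind)  # f: int combined with float gives float
print((data * data).dtype.kind)  # i: int times int stays int
print((data / data).dtype.kind)  # f: true division gives float
```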

Aggregation functions can also be applied to arrays, such as:

- ``.min()``: to get the minimum element;
- ``.max()``: to get the maximum element;
- ``.sum()``: to get the summation;
- ``.mean()``: to get the average;
- ``.prod()``: to get the result of multiplying the elements together;
- ``.std()``: to get the standard deviation.

```python
data = bm.array([1, 2, 3])

data.max()
```

```
DeviceArray(3, dtype=int32)
```

```python
data.min()
```

```
DeviceArray(1, dtype=int32)
```

```python
data.sum()
```

```
DeviceArray(6, dtype=int32)
```

![](../_static/arraye_aggregation.png)

It's very common to specify the aggregation along a row or column. For example, you can find the maximum value within each column by specifying ``axis=0``.

```python
a = bm.array([[1, 2],
              [5, 3],
              [4, 6]])
```

```python
a.max(axis=0)
```

```
JaxArray(DeviceArray([5, 6], dtype=int32))
```

```python
a.max(axis=1)
```

```
JaxArray(DeviceArray([2, 5, 6], dtype=int32))
```

![](../_static/array_matrix_aggregation_row.png)
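A way to remember which axis is which: the `axis` argument names the axis that gets collapsed. `axis=0` reduces over the rows (one result per column) and `axis=1` reduces over the columns (one result per row). The equivalent explicit loops, sketched in NumPy:

```python
import numpy as np

# NumPy stands in for brainpy.math here
a = np.array([[1, 2],
              [5, 3],
              [4, 6]])

# axis=0 collapses the first axis: one maximum per column
col_max = [max(column) for column in a.T.tolist()]
# axis=1 collapses the second axis: one maximum per row
row_max = [max(row) for row in a.tolist()]

assert a.max(axis=0).tolist() == col_max  # [5, 6]
assert a.max(axis=1).tolist() == row_max  # [2, 5, 6]
print(a.max(axis=0).shape, a.max(axis=1).shape)  # (2,) (3,)
```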

### Broadcasting

Array operations are usually performed on pairs of arrays on an element-by-element basis. In the simplest case, the two arrays must have exactly the same shape, as in the following example:

```python
a = bm.array([1.0, 2.0, 3.0])
b = bm.array([2.0, 2.0, 2.0])

a * b
```

```
JaxArray(DeviceArray([2., 4., 6.], dtype=float32))
```

However, this requirement is relaxed by **broadcasting** when the shapes of the arrays meet certain constraints. The simplest broadcasting example occurs when an array and a scalar value are combined in an operation:

```python
a = bm.array([1, 2])
b = 1.6

a * b
```

```
JaxArray(DeviceArray([1.6, 3.2], dtype=float32, weak_type=True))
```

![](../_static/multiply_broadcasting.png)

Similarly, broadcasting can be applied to matrices:

```python
data = bm.array([[1, 2],
                 [3, 4],
                 [5, 6]])

ones_row = bm.array([[1, 1]])

data + ones_row
```

```
JaxArray(DeviceArray([[2, 3],
                      [4, 5],
                      [6, 7]], dtype=int32))
```

![](../_static/matrix_broadcasting.png)

Under certain constraints, the smaller array can be "broadcast" across the larger array so that they have compatible shapes. Broadcasting provides a means of vectorizing array operations so that the looping happens in compiled code rather than in Python. It does this without making needless copies of data and usually leads to efficient algorithm implementations.

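More precisely, the shapes of two arrays are compared element-wise, starting from the trailing (rightmost) dimension. Two dimensions are compatible when

- they are equal, or

- one of them is 1.

If one array has fewer dimensions, its missing leading dimensions are treated as 1. If these conditions are not met, an error will occur.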
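The rule fits in a few lines of Python. The helper `broadcast_shape` below is hypothetical, written only for illustration: it right-aligns the two shapes by padding the shorter one with leading 1s, then compares them dimension by dimension. It is checked against NumPy's own `np.broadcast_shapes` (available in NumPy 1.20 and later).

```python
import numpy as np


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: result shape of broadcasting two shapes, or ValueError."""
    ndim = max(len(shape_a), len(shape_b))
    # Missing leading dimensions are treated as 1
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible shapes: {shape_a} and {shape_b}")
    return tuple(result)


print(broadcast_shape((3, 1), (1, 4)))  # (3, 4)
print(broadcast_shape((5, 4), (4,)))    # (5, 4)

# The helper agrees with NumPy
assert broadcast_shape((8, 1, 6, 1), (7, 1, 5)) == np.broadcast_shapes((8, 1, 6, 1), (7, 1, 5))

try:
    broadcast_shape((2, 3), (3, 2))
except ValueError as err:
    print(err)  # incompatible shapes: (2, 3) and (3, 2)
```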

For example, according to the broadcasting rules, the following two shapes are compatible:

```bash
Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3
```

```python
image = bm.random.random((256, 256, 3))
scale = bm.random.random(3)

_ = image + scale
_ = image - scale
_ = image * scale
_ = image / scale
```

These shapes are also compatible:

```bash
A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
```

```python
A = bm.random.random((8, 1, 6, 1))
B = bm.random.random((7, 1, 5))

_ = A + B
_ = A - B
_ = A * B
_ = A / B
```

However, in these examples, the shapes cannot be broadcast:

```bash
A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
```

```python
A = bm.random.random((3,))
B = bm.random.random((4,))

try:
    _ = A + B
except Exception as e:
    print(e)
```

```
add got incompatible shapes for broadcasting: (3,), (4,).
```

```python
A = bm.random.random((2, 1))
B = bm.random.random((8, 4, 3))

try:
    _ = A + B
except Exception as e:
    print(e)
```

```
Incompatible shapes for broadcasting: ((1, 2, 1), (8, 4, 3))
```


For more details about broadcasting, please see [NumPy documentation: broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html).

### Indexing, Slicing and Iterating

Arrays can be indexed, sliced, and iterated over, much like lists and other Python sequences. For example:

```python
a = bm.arange(10) ** 3

a
```

```
JaxArray(DeviceArray([  0,   1,   8,  27,  64, 125, 216, 343, 512, 729], dtype=int32))
```

```python
a[2]
```

```
DeviceArray(8, dtype=int32)
```

```python
a[2:5]
```

```
DeviceArray([ 8, 27, 64], dtype=int32)
```

```python
# from start to position 6, exclusive, set every 2nd element to 1000,
# equivalent to a[0:6:2] = 1000
a[:6:2] = 1000

a
```

```
JaxArray(DeviceArray([1000,    1, 1000,   27, 1000,  125,  216,  343,  512,  729], dtype=int32))
```
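To see exactly which positions a slice like `:6:2` touches, Python's built-in `slice.indices` resolves the omitted start, stop, and step for a sequence of a given length. Plain Python, no array library needed:

```python
# Resolve a[:6:2] for a length-10 array into explicit (start, stop, step)
start, stop, step = slice(None, 6, 2).indices(10)
print(start, stop, step)               # 0 6 2
print(list(range(start, stop, step)))  # [0, 2, 4]: the positions set to 1000 above

# A negative step with omitted bounds walks the whole sequence backwards (what a[::-1] does)
print(list(range(*slice(None, None, -1).indices(10))))  # [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
```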

```python
a[::-1]  # reversed a
```

```
DeviceArray([ 729,  512,  343,  216,  125, 1000,   27, 1000,    1, 1000], dtype=int32)
```

```python
for i in a:  # iterate a
    print(i**(1 / 3.))
```

```
10.000001
1.0
10.000001
3.0
10.000001
5.0000005
6.0000005
7.0000005
8.000001
9.000001
```


For multi-dimensional arrays, give one index per axis as a tuple, separated by commas. For example:

```python
b = bm.arange(20).reshape((5, 4))
```

```python
b[2, 3]
```

```
DeviceArray(11, dtype=int32)
```

```python
b[0:5, 1]  # each row in the second column of b
```

```
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
```

```python
b[:, 1]    # equivalent to the previous example
```

```
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
```

```python
b[1:3, :]  # each column in the second and third row of b
```

```
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
```

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

```python
b[-1]   # the last row. Equivalent to b[-1, :]
```

```
DeviceArray([16, 17, 18, 19], dtype=int32)
```

You can also write this using dots, as ``b[i, ...]``. The dots (``...``) represent as many colons as needed to produce a complete indexing tuple. For example, if ``x`` is an array with 5 axes, then

- ``x[1, 2, ...]`` is equivalent to ``x[1, 2, :, :, :]``,

- ``x[..., 3]`` to ``x[:, :, :, :, 3]`` and

- ``x[4, ..., 5, :]`` to ``x[4, :, :, 5, :]``.
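These equivalences can be checked on a concrete 5-axis array. A NumPy sketch; the shape `(5, 3, 4, 6, 7)` is arbitrary, chosen only so that every index used above is in range:

```python
import numpy as np

# NumPy stands in for brainpy.math here
x = np.arange(5 * 3 * 4 * 6 * 7).reshape((5, 3, 4, 6, 7))

assert np.array_equal(x[1, 2, ...], x[1, 2, :, :, :])
assert np.array_equal(x[..., 3], x[:, :, :, :, 3])
assert np.array_equal(x[4, ..., 5, :], x[4, :, :, 5, :])

print(x[1, 2, ...].shape)     # (4, 6, 7)
print(x[..., 3].shape)        # (5, 3, 4, 6)
print(x[4, ..., 5, :].shape)  # (3, 4, 7)
```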

```python
c = bm.arange(48).reshape((6, 4, 2))
```

```python
c[1, ...]  # same as c[1, :, :] or c[1]
```

```
DeviceArray([[ 8,  9],
             [10, 11],
             [12, 13],
             [14, 15]], dtype=int32)
```

```python
c[..., 1]  # same as c[:, :, 1]
```

```
DeviceArray([[ 1,  3,  5,  7],
             [ 9, 11, 13, 15],
             [17, 19, 21, 23],
             [25, 27, 29, 31],
             [33, 35, 37, 39],
             [41, 43, 45, 47]], dtype=int32)
```

Iterating over multidimensional arrays is done with respect to the first axis:

```python
for row in b:
    print(row)
```

```
[0 1 2 3]
[4 5 6 7]
[ 8  9 10 11]
[12 13 14 15]
[16 17 18 19]
```
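Since iteration follows the first axis, a 2D array yields rows. To visit every element instead, flatten the array first. A NumPy sketch using `ravel()` (NumPy also offers the `.flat` iterator):

```python
import numpy as np

# NumPy stands in for brainpy.math here
b = np.arange(20).reshape((5, 4))

rows = [row for row in b]        # iteration follows the first axis
print(len(rows), rows[0].shape)  # 5 (4,)

elements = [int(v) for v in b.ravel()]  # ravel() flattens to 1D, row by row
print(len(elements))  # 20
print(elements[:6])   # [0, 1, 2, 3, 4, 5]
```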


For more methods or advanced indexing and index tricks, please see [NumPy tutorial: Indexing](https://numpy.org/doc/stable/user/basics.indexing.html).

### Mathematical functions

Arrays support many other functions, including

- [mathematical functions](https://numpy.org/doc/stable/reference/routines.math.html)
- [logical functions](https://numpy.org/doc/stable/reference/routines.logic.html)
- [statistics functions](https://numpy.org/doc/stable/reference/routines.statistics.html)
- [window functions](https://numpy.org/doc/stable/reference/routines.window.html)
- [sorting, searching, and counting functions](https://numpy.org/doc/stable/reference/routines.sort.html)
- and others.

Most of these functions can be found in the ``brainpy.math`` module. Let's take a look at trigonometric, hyperbolic, and rounding functions.

```python
d = bm.linspace(0, 1, 10)

d
```

```
JaxArray(DeviceArray([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
                      0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],            dtype=float32))
```

```python
# trigonometric functions
bm.sin(d)
```

```
JaxArray(DeviceArray([0.        , 0.11088263, 0.22039774, 0.32719472, 0.42995638,
                      0.5274154 , 0.6183698 , 0.7016979 , 0.7763719 , 0.84147096],            dtype=float32))
```

```python
bm.arcsin(d)
```

```
JaxArray(DeviceArray([0.        , 0.11134101, 0.2240931 , 0.3398369 , 0.460554  ,
                      0.58903104, 0.7297277 , 0.8911225 , 1.0949141 , 1.5707964 ],            dtype=float32))
```

```python
# hyperbolic functions
bm.sinh(d)
```

```
JaxArray(DeviceArray([0.        , 0.11133985, 0.22405571, 0.33954054, 0.45922154,
                      0.58457786, 0.7171585 , 0.8586021 , 1.0106566 , 1.1752012 ],            dtype=float32))
```

```python
# rounding functions
bm.round(d)
```

```
JaxArray(DeviceArray([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], dtype=float32))
```
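One detail worth knowing about rounding: NumPy's `round` sends exact halves to the nearest even value ("round half to even"), not always upward. The values in `d` above contain no exact halves, so they are unaffected. The sketch uses NumPy; that ``bm.round`` behaves the same way (via `jax.numpy.round`) is an assumption to verify for your version.

```python
import numpy as np

# NumPy stands in for brainpy.math here
halves = np.array([0.5, 1.5, 2.5, 3.5])

print(np.round(halves).tolist())        # [0.0, 2.0, 2.0, 4.0]: halves go to the even neighbor
print(np.floor(halves + 0.5).tolist())  # [1.0, 2.0, 3.0, 4.0]: "round half up" for non-negative values
```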

```python
# sum function
bm.sum(d)
```

```
DeviceArray(5., dtype=float32)
```