# Math Operations with Quantity

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/physical_units/math_operations_with_quantity.ipynb)
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/physical_units/math_operations_with_quantity.ipynb)

```python
import brainunit as u
import jax.numpy as jnp
```

As in NumPy and JAX NumPy, arithmetic operators on quantities apply elementwise.

```python
a = [20, 30, 40, 50] * u.mV
b = jnp.arange(4) * u.mV
b
```

```
ArrayImpl([0, 1, 2, 3]) * mvolt
```

## Addition and Subtraction

Quantities can only be added or subtracted if they have the same units, and the result keeps those units.

```python
c = a - b
c
```

```
ArrayImpl([20, 29, 38, 47]) * mvolt
```

```python
c + b
```

```
ArrayImpl([20, 30, 40, 50]) * mvolt
```
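The rule can be illustrated with a deliberately tiny model. The sketch below is hypothetical and is not brainunit's implementation. `ToyQuantity` stores a magnitude and a unit label, and `+` and `-` refuse any operand whose label differs. Real brainunit units carry dimensions and scales, which this toy ignores.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical toy: a magnitude plus a unit label (no dimensions or scales)."""
    value: float
    unit: str

    def _check_same_unit(self, other):
        # Both operands must be quantities with exactly the same unit label.
        if not isinstance(other, ToyQuantity):
            raise TypeError(f"cannot combine {self.unit} with a unitless value")
        if other.unit != self.unit:
            raise TypeError(f"unit mismatch: {self.unit} vs {other.unit}")

    def __add__(self, other):
        self._check_same_unit(other)
        return ToyQuantity(self.value + other.value, self.unit)

    def __sub__(self, other):
        self._check_same_unit(other)
        return ToyQuantity(self.value - other.value, self.unit)


a = ToyQuantity(20.0, "mV")
b = ToyQuantity(1.0, "mV")
print(a - b)  # ToyQuantity(value=19.0, unit='mV')

try:
    a + ToyQuantity(1.0, "mA")
except TypeError as err:
    print(err)  # unit mismatch: mV vs mA
```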

## Multiplication and Division

Multiplication and division multiply or divide the values and combine the units. The dimension exponents add under multiplication and subtract under division.

```python
A = jnp.array([[1, 2], [3, 4]]) * u.mV
B = jnp.array([[5, 6], [7, 8]]) * u.mV

A, B
```

```
(ArrayImpl([[1, 2],
            [3, 4]]) * mvolt,
 ArrayImpl([[5, 6],
            [7, 8]]) * mvolt)
```

```python
A * B  # element-wise multiplication
```

```
ArrayImpl([[ 5, 12],
           [21, 32]]) * mvolt2
```

```python
A @ B  # matrix multiplication
```

```
ArrayImpl([[19, 22],
           [43, 50]]) * mvolt2
```

```python
A.dot(B)  # matrix multiplication
```

```
ArrayImpl([[19, 22],
           [43, 50]]) * mvolt2
```
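Both `A * B` and `A @ B` end up in `mvolt2`. The toy sketch below shows why by treating a unit as a mapping from unit symbol to exponent. The helpers `mul_units` and `matmul` are hypothetical, and `mV` is treated as an opaque symbol rather than 10⁻³ V. Multiplication adds exponents. In a matrix product every entry is a sum of `mV * mV` terms, so the unit matches the element-wise product.

```python
from collections import Counter


def mul_units(u1, u2):
    """Multiply two units by adding the exponent of each symbol (hypothetical helper)."""
    total = Counter(u1)
    total.update(u2)  # Counter.update adds counts instead of replacing them
    return {sym: exp for sym, exp in total.items() if exp != 0}


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


mV = {"mV": 1}
A_vals = [[1, 2], [3, 4]]
B_vals = [[5, 6], [7, 8]]

# Element-wise product: multiply matching entries, add the unit exponents.
elementwise = [[x * y for x, y in zip(ra, rb)] for ra, rb in zip(A_vals, B_vals)]
print(elementwise, mul_units(mV, mV))  # [[5, 12], [21, 32]] {'mV': 2}

# Matrix product: each entry is a sum of mV * mV terms, so the unit is again mV**2.
print(matmul(A_vals, B_vals), mul_units(mV, mV))  # [[19, 22], [43, 50]] {'mV': 2}
```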

```python
A / 2  # divide by a scalar
```

```
ArrayImpl([[0.5, 1. ],
           [1.5, 2. ]], dtype=float32) * mvolt
```

If the result is dimensionless, the unit is dropped and a plain `jax.Array` is returned.

```python
A / (2 * u.mV)  # divide by a quantity with the same unit: returns a jax.Array
```

```
Array([[0.5, 1. ],
       [1.5, 2. ]], dtype=float32)
```

```python
A / (2 * u.mA)  # divide by a quantity with a different unit: returns a Quantity
```

```
ArrayImpl([[0.5, 1. ],
           [1.5, 2. ]], dtype=float32) * ohm
```
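The `ohm` in the last output follows from dimensional analysis. The sketch below writes dimensions as exponent tuples over the SI base units (m, kg, s, A) and subtracts them on division; the helper `divide` and its `(value, dims)` representation are hypothetical. In both divisions the milli prefixes cancel. `V / V` has all-zero exponents, so only a number is returned, mirroring the `jax.Array` result above. `V / A` has the exponents of the ohm.

```python
# Dimensions as exponent tuples over the SI base units (m, kg, s, A).
VOLT = (2, 1, -3, -1)    # V = m^2 kg s^-3 A^-1
AMPERE = (0, 0, 0, 1)    # A
OHM = (2, 1, -3, -2)     # ohm = V / A
MILLI = 1e-3


def divide(v1, d1, v2, d2):
    """Divide two (value, dims) quantities; return a bare number if dimensionless."""
    dims = tuple(e1 - e2 for e1, e2 in zip(d1, d2))  # division subtracts exponents
    value = v1 / v2
    return value if not any(dims) else (value, dims)


# 1 mV / 2 mV: prefixes and dimensions cancel, so a plain number comes back.
print(divide(1 * MILLI, VOLT, 2 * MILLI, VOLT))  # 0.5

# 1 mV / 2 mA: prefixes cancel, and V / A has the dimensions of an ohm.
value, dims = divide(1 * MILLI, VOLT, 2 * MILLI, AMPERE)
print(value, dims == OHM)  # 0.5 True
```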

## Power

The power operator raises both the value and the unit of a quantity to a scalar power. Equivalently, every exponent in the unit is multiplied by that scalar.

```python
A
```

```
ArrayImpl([[1, 2],
           [3, 4]]) * mvolt
```

```python
A ** 2  # element-wise power
```

```
ArrayImpl([[ 1,  4],
           [ 9, 16]]) * mvolt2
```
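A small worked check of the power rule, using exact fractions for the exponents (`pow_dims` is a hypothetical helper). Squaring the volt doubles every SI exponent, and a cube root divides them by 3. This is why the loop over `i**(1 / 3.)` in the iteration example prints units of `mvolt ** 0.333...`.

```python
from fractions import Fraction

# SI exponents of the volt: V = m^2 kg s^-3 A^-1
VOLT = {"m": 2, "kg": 1, "s": -3, "A": -1}


def pow_dims(dims, p):
    """Raise a unit to the power p by multiplying every exponent by p (exact arithmetic)."""
    p = Fraction(p)
    return {base: exp * p for base, exp in dims.items()}


squared = pow_dims(VOLT, 2)
cube_root = pow_dims(VOLT, Fraction(1, 3))

print(squared == {"m": 4, "kg": 2, "s": -6, "A": -2})  # True
print(cube_root["m"], cube_root["A"])  # 2/3 -1/3
```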

## Built-in Functions

The `Quantity` class provides a number of built-in methods for operating on quantities, including:
- unary operations
    - positive(+)
    - negative(-)
    - absolute(abs)
    - invert(~)
- logical operations
    - all
    - any
- shape operations
    - reshape
    - resize
    - squeeze
    - unsqueeze
    - split
    - swapaxes
    - transpose
    - ravel
    - take
    - repeat
    - diagonal
    - trace
- mathematical functions
    - nonzero
    - argmax
    - argmin
    - argsort
    - var
    - round
    - std
    - sum
    - cumsum
    - cumprod
    - max
    - mean
    - min
    - ptp
    - clip
    - conj
    - dot
    - fill
    - item
    - prod
    - clamp
    - sort

For more details on these functions, refer to the [documentation](https://brainunit.readthedocs.io/en/latest/apis/generated/brainunit.Quantity.html).
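The listed reductions do not all return the same unit. One library-independent way to see which unit a statistic should carry is to re-express the same data in another unit and watch how the result scales. The sketch below uses plain Python `statistics`, not brainunit. Going from mV to V scales the mean and standard deviation by 10⁻³, like the data itself. The variance scales by 10⁻⁶, the signature of a squared unit. `argmax` returns an index, which does not change at all.

```python
import statistics

values_mV = [1.0, 2.0, 3.0, 4.0]
values_V = [v * 1e-3 for v in values_mV]  # the same data re-expressed in volts


def argmax(xs):
    """Index of the largest element (an index has no unit)."""
    return max(range(len(xs)), key=xs.__getitem__)


mean_ratio = statistics.mean(values_V) / statistics.mean(values_mV)
std_ratio = statistics.pstdev(values_V) / statistics.pstdev(values_mV)
var_ratio = statistics.pvariance(values_V) / statistics.pvariance(values_mV)

print(mean_ratio, std_ratio)  # both approximately 1e-3: mean and std keep the unit
print(var_ratio)              # approximately 1e-6: variance carries the squared unit
print(argmax(values_mV) == argmax(values_V))  # True: the index is unit-free
```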

## Indexing, Slicing and Iterating

One-dimensional quantities can be indexed, sliced and iterated over, much like lists and other Python sequences.

```python
a = jnp.arange(10) ** 3 * u.mV
a
```

```
ArrayImpl([  0,   1,   8,  27,  64, 125, 216, 343, 512, 729]) * mvolt
```

```python
a[2]
```

```
8 * mvolt
```

```python
a[2:5]
```

```
ArrayImpl([ 8, 27, 64]) * mvolt
```

Only a Quantity with the same dimension can be assigned to a slice of a Quantity.

```python
# equivalent to a[0:6:2] = 1000 * u.mV:
# from the start up to position 6 (exclusive), set every 2nd element to 1000 mV
a[:6:2] = 1000 * u.mV
a
```

```
ArrayImpl([1000,    1, 1000,   27, 1000,  125,  216,  343,  512,  729]) * mvolt
```
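Below is a hypothetical toy version of this rule, not brainunit's implementation. `assign` accepts a new value only if its unit has the same dimension as the array's unit, then rescales it by the difference in power-of-ten prefixes. The unit tables are assumptions for illustration. The examples above do not show whether brainunit rescales differently prefixed units on assignment.

```python
# Hypothetical unit tables for this toy (not brainunit's): dimension name and
# power-of-ten prefix of each unit.
DIMENSION = {"mV": "voltage", "V": "voltage", "mA": "current"}
PREFIX_EXP = {"mV": -3, "V": 0, "mA": -3}


def assign(values, unit, index, new_value, new_unit):
    """Return a copy of `values` (stored in `unit`) with values[index] = new_value.

    `index` must be a slice; `new_value` is broadcast to every selected position
    after converting it from `new_unit` to `unit`.
    """
    if DIMENSION[new_unit] != DIMENSION[unit]:
        raise TypeError(f"cannot assign a value in {new_unit} to an array in {unit}")
    converted = new_value * 10 ** (PREFIX_EXP[new_unit] - PREFIX_EXP[unit])
    out = list(values)
    out[index] = [converted] * len(out[index])
    return out


a = [0, 1, 8, 27, 64, 125, 216, 343, 512, 729]  # magnitudes in mV
print(assign(a, "mV", slice(None, 6, 2), 1000, "mV"))
# [1000, 1, 1000, 27, 1000, 125, 216, 343, 512, 729]
print(assign(a, "mV", slice(None, 6, 2), 1, "V"))  # 1 V is stored as 1000 mV
```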

```python
a[::-1]  # reversed a
```

```
ArrayImpl([ 729,  512,  343,  216,  125, 1000,   27, 1000,    1, 1000]) * mvolt
```

```python
for i in a:
    print(i**(1 / 3.))
```

```
10.00000095 * mvolt ** 0.3333333333333333
1. * mvolt ** 0.3333333333333333
10.00000095 * mvolt ** 0.3333333333333333
3. * mvolt ** 0.3333333333333333
10.00000095 * mvolt ** 0.3333333333333333
5.00000048 * mvolt ** 0.3333333333333333
6.00000048 * mvolt ** 0.3333333333333333
7.00000048 * mvolt ** 0.3333333333333333
8.00000095 * mvolt ** 0.3333333333333333
9.00000095 * mvolt ** 0.3333333333333333
```


Multidimensional quantities can have one index per axis. These indices are given in a tuple separated by commas:

```python
def f(x, y):
    return 10 * x + y

b = jnp.fromfunction(f, (5, 4), dtype=jnp.int32) * u.mV
b
```

```
ArrayImpl([[ 0,  1,  2,  3],
           [10, 11, 12, 13],
           [20, 21, 22, 23],
           [30, 31, 32, 33],
           [40, 41, 42, 43]]) * mvolt
```

```python
b[2, 3]
```

```
23 * mvolt
```

```python
b[0:5, 1]  # each row in the second column of b
```

```
ArrayImpl([ 1, 11, 21, 31, 41]) * mvolt
```

```python
b[:, 1]  # equivalent to the previous example
```

```
ArrayImpl([ 1, 11, 21, 31, 41]) * mvolt
```

```python
b[1:3, :]  # each column in the second and third row of b
```

```
ArrayImpl([[10, 11, 12, 13],
           [20, 21, 22, 23]]) * mvolt
```

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

```python
b[-1]
```

```
ArrayImpl([40, 41, 42, 43]) * mvolt
```

The expression within the brackets in `b[i]` is treated as `i` followed by as many instances of `:` as needed to represent the remaining axes. NumPy also allows you to write this with dots, as `b[i, ...]`.

The dots (`...`) represent as many colons as needed to produce a complete indexing tuple. For example, if `x` is a Quantity with 5 axes, then
- `x[1, 2, ...]` is equivalent to `x[1, 2, :, :, :]`,
- `x[..., 3]` to `x[:, :, :, :, 3]`, and
- `x[4, ..., 5, :]` to `x[4, :, :, 5, :]`.
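These three equivalences are ordinary NumPy indexing rules, so they can be checked with a plain NumPy array. The 5-axis shape below is an arbitrary choice, sized so that every index in the list is valid.

```python
import numpy as np

# A 5-axis array, sized so that every index used in the list above is valid.
x = np.arange(5 * 3 * 4 * 6 * 7).reshape(5, 3, 4, 6, 7)

assert np.array_equal(x[1, 2, ...], x[1, 2, :, :, :])
assert np.array_equal(x[..., 3], x[:, :, :, :, 3])
assert np.array_equal(x[4, ..., 5, :], x[4, :, :, 5, :])

print(x[1, 2, ...].shape)     # (4, 6, 7)
print(x[..., 3].shape)        # (5, 3, 4, 6)
print(x[4, ..., 5, :].shape)  # (3, 4, 7)
```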

```python
c = jnp.array([[[0, 1, 2], [10, 12, 13]], [[100, 101, 102], [110, 112, 113]]]) * u.mV  # a 3D array (two stacked 2D arrays)
c.shape
```

```
(2, 2, 3)
```

```python
c[1, ...]  # same as c[1, :, :] or c[1]
```

```
ArrayImpl([[100, 101, 102],
           [110, 112, 113]]) * mvolt
```

```python
c[..., 2]  # same as c[:, :, 2]
```

```
ArrayImpl([[  2,  13],
           [102, 113]]) * mvolt
```

Iterating over multidimensional Quantity is done with respect to the first axis:

```python
for row in b:
    print(row)
```

```
ArrayImpl([0, 1, 2, 3]) * mvolt
ArrayImpl([10, 11, 12, 13]) * mvolt
ArrayImpl([20, 21, 22, 23]) * mvolt
ArrayImpl([30, 31, 32, 33]) * mvolt
ArrayImpl([40, 41, 42, 43]) * mvolt
```


## Operating on Subsets

The `.at` property operates on a subset of a Quantity. As the outputs below show, each call returns a new Quantity and leaves `q` itself unchanged:

```python
q = jnp.arange(5.0) * u.mV
q
```

```
ArrayImpl([0., 1., 2., 3., 4.], dtype=float32) * mvolt
```

```python
q.at[2].add(10 * u.mV)
```

```
ArrayImpl([ 0.,  1., 12.,  3.,  4.], dtype=float32) * mvolt
```

```python
q.at[10].add(10 * u.mV)  # out-of-bounds indices are ignored
```

```
ArrayImpl([0., 1., 2., 3., 4.], dtype=float32) * mvolt
```

```python
q.at[20].add(10 * u.mV, mode='clip')  # out-of-bounds indices are clipped
```

```
ArrayImpl([ 0.,  1.,  2.,  3., 14.], dtype=float32) * mvolt
```

```python
q.at[2].get()
```

```
2. * mvolt
```

```python
q.at[20].get()  # out-of-bounds indices are clipped
```

```
4. * mvolt
```

```python
q.at[20].get(mode='fill')  # out-of-bounds indices are filled with NaN
```

```
nan * mvolt
```
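The out-of-bounds behavior in these cells fits in a few lines of plain Python. The sketch below is a hypothetical model of only the behavior shown above, not how JAX or brainunit implement `.at`. By default, updates drop out-of-bounds indices and reads clip them, while `mode="fill"` returns `fill_value`. It handles only non-negative integer indices.

```python
import math


def at_add(values, index, delta, mode="drop"):
    """Toy model: return a copy of `values` with `delta` added at `index`.

    Only non-negative integer indices are handled. With mode="drop" (the default here)
    an out-of-bounds update is ignored; with mode="clip" it lands on the last element.
    """
    out = list(values)
    n = len(out)
    if index < n:
        out[index] += delta
    elif mode == "clip":
        out[n - 1] += delta
    return out


def at_get(values, index, mode="clip", fill_value=math.nan):
    """Toy model: read values[index], clipping or filling when out of bounds."""
    n = len(values)
    if index < n:
        return values[index]
    if mode == "fill":
        return fill_value
    return values[n - 1]  # mode == "clip"


q = [0.0, 1.0, 2.0, 3.0, 4.0]
print(at_add(q, 2, 10.0))                # [0.0, 1.0, 12.0, 3.0, 4.0]
print(at_add(q, 10, 10.0))               # [0.0, 1.0, 2.0, 3.0, 4.0]
print(at_add(q, 20, 10.0, mode="clip"))  # [0.0, 1.0, 2.0, 3.0, 14.0]
print(at_get(q, 20))                     # 4.0
print(at_get(q, 20, mode="fill"))        # nan
print(q)                                 # [0.0, 1.0, 2.0, 3.0, 4.0]: never modified
```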

Brainunit checks that the units in these operations are consistent and raises an error on a dimension mismatch:

```python
try:
    q.at[2].add(10)
except Exception as e:
    print(e)
```

```
Cannot convert to a unit with different dimensions. (units are Unit(10.0^0) and mV).
```


Brainunit also accepts a custom `fill_value` with `mode='fill'`. The fill value must have the same dimension as the Quantity:

```python
q.at[20].get(mode='fill', fill_value=-1 * u.mV)  # custom fill value
```

```
-1. * mvolt
```

```python
try:
    q.at[20].get(mode='fill', fill_value=-1)
except Exception as e:
    print(e)
```

```
Cannot convert to a unit with different dimensions. (units are Unit(10.0^0) and mV).
```
