# Quantity

Braincore generates standard names for units by combining a unit name (e.g. “siemens”) with a prefix (e.g. “m”), and also generates squared and cubed versions by appending a number. For example, the units “msiemens”, “siemens2”, and “usiemens3” are all predefined. You can import these units from the package `brainunit`; accordingly, `from brainunit import *` imports everything.
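
As a rough illustration of the naming scheme described above, the following sketch builds names from a prefix, a base name, and an optional power suffix. The `PREFIXES` and `POWERS` tables and the `standard_names` helper are hypothetical and cover only a tiny subset; this is not how the package itself generates its units.

```python
from itertools import product

# Illustrative subset only; the real prefix and unit tables are larger.
PREFIXES = ["", "m", "u"]   # no prefix, milli, micro
POWERS = ["", "2", "3"]     # plain, squared, cubed


def standard_names(base_name):
    """Return every prefix/power combination for one base unit name."""
    return [
        f"{prefix}{base_name}{power}"
        for prefix, power in product(PREFIXES, POWERS)
    ]


names = standard_names("siemens")
print(names)
```

The resulting list contains the three names mentioned above ("msiemens", "siemens2", "usiemens3") alongside the other combinations.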

To keep the namespace clean, we recommend importing the package under an alias, for example `import brainunit as bu`, and then writing `bu.msiemens` instead of `msiemens`.

In [1]:
import brainunit as bu

You can generate a physical quantity by multiplying a scalar or ndarray with its physical unit:

In [None]:
tau = 20 * bu.ms
tau

20. * msecond

In [None]:
rates = [10, 20, 30] * bu.Hz
rates

ArrayImpl([10., 20., 30.], dtype=float32) * hertz

In [None]:
rates = [[10, 20, 30], [20, 30, 40]] * bu.Hz
rates

ArrayImpl([[10., 20., 30.],
           [20., 30., 40.]], dtype=float32) * hertz

Braincore will check the consistency of operations on units and raise an error for dimensionality mismatches:

In [5]:
try:
    tau += 1  # ms? second?
except Exception as e:
    print(e)

Cannot calculate ... += 1, units do not match (units are s and 1).


In [6]:
try:
    3 * bu.kgram + 3 * bu.amp
except Exception as e:
    print(e)

Cannot calculate 3.0 + 3.0, units do not match (units are kg and A).
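
To make the idea of a consistency check concrete, here is a minimal toy sketch of how addition could be guarded by a dimension comparison. The `ToyQuantity` class and its `(m, kg, s, A)` exponent tuple are assumptions made for illustration; they are not Braincore's actual implementation.

```python
from dataclasses import dataclass

UNITLESS = (0, 0, 0, 0)  # exponents of (m, kg, s, A)
SECOND = (0, 0, 1, 0)


@dataclass(frozen=True)
class ToyQuantity:
    """A value in base SI units plus its dimension exponents (toy model)."""
    value: float
    dim: tuple

    def __add__(self, other):
        # Treat plain numbers as unitless quantities.
        if not isinstance(other, ToyQuantity):
            other = ToyQuantity(float(other), UNITLESS)
        if self.dim != other.dim:
            raise ValueError(
                f"units do not match (dimensions are {self.dim} and {other.dim})"
            )
        return ToyQuantity(self.value + other.value, self.dim)


tau = ToyQuantity(0.02, SECOND)  # 20 ms expressed in seconds

try:
    tau + 1  # a bare number has no dimension, so this is rejected
    mismatch_detected = False
except ValueError as err:
    mismatch_detected = True
    print(err)

total = tau + ToyQuantity(0.005, SECOND)  # same dimension, so this is allowed
```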


## Creating Quantity Instances

A `Quantity` object can be created in several ways. The methods below are grouped by the type of input they take.

In [7]:
import jax.numpy as jnp
import brainstate as bst
bst.environ.set(precision=64) # we recommend using 64-bit precision for better numerical stability

#### Scalar and Array Multiplication
- Multiplying a Scalar with a Unit

In [8]:
5 * bu.ms

5. * msecond

- Multiplying a JAX NumPy scalar type with a Unit:

In [9]:
jnp.float64(5) * bu.ms

5. * msecond

- Multiplying a JAX NumPy array with a Unit:

In [10]:
jnp.array([1, 2, 3]) * bu.ms

ArrayImpl([1., 2., 3.]) * msecond

- Multiplying a List with a Unit:

In [11]:
[1, 2, 3] * bu.ms

ArrayImpl([1., 2., 3.]) * msecond

#### Direct Quantity Creation

- Creating a Quantity Directly with a Value

In [12]:
bu.Quantity(5)

Quantity(5.)

- Creating a Quantity Directly with a Value and Unit

In [13]:
bu.Quantity(5, unit=bu.ms)

5. * msecond

- Creating a Quantity with a JAX NumPy Array of Values and a Unit

In [14]:
bu.Quantity(jnp.array([1, 2, 3]), unit=bu.ms)

ArrayImpl([1., 2., 3.]) * msecond

- Creating a Quantity with a List of Values and a Unit

In [15]:
bu.Quantity([1, 2, 3], unit=bu.ms)

ArrayImpl([1., 2., 3.]) * msecond

- Creating a Quantity with a List of Quantities

In [16]:
bu.Quantity([500 * bu.ms, 1 * bu.second])

ArrayImpl([0.5, 1. ]) * second

- Using the with_units Method

In [17]:
bu.Quantity.with_units(jnp.array([0.5, 1]), second=1)

ArrayImpl([0.5, 1. ]) * second
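
The list-of-quantities case above mixes `ms` and `second` yet yields a single array in seconds. One way to picture this is that each value is rescaled to a common unit before the array is built. The sketch below assumes a hypothetical `SCALE_TO_SECOND` table and `to_seconds` helper; it only illustrates the arithmetic and is not the library's code.

```python
# Hypothetical scale table: the factor that converts each unit to seconds.
SCALE_TO_SECOND = {"ms": 1e-3, "second": 1.0}


def to_seconds(items):
    """Rescale (value, unit_name) pairs to plain numbers in seconds."""
    return [value * SCALE_TO_SECOND[unit] for value, unit in items]


values = to_seconds([(500, "ms"), (1, "second")])
print(values)
```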

#### Unitless Quantity
A quantity can be unitless: if no unit is provided, the quantity is assumed to have none. The following examples create unitless quantities:

In [18]:
bu.Quantity([1, 2, 3])

Quantity(ArrayImpl([1., 2., 3.]))

In [19]:
bu.Quantity(jnp.array([1, 2, 3]))

Quantity(ArrayImpl([1., 2., 3.]))

In [20]:
bu.Quantity([])

Quantity(ArrayImpl([], dtype=float64))

#### Illegal Quantity Creation
The following are examples of illegal quantity creation:

In [21]:
try:
    bu.Quantity([500 * bu.ms, 1])
except Exception as e:
    print(e)

All elements must have the same unit


In [22]:
try:
    bu.Quantity(["some", "nonsense"])
except Exception as e:
    print(e)

Value 'some' with dtype <U4 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.


In [23]:
try:
    bu.Quantity([500 * bu.ms, 1 * bu.volt])
except Exception as e:
    print(e)

All elements must have the same unit
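
The rejections above can be pictured as a check that every element of the input shares one dimension. The following is a toy sketch under that assumption; `common_dim` and the `(m, kg, s, A)` exponent tuples are hypothetical names used only for illustration.

```python
UNITLESS = (0, 0, 0, 0)   # exponents of (m, kg, s, A)
SECOND = (0, 0, 1, 0)
VOLT = (2, 1, -3, -1)


def common_dim(items):
    """Return the dimension shared by all (value, dim) pairs, or raise."""
    if not items:
        return UNITLESS  # an empty input is treated as unitless
    dims = {dim for _, dim in items}
    if len(dims) != 1:
        raise ValueError("All elements must have the same unit")
    return dims.pop()


ok = common_dim([(0.5, SECOND), (1.0, SECOND)])

errors = []
bad_inputs = [
    [(0.5, SECOND), (1, UNITLESS)],  # a quantity mixed with a bare number
    [(0.5, SECOND), (1.0, VOLT)],    # time mixed with voltage
]
for bad in bad_inputs:
    try:
        common_dim(bad)
    except ValueError as err:
        errors.append(str(err))
```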


## Converting to Different Units

TODO

## Attributes of a Quantity

The important attributes of a `Quantity` object are:
- `value`: the numerical value of the quantity
- `dim`: the dimension of the unit of the quantity
- `ndim`: the number of dimensions of the quantity's value
- `shape`: the shape of the quantity's value
- `size`: the size of the quantity's value
- `dtype`: the dtype of the quantity's value

### An example

In [24]:
rates = [[10, 20, 30], [20, 30, 40]] * bu.Hz
rates

ArrayImpl([[10., 20., 30.],
           [20., 30., 40.]]) * hertz

In [25]:
rates.value

Array([[10., 20., 30.],
       [20., 30., 40.]], dtype=float64, weak_type=True)

In [26]:
rates.dim

second ** -1

In [27]:
rates.ndim, rates.shape, rates.size, rates.dtype

(2, (2, 3), 6, dtype('float64'))
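
The shape-related attributes describe the underlying array rather than the unit. As a point of comparison, the same four attributes on a plain NumPy array with identical contents look like this (a sketch using NumPy directly, not a `Quantity`):

```python
import numpy as np

# Same numbers as `rates` above, without any unit attached.
values = np.array([[10, 20, 30], [20, 30, 40]], dtype=np.float64)

info = (values.ndim, values.shape, values.size, values.dtype)
print(info)
```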

## Arithmetic Functions

As in NumPy and JAX NumPy, arithmetic operators on quantities apply elementwise.

In [28]:
a = [20, 30, 40, 50] * bu.mV
b = jnp.arange(4) * bu.mV
b

ArrayImpl([0., 1., 2., 3.]) * mvolt

### Addition and Subtraction

Quantities can be added or subtracted only if they have the same units; the result keeps those units.

In [29]:
c = a - b
c

ArrayImpl([20., 29., 38., 47.]) * mvolt

In [30]:
c + b

ArrayImpl([20., 30., 40., 50.]) * mvolt

### Multiplication and Division

Multiplying or dividing quantities multiplies or divides their values, and adds or subtracts the exponents of their unit dimensions.

In [31]:
A = jnp.array([[1, 2], [3, 4]]) * bu.mV
B = jnp.array([[5, 6], [7, 8]]) * bu.mV

A, B

(ArrayImpl([[1., 2.],
            [3., 4.]]) * mvolt,
 ArrayImpl([[5., 6.],
            [7., 8.]]) * mvolt)

In [32]:
A * B # element-wise multiplication

ArrayImpl([[ 5., 12.],
           [21., 32.]]) * mvolt2

In [33]:
A @ B # matrix multiplication

ArrayImpl([[19., 22.],
           [43., 50.]]) * mvolt2

In [34]:
A.dot(B) # matrix multiplication

ArrayImpl([[19., 22.],
           [43., 50.]]) * mvolt2

In [35]:
A / 2 # divide by a scalar

ArrayImpl([[0.5, 1. ],
           [1.5, 2. ]]) * mvolt

If the result is unitless, the unit is dropped and the result is returned as a plain `jax.Array`:

In [38]:
A / (2 * bu.mV) # divide by a quantity, return jax array

Array([[0.5, 1. ],
       [1.5, 2. ]], dtype=float64, weak_type=True)

In [40]:
A / (2 * bu.mA) # divide by a quantity, return quantity

ArrayImpl([[0.5, 1. ],
           [1.5, 2. ]]) * ohm
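
The unit bookkeeping in the cells above can be sketched with dimension exponents alone. The tuples below follow an assumed `(m, kg, s, A)` ordering, and `mul_dims` / `div_dims` are hypothetical helpers; the sketch ignores scale prefixes such as milli and is not the library's implementation.

```python
VOLT = (2, 1, -3, -1)    # volt = m^2 kg s^-3 A^-1
AMPERE = (0, 0, 0, 1)


def mul_dims(a, b):
    """Multiplication adds dimension exponents."""
    return tuple(x + y for x, y in zip(a, b))


def div_dims(a, b):
    """Division subtracts dimension exponents."""
    return tuple(x - y for x, y in zip(a, b))


volt_squared = mul_dims(VOLT, VOLT)   # (4, 2, -6, -2)
ohm = div_dims(VOLT, AMPERE)          # (2, 1, -3, -2)
unitless = div_dims(VOLT, VOLT)       # (0, 0, 0, 0): nothing left to attach
```

When every exponent is zero, as in the last line, there is no unit left, which corresponds to the case where a plain array is returned.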

### Power

The power operator raises the value of the quantity to the power of the scalar, and multiplies the exponents of the unit's dimensions by the scalar.

In [48]:
A

ArrayImpl([[1., 2.],
           [3., 4.]]) * mvolt

In [49]:
A ** 2 # element-wise power

ArrayImpl([[ 1.,  4.],
           [ 9., 16.]]) * mvolt2
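
In terms of dimension exponents, raising to a power scales each exponent. The sketch below uses an assumed `(m, kg, s, A)` tuple for the volt and a hypothetical `pow_dims` helper, with `Fraction` to keep fractional exponents exact; it illustrates the rule rather than the library's code.

```python
from fractions import Fraction

VOLT = (2, 1, -3, -1)  # volt = m^2 kg s^-3 A^-1


def pow_dims(dims, exponent):
    """Raising to a power multiplies every dimension exponent."""
    return tuple(d * exponent for d in dims)


squared = pow_dims(VOLT, 2)                  # (4, 2, -6, -2)
cube_root = pow_dims(VOLT, Fraction(1, 3))   # (2/3, 1/3, -1, -1/3)
print(squared, cube_root)
```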

## Built-in Functions

Braincore provides a number of built-in functions in the `Quantity` class to perform operations on quantities. These functions are:
- unary operations
    - positive(+)
    - negative(-)
    - absolute(abs)
    - invert(~)
- logical operations
    - all
    - any
- shape operations
    - reshape
    - resize
    - squeeze
    - unsqueeze
    - split
    - swapaxes
    - transpose
    - ravel
    - take
    - repeat
    - diagonal
    - trace
- mathematical functions
    - nonzero
    - argmax
    - argmin
    - argsort
    - var
    - round
    - std
    - sum
    - cumsum
    - cumprod
    - max
    - mean
    - min
    - ptp
    - clip
    - conj
    - dot
    - fill
    - item
    - prod
    - clamp
    - sort

For more details on these functions, refer to the [documentation](https://braincore.readthedocs.io/en/latest/braincore/apis/generated/brainunit.Quantity.html).

## Indexing, Slicing and Iterating

A one-dimensional Quantity can be indexed, sliced, and iterated over, much like lists and other Python sequences.

In [41]:
a = jnp.arange(10) ** 3 * bu.mV
a

ArrayImpl([  0.,   1.,   8.,  27.,  64., 125., 216., 343., 512., 729.]) * mvolt

In [42]:
a[2]

8. * mvolt

In [43]:
a[2:5]

ArrayImpl([ 8., 27., 64.]) * mvolt

Only a Quantity with the same dimension can be assigned to a slice of a Quantity.

In [44]:
# equivalent to a[0:6:2] = 1000 * bu.mV;
# from start to position 6, exclusive, set every 2nd element to 1000 mV
a[:6:2] = 1000 * bu.mV
a

ArrayImpl([1000.,    1., 1000.,   27., 1000.,  125.,  216.,  343.,  512.,
            729.]) * mvolt
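
The same-dimension rule for slice assignment can be illustrated with a toy container that checks the dimension before writing. `ToyArrayQuantity` and its `(m, kg, s, A)` exponent tuples are hypothetical, and the sketch ignores unit scales such as milli; it is not the library's implementation.

```python
VOLT = (2, 1, -3, -1)    # exponents of (m, kg, s, A)
AMPERE = (0, 0, 0, 1)


class ToyArrayQuantity:
    """A list of values sharing one dimension (toy model)."""

    def __init__(self, values, dim):
        self.values = list(values)
        self.dim = dim

    def __setitem__(self, index, item):
        value, dim = item  # the assigned item is a (value, dim) pair
        if dim != self.dim:
            raise ValueError("dimension of assigned value does not match")
        if isinstance(index, slice):
            # Broadcast the single value over every selected position.
            for i in range(*index.indices(len(self.values))):
                self.values[i] = value
        else:
            self.values[index] = value


a = ToyArrayQuantity([i ** 3 for i in range(10)], VOLT)
a[:6:2] = (1000, VOLT)        # accepted: same dimension

try:
    a[0] = (1, AMPERE)        # rejected: different dimension
    rejected = False
except ValueError:
    rejected = True
```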

In [45]:
a[::-1] # reversed a

ArrayImpl([ 729.,  512.,  343.,  216.,  125., 1000.,   27., 1000.,    1.,
           1000.]) * mvolt

In [46]:
for i in a:
    print(i**(1 / 3.))

1. m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.1 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
1. m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.3 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
1. m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.5 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.6 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.7 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.8 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333
0.9 m^0.6666666666666666 kg^0.3333333333333333 s^-1.0 A^-0.3333333333333333


A multidimensional Quantity can have one index per axis. These indices are given as a comma-separated tuple:

In [50]:
def f(x, y):
    return 10 * x + y
b = jnp.fromfunction(f, (5, 4), dtype=jnp.int32) * bu.mV
b

ArrayImpl([[ 0.,  1.,  2.,  3.],
           [10., 11., 12., 13.],
           [20., 21., 22., 23.],
           [30., 31., 32., 33.],
           [40., 41., 42., 43.]]) * mvolt

In [51]:
b[2, 3]

23. * mvolt

In [52]:
b[0:5, 1]  # each row in the second column of b

ArrayImpl([ 1., 11., 21., 31., 41.]) * mvolt

In [53]:
b[:, 1]  # equivalent to the previous example

ArrayImpl([ 1., 11., 21., 31., 41.]) * mvolt

In [54]:
b[1:3, :]  # each column in the second and third row of b

ArrayImpl([[10., 11., 12., 13.],
           [20., 21., 22., 23.]]) * mvolt

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

In [55]:
b[-1]

ArrayImpl([40., 41., 42., 43.]) * mvolt

The expression within brackets in `b[i]` is treated as an `i` followed by as many instances of `:` as needed to represent the remaining axes. As in NumPy, you can also write this using dots as `b[i, ...]`.

The dots (`...`) represent as many colons as needed to produce a complete indexing tuple. For example, if `x` is a Quantity with 5 axes, then
- `x[1, 2, ...]` is equivalent to `x[1, 2, :, :, :]`,
- `x[..., 3]` to `x[:, :, :, :, 3]` and
- `x[4, ..., 5, :]` to `x[4, :, :, 5, :]`.
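
The three equivalences above can be checked on a plain NumPy array with 5 axes. This sketch uses NumPy rather than a `Quantity`, on the assumption that the indexing rules carry over unchanged; the array shape is chosen arbitrarily so that every index is in range.

```python
import numpy as np

# A 5-axis array; the shape is arbitrary but large enough for the indices used.
x = np.arange(5 * 3 * 2 * 6 * 4).reshape(5, 3, 2, 6, 4)

checks = [
    np.array_equal(x[1, 2, ...], x[1, 2, :, :, :]),
    np.array_equal(x[..., 3], x[:, :, :, :, 3]),
    np.array_equal(x[4, ..., 5, :], x[4, :, :, 5, :]),
]
print(checks)
```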

In [56]:
c = jnp.array([[[0, 1, 2], [10, 12, 13]], [[100, 101, 102], [110, 112, 113]]]) * bu.mV # a 3D array (two stacked 2D arrays)
c.shape

(2, 2, 3)

In [57]:
c[1, ...] # same as c[1, :, :] or c[1]

ArrayImpl([[100., 101., 102.],
           [110., 112., 113.]]) * mvolt

In [58]:
c[..., 2] # same as c[:, :, 2]

ArrayImpl([[  2.,  13.],
           [102., 113.]]) * mvolt

Iterating over a multidimensional Quantity proceeds along the first axis:

In [59]:
for row in b:
    print(row)

[0. 1. 2. 3.] mV
[10. 11. 12. 13.] mV
[20. 21. 22. 23.] mV
[30. 31. 32. 33.] mV
[40. 41. 42. 43.] mV
