# Combining, Defining, and Displaying Units

## Basic example

Units and quantities can be combined using the regular Python arithmetic operators:

```python
import brainunit as bu

# Build the volt from SI base units: m^2 kg s^-3 A^-1
volt = bu.meter2 * bu.kilogram / (bu.second3 * bu.ampere)
volt == bu.volt
```

```
True
```

## Defining units

You can define new units, either fundamental or compound, with `Unit.create`, and derive prefixed versions of existing units with `Unit.create_scaled_unit`:

##### Creating Basic Units
First, we create some basic units, such as the metre and the second:

```python
from brainunit import Unit, get_or_create_dimension

# Creating a basic unit: metre
metre = Unit.create(get_or_create_dimension(m=1), "metre", "m")

# Creating a basic unit: second
second = Unit.create(get_or_create_dimension(s=1), "second", "s")

metre, second
```

```
(metre, second)
```

Here, `get_or_create_dimension(m=1)` returns a dimension object representing length, and `Unit.create` uses this dimension to create a unit named "metre" with the display name "m". The second is defined the same way from the time dimension (`s=1`).

##### Creating Compound Units
Next, we create a compound unit, the volt (metre^2 * kilogram / (second^3 * ampere), i.e. m^2 kg s^-3 A^-1):

```python
volt = Unit.create(get_or_create_dimension(m=2, kg=1, s=-3, A=-1), "volt", "V")

volt
```

```
volt
```

Here, we pass the exponent of each base dimension (`m=2, kg=1, s=-3, A=-1`) to `get_or_create_dimension` and create a new unit named "volt" with the display name "V".

##### Creating Scaled Units
Finally, we create a scaled version of a basic unit, such as kilometers (kilometre):

```python
kilometre = Unit.create_scaled_unit(metre, "k")

kilometre
```

```
kmetre
```

```python
1 * kilometre / (1 * metre)
```

```
1000.0
```

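Here, `create_scaled_unit` creates a new unit by scaling the base unit `metre` with the SI prefix "k" (kilo, a factor of 10^3). As the output shows, the resulting unit is named "kmetre".

The prefix determines the power of ten relative to the base unit, which is why `1 * kilometre / (1 * metre)` evaluates to 1000 and why quantities expressed in different scales of the same unit can be converted into each other.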

## Basic Display methods

Besides `str` and `print`, `brainunit` provides `in_unit` and `in_best_unit`, which return a string representation of a `Quantity` in a specified unit or in the "best" unit (a scaled unit in which the numerical value is neither too large nor too small), respectively.

```python
from brainunit import in_unit, in_best_unit
a = 3 * bu.volt

print(a)  # print formats the quantity the same way as `in_best_unit(a)`

in_unit(a, bu.mV), in_best_unit(a)
```

```
3. V
('3000. mV', '3. V')
```

## Displaying in JIT transformations

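`brainunit` also supports the display methods above inside JIT-compiled functions. Under `jax.jit`, the values are abstract tracers, so only the unit part of the printed output is informative: `print(b)` shows the unit in SI base dimensions, whereas `in_unit` and `in_best_unit` show it as `S/(m^2)`.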

```python
import numpy as np
import jax

@jax.jit
def f1(a):
    b = a * bu.siemens / bu.cm ** 2
    # Executed while JAX traces f1, so `b` is shown as a tracer plus its unit
    print(b)
    return b

val = np.random.rand(3)
r = f1(val)
bu.math.allclose(val * bu.siemens / bu.cm ** 2, r)
```

```
Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)> m^-4 kg^-1 s^3 A^2
Array(True, dtype=bool)
```

```python
@jax.jit
def f2(a):
    b = a * bu.siemens / bu.cm ** 2
    print(in_unit(b, bu.siemens / bu.meter ** 2))
    return b

val = np.random.rand(3)
r = f2(val)
bu.math.allclose(val * bu.siemens / bu.cm ** 2, r)
```

```
Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)> S/(m^2)
Array(True, dtype=bool)
```

```python
@jax.jit
def f3(a):
    b = a * bu.siemens / bu.cm ** 2
    print(in_best_unit(b))
    return b

val = np.random.rand(3)
r = f3(val)
bu.math.allclose(val * bu.siemens / bu.cm ** 2, r)
```

```
Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)> S/(m^2)
Array(True, dtype=bool)
```