Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docs] Update docs for brainunit.math #4

Merged
merged 15 commits into from
Jun 11, 2024
73 changes: 40 additions & 33 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds):
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
if not jnp.allclose(q.value, values):
raise AssertionError(f"Values do not match: {q.value} != {values}")
elif isinstance(q, jnp.ndarray):
Expand Down Expand Up @@ -144,10 +144,10 @@ def test_get_dimensions():
Test various ways of getting/comparing the dimensions of a Array.
"""
q = 500 * ms
assert get_unit(q) is get_or_create_dimension(q.unit._dims)
assert get_unit(q) is q.unit
assert get_unit(q) is get_or_create_dimension(q.dim._dims)
assert get_unit(q) is q.dim
assert q.has_same_unit(3 * second)
dims = q.unit
dims = q.dim
assert_equal(dims.get_dimension("time"), 1.0)
assert_equal(dims.get_dimension("length"), 0)

Expand Down Expand Up @@ -201,47 +201,54 @@ def test_unary_operations():


def test_operations():
q1 = Quantity(5, dim=mV)
q2 = Quantity(10, dim=mV)
assert_quantity(q1 + q2, 15, mV)
assert_quantity(q1 - q2, -5, mV)
assert_quantity(q1 * q2, 50, mV * mV)
q1 = 5 * second
q2 = 10 * second
assert_quantity(q1 + q2, 15, second)
assert_quantity(q1 - q2, -5, second)
assert_quantity(q1 * q2, 50, second * second)
assert_quantity(q2 / q1, 2, DIMENSIONLESS)
assert_quantity(q2 // q1, 2, DIMENSIONLESS)
assert_quantity(q2 % q1, 0, mV)
assert_quantity(q2 % q1, 0, second)
assert_quantity(divmod(q2, q1)[0], 2, DIMENSIONLESS)
assert_quantity(divmod(q2, q1)[1], 0, mV)
assert_quantity(q1 ** 2, 25, mV ** 2)
assert_quantity(q1 << 1, 10, mV)
assert_quantity(q1 >> 1, 2, mV)
assert_quantity(round(q1, 0), 5, mV)
assert_quantity(divmod(q2, q1)[1], 0, second)
assert_quantity(q1 ** 2, 25, second ** 2)
assert_quantity(round(q1, 0), 5, second)

# matmul
q1 = Quantity([1, 2], dim=mV)
q2 = Quantity([3, 4], dim=mV)
assert_quantity(q1 @ q2, 11, mV ** 2)
q1 = [1, 2] * second
q2 = [3, 4] * second
assert_quantity(q1 @ q2, 11, second ** 2)
q1 = Quantity([1, 2], unit=second)
q2 = Quantity([3, 4], unit=second)
assert_quantity(q1 @ q2, 11, second ** 2)

# shift
q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS)
assert_quantity(q1 << 1, 0b11000, second)
assert_quantity(q1 >> 1, 0b110, second)


def test_numpy_methods():
q = Quantity([[1, 2], [3, 4]], dim=mV)
q = [[1, 2], [3, 4]] * second
assert q.all()
assert q.any()
assert q.nonzero()[0].tolist() == [0, 0, 1, 1]
assert q.argmax() == 3
assert q.argmin() == 0
assert q.argsort(axis=None).tolist() == [0, 1, 2, 3]
assert_quantity(q.var(), 1.25, mV ** 2)
assert_quantity(q.round(), [[1, 2], [3, 4]], mV)
assert_quantity(q.std(), 1.11803398875, mV)
assert_quantity(q.sum(), 10, mV)
assert_quantity(q.trace(), 5, mV)
assert_quantity(q.cumsum(), [1, 3, 6, 10], mV)
assert_quantity(q.cumprod(), [1, 2, 6, 24], mV ** 4)
assert_quantity(q.diagonal(), [1, 4], mV)
assert_quantity(q.max(), 4, mV)
assert_quantity(q.mean(), 2.5, mV)
assert_quantity(q.min(), 1, mV)
assert_quantity(q.ptp(), 3, mV)
assert_quantity(q.ravel(), [1, 2, 3, 4], mV)
assert_quantity(q.var(), 1.25, second ** 2)
assert_quantity(q.round(), [[1, 2], [3, 4]], second)
assert_quantity(q.std(), 1.11803398875, second)
assert_quantity(q.sum(), 10, second)
assert_quantity(q.trace(), 5, second)
assert_quantity(q.cumsum(), [1, 3, 6, 10], second)
assert_quantity(q.cumprod(), [1, 2, 6, 24], second ** 4)
assert_quantity(q.diagonal(), [1, 4], second)
assert_quantity(q.max(), 4, second)
assert_quantity(q.mean(), 2.5, second)
assert_quantity(q.min(), 1, second)
assert_quantity(q.ptp(), 3, second)
assert_quantity(q.ravel(), [1, 2, 3, 4], second)


def test_shape_manipulation():
Expand Down Expand Up @@ -1596,7 +1603,7 @@ def test_constants():
import brainunit._unit_constants as constants

# Check that the expected names exist and have the correct dimensions
assert constants.avogadro_constant.dim == (1 / mole).unit
assert constants.avogadro_constant.dim == (1 / mole).dim
assert constants.boltzmann_constant.dim == (joule / kelvin).dim
assert constants.electric_constant.dim == (farad / meter).dim
assert constants.electron_mass.dim == kilogram.dim
Expand Down
65 changes: 60 additions & 5 deletions brainunit/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,67 @@
# limitations under the License.
# ==============================================================================

from ._compat_numpy import *
from ._compat_numpy import __all__ as _compat_numpy_all
# from ._compat_numpy import *
# from ._compat_numpy import __all__ as _compat_numpy_all
from ._others import *
from ._others import __all__ as _other_all
from ._compat_numpy_array_creation import *
from ._compat_numpy_array_creation import __all__ as _compat_array_creation_all
from ._compat_numpy_array_manipulation import *
from ._compat_numpy_array_manipulation import __all__ as _compat_array_manipulation_all
from ._compat_numpy_funcs_accept_unitless import *
from ._compat_numpy_funcs_accept_unitless import __all__ as _compat_funcs_accept_unitless_all
from ._compat_numpy_funcs_bit_operation import *
from ._compat_numpy_funcs_bit_operation import __all__ as _compat_funcs_bit_operation_all
from ._compat_numpy_funcs_change_unit import *
from ._compat_numpy_funcs_change_unit import __all__ as _compat_funcs_change_unit_all
from ._compat_numpy_funcs_indexing import *
from ._compat_numpy_funcs_indexing import __all__ as _compat_funcs_indexing_all
from ._compat_numpy_funcs_keep_unit import *
from ._compat_numpy_funcs_keep_unit import __all__ as _compat_funcs_keep_unit_all
from ._compat_numpy_funcs_logic import *
from ._compat_numpy_funcs_logic import __all__ as _compat_funcs_logic_all
from ._compat_numpy_funcs_match_unit import *
from ._compat_numpy_funcs_match_unit import __all__ as _compat_funcs_match_unit_all
from ._compat_numpy_funcs_remove_unit import *
from ._compat_numpy_funcs_remove_unit import __all__ as _compat_funcs_remove_unit_all
from ._compat_numpy_funcs_window import *
from ._compat_numpy_funcs_window import __all__ as _compat_funcs_window_all
from ._compat_numpy_get_attribute import *
from ._compat_numpy_get_attribute import __all__ as _compat_get_attribute_all
from ._compat_numpy_linear_algebra import *
from ._compat_numpy_linear_algebra import __all__ as _compat_linear_algebra_all
from ._compat_numpy_misc import *
from ._compat_numpy_misc import __all__ as _compat_misc_all

__all__ = _compat_numpy_all + _other_all

del _compat_numpy_all, _other_all
__all__ = _compat_array_creation_all + \
_compat_array_manipulation_all + \
_compat_funcs_change_unit_all + \
_compat_funcs_keep_unit_all + \
_compat_funcs_accept_unitless_all + \
_compat_funcs_match_unit_all + \
_compat_funcs_remove_unit_all + \
_compat_get_attribute_all + \
_compat_funcs_bit_operation_all + \
_compat_funcs_logic_all + \
_compat_funcs_indexing_all + \
_compat_funcs_window_all + \
_compat_linear_algebra_all + \
_compat_misc_all + _other_all + \
_other_all

del _compat_array_creation_all, \
_compat_array_manipulation_all, \
_compat_funcs_change_unit_all, \
_compat_funcs_keep_unit_all, \
_compat_funcs_accept_unitless_all, \
_compat_funcs_match_unit_all, \
_compat_funcs_remove_unit_all, \
_compat_get_attribute_all, \
_compat_funcs_bit_operation_all, \
_compat_funcs_logic_all, \
_compat_funcs_indexing_all, \
_compat_funcs_window_all, \
_compat_linear_algebra_all, \
_compat_misc_all, \
_other_all
Loading