Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 11, 2024
1 parent cb640d3 commit 3ba124b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 74 deletions.
80 changes: 41 additions & 39 deletions brainpy/_src/math/units/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"have_same_dimensions",
"in_unit",
"in_best_unit",
"UnitArray",
"Quantity",
"Unit",
"register_new_unit",
"check_units",
Expand Down Expand Up @@ -300,7 +300,7 @@ def wrap_function_keep_dimensions(func):
"""

def f(x, *args, **kwds): # pylint: disable=C0111
return UnitArray(func(np.array(x, copy=False), *args, **kwds), dim=x.dim)
return Quantity(func(np.array(x, copy=False), *args, **kwds), dim=x.dim)

f._arg_units = [None]
f._return_unit = lambda u: u
Expand All @@ -325,7 +325,7 @@ def wrap_function_change_dimensions(func, change_dim_func):

def f(x, *args, **kwds): # pylint: disable=C0111
ar = np.array(x, copy=False)
return UnitArray(func(ar, *args, **kwds), dim=change_dim_func(ar, x.dim))
return Quantity(func(ar, *args, **kwds), dim=change_dim_func(ar, x.dim))

f._arg_units = [None]
f._return_unit = change_dim_func
Expand Down Expand Up @@ -781,7 +781,7 @@ def get_dimensions(obj):
] or isinstance(obj, (numbers.Number, np.number, np.ndarray)):
return DIMENSIONLESS
try:
return UnitArray(obj).dim
return Quantity(obj).dim
except TypeError:
raise TypeError(f"Object of type {type(obj)} does not have dimensions")

Expand Down Expand Up @@ -938,7 +938,7 @@ def quantity_with_dimensions(floatval, dims):
Returns
-------
q : `UnitArray`
q : `Quantity`
A quantity with the given dimensions.
Examples
Expand All @@ -951,10 +951,10 @@ def quantity_with_dimensions(floatval, dims):
--------
get_or_create_dimensions
"""
return UnitArray(floatval, get_or_create_dimension(dims._dims))
return Quantity(floatval, get_or_create_dimension(dims._dims))


class UnitArray(np.ndarray):
class Quantity(np.ndarray):
"""
A number with an associated physical dimension. In most cases, it is not
necessary to create a Quantity object by hand, instead use multiplication
Expand Down Expand Up @@ -1024,7 +1024,9 @@ class UnitArray(np.ndarray):
in_best_unit
"""

__slots__ = ["dim"]
__slots__ = ["dim", 'value', 'unit']
# value: jax.Array, np.ndarray, or number, custom type, pytree
# unit: Unit, 1, None

__array_priority__ = 1000

Expand Down Expand Up @@ -1061,7 +1063,7 @@ def __new__(cls, arr, dim=None, dtype=None, copy=False, force_quantity=False):
if not isinstance(arr, (np.ndarray, np.number, numbers.Number)):
# check whether it is an iterable containing Quantity objects
try:
is_quantity = [isinstance(x, UnitArray) for x in _flatten(arr)]
is_quantity = [isinstance(x, Quantity) for x in _flatten(arr)]
except TypeError:
# Not iterable
is_quantity = [False]
Expand Down Expand Up @@ -1211,12 +1213,12 @@ def __array_wrap__(self, array, context=None):
# a 1 * volt ** 2 quantitiy instead of volt ** 2. But this should
# rarely be an issue. The alternative leads to more confusing
# behaviour: np.float64(3) * mV would result in a dimensionless float64
result = array.view(UnitArray)
result = array.view(Quantity)
result.dim = dim
return result

def __deepcopy__(self, memo):
return UnitArray(self, copy=True)
return Quantity(self, copy=True)

# ==============================================================================
# Quantity-specific functions (not existing in ndarray)
Expand All @@ -1237,17 +1239,17 @@ def with_dimensions(value, *args, **keywords):
Returns
-------
q : `UnitArray`
q : `Quantity`
A `Quantity` object with the given dim
Examples
--------
All of these define an equivalent `Quantity` object:
>>> from brainpy.math.units import *
>>> UnitArray.with_dimensions(2, get_or_create_dimension(length=1))
>>> Quantity.with_dimensions(2, get_or_create_dimension(length=1))
2. * metre
>>> UnitArray.with_dimensions(2, length=1)
>>> Quantity.with_dimensions(2, length=1)
2. * metre
>>> 2 * metre
2. * metre
Expand All @@ -1256,7 +1258,7 @@ def with_dimensions(value, *args, **keywords):
dimensions = args[0]
else:
dimensions = get_or_create_dimension(*args, **keywords)
return UnitArray(value, dim=dimensions)
return Quantity(value, dim=dimensions)

### ATTRIBUTES ###
is_dimensionless = property(
Expand Down Expand Up @@ -1378,7 +1380,7 @@ def get_best_unit(self, *regs):
Returns
-------
u : `UnitArray` or `Unit`
u : `Quantity` or `Unit`
The best-fitting unit for the quantity `x`.
"""
if self.is_dimensionless:
Expand All @@ -1389,7 +1391,7 @@ def get_best_unit(self, *regs):
return r[self]
except KeyError:
pass
return UnitArray(1, self.dim)
return Quantity(1, self.dim)
else:
return self.get_best_unit(
standard_unit_register, user_unit_register, additional_unit_register
Expand Down Expand Up @@ -1447,11 +1449,11 @@ def __getitem__(self, key):
"""Overwritten to assure that single elements (i.e., indexed with a
single integer or a tuple of integers) retain their unit.
"""
return UnitArray(np.ndarray.__getitem__(self, key), self.dim)
return Quantity(np.ndarray.__getitem__(self, key), self.dim)

def item(self, *args):
"""Overwritten to assure that the returned element retains its unit."""
return UnitArray(np.ndarray.item(self, *args), self.dim)
return Quantity(np.ndarray.item(self, *args), self.dim)

def __setitem__(self, key, value):
fail_for_dimension_mismatch(self, value, "Inconsistent units in assignment")
Expand Down Expand Up @@ -1522,7 +1524,7 @@ def _binary_operation(

if inplace:
if self.shape == ():
self_value = UnitArray(self, copy=True)
self_value = Quantity(self, copy=True)
else:
self_value = self
operation(self_value, other)
Expand All @@ -1533,7 +1535,7 @@ def _binary_operation(
self_arr = np.array(self, copy=False)
other_arr = np.array(other, copy=False)
result = operation(self_arr, other_arr)
return UnitArray(result, newdims)
return Quantity(result, newdims)

def __mul__(self, other):
return self._binary_operation(other, operator.mul, operator.mul)
Expand Down Expand Up @@ -1601,12 +1603,12 @@ def __rsub__(self, other):
# We allow operations with 0 even for dimension mismatches, e.g.
# 0 - 3*mV is allowed. In this case, the 0 is not represented by a
# Quantity object so we cannot simply call Quantity.__sub__
if (not isinstance(other, UnitArray) or other.dim is DIMENSIONLESS) and np.all(
if (not isinstance(other, Quantity) or other.dim is DIMENSIONLESS) and np.all(
other == 0
):
return self.__neg__()
else:
return UnitArray(other, copy=False, force_quantity=True).__sub__(self)
return Quantity(other, copy=False, force_quantity=True).__sub__(self)

def __isub__(self, other):
return self._binary_operation(
Expand All @@ -1631,15 +1633,15 @@ def __pow__(self, other):
exponent=other,
)
other = np.array(other, copy=False)
return UnitArray(np.array(self, copy=False) ** other, self.dim ** other)
return Quantity(np.array(self, copy=False) ** other, self.dim ** other)
else:
return NotImplemented

def __rpow__(self, other):
if self.is_dimensionless:
if isinstance(other, np.ndarray) or isinstance(other, np.ndarray):
new_array = np.array(other, copy=False) ** np.array(self, copy=False)
return UnitArray(new_array, DIMENSIONLESS)
return Quantity(new_array, DIMENSIONLESS)
else:
return NotImplemented
else:
Expand Down Expand Up @@ -1671,21 +1673,21 @@ def __ipow__(self, other):
return NotImplemented

def __neg__(self):
return UnitArray(-np.array(self, copy=False), self.dim)
return Quantity(-np.array(self, copy=False), self.dim)

def __pos__(self):
return self

def __abs__(self):
return UnitArray(abs(np.array(self, copy=False)), self.dim)
return Quantity(abs(np.array(self, copy=False)), self.dim)

def tolist(self):
"""
Convert the array into a list.
Returns
-------
l : list of `UnitArray`
l : list of `Quantity`
A (possibly nested) list equivalent to the original array.
"""

Expand All @@ -1696,15 +1698,15 @@ def replace_with_quantity(seq, dim):
"""
# No recursion needed for single values
if not isinstance(seq, list):
return UnitArray(seq, dim)
return Quantity(seq, dim)

def top_replace(s):
"""
Recursivley descend into the list.
"""
for i in s:
if not isinstance(i, list):
yield UnitArray(i, dim)
yield Quantity(i, dim)
else:
yield type(i)(top_replace(i))

Expand Down Expand Up @@ -1834,7 +1836,7 @@ def put(self, indices, values, *args, **kwds): # pylint: disable=C0111
def clip(self, a_min, a_max, *args, **kwds): # pylint: disable=C0111
fail_for_dimension_mismatch(self, a_min, "clip")
fail_for_dimension_mismatch(self, a_max, "clip")
return UnitArray(
return Quantity(
np.clip(
np.array(self, copy=False),
np.array(a_min, copy=False),
Expand All @@ -1849,7 +1851,7 @@ def clip(self, a_min, a_max, *args, **kwds): # pylint: disable=C0111
clip._do_not_run_doctests = True

def dot(self, other, **kwds): # pylint: disable=C0111
return UnitArray(
return Quantity(
np.array(self).dot(np.array(other), **kwds),
self.dim * get_dimensions(other),
)
Expand Down Expand Up @@ -1879,7 +1881,7 @@ def prod(self, *args, **kwds): # pylint: disable=C0111
# identical
if dim_exponent.size > 1:
dim_exponent = dim_exponent[0]
return UnitArray(np.array(prod_result, copy=False), self.dim ** dim_exponent)
return Quantity(np.array(prod_result, copy=False), self.dim ** dim_exponent)

prod.__doc__ = np.ndarray.prod.__doc__
prod._do_not_run_doctests = True
Expand All @@ -1890,16 +1892,16 @@ def cumprod(self, *args, **kwds): # pylint: disable=C0111
"cumprod over array elements on quantities "
"with dimensions is not possible."
)
return UnitArray(np.array(self, copy=False).cumprod(*args, **kwds))
return Quantity(np.array(self, copy=False).cumprod(*args, **kwds))

cumprod.__doc__ = np.ndarray.cumprod.__doc__
cumprod._do_not_run_doctests = True


UnitArray.__module__ = "brainpy.math.units"
Quantity.__module__ = "brainpy.math.units"


class Unit(UnitArray):
class Unit(Quantity):
r"""
A physical unit.
Expand Down Expand Up @@ -2287,7 +2289,7 @@ def __eq__(self, other):
if isinstance(other, Unit):
return other.dim is self.dim and other.scale == self.scale
else:
return UnitArray.__eq__(self, other)
return Quantity.__eq__(self, other)

def __neq__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -2562,13 +2564,13 @@ def new_f(*args, **kwds):
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if (
not isinstance(v, (UnitArray, str, bool))
not isinstance(v, (Quantity, str, bool))
and v is not None
and n in au
):
try:
# allow e.g. to pass a Python list of values
v = UnitArray(v)
v = Quantity(v)
except TypeError:
if have_same_dimensions(au[n], 1):
raise TypeError(
Expand Down

0 comments on commit 3ba124b

Please sign in to comment.