Skip to content

Commit

Permalink
Implement set_default_magnitude
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jul 12, 2024
1 parent d915314 commit c25506c
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 10 deletions.
76 changes: 76 additions & 0 deletions brainunit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@

__version__ = "0.0.1.1"

import importlib

from . import math
from . import _base
from . import _unit_common
from . import _unit_constants
from . import _unit_shortcuts

from ._base import *
from ._base import _default_magnitude, _siprefixes
from ._base import __all__ as _base_all
from ._unit_common import *
from ._unit_common import __all__ as _common_all
Expand All @@ -27,3 +35,71 @@

__all__ = ['math'] + _common_all + _std_units_all + _constants_all + _base_all
del _common_all, _std_units_all, _constants_all, _base_all


def set_default_magnitude(
magnitude: int | dict[str, int],
# unit: Unit = None,
):
"""
Set the default magnitude for units.
Parameters
----------
magnitude : int | dict[str, int]
The magnitude to set. If an int is given, it will be set for all
dimensions. If a dict is given, it will be set for the specified
dimensions.
Examples
--------
>>> set_default_magnitude('n') # Sets the default magnitude to 'nano' (10e-9)
>>> set_default_magnitude(-9) # Alternatively, use an integer to represent the exponent of 10
>>> set_default_magnitude({'m': -3, 'kg': -9}) # Set the default magnitude for 'metre' to 'milli' and 'kilogram' to 'nano'
>>> set_default_magnitude({'m': 'm', 'kg': 'n'}) # Alternatively, use a string to represent the magnitude
>>> set_default_magnitude(-3, unit=volt) # Set the default magnitude for the 'volt' unit to 'milli'
"""
global _default_magnitude
if isinstance(magnitude, int):
# if isinstance(unit, Unit):
# for key, dim in zip(_default_magnitude.keys(), unit.dim._dims):
# _default_magnitude[key] = magnitude / abs(dim) if dim != 0 else 0
# else:
_default_magnitude.update((key, magnitude) for key in _default_magnitude)
elif isinstance(magnitude, str):
# if isinstance(unit, Unit):
# for key, dim in zip(_default_magnitude.keys(), unit.dim._dims):
# _default_magnitude[key] = _siprefixes[magnitude] / abs(dim) if dim != 0 else 0
# else:
_default_magnitude.update((key, _siprefixes[magnitude]) for key in _default_magnitude)
elif isinstance(magnitude, dict):
_default_magnitude.update((key, 0) for key in _default_magnitude)
for key, value in magnitude.items():
if isinstance(value, int):
_default_magnitude[key] = value
elif isinstance(value, str):
_default_magnitude[key] = _siprefixes[value]
else:
raise ValueError(f"Invalid magnitude value: {value}")
else:
raise ValueError(f"Invalid magnitude: {magnitude}")

global _unit_common
global _unit_constants
global _unit_shortcuts
# Reload modules
importlib.reload(_unit_common)
importlib.reload(_unit_constants)
importlib.reload(_unit_shortcuts)

from ._base import __all__ as _base_all
from ._unit_common import __all__ as _common_all
from ._unit_constants import __all__ as _constants_all
from ._unit_shortcuts import __all__ as _std_units_all
globals().update({k: getattr(_unit_common, k) for k in _common_all})
globals().update({k: getattr(_unit_constants, k) for k in _constants_all})
globals().update({k: getattr(_unit_shortcuts, k) for k in _std_units_all})

global __all__
__all__ = ['math'] + _common_all + _std_units_all + _constants_all + _base_all
del _common_all, _std_units_all, _constants_all, _base_all
115 changes: 106 additions & 9 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,25 @@
'check_units',
'is_scalar_type',
'fail_for_dimension_mismatch',
'assert_quantity'
'assert_quantity',
]

_all_slice = slice(None, None, None)
_unit_checking = True
_allow_python_scalar_value = False
_auto_register_unit = True

# for setting the default magnitude of the unit
_default_magnitude = {
"m": 0,
"kg": 0,
"s": 0,
"A": 0,
"K": 0,
"mol": 0,
"cd": 0,
}


@contextmanager
def turn_off_auto_unit_register():
Expand Down Expand Up @@ -584,6 +595,31 @@ def get_dim(obj) -> Dimension:
raise TypeError(f"Object of type {type(obj)} does not have dimensions")


def get_unit(obj) -> Unit:
"""
Return the unit of any object that has them.
Parameters
----------
obj : `object`
The object to check.
Returns
-------
unit : Unit
The physical unit of the `obj`.
"""
try:
return obj.unit
except AttributeError:
if isinstance(obj, (numbers.Number, jax.Array, np.number, np.ndarray)):
return Unit(1, name='1', dispname='1')
try:
return Quantity(obj).unit
except TypeError:
raise TypeError(f"Object of type {type(obj)} does not have a unit")


def have_same_unit(obj1, obj2) -> bool:
"""Test if two values have the same dimensions.
Expand Down Expand Up @@ -969,12 +1005,15 @@ def check_units_and_collect_values(lst):

def _get_dim(dim: Dimension, unit: 'Unit'):
if dim != DIMENSIONLESS and unit is not None:
raise ValueError("Cannot specify both a dimension and a unit")
return None, dim
if dim == DIMENSIONLESS:
if unit is None:
return None, DIMENSIONLESS
else:
return unit.value, unit.dim
try:
return unit.value, unit.dim
except:
return None, unit.dim
else:
return None, dim

Expand All @@ -987,7 +1026,7 @@ class Quantity(object):
"""
__module__ = "brainunit"
__slots__ = ('_value', '_dim')
__slots__ = ('_value', '_dim', '_unit')
_value: Union[jax.Array, numbers.Number]
_dim: Dimension
__array_priority__ = 1000
Expand All @@ -1005,6 +1044,8 @@ def __init__(
if isinstance(value, numbers.Number):
self._dim = dim
self._value = (value if scale is None else (value * scale))
if dim is not DIMENSIONLESS:
self._unit = unit
return

if isinstance(value, (list, tuple)):
Expand All @@ -1023,6 +1064,7 @@ def __init__(
# array value
if isinstance(value, Quantity):
self._dim = value.dim
self._unit = Unit(1, name='1', dispname='1') if unit is None else unit
self._value = jnp.array(value.value, dtype=dtype)
return

Expand All @@ -1044,6 +1086,12 @@ def __init__(
# dimension
self._dim = dim

# unit
if unit is None:
self._unit = Unit(1, name='1', dispname='1')
else:
self._unit = unit

@property
def value(self) -> jax.Array | numbers.Number:
# return the value
Expand Down Expand Up @@ -1101,7 +1149,8 @@ def dim(self, *args):

@property
def unit(self) -> 'Unit':
return Unit(1., self.dim, register=False)
return self._unit
# return Unit(1., self.dim, register=False)

@unit.setter
def unit(self, *args):
Expand Down Expand Up @@ -1214,14 +1263,14 @@ def get_best_unit(self, *regs) -> 'Quantity':
The best unit for this `Array`.
"""
if self.is_unitless:
return Unit(1)
return Unit(1, name='1', dispname='1')
if len(regs):
for r in regs:
try:
return r[self]
except KeyError:
pass
return Quantity(1, dim=self.dim)
return self.unit
else:
return self.get_best_unit(standard_unit_register, user_unit_register, additional_unit_register)

Expand Down Expand Up @@ -1556,6 +1605,7 @@ def _binary_operation(
"""
other = _to_quantity(other)
other_dim = None
other_unit = None

if fail_for_mismatch:
if inplace:
Expand All @@ -1568,9 +1618,13 @@ def _binary_operation(
if other_dim is None:
other_dim = get_dim(other)

if other_unit is None:
other_unit = get_unit(other)

new_dim = unit_operation(self.dim, other_dim)
new_unit = unit_operation(self.unit, other_unit)
result = value_operation(self.value, other.value)
r = Quantity(result, dim=new_dim)
r = Quantity(result, dim=new_dim, unit=new_unit)
if inplace:
self.update_value(r.value)
return self
Expand Down Expand Up @@ -2688,11 +2742,18 @@ def __init__(
# Whether this unit is a combination of other units
self.iscompound = iscompound

super().__init__(value, dtype=dtype, dim=dim)
if dim == DIMENSIONLESS:
super().__init__(value, dtype=dtype, dim=dim)
else:
super().__init__(value, dtype=dtype, dim=dim, unit=self)

if _auto_register_unit and register:
register_new_unit(self)

@property
def unit(self) -> 'Unit':
return self

@staticmethod
def create(unit: Dimension, name: str, dispname: str, scale: int = 0):
"""
Expand All @@ -2718,6 +2779,8 @@ def create(unit: Dimension, name: str, dispname: str, scale: int = 0):
name = str(name)
dispname = str(dispname)

scale = calculate_scale(unit=unit, scale=scale)

u = Unit(
10.0 ** scale,
dim=unit,
Expand Down Expand Up @@ -3274,3 +3337,37 @@ def new_f(*args, **kwds):
return new_f

return do_check_units


def calculate_scale(
unit: Dimension,
scale: int = 0
) -> int:
"""
Calculate the scale for a unit.
Parameters
----------
unit : Dimension
The unit to determine the scale for.
Returns
-------
scale : int
The scale for the unit.
Examples
--------
>>> set_default_magnitude({'m': -3, 'kg': -9})
>>> calculate_scale(get_or_create_dimension(m=1))
-3
>>> calculate_scale(get_or_create_dimension(kg=1))
-9
>>> calculate_scale(get_or_create_dimension(m=1, kg=1))
-12
>>> calculate_scale(get_or_create_dimension(m=2, kg=-1))
3
"""
for dim, magnitude in zip(unit._dims, _default_magnitude.values()):
scale -= dim * magnitude
return scale
Loading

0 comments on commit c25506c

Please sign in to comment.