Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 27, 2024
1 parent f3559d2 commit 8e0d3b6
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 112 deletions.
88 changes: 56 additions & 32 deletions braincore/units/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,28 +1065,43 @@ class Array(object):
# )
# return subarr

def __init__(self, value, dtype: Any = None, unit=DIMENSIONLESS, copy=False):
def __init__(self, value, dtype=float, unit=DIMENSIONLESS, copy=False):
# This checks if the list is empty
try:
is_empty_list = not value
except:
is_empty_list = False

# array value
if isinstance(value, Array):
unit = value.unit
value = value._value
elif isinstance(value, (tuple, list, np.ndarray)):
# get the unit if it is an unit list
if len(value) > 0 and hasattr(value[0], 'unit'):
# make sure all elements have the same unit
units = [v.unit for v in value if hasattr(v, 'unit')]
self._unit = value.unit
self._value = jnp.array(value.value, dtype=dtype, copy=copy)
return
if is_empty_list:
self._unit = unit
self._value = value
return
elif isinstance(value, (tuple, list, np.ndarray, jnp.ndarray)):
# Existing logic to check for mixed types or process units
has_units = [hasattr(v, 'unit') for v in value]
if any(has_units) and not all(has_units):
raise TypeError("All elements must have the same unit or no unit at all")
if all(has_units):
units = [v.unit for v in value]
if not all(u == units[0] for u in units):
raise ValueError("All elements must have the same unit")
raise TypeError("All elements must have the same unit")
unit = units[0]
# get the value
value = [v.value if hasattr(v, 'value') else v for v in value]
# transform to jnp array
value = jnp.asarray(value, dtype=dtype, copy=copy)
value = [v.value for v in value]
# Transform to jnp array
value = jnp.array(value, dtype=dtype, copy=copy)
elif isinstance(value, (np.number, numbers.Number)):
value = jnp.array(value, dtype=dtype, copy=copy)
else:
raise TypeError(f"Invalid type for value: {type(value)}")
if dtype is not None:
value = jnp.asarray(value, dtype=dtype, copy=copy)
value = jnp.array(value, dtype=dtype, copy=copy)

self._value = value

# unit
self._unit = unit

Expand Down Expand Up @@ -1185,8 +1200,8 @@ def has_same_unit(self, other):
bool
Whether the two Arrays have the same unit dimensions
"""
other_unit = get_unit(other)
return (self.unit is other_unit) or (self.unit == other_unit)
other_unit = get_unit(other.unit)
return (get_unit(self.unit) is other_unit) or (get_unit(self.unit) == other_unit)

def in_unit(self, u, precision=None, python_code=False):
"""
Expand Down Expand Up @@ -1218,18 +1233,25 @@ def in_unit(self, u, precision=None, python_code=False):
>>> x.in_unit(mV, 3)
'25.123 mV'
"""
fail_for_dimension_mismatch(self.unit, u, 'Non-matching unit for method "in_unit"')

value = jnp.array(self.value / u, copy=False)
fail_for_dimension_mismatch(self, u, 'Non-matching unit for method "in_unit"')

value = np.asarray(self.value / u)
if value.shape == ():
s = jnp.array_str(jnp.array([value]), precision=precision)
s = np.array_str(jnp.array([value]), precision=precision)
s = s.replace("[", "").replace("]", "").strip()
else:
if python_code:
s = jnp.array_repr(value, precision=precision)
if value.size > 100:
if python_code:
s = np.array_repr(value, precision=precision)[:100]
s += "..."
else:
s = np.array_str(value, precision=precision)[:100]
s += "..."
else:
s = jnp.array_str(value, precision=precision)
if python_code:
s = np.array_repr(value, precision=precision)
else:
s = np.array_str(value, precision=precision)

if not u.is_dimensionless:
if isinstance(u, Unit):
Expand Down Expand Up @@ -1269,7 +1291,7 @@ def get_best_unit(self, *regs):
return r[self]
except KeyError:
pass
return Array(1, unit=self.dim)
return Array(1, unit=self.unit)
else:
return self.get_best_unit(
standard_unit_register, user_unit_register, additional_unit_register
Expand Down Expand Up @@ -3199,6 +3221,8 @@ def __init__(
# The full name of this unit
self._name = name
# The display name of this unit
if dispname is None:
dispname = name
self._dispname = dispname
# Whether this unit is a combination of other units
self.iscompound = iscompound
Expand Down Expand Up @@ -3231,7 +3255,7 @@ def create(unit, name, dispname, scale=0):
The new unit.
"""
name = str(name)
dispname = str(name)
dispname = str(dispname)

u = Unit(
10.0 ** scale,
Expand Down Expand Up @@ -3335,16 +3359,16 @@ def __mul__(self, other):
def __div__(self, other):
if isinstance(other, Unit):
if self.iscompound:
dispname = f"({self.dispname}"
name = f"({self.name}"
dispname = f"({self.dispname})"
name = f"({self.name})"
else:
dispname = self.dispname
name = self.name
dispname += "/"
name += " / "
if other.iscompound:
dispname += f"{other.dispname})"
name += f"{other.name})"
dispname += f"({other.dispname})"
name += f"({other.name})"
else:
dispname += other.dispname
name += other.name
Expand Down Expand Up @@ -3467,7 +3491,7 @@ def __getitem__(self, x):
unit where the deviations from that are the smallest. More precisely,
the unit that minimizes the sum of (log10(m)-1)**2 over all entries).
"""
matching = self.units_for_dimensions.get(x.dim, {})
matching = self.units_for_dimensions.get(x.unit, {})
if len(matching) == 0:
raise KeyError("Unit not found in registry.")

Expand Down Expand Up @@ -3550,7 +3574,7 @@ def get_unit(d):
]:
if 1.0 in unit_register.units_for_dimensions[d]:
return unit_register.units_for_dimensions[d][1.0]
return Unit(1.0, dim=d)
return Unit(1.0, unit=d)


def get_unit_for_display(d):
Expand Down

0 comments on commit 8e0d3b6

Please sign in to comment.