Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
2 parents 30a55a6 + c39110d commit 419c104
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 28 deletions.
5 changes: 4 additions & 1 deletion brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,10 @@ def get_dim(obj) -> Dimension:
The physical dimensions of the `obj`.
"""
try:
return obj.dim
try:
return obj.dim.dim
except:
return obj.dim
except AttributeError:
# The following is not very pretty, but it will avoid the costly
# isinstance check for the common types
Expand Down
5 changes: 4 additions & 1 deletion brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',
'percentile', 'nanpercentile', 'quantile', 'nanquantile',
'corrcoef', 'correlate', 'cov',

# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',
'corrcoef', 'correlate', 'cov',


# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
Expand Down
2 changes: 1 addition & 1 deletion brainunit/math/_fun_remove_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
from jax import Array

from .._base import Quantity, DIMENSIONLESS, fail_for_dimension_mismatch
from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS
from .._misc import set_module_as

__all__ = [
Expand Down
51 changes: 26 additions & 25 deletions brainunit/math/_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def test_round(self):
self.assertTrue(jnp.all(result == jnp.round(array)))

q = bu.Quantity([1.123, 2.567, 3.891], bu.second)
result_q = bu.math.round(q)
result_q = bu.math.round(q, unit_to_scale=bu.second)
expected_q = jnp.round(jnp.array([1.123, 2.567, 3.891])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

Expand All @@ -383,7 +383,7 @@ def test_rint(self):
self.assertTrue(jnp.all(result == jnp.rint(array)))

q = bu.Quantity([1.5, 2.3, 3.8], bu.second)
result_q = bu.math.rint(q)
result_q = bu.math.rint(q, unit_to_scale=bu.second)
expected_q = jnp.rint(jnp.array([1.5, 2.3, 3.8])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

Expand All @@ -393,7 +393,7 @@ def test_floor(self):
self.assertTrue(jnp.all(result == jnp.floor(array)))

q = bu.Quantity([1.5, 2.3, 3.8], bu.second)
result_q = bu.math.floor(q)
result_q = bu.math.floor(q, unit_to_scale=bu.second)
expected_q = jnp.floor(jnp.array([1.5, 2.3, 3.8]))
assert_quantity(result_q, expected_q, bu.second)

Expand Down Expand Up @@ -713,6 +713,16 @@ def test_gcd(self):
expected_q = jnp.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

def test_remainder(self):
result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3]))))

q1 = [5, 7] * bu.second
q2 = [2, 3] * bu.second
result_q = bu.math.remainder(q1, q2)
expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)


class TestMathFuncsKeepUnitUnary2(unittest.TestCase):

Expand Down Expand Up @@ -946,15 +956,6 @@ def test_divmod(self):
expected = jnp.divmod(jnp.array([5, 6]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1]))

def test_remainder(self):
result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3]))))

q1 = [5, 7] * (bu.second ** 2)
q2 = [2, 3] * bu.second
result_q = bu.math.remainder(q1, q2)
expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

def test_convolve(self):
result = bu.math.convolve(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
Expand Down Expand Up @@ -1625,7 +1626,7 @@ def test_append(self):
self.assertTrue(jnp.all(result == jnp.append(array, 3)))

q = [0, 1, 2] * bu.second
result_q = bu.math.append(q, 3)
result_q = bu.math.append(q, 3 * bu.second)
expected_q = jnp.append(jnp.array([0, 1, 2]), 3)
assert_quantity(result_q, expected_q, bu.second)

Expand Down Expand Up @@ -1901,7 +1902,7 @@ def test_searchsorted(self):
self.assertTrue(result == jnp.searchsorted(array, 2))

q = [0, 1, 2] * bu.second
result_q = bu.math.searchsorted(q, 2)
result_q = bu.math.searchsorted(q, 2 * bu.second)
expected_q = jnp.searchsorted(jnp.array([0, 1, 2]), 2)
assert result_q == expected_q

Expand Down Expand Up @@ -1932,15 +1933,15 @@ def test_bitwise_not(self):
result = bu.math.bitwise_not(jnp.array([0b1100]))
self.assertTrue(jnp.all(result == jnp.bitwise_not(jnp.array([0b1100]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.bitwise_not(q)

def test_invert(self):
result = bu.math.invert(jnp.array([0b1100]))
self.assertTrue(jnp.all(result == jnp.invert(jnp.array([0b1100]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.invert(q)

Expand All @@ -1951,7 +1952,7 @@ def test_bitwise_and(self):
result = bu.math.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_and(q1, q2)
Expand All @@ -1960,7 +1961,7 @@ def test_bitwise_or(self):
result = bu.math.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_or(q1, q2)
Expand All @@ -1969,7 +1970,7 @@ def test_bitwise_xor(self):
result = bu.math.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_xor(q1, q2)
Expand All @@ -1978,15 +1979,15 @@ def test_left_shift(self):
result = bu.math.left_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.left_shift(q, 2)

def test_right_shift(self):
result = bu.math.right_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.right_shift(q, 2)

Expand All @@ -1996,23 +1997,23 @@ def test_all(self):
result = bu.math.all(jnp.array([True, True, True]))
self.assertTrue(result == jnp.all(jnp.array([True, True, True])))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [True, True, True] * bu.second
result_q = bu.math.all(q)

def test_any(self):
result = bu.math.any(jnp.array([False, True, False]))
self.assertTrue(result == jnp.any(jnp.array([False, True, False])))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [False, True, False] * bu.second
result_q = bu.math.any(q)

def test_logical_not(self):
result = bu.math.logical_not(jnp.array([True, False]))
self.assertTrue(jnp.all(result == jnp.logical_not(jnp.array([True, False]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [True, False] * bu.second
result_q = bu.math.logical_not(q)

Expand Down Expand Up @@ -2153,7 +2154,7 @@ def test_where(self):
self.assertTrue(jnp.all(result == jnp.where(array > 2, array, 0)))

q = [1, 2, 3, 4, 5] * bu.second
result_q = bu.math.where(q > 2 * bu.second, q, 0)
result_q = bu.math.where(q > 2 * bu.second, q.to_value(bu.second), 0)
expected_q = jnp.where(jnp.array([1, 2, 3, 4, 5]) > 2, jnp.array([1, 2, 3, 4, 5]), 0)
assert_quantity(result_q, expected_q, bu.second)

Expand Down

0 comments on commit 419c104

Please sign in to comment.