Skip to content

Commit

Permalink
Fix tensor frexp (#3259)
Browse files Browse the repository at this point in the history
* Fix tensor frexp

* Test for Ray DAG

* Fix

Co-authored-by: 刘宝 <po.lb@antgroup.com>
  • Loading branch information
fyrestone and 刘宝 committed Sep 19, 2022
1 parent c69eea9 commit 8e0d0e6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
15 changes: 4 additions & 11 deletions mars/tensor/arithmetic/frexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,10 @@ def execute(cls, ctx, op):
where = None
kw["order"] = op.order

try:
args = [input]
if out1 is not None:
args.append(out1)
if out2 is not None:
args.append(out2)
mantissa, exponent = xp.frexp(*args, **kw)
except TypeError:
if where is None:
raise
mantissa, exponent = xp.frexp(input)
# The out1 out2 are immutable because they are got from
# the shared memory.
mantissa, exponent = xp.frexp(input)
if where is not None:
mantissa, exponent = (
xp.where(where, mantissa, out1),
xp.where(where, exponent, out2),
Expand Down
14 changes: 13 additions & 1 deletion mars/tensor/arithmetic/tests/test_arithmetic_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ....session import execute, fetch
from ....tests.core import require_cupy
from ....utils import ignore_warning
from ...datasource import ones, tensor, zeros
from ...datasource import ones, tensor, zeros, arange
from .. import (
add,
cos,
Expand Down Expand Up @@ -349,6 +349,7 @@ def test_arctan2_execution(setup):
np.testing.assert_equal(result, np.arctan2(0, raw2.A))


@pytest.mark.ray_dag
def test_frexp_execution(setup):
data1 = np.random.RandomState(0).randint(0, 100, (5, 9, 6))

Expand Down Expand Up @@ -381,6 +382,17 @@ def test_frexp_execution(setup):
expected = sum(np.frexp(data1.toarray()))
np.testing.assert_equal(res.toarray(), expected)

x = np.arange(9)
a = np.zeros(9)
b = np.zeros(9)
mx = arange(9)
ma = zeros(9)
mb = zeros(9)
res = frexp(mx, ma, mb, where=mx > 5).execute()
expected = np.frexp(x, a, b, where=x > 5)
np.testing.assert_equal(res[0], expected[0])
np.testing.assert_equal(res[1], expected[1])


def test_frexp_order_execution(setup):
data1 = np.random.RandomState(0).random((5, 9))
Expand Down

0 comments on commit 8e0d0e6

Please sign in to comment.