Skip to content

Commit

Permalink
Support scipy special function with tuple output (#3139)
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomY-2 committed Jun 16, 2022
1 parent 3e3f4fd commit a8cb40e
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/source/reference/tensor/special.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Bessel functions
mars.tensor.special.hankel2e


Error function
Error functions and fresnel integrals
--------------

.. autosummary::
Expand All @@ -48,6 +48,7 @@ Error function
mars.tensor.special.erfi
mars.tensor.special.erfinv
mars.tensor.special.erfcinv
mars.tensor.special.fresnel


Ellipsoidal harmonics
Expand Down
2 changes: 2 additions & 0 deletions mars/tensor/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
TensorErfinv,
erfcinv,
TensorErfcinv,
fresnel,
TensorFresnel,
)
from .gamma_funcs import (
gamma,
Expand Down
59 changes: 59 additions & 0 deletions mars/tensor/special/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import scipy.special as spspecial

from ...core import ExecutableTuple
from ... import opcodes
from ..datasource import tensor as astensor
from ..arithmetic.core import TensorUnaryOp, TensorBinOp, TensorMultiOp
from ..array_utils import (
np,
Expand Down Expand Up @@ -112,3 +115,59 @@ def execute(cls, ctx, op):
if ret.dtype != op.dtype:
ret = ret.astype(op.dtype)
ctx[op.outputs[0].key] = ret


class TensorTupleOp(TensorSpecialUnaryOp):
@property
def output_limit(self):
return self._n_outputs

def __call__(self, x, out=None):
x = astensor(x)

if out is not None:
if not isinstance(out, ExecutableTuple):
raise TypeError(
f"out should be ExecutableTuple object, got {type(out)} instead"
)
if len(out) != self._n_outputs:
raise TypeError(
f"out should be an ExecutableTuple object with {self._n_outputs} elements, got {len(out)} instead"
)

func = getattr(spspecial, self._func_name)
res = func(np.ones(x.shape, dtype=x.dtype))
res_tensors = self.new_tensors(
[x],
kws=[
{
"side": f"{self._func_name}[{i}]",
"dtype": output.dtype,
"shape": output.shape,
}
for i, output in enumerate(res)
],
)

if out is None:
return ExecutableTuple(res_tensors)

for res_tensor, out_tensor in zip(res_tensors, out):
out_tensor.data = res_tensor.data
return out

@classmethod
def execute(cls, ctx, op):
inputs, device_id, xp = as_same_device(
[ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
)

with device(device_id):
with np.errstate(**op.err):
if op.is_gpu():
ret = cls._execute_gpu(op, xp, inputs[0])
else:
ret = cls._execute_cpu(op, xp, inputs[0])

for output, ret_element in zip(op.outputs, ret):
ctx[output.key] = ret_element
19 changes: 18 additions & 1 deletion mars/tensor/special/err_fresnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from ..arithmetic.utils import arithmetic_operand
from ..utils import infer_dtype, implement_scipy
from .core import TensorSpecialUnaryOp, _register_special_op
from .core import (
TensorSpecialUnaryOp,
TensorTupleOp,
_register_special_op,
)


@_register_special_op
Expand Down Expand Up @@ -55,6 +59,12 @@ class TensorErfcinv(TensorSpecialUnaryOp):
_func_name = "erfcinv"


@_register_special_op
class TensorFresnel(TensorTupleOp):
_func_name = "fresnel"
_n_outputs = 2


@implement_scipy(spspecial.erf)
@infer_dtype(spspecial.erf)
def erf(x, out=None, where=None, **kwargs):
Expand Down Expand Up @@ -140,3 +150,10 @@ def erfinv(x, out=None, where=None, **kwargs):
def erfcinv(x, out=None, where=None, **kwargs):
op = TensorErfcinv(**kwargs)
return op(x, out=out, where=where)


@implement_scipy(spspecial.fresnel)
@infer_dtype(spspecial.fresnel, multi_outputs=True)
def fresnel(x, out=None, **kwargs):
op = TensorFresnel(**kwargs)
return op(x, out=out)
47 changes: 46 additions & 1 deletion mars/tensor/special/tests/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
ellipkinc as scipy_ellipkinc,
ellipe as scipy_ellipe,
ellipeinc as scipy_ellipeinc,
fresnel as scipy_fresnel,
betainc as scipy_betainc,
)

from ....lib.version import parse as parse_version
from ....core import tile
from ....core import tile, ExecutableTuple
from ... import tensor
from ..err_fresnel import (
erf,
Expand All @@ -47,6 +48,8 @@
TensorErfinv,
erfcinv,
TensorErfcinv,
fresnel,
TensorFresnel,
)
from ..gamma_funcs import (
gammaln,
Expand Down Expand Up @@ -276,6 +279,48 @@ def test_erfcinv():
assert c.shape == c.inputs[0].shape


def test_fresnel():
raw = np.random.rand(10, 8, 5)
t = tensor(raw, chunk_size=3)

r = fresnel(t)
expect = scipy_fresnel(raw)

assert isinstance(r, ExecutableTuple)
assert len(r) == 2

for i in range(len(r)):
assert r[i].shape == expect[i].shape
assert r[i].dtype == expect[i].dtype
assert isinstance(r[i].op, TensorFresnel)

non_tuple_out = tensor(raw, chunk_size=3)
with pytest.raises(TypeError):
r = fresnel(t, non_tuple_out)

mismatch_size_tuple = ExecutableTuple([t])
with pytest.raises(TypeError):
r = fresnel(t, mismatch_size_tuple)

out = ExecutableTuple([t, t])
r_out = fresnel(t, out=out)

assert isinstance(out, ExecutableTuple)
assert isinstance(r_out, ExecutableTuple)

assert len(out) == 2
assert len(r_out) == 2

for r_output, expected_output, out_output in zip(r, expect, out):
assert r_output.shape == expected_output.shape
assert r_output.dtype == expected_output.dtype
assert isinstance(r_output.op, TensorFresnel)

assert out_output.shape == expected_output.shape
assert out_output.dtype == expected_output.dtype
assert isinstance(out_output.op, TensorFresnel)


def test_beta_inc():
raw1 = np.random.rand(4, 3, 2)
raw2 = np.random.rand(4, 3, 2)
Expand Down
22 changes: 22 additions & 0 deletions mars/tensor/special/tests/test_special_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,25 @@ def test_quintuple_execution(setup, func):

expected = sp_func(raw1.toarray(), raw2, raw3, raw4, raw5)
np.testing.assert_array_equal(result.toarray(), expected)


@pytest.mark.parametrize(
"func",
[
"fresnel",
],
)
def test_unary_tuple_execution(setup, func):
sp_func = getattr(spspecial, func)
mt_func = getattr(mt_special, func)

raw = np.random.rand(10, 8, 6)
a = tensor(raw, chunk_size=3)

r = mt_func(a)

result = r.execute().fetch()
expected = sp_func(raw)

for actual_output, expected_output in zip(result, expected):
np.testing.assert_array_equal(actual_output, expected_output)
7 changes: 5 additions & 2 deletions mars/tensor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def call(*tensors, **kw):
return inner


def infer_dtype(np_func, empty=True, reverse=False, check=True):
def infer_dtype(np_func, multi_outputs=False, empty=True, reverse=False, check=True):
def make_arg(arg):
if empty:
return np.empty((1,) * max(1, arg.ndim), dtype=arg.dtype)
Expand Down Expand Up @@ -267,7 +267,10 @@ def h(*tensors, **kw):
# that implements __tensor_ufunc__
try:
with np.errstate(all="ignore"):
dtype = np_func(*args, **np_kw).dtype
if multi_outputs:
dtype = np_func(*args, **np_kw)[0].dtype
else:
dtype = np_func(*args, **np_kw).dtype
except: # noqa: E722
dtype = None

Expand Down

0 comments on commit a8cb40e

Please sign in to comment.