Skip to content

Commit

Permalink
Implement airy functions (except the ai_zeros and bi_zeros functi…
Browse files Browse the repository at this point in the history
…ons) (#3195)
  • Loading branch information
shantam-8 committed Jul 28, 2022
1 parent 9d7fbd3 commit d4ef9fc
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 11 deletions.
12 changes: 12 additions & 0 deletions docs/source/reference/tensor/special.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
Special Functions
=================

Airy functions
--------------

.. autosummary::
:toctree: generated/
:nosignatures:

mars.tensor.special.airy
mars.tensor.special.airye
mars.tensor.special.itairy


Information Theory functions
----------------------------

Expand Down
4 changes: 4 additions & 0 deletions mars/lib/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def reciprocal(x, **kw):
elliprg = partial(call_sparse, "elliprg")
elliprj = partial(call_sparse, "elliprj")

airy = partial(_call_unary, "airy")
airye = partial(_call_unary, "airye")
itairy = partial(_call_unary, "itairy")


def equal(a, b, **_):
try:
Expand Down
4 changes: 4 additions & 0 deletions mars/lib/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,10 @@ def _scipy_binary(self, func_name, other):

hyp0f1 = partialmethod(_scipy_binary, "hyp0f1")

airy = partialmethod(_scipy_unary, "airy")
airye = partialmethod(_scipy_unary, "airye")
itairy = partialmethod(_scipy_unary, "itairy")

def __eq__(self, other):
try:
naked_other = naked(other)
Expand Down
8 changes: 8 additions & 0 deletions mars/tensor/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@
elliprj,
TensorElliprj,
)
from .airy import (
airy,
TensorAiry,
airye,
TensorAirye,
itairy,
TensorItairy,
)
except ImportError: # pragma: no cover
pass

Expand Down
57 changes: 57 additions & 0 deletions mars/tensor/special/airy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import scipy.special as spspecial

from ..utils import infer_dtype, implement_scipy
from .core import TensorTupleOp, _register_special_op


@_register_special_op
class TensorAiry(TensorTupleOp):
_func_name = "airy"
_n_outputs = 4


@implement_scipy(spspecial.airy)
@infer_dtype(spspecial.airy, multi_outputs=True)
def airy(z, out=None, **kwargs):
op = TensorAiry(**kwargs)
return op(z, out=out)


@_register_special_op
class TensorAirye(TensorTupleOp):
_func_name = "airye"
_n_outputs = 4


@implement_scipy(spspecial.airye)
@infer_dtype(spspecial.airye, multi_outputs=True)
def airye(z, out=None, **kwargs):
op = TensorAirye(**kwargs)
return op(z, out=out)


@_register_special_op
class TensorItairy(TensorTupleOp):
_func_name = "itairy"
_n_outputs = 4


@implement_scipy(spspecial.itairy)
@infer_dtype(spspecial.itairy, multi_outputs=True)
def itairy(x, out=None, **kwargs):
op = TensorItairy(**kwargs)
return op(x, out=out)
20 changes: 14 additions & 6 deletions mars/tensor/special/tests/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
TensorEllipe,
TensorEllipeinc,
)
from ..airy import (
TensorAiry,
TensorAirye,
TensorItairy,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -135,14 +140,17 @@ def test_unary_operand_out(func, tensor_cls):


@pytest.mark.parametrize(
"func,tensor_cls",
"func,tensor_cls,n_outputs",
[
("fresnel", TensorFresnel),
("modfresnelp", TensorModFresnelP),
("modfresnelm", TensorModFresnelM),
("fresnel", TensorFresnel, 2),
("modfresnelp", TensorModFresnelP, 2),
("modfresnelm", TensorModFresnelM, 2),
("airy", TensorAiry, 4),
("airye", TensorAirye, 4),
("itairy", TensorItairy, 4),
],
)
def test_unary_tuple_operand(func, tensor_cls):
def test_unary_tuple_operand(func, tensor_cls, n_outputs):
sp_func = getattr(spsecial, func)
mt_func = getattr(mt_special, func)

Expand All @@ -167,7 +175,7 @@ def test_unary_tuple_operand(func, tensor_cls):
with pytest.raises(TypeError):
r = mt_func(t, mismatch_size_tuple)

out = ExecutableTuple([t, t])
out = ExecutableTuple([t] * n_outputs)
r_out = mt_func(t, out=out)

assert isinstance(out, ExecutableTuple)
Expand Down
6 changes: 1 addition & 5 deletions mars/tensor/special/tests/test_special_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,7 @@ def test_quintuple_execution(setup, func):

@pytest.mark.parametrize(
"func",
[
"fresnel",
"modfresnelp",
"modfresnelm",
],
["fresnel", "modfresnelp", "modfresnelm", "airy", "airye", "itairy"],
)
def test_unary_tuple_execution(setup, func):
sp_func = getattr(spspecial, func)
Expand Down

0 comments on commit d4ef9fc

Please sign in to comment.