Skip to content

Commit

Permalink
Merge branch 'master' into enhancement/721-warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
coquelin77 committed Mar 16, 2021
2 parents 237e484 + 13627f3 commit f298dda
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
- [#660](https://github.com/helmholtz-analytics/heat/pull/660) New feature: Data loader for H5 datasets which shuffles data in the background during training (`utils.data.partial_dataset.PartialH5Dataset`)
### Logical
- [#711](https://github.com/helmholtz-analytics/heat/pull/711) `isfinite()`, `isinf()`, `isnan()`
- [#743](https://github.com/helmholtz-analytics/heat/pull/743) `isneginf()`, `isposinf()`

### Types
- [#738](https://github.com/helmholtz-analytics/heat/pull/738) `iscomplex()`, `isreal()`


## Bug fixes
Expand Down
34 changes: 34 additions & 0 deletions heat/core/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"isfinite",
"isinf",
"isnan",
"isneginf",
"isposinf",
"logical_and",
"logical_not",
"logical_or",
Expand Down Expand Up @@ -275,6 +277,38 @@ def isnan(x):
return _operations.__local_op(torch.isnan, x, None, no_cast=True)


def isneginf(x, out=None):
"""
Test if each element of `x` is negative infinite, return result as a bool array.
Parameters
----------
x : DNDarray
Examples
--------
>>> ht.isnan(ht.array([1, ht.inf, -ht.inf, ht.nan]))
DNDarray([False, False, True, False], dtype=ht.bool, device=cpu:0, split=None)
"""
return _operations.__local_op(torch.isneginf, x, out, no_cast=True)


def isposinf(x, out=None):
"""
Test if each element of `x` is positive infinite, return result as a bool array.
Parameters
----------
x : DNDarray
Examples
--------
>>> ht.isnan(ht.array([1, ht.inf, -ht.inf, ht.nan]))
DNDarray([False, True, False, False], dtype=ht.bool, device=cpu:0, split=None)
"""
return _operations.__local_op(torch.isposinf, x, out, no_cast=True)


def logical_and(t1, t2):
"""
Compute the truth value of t1 AND t2 element-wise.
Expand Down
68 changes: 68 additions & 0 deletions heat/core/tests/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,74 @@ def test_isnan(self):
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

def test_isneginf(self):
a = ht.array([1, ht.inf, -ht.inf, ht.nan])
s = ht.array([False, False, True, False])
r = ht.isneginf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.array([1, ht.inf, -ht.inf, ht.nan], split=0)
out = ht.empty(4, dtype=ht.bool, split=0)
s = ht.array([False, False, True, False], split=0)
ht.isneginf(a, out)
self.assertEqual(out.shape, s.shape)
self.assertEqual(out.dtype, s.dtype)
self.assertEqual(out.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((6, 6), dtype=ht.bool, split=0)
s = ht.zeros((6, 6), dtype=ht.bool, split=0)
r = ht.isneginf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((5, 5), dtype=ht.int, split=1)
s = ht.zeros((5, 5), dtype=ht.bool, split=1)
r = ht.isneginf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

def test_isposinf(self):
a = ht.array([1, ht.inf, -ht.inf, ht.nan])
s = ht.array([False, True, False, False])
r = ht.isposinf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.array([1, ht.inf, -ht.inf, ht.nan], split=0)
out = ht.empty(4, dtype=ht.bool, split=0)
s = ht.array([False, True, False, False], split=0)
ht.isposinf(a, out)
self.assertEqual(out.shape, s.shape)
self.assertEqual(out.dtype, s.dtype)
self.assertEqual(out.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((6, 6), dtype=ht.bool, split=0)
s = ht.zeros((6, 6), dtype=ht.bool, split=0)
r = ht.isposinf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((5, 5), dtype=ht.int, split=1)
s = ht.zeros((5, 5), dtype=ht.bool, split=1)
r = ht.isposinf(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

def test_logical_and(self):
first_tensor = ht.array([[True, True], [False, False]])
second_tensor = ht.array([[True, False], [True, False]])
Expand Down
66 changes: 66 additions & 0 deletions heat/core/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,72 @@ def test_complex128(self):

self.assertEqual(ht.complex128.char(), "c16")

def test_iscomplex(self):
a = ht.array([1, 1.2, 1 + 1j, 1 + 0j])
s = ht.array([False, False, True, False])
r = ht.iscomplex(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.array([1, 1.2, True], split=0)
s = ht.array([False, False, False], split=0)
r = ht.iscomplex(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((6, 6), dtype=ht.bool, split=0)
s = ht.zeros((6, 6), dtype=ht.bool, split=0)
r = ht.iscomplex(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1)
s = ht.ones((5, 5), dtype=ht.bool, split=1)
r = ht.iscomplex(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

def test_isreal(self):
a = ht.array([1, 1.2, 1 + 1j, 1 + 0j])
s = ht.array([True, True, False, True])
r = ht.isreal(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.array([1, 1.2, True], split=0)
s = ht.array([True, True, True], split=0)
r = ht.isreal(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.ones((6, 6), dtype=ht.bool, split=0)
s = ht.ones((6, 6), dtype=ht.bool, split=0)
r = ht.isreal(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))

a = ht.full((5, 5), 1 + 1j, dtype=ht.int, split=1)
s = ht.zeros((5, 5), dtype=ht.bool, split=1)
r = ht.isreal(a)
self.assertEqual(r.shape, s.shape)
self.assertEqual(r.dtype, s.dtype)
self.assertEqual(r.device, s.device)
self.assertTrue(ht.equal(r, s))


class TestTypeConversion(TestCase):
def test_can_cast(self):
Expand Down
40 changes: 40 additions & 0 deletions heat/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from . import communication
from . import devices
from . import factories
from . import _operations
from . import sanitation


__all__ = [
Expand Down Expand Up @@ -55,6 +57,8 @@
"double",
"flexible",
"can_cast",
"iscomplex",
"isreal",
"issubdtype",
"promote_types",
"complex64",
Expand Down Expand Up @@ -607,6 +611,42 @@ def can_cast(from_, to, casting="intuitive"):
break


def iscomplex(x):
"""
Test element-wise if input is complex.
Parameters
----------
x : DNDarray
Examples
--------
>>> ht.iscomplex(ht.array([1+1j, 1]))
DNDarray([ True, False], dtype=ht.bool, device=cpu:0, split=None)
"""
sanitation.sanitize_in(x)

if issubclass(x.dtype, _complexfloating):
return x.imag != 0
else:
return factories.zeros(x.shape, bool, split=x.split, device=x.device, comm=x.comm)


def isreal(x):
"""
Test element-wise if input is real-valued.
Parameters
----------
x : DNDarray
Examples
--------
ht.iscomplex(ht.array([1+1j, 1]))
"""
return _operations.__local_op(torch.isreal, x, None, no_cast=True)


def issubdtype(arg1, arg2):
"""
Returns True if first argument is a typecode lower/equal in type hierarchy.
Expand Down

0 comments on commit f298dda

Please sign in to comment.