Merge pull request #735 from helmholtz-analytics/bug/729-dtype-relational-functions

relational functions cast to boolean
coquelin77 authored Mar 11, 2021
2 parents 1febcfc + f504eed commit 34c356b
Showing 4 changed files with 116 additions and 23 deletions.
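
In short: comparison results that previously came back in a promoted numeric dtype (e.g. uint8, as the old docstring examples below show) are now always cast to ht.bool. A minimal sketch of the user-visible behavior after this commit (hypothetical session, using ht.array and ht.eq as exercised in the updated tests):

    import heat as ht

    a = ht.array([[1, 2], [3, 4]])
    b = ht.array([[1, 5], [3, 4]])

    res = ht.eq(a, b)  # elementwise comparison, distributed like the inputs
    res.dtype          # now always ht.bool, regardless of the input dtypes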
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@

## Bug fixes
- [#709](https://github.com/helmholtz-analytics/heat/pull/709) Set the encoding for README.md in setup.py explicitly.
+- [#735](https://github.com/helmholtz-analytics/heat/pull/735) Set return type to bool in relational functions.

# v0.5.2

92 changes: 86 additions & 6 deletions heat/core/relational.py
@@ -2,6 +2,8 @@

from .communication import MPI
from . import _operations
+from . import dndarray
+from . import types

__all__ = ["eq", "equal", "ge", "gt", "le", "lt", "ne"]

@@ -37,7 +39,20 @@ def eq(t1, t2):
    tensor([[0, 1],
            [0, 0]])
    """
-    return _operations.__binary_op(torch.eq, t1, t2)
+    res = _operations.__binary_op(torch.eq, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res


def equal(t1, t2):
@@ -111,7 +126,20 @@ def ge(t1, t2):
    tensor([[0, 1],
            [1, 1]], dtype=torch.uint8)
    """
-    return _operations.__binary_op(torch.ge, t1, t2)
+    res = _operations.__binary_op(torch.ge, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res


def gt(t1, t2):
@@ -147,7 +175,20 @@ def gt(t1, t2):
    tensor([[0, 0],
            [1, 1]], dtype=torch.uint8)
    """
-    return _operations.__binary_op(torch.gt, t1, t2)
+    res = _operations.__binary_op(torch.gt, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res


def le(t1, t2):
@@ -182,7 +223,20 @@ def le(t1, t2):
    tensor([[1, 1],
            [0, 0]], dtype=torch.uint8)
    """
-    return _operations.__binary_op(torch.le, t1, t2)
+    res = _operations.__binary_op(torch.le, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res


def lt(t1, t2):
@@ -217,7 +271,20 @@ def lt(t1, t2):
    tensor([[1, 0],
            [0, 0]], dtype=torch.uint8)
    """
-    return _operations.__binary_op(torch.lt, t1, t2)
+    res = _operations.__binary_op(torch.lt, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res


def ne(t1, t2):
@@ -251,4 +318,17 @@ def ne(t1, t2):
    tensor([[1, 0],
            [1, 1]])
    """
-    return _operations.__binary_op(torch.ne, t1, t2)
+    res = _operations.__binary_op(torch.ne, t1, t2)
+
+    if res.dtype != types.bool:
+        res = dndarray.DNDarray(
+            res.larray.type(torch.bool),
+            res.gshape,
+            types.bool,
+            res.split,
+            res.device,
+            res.comm,
+            res.balanced,
+        )
+
+    return res
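
The cast-to-bool block above is repeated verbatim in all six relational functions. A possible follow-up refactor, sketched here only for illustration (the helper name __cast_to_bool is hypothetical and not part of this commit; it reuses the DNDarray constructor signature from the diff above):

    def __cast_to_bool(res):
        # Re-wrap the result with a torch.bool local tensor and the ht.bool
        # dtype, keeping shape, split, device, communicator and balance
        # information intact.
        if res.dtype != types.bool:
            res = dndarray.DNDarray(
                res.larray.type(torch.bool),
                res.gshape,
                types.bool,
                res.split,
                res.device,
                res.comm,
                res.balanced,
            )
        return res

Each function body would then reduce to, e.g., return __cast_to_bool(_operations.__binary_op(torch.eq, t1, t2)).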
44 changes: 28 additions & 16 deletions heat/core/tests/test_relational.py
@@ -20,16 +20,18 @@ def setUpClass(cls):
        cls.errorneous_type = (2, 2)

    def test_eq(self):
-        result = ht.uint8([[0, 1], [0, 0]])
+        result = ht.array([[False, True], [False, False]])

-        self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_scalar), ht.uint8([1])))
+        self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_scalar), ht.array([True])))
        self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.a_scalar), result))
        self.assertTrue(ht.equal(ht.eq(self.a_scalar, self.a_tensor), result))
        self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.another_tensor), result))
        self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.a_vector), result))
        self.assertTrue(ht.equal(ht.eq(self.a_tensor, self.an_int_scalar), result))
        self.assertTrue(ht.equal(ht.eq(self.a_split_tensor, self.a_tensor), result))

+        self.assertEqual(ht.eq(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.eq(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
@@ -44,17 +46,19 @@ def test_equal(self):
        self.assertFalse(ht.equal(self.another_tensor, self.a_scalar))

    def test_ge(self):
-        result = ht.uint8([[0, 1], [1, 1]])
-        commutated_result = ht.uint8([[1, 1], [0, 0]])
+        result = ht.uint8([[False, True], [True, True]])
+        commutated_result = ht.array([[True, True], [False, False]])

-        self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_scalar), ht.uint8([1])))
+        self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_scalar), ht.array([True])))
        self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.a_scalar), result))
        self.assertTrue(ht.equal(ht.ge(self.a_scalar, self.a_tensor), commutated_result))
        self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.another_tensor), result))
        self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.a_vector), result))
        self.assertTrue(ht.equal(ht.ge(self.a_tensor, self.an_int_scalar), result))
        self.assertTrue(ht.equal(ht.ge(self.a_split_tensor, self.a_tensor), commutated_result))

+        self.assertEqual(ht.ge(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.ge(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
@@ -63,17 +67,19 @@ def test_ge(self):
            ht.ge("self.a_tensor", "s")

    def test_gt(self):
-        result = ht.uint8([[0, 0], [1, 1]])
-        commutated_result = ht.uint8([[1, 0], [0, 0]])
+        result = ht.array([[False, False], [True, True]])
+        commutated_result = ht.array([[True, False], [False, False]])

-        self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_scalar), ht.uint8([0])))
+        self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_scalar), ht.array([False])))
        self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.a_scalar), result))
        self.assertTrue(ht.equal(ht.gt(self.a_scalar, self.a_tensor), commutated_result))
        self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.another_tensor), result))
        self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.a_vector), result))
        self.assertTrue(ht.equal(ht.gt(self.a_tensor, self.an_int_scalar), result))
        self.assertTrue(ht.equal(ht.gt(self.a_split_tensor, self.a_tensor), commutated_result))

+        self.assertEqual(ht.gt(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.gt(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
@@ -82,17 +88,19 @@ def test_gt(self):
            ht.gt("self.a_tensor", "s")

    def test_le(self):
-        result = ht.uint8([[1, 1], [0, 0]])
-        commutated_result = ht.uint8([[0, 1], [1, 1]])
+        result = ht.array([[True, True], [False, False]])
+        commutated_result = ht.array([[False, True], [True, True]])

-        self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_scalar), ht.uint8([1])))
+        self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_scalar), ht.array([True])))
        self.assertTrue(ht.equal(ht.le(self.a_tensor, self.a_scalar), result))
        self.assertTrue(ht.equal(ht.le(self.a_scalar, self.a_tensor), commutated_result))
        self.assertTrue(ht.equal(ht.le(self.a_tensor, self.another_tensor), result))
        self.assertTrue(ht.equal(ht.le(self.a_tensor, self.a_vector), result))
        self.assertTrue(ht.equal(ht.le(self.a_tensor, self.an_int_scalar), result))
        self.assertTrue(ht.equal(ht.le(self.a_split_tensor, self.a_tensor), commutated_result))

+        self.assertEqual(ht.le(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.le(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
@@ -101,17 +109,19 @@ def test_le(self):
            ht.le("self.a_tensor", "s")

    def test_lt(self):
-        result = ht.uint8([[1, 0], [0, 0]])
-        commutated_result = ht.uint8([[0, 0], [1, 1]])
+        result = ht.array([[True, False], [False, False]])
+        commutated_result = ht.array([[False, False], [True, True]])

-        self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_scalar), ht.uint8([0])))
+        self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_scalar), ht.array([False])))
        self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.a_scalar), result))
        self.assertTrue(ht.equal(ht.lt(self.a_scalar, self.a_tensor), commutated_result))
        self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.another_tensor), result))
        self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.a_vector), result))
        self.assertTrue(ht.equal(ht.lt(self.a_tensor, self.an_int_scalar), result))
        self.assertTrue(ht.equal(ht.lt(self.a_split_tensor, self.a_tensor), commutated_result))

+        self.assertEqual(ht.lt(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.lt(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
@@ -120,9 +130,9 @@ def test_lt(self):
            ht.lt("self.a_tensor", "s")

    def test_ne(self):
-        result = ht.uint8([[1, 0], [1, 1]])
+        result = ht.array([[True, False], [True, True]])

-        # self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_scalar), ht.uint8([0])))
+        # self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_scalar), ht.array([False])))
        # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.a_scalar), result))
        # self.assertTrue(ht.equal(ht.ne(self.a_scalar, self.a_tensor), result))
        # self.assertTrue(ht.equal(ht.ne(self.a_tensor, self.another_tensor), result))
@@ -131,6 +141,8 @@ def test_ne(self):
        self.assertTrue(ht.equal(ht.ne(self.a_split_tensor, self.a_tensor), result))
        self.assertTrue(ht.equal(self.a_split_tensor != self.a_tensor, result))

+        self.assertEqual(ht.ne(self.a_split_tensor, self.a_tensor).dtype, ht.bool)
+
        with self.assertRaises(ValueError):
            ht.ne(self.a_tensor, self.another_vector)
        with self.assertRaises(TypeError):
2 changes: 1 addition & 1 deletion setup.py
@@ -34,7 +34,7 @@
    install_requires=[
        "mpi4py>=3.0.0",
        "numpy>=1.13.0",
-        "torch==1.7.0",
+        "torch>=1.7.0",
        "scipy>=0.14.0",
        "pillow>=6.0.0",
        "torchvision>=0.5.0",
