From 636c797815b6f480f739533aa00aa1af34d5260e Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Thu, 16 Sep 2021 14:28:48 +0200
Subject: [PATCH 1/7] add split method for communicator

---
 heat/core/communication.py            | 13 +++++++++++++
 heat/core/tests/test_communication.py | 13 +++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/heat/core/communication.py b/heat/core/communication.py
index 6d65b0c810..4443ab6bbc 100644
--- a/heat/core/communication.py
+++ b/heat/core/communication.py
@@ -436,6 +436,19 @@ def alltoall_recvbuffer(

         return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes

+    def Split(self, color: int = 0, key: int = 0):
+        """
+        Split communicator by color and key.
+
+        Parameters
+        ----------
+        color : int, optional
+            Processes with the same color are assigned to the same new communicator.
+        key : int, optional
+            Determines the rank ordering within the new communicator.
+        """
+        return MPICommunication(self.handle.Split(color, key))
+
     def Irecv(
         self,
         buf: Union[DNDarray, torch.Tensor, Any],
diff --git a/heat/core/tests/test_communication.py b/heat/core/tests/test_communication.py
index 90f7808e34..af3a06f889 100644
--- a/heat/core/tests/test_communication.py
+++ b/heat/core/tests/test_communication.py
@@ -196,6 +196,19 @@ def test_default_comm(self):
         with self.assertRaises(TypeError):
             ht.use_comm("1")

+    def test_split(self):
+        a = ht.zeros((4, 5), split=0)
+
+        color = a.comm.rank % 2
+        newcomm = a.comm.Split(color, key=a.comm.rank)
+
+        self.assertIsInstance(newcomm, ht.MPICommunication)
+        if ht.MPI_WORLD.size == 1:
+            self.assertTrue(newcomm.size == a.comm.size)
+        else:
+            self.assertTrue(newcomm.size < a.comm.size)
+        self.assertIsNot(newcomm, a.comm)
+
     def test_allgather(self):
         # contiguous data
         data = ht.ones((1, 7))

From feab51a71c5607ba498f9113033c903bfaccb87e Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Thu, 16 Sep 2021 14:30:25 +0200
Subject: [PATCH 2/7] fine tuning single element binary_op

---
 heat/core/_operations.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/heat/core/_operations.py b/heat/core/_operations.py
index 8a8e8ea835..e4ce2b8768 100644
--- a/heat/core/_operations.py
+++ b/heat/core/_operations.py
@@ -9,7 +9,6 @@

 from .communication import MPI, MPI_WORLD
 from . import factories
-from . import devices
 from . import stride_tricks
 from . import sanitation
 from . import statistics
@@ -117,22 +116,26 @@ def __binary_op(
                     # warnings.warn(
                     #     "Broadcasting requires transferring data of first operator between MPI ranks!"
                     # )
-                    if t1.comm.rank > 0:
+                    color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
+                    newcomm = t1.comm.Split(color, t1.comm.rank)
+                    if t1.comm.rank > 0 and color == 0:
                         t1.larray = torch.zeros(
                             t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
                         )
-                    t1.comm.Bcast(t1)
+                    newcomm.Bcast(t1)

             if t2.split is not None:
                 if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
                     # warnings.warn(
                     #     "Broadcasting requires transferring data of second operator between MPI ranks!"
                    # )
-                    if t2.comm.rank > 0:
+                    color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
+                    newcomm = t2.comm.Split(color, t2.comm.rank)
+                    if t2.comm.rank > 0 and color == 0:
                         t2.larray = torch.zeros(
                             t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
                         )
-                    t2.comm.Bcast(t2)
+                    newcomm.Bcast(t2)

         else:
             raise TypeError(

From 6a4dc6a3f818090d64bf182acbc7e297718b15e7 Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Thu, 16 Sep 2021 15:31:41 +0200
Subject: [PATCH 3/7] update changelog

---
 CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7f45d8100..86600707b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,9 +3,13 @@
 - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
 - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when axis and keepdim were set.
 - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min`, `max` where DNDarrays with empty processes can't be computed.
+- [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was incorrectly distributed if a DNDarray has a single element.

 ## Feature Additions

+### Communication
+- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`
+
 ### DNDarray
 - [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`


From b306f23dcd29a6116e71240ea59df64afe07d3f9 Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Fri, 17 Sep 2021 10:16:14 +0200
Subject: [PATCH 4/7] add specific test

---
 heat/core/_operations.py            |  1 -
 heat/core/tests/test_arithmetics.py | 10 ++++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/heat/core/_operations.py b/heat/core/_operations.py
index e4ce2b8768..22ee5ed523 100644
--- a/heat/core/_operations.py
+++ b/heat/core/_operations.py
@@ -110,7 +110,6 @@ def __binary_op(
             output_device = t1.device
             output_comm = t1.comm

-            # ToDo: Fine tuning in case of comm.size>t1.shape[t1.split]. Send torch tensors only to ranks, that will hold data.
            if t1.split is not None:
                if t1.shape[t1.split] == 1 and t1.comm.is_distributed():
                    # warnings.warn(
diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py
index fa8a86ca2d..b601dcb41f 100644
--- a/heat/core/tests/test_arithmetics.py
+++ b/heat/core/tests/test_arithmetics.py
@@ -34,6 +34,16 @@ def test_add(self):
         self.assertTrue(ht.equal(ht.add(self.a_tensor, self.an_int_scalar), result))
         self.assertTrue(ht.equal(ht.add(self.a_split_tensor, self.a_tensor), result))

+        # Single element split
+        a = ht.array([1], split=0)
+        b = ht.array([1, 2], split=0)
+        c = ht.add(a, b)
+        self.assertTrue(ht.equal(c, ht.array([2, 3])))
+        if c.comm.rank < 2:
+            self.assertEqual(c.larray.size()[0], 1)
+        else:
+            self.assertEqual(c.larray.size()[0], 0)
+
         with self.assertRaises(ValueError):
             ht.add(self.a_tensor, self.another_vector)
         with self.assertRaises(TypeError):

From c3881cadb86c5a870f18e93d92ef00f2fd65e46f Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Fri, 17 Sep 2021 11:43:59 +0200
Subject: [PATCH 5/7] fix test

---
 heat/core/tests/test_arithmetics.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py
index b601dcb41f..1203b57512 100644
--- a/heat/core/tests/test_arithmetics.py
+++ b/heat/core/tests/test_arithmetics.py
@@ -39,10 +39,11 @@ def test_add(self):
         b = ht.array([1, 2], split=0)
         c = ht.add(a, b)
         self.assertTrue(ht.equal(c, ht.array([2, 3])))
-        if c.comm.rank < 2:
-            self.assertEqual(c.larray.size()[0], 1)
-        else:
-            self.assertEqual(c.larray.size()[0], 0)
+        if c.comm.size > 1:
+            if c.comm.rank < 2:
+                self.assertEqual(c.larray.size()[0], 1)
+            else:
+                self.assertEqual(c.larray.size()[0], 0)

         with self.assertRaises(ValueError):
             ht.add(self.a_tensor, self.another_vector)

From ce431cae0d226945d36adea9d298c8bbcb34c629 Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Thu, 23 Sep 2021 11:51:48 +0200
Subject: [PATCH 6/7] add return type

---
 heat/core/communication.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/heat/core/communication.py b/heat/core/communication.py
index 4443ab6bbc..6c3662cf5e 100644
--- a/heat/core/communication.py
+++ b/heat/core/communication.py
@@ -436,7 +436,7 @@ def alltoall_recvbuffer(

         return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes

-    def Split(self, color: int = 0, key: int = 0):
+    def Split(self, color: int = 0, key: int = 0) -> MPICommunication:
         """
         Split communicator by color and key.

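A minimal usage sketch of the new `Split` method (an illustration, not part of the patch series; it mirrors `test_split` above and assumes a Heat build with the preceding patches applied):

    import heat as ht

    # Split a DNDarray's communicator into two groups by rank parity:
    # processes with the same color land in the same subcommunicator,
    # and key determines their rank ordering within it.
    a = ht.zeros((4, 5), split=0)
    color = a.comm.rank % 2
    newcomm = a.comm.Split(color, key=a.comm.rank)

    # Run with e.g. `mpirun -n 4`; each group then contains two processes.
    print(a.comm.rank, "->", newcomm.rank, "of", newcomm.size)

Internally, `Split` wraps the subcommunicator returned by mpi4py's `Comm.Split` (`self.handle.Split(color, key)`) in a fresh `MPICommunication` object.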
From 9016f42d5801f408648b01cfc6461b4a60d35eb5 Mon Sep 17 00:00:00 2001
From: Michael Tarnawa
Date: Thu, 23 Sep 2021 14:04:57 +0200
Subject: [PATCH 7/7] add Free() method for MPICommunication

---
 heat/core/_operations.py              | 2 ++
 heat/core/communication.py            | 6 ++++++
 heat/core/tests/test_communication.py | 2 ++
 3 files changed, 10 insertions(+)

diff --git a/heat/core/_operations.py b/heat/core/_operations.py
index 22ee5ed523..1a34d1a84f 100644
--- a/heat/core/_operations.py
+++ b/heat/core/_operations.py
@@ -122,6 +122,7 @@ def __binary_op(
                             t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
                         )
                     newcomm.Bcast(t1)
+                    newcomm.Free()

             if t2.split is not None:
                 if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
@@ -135,6 +136,7 @@ def __binary_op(
                             t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
                         )
                     newcomm.Bcast(t2)
+                    newcomm.Free()

         else:
             raise TypeError(
diff --git a/heat/core/communication.py b/heat/core/communication.py
index 6c3662cf5e..388949bcd8 100644
--- a/heat/core/communication.py
+++ b/heat/core/communication.py
@@ -436,6 +436,12 @@ def alltoall_recvbuffer(

         return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes

+    def Free(self) -> None:
+        """
+        Free the underlying MPI communicator.
+        """
+        self.handle.Free()
+
     def Split(self, color: int = 0, key: int = 0) -> MPICommunication:
         """
         Split communicator by color and key.
diff --git a/heat/core/tests/test_communication.py b/heat/core/tests/test_communication.py
index af3a06f889..1410eaf9cc 100644
--- a/heat/core/tests/test_communication.py
+++ b/heat/core/tests/test_communication.py
@@ -209,6 +209,8 @@ def test_split(self):
             self.assertTrue(newcomm.size < a.comm.size)
         self.assertIsNot(newcomm, a.comm)

+        newcomm.Free()
+
     def test_allgather(self):
         # contiguous data
         data = ht.ones((1, 7))
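Taken together, patches 2 and 7 make `__binary_op` broadcast a single-element operand only among the ranks that will actually hold output data, and release the temporary subcommunicator afterwards. A rough standalone sketch of this Split/Bcast/Free pattern (an illustration under the same assumptions as above; it reaches into internals such as `larray` exactly the way `__binary_op` does, and `t1`/`t2` are stand-ins for the two operands):

    import heat as ht
    import torch

    t1 = ht.array([5], split=0)        # single element along the split axis
    t2 = ht.array([1, 2, 3], split=0)  # its extent decides which ranks hold output data

    # Ranks below the output extent get color 0 and need t1's value.
    color = 0 if t1.comm.rank < t2.shape[0] else 1
    newcomm = t1.comm.Split(color, t1.comm.rank)
    if t1.comm.rank > 0 and color == 0:
        # Allocate a receive buffer on every color-0 rank except the root.
        t1.larray = torch.zeros(
            t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
        )
    newcomm.Bcast(t1)  # the broadcast stays within the subcommunicator
    newcomm.Free()     # release the handle, as patch 7 does after each broadcast

Without the `Free` call, every such broadcast would leave a dangling MPI communicator behind, so repeated binary operations would steadily leak communicator handles.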