diff --git a/CHANGELOG.md b/CHANGELOG.md index f4353be514..f9cd87950f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ - [#620](https://github.com/helmholtz-analytics/heat/pull/620) New feature: KNN - [#624](https://github.com/helmholtz-analytics/heat/pull/624) Bugfix: distributed median() indexing and casting - [#629](https://github.com/helmholtz-analytics/heat/pull/629) New features: `asin`, `acos`, `atan`, `atan2` +- [#631](https://github.com/helmholtz-analytics/heat/pull/631) Bugfix: get_halo behaviour when rank has no data. - [#634](https://github.com/helmholtz-analytics/heat/pull/634) New features: `kmedians`, `kmedoids`, `manhattan` - [#633](https://github.com/helmholtz-analytics/heat/pull/633) Documentation: updated contributing.md - [#635](https://github.com/helmholtz-analytics/heat/pull/635) `DNDarray.__getitem__` balances and resplits the given key to None if the key is a DNDarray diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 1e4457bc7b..1755f2403e 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -312,21 +312,37 @@ def get_halo(self, halo_size): halo_size : int Size of the halo. """ + if not self.is_balanced(): + raise RuntimeError( + "halo cannot be created for unbalanced tensors, running the .balance_() function is recommended" + ) if not isinstance(halo_size, int): raise TypeError( - "halo_size needs to be of Python type integer, {} given)".format(type(halo_size)) + "halo_size needs to be of Python type integer, {} given".format(type(halo_size)) ) if halo_size < 0: raise ValueError( - "halo_size needs to be a positive Python integer, {} given)".format(type(halo_size)) + "halo_size needs to be a positive Python integer, {} given".format(type(halo_size)) ) if self.comm.is_distributed() and self.split is not None: - min_chunksize = self.shape[self.split] // self.comm.size - if halo_size > min_chunksize: + # gather lshapes + lshape_map = self.create_lshape_map() + rank = self.comm.rank + size = self.comm.size + next_rank = rank + 1 + prev_rank = rank - 1 + last_rank = size - 1 + + # if local shape is zero and its the last process + if self.lshape[self.split] == 0: + return # if process has no data we ignore it + + if halo_size > self.lshape[self.split]: + # if on at least one process the halo_size is larger than the local size throw ValueError raise ValueError( - "halo_size {} needs to smaller than chunck-size {} )".format( - halo_size, min_chunksize + "halo_size {} needs to be smaller than chunck-size {} )".format( + halo_size, self.lshape[self.split] ) ) @@ -338,19 +354,20 @@ def get_halo(self, halo_size): req_list = list() - if self.comm.rank != self.comm.size - 1: - self.comm.Isend(a_next, self.comm.rank + 1) + # only exchange data with next process if it has data + if rank != last_rank and (lshape_map[next_rank, self.split] > 0): + self.comm.Isend(a_next, next_rank) res_prev = torch.zeros( a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device ) - req_list.append(self.comm.Irecv(res_prev, source=self.comm.rank + 1)) + req_list.append(self.comm.Irecv(res_prev, source=next_rank)) - if self.comm.rank != 0: - self.comm.Isend(a_prev, self.comm.rank - 1) + if rank != 0: + self.comm.Isend(a_prev, prev_rank) res_next = torch.zeros( a_next.size(), dtype=a_next.dtype, device=self.device.torch_device ) - req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1)) + req_list.append(self.comm.Irecv(res_next, source=prev_rank)) for req in req_list: req.wait() diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index a261168e01..341a41c43b 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -126,6 +126,49 @@ def test_gethalo(self): # exception for too large halos with self.assertRaises(ValueError): data.get_halo(4) + # exception on non balanced tensor + with self.assertRaises(RuntimeError): + if data.comm.rank == 1: + data._DNDarray__array = torch.empty(0) + data.get_halo(1) + + # test no data on process + data_np = np.arange(2 * 12).reshape(2, 12) + data = ht.array(data_np, split=0) + data.get_halo(1) + + data_with_halos = data.array_with_halos + + if data.comm.rank == 0: + self.assertTrue(data.halo_prev is None) + self.assertTrue(data.halo_next is not None) + self.assertEqual(data_with_halos.shape, (2, 12)) + if data.comm.rank == 1: + self.assertTrue(data.halo_prev is not None) + self.assertTrue(data.halo_next is None) + self.assertEqual(data_with_halos.shape, (2, 12)) + if data.comm.rank == 2: + self.assertTrue(data.halo_prev is None) + self.assertTrue(data.halo_next is None) + self.assertEqual(data_with_halos.shape, (0, 12)) + + data = data.reshape((12, 2), axis=1) + data.get_halo(1) + + data_with_halos = data.array_with_halos + + if data.comm.rank == 0: + self.assertTrue(data.halo_prev is None) + self.assertTrue(data.halo_next is not None) + self.assertEqual(data_with_halos.shape, (12, 2)) + if data.comm.rank == 1: + self.assertTrue(data.halo_prev is not None) + self.assertTrue(data.halo_next is None) + self.assertEqual(data_with_halos.shape, (12, 2)) + if data.comm.rank == 2: + self.assertTrue(data.halo_prev is None) + self.assertTrue(data.halo_next is None) + self.assertEqual(data_with_halos.shape, (12, 0)) def test_astype(self): data = ht.float32([[1, 2, 3], [4, 5, 6]]) diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index b2bf70e348..0685344cec 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -993,8 +993,8 @@ def test_minimum(self): ht.minimum(random_volume_1, random_volume_2, out=output) def test_percentile(self): - # test local, distributed, split/axis combination, TODO no data on process, Issue #568 - x_np = np.arange(10 * 10 * 10).reshape(10, 10, 10) + # test local, distributed, split/axis combination, no data on process + x_np = np.arange(3 * 10 * 10).reshape(3, 10, 10) x_ht = ht.array(x_np) x_ht_split0 = ht.array(x_np, split=0) x_ht_split1 = ht.array(x_np, split=1)