Skip to content

Commit

Permalink
Merge pull request #631 from helmholtz-analytics/bug/568-get_halo-beh…
Browse files Browse the repository at this point in the history
…aviour-if-no-data-on-process

Bug/568 get halo behaviour if no data on process
  • Loading branch information
ClaudiaComito committed Sep 22, 2020
2 parents 7cee8f0 + 5f063ad commit 6c490d7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- [#620](https://github.com/helmholtz-analytics/heat/pull/620) New feature: KNN
- [#624](https://github.com/helmholtz-analytics/heat/pull/624) Bugfix: distributed median() indexing and casting
- [#629](https://github.com/helmholtz-analytics/heat/pull/629) New features: `asin`, `acos`, `atan`, `atan2`
- [#631](https://github.com/helmholtz-analytics/heat/pull/631) Bugfix: get_halo behaviour when rank has no data.
- [#634](https://github.com/helmholtz-analytics/heat/pull/634) New features: `kmedians`, `kmedoids`, `manhattan`
- [#633](https://github.com/helmholtz-analytics/heat/pull/633) Documentation: updated contributing.md
- [#635](https://github.com/helmholtz-analytics/heat/pull/635) `DNDarray.__getitem__` balances and resplits the given key to None if the key is a DNDarray
Expand Down
41 changes: 29 additions & 12 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,21 +312,37 @@ def get_halo(self, halo_size):
halo_size : int
Size of the halo.
"""
if not self.is_balanced():
raise RuntimeError(
"halo cannot be created for unbalanced tensors, running the .balance_() function is recommended"
)
if not isinstance(halo_size, int):
raise TypeError(
"halo_size needs to be of Python type integer, {} given)".format(type(halo_size))
"halo_size needs to be of Python type integer, {} given".format(type(halo_size))
)
if halo_size < 0:
raise ValueError(
"halo_size needs to be a positive Python integer, {} given)".format(type(halo_size))
"halo_size needs to be a positive Python integer, {} given".format(type(halo_size))
)

if self.comm.is_distributed() and self.split is not None:
min_chunksize = self.shape[self.split] // self.comm.size
if halo_size > min_chunksize:
# gather lshapes
lshape_map = self.create_lshape_map()
rank = self.comm.rank
size = self.comm.size
next_rank = rank + 1
prev_rank = rank - 1
last_rank = size - 1

# if local shape is zero and its the last process
if self.lshape[self.split] == 0:
return # if process has no data we ignore it

if halo_size > self.lshape[self.split]:
# if on at least one process the halo_size is larger than the local size throw ValueError
raise ValueError(
"halo_size {} needs to smaller than chunck-size {} )".format(
halo_size, min_chunksize
"halo_size {} needs to be smaller than chunck-size {} )".format(
halo_size, self.lshape[self.split]
)
)

Expand All @@ -338,19 +354,20 @@ def get_halo(self, halo_size):

req_list = list()

if self.comm.rank != self.comm.size - 1:
self.comm.Isend(a_next, self.comm.rank + 1)
# only exchange data with next process if it has data
if rank != last_rank and (lshape_map[next_rank, self.split] > 0):
self.comm.Isend(a_next, next_rank)
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=self.comm.rank + 1))
req_list.append(self.comm.Irecv(res_prev, source=next_rank))

if self.comm.rank != 0:
self.comm.Isend(a_prev, self.comm.rank - 1)
if rank != 0:
self.comm.Isend(a_prev, prev_rank)
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1))
req_list.append(self.comm.Irecv(res_next, source=prev_rank))

for req in req_list:
req.wait()
Expand Down
43 changes: 43 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,49 @@ def test_gethalo(self):
# exception for too large halos
with self.assertRaises(ValueError):
data.get_halo(4)
# exception on non balanced tensor
with self.assertRaises(RuntimeError):
if data.comm.rank == 1:
data._DNDarray__array = torch.empty(0)
data.get_halo(1)

# test no data on process
data_np = np.arange(2 * 12).reshape(2, 12)
data = ht.array(data_np, split=0)
data.get_halo(1)

data_with_halos = data.array_with_halos

if data.comm.rank == 0:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is not None)
self.assertEqual(data_with_halos.shape, (2, 12))
if data.comm.rank == 1:
self.assertTrue(data.halo_prev is not None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (2, 12))
if data.comm.rank == 2:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (0, 12))

data = data.reshape((12, 2), axis=1)
data.get_halo(1)

data_with_halos = data.array_with_halos

if data.comm.rank == 0:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is not None)
self.assertEqual(data_with_halos.shape, (12, 2))
if data.comm.rank == 1:
self.assertTrue(data.halo_prev is not None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (12, 2))
if data.comm.rank == 2:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (12, 0))

def test_astype(self):
data = ht.float32([[1, 2, 3], [4, 5, 6]])
Expand Down
4 changes: 2 additions & 2 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,8 @@ def test_minimum(self):
ht.minimum(random_volume_1, random_volume_2, out=output)

def test_percentile(self):
# test local, distributed, split/axis combination, TODO no data on process, Issue #568
x_np = np.arange(10 * 10 * 10).reshape(10, 10, 10)
# test local, distributed, split/axis combination, no data on process
x_np = np.arange(3 * 10 * 10).reshape(3, 10, 10)
x_ht = ht.array(x_np)
x_ht_split0 = ht.array(x_np, split=0)
x_ht_split1 = ht.array(x_np, split=1)
Expand Down

0 comments on commit 6c490d7

Please sign in to comment.