Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug/568 get halo behaviour if no data on process #631

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- [#620](https://github.com/helmholtz-analytics/heat/pull/620) New feature: KNN
- [#624](https://github.com/helmholtz-analytics/heat/pull/624) Bugfix: distributed median() indexing and casting
- [#629](https://github.com/helmholtz-analytics/heat/pull/629) New features: `asin`, `acos`, `atan`, `atan2`
- [#631](https://github.com/helmholtz-analytics/heat/pull/631) Bugfix: get_halo behaviour when rank has no data.
- [#634](https://github.com/helmholtz-analytics/heat/pull/634) New features: `kmedians`, `kmedoids`, `manhattan`
- [#633](https://github.com/helmholtz-analytics/heat/pull/633) Documentation: updated contributing.md
- [#635](https://github.com/helmholtz-analytics/heat/pull/635) `DNDarray.__getitem__` balances and resplits the given key to None if the key is a DNDarray
Expand Down
41 changes: 29 additions & 12 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,21 +312,37 @@ def get_halo(self, halo_size):
halo_size : int
Size of the halo.
"""
if not self.is_balanced():
raise RuntimeError(
"halo cannot be created for unbalanced tensors, running the .balance_() function is recommended"
)
if not isinstance(halo_size, int):
raise TypeError(
"halo_size needs to be of Python type integer, {} given)".format(type(halo_size))
"halo_size needs to be of Python type integer, {} given".format(type(halo_size))
)
if halo_size < 0:
raise ValueError(
"halo_size needs to be a positive Python integer, {} given)".format(type(halo_size))
"halo_size needs to be a positive Python integer, {} given".format(type(halo_size))
)

if self.comm.is_distributed() and self.split is not None:
min_chunksize = self.shape[self.split] // self.comm.size
if halo_size > min_chunksize:
# gather lshapes
lshape_map = self.create_lshape_map()
rank = self.comm.rank
size = self.comm.size
next_rank = rank + 1
prev_rank = rank - 1
last_rank = size - 1

# if local shape is zero and it's the last process
if self.lshape[self.split] == 0:
return # if process has no data we ignore it

if halo_size > self.lshape[self.split]:
# if the halo_size is larger than the local size on at least one process, raise a ValueError
raise ValueError(
"halo_size {} needs to smaller than chunck-size {} )".format(
halo_size, min_chunksize
"halo_size {} needs to be smaller than chunck-size {} )".format(
halo_size, self.lshape[self.split]
)
)

Expand All @@ -338,19 +354,20 @@ def get_halo(self, halo_size):

req_list = list()

if self.comm.rank != self.comm.size - 1:
self.comm.Isend(a_next, self.comm.rank + 1)
# only exchange data with next process if it has data
if rank != last_rank and (lshape_map[next_rank, self.split] > 0):
self.comm.Isend(a_next, next_rank)
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=self.comm.rank + 1))
req_list.append(self.comm.Irecv(res_prev, source=next_rank))

if self.comm.rank != 0:
self.comm.Isend(a_prev, self.comm.rank - 1)
if rank != 0:
self.comm.Isend(a_prev, prev_rank)
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1))
req_list.append(self.comm.Irecv(res_next, source=prev_rank))

for req in req_list:
req.wait()
Expand Down
43 changes: 43 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,49 @@ def test_gethalo(self):
# exception for too large halos
with self.assertRaises(ValueError):
data.get_halo(4)
# exception on unbalanced tensor
with self.assertRaises(RuntimeError):
if data.comm.rank == 1:
data._DNDarray__array = torch.empty(0)
data.get_halo(1)

# test no data on process
data_np = np.arange(2 * 12).reshape(2, 12)
data = ht.array(data_np, split=0)
data.get_halo(1)

data_with_halos = data.array_with_halos

if data.comm.rank == 0:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is not None)
self.assertEqual(data_with_halos.shape, (2, 12))
if data.comm.rank == 1:
self.assertTrue(data.halo_prev is not None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (2, 12))
if data.comm.rank == 2:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (0, 12))

data = data.reshape((12, 2), axis=1)
data.get_halo(1)

data_with_halos = data.array_with_halos

if data.comm.rank == 0:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is not None)
self.assertEqual(data_with_halos.shape, (12, 2))
if data.comm.rank == 1:
self.assertTrue(data.halo_prev is not None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (12, 2))
if data.comm.rank == 2:
self.assertTrue(data.halo_prev is None)
self.assertTrue(data.halo_next is None)
self.assertEqual(data_with_halos.shape, (12, 0))

def test_astype(self):
data = ht.float32([[1, 2, 3], [4, 5, 6]])
Expand Down
4 changes: 2 additions & 2 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,8 @@ def test_minimum(self):
ht.minimum(random_volume_1, random_volume_2, out=output)

def test_percentile(self):
# test local, distributed, split/axis combination, TODO no data on process, Issue #568
x_np = np.arange(10 * 10 * 10).reshape(10, 10, 10)
# test local, distributed, split/axis combination, no data on process
x_np = np.arange(3 * 10 * 10).reshape(3, 10, 10)
x_ht = ht.array(x_np)
x_ht_split0 = ht.array(x_np, split=0)
x_ht_split1 = ht.array(x_np, split=1)
Expand Down