From ca8a106c367744b441a4bf6abfbcd79dd9a5f69b Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 21 Sep 2020 11:06:11 +0200 Subject: [PATCH 01/27] Sanitizing parameters --- heat/core/manipulations.py | 72 +++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 32f382669b..4ff7f29a8a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -13,7 +13,6 @@ from . import types from . import _operations - __all__ = [ "column_stack", "concatenate", @@ -31,6 +30,7 @@ "row_stack", "shape", "sort", + "split", "squeeze", "stack", "topk", @@ -1424,6 +1424,76 @@ def sort(a, axis=None, descending=False, out=None): return tensor, return_indices +def split(ary, indices_or_sections, axis=0): + """ + Split a DNDarray into multiple sub-DNDarrays as views into ary. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. DNDarray, list or tuple) + If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along axis. + If such a split is not possible, an error is raised. + If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along axis + the array is split. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of DNDarrays + A list of sub-DNDarrays as views into ary. + + Raises + ------ + ValueError + If indices_or_sections is given as integer, but a split does not result in equal division. + + Examples #TODO + -------- + >>> x = ht.array(12).reshape((2,2,3)) + + """ + # sanitize ary + if not isinstance(ary, dndarray.DNDarray): + raise TypeError("Expected ary to be a DNDarray, but was {}".format(type(ary))) + + # sanitize axis + if not isinstance(axis, int): + raise TypeError("Expected axis to be an integer, but was {}".format(type(axis))) + if axis < 0 or axis > len(ary.gshape) - 1: + raise ValueError( + "Invalid input for axis. Valid range is between 0 and {}, but was {}".format( + len(ary.gshape) - 1, axis + ) + ) + + # sanitize indices_or_sections + if isinstance(indices_or_sections, int): + if ary.gshape[axis] % indices_or_sections != 0: + raise ValueError( + "DNDarray with shape {} can't be divided equally into {} chunks along axis {}".format( + ary.gshape, indices_or_sections, axis + ) + ) + elif isinstance(indices_or_sections, (list, tuple, dndarray.DNDarray)): + if isinstance(indices_or_sections, (list, tuple)): + indices_or_sections = factories.array(indices_or_sections) + if len(indices_or_sections.gshape) != 1: + raise ValueError( + "Expected indices_or_sections to be 1-dimensional, but was {}-dimensional instead.".format( + len(indices_or_sections.gshape) - 1 + ) + ) + else: + raise TypeError( + "Expected indices_or_sections to be array_like (DNDarray, list or tuple), but was {}".format( + type(indices_or_sections) + ) + ) + + def squeeze(x, axis=None): """ Remove single-dimensional entries from the shape of a tensor. From be6ab9913803d941bf1d8fcca673ad0265a64628 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 21 Sep 2020 13:16:19 +0200 Subject: [PATCH 02/27] First approach to function --- heat/core/manipulations.py | 11 ++++++++++- heat/core/tests/test_manipulations.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 4ff7f29a8a..ef1fb206d8 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1442,7 +1442,7 @@ def split(ary, indices_or_sections, axis=0): Returns ------- - sub-arrays : list of DNDarrays + sub_arrays : list of DNDarrays A list of sub-DNDarrays as views into ary. Raises @@ -1477,6 +1477,8 @@ def split(ary, indices_or_sections, axis=0): ary.gshape, indices_or_sections, axis ) ) + # to adapt torch syntax + indices_or_sections -= 1 elif isinstance(indices_or_sections, (list, tuple, dndarray.DNDarray)): if isinstance(indices_or_sections, (list, tuple)): indices_or_sections = factories.array(indices_or_sections) @@ -1493,6 +1495,13 @@ def split(ary, indices_or_sections, axis=0): ) ) + if ary.split is None: + sub_arrays_t = torch.split(ary, indices_or_sections, axis) + + return factories.array( + sub_arrays_t, dtype=ary.dtype, is_split=ary.split, device=ary.device, comm=ary.comm + ) + def squeeze(x, axis=None): """ diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 5243a583be..6c60a9dcd2 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1353,6 +1353,9 @@ def test_sort(self): ).all() ) + def test_split(self): + pass + def test_resplit(self): if ht.MPI_WORLD.size > 1: # resplitting with same axis, should leave everything unchanged From 7dc205a6a87d062f8be4c853c7c30ef7c1350180 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 21 Sep 2020 14:13:08 +0200 Subject: [PATCH 03/27] indices_or_sections = int, undistributed case --- heat/core/manipulations.py | 30 +++++++++++++++++++++------ heat/core/tests/test_manipulations.py | 25 +++++++++++++++++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ef1fb206d8..f63d1ef119 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1477,8 +1477,9 @@ def split(ary, indices_or_sections, axis=0): ary.gshape, indices_or_sections, axis ) ) - # to adapt torch syntax - indices_or_sections -= 1 + # np to torch mapping - calculate size of resulting data chunks + indices_or_sections_t = ary.gshape[axis] // indices_or_sections + elif isinstance(indices_or_sections, (list, tuple, dndarray.DNDarray)): if isinstance(indices_or_sections, (list, tuple)): indices_or_sections = factories.array(indices_or_sections) @@ -1496,11 +1497,28 @@ def split(ary, indices_or_sections, axis=0): ) if ary.split is None: - sub_arrays_t = torch.split(ary, indices_or_sections, axis) + if isinstance(indices_or_sections, int): + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + else: + if ( + indices_or_sections.split is None + ): # TODO map np syntax to torch (calculate chunk sizes) + pass + sub_arrays_t = torch.split( + ary._DNDarray__array, indices_or_sections._DNDarray__array, axis + ) - return factories.array( - sub_arrays_t, dtype=ary.dtype, is_split=ary.split, device=ary.device, comm=ary.comm - ) + sub_arrays_ht = [ + factories.array( + sub_DNDarray, dtype=ary.dtype, is_split=ary.split, device=ary.device, comm=ary.comm + ) + for sub_DNDarray in sub_arrays_t + ] + + for sub_DNDarray in sub_arrays_ht: + sub_DNDarray.balance_() + + return sub_arrays_ht def squeeze(x, axis=None): diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 6c60a9dcd2..995e57e3d7 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1354,7 +1354,30 @@ def test_sort(self): ) def test_split(self): - pass + # ==================================== + # UNDISTRIBUTED CASE + # ==================================== + data_ht = ht.arange(24).reshape((2, 3, 4)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + # result = ht.split(data_ht, ) + + # indices_or_sections = list + + # indices_or_sections = undistributed DNDarray + + # indices_or_sections = distributed DNDarray def test_resplit(self): if ht.MPI_WORLD.size > 1: From d273872c276d08f54f6fa216130f332a499950c2 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 21 Sep 2020 15:42:58 +0200 Subject: [PATCH 04/27] Undistributed indices_or_sections (mapping np -> torch), undistributed case --- heat/core/manipulations.py | 32 ++++++++++++++++++++++----- heat/core/tests/test_manipulations.py | 9 +++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f63d1ef119..eb2f7d7bec 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -4,9 +4,11 @@ from .communication import MPI +from . import arithmetics from . import constants from . import dndarray from . import factories +from . import indexing from . import linalg from . import stride_tricks from . import tiling @@ -1500,13 +1502,31 @@ def split(ary, indices_or_sections, axis=0): if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) else: - if ( - indices_or_sections.split is None - ): # TODO map np syntax to torch (calculate chunk sizes) + if indices_or_sections.split is None: + # np to torch mapping + + # 1. replace all values out of range with gshape[axis] to generate size 0 + indices_or_sections_t = indexing.where( + indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + ) + + # 2. add first and last value to DNDarray + additional_zero = factories.array([0]) + additional_length = factories.array([ary.gshape[axis]]) + indices_or_sections_t = concatenate( + [additional_zero, indices_or_sections_t, additional_length] + ) + + # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes + indices_or_sections_t = arithmetics.diff(indices_or_sections_t) + + # 4. transform the result into a list (torch requirement) + indices_or_sections_t = [int(ele) for ele in indices_or_sections_t] + # indices_or_sections distributed + else: # TODO pass - sub_arrays_t = torch.split( - ary._DNDarray__array, indices_or_sections._DNDarray__array, axis - ) + + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) sub_arrays_ht = [ factories.array( diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 995e57e3d7..53ed954e00 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1371,7 +1371,14 @@ def test_split(self): self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = tuple - # result = ht.split(data_ht, ) + result = ht.split(data_ht, [0, 1]) + comparison = np.split(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = list From 16f465efbcf92bfa1c8b64071cc27d76e0bef37f Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 21 Sep 2020 16:01:11 +0200 Subject: [PATCH 05/27] Additional tests --- heat/core/tests/test_manipulations.py | 61 ++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 53ed954e00..44b44875be 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1357,6 +1357,8 @@ def test_split(self): # ==================================== # UNDISTRIBUTED CASE # ==================================== + # axis = 0 + # ==================================== data_ht = ht.arange(24).reshape((2, 3, 4)) data_np = data_ht.numpy() @@ -1371,8 +1373,8 @@ def test_split(self): self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = tuple - result = ht.split(data_ht, [0, 1]) - comparison = np.split(data_np, [0, 1]) + result = ht.split(data_ht, (0, 1)) + comparison = np.split(data_np, (0, 1)) self.assertTrue(len(result) == len(comparison)) @@ -1381,11 +1383,66 @@ def test_split(self): self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = list + result = ht.split(data_ht, [0, 1]) + comparison = np.split(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([0, 1])) + comparison = np.split(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = distributed DNDarray + # ==================================== + # axis != 0 (2 in this case) + # ==================================== + # indices_or_sections = int + result = ht.split(data_ht, 2, 2) + comparison = np.split(data_np, 2, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.split(data_ht, (0, 1)) + comparison = np.split(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # exceptions + with self.assertRaises(TypeError): + ht.split([1, 2, 3, 4], 2) + with self.assertRaises(TypeError): + ht.split(data_ht, "2") + with self.assertRaises(TypeError): + ht.split(data_ht, 2, "0") + with self.assertRaises(ValueError): + ht.split(data_ht, 2, -1) + with self.assertRaises(ValueError): + ht.split(data_ht, 2, 3) + with self.assertRaises(ValueError): + ht.split(data_ht, 5) + with self.assertRaises(ValueError): + ht.split(data_ht, [[0, 1]]) + def test_resplit(self): if ht.MPI_WORLD.size > 1: # resplitting with same axis, should leave everything unchanged From ed75e8ba2d50d412ab411041d028790f96aeb928 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 22 Sep 2020 09:42:51 +0200 Subject: [PATCH 06/27] indices_or sections distributed, undistributed case --- heat/core/manipulations.py | 46 ++++++++++++++++----------- heat/core/tests/test_manipulations.py | 8 +++++ 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index eb2f7d7bec..5d3be27086 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1502,29 +1502,37 @@ def split(ary, indices_or_sections, axis=0): if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) else: - if indices_or_sections.split is None: - # np to torch mapping - - # 1. replace all values out of range with gshape[axis] to generate size 0 - indices_or_sections_t = indexing.where( - indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + if indices_or_sections.split is not None: + warnings.warn( + "`indices_or_sections` might not be distributed (along axis {}) if `ary` is not distributed.\n" + "`indices_or_sections` will be copied with new split axis None.".format( + indices_or_sections.split + ) ) + indices_or_sections = resplit(indices_or_sections, None) - # 2. add first and last value to DNDarray - additional_zero = factories.array([0]) - additional_length = factories.array([ary.gshape[axis]]) - indices_or_sections_t = concatenate( - [additional_zero, indices_or_sections_t, additional_length] - ) + # np to torch mapping - # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes - indices_or_sections_t = arithmetics.diff(indices_or_sections_t) + # 1. replace all values out of range with gshape[axis] to generate size 0 + indices_or_sections_t = indexing.where( + indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + ) + + # 2. add first and last value to DNDarray + # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes + indices_or_sections_t = arithmetics.diff( + indices_or_sections_t, prepend=0, append=ary.gshape[axis] + ) + indices_or_sections_t = factories.array( + indices_or_sections_t, + dtype=types.int64, + is_split=indices_or_sections_t.split, + comm=indices_or_sections_t.comm, + device=indices_or_sections_t.device, + ) - # 4. transform the result into a list (torch requirement) - indices_or_sections_t = [int(ele) for ele in indices_or_sections_t] - # indices_or_sections distributed - else: # TODO - pass + # 4. transform the result into a list (torch requirement) + indices_or_sections_t = indices_or_sections_t.tolist() sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 44b44875be..d91c9f983a 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1403,6 +1403,14 @@ def test_split(self): self.assert_array_equal(result[i], comparison[i]) # indices_or_sections = distributed DNDarray + result = ht.split(data_ht, ht.array([0, 1], split=0)) + comparison = np.split(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) # ==================================== # axis != 0 (2 in this case) From f7adbc085c3ca1e4798df10f0b93571b894e5618 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 22 Sep 2020 12:55:36 +0200 Subject: [PATCH 07/27] Distributed case, axis != a.split, int axis == a.split --- heat/core/manipulations.py | 43 +++++++++++++++++++++++---- heat/core/tests/test_manipulations.py | 24 +++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 5d3be27086..7920a1fc7f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -10,6 +10,7 @@ from . import factories from . import indexing from . import linalg +from . import sanitation from . import stride_tricks from . import tiling from . import types @@ -1454,19 +1455,34 @@ def split(ary, indices_or_sections, axis=0): Examples #TODO -------- - >>> x = ht.array(12).reshape((2,2,3)) + >>> x = ht.array(12).reshape((4,3)) + >>> ht.split(x, 2) + [ DNDarray([[0, 1, 2], + [3, 4, 5]]), + DNDarray([[ 6, 7, 8], + [ 9, 10, 11]]) + ] + >>> ht.split(x, [2, 3, 5]) + [ DNDarray([[0, 1, 2], + [3, 4, 5]]), + DNDarray([[6, 7, 8]] + DNDarray([[ 9, 10, 11]]), + DNDarray([]) + ] + + + """ # sanitize ary - if not isinstance(ary, dndarray.DNDarray): - raise TypeError("Expected ary to be a DNDarray, but was {}".format(type(ary))) + sanitation.sanitize_input(ary) # sanitize axis if not isinstance(axis, int): - raise TypeError("Expected axis to be an integer, but was {}".format(type(axis))) + raise TypeError("Expected `axis` to be an integer, but was {}".format(type(axis))) if axis < 0 or axis > len(ary.gshape) - 1: raise ValueError( - "Invalid input for axis. Valid range is between 0 and {}, but was {}".format( + "Invalid input for `axis`. Valid range is between 0 and {}, but was {}".format( len(ary.gshape) - 1, axis ) ) @@ -1493,11 +1509,14 @@ def split(ary, indices_or_sections, axis=0): ) else: raise TypeError( - "Expected indices_or_sections to be array_like (DNDarray, list or tuple), but was {}".format( + "Expected `indices_or_sections` to be array_like (DNDarray, list or tuple), but was {}".format( type(indices_or_sections) ) ) + # start of actual algorithm + + # undistributed case if ary.split is None: if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) @@ -1535,6 +1554,18 @@ def split(ary, indices_or_sections, axis=0): indices_or_sections_t = indices_or_sections_t.tolist() sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + # distributed case + else: + if ary.split == axis: + raise ValueError( + "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" + "Split axis {} is not allowed for `ary`".format(ary.split) + ) + else: + if isinstance(indices_or_sections, int): + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + else: + pass sub_arrays_ht = [ factories.array( diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index d91c9f983a..185727298a 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1451,6 +1451,30 @@ def test_split(self): with self.assertRaises(ValueError): ht.split(data_ht, [[0, 1]]) + # ==================================== + # DISTRIBUTED CASE + # ==================================== + # axis == ary.split + # ==================================== + data_ht = ht.arange(120, split=0).reshape((4, 5, 6)) + data_np = data_ht.numpy() + + with self.assertRaises(ValueError): + ht.split(data_ht, 2, 0) + + # ==================================== + # axis != ary.split + # ==================================== + # indices_or_sections = int + result = ht.split(data_ht, 2, 2) + comparison = np.split(data_np, 2, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_resplit(self): if ht.MPI_WORLD.size > 1: # resplitting with same axis, should leave everything unchanged From 2cbad85f74ba765f360296b824e217050e970aaf Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 22 Sep 2020 13:23:26 +0200 Subject: [PATCH 08/27] Implementation for all cases --- heat/core/manipulations.py | 27 +++++++---------------- heat/core/tests/test_manipulations.py | 31 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7920a1fc7f..fa633c9a78 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1435,13 +1435,14 @@ def split(ary, indices_or_sections, axis=0): ---------- ary : DNDarray DNDArray to be divided into sub-DNDarrays. - indices_or_sections : int or 1-dimensional array_like (i.e. DNDarray, list or tuple) + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along axis. If such a split is not possible, an error is raised. If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along axis the array is split. axis : int, optional The axis along which to split, default is 0. + axis is not allowed to equal ary.split if ary is distributed. Returns ------- @@ -1469,10 +1470,6 @@ def split(ary, indices_or_sections, axis=0): DNDarray([[ 9, 10, 11]]), DNDarray([]) ] - - - - """ # sanitize ary sanitation.sanitize_input(ary) @@ -1516,8 +1513,12 @@ def split(ary, indices_or_sections, axis=0): # start of actual algorithm - # undistributed case - if ary.split is None: + if ary.split == axis: + raise ValueError( + "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" + "Split axis {} is not allowed for `ary` in this case.".format(ary.split) + ) + else: if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) else: @@ -1554,18 +1555,6 @@ def split(ary, indices_or_sections, axis=0): indices_or_sections_t = indices_or_sections_t.tolist() sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) - # distributed case - else: - if ary.split == axis: - raise ValueError( - "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" - "Split axis {} is not allowed for `ary`".format(ary.split) - ) - else: - if isinstance(indices_or_sections, int): - sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) - else: - pass sub_arrays_ht = [ factories.array( diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 185727298a..b0db44d3a2 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1475,6 +1475,37 @@ def test_split(self): self.assertIsInstance(result[i], ht.DNDarray) self.assert_array_equal(result[i], comparison[i]) + # indices_or_sections = list + result = ht.split(data_ht, [3, 4, 6], 2) + comparison = np.split(data_np, [3, 4, 6], 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([3, 4, 6]), 2) + comparison = np.split(data_np, np.array([3, 4, 6]), 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + indices = ht.array([3, 4, 6], split=0) + result = ht.split(data_ht, indices, 2) + comparison = np.split(data_np, np.array([3, 4, 6]), 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_resplit(self): if ht.MPI_WORLD.size > 1: # resplitting with same axis, should leave everything unchanged From 7093e3bf90807d444dd4612b55e3d297134729d2 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Wed, 23 Sep 2020 09:32:37 +0200 Subject: [PATCH 09/27] ary.split == axis, ary.comm.size == indices --- heat/core/manipulations.py | 35 +++++++++++++++++++++++---- heat/core/tests/test_manipulations.py | 13 +++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index fa633c9a78..1cb7ae1c1a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1513,11 +1513,36 @@ def split(ary, indices_or_sections, axis=0): # start of actual algorithm - if ary.split == axis: - raise ValueError( - "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" - "Split axis {} is not allowed for `ary` in this case.".format(ary.split) - ) + if ary.split == axis and ary.split is not None: + # CASE 0 number of processes == indices_or_selections -> split already done due to distribution + if isinstance(indices_or_sections, int) and ary.comm.size == indices_or_sections: + new_lshape = list(ary.lshape) + new_lshape[axis] = 0 + sub_arrays_t = [ + torch.empty(new_lshape) if i != ary.comm.rank else ary._DNDarray__array + for i in range(indices_or_sections) + ] + + # CASE 1 number of processes > tensor-chunk size -> reorder chunks correctly + elif isinstance(indices_or_sections, int) and ary.comm.size > indices_or_sections: + # no data + if ary.lshape[axis] == 0: + sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections_t)] + # already correctly split + elif indices_or_sections_t == ary.lshape[axis]: + sub_arrays_t = [ + torch.empty(ary.lshape) if i != ary.comm.rank else ary._DNDarray__array + for i in range(indices_or_sections) + ] + # chunks too small + else: + pass + + else: + raise ValueError( + "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" + "Split axis {} is not allowed for `ary` in this case.".format(ary.split) + ) else: if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index b0db44d3a2..5f31fc60ad 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1459,8 +1459,19 @@ def test_split(self): data_ht = ht.arange(120, split=0).reshape((4, 5, 6)) data_np = data_ht.numpy() + if data_ht.comm.size == 2: # TODO generalize + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) + # self.assert_array_equal(result[i], comparison[i]) + with self.assertRaises(ValueError): - ht.split(data_ht, 2, 0) + ht.split(data_ht, [0, 2], 0) # ==================================== # axis != ary.split From b949e26206bd2d0c15679bb4f2bd0bf7a19b5570 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Wed, 23 Sep 2020 11:22:59 +0200 Subject: [PATCH 10/27] ary.split == axis, ary.comm.size >= indices --- heat/core/manipulations.py | 31 ++++++++++++++++++--------- heat/core/tests/test_manipulations.py | 9 ++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b472fd3966..3c72c2cbfc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1858,7 +1858,7 @@ def split(ary, indices_or_sections, axis=0): # start of actual algorithm - if ary.split == axis and ary.split is not None: + if ary.split == axis and ary.split is not None and ary.comm.size > 1: # CASE 0 number of processes == indices_or_selections -> split already done due to distribution if isinstance(indices_or_sections, int) and ary.comm.size == indices_or_sections: new_lshape = list(ary.lshape) @@ -1868,20 +1868,31 @@ def split(ary, indices_or_sections, axis=0): for i in range(indices_or_sections) ] - # CASE 1 number of processes > tensor-chunk size -> reorder chunks correctly + # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly elif isinstance(indices_or_sections, int) and ary.comm.size > indices_or_sections: # no data if ary.lshape[axis] == 0: sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections_t)] - # already correctly split - elif indices_or_sections_t == ary.lshape[axis]: - sub_arrays_t = [ - torch.empty(ary.lshape) if i != ary.comm.rank else ary._DNDarray__array - for i in range(indices_or_sections) - ] - # chunks too small + # calculate mapped list else: - pass + offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + idx_block = offset // indices_or_sections_t + left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) + + # put all available data on process on concerned block + if left_data_block >= ary.lshape[axis]: + new_lshape = list(ary.lshape) + new_lshape[axis] = 0 + sub_arrays_t = [ + torch.empty(new_lshape) if i != idx_block else ary._DNDarray__array + for i in range(indices_or_sections) + ] + else: + new_indices = torch.zeros(indices_or_sections, dtype=int) + new_indices[idx_block] = left_data_block + new_indices[idx_block + 1] = ary.gshape[axis] - left_data_block + + sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) else: raise ValueError( diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 7eed4c9de7..98060298f1 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1796,6 +1796,15 @@ def test_split(self): self.assertIsInstance(result[i], ht.DNDarray) self.assertTrue((ht.array(comparison[i]) == result[i]).all()) # self.assert_array_equal(result[i], comparison[i]) + if data_ht.comm.size > 2: # TODO generalize + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) with self.assertRaises(ValueError): ht.split(data_ht, [0, 2], 0) From a5f49f8e32e43ea8b249e00e6dbc8e58f6dd148e Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Wed, 23 Sep 2020 13:17:24 +0200 Subject: [PATCH 11/27] ary.split == axis, indices = int --- heat/core/manipulations.py | 79 +++++++++++++++++---------- heat/core/tests/test_manipulations.py | 29 ++++------ 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 3c72c2cbfc..bf0c292b0a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1859,41 +1859,62 @@ def split(ary, indices_or_sections, axis=0): # start of actual algorithm if ary.split == axis and ary.split is not None and ary.comm.size > 1: - # CASE 0 number of processes == indices_or_selections -> split already done due to distribution - if isinstance(indices_or_sections, int) and ary.comm.size == indices_or_sections: - new_lshape = list(ary.lshape) - new_lshape[axis] = 0 - sub_arrays_t = [ - torch.empty(new_lshape) if i != ary.comm.rank else ary._DNDarray__array - for i in range(indices_or_sections) - ] - - # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly - elif isinstance(indices_or_sections, int) and ary.comm.size > indices_or_sections: - # no data - if ary.lshape[axis] == 0: - sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections_t)] - # calculate mapped list - else: + + if isinstance(indices_or_sections, int): + # CASE 0 number of processes == indices_or_selections -> split already done due to distribution + if ary.comm.size == indices_or_sections: + new_lshape = list(ary.lshape) + new_lshape[axis] = 0 + sub_arrays_t = [ + torch.empty(new_lshape) if i != ary.comm.rank else ary._DNDarray__array + for i in range(indices_or_sections) + ] + + # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly + elif ary.comm.size > indices_or_sections: + # no data + if ary.lshape[axis] == 0: + sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections)] + # calculate mapped list + else: + offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + idx_block = offset // indices_or_sections_t + left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) + + # put all available data on process on concerned block + if left_data_block >= ary.lshape[axis]: + new_lshape = list(ary.lshape) + new_lshape[axis] = 0 + sub_arrays_t = [ + torch.empty(new_lshape) if i != idx_block else ary._DNDarray__array + for i in range(indices_or_sections) + ] + else: + new_indices = torch.zeros(indices_or_sections, dtype=int) + new_indices[idx_block] = left_data_block + new_indices[idx_block + 1] = ary.gshape[axis] - left_data_block + + sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) + # CASE 2 number of processes < tensor-chunk size -> reorder (and split) chunks correctly + elif ary.comm.size < indices_or_sections: offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) idx_block = offset // indices_or_sections_t left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) + left_data_process = ary.lshape[axis] - # put all available data on process on concerned block - if left_data_block >= ary.lshape[axis]: - new_lshape = list(ary.lshape) - new_lshape[axis] = 0 - sub_arrays_t = [ - torch.empty(new_lshape) if i != idx_block else ary._DNDarray__array - for i in range(indices_or_sections) - ] - else: - new_indices = torch.zeros(indices_or_sections, dtype=int) - new_indices[idx_block] = left_data_block - new_indices[idx_block + 1] = ary.gshape[axis] - left_data_block + new_indices = torch.zeros(indices_or_sections, dtype=int) - sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) + for i in range(idx_block, indices_or_sections): + if left_data_block >= left_data_process: + new_indices[idx_block] = left_data_process + break + else: + new_indices[idx_block] = left_data_block + left_data_process -= left_data_block + idx_block += 1 + left_data_block = indices_or_sections_t + sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) else: raise ValueError( "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 98060298f1..2d75c79d47 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1786,28 +1786,19 @@ def test_split(self): data_ht = ht.arange(120, split=0).reshape((4, 5, 6)) data_np = data_ht.numpy() - if data_ht.comm.size == 2: # TODO generalize - result = ht.split(data_ht, 2) - comparison = np.split(data_np, 2) - - self.assertTrue(len(result) == len(comparison)) - - for i in range(len(result)): - self.assertIsInstance(result[i], ht.DNDarray) - self.assertTrue((ht.array(comparison[i]) == result[i]).all()) - # self.assert_array_equal(result[i], comparison[i]) - if data_ht.comm.size > 2: # TODO generalize - result = ht.split(data_ht, 2) - comparison = np.split(data_np, 2) + # indices = int + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) - self.assertTrue(len(result) == len(comparison)) + self.assertTrue(len(result) == len(comparison)) - for i in range(len(result)): - self.assertIsInstance(result[i], ht.DNDarray) - self.assertTrue((ht.array(comparison[i]) == result[i]).all()) + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) - with self.assertRaises(ValueError): - ht.split(data_ht, [0, 2], 0) + if data_ht.comm.size > 1: + with self.assertRaises(ValueError): + ht.split(data_ht, [0, 2], 0) # ==================================== # axis != ary.split From 5a363d93fb1e61714e772679cf47ca95c0102729 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Wed, 23 Sep 2020 13:36:58 +0200 Subject: [PATCH 12/27] United cases 1 & 2, replaced for loop --- heat/core/manipulations.py | 47 +++++++++++++++----------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index bf0c292b0a..3b4396e1b5 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1870,51 +1870,40 @@ def split(ary, indices_or_sections, axis=0): for i in range(indices_or_sections) ] - # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly - elif ary.comm.size > indices_or_sections: + # # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly + # # CASE 2 number of processes < tensor-chunk size -> reorder (and split) chunks correctly + else: # no data if ary.lshape[axis] == 0: sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections)] - # calculate mapped list else: offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) idx_block = offset // indices_or_sections_t left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) + left_data_process = ary.lshape[axis] - # put all available data on process on concerned block - if left_data_block >= ary.lshape[axis]: - new_lshape = list(ary.lshape) - new_lshape[axis] = 0 - sub_arrays_t = [ - torch.empty(new_lshape) if i != idx_block else ary._DNDarray__array - for i in range(indices_or_sections) - ] - else: - new_indices = torch.zeros(indices_or_sections, dtype=int) - new_indices[idx_block] = left_data_block - new_indices[idx_block + 1] = ary.gshape[axis] - left_data_block - - sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) - # CASE 2 number of processes < tensor-chunk size -> reorder (and split) chunks correctly - elif ary.comm.size < indices_or_sections: - offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) - idx_block = offset // indices_or_sections_t - left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) - left_data_process = ary.lshape[axis] + new_indices = torch.zeros(indices_or_sections, dtype=int) - new_indices = torch.zeros(indices_or_sections, dtype=int) - - for i in range(idx_block, indices_or_sections): if left_data_block >= left_data_process: new_indices[idx_block] = left_data_process - break else: new_indices[idx_block] = left_data_block left_data_process -= left_data_block idx_block += 1 - left_data_block = indices_or_sections_t - sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) + # calculate blocks which can be filled completely + left_blocks_to_fill = left_data_process // indices_or_sections_t + new_indices[ + idx_block : (left_blocks_to_fill + idx_block) + ] = indices_or_sections_t + + # assign residuate to following process + new_indices[left_blocks_to_fill + idx_block] = ( + left_data_process % indices_or_sections_t + ) + + sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) + # indices or sections == DNDarray else: raise ValueError( "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" From 730734067a35debded10df529807130bdf1e4bf9 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Wed, 23 Sep 2020 15:36:50 +0200 Subject: [PATCH 13/27] ary.split == axis, indices array_like --- heat/core/manipulations.py | 63 ++++++++++++++++++++++++--- heat/core/tests/test_manipulations.py | 42 ++++++++++++++++-- 2 files changed, 97 insertions(+), 8 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 3b4396e1b5..88d22c15f6 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1860,7 +1860,9 @@ def split(ary, indices_or_sections, axis=0): if ary.split == axis and ary.split is not None and ary.comm.size > 1: - if isinstance(indices_or_sections, int): + if isinstance( + indices_or_sections, int + ): # TODO eventually use x.comm.counts_displs (easier solution) # CASE 0 number of processes == indices_or_selections -> split already done due to distribution if ary.comm.size == indices_or_sections: new_lshape = list(ary.lshape) @@ -1897,7 +1899,7 @@ def split(ary, indices_or_sections, axis=0): idx_block : (left_blocks_to_fill + idx_block) ] = indices_or_sections_t - # assign residuate to following process + # assign residual to following process new_indices[left_blocks_to_fill + idx_block] = ( left_data_process % indices_or_sections_t ) @@ -1905,10 +1907,61 @@ def split(ary, indices_or_sections, axis=0): sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) # indices or sections == DNDarray else: - raise ValueError( - "Split can only be applied to undistributed tensors if `ary.split` == `axis`.\n" - "Split axis {} is not allowed for `ary` in this case.".format(ary.split) + if indices_or_sections.split is not None: + warnings.warn( + "`indices_or_sections` might not be distributed (along axis {}) if `ary` is not distributed.\n" + "`indices_or_sections` will be copied with new split axis None.".format( + indices_or_sections.split + ) + ) + indices_or_sections = resplit(indices_or_sections, None) + + offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + + # np to torch mapping + + # 1. replace all values out of range with gshape[axis] to generate size 0 + indices_or_sections_t = indexing.where( + indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + ) + + # 2. add first and last value to DNDarray + # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes + indices_or_sections_t = arithmetics.diff( + indices_or_sections_t, prepend=0, append=ary.gshape[axis] + ) + indices_or_sections_t = factories.array( + indices_or_sections_t, + dtype=types.int64, + is_split=indices_or_sections_t.split, + comm=indices_or_sections_t.comm, + device=indices_or_sections_t.device, ) + + # 4. transform the result into a list (torch requirement) + indices_or_sections_t = indices_or_sections_t.tolist() + + left_data_process = ary.lshape[axis] + + # 5. subtract already split data on a different process + for i in range(len(indices_or_sections_t)): + if offset != 0 and offset - indices_or_sections_t[i] >= 0: + offset -= indices_or_sections_t[i] + indices_or_sections_t[i] = 0 + else: + if offset != 0: + indices_or_sections_t[i] -= offset + offset = 0 + if left_data_process - indices_or_sections_t[i] >= 0: + left_data_process -= indices_or_sections_t[i] + else: + indices_or_sections_t[i] = left_data_process + indices_or_sections_t[i + 1 :] = [0] * ( + len(indices_or_sections_t) - (i + 1) + ) + break + + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) else: if isinstance(indices_or_sections, int): sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 2d75c79d47..4f0e5f967a 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1796,9 +1796,45 @@ def test_split(self): self.assertIsInstance(result[i], ht.DNDarray) self.assertTrue((ht.array(comparison[i]) == result[i]).all()) - if data_ht.comm.size > 1: - with self.assertRaises(ValueError): - ht.split(data_ht, [0, 2], 0) + # indices_or_sections = tuple + result = ht.split(data_ht, (1, 3, 5)) + comparison = np.split(data_np, (1, 3, 5)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.split(data_ht, [1, 3, 5]) + comparison = np.split(data_np, [1, 3, 5]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([1, 3, 5])) + comparison = np.split(data_np, np.array([1, 3, 5])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.split(data_ht, ht.array([1, 3, 5], split=0)) + comparison = np.split(data_np, np.array([1, 3, 5])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) # ==================================== # axis != ary.split From fcd48347ab0a93f5f508b8c20f2c530621d4853f Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Thu, 24 Sep 2020 14:57:17 +0200 Subject: [PATCH 14/27] Expanded Docstrings --- heat/core/manipulations.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 88d22c15f6..640186bb42 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1799,7 +1799,7 @@ def split(ary, indices_or_sections, axis=0): ValueError If indices_or_sections is given as integer, but a split does not result in equal division. - Examples #TODO + Examples -------- >>> x = ht.array(12).reshape((4,3)) >>> ht.split(x, 2) @@ -1811,10 +1811,25 @@ def split(ary, indices_or_sections, axis=0): >>> ht.split(x, [2, 3, 5]) [ DNDarray([[0, 1, 2], [3, 4, 5]]), - DNDarray([[6, 7, 8]] + DNDarray([[6, 7, 8]] DNDarray([[ 9, 10, 11]]), DNDarray([]) ] + >>> ht.split(x, [1, 2], 1) + [ DNDarray([[0], + [3], + [6], + [9]]), + DNDarray([[ 1], + [ 4], + [ 7], + [10]], + DNDarray([[ 2], + [ 5], + [ 8], + [11]]) + ] + """ # sanitize ary sanitation.sanitize_input(ary) @@ -1860,9 +1875,7 @@ def split(ary, indices_or_sections, axis=0): if ary.split == axis and ary.split is not None and ary.comm.size > 1: - if isinstance( - indices_or_sections, int - ): # TODO eventually use x.comm.counts_displs (easier solution) + if isinstance(indices_or_sections, int): # CASE 0 number of processes == indices_or_selections -> split already done due to distribution if ary.comm.size == indices_or_sections: new_lshape = list(ary.lshape) @@ -1943,6 +1956,7 @@ def split(ary, indices_or_sections, axis=0): left_data_process = ary.lshape[axis] + # TODO eventually restructure code and add if case here instead # 5. subtract already split data on a different process for i in range(len(indices_or_sections_t)): if offset != 0 and offset - indices_or_sections_t[i] >= 0: From ffbbc1964ac67dbce2109d37e978ac20516f5e53 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Thu, 24 Sep 2020 16:55:02 +0200 Subject: [PATCH 15/27] Changed algorithm using where expression --- heat/core/manipulations.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 640186bb42..ddd89687ef 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1930,18 +1930,23 @@ def split(ary, indices_or_sections, axis=0): indices_or_sections = resplit(indices_or_sections, None) offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + slice_axis = slices[axis] - # np to torch mapping + # reduce information to the (chunk) relevant + indices_or_sections_t = indexing.where( + indices_or_sections <= slice_axis.start, slice_axis.start, indices_or_sections + ) - # 1. replace all values out of range with gshape[axis] to generate size 0 indices_or_sections_t = indexing.where( - indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + indices_or_sections_t >= slice_axis.stop, slice_axis.stop, indices_or_sections_t ) + # np to torch mapping + # 2. add first and last value to DNDarray # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes indices_or_sections_t = arithmetics.diff( - indices_or_sections_t, prepend=0, append=ary.gshape[axis] + indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop ) indices_or_sections_t = factories.array( indices_or_sections_t, @@ -1954,27 +1959,6 @@ def split(ary, indices_or_sections, axis=0): # 4. transform the result into a list (torch requirement) indices_or_sections_t = indices_or_sections_t.tolist() - left_data_process = ary.lshape[axis] - - # TODO eventually restructure code and add if case here instead - # 5. subtract already split data on a different process - for i in range(len(indices_or_sections_t)): - if offset != 0 and offset - indices_or_sections_t[i] >= 0: - offset -= indices_or_sections_t[i] - indices_or_sections_t[i] = 0 - else: - if offset != 0: - indices_or_sections_t[i] -= offset - offset = 0 - if left_data_process - indices_or_sections_t[i] >= 0: - left_data_process -= indices_or_sections_t[i] - else: - indices_or_sections_t[i] = left_data_process - indices_or_sections_t[i + 1 :] = [0] * ( - len(indices_or_sections_t) - (i + 1) - ) - break - sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) else: if isinstance(indices_or_sections, int): From 1c65be16227b61ae938e17714f5101b3476b5311 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 08:38:37 +0200 Subject: [PATCH 16/27] Additional test case, first draft hsplit, vsplit, dsplit --- heat/core/manipulations.py | 89 ++++++++++++++++++++++++++- heat/core/tests/test_manipulations.py | 17 +++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ddd89687ef..5c30052295 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -21,11 +21,13 @@ "concatenate", "diag", "diagonal", + "dsplit", "expand_dims", "flatten", "flip", "fliplr", "flipud", + "hsplit", "hstack", "pad", "reshape", @@ -39,6 +41,7 @@ "stack", "topk", "unique", + "vsplit", "vstack", ] @@ -643,6 +646,77 @@ def diagonal(a, offset=0, dim1=0, dim2=1): return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) +def dsplit(ary, indices_or_sections): + """ + Split array into multiple sub-arrays along the 3rd axis (depth). + + Please refer to the split documentation. dsplit is equivalent to split with axis=2, + the array is always split along the third axis provided the array dimension is greater than or equal to 3. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 3rd axis. + If such a split is not possible, an error is raised. + If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 3rd axis + the array is split. + If an index exceeds the dimension of the array along the 3rd axis, an empty sub-array is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as views into ary. + + Raises + ------ + ValueError + If indices_or_sections is given as integer, but a split does not result in equal division. + + Examples + -------- + >>> x = ht.array(24).reshape((2, 3, 4)) + >>> ht.dsplit(x, 2) + [ + DNDarray([[[ 0, 1], + [ 4, 5], + [ 8, 9]], + + [[12, 13], + [16, 17], + [20, 21]]]), + DNDarray([[[ 2, 3], + [ 6, 7], + [10, 11]], + + [[14, 15], + [18, 19], + [22, 23]]]) + ] + >>> ht.dsplit(x, [1, 4]) + [ + DNDarray([[[ 0], + [ 4], + [ 8]], + + [[12], + [16], + [20]]]), + DNDarray([[[ 1, 2, 3], + [ 5, 6, 7], + [ 9, 10, 11]], + + [[13, 14, 15], + [17, 18, 19], + [21, 22, 23]]]), + DNDarray([]) + ] + + """ + return split(ary, indices_or_sections, 2) + + def expand_dims(a, axis): """ Expand the shape of an array. @@ -865,6 +939,10 @@ def flipud(a): return flip(a, 0) +def hsplit(ary, indices_or_sections): # TODO + pass + + def hstack(tup): """ Stack arrays in sequence horizontally (column wise). @@ -1096,7 +1174,7 @@ def pad(array, pad_width, mode="constant", constant_values=0): if len(pad) // 2 > len(array.shape): raise ValueError( f"Not enough dimensions to pad.\n" - f"Padding a {len(array.shape)}-dimensional tensor for {len(pad)//2}" + f"Padding a {len(array.shape)}-dimensional tensor for {len(pad) // 2}" f" dimensions is not possible." ) @@ -1785,6 +1863,11 @@ def split(ary, indices_or_sections, axis=0): If such a split is not possible, an error is raised. If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along axis the array is split. + For example, indices_or_sections = [2, 3] would, for axis = 0, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly. axis : int, optional The axis along which to split, default is 0. axis is not allowed to equal ary.split if ary is distributed. @@ -2510,6 +2593,10 @@ def unique(a, sorted=False, return_inverse=False, axis=None): return return_value +def vsplit(ary, indices_or_sections): # TODO + pass + + def resplit(arr, axis=None): """ Out-of-place redistribution of the content of the tensor. Allows to "unsplit" (i.e. gather) all values from all diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 4f0e5f967a..c32660815f 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -738,6 +738,9 @@ def test_diagonal(self): numpy_args={"axis1": 0, "axis2": 1}, ) + def test_dsplit(self): + pass # TODO + def test_expand_dims(self): # vector data a = ht.arange(10) @@ -1792,6 +1795,20 @@ def test_split(self): self.assertTrue(len(result) == len(comparison)) + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) + + # larger example + data_ht_large = ht.arange(160, split=0).reshape((8, 5, 4)) + data_np_large = data_ht_large.numpy() + + # indices = int + result = ht.split(data_ht_large, 2) + comparison = np.split(data_np_large, 2) + + self.assertTrue(len(result) == len(comparison)) + for i in range(len(result)): self.assertIsInstance(result[i], ht.DNDarray) self.assertTrue((ht.array(comparison[i]) == result[i]).all()) From 6286635ad5769edde9163e10db652ab543573753 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 09:14:42 +0200 Subject: [PATCH 17/27] dsplit, hsplit, vsplit implemented + docstrings --- heat/core/manipulations.py | 135 +++++++++++++++++++++++++++++++++++-- 1 file changed, 131 insertions(+), 4 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 5c30052295..bab5388a3e 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -939,8 +939,70 @@ def flipud(a): return flip(a, 0) -def hsplit(ary, indices_or_sections): # TODO - pass +def hsplit(ary, indices_or_sections): + """ + Split array into multiple sub-arrays along the 2nd axis (horizontally/column-wise). + + Please refer to the split documentation. hsplit is nearly equivalent to split with axis=1, + the array is always split along the second axis though, in contrary to split, regardless of the array dimension. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 2nd axis. + If such a split is not possible, an error is raised. + If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 2nd axis + the array is split. + If an index exceeds the dimension of the array along the 2nd axis, an empty sub-array is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as views into ary. + + Raises + ------ + ValueError + If indices_or_sections is given as integer, but a split does not result in equal division. + + Examples + -------- + >>> x = ht.arange(24).reshape((2, 4, 3)) + >>> ht.hsplit(x, 2) + [ + DNDarray([[[ 0, 1, 2], + [ 3, 4, 5]], + + [[12, 13, 14], + [15, 16, 17]]]), + DNDarray([[[ 6, 7, 8], + [ 9, 10, 11]], + + [[18, 19, 20], + [21, 22, 23]]]) + ] + + >>> ht.hsplit(x, [1, 3]) + [ + DNDarray([[[ 0, 1, 2]], + + [[12, 13, 14]]]), + DNDarray([[[ 3, 4, 5], + [ 6, 7, 8]], + + [[15, 16, 17], + [18, 19, 20]]]), + DNDarray([[[ 9, 10, 11]], + + [[21, 22, 23]]])] + """ + sanitation.sanitize_input(ary) + + if len(ary.lshape) < 2: + ary = ary.reshape(ary, (1, ary.lshape[0])) + return split(ary, indices_or_sections, 1) def hstack(tup): @@ -2593,8 +2655,73 @@ def unique(a, sorted=False, return_inverse=False, axis=None): return return_value -def vsplit(ary, indices_or_sections): # TODO - pass +def vsplit(ary, indices_or_sections): + """ + Split array into multiple sub-arrays along the 1st axis (vertically/row-wise). + + Please refer to the split documentation. hsplit is equivalent to split with axis=0, + the array is always split along the first axis regardless of the array dimension. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 1st axis. + If such a split is not possible, an error is raised. + If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 1st axis + the array is split. + If an index exceeds the dimension of the array along the 1st axis, an empty sub-array is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as views into ary. + + Raises + ------ + ValueError + If indices_or_sections is given as integer, but a split does not result in equal division. + + Examples + -------- + >>> x = ht.arange(24).reshape((4, 3, 2)) + >>> ht.vsplit(x, 2) + [ + DNDarray([[[ 0, 1], + [ 2, 3], + [ 4, 5]], + + [[ 6, 7], + [ 8, 9], + [10, 11]]]), + DNDarray([[[12, 13], + [14, 15], + [16, 17]], + + [[18, 19], + [20, 21], + [22, 23]]]) + ] + + >>> ht.vsplit(x, [1, 3]) + [ + DNDarray([[[0, 1], + [2, 3], + [4, 5]]]), + DNDarray([[[ 6, 7], + [ 8, 9], + [10, 11]], + + [[12, 13], + [14, 15], + [16, 17]]]), + DNDarray([[[18, 19], + [20, 21], + [22, 23]]])] + + """ + return split(ary, indices_or_sections, 0) def resplit(arr, axis=None): From a98c32997828a6c5a7be8db76867bb19a7b4a2a4 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 09:45:03 +0200 Subject: [PATCH 18/27] Tests for hsplit, vsplit, dsplit + flatten in hsplit --- heat/core/manipulations.py | 9 +- heat/core/tests/test_manipulations.py | 218 +++++++++++++++++++++++++- 2 files changed, 224 insertions(+), 3 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index bab5388a3e..ec78f9ba69 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1001,8 +1001,13 @@ def hsplit(ary, indices_or_sections): sanitation.sanitize_input(ary) if len(ary.lshape) < 2: - ary = ary.reshape(ary, (1, ary.lshape[0])) - return split(ary, indices_or_sections, 1) + ary = reshape(ary, (1, ary.lshape[0])) + result = split(ary, indices_or_sections, 1) + result = [flatten(sub_array) for sub_array in result] + else: + result = split(ary, indices_or_sections, 1) + + return result def hstack(tup): diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index c32660815f..77cc2263a1 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -739,7 +739,59 @@ def test_diagonal(self): ) def test_dsplit(self): - pass # TODO + # for further testing, see test_split + data_ht = ht.arange(24).reshape((2, 3, 4)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.dsplit(data_ht, 2) + comparison = np.dsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.dsplit(data_ht, (0, 1)) + comparison = np.dsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.dsplit(data_ht, [0, 1]) + comparison = np.dsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.dsplit(data_ht, ht.array([0, 1])) + comparison = np.dsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.dsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.dsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) def test_expand_dims(self): # vector data @@ -967,6 +1019,115 @@ def test_flipud(self): ) self.assertTrue(ht.equal(ht.resplit(ht.flipud(c), 0), r_c)) + def test_hsplit(self): + # for further testing, see test_split + # 1-dimensional array (as forbidden in split) + data_ht = ht.arange(24) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.hsplit(data_ht, 2) + comparison = np.hsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.hsplit(data_ht, (0, 1)) + comparison = np.hsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.hsplit(data_ht, [0, 1]) + comparison = np.hsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1])) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + data_ht = ht.arange(24).reshape((2, 4, 3)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.hsplit(data_ht, 2) + comparison = np.hsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.hsplit(data_ht, (0, 1)) + comparison = np.hsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.hsplit(data_ht, [0, 1]) + comparison = np.hsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1])) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_hstack(self): # cases to test: # MM=================================== @@ -2351,6 +2512,61 @@ def test_unique(self): res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + def test_vsplit(self): + # for further testing, see test_split + data_ht = ht.arange(24).reshape((4, 3, 2)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.vsplit(data_ht, 2) + comparison = np.vsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.vsplit(data_ht, (0, 1)) + comparison = np.vsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.vsplit(data_ht, [0, 1]) + comparison = np.vsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.vsplit(data_ht, ht.array([0, 1])) + comparison = np.vsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.vsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.vsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_vstack(self): # cases to test: # MM=================================== From bd7af0b923af73bf149f6acd40bfbff99c0367b0 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 09:58:24 +0200 Subject: [PATCH 19/27] Correction of docstring example --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ec78f9ba69..256c2e487a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1951,7 +1951,7 @@ def split(ary, indices_or_sections, axis=0): Examples -------- - >>> x = ht.array(12).reshape((4,3)) + >>> x = ht.arange(12).reshape((4,3)) >>> ht.split(x, 2) [ DNDarray([[0, 1, 2], [3, 4, 5]]), From 057d430fcadbe1561f876513187c581c97887f5b Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 10:08:40 +0200 Subject: [PATCH 20/27] Added warning to docstring --- heat/core/manipulations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 256c2e487a..7c3376bcf6 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1920,6 +1920,9 @@ def sort(a, axis=None, descending=False, out=None): def split(ary, indices_or_sections, axis=0): """ Split a DNDarray into multiple sub-DNDarrays as views into ary. + ! Warning ! + Though it is possible to distribute `ary`, this function has nothing to do with the split + parameter of a DNDarray. Parameters ---------- From 439540aaf926bfdb5a68bbb7a82fcc5b4bec9702 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 13:24:10 +0200 Subject: [PATCH 21/27] Clarifying modifications --- heat/core/manipulations.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7c3376bcf6..98f3269645 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2029,7 +2029,7 @@ def split(ary, indices_or_sections, axis=0): if ary.split == axis and ary.split is not None and ary.comm.size > 1: if isinstance(indices_or_sections, int): - # CASE 0 number of processes == indices_or_selections -> split already done due to distribution + # CASE 1 number of processes == indices_or_selections -> split already done due to distribution if ary.comm.size == indices_or_sections: new_lshape = list(ary.lshape) new_lshape[axis] = 0 @@ -2038,35 +2038,34 @@ def split(ary, indices_or_sections, axis=0): for i in range(indices_or_sections) ] - # # CASE 1 number of processes > tensor-chunk size -> reorder (and split) chunks correctly - # # CASE 2 number of processes < tensor-chunk size -> reorder (and split) chunks correctly + # # CASE 2 number of processes != indices_or_selections -> reorder (and split) chunks correctly else: # no data if ary.lshape[axis] == 0: sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections)] else: offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) - idx_block = offset // indices_or_sections_t - left_data_block = indices_or_sections_t - (offset % indices_or_sections_t) + idx_frst_chunk_affctd = offset // indices_or_sections_t + left_data_chunk = indices_or_sections_t - (offset % indices_or_sections_t) left_data_process = ary.lshape[axis] new_indices = torch.zeros(indices_or_sections, dtype=int) - if left_data_block >= left_data_process: - new_indices[idx_block] = left_data_process + if left_data_chunk >= left_data_process: + new_indices[idx_frst_chunk_affctd] = left_data_process else: - new_indices[idx_block] = left_data_block - left_data_process -= left_data_block - idx_block += 1 + new_indices[idx_frst_chunk_affctd] = left_data_chunk + left_data_process -= left_data_chunk + idx_frst_chunk_affctd += 1 - # calculate blocks which can be filled completely - left_blocks_to_fill = left_data_process // indices_or_sections_t + # calculate chunks which can be filled completely + left_chunks_to_fill = left_data_process // indices_or_sections_t new_indices[ - idx_block : (left_blocks_to_fill + idx_block) + idx_frst_chunk_affctd : (left_chunks_to_fill + idx_frst_chunk_affctd) ] = indices_or_sections_t # assign residual to following process - new_indices[left_blocks_to_fill + idx_block] = ( + new_indices[left_chunks_to_fill + idx_frst_chunk_affctd] = ( left_data_process % indices_or_sections_t ) From dd22292a65afeaa2e940a5aa640702b11b997db7 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Fri, 25 Sep 2020 13:43:58 +0200 Subject: [PATCH 22/27] Added PR to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ae128ec28..e8782ec493 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # Pending additions +- [#677](https://github.com/helmholtz-analytics/heat/pull/677) New feature: split, vsplit, dsplit, hsplit ## New features ### Manipulations ### Statistical Functions From 5806e693f74bd18d165f698325e06936d6e5653b Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 26 Oct 2020 16:29:28 +0100 Subject: [PATCH 23/27] Added reference to split in docs --- heat/core/manipulations.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index a32dc9d2fd..f05cd350cd 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -664,6 +664,10 @@ def dsplit(ary, indices_or_sections): ValueError If indices_or_sections is given as integer, but a split does not result in equal division. + See Also + ------ + :function:`split ` + Examples -------- >>> x = ht.array(24).reshape((2, 3, 4)) @@ -949,6 +953,10 @@ def hsplit(ary, indices_or_sections): ValueError If indices_or_sections is given as integer, but a split does not result in equal division. + See Also + ------ + :function:`split ` + Examples -------- >>> x = ht.arange(24).reshape((2, 4, 3)) @@ -2670,6 +2678,10 @@ def vsplit(ary, indices_or_sections): ValueError If indices_or_sections is given as integer, but a split does not result in equal division. + See Also + ------ + :function:`split ` + Examples -------- >>> x = ht.arange(24).reshape((4, 3, 2)) From ddb9bb7ba6f25bbee52e6d701731c430a713a6b6 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Mon, 26 Oct 2020 16:45:13 +0100 Subject: [PATCH 24/27] Restructured documentation (Notes section) --- heat/core/manipulations.py | 48 +++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f05cd350cd..7f27db9563 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -640,9 +640,6 @@ def dsplit(ary, indices_or_sections): """ Split array into multiple sub-arrays along the 3rd axis (depth). - Please refer to the split documentation. dsplit is equivalent to split with axis=2, - the array is always split along the third axis provided the array dimension is greater than or equal to 3. - Parameters ---------- ary : DNDarray @@ -659,6 +656,11 @@ def dsplit(ary, indices_or_sections): sub_arrays : list of DNDarrays A list of sub-DNDarrays as views into ary. + Notes + ----- + Please refer to the split documentation. dsplit is equivalent to split with axis=2, + the array is always split along the third axis provided the array dimension is greater than or equal to 3. + Raises ------ ValueError @@ -676,14 +678,12 @@ def dsplit(ary, indices_or_sections): DNDarray([[[ 0, 1], [ 4, 5], [ 8, 9]], - [[12, 13], [16, 17], [20, 21]]]), DNDarray([[[ 2, 3], [ 6, 7], [10, 11]], - [[14, 15], [18, 19], [22, 23]]]) @@ -693,14 +693,12 @@ def dsplit(ary, indices_or_sections): DNDarray([[[ 0], [ 4], [ 8]], - [[12], [16], [20]]]), DNDarray([[[ 1, 2, 3], [ 5, 6, 7], [ 9, 10, 11]], - [[13, 14, 15], [17, 18, 19], [21, 22, 23]]]), @@ -929,9 +927,6 @@ def hsplit(ary, indices_or_sections): """ Split array into multiple sub-arrays along the 2nd axis (horizontally/column-wise). - Please refer to the split documentation. hsplit is nearly equivalent to split with axis=1, - the array is always split along the second axis though, in contrary to split, regardless of the array dimension. - Parameters ---------- ary : DNDarray @@ -948,13 +943,18 @@ def hsplit(ary, indices_or_sections): sub_arrays : list of DNDarrays A list of sub-DNDarrays as views into ary. + Notes + ----- + Please refer to the split documentation. hsplit is nearly equivalent to split with axis=1, + the array is always split along the second axis though, in contrary to split, regardless of the array dimension. + Raises ------ ValueError If indices_or_sections is given as integer, but a split does not result in equal division. See Also - ------ + -------- :function:`split ` Examples @@ -1908,9 +1908,6 @@ def sort(a, axis=None, descending=False, out=None): def split(ary, indices_or_sections, axis=0): """ Split a DNDarray into multiple sub-DNDarrays as views into ary. - ! Warning ! - Though it is possible to distribute `ary`, this function has nothing to do with the split - parameter of a DNDarray. Parameters ---------- @@ -1935,11 +1932,21 @@ def split(ary, indices_or_sections, axis=0): sub_arrays : list of DNDarrays A list of sub-DNDarrays as views into ary. + Warnings + -------- + Though it is possible to distribute `ary`, this function has nothing to do with the split + parameter of a DNDarray. + Raises ------ ValueError If indices_or_sections is given as integer, but a split does not result in equal division. + See Also + -------- + :function:`dsplit `, :function:`hsplit `, + :function:`vsplit ` + Examples -------- >>> x = ht.arange(12).reshape((4,3)) @@ -2654,9 +2661,6 @@ def vsplit(ary, indices_or_sections): """ Split array into multiple sub-arrays along the 1st axis (vertically/row-wise). - Please refer to the split documentation. hsplit is equivalent to split with axis=0, - the array is always split along the first axis regardless of the array dimension. - Parameters ---------- ary : DNDarray @@ -2673,13 +2677,18 @@ def vsplit(ary, indices_or_sections): sub_arrays : list of DNDarrays A list of sub-DNDarrays as views into ary. + Notes + ----- + Please refer to the split documentation. hsplit is equivalent to split with axis=0, + the array is always split along the first axis regardless of the array dimension. + Raises ------ ValueError If indices_or_sections is given as integer, but a split does not result in equal division. See Also - ------ + -------- :function:`split ` Examples @@ -2690,14 +2699,12 @@ def vsplit(ary, indices_or_sections): DNDarray([[[ 0, 1], [ 2, 3], [ 4, 5]], - [[ 6, 7], [ 8, 9], [10, 11]]]), DNDarray([[[12, 13], [14, 15], [16, 17]], - [[18, 19], [20, 21], [22, 23]]]) @@ -2711,7 +2718,6 @@ def vsplit(ary, indices_or_sections): DNDarray([[[ 6, 7], [ 8, 9], [10, 11]], - [[12, 13], [14, 15], [16, 17]]]), From 1fcd1e169db3baeeddb206b2b66426427cff3c00 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 27 Oct 2020 08:30:33 +0100 Subject: [PATCH 25/27] Moved PR to section --- CHANGELOG.md | 3 ++- heat/core/manipulations.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e90553c0c..ab12f0693c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ # Pending additions -- [#677](https://github.com/helmholtz-analytics/heat/pull/677) New feature: split, vsplit, dsplit, hsplit + ## New features - [#680](https://github.com/helmholtz-analytics/heat/pull/680) New property: larray - [#683](https://github.com/helmholtz-analytics/heat/pull/683) New properties: nbytes, gnbytes, lnbytes ### Manipulations +- [#677](https://github.com/helmholtz-analytics/heat/pull/677) split, vsplit, dsplit, hsplit ### Statistical Functions ### Linear Algebra ### ... diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7f27db9563..989e02a053 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -668,7 +668,7 @@ def dsplit(ary, indices_or_sections): See Also ------ - :function:`split ` + :function:`split` Examples -------- @@ -955,7 +955,7 @@ def hsplit(ary, indices_or_sections): See Also -------- - :function:`split ` + :function:`split` Examples -------- @@ -1944,8 +1944,7 @@ def split(ary, indices_or_sections, axis=0): See Also -------- - :function:`dsplit `, :function:`hsplit `, - :function:`vsplit ` + :function:`dsplit`, :function:`hsplit`, :function:`vsplit` Examples -------- @@ -2689,7 +2688,7 @@ def vsplit(ary, indices_or_sections): See Also -------- - :function:`split ` + :function:`split` Examples -------- From fd58cc834a0a1b97cf15d61f84e251602fe9bda4 Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 27 Oct 2020 14:26:37 +0100 Subject: [PATCH 26/27] Adapted requested changes & changed description of 'views' to 'copies' --- heat/core/manipulations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index eb843f6a80..16a0733ff1 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2166,7 +2166,7 @@ def sort(a, axis=None, descending=False, out=None): def split(ary, indices_or_sections, axis=0): """ - Split a DNDarray into multiple sub-DNDarrays as views into ary. + Split a DNDarray into multiple sub-DNDarrays as copies of parts of ary. Parameters ---------- @@ -2279,7 +2279,7 @@ def split(ary, indices_or_sections, axis=0): # start of actual algorithm - if ary.split == axis and ary.split is not None and ary.comm.size > 1: + if ary.split == axis and ary.is_distributed(): if isinstance(indices_or_sections, int): # CASE 1 number of processes == indices_or_selections -> split already done due to distribution From 454cdee6f12da69fd180a261197d3410df04f50b Mon Sep 17 00:00:00 2001 From: Lena Blind Date: Tue, 17 Nov 2020 13:31:14 +0100 Subject: [PATCH 27/27] Added docstrings & updated to current master --- heat/core/manipulations.py | 65 ++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 3abc6e8d57..96848c277f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -632,23 +632,24 @@ def diagonal(a, offset=0, dim1=0, dim2=1): def dsplit(ary, indices_or_sections): """ - Split array into multiple sub-arrays along the 3rd axis (depth). + Split array into multiple sub-DNDarrays along the 3rd axis (depth). + Note that this function returns copies and not views into `ary`. Parameters ---------- ary : DNDarray DNDArray to be divided into sub-DNDarrays. indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) - If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 3rd axis. + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 3rd axis. If such a split is not possible, an error is raised. - If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 3rd axis + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 3rd axis the array is split. - If an index exceeds the dimension of the array along the 3rd axis, an empty sub-array is returned correspondingly. + If an index exceeds the dimension of the array along the 3rd axis, an empty sub-DNDarray is returned correspondingly. Returns ------- sub_arrays : list of DNDarrays - A list of sub-DNDarrays as views into ary. + A list of sub-DNDarrays as copies of parts of `ary`. Notes ----- @@ -658,7 +659,7 @@ def dsplit(ary, indices_or_sections): Raises ------ ValueError - If indices_or_sections is given as integer, but a split does not result in equal division. + If `indices_or_sections` is given as integer, but a split does not result in equal division. See Also ------ @@ -919,23 +920,24 @@ def flipud(a): def hsplit(ary, indices_or_sections): """ - Split array into multiple sub-arrays along the 2nd axis (horizontally/column-wise). + Split array into multiple sub-DNDarrays along the 2nd axis (horizontally/column-wise). + Note that this function returns copies and not views into `ary`. Parameters ---------- ary : DNDarray DNDArray to be divided into sub-DNDarrays. indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) - If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 2nd axis. + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 2nd axis. If such a split is not possible, an error is raised. - If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 2nd axis + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 2nd axis the array is split. - If an index exceeds the dimension of the array along the 2nd axis, an empty sub-array is returned correspondingly. + If an index exceeds the dimension of the array along the 2nd axis, an empty sub-DNDarray is returned correspondingly. Returns ------- sub_arrays : list of DNDarrays - A list of sub-DNDarrays as views into ary. + A list of sub-DNDarrays as copies of parts of `ary` Notes ----- @@ -945,7 +947,7 @@ def hsplit(ary, indices_or_sections): Raises ------ ValueError - If indices_or_sections is given as integer, but a split does not result in equal division. + If `indices_or_sections` is given as integer, but a split does not result in equal division. See Also -------- @@ -982,7 +984,7 @@ def hsplit(ary, indices_or_sections): [[21, 22, 23]]])] """ - sanitation.sanitize_input(ary) + sanitation.sanitize_in(ary) if len(ary.lshape) < 2: ary = reshape(ary, (1, ary.lshape[0])) @@ -2159,30 +2161,30 @@ def sort(a, axis=None, descending=False, out=None): def split(ary, indices_or_sections, axis=0): """ - Split a DNDarray into multiple sub-DNDarrays as copies of parts of ary. + Split a DNDarray into multiple sub-DNDarrays as copies of parts of `ary`. Parameters ---------- ary : DNDarray DNDArray to be divided into sub-DNDarrays. indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) - If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along axis. + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along axis. If such a split is not possible, an error is raised. - If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along axis + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along axis the array is split. - For example, indices_or_sections = [2, 3] would, for axis = 0, result in - - ary[:2] - - ary[2:3] - - ary[3:] + For example, `indices_or_sections = [2, 3]` would, for `axis = 0`, result in + - `ary[:2]` + - `ary[2:3]` + - `ary[3:]` If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly. axis : int, optional The axis along which to split, default is 0. - axis is not allowed to equal ary.split if ary is distributed. + `axis` is not allowed to equal `ary.split` if `ary` is distributed. Returns ------- sub_arrays : list of DNDarrays - A list of sub-DNDarrays as views into ary. + A list of sub-DNDarrays as copies of parts of `ary`. Warnings -------- @@ -2192,7 +2194,7 @@ def split(ary, indices_or_sections, axis=0): Raises ------ ValueError - If indices_or_sections is given as integer, but a split does not result in equal division. + If `indices_or_sections` is given as integer, but a split does not result in equal division. See Also -------- @@ -2231,7 +2233,7 @@ def split(ary, indices_or_sections, axis=0): """ # sanitize ary - sanitation.sanitize_input(ary) + sanitation.sanitize_in(ary) # sanitize axis if not isinstance(axis, int): @@ -2893,33 +2895,34 @@ def unique(a, sorted=False, return_inverse=False, axis=None): def vsplit(ary, indices_or_sections): """ - Split array into multiple sub-arrays along the 1st axis (vertically/row-wise). + Split array into multiple sub-DNDNarrays along the 1st axis (vertically/row-wise). + Note that this function returns copies and not views into `ary`. Parameters ---------- ary : DNDarray DNDArray to be divided into sub-DNDarrays. indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) - If indices_or_sections is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 1st axis. + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 1st axis. If such a split is not possible, an error is raised. - If indices_or_sections is a 1-D DNDarray of sorted integers, the entries indicate where along the 1st axis + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 1st axis the array is split. - If an index exceeds the dimension of the array along the 1st axis, an empty sub-array is returned correspondingly. + If an index exceeds the dimension of the array along the 1st axis, an empty sub-DNDarray is returned correspondingly. Returns ------- sub_arrays : list of DNDarrays - A list of sub-DNDarrays as views into ary. + A list of sub-DNDarrays as copies of parts of `ary`. Notes ----- - Please refer to the split documentation. hsplit is equivalent to split with axis=0, + Please refer to the split documentation. hsplit is equivalent to split with `axis=0`, the array is always split along the first axis regardless of the array dimension. Raises ------ ValueError - If indices_or_sections is given as integer, but a split does not result in equal division. + If `indices_or_sections` is given as integer, but a split does not result in equal division. See Also --------