diff --git a/CHANGELOG.md b/CHANGELOG.md index 9794564245..b029cd8241 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - [#683](https://github.com/helmholtz-analytics/heat/pull/683) New properties: nbytes, gnbytes, lnbytes - [#687](https://github.com/helmholtz-analytics/heat/pull/687) New DNDarray property: balanced ### Manipulations +- [#677](https://github.com/helmholtz-analytics/heat/pull/677) split, vsplit, dsplit, hsplit ### Statistical Functions - [#679](https://github.com/helmholtz-analytics/heat/pull/679) New feature: ``histc()`` and ``histogram()`` ### Linear Algebra diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 0926a24779..96848c277f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -4,9 +4,11 @@ from .communication import MPI +from . import arithmetics from . import constants from . import dndarray from . import factories +from . import indexing from . import linalg from . import sanitation from . import stride_tricks @@ -19,11 +21,13 @@ "concatenate", "diag", "diagonal", + "dsplit", "expand_dims", "flatten", "flip", "fliplr", "flipud", + "hsplit", "hstack", "pad", "repeat", @@ -33,10 +37,12 @@ "row_stack", "shape", "sort", + "split", "squeeze", "stack", "topk", "unique", + "vsplit", "vstack", ] @@ -624,6 +630,80 @@ def diagonal(a, offset=0, dim1=0, dim2=1): return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) +def dsplit(ary, indices_or_sections): + """ + Split array into multiple sub-DNDarrays along the 3rd axis (depth). + Note that this function returns copies and not views into `ary`. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 3rd axis. + If such a split is not possible, an error is raised. + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 3rd axis + the array is split. + If an index exceeds the dimension of the array along the 3rd axis, an empty sub-DNDarray is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as copies of parts of `ary`. + + Notes + ----- + Please refer to the split documentation. dsplit is equivalent to split with axis=2, + the array is always split along the third axis provided the array dimension is greater than or equal to 3. + + Raises + ------ + ValueError + If `indices_or_sections` is given as integer, but a split does not result in equal division. + + See Also + ------ + :function:`split` + + Examples + -------- + >>> x = ht.array(24).reshape((2, 3, 4)) + >>> ht.dsplit(x, 2) + [ + DNDarray([[[ 0, 1], + [ 4, 5], + [ 8, 9]], + [[12, 13], + [16, 17], + [20, 21]]]), + DNDarray([[[ 2, 3], + [ 6, 7], + [10, 11]], + [[14, 15], + [18, 19], + [22, 23]]]) + ] + >>> ht.dsplit(x, [1, 4]) + [ + DNDarray([[[ 0], + [ 4], + [ 8]], + [[12], + [16], + [20]]]), + DNDarray([[[ 1, 2, 3], + [ 5, 6, 7], + [ 9, 10, 11]], + [[13, 14, 15], + [17, 18, 19], + [21, 22, 23]]]), + DNDarray([]) + ] + + """ + return split(ary, indices_or_sections, 2) + + def expand_dims(a, axis): """ Expand the shape of an array. @@ -838,6 +918,84 @@ def flipud(a): return flip(a, 0) +def hsplit(ary, indices_or_sections): + """ + Split array into multiple sub-DNDarrays along the 2nd axis (horizontally/column-wise). + Note that this function returns copies and not views into `ary`. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 2nd axis. + If such a split is not possible, an error is raised. + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 2nd axis + the array is split. + If an index exceeds the dimension of the array along the 2nd axis, an empty sub-DNDarray is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as copies of parts of `ary` + + Notes + ----- + Please refer to the split documentation. hsplit is nearly equivalent to split with axis=1, + the array is always split along the second axis though, in contrary to split, regardless of the array dimension. + + Raises + ------ + ValueError + If `indices_or_sections` is given as integer, but a split does not result in equal division. + + See Also + -------- + :function:`split` + + Examples + -------- + >>> x = ht.arange(24).reshape((2, 4, 3)) + >>> ht.hsplit(x, 2) + [ + DNDarray([[[ 0, 1, 2], + [ 3, 4, 5]], + + [[12, 13, 14], + [15, 16, 17]]]), + DNDarray([[[ 6, 7, 8], + [ 9, 10, 11]], + + [[18, 19, 20], + [21, 22, 23]]]) + ] + + >>> ht.hsplit(x, [1, 3]) + [ + DNDarray([[[ 0, 1, 2]], + + [[12, 13, 14]]]), + DNDarray([[[ 3, 4, 5], + [ 6, 7, 8]], + + [[15, 16, 17], + [18, 19, 20]]]), + DNDarray([[[ 9, 10, 11]], + + [[21, 22, 23]]])] + """ + sanitation.sanitize_in(ary) + + if len(ary.lshape) < 2: + ary = reshape(ary, (1, ary.lshape[0])) + result = split(ary, indices_or_sections, 1) + result = [flatten(sub_array) for sub_array in result] + else: + result = split(ary, indices_or_sections, 1) + + return result + + def hstack(tup): """ Stack arrays in sequence horizontally (column wise). @@ -1069,7 +1227,7 @@ def pad(array, pad_width, mode="constant", constant_values=0): if len(pad) // 2 > len(array.shape): raise ValueError( f"Not enough dimensions to pad.\n" - f"Padding a {len(array.shape)}-dimensional tensor for {len(pad)//2}" + f"Padding a {len(array.shape)}-dimensional tensor for {len(pad) // 2}" f" dimensions is not possible." ) @@ -2001,6 +2159,258 @@ def sort(a, axis=None, descending=False, out=None): return tensor, return_indices +def split(ary, indices_or_sections, axis=0): + """ + Split a DNDarray into multiple sub-DNDarrays as copies of parts of `ary`. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along axis. + If such a split is not possible, an error is raised. + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along axis + the array is split. + For example, `indices_or_sections = [2, 3]` would, for `axis = 0`, result in + - `ary[:2]` + - `ary[2:3]` + - `ary[3:]` + If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + `axis` is not allowed to equal `ary.split` if `ary` is distributed. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as copies of parts of `ary`. + + Warnings + -------- + Though it is possible to distribute `ary`, this function has nothing to do with the split + parameter of a DNDarray. + + Raises + ------ + ValueError + If `indices_or_sections` is given as integer, but a split does not result in equal division. + + See Also + -------- + :function:`dsplit`, :function:`hsplit`, :function:`vsplit` + + Examples + -------- + >>> x = ht.arange(12).reshape((4,3)) + >>> ht.split(x, 2) + [ DNDarray([[0, 1, 2], + [3, 4, 5]]), + DNDarray([[ 6, 7, 8], + [ 9, 10, 11]]) + ] + >>> ht.split(x, [2, 3, 5]) + [ DNDarray([[0, 1, 2], + [3, 4, 5]]), + DNDarray([[6, 7, 8]] + DNDarray([[ 9, 10, 11]]), + DNDarray([]) + ] + >>> ht.split(x, [1, 2], 1) + [ DNDarray([[0], + [3], + [6], + [9]]), + DNDarray([[ 1], + [ 4], + [ 7], + [10]], + DNDarray([[ 2], + [ 5], + [ 8], + [11]]) + ] + + """ + # sanitize ary + sanitation.sanitize_in(ary) + + # sanitize axis + if not isinstance(axis, int): + raise TypeError("Expected `axis` to be an integer, but was {}".format(type(axis))) + if axis < 0 or axis > len(ary.gshape) - 1: + raise ValueError( + "Invalid input for `axis`. Valid range is between 0 and {}, but was {}".format( + len(ary.gshape) - 1, axis + ) + ) + + # sanitize indices_or_sections + if isinstance(indices_or_sections, int): + if ary.gshape[axis] % indices_or_sections != 0: + raise ValueError( + "DNDarray with shape {} can't be divided equally into {} chunks along axis {}".format( + ary.gshape, indices_or_sections, axis + ) + ) + # np to torch mapping - calculate size of resulting data chunks + indices_or_sections_t = ary.gshape[axis] // indices_or_sections + + elif isinstance(indices_or_sections, (list, tuple, dndarray.DNDarray)): + if isinstance(indices_or_sections, (list, tuple)): + indices_or_sections = factories.array(indices_or_sections) + if len(indices_or_sections.gshape) != 1: + raise ValueError( + "Expected indices_or_sections to be 1-dimensional, but was {}-dimensional instead.".format( + len(indices_or_sections.gshape) - 1 + ) + ) + else: + raise TypeError( + "Expected `indices_or_sections` to be array_like (DNDarray, list or tuple), but was {}".format( + type(indices_or_sections) + ) + ) + + # start of actual algorithm + + if ary.split == axis and ary.is_distributed(): + + if isinstance(indices_or_sections, int): + # CASE 1 number of processes == indices_or_selections -> split already done due to distribution + if ary.comm.size == indices_or_sections: + new_lshape = list(ary.lshape) + new_lshape[axis] = 0 + sub_arrays_t = [ + torch.empty(new_lshape) if i != ary.comm.rank else ary._DNDarray__array + for i in range(indices_or_sections) + ] + + # # CASE 2 number of processes != indices_or_selections -> reorder (and split) chunks correctly + else: + # no data + if ary.lshape[axis] == 0: + sub_arrays_t = [torch.empty(ary.lshape) for i in range(indices_or_sections)] + else: + offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + idx_frst_chunk_affctd = offset // indices_or_sections_t + left_data_chunk = indices_or_sections_t - (offset % indices_or_sections_t) + left_data_process = ary.lshape[axis] + + new_indices = torch.zeros(indices_or_sections, dtype=int) + + if left_data_chunk >= left_data_process: + new_indices[idx_frst_chunk_affctd] = left_data_process + else: + new_indices[idx_frst_chunk_affctd] = left_data_chunk + left_data_process -= left_data_chunk + idx_frst_chunk_affctd += 1 + + # calculate chunks which can be filled completely + left_chunks_to_fill = left_data_process // indices_or_sections_t + new_indices[ + idx_frst_chunk_affctd : (left_chunks_to_fill + idx_frst_chunk_affctd) + ] = indices_or_sections_t + + # assign residual to following process + new_indices[left_chunks_to_fill + idx_frst_chunk_affctd] = ( + left_data_process % indices_or_sections_t + ) + + sub_arrays_t = torch.split(ary._DNDarray__array, new_indices.tolist(), axis) + # indices or sections == DNDarray + else: + if indices_or_sections.split is not None: + warnings.warn( + "`indices_or_sections` might not be distributed (along axis {}) if `ary` is not distributed.\n" + "`indices_or_sections` will be copied with new split axis None.".format( + indices_or_sections.split + ) + ) + indices_or_sections = resplit(indices_or_sections, None) + + offset, local_shape, slices = ary.comm.chunk(ary.gshape, axis) + slice_axis = slices[axis] + + # reduce information to the (chunk) relevant + indices_or_sections_t = indexing.where( + indices_or_sections <= slice_axis.start, slice_axis.start, indices_or_sections + ) + + indices_or_sections_t = indexing.where( + indices_or_sections_t >= slice_axis.stop, slice_axis.stop, indices_or_sections_t + ) + + # np to torch mapping + + # 2. add first and last value to DNDarray + # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes + indices_or_sections_t = arithmetics.diff( + indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop + ) + indices_or_sections_t = factories.array( + indices_or_sections_t, + dtype=types.int64, + is_split=indices_or_sections_t.split, + comm=indices_or_sections_t.comm, + device=indices_or_sections_t.device, + ) + + # 4. transform the result into a list (torch requirement) + indices_or_sections_t = indices_or_sections_t.tolist() + + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + else: + if isinstance(indices_or_sections, int): + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + else: + if indices_or_sections.split is not None: + warnings.warn( + "`indices_or_sections` might not be distributed (along axis {}) if `ary` is not distributed.\n" + "`indices_or_sections` will be copied with new split axis None.".format( + indices_or_sections.split + ) + ) + indices_or_sections = resplit(indices_or_sections, None) + + # np to torch mapping + + # 1. replace all values out of range with gshape[axis] to generate size 0 + indices_or_sections_t = indexing.where( + indices_or_sections <= ary.gshape[axis], indices_or_sections, ary.gshape[axis] + ) + + # 2. add first and last value to DNDarray + # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes + indices_or_sections_t = arithmetics.diff( + indices_or_sections_t, prepend=0, append=ary.gshape[axis] + ) + indices_or_sections_t = factories.array( + indices_or_sections_t, + dtype=types.int64, + is_split=indices_or_sections_t.split, + comm=indices_or_sections_t.comm, + device=indices_or_sections_t.device, + ) + + # 4. transform the result into a list (torch requirement) + indices_or_sections_t = indices_or_sections_t.tolist() + + sub_arrays_t = torch.split(ary._DNDarray__array, indices_or_sections_t, axis) + + sub_arrays_ht = [ + factories.array( + sub_DNDarray, dtype=ary.dtype, is_split=ary.split, device=ary.device, comm=ary.comm + ) + for sub_DNDarray in sub_arrays_t + ] + + for sub_DNDarray in sub_arrays_ht: + sub_DNDarray.balance_() + + return sub_arrays_ht + + def squeeze(x, axis=None): """ Remove single-dimensional entries from the shape of a tensor. @@ -2483,6 +2893,79 @@ def unique(a, sorted=False, return_inverse=False, axis=None): return return_value +def vsplit(ary, indices_or_sections): + """ + Split array into multiple sub-DNDNarrays along the 1st axis (vertically/row-wise). + Note that this function returns copies and not views into `ary`. + + Parameters + ---------- + ary : DNDarray + DNDArray to be divided into sub-DNDarrays. + indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple) + If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 1st axis. + If such a split is not possible, an error is raised. + If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 1st axis + the array is split. + If an index exceeds the dimension of the array along the 1st axis, an empty sub-DNDarray is returned correspondingly. + + Returns + ------- + sub_arrays : list of DNDarrays + A list of sub-DNDarrays as copies of parts of `ary`. + + Notes + ----- + Please refer to the split documentation. hsplit is equivalent to split with `axis=0`, + the array is always split along the first axis regardless of the array dimension. + + Raises + ------ + ValueError + If `indices_or_sections` is given as integer, but a split does not result in equal division. + + See Also + -------- + :function:`split` + + Examples + -------- + >>> x = ht.arange(24).reshape((4, 3, 2)) + >>> ht.vsplit(x, 2) + [ + DNDarray([[[ 0, 1], + [ 2, 3], + [ 4, 5]], + [[ 6, 7], + [ 8, 9], + [10, 11]]]), + DNDarray([[[12, 13], + [14, 15], + [16, 17]], + [[18, 19], + [20, 21], + [22, 23]]]) + ] + + >>> ht.vsplit(x, [1, 3]) + [ + DNDarray([[[0, 1], + [2, 3], + [4, 5]]]), + DNDarray([[[ 6, 7], + [ 8, 9], + [10, 11]], + [[12, 13], + [14, 15], + [16, 17]]]), + DNDarray([[[18, 19], + [20, 21], + [22, 23]]])] + + """ + return split(ary, indices_or_sections, 0) + + def resplit(arr, axis=None): """ Out-of-place redistribution of the content of the tensor. Allows to "unsplit" (i.e. gather) all values from all diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index d16f297ed3..ed7882b6d5 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -725,6 +725,61 @@ def test_diagonal(self): numpy_args={"axis1": 0, "axis2": 1}, ) + def test_dsplit(self): + # for further testing, see test_split + data_ht = ht.arange(24).reshape((2, 3, 4)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.dsplit(data_ht, 2) + comparison = np.dsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.dsplit(data_ht, (0, 1)) + comparison = np.dsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.dsplit(data_ht, [0, 1]) + comparison = np.dsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.dsplit(data_ht, ht.array([0, 1])) + comparison = np.dsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.dsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.dsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_expand_dims(self): # vector data a = ht.arange(10) @@ -951,6 +1006,115 @@ def test_flipud(self): ) self.assertTrue(ht.equal(ht.resplit(ht.flipud(c), 0), r_c)) + def test_hsplit(self): + # for further testing, see test_split + # 1-dimensional array (as forbidden in split) + data_ht = ht.arange(24) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.hsplit(data_ht, 2) + comparison = np.hsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.hsplit(data_ht, (0, 1)) + comparison = np.hsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.hsplit(data_ht, [0, 1]) + comparison = np.hsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1])) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + data_ht = ht.arange(24).reshape((2, 4, 3)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.hsplit(data_ht, 2) + comparison = np.hsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.hsplit(data_ht, (0, 1)) + comparison = np.hsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.hsplit(data_ht, [0, 1]) + comparison = np.hsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1])) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.hsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.hsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_hstack(self): # cases to test: # MM=================================== @@ -2167,6 +2331,220 @@ def test_sort(self): if rank == i: self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all()) + def test_split(self): + # ==================================== + # UNDISTRIBUTED CASE + # ==================================== + # axis = 0 + # ==================================== + data_ht = ht.arange(24).reshape((2, 3, 4)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.split(data_ht, (0, 1)) + comparison = np.split(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.split(data_ht, [0, 1]) + comparison = np.split(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([0, 1])) + comparison = np.split(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.split(data_ht, ht.array([0, 1], split=0)) + comparison = np.split(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # ==================================== + # axis != 0 (2 in this case) + # ==================================== + # indices_or_sections = int + result = ht.split(data_ht, 2, 2) + comparison = np.split(data_np, 2, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.split(data_ht, (0, 1)) + comparison = np.split(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # exceptions + with self.assertRaises(TypeError): + ht.split([1, 2, 3, 4], 2) + with self.assertRaises(TypeError): + ht.split(data_ht, "2") + with self.assertRaises(TypeError): + ht.split(data_ht, 2, "0") + with self.assertRaises(ValueError): + ht.split(data_ht, 2, -1) + with self.assertRaises(ValueError): + ht.split(data_ht, 2, 3) + with self.assertRaises(ValueError): + ht.split(data_ht, 5) + with self.assertRaises(ValueError): + ht.split(data_ht, [[0, 1]]) + + # ==================================== + # DISTRIBUTED CASE + # ==================================== + # axis == ary.split + # ==================================== + data_ht = ht.arange(120, split=0).reshape((4, 5, 6)) + data_np = data_ht.numpy() + + # indices = int + result = ht.split(data_ht, 2) + comparison = np.split(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) + + # larger example + data_ht_large = ht.arange(160, split=0).reshape((8, 5, 4)) + data_np_large = data_ht_large.numpy() + + # indices = int + result = ht.split(data_ht_large, 2) + comparison = np.split(data_np_large, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assertTrue((ht.array(comparison[i]) == result[i]).all()) + + # indices_or_sections = tuple + result = ht.split(data_ht, (1, 3, 5)) + comparison = np.split(data_np, (1, 3, 5)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.split(data_ht, [1, 3, 5]) + comparison = np.split(data_np, [1, 3, 5]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([1, 3, 5])) + comparison = np.split(data_np, np.array([1, 3, 5])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.split(data_ht, ht.array([1, 3, 5], split=0)) + comparison = np.split(data_np, np.array([1, 3, 5])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # ==================================== + # axis != ary.split + # ==================================== + # indices_or_sections = int + result = ht.split(data_ht, 2, 2) + comparison = np.split(data_np, 2, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.split(data_ht, [3, 4, 6], 2) + comparison = np.split(data_np, [3, 4, 6], 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.split(data_ht, ht.array([3, 4, 6]), 2) + comparison = np.split(data_np, np.array([3, 4, 6]), 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + indices = ht.array([3, 4, 6], split=0) + result = ht.split(data_ht, indices, 2) + comparison = np.split(data_np, np.array([3, 4, 6]), 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_resplit(self): if ht.MPI_WORLD.size > 1: # resplitting with same axis, should leave everything unchanged @@ -2600,6 +2978,61 @@ def test_unique(self): res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + def test_vsplit(self): + # for further testing, see test_split + data_ht = ht.arange(24).reshape((4, 3, 2)) + data_np = data_ht.numpy() + + # indices_or_sections = int + result = ht.vsplit(data_ht, 2) + comparison = np.vsplit(data_np, 2) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = tuple + result = ht.vsplit(data_ht, (0, 1)) + comparison = np.vsplit(data_np, (0, 1)) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = list + result = ht.vsplit(data_ht, [0, 1]) + comparison = np.vsplit(data_np, [0, 1]) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = undistributed DNDarray + result = ht.vsplit(data_ht, ht.array([0, 1])) + comparison = np.vsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + + # indices_or_sections = distributed DNDarray + result = ht.vsplit(data_ht, ht.array([0, 1], split=0)) + comparison = np.vsplit(data_np, np.array([0, 1])) + + self.assertTrue(len(result) == len(comparison)) + + for i in range(len(result)): + self.assertIsInstance(result[i], ht.DNDarray) + self.assert_array_equal(result[i], comparison[i]) + def test_vstack(self): # cases to test: # MM===================================