diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d8a356fac5..267787f455 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -37,7 +37,7 @@ These tools only profile the memory used by each process, not the entire functio - with `split=None` and `split not None` Python has an embedded profiler: https://docs.python.org/3.9/library/profile.html -Again, this will only provile the performance on each process. Printing the results with many processes +Again, this will only profile the performance on each process. Printing the results with many processes my be illegible. It may be easiest to save the output of each to a file. ---> diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bd8162561..1478ae8712 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,9 @@ - [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element. ## Feature Additions -### Linear Algebra -- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` + +### Arithmetics + - - [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps. ## Feature additions ### Communication @@ -25,6 +26,7 @@ # Feature additions ### Linear Algebra - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()` +- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm` - [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross` - [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det` diff --git a/heat/cluster/tests/test_kmedoids.py b/heat/cluster/tests/test_kmedoids.py index c3969b7655..523e4a8fee 100644 --- a/heat/cluster/tests/test_kmedoids.py +++ b/heat/cluster/tests/test_kmedoids.py @@ -50,7 +50,7 @@ def create_spherical_dataset( cluster4 = ht.stack((x - 2 * offset, y - 2 * offset, z - 2 * offset), axis=1) data = ht.concatenate((cluster1, cluster2, cluster3, cluster4), axis=0) - # Note: enhance when shuffel is available + # Note: enhance when shuffle is available return data def test_clusterer(self): diff --git a/heat/core/_operations.py b/heat/core/_operations.py index fe5f132393..9fbb53452e 100644 --- a/heat/core/_operations.py +++ b/heat/core/_operations.py @@ -51,147 +51,134 @@ def __binary_op( ------- result: ht.DNDarray A DNDarray containing the results of element-wise operation. + + Warning + ------- + If both operands are distributed, they must be distributed along the same dimension, i.e. `t1.split = t2.split`. + + MPI communication is necessary when both operands are distributed along the same dimension, but the distribution maps do not match. E.g.: + ``` + a = ht.ones(10000, split=0) + b = ht.zeros(10000, split=0) + c = a[:-1] + b[1:] + ``` + In such cases, one of the operands is redistributed OUT-OF-PLACE to match the distribution map of the other operand. + The operand determining the resulting distribution is chosen as follows: + 1) split is preferred to no split + 2) no (shape)-broadcasting in the split dimension if not necessary + 3) t1 is preferred to t2 """ + # Check inputs + if not np.isscalar(t1) and not isinstance(t1, DNDarray): + raise TypeError( + "Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t1)) + ) + if not np.isscalar(t2) and not isinstance(t2, DNDarray): + raise TypeError( + "Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t2)) + ) promoted_type = types.result_type(t1, t2).torch_type() - if np.isscalar(t1): + # Make inputs Dndarrays + if np.isscalar(t1) and np.isscalar(t2): try: - t1 = factories.array(t1, device=t2.device if isinstance(t2, DNDarray) else None) + t1 = factories.array(t1) + t2 = factories.array(t2) except (ValueError, TypeError): - raise TypeError("Data type not supported, input was {}".format(type(t1))) - - if np.isscalar(t2): - try: - t2 = factories.array(t2) - except (ValueError, TypeError): - raise TypeError( - "Only numeric scalars are supported, but input was {}".format(type(t2)) - ) - output_shape = (1,) - output_split = None - output_device = t2.device - output_comm = MPI_WORLD - elif isinstance(t2, DNDarray): - output_shape = t2.shape - output_split = t2.split - output_device = t2.device - output_comm = t2.comm - else: raise TypeError( - "Only tensors and numeric scalars are supported, but input was {}".format(type(t2)) - ) - - if t1.dtype != t2.dtype: - t1 = t1.astype(t2.dtype) - - elif isinstance(t1, DNDarray): - if np.isscalar(t2): - try: - t2 = factories.array(t2, device=t1.device) - output_shape = t1.shape - output_split = t1.split - output_device = t1.device - output_comm = t1.comm - except (ValueError, TypeError): - raise TypeError("Data type not supported, input was {}".format(type(t2))) - - elif isinstance(t2, DNDarray): - if t1.split is None: - t1 = factories.array( - t1, split=t2.split, copy=False, comm=t1.comm, device=t1.device, ndmin=-t2.ndim - ) - elif t2.split is None: - t2 = factories.array( - t2, split=t1.split, copy=False, comm=t2.comm, device=t2.device, ndmin=-t1.ndim - ) - elif t1.split != t2.split: - # It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0 - # and split=1 - raise NotImplementedError("Not implemented for other splittings") - - output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) - output_split = t1.split - output_device = t1.device - output_comm = t1.comm - - if t1.split is not None: - if t1.shape[t1.split] == 1 and t1.comm.is_distributed(): - # warnings.warn( - # "Broadcasting requires transferring data of first operator between MPI ranks!" - # ) - color = 0 if t1.comm.rank < t2.shape[t1.split] else 1 - newcomm = t1.comm.Split(color, t1.comm.rank) - if t1.comm.rank > 0 and color == 0: - t1.larray = torch.zeros( - t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device - ) - newcomm.Bcast(t1) - newcomm.Free() - - if t2.split is not None: - if t2.shape[t2.split] == 1 and t2.comm.is_distributed(): - # warnings.warn( - # "Broadcasting requires transferring data of second operator between MPI ranks!" - # ) - color = 0 if t2.comm.rank < t1.shape[t2.split] else 1 - newcomm = t2.comm.Split(color, t2.comm.rank) - if t2.comm.rank > 0 and color == 0: - t2.larray = torch.zeros( - t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device - ) - newcomm.Bcast(t2) - newcomm.Free() - - else: - raise TypeError( - "Only tensors and numeric scalars are supported, but input was {}".format(type(t2)) - ) - else: - raise NotImplementedError("Not implemented for non scalar") - - # sanitize output - if out is not None: - sanitation.sanitize_out(out, output_shape, output_split, output_device) - - # promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type() - if t1.split is not None: - if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0: - result = t1.larray.type(promoted_type) - else: - result = operation( - t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs + "Data type not supported, inputs were {} and {}".format(type(t1), type(t2)) ) + elif np.isscalar(t1) and isinstance(t2, DNDarray): + try: + t1 = factories.array(t1, device=t2.device, comm=t2.comm) + except (ValueError, TypeError): + raise TypeError("Data type not supported, input was {}".format(type(t1))) + elif isinstance(t1, DNDarray) and np.isscalar(t2): + try: + t2 = factories.array(t2, device=t1.device, comm=t1.comm) + except (ValueError, TypeError): + raise TypeError("Data type not supported, input was {}".format(type(t2))) + + # Make inputs have the same dimensionality + output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape) + # Broadcasting allows additional empty dimensions on the left side + # TODO simplify this once newaxis-indexing is supported to get rid of the loops + while len(t1.shape) < len(output_shape): + t1 = t1.expand_dims(axis=0) + while len(t2.shape) < len(output_shape): + t2 = t2.expand_dims(axis=0) + # t1 = t1[tuple([None] * (len(output_shape) - t1.ndim))] + # t2 = t2[tuple([None] * (len(output_shape) - t2.ndim))] + # print(t1.lshape, t2.lshape) + + def __get_out_params(target, other=None, map=None): + """ + Getter for the output parameters of a binary operation with target distribution. + If `other` is provided, its distribution will be matched to `target` or, if provided, + redistributed according to `map`. + + Parameters + ---------- + target : DNDarray + DNDarray determining the parameters + other : DNDarray + DNDarray to be adapted + map : Tensor + lshape_map `other` should be matched to. Defaults to `target.lshape_map` + + Returns + ------- + Tuple + split, device, comm, balanced, [other] + """ + if other is not None: + if out is None: + other = sanitation.sanitize_distribution(other, target=target, diff_map=map) + return target.split, target.device, target.comm, target.balanced, other + return target.split, target.device, target.comm, target.balanced + + if t1.split is not None and t1.shape[t1.split] == output_shape[t1.split]: # t1 is "dominant" + output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(t1, t2) + elif t2.split is not None and t2.shape[t2.split] == output_shape[t2.split]: # t2 is "dominant" + output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(t2, t1) + elif t1.split is not None: + # t1 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data + lmap = t1.lshape_map + idx = lmap[:, t1.split].nonzero(as_tuple=True)[0] + lmap[idx.item(), t1.split] = output_shape[t1.split] + output_split, output_device, output_comm, output_balanced, t2 = __get_out_params( + t1, t2, map=lmap + ) elif t2.split is not None: - - if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0: - result = t2.larray.type(promoted_type) - else: - result = operation( - t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs - ) - else: - result = operation( - t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs + # t2 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data + lmap = t2.lshape_map + idx = lmap[:, t2.split].nonzero(as_tuple=True)[0] + lmap[idx.item(), t2.split] = output_shape[t2.split] + output_split, output_device, output_comm, output_balanced, t1 = __get_out_params( + t2, other=t1, map=lmap ) - - if not isinstance(result, torch.Tensor): - result = torch.tensor(result, device=output_device.torch_device) + else: # both are not split + output_split, output_device, output_comm, output_balanced = __get_out_params(t1) if out is not None: - out_dtype = out.dtype - out.larray = result - out._DNDarray__comm = output_comm - out = out.astype(out_dtype) + sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm) + t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out) + out.larray[:] = operation( + t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs + ) return out + # print(t1.lshape, t2.lshape) + + result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs) return DNDarray( result, output_shape, types.heat_type_of(result), output_split, - output_device, - output_comm, - balanced=None, + device=output_device, + comm=output_comm, + balanced=output_balanced, ) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 2580d09090..539dc5e604 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -507,7 +507,7 @@ def balance_(self) -> DNDarray: [1/2] (7, 2) (2, 2) [2/2] (7, 2) (2, 2) """ - if self.is_balanced(): + if self.is_balanced(force_check=True): return self.redistribute_() @@ -582,7 +582,7 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor: result. Otherwise, create the lshape_map """ if not force_check and self.__lshape_map is not None: - return self.__lshape_map + return self.__lshape_map.clone() lshape_map = torch.zeros( (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device @@ -590,7 +590,7 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor: if not self.is_distributed: lshape_map[:] = torch.tensor(self.gshape, device=self.device.torch_device) return lshape_map - if self.is_balanced(): + if self.is_balanced(force_check=True): for i in range(self.comm.size): _, lshape, _ = self.comm.chunk(self.gshape, self.split, rank=i) lshape_map[i, :] = torch.tensor(lshape, device=self.device.torch_device) @@ -601,7 +601,7 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor: self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) self.__lshape_map = lshape_map - return lshape_map + return lshape_map.clone() def __float__(self) -> DNDarray: """ @@ -709,7 +709,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar """ this loop handles all other cases. DNDarrays which make it to here refer to advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors are cast into lists here by PyTorch. lists mean advanced indexing will be used""" - h = [slice(None, None, None)] * self.ndim + h = [slice(None, None, None)] * max(self.ndim, 1) if isinstance(key, DNDarray): key = manipulations.resplit(key) if key.larray.dtype in [torch.bool, torch.uint8]: @@ -751,11 +751,17 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar kend = key[ell_ind + 1 :] slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) key = kst + slices + kend + else: + key = key + [slice(None)] * (self.ndim - len(key)) - key = tuple(key) + self_proxy = self.__torch_proxy__() + for i in range(len(key)): + if self.__key_adds_dimension(key, i, self_proxy): + key[i] = slice(None) + return self.expand_dims(i)[tuple(key)] + key = tuple(key) # assess final global shape - self_proxy = self.__torch_proxy__() gout_full = list(self_proxy[key].shape) # calculate new split axis @@ -766,7 +772,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar new_split = 0 else: for i in range(len(key[: self.split + 1])): - if self.__is_key_singular(key, i, self_proxy): + if self.__key_is_singular(key, i, self_proxy): new_split = None if i == self.split else new_split - 1 key = tuple(key) @@ -841,15 +847,12 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # standard slicing along the split axis, # adjust the slice start, stop, and step, then run it on the processes which have the requested data key = list(key) - key_start = key[self.split].start if key[self.split].start is not None else 0 - key_stop = ( - key[self.split].stop - if key[self.split].stop is not None - else self.gshape[self.split] + key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split]) + key_start, key_stop, key_step = ( + key[self.split].start, + key[self.split].stop, + key[self.split].step, ) - if key_stop < 0: - key_stop = self.gshape[self.split] + key[self.split].stop - key_step = key[self.split].step og_key_start = key_start st_pr = torch.where(key_start < chunk_ends)[0] st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size @@ -873,7 +876,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar lout[new_split] = 0 arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device) - elif self.__is_key_singular(key, self.split, self_proxy): + elif self.__key_is_singular(key, self.split, self_proxy): # getting one item along split axis: key = list(key) if isinstance(key[self.split], list): @@ -957,11 +960,17 @@ def is_distributed(self) -> bool: return self.split is not None and self.comm.is_distributed() @staticmethod - def __is_key_singular(key: any, axis: int, self_proxy: torch.Tensor) -> bool: + def __key_is_singular(key: any, axis: int, self_proxy: torch.Tensor) -> bool: # determine if the key gets a singular item - zeros = tuple([0] * (self_proxy.ndim - 1)) + zeros = (0,) * (self_proxy.ndim - 1) return self_proxy[(*zeros[:axis], key[axis], *zeros[axis:])].ndim == 0 + @staticmethod + def __key_adds_dimension(key: any, axis: int, self_proxy: torch.Tensor) -> bool: + # determine if the key adds a new dimension + zeros = (0,) * (self_proxy.ndim - 1) + return self_proxy[(*zeros[:axis], key[axis], *zeros[axis:])].ndim == 2 + def item(self): """ Returns the only element of a 1-element :class:`DNDarray`. diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 37666abbf5..54540485d0 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2891,7 +2891,7 @@ def stack( ValueError If the `DNDarray`s are of different shapes, or if they are split along different axes (`split` attribute). RuntimeError - If the `DNDarrays` reside of different devices, or if they are unevenly distributed across ranks (method `is_balanced()` returns `False`) + If the `DNDarrays` reside on different devices. Notes ----- @@ -2949,83 +2949,52 @@ def stack( [2/2] [18, 38], [2/2] [19, 39]]]) """ - # sanitation - sanitation.sanitize_sequence(arrays) - + arrays = sanitation.sanitize_sequence(arrays) if len(arrays) < 2: raise ValueError("stack expects a sequence of at least 2 DNDarrays") - for i, array in enumerate(arrays): - sanitation.sanitize_in(array) - - arrays_metadata = list( - [array.gshape, array.split, array.device, array.balanced] for array in arrays - ) - num_arrays = len(arrays) - # metadata must be identical for all arrays - if arrays_metadata.count(arrays_metadata[0]) != num_arrays: - shapes = list(array.gshape for array in arrays) - if shapes.count(shapes[0]) != num_arrays: - raise ValueError( - "All DNDarrays in sequence must have the same shape, got shapes {}".format(shapes) - ) - splits = list(array.split for array in arrays) - if splits.count(splits[0]) != num_arrays: - raise ValueError( - "All DNDarrays in sequence must have the same split axis, got splits {}" - "Check out the heat.resplit() documentation.".format(splits) - ) - devices = list(array.device for array in arrays) - if devices.count(devices[0]) != num_arrays: - raise RuntimeError( - "DNDarrays in sequence must reside on the same device, got devices {} {} {}".format( - devices, devices[0].device_id, devices[1].device_id - ) - ) - else: - array_shape, array_split, array_device, array_balanced = arrays_metadata[0][:4] - # extract torch tensors - t_arrays = list(array.larray for array in arrays) - # output dtype - t_dtypes = list(t_array.dtype for t_array in t_arrays) - t_array_dtype = t_dtypes[0] - if t_dtypes.count(t_dtypes[0]) != num_arrays: - for d in range(1, len(t_dtypes)): - t_array_dtype = ( - t_array_dtype - if t_array_dtype is t_dtypes[d] - else torch.promote_types(t_array_dtype, t_dtypes[d]) - ) - t_arrays = list(t_array.type(t_array_dtype) for t_array in t_arrays) - array_dtype = types.canonical_heat_type(t_array_dtype) + target = arrays[0] + try: + arrays = sanitation.sanitize_distribution( + *arrays, target=target + ) # also checks target again + except NotImplementedError as e: # transform split axis error to ValueError + raise ValueError(e) - # sanitize axis - axis = stride_tricks.sanitize_axis(array_shape + (num_arrays,), axis) + # extract torch tensors + t_arrays = list(array.larray for array in arrays) # output shape and split - stacked_shape = array_shape[:axis] + (num_arrays,) + array_shape[axis:] - if array_split is not None: - stacked_split = array_split + 1 if axis <= array_split else array_split + axis = stride_tricks.sanitize_axis(target.gshape + (len(arrays),), axis) + stacked_shape = target.gshape[:axis] + (len(arrays),) + target.gshape[axis:] + if target.split is not None: + stacked_split = target.split + 1 if axis <= target.split else target.split else: stacked_split = None # stack locally - t_stacked = torch.stack(t_arrays, dim=axis) + try: + t_stacked = torch.stack(t_arrays, dim=axis) + result_dtype = types.canonical_heat_type(t_stacked.dtype) + except Exception as e: + if "size" in e.args[0] or "shape" in e.args[0]: + raise ValueError(e) + raise e # return stacked DNDarrays if out is not None: - sanitation.sanitize_out(out, stacked_shape, stacked_split, array_device) + sanitation.sanitize_out(out, stacked_shape, stacked_split, target.device) out.larray = t_stacked.type(out.larray.dtype) return out stacked = DNDarray( t_stacked, gshape=stacked_shape, - dtype=array_dtype, + dtype=result_dtype, split=stacked_split, - device=array_device, - comm=arrays[0].comm, - balanced=array_balanced, + device=target.device, + comm=target.comm, + balanced=target.balanced, ) return stacked diff --git a/heat/core/relational.py b/heat/core/relational.py index 3fa1172da0..3b025ce56e 100644 --- a/heat/core/relational.py +++ b/heat/core/relational.py @@ -4,6 +4,7 @@ from __future__ import annotations import torch +import numpy as np from typing import Union @@ -12,6 +13,8 @@ from . import _operations from . import dndarray from . import types +from . import sanitation +from . import factories __all__ = [ "eq", @@ -98,14 +101,78 @@ def equal(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> boo >>> ht.equal(x, 3.0) False """ - result_tensor = _operations.__binary_op(torch.equal, x, y) - - if result_tensor.larray.numel() == 1: - result_value = result_tensor.larray.item() + if np.isscalar(x) and np.isscalar(y): + x = factories.array(x) + y = factories.array(y) + elif isinstance(x, DNDarray) and np.isscalar(y): + if x.gnumel == 1: + return equal(x.item(), y) + return False + # y = factories.full_like(x, fill_value=y) + elif np.isscalar(x) and isinstance(y, DNDarray): + if y.gnumel == 1: + return equal(x, y.item()) + return False + # x = factories.full_like(y, fill_value=x) + else: # elif isinstance(x, DNDarray) and isinstance(y, DNDarray): + if x.gnumel == 1: + return equal(x.item(), y) + elif y.gnumel == 1: + return equal(x, y.item()) + elif not x.comm == y.comm: + raise NotImplementedError("Not implemented for other comms") + elif not x.gshape == y.gshape: + return False + + if x.split is None and y.split is None: + pass + elif x.split is None and y.split is not None: + if y.is_balanced(force_check=False): + x = factories.array(x, split=y.split, copy=False, comm=x.comm, device=x.device) + else: + target_map = y.lshape_map + idx = [slice(None)] * x.ndim + idx[y.split] = slice( + target_map[: x.comm.rank, y.split].sum(), + target_map[: x.comm.rank + 1, y.split].sum(), + ) + x = factories.array( + x.larray[tuple(idx)], is_split=y.split, copy=False, comm=x.comm, device=x.device + ) + elif x.split is not None and y.split is None: + if x.is_balanced(force_check=False): + y = factories.array(y, split=x.split, copy=False, comm=y.comm, device=y.device) + else: + target_map = x.lshape_map + idx = [slice(None)] * y.ndim + idx[x.split] = slice( + target_map[: y.comm.rank, x.split].sum(), + target_map[: y.comm.rank + 1, x.split].sum(), + ) + y = factories.array( + y.larray[tuple(idx)], is_split=x.split, copy=False, comm=y.comm, device=y.device + ) + elif not x.split == y.split: + raise ValueError( + "DNDarrays must have the same split axes, found {} and {}".format(x.split, y.split) + ) + elif not (x.is_balanced(force_check=False) and y.is_balanced(force_check=False)): + x_lmap = x.lshape_map + y_lmap = y.lshape_map + if not torch.equal(x_lmap, y_lmap): + x = x.balance() + y = y.balance() + + result_type = types.result_type(x, y) + x = x.astype(result_type) + y = y.astype(result_type) + + if x.larray.numel() > 0: + result_value = torch.equal(x.larray, y.larray) else: result_value = True - return result_tensor.comm.allreduce(result_value, MPI.LAND) + return x.comm.allreduce(result_value, MPI.LAND) def ge(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray: diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 38c98a0fa8..f5550f8562 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -8,7 +8,7 @@ import warnings from typing import Any, Union, Sequence, List, Tuple -from .communication import MPI +from .communication import MPI, Communication from .dndarray import DNDarray from . import factories @@ -17,6 +17,7 @@ __all__ = [ + "sanitize_distribution", "sanitize_in", "sanitize_infinity", "sanitize_in_tensor", @@ -27,6 +28,134 @@ ] +def sanitize_distribution( + *args: DNDarray, target: DNDarray, diff_map: torch.Tensor = None +) -> Union[DNDarray, Tuple(DNDarray)]: + """ + Distribute every `arg` according to `target.lshape_map` or, if provided, `diff_map`. + After this sanitation, the lshapes are compatible along the split dimension. + `Args` can contain non-distributed DNDarrays, they will be split afterwards, if `target` is split. + + Parameters + ---------- + args : DNDarray + Dndarrays to be distributed + + target : DNDarray + Dndarray used to sanitize the metadata and to, if diff_map is not given, determine the resulting distribution. + + diff_map : torch.Tensor (optional) + Different lshape_map. Overwrites the distribution of the target array. + Used in cases when the target array does not correspond to the actually wanted distribution, + e.g. because it only contains a single element along the split axis and gets broadcast. + + Raises + ------ + TypeError + When an argument is not a ``DNDarray`` or ``None``. + ValueError + When the split-axes or sizes along the split-axis do not match. + + See Also + --------- + :func:`~heat.core.dndarray.create_lshape_map` + Function to create the lshape_map. + """ + out = [] + sanitize_in(target) + target_split = target.split + if diff_map is not None: + sanitize_in_tensor(diff_map) + target_map = diff_map + if target_split is not None: + tmap_split = target_map[:, target_split] + target_size = tmap_split.sum().item() + # Check if the diff_map is balanced + w_size = target_map.shape[0] + tmap_balanced = torch.full_like(tmap_split, fill_value=target_size // w_size) + remainder = target_size % w_size + tmap_balanced[:remainder] += 1 + target_balanced = torch.equal(tmap_balanced, tmap_split) + elif target_split is not None: + target_map = target.lshape_map + target_size = target.shape[target_split] + target_balanced = target.is_balanced(force_check=False) + + for arg in args: + sanitize_in(arg) + if not target.comm == arg.comm: + try: + raise NotImplementedError( + "Not implemented for other comms, found {} and {}".format( + target.comm.name, arg.comm.name + ) + ) + except Exception: + raise NotImplementedError("Not implemented for other comms") + elif target_split is None: + if arg.split is not None: + raise NotImplementedError( + "DNDarrays must have the same split axes, found {} and {}".format( + target_split, arg.split + ) + ) + else: + out.append(arg) + elif arg.shape[target_split] == 1 and target_size > 1: # broadcasting in split-dimension + out.append(arg.resplit(None)) + elif arg.shape[target_split] != target_size: + raise ValueError( + "Cannot distribute to match in split dimension, shapes are {} and {}".format( + target.shape, arg.shape + ) + ) + elif arg.split is None: # undistributed case + if target_balanced: + out.append( + factories.array( + arg, split=target_split, copy=False, comm=arg.comm, device=arg.device + ) + ) + else: + idx = [slice(None)] * arg.ndim + idx[target_split] = slice( + target_map[: arg.comm.rank, target_split].sum(), + target_map[: arg.comm.rank + 1, target_split].sum(), + ) + out.append( + factories.array( + arg.larray[tuple(idx)], + is_split=target_split, + copy=False, + comm=arg.comm, + device=arg.device, + ) + ) + elif arg.split != target_split: + raise NotImplementedError( + "DNDarrays must have the same split axes, found {} and {}".format( + target_split, arg.split + ) + ) + elif not ( + # False + target_balanced + and arg.is_balanced(force_check=False) + ): # Split axes are the same and atleast one is not balanced + current_map = arg.lshape_map + out_map = current_map.clone() + out_map[:, target_split] = target_map[:, target_split] + if not (current_map[:, target_split] == target_map[:, target_split]).all(): + out.append(arg.redistribute(lshape_map=current_map, target_map=out_map)) + else: + out.append(arg) + else: # both are balanced + out.append(arg) + if len(out) == 1: + return out[0] + return tuple(out) + + def sanitize_in(x: Any): """ Verify that input object is ``DNDarray``. @@ -127,7 +256,13 @@ def sanitize_lshape(array: DNDarray, tensor: torch.Tensor): ) -def sanitize_out(out: Any, output_shape: Tuple, output_split: int, output_device: str): +def sanitize_out( + out: Any, + output_shape: Tuple, + output_split: int, + output_device: str, + output_comm: Communication = None, +): """ Validate output buffer ``out``. @@ -145,6 +280,9 @@ def sanitize_out(out: Any, output_shape: Tuple, output_split: int, output_device output_device : Str "cpu" or "gpu" as per location of data + output_comm : Communication + Communication object of the result of the operation + Raises ------ TypeError @@ -155,18 +293,59 @@ def sanitize_out(out: Any, output_shape: Tuple, output_split: int, output_device if not isinstance(out, DNDarray): raise TypeError("expected `out` to be None or a DNDarray, but was {}".format(type(out))) - if out.gshape != output_shape: + out_proxy = out.__torch_proxy__() + out_proxy.names = [ + "split" if (out.split is not None and i == out.split) else "_{}".format(i) + for i in range(out_proxy.ndim) + ] + out_proxy = out_proxy.squeeze() + + check_proxy = torch.ones(1).expand(output_shape) + check_proxy.names = [ + "split" if (output_split is not None and i == output_split) else "_{}".format(i) + for i in range(check_proxy.ndim) + ] + check_proxy = check_proxy.squeeze() + + if out_proxy.shape != check_proxy.shape: raise ValueError( "Expecting output buffer of shape {}, got {}".format(output_shape, out.shape) ) - if out.split is not output_split: + count_split = int(out.split is not None) + int(output_split is not None) + if count_split == 1: raise ValueError( - "Split axis of output buffer is inconsistent with split semantics (see documentation)." + "Split axis of output buffer is inconsistent with split semantics for this operation." ) + elif count_split == 2: + if out.shape[out.split] > 1: # split axis is not squeezed out + if out_proxy.names.index("split") != check_proxy.names.index("split"): + raise ValueError( + "Split axis of output buffer is inconsistent with split semantics for this operation." + ) + else: # split axis is squeezed out + num_dim_before_split = len( + [name for name in out_proxy.names if int(name[1:]) < out.split] + ) + check_num_dim_before_split = len( + [name for name in check_proxy.names if int(name[1:]) < output_split] + ) + if num_dim_before_split != check_num_dim_before_split: + raise ValueError( + "Split axis of output buffer is inconsistent with split semantics for this operation." + ) if out.device is not output_device: raise ValueError( "Device mismatch: out is on {}, should be on {}".format(out.device, output_device) ) + if output_comm is not None and out.comm != output_comm: + try: + raise NotImplementedError( + "Not implemented for other comms, found {} and {}".format( + out.comm.name, output_comm.name + ) + ) + except Exception: + raise NotImplementedError("Not implemented for other comms") def sanitize_sequence( diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index f07af2418a..8d2ac968cf 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -1,9 +1,10 @@ """ -A collection of functions used for inferring or correction things before major computation +A collection of functions used for inferring or correcting things before major computation """ import itertools import numpy as np +import torch from typing import Tuple, Union @@ -41,17 +42,31 @@ def broadcast_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b) ValueError: operands could not be broadcast, input shapes (2, 1) (8, 4, 3) """ - it = itertools.zip_longest(shape_a[::-1], shape_b[::-1], fillvalue=1) - resulting_shape = max(len(shape_a), len(shape_b)) * [None] - for i, (a, b) in enumerate(it): - if a == 1 or b == 1 or a == b: - resulting_shape[i] = max(a, b) - else: - raise ValueError( - "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b) - ) - - return tuple(resulting_shape[::-1]) + try: + resulting_shape = torch.broadcast_shapes(shape_a, shape_b) + except AttributeError: # torch < 1.8 + it = itertools.zip_longest(shape_a[::-1], shape_b[::-1], fillvalue=1) + resulting_shape = max(len(shape_a), len(shape_b)) * [None] + for i, (a, b) in enumerate(it): + if a == 0 and b == 1 or b == 0 and a == 1: + resulting_shape[i] = 0 + elif a == 1 or b == 1 or a == b: + resulting_shape[i] = max(a, b) + else: + raise ValueError( + "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b) + ) + return tuple(resulting_shape[::-1]) + except TypeError: + raise TypeError("operand 1 must be tuple of ints, not {}".format(type(shape_a))) + except NameError: + raise TypeError("operands must be tuples of ints, not {} and {}".format(shape_a, shape_b)) + except RuntimeError: + raise ValueError( + "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b) + ) + + return tuple(resulting_shape) def sanitize_axis( diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py index 1203b57512..e57ce93901 100644 --- a/heat/core/tests/test_arithmetics.py +++ b/heat/core/tests/test_arithmetics.py @@ -21,9 +21,10 @@ def setUpClass(cls): cls.another_tensor = ht.array([[2.0, 2.0], [2.0, 2.0]]) cls.a_split_tensor = cls.another_tensor.copy().resplit_(0) - cls.errorneous_type = (2, 2) + cls.erroneous_type = (2, 2) def test_add(self): + # test basics result = ht.array([[3.0, 4.0], [5.0, 6.0]]) self.assertTrue(ht.equal(ht.add(self.a_scalar, self.a_scalar), ht.float32(4.0))) @@ -45,10 +46,41 @@ def test_add(self): else: self.assertEqual(c.larray.size()[0], 0) + # test with differently distributed DNDarrays + a = ht.ones(10, split=0) + b = ht.zeros(10, split=0) + c = a[:-1] + b[1:] + self.assertTrue((c == 1).all()) + self.assertTrue(c.lshape == a[:-1].lshape) + + c = a[1:-1] + b[1:-1] # test unbalanced + self.assertTrue((c == 1).all()) + self.assertTrue(c.lshape == a[1:-1].lshape) + + # test one unsplit + a = ht.ones(10, split=None) + b = ht.zeros(10, split=0) + c = a[:-1] + b[1:] + self.assertTrue((c == 1).all()) + self.assertEqual(c.lshape, b[1:].lshape) + c = b[:-1] + a[1:] + self.assertTrue((c == 1).all()) + self.assertEqual(c.lshape, b[:-1].lshape) + + # broadcast in split dimension + a = ht.ones((1, 10), split=0) + b = ht.zeros((2, 10), split=0) + c = a + b + self.assertTrue((c == 1).all()) + self.assertTrue(c.lshape == b.lshape) + c = b + a + self.assertTrue((c == 1).all()) + self.assertTrue(c.lshape == b.lshape) + with self.assertRaises(ValueError): ht.add(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.add(self.a_tensor, self.errorneous_type) + ht.add(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.add("T", "s") @@ -76,7 +108,7 @@ def test_bitwise_and(self): with self.assertRaises(ValueError): ht.bitwise_and(an_int_vector, another_int_vector) with self.assertRaises(TypeError): - ht.bitwise_and(self.a_tensor, self.errorneous_type) + ht.bitwise_and(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.bitwise_and("T", "s") with self.assertRaises(TypeError): @@ -112,7 +144,7 @@ def test_bitwise_or(self): with self.assertRaises(ValueError): ht.bitwise_or(an_int_vector, another_int_vector) with self.assertRaises(TypeError): - ht.bitwise_or(self.a_tensor, self.errorneous_type) + ht.bitwise_or(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.bitwise_or("T", "s") with self.assertRaises(TypeError): @@ -148,7 +180,7 @@ def test_bitwise_xor(self): with self.assertRaises(ValueError): ht.bitwise_xor(an_int_vector, another_int_vector) with self.assertRaises(TypeError): - ht.bitwise_xor(self.a_tensor, self.errorneous_type) + ht.bitwise_xor(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.bitwise_xor("T", "s") with self.assertRaises(TypeError): @@ -257,17 +289,16 @@ def test_cumsum(self): def test_diff(self): ht_array = ht.random.rand(20, 20, 20, split=None) arb_slice = [0] * 3 - for dim in range(3): # loop over 3 dimensions + for dim in range(0, 3): # loop over 3 dimensions arb_slice[dim] = slice(None) + tup_arb = tuple(arb_slice) + np_array = ht_array[tup_arb].numpy() for ax in range(dim + 1): # loop over the possible axis values for sp in range(dim + 1): # loop over the possible split values + lp_array = ht.manipulations.resplit(ht_array[tup_arb], sp) # loop to 3 for the number of times to do the diff for nl in range(1, 4): # only generating the number once and then - tup_arb = tuple(arb_slice) - lp_array = ht.manipulations.resplit(ht_array[tup_arb], sp) - np_array = ht_array[tup_arb].numpy() - ht_diff = ht.diff(lp_array, n=nl, axis=ax) np_diff = ht.array(np.diff(np_array, n=nl, axis=ax)) @@ -280,10 +311,11 @@ def test_diff(self): ht_append = ht.ones( append_shape, dtype=lp_array.dtype, split=lp_array.split ) + ht_diff_pend = ht.diff(lp_array, n=nl, axis=ax, prepend=0, append=ht_append) + np_append = np.ones(append_shape, dtype=lp_array.larray.cpu().numpy().dtype) np_diff_pend = ht.array( - np.diff(np_array, n=nl, axis=ax, prepend=0, append=ht_append.numpy()), - dtype=ht_diff_pend.dtype, + np.diff(np_array, n=nl, axis=ax, prepend=0, append=np_append) ) self.assertTrue(ht.equal(ht_diff_pend, np_diff_pend)) self.assertEqual(ht_diff_pend.split, sp) @@ -333,7 +365,7 @@ def test_div(self): with self.assertRaises(ValueError): ht.div(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.div(self.a_tensor, self.errorneous_type) + ht.div(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.div("T", "s") @@ -362,7 +394,7 @@ def test_fmod(self): with self.assertRaises(ValueError): ht.fmod(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.fmod(self.a_tensor, self.errorneous_type) + ht.fmod(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.fmod("T", "s") @@ -419,7 +451,7 @@ def test_mul(self): with self.assertRaises(ValueError): ht.mul(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.mul(self.a_tensor, self.errorneous_type) + ht.mul(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.mul("T", "s") @@ -464,7 +496,7 @@ def test_pow(self): with self.assertRaises(ValueError): ht.pow(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.pow(self.a_tensor, self.errorneous_type) + ht.pow(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.pow("T", "s") @@ -601,7 +633,7 @@ def test_sub(self): with self.assertRaises(ValueError): ht.sub(self.a_tensor, self.another_vector) with self.assertRaises(TypeError): - ht.sub(self.a_tensor, self.errorneous_type) + ht.sub(self.a_tensor, self.erroneous_type) with self.assertRaises(TypeError): ht.sub("T", "s") diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 0bfb1dbfdb..8aa8a35920 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1413,6 +1413,19 @@ def test_setitem_getitem(self): setting = ht.zeros((8, 8), split=1) x[1:-1, 1:-1] = setting + for split in [None, 0, 1, 2]: + for new_dim in [0, 1, 2]: + for add in [np.newaxis, None]: + arr = ht.ones((4, 3, 2), split=split, dtype=ht.int32) + check = torch.ones((4, 3, 2), dtype=torch.int32) + idx = [slice(None), slice(None), slice(None)] + idx[new_dim] = add + idx = tuple(idx) + arr = arr[idx] + check = check[idx] + self.assertTrue(arr.shape == check.shape) + self.assertTrue(arr.lshape[new_dim] == 1) + def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None) self.assertEqual(a.size, 10 * 10 * 10) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index ef74b39a4d..3815ae3e73 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3245,7 +3245,7 @@ def test_stack(self): with self.assertRaises(ValueError): ht.stack((ht_a_split, ht_b_wrong_split, ht_c_split)) with self.assertRaises(ValueError): - ht.stack((ht_a_split, ht_b, ht_c_split)) + ht.stack((ht_a_split, ht_b.resplit(1), ht_c_split)) out_wrong_type = torch.empty((3, 5, 4), dtype=torch.float32) with self.assertRaises(TypeError): ht.stack((ht_a_split, ht_b_split, ht_c_split), out=out_wrong_type) diff --git a/heat/core/tests/test_operations.py b/heat/core/tests/test_operations.py index cce5015b64..ac42ea967c 100644 --- a/heat/core/tests/test_operations.py +++ b/heat/core/tests/test_operations.py @@ -1,6 +1,7 @@ import torch import heat as ht +import numpy as np from .test_suites.basic_test import TestCase @@ -64,7 +65,7 @@ def test___binary_bit_op_broadcast(self): self.assertEqual(result.shape, (1, 2)) # broadcast with unequal dimensions and two splitted tensors - left_tensor = ht.ones((4, 1, 3, 1, 2), split=0, dtype=torch.uint8) + left_tensor = ht.ones((4, 1, 3, 1, 2), split=2, dtype=torch.uint8) right_tensor = ht.ones((1, 3, 1), split=0, dtype=torch.uint8) result = left_tensor & right_tensor self.assertEqual(result.shape, (4, 1, 3, 3, 2)) @@ -77,3 +78,26 @@ def test___binary_bit_op_broadcast(self): ht.bitwise_or( ht.ones((1, 2), dtype=ht.int32, split=0), ht.ones((1, 2), dtype=ht.int32, split=1) ) + + a = ht.ones((4, 4), split=None) + b = ht.zeros((4, 4), split=0) + self.assertTrue(ht.equal(a * b, b)) + self.assertTrue(ht.equal(b * a, b)) + self.assertTrue(ht.equal(a[0] * b[0], b[0])) + self.assertTrue(ht.equal(b[0] * a[0], b[0])) + self.assertTrue(ht.equal(a * b[0:1], b)) + self.assertTrue(ht.equal(b[0:1] * a, b)) + self.assertTrue(ht.equal(a[0:1] * b, b)) + self.assertTrue(ht.equal(b * a[0:1], b)) + + c = ht.array([1, 2, 3, 4], comm=ht.MPI_SELF) + with self.assertRaises(NotImplementedError): + b + c + with self.assertRaises(TypeError): + ht.minimum(a, np.float128(1)) + with self.assertRaises(TypeError): + ht.minimum(np.float128(1), a) + with self.assertRaises(NotImplementedError): + a.resplit(1) * b + with self.assertRaises(ValueError): + a[2:] * b diff --git a/heat/core/tests/test_relational.py b/heat/core/tests/test_relational.py index 654e41f9ce..e3f72522f8 100644 --- a/heat/core/tests/test_relational.py +++ b/heat/core/tests/test_relational.py @@ -41,9 +41,25 @@ def test_eq(self): def test_equal(self): self.assertTrue(ht.equal(self.a_tensor, self.a_tensor)) + self.assertFalse(ht.equal(self.a_tensor[1:], self.a_tensor)) + self.assertFalse(ht.equal(self.a_split_tensor[1:], self.a_tensor[1:])) + self.assertFalse(ht.equal(self.a_tensor[1:], self.a_split_tensor[1:])) self.assertFalse(ht.equal(self.a_tensor, self.another_tensor)) self.assertFalse(ht.equal(self.a_tensor, self.a_scalar)) + self.assertFalse(ht.equal(self.a_scalar, self.a_tensor)) + self.assertFalse(ht.equal(self.a_scalar, self.a_tensor[0, 0])) + self.assertFalse(ht.equal(self.a_tensor[0, 0], self.a_scalar)) self.assertFalse(ht.equal(self.another_tensor, self.a_scalar)) + self.assertTrue(ht.equal(self.split_ones_tensor[:, 0], self.split_ones_tensor[:, 1])) + self.assertTrue(ht.equal(self.split_ones_tensor[:, 1], self.split_ones_tensor[:, 0])) + self.assertFalse(ht.equal(self.a_tensor, self.a_split_tensor)) + self.assertFalse(ht.equal(self.a_split_tensor, self.a_tensor)) + + arr = ht.array([[1, 2], [3, 4]], comm=ht.MPI_SELF) + with self.assertRaises(NotImplementedError): + ht.equal(self.a_tensor, arr) + with self.assertRaises(ValueError): + ht.equal(self.a_split_tensor, self.a_split_tensor.resplit(1)) def test_ge(self): result = ht.uint8([[False, True], [True, True]]) diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index a26305590c..5ee00645f3 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -782,24 +782,24 @@ def test_maximum(self): random_volume_3 = ht.array([]) with self.assertRaises(ValueError): ht.maximum(random_volume_1, random_volume_3) - random_volume_3 = ht.random.randn(4, 2, 3, split=0) + random_volume_4 = ht.random.randn(4, 2, 3, split=0) with self.assertRaises(ValueError): - ht.maximum(random_volume_1, random_volume_3) - random_volume_3 = torch.ones(12, 3, 3, device=self.device.torch_device) + ht.maximum(random_volume_1, random_volume_4) + random_volume_5 = torch.ones(12, 3, 3, device=self.device.torch_device) with self.assertRaises(TypeError): - ht.maximum(random_volume_1, random_volume_3) - random_volume_3 = ht.random.randn(6, 3, 3, split=1) + ht.maximum(random_volume_1, random_volume_5) + random_volume_6 = ht.random.randn(6, 3, 3, split=1) with self.assertRaises(NotImplementedError): - ht.maximum(random_volume_1, random_volume_3) - output = torch.ones(12, 3, 3, device=self.device.torch_device) + ht.maximum(random_volume_1, random_volume_6) + output1 = torch.ones(12, 3, 3, device=self.device.torch_device) with self.assertRaises(TypeError): - ht.maximum(random_volume_1, random_volume_2, out=output) - output = ht.ones((12, 4, 3)) + ht.maximum(random_volume_1, random_volume_2, out=output1) + output2 = ht.ones((12, 4, 3)) with self.assertRaises(ValueError): - ht.maximum(random_volume_1, random_volume_2, out=output) - output = ht.ones((6, 3, 3), split=1) + ht.maximum(random_volume_1, random_volume_2, out=output2) + output3 = ht.ones((6, 3, 3), split=1) with self.assertRaises(ValueError): - ht.maximum(random_volume_1, random_volume_2, out=output) + ht.maximum(random_volume_1, random_volume_2, out=output3) def test_mean(self): array_0_len = 5 @@ -1053,7 +1053,7 @@ def test_minimum(self): with self.assertRaises(TypeError): ht.minimum(random_volume_1, random_volume_3) random_volume_3 = np.array(7.2) - with self.assertRaises(NotImplementedError): + with self.assertRaises(TypeError): ht.minimum(random_volume_3, random_volume_1) random_volume_3 = ht.random.randn(6, 3, 3, split=1) with self.assertRaises(NotImplementedError): @@ -1067,6 +1067,12 @@ def test_minimum(self): output = ht.ones((6, 3, 3), split=1) with self.assertRaises(ValueError): ht.minimum(random_volume_1, random_volume_2, out=output) + output = ht.ones((6, 3, 3), split=None, comm=ht.MPI_SELF) + with self.assertRaises(ValueError): + ht.minimum(random_volume_1, random_volume_2, out=output) + output = ht.ones((6, 3, 3), split=0, comm=ht.MPI_SELF) + with self.assertRaises(NotImplementedError): + ht.minimum(random_volume_1, random_volume_2, out=output) def test_percentile(self): # test local, distributed, split/axis combination, no data on process diff --git a/heat/core/types.py b/heat/core/types.py index dc7c72fa2c..8b1421c40f 100644 --- a/heat/core/types.py +++ b/heat/core/types.py @@ -1052,5 +1052,5 @@ def _init(self, dtype: Type[datatype]): return self -# tensor is imported at the very end to break circular dependency +# dndarray is imported at the very end to break circular dependency from . import dndarray