diff --git a/heat/core/operations.py b/heat/core/operations.py
index 64dfed92da..c422c5a1fb 100644
--- a/heat/core/operations.py
+++ b/heat/core/operations.py
@@ -11,6 +11,7 @@
 
 __all__ = []
 
+
 def __binary_op(operation, t1, t2):
     """
     Generic wrapper for element-wise binary operations of two operands (either can be tensor or scalar).
@@ -99,7 +100,7 @@ def __binary_op(operation, t1, t2):
             warnings.warn('Broadcasting requires transferring data of second operator between MPI ranks!')
             if t2.comm.rank > 0:
                 t2._DNDarray__array = torch.zeros(t2.shape, dtype=t2.dtype.torch_type())
-            t2.comm.Bcast(t2)
+            t2.comm.Bcast(t2)
         else:
             raise TypeError('Only tensors and numeric scalars are supported, but input was {}'.format(type(t2)))
 
@@ -110,7 +111,6 @@ def __binary_op(operation, t1, t2):
     else:
         raise NotImplementedError('Not implemented for non scalar')
 
-
    promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type()
    if t1.split is not None:
        if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0:
@@ -118,7 +118,7 @@ def __binary_op(operation, t1, t2):
         else:
             result = operation(t1._DNDarray__array.type(promoted_type), t2._DNDarray__array.type(promoted_type))
     elif t2.split is not None:
-
+
         if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0:
             result = t2._DNDarray__array.type(promoted_type)
         else:
diff --git a/heat/core/statistics.py b/heat/core/statistics.py
index 7f25232f3d..8d2db876c1 100644
--- a/heat/core/statistics.py
+++ b/heat/core/statistics.py
@@ -8,6 +8,7 @@
 from . import operations
 from . import dndarray
 from . import types
+from . import stride_tricks
 
 
 __all__ = [
@@ -16,6 +17,7 @@ __all__ = [
     'max',
     'mean',
     'min',
+    'minimum',
     'std',
     'var'
 ]
@@ -524,6 +526,7 @@ def min(x, axis=None, out=None, keepdim=None):
            [ 7.],
            [10.]])
     """
+
     def local_min(*args, **kwargs):
         result = torch.min(*args, **kwargs)
         if isinstance(result, tuple):
@@ -533,6 +536,141 @@ def local_min(*args, **kwargs):
     return operations.__reduce_op(x, local_min, MPI.MIN, axis=axis, out=out, keepdim=keepdim)
 
 
+def minimum(x1, x2, out=None, **kwargs):
+    """
+    Compares two tensors and returns a new tensor containing the element-wise minima.
+    If one of the elements being compared is NaN, then that element is returned.
+    TODO: check this: if both elements are NaNs, then the first is returned.
+    The latter distinction is important for complex NaNs, which are defined as at least one of the real or
+    imaginary parts being NaN. The net effect is that NaNs are propagated.
+
+    Parameters:
+    -----------
+    x1, x2 : ht.DNDarray
+        The tensors containing the elements to be compared. They must have the same shape, or shapes
+        that can be broadcast to a single shape.
+        For broadcasting semantics, see: https://pytorch.org/docs/stable/notes/broadcasting.html
+
+    out : ht.DNDarray or None, optional
+        A location into which the result is stored. If provided, it must have a shape that the inputs
+        broadcast to. If not provided or None, a freshly-allocated tensor is returned.
+
+    Returns:
+    --------
+    minimum : ht.DNDarray
+        Element-wise minimum of the two input tensors.
+
+    Examples:
+    ---------
+    >>> import heat as ht
+    >>> import torch
+    >>> torch.manual_seed(1)
+
+    >>> a = ht.random.randn(3, 4)
+    >>> a
+    tensor([[-0.1955, -0.9656,  0.4224,  0.2673],
+            [-0.4212, -0.5107, -1.5727, -0.1232],
+            [ 3.5870, -1.8313,  1.5987, -1.2770]])
+
+    >>> b = ht.random.randn(3, 4)
+    >>> b
+    tensor([[ 0.8310, -0.2477, -0.8029,  0.2366],
+            [ 0.2857,  0.6898, -0.6331,  0.8795],
+            [-0.6842,  0.4533,  0.2912, -0.8317]])
+
+    >>> ht.minimum(a, b)
+    tensor([[-0.1955, -0.9656, -0.8029,  0.2366],
+            [-0.4212, -0.5107, -1.5727, -0.1232],
+            [-0.6842, -1.8313,  0.2912, -1.2770]])
+
+    >>> c = ht.random.randn(1, 4)
+    >>> c
+    tensor([[-1.6428,  0.9803, -0.0421, -0.8206]])
+
+    >>> ht.minimum(a, c)
+    tensor([[-1.6428, -0.9656, -0.0421, -0.8206],
+            [-1.6428, -0.5107, -1.5727, -0.8206],
+            [-1.6428, -1.8313, -0.0421, -1.2770]])
+
+    >>> b[0, 1] = ht.nan
+    >>> b
+    tensor([[ 0.8310,     nan, -0.8029,  0.2366],
+            [ 0.2857,  0.6898, -0.6331,  0.8795],
+            [-0.6842,  0.4533,  0.2912, -0.8317]])
+    >>> ht.minimum(a, b)
+    tensor([[-0.1955,     nan, -0.8029,  0.2366],
+            [-0.4212, -0.5107, -1.5727, -0.1232],
+            [-0.6842, -1.8313,  0.2912, -1.2770]])
+
+    >>> d = ht.random.randn(3, 4, 5)
+    >>> ht.minimum(a, d)
+    ValueError: operands could not be broadcast, input shapes (3, 4) (3, 4, 5)
+    """
+    # perform sanitation
+    if not isinstance(x1, dndarray.DNDarray) or not isinstance(x2, dndarray.DNDarray):
+        raise TypeError('expected x1 and x2 to be ht.DNDarrays, but they were {}, {}'.format(type(x1), type(x2)))
+    if out is not None and not isinstance(out, dndarray.DNDarray):
+        raise TypeError('expected out to be None or an ht.DNDarray, but was {}'.format(type(out)))
+
+    # apply split semantics
+    if x1.split is not None or x2.split is not None:
+        if x1.split is None:
+            x1.resplit(x2.split)
+        if x2.split is None:
+            x2.resplit(x1.split)
+        if x1.split != x2.split:
+            if np.prod(x1.gshape) < np.prod(x2.gshape):
+                x1.resplit(x2.split)
+            elif np.prod(x2.gshape) < np.prod(x1.gshape):
+                x2.resplit(x1.split)
+            else:
+                if x1.split < x2.split:
+                    x2.resplit(x1.split)
+                else:
+                    x1.resplit(x2.split)
+        split = x1.split
+    else:
+        split = None
+
+    # locally: apply torch.min(x1, x2)
+    output_lshape = stride_tricks.broadcast_shape(x1.lshape, x2.lshape)
+    lresult = factories.empty(output_lshape)
+    lresult._DNDarray__array = torch.min(x1._DNDarray__array, x2._DNDarray__array)
+    lresult._DNDarray__dtype = types.promote_types(x1.dtype, x2.dtype)
+    lresult._DNDarray__split = split
+    if x1.split is not None or x2.split is not None:
+        if x1.comm.is_distributed():  # assuming x1.comm == x2.comm
+            output_gshape = stride_tricks.broadcast_shape(x1.gshape, x2.gshape)
+            result = factories.empty(output_gshape)
+            x1.comm.Allgather(lresult, result)
+            # TODO: adopt Allgatherv() as soon as it is fixed, Issue #233
+            result._DNDarray__dtype = lresult._DNDarray__dtype
+            result._DNDarray__split = split
+
+            if out is not None:
+                if out.shape != output_gshape:
+                    raise ValueError('Expecting output buffer of shape {}, got {}'.format(output_gshape, out.shape))
+                out._DNDarray__array = result._DNDarray__array
+                out._DNDarray__dtype = result._DNDarray__dtype
+                out._DNDarray__split = split
+                out._DNDarray__device = x1.device
+                out._DNDarray__comm = x1.comm
+
+                return out
+            return result
+
+    if out is not None:
+        if out.shape != output_lshape:
+            raise ValueError('Expecting output buffer of shape {}, got {}'.format(output_lshape, out.shape))
+        out._DNDarray__array = lresult._DNDarray__array
+        out._DNDarray__dtype = lresult._DNDarray__dtype
+        out._DNDarray__split = split
+        out._DNDarray__device = x1.device
+        out._DNDarray__comm = x1.comm
+
+    return lresult
+
+
 def mpi_argmax(a, b, _):
     lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
     rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))
@@ -554,7 +692,6 @@ def mpi_argmax(a, b, _):
 def mpi_argmin(a, b, _):
     lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
     rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))
-
     # extract the values and minimal indices from the buffers (first half are values, second are indices)
     values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0],), dim=1)
     indices = torch.stack((lhs.chunk(2)[1], rhs.chunk(2)[1],), dim=1)
diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py
index a4a7b431e2..6bb2ac5b00 100644
--- a/heat/core/tests/test_statistics.py
+++ b/heat/core/tests/test_statistics.py
@@ -412,6 +412,103 @@ def test_min(self):
         with self.assertRaises(ValueError):
             ht.min(ht_array, axis=-4)
 
+    def test_minimum(self):
+        data1 = [
+            [1, 2, 3],
+            [4, 5, 6],
+            [7, 8, 9],
+            [10, 11, 12]
+        ]
+        data2 = [
+            [0, 3, 2],
+            [5, 4, 7],
+            [6, 9, 8],
+            [9, 10, 11]
+        ]
+
+        ht_array1 = ht.array(data1)
+        ht_array2 = ht.array(data2)
+        comparison1 = torch.tensor(data1)
+        comparison2 = torch.tensor(data2)
+
+        # check minimum
+        minimum = ht.minimum(ht_array1, ht_array2)
+
+        self.assertIsInstance(minimum, ht.DNDarray)
+        self.assertEqual(minimum.shape, (4, 3))
+        self.assertEqual(minimum.lshape, (4, 3))
+        self.assertEqual(minimum.split, None)
+        self.assertEqual(minimum.dtype, ht.int64)
+        self.assertEqual(minimum._DNDarray__array.dtype, torch.int64)
+        self.assertTrue((minimum._DNDarray__array == torch.min(comparison1, comparison2)).all())
+
+        # check minimum over float elements of split 3d tensors
+        # TODO: add check for uneven distribution of dimensions (see Issue #273)
+        size = ht.MPI_WORLD.size
+        torch.manual_seed(1)
+        random_volume_1 = ht.array(ht.random.randn(12, 3, 3), is_split=0)
+        random_volume_2 = ht.array(ht.random.randn(12, 1, 3), is_split=0)
+        minimum_volume = ht.minimum(random_volume_1, random_volume_2)
+
+        self.assertIsInstance(minimum_volume, ht.DNDarray)
+        self.assertEqual(minimum_volume.shape, (size * 12, 3, 3))
+        self.assertEqual(minimum_volume.lshape, (size * 12, 3, 3))
+        self.assertEqual(minimum_volume.dtype, ht.float32)
+        self.assertEqual(minimum_volume._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(minimum_volume.split, random_volume_1.split)
+
+        # check minimum over float elements of split 3d tensors with different split axis
+        torch.manual_seed(1)
+        random_volume_1_splitdiff = ht.array(ht.random.randn(size * 3, size * 3, 4), split=0)
+        random_volume_2_splitdiff = ht.array(ht.random.randn(size * 3, size * 3, 4), split=1)
+        minimum_volume_splitdiff = ht.minimum(random_volume_1_splitdiff, random_volume_2_splitdiff)
+        self.assertIsInstance(minimum_volume_splitdiff, ht.DNDarray)
+        self.assertEqual(minimum_volume_splitdiff.shape, (size * 3, size * 3, 4))
+        self.assertEqual(minimum_volume_splitdiff.lshape, (size * 3, size * 3, 4))
+        self.assertEqual(minimum_volume_splitdiff.dtype, ht.float32)
+        self.assertEqual(minimum_volume_splitdiff._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(minimum_volume_splitdiff.split, 0)
+
+        random_volume_1_splitdiff = ht.array(ht.random.randn(size * 3, size * 3, 4), split=1)
+        random_volume_2_splitdiff = ht.array(ht.random.randn(size * 3, size * 3, 4), split=0)
+        minimum_volume_splitdiff = ht.minimum(random_volume_1_splitdiff, random_volume_2_splitdiff)
+        self.assertEqual(minimum_volume_splitdiff.split, 0)
+
+        random_volume_1_splitNone = ht.array(ht.random.randn(size * 3, size * 3, 4), split=None)
+        random_volume_2_splitdiff = ht.array(ht.random.randn(size * 3, size * 3, 4), split=1)
+        minimum_volume_splitdiff = ht.minimum(random_volume_1_splitNone, random_volume_2_splitdiff)
+        self.assertEqual(minimum_volume_splitdiff.split, 1)
+
+        random_volume_1_split0 = ht.array(ht.random.randn(size * 3, size * 3, 4), split=0)
+        random_volume_2_splitNone = ht.array(ht.random.randn(size * 3, size * 3, 4), split=None)
+        minimum_volume_splitdiff = ht.minimum(random_volume_1_split0, random_volume_2_splitNone)
+        self.assertEqual(minimum_volume_splitdiff.split, 0)
+
+        # check output buffer
+        out_shape = ht.stride_tricks.broadcast_shape(random_volume_1.gshape, random_volume_2.gshape)
+        output = ht.empty(out_shape)
+        ht.minimum(random_volume_1, random_volume_2, out=output)
+        self.assertIsInstance(output, ht.DNDarray)
+        self.assertEqual(output.shape, (size * 12, 3, 3))
+        self.assertEqual(output.lshape, (size * 12, 3, 3))
+        self.assertEqual(output.dtype, ht.float32)
+        self.assertEqual(output._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(output.split, random_volume_1.split)
+
+        # check exceptions
+        random_volume_3 = ht.array(ht.random.randn(4, 2, 3), split=0)
+        with self.assertRaises(ValueError):
+            ht.minimum(random_volume_1, random_volume_3)
+        random_volume_3 = torch.ones(12, 3, 3)
+        with self.assertRaises(TypeError):
+            ht.minimum(random_volume_1, random_volume_3)
+        output = torch.ones(12, 3, 3)
+        with self.assertRaises(TypeError):
+            ht.minimum(random_volume_1, random_volume_2, out=output)
+        output = ht.ones((12, 4, 3))
+        with self.assertRaises(ValueError):
+            ht.minimum(random_volume_1, random_volume_2, out=output)
+
     def test_std(self):
         # test raises
         x = ht.zeros((2, 3, 4))
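Note on the split semantics implemented in ht.minimum above: an operand with a None split adopts the other operand's split; when both operands are split along different axes, the one with the smaller global volume is resplit to the other's axis, and equal volumes fall back to the smaller split axis. The standalone sketch below is an illustration only, not part of heat; resolve_split is a hypothetical helper that merely mirrors the precedence of the resplit calls in the patch:

from math import prod

def resolve_split(gshape1, split1, gshape2, split2):
    # mirrors the split resolution in ht.minimum; also covers the both-None case
    if split1 == split2:
        return split1
    if split1 is None:
        return split2  # x1 would be resplit to x2.split
    if split2 is None:
        return split1  # x2 would be resplit to x1.split
    if prod(gshape1) < prod(gshape2):
        return split2  # the smaller operand follows the larger one
    if prod(gshape2) < prod(gshape1):
        return split1
    return min(split1, split2)  # equal volumes: keep the smaller split axis

# equal-volume operands split along axes 0 and 1 resolve to axis 0, and a None
# split adopts the other operand's axis -- the cases exercised in test_minimum
assert resolve_split((6, 6, 4), 0, (6, 6, 4), 1) == 0
assert resolve_split((6, 6, 4), None, (6, 6, 4), 1) == 1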