diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1948b5d7d1..a5d698b5bb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,10 +7,11 @@
 - [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
 - [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape
 - [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
-- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New Feature: DNDarray.tolist()
-- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature arctan2()
+- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New feature: DNDarray.tolist()
+- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature: arctan2()
 - [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing
 - [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced
+- [#596](https://github.com/helmholtz-analytics/heat/pull/596) New feature: outer()
 - [#600](https://github.com/helmholtz-analytics/heat/pull/600) New feature: shape()

 # v0.4.0
diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py
index f8538ee47b..1f2337220b 100644
--- a/heat/core/linalg/basics.py
+++ b/heat/core/linalg/basics.py
@@ -9,7 +9,7 @@
 from .. import manipulations
 from .. import types

-__all__ = ["dot", "matmul", "norm", "projection", "transpose", "tril", "triu"]
+__all__ = ["dot", "matmul", "norm", "outer", "projection", "transpose", "tril", "triu"]


 def dot(a, b, out=None):
@@ -827,6 +827,241 @@ def norm(a):
     return exponential.sqrt(d).item()


+def outer(a, b, out=None, split=None):
+    """
+    Compute the outer product of two 1-D DNDarrays.
+
+    Given two vectors, :math:`a = (a_0, a_1, ..., a_N)` and :math:`b = (b_0, b_1, ..., b_M)`, the outer product is:
+
+    .. math::
+
+        \\begin{pmatrix}
+            a_0 \\cdot b_0 & a_0 \\cdot b_1 & \\cdots & a_0 \\cdot b_M \\\\
+            a_1 \\cdot b_0 & a_1 \\cdot b_1 & \\cdots & a_1 \\cdot b_M \\\\
+            \\vdots & \\vdots & \\ddots & \\vdots \\\\
+            a_N \\cdot b_0 & a_N \\cdot b_1 & \\cdots & a_N \\cdot b_M
+        \\end{pmatrix}
+
+    Parameters
+    ----------
+    a : DNDarray
+        1-dimensional array of size :math:`N`.
+        Will be flattened by default if more than 1-D.
+    b : DNDarray
+        1-dimensional array of size :math:`M`.
+        Will be flattened by default if more than 1-D.
+    out : DNDarray, optional
+        2-dimensional array of shape :math:`N \\times M` in which the result is stored.
+    split : int, optional
+        Split dimension of the resulting DNDarray. Can be 0, 1, or None.
+        Only relevant if the calculation is memory-distributed, in which case
+        the default is ``split=0`` (see Notes).
+
+    Notes
+    -----
+    Parallel implementation of the outer product for dense arrays: one DNDarray stays put while
+    the other is passed around the ranks in ring communication, and the slice-by-slice outer
+    product is calculated locally (here via ``torch.einsum()``).
+    If ``b`` is sent around, the resulting outer product is split along the rows (``split=0``);
+    if ``a`` is sent around, it is split along the columns (``split=1``). So if ``split`` is
+    not None, it defines which DNDarray stays put and which one is passed around; no
+    communication is needed beyond the ring communication of one of the DNDarrays.
+    If ``split`` is None or unspecified, the result is distributed along axis 0, i.e. by
+    default ``b`` is passed around and ``a`` stays put.
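+
+    Schematically, for 3 processes and the default ``split=0`` (``b`` travels the ring while
+    ``a`` stays put), each rank first computes the block of columns matching its local chunk
+    of ``b``, then fills in one remote block per ring iteration:
+
+        rank 0 computes: b-chunk 0 | then receives: b-chunk 2, b-chunk 1
+        rank 1 computes: b-chunk 1 | then receives: b-chunk 0, b-chunk 2
+        rank 2 computes: b-chunk 2 | then receives: b-chunk 1, b-chunk 0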
+
+    Returns
+    -------
+    out : DNDarray
+        2-dimensional array of shape :math:`N \\times M`, with ``out[i, j] = a[i] * b[j]``.
+
+    Examples
+    --------
+    >>> a = ht.arange(4)
+    >>> b = ht.arange(3)
+    >>> ht.outer(a, b)
+    (3 processes)
+    (0/3) tensor([[0, 0, 0],
+                  [0, 1, 2],
+                  [0, 2, 4],
+                  [0, 3, 6]], dtype=torch.int32)
+    (1/3) tensor([[0, 0, 0],
+                  [0, 1, 2],
+                  [0, 2, 4],
+                  [0, 3, 6]], dtype=torch.int32)
+    (2/3) tensor([[0, 0, 0],
+                  [0, 1, 2],
+                  [0, 2, 4],
+                  [0, 3, 6]], dtype=torch.int32)
+    >>> a = ht.arange(4, split=0)
+    >>> b = ht.arange(3, split=0)
+    >>> ht.outer(a, b)
+    (0/3) tensor([[0, 0, 0],
+                  [0, 1, 2]], dtype=torch.int32)
+    (1/3) tensor([[0, 2, 4]], dtype=torch.int32)
+    (2/3) tensor([[0, 3, 6]], dtype=torch.int32)
+    >>> ht.outer(a, b, split=1)
+    (0/3) tensor([[0],
+                  [0],
+                  [0],
+                  [0]], dtype=torch.int32)
+    (1/3) tensor([[0],
+                  [1],
+                  [2],
+                  [3]], dtype=torch.int32)
+    (2/3) tensor([[0],
+                  [2],
+                  [4],
+                  [6]], dtype=torch.int32)
+    >>> a = ht.arange(5, dtype=ht.float32, split=0)
+    >>> b = ht.arange(4, dtype=ht.float64, split=0)
+    >>> out = ht.empty((5, 4), dtype=ht.float64, split=1)
+    >>> ht.outer(a, b, split=1, out=out)
+    >>> out
+    (0/3) tensor([[0., 0.],
+                  [0., 1.],
+                  [0., 2.],
+                  [0., 3.],
+                  [0., 4.]], dtype=torch.float64)
+    (1/3) tensor([[0.],
+                  [2.],
+                  [4.],
+                  [6.],
+                  [8.]], dtype=torch.float64)
+    (2/3) tensor([[ 0.],
+                  [ 3.],
+                  [ 6.],
+                  [ 9.],
+                  [12.]], dtype=torch.float64)
+    """
+    # sanitize input
+    if not isinstance(a, dndarray.DNDarray) or not isinstance(b, dndarray.DNDarray):
+        raise TypeError(
+            "a, b must be of type ht.DNDarray, but were {}, {}".format(type(a), type(b))
+        )
+
+    # sanitize dimensions
+    # TODO move to sanitation module #468
+    if a.ndim > 1:
+        a = manipulations.flatten(a)
+    if b.ndim > 1:
+        b = manipulations.flatten(b)
+    if a.ndim == 0 or b.ndim == 0:
+        raise RuntimeError(
+            "a, b must be 1-D DNDarrays, but were {}-D and {}-D".format(a.ndim, b.ndim)
+        )
+
+    outer_gshape = (a.gshape[0], b.gshape[0])
+    t_a = a._DNDarray__array
+    t_b = b._DNDarray__array
+    t_outer_dtype = torch.promote_types(t_a.dtype, t_b.dtype)
+    t_a, t_b = t_a.type(t_outer_dtype), t_b.type(t_outer_dtype)
+    outer_dtype = types.canonical_heat_type(t_outer_dtype)
+
+    if out is not None:
+        if not isinstance(out, dndarray.DNDarray):
+            raise TypeError("out must be of type ht.DNDarray, was {}".format(type(out)))
+        if out.dtype is not outer_dtype:
+            raise TypeError(
+                "Wrong datatype for out: expected {}, got {}".format(outer_dtype, out.dtype)
+            )
+        if out.gshape != outer_gshape:
+            raise ValueError("out must have shape {}, got {}".format(outer_gshape, out.gshape))
+        if out.split != split:
+            raise ValueError(
+                "Split dimension mismatch for out: expected {}, got {}".format(split, out.split)
+            )
+
+    # distributed outer product, dense arrays (TODO: sparse, #384)
+    if (
+        a.comm.is_distributed() and split is not None
+    ) or a.split is not None or b.split is not None:
+        # MPI coordinates
+        rank = a.comm.rank
+        size = a.comm.size
+        t_outer_slice = 2 * [slice(None, None, None)]
+
+        if a.split is None:
+            a.resplit_(axis=0)
+            t_a = a._DNDarray__array.type(t_outer_dtype)
+        if b.split is None:
+            b.resplit_(axis=0)
+            t_b = b._DNDarray__array.type(t_outer_dtype)
+        if split is None:
+            # split semantics: by default the result follows a.split, i.e. split = 0
+            split = a.split
+        if out is not None and out.split is None:
+            out.resplit_(axis=split)
+
+        # calculate local slice of outer product
+        if split == 0:
+            # split = 0: the result is row-distributed; ``b`` will travel the ring
+            lshape_map = b.create_lshape_map()
+            t_outer_shape = (a.lshape[0], b.gshape[0])
+            _, _, local_slice = b.comm.chunk(b.gshape, b.split)
+            t_outer_slice[1] = local_slice[0]
+        elif split == 1:
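+            # split = 1: the result is column-distributed; ``a`` will travel the ring,
+            # each pass filling the rows that match the currently held chunk of ``a``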
+            lshape_map = a.create_lshape_map()
+            t_outer_shape = (a.gshape[0], b.lshape[0])
+            _, _, local_slice = a.comm.chunk(a.gshape, a.split)
+            t_outer_slice[0] = local_slice[0]
+        t_outer = torch.zeros(t_outer_shape, dtype=t_outer_dtype, device=t_a.device)
+        if lshape_map[rank] != 0:
+            t_outer[t_outer_slice] = torch.einsum("i,j->ij", t_a, t_b)
+
+        # Ring: fill in missing slices of outer product
+        # allocate memory for traveling data
+        if split == 0:
+            t_b_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_a.device)
+        elif split == 1:
+            t_a_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_b.device)
+
+        for p in range(size - 1):
+            # prepare for sending
+            dest_rank = rank + 1 if rank != size - 1 else 0
+            # prepare for receiving
+            origin_rank = rank - 1 if rank != 0 else size - 1
+            # rank on which the chunk received at iteration p originated
+            actual_origin = origin_rank - p
+            if origin_rank < p:
+                actual_origin += size
+            # blocking send and recv
+            if split == 0:
+                b.comm.Send(t_b, dest_rank)
+                b.comm.Recv(t_b_run, origin_rank)
+                # buffer from actual_origin could be smaller than the allocated buffer
+                t_b = t_b_run[: lshape_map[actual_origin]]
+                _, _, remote_slice = b.comm.chunk(
+                    b.gshape, b.split, rank=actual_origin, w_size=size
+                )
+                t_outer_slice[1] = remote_slice[0]
+            elif split == 1:
+                a.comm.Send(t_a, dest_rank)
+                a.comm.Recv(t_a_run, origin_rank)
+                # buffer from actual_origin could be smaller than the allocated buffer
+                t_a = t_a_run[: lshape_map[actual_origin]]
+                _, _, remote_slice = a.comm.chunk(
+                    a.gshape, a.split, rank=actual_origin, w_size=size
+                )
+                t_outer_slice[0] = remote_slice[0]
+            t_outer[t_outer_slice] = torch.einsum("i,j->ij", t_a, t_b)
+    else:
+        # outer product, all local
+        t_outer = torch.einsum("i,j->ij", t_a, t_b)
+        split = None
+
+    outer = dndarray.DNDarray(
+        t_outer, gshape=outer_gshape, dtype=outer_dtype, split=split, device=a.device, comm=a.comm
+    )
+
+    if out is not None:
+        out._DNDarray__array = outer._DNDarray__array
+        return out
+
+    return outer
+
+
 def projection(a, b):
     """
     Projection of vector a onto vector b
diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py
index 7d32ea2c10..7bb43db169 100644
--- a/heat/core/linalg/tests/test_basics.py
+++ b/heat/core/linalg/tests/test_basics.py
@@ -479,6 +479,81 @@ def test_norm(self):
         c = np.arange(9) - 4
         ht.linalg.norm(c)

+    def test_outer(self):
+        # test outer, a and b local, different dtypes
+        a = ht.arange(3, dtype=ht.int32)
+        b = ht.arange(8, dtype=ht.float32)
+        ht_outer = ht.outer(a, b, split=None)
+        np_outer = np.outer(a.numpy(), b.numpy())
+        t_outer = torch.einsum("i,j->ij", a._DNDarray__array, b._DNDarray__array)
+        self.assertTrue((ht_outer.numpy() == np_outer).all())
+        self.assertTrue(ht_outer._DNDarray__array.dtype is t_outer.dtype)
+
+        # test outer, a and b distributed, no data on some ranks
+        a_split = ht.arange(3, dtype=ht.float32, split=0)
+        b_split = ht.arange(8, dtype=ht.float32, split=0)
+        ht_outer_split = ht.outer(a_split, b_split, split=None)
+        self.assertTrue((ht_outer_split.numpy() == np_outer).all())
+
+        # a and b split 0, outer split 1
+        ht_outer_split = ht.outer(a_split, b_split, split=1)
+        self.assertTrue((ht_outer_split.numpy() == np_outer).all())
+        self.assertTrue(ht_outer_split.split == 1)
+
+        # a and b distributed, outer split unspecified
+        ht_outer_split = ht.outer(a_split, b_split, split=None)
+        self.assertTrue((ht_outer_split.numpy() == np_outer).all())
+        self.assertTrue(ht_outer_split.split == 0)
+
+        # a not distributed, outer.split = 1
+        ht_outer_split = ht.outer(a, b_split, split=1)
+        self.assertTrue((ht_outer_split.numpy() == np_outer).all())
+        self.assertTrue(ht_outer_split.split == 1)
+
+        # b not distributed, outer.split = 0
+        ht_outer_split = ht.outer(a_split, b, split=0)
+        self.assertTrue((ht_outer_split.numpy() == np_outer).all())
+        self.assertTrue(ht_outer_split.split == 0)
+
+        # a_split.ndim > 1 and a.split != 0
+        a_split_3d = ht.random.randn(3, 3, 3, dtype=ht.float64, split=2)
+        ht_outer_split = ht.outer(a_split_3d, b_split)
+        np_outer_3d = np.outer(a_split_3d.numpy(), b_split.numpy())
+        self.assertTrue((ht_outer_split.numpy() == np_outer_3d).all())
+        self.assertTrue(ht_outer_split.split == 0)
+
+        # write to out buffer
+        ht_out = ht.empty((a.gshape[0], b.gshape[0]), dtype=ht.float32)
+        ht.outer(a, b, out=ht_out)
+        self.assertTrue((ht_out.numpy() == np_outer).all())
+        ht_out_split = ht.empty((a_split.gshape[0], b_split.gshape[0]), dtype=ht.float32, split=1)
+        ht.outer(a_split, b_split, out=ht_out_split, split=1)
+        self.assertTrue((ht_out_split.numpy() == np_outer).all())
+
+        # test exceptions
+        t_a = torch.arange(3)
+        with self.assertRaises(TypeError):
+            ht.outer(t_a, b)
+        np_b = np.arange(8)
+        with self.assertRaises(TypeError):
+            ht.outer(a, np_b)
+        a_0d = ht.array(2.3)
+        with self.assertRaises(RuntimeError):
+            ht.outer(a_0d, b)
+        t_out = torch.empty((a.gshape[0], b.gshape[0]), dtype=torch.float32)
+        with self.assertRaises(TypeError):
+            ht.outer(a, b, out=t_out)
+        ht_out_wrong_dtype = ht.empty((a.gshape[0], b.gshape[0]), dtype=ht.float64)
+        with self.assertRaises(TypeError):
+            ht.outer(a, b, out=ht_out_wrong_dtype)
+        ht_out_wrong_shape = ht.empty((7, b.gshape[0]), dtype=ht.float32)
+        with self.assertRaises(ValueError):
+            ht.outer(a, b, out=ht_out_wrong_shape)
+        ht_out_wrong_split = ht.empty(
+            (a_split.gshape[0], b_split.gshape[0]), dtype=ht.float32, split=1
+        )
+        with self.assertRaises(ValueError):
+            ht.outer(a_split, b_split, out=ht_out_wrong_split, split=0)
+
     def test_projection(self):
         a = ht.arange(1, 4, dtype=ht.float32, split=None)
         e1 = ht.array([1, 0, 0], dtype=ht.float32, split=None)
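A minimal standalone sketch (pure Python, no MPI) of the `actual_origin` bookkeeping used in the
ring loop of `outer()` above: at iteration `p`, rank `r` handles the chunk that originated at rank
`(r - 1 - p) % size`, so after `size - 1` ring steps every rank has seen every origin exactly once.
The variable names mirror the loop in the patch; `size = 3` is an arbitrary choice.

    # simulate the ring bookkeeping from outer() without any communication
    size = 3
    for rank in range(size):
        seen = [rank]  # the local chunk is handled before the ring starts
        for p in range(size - 1):
            # same index arithmetic as in the patch
            origin_rank = rank - 1 if rank != 0 else size - 1
            actual_origin = origin_rank - p
            if origin_rank < p:
                actual_origin += size
            seen.append(actual_origin)
        # every origin appears exactly once per rank
        assert sorted(seen) == list(range(size)), seen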