diff --git a/CHANGELOG.md b/CHANGELOG.md index ef8d82c372..d4718a944f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,9 @@ - [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray - [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape - [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr() -- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New feature: DNDarray.tolist() -- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature: arctan2() +- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New Feature: DNDarray.tolist() +- [#583](https://github.com/helmholtz-analytics/heat/pull/583) New feature: rot90() +- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature arctan2() - [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing - [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced - [#596](https://github.com/helmholtz-analytics/heat/pull/596) New feature: outer() diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 90cfa73716..175c11afc3 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -7,6 +7,7 @@ from . import constants from . import dndarray from . import factories +from . import linalg from . import stride_tricks from . import tiling from . import types @@ -25,6 +26,7 @@ "hstack", "reshape", "resplit", + "rot90", "shape", "sort", "squeeze", @@ -940,6 +942,104 @@ def reshape_argsort_counts_displs( return factories.array(data, dtype=a.dtype, is_split=axis, device=a.device, comm=a.comm) +def rot90(m, k=1, axes=(0, 1)): + """ + Rotate an array by 90 degrees in the plane specified by axes. + Rotation direction is from the first towards the second axis. + + Parameters + ---------- + m : DNDarray + Array of two or more dimensions. + k : integer + Number of times the array is rotated by 90 degrees. + axes: (2,) int list or tuple + The array is rotated in the plane defined by the axes. + Axes must be different. + + Returns + ------- + DNDarray + + Notes + ----- + rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1)) + rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1)) + + May change the split axis on distributed tensors + + Raises + ------ + TypeError + If first parameter is not a :class:DNDarray. + TypeError + If parameter ``k`` is not castable to integer. + ValueError + If ``len(axis)!=2``. + ValueError + If the axes are the same. + ValueError + If axes are out of range. + + Examples + -------- + >>> m = ht.array([[1,2],[3,4]], dtype=ht.int) + >>> m + tensor([[1, 2], + [3, 4]], dtype=torch.int32) + >>> ht.rot90(m) + tensor([[2, 4], + [1, 3]], dtype=torch.int32) + >>> ht.rot90(m, 2) + tensor([[4, 3], + [2, 1]], dtype=torch.int32) + >>> m = ht.arange(8).reshape((2,2,2)) + >>> ht.rot90(m, 1, (1,2)) + tensor([[[1, 3], + [0, 2]], + [[5, 7], + [4, 6]]], dtype=torch.int32) + """ + axes = tuple(axes) + if len(axes) != 2: + raise ValueError("len(axes) must be 2.") + + if not isinstance(m, dndarray.DNDarray): + raise TypeError("expected m to be a ht.DNDarray, but was {}".format(type(m))) + + if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == m.ndim: + raise ValueError("Axes must be different.") + + if axes[0] >= m.ndim or axes[0] < -m.ndim or axes[1] >= m.ndim or axes[1] < -m.ndim: + raise ValueError("Axes={} out of range for array of ndim={}.".format(axes, m.ndim)) + + if m.split is None: + return factories.array( + torch.rot90(m._DNDarray__array, k, axes), dtype=m.dtype, device=m.device, comm=m.comm + ) + + try: + k = int(k) + except (TypeError, ValueError): + raise TypeError("Unknown type, must be castable to integer") + + k %= 4 + + if k == 0: + return m.copy() + if k == 2: + return flip(flip(m, axes[0]), axes[1]) + + axes_list = np.arange(0, m.ndim).tolist() + (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]], axes_list[axes[0]]) + + if k == 1: + return linalg.transpose(flip(m, axes[1]), axes_list) + else: + # k == 3 + return flip(linalg.transpose(m, axes_list), axes[1]) + + def shape(a): """ Returns the shape of a DNDarray `a`. diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index ad736e4380..500310362d 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1069,6 +1069,53 @@ def test_reshape(self): with self.assertRaises(TypeError): ht.reshape(ht.zeros((4, 3)), "(5, 7)") + def test_rot90(self): + size = ht.MPI_WORLD.size + m = ht.arange(size ** 3, dtype=ht.int).reshape((size, size, size)) + + self.assertTrue(ht.equal(ht.rot90(m, 0), m)) + self.assertTrue(ht.equal(ht.rot90(m, 4), m)) + self.assertTrue(ht.equal(ht.rot90(ht.rot90(m, 1), 1, (1, 0)), m)) + + a = ht.resplit(m, 0) + + self.assertTrue(ht.equal(ht.rot90(a, 0), a)) + self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 1))) + self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 0))) + self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 0))) + + m = ht.arange(size ** 3, dtype=ht.float).reshape((size, size, size)) + a = ht.resplit(m, 1) + + self.assertTrue(ht.equal(ht.rot90(a, 0), a)) + self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 0))) + self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 1))) + self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 2))) + + a = ht.resplit(m, 2) + + self.assertTrue(ht.equal(ht.rot90(a, 0), a)) + self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 2))) + self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 2))) + self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 1))) + + with self.assertRaises(ValueError): + ht.rot90(ht.ones((2, 3)), 1, (0, 1, 2)) + with self.assertRaises(TypeError): + ht.rot90(torch.tensor((2, 3))) + with self.assertRaises(ValueError): + ht.rot90(ht.zeros((2, 2)), 1, (0, 0)) + with self.assertRaises(ValueError): + ht.rot90(ht.zeros((2, 2)), 1, (-3, 1)) + with self.assertRaises(ValueError): + ht.rot90(ht.zeros((2, 2)), 1, (4, 1)) + with self.assertRaises(ValueError): + ht.rot90(ht.zeros((2, 2)), 1, (0, -2)) + with self.assertRaises(ValueError): + ht.rot90(ht.zeros((2, 2)), 1, (0, 3)) + with self.assertRaises(TypeError): + ht.rot90(ht.zeros((2, 3)), "k", (0, 1)) + def test_shape(self): x = ht.random.randn(3, 4, 5, split=2) self.assertEqual(ht.shape(x), (3, 4, 5))