Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features rot90 #583

Merged
merged 21 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
- [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
- [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape
- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New feature: DNDarray.tolist()
- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature: arctan2()
- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New Feature: DNDarray.tolist()
- [#583](https://github.com/helmholtz-analytics/heat/pull/583) New feature: rot90()
- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature arctan2()
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing
- [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced
- [#596](https://github.com/helmholtz-analytics/heat/pull/596) New feature: outer()
Expand Down
100 changes: 100 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import constants
from . import dndarray
from . import factories
from . import linalg
from . import stride_tricks
from . import tiling
from . import types
Expand All @@ -25,6 +26,7 @@
"hstack",
"reshape",
"resplit",
"rot90",
"shape",
"sort",
"squeeze",
Expand Down Expand Up @@ -940,6 +942,104 @@ def reshape_argsort_counts_displs(
return factories.array(data, dtype=a.dtype, is_split=axis, device=a.device, comm=a.comm)


def rot90(m, k=1, axes=(0, 1)):
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved
"""
Rotate an array by 90 degrees in the plane specified by axes.
Rotation direction is from the first towards the second axis.

Parameters
mtar marked this conversation as resolved.
Show resolved Hide resolved
----------
m : DNDarray
Array of two or more dimensions.
k : integer
Number of times the array is rotated by 90 degrees.
axes: (2,) int list or tuple
The array is rotated in the plane defined by the axes.
Axes must be different.

Returns
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved
-------
DNDarray

Notes
-----
rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1))
rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1))

May change the split axis on distributed tensors

Raises
------
TypeError
If first parameter is not a :class:DNDarray.
TypeError
If parameter ``k`` is not castable to integer.
ValueError
If ``len(axis)!=2``.
ValueError
If the axes are the same.
ValueError
If axes are out of range.

Examples
--------
>>> m = ht.array([[1,2],[3,4]], dtype=ht.int)
>>> m
tensor([[1, 2],
[3, 4]], dtype=torch.int32)
>>> ht.rot90(m)
tensor([[2, 4],
[1, 3]], dtype=torch.int32)
>>> ht.rot90(m, 2)
tensor([[4, 3],
[2, 1]], dtype=torch.int32)
>>> m = ht.arange(8).reshape((2,2,2))
>>> ht.rot90(m, 1, (1,2))
tensor([[[1, 3],
[0, 2]],
[[5, 7],
[4, 6]]], dtype=torch.int32)
"""
axes = tuple(axes)
if len(axes) != 2:
raise ValueError("len(axes) must be 2.")

if not isinstance(m, dndarray.DNDarray):
raise TypeError("expected m to be a ht.DNDarray, but was {}".format(type(m)))

if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == m.ndim:
raise ValueError("Axes must be different.")

if axes[0] >= m.ndim or axes[0] < -m.ndim or axes[1] >= m.ndim or axes[1] < -m.ndim:
raise ValueError("Axes={} out of range for array of ndim={}.".format(axes, m.ndim))

if m.split is None:
return factories.array(
torch.rot90(m._DNDarray__array, k, axes), dtype=m.dtype, device=m.device, comm=m.comm
)

try:
k = int(k)
except (TypeError, ValueError):
raise TypeError("Unknown type, must be castable to integer")

k %= 4
mtar marked this conversation as resolved.
Show resolved Hide resolved

if k == 0:
return m.copy()
if k == 2:
return flip(flip(m, axes[0]), axes[1])
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved

axes_list = np.arange(0, m.ndim).tolist()
(axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]], axes_list[axes[0]])

if k == 1:
return linalg.transpose(flip(m, axes[1]), axes_list)
else:
# k == 3
return flip(linalg.transpose(m, axes_list), axes[1])


def shape(a):
"""
Returns the shape of a DNDarray `a`.
Expand Down
47 changes: 47 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,53 @@ def test_reshape(self):
with self.assertRaises(TypeError):
ht.reshape(ht.zeros((4, 3)), "(5, 7)")

def test_rot90(self):
size = ht.MPI_WORLD.size
m = ht.arange(size ** 3, dtype=ht.int).reshape((size, size, size))

self.assertTrue(ht.equal(ht.rot90(m, 0), m))
self.assertTrue(ht.equal(ht.rot90(m, 4), m))
self.assertTrue(ht.equal(ht.rot90(ht.rot90(m, 1), 1, (1, 0)), m))

a = ht.resplit(m, 0)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 1)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 0)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 0)))

m = ht.arange(size ** 3, dtype=ht.float).reshape((size, size, size))
a = ht.resplit(m, 1)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 0)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 1)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 2)))

a = ht.resplit(m, 2)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 2)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 2)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 1)))

with self.assertRaises(ValueError):
ht.rot90(ht.ones((2, 3)), 1, (0, 1, 2))
with self.assertRaises(TypeError):
ht.rot90(torch.tensor((2, 3)))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, 0))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (-3, 1))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (4, 1))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, -2))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, 3))
with self.assertRaises(TypeError):
ht.rot90(ht.zeros((2, 3)), "k", (0, 1))

def test_shape(self):
x = ht.random.randn(3, 4, 5, split=2)
self.assertEqual(ht.shape(x), (3, 4, 5))
Expand Down