Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features rot90 #583

Merged
merged 21 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- [#573](https://github.com/helmholtz-analytics/heat/pull/573) Bugfix: matmul fixes: early out for 2 vectors, remainders not added if inner block is 1 for split 10 case
- [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
- [#583](https://github.com/helmholtz-analytics/heat/pull/583) New feature: rot90()


# v0.4.0
Expand Down
82 changes: 82 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import stride_tricks
from . import tiling
from . import types
from . import linalg
mtar marked this conversation as resolved.
Show resolved Hide resolved


__all__ = [
Expand All @@ -23,6 +24,7 @@
"hstack",
"reshape",
"resplit",
"rot90",
"sort",
"squeeze",
"unique",
Expand Down Expand Up @@ -936,6 +938,86 @@ def reshape_argsort_counts_displs(
return factories.array(data, dtype=a.dtype, is_split=axis, device=a.device, comm=a.comm)


def rot90(m, k=1, axes=(0, 1)):
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved
"""
Rotate an array by 90 degrees in the plane specified by axes.
Rotation direction is from the first towards the second axis.
Parameters
mtar marked this conversation as resolved.
Show resolved Hide resolved
----------
m : DNDarray
Array of two or more dimensions.
k : integer
Number of times the array is rotated by 90 degrees.
axes: (2,) int list or tuple
The array is rotated in the plane defined by the axes.
Axes must be different.

Returns
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved
-------
y : DNDarray
A rotated view of `m`.

Notes
-----
rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1))
rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1))

May change the split axis on distributed tensors

Examples
--------
>>> m = ht.array([[1,2],[3,4]], dtype=ht.int)
>>> m
tensor([[1, 2],
[3, 4]], dtype=torch.int32)
>>> ht.rot90(m)
tensor([[2, 4],
[1, 3]], dtype=torch.int32)
>>> ht.rot90(m, 2)
tensor([[4, 3],
[2, 1]], dtype=torch.int32)
>>> m = ht.arange(8).reshape((2,2,2))
>>> ht.rot90(m, 1, (1,2))
tensor([[[1, 3],
[0, 2]],
[[5, 7],
[4, 6]]], dtype=torch.int32)
"""
axes = tuple(axes)
if len(axes) != 2:
raise ValueError("len(axes) must be 2.")

if not isinstance(m, dndarray.DNDarray):
raise TypeError("expected m to be a ht.DNDarray, but was {}".format(type(m)))

if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == m.ndim:
raise ValueError("Axes must be different.")

if axes[0] >= m.ndim or axes[0] < -m.ndim or axes[1] >= m.ndim or axes[1] < -m.ndim:
raise ValueError("Axes={} out of range for array of ndim={}.".format(axes, m.ndim))

if m.split is None:
return factories.array(
torch.rot90(m._DNDarray__array, k, axes), dtype=m.dtype, device=m.device, comm=m.comm
)

k %= 4
mtar marked this conversation as resolved.
Show resolved Hide resolved

if k == 0:
return m[:]
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved
if k == 2:
return flip(flip(m, axes[0]), axes[1])
Markus-Goetz marked this conversation as resolved.
Show resolved Hide resolved

axes_list = np.arange(0, m.ndim).tolist()
(axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]], axes_list[axes[0]])

if k == 1:
return linalg.transpose(flip(m, axes[1]), axes_list)
else:
# k == 3
return flip(linalg.transpose(m, axes_list), axes[1])


def sort(a, axis=None, descending=False, out=None):
"""
Sorts the elements of the DNDarray a along the given dimension (by default in ascending order) by their value.
Expand Down
45 changes: 45 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,51 @@ def test_reshape(self):
with self.assertRaises(TypeError):
ht.reshape(ht.zeros((4, 3)), "(5, 7)")

def test_rot90(self):
size = ht.MPI_WORLD.size
m = ht.arange(size ** 3, dtype=ht.int).reshape((size, size, size))

self.assertTrue(ht.equal(ht.rot90(m, 0), m))
self.assertTrue(ht.equal(ht.rot90(m, 4), m))
self.assertTrue(ht.equal(ht.rot90(ht.rot90(m, 1), 1, (1, 0)), m))

a = ht.resplit(m, 0)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 1)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 0)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 0)))

m = ht.arange(size ** 3, dtype=ht.float).reshape((size, size, size))
a = ht.resplit(m, 1)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 0)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 1)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 2)))

a = ht.resplit(m, 2)

self.assertTrue(ht.equal(ht.rot90(a, 0), a))
self.assertTrue(ht.equal(ht.rot90(a), ht.resplit(ht.rot90(m), 2)))
self.assertTrue(ht.equal(ht.rot90(a, 2), ht.resplit(ht.rot90(m, 2), 2)))
self.assertTrue(ht.equal(ht.rot90(a, 3, (1, 2)), ht.resplit(ht.rot90(m, 3, (1, 2)), 1)))

with self.assertRaises(ValueError):
ht.rot90(ht.ones((2, 3)), 1, (0, 1, 2))
with self.assertRaises(TypeError):
ht.rot90(torch.tensor((2, 3)))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, 0))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (-3, 1))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (4, 1))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, -2))
with self.assertRaises(ValueError):
ht.rot90(ht.zeros((2, 2)), 1, (0, 3))

def test_sort(self):
size = ht.MPI_WORLD.size
rank = ht.MPI_WORLD.rank
Expand Down