Skip to content

Commit

Permalink
Add permute_along_dim (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Apr 18, 2023
1 parent f9bbc33 commit cacd651
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/redcat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"get_available_devices",
"get_batch_dims",
"get_seq_dims",
"permute_along_dim",
"swap2",
]

Expand Down Expand Up @@ -246,6 +247,57 @@ def get_seq_dims(args: tuple[Any, ...], kwargs: dict[str, Any]) -> set[int]:
return dims


def permute_along_dim(tensor: Tensor, permutation: Tensor, dim: int = 0) -> Tensor:
r"""Permutes the values of a tensor along a given dimension.
Args:
----
tensor (``torch.Tensor``): Specifies the tensor to permute.
permutation (``torch.Tensor`` of type long and shape
``(dimension,)``): Specifies the permutation to use on the
tensor. The dimension of this tensor should be compatible
with the shape of the tensor to permute.
dim (int, optional): Specifies the dimension used to permute the
tensor. Default: ``0``
Returns:
-------
``torch.Tensor``: The permuted tensor.
Example usage:
.. code-block:: python
>>> from redcat.utils import permute_along_dim
>>> permute_along_dim(tensor=torch.arange(4), permutation=torch.tensor([0, 2, 1, 3]))
tensor([0, 2, 1, 3])
>>> permute_along_dim(
... tensor=torch.arange(20).view(4, 5),
... permutation=torch.tensor([0, 2, 1, 3]),
... )
tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14], [5, 6, 7, 8, 9], [15, 16, 17, 18, 19]])
>>> permute_along_dim(
... tensor=torch.arange(20).view(4, 5),
... permutation=torch.tensor([0, 4, 2, 1, 3]),
... dim=1,
... )
tensor([[ 0, 4, 2, 1, 3],
[ 5, 9, 7, 6, 8],
[10, 14, 12, 11, 13],
[15, 19, 17, 16, 18]])
>>> permute_along_dim(
... tensor=torch.arange(20).view(2, 2, 5),
... permutation=torch.tensor([0, 4, 2, 1, 3]),
... dim=2,
... )
tensor([[[ 0, 4, 2, 1, 3],
[ 5, 9, 7, 6, 8]],
[[10, 14, 12, 11, 13],
[15, 19, 17, 16, 18]]])
"""
return tensor.transpose(0, dim)[permutation].transpose(0, dim).contiguous()


@overload
def swap2(sequence: Tensor, index0: int, index1: int) -> Tensor:
r"""``swap2`` for a ``torch.Tensor``."""
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_available_devices,
get_batch_dims,
get_seq_dims,
permute_along_dim,
swap2,
)

Expand Down Expand Up @@ -241,6 +242,43 @@ def test_get_seq_dims_empty() -> None:
assert get_seq_dims(tuple(), dict()) == set()


#######################################
# Tests for permute_along_dim #
#######################################


def test_permute_along_dim_1d() -> None:
assert permute_along_dim(tensor=torch.arange(4), permutation=torch.tensor([0, 2, 1, 3])).equal(
torch.tensor([0, 2, 1, 3])
)


def test_permute_along_dim_2d_dim_0() -> None:
assert permute_along_dim(
tensor=torch.arange(20).view(4, 5), permutation=torch.tensor([0, 2, 1, 3])
).equal(
torch.tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14], [5, 6, 7, 8, 9], [15, 16, 17, 18, 19]])
)


def test_permute_along_dim_2d_dim_1() -> None:
assert permute_along_dim(
tensor=torch.arange(20).view(4, 5), permutation=torch.tensor([0, 4, 2, 1, 3]), dim=1
).equal(
torch.tensor([[0, 4, 2, 1, 3], [5, 9, 7, 6, 8], [10, 14, 12, 11, 13], [15, 19, 17, 16, 18]])
)


def test_permute_along_dim_3d_dim_2() -> None:
assert permute_along_dim(
tensor=torch.arange(20).view(2, 2, 5), permutation=torch.tensor([0, 4, 2, 1, 3]), dim=2
).equal(
torch.tensor(
[[[0, 4, 2, 1, 3], [5, 9, 7, 6, 8]], [[10, 14, 12, 11, 13], [15, 19, 17, 16, 18]]]
)
)


###########################
# Tests for swap2 #
###########################
Expand Down

0 comments on commit cacd651

Please sign in to comment.