Skip to content

Commit

Permalink
Add argsort of BatchedTensor(Seq) (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Aug 12, 2023
1 parent c6ae613 commit 6d75df7
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.8a55"
version = "0.0.8a56"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
80 changes: 80 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,84 @@ def sub_(
# Mathematical | advanced arithmetical operations #
###########################################################

def argsort(
    self,
    dim: int = -1,
    descending: bool = False,
    stable: bool = False,
) -> TBatchedTensor:
    r"""Returns the indices that sort the batch along a given
    dimension in monotonic order by value.

    Args:
    ----
        dim (int, optional): Specifies the dimension to sort along.
            Default: ``-1``
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensor``: The indices that sort the batch along
            the given dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.argsort(descending=True)
        tensor([[4, 3, 2, 1, 0],
                [4, 3, 2, 1, 0]], batch_dim=0)
    """
    # Compute the sorting indices on the wrapped tensor, then wrap them
    # in a new batch that keeps the same batch configuration.
    indices = self._data.argsort(dim=dim, descending=descending, stable=stable)
    return self._create_new_batch(indices)

def argsort_along_batch(
    self,
    descending: bool = False,
    stable: bool = False,
) -> TBatchedTensor:
    r"""Returns the indices that sort the elements of the batch along
    the batch dimension in monotonic order by value.

    Args:
    ----
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensor``: The indices that sort the batch along
            the batch dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(5, 2))
        >>> batch.argsort_along_batch(descending=True)
        tensor([[4, 4],
                [3, 3],
                [2, 2],
                [1, 1],
                [0, 0]], batch_dim=0)
    """
    # Delegate to ``argsort`` with the batch dimension of the current batch.
    return self.argsort(dim=self._batch_dim, descending=descending, stable=stable)

def cumsum(self, dim: int, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative sum of elements of the current batch
in a given dimension.
Expand Down Expand Up @@ -1685,6 +1763,8 @@ def sort(
Args:
----
dim (int, optional): Specifies the dimension to sort along.
Default: ``-1``
descending (bool, optional): Controls the sorting order.
If ``True``, the elements are sorted in descending
order by value. Default: ``False``
Expand Down
35 changes: 35 additions & 0 deletions src/redcat/tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ def sub_(
# Mathematical | advanced arithmetical operations #
###########################################################

def argsort_along_seq(
    self,
    descending: bool = False,
    stable: bool = False,
) -> BatchedTensorSeq:
    r"""Returns the indices that sort the elements of the batch along
    the sequence dimension in monotonic order by value.

    Args:
    ----
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensorSeq``: The indices that sort the batch along
            the sequence dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensorSeq
        >>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
        >>> batch.argsort_along_seq(descending=True)
        tensor([[4, 3, 2, 1, 0],
                [4, 3, 2, 1, 0]], batch_dim=0, seq_dim=1)
    """
    # Delegate to ``argsort`` with the sequence dimension of the current batch.
    return self.argsort(dim=self._seq_dim, descending=descending, stable=stable)

def cumsum_along_seq(self, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative sum of elements of the current batch
in the sequence dimension.
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,82 @@ def test_batched_tensor_sub__incorrect_batch_dim() -> None:
###########################################################


def test_batched_tensor_argsort_descending_false() -> None:
    # Ascending sort indices per row (default order).
    batch = BatchedTensor(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort(descending=False), expected)


def test_batched_tensor_argsort_descending_true() -> None:
    # Descending sort indices per row.
    batch = BatchedTensor(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensor(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort(descending=True), expected)


def test_batched_tensor_argsort_dim_0() -> None:
    # Sorting along dimension 0 sorts each column independently.
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_argsort_dim_1() -> None:
    # 3D input: sorting along dimension 1 sorts within each (dim0, dim2) slice.
    batch = BatchedTensor(
        torch.tensor(
            [
                [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
                [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
            ]
        )
    )
    expected = BatchedTensor(
        torch.tensor(
            [
                [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
                [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
            ]
        )
    )
    assert objects_are_equal(batch.argsort(dim=1), expected)


def test_batched_tensor_argsort_custom_dims() -> None:
    # The custom batch dimension must be preserved in the result.
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), batch_dim=1)
    expected = BatchedTensor(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), batch_dim=1)
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_argsort_along_batch_descending_false() -> None:
    # Default (ascending) sort indices along the batch dimension (dim 0).
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_argsort_along_batch_descending_true() -> None:
    # Descending sort indices along the batch dimension (dim 0).
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[3, 0], [0, 4], [4, 1], [2, 3], [1, 2]]))
    assert objects_are_equal(batch.argsort_along_batch(descending=True), expected)


def test_batched_tensor_argsort_along_batch_custom_dims() -> None:
    # With batch_dim=1, sorting happens along dim 1 and the config is kept.
    batch = BatchedTensor(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]), batch_dim=1)
    expected = BatchedTensor(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]), batch_dim=1)
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_cumsum() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,136 @@ def test_batched_tensor_seq_sub__incorrect_seq_dim() -> None:
###########################################################


def test_batched_tensor_seq_argsort_descending_false() -> None:
    # Ascending sort indices per row (default order).
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort(descending=False), expected)


def test_batched_tensor_seq_argsort_descending_true() -> None:
    # Descending sort indices per row.
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort(descending=True), expected)


def test_batched_tensor_seq_argsort_dim_0() -> None:
    # Sorting along dimension 0 sorts each column independently.
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_seq_argsort_dim_1() -> None:
    # 3D input: sorting along dimension 1 sorts within each (dim0, dim2) slice.
    batch = BatchedTensorSeq(
        torch.tensor(
            [
                [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
                [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
            ]
        )
    )
    expected = BatchedTensorSeq(
        torch.tensor(
            [
                [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
                [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
            ]
        )
    )
    assert objects_are_equal(batch.argsort(dim=1), expected)


def test_batched_tensor_seq_argsort_custom_dims() -> None:
    # Custom seq/batch dimensions must be preserved in the result.
    batch = BatchedTensorSeq(
        torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_seq_argsort_along_batch_descending_false() -> None:
    # Default (ascending) sort indices along the batch dimension (dim 0).
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_seq_argsort_along_batch_descending_true() -> None:
    # Descending sort indices along the batch dimension (dim 0).
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0], [0, 4], [4, 1], [2, 3], [1, 2]]))
    assert objects_are_equal(batch.argsort_along_batch(descending=True), expected)


def test_batched_tensor_seq_argsort_along_batch_custom_dims() -> None:
    # With batch_dim=1, the batch sort happens along dim 1; config is kept.
    batch = BatchedTensorSeq(
        torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_seq_argsort_along_seq_descending_false() -> None:
    # Default (ascending) sort indices along the sequence dimension (dim 1).
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort_along_seq(), expected)


def test_batched_tensor_seq_argsort_along_seq_descending_true() -> None:
    # Descending sort indices along the sequence dimension (dim 1).
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort_along_seq(descending=True), expected)


def test_batched_tensor_seq_argsort_along_seq_dim_3() -> None:
    # 3D input: the sequence sort happens along the default seq dimension (dim 1).
    batch = BatchedTensorSeq(
        torch.tensor(
            [
                [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
                [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
            ]
        )
    )
    expected = BatchedTensorSeq(
        torch.tensor(
            [
                [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
                [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
            ]
        )
    )
    assert objects_are_equal(batch.argsort_along_seq(), expected)


def test_batched_tensor_seq_argsort_along_seq_custom_dims() -> None:
    # With seq_dim=0, the sequence sort happens along dim 0; config is kept.
    batch = BatchedTensorSeq(
        torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort_along_seq(), expected)


def test_batched_tensor_seq_cumsum_dim_0() -> None:
assert (
BatchedTensorSeq(torch.arange(10).view(2, 5))
Expand Down

0 comments on commit 6d75df7

Please sign in to comment.