Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argsort of BatchedTensor(Seq) #385

Merged
merged 1 commit into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.8a55"
version = "0.0.8a56"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
80 changes: 80 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,84 @@ def sub_(
# Mathematical | advanced arithmetical operations #
###########################################################

def argsort(
    self,
    dim: int = -1,
    descending: bool = False,
    stable: bool = False,
) -> TBatchedTensor:
    r"""Returns the indices that sort the batch along a given
    dimension in monotonic order by value.

    Args:
    ----
        dim (int, optional): Specifies the dimension to sort along.
            Default: ``-1``
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensor``: The indices that sort the batch along
            the given dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.argsort(descending=True)
        tensor([[4, 3, 2, 1, 0],
                [4, 3, 2, 1, 0]], batch_dim=0)
    """
    # Delegate the sort to torch, then re-wrap the index tensor so the
    # batch metadata (e.g. batch_dim) is carried over to the result.
    indices = self._data.argsort(dim=dim, descending=descending, stable=stable)
    return self._create_new_batch(indices)

def argsort_along_batch(
    self,
    descending: bool = False,
    stable: bool = False,
) -> TBatchedTensor:
    r"""Returns the indices that sort the batch along the batch
    dimension in monotonic order by value.

    Args:
    ----
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensor``: The indices that sort the batch along
            the batch dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(5, 2))
        >>> batch.argsort_along_batch(descending=True)
        tensor([[4, 4],
                [3, 3],
                [2, 2],
                [1, 1],
                [0, 0]], batch_dim=0)
    """
    # Thin convenience wrapper: reuse ``argsort`` with the stored
    # batch dimension so behavior stays consistent with it.
    return self.argsort(
        dim=self._batch_dim,
        descending=descending,
        stable=stable,
    )

def cumsum(self, dim: int, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative sum of elements of the current batch
in a given dimension.
Expand Down Expand Up @@ -1685,6 +1763,8 @@ def sort(

Args:
----
dim (int, optional): Specifies the dimension to sort along.
Default: ``-1``
descending (bool, optional): Controls the sorting order.
If ``True``, the elements are sorted in descending
order by value. Default: ``False``
Expand Down
35 changes: 35 additions & 0 deletions src/redcat/tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ def sub_(
# Mathematical | advanced arithmetical operations #
###########################################################

def argsort_along_seq(
    self,
    descending: bool = False,
    stable: bool = False,
) -> BatchedTensorSeq:
    r"""Returns the indices that sort the batch along the sequence
    dimension in monotonic order by value.

    Args:
    ----
        descending (bool, optional): Controls the sorting order.
            If ``True``, the elements are sorted in descending
            order by value. Default: ``False``
        stable (bool, optional): Makes the sorting routine stable,
            which guarantees that the order of equivalent elements
            is preserved. Default: ``False``

    Returns:
    -------
        ``BatchedTensorSeq``: The indices that sort the batch along
            the sequence dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensorSeq
        >>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
        >>> batch.argsort_along_seq(descending=True)
        tensor([[4, 3, 2, 1, 0],
                [4, 3, 2, 1, 0]], batch_dim=0, seq_dim=1)
    """
    # Thin convenience wrapper: reuse ``argsort`` with the stored
    # sequence dimension so behavior stays consistent with it.
    return self.argsort(
        dim=self._seq_dim,
        descending=descending,
        stable=stable,
    )

def cumsum_along_seq(self, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative sum of elements of the current batch
in the sequence dimension.
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,82 @@ def test_batched_tensor_sub__incorrect_batch_dim() -> None:
###########################################################


def test_batched_tensor_argsort_descending_false() -> None:
    """Default (ascending) argsort orders each row's indices by value."""
    batch = BatchedTensor(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort(descending=False), expected)


def test_batched_tensor_argsort_descending_true() -> None:
    """``descending=True`` reverses the index order within each row."""
    batch = BatchedTensor(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensor(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort(descending=True), expected)


def test_batched_tensor_argsort_dim_0() -> None:
    """argsort accepts an explicit ``dim`` (here dimension 0)."""
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_argsort_dim_1() -> None:
    """argsort works along a middle dimension of a 3-d tensor."""
    data = torch.tensor(
        [
            [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
            [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
        ]
    )
    expected = torch.tensor(
        [
            [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
            [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
        ]
    )
    assert objects_are_equal(BatchedTensor(data).argsort(dim=1), BatchedTensor(expected))


def test_batched_tensor_argsort_custom_dims() -> None:
    """A non-default ``batch_dim`` is propagated to the returned indices."""
    batch = BatchedTensor(
        torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), batch_dim=1
    )
    expected = BatchedTensor(
        torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), batch_dim=1
    )
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_argsort_along_batch_descending_false() -> None:
    """argsort_along_batch defaults to ascending order along the batch dim."""
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_argsort_along_batch_descending_true() -> None:
    """``descending=True`` reverses the order along the batch dimension."""
    batch = BatchedTensor(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensor(torch.tensor([[3, 0], [0, 4], [4, 1], [2, 3], [1, 2]]))
    assert objects_are_equal(batch.argsort_along_batch(descending=True), expected)


def test_batched_tensor_argsort_along_batch_custom_dims() -> None:
    """argsort_along_batch follows a non-default ``batch_dim``."""
    batch = BatchedTensor(
        torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]), batch_dim=1
    )
    expected = BatchedTensor(
        torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]), batch_dim=1
    )
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_cumsum() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,136 @@ def test_batched_tensor_seq_sub__incorrect_seq_dim() -> None:
###########################################################


def test_batched_tensor_seq_argsort_descending_false() -> None:
    """Default (ascending) argsort orders each row's indices by value."""
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort(descending=False), expected)


def test_batched_tensor_seq_argsort_descending_true() -> None:
    """``descending=True`` reverses the index order within each row."""
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort(descending=True), expected)


def test_batched_tensor_seq_argsort_dim_0() -> None:
    """argsort accepts an explicit ``dim`` (here dimension 0)."""
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_seq_argsort_dim_1() -> None:
    """argsort works along a middle dimension of a 3-d tensor."""
    data = torch.tensor(
        [
            [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
            [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
        ]
    )
    expected = torch.tensor(
        [
            [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
            [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
        ]
    )
    assert objects_are_equal(
        BatchedTensorSeq(data).argsort(dim=1), BatchedTensorSeq(expected)
    )


def test_batched_tensor_seq_argsort_custom_dims() -> None:
    """Non-default ``seq_dim``/``batch_dim`` are propagated to the result."""
    batch = BatchedTensorSeq(
        torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort(dim=0), expected)


def test_batched_tensor_seq_argsort_along_batch_descending_false() -> None:
    """argsort_along_batch defaults to ascending order along the batch dim."""
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]))
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_seq_argsort_along_batch_descending_true() -> None:
    """``descending=True`` reverses the order along the batch dimension."""
    batch = BatchedTensorSeq(torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0], [0, 4], [4, 1], [2, 3], [1, 2]]))
    assert objects_are_equal(batch.argsort_along_batch(descending=True), expected)


def test_batched_tensor_seq_argsort_along_batch_custom_dims() -> None:
    """argsort_along_batch follows a non-default ``batch_dim``."""
    batch = BatchedTensorSeq(
        torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort_along_batch(), expected)


def test_batched_tensor_seq_argsort_along_seq_descending_false() -> None:
    """argsort_along_seq defaults to ascending order along the seq dim."""
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[1, 2, 4, 0, 3], [2, 3, 1, 4, 0]]))
    assert objects_are_equal(batch.argsort_along_seq(), expected)


def test_batched_tensor_seq_argsort_along_seq_descending_true() -> None:
    """``descending=True`` reverses the order along the sequence dimension."""
    batch = BatchedTensorSeq(torch.tensor([[4, 1, 2, 5, 3], [9, 7, 5, 6, 8]]))
    expected = BatchedTensorSeq(torch.tensor([[3, 0, 4, 2, 1], [0, 4, 1, 3, 2]]))
    assert objects_are_equal(batch.argsort_along_seq(descending=True), expected)


def test_batched_tensor_seq_argsort_along_seq_dim_3() -> None:
    """argsort_along_seq also works on a 3-d tensor (default seq_dim=1)."""
    data = torch.tensor(
        [
            [[0, 1], [-2, 3], [-4, 5], [-6, 7], [-8, 9]],
            [[10, -11], [12, -13], [14, -15], [16, -17], [18, -19]],
        ]
    )
    expected = torch.tensor(
        [
            [[4, 0], [3, 1], [2, 2], [1, 3], [0, 4]],
            [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
        ]
    )
    assert objects_are_equal(
        BatchedTensorSeq(data).argsort_along_seq(), BatchedTensorSeq(expected)
    )


def test_batched_tensor_seq_argsort_along_seq_custom_dims() -> None:
    """argsort_along_seq follows a non-default ``seq_dim``."""
    batch = BatchedTensorSeq(
        torch.tensor([[4, 9], [1, 7], [2, 5], [5, 6], [3, 8]]), seq_dim=0, batch_dim=1
    )
    expected = BatchedTensorSeq(
        torch.tensor([[1, 2], [2, 3], [4, 1], [0, 4], [3, 0]]), seq_dim=0, batch_dim=1
    )
    assert objects_are_equal(batch.argsort_along_seq(), expected)


def test_batched_tensor_seq_cumsum_dim_0() -> None:
assert (
BatchedTensorSeq(torch.arange(10).view(2, 5))
Expand Down