Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cumprod of BatchedTensor(Seq) #415

Merged
merged 1 commit into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.8"
version = "0.0.9a0"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
102 changes: 102 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,108 @@ def argsort_along_batch(self, *args, **kwargs) -> TBatchedTensor:
"""
return self.argsort(self._batch_dim, *args, **kwargs)

def cumprod(self, dim: int, *args, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative product of elements of the current
batch in a given dimension.

Args:
----
dim (int): Specifies the dimension of the cumulative sum.
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Returns:
-------
``BatchedTensor``: A batch with the cumulative product of
elements of the current batch in a given dimension.

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.arange(10).view(2, 5))
>>> batch.cumprod(dim=0)
tensor([[ 0, 1, 2, 3, 4],
[ 0, 6, 14, 24, 36]], batch_dim=0)
"""
return torch.cumprod(self, dim, *args, **kwargs)

def cumprod_(self, dim: int, *args, **kwargs) -> None:
r"""Computes the cumulative product of elements of the current
batch in a given dimension.

Args:
----
dim (int): Specifies the dimension of the cumulative product.
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.arange(10).view(2, 5))
>>> batch.cumprod_(dim=0)
>>> batch
tensor([[ 0, 1, 2, 3, 4],
[ 0, 6, 14, 24, 36]], batch_dim=0)
"""
self._data.cumprod_(dim, *args, **kwargs)

def cumprod_along_batch(self, *args, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative product of elements of the current
batch in the batch dimension.

Args:
----
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Returns:
-------
``BatchedTensor``: A batch with the cumulative product of
elements of the current batch in the batch dimension.

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.arange(10).view(2, 5))
>>> batch.cumprod_along_batch()
tensor([[ 0, 1, 2, 3, 4],
[ 0, 6, 14, 24, 36]], batch_dim=0)
"""
return self.cumprod(self._batch_dim, *args, **kwargs)

def cumprod_along_batch_(self, *args, **kwargs) -> None:
r"""Computes the cumulative product of elements of the current
batch in the batch dimension.

Args:
----
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.arange(10).view(2, 5))
>>> batch.cumprod_along_batch_()
>>> batch
tensor([[ 0, 1, 2, 3, 4],
[ 0, 6, 14, 24, 36]], batch_dim=0)
"""
self.cumprod_(self._batch_dim, *args, **kwargs)

def cumsum(self, dim: int, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative sum of elements of the current batch
in a given dimension.
Expand Down
50 changes: 50 additions & 0 deletions src/redcat/tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,56 @@ def argsort_along_seq(self, *args, **kwargs) -> BatchedTensorSeq:
"""
return self.argsort(self._seq_dim, *args, **kwargs)

def cumprod_along_seq(self, *args, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative product of elements of the current
batch in the sequence dimension.

Args:
----
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Returns:
-------
``BatchedTensorSeq``: A batch with the cumulative sum of
elements of the current batch in the sequence dimension.

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensorSeq
>>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5)).cumprod_along_seq()
>>> batch
tensor([[ 0, 0, 0, 0, 0],
[ 5, 30, 210, 1680, 15120]], batch_dim=0, seq_dim=1)
"""
return self.cumprod(self._seq_dim, *args, **kwargs)

def cumprod_along_seq_(self, *args, **kwargs) -> None:
r"""Computes the cumulative product of elements of the current
batch in the sequence dimension.

Args:
----
*args: See the documentation of ``torch.Tensor.cumprod``
**kwargs: See the documentation of ``torch.Tensor.cumprod``

Example usage:

.. code-block:: pycon

>>> import torch
>>> from redcat import BatchedTensorSeq
>>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
>>> batch.cumprod_along_seq_()
>>> batch
tensor([[ 0, 0, 0, 0, 0],
[ 5, 30, 210, 1680, 15120]], batch_dim=0, seq_dim=1)
"""
self.cumprod_(self._seq_dim, *args, **kwargs)

def cumsum_along_seq(self, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative sum of elements of the current batch
in the sequence dimension.
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,94 @@ def test_batched_tensor_argsort_along_batch_custom_dims() -> None:
)


def test_batched_tensor_cumprod_dim_0() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.cumprod(dim=0)
.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]])))
)


def test_batched_tensor_cumprod_dim_1() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.cumprod(dim=1)
.equal(BatchedTensor(torch.tensor([[0, 0, 0, 0, 0], [5, 30, 210, 1680, 15120]])))
)


def test_batched_tensor_cumprod_custom_dims() -> None:
assert (
BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
.cumprod(dim=1)
.equal(
BatchedTensor(torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1)
)
)


def test_batched_tensor_cumprod_dtype() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.cumprod(dim=0, dtype=torch.int)
.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]], dtype=torch.int)))
)


def test_batched_tensor_cumprod_() -> None:
batch = BatchedTensor(torch.arange(10).view(2, 5))
batch.cumprod_(dim=0)
assert batch.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]])))


def test_batched_tensor_cumprod__custom_dims() -> None:
batch = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
batch.cumprod_(dim=1)
assert batch.equal(
BatchedTensor(torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1)
)


def test_batched_tensor_cumprod_along_batch() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.cumprod_along_batch()
.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]])))
)


def test_batched_tensor_cumprod_along_batch_custom_dims() -> None:
assert (
BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
.cumprod_along_batch()
.equal(
BatchedTensor(torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1)
)
)


def test_batched_tensor_cumprod_along_batch_dtype() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.cumprod_along_batch(dtype=torch.int)
.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]], dtype=torch.int)))
)


def test_batched_tensor_cumprod_along_batch_() -> None:
batch = BatchedTensor(torch.arange(10).view(2, 5))
batch.cumprod_along_batch_()
assert batch.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]])))


def test_batched_tensor_cumprod_along_batch__custom_dims() -> None:
batch = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
batch.cumprod_along_batch_()
assert batch.equal(
BatchedTensor(torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1)
)


def test_batched_tensor_cumsum() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
Expand Down
Loading
Loading