Skip to content

Commit

Permalink
Add cumprod of BatchedTensor(Seq) (#415)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Aug 16, 2023
1 parent 8093560 commit db0d52a
Show file tree
Hide file tree
Showing 5 changed files with 384 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.8"
version = "0.0.9a0"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
102 changes: 102 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,108 @@ def argsort_along_batch(self, *args, **kwargs) -> TBatchedTensor:
"""
return self.argsort(self._batch_dim, *args, **kwargs)

def cumprod(self, dim: int, *args, **kwargs) -> TBatchedTensor:
    r"""Computes the cumulative product of elements of the current
    batch in a given dimension.

    Args:
    ----
        dim (int): Specifies the dimension of the cumulative product.
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Returns:
    -------
        ``BatchedTensor``: A batch with the cumulative product of
            elements of the current batch in a given dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.cumprod(dim=0)
        tensor([[ 0,  1,  2,  3,  4],
                [ 0,  6, 14, 24, 36]], batch_dim=0)
    """
    # Dispatch through ``torch.cumprod`` rather than ``self._data.cumprod``
    # so the batched wrapper type (and its dim metadata) is preserved.
    return torch.cumprod(self, dim, *args, **kwargs)

def cumprod_(self, dim: int, *args, **kwargs) -> None:
    r"""In-place variant of ``cumprod``: replaces the values of the
    current batch with their cumulative product along ``dim``.

    Args:
    ----
        dim (int): Specifies the dimension of the cumulative product.
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.cumprod_(dim=0)
        >>> batch
        tensor([[ 0,  1,  2,  3,  4],
                [ 0,  6, 14, 24, 36]], batch_dim=0)
    """
    # Mutates the wrapped tensor directly; returns ``None`` per the
    # stdlib/pytorch convention for in-place operations.
    self._data.cumprod_(dim, *args, **kwargs)

def cumprod_along_batch(self, *args, **kwargs) -> TBatchedTensor:
    r"""Computes the cumulative product of elements of the current
    batch in the batch dimension.

    Args:
    ----
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Returns:
    -------
        ``BatchedTensor``: A batch with the cumulative product of
            elements of the current batch in the batch dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.cumprod_along_batch()
        tensor([[ 0,  1,  2,  3,  4],
                [ 0,  6, 14, 24, 36]], batch_dim=0)
    """
    # Thin convenience wrapper: delegate to ``cumprod`` with the
    # batch dimension filled in.
    return self.cumprod(self._batch_dim, *args, **kwargs)

def cumprod_along_batch_(self, *args, **kwargs) -> None:
    r"""In-place variant of ``cumprod_along_batch``: replaces the values
    of the current batch with their cumulative product along the batch
    dimension.

    Args:
    ----
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensor
        >>> batch = BatchedTensor(torch.arange(10).view(2, 5))
        >>> batch.cumprod_along_batch_()
        >>> batch
        tensor([[ 0,  1,  2,  3,  4],
                [ 0,  6, 14, 24, 36]], batch_dim=0)
    """
    # Delegate to the in-place ``cumprod_`` with the batch dimension.
    self.cumprod_(self._batch_dim, *args, **kwargs)

def cumsum(self, dim: int, **kwargs) -> TBatchedTensor:
r"""Computes the cumulative sum of elements of the current batch
in a given dimension.
Expand Down
50 changes: 50 additions & 0 deletions src/redcat/tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,56 @@ def argsort_along_seq(self, *args, **kwargs) -> BatchedTensorSeq:
"""
return self.argsort(self._seq_dim, *args, **kwargs)

def cumprod_along_seq(self, *args, **kwargs) -> BatchedTensorSeq:
    r"""Computes the cumulative product of elements of the current
    batch in the sequence dimension.

    Args:
    ----
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Returns:
    -------
        ``BatchedTensorSeq``: A batch with the cumulative product of
            elements of the current batch in the sequence dimension.

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensorSeq
        >>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
        >>> batch.cumprod_along_seq()
        tensor([[    0,     0,     0,     0,     0],
                [    5,    30,   210,  1680, 15120]], batch_dim=0, seq_dim=1)
    """
    # Thin convenience wrapper: delegate to ``cumprod`` with the
    # sequence dimension filled in.
    return self.cumprod(self._seq_dim, *args, **kwargs)

def cumprod_along_seq_(self, *args, **kwargs) -> None:
    r"""In-place variant of ``cumprod_along_seq``: replaces the values
    of the current batch with their cumulative product along the
    sequence dimension.

    Args:
    ----
        *args: See the documentation of ``torch.Tensor.cumprod``
        **kwargs: See the documentation of ``torch.Tensor.cumprod``

    Example usage:

    .. code-block:: pycon

        >>> import torch
        >>> from redcat import BatchedTensorSeq
        >>> batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
        >>> batch.cumprod_along_seq_()
        >>> batch
        tensor([[    0,     0,     0,     0,     0],
                [    5,    30,   210,  1680, 15120]], batch_dim=0, seq_dim=1)
    """
    # Delegate to the in-place ``cumprod_`` with the sequence dimension.
    self.cumprod_(self._seq_dim, *args, **kwargs)

def cumsum_along_seq(self, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative sum of elements of the current batch
in the sequence dimension.
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,94 @@ def test_batched_tensor_argsort_along_batch_custom_dims() -> None:
)


def test_batched_tensor_cumprod_dim_0() -> None:
    r"""``cumprod`` along dim 0 multiplies down the rows."""
    out = BatchedTensor(torch.arange(10).view(2, 5)).cumprod(dim=0)
    expected = BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]]))
    assert out.equal(expected)


def test_batched_tensor_cumprod_dim_1() -> None:
    r"""``cumprod`` along dim 1 multiplies across the columns."""
    out = BatchedTensor(torch.arange(10).view(2, 5)).cumprod(dim=1)
    expected = BatchedTensor(torch.tensor([[0, 0, 0, 0, 0], [5, 30, 210, 1680, 15120]]))
    assert out.equal(expected)


def test_batched_tensor_cumprod_custom_dims() -> None:
    r"""``cumprod`` preserves a custom batch dimension."""
    out = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1).cumprod(dim=1)
    expected = BatchedTensor(
        torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1
    )
    assert out.equal(expected)


def test_batched_tensor_cumprod_dtype() -> None:
    r"""``cumprod`` forwards the ``dtype`` keyword to torch."""
    out = BatchedTensor(torch.arange(10).view(2, 5)).cumprod(dim=0, dtype=torch.int)
    expected = BatchedTensor(
        torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]], dtype=torch.int)
    )
    assert out.equal(expected)


def test_batched_tensor_cumprod_() -> None:
    r"""In-place ``cumprod_`` mutates the batch along dim 0."""
    batch = BatchedTensor(torch.arange(10).view(2, 5))
    batch.cumprod_(dim=0)
    expected = BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]]))
    assert batch.equal(expected)


def test_batched_tensor_cumprod__custom_dims() -> None:
    r"""In-place ``cumprod_`` preserves a custom batch dimension."""
    batch = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
    batch.cumprod_(dim=1)
    expected = BatchedTensor(
        torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1
    )
    assert batch.equal(expected)


def test_batched_tensor_cumprod_along_batch() -> None:
    r"""``cumprod_along_batch`` uses the default batch dimension (0)."""
    out = BatchedTensor(torch.arange(10).view(2, 5)).cumprod_along_batch()
    expected = BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]]))
    assert out.equal(expected)


def test_batched_tensor_cumprod_along_batch_custom_dims() -> None:
    r"""``cumprod_along_batch`` follows a custom batch dimension."""
    out = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1).cumprod_along_batch()
    expected = BatchedTensor(
        torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1
    )
    assert out.equal(expected)


def test_batched_tensor_cumprod_along_batch_dtype() -> None:
    r"""``cumprod_along_batch`` forwards the ``dtype`` keyword."""
    out = BatchedTensor(torch.arange(10).view(2, 5)).cumprod_along_batch(dtype=torch.int)
    expected = BatchedTensor(
        torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]], dtype=torch.int)
    )
    assert out.equal(expected)


def test_batched_tensor_cumprod_along_batch_() -> None:
    r"""In-place ``cumprod_along_batch_`` mutates along the batch dim."""
    batch = BatchedTensor(torch.arange(10).view(2, 5))
    batch.cumprod_along_batch_()
    expected = BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [0, 6, 14, 24, 36]]))
    assert batch.equal(expected)


def test_batched_tensor_cumprod_along_batch__custom_dims() -> None:
    r"""In-place ``cumprod_along_batch_`` follows a custom batch dim."""
    batch = BatchedTensor(torch.arange(10).view(5, 2), batch_dim=1)
    batch.cumprod_along_batch_()
    expected = BatchedTensor(
        torch.tensor([[0, 0], [2, 6], [4, 20], [6, 42], [8, 72]]), batch_dim=1
    )
    assert batch.equal(expected)


def test_batched_tensor_cumsum() -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
Expand Down
Loading

0 comments on commit db0d52a

Please sign in to comment.