Add div_ to BatchedTensor(Seq) #36

Merged · 1 commit · Apr 7, 2023

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a32"
version = "0.0.1a33"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
36 changes: 35 additions & 1 deletion src/redcat/base.py
@@ -751,7 +751,7 @@ def div(
) -> TBatchedTensor:
r"""Divides the ``self`` batch by the input ``other`.

-Similar to ``out = self / other`` (in-place)
+Similar to ``out = self / other``

Args:
other (``BaseBatchedTensor`` or ``torch.Tensor`` or int or
@@ -785,6 +785,40 @@ def div(
"""
return torch.div(self, other, rounding_mode=rounding_mode)

@abstractmethod
def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
r"""Divides the ``self`` batch by the input ``other`.

Similar to ``self /= other`` (in-place)

Args:
other (``BaseBatchedTensor`` or ``torch.Tensor`` or int or
float): Specifies the divisor.
rounding_mode (str or ``None``, optional): Specifies the
type of rounding applied to the result.
- ``None``: true division.
- ``"trunc"``: rounds the results of the division
towards zero.
- ``"floor"``: floor division.
Default: ``None``

Example usage:

.. code-block:: python

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3))
>>> batch.div_(BatchedTensor(torch.full((2, 3), 2.0)))
>>> batch
tensor([[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]], batch_dim=0)
"""

def mul(self, other: BaseBatchedTensor | torch.Tensor | int | float) -> TBatchedTensor:
r"""Multiplies the ``self`` batch by the input ``other`.

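The new abstract ``div_`` mirrors ``div`` but mutates the batch in place. As a minimal sketch of the three ``rounding_mode`` values the docstring lists, here is how they behave on plain ``torch`` tensors (the batched classes delegate to ``torch.div``):

    import torch

    a = torch.tensor([7.0, -7.0])
    b = torch.tensor([2.0, 2.0])

    # rounding_mode=None (the default) is true division.
    print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])
    # "trunc" rounds the quotient towards zero.
    print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.])
    # "floor" rounds the quotient towards negative infinity.
    print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.])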
8 changes: 8 additions & 0 deletions src/redcat/tensor.py
@@ -236,6 +236,14 @@ def add_(
check_batch_dims(get_batch_dims((self, other), {}))
return self._data.add_(other, alpha=alpha)

def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
return self._data.div_(other, rounding_mode=rounding_mode)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim}

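A short usage sketch of the concrete ``BatchedTensor.div_``, consistent with the docstring example and the new unit tests below; the divisor may be another batch, a plain ``torch.Tensor``, or a scalar:

    import torch
    from redcat import BatchedTensor

    batch = BatchedTensor(torch.ones(2, 3))
    batch.div_(2.0)  # in-place: every value becomes 0.5
    assert batch.equal(BatchedTensor(torch.full((2, 3), 0.5)))

    # The batch dimensions are checked before delegating to torch.Tensor.div_;
    # a mismatch raises RuntimeError, e.g.:
    # batch.div_(BatchedTensor(torch.ones(2, 3), batch_dim=1))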
9 changes: 9 additions & 0 deletions src/redcat/tensor_seq.py
@@ -303,6 +303,15 @@ def add_(
check_seq_dims(get_seq_dims((self, other), {}))
return self._data.add_(other, alpha=alpha)

def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
check_seq_dims(get_seq_dims((self, other), {}))
return self._data.div_(other, rounding_mode=rounding_mode)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}

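``BatchedTensorSeq.div_`` validates both the batch and sequence dimensions before dividing; a minimal sketch mirroring the new tests:

    import torch
    from redcat import BatchedTensorSeq

    batch = BatchedTensorSeq(torch.ones(2, 3))
    batch.div_(BatchedTensorSeq(torch.full((2, 3), 2.0)))
    assert batch.equal(BatchedTensorSeq(torch.full((2, 3), 0.5)))

    # The dimension checks run first, so a mismatch raises RuntimeError, e.g.:
    # BatchedTensorSeq(torch.ones(2, 3, 1)).div_(
    #     BatchedTensorSeq(torch.ones(2, 3, 1), seq_dim=2)
    # )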
40 changes: 38 additions & 2 deletions tests/unit/test_tensor.py
@@ -882,7 +882,7 @@ def test_batched_tensor_le_custom_batch_dim() -> None:
5.0,
),
)
-def test_batched_tensor_lt(other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]) -> None:
+def test_batched_tensor_lt(other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]) -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.lt(other)
@@ -1064,7 +1064,7 @@ def test_batched_tensor_add_incorrect_batch_dim() -> None:
),
)
def test_batched_tensor_add_(
-    other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]
+    other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.add_(other)
@@ -1134,6 +1134,42 @@ def test_batched_tensor_div_incorrect_batch_dim() -> None:
BatchedTensor(torch.ones(2, 3)).div(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
"other",
(
BatchedTensor(torch.ones(2, 3).mul(2)),
BatchedTensorSeq(torch.ones(2, 3).mul(2)),
torch.ones(2, 3).mul(2),
2,
2.0,
),
)
def test_batched_tensor_div_(
other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.div_(other)
assert batch.equal(BatchedTensor(torch.ones(2, 3).mul(0.5)))


def test_batched_tensor_div__rounding_mode_floor() -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.div_(BatchedTensor(torch.ones(2, 3).mul(2)), rounding_mode="floor")
assert batch.equal(BatchedTensor(torch.zeros(2, 3)))


def test_batched_tensor_div__custom_dims() -> None:
batch = BatchedTensor(torch.ones(2, 3), batch_dim=1)
batch.div_(BatchedTensor(torch.ones(2, 3).mul(2), batch_dim=1))
assert batch.equal(BatchedTensor(torch.ones(2, 3).mul(0.5), batch_dim=1))


def test_batched_tensor_div__incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.ones(2, 3))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.div_(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
"other",
(
44 changes: 43 additions & 1 deletion tests/unit/test_tensor_seq.py
@@ -1233,7 +1233,7 @@ def test_batched_tensor_seq_add__incorrect_batch_dim() -> None:
def test_batched_tensor_seq_add__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
-        batch.add_(BatchedTensorSeq(torch.zeros(2, 1, 3), seq_dim=2))
+        batch.add_(BatchedTensorSeq(torch.zeros(2, 3, 1), seq_dim=2))


@mark.parametrize(
@@ -1284,6 +1284,48 @@ def test_batched_tensor_seq_div_incorrect_seq_dim() -> None:
BatchedTensorSeq(torch.ones(2, 3, 1)).div(BatchedTensorSeq(torch.zeros(2, 1, 3), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(torch.ones(2, 3).mul(2)),
BatchedTensor(torch.ones(2, 3).mul(2)),
torch.ones(2, 3).mul(2),
2,
2.0,
),
)
def test_batched_tensor_seq_div_(
other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensorSeq(torch.ones(2, 3))
batch.div_(other)
assert batch.equal(BatchedTensorSeq(torch.ones(2, 3).mul(0.5)))


def test_batched_tensor_seq_div__rounding_mode_floor() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3))
batch.div_(BatchedTensorSeq(torch.ones(2, 3).mul(2)), rounding_mode="floor")
assert batch.equal(BatchedTensorSeq(torch.zeros(2, 3)))


def test_batched_tensor_seq_div__custom_dims() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3), batch_dim=1, seq_dim=0)
batch.div_(BatchedTensorSeq(torch.ones(2, 3).mul(2), batch_dim=1, seq_dim=0))
assert batch.equal(BatchedTensorSeq(torch.ones(2, 3).mul(0.5), batch_dim=1, seq_dim=0))


def test_batched_tensor_seq_div__incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.div_(BatchedTensorSeq(torch.ones(2, 3, 1), batch_dim=2))


def test_batched_tensor_seq_div__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.div_(BatchedTensorSeq(torch.ones(2, 3, 1), seq_dim=2))


@mark.parametrize(
"other",
(