Skip to content

Commit

Permalink
Add div_ to BatchedTensor(Seq) (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Apr 7, 2023
1 parent 20fb201 commit ba858c2
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a32"
version = "0.0.1a33"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
36 changes: 35 additions & 1 deletion src/redcat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def div(
) -> TBatchedTensor:
r"""Divides the ``self`` batch by the input ``other`.
Similar to ``out = self / other`` (in-place)
Similar to ``out = self / other``
Args:
other (``BaseBatchedTensor`` or ``torch.Tensor`` or int or
Expand Down Expand Up @@ -785,6 +785,40 @@ def div(
"""
return torch.div(self, other, rounding_mode=rounding_mode)

@abstractmethod
def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
r"""Divides the ``self`` batch by the input ``other`.
Similar to ``self /= other`` (in-place)
Args:
other (``BaseBatchedTensor`` or ``torch.Tensor`` or int or
float): Specifies the dividend.
rounding_mode (str or ``None``, optional): Specifies the
type of rounding applied to the result.
- ``None``: true division.
- ``"trunc"``: rounds the results of the division
towards zero.
- ``"floor"``: floor division.
Default: ``None``
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3))
>>> batch.div_(BatchedTensor(torch.full((2, 3), 2.0)))
>>> batch
tensor([[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]], batch_dim=0)
"""

def mul(self, other: BaseBatchedTensor | torch.Tensor | int | float) -> TBatchedTensor:
r"""Multiplies the ``self`` batch by the input ``other`.
Expand Down
8 changes: 8 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ def add_(
check_batch_dims(get_batch_dims((self, other), {}))
return self._data.add_(other, alpha=alpha)

def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
return self._data.div_(other, rounding_mode=rounding_mode)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim}

Expand Down
9 changes: 9 additions & 0 deletions src/redcat/tensor_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,15 @@ def add_(
check_seq_dims(get_seq_dims((self, other), {}))
return self._data.add_(other, alpha=alpha)

def div_(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
check_seq_dims(get_seq_dims((self, other), {}))
return self._data.div_(other, rounding_mode=rounding_mode)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}

Expand Down
40 changes: 38 additions & 2 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def test_batched_tensor_le_custom_batch_dim() -> None:
5.0,
),
)
def test_batched_tensor_lt(other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]) -> None:
def test_batched_tensor_lt(other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]) -> None:
assert (
BatchedTensor(torch.arange(10).view(2, 5))
.lt(other)
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def test_batched_tensor_add_incorrect_batch_dim() -> None:
),
)
def test_batched_tensor_add_(
other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]
other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.add_(other)
Expand Down Expand Up @@ -1134,6 +1134,42 @@ def test_batched_tensor_div_incorrect_batch_dim() -> None:
BatchedTensor(torch.ones(2, 3)).div(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
"other",
(
BatchedTensor(torch.ones(2, 3).mul(2)),
BatchedTensorSeq(torch.ones(2, 3).mul(2)),
torch.ones(2, 3).mul(2),
2,
2.0,
),
)
def test_batched_tensor_div_(
other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.div_(other)
assert batch.equal(BatchedTensor(torch.ones(2, 3).mul(0.5)))


def test_batched_tensor_div__rounding_mode_floor() -> None:
batch = BatchedTensor(torch.ones(2, 3))
batch.div_(BatchedTensor(torch.ones(2, 3).mul(2)), rounding_mode="floor")
assert batch.equal(BatchedTensor(torch.zeros(2, 3)))


def test_batched_tensor_div__custom_dims() -> None:
batch = BatchedTensor(torch.ones(2, 3), batch_dim=1)
batch.div_(BatchedTensor(torch.ones(2, 3).mul(2), batch_dim=1))
assert batch.equal(BatchedTensor(torch.ones(2, 3).mul(0.5), batch_dim=1))


def test_batched_tensor_div__incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.ones(2, 3))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.div_(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
"other",
(
Expand Down
44 changes: 43 additions & 1 deletion tests/unit/test_tensor_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ def test_batched_tensor_seq_add__incorrect_batch_dim() -> None:
def test_batched_tensor_seq_add__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.add_(BatchedTensorSeq(torch.zeros(2, 1, 3), seq_dim=2))
batch.add_(BatchedTensorSeq(torch.zeros(2, 3, 1), seq_dim=2))


@mark.parametrize(
Expand Down Expand Up @@ -1284,6 +1284,48 @@ def test_batched_tensor_seq_div_incorrect_seq_dim() -> None:
BatchedTensorSeq(torch.ones(2, 3, 1)).div(BatchedTensorSeq(torch.zeros(2, 1, 3), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(torch.ones(2, 3).mul(2)),
BatchedTensor(torch.ones(2, 3).mul(2)),
torch.ones(2, 3).mul(2),
2,
2.0,
),
)
def test_batched_tensor_seq_div_(
other: Union[BaseBatchedTensor, torch.Tensor, bool, int, float]
) -> None:
batch = BatchedTensorSeq(torch.ones(2, 3))
batch.div_(other)
assert batch.equal(BatchedTensorSeq(torch.ones(2, 3).mul(0.5)))


def test_batched_tensor_seq_div__rounding_mode_floor() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3))
batch.div_(BatchedTensorSeq(torch.ones(2, 3).mul(2)), rounding_mode="floor")
assert batch.equal(BatchedTensorSeq(torch.zeros(2, 3)))


def test_batched_tensor_seq_div__custom_dims() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3), batch_dim=1, seq_dim=0)
batch.div_(BatchedTensorSeq(torch.ones(2, 3).mul(2), batch_dim=1, seq_dim=0))
assert batch.equal(BatchedTensorSeq(torch.ones(2, 3).mul(0.5), batch_dim=1, seq_dim=0))


def test_batched_tensor_seq_div__incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.div_(BatchedTensorSeq(torch.ones(2, 3, 1), batch_dim=2))


def test_batched_tensor_seq_div__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.div_(BatchedTensorSeq(torch.ones(2, 3, 1), seq_dim=2))


@mark.parametrize(
"other",
(
Expand Down

0 comments on commit ba858c2

Please sign in to comment.