
Add add_ to BatchedTensor(Seq) (#35)
durandtibo committed Apr 7, 2023
1 parent 53ebb2f commit 20fb201
Showing 6 changed files with 170 additions and 30 deletions.
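In summary, this commit bumps the package version to 0.0.1a32, declares an abstract in-place ``add_`` method on ``BaseBatchedTensor``, and implements it on ``BatchedTensor`` and ``BatchedTensorSeq`` together with unit tests. A quick sketch of the new API, following the docstring example and the tests below (the second call mixes in a scalar, which the tests also cover):

>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3))
>>> batch.add_(BatchedTensor(torch.full((2, 3), 2.0)))  # in place: self += other
>>> batch.add_(2, alpha=2.0)                            # in place: self += alpha * other
>>> batch
tensor([[7., 7., 7.],
        [7., 7., 7.]], batch_dim=0)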
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a31"
version = "0.0.1a32"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
37 changes: 34 additions & 3 deletions src/redcat/base.py
@@ -671,9 +671,9 @@ def long(self) -> TBatchedTensor:
"""
return self.__class__(self._data.long(), **self._get_kwargs())

###################################
# Arithmetical operations #
###################################
##################################################
# Mathematical | arithmetical operations #
##################################################

def add(
self,
@@ -713,6 +713,37 @@ def add(
"""
return torch.add(self, other, alpha=alpha)

@abstractmethod
def add_(
self,
other: BaseBatchedTensor | Tensor | int | float,
alpha: int | float = 1.0,
) -> None:
r"""Adds the input ``other``, scaled by ``alpha``, to the ``self``
batch.
Similar to ``self += alpha * other`` (in-place)
Args:
other (``BaseBatchedTensor`` or ``torch.Tensor`` or int or
float): Specifies the other value to add to the
current batch.
alpha (int or float, optional): Specifies the scale of the
batch to add. Default: ``1.0``
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3))
>>> batch.add_(BatchedTensor(torch.full((2, 3), 2.0)))
>>> batch
tensor([[3., 3., 3.],
[3., 3., 3.]], batch_dim=0)
"""

def div(
self,
other: BaseBatchedTensor | torch.Tensor | int | float,
23 changes: 17 additions & 6 deletions src/redcat/tensor.py
@@ -9,6 +9,7 @@
from torch import Tensor

from redcat.base import BaseBatchedTensor
from redcat.utils import check_batch_dims, get_batch_dims


class BatchedTensor(BaseBatchedTensor):
@@ -40,13 +41,11 @@ def __torch_function__(
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> BatchedTensor:
        batch_dims = {a._batch_dim for a in args if hasattr(a, "_batch_dim")}
        if len(batch_dims) > 1:
            raise RuntimeError(
                f"The batch dimensions do not match. Received multiple values: {batch_dims}"
            )
        kwargs = kwargs or {}
        batch_dims = get_batch_dims(args, kwargs)
        check_batch_dims(batch_dims)
        args = [a._data if hasattr(a, "_data") else a for a in args]
        return cls(func(*args, **(kwargs or {})), batch_dim=batch_dims.pop())
        return cls(func(*args, **kwargs), batch_dim=batch_dims.pop())

    @property
    def batch_dim(self) -> int:
@@ -225,6 +224,18 @@ def equal(self, other: Any) -> bool:
            return False
        return self._data.equal(other.data)

    ##################################################
    #     Mathematical | arithmetical operations     #
    ##################################################

    def add_(
        self,
        other: BaseBatchedTensor | Tensor | int | float,
        alpha: int | float = 1.0,
    ) -> None:
        check_batch_dims(get_batch_dims((self, other), {}))
        return self._data.add_(other, alpha=alpha)

    def _get_kwargs(self) -> dict:
        return {"batch_dim": self._batch_dim}

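Both the refactored ``__torch_function__`` and the new ``add_`` above rely on ``get_batch_dims`` and ``check_batch_dims`` imported from ``redcat.utils``, a module that is not part of this diff. A hypothetical sketch of those helpers, inferred only from their call sites and from the inline error message they replace (the real implementations may differ):

# Hypothetical reconstruction of the redcat.utils helpers used above; the
# behavior is inferred from the call sites in this commit, not copied from
# the actual module.
from __future__ import annotations

from typing import Any


def get_batch_dims(args: tuple[Any, ...], kwargs: dict[str, Any]) -> set[int]:
    # Collect the batch dimension of every batched tensor found in args/kwargs.
    dims = {a._batch_dim for a in args if hasattr(a, "_batch_dim")}
    dims.update(v._batch_dim for v in kwargs.values() if hasattr(v, "_batch_dim"))
    return dims


def check_batch_dims(dims: set[int]) -> None:
    # The operation is only well defined if all inputs share a single batch dimension.
    if len(dims) > 1:
        raise RuntimeError(
            f"The batch dimensions do not match. Received multiple values: {dims}"
        )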
34 changes: 21 additions & 13 deletions src/redcat/tensor_seq.py
@@ -6,8 +6,10 @@
from typing import Any

import torch
from torch import Tensor

from redcat.base import BaseBatchedTensor
from redcat.utils import check_batch_dims, check_seq_dims, get_batch_dims, get_seq_dims


class BatchedTensorSeq(BaseBatchedTensor):
@@ -42,20 +44,13 @@ def __torch_function__(
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> BatchedTensorSeq:
        batch_dims = {a._batch_dim for a in args if hasattr(a, "_batch_dim")}
        if len(batch_dims) > 1:
            raise RuntimeError(
                f"The batch dimensions do not match. Received multiple values: {batch_dims}"
            )
        seq_dims = {a._seq_dim for a in args if hasattr(a, "_seq_dim")}
        if len(seq_dims) > 1:
            raise RuntimeError(
                f"The sequence dimensions do not match. Received multiple values: {seq_dims}"
            )
        kwargs = kwargs or {}
        batch_dims = get_batch_dims(args, kwargs)
        check_batch_dims(batch_dims)
        seq_dims = get_seq_dims(args, kwargs)
        check_seq_dims(seq_dims)
        args = [a._data if hasattr(a, "_data") else a for a in args]
        return cls(
            func(*args, **(kwargs or {})), batch_dim=batch_dims.pop(), seq_dim=seq_dims.pop()
        )
        return cls(func(*args, **kwargs), batch_dim=batch_dims.pop(), seq_dim=seq_dims.pop())

    @property
    def batch_dim(self) -> int:
@@ -295,6 +290,19 @@ def equal(self, other: Any) -> bool:
            return False
        return self._data.equal(other.data)

    ##################################################
    #     Mathematical | arithmetical operations     #
    ##################################################

    def add_(
        self,
        other: BaseBatchedTensor | Tensor | int | float,
        alpha: int | float = 1.0,
    ) -> None:
        check_batch_dims(get_batch_dims((self, other), {}))
        check_seq_dims(get_seq_dims((self, other), {}))
        return self._data.add_(other, alpha=alpha)

    def _get_kwargs(self) -> dict:
        return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}

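``BatchedTensorSeq`` additionally validates the sequence dimension through ``get_seq_dims`` and ``check_seq_dims``, also imported from ``redcat.utils`` and not shown in this diff. A hypothetical sketch mirroring the batch-dimension helpers above (assumed, not taken from the library):

# Hypothetical sequence-dimension counterparts; inferred from the call sites
# in BatchedTensorSeq, not copied from the actual redcat.utils module.
from __future__ import annotations

from typing import Any


def get_seq_dims(args: tuple[Any, ...], kwargs: dict[str, Any]) -> set[int]:
    # Collect the sequence dimension of every sequence batch found in args/kwargs.
    dims = {a._seq_dim for a in args if hasattr(a, "_seq_dim")}
    dims.update(v._seq_dim for v in kwargs.values() if hasattr(v, "_seq_dim"))
    return dims


def check_seq_dims(dims: set[int]) -> None:
    # All sequence batches must agree on where the sequence dimension lives.
    if len(dims) > 1:
        raise RuntimeError(
            f"The sequence dimensions do not match. Received multiple values: {dims}"
        )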
48 changes: 45 additions & 3 deletions tests/unit/test_tensor.py
@@ -1003,9 +1003,9 @@ def test_batched_tensor_long_custom_batch_dim() -> None:
    )


###################################
#     Arithmetical operations     #
###################################
##################################################
#     Mathematical | arithmetical operations     #
##################################################


@mark.parametrize(
@@ -1053,6 +1053,48 @@ def test_batched_tensor_add_incorrect_batch_dim() -> None:
        BatchedTensor(torch.ones(2, 3)).add(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
    "other",
    (
        BatchedTensor(torch.full((2, 3), 2.0)),
        BatchedTensorSeq(torch.full((2, 3), 2.0)),
        torch.full((2, 3), 2.0),
        2,
        2.0,
    ),
)
def test_batched_tensor_add_(
    other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]
) -> None:
    batch = BatchedTensor(torch.ones(2, 3))
    batch.add_(other)
    assert batch.equal(BatchedTensor(torch.full((2, 3), 3.0)))


def test_batched_tensor_add__alpha_2_float() -> None:
    batch = BatchedTensor(torch.ones(2, 3))
    batch.add_(BatchedTensor(torch.full((2, 3), 2.0)), alpha=2.0)
    assert batch.equal(BatchedTensor(torch.full((2, 3), 5.0)))


def test_batched_tensor_add__alpha_2_long() -> None:
    batch = BatchedTensor(torch.ones(2, 3, dtype=torch.long))
    batch.add_(BatchedTensor(torch.ones(2, 3, dtype=torch.long).mul(2)), alpha=2)
    assert batch.equal(BatchedTensor(torch.ones(2, 3, dtype=torch.long).mul(5)))


def test_batched_tensor_add__custom_dims() -> None:
    batch = BatchedTensor(torch.ones(2, 3), batch_dim=1)
    batch.add_(BatchedTensor(torch.full((2, 3), 2.0), batch_dim=1))
    assert batch.equal(BatchedTensor(torch.full((2, 3), 3.0), batch_dim=1))


def test_batched_tensor_add__incorrect_batch_dim() -> None:
    batch = BatchedTensor(torch.ones(2, 3))
    with raises(RuntimeError, match=r"The batch dimensions do not match."):
        batch.add_(BatchedTensor(torch.ones(2, 3), batch_dim=1))


@mark.parametrize(
    "other",
    (
56 changes: 52 additions & 4 deletions tests/unit/test_tensor_seq.py
@@ -1127,9 +1127,9 @@ def test_batched_tensor_seq_long_custom_dims() -> None:
    )


###################################
#     Arithmetical operations     #
###################################
##################################################
#     Mathematical | arithmetical operations     #
##################################################


@mark.parametrize(
@@ -1188,6 +1188,54 @@ def test_batched_tensor_seq_add_incorrect_seq_dim() -> None:
        )


@mark.parametrize(
    "other",
    (
        BatchedTensorSeq(torch.full((2, 3), 2.0)),
        BatchedTensor(torch.full((2, 3), 2.0)),
        torch.full((2, 3), 2.0),
        2,
        2.0,
    ),
)
def test_batched_tensor_seq_add_(
    other: Union[BatchedTensorSeq, torch.Tensor, bool, int, float]
) -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3))
    batch.add_(other)
    assert batch.equal(BatchedTensorSeq(torch.full((2, 3), 3.0)))


def test_batched_tensor_seq_add__alpha_2_float() -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3))
    batch.add_(BatchedTensorSeq(torch.full((2, 3), 2.0)), alpha=2.0)
    assert batch.equal(BatchedTensorSeq(torch.full((2, 3), 5.0)))


def test_batched_tensor_seq_add__alpha_2_long() -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3, dtype=torch.long))
    batch.add_(BatchedTensorSeq(torch.ones(2, 3, dtype=torch.long).mul(2)), alpha=2)
    assert batch.equal(BatchedTensorSeq(torch.ones(2, 3, dtype=torch.long).mul(5)))


def test_batched_tensor_seq_add__custom_dims() -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3), batch_dim=1, seq_dim=0)
    batch.add_(BatchedTensorSeq(torch.full((2, 3), 2.0), batch_dim=1, seq_dim=0))
    assert batch.equal(BatchedTensorSeq(torch.full((2, 3), 3.0), batch_dim=1, seq_dim=0))


def test_batched_tensor_seq_add__incorrect_batch_dim() -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3, 1))
    with raises(RuntimeError, match=r"The batch dimensions do not match."):
        batch.add_(BatchedTensorSeq(torch.ones(2, 3, 1), batch_dim=2))


def test_batched_tensor_seq_add__incorrect_seq_dim() -> None:
    batch = BatchedTensorSeq(torch.ones(2, 3, 1))
    with raises(RuntimeError, match=r"The sequence dimensions do not match."):
        batch.add_(BatchedTensorSeq(torch.zeros(2, 1, 3), seq_dim=2))


@mark.parametrize(
    "other",
    (
@@ -1296,7 +1344,7 @@ def test_batched_tensor_seq_sub_alpha_2_float() -> None:
    assert (
        BatchedTensorSeq(torch.ones(2, 3))
        .sub(BatchedTensorSeq(torch.full((2, 3), 2.0)), alpha=2.0)
        .equal(BatchedTensorSeq(-torch.ones(2, 3).mul(3)))
        .equal(BatchedTensorSeq(-torch.full((2, 3), 3.0)))
    )


