Skip to content

Commit

Permalink
Add logical_xor to BatchedTensor(Seq) (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Apr 17, 2023
1 parent dbf6046 commit 9b8fd11
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a62"
version = "0.0.1a63"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
61 changes: 61 additions & 0 deletions src/redcat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2222,3 +2222,64 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
tensor([[ True, True, True, False],
[ True, True, True, True]], batch_dim=0)
"""

def logical_xor(self, other: BaseBatchedTensor | Tensor) -> TBatchedTensor:
r"""Computes the element-wise logical XOR.
Zeros are treated as ``False`` and non-zeros are treated as
``True``.
Args:
other (``BaseBatchedTensor`` or ``torch.Tensor``):
Specifies the batch to compute logical XOR with.
Returns:
``BaseBatchedTensor``: A batch containing the element-wise
logical XOR.
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch1 = BatchedTensor(
... torch.tensor([[True, True, False, False], [True, False, True, False]])
... )
>>> batch2 = BatchedTensor(
... torch.tensor([[True, False, True, False], [True, True, True, True]])
... )
>>> batch1.logical_xor(batch2)
tensor([[False, True, True, False],
[False, True, False, True]], batch_dim=0)
"""
return torch.logical_xor(self, other)

@abstractmethod
def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
r"""Computes the element-wise logical XOR.
Zeros are treated as ``False`` and non-zeros are treated as
``True``.
Args:
other (``BaseBatchedTensor`` or ``torch.Tensor``):
Specifies the batch to compute logical XOR with.
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch1 = BatchedTensor(
... torch.tensor([[True, True, False, False], [True, False, True, False]])
... )
>>> batch2 = BatchedTensor(
... torch.tensor([[True, False, True, False], [True, True, True, True]])
... )
>>> batch1.logical_xor_(batch2)
>>> batch1
tensor([[False, True, True, False],
[False, True, False, True]], batch_dim=0)
"""
4 changes: 4 additions & 0 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
self._data.logical_or_(other)

def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
self._data.logical_xor_(other)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim}

Expand Down
5 changes: 5 additions & 0 deletions src/redcat/tensor_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
check_seq_dims(get_seq_dims((self, other), {}))
self._data.logical_or_(other)

def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
check_seq_dims(get_seq_dims((self, other), {}))
self._data.logical_xor_(other)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}

Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3258,6 +3258,136 @@ def test_batched_tensor_logical_or__incorrect_batch_dim() -> None:
)


@mark.parametrize(
"other",
(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_logical_xor(other: BaseBatchedTensor | Tensor, dtype: torch.dtype) -> None:
assert (
BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
.logical_xor(other)
.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
)
)
)
)


def test_batched_tensor_logical_xor_custom_dims() -> None:
assert (
BatchedTensor(
torch.tensor(
[[True, True, False, False], [True, False, True, False]], dtype=torch.bool
),
batch_dim=1,
)
.logical_xor(
BatchedTensor(
torch.tensor(
[[True, False, True, False], [True, True, True, True]], dtype=torch.bool
),
batch_dim=1,
)
)
.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
)
)
)


def test_batched_tensor_logical_xor_incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.zeros(2, 3, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor(
BatchedTensor(
torch.zeros(2, 3, dtype=torch.bool),
batch_dim=1,
)
)


@mark.parametrize(
"other",
(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_logical_xor_(other: BaseBatchedTensor | Tensor, dtype: torch.dtype) -> None:
batch = BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
batch.logical_xor_(other)
assert batch.equal(
BatchedTensor(
torch.tensor([[False, True, True, False], [False, True, False, True]], dtype=torch.bool)
)
)


def test_batched_tensor_logical_xor__custom_dims() -> None:
batch = BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=torch.bool),
batch_dim=1,
)
batch.logical_xor_(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
batch_dim=1,
)
)
assert batch.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
)
)


def test_batched_tensor_logical_xor__incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.zeros(2, 3, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor_(
BatchedTensor(
torch.zeros(2, 3, dtype=torch.bool),
batch_dim=1,
)
)


########################################
# Tests for check_data_and_dim #
########################################
Expand Down
152 changes: 152 additions & 0 deletions tests/unit/test_tensor_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3639,6 +3639,158 @@ def test_batched_tensor_seq_logical_or__incorrect_seq_dim() -> None:
batch.logical_or_(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_seq_logical_xor(
other: BaseBatchedTensor | Tensor, dtype: torch.dtype
) -> None:
assert (
BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
.logical_xor(other)
.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
)
)
)
)


def test_batched_tensor_seq_logical_xor_custom_dims() -> None:
assert (
BatchedTensorSeq(
torch.tensor(
[[True, True, False, False], [True, False, True, False]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
.logical_xor(
BatchedTensorSeq(
torch.tensor(
[[True, False, True, False], [True, True, True, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)
.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)
)


def test_batched_tensor_seq_logical_xor_incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor(
BatchedTensorSeq(
torch.zeros(2, 3, 1, dtype=torch.bool),
batch_dim=2,
)
)


def test_batched_tensor_seq_logical_xor_incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.logical_xor(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_seq_logical_xor_(
other: BaseBatchedTensor | Tensor, dtype: torch.dtype
) -> None:
batch = BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
batch.logical_xor_(other)
assert batch.equal(
BatchedTensorSeq(
torch.tensor([[False, True, True, False], [False, True, False, True]], dtype=torch.bool)
)
)


def test_batched_tensor_seq_logical_xor__custom_dims() -> None:
batch = BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=torch.bool),
batch_dim=1,
seq_dim=0,
)
batch.logical_xor_(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
batch_dim=1,
seq_dim=0,
)
)
assert batch.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)


def test_batched_tensor_seq_logical_xor__incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor_(
BatchedTensorSeq(
torch.zeros(2, 3, 1, dtype=torch.bool),
batch_dim=2,
)
)


def test_batched_tensor_seq_logical_xor__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.logical_xor_(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


#########################################
# Tests for check_data_and_dims #
#########################################
Expand Down

0 comments on commit 9b8fd11

Please sign in to comment.