Add logical_xor to BatchedTensor(Seq) #69

Merged: 1 commit, Apr 17, 2023
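For orientation before the diff: a minimal usage sketch of the two methods this PR adds, assembled from the docstring examples below. It assumes a redcat build that includes this change; it is not part of the diff itself.

```python
# Usage sketch assembled from the docstring examples in this PR.
import torch

from redcat import BatchedTensor

batch1 = BatchedTensor(
    torch.tensor([[True, True, False, False], [True, False, True, False]])
)
batch2 = BatchedTensor(
    torch.tensor([[True, False, True, False], [True, True, True, True]])
)

# Out-of-place: returns a new batch with the element-wise XOR.
batch1.logical_xor(batch2)
# tensor([[False,  True,  True, False],
#         [False,  True, False,  True]], batch_dim=0)

# In-place: mutates batch1 and returns None.
batch1.logical_xor_(batch2)
```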
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a62"
version = "0.0.1a63"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
61 changes: 61 additions & 0 deletions src/redcat/base.py
@@ -2222,3 +2222,64 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
tensor([[ True, True, True, False],
[ True, True, True, True]], batch_dim=0)
"""

def logical_xor(self, other: BaseBatchedTensor | Tensor) -> TBatchedTensor:
r"""Computes the element-wise logical XOR.

Zeros are treated as ``False`` and non-zeros are treated as
``True``.

Args:
other (``BaseBatchedTensor`` or ``torch.Tensor``):
Specifies the batch to compute logical XOR with.

Returns:
``BaseBatchedTensor``: A batch containing the element-wise
logical XOR.

Example usage:

.. code-block:: python

>>> import torch
>>> from redcat import BatchedTensor
>>> batch1 = BatchedTensor(
... torch.tensor([[True, True, False, False], [True, False, True, False]])
... )
>>> batch2 = BatchedTensor(
... torch.tensor([[True, False, True, False], [True, True, True, True]])
... )
>>> batch1.logical_xor(batch2)
tensor([[False, True, True, False],
[False, True, False, True]], batch_dim=0)
"""
return torch.logical_xor(self, other)

@abstractmethod
def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
r"""Computes the element-wise logical XOR.

Zeros are treated as ``False`` and non-zeros are treated as
``True``.

Args:
other (``BaseBatchedTensor`` or ``torch.Tensor``):
Specifies the batch to compute logical XOR with.

Example usage:

.. code-block:: python

>>> import torch
>>> from redcat import BatchedTensor
>>> batch1 = BatchedTensor(
... torch.tensor([[True, True, False, False], [True, False, True, False]])
... )
>>> batch2 = BatchedTensor(
... torch.tensor([[True, False, True, False], [True, True, True, True]])
... )
>>> batch1.logical_xor_(batch2)
>>> batch1
tensor([[False, True, True, False],
[False, True, False, True]], batch_dim=0)
"""
4 changes: 4 additions & 0 deletions src/redcat/tensor.py
@@ -280,6 +280,10 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
self._data.logical_or_(other)

def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
self._data.logical_xor_(other)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim}

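The in-place override above follows the same pattern as `logical_or_`: validate that both operands agree on the batch dimension, then defer to the wrapped tensor. `get_batch_dims` and `check_batch_dims` are existing redcat helpers whose bodies are not shown in this diff; a hypothetical sketch consistent with their call sites here and with the `RuntimeError` message the unit tests match could look like this:

```python
# Hypothetical sketch of the helpers, inferred from their call sites in this
# diff and from the error message asserted in the tests; not redcat's code.
from __future__ import annotations

from typing import Any


def get_batch_dims(tensors: tuple[Any, ...], kwargs: dict | None = None) -> set[int]:
    """Collects the batch dimension of every batched argument."""
    kwargs = kwargs or {}
    dims = {t.batch_dim for t in tensors if hasattr(t, "batch_dim")}
    dims.update(v.batch_dim for v in kwargs.values() if hasattr(v, "batch_dim"))
    return dims


def check_batch_dims(dims: set[int]) -> None:
    """Raises if the batched arguments disagree on the batch dimension."""
    if len(dims) > 1:
        raise RuntimeError(f"The batch dimensions do not match. Received multiple values: {dims}")
```

The `get_seq_dims`/`check_seq_dims` pair used in `tensor_seq.py` below presumably mirrors this for the sequence dimension.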
5 changes: 5 additions & 0 deletions src/redcat/tensor_seq.py
@@ -354,6 +354,11 @@ def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
check_seq_dims(get_seq_dims((self, other), {}))
self._data.logical_or_(other)

def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
check_batch_dims(get_batch_dims((self, other), {}))
check_seq_dims(get_seq_dims((self, other), {}))
self._data.logical_xor_(other)

def _get_kwargs(self) -> dict:
return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}

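Because the sequence variant checks both dimensions, a mismatch on either one aborts the operation before any data is touched. A quick illustration of the failure mode, mirroring `test_batched_tensor_seq_logical_xor__incorrect_seq_dim` further down:

```python
# Mirrors the seq-dim mismatch case covered by the unit tests in this PR.
import torch

from redcat import BatchedTensorSeq

batch = BatchedTensorSeq(torch.ones(2, 3, 1, dtype=torch.bool))  # default dims assumed
other = BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2)

try:
    batch.logical_xor_(other)
except RuntimeError as exc:
    print(exc)  # "The sequence dimensions do not match. ..."
```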
130 changes: 130 additions & 0 deletions tests/unit/test_tensor.py
@@ -3258,6 +3258,136 @@ def test_batched_tensor_logical_or__incorrect_batch_dim() -> None:
)


@mark.parametrize(
"other",
(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_logical_xor(other: BaseBatchedTensor | Tensor, dtype: torch.dtype) -> None:
assert (
BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
.logical_xor(other)
.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
)
)
)
)


def test_batched_tensor_logical_xor_custom_dims() -> None:
assert (
BatchedTensor(
torch.tensor(
[[True, True, False, False], [True, False, True, False]], dtype=torch.bool
),
batch_dim=1,
)
.logical_xor(
BatchedTensor(
torch.tensor(
[[True, False, True, False], [True, True, True, True]], dtype=torch.bool
),
batch_dim=1,
)
)
.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
)
)
)


def test_batched_tensor_logical_xor_incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.zeros(2, 3, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor(
BatchedTensor(
torch.zeros(2, 3, dtype=torch.bool),
batch_dim=1,
)
)


@mark.parametrize(
"other",
(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_logical_xor_(other: BaseBatchedTensor | Tensor, dtype: torch.dtype) -> None:
batch = BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
batch.logical_xor_(other)
assert batch.equal(
BatchedTensor(
torch.tensor([[False, True, True, False], [False, True, False, True]], dtype=torch.bool)
)
)


def test_batched_tensor_logical_xor__custom_dims() -> None:
batch = BatchedTensor(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=torch.bool),
batch_dim=1,
)
batch.logical_xor_(
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
batch_dim=1,
)
)
assert batch.equal(
BatchedTensor(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
)
)


def test_batched_tensor_logical_xor__incorrect_batch_dim() -> None:
batch = BatchedTensor(torch.zeros(2, 3, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor_(
BatchedTensor(
torch.zeros(2, 3, dtype=torch.bool),
batch_dim=1,
)
)


########################################
# Tests for check_data_and_dim #
########################################
152 changes: 152 additions & 0 deletions tests/unit/test_tensor_seq.py
@@ -3639,6 +3639,158 @@ def test_batched_tensor_seq_logical_or__incorrect_seq_dim() -> None:
batch.logical_or_(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_seq_logical_xor(
other: BaseBatchedTensor | Tensor, dtype: torch.dtype
) -> None:
assert (
BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
.logical_xor(other)
.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
)
)
)
)


def test_batched_tensor_seq_logical_xor_custom_dims() -> None:
assert (
BatchedTensorSeq(
torch.tensor(
[[True, True, False, False], [True, False, True, False]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
.logical_xor(
BatchedTensorSeq(
torch.tensor(
[[True, False, True, False], [True, True, True, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)
.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)
)


def test_batched_tensor_seq_logical_xor_incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor(
BatchedTensorSeq(
torch.zeros(2, 3, 1, dtype=torch.bool),
batch_dim=2,
)
)


def test_batched_tensor_seq_logical_xor_incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.logical_xor(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


@mark.parametrize(
"other",
(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
BatchedTensor(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool)
),
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
BatchedTensorSeq(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
BatchedTensor(torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float)),
torch.tensor([[1, 0, 1, 0], [1, 1, 1, 1]], dtype=torch.float),
),
)
@mark.parametrize("dtype", (torch.bool, torch.float, torch.long))
def test_batched_tensor_seq_logical_xor_(
other: BaseBatchedTensor | Tensor, dtype: torch.dtype
) -> None:
batch = BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=dtype)
)
batch.logical_xor_(other)
assert batch.equal(
BatchedTensorSeq(
torch.tensor([[False, True, True, False], [False, True, False, True]], dtype=torch.bool)
)
)


def test_batched_tensor_seq_logical_xor__custom_dims() -> None:
batch = BatchedTensorSeq(
torch.tensor([[True, True, False, False], [True, False, True, False]], dtype=torch.bool),
batch_dim=1,
seq_dim=0,
)
batch.logical_xor_(
BatchedTensorSeq(
torch.tensor([[True, False, True, False], [True, True, True, True]], dtype=torch.bool),
batch_dim=1,
seq_dim=0,
)
)
assert batch.equal(
BatchedTensorSeq(
torch.tensor(
[[False, True, True, False], [False, True, False, True]], dtype=torch.bool
),
batch_dim=1,
seq_dim=0,
)
)


def test_batched_tensor_seq_logical_xor__incorrect_batch_dim() -> None:
batch = BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The batch dimensions do not match."):
batch.logical_xor_(
BatchedTensorSeq(
torch.zeros(2, 3, 1, dtype=torch.bool),
batch_dim=2,
)
)


def test_batched_tensor_seq_logical_xor__incorrect_seq_dim() -> None:
batch = BatchedTensorSeq(torch.ones(2, 3, 1, dtype=torch.bool))
with raises(RuntimeError, match=r"The sequence dimensions do not match."):
batch.logical_xor_(BatchedTensorSeq(torch.zeros(2, 3, 1, dtype=torch.bool), seq_dim=2))


#########################################
# Tests for check_data_and_dims #
#########################################
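To run only the coverage added here (assuming a standard pytest setup for the repo), one can filter on the new test names, e.g. via `pytest.main`:

```python
# Runs just the logical_xor tests added in this PR; assumes pytest is
# installed and the repository layout shown in the diff.
import pytest

pytest.main(
    ["tests/unit/test_tensor.py", "tests/unit/test_tensor_seq.py", "-k", "logical_xor"]
)
```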