Add sqrt to BatchedTensor(Seq) #48

Merged · 1 commit · Apr 12, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a43"
version = "0.0.1a44"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
43 changes: 41 additions & 2 deletions src/redcat/base.py
@@ -1324,7 +1324,7 @@ def pow(self, exponent: int | float | BaseBatchedTensor) -> TBatchedTensor:
            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]))
-            >>> batch.pow(2).data
+            >>> batch.pow(2)
            tensor([[ 0., 1., 4.],
                    [ 9., 16., 25.]], batch_dim=0)
        """
@@ -1350,7 +1350,46 @@ def pow_(self, exponent: int | float | BaseBatchedTensor) -> None:
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]))
            >>> batch.pow_(2)
-            >>> batch.data
+            >>> batch
            tensor([[ 0., 1., 4.],
                    [ 9., 16., 25.]], batch_dim=0)
        """

    def sqrt(self) -> TBatchedTensor:
        r"""Computes the square-root of each element.

        Return:
            ``BaseBatchedTensor``: A batch with the square-root of
                each element.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]))
            >>> batch.sqrt()
            tensor([[0., 1., 2.],
                    [3., 4., 5.]], batch_dim=0)
        """
        return torch.sqrt(self)

    def sqrt_(self) -> None:
        r"""Computes the square-root of each element.

        In-place version of ``sqrt()``.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]))
            >>> batch.sqrt_()
            >>> batch
            tensor([[0., 1., 2.],
                    [3., 4., 5.]], batch_dim=0)
        """
        self._data.sqrt_()
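
The new pair mirrors the existing pow/pow_ methods on the base class: sqrt() returns a new batch via torch.sqrt, while sqrt_() updates the wrapped tensor in place and returns None. A minimal usage sketch, assuming only what the diff above shows (BatchedTensorSeq importable from redcat alongside BatchedTensor; the output reprs are illustrative):

    >>> import torch
    >>> from redcat import BatchedTensor, BatchedTensorSeq
    >>> BatchedTensor(torch.tensor([[1.0, 4.0], [9.0, 16.0]])).sqrt()
    tensor([[1., 2.],
            [3., 4.]], batch_dim=0)
    >>> seq = BatchedTensorSeq(torch.tensor([[1.0, 4.0], [9.0, 16.0]]))
    >>> seq.sqrt_()  # in-place: no return value, seq itself is modified
    >>> seq.data
    tensor([[1., 2.],
            [3., 4.]])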
46 changes: 46 additions & 0 deletions tests/unit/test_tensor.py
@@ -1931,6 +1931,52 @@ def test_batched_tensor_pow__incorrect_batch_dim() -> None:
        BatchedTensor(torch.ones(2, 3, 1)).pow_(BatchedTensor(torch.ones(2, 3, 1), batch_dim=2))


def test_batched_tensor_sqrt() -> None:
    assert (
        BatchedTensor(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float))
        .sqrt()
        .equal(BatchedTensor(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float)))
    )


def test_batched_tensor_sqrt_custom_dims() -> None:
    assert (
        BatchedTensor(
            torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float),
            batch_dim=1,
        )
        .sqrt()
        .equal(
            BatchedTensor(
                torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float),
                batch_dim=1,
            )
        )
    )


def test_batched_tensor_sqrt_() -> None:
    batch = BatchedTensor(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float))
    batch.sqrt_()
    assert batch.equal(
        BatchedTensor(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float))
    )


def test_batched_tensor_sqrt__custom_dims() -> None:
    batch = BatchedTensor(
        torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float),
        batch_dim=1,
    )
    batch.sqrt_()
    assert batch.equal(
        BatchedTensor(
            torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float),
            batch_dim=1,
        )
    )


########################################
# Tests for check_data_and_dim #
########################################
52 changes: 52 additions & 0 deletions tests/unit/test_tensor_seq.py
@@ -2192,6 +2192,58 @@ def test_batched_tensor_seq_pow__incorrect_seq_dim() -> None:
        )


def test_batched_tensor_seq_sqrt() -> None:
    assert (
        BatchedTensorSeq(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float))
        .sqrt()
        .equal(
            BatchedTensorSeq(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float))
        )
    )


def test_batched_tensor_seq_sqrt_custom_dims() -> None:
    assert (
        BatchedTensorSeq(
            torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float),
            batch_dim=1,
            seq_dim=0,
        )
        .sqrt()
        .equal(
            BatchedTensorSeq(
                torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float),
                batch_dim=1,
                seq_dim=0,
            )
        )
    )


def test_batched_tensor_seq_sqrt_() -> None:
    batch = BatchedTensorSeq(torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float))
    batch.sqrt_()
    assert batch.equal(
        BatchedTensorSeq(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float))
    )


def test_batched_tensor_seq_sqrt__custom_dims() -> None:
    batch = BatchedTensorSeq(
        torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]], dtype=torch.float),
        batch_dim=1,
        seq_dim=0,
    )
    batch.sqrt_()
    assert batch.equal(
        BatchedTensorSeq(
            torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=torch.float),
            batch_dim=1,
            seq_dim=0,
        )
    )


#########################################
# Tests for check_data_and_dims #
#########################################
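
One property the custom-dims tests above exercise implicitly: sqrt() is purely elementwise, so the batch_dim and seq_dim metadata of the input carry over unchanged to the result. A small sketch of that invariant, under the assumption that BatchedTensorSeq exposes batch_dim and seq_dim attributes matching its constructor arguments:

    >>> import torch
    >>> from redcat import BatchedTensorSeq
    >>> batch = BatchedTensorSeq(
    ...     torch.tensor([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]), batch_dim=1, seq_dim=0
    ... )
    >>> out = batch.sqrt()
    >>> (out.batch_dim, out.seq_dim)
    (1, 0)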