Improve BatchedTensorSeq and BatchedTensor #7

Merged · 2 commits · Apr 6, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -40,7 +40,7 @@ repos:
       - id: python-check-blanket-noqa
       - id: python-use-type-annotations
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.260
+    rev: v0.0.261
     hooks:
       - id: ruff
         args: [--fix]
66 changes: 52 additions & 14 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "redcat"
-version = "0.0.1a3"
+version = "0.0.1a4"
 description = "A library to manipulate batches of examples"
 readme = "README.md"
 authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
@@ -17,6 +17,7 @@ packages = [
 coola = "^0.0"
 python = "^3.9"
 torch = "^2.0"
+numpy = "^1.24"  # make optional

 [tool.poetry.group.dev.dependencies]
 black = "^22.10"
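The new numpy requirement carries a `# make optional` note. A hedged sketch of how that note could later be realized with Poetry's extras mechanism (a hypothetical follow-up, not part of this PR):

```toml
# Hypothetical follow-up sketch (not in this PR): mark numpy opt-in so the
# core install stays torch-only, and expose it through a "numpy" extra.
[tool.poetry.dependencies]
numpy = { version = "^1.24", optional = true }

[tool.poetry.extras]
numpy = ["numpy"]
```

Users would then opt in with `pip install redcat[numpy]` or `poetry install --extras numpy`.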
7 changes: 5 additions & 2 deletions src/redcat/tensor.py
@@ -98,7 +98,7 @@ def contiguous(
         True
         """
         return BatchedTensor(
-            data=self._data.contiguous(memory_format=memory_format), batch_dim=self._batch_dim
+            data=self._data.contiguous(memory_format=memory_format), **self._get_kwargs()
         )

     def is_contiguous(self, memory_format: torch.memory_format = torch.contiguous_format) -> bool:
@@ -148,7 +148,7 @@ def to(self, *args, **kwargs) -> BatchedTensor:
             tensor([[True, True, True],
                     [True, True, True]])
         """
-        return BatchedTensor(data=self._data.to(*args, **kwargs), batch_dim=self._batch_dim)
+        return BatchedTensor(data=self._data.to(*args, **kwargs), **self._get_kwargs())

     #################################
     #     Comparison operations     #
@@ -221,6 +221,9 @@ def add(
         """
         return torch.add(self, other, alpha=alpha)

+    def _get_kwargs(self) -> dict:
+        return {"batch_dim": self._batch_dim}
+

 def check_data_and_dim(data: Tensor, batch_dim: int) -> None:
     r"""Checks if the tensor ``data`` and ``batch_dim`` are correct.
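The new `_get_kwargs` helper gathers the constructor state that every conversion method must forward to the batch it returns, so a subclass only has to override one method to add extra dimensions. A minimal toy sketch of the pattern (simplified standalone classes, not the actual redcat implementation, which redefines each conversion method per class):

```python
# Toy sketch of the kwargs-propagation pattern (not the redcat classes).
# Combining type(self) with _get_kwargs() lets a subclass inherit conversion
# methods and still get back an instance of its own type with all dims intact.
from __future__ import annotations

import torch


class ToyBatchedTensor:
    def __init__(self, data: torch.Tensor, batch_dim: int = 0) -> None:
        self._data = data
        self._batch_dim = batch_dim

    def contiguous(self) -> ToyBatchedTensor:
        # **self._get_kwargs() re-injects whatever state the subclass defines.
        return type(self)(self._data.contiguous(), **self._get_kwargs())

    def _get_kwargs(self) -> dict:
        return {"batch_dim": self._batch_dim}


class ToyBatchedTensorSeq(ToyBatchedTensor):
    def __init__(self, data: torch.Tensor, batch_dim: int = 0, seq_dim: int = 1) -> None:
        super().__init__(data, batch_dim)
        self._seq_dim = seq_dim

    def _get_kwargs(self) -> dict:
        return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}


batch = ToyBatchedTensorSeq(torch.ones(2, 3), batch_dim=0, seq_dim=1)
assert isinstance(batch.contiguous(), ToyBatchedTensorSeq)  # dims carried over
```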
84 changes: 84 additions & 0 deletions src/redcat/tensor_seq.py
@@ -63,6 +63,87 @@ def seq_len(self) -> int:
         r"""int: The sequence length."""
         return self._data.shape[self._seq_dim]

+    #################################
+    #     Conversion operations     #
+    #################################
+
+    def contiguous(
+        self, memory_format: torch.memory_format = torch.contiguous_format
+    ) -> BatchedTensorSeq:
+        r"""Creates a batch with a contiguous representation of the data.
+
+        Args:
+            memory_format (``torch.memory_format``, optional):
+                Specifies the desired memory format.
+                Default: ``torch.contiguous_format``
+
+        Returns:
+            ``BatchedTensorSeq``: A new batch with a contiguous
+                representation of the data.
+
+        Example usage:
+
+        .. code-block:: python
+
+            >>> import torch
+            >>> from redcat import BatchedTensorSeq
+            >>> batch = BatchedTensorSeq(torch.ones(2, 3)).contiguous()
+            >>> batch.data.is_contiguous()
+            True
+        """
+        return BatchedTensorSeq(
+            data=self._data.contiguous(memory_format=memory_format), **self._get_kwargs()
+        )
+
+    def is_contiguous(self, memory_format: torch.memory_format = torch.contiguous_format) -> bool:
+        r"""Indicates if a batch has a contiguous representation of the data.
+
+        Args:
+            memory_format (``torch.memory_format``, optional):
+                Specifies the desired memory format.
+                Default: ``torch.contiguous_format``
+
+        Returns:
+            bool: ``True`` if the data are stored with a contiguous
+                tensor, otherwise ``False``.
+
+        Example usage:
+
+        .. code-block:: python
+
+            >>> import torch
+            >>> from redcat import BatchedTensorSeq
+            >>> BatchedTensorSeq(torch.ones(2, 3)).is_contiguous()
+            True
+        """
+        return self._data.is_contiguous(memory_format=memory_format)
+
+    def to(self, *args, **kwargs) -> BatchedTensorSeq:
+        r"""Moves and/or casts the data.
+
+        Args:
+            *args: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
+            **kwargs: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
+
+        Returns:
+            ``BatchedTensorSeq``: A new batch with the data after dtype
+                and/or device conversion.
+
+        Example usage:
+
+        .. code-block:: python
+
+            >>> import torch
+            >>> from redcat import BatchedTensorSeq
+            >>> batch = BatchedTensorSeq(torch.ones(2, 3))
+            >>> batch_cuda = batch.to(device=torch.device('cuda:0'))
+            >>> batch_bool = batch.to(dtype=torch.bool)
+            >>> batch_bool.data
+            tensor([[True, True, True],
+                    [True, True, True]])
+        """
+        return BatchedTensorSeq(data=self._data.to(*args, **kwargs), **self._get_kwargs())
+
     #################################
     #     Comparison operations     #
     #################################
@@ -92,6 +173,9 @@ def equal(self, other: Any) -> bool:
             return False
         return self._data.equal(other.data)

+    def _get_kwargs(self) -> dict:
+        return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}
+

 def check_data_and_dims(data: torch.Tensor, batch_dim: int, seq_dim: int) -> None:
     r"""Checks if the tensor ``data``, ``batch_dim`` and ``seq_dim`` are
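With these methods in place, a `BatchedTensorSeq` should survive dtype/device round-trips without losing its dimension configuration. A hedged usage sketch, assuming the `BatchedTensorSeq(data, batch_dim=0, seq_dim=1)` constructor implied by the docstring examples:

```python
# Usage sketch: conversion ops return a new BatchedTensorSeq and, through
# _get_kwargs(), carry both batch_dim and seq_dim over to the result.
import torch

from redcat import BatchedTensorSeq

batch = BatchedTensorSeq(torch.ones(3, 2), batch_dim=1, seq_dim=0)
converted = batch.to(dtype=torch.bool).contiguous()
assert converted.data.dtype == torch.bool
assert converted.data.is_contiguous()
# batch_dim=1 and seq_dim=0 are preserved on `converted` as well.
```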
20 changes: 18 additions & 2 deletions tests/unit/test_tensor.py
@@ -1,5 +1,6 @@
-from typing import Union
+from typing import Any, Union

+import numpy as np
 import torch
 from pytest import mark, raises
 from torch.overrides import is_tensor_like
@@ -13,8 +14,23 @@ def test_batched_tensor_is_tensor_like() -> None:
     assert is_tensor_like(BatchedTensor(torch.ones(2, 3)))


+@mark.parametrize(
+    "data",
+    (
+        torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float),
+        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32),
+        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
+        ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),
+    ),
+)
+def test_batched_tensor_init_data(data: Any) -> None:
+    assert BatchedTensor(data).data.equal(
+        torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float)
+    )
+
+
 @mark.parametrize("batch_dim", (-1, 1, 2))
-def test_batched_tensor_incorrect_batch_dim(batch_dim: int) -> None:
+def test_batched_tensor_init_incorrect_batch_dim(batch_dim: int) -> None:
     with raises(RuntimeError):
         BatchedTensor(torch.ones(2), batch_dim=batch_dim)

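The parametrized `test_batched_tensor_init_data` pins down that `BatchedTensor` accepts tensors, NumPy arrays, lists, and tuples and normalizes them all to the same tensor. A small standalone sketch of the coercion being exercised (assumption: `__init__` converts array-likes with `torch.as_tensor` or an equivalent; the exact call is not visible in this diff):

```python
# Each array-like below coerces to the same float32 tensor, which is the
# behavior the new parametrized test asserts through BatchedTensor(data).data.
import numpy as np
import torch

expected = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
for data in (
    expected,
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32),
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),
):
    assert torch.as_tensor(data).equal(expected)
```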