Skip to content

Commit

Permalink
Start to implement BatchList (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed May 14, 2023
1 parent 12f78dd commit aea1f61
Show file tree
Hide file tree
Showing 5 changed files with 474 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a194"
version = "0.0.1a195"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
8 changes: 7 additions & 1 deletion src/redcat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
__all__ = ["BaseBatch", "BatchedTensor", "BatchedTensorSeq"]
__all__ = [
"BaseBatch",
"BatchList",
"BatchedTensor",
"BatchedTensorSeq",
]

from redcat import comparators # noqa: F401
from redcat.base import BaseBatch
from redcat.list import BatchList
from redcat.tensor import BatchedTensor
from redcat.tensorseq import BatchedTensorSeq
16 changes: 8 additions & 8 deletions src/redcat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,11 @@ def data(self) -> T:
###############################

@abstractmethod
def clone(self, *args, **kwargs) -> TBatch:
def clone(self) -> TBatch:
r"""Creates a copy of the current batch.
Args:
*args: See the documentation of ``torch.Tensor.clone``
**kwargs: See the documentation of ``torch.Tensor.clone``
Returns:
``BaseBatchedTensor``: A copy of the current batch.
``BaseBatch``: A copy of the current batch.
Example usage:
Expand All @@ -56,6 +52,10 @@ def clone(self, *args, **kwargs) -> TBatch:
[1., 1., 1.]], batch_dim=0)
"""

#################################
# Comparison operations #
#################################

@abstractmethod
def allclose(
self, other: Any, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False
Expand Down Expand Up @@ -123,7 +123,7 @@ def permute_along_batch(self, permutation: Sequence[int] | Tensor) -> TBatch:
input should be compatible with the shape of the data.
Returns:
``BaseBatchedTensor``: A new batch with permuted data.
``BaseBatch``: A new batch with permuted data.
Example usage:
Expand Down Expand Up @@ -175,7 +175,7 @@ def shuffle_along_batch(self, generator: torch.Generator | None = None) -> TBatc
Default: ``None``
Returns:
``BaseBatchedTensor``: A new batch with shuffled data.
``BaseBatch``: A new batch with shuffled data.
Example usage:
Expand Down
Loading

0 comments on commit aea1f61

Please sign in to comment.