Clean BatchedTensorSeq implementation (#174)
durandtibo committed May 6, 2023
1 parent cf8f9ce commit 021a5c3
Showing 3 changed files with 22 additions and 122 deletions.
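
The core of this commit: BatchedTensorSeq now derives from BatchedTensor instead of BaseBatchedTensor, so the batch-dimension plumbing (the batch_dim and batch_size properties, the *_along_batch helpers, and the duplicated dimension checks in the in-place operations) can be deleted from the sequence class and inherited instead. A minimal sketch of the resulting class layout, using only names that appear in this diff; the bodies and omitted methods are illustrative, not the actual redcat source:

    # Illustrative sketch of the new hierarchy; not the actual redcat source.
    import torch

    class BatchedTensor:  # provides batch_dim, batch_size, *_along_batch, ...
        def __init__(self, data: torch.Tensor, *, batch_dim: int = 0, **kwargs) -> None:
            self._data = data
            self._batch_dim = int(batch_dim)

    class BatchedTensorSeq(BatchedTensor):  # only adds the sequence dimension
        def __init__(self, data: torch.Tensor, *, batch_dim: int = 0, seq_dim: int = 1, **kwargs) -> None:
            # batch_dim is forwarded to the parent instead of being stored again here.
            super().__init__(data, batch_dim=batch_dim, **kwargs)
            self._seq_dim = int(seq_dim)
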
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a147"
version = "0.0.1a148"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
10 changes: 5 additions & 5 deletions src/redcat/tensor.py
@@ -115,7 +115,7 @@ def new_full(
shape[self._batch_dim] = batch_size
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensor(
+ return self.__class__(
torch.full(size=shape, fill_value=fill_value, **kwargs), **self._get_kwargs()
)

@@ -163,7 +163,7 @@ def new_ones(
shape[self._batch_dim] = batch_size
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensor(torch.ones(*shape, **kwargs), **self._get_kwargs())
+ return self.__class__(torch.ones(*shape, **kwargs), **self._get_kwargs())

def new_zeros(
self,
@@ -209,7 +209,7 @@ def new_zeros(
shape[self._batch_dim] = batch_size
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensor(torch.zeros(*shape, **kwargs), **self._get_kwargs())
+ return self.__class__(torch.zeros(*shape, **kwargs), **self._get_kwargs())

#################################
# Comparison operations #
@@ -218,7 +218,7 @@ def new_zeros(
def allclose(
self, other: Any, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False
) -> bool:
- if not isinstance(other, BatchedTensor):
+ if not isinstance(other, self.__class__):
return False
if self._batch_dim != other.batch_dim:
return False
@@ -227,7 +227,7 @@ def allclose(
return self._data.allclose(other.data, rtol=rtol, atol=atol, equal_nan=equal_nan)

def equal(self, other: Any) -> bool:
- if not isinstance(other, BatchedTensor):
+ if not isinstance(other, self.__class__):
return False
if self._batch_dim != other.batch_dim:
return False
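
In tensor.py the only change is to stop hard-coding the BatchedTensor type: the new_full/new_ones/new_zeros factories construct self.__class__, and allclose/equal test isinstance(other, self.__class__), so a subclass such as BatchedTensorSeq gets back, and compares against, its own type. A small sketch of why self.__class__ matters here (Base, Child, and new_zeros_like are illustrative names, not redcat API):

    # Sketch of the self.__class__ pattern used by the factory methods.
    import torch

    class Base:
        def __init__(self, data: torch.Tensor) -> None:
            self._data = data

        def new_zeros_like(self) -> "Base":
            # Returns Child when called on a Child instance, Base otherwise.
            return self.__class__(torch.zeros_like(self._data))

    class Child(Base):
        pass

    assert isinstance(Child(torch.ones(2, 3)).new_zeros_like(), Child)
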
132 changes: 16 additions & 116 deletions src/redcat/tensorseq.py
@@ -20,13 +20,12 @@
compute_batch_seq_permutation,
get_batch_dims,
get_seq_dims,
- permute_along_dim,
)

HANDLED_FUNCTIONS = {}


- class BatchedTensorSeq(BaseBatchedTensor):
+ class BatchedTensorSeq(BatchedTensor):
r"""Implements a batched tensor to easily manipulate a batch of
sequences.
@@ -43,9 +42,8 @@ class BatchedTensorSeq(BaseBatchedTensor):
"""

def __init__(self, data: Any, *, batch_dim: int = 0, seq_dim: int = 1, **kwargs) -> None:
- super().__init__(data, **kwargs)
+ super().__init__(data, batch_dim=batch_dim, **kwargs)
check_data_and_dims(self._data, batch_dim, seq_dim)
- self._batch_dim = int(batch_dim)
self._seq_dim = int(seq_dim)

def __repr__(self) -> str:
@@ -70,15 +68,6 @@ def __torch_function__(
args = [a._data if hasattr(a, "_data") else a for a in args]
return cls(func(*args, **kwargs), batch_dim=batch_dims.pop(), seq_dim=seq_dims.pop())

- @property
- def batch_dim(self) -> int:
- r"""int: The batch dimension in the ``torch.Tensor`` object."""
- return self._batch_dim
-
- @property
- def batch_size(self) -> int:
- return self._data.shape[self._batch_dim]

@property
def seq_dim(self) -> int:
r"""int: The sequence dimension in the ``torch.Tensor`` object."""
@@ -148,7 +137,7 @@ def new_full(
shape[self._seq_dim] = seq_len
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensorSeq(
+ return self.__class__(
torch.full(size=shape, fill_value=fill_value, **kwargs), **self._get_kwargs()
)

@@ -205,7 +194,7 @@ def new_ones(
shape[self._seq_dim] = seq_len
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensorSeq(torch.ones(*shape, **kwargs), **self._get_kwargs())
+ return self.__class__(torch.ones(*shape, **kwargs), **self._get_kwargs())

def new_zeros(
self,
@@ -260,7 +249,7 @@ def new_zeros(
shape[self._seq_dim] = seq_len
kwargs["dtype"] = kwargs.get("dtype", self.dtype)
kwargs["device"] = kwargs.get("device", self.device)
- return BatchedTensorSeq(torch.zeros(*shape, **kwargs), **self._get_kwargs())
+ return self.__class__(torch.zeros(*shape, **kwargs), **self._get_kwargs())

@classmethod
def from_seq_batch(cls, data: Any, **kwargs) -> BatchedTensorSeq:
@@ -317,48 +306,37 @@ def add_(
other: BaseBatchedTensor | Tensor | int | float,
alpha: int | float = 1.0,
) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.add_(other, alpha=alpha)
+ super().add_(other, alpha=alpha)

def div_(
self,
other: BaseBatchedTensor | Tensor | int | float,
rounding_mode: str | None = None,
) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.div_(other, rounding_mode=rounding_mode)
+ super().div_(other, rounding_mode=rounding_mode)

def fmod_(self, divisor: BaseBatchedTensor | Tensor | int | float) -> None:
- check_batch_dims(get_batch_dims((self, divisor)))
- check_seq_dims(get_seq_dims((self, divisor)))
- self._data.fmod_(divisor)
+ super().fmod_(divisor)

def mul_(self, other: BaseBatchedTensor | Tensor | int | float) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.mul_(other)
+ super().mul_(other)

def sub_(
self,
other: BaseBatchedTensor | Tensor | int | float,
alpha: int | float = 1.0,
) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.sub_(other, alpha=alpha)
+ super().sub_(other, alpha=alpha)
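
The in-place arithmetic methods (add_, div_, fmod_, mul_, sub_, and further down pow_ and the logical_* variants) now delegate to super() rather than repeating the dimension checks and the update on self._data; the shared validation is assumed to live in BatchedTensor. A minimal delegation sketch with illustrative classes (only the add_ name and the super() call mirror the diff):

    # Illustrative delegation sketch; not the actual redcat classes.
    import torch

    class Parent:
        def __init__(self, data: torch.Tensor) -> None:
            self._data = data

        def add_(self, other: torch.Tensor, alpha: float = 1.0) -> None:
            # Validation and the in-place update are implemented once, here.
            self._data.add_(other, alpha=alpha)

    class Child(Parent):
        def add_(self, other: torch.Tensor, alpha: float = 1.0) -> None:
            # The subclass override simply reuses the parent implementation.
            super().add_(other, alpha=alpha)

    c = Child(torch.zeros(2, 3))
    c.add_(torch.ones(2, 3), alpha=2.0)  # c._data is now all 2s
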

###########################################################
# Mathematical | advanced arithmetical operations #
###########################################################

- def cumsum_along_batch(self, **kwargs) -> BatchedTensorSeq:
- return self.cumsum(self._batch_dim, **kwargs)
-
- def cumsum_along_batch_(self) -> None:
- self.cumsum_(self._batch_dim)

def cumsum_along_seq(self, **kwargs) -> BatchedTensorSeq:
r"""Computes the cumulative sum of elements of the current batch
in the sequence dimension.
@@ -401,12 +379,6 @@ def cumsum_along_seq_(self) -> None:
"""
self.cumsum_(self._seq_dim)

- def logcumsumexp_along_batch(self) -> BatchedTensorSeq:
- return self.logcumsumexp(self._batch_dim)
-
- def logcumsumexp_along_batch_(self) -> None:
- self.logcumsumexp_(self._batch_dim)

def logcumsumexp_along_seq(self) -> BatchedTensorSeq:
r"""Computes the logarithm of the cumulative summation of the
exponentiation of elements of the current batch in the sequence
@@ -448,25 +420,6 @@ def logcumsumexp_along_seq_(self) -> None:
"""
self.logcumsumexp_(self._seq_dim)

- def permute_along_batch(self, permutation: Sequence[int] | Tensor) -> BatchedTensorSeq:
- return self.permute_along_dim(permutation, dim=self._batch_dim)
-
- def permute_along_batch_(self, permutation: Sequence[int] | Tensor) -> None:
- self.permute_along_dim_(permutation, dim=self._batch_dim)
-
- def permute_along_dim(self, permutation: Sequence[int] | Tensor, dim: int) -> BatchedTensorSeq:
- if not torch.is_tensor(permutation):
- permutation = torch.tensor(permutation)
- return self.__class__(
- permute_along_dim(tensor=self._data, permutation=permutation, dim=dim),
- **self._get_kwargs(),
- )
-
- def permute_along_dim_(self, permutation: Sequence[int] | Tensor, dim: int) -> None:
- if not torch.is_tensor(permutation):
- permutation = torch.tensor(permutation)
- self._data = permute_along_dim(tensor=self._data, permutation=permutation, dim=dim)

def permute_along_seq(self, permutation: Sequence[int] | Tensor) -> BatchedTensorSeq:
r"""Permutes the data along the sequence dimension.
@@ -560,13 +513,6 @@ def shuffle_along_seq_(self, generator: torch.Generator | None = None) -> None:
"""
self.permute_along_seq_(torch.randperm(self.seq_len, generator=generator))

- def sort_along_batch(
- self,
- descending: bool = False,
- stable: bool = False,
- ) -> tuple[BatchedTensorSeq, BatchedTensorSeq]:
- return self.sort(dim=self._batch_dim, descending=descending, stable=stable)

def sort_along_seq(
self, descending: bool = False, stable: bool = False
) -> tuple[BatchedTensorSeq, BatchedTensorSeq]:
@@ -608,28 +554,24 @@ def sort_along_seq(
################################################

def pow_(self, exponent: int | float | BaseBatchedTensor) -> None:
- check_batch_dims(get_batch_dims((self, exponent)))
- check_seq_dims(get_seq_dims((self, exponent)))
- self._data.pow_(exponent)
+ super().pow_(exponent)

#############################################
# Mathematical | logical operations #
#############################################

def logical_and_(self, other: BaseBatchedTensor | Tensor) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.logical_and_(other)
+ super().logical_and_(other)

def logical_or_(self, other: BaseBatchedTensor | Tensor) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.logical_or_(other)
+ super().logical_or_(other)

def logical_xor_(self, other: BaseBatchedTensor | Tensor) -> None:
- check_batch_dims(get_batch_dims((self, other)))
- check_seq_dims(get_seq_dims((self, other)))
- self._data.logical_xor_(other)
+ super().logical_xor_(other)

################################
# Reduction operations #
@@ -858,7 +800,7 @@ def align_as(self, other: BatchedTensorSeq) -> BatchedTensorSeq:
tensor([[0, 2, 4, 6, 8],
[1, 3, 5, 7, 9]], batch_dim=1, seq_dim=0)
"""
- if not isinstance(other, BatchedTensorSeq):
+ if not isinstance(other, self.__class__):
raise TypeError(
f"Incorrect type {type(other)}. No implementation available to `align_as` "
f"{type(self)} with {type(other)}"
@@ -896,7 +838,7 @@ def align_to_batch_seq(self) -> BatchedTensorSeq:
[1, 3, 5, 7, 9]], batch_dim=0, seq_dim=1)
"""
return self.__class__(
- align_to_batch_seq(tensor=self._data, batch_dim=self._batch_dim, seq_dim=self._seq_dim),
+ align_to_batch_seq(tensor=self._data, **self._get_kwargs()),
batch_dim=0,
seq_dim=1,
)
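
align_to_batch_seq now forwards the dimension arguments through **self._get_kwargs() instead of naming batch_dim and seq_dim explicitly. This is only equivalent if _get_kwargs() returns exactly those two keyword arguments, which the constructor calls elsewhere in this file (self.__class__(..., **self._get_kwargs())) suggest; a sketch of that assumed helper, whose real definition is not part of this diff:

    # Assumed shape of the helper used throughout BatchedTensorSeq.
    def _get_kwargs(self) -> dict:
        return {"batch_dim": self._batch_dim, "seq_dim": self._seq_dim}
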
@@ -927,16 +869,6 @@ def align_to_seq_batch(self) -> BatchedTensorSeq:
seq_dim=0,
)

- def cat_along_batch(
- self, tensors: BaseBatchedTensor | Tensor | Iterable[BaseBatchedTensor | Tensor]
- ) -> BatchedTensor:
- return self.cat(tensors, dim=self._batch_dim)
-
- def cat_along_batch_(
- self, tensors: BaseBatchedTensor | Tensor | Iterable[BaseBatchedTensor | Tensor]
- ) -> None:
- self.cat_(tensors, dim=self._batch_dim)

def cat_along_seq(
self, tensors: BaseBatchedTensor | Tensor | Iterable[BaseBatchedTensor | Tensor]
) -> BatchedTensorSeq:
@@ -1025,9 +957,6 @@ def cat_along_seq_(
"""
self.cat_(tensors, dim=self._seq_dim)

- def chunk_along_batch(self, chunks: int) -> tuple[BaseBatchedTensor, ...]:
- return self.chunk(chunks, self._batch_dim)

def chunk_along_seq(self, chunks: int) -> tuple[BaseBatchedTensor, ...]:
r"""Splits the batch into chunks along the sequence dimension.
@@ -1051,14 +980,6 @@ def chunk_along_seq(self, chunks: int) -> tuple[BaseBatchedTensor, ...]:
"""
return self.chunk(chunks, self._seq_dim)

- def index_select(self, dim: int, index: torch.Tensor | Sequence[int]) -> BatchedTensorSeq:
- if not torch.is_tensor(index):
- index = torch.tensor(index)
- return self.__class__(self._data.index_select(dim, index), **self._get_kwargs())
-
- def index_select_along_batch(self, index: torch.Tensor | Sequence[int]) -> BatchedTensorSeq:
- return self.index_select(self._batch_dim, index)

def index_select_along_seq(self, index: torch.Tensor | Sequence[int]) -> BatchedTensorSeq:
r"""Slices the batch along the sequence dimension at the given
indices.
@@ -1120,9 +1041,6 @@ def repeat_along_seq(self, repeats: int) -> BatchedTensorSeq:
sizes[self._seq_dim] = repeats
return self.__class__(self._data.repeat(*sizes), **self._get_kwargs())

- def select_along_batch(self, index: int) -> Tensor:
- return self.select(self._batch_dim, index)

def select_along_seq(self, index: int) -> BatchedTensor:
r"""Slices the batch along the sequence dimension at the given
index.
@@ -1148,11 +1066,6 @@ def select_along_seq(self, index: int) -> BatchedTensor:
batch_dim=self._batch_dim if self._seq_dim > self._batch_dim else self._batch_dim - 1,
)

- def slice_along_batch(
- self, start: int = 0, stop: int | None = None, step: int = 1
- ) -> BatchedTensorSeq:
- return self.slice_along_dim(self._batch_dim, start, stop, step)

def slice_along_seq(
self, start: int = 0, stop: int | None = None, step: int = 1
) -> BatchedTensorSeq:
@@ -1189,21 +1102,11 @@ def slice_along_seq(
"""
return self.slice_along_dim(self._seq_dim, start, stop, step)

- def split_along_batch(
- self, split_size_or_sections: int | Sequence[int]
- ) -> tuple[BatchedTensorSeq, ...]:
- return self.split(split_size_or_sections, dim=self._batch_dim)

def split_along_seq(
self, split_size_or_sections: int | Sequence[int]
) -> tuple[BatchedTensorSeq, ...]:
return self.split(split_size_or_sections, dim=self._seq_dim)

- def take_along_batch(
- self, indices: BaseBatch[Tensor | Sequence] | Tensor | Sequence
- ) -> BatchedTensorSeq:
- return self.take_along_dim(indices, dim=self._batch_dim)

def take_along_seq(self, indices: BaseBatch | Tensor | Sequence) -> BatchedTensorSeq:
r"""Takes values along the sequence dimension.
@@ -1238,9 +1141,6 @@ def unsqueeze(self, dim: int) -> BatchedTensorSeq:
seq_dim=self._seq_dim + 1 if self._seq_dim >= dim and dim >= 0 else self._seq_dim,
)

- def view(self, *shape: tuple[int, ...]) -> Tensor:
- return self._data.view(*shape)

def view_as(self, other: BaseBatchedTensor | Tensor) -> BatchedTensorSeq:
check_batch_dims(get_batch_dims((self, other)))
check_seq_dims(get_seq_dims((self, other)))
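
The remaining deletions are the batch-dimension helpers that BatchedTensorSeq used to re-implement (cumsum_along_batch, logcumsumexp_along_batch, permute_along_batch, permute_along_dim, sort_along_batch, cat_along_batch, chunk_along_batch, index_select, index_select_along_batch, select_along_batch, slice_along_batch, split_along_batch, take_along_batch, view); after this commit they are presumably inherited from BatchedTensor, so callers keep the same API. A hedged usage sketch (the import path and data values are illustrative):

    # Usage sketch: batch helpers are inherited, sequence helpers stay on the subclass.
    import torch
    from redcat import BatchedTensorSeq  # assumed import path

    batch = BatchedTensorSeq(torch.arange(10).view(2, 5))
    batch.select_along_batch(0)  # now provided by the BatchedTensor parent
    batch.slice_along_seq(0, 3)  # still defined on BatchedTensorSeq
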
