Skip to content

Commit

Permalink
Add BaseBatchedTensor (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Apr 6, 2023
1 parent 23d4e61 commit 4b5192c
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 289 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a5"
version = "0.0.1a6"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
213 changes: 213 additions & 0 deletions src/redcat/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from __future__ import annotations

__all__ = ["BaseBatchedTensor"]

from abc import ABC, abstractmethod
from typing import Any, TypeVar

import torch
from torch import Tensor

# Workaround because Self is not available for python 3.9 and 3.10
# https://peps.python.org/pep-0673/
TBatchedTensor = TypeVar("TBatchedTensor", bound="BaseBatchedTensor")


class BaseBatchedTensor(ABC):
    r"""Defines the base class to implement a batched tensor.

    The values are stored in a ``torch.Tensor``. Child classes have to
    implement ``_get_kwargs`` to return the extra keyword arguments
    (e.g. the batch dimension) that are needed to instantiate a new
    batch of the same type.

    Args:
        data: Specifies the data of the batch. The data are converted
            to a ``torch.Tensor`` with ``torch.as_tensor``.
        **kwargs: Extra keyword arguments that are passed to
            ``torch.as_tensor`` (e.g. ``dtype`` or ``device``).
    """

    def __init__(self, data: Any, **kwargs) -> None:
        super().__init__()
        # ``torch.as_tensor`` avoids a copy when ``data`` is already a
        # tensor with a compatible dtype/device.
        self._data = torch.as_tensor(data, **kwargs)

    @abstractmethod
    def _get_kwargs(self) -> dict:
        r"""Returns the keyword arguments that are needed to create a new
        batch with the same configuration as the current batch.

        Returns:
            dict: The keyword arguments to pass to ``__init__``
                (in addition to ``data``).
        """

    def _make_batch(self, data: Tensor) -> TBatchedTensor:
        r"""Creates a new batch of the same type and configuration as the
        current batch, but with new data.

        Args:
            data (``torch.Tensor``): Specifies the data of the new batch.

        Returns:
            A new batch of the same type with the given data.
        """
        return self.__class__(data, **self._get_kwargs())

    @property
    def data(self) -> Tensor:
        r"""``torch.Tensor``: The data in the batch."""
        return self._data

    @property
    def device(self) -> torch.device:
        r"""``torch.device``: The device where the batch data/tensor is."""
        return self._data.device

    #################################
    #     Conversion operations     #
    #################################

    def contiguous(
        self, memory_format: torch.memory_format = torch.contiguous_format
    ) -> TBatchedTensor:
        r"""Creates a batch with a contiguous representation of the data.

        Args:
            memory_format (``torch.memory_format``, optional):
                Specifies the desired memory format.
                Default: ``torch.contiguous_format``

        Returns:
            ``BaseBatchedTensor``: A new batch with a contiguous
                representation of the data.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3)).contiguous()
            >>> batch.data.is_contiguous()
            True
        """
        return self._make_batch(self._data.contiguous(memory_format=memory_format))

    def is_contiguous(self, memory_format: torch.memory_format = torch.contiguous_format) -> bool:
        r"""Indicates if a batch has a contiguous representation of the data.

        Args:
            memory_format (``torch.memory_format``, optional):
                Specifies the desired memory format.
                Default: ``torch.contiguous_format``

        Returns:
            bool: ``True`` if the data are stored with a contiguous
                tensor, otherwise ``False``.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> BatchedTensor(torch.ones(2, 3)).is_contiguous()
            True
        """
        return self._data.is_contiguous(memory_format=memory_format)

    def to(self, *args, **kwargs) -> TBatchedTensor:
        r"""Moves and/or casts the data.

        Args:
            *args: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
            **kwargs: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to

        Returns:
            ``BaseBatchedTensor``: A new batch with the data after dtype
                and/or device conversion.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch_cuda = batch.to(device=torch.device('cuda:0'))
            >>> batch_bool = batch.to(dtype=torch.bool)
            >>> batch_bool.data
            tensor([[True, True, True],
                    [True, True, True]])
        """
        return self._make_batch(self._data.to(*args, **kwargs))

    #################
    #     dtype     #
    #################

    @property
    def dtype(self) -> torch.dtype:
        r"""``torch.dtype``: The data type."""
        return self._data.dtype

    def bool(self) -> TBatchedTensor:
        r"""Converts the current batch to bool data type.

        Returns:
            ``BaseBatchedTensor``: A new batch with the data converted
                to bool data type.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch.bool().dtype
            torch.bool
        """
        return self._make_batch(self._data.bool())

    def double(self) -> TBatchedTensor:
        r"""Converts the current batch to double (``float64``) data type.

        Returns:
            ``BaseBatchedTensor``: A new batch with the data converted
                to double data type.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch.double().dtype
            torch.float64
        """
        return self._make_batch(self._data.double())

    def float(self) -> TBatchedTensor:
        r"""Converts the current batch to float (``float32``) data type.

        Returns:
            ``BaseBatchedTensor``: A new batch with the data converted
                to float data type.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch.float().dtype
            torch.float32
        """
        return self._make_batch(self._data.float())

    def int(self) -> TBatchedTensor:
        r"""Converts the current batch to int (``int32``) data type.

        Returns:
            ``BaseBatchedTensor``: A new batch with the data converted
                to int data type.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch.int().dtype
            torch.int32
        """
        return self._make_batch(self._data.int())

    def long(self) -> TBatchedTensor:
        r"""Converts the current batch to long (``int64``) data type.

        Returns:
            ``BaseBatchedTensor``: A new batch with the data converted
                to long data type.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch.long().dtype
            torch.int64
        """
        return self._make_batch(self._data.long())
98 changes: 4 additions & 94 deletions src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import torch
from torch import Tensor

from redcat.base import BaseBatchedTensor

class BatchedTensor:

class BatchedTensor(BaseBatchedTensor):
r"""Implements a batched tensor to easily manipulate a batch of examples.
Args:
Expand All @@ -23,8 +25,7 @@ class BatchedTensor:
"""

def __init__(self, data: Any, *, batch_dim: int = 0, **kwargs) -> None:
super().__init__()
self._data = torch.as_tensor(data, **kwargs)
super().__init__(data, **kwargs)
check_data_and_dim(self._data, batch_dim)
self._batch_dim = int(batch_dim)

Expand Down Expand Up @@ -59,97 +60,6 @@ def batch_size(self) -> int:
r"""int: The batch size."""
return self._data.shape[self._batch_dim]

@property
def data(self) -> Tensor:
r"""``torch.Tensor``: The data in the batch."""
return self._data

@property
def device(self) -> torch.device:
r"""``torch.device``: The device where the batch data/tensor is."""
return self._data.device

#################################
# Conversion operations #
#################################

def contiguous(
self, memory_format: torch.memory_format = torch.contiguous_format
) -> BatchedTensor:
r"""Creates a batch with a contiguous representation of the data.
Args:
memory_format (``torch.memory_format``, optional):
Specifies the desired memory format.
Default: ``torch.contiguous_format``
Returns:
``BatchedTensor``: A new batch with a contiguous
representation of the data.
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3)).contiguous()
>>> batch.data.is_contiguous()
True
"""
return BatchedTensor(
data=self._data.contiguous(memory_format=memory_format), **self._get_kwargs()
)

def is_contiguous(self, memory_format: torch.memory_format = torch.contiguous_format) -> bool:
r"""Indicates if a batch has a contiguous representation of the data.
Args:
memory_format (``torch.memory_format``, optional):
Specifies the desired memory format.
Default: ``torch.contiguous_format``
Returns:
bool: ``True`` if the data are stored with a contiguous
tensor, otherwise ``False``.
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> BatchedTensor(torch.ones(2, 3)).is_contiguous()
True
"""
return self._data.is_contiguous(memory_format=memory_format)

def to(self, *args, **kwargs) -> BatchedTensor:
r"""Moves and/or casts the data.
Args:
*args: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
**kwargs: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
Returns:
``BatchedTensor``: A new batch with the data after dtype
and/or device conversion.
Example usage:
.. code-block:: python
>>> import torch
>>> from redcat import BatchedTensor
>>> batch = BatchedTensor(torch.ones(2, 3))
>>> batch_cuda = batch.to(device=torch.device('cuda:0'))
>>> batch_bool = batch.to(dtype=torch.bool)
>>> batch_bool.data
tensor([[True, True, True],
[True, True, True]])
"""
return BatchedTensor(data=self._data.to(*args, **kwargs), **self._get_kwargs())

#################################
# Comparison operations #
#################################
Expand Down
Loading

0 comments on commit 4b5192c

Please sign in to comment.