Add some conversion operations (#5)
durandtibo authored Apr 5, 2023
1 parent e7ec83e commit e298edf
Showing 5 changed files with 168 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a1"
version = "0.0.1a2"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
92 changes: 83 additions & 9 deletions src/redcat/tensor.py
@@ -8,8 +8,6 @@
import torch
from torch import Tensor

from redcat.utils import IndexType


class BatchedTensor:
    r"""Implements a batched tensor to easily manipulate a batch of examples.

@@ -66,15 +64,91 @@ def data(self) -> Tensor:
        r"""``torch.Tensor``: The data in the batch."""
        return self._data

    ###############################
    #     Indexing operations     #
    ###############################

    def __getitem__(self, index: IndexType) -> Tensor:
        return self._data[index]

    def __setitem__(self, index: IndexType, value: Tensor | int | float) -> None:
        self._data[index] = value

    @property
    def device(self) -> torch.device:
        r"""``torch.device``: The device where the batch data/tensor is."""
        return self._data.device

    #################################
    #     Conversion operations     #
    #################################

    def contiguous(
        self, memory_format: torch.memory_format = torch.contiguous_format
    ) -> BatchedTensor:
        r"""Creates a batch with a contiguous representation of the data.

        Args:
            memory_format (``torch.memory_format``, optional):
                Specifies the desired memory format.
                Default: ``torch.contiguous_format``

        Returns:
            ``BatchedTensor``: A new batch with a contiguous
                representation of the data.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3)).contiguous()
            >>> batch.data.is_contiguous()
            True
        """
        return BatchedTensor(
            data=self._data.contiguous(memory_format=memory_format), batch_dim=self._batch_dim
        )

    def is_contiguous(self, memory_format: torch.memory_format = torch.contiguous_format) -> bool:
        r"""Indicates if a batch has a contiguous representation of the data.

        Args:
            memory_format (``torch.memory_format``, optional):
                Specifies the desired memory format.
                Default: ``torch.contiguous_format``

        Returns:
            bool: ``True`` if the data are stored with a contiguous
                tensor, otherwise ``False``.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> BatchedTensor(torch.ones(2, 3)).is_contiguous()
            True
        """
        return self._data.is_contiguous(memory_format=memory_format)

    def to(self, *args, **kwargs) -> BatchedTensor:
        r"""Moves and/or casts the data.

        Args:
            *args: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to
            **kwargs: see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch-tensor-to

        Returns:
            ``BatchedTensor``: A new batch with the data after dtype
                and/or device conversion.

        Example usage:

        .. code-block:: python

            >>> import torch
            >>> from redcat import BatchedTensor
            >>> batch = BatchedTensor(torch.ones(2, 3))
            >>> batch_cuda = batch.to(device=torch.device('cuda:0'))
            >>> batch_bool = batch.to(dtype=torch.bool)
            >>> batch_bool.data
            tensor([[True, True, True],
                    [True, True, True]])
        """
        return BatchedTensor(data=self._data.to(*args, **kwargs), batch_dim=self._batch_dim)

    #################################
    #     Comparison operations     #
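Taken together, the new ``device`` property and the conversion methods let callers inspect and convert a batch without unwrapping the underlying tensor. Below is a minimal usage sketch of the API added above, not part of the commit itself (it assumes ``redcat`` 0.0.1a2 is installed; the last branch only runs when a GPU is visible):

.. code-block:: python

    import torch

    from redcat import BatchedTensor

    batch = BatchedTensor(torch.ones(2, 3))

    # Inspect where the underlying tensor lives.
    print(batch.device)  # cpu

    # Cast the data; the batch dimension metadata is preserved.
    bool_batch = batch.to(dtype=torch.bool)
    print(bool_batch.data.dtype)  # torch.bool

    # Restore a contiguous memory layout after a transpose.
    transposed = BatchedTensor(torch.ones(3, 2).transpose(0, 1))
    print(transposed.is_contiguous())  # False
    print(transposed.contiguous().is_contiguous())  # True

    # Move to the GPU only if one is available (sketch assumption: device "cuda:0").
    if torch.cuda.is_available():
        print(batch.to(device=torch.device("cuda:0")).device)  # cuda:0

Each method returns a new ``BatchedTensor`` that carries over ``batch_dim``, so conversions compose without losing the batch metadata.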
23 changes: 22 additions & 1 deletion src/redcat/utils.py
@@ -1,4 +1,4 @@
__all__ = ["DeviceType", "IndexType"]
__all__ = ["DeviceType", "IndexType", "get_available_devices"]

from collections.abc import Sequence
from typing import Union
@@ -8,3 +8,24 @@

DeviceType = Union[torch.device, str, int]
IndexType = Union[None, int, slice, str, Tensor, Sequence]



def get_available_devices() -> tuple[str, ...]:
    r"""Gets the available PyTorch devices on the machine.

    Returns:
        tuple: The available devices.

    Example usage:

    .. code-block:: python

        >>> from redcat.utils import get_available_devices
        >>> get_available_devices()
        ('cpu', 'cuda:0')
    """
    if torch.cuda.is_available():
        return ("cpu", "cuda:0")
    return ("cpu",)
57 changes: 44 additions & 13 deletions tests/unit/test_tensor.py
@@ -6,6 +6,7 @@

from redcat import BatchedTensor
from redcat.tensor import check_data_and_dim
from redcat.utils import get_available_devices


def test_batched_tensor_is_tensor_like() -> None:
@@ -31,25 +32,55 @@ def test_batched_tensor_repr() -> None:
    assert repr(BatchedTensor(torch.arange(3))) == "tensor([0, 1, 2], batch_dim=0)"


###############################
#     Indexing operations     #
###############################


def test_batched_tensor_getitem() -> None:
    assert BatchedTensor(torch.arange(10).view(2, 5))[1, 1:3].equal(torch.tensor([6, 7]))


def test_batched_tensor_setitem_tensor() -> None:
    batch = BatchedTensor(torch.arange(10).view(2, 5))
    batch[1, 1:3] = torch.tensor([16, 17])
    assert batch.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [5, 16, 17, 8, 9]])))


def test_batched_tensor_setitem_int() -> None:
    batch = BatchedTensor(torch.arange(10).view(2, 5))
    batch[1, 1:3] = 2
    assert batch.equal(BatchedTensor(torch.tensor([[0, 1, 2, 3, 4], [5, 2, 2, 8, 9]])))


@mark.parametrize("device", get_available_devices())
def test_batched_tensor_device(device: str) -> None:
    device = torch.device(device)
    assert BatchedTensor(torch.ones(2, 3, device=device)).device == device


#################################
#     Conversion operations     #
#################################


def test_batched_tensor_contiguous() -> None:
    batch = BatchedTensor(torch.ones(3, 2).transpose(0, 1))
    assert not batch.is_contiguous()
    cont = batch.contiguous()
    assert cont.equal(BatchedTensor(torch.ones(2, 3)))
    assert cont.is_contiguous()


def test_batched_tensor_contiguous_custom_dim() -> None:
    batch = BatchedTensor(torch.ones(3, 2).transpose(0, 1), batch_dim=1)
    assert not batch.is_contiguous()
    cont = batch.contiguous()
    assert cont.equal(BatchedTensor(torch.ones(2, 3), batch_dim=1))
    assert cont.is_contiguous()


def test_batched_tensor_contiguous_memory_format() -> None:
    batch = BatchedTensor(torch.ones(2, 3, 4, 5))
    assert not batch.data.is_contiguous(memory_format=torch.channels_last)
    cont = batch.contiguous(memory_format=torch.channels_last)
    assert cont.equal(BatchedTensor(torch.ones(2, 3, 4, 5)))
    assert cont.is_contiguous(memory_format=torch.channels_last)


def test_batched_tensor_to() -> None:
    assert (
        BatchedTensor(torch.ones(2, 3))
        .to(dtype=torch.bool)
        .equal(BatchedTensor(torch.ones(2, 3, dtype=torch.bool)))
    )


def test_batched_tensor_to_custom_dim() -> None:
    assert (
        BatchedTensor(torch.ones(2, 3), batch_dim=1)
        .to(dtype=torch.bool)
        .equal(BatchedTensor(torch.ones(2, 3, dtype=torch.bool), batch_dim=1))
    )


#################################
18 changes: 18 additions & 0 deletions tests/unit/test_utils.py
@@ -0,0 +1,18 @@
from unittest.mock import patch

from redcat.utils import get_available_devices

###########################################
#     Tests for get_available_devices     #
###########################################


@patch("torch.cuda.is_available", lambda *args, **kwargs: False)
def test_get_available_devices_cpu() -> None:
    assert get_available_devices() == ("cpu",)


@patch("torch.cuda.is_available", lambda *args, **kwargs: True)
@patch("torch.cuda.device_count", lambda *args, **kwargs: 1)
def test_get_available_devices_cpu_and_gpu() -> None:
    assert get_available_devices() == ("cpu", "cuda:0")
