Skip to content

Commit

Permalink
Wrap batch utility functions to enforce typing
Browse files Browse the repository at this point in the history
  • Loading branch information
markus-kreft committed Mar 23, 2024
1 parent 9f9d2de commit 43af752
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
47 changes: 38 additions & 9 deletions ocf_datapipes/batch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import numpy as np
import torch

from ocf_datapipes.batch import NumpyBatch, TensorBatch

def copy_batch_to_device(batch: dict, device: torch.device) -> dict:

def _copy_batch_to_device(batch: dict, device: torch.device) -> dict:
"""
Moves a dict-batch of tensors to new device.
Moves tensor leaves in a nested dict to a new device
Args:
batch: dict with tensors to move
batch: nested dict with tensors to move
device: Device to move tensors to
Returns:
Expand All @@ -19,28 +21,55 @@ def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
for k, v in batch.items():
if isinstance(v, dict):
# Recursion to reach the nested NWP
batch_copy[k] = copy_batch_to_device(v, device)
batch_copy[k] = _copy_batch_to_device(v, device)
elif isinstance(v, torch.Tensor):
batch_copy[k] = v.to(device)
else:
batch_copy[k] = v
return batch_copy


def batch_to_tensor(batch: dict) -> dict:
def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
    """
    Typed entry point for moving the tensors of a TensorBatch to another device.

    The work is delegated to `_copy_batch_to_device`, which recurses into
    nested dicts, calls `.to(device)` on every `torch.Tensor` value and
    carries any other value over unchanged.

    Args:
        batch: TensorBatch whose tensors should be moved
        device: Device the tensors are moved to

    Returns:
        TensorBatch with its tensors on `device`
    """
    moved_batch = _copy_batch_to_device(batch, device)
    return moved_batch


def _batch_to_tensor(batch: dict) -> dict:
"""
Moves numpy batch to a tensor
Moves ndarrays in a nested dict to torch tensors
Args:
batch: dict-like batch with data in numpy arrays
batch: nested dict with data in numpy arrays
Returns:
A batch with data in torch tensors
Nested dict with data in torch tensors
"""
for k, v in batch.items():
if isinstance(v, dict):
# Recursion to reach the nested NWP
batch[k] = batch_to_tensor(v)
batch[k] = _batch_to_tensor(v)
elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
batch[k] = torch.as_tensor(v)
return batch


def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
    """
    Typed entry point for turning a NumpyBatch into a TensorBatch.

    The work is delegated to `_batch_to_tensor`, which recurses into nested
    dicts and replaces every numeric ndarray with `torch.as_tensor(array)`.
    Values that are not numeric ndarrays are left as they are. The helper
    writes the results back into the dict it was given and returns that same
    dict, so the input batch is modified in place.

    Args:
        batch: NumpyBatch with data in numpy arrays

    Returns:
        TensorBatch with data in torch tensors
    """
    tensor_batch = _batch_to_tensor(batch)
    return tensor_batch
17 changes: 11 additions & 6 deletions tests/batch/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch

from ocf_datapipes.batch import BatchKey, NumpyBatch
from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor


Expand All @@ -11,15 +11,20 @@ def _create_test_batch() -> NumpyBatch:
return sample


def test_batch_to_tensor():
batch = _create_test_batch()
def test_batch_to_tensor() -> None:
    """Check that batch_to_tensor yields a torch.Tensor under satellite_actual."""
    numpy_batch: NumpyBatch = _create_test_batch()
    converted = batch_to_tensor(numpy_batch)
    satellite = converted[BatchKey.satellite_actual]
    assert isinstance(satellite, torch.Tensor)


def test_copy_batch_to_device():
def test_copy_batch_to_device() -> None:
batch = _create_test_batch()
tensor_batch = batch_to_tensor(batch)
device = torch.device("cpu")
batch_copy = copy_batch_to_device(tensor_batch, device)
assert batch_copy[BatchKey.satellite_actual].device == device
batch_copy: TensorBatch = copy_batch_to_device(tensor_batch, device)
assert batch_copy[BatchKey.satellite_actual].device == device # type: ignore


if __name__ == "__main__":
    # Allow the tests to be run directly as a script, in definition order.
    for run_test in (test_batch_to_tensor, test_copy_batch_to_device):
        run_test()

0 comments on commit 43af752

Please sign in to comment.