Skip to content

Commit

Permalink
Merge pull request #287 from markus-kreft/utils_from_pvnet
Browse files Browse the repository at this point in the history
Add NumpyBatch utils from PVNet
  • Loading branch information
dfulu committed Mar 26, 2024
2 parents a066adc + 53a2e6d commit 2e8b064
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 4 deletions.
3 changes: 2 additions & 1 deletion ocf_datapipes/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Datapipes for batching together data"""
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, XarrayBatch
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, TensorBatch, XarrayBatch
from .merge_numpy_examples_to_batch import (
MergeNumpyBatchIterDataPipe as MergeNumpyBatch,
)
Expand All @@ -12,3 +12,4 @@
)
from .merge_numpy_modalities import MergeNumpyModalitiesIterDataPipe as MergeNumpyModalities
from .merge_numpy_modalities import MergeNWPNumpyModalitiesIterDataPipe as MergeNWPNumpyModalities
from .utils import batch_to_tensor, copy_batch_to_device
3 changes: 3 additions & 0 deletions ocf_datapipes/batch/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union

import numpy as np
import torch
import xarray as xr


Expand Down Expand Up @@ -229,3 +230,5 @@ class NWPBatchKey(Enum):
# Batch of numpy data: each BatchKey maps either to an ndarray or to a dict with str keys
# and NWPNumpyBatch values (presumably one entry per NWP source - TODO confirm).
NumpyBatch = dict[BatchKey, Union[np.ndarray, dict[str, NWPNumpyBatch]]]

# Batch of xarray data: each BatchKey maps to a DataArray or a Dataset.
XarrayBatch = dict[BatchKey, Union[xr.DataArray, xr.Dataset]]

# Batch of torch data: each BatchKey maps either to a Tensor or to a dict with str keys
# whose values are dicts mapping NWPBatchKey to Tensor.
TensorBatch = dict[BatchKey, Union[torch.Tensor, dict[str, dict[NWPBatchKey, torch.Tensor]]]]
75 changes: 75 additions & 0 deletions ocf_datapipes/batch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Additional utils for working with batches"""
import numpy as np
import torch

from ocf_datapipes.batch import NumpyBatch, TensorBatch


def _copy_batch_to_device(batch: dict, device: torch.device) -> dict:
    """
    Build a copy of a nested dict with every tensor leaf placed on `device`.

    Args:
        batch: Nested dict whose leaves may be tensors
        device: Device the tensors are moved to

    Returns:
        A new dict with the same keys in the same order. Nested dicts are rebuilt
        recursively, tensors are passed through `Tensor.to(device)` and every other
        value is carried over unchanged.
    """

    def _relocate(value):
        # Nested dicts (e.g. the NWP entry) are handled by recursing into them
        if isinstance(value, dict):
            return _copy_batch_to_device(value, device)
        if isinstance(value, torch.Tensor):
            return value.to(device)
        # Anything that is neither a dict nor a tensor is kept as it is
        return value

    return {key: _relocate(value) for key, value in batch.items()}


def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
    """
    Copy a TensorBatch with all of its tensors placed on the given device.

    Args:
        batch: TensorBatch holding the tensors to move
        device: Device the tensors are moved to

    Returns:
        The TensorBatch produced by `_copy_batch_to_device` for this batch and device
    """
    moved_batch = _copy_batch_to_device(batch, device)
    return moved_batch


def _batch_to_tensor(batch: dict) -> dict:
    """
    Convert the numeric numpy arrays of a nested dict into torch tensors.

    The dict is modified in place: converted values replace the arrays under the same
    keys, nested dicts are processed recursively, and the very same dict object is
    returned.

    Args:
        batch: Nested dict with data in numpy arrays

    Returns:
        The input dict, with every ndarray whose dtype is a subtype of `np.number`
        replaced by the result of `torch.as_tensor`. Other values, including arrays of
        other dtypes such as datetime64 or strings, are left untouched.
    """
    # Only existing keys are reassigned, so the dict never changes size while iterating
    for key in batch:
        value = batch[key]

        if isinstance(value, dict):
            # Recursion to reach the nested NWP
            batch[key] = _batch_to_tensor(value)
            continue

        if not isinstance(value, np.ndarray):
            continue

        if np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)

    return batch


def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
    """
    Convert the data of a NumpyBatch into a TensorBatch.

    Args:
        batch: NumpyBatch with data in numpy arrays

    Returns:
        The TensorBatch produced by `_batch_to_tensor` for this batch
    """
    tensor_batch = _batch_to_tensor(batch)
    return tensor_batch
3 changes: 1 addition & 2 deletions ocf_datapipes/training/pvnet_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import xarray as xr
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from ocf_datapipes.batch import BatchKey, NumpyBatch

from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.batch import BatchKey, MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.training.common import (
DatapipeKeyForker,
_get_datapipes_dict,
Expand Down
1 change: 0 additions & 1 deletion ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(self, filenames: List[str], keys: List[str]):

def __iter__(self):
"""Iterate through each filename, loading it, uncombining it, and then yielding it"""
import numpy as np

while True:
for filename in self.filenames:
Expand Down
30 changes: 30 additions & 0 deletions tests/batch/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import torch

from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor


def _create_test_batch() -> NumpyBatch:
    """Build a minimal NumpyBatch with an all-zero array under BatchKey.satellite_actual."""
    zeros = np.full((12, 10, 24, 24), 0)
    batch: NumpyBatch = {BatchKey.satellite_actual: zeros}
    return batch


def test_batch_to_tensor() -> None:
    """Check that batch_to_tensor turns the array of the test batch into a torch tensor."""
    numpy_batch: NumpyBatch = _create_test_batch()
    converted = batch_to_tensor(numpy_batch)
    satellite = converted[BatchKey.satellite_actual]
    assert isinstance(satellite, torch.Tensor)


def test_copy_batch_to_device() -> None:
    """Check that copy_batch_to_device returns tensors located on the requested device."""
    target = torch.device("cpu")
    source = batch_to_tensor(_create_test_batch())
    moved: TensorBatch = copy_batch_to_device(source, target)
    assert moved[BatchKey.satellite_actual].device == target  # type: ignore


if __name__ == "__main__":
    # Allow running this file directly: execute both tests in definition order
    for _run_test in (test_batch_to_tensor, test_copy_batch_to_device):
        _run_test()

0 comments on commit 2e8b064

Please sign in to comment.