class _BatchLoader:
    """
    Utility class that loads data in batches for a fixed number of iterations.

    The loader walks over the samples in order, groups ``batch_size`` consecutive
    samples into one batch and yields exactly ``iterations`` batches. When the
    data is exhausted before enough batches were produced, iteration wraps
    around to the first sample; samples left over from the previous pass are
    kept and become the start of the next batch.

    Two sample layouts are supported:

    * multi-input: every sample is a ``list`` of arrays. Each batch is a list
      with one stacked array of shape ``(batch_size, *input_shape)`` per input.
    * single-input: every sample is a bare array. Each batch is a single
      stacked array of shape ``(batch_size, *sample_shape)``.

    NOTE: the arrays that are yielded are pre-allocated buffers which are
    overwritten in place when the next batch is built. Copy a batch if it has
    to outlive the current iteration step.

    :param data: indexable, non-empty collection of samples; ``data[0]`` is
        inspected to decide between the single- and multi-input layout
    :param batch_size: positive number of samples per batch
    :param iterations: positive number of batches to yield per iteration pass
    :raises ValueError: if ``batch_size`` or ``iterations`` is not positive
    """

    __slots__ = [
        "_data",
        "_batch_size",
        "_was_wrapped_originally",
        "_iterations",
        "_batch_buffer",
        "_batch_template",
        "_batches_created",
    ]

    def __init__(
        self,
        data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]],
        batch_size: int,
        iterations: int,
    ):
        # validate the cheap scalar arguments before touching the data
        if batch_size <= 0 or iterations <= 0:
            raise ValueError(
                f"Both batch size and number of iterations should be positive, "
                f"supplied values (batch_size, iterations):{(batch_size, iterations)}"
            )

        # multi-input datasets store every sample as a list of arrays
        self._was_wrapped_originally = isinstance(data[0], list)
        if self._was_wrapped_originally:
            self._data = data
        else:
            # Wrap each sample on its own so that both layouts can share the
            # same batching code. Wrapping the whole dataset in one list would
            # turn the entire dataset into a single "sample" and every batch
            # would only ever contain copies of the first array.
            self._data = [[sample] for sample in data]

        self._batch_size = batch_size
        self._iterations = iterations
        self._batch_buffer = []
        self._batch_template = self._init_batch_template()
        self._batches_created = 0

    def __iter__(
        self,
    ) -> Generator[Union[numpy.ndarray, List[numpy.ndarray]], None, None]:
        # Start every pass from a clean state so the loader can be iterated
        # more than once, also after a previous pass was abandoned early.
        self._batch_buffer = []
        self._batches_created = 0
        yield from self._multi_input_batch_generator()

    @property
    def _buffer_is_full(self) -> bool:
        # True once enough samples were collected to form one batch
        return len(self._batch_buffer) >= self._batch_size

    @property
    def _all_batches_loaded(self) -> bool:
        # True once the requested number of batches was produced
        return self._batches_created >= self._iterations

    def _multi_input_batch_generator(
        self,
    ) -> Generator[Union[numpy.ndarray, List[numpy.ndarray]], None, None]:
        # A generator with each element of the form
        # [(batch_size, features_a), (batch_size, features_b), ...]
        # (or a single (batch_size, features) array for single-input data);
        # keeps cycling over the data until enough batches were produced
        while not self._all_batches_loaded:
            yield from self._batch_generator(source=self._data)

    def _batch_generator(
        self, source
    ) -> Generator[Union[numpy.ndarray, List[numpy.ndarray]], None, None]:
        # Yields batches built from one pass over source; samples that do not
        # fill a whole batch stay in the buffer for the next pass
        for sample in source:
            self._batch_buffer.append(sample)
            if not self._buffer_is_full:
                continue

            batch = self._make_batch()
            # Update the bookkeeping before handing control to the consumer:
            # if the consumer stops iterating right after this batch, the
            # loader must not be left behind with a stale, full buffer.
            self._batch_buffer = []
            self._batches_created += 1
            yield batch

            if self._all_batches_loaded:
                break

    def _init_batch_template(self) -> List[numpy.ndarray]:
        # A placeholder for batches: one pre-allocated (C-contiguous) array
        # per model input, shaped (batch_size, *input_shape)
        return [
            numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype)
            for _input in self._data[0]
        ]

    def _make_batch(self) -> Union[numpy.ndarray, List[numpy.ndarray]]:
        # Copy contents of buffer to batch placeholder
        # and return the numpy array(s) representing the batch
        batch = [
            numpy.stack([sample[idx] for sample in self._batch_buffer], out=template)
            for idx, template in enumerate(self._batch_template)
        ]

        # single-input data was wrapped internally, unwrap the outer list again
        return batch if self._was_wrapped_originally else batch[0]
tests/sparsezoo/utils.py rename to tests/sparsezoo/helpers.py diff --git a/tests/sparsezoo/models/classification/test_efficientnet.py b/tests/sparsezoo/models/classification/test_efficientnet.py index 4e8b73f8..29eae4c5 100644 --- a/tests/sparsezoo/models/classification/test_efficientnet.py +++ b/tests/sparsezoo/models/classification/test_efficientnet.py @@ -15,7 +15,7 @@ import pytest from sparsezoo.models.classification import efficientnet_b0, efficientnet_b4 -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/classification/test_inception.py b/tests/sparsezoo/models/classification/test_inception.py index 208992dd..e4d4a6df 100644 --- a/tests/sparsezoo/models/classification/test_inception.py +++ b/tests/sparsezoo/models/classification/test_inception.py @@ -15,7 +15,7 @@ import pytest from sparsezoo.models.classification import inception_v3 -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/classification/test_mobilenet.py b/tests/sparsezoo/models/classification/test_mobilenet.py index 0c3ae7b9..bab2427b 100644 --- a/tests/sparsezoo/models/classification/test_mobilenet.py +++ b/tests/sparsezoo/models/classification/test_mobilenet.py @@ -14,7 +14,7 @@ import pytest from sparsezoo.models.classification import mobilenet_v1, mobilenet_v2 -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/classification/test_resnet.py b/tests/sparsezoo/models/classification/test_resnet.py index 8d50b02e..184355c2 100644 --- a/tests/sparsezoo/models/classification/test_resnet.py +++ b/tests/sparsezoo/models/classification/test_resnet.py @@ -23,7 +23,7 @@ resnet_101_2x, resnet_152, ) -from tests.sparsezoo.utils import model_constructor 
+from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/classification/test_vgg.py b/tests/sparsezoo/models/classification/test_vgg.py index 4910a1e6..b34948cc 100644 --- a/tests/sparsezoo/models/classification/test_vgg.py +++ b/tests/sparsezoo/models/classification/test_vgg.py @@ -24,7 +24,7 @@ vgg_19, vgg_19bn, ) -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/detection/test_ssd.py b/tests/sparsezoo/models/detection/test_ssd.py index 8f8111e8..cf44b96c 100644 --- a/tests/sparsezoo/models/detection/test_ssd.py +++ b/tests/sparsezoo/models/detection/test_ssd.py @@ -15,7 +15,7 @@ import pytest from sparsezoo.models.detection import ssd_resnet50_300 -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/detection/test_yolo.py b/tests/sparsezoo/models/detection/test_yolo.py index 31029151..dbd0c4df 100644 --- a/tests/sparsezoo/models/detection/test_yolo.py +++ b/tests/sparsezoo/models/detection/test_yolo.py @@ -15,7 +15,7 @@ import pytest from sparsezoo.models.detection import yolo_v3 -from tests.sparsezoo.utils import model_constructor +from tests.sparsezoo.helpers import model_constructor @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/test_zoo.py b/tests/sparsezoo/models/test_zoo.py index 9364eaee..855b58e0 100644 --- a/tests/sparsezoo/models/test_zoo.py +++ b/tests/sparsezoo/models/test_zoo.py @@ -19,7 +19,7 @@ from sparsezoo import Zoo from sparsezoo.utils import CACHE_DIR -from tests.sparsezoo.utils import validate_downloaded_model +from tests.sparsezoo.helpers import validate_downloaded_model @pytest.mark.parametrize( diff --git a/tests/sparsezoo/models/test_zoo_extensive.py b/tests/sparsezoo/models/test_zoo_extensive.py index bd9f8b98..5a1dd9d3 100644 
--- a/tests/sparsezoo/models/test_zoo_extensive.py +++ b/tests/sparsezoo/models/test_zoo_extensive.py @@ -17,7 +17,7 @@ import pytest from sparsezoo.models import Zoo -from tests.sparsezoo.utils import download_and_verify +from tests.sparsezoo.helpers import download_and_verify def _get_models(domain, sub_domain) -> List[str]: diff --git a/tests/sparsezoo/utils/__init__.py b/tests/sparsezoo/utils/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/tests/sparsezoo/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparsezoo/utils/test_data.py b/tests/sparsezoo/utils/test_data.py new file mode 100644 index 00000000..6f0f0ec0 --- /dev/null +++ b/tests/sparsezoo/utils/test_data.py @@ -0,0 +1,214 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Iterable

import numpy
import numpy as np
import pytest

from sparsezoo.utils import Dataset


# Parameter grids shared by the batching tests below
_BATCH_SIZES = [1, 2, 10]
_ITERATIONS = [1, 2, 4, 10]
# Smaller grid used by the multi-input shape test
_SMALL_GRID = [1, 2, 3]


@pytest.fixture
def dummy_dataset():
    """Dataset holding one random sample of shape (100, 10)."""
    return Dataset(data=[np.random.rand(100, 10)], name="dummy-dataset")


@pytest.fixture
def single_input_dataset():
    """Dataset whose three samples are bare arrays of shape (3, 2)."""
    data = [np.random.rand(3, 2), np.random.rand(3, 2), np.random.rand(3, 2)]
    return Dataset(data=data, name="single-input-test-dataset")


@pytest.fixture
def multi_input_dataset():
    """Dataset whose two samples are lists of two arrays, (1, 2) and (1, 3)."""
    data = [
        [np.random.rand(1, 2), np.random.rand(1, 3)],
        [np.random.rand(1, 2), np.random.rand(1, 3)],
    ]
    return Dataset(data=data, name="multi-input-test-dataset")


@pytest.fixture
def both_datasets(single_input_dataset, multi_input_dataset):
    """The single-input and the multi-input dataset, in that order."""
    return [single_input_dataset, multi_input_dataset]


def test_has_iter_batches(dummy_dataset):
    """Dataset exposes an ``iter_batches`` attribute."""
    assert hasattr(dummy_dataset, "iter_batches")


@pytest.mark.parametrize("batch_size", _BATCH_SIZES)
@pytest.mark.parametrize("iterations", _ITERATIONS)
def test_iter_batches_returns_iterable(both_datasets, batch_size, iterations):
    """``iter_batches`` returns an iterable for both dataset layouts."""
    for dataset in both_datasets:
        loader = dataset.iter_batches(batch_size=batch_size, iterations=iterations)
        assert isinstance(loader, Iterable)


@pytest.mark.parametrize("batch_size", _BATCH_SIZES)
@pytest.mark.parametrize("iterations", _ITERATIONS)
def test_batch_is_in_list(multi_input_dataset, batch_size, iterations):
    """Batches of a multi-input dataset are lists."""
    loader = multi_input_dataset.iter_batches(
        batch_size=batch_size, iterations=iterations
    )
    for batch in loader:
        assert isinstance(batch, list)


@pytest.mark.parametrize("batch_size", _BATCH_SIZES)
@pytest.mark.parametrize("iterations", _ITERATIONS)
def test_batch_not_in_list_for_single_input(
    single_input_dataset, batch_size, iterations
):
    """Batches of a single-input dataset are bare arrays, not lists."""
    loader = single_input_dataset.iter_batches(
        batch_size=batch_size, iterations=iterations
    )
    for batch in loader:
        # a numpy array is never a list, so one check covers both conditions
        assert isinstance(batch, numpy.ndarray)


@pytest.mark.parametrize("batch_size", _BATCH_SIZES)
@pytest.mark.parametrize("iterations", _ITERATIONS)
def test_iter_batches_single_input_batch_shape(
    single_input_dataset, batch_size, iterations
):
    """Single-input batches have shape ``(batch_size, *sample_shape)``."""
    loader = single_input_dataset.iter_batches(
        batch_size=batch_size, iterations=iterations
    )

    sample_shape = single_input_dataset.data[0].shape
    expected_batch_shape = (batch_size, *sample_shape)
    for batch in loader:
        assert batch.shape == expected_batch_shape


@pytest.mark.parametrize("batch_size", _BATCH_SIZES)
@pytest.mark.parametrize("iterations", _ITERATIONS)
def test_iter_batches_number_of_iterations(both_datasets, batch_size, iterations):
    """Exactly ``iterations`` batches are produced for both dataset layouts."""
    for dataset in both_datasets:
        loader = dataset.iter_batches(batch_size=batch_size, iterations=iterations)
        # Count explicitly instead of reading the loop variable after the
        # loop: an empty loader then fails the assertion rather than raising
        # a NameError.
        assert sum(1 for _ in loader) == iterations


@pytest.mark.parametrize("batch_size", _SMALL_GRID)
@pytest.mark.parametrize("iterations", _SMALL_GRID)
def test_iter_batches_multi_input_batch_shape(
    multi_input_dataset, batch_size, iterations
):
    """Every array of a multi-input batch has shape ``(batch_size, *input_shape)``."""
    expected_batch_dimensions = [
        (batch_size, *multi_input.shape) for multi_input in multi_input_dataset.data[0]
    ]
    loader = multi_input_dataset.iter_batches(
        batch_size=batch_size, iterations=iterations
    )

    for batch in loader:
        assert all(
            expected_batch_dimensions[idx] == multi_input.shape
            for idx, multi_input in enumerate(batch)
        )