diff --git a/chainer_chemistry/dataset/indexer.py b/chainer_chemistry/dataset/indexer.py index 84ef83e6..635f32f8 100644 --- a/chainer_chemistry/dataset/indexer.py +++ b/chainer_chemistry/dataset/indexer.py @@ -167,10 +167,7 @@ def _extract_feature(self, data_index, j): self.dataset_length)) data_index = numpy.argwhere(data_index).ravel() - if len(data_index) == 1: - return self.extract_feature(data_index[0], j) - else: - res = [self.extract_feature(i, j) for i in data_index] + res = [self.extract_feature(i, j) for i in data_index] else: return self.extract_feature(data_index, j) try: diff --git a/chainer_chemistry/datasets/numpy_tuple_dataset.py b/chainer_chemistry/datasets/numpy_tuple_dataset.py index 4fa86416..1018090b 100644 --- a/chainer_chemistry/datasets/numpy_tuple_dataset.py +++ b/chainer_chemistry/datasets/numpy_tuple_dataset.py @@ -35,7 +35,7 @@ def __init__(self, *datasets): def __getitem__(self, index): batches = [dataset[index] for dataset in self._datasets] - if isinstance(index, slice): + if isinstance(index, (slice, list, numpy.ndarray)): length = len(batches[0]) return [tuple([batch[i] for batch in batches]) for i in six.moves.range(length)] diff --git a/tests/dataset_tests/test_numpy_tuple_feature_indexer.py b/tests/dataset_tests/test_numpy_tuple_feature_indexer.py index dd6e1104..70634dec 100644 --- a/tests/dataset_tests/test_numpy_tuple_feature_indexer.py +++ b/tests/dataset_tests/test_numpy_tuple_feature_indexer.py @@ -33,6 +33,14 @@ def test_extract_feature_by_slice(self, indexer, data, slice_index, j): indexer.extract_feature_by_slice(slice_index, j), data[j][slice_index]) + @pytest.mark.parametrize('ndarray_index', [numpy.asarray([0, 1]), + numpy.asarray([1])]) + @pytest.mark.parametrize('j', [0, 1]) + def test_extract_feature_by_ndarray(self, indexer, data, ndarray_index, j): + numpy.testing.assert_array_equal( + indexer.extract_feature_by_slice(ndarray_index, j), + data[j][ndarray_index]) + if __name__ == '__main__': pytest.main([__file__, '-v']) diff --git a/tests/datasets_tests/test_numpy_tuple_dataset.py b/tests/datasets_tests/test_numpy_tuple_dataset.py index 4063fcbe..80166fab 100644 --- a/tests/datasets_tests/test_numpy_tuple_dataset.py +++ b/tests/datasets_tests/test_numpy_tuple_dataset.py @@ -16,6 +16,14 @@ def data(): return a, b, c +@pytest.fixture +def long_data(): + a = numpy.array([1, 2, 3, 4]) + b = numpy.array([4, 5, 6, 7]) + c = numpy.array([[6, 7, 8], [8, 9, 10], [11, 12, 13], [14, 15, 16]]) + return a, b, c + + class TestNumpyTupleDataset(object): def test_len(self, data): @@ -47,6 +55,39 @@ def test_get_item_slice_index(self, data, index): for a, e in six.moves.zip(tuple_a, tuple_e): numpy.testing.assert_array_equal(a, e) + @pytest.mark.parametrize('index', [numpy.asarray([2, 0]), + numpy.asarray([1])]) + def test_get_item_ndarray_index(self, long_data, index): + dataset = NumpyTupleDataset(*long_data) + actual = dataset[index] + + batches = [d[index] for d in long_data] + length = len(batches[0]) + expect = [tuple([batch[i] for batch in batches]) + for i in six.moves.range(length)] + + assert len(actual) == len(expect) + for tuple_a, tuple_e in six.moves.zip(actual, expect): + assert len(tuple_a) == len(tuple_e) + for a, e in six.moves.zip(tuple_a, tuple_e): + numpy.testing.assert_array_equal(a, e) + + @pytest.mark.parametrize('index', [[2, 0], [1]]) + def test_get_item_list_index(self, long_data, index): + dataset = NumpyTupleDataset(*long_data) + actual = dataset[index] + + batches = [d[index] for d in long_data] + length = len(batches[0]) + expect = [tuple([batch[i] for batch in batches]) + for i in six.moves.range(length)] + + assert len(actual) == len(expect) + for tuple_a, tuple_e in six.moves.zip(actual, expect): + assert len(tuple_a) == len(tuple_e) + for a, e in six.moves.zip(tuple_a, tuple_e): + numpy.testing.assert_array_equal(a, e) + def test_invalid_datasets(self): a = numpy.array([1, 2]) b = numpy.array([1, 2, 3])