Skip to content

Commit

Permalink
Merge pull request #200 from mottodora/fix-indexing
Browse files Browse the repository at this point in the history
Fix indexing behavior in NumpyTupleDataset
  • Loading branch information
corochann committed Jun 27, 2018
2 parents b524e4f + 7a70195 commit e5077f5
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 5 deletions.
5 changes: 1 addition & 4 deletions chainer_chemistry/dataset/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ def _extract_feature(self, data_index, j):
self.dataset_length))
data_index = numpy.argwhere(data_index).ravel()

if len(data_index) == 1:
return self.extract_feature(data_index[0], j)
else:
res = [self.extract_feature(i, j) for i in data_index]
res = [self.extract_feature(i, j) for i in data_index]
else:
return self.extract_feature(data_index, j)
try:
Expand Down
2 changes: 1 addition & 1 deletion chainer_chemistry/datasets/numpy_tuple_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, *datasets):

def __getitem__(self, index):
batches = [dataset[index] for dataset in self._datasets]
if isinstance(index, slice):
if isinstance(index, (slice, list, numpy.ndarray)):
length = len(batches[0])
return [tuple([batch[i] for batch in batches])
for i in six.moves.range(length)]
Expand Down
8 changes: 8 additions & 0 deletions tests/dataset_tests/test_numpy_tuple_feature_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def test_extract_feature_by_slice(self, indexer, data, slice_index, j):
indexer.extract_feature_by_slice(slice_index, j),
data[j][slice_index])

@pytest.mark.parametrize('ndarray_index', [numpy.asarray([0, 1]),
numpy.asarray([1])])
@pytest.mark.parametrize('j', [0, 1])
def test_extract_feature_by_ndarray(self, indexer, data, ndarray_index, j):
numpy.testing.assert_array_equal(
indexer.extract_feature_by_slice(ndarray_index, j),
data[j][ndarray_index])


if __name__ == '__main__':
pytest.main([__file__, '-v'])
41 changes: 41 additions & 0 deletions tests/datasets_tests/test_numpy_tuple_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def data():
return a, b, c


@pytest.fixture
def long_data():
a = numpy.array([1, 2, 3, 4])
b = numpy.array([4, 5, 6, 7])
c = numpy.array([[6, 7, 8], [8, 9, 10], [11, 12, 13], [14, 15, 16]])
return a, b, c


class TestNumpyTupleDataset(object):

def test_len(self, data):
Expand Down Expand Up @@ -47,6 +55,39 @@ def test_get_item_slice_index(self, data, index):
for a, e in six.moves.zip(tuple_a, tuple_e):
numpy.testing.assert_array_equal(a, e)

@pytest.mark.parametrize('index', [numpy.asarray([2, 0]),
numpy.asarray([1])])
def test_get_item_ndarray_index(self, long_data, index):
dataset = NumpyTupleDataset(*long_data)
actual = dataset[index]

batches = [d[index] for d in long_data]
length = len(batches[0])
expect = [tuple([batch[i] for batch in batches])
for i in six.moves.range(length)]

assert len(actual) == len(expect)
for tuple_a, tuple_e in six.moves.zip(actual, expect):
assert len(tuple_a) == len(tuple_e)
for a, e in six.moves.zip(tuple_a, tuple_e):
numpy.testing.assert_array_equal(a, e)

@pytest.mark.parametrize('index', [[2, 0], [1]])
def test_get_item_list_index(self, long_data, index):
dataset = NumpyTupleDataset(*long_data)
actual = dataset[index]

batches = [d[index] for d in long_data]
length = len(batches[0])
expect = [tuple([batch[i] for batch in batches])
for i in six.moves.range(length)]

assert len(actual) == len(expect)
for tuple_a, tuple_e in six.moves.zip(actual, expect):
assert len(tuple_a) == len(tuple_e)
for a, e in six.moves.zip(tuple_a, tuple_e):
numpy.testing.assert_array_equal(a, e)

def test_invalid_datasets(self):
a = numpy.array([1, 2])
b = numpy.array([1, 2, 3])
Expand Down

0 comments on commit e5077f5

Please sign in to comment.