Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add SliceableDataset #454

Merged
merged 32 commits into from
Apr 17, 2018
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
505c8fb
add AnnotatedDatasetMixin
Hakuyume Oct 14, 2017
7ff1dbd
add test
Hakuyume Oct 14, 2017
b7bce0c
fix
Hakuyume Oct 14, 2017
4e31ea2
add docs
Hakuyume Oct 14, 2017
d4ba10b
add PickableDataset
Hakuyume Oct 15, 2017
209a9bc
add test
Hakuyume Oct 15, 2017
27254c7
remove AnnotatedImageDatasetMixin
Hakuyume Oct 15, 2017
8fbd490
remove unused import
Hakuyume Oct 15, 2017
64a6fa0
fix super for Python2
Hakuyume Oct 15, 2017
5cad770
fix name
Hakuyume Oct 17, 2017
8145725
copy sliceable datasets from chainer
Hakuyume Feb 20, 2018
95844ca
Merge branch 'annotated-dataset-mixin' of https://github.com/Hakuyume…
Hakuyume Feb 20, 2018
fd28f45
fix style
Hakuyume Feb 20, 2018
1e76edb
Merge branch 'master' into annotated-dataset-mixin
Hakuyume Feb 20, 2018
5c61e9f
Revert "fix style"
Hakuyume Feb 20, 2018
6df3b86
Merge branch 'use-braces-brackets' into annotated-dataset-mixin
Hakuyume Feb 20, 2018
a5c6410
fix style
Hakuyume Feb 20, 2018
7a252c2
Merge branch 'master' into annotated-dataset-mixin
Hakuyume Mar 1, 2018
893e7e9
add example code
Hakuyume Apr 2, 2018
0abae4d
add rst
Hakuyume Apr 2, 2018
7f9fcae
fix docs
Hakuyume Apr 2, 2018
9308ecc
update docs
Hakuyume Apr 2, 2018
126bcd6
fix docs
Hakuyume Apr 6, 2018
7a59b50
super
Hakuyume Apr 6, 2018
cf2837b
doc?
Hakuyume Apr 17, 2018
f5c6320
update tutorial
Hakuyume Apr 17, 2018
5ef9b2d
fix
Hakuyume Apr 17, 2018
bf50752
update
Hakuyume Apr 17, 2018
8303d46
update
Hakuyume Apr 17, 2018
ec1423d
add to index
Hakuyume Apr 17, 2018
608ba79
fix
Hakuyume Apr 17, 2018
10889e6
fix
Hakuyume Apr 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainercv/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pkg_resources

from chainercv import chainer_experimental # NOQA
from chainercv import datasets # NOQA
from chainercv import evaluations # NOQA
from chainercv import extensions # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/chainer_experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from chainercv.chainer_experimental import datasets # NOQA
1 change: 1 addition & 0 deletions chainercv/chainer_experimental/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from chainercv.chainer_experimental.datasets import sliceable # NOQA
6 changes: 6 additions & 0 deletions chainercv/chainer_experimental/datasets/sliceable/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset import SliceableDataset # NOQA

from chainercv.chainer_experimental.datasets.sliceable.concatenated_dataset import ConcatenatedDataset # NOQA
from chainercv.chainer_experimental.datasets.sliceable.getter_dataset import GetterDataset # NOQA
from chainercv.chainer_experimental.datasets.sliceable.transform_dataset import TransformDataset # NOQA
from chainercv.chainer_experimental.datasets.sliceable.tuple_dataset import TupleDataset # NOQA
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


class ConcatenatedDataset(SliceableDataset):
"""A sliceable version of :class:`chainer.datasets.ConcatenatedDataset`.

Hew is an example.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here


>>> dataset_a = TupleDataset([0, 1, 2], [0, 1, 4])
>>> dataset_b = TupleDataset([3, 4, 5], [9, 16, 25])
>>>
>>> dataset = ConcatenatedDataset(dataset_a, dataset_b)
>>> dataset.slice[:, 0][:] # [0, 1, 2, 3, 4, 5]

Args:
datasets: The underlying datasets.
Each dataset should inherit
:class:~chainer.datasets.sliceable.Sliceabledataset`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no period

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not chainer.datasets, but chainercv.*

and should have the same keys.
"""

def __init__(self, *datasets):
if len(datasets) == 0:
raise ValueError('At least one dataset is required')
self._datasets = datasets
self._keys = datasets[0].keys
for dataset in datasets[1:]:
if not dataset.keys == self._keys:
raise ValueError('All datasets should have the same keys')

def __len__(self):
return sum(len(dataset) for dataset in self._datasets)

@property
def keys(self):
return self._keys

def get_example_by_keys(self, index, key_indices):
if index < 0:
raise IndexError
for dataset in self._datasets:
if index < len(dataset):
return dataset.get_example_by_keys(index, key_indices)
index -= len(dataset)
raise IndexError
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


def _as_tuple(t):
if isinstance(t, tuple):
return t
else:
return t,


class GetterDataset(SliceableDataset):
"""A sliceable dataset class that defined by getters.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is defined with


This ia a dataset class with getters.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ia --> This is a

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a comment that lets users know about the tutorial?

Please refer to the tutorial for more detailed explanation.


Here is an example.

>>> class SliceableLabeledImageDataset(GetterDataset):
>>> def __init__(self, pairs, root='.'):
>>> super().__init__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python 2 would not work.

>>> with open(pairs) as f:
>>> self._pairs = [l.split() for l in f]
>>> self._root = root
>>>
>>> self.add_getter('image', self.get_image)
>>> self.add_getter('label', self.get_label)
>>>
>>> def __len__(self):
>>> return len(self._pairs)
>>>
>>> def get_image(self, i):
>>> path, _ = self._pairs[i]
>>> return read_image(os.path.join(self._root, path))
>>>
>>> def get_label(self, i):
>>> _, label = self._pairs[i]
>>> return np.int32(label)
>>
>>> dataset = SliceableLabeledImageDataset('list.txt')
>>>
>>> # get a subset with label = 0, 1, 2
>>> # no images are loaded
>>> indices = [i for i, label in
>>> enumerate(dataset.slice[:, 'label']) if label in {0, 1, 2}]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation is wiered

>>> dataset_012 = dataset.slice[indices]
"""

def __init__(self):
self._keys = []
self._getters = []

def __len__(self):
raise NotImplementedError

@property
def keys(self):
return tuple(key for key, _, _ in self._keys)

def add_getter(self, keys, getter):
"""Register a getter function

Args:
keys (int or string or tuple of strings): The number or name(s) of
data that the getter function returns.
getter (callable): A getter function that takes an index and
returns data of the corresponding example.
"""
self._getters.append(getter)
if isinstance(keys, int):
if keys == 1:
keys = None
else:
keys = (None,) * keys
if isinstance(keys, tuple):
for key_index, key in enumerate(keys):
self._keys.append((key, len(self._getters) - 1, key_index))
else:
self._keys.append((keys, len(self._getters) - 1, None))

def get_example_by_keys(self, index, key_indices):
example = []
cache = {}
for key_index in key_indices:
_, getter_index, key_index = self._keys[key_index]
if getter_index not in cache:
cache[getter_index] = self._getters[getter_index](index)
if key_index is None:
example.append(cache[getter_index])
else:
example.append(cache[getter_index][key_index])
return tuple(example)
137 changes: 137 additions & 0 deletions chainercv/chainer_experimental/datasets/sliceable/sliceable_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import six

import chainer


def _as_tuple(t):
if isinstance(t, tuple):
return t
else:
return t,


class SliceableDataset(chainer.dataset.DatasetMixin):
"""An abstract dataset class that supports slicing.

This ia a dataset class that supports slicing.
A dataset class inheriting this class should implement
three methods: :meth:`__len__`, :meth:`keys`, and
:meth:`get_example_by_keys`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we recommend users to use GetterDataset?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the relationship between SlicableDataset and GetterDataset should be clear.

From users perspective,
they would first come and read SlicableDataset. Since this is not intended to be directly touched by users, we should guide them to GetterDataset.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice suggestion. I agree with you.

"""

def __len__(self):
raise NotImplementedError

@property
def keys(self):
"""Return names of all keys

Returns:
string or tuple of strings
"""
raise NotImplementedError

def get_example_by_keys(self, index, key_indices):
"""Return data of an example by keys

Args:
index (int): An index of an example.
key_indices (tuple of ints): A tuple of indices of requested keys.

Returns:
tuple of data
"""
raise NotImplementedError

def get_example(self, index):
if isinstance(self.keys, tuple):
return self.get_example_by_keys(
index, tuple(range(len(self.keys))))
else:
return self.get_example_by_keys(index, (0,))[0]

@property
def slice(self):
return SliceHelper(self)

def __iter__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not necessary for the sliceable functionality. I added this for convenience.

In the case of calculating statistics of labels, we can write

collections.Counter(dataset[:, 'label'])

Without __iter__, we have to write

collections.Counter(dataset[:, 'label'][:])  # temporary list creation (not good for both speed and memory)
# or 
collections.Counter(dataset[:, 'label'][i] for i in range(len(dataset)))  # lengthy 

return (self.get_example(i) for i in six.moves.range(len(self)))


class SliceHelper(object):
"""A helper class for :class:`SliceableDataset`."""

def __init__(self, dataset):
self._dataset = dataset

def __getitem__(self, args):
if isinstance(args, tuple):
index, keys = args
else:
index = args
keys = self._dataset.keys

if isinstance(keys, (list, tuple)):
return_tuple = True
else:
keys, return_tuple = (keys,), False

# convert name to index
key_indices = []
for key in keys:
if isinstance(key, int):
key_index = key
if key_index >= len(self._dataset.keys):
raise IndexError('Invalid index of key')
if key_index < 0:
key_index += len(self._dataset.keys)
else:
try:
key_index = _as_tuple(self._dataset.keys).index(key)
except ValueError:
raise KeyError('{} does not exists'.format(key))
key_indices.append(key_index)

return SlicedDataset(
self._dataset, index,
tuple(key_indices) if return_tuple else key_indices[0])


class SlicedDataset(SliceableDataset):
"""A sliced view for :class:`SliceableDataset`."""

def __init__(self, dataset, index, key_indices):
self._dataset = dataset
self._index = index
self._key_indices = key_indices

def __len__(self):
if isinstance(self._index, slice):
start, end, step = self._index.indices(len(self._dataset))
return len(range(start, end, step))
else:
return len(self._index)

@property
def keys(self):
keys = _as_tuple(self._dataset.keys)
if isinstance(self._key_indices, tuple):
return tuple(keys[key_index] for key_index in self._key_indices)
else:
return keys[self._key_indices]

def get_example_by_keys(self, index, key_indices):
if isinstance(key_indices, tuple):
key_indices = tuple(
_as_tuple(self._key_indices)[key_index]
for key_index in key_indices)
else:
key_indices = _as_tuple(self._key_indices)[key_indices]

if isinstance(self._index, slice):
start, _, step = self._index.indices(len(self._dataset))
return self._dataset.get_example_by_keys(
start + index * step, key_indices)
else:
return self._dataset.get_example_by_keys(
self._index[index], key_indices)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from chainercv.chainer_experimental.datasets.sliceable import GetterDataset


class TransformDataset(GetterDataset):
"""A sliceable version of :class:`chainer.datasets.TransformDataset`.

Note that it reuqires :obj:`keys` to determine the names of returned
values.

Hew is an example.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here


>>> def transfrom(in_data):
>>> img, bbox, label = in_data
>>> ...
>>> return new_img, new_label
>>>
>>> dataset = TramsformDataset(dataset, ('img', 'label'), transform)
>>> dataset.keys # ('img', 'label')

Args:
dataset: The underlying dataset.
This dataset should have :meth:`__len__` and :meth:`__getitem__`.
keys (int or string or tuple of strings): The number or name(s) of
data that the transform function returns.
transform (callable): A function that is called to transform values
returned by the underlying dataset's :meth:`__getitem__`.
"""

def __init__(self, dataset, keys, transform):
super(TransformDataset, self).__init__()
self._dataset = dataset
if isinstance(keys, int):
if keys == 1:
keys = None
else:
keys = (None,) * keys
self.add_getter(keys, lambda index: transform(dataset[index]))

def __len__(self):
return len(self._dataset)
Loading