Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
86 lines (67 sloc) 2.85 KB
import numpy
import six
class DatasetMixin(object):

    """Default implementation of dataset indexing.

    DatasetMixin provides the :meth:`__getitem__` operator. The default
    implementation uses :meth:`get_example` to extract each example, and
    combines the results into a list. This mixin makes it easy to implement a
    new dataset that does not support efficient slicing.

    Dataset implementation using DatasetMixin still has to provide the
    :meth:`__len__` operator explicitly.

    """

    def __getitem__(self, index):
        """Returns an example or a sequence of examples.

        It implements the standard Python indexing and one-dimensional integer
        array indexing. It uses the :meth:`get_example` method by default, but
        it may be overridden by the implementation to, for example, improve the
        slicing performance.

        Args:
            index (int, slice, list or numpy.ndarray): An index of an example
                or indexes of examples.

        Returns:
            If index is int, returns an example created by `get_example`.
            If index is either slice or one-dimensional list or numpy.ndarray,
            returns a list of examples created by `get_example`.

        .. admonition:: Example

           >>> import numpy
           >>> from chainer import dataset
           >>> class SimpleDataset(dataset.DatasetMixin):
           ...     def __init__(self, values):
           ...         self.values = values
           ...     def __len__(self):
           ...         return len(self.values)
           ...     def get_example(self, i):
           ...         return self.values[i]
           ...
           >>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
           >>> ds[1]   # Access by int
           1
           >>> ds[1:3]   # Access by slice
           [1, 2]
           >>> ds[[4, 0]]   # Access by one-dimensional integer list
           [4, 0]
           >>> index = numpy.arange(3)
           >>> ds[index]   # Access by one-dimensional integer numpy.ndarray
           [0, 1, 2]

        """
        if isinstance(index, slice):
            # slice.indices() resolves None/negative bounds and the step
            # against the dataset length, so any Python slice is supported.
            return [self.get_example(i)
                    for i in range(*index.indices(len(self)))]
        if isinstance(index, (list, numpy.ndarray)):
            # Elements are passed to get_example as-is, iterating over the
            # first axis: boolean arrays are NOT treated as masks and
            # multi-dimensional arrays hand whole rows to get_example.
            return [self.get_example(i) for i in index]
        return self.get_example(index)

    def __len__(self):
        """Returns the number of data points."""
        raise NotImplementedError

    def get_example(self, i):
        """Returns the i-th example.

        Implementations should override it. It should raise :class:`IndexError`
        if the index is invalid.

        Args:
            i (int): The index of the example.

        Returns:
            The i-th example.

        """
        raise NotImplementedError
You can’t perform that action at this time.