/
dataset_mixin.py
86 lines (67 loc) · 2.85 KB
/
dataset_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy
import six
class DatasetMixin(object):

    """Mixin that derives ``__getitem__`` from a per-example accessor.

    A dataset class inheriting from this mixin only needs to supply
    :meth:`get_example` (fetch a single example by position) and
    :meth:`__len__`. The mixin then provides :meth:`__getitem__`, which
    answers integer, slice, and integer-list / integer-array queries by
    calling :meth:`get_example` once per requested position and gathering
    the results into a list.

    This is convenient for datasets that cannot slice efficiently. Those
    that can may override :meth:`__getitem__` with a faster version.

    Note that :meth:`__len__` is not derived; subclasses must implement it
    themselves (the mixin's version just raises ``NotImplementedError``).
    """

    def __getitem__(self, index):
        """Fetch one example, or a list of examples, by position.

        Dispatch depends on the type of ``index``:

        * ``slice`` -- normalised against ``len(self)`` via
          ``slice.indices``, so it requires :meth:`__len__`. The resulting
          positions are fetched one by one.
        * ``list`` or ``numpy.ndarray`` -- each element is handed to
          :meth:`get_example` in iteration order. Only one-dimensional
          integer sequences are intended here.
        * anything else -- forwarded unchanged to :meth:`get_example`, and
          that single result is returned as-is (not wrapped in a list).

        Subclasses may override this method, e.g. to slice faster.

        Args:
            index (int, slice, list or numpy.ndarray): Position of one
                example, or positions of several examples.

        Returns:
            The example from :meth:`get_example` for a scalar ``index``;
            otherwise a ``list`` of examples in the requested order.

        .. admonition:: Example

            >>> import numpy
            >>> from chainer import dataset
            >>> class SimpleDataset(dataset.DatasetMixin):
            ...     def __init__(self, values):
            ...         self.values = values
            ...     def __len__(self):
            ...         return len(self.values)
            ...     def get_example(self, i):
            ...         return self.values[i]
            ...
            >>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
            >>> ds[1]  # Access by int
            1
            >>> ds[1:3]  # Access by slice
            [1, 2]
            >>> ds[[4, 0]]  # Access by one-dimensional integer list
            [4, 0]
            >>> index = numpy.arange(3)
            >>> ds[index]  # Access by one-dimensional integer numpy.ndarray
            [0, 1, 2]
        """
        if isinstance(index, slice):
            # Resolve None / negative bounds against the dataset length.
            positions = six.moves.range(*index.indices(len(self)))
        elif isinstance(index, (list, numpy.ndarray)):
            positions = index
        else:
            # Scalar-style index: return the bare example, not a list.
            return self.get_example(index)
        return [self.get_example(position) for position in positions]

    def __len__(self):
        """Report how many examples the dataset holds.

        Subclasses must implement this; the mixin's version always raises.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def get_example(self, i):
        """Fetch the single example stored at position ``i``.

        Subclasses are expected to override this. An implementation should
        raise :class:`IndexError` when ``i`` does not refer to a valid
        example.

        Args:
            i (int): Position of the example to fetch.

        Returns:
            The example at position ``i``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError