-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
sub_dataset.py
245 lines (194 loc) · 9.62 KB
/
sub_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import numpy
import six
from chainer.dataset import dataset_mixin
class SubDataset(dataset_mixin.DatasetMixin):
    """View onto a contiguous (optionally permuted) slice of a base dataset.

    The subset covers the half-open interval ``[start, finish)`` of indexes.
    When ``order`` is given, the ``i``-th example of this dataset is the
    ``order[start + i]``-th example of the base dataset; otherwise it is
    simply the ``start + i``-th example. Negative ``i`` is also supported and
    resolves relative to the end of the interval, i.e. the base index is
    ``finish + i`` (looked up through ``order`` if one is given).

    SubDataset is typically used to split a dataset into training and
    validation subsets. A single split can be made with :func:`split_dataset`
    or :func:`split_dataset_random`; :math:`k`-fold cross-validation splits
    can be made with :func:`get_cross_validation_datasets`. Nested SubDatasets
    can be used to carve out a further test set from a split.

    Args:
        dataset: Base dataset.
        start (int): The first index in the interval.
        finish (int): The next-to-the-last index in the interval.
        order (sequence of ints): Permutation of indexes in the base dataset.
            If this is ``None``, then the ascending order of indexes is used.
            It must have the same length as the base dataset.

    """

    def __init__(self, dataset, start, finish, order=None):
        # The interval must lie inside the base dataset. Note that
        # ``start > finish`` is not rejected here; it just yields a
        # non-positive size.
        if start < 0 or finish > len(dataset):
            raise ValueError('subset overruns the base dataset.')
        if order is not None and len(order) != len(dataset):
            msg = ('order option must have the same length as the base '
                   'dataset: len(order) = {} while len(dataset) = {}'.format(
                       len(order), len(dataset)))
            raise ValueError(msg)
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start
        self._order = order

    def __len__(self):
        # Number of examples in this subset.
        return self._size

    def get_example(self, i):
        # Resolve ``i`` against the interval: non-negative indexes count from
        # ``start``, negative ones from ``finish``.
        if not -self._size <= i < self._size:
            raise IndexError('dataset index out of range')
        base = self._start if i >= 0 else self._finish
        index = base + i
        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
def split_dataset(dataset, split_at, order=None):
    """Splits a dataset into two subsets at a given position.

    Two :class:`SubDataset` views are created over the same base dataset.
    They share no examples and together cover the whole base dataset.

    Args:
        dataset: Dataset to split.
        split_at (int): Position at which the base dataset is split.
        order (sequence of ints): Permutation of indexes in the base dataset.
            See the documentation of :class:`SubDataset` for details.

    Returns:
        tuple: Two :class:`SubDataset` objects. The first subset represents
        the examples of indexes ``order[:split_at]`` while the second subset
        represents the examples of indexes ``order[split_at:]``.

    Raises:
        TypeError: If ``split_at`` is not an integer.
        ValueError: If ``split_at`` is negative or exceeds the dataset size.

    """
    size = len(dataset)
    # Accept plain and numpy integers only; anything else is a caller bug.
    if not isinstance(split_at, (six.integer_types, numpy.integer)):
        raise TypeError('split_at must be int, got {} instead'
                        .format(type(split_at)))
    if split_at < 0:
        raise ValueError('split_at must be non-negative')
    if split_at > size:
        raise ValueError('split_at exceeds the dataset size')
    front = SubDataset(dataset, 0, split_at, order)
    back = SubDataset(dataset, split_at, size, order)
    return front, back
def split_dataset_random(dataset, first_size, seed=None):
    """Splits a dataset into two subsets randomly.

    A random permutation of the indexes is generated and the dataset is then
    split with :func:`split_dataset`, so the two returned subsets share no
    examples and together cover the whole base dataset.

    Args:
        dataset: Dataset to split.
        first_size (int): Size of the first subset.
        seed (int): Seed for the generator producing the permutation of
            indexes. If an integer convertible to a 32-bit unsigned integer
            is given, each sample is guaranteed to always land in the same
            subset. If ``None``, the permutation changes on every call.

    Returns:
        tuple: Two :class:`SubDataset` objects. The first subset contains
        ``first_size`` examples randomly chosen from the dataset without
        replacement, and the second subset contains the rest of the dataset.

    """
    rng = numpy.random.RandomState(seed)
    return split_dataset(dataset, first_size, rng.permutation(len(dataset)))
def split_dataset_n(dataset, n, order=None):
    """Splits a dataset into ``n`` equally sized subsets.

    Args:
        dataset: Dataset to split.
        n (int): The number of subsets.
        order (sequence of ints): Permutation of indexes in the base dataset.
            See the documentation of :class:`SubDataset` for details.

    Returns:
        list: List of ``n`` :class:`SubDataset` objects. The ``i``-th subset
        contains the examples of indexes
        ``order[i * (len(dataset) // n):(i + 1) * (len(dataset) // n)]``;
        any remainder of the integer division is left out.

    """
    part = len(dataset) // n
    return [SubDataset(dataset, i * part, (i + 1) * part, order)
            for i in six.moves.range(n)]
def split_dataset_n_random(dataset, n, seed=None):
    """Splits a dataset into ``n`` equally sized subsets randomly.

    Args:
        dataset: Dataset to split.
        n (int): The number of subsets.
        seed (int): Seed for the generator producing the permutation of
            indexes. If an integer convertible to a 32-bit unsigned integer
            is given, each sample is guaranteed to always land in the same
            subset. If ``None``, the permutation changes on every call.

    Returns:
        list: List of ``n`` :class:`SubDataset` objects. Each subset contains
        ``len(dataset) // n`` examples randomly chosen from the dataset
        without replacement.

    """
    rng = numpy.random.RandomState(seed)
    order = rng.permutation(len(dataset))
    part = len(dataset) // n
    return [SubDataset(dataset, i * part, (i + 1) * part, order)
            for i in six.moves.range(n)]
def get_cross_validation_datasets(dataset, n_fold, order=None):
    """Creates a set of training/test splits for cross validation.

    This function generates ``n_fold`` splits of the given dataset. The first
    part of each split corresponds to the training dataset, while the second
    part to the test dataset. No pairs of test datasets share any examples, and
    all test datasets together cover the whole base dataset. Each test dataset
    contains almost same number of examples (the numbers may differ up to 1).

    Args:
        dataset: Dataset to split.
        n_fold (int): Number of splits for cross validation.
        order (sequence of ints): Order of indexes with which each split is
            determined. If it is ``None``, then no permutation is used.

    Returns:
        list of tuples: List of ``(training set, test set)`` dataset splits,
        one per fold.

    """
    if order is None:
        order = numpy.arange(len(dataset))
    else:
        order = numpy.array(order)  # copy, so the caller's sequence is never mutated
    whole_size = len(dataset)
    # Fold boundaries; consecutive differences give each fold's test size,
    # and the sizes differ from each other by at most one.
    borders = [whole_size * i // n_fold for i in six.moves.range(n_fold + 1)]
    test_sizes = [borders[i + 1] - borders[i] for i in six.moves.range(n_fold)]
    splits = []
    for test_size in reversed(test_sizes):
        size = whole_size - test_size
        # The last ``test_size`` entries of ``order`` become this fold's test
        # set (second element of the tuple); the rest is the training set.
        splits.append(split_dataset(dataset, size, order))
        # Rotate ``order``: move the chunk just used as the test set to the
        # front so the next fold tests a different chunk. A fresh array is
        # allocated, so the SubDatasets created above keep referencing their
        # own, unrotated permutation.
        new_order = numpy.empty_like(order)
        new_order[:test_size] = order[-test_size:]
        new_order[test_size:] = order[:-test_size]
        order = new_order
    return splits
def get_cross_validation_datasets_random(dataset, n_fold, seed=None):
    """Creates a set of training/test splits for cross validation randomly.

    This function behaves like :func:`get_cross_validation_datasets` except
    that a random permutation of indexes is generated automatically.

    Args:
        dataset: Dataset to split.
        n_fold (int): Number of splits for cross validation.
        seed (int): Seed for the generator producing the permutation of
            indexes. If an integer convertible to a 32-bit unsigned integer
            is given, each sample is guaranteed to always land in the same
            subset. If ``None``, the permutation changes on every call.

    Returns:
        list of tuples: List of dataset splits.

    """
    rng = numpy.random.RandomState(seed)
    return get_cross_validation_datasets(
        dataset, n_fold, rng.permutation(len(dataset)))