-
-
Notifications
You must be signed in to change notification settings - Fork 29
/
k_fold.dart
31 lines (28 loc) · 1.16 KB
/
k_fold.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:xrange/zrange.dart';
/// Splits observation indices into a fixed number of contiguous folds for
/// k-fold cross-validation.
///
/// The folds cover the index range `[0, numOfObservations)` in order and do
/// not overlap. When the number of observations is not divisible by the
/// number of folds, each of the trailing `numOfObservations % numberOfFolds`
/// folds receives one extra index.
class KFoldSplitter implements Splitter {
  /// Creates a splitter that produces [_numberOfFolds] folds.
  ///
  /// Throws a [RangeError] if the number of folds is less than 2.
  KFoldSplitter(this._numberOfFolds) {
    // Reject every value below 2, not just 0 and 1: a negative fold count
    // would otherwise make [split] silently yield no folds at all.
    if (_numberOfFolds < 2) {
      throw RangeError(
          'Number of folds must be greater than 1 and less than or equal to '
          'number of samples');
    }
  }

  final int _numberOfFolds;

  /// Returns the folds for a dataset of [numOfObservations] samples.
  ///
  /// Each fold is an iterable of consecutive indices, and the folds are
  /// produced lazily. Throws a [RangeError] immediately (at call time, not
  /// on first iteration) if [numOfObservations] is smaller than the number
  /// of folds. A fold count equal to [numOfObservations] is allowed.
  @override
  Iterable<Iterable<int>> split(int numOfObservations) {
    if (_numberOfFolds > numOfObservations) {
      throw RangeError.range(_numberOfFolds, 2, numOfObservations, null,
          'Number of folds must be less than or equal to number of samples!');
    }
    return _generateFolds(numOfObservations);
  }

  // Lazily yields the folds. Assumes [numOfObservations] has already been
  // validated by [split].
  Iterable<Iterable<int>> _generateFolds(int numOfObservations) sync* {
    final remainder = numOfObservations % _numberOfFolds;
    final foldSize = numOfObservations ~/ _numberOfFolds;
    // The first (_numberOfFolds - remainder) folds have [foldSize] indices.
    // Every fold after them has [foldSize] + 1, so the folds together cover
    // all [numOfObservations] indices.
    final firstLargerFold = _numberOfFolds - remainder;
    var startIdx = 0;
    for (var i = 0; i < _numberOfFolds; i++) {
      final endIdx = startIdx + foldSize + (i >= firstLargerFold ? 1 : 0);
      yield ZRange.closedOpen(startIdx, endIdx).values();
      startIdx = endIdx;
    }
  }
}