Skip to content

Commit

Permalink
k fold splitter refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Apr 22, 2019
1 parent f85b376 commit 397290f
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 45 deletions.
44 changes: 44 additions & 0 deletions benchmark/cross_validator.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// 10 sec
import 'dart:async';

import 'package:benchmark_harness/benchmark_harness.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

/// Number of rows (observations) generated for the benchmark data set.
const observationsNum = 1000;
/// Length of every randomly generated feature vector (matrix column count).
const featuresNum = 20;

/// Benchmark that measures one complete 5-fold cross validation of a KNN
/// regressor (`k: 7`, scored with [MetricType.mape]) on randomly filled data.
///
/// The data set has [observationsNum] rows and [featuresNum] feature columns,
/// plus a single label column; it is generated once in [setup].
class CrossValidatorBenchmark extends BenchmarkBase {
  CrossValidatorBenchmark() : super('Cross validator benchmark');

  /// Feature matrix built in [setup]: [observationsNum] rows, each a random
  /// vector of length [featuresNum].
  Matrix features;

  /// Single-column label matrix built in [setup] from one random vector of
  /// length [observationsNum].
  Matrix labels;

  /// The k-fold cross validator under measurement, created in [setup].
  CrossValidator crossValidator;

  /// Creates the benchmark and calls `report()` on it.
  static void main() {
    CrossValidatorBenchmark().report();
  }

  /// The measured workload: evaluates a freshly built KNN regressor over all
  /// folds of [features] / [labels].
  @override
  void run() {
    crossValidator.evaluate(
        (trainFeatures, trainLabels) =>
            ParameterlessRegressor.knn(trainFeatures, trainLabels, k: 7),
        features,
        labels,
        MetricType.mape);
  }

  /// Generates the random data set and the validator before measuring starts.
  @override
  void setup() {
    features = Matrix.fromRows(List.generate(
        observationsNum, (i) => Vector.randomFilled(featuresNum)));
    labels = Matrix.fromColumns([Vector.randomFilled(observationsNum)]);

    crossValidator = CrossValidator.kFold(numberOfFolds: 5);
  }

  /// Harness cleanup hook.
  ///
  /// `BenchmarkBase` spells this hook `teardown`; the original `tearDown`
  /// spelling was therefore never an override and never invoked. This override
  /// forwards to [tearDown] so any cleanup placed there actually runs.
  @override
  void teardown() => tearDown();

  /// Cleanup logic (currently nothing to release). Kept under its original
  /// name for source compatibility; invoked through [teardown].
  void tearDown() {}
}

/// Script entry point: delegates to the static
/// `CrossValidatorBenchmark.main`, which creates the benchmark and reports it.
Future main() async {
CrossValidatorBenchmark.main();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ class CrossValidatorImpl implements CrossValidator {
}

final allIndicesGroups = _splitter.split(observations.rowsNum);
// TODO get rid of length accessing
final scores = List<double>(allIndicesGroups.length);
int scoreCounter = 0;
var score = 0.0;
var folds = 0;

for (final testIndices in allIndicesGroups) {
final trainFeatures =
Expand Down Expand Up @@ -63,13 +62,14 @@ class CrossValidatorImpl implements CrossValidator {
Matrix.fromRows(trainLabels, dtype: dtype),
)..fit();

scores[scoreCounter++] = predictor.test(
score += predictor.test(
Matrix.fromRows(testFeatures, dtype: dtype),
Matrix.fromRows(testLabels, dtype: dtype),
metric
);
folds++;
}

return scores.fold<double>(0, (sum, value) => sum + value) / scores.length;
return score / folds;
}
}
41 changes: 11 additions & 30 deletions lib/src/model_selection/data_splitter/k_fold.dart
Original file line number Diff line number Diff line change
@@ -1,47 +1,28 @@
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:xrange/zrange.dart';

/// Splits the index range `0..numOfObservations` into consecutive,
/// non-overlapping folds for k-fold cross validation.
///
/// Fold sizes differ by at most one: when the number of observations is not
/// divisible by the number of folds, the trailing `remainder` folds each
/// receive one extra index (e.g. 12 observations in 5 folds yields sizes
/// 2, 2, 2, 3, 3).
class KFoldSplitter implements Splitter {
  /// Creates a splitter producing [_numberOfFolds] folds.
  ///
  /// Throws a [RangeError] when the requested number of folds is less than 2
  /// (this includes zero and negative values, which previously slipped past
  /// the `== 0 || == 1` check).
  KFoldSplitter(this._numberOfFolds) {
    if (_numberOfFolds < 2) {
      throw RangeError(
          'Number of folds must be greater than 1 and less than number of samples');
    }
  }

  final int _numberOfFolds;

  /// Lazily yields one iterable of indices per fold, in ascending order.
  ///
  /// Throws a [RangeError] if there are more folds than observations. Because
  /// this is a `sync*` generator, that check runs when the returned iterable
  /// is first iterated, not when [split] is called.
  @override
  Iterable<Iterable<int>> split(int numOfObservations) sync* {
    if (_numberOfFolds > numOfObservations) {
      throw RangeError.range(_numberOfFolds, 0, numOfObservations, null,
          'Number of folds must be less than number of samples!');
    }
    final remainder = numOfObservations % _numberOfFolds;
    final foldSize = numOfObservations ~/ _numberOfFolds;
    // Folds with an index at or beyond this one get a single extra element so
    // that the leftover observations are spread over the last folds.
    final firstEnlargedFold = _numberOfFolds - remainder;
    var startIdx = 0;
    for (var i = 0; i < _numberOfFolds; i++) {
      final endIdx = startIdx + foldSize + (i >= firstEnlargedFold ? 1 : 0);
      yield ZRange.closedOpen(startIdx, endIdx).values();
      startIdx = endIdx;
    }
  }
}
20 changes: 10 additions & 10 deletions test/data_splitter/data_splitter_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ void main() {
expect(
kfoldSplitter.split(12),
equals([
[0, 1, 2],
[3, 4, 5],
[6, 7],
[8, 9],
[10, 11]
[0, 1],
[2, 3],
[4, 5],
[6, 7, 8],
[9, 10, 11],
]));

kfoldSplitter = KFoldSplitter(4);
Expand Down Expand Up @@ -64,11 +64,11 @@ void main() {
expect(
kfoldSplitter.split(37),
equals([
[0, 1, 2, 3, 4, 5, 6, 7],
[8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22],
[23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36]
[0, 1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12, 13],
[14, 15, 16, 17, 18, 19, 20],
[21, 22, 23, 24, 25, 26, 27, 28],
[29, 30, 31, 32, 33, 34, 35, 36],
]));

kfoldSplitter = KFoldSplitter(3);
Expand Down

0 comments on commit 397290f

Please sign in to comment.