-
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
70 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// 10 sec | ||
import 'dart:async'; | ||
|
||
import 'package:benchmark_harness/benchmark_harness.dart'; | ||
import 'package:ml_algo/ml_algo.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
|
||
const observationsNum = 1000; | ||
const featuresNum = 20; | ||
|
||
class CrossValidatorBenchmark extends BenchmarkBase { | ||
CrossValidatorBenchmark() : super('Cross validator benchmark'); | ||
|
||
Matrix features; | ||
Matrix labels; | ||
CrossValidator crossValidator; | ||
|
||
static void main() { | ||
CrossValidatorBenchmark().report(); | ||
} | ||
|
||
@override | ||
void run() { | ||
crossValidator.evaluate((trainFeatures, trainLabels) => | ||
ParameterlessRegressor.knn(trainFeatures, trainLabels, k: 7), | ||
features, labels, MetricType.mape); | ||
} | ||
|
||
@override | ||
void setup() { | ||
features = Matrix.fromRows(List.generate(observationsNum, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
labels = Matrix.fromColumns([Vector.randomFilled(observationsNum)]); | ||
|
||
crossValidator = CrossValidator.kFold(numberOfFolds: 5); | ||
} | ||
|
||
void tearDown() {} | ||
} | ||
|
||
Future main() async { | ||
CrossValidatorBenchmark.main(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,28 @@ | ||
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; | ||
import 'package:xrange/zrange.dart'; | ||
|
||
class KFoldSplitter implements Splitter { | ||
final int _numberOfFolds; | ||
|
||
KFoldSplitter(this._numberOfFolds) { | ||
if (_numberOfFolds == 0 || _numberOfFolds == 1) { | ||
throw RangeError( | ||
'Number of folds must be greater than 1 and less than number of samples'); | ||
} | ||
} | ||
|
||
final int _numberOfFolds; | ||
|
||
@override | ||
Iterable<Iterable<int>> split(int numberOfSamples) sync* { | ||
if (_numberOfFolds > numberOfSamples) { | ||
throw RangeError.range(_numberOfFolds, 0, numberOfSamples, null, | ||
Iterable<Iterable<int>> split(int numOfObservations) sync* { | ||
if (_numberOfFolds > numOfObservations) { | ||
throw RangeError.range(_numberOfFolds, 0, numOfObservations, null, | ||
'Number of folds must be less than number of samples!'); | ||
} | ||
|
||
final remainder = numberOfSamples % _numberOfFolds; | ||
final size = (numberOfSamples / _numberOfFolds).truncate(); | ||
final sizes = List<int>.filled(_numberOfFolds, 1) | ||
.map((int el) => el * size) | ||
.toList(growable: false); | ||
|
||
if (remainder > 0) { | ||
final range = | ||
sizes.take(remainder).map((int el) => ++el).toList(growable: false); | ||
sizes.setRange(0, remainder, range); | ||
} | ||
|
||
int startIdx = 0; | ||
int endIdx = 0; | ||
|
||
for (int i = 0; i < sizes.length; i++) { | ||
endIdx = startIdx + sizes[i]; | ||
yield _range(startIdx, endIdx); | ||
final remainder = numOfObservations % _numberOfFolds; | ||
final foldSize = numOfObservations ~/ _numberOfFolds; | ||
for (int i = 0, startIdx = 0, endIdx = 0; i < _numberOfFolds; i++) { | ||
endIdx = startIdx + foldSize + (i >= _numberOfFolds - remainder ? 1 : 0); | ||
yield ZRange.closedOpen(startIdx, endIdx).values(); | ||
startIdx = endIdx; | ||
} | ||
} | ||
|
||
Iterable<int> _range(int start, int end) sync* { | ||
for (int i = start; i < end; i++) { | ||
yield i; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters