Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #100 from gyrdym/cross-validator-unit-tests
Cross validator and data splitters refactored + unit tests for them added
- Loading branch information
Showing
13 changed files
with
322 additions
and
217 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// 8.5 sec | ||
import 'dart:async'; | ||
|
||
import 'package:benchmark_harness/benchmark_harness.dart'; | ||
import 'package:ml_algo/ml_algo.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
|
||
const observationsNum = 1000; | ||
const featuresNum = 20; | ||
|
||
class CrossValidatorBenchmark extends BenchmarkBase { | ||
CrossValidatorBenchmark() : super('Cross validator benchmark'); | ||
|
||
Matrix features; | ||
Matrix labels; | ||
CrossValidator crossValidator; | ||
|
||
static void main() { | ||
CrossValidatorBenchmark().report(); | ||
} | ||
|
||
@override | ||
void run() { | ||
crossValidator.evaluate((trainFeatures, trainLabels) => | ||
ParameterlessRegressor.knn(trainFeatures, trainLabels, k: 7), | ||
features, labels, MetricType.mape); | ||
} | ||
|
||
@override | ||
void setup() { | ||
features = Matrix.fromRows(List.generate(observationsNum, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
labels = Matrix.fromColumns([Vector.randomFilled(observationsNum)]); | ||
|
||
crossValidator = CrossValidator.kFold(numberOfFolds: 5); | ||
} | ||
|
||
void tearDown() {} | ||
} | ||
|
||
Future main() async { | ||
CrossValidatorBenchmark.main(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,31 @@ | ||
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; | ||
import 'package:xrange/zrange.dart'; | ||
|
||
class KFoldSplitter implements Splitter { | ||
final int _numberOfFolds; | ||
|
||
KFoldSplitter(this._numberOfFolds) { | ||
if (_numberOfFolds == 0 || _numberOfFolds == 1) { | ||
throw RangeError( | ||
'Number of folds must be greater than 1 and less than number of samples'); | ||
'Number of folds must be greater than 1 and less than number of ' | ||
'samples'); | ||
} | ||
} | ||
|
||
final int _numberOfFolds; | ||
|
||
@override | ||
Iterable<Iterable<int>> split(int numberOfSamples) sync* { | ||
if (_numberOfFolds > numberOfSamples) { | ||
throw RangeError.range(_numberOfFolds, 0, numberOfSamples, null, | ||
Iterable<Iterable<int>> split(int numOfObservations) sync* { | ||
if (_numberOfFolds > numOfObservations) { | ||
throw RangeError.range(_numberOfFolds, 0, numOfObservations, null, | ||
'Number of folds must be less than number of samples!'); | ||
} | ||
|
||
final remainder = numberOfSamples % _numberOfFolds; | ||
final size = (numberOfSamples / _numberOfFolds).truncate(); | ||
final sizes = List<int>.filled(_numberOfFolds, 1) | ||
.map((int el) => el * size) | ||
.toList(growable: false); | ||
|
||
if (remainder > 0) { | ||
final range = | ||
sizes.take(remainder).map((int el) => ++el).toList(growable: false); | ||
sizes.setRange(0, remainder, range); | ||
} | ||
|
||
int startIdx = 0; | ||
int endIdx = 0; | ||
|
||
for (int i = 0; i < sizes.length; i++) { | ||
endIdx = startIdx + sizes[i]; | ||
yield _range(startIdx, endIdx); | ||
final remainder = numOfObservations % _numberOfFolds; | ||
final foldSize = numOfObservations ~/ _numberOfFolds; | ||
for (int i = 0, startIdx = 0, endIdx = 0; i < _numberOfFolds; i++) { | ||
// if we reached last fold of size [foldSize], all the next folds up | ||
// to the last fold will have size that is equal to [foldSize] + 1 | ||
endIdx = startIdx + foldSize + (i >= _numberOfFolds - remainder ? 1 : 0); | ||
yield ZRange.closedOpen(startIdx, endIdx).values(); | ||
startIdx = endIdx; | ||
} | ||
} | ||
|
||
Iterable<int> _range(int start, int end) sync* { | ||
for (int i = start; i < end; i++) { | ||
yield i; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import 'dart:typed_data'; | ||
|
||
import 'package:ml_algo/src/metric/metric_type.dart'; | ||
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; | ||
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:mockito/mockito.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
import '../test_utils/mocks.dart'; | ||
|
||
Splitter createSplitter(Iterable<Iterable<int>> indices) { | ||
final splitter = SplitterMock(); | ||
when(splitter.split(any)).thenReturn(indices); | ||
return splitter; | ||
} | ||
|
||
void main() { | ||
group('CrossValidatorImpl', () { | ||
test('should perform validation of a model on given test indices of' | ||
'observations', () { | ||
final allObservations = Matrix.from([ | ||
[330, 930, 130], | ||
[630, 830, 230], | ||
[730, 730, 330], | ||
[830, 630, 430], | ||
[930, 530, 530], | ||
[130, 430, 630], | ||
[230, 330, 730], | ||
[430, 230, 830], | ||
[530, 130, 930], | ||
]); | ||
final allOutcomes = Matrix.from([ | ||
[100],[200],[300],[400],[500],[600],[700],[800],[900], | ||
]); | ||
final metric = MetricType.mape; | ||
final splitter = createSplitter([[0,2,4],[6, 8]]); | ||
final predictor = PredictorMock(); | ||
final validator = CrossValidatorImpl(Float32x4, splitter); | ||
|
||
var score = 20.0; | ||
when(predictor.test(any, any, any)) | ||
.thenAnswer((Invocation inv) => score = score + 10); | ||
|
||
final actual = validator.evaluate((observations, outcomes) => predictor, | ||
allObservations, allOutcomes, metric); | ||
|
||
expect(actual, 35); | ||
|
||
verify(predictor.test(argThat(equals([ | ||
[330, 930, 130], | ||
[730, 730, 330], | ||
[930, 530, 530], | ||
])), argThat(equals([[100], [300], [500]])), metric)).called(1); | ||
|
||
verify(predictor.test(argThat(equals([ | ||
[230, 330, 730], | ||
[530, 130, 930], | ||
])), argThat(equals([[700], [900]])), metric)).called(1); | ||
}); | ||
|
||
test('should throw an exception if observations number and outcomes number ' | ||
'mismatch', () { | ||
final allObservations = Matrix.from([ | ||
[330, 930, 130], | ||
[630, 830, 230], | ||
]); | ||
final allOutcomes = Matrix.from([ | ||
[100], | ||
]); | ||
final metric = MetricType.mape; | ||
final splitter = SplitterMock(); | ||
final predictor = PredictorMock(); | ||
final validator = CrossValidatorImpl(Float32x4, splitter); | ||
|
||
expect(() => validator.evaluate((observations, outcomes) => predictor, | ||
allObservations, allOutcomes, metric), throwsException); | ||
}); | ||
}); | ||
} |
Oops, something went wrong.