diff --git a/lib/src/model_selection/cross_validator/cross_validator.dart b/lib/src/model_selection/cross_validator/cross_validator.dart index d69ce5fd..37d64e77 100644 --- a/lib/src/model_selection/cross_validator/cross_validator.dart +++ b/lib/src/model_selection/cross_validator/cross_validator.dart @@ -1,5 +1,7 @@ import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; +import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart'; +import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_linalg/matrix.dart'; @@ -9,14 +11,15 @@ abstract class CrossValidator { /// /// It splits a dataset into [numberOfFolds] test sets and subsequently /// evaluates the predictor on each produced test set - factory CrossValidator.kFold({Type dtype, int numberOfFolds}) = - CrossValidatorImpl.kFold; + factory CrossValidator.kFold({Type dtype, int numberOfFolds}) => + CrossValidatorImpl(dtype, KFoldSplitter(numberOfFolds)); /// Creates LPO validator to evaluate quality of a predictor. /// /// It splits a dataset into all possible test sets of size [p] and /// subsequently evaluates quality of the predictor on each produced test set - factory CrossValidator.lpo({Type dtype, int p}) = CrossValidatorImpl.lpo; + factory CrossValidator.lpo({Type dtype, int p}) => + CrossValidatorImpl(dtype, LeavePOutSplitter(p)); /// Returns a score of quality of passed predictor depending on given [metric] double evaluate(Predictor predictorFactory(Matrix features, Matrix outcomes), diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index 1fc6cbc4..a09b12ed 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -1,21 +1,13 @@ -import 'package:ml_algo/src/utils/default_parameter_values.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; -import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart'; -import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart'; import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/utils/default_parameter_values.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; class CrossValidatorImpl implements CrossValidator { - factory CrossValidatorImpl.kFold({Type dtype, int numberOfFolds = 5}) => - CrossValidatorImpl._(dtype, KFoldSplitter(numberOfFolds)); - - factory CrossValidatorImpl.lpo({Type dtype, int p}) => - CrossValidatorImpl._(dtype, LeavePOutSplitter(p)); - - CrossValidatorImpl._(Type dtype, this._splitter) + CrossValidatorImpl(Type dtype, this._splitter) : dtype = dtype ?? DefaultParameterValues.dtype; final Type dtype; diff --git a/lib/src/model_selection/data_splitter/k_fold.dart b/lib/src/model_selection/data_splitter/k_fold.dart index d350df7a..2e03208f 100644 --- a/lib/src/model_selection/data_splitter/k_fold.dart +++ b/lib/src/model_selection/data_splitter/k_fold.dart @@ -5,7 +5,8 @@ class KFoldSplitter implements Splitter { KFoldSplitter(this._numberOfFolds) { if (_numberOfFolds == 0 || _numberOfFolds == 1) { throw RangeError( - 'Number of folds must be greater than 1 and less than number of samples'); + 'Number of folds must be greater than 1 and less than number of ' + 'samples'); } } @@ -20,6 +21,8 @@ class KFoldSplitter implements Splitter { final remainder = numOfObservations % _numberOfFolds; final foldSize = numOfObservations ~/ _numberOfFolds; for (int i = 0, startIdx = 0, endIdx = 0; i < _numberOfFolds; i++) { + // if we reached last fold of size [foldSize], all the next folds up + // to the last fold will have size that is equal to [foldSize] + 1 endIdx = startIdx + foldSize + (i >= _numberOfFolds - remainder ? 1 : 0); yield ZRange.closedOpen(startIdx, endIdx).values(); startIdx = endIdx; diff --git a/lib/src/model_selection/data_splitter/leave_p_out.dart b/lib/src/model_selection/data_splitter/leave_p_out.dart index f5d4c18f..4df31ebb 100644 --- a/lib/src/model_selection/data_splitter/leave_p_out.dart +++ b/lib/src/model_selection/data_splitter/leave_p_out.dart @@ -1,21 +1,18 @@ import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; class LeavePOutSplitter implements Splitter { - int _p = 2; - - LeavePOutSplitter(int p) { - if (p == 0) { - throw UnsupportedError('Value `$p` for parameter `p` is unsupported'); + LeavePOutSplitter([this._p = 2]) { + if (_p == 0) { + throw UnsupportedError('Value `$_p` for parameter `p` is unsupported'); } - _p = p; } + final int _p; + @override Iterable> split(int numberOfSamples) sync* { for (int u = 0; u < 1 << numberOfSamples; u++) { - if (_count(u) == _p) { - yield _generateCombination(u); - } + if (_count(u) == _p) yield _generateCombination(u); } } @@ -28,9 +25,7 @@ class LeavePOutSplitter implements Splitter { Iterable _generateCombination(int u) sync* { for (int n = 0; u > 0; ++n, u >>= 1) { - if ((u & 1) > 0) { - yield n; - } + if ((u & 1) > 0) yield n; } } } diff --git a/test/data_splitter/data_splitter_test.dart b/test/data_splitter/data_splitter_test.dart deleted file mode 100644 index 094d746c..00000000 --- a/test/data_splitter/data_splitter_test.dart +++ /dev/null @@ -1,148 +0,0 @@ -import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart'; -import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart'; -import 'package:test/test.dart'; - -void main() { - group('Splitters test:\n', () { - KFoldSplitter kfoldSplitter; - LeavePOutSplitter leavePOutSplitter; - - test('K fold splitter test', () { - expect(() => KFoldSplitter(0), throwsRangeError); - expect(() => KFoldSplitter(1), throwsRangeError); - - kfoldSplitter = KFoldSplitter(5); - - expect( - kfoldSplitter.split(12), - equals([ - [0, 1], - [2, 3], - [4, 5], - [6, 7, 8], - [9, 10, 11], - ])); - - kfoldSplitter = KFoldSplitter(4); - expect( - kfoldSplitter.split(12), - equals([ - [0, 1, 2], - [3, 4, 5], - [6, 7, 8], - [9, 10, 11] - ])); - - kfoldSplitter = KFoldSplitter(3); - expect( - kfoldSplitter.split(12), - equals([ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11] - ])); - - kfoldSplitter = KFoldSplitter(12); - expect( - kfoldSplitter.split(12), - equals([ - [0], - [1], - [2], - [3], - [4], - [5], - [6], - [7], - [8], - [9], - [10], - [11] - ])); - - kfoldSplitter = KFoldSplitter(5); - expect( - kfoldSplitter.split(37), - equals([ - [0, 1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12, 13], - [14, 15, 16, 17, 18, 19, 20], - [21, 22, 23, 24, 25, 26, 27, 28], - [29, 30, 31, 32, 33, 34, 35, 36], - ])); - - kfoldSplitter = KFoldSplitter(3); - expect(() => kfoldSplitter.split(0), throwsRangeError); - - kfoldSplitter = KFoldSplitter(9); - expect(() => kfoldSplitter.split(8), throwsRangeError); - }); - - test('Leave P out splitter test... ', () { - leavePOutSplitter = LeavePOutSplitter(2); - - expect( - leavePOutSplitter.split(4).toSet(), - equals([ - [0, 1], - [0, 2], - [0, 3], - [1, 2], - [1, 3], - [2, 3] - ].toSet())); - expect( - leavePOutSplitter.split(5).toSet(), - equals([ - [0, 1], - [0, 2], - [0, 3], - [0, 4], - [1, 2], - [1, 3], - [1, 4], - [2, 3], - [2, 4], - [3, 4] - ].toSet())); - - leavePOutSplitter = LeavePOutSplitter(1); - expect( - leavePOutSplitter.split(5).toSet(), - equals([ - [0], - [1], - [2], - [3], - [4] - ])); - - leavePOutSplitter = LeavePOutSplitter(3); - expect( - leavePOutSplitter.split(4).toSet(), - equals([ - [0, 1, 2], - [0, 1, 3], - [0, 2, 3], - [1, 2, 3] - ].toSet())); - expect( - leavePOutSplitter.split(5).toSet(), - equals([ - [0, 1, 2], - [0, 1, 3], - [0, 1, 4], - [0, 2, 3], - [0, 2, 4], - [0, 3, 4], - [1, 2, 3], - [1, 2, 4], - [1, 3, 4], - [2, 3, 4] - ].toSet())); - - expect(() => leavePOutSplitter = LeavePOutSplitter(0), - throwsUnsupportedError); - }); - }); -} diff --git a/test/data_splitter/k_fold_splitter_test.dart b/test/data_splitter/k_fold_splitter_test.dart new file mode 100644 index 00000000..94fae3a4 --- /dev/null +++ b/test/data_splitter/k_fold_splitter_test.dart @@ -0,0 +1,69 @@ +import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart'; +import 'package:test/test.dart'; + +void main() { + group('KFold splitter', () { + void testKFoldSplitter(int numOfFold, int numOfObservations, + Iterable> expected) { + test('should return proper groups of indices if number of folds is ' + '$numOfFold and number of observations is $numOfObservations', () { + final splitter = KFoldSplitter(numOfFold); + expect(splitter.split(numOfObservations), equals(expected)); + }); + } + + test('should throw an exception if passed number of folds is equal to ' + '0', () { + expect(() => KFoldSplitter(0), throwsRangeError); + }); + + test('should throw an exception if passed number of folds is equal to ' + '1', () { + expect(() => KFoldSplitter(1), throwsRangeError); + }); + + testKFoldSplitter(5, 12, [ + [0, 1], + [2, 3], + [4, 5], + [6, 7, 8], + [9, 10, 11], + ]); + + testKFoldSplitter(4, 12, [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + ]); + + testKFoldSplitter(3, 12, [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + ]); + + testKFoldSplitter(12, 12, [ + [0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + ]); + + testKFoldSplitter(5, 37, [ + [0, 1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12, 13], + [14, 15, 16, 17, 18, 19, 20], + [21, 22, 23, 24, 25, 26, 27, 28], + [29, 30, 31, 32, 33, 34, 35, 36], + ]); + + test('should throws a range error if number of observations is 0', () { + final splitter = KFoldSplitter(3); + expect(() => splitter.split(0), throwsRangeError); + }); + + test('should throws a range error if number of observations is less than' + 'number of folds', () { + final splitter = KFoldSplitter(9); + expect(() => splitter.split(8), throwsRangeError); + }); + }); +} diff --git a/test/data_splitter/lpo_splitter_test.dart b/test/data_splitter/lpo_splitter_test.dart new file mode 100644 index 00000000..04f5dab8 --- /dev/null +++ b/test/data_splitter/lpo_splitter_test.dart @@ -0,0 +1,69 @@ +import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart'; +import 'package:test/test.dart'; + +void main() { + group('Leave p out splitter', () { + void testLpoSplitter(int p, int numOfObservations, + Iterable> expected) { + test('should return proper groups of indices if p is $p and number of ' + 'observations is $numOfObservations', () { + final splitter = LeavePOutSplitter(p); + expect(splitter.split(numOfObservations).toSet(), equals(expected)); + }); + } + + testLpoSplitter(2, 4, [ + [0, 1], + [0, 2], + [0, 3], + [1, 2], + [1, 3], + [2, 3], + ].toSet()); + + testLpoSplitter(2, 5, [ + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [1, 2], + [1, 3], + [1, 4], + [2, 3], + [2, 4], + [3, 4], + ].toSet()); + + testLpoSplitter(1, 5, [ + [0], + [1], + [2], + [3], + [4], + ].toSet()); + + testLpoSplitter(3, 4, [ + [0, 1, 2], + [0, 1, 3], + [0, 2, 3], + [1, 2, 3], + ].toSet()); + + testLpoSplitter(3, 5, [ + [0, 1, 2], + [0, 1, 3], + [0, 1, 4], + [0, 2, 3], + [0, 2, 4], + [0, 3, 4], + [1, 2, 3], + [1, 2, 4], + [1, 3, 4], + [2, 3, 4], + ].toSet()); + + test('should throw an error, if p is equal to 0', () { + expect(() => LeavePOutSplitter(0), throwsUnsupportedError); + }); + }); +} diff --git a/test/test_all.dart b/test/test_all.dart index d83c9c14..fbe1de78 100644 --- a/test/test_all.dart +++ b/test/test_all.dart @@ -8,7 +8,8 @@ import 'classifier/softmax_regressor_test.dart' as softmax_regressor_test; import 'cost_function/cost_function_test.dart' as cost_function_test; import 'data_preprocessing/intercept_preprocessor_test.dart' as intercept_preprocessor_test; -import 'data_splitter/data_splitter_test.dart' as data_splitter_test; +import 'data_splitter/k_fold_splitter_test.dart' as k_fold_splitter_test; +import 'data_splitter/lpo_splitter_test.dart' as lpo_splitter_test; import 'math/randomizer_test.dart' as randomizer_test; import 'optimizer/convergence_detector/convergence_detector_impl_test.dart' as convergence_detector_test; @@ -34,7 +35,8 @@ void main() { softmax_regressor_test.main(); cost_function_test.main(); intercept_preprocessor_test.main(); - data_splitter_test.main(); + k_fold_splitter_test.main(); + lpo_splitter_test.main(); randomizer_test.main(); convergence_detector_test.main(); coord_optimizer_integration_test.main();