diff --git a/CHANGELOG.md b/CHANGELOG.md index 1436c70d..ea01e8ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 14.0.1 +- data splitters renamed and reorganized + ## 14.0.0 - Breaking change: - `CrossValidator`: `evalute` method's api changed, it returns a Future resolving with scores Vector now instead diff --git a/lib/src/di/dependencies.dart b/lib/src/di/dependencies.dart index 7afc9f2f..8f60a4fc 100644 --- a/lib/src/di/dependencies.dart +++ b/lib/src/di/dependencies.dart @@ -30,8 +30,8 @@ import 'package:ml_algo/src/link_function/softmax/float32_softmax_link_function. import 'package:ml_algo/src/link_function/softmax/float64_softmax_link_function.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory_impl.dart'; import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart'; @@ -87,8 +87,8 @@ Injector get dependencies => (_) => const Float64SoftmaxLinkFunction(), dependencyName: float64SoftmaxLinkFunctionToken) - ..registerSingleton( - (_) => const DataSplitterFactoryImpl()) + ..registerSingleton( + (_) => const SplitIndicesProviderFactoryImpl()) ..registerSingleton( (_) => const SoftmaxRegressorFactoryImpl()) diff --git a/lib/src/model_selection/cross_validator/cross_validator.dart b/lib/src/model_selection/cross_validator/cross_validator.dart index 13482a6b..b53e3481 100644 --- a/lib/src/model_selection/cross_validator/cross_validator.dart +++ b/lib/src/model_selection/cross_validator/cross_validator.dart @@ -2,8 +2,8 @@ import 'package:ml_algo/src/di/dependencies.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/linalg.dart'; @@ -38,9 +38,9 @@ abstract class CrossValidator { DType dtype = DType.float32, }) { final dataSplitterFactory = dependencies - .getDependency(); + .getDependency(); final dataSplitter = dataSplitterFactory - .createByType(DataSplitterType.kFold, numberOfFolds: numberOfFolds); + .createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds); return CrossValidatorImpl( samples, @@ -50,7 +50,7 @@ abstract class CrossValidator { ); } - /// Creates LPO validator to evaluate quality of a predictor. + /// Creates LPO validator to evaluate performance of a predictor. /// /// It splits a dataset into all possible test sets of size [p] and /// subsequently evaluates quality of the predictor on each produced test set. @@ -71,9 +71,10 @@ abstract class CrossValidator { int p, { DType dtype = DType.float32, }) { - final dataSplitterFactory = dependencies.getDependency(); + final dataSplitterFactory = dependencies + .getDependency(); final dataSplitter = dataSplitterFactory - .createByType(DataSplitterType.lpo, p: p); + .createByType(SplitIndicesProviderType.lpo, p: p); return CrossValidatorImpl( samples, diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index de138607..f3b24432 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -2,7 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; @@ -20,7 +20,7 @@ class CrossValidatorImpl implements CrossValidator { final DataFrame samples; final DType dtype; final Iterable targetNames; - final DataSplitter _splitter; + final SplitIndicesProvider _splitter; @override Future evaluate( @@ -35,7 +35,7 @@ class CrossValidatorImpl implements CrossValidator { final discreteColumns = enumerate(samples.series) .where((indexedSeries) => indexedSeries.value.isDiscrete) .map((indexedSeries) => indexedSeries.index); - final allIndicesGroups = _splitter.split(samplesAsMatrix.rowsNum); + final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum); final scores = allIndicesGroups .map((testRowsIndices) { final split = _makeSplit(testRowsIndices, discreteColumns); diff --git a/lib/src/model_selection/split_indices_provider/data_splitter.dart b/lib/src/model_selection/split_indices_provider/data_splitter.dart deleted file mode 100644 index 265ca4e9..00000000 --- a/lib/src/model_selection/split_indices_provider/data_splitter.dart +++ /dev/null @@ -1,3 +0,0 @@ -abstract class DataSplitter { - Iterable> split(int numberOfSamples); -} diff --git a/lib/src/model_selection/split_indices_provider/data_splitter_factory.dart b/lib/src/model_selection/split_indices_provider/data_splitter_factory.dart deleted file mode 100644 index 3942c5f1..00000000 --- a/lib/src/model_selection/split_indices_provider/data_splitter_factory.dart +++ /dev/null @@ -1,9 +0,0 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart'; - -abstract class DataSplitterFactory { - DataSplitter createByType(DataSplitterType splitterType, { - int numberOfFolds, - int p, - }); -} diff --git a/lib/src/model_selection/split_indices_provider/data_splitter_type.dart b/lib/src/model_selection/split_indices_provider/data_splitter_type.dart deleted file mode 100644 index 0b29dbac..00000000 --- a/lib/src/model_selection/split_indices_provider/data_splitter_type.dart +++ /dev/null @@ -1,3 +0,0 @@ -enum DataSplitterType { - lpo, kFold, -} diff --git a/lib/src/model_selection/split_indices_provider/k_fold_data_splitter.dart b/lib/src/model_selection/split_indices_provider/k_fold_data_splitter.dart index 1756b0bf..370a5ba3 100644 --- a/lib/src/model_selection/split_indices_provider/k_fold_data_splitter.dart +++ b/lib/src/model_selection/split_indices_provider/k_fold_data_splitter.dart @@ -1,8 +1,8 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:xrange/integers.dart'; -class KFoldDataSplitter implements DataSplitter { - KFoldDataSplitter(this._numberOfFolds) { +class KFoldIndicesProvider implements SplitIndicesProvider { + KFoldIndicesProvider(this._numberOfFolds) { if (_numberOfFolds == 0 || _numberOfFolds == 1) { throw RangeError( 'Number of folds must be greater than 1 and less than the number of ' @@ -13,7 +13,7 @@ class KFoldDataSplitter implements DataSplitter { final int _numberOfFolds; @override - Iterable> split(int numOfObservations) sync* { + Iterable> getIndices(int numOfObservations) sync* { if (_numberOfFolds > numOfObservations) { throw RangeError.range(_numberOfFolds, 0, numOfObservations, null, 'Number of folds must be less than the number of samples'); diff --git a/lib/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart b/lib/src/model_selection/split_indices_provider/lpo_indices_provider.dart similarity index 76% rename from lib/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart rename to lib/src/model_selection/split_indices_provider/lpo_indices_provider.dart index 0d6a73c2..6f768d2e 100644 --- a/lib/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart +++ b/lib/src/model_selection/split_indices_provider/lpo_indices_provider.dart @@ -1,7 +1,7 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; -class LeavePOutDataSplitter implements DataSplitter { - LeavePOutDataSplitter([this._p = 2]) { +class LpoIndicesProvider implements SplitIndicesProvider { + LpoIndicesProvider([this._p = 2]) { if (_p == 0) { throw UnsupportedError('Value `$_p` for parameter `p` is unsupported'); } @@ -10,7 +10,7 @@ class LeavePOutDataSplitter implements DataSplitter { final int _p; @override - Iterable> split(int numberOfSamples) sync* { + Iterable> getIndices(int numberOfSamples) sync* { for (var u = 0; u < 1 << numberOfSamples; u++) { if (_count(u) == _p) yield _generateCombination(u); } diff --git a/lib/src/model_selection/split_indices_provider/split_indices_provider.dart b/lib/src/model_selection/split_indices_provider/split_indices_provider.dart new file mode 100644 index 00000000..e4906366 --- /dev/null +++ b/lib/src/model_selection/split_indices_provider/split_indices_provider.dart @@ -0,0 +1,3 @@ +abstract class SplitIndicesProvider { + Iterable> getIndices(int numberOfSamples); +} diff --git a/lib/src/model_selection/split_indices_provider/split_indices_provider_factory.dart b/lib/src/model_selection/split_indices_provider/split_indices_provider_factory.dart new file mode 100644 index 00000000..0bf151da --- /dev/null +++ b/lib/src/model_selection/split_indices_provider/split_indices_provider_factory.dart @@ -0,0 +1,9 @@ +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; + +abstract class SplitIndicesProviderFactory { + SplitIndicesProvider createByType(SplitIndicesProviderType splitterType, { + int numberOfFolds, + int p, + }); +} diff --git a/lib/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart b/lib/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart similarity index 61% rename from lib/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart rename to lib/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart index baa4f2a0..15a9f55f 100644 --- a/lib/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart +++ b/lib/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart @@ -1,30 +1,30 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/lpo_indices_provider.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; -class DataSplitterFactoryImpl implements DataSplitterFactory { - const DataSplitterFactoryImpl(); +class SplitIndicesProviderFactoryImpl implements SplitIndicesProviderFactory { + const SplitIndicesProviderFactoryImpl(); @override - DataSplitter createByType(DataSplitterType splitterType, { + SplitIndicesProvider createByType(SplitIndicesProviderType splitterType, { int numberOfFolds, int p, }) { switch (splitterType) { - case DataSplitterType.kFold: + case SplitIndicesProviderType.kFold: if (numberOfFolds == null) { throw Exception('Number of folds is not defined for K-fold splitter'); } - return KFoldDataSplitter(numberOfFolds); + return KFoldIndicesProvider(numberOfFolds); - case DataSplitterType.lpo: + case SplitIndicesProviderType.lpo: if (p == null) { throw Exception('`p` parameter is not defined for leave-p-out ' 'splitter'); } - return LeavePOutDataSplitter(p); + return LpoIndicesProvider(p); default: throw UnimplementedError('Splitter of type $splitterType is not ' diff --git a/lib/src/model_selection/split_indices_provider/split_indices_provider_type.dart b/lib/src/model_selection/split_indices_provider/split_indices_provider_type.dart new file mode 100644 index 00000000..aa5f2d74 --- /dev/null +++ b/lib/src/model_selection/split_indices_provider/split_indices_provider_type.dart @@ -0,0 +1,3 @@ +enum SplitIndicesProviderType { + lpo, kFold, +} diff --git a/pubspec.yaml b/pubspec.yaml index 89508e68..2d9717b4 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms written in native dart -version: 14.0.0 +version: 14.0.1 homepage: https://github.com/gyrdym/ml_algo environment: diff --git a/test/mocks.dart b/test/mocks.dart index 2bba0053..30e3c75e 100644 --- a/test/mocks.dart +++ b/test/mocks.dart @@ -23,8 +23,8 @@ import 'package:ml_algo/src/link_function/link_function.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_algo/src/tree_trainer/decision_tree_trainer.dart'; @@ -78,9 +78,9 @@ class ConvergenceDetectorFactoryMock extends Mock class ConvergenceDetectorMock extends Mock implements ConvergenceDetector {} -class DataSplitterMock extends Mock implements DataSplitter {} +class DataSplitterMock extends Mock implements SplitIndicesProvider {} -class DataSplitterFactoryMock extends Mock implements DataSplitterFactory {} +class DataSplitterFactoryMock extends Mock implements SplitIndicesProviderFactory {} class AssessableMock extends Mock implements Assessable {} @@ -216,7 +216,7 @@ LinearOptimizerFactory createLinearOptimizerFactoryMock( return factory; } -DataSplitterFactory createDataSplitterFactoryMock(DataSplitter dataSplitter) { +SplitIndicesProviderFactory createDataSplitterFactoryMock(SplitIndicesProvider dataSplitter) { final factory = DataSplitterFactoryMock(); when(factory.createByType(any, numberOfFolds: anyNamed('numberOfFolds'), diff --git a/test/model_selection/cross_validator/cross_validator_impl_test.dart b/test/model_selection/cross_validator/cross_validator_impl_test.dart index ea0b9801..f894c9b8 100644 --- a/test/model_selection/cross_validator/cross_validator_impl_test.dart +++ b/test/model_selection/cross_validator/cross_validator_impl_test.dart @@ -2,7 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:mockito/mockito.dart'; @@ -10,9 +10,9 @@ import 'package:test/test.dart'; import '../../mocks.dart'; -DataSplitter createSplitter(Iterable> indices) { +SplitIndicesProvider createSplitter(Iterable> indices) { final splitter = DataSplitterMock(); - when(splitter.split(any)).thenReturn(indices); + when(splitter.getIndices(any)).thenReturn(indices); return splitter; } diff --git a/test/model_selection/cross_validator/cross_validator_test.dart b/test/model_selection/cross_validator/cross_validator_test.dart index 017c5101..9ef8977c 100644 --- a/test/model_selection/cross_validator/cross_validator_test.dart +++ b/test/model_selection/cross_validator/cross_validator_test.dart @@ -1,9 +1,9 @@ import 'package:injector/injector.dart'; import 'package:ml_algo/ml_algo.dart'; import 'package:ml_algo/src/di/injector.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:mockito/mockito.dart'; import 'package:test/test.dart'; @@ -18,15 +18,15 @@ void main() { header: ['1', '2', '3', '4'], ); - DataSplitter dataSplitter; - DataSplitterFactory dataSplitterFactory; + SplitIndicesProvider dataSplitter; + SplitIndicesProviderFactory dataSplitterFactory; setUp(() { dataSplitter = DataSplitterMock(); dataSplitterFactory = createDataSplitterFactoryMock(dataSplitter); injector = Injector() - ..registerDependency((_) => dataSplitterFactory); + ..registerDependency((_) => dataSplitterFactory); }); tearDown(() => injector = null); @@ -36,7 +36,7 @@ void main() { CrossValidator.kFold(data, ['4'], numberOfFolds: 10); verify(dataSplitterFactory - .createByType(DataSplitterType.kFold, numberOfFolds: 10), + .createByType(SplitIndicesProviderType.kFold, numberOfFolds: 10), ).called(1); }); @@ -45,7 +45,7 @@ void main() { CrossValidator.kFold(data, ['4']); verify(dataSplitterFactory - .createByType(DataSplitterType.kFold, numberOfFolds: 5), + .createByType(SplitIndicesProviderType.kFold, numberOfFolds: 5), ).called(1); }); @@ -54,7 +54,7 @@ void main() { CrossValidator.lpo(data, ['4'], 30); verify(dataSplitterFactory - .createByType(DataSplitterType.lpo, p: 30), + .createByType(SplitIndicesProviderType.lpo, p: 30), ).called(1); }); }); diff --git a/test/model_selection/data_splitter/data_splitter_factory_impl_test.dart b/test/model_selection/data_splitter/data_splitter_factory_impl_test.dart deleted file mode 100644 index 31fba6af..00000000 --- a/test/model_selection/data_splitter/data_splitter_factory_impl_test.dart +++ /dev/null @@ -1,41 +0,0 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart'; -import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart'; -import 'package:test/test.dart'; - -void main() { - group('DataSplitterFactoryImpl', () { - const factory = DataSplitterFactoryImpl(); - - test('should create k fold data splitter', () { - expect( - factory.createByType(DataSplitterType.kFold, numberOfFolds: 3), - isA(), - ); - }); - - test('should throw an exception if number of folds is not provided for ' - 'k fold data splitter', () { - expect( - () => factory.createByType(DataSplitterType.kFold), - throwsException, - ); - }); - - test('should create leave p out data splitter', () { - expect( - factory.createByType(DataSplitterType.lpo, p: 3), - isA(), - ); - }); - - test('should throw an exception if `p` parameter is not provided for ' - 'leave p out data splitter', () { - expect( - () => factory.createByType(DataSplitterType.lpo), - throwsException, - ); - }); - }); -} diff --git a/test/model_selection/data_splitter/k_fold_data_splitter_test.dart b/test/model_selection/data_splitter/k_fold_split_indices_provider_test.dart similarity index 74% rename from test/model_selection/data_splitter/k_fold_data_splitter_test.dart rename to test/model_selection/data_splitter/k_fold_split_indices_provider_test.dart index fc3e2257..1dfe589e 100644 --- a/test/model_selection/data_splitter/k_fold_data_splitter_test.dart +++ b/test/model_selection/data_splitter/k_fold_split_indices_provider_test.dart @@ -2,24 +2,24 @@ import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_s import 'package:test/test.dart'; void main() { - group('KFoldDataSplitter', () { + group('KFoldIndicesProvider', () { void testKFoldSplitter(int numOfFold, int numOfObservations, Iterable> expected) { test('should return proper groups of indices if number of folds is ' '$numOfFold and number of observations is $numOfObservations', () { - final splitter = KFoldDataSplitter(numOfFold); - expect(splitter.split(numOfObservations), equals(expected)); + final splitter = KFoldIndicesProvider(numOfFold); + expect(splitter.getIndices(numOfObservations), equals(expected)); }); } test('should throw an exception if passed number of folds is equal to ' '0', () { - expect(() => KFoldDataSplitter(0), throwsRangeError); + expect(() => KFoldIndicesProvider(0), throwsRangeError); }); test('should throw an exception if passed number of folds is equal to ' '1', () { - expect(() => KFoldDataSplitter(1), throwsRangeError); + expect(() => KFoldIndicesProvider(1), throwsRangeError); }); testKFoldSplitter(5, 12, [ @@ -56,14 +56,14 @@ void main() { ]); test('should throws a range error if number of observations is 0', () { - final splitter = KFoldDataSplitter(3); - expect(() => splitter.split(0), throwsRangeError); + final splitter = KFoldIndicesProvider(3); + expect(() => splitter.getIndices(0), throwsRangeError); }); test('should throws a range error if number of observations is less than' 'number of folds', () { - final splitter = KFoldDataSplitter(9); - expect(() => splitter.split(8), throwsRangeError); + final splitter = KFoldIndicesProvider(9); + expect(() => splitter.getIndices(8), throwsRangeError); }); }); } diff --git a/test/model_selection/data_splitter/lpo_data_splitter_test.dart b/test/model_selection/data_splitter/lpo_split_indices_provider_test.dart similarity index 81% rename from test/model_selection/data_splitter/lpo_data_splitter_test.dart rename to test/model_selection/data_splitter/lpo_split_indices_provider_test.dart index 1e934875..08e2e9f8 100644 --- a/test/model_selection/data_splitter/lpo_data_splitter_test.dart +++ b/test/model_selection/data_splitter/lpo_split_indices_provider_test.dart @@ -1,14 +1,14 @@ -import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/lpo_indices_provider.dart'; import 'package:test/test.dart'; void main() { - group('LeavePOutDataSplitter', () { + group('LpoIndicesProvider', () { void testLpoSplitter(int p, int numOfObservations, Iterable> expected) { test('should return proper groups of indices if p is $p and number of ' 'observations is $numOfObservations', () { - final splitter = LeavePOutDataSplitter(p); - expect(splitter.split(numOfObservations).toSet(), equals(expected)); + final splitter = LpoIndicesProvider(p); + expect(splitter.getIndices(numOfObservations).toSet(), equals(expected)); }); } @@ -63,7 +63,7 @@ void main() { ].toSet()); test('should throw an error, if p is equal to 0', () { - expect(() => LeavePOutDataSplitter(0), throwsUnsupportedError); + expect(() => LpoIndicesProvider(0), throwsUnsupportedError); }); }); } diff --git a/test/model_selection/data_splitter/split_indices_provider_factory_impl_test.dart b/test/model_selection/data_splitter/split_indices_provider_factory_impl_test.dart new file mode 100644 index 00000000..b4855222 --- /dev/null +++ b/test/model_selection/data_splitter/split_indices_provider_factory_impl_test.dart @@ -0,0 +1,41 @@ +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart'; +import 'package:ml_algo/src/model_selection/split_indices_provider/lpo_indices_provider.dart'; +import 'package:test/test.dart'; + +void main() { + group('SplitIndicesProviderFactoryImpl', () { + const factory = SplitIndicesProviderFactoryImpl(); + + test('should create k fold indices provider', () { + expect( + factory.createByType(SplitIndicesProviderType.kFold, numberOfFolds: 3), + isA(), + ); + }); + + test('should throw an exception if number of folds is not provided for ' + 'k fold indices provider', () { + expect( + () => factory.createByType(SplitIndicesProviderType.kFold), + throwsException, + ); + }); + + test('should create leave p out indices provider', () { + expect( + factory.createByType(SplitIndicesProviderType.lpo, p: 3), + isA(), + ); + }); + + test('should throw an exception if `p` parameter is not provided for ' + 'leave p out indices provider', () { + expect( + () => factory.createByType(SplitIndicesProviderType.lpo), + throwsException, + ); + }); + }); +}