diff --git a/test/cross_validator/cross_validator_impl_test.dart b/test/cross_validator/cross_validator_impl_test.dart new file mode 100644 index 00000000..e30d07b8 --- /dev/null +++ b/test/cross_validator/cross_validator_impl_test.dart @@ -0,0 +1,62 @@ +import 'dart:typed_data'; + +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; +import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'package:mockito/mockito.dart'; +import 'package:test/test.dart'; + +import '../test_utils/mocks.dart'; + +Splitter createSplitter(Iterable> indices) { + final splitter = SplitterMock(); + when(splitter.split(any)).thenReturn(indices); + return splitter; +} + +void main() { + group('CrossValidatorImpl', () { + test('should perform validation of a model on given test indices of' + 'observations', () { + final allObservations = Matrix.from([ + [330, 930, 130], + [630, 830, 230], + [730, 730, 330], + [830, 630, 430], + [930, 530, 530], + [130, 430, 630], + [230, 330, 730], + [430, 230, 830], + [530, 130, 930], + ]); + final allOutcomes = Matrix.from([ + [100],[200],[300],[400],[500],[600],[700],[800],[900], + ]); + final metric = MetricType.mape; + final splitter = createSplitter([[0,2,4],[6, 8]]); + final predictor = PredictorMock(); + final validator = CrossValidatorImpl(Float32x4, splitter); + + var score = 20.0; + when(predictor.test(any, any, any)) + .thenAnswer((Invocation inv) => score = score + 10); + + final actual = validator.evaluate((observations, outcomes) => predictor, + allObservations, allOutcomes, metric); + + expect(actual, 35); + + verify(predictor.test(argThat(equals([ + [330, 930, 130], + [730, 730, 330], + [930, 530, 530], + ])), argThat(equals([[100], [300], [500]])), metric)).called(1); + + verify(predictor.test(argThat(equals([ + [230, 330, 730], + [530, 130, 930], + ])), argThat(equals([[700], [900]])), metric)).called(1); + }); + }); +} diff --git a/test/test_all.dart b/test/test_all.dart index fbe1de78..d2c81143 100644 --- a/test/test_all.dart +++ b/test/test_all.dart @@ -6,6 +6,8 @@ import 'classifier/logistic_regressor_integration_test.dart' import 'classifier/logistic_regressor_test.dart' as logistic_regressor_test; import 'classifier/softmax_regressor_test.dart' as softmax_regressor_test; import 'cost_function/cost_function_test.dart' as cost_function_test; +import 'cross_validator/cross_validator_impl_test.dart' + as cross_validator_impl_test; import 'data_preprocessing/intercept_preprocessor_test.dart' as intercept_preprocessor_test; import 'data_splitter/k_fold_splitter_test.dart' as k_fold_splitter_test; @@ -34,6 +36,7 @@ void main() { logistic_regressor_test.main(); softmax_regressor_test.main(); cost_function_test.main(); + cross_validator_impl_test.main(); intercept_preprocessor_test.main(); k_fold_splitter_test.main(); lpo_splitter_test.main(); diff --git a/test/test_utils/mocks.dart b/test/test_utils/mocks.dart index 3e7cd6fe..19fbcab5 100644 --- a/test/test_utils/mocks.dart +++ b/test/test_utils/mocks.dart @@ -5,6 +5,7 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; +import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart'; import 'package:ml_algo/src/optimizer/convergence_detector/convergence_detector.dart'; import 'package:ml_algo/src/optimizer/convergence_detector/convergence_detector_factory.dart'; import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator.dart'; @@ -16,6 +17,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; @@ -60,6 +62,10 @@ class ConvergenceDetectorFactoryMock extends Mock class ConvergenceDetectorMock extends Mock implements ConvergenceDetector {} +class SplitterMock extends Mock implements Splitter {} + +class PredictorMock extends Mock implements Predictor {} + LearningRateGeneratorFactoryMock createLearningRateGeneratorFactoryMock({ Map generators, }) {