Skip to content

Commit

Permalink
cross validation unit tests added positive case
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Apr 23, 2019
1 parent 513d75f commit eb52edb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
62 changes: 62 additions & 0 deletions test/cross_validator/cross_validator_impl_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import 'dart:typed_data';

import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';

import '../test_utils/mocks.dart';

Splitter createSplitter(Iterable<Iterable<int>> indices) {
final splitter = SplitterMock();
when(splitter.split(any)).thenReturn(indices);
return splitter;
}

void main() {
group('CrossValidatorImpl', () {
test('should perform validation of a model on given test indices of'
'observations', () {
final allObservations = Matrix.from([
[330, 930, 130],
[630, 830, 230],
[730, 730, 330],
[830, 630, 430],
[930, 530, 530],
[130, 430, 630],
[230, 330, 730],
[430, 230, 830],
[530, 130, 930],
]);
final allOutcomes = Matrix.from([
[100],[200],[300],[400],[500],[600],[700],[800],[900],
]);
final metric = MetricType.mape;
final splitter = createSplitter([[0,2,4],[6, 8]]);
final predictor = PredictorMock();
final validator = CrossValidatorImpl(Float32x4, splitter);

var score = 20.0;
when(predictor.test(any, any, any))
.thenAnswer((Invocation inv) => score = score + 10);

final actual = validator.evaluate((observations, outcomes) => predictor,
allObservations, allOutcomes, metric);

expect(actual, 35);

verify(predictor.test(argThat(equals([
[330, 930, 130],
[730, 730, 330],
[930, 530, 530],
])), argThat(equals([[100], [300], [500]])), metric)).called(1);

verify(predictor.test(argThat(equals([
[230, 330, 730],
[530, 130, 930],
])), argThat(equals([[700], [900]])), metric)).called(1);
});
});
}
3 changes: 3 additions & 0 deletions test/test_all.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import 'classifier/logistic_regressor_integration_test.dart'
import 'classifier/logistic_regressor_test.dart' as logistic_regressor_test;
import 'classifier/softmax_regressor_test.dart' as softmax_regressor_test;
import 'cost_function/cost_function_test.dart' as cost_function_test;
import 'cross_validator/cross_validator_impl_test.dart'
as cross_validator_impl_test;
import 'data_preprocessing/intercept_preprocessor_test.dart'
as intercept_preprocessor_test;
import 'data_splitter/k_fold_splitter_test.dart' as k_fold_splitter_test;
Expand Down Expand Up @@ -34,6 +36,7 @@ void main() {
logistic_regressor_test.main();
softmax_regressor_test.main();
cost_function_test.main();
cross_validator_impl_test.main();
intercept_preprocessor_test.main();
k_fold_splitter_test.main();
lpo_splitter_test.main();
Expand Down
6 changes: 6 additions & 0 deletions test/test_utils/mocks.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_
import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart';
import 'package:ml_algo/src/math/randomizer/randomizer.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:ml_algo/src/optimizer/convergence_detector/convergence_detector.dart';
import 'package:ml_algo/src/optimizer/convergence_detector/convergence_detector_factory.dart';
import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator.dart';
Expand All @@ -16,6 +17,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_
import 'package:ml_algo/src/optimizer/optimizer.dart';
import 'package:ml_algo/src/optimizer/optimizer_factory.dart';
import 'package:ml_algo/src/optimizer/optimizer_type.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';
Expand Down Expand Up @@ -60,6 +62,10 @@ class ConvergenceDetectorFactoryMock extends Mock

class ConvergenceDetectorMock extends Mock implements ConvergenceDetector {}

class SplitterMock extends Mock implements Splitter {}

class PredictorMock extends Mock implements Predictor {}

LearningRateGeneratorFactoryMock createLearningRateGeneratorFactoryMock({
Map<LearningRateType, LearningRateGenerator> generators,
}) {
Expand Down

0 comments on commit eb52edb

Please sign in to comment.