Skip to content

Commit

Permalink
data splitter unit tests split up into bunch of micro tests, CrossVal…
Browse files Browse the repository at this point in the history
…idator initializing refactored
  • Loading branch information
gyrdym committed Apr 23, 2019
1 parent 397290f commit 513d75f
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 176 deletions.
9 changes: 6 additions & 3 deletions lib/src/model_selection/cross_validator/cross_validator.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_linalg/matrix.dart';

Expand All @@ -9,14 +11,15 @@ abstract class CrossValidator {
///
/// It splits a dataset into [numberOfFolds] test sets and subsequently
/// evaluates the predictor on each produced test set
factory CrossValidator.kFold({Type dtype, int numberOfFolds}) =
CrossValidatorImpl.kFold;
factory CrossValidator.kFold({Type dtype, int numberOfFolds}) =>
CrossValidatorImpl(dtype, KFoldSplitter(numberOfFolds));

/// Creates LPO validator to evaluate quality of a predictor.
///
/// It splits a dataset into all possible test sets of size [p] and
/// subsequently evaluates quality of the predictor on each produced test set
factory CrossValidator.lpo({Type dtype, int p}) = CrossValidatorImpl.lpo;
factory CrossValidator.lpo({Type dtype, int p}) =>
CrossValidatorImpl(dtype, LeavePOutSplitter(p));

/// Returns a score of quality of passed predictor depending on given [metric]
double evaluate(Predictor predictorFactory(Matrix features, Matrix outcomes),
Expand Down
12 changes: 2 additions & 10 deletions lib/src/model_selection/cross_validator/cross_validator_impl.dart
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
import 'package:ml_algo/src/utils/default_parameter_values.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/utils/default_parameter_values.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

class CrossValidatorImpl implements CrossValidator {
factory CrossValidatorImpl.kFold({Type dtype, int numberOfFolds = 5}) =>
CrossValidatorImpl._(dtype, KFoldSplitter(numberOfFolds));

factory CrossValidatorImpl.lpo({Type dtype, int p}) =>
CrossValidatorImpl._(dtype, LeavePOutSplitter(p));

CrossValidatorImpl._(Type dtype, this._splitter)
CrossValidatorImpl(Type dtype, this._splitter)
: dtype = dtype ?? DefaultParameterValues.dtype;

final Type dtype;
Expand Down
5 changes: 4 additions & 1 deletion lib/src/model_selection/data_splitter/k_fold.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ class KFoldSplitter implements Splitter {
KFoldSplitter(this._numberOfFolds) {
if (_numberOfFolds == 0 || _numberOfFolds == 1) {
throw RangeError(
'Number of folds must be greater than 1 and less than number of samples');
'Number of folds must be greater than 1 and less than number of '
'samples');
}
}

Expand All @@ -20,6 +21,8 @@ class KFoldSplitter implements Splitter {
final remainder = numOfObservations % _numberOfFolds;
final foldSize = numOfObservations ~/ _numberOfFolds;
for (int i = 0, startIdx = 0, endIdx = 0; i < _numberOfFolds; i++) {
// if we reached last fold of size [foldSize], all the next folds up
// to the last fold will have size that is equal to [foldSize] + 1
endIdx = startIdx + foldSize + (i >= _numberOfFolds - remainder ? 1 : 0);
yield ZRange.closedOpen(startIdx, endIdx).values();
startIdx = endIdx;
Expand Down
19 changes: 7 additions & 12 deletions lib/src/model_selection/data_splitter/leave_p_out.dart
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';

class LeavePOutSplitter implements Splitter {
int _p = 2;

LeavePOutSplitter(int p) {
if (p == 0) {
throw UnsupportedError('Value `$p` for parameter `p` is unsupported');
LeavePOutSplitter([this._p = 2]) {
if (_p == 0) {
throw UnsupportedError('Value `$_p` for parameter `p` is unsupported');
}
_p = p;
}

final int _p;

@override
Iterable<Iterable<int>> split(int numberOfSamples) sync* {
for (int u = 0; u < 1 << numberOfSamples; u++) {
if (_count(u) == _p) {
yield _generateCombination(u);
}
if (_count(u) == _p) yield _generateCombination(u);
}
}

Expand All @@ -28,9 +25,7 @@ class LeavePOutSplitter implements Splitter {

Iterable<int> _generateCombination(int u) sync* {
for (int n = 0; u > 0; ++n, u >>= 1) {
if ((u & 1) > 0) {
yield n;
}
if ((u & 1) > 0) yield n;
}
}
}
148 changes: 0 additions & 148 deletions test/data_splitter/data_splitter_test.dart

This file was deleted.

69 changes: 69 additions & 0 deletions test/data_splitter/k_fold_splitter_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:test/test.dart';

void main() {
group('KFold splitter', () {
void testKFoldSplitter(int numOfFold, int numOfObservations,
Iterable<Iterable<int>> expected) {
test('should return proper groups of indices if number of folds is '
'$numOfFold and number of observations is $numOfObservations', () {
final splitter = KFoldSplitter(numOfFold);
expect(splitter.split(numOfObservations), equals(expected));
});
}

test('should throw an exception if passed number of folds is equal to '
'0', () {
expect(() => KFoldSplitter(0), throwsRangeError);
});

test('should throw an exception if passed number of folds is equal to '
'1', () {
expect(() => KFoldSplitter(1), throwsRangeError);
});

testKFoldSplitter(5, 12, [
[0, 1],
[2, 3],
[4, 5],
[6, 7, 8],
[9, 10, 11],
]);

testKFoldSplitter(4, 12, [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9, 10, 11],
]);

testKFoldSplitter(3, 12, [
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
]);

testKFoldSplitter(12, 12, [
[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
]);

testKFoldSplitter(5, 37, [
[0, 1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12, 13],
[14, 15, 16, 17, 18, 19, 20],
[21, 22, 23, 24, 25, 26, 27, 28],
[29, 30, 31, 32, 33, 34, 35, 36],
]);

test('should throws a range error if number of observations is 0', () {
final splitter = KFoldSplitter(3);
expect(() => splitter.split(0), throwsRangeError);
});

test('should throws a range error if number of observations is less than'
'number of folds', () {
final splitter = KFoldSplitter(9);
expect(() => splitter.split(8), throwsRangeError);
});
});
}
69 changes: 69 additions & 0 deletions test/data_splitter/lpo_splitter_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:test/test.dart';

void main() {
group('Leave p out splitter', () {
void testLpoSplitter(int p, int numOfObservations,
Iterable<Iterable<int>> expected) {
test('should return proper groups of indices if p is $p and number of '
'observations is $numOfObservations', () {
final splitter = LeavePOutSplitter(p);
expect(splitter.split(numOfObservations).toSet(), equals(expected));
});
}

testLpoSplitter(2, 4, [
[0, 1],
[0, 2],
[0, 3],
[1, 2],
[1, 3],
[2, 3],
].toSet());

testLpoSplitter(2, 5, [
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[1, 2],
[1, 3],
[1, 4],
[2, 3],
[2, 4],
[3, 4],
].toSet());

testLpoSplitter(1, 5, [
[0],
[1],
[2],
[3],
[4],
].toSet());

testLpoSplitter(3, 4, [
[0, 1, 2],
[0, 1, 3],
[0, 2, 3],
[1, 2, 3],
].toSet());

testLpoSplitter(3, 5, [
[0, 1, 2],
[0, 1, 3],
[0, 1, 4],
[0, 2, 3],
[0, 2, 4],
[0, 3, 4],
[1, 2, 3],
[1, 2, 4],
[1, 3, 4],
[2, 3, 4],
].toSet());

test('should throw an error, if p is equal to 0', () {
expect(() => LeavePOutSplitter(0), throwsUnsupportedError);
});
});
}
Loading

0 comments on commit 513d75f

Please sign in to comment.