Skip to content

Commit

Permalink
Merge pull request #100 from gyrdym/cross-validator-unit-tests
Browse files Browse the repository at this point in the history
Cross validator and data splitters refactored + unit tests for them added
  • Loading branch information
gyrdym committed Apr 23, 2019
2 parents f85b376 + bd13b7f commit f5b8546
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 217 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# Changelog

## 11.0.1
- Cross validator refactored
- Data splitters refactored
- Unit tests for cross validator added

## 11.0.0
- Added immutable state to all the predictor subclasses

Expand Down
44 changes: 44 additions & 0 deletions benchmark/cross_validator.dart
@@ -0,0 +1,44 @@
// 8.5 sec
import 'dart:async';

import 'package:benchmark_harness/benchmark_harness.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

const observationsNum = 1000;
const featuresNum = 20;

class CrossValidatorBenchmark extends BenchmarkBase {
CrossValidatorBenchmark() : super('Cross validator benchmark');

Matrix features;
Matrix labels;
CrossValidator crossValidator;

static void main() {
CrossValidatorBenchmark().report();
}

@override
void run() {
crossValidator.evaluate((trainFeatures, trainLabels) =>
ParameterlessRegressor.knn(trainFeatures, trainLabels, k: 7),
features, labels, MetricType.mape);
}

@override
void setup() {
features = Matrix.fromRows(List.generate(observationsNum,
(i) => Vector.randomFilled(featuresNum)));
labels = Matrix.fromColumns([Vector.randomFilled(observationsNum)]);

crossValidator = CrossValidator.kFold(numberOfFolds: 5);
}

void tearDown() {}
}

Future main() async {
CrossValidatorBenchmark.main();
}
9 changes: 6 additions & 3 deletions lib/src/model_selection/cross_validator/cross_validator.dart
@@ -1,5 +1,7 @@
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_linalg/matrix.dart';

Expand All @@ -9,14 +11,15 @@ abstract class CrossValidator {
///
/// It splits a dataset into [numberOfFolds] test sets and subsequently
/// evaluates the predictor on each produced test set
factory CrossValidator.kFold({Type dtype, int numberOfFolds}) =
CrossValidatorImpl.kFold;
factory CrossValidator.kFold({Type dtype, int numberOfFolds}) =>
CrossValidatorImpl(dtype, KFoldSplitter(numberOfFolds));

/// Creates LPO validator to evaluate quality of a predictor.
///
/// It splits a dataset into all possible test sets of size [p] and
/// subsequently evaluates quality of the predictor on each produced test set
factory CrossValidator.lpo({Type dtype, int p}) = CrossValidatorImpl.lpo;
factory CrossValidator.lpo({Type dtype, int p}) =>
CrossValidatorImpl(dtype, LeavePOutSplitter(p));

/// Returns a score of quality of passed predictor depending on given [metric]
double evaluate(Predictor predictorFactory(Matrix features, Matrix outcomes),
Expand Down
33 changes: 13 additions & 20 deletions lib/src/model_selection/cross_validator/cross_validator_impl.dart
@@ -1,21 +1,13 @@
import 'package:ml_algo/src/utils/default_parameter_values.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/utils/default_parameter_values.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

class CrossValidatorImpl implements CrossValidator {
factory CrossValidatorImpl.kFold({Type dtype, int numberOfFolds = 5}) =>
CrossValidatorImpl._(dtype, KFoldSplitter(numberOfFolds));

factory CrossValidatorImpl.lpo({Type dtype, int p}) =>
CrossValidatorImpl._(dtype, LeavePOutSplitter(p));

CrossValidatorImpl._(Type dtype, this._splitter)
CrossValidatorImpl(Type dtype, this._splitter)
: dtype = dtype ?? DefaultParameterValues.dtype;

final Type dtype;
Expand All @@ -30,24 +22,24 @@ class CrossValidatorImpl implements CrossValidator {
}

final allIndicesGroups = _splitter.split(observations.rowsNum);
// TODO get rid of length accessing
final scores = List<double>(allIndicesGroups.length);
int scoreCounter = 0;
var score = 0.0;
var folds = 0;

for (final testIndices in allIndicesGroups) {
final testIndicesAsSet = Set<int>.from(testIndices);
final trainFeatures =
List<Vector>(observations.rowsNum - testIndices.length);
List<Vector>(observations.rowsNum - testIndicesAsSet.length);
final trainLabels =
List<Vector>(observations.rowsNum - testIndices.length);
List<Vector>(observations.rowsNum - testIndicesAsSet.length);

final testFeatures = List<Vector>(testIndices.length);
final testLabels = List<Vector>(testIndices.length);
final testFeatures = List<Vector>(testIndicesAsSet.length);
final testLabels = List<Vector>(testIndicesAsSet.length);

int trainPointsCounter = 0;
int testPointsCounter = 0;

for (int index = 0; index < observations.rowsNum; index++) {
if (testIndices.contains(index)) {
if (testIndicesAsSet.contains(index)) {
testFeatures[testPointsCounter] = observations.getRow(index);
testLabels[testPointsCounter] = labels.getRow(index);
testPointsCounter++;
Expand All @@ -63,13 +55,14 @@ class CrossValidatorImpl implements CrossValidator {
Matrix.fromRows(trainLabels, dtype: dtype),
)..fit();

scores[scoreCounter++] = predictor.test(
score += predictor.test(
Matrix.fromRows(testFeatures, dtype: dtype),
Matrix.fromRows(testLabels, dtype: dtype),
metric
);
folds++;
}

return scores.fold<double>(0, (sum, value) => sum + value) / scores.length;
return score / folds;
}
}
46 changes: 15 additions & 31 deletions lib/src/model_selection/data_splitter/k_fold.dart
@@ -1,47 +1,31 @@
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:xrange/zrange.dart';

class KFoldSplitter implements Splitter {
final int _numberOfFolds;

KFoldSplitter(this._numberOfFolds) {
if (_numberOfFolds == 0 || _numberOfFolds == 1) {
throw RangeError(
'Number of folds must be greater than 1 and less than number of samples');
'Number of folds must be greater than 1 and less than number of '
'samples');
}
}

final int _numberOfFolds;

@override
Iterable<Iterable<int>> split(int numberOfSamples) sync* {
if (_numberOfFolds > numberOfSamples) {
throw RangeError.range(_numberOfFolds, 0, numberOfSamples, null,
Iterable<Iterable<int>> split(int numOfObservations) sync* {
if (_numberOfFolds > numOfObservations) {
throw RangeError.range(_numberOfFolds, 0, numOfObservations, null,
'Number of folds must be less than number of samples!');
}

final remainder = numberOfSamples % _numberOfFolds;
final size = (numberOfSamples / _numberOfFolds).truncate();
final sizes = List<int>.filled(_numberOfFolds, 1)
.map((int el) => el * size)
.toList(growable: false);

if (remainder > 0) {
final range =
sizes.take(remainder).map((int el) => ++el).toList(growable: false);
sizes.setRange(0, remainder, range);
}

int startIdx = 0;
int endIdx = 0;

for (int i = 0; i < sizes.length; i++) {
endIdx = startIdx + sizes[i];
yield _range(startIdx, endIdx);
final remainder = numOfObservations % _numberOfFolds;
final foldSize = numOfObservations ~/ _numberOfFolds;
for (int i = 0, startIdx = 0, endIdx = 0; i < _numberOfFolds; i++) {
// if we reached last fold of size [foldSize], all the next folds up
// to the last fold will have size that is equal to [foldSize] + 1
endIdx = startIdx + foldSize + (i >= _numberOfFolds - remainder ? 1 : 0);
yield ZRange.closedOpen(startIdx, endIdx).values();
startIdx = endIdx;
}
}

Iterable<int> _range(int start, int end) sync* {
for (int i = start; i < end; i++) {
yield i;
}
}
}
19 changes: 7 additions & 12 deletions lib/src/model_selection/data_splitter/leave_p_out.dart
@@ -1,21 +1,18 @@
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';

class LeavePOutSplitter implements Splitter {
int _p = 2;

LeavePOutSplitter(int p) {
if (p == 0) {
throw UnsupportedError('Value `$p` for parameter `p` is unsupported');
LeavePOutSplitter([this._p = 2]) {
if (_p == 0) {
throw UnsupportedError('Value `$_p` for parameter `p` is unsupported');
}
_p = p;
}

final int _p;

@override
Iterable<Iterable<int>> split(int numberOfSamples) sync* {
for (int u = 0; u < 1 << numberOfSamples; u++) {
if (_count(u) == _p) {
yield _generateCombination(u);
}
if (_count(u) == _p) yield _generateCombination(u);
}
}

Expand All @@ -28,9 +25,7 @@ class LeavePOutSplitter implements Splitter {

Iterable<int> _generateCombination(int u) sync* {
for (int n = 0; u > 0; ++n, u >>= 1) {
if ((u & 1) > 0) {
yield n;
}
if ((u & 1) > 0) yield n;
}
}
}
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 11.0.0
version: 11.0.1
author: Ilia Gyrdymov <ilgyrd@gmail.com>
homepage: https://github.com/gyrdym/ml_algo

Expand Down
80 changes: 80 additions & 0 deletions test/cross_validator/cross_validator_impl_test.dart
@@ -0,0 +1,80 @@
import 'dart:typed_data';

import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/splitter.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';

import '../test_utils/mocks.dart';

Splitter createSplitter(Iterable<Iterable<int>> indices) {
final splitter = SplitterMock();
when(splitter.split(any)).thenReturn(indices);
return splitter;
}

void main() {
group('CrossValidatorImpl', () {
test('should perform validation of a model on given test indices of'
'observations', () {
final allObservations = Matrix.from([
[330, 930, 130],
[630, 830, 230],
[730, 730, 330],
[830, 630, 430],
[930, 530, 530],
[130, 430, 630],
[230, 330, 730],
[430, 230, 830],
[530, 130, 930],
]);
final allOutcomes = Matrix.from([
[100],[200],[300],[400],[500],[600],[700],[800],[900],
]);
final metric = MetricType.mape;
final splitter = createSplitter([[0,2,4],[6, 8]]);
final predictor = PredictorMock();
final validator = CrossValidatorImpl(Float32x4, splitter);

var score = 20.0;
when(predictor.test(any, any, any))
.thenAnswer((Invocation inv) => score = score + 10);

final actual = validator.evaluate((observations, outcomes) => predictor,
allObservations, allOutcomes, metric);

expect(actual, 35);

verify(predictor.test(argThat(equals([
[330, 930, 130],
[730, 730, 330],
[930, 530, 530],
])), argThat(equals([[100], [300], [500]])), metric)).called(1);

verify(predictor.test(argThat(equals([
[230, 330, 730],
[530, 130, 930],
])), argThat(equals([[700], [900]])), metric)).called(1);
});

test('should throw an exception if observations number and outcomes number '
'mismatch', () {
final allObservations = Matrix.from([
[330, 930, 130],
[630, 830, 230],
]);
final allOutcomes = Matrix.from([
[100],
]);
final metric = MetricType.mape;
final splitter = SplitterMock();
final predictor = PredictorMock();
final validator = CrossValidatorImpl(Float32x4, splitter);

expect(() => validator.evaluate((observations, outcomes) => predictor,
allObservations, allOutcomes, metric), throwsException);
});
});
}

0 comments on commit f5b8546

Please sign in to comment.