Skip to content

Commit

Permalink
CHANGELOG record added, version updated
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Apr 23, 2019
1 parent eb52edb commit bd13b7f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 11.0.1
- Cross validator refactored
- Data splitters refactored
- Unit tests for cross validator added

## 11.0.0
- Added immutable state to all the predictor subclasses

Expand Down
2 changes: 1 addition & 1 deletion benchmark/cross_validator.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// 10 sec
// 8.5 sec
import 'dart:async';

import 'package:benchmark_harness/benchmark_harness.dart';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ class CrossValidatorImpl implements CrossValidator {
var folds = 0;

for (final testIndices in allIndicesGroups) {
final testIndicesAsSet = Set<int>.from(testIndices);
final trainFeatures =
List<Vector>(observations.rowsNum - testIndices.length);
List<Vector>(observations.rowsNum - testIndicesAsSet.length);
final trainLabels =
List<Vector>(observations.rowsNum - testIndices.length);
List<Vector>(observations.rowsNum - testIndicesAsSet.length);

final testFeatures = List<Vector>(testIndices.length);
final testLabels = List<Vector>(testIndices.length);
final testFeatures = List<Vector>(testIndicesAsSet.length);
final testLabels = List<Vector>(testIndicesAsSet.length);

int trainPointsCounter = 0;
int testPointsCounter = 0;

for (int index = 0; index < observations.rowsNum; index++) {
if (testIndices.contains(index)) {
if (testIndicesAsSet.contains(index)) {
testFeatures[testPointsCounter] = observations.getRow(index);
testLabels[testPointsCounter] = labels.getRow(index);
testPointsCounter++;
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 11.0.0
version: 11.0.1
author: Ilia Gyrdymov <ilgyrd@gmail.com>
homepage: https://github.com/gyrdym/ml_algo

Expand Down
18 changes: 18 additions & 0 deletions test/cross_validator/cross_validator_impl_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,23 @@ void main() {
[530, 130, 930],
])), argThat(equals([[700], [900]])), metric)).called(1);
});

test('should throw an exception if observations number and outcomes number '
'mismatch', () {
final allObservations = Matrix.from([
[330, 930, 130],
[630, 830, 230],
]);
final allOutcomes = Matrix.from([
[100],
]);
final metric = MetricType.mape;
final splitter = SplitterMock();
final predictor = PredictorMock();
final validator = CrossValidatorImpl(Float32x4, splitter);

expect(() => validator.evaluate((observations, outcomes) => predictor,
allObservations, allOutcomes, metric), throwsException);
});
});
}

0 comments on commit bd13b7f

Please sign in to comment.