Skip to content

Commit

Permalink
splitData helper added
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Jun 22, 2020
1 parent 3728e58 commit a721e19
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## 14.1.0
- `Model selection`: `splitData` helper added

## 14.0.1
- data splitters renamed and reorganized

Expand Down
1 change: 1 addition & 0 deletions lib/ml_algo.dart
Expand Up @@ -10,5 +10,6 @@ export 'package:ml_algo/src/metric/classification/type.dart';
export 'package:ml_algo/src/metric/metric_type.dart';
export 'package:ml_algo/src/metric/regression/type.dart';
export 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
export 'package:ml_algo/src/model_selection/split_data.dart';
export 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
export 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart';
41 changes: 29 additions & 12 deletions lib/src/model_selection/split_data.dart
Expand Up @@ -9,28 +9,45 @@ List<DataFrame> splitData(DataFrame data, Iterable<num> ratios) {
.rows
.toList();
var start = 0;
var ratioSum = 0.0;

return ratios
.map((ratio) {
if (ratio <= 0 || ratio >= 1) {
throw Exception('Ratio value must be within range 0..1 (both '
throw Exception('Ratio value must be within the range 0..1 (both '
'exclusive), $ratio given');
}

final end = start + (inputRows.length * ratio).ceil();
ratioSum += ratio;

if (ratioSum >= 1) {
throw Exception('Ratios sum is more than or equal to 1');
}

final rawSplitSize = inputRows.length * ratio;

if (rawSplitSize < 1) {
throw Exception('Ratio is too small comparing to the input data size: '
'ratio $ratio, min ratio value '
'${(1 / inputRows.length).toStringAsFixed(2)}');
}

final end = start + (rawSplitSize.ceil() == inputRows.length
? rawSplitSize.floor()
: rawSplitSize.ceil());
final rows = inputRows.sublist(start, end);

start = end;

return DataFrame(rows, headerExists: false, header: data.header);
return DataFrame(
rows,
headerExists: false,
header: data.header,
);
})
.toList()
.followedBy([
DataFrame(
inputRows.sublist(start),
headerExists: false,
header: data.header
),
])
.toList(growable: false);
.toList()..add(DataFrame(
inputRows.sublist(start),
headerExists: false,
header: data.header
));
}
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 14.0.1
version: 14.1.0
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down
131 changes: 131 additions & 0 deletions test/model_selection/split_data_test.dart
Expand Up @@ -57,5 +57,136 @@ void main() {
[500981, 29918, 5008.55],
]);
});

test('should split data, case 2', () {
final splits = splitData(data, [0.2, 0.2, 0.2, 0.2])
.toList();

expect(splits, hasLength(5));
expect(splits[0].header, header);
expect(splits[0].rows, [
[100.00, null, 200.33],
]);
expect(splits[1].header, header);
expect(splits[1].rows, [
[-2221, 1002, 70009],
]);
expect(splits[2].header, header);
expect(splits[2].rows, [
[ 9008, 10006, null],
]);
expect(splits[3].header, header);
expect(splits[3].rows, [
[ 7888, 10002, 300918],
]);
expect(splits[4].header, header);
expect(splits[4].rows, [
[500981, 29918, 5008.55],
]);
});

test('should throw exception if there is a too small ratio, case 1', () {
expect(() => splitData(data, [0.2, 0.3, 0.01]), throwsException);
});

test('should throw exception if there is a too small ratio, case 2', () {
expect(() => splitData(data, [0.2, 0.3, 0.1]), throwsException);
});

test('should split data into two parts, first part is less than the '
'second one', () {
final splits = splitData(data, [0.2])
.toList();

expect(splits, hasLength(2));
expect(splits[0].header, header);
expect(splits[0].rows, [
[100.00, null, 200.33],
]);
expect(splits[1].header, header);
expect(splits[1].rows, [
[ -2221, 1002, 70009],
[ 9008, 10006, null],
[ 7888, 10002, 300918],
[500981, 29918, 5008.55],
]);
});

test('should split data into two parts, first part is less than the '
'second one, case 2', () {
final splits = splitData(data, [0.25])
.toList();

expect(splits, hasLength(2));
expect(splits[0].header, header);
expect(splits[0].rows, [
[100.00, null, 200.33],
[-2221, 1002, 70009],
]);
expect(splits[1].header, header);
expect(splits[1].rows, [
[ 9008, 10006, null],
[ 7888, 10002, 300918],
[500981, 29918, 5008.55],
]);
});

test('should split data into two parts, first part is greater than the '
'second one', () {
final splits = splitData(data, [0.9])
.toList();

expect(splits, hasLength(2));
expect(splits[0].header, header);
expect(splits[0].rows, [
[100.00, null, 200.33],
[ -2221, 1002, 70009],
[ 9008, 10006, null],
[ 7888, 10002, 300918],
]);
expect(splits[1].header, header);
expect(splits[1].rows, [
[500981, 29918, 5008.55],
]);
});

test('should split data into two parts, first part is greater than the '
'second one, case 2', () {
final splits = splitData(data, [0.95])
.toList();

expect(splits, hasLength(2));
expect(splits[0].header, header);
expect(splits[0].rows, [
[100.00, null, 200.33],
[ -2221, 1002, 70009],
[ 9008, 10006, null],
[ 7888, 10002, 300918],
]);
expect(splits[1].header, header);
expect(splits[1].rows, [
[500981, 29918, 5008.55],
]);
});

test('should throw exception if ratios sum is equal to 1, '
'2 elements', () {
expect(() => splitData(data, [0.5, 0.5]), throwsException);
});

test('should throw exception if ratios sum is equal to 1, '
'3 elements', () {
expect(() => splitData(data, [0.3, 0.3, 0.4]), throwsException);
});

test('should throw exception if ratios sum is greater than 1, '
'2 elements', () {
expect(() => splitData(data, [0.5, 0.6]), throwsException);
});

test('should throw exception if ratios sum is greater than 1, '
'3 elements', () {
expect(() => splitData(data, [0.3, 0.5, 0.4]), throwsException);
});
});
}

0 comments on commit a721e19

Please sign in to comment.