Skip to content

Commit

Permalink
Merge deffbe1 into 5501afd
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Jun 21, 2020
2 parents 5501afd + deffbe1 commit 74aaf5b
Show file tree
Hide file tree
Showing 22 changed files with 153 additions and 149 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## 14.0.1
- data splitters renamed and reorganized

## 14.0.0
- Breaking change:
- `CrossValidator`: `evalute` method's api changed, it returns a Future resolving with scores Vector now instead
Expand Down
8 changes: 4 additions & 4 deletions lib/src/di/dependencies.dart
Expand Up @@ -30,8 +30,8 @@ import 'package:ml_algo/src/link_function/softmax/float32_softmax_link_function.
import 'package:ml_algo/src/link_function/softmax/float64_softmax_link_function.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory_impl.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory_impl.dart';
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart';
Expand Down Expand Up @@ -87,8 +87,8 @@ Injector get dependencies =>
(_) => const Float64SoftmaxLinkFunction(),
dependencyName: float64SoftmaxLinkFunctionToken)

..registerSingleton<DataSplitterFactory>(
(_) => const DataSplitterFactoryImpl())
..registerSingleton<SplitIndicesProviderFactory>(
(_) => const SplitIndicesProviderFactoryImpl())

..registerSingleton<SoftmaxRegressorFactory>(
(_) => const SoftmaxRegressorFactoryImpl())
Expand Down
15 changes: 8 additions & 7 deletions lib/src/model_selection/cross_validator/cross_validator.dart
Expand Up @@ -2,8 +2,8 @@ import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/linalg.dart';
Expand Down Expand Up @@ -38,9 +38,9 @@ abstract class CrossValidator {
DType dtype = DType.float32,
}) {
final dataSplitterFactory = dependencies
.getDependency<DataSplitterFactory>();
.getDependency<SplitIndicesProviderFactory>();
final dataSplitter = dataSplitterFactory
.createByType(DataSplitterType.kFold, numberOfFolds: numberOfFolds);
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds);

return CrossValidatorImpl(
samples,
Expand All @@ -50,7 +50,7 @@ abstract class CrossValidator {
);
}

/// Creates LPO validator to evaluate quality of a predictor.
/// Creates LPO validator to evaluate performance of a predictor.
///
/// It splits a dataset into all possible test sets of size [p] and
/// subsequently evaluates quality of the predictor on each produced test set.
Expand All @@ -71,9 +71,10 @@ abstract class CrossValidator {
int p, {
DType dtype = DType.float32,
}) {
final dataSplitterFactory = dependencies.getDependency<DataSplitterFactory>();
final dataSplitterFactory = dependencies
.getDependency<SplitIndicesProviderFactory>();
final dataSplitter = dataSplitterFactory
.createByType(DataSplitterType.lpo, p: p);
.createByType(SplitIndicesProviderType.lpo, p: p);

return CrossValidatorImpl(
samples,
Expand Down
Expand Up @@ -2,7 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
Expand All @@ -20,7 +20,7 @@ class CrossValidatorImpl implements CrossValidator {
final DataFrame samples;
final DType dtype;
final Iterable<String> targetNames;
final DataSplitter _splitter;
final SplitIndicesProvider _splitter;

@override
Future<Vector> evaluate(
Expand All @@ -35,7 +35,7 @@ class CrossValidatorImpl implements CrossValidator {
final discreteColumns = enumerate(samples.series)
.where((indexedSeries) => indexedSeries.value.isDiscrete)
.map((indexedSeries) => indexedSeries.index);
final allIndicesGroups = _splitter.split(samplesAsMatrix.rowsNum);
final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum);
final scores = allIndicesGroups
.map((testRowsIndices) {
final split = _makeSplit(testRowsIndices, discreteColumns);
Expand Down
3 changes: 0 additions & 3 deletions lib/src/model_selection/data_splitter/data_splitter.dart

This file was deleted.

This file was deleted.

This file was deleted.

3 changes: 0 additions & 3 deletions lib/src/model_selection/data_splitter/data_splitter_type.dart

This file was deleted.

@@ -1,22 +1,22 @@
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:xrange/integers.dart';

class KFoldDataSplitter implements DataSplitter {
KFoldDataSplitter(this._numberOfFolds) {
class KFoldIndicesProvider implements SplitIndicesProvider {
KFoldIndicesProvider(this._numberOfFolds) {
if (_numberOfFolds == 0 || _numberOfFolds == 1) {
throw RangeError(
'Number of folds must be greater than 1 and less than number of '
'Number of folds must be greater than 1 and less than the number of '
'samples');
}
}

final int _numberOfFolds;

@override
Iterable<Iterable<int>> split(int numOfObservations) sync* {
Iterable<Iterable<int>> getIndices(int numOfObservations) sync* {
if (_numberOfFolds > numOfObservations) {
throw RangeError.range(_numberOfFolds, 0, numOfObservations, null,
'Number of folds must be less than number of samples!');
'Number of folds must be less than the number of samples');
}
final remainder = numOfObservations % _numberOfFolds;
final foldSize = numOfObservations ~/ _numberOfFolds;
Expand Down
@@ -1,7 +1,7 @@
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';

class LeavePOutDataSplitter implements DataSplitter {
LeavePOutDataSplitter([this._p = 2]) {
class LpoIndicesProvider implements SplitIndicesProvider {
LpoIndicesProvider([this._p = 2]) {
if (_p == 0) {
throw UnsupportedError('Value `$_p` for parameter `p` is unsupported');
}
Expand All @@ -10,7 +10,7 @@ class LeavePOutDataSplitter implements DataSplitter {
final int _p;

@override
Iterable<Iterable<int>> split(int numberOfSamples) sync* {
Iterable<Iterable<int>> getIndices(int numberOfSamples) sync* {
for (var u = 0; u < 1 << numberOfSamples; u++) {
if (_count(u) == _p) yield _generateCombination(u);
}
Expand Down
@@ -0,0 +1,3 @@
abstract class SplitIndicesProvider {
Iterable<Iterable<int>> getIndices(int numberOfSamples);
}
@@ -0,0 +1,9 @@
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';

abstract class SplitIndicesProviderFactory {
SplitIndicesProvider createByType(SplitIndicesProviderType splitterType, {
int numberOfFolds,
int p,
});
}
@@ -0,0 +1,34 @@
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/lpo_indices_provider.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';

class SplitIndicesProviderFactoryImpl implements SplitIndicesProviderFactory {
const SplitIndicesProviderFactoryImpl();

@override
SplitIndicesProvider createByType(SplitIndicesProviderType splitterType, {
int numberOfFolds,
int p,
}) {
switch (splitterType) {
case SplitIndicesProviderType.kFold:
if (numberOfFolds == null) {
throw Exception('Number of folds is not defined for K-fold splitter');
}
return KFoldIndicesProvider(numberOfFolds);

case SplitIndicesProviderType.lpo:
if (p == null) {
throw Exception('`p` parameter is not defined for leave-p-out '
'splitter');
}
return LpoIndicesProvider(p);

default:
throw UnimplementedError('Splitter of type $splitterType is not '
'implemented yet');
}
}
}
@@ -0,0 +1,3 @@
enum SplitIndicesProviderType {
lpo, kFold,
}
6 changes: 3 additions & 3 deletions pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 14.0.0
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 14.0.1
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand All @@ -11,7 +11,7 @@ dependencies:
json_annotation: ^3.0.1
json_serializable: ^3.3.0
ml_dataframe: ^0.1.1
ml_linalg: ^12.17.0
ml_linalg: ^12.17.1
quiver: ^2.0.2
xrange: ^0.0.8

Expand Down
10 changes: 5 additions & 5 deletions test/mocks.dart
Expand Up @@ -23,8 +23,8 @@ import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/math/randomizer/randomizer.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart';
import 'package:ml_algo/src/tree_trainer/decision_tree_trainer.dart';
Expand Down Expand Up @@ -78,9 +78,9 @@ class ConvergenceDetectorFactoryMock extends Mock

class ConvergenceDetectorMock extends Mock implements ConvergenceDetector {}

class DataSplitterMock extends Mock implements DataSplitter {}
class DataSplitterMock extends Mock implements SplitIndicesProvider {}

class DataSplitterFactoryMock extends Mock implements DataSplitterFactory {}
class DataSplitterFactoryMock extends Mock implements SplitIndicesProviderFactory {}

class AssessableMock extends Mock implements Assessable {}

Expand Down Expand Up @@ -216,7 +216,7 @@ LinearOptimizerFactory createLinearOptimizerFactoryMock(
return factory;
}

DataSplitterFactory createDataSplitterFactoryMock(DataSplitter dataSplitter) {
SplitIndicesProviderFactory createDataSplitterFactoryMock(SplitIndicesProvider dataSplitter) {
final factory = DataSplitterFactoryMock();
when(factory.createByType(any,
numberOfFolds: anyNamed('numberOfFolds'),
Expand Down
Expand Up @@ -2,17 +2,17 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';

import '../../mocks.dart';

DataSplitter createSplitter(Iterable<Iterable<int>> indices) {
SplitIndicesProvider createSplitter(Iterable<Iterable<int>> indices) {
final splitter = DataSplitterMock();
when(splitter.split(any)).thenReturn(indices);
when(splitter.getIndices(any)).thenReturn(indices);
return splitter;
}

Expand Down
18 changes: 9 additions & 9 deletions test/model_selection/cross_validator/cross_validator_test.dart
@@ -1,9 +1,9 @@
import 'package:injector/injector.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/di/injector.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
Expand All @@ -18,15 +18,15 @@ void main() {
header: ['1', '2', '3', '4'],
);

DataSplitter dataSplitter;
DataSplitterFactory dataSplitterFactory;
SplitIndicesProvider dataSplitter;
SplitIndicesProviderFactory dataSplitterFactory;

setUp(() {
dataSplitter = DataSplitterMock();
dataSplitterFactory = createDataSplitterFactoryMock(dataSplitter);

injector = Injector()
..registerDependency<DataSplitterFactory>((_) => dataSplitterFactory);
..registerDependency<SplitIndicesProviderFactory>((_) => dataSplitterFactory);
});

tearDown(() => injector = null);
Expand All @@ -36,7 +36,7 @@ void main() {
CrossValidator.kFold(data, ['4'], numberOfFolds: 10);

verify(dataSplitterFactory
.createByType(DataSplitterType.kFold, numberOfFolds: 10),
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: 10),
).called(1);
});

Expand All @@ -45,7 +45,7 @@ void main() {
CrossValidator.kFold(data, ['4']);

verify(dataSplitterFactory
.createByType(DataSplitterType.kFold, numberOfFolds: 5),
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: 5),
).called(1);
});

Expand All @@ -54,7 +54,7 @@ void main() {
CrossValidator.lpo(data, ['4'], 30);

verify(dataSplitterFactory
.createByType(DataSplitterType.lpo, p: 30),
.createByType(SplitIndicesProviderType.lpo, p: 30),
).called(1);
});
});
Expand Down

0 comments on commit 74aaf5b

Please sign in to comment.