Skip to content

Commit

Permalink
Rename and reorganize data splitters
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Jun 21, 2020
1 parent 5501afd commit b5992b5
Show file tree
Hide file tree
Showing 17 changed files with 36 additions and 36 deletions.
4 changes: 2 additions & 2 deletions lib/src/di/dependencies.dart
Expand Up @@ -30,8 +30,8 @@ import 'package:ml_algo/src/link_function/softmax/float32_softmax_link_function.
import 'package:ml_algo/src/link_function/softmax/float64_softmax_link_function.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory_impl.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory_impl.dart';
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart';
Expand Down
4 changes: 2 additions & 2 deletions lib/src/model_selection/cross_validator/cross_validator.dart
Expand Up @@ -2,8 +2,8 @@ import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/linalg.dart';
Expand Down
Expand Up @@ -2,7 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
Expand Down

This file was deleted.

@@ -0,0 +1,9 @@
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart';

abstract class DataSplitterFactory {
DataSplitter createByType(DataSplitterType splitterType, {
int numberOfFolds,
int p,
});
}
@@ -1,8 +1,8 @@
import 'package:ml_algo/src/model_selection/data_splitter/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out_data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart';

class DataSplitterFactoryImpl implements DataSplitterFactory {
const DataSplitterFactoryImpl();
Expand Down
@@ -1,11 +1,11 @@
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:xrange/integers.dart';

class KFoldDataSplitter implements DataSplitter {
KFoldDataSplitter(this._numberOfFolds) {
if (_numberOfFolds == 0 || _numberOfFolds == 1) {
throw RangeError(
'Number of folds must be greater than 1 and less than number of '
'Number of folds must be greater than 1 and less than the number of '
'samples');
}
}
Expand All @@ -16,7 +16,7 @@ class KFoldDataSplitter implements DataSplitter {
Iterable<Iterable<int>> split(int numOfObservations) sync* {
if (_numberOfFolds > numOfObservations) {
throw RangeError.range(_numberOfFolds, 0, numOfObservations, null,
'Number of folds must be less than number of samples!');
'Number of folds must be less than the number of samples');
}
final remainder = numOfObservations % _numberOfFolds;
final foldSize = numOfObservations ~/ _numberOfFolds;
Expand Down
@@ -1,4 +1,4 @@
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';

class LeavePOutDataSplitter implements DataSplitter {
LeavePOutDataSplitter([this._p = 2]) {
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Expand Up @@ -11,7 +11,7 @@ dependencies:
json_annotation: ^3.0.1
json_serializable: ^3.3.0
ml_dataframe: ^0.1.1
ml_linalg: ^12.17.0
ml_linalg: ^12.17.1
quiver: ^2.0.2
xrange: ^0.0.8

Expand Down
4 changes: 2 additions & 2 deletions test/mocks.dart
Expand Up @@ -23,8 +23,8 @@ import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/math/randomizer/randomizer.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart';
import 'package:ml_algo/src/tree_trainer/decision_tree_trainer.dart';
Expand Down
Expand Up @@ -2,7 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:mockito/mockito.dart';
Expand Down
@@ -1,9 +1,9 @@
import 'package:injector/injector.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/di/injector.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
Expand Down
@@ -1,7 +1,7 @@
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_factory_impl.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/data_splitter_type.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart';
import 'package:test/test.dart';

void main() {
Expand Down
@@ -1,4 +1,4 @@
import 'package:ml_algo/src/model_selection/data_splitter/k_fold_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/k_fold_data_splitter.dart';
import 'package:test/test.dart';

void main() {
Expand Down
@@ -1,4 +1,4 @@
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out_data_splitter.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/leave_p_out_data_splitter.dart';
import 'package:test/test.dart';

void main() {
Expand Down

0 comments on commit b5992b5

Please sign in to comment.