Skip to content

Commit

Permalink
Merge 5e06b60 into cef79bc
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Jun 21, 2020
2 parents cef79bc + 5e06b60 commit 3886eaf
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 91 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# Changelog

## 14.0.0
- Breaking change:
- `CrossValidator`: `evaluate` method's API changed; it now returns a Future resolving with a Vector of scores

## 13.10.0
- `LinearRegressor`:
- `Default constructor`: `collectLearningData` parameter added
Expand Down
@@ -0,0 +1,10 @@
/// Signals that test data has an unexpected number of columns.
class InvalidTestDataColumnsNumberException implements Exception {
  /// Creates an exception whose [message] reports the [expected] and the
  /// actually [received] number of columns.
  InvalidTestDataColumnsNumberException(int expected, int received)
      : message = 'Unexpected columns number in test data, '
            'expected $expected, received $received';

  /// Human-readable description of the column count mismatch.
  final String message;

  /// Returns [message], so the exception prints as its description.
  @override
  String toString() => message;
}
@@ -0,0 +1,10 @@
/// Signals that training data has an unexpected number of columns.
class InvalidTrainDataColumnsNumberException implements Exception {
  /// Creates an exception whose [message] reports the [expected] and the
  /// actually [received] number of columns.
  InvalidTrainDataColumnsNumberException(int expected, int received)
      : message = 'Unexpected columns number in training data, '
            'expected $expected, received $received';

  /// Human-readable description of the column count mismatch.
  final String message;

  /// Returns [message], so the exception prints as its description.
  @override
  String toString() => message;
}
36 changes: 20 additions & 16 deletions lib/src/model_selection/cross_validator/cross_validator.dart
Expand Up @@ -23,8 +23,8 @@ abstract class CrossValidator {
///
/// Parameters:
///
/// [samples] The whole training dataset to be split into parts to iteratively
/// evaluate given predictor on the each particular part
/// [samples] A dataset to be split into parts to iteratively evaluate given
/// predictor's performance
///
/// [targetColumnNames] Names of columns from [samples] that contain outcomes
///
Expand Down Expand Up @@ -57,8 +57,8 @@ abstract class CrossValidator {
///
/// Parameters:
///
/// [samples] The whole training dataset to be split into parts to iteratively
/// evaluate given model on the each particular part.
/// [samples] A dataset to be split into parts to iteratively
/// evaluate given predictor's performance
///
/// [targetColumnNames] Names of columns from [samples] that contain outcomes.
///
Expand All @@ -83,21 +83,21 @@ abstract class CrossValidator {
);
}

/// Returns a score of quality of passed predictor depending on given
/// [metricType]
/// Returns a future resolving with a vector of quality scores of the passed
/// predictor, measured according to the given [metricType]
///
/// Parameters:
///
/// [predictorFactory] A factory function that returns a testing predictor
/// [predictorFactory] A factory function that returns a predictor to be
/// evaluated
///
/// [metricType] Metric to assess a predictor, that is being created by
/// [metricType] Metric used to assess a predictor created by
/// [predictorFactory]
///
/// [onDataSplit] A callback that is called when a new train-test split is
/// ready to be passed into evaluating predictor. One may place some
/// additional data-dependent logic here, e.g., data preprocessing. The
/// callback accepts train and test data from a new split and returns
/// transformed split as list, where the first element is training data and
/// transformed split as list, where the first element is train data and
/// the second one - test data, both of [DataFrame] type. This new transformed
/// split will be passed into the predictor.
///
Expand All @@ -115,26 +115,30 @@ abstract class CrossValidator {
/// header: header,
/// headerExists: false,
/// );
///
/// final predictorFactory = (trainData, _) =>
/// KnnRegressor(trainData, 'col_3', k: 4);
///
/// final onDataSplit = (trainData, testData) {
/// final standardizer = Standardizer(trainData);
/// return [
/// standardizer.process(trainData),
/// standardizer.process(testData),
/// ];
/// };
///
/// final validator = CrossValidator.kFold(data, ['col_3']);
/// final score = validator.evaluate(
/// final scores = await validator.evaluate(
/// predictorFactory,
/// MetricType.mape,
/// onDataSplit: onDataSplit,
/// );
/// final averageScore = scores.mean();
///
/// print(averageScore);
/// ````
double evaluate(PredictorFactory predictorFactory, MetricType metricType, {
DataPreprocessFn onDataSplit,
});
Future<Vector> evaluate(
PredictorFactory predictorFactory,
MetricType metricType,
{
DataPreprocessFn onDataSplit,
}
);
}
80 changes: 38 additions & 42 deletions lib/src/model_selection/cross_validator/cross_validator_impl.dart
@@ -1,3 +1,5 @@
import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart';
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter.dart';
Expand All @@ -21,54 +23,48 @@ class CrossValidatorImpl implements CrossValidator {
final DataSplitter _splitter;

@override
double evaluate(PredictorFactory predictorFactory, MetricType metricType, {
DataPreprocessFn onDataSplit,
}) {
Future<Vector> evaluate(
PredictorFactory predictorFactory,
MetricType metricType,
{
DataPreprocessFn onDataSplit,
}
) {
final samplesAsMatrix = samples.toMatrix(dtype);
final sourceColumnsNum = samplesAsMatrix.columnsNum;

final discreteColumns = enumerate(samples.series)
.where((indexedSeries) => indexedSeries.value.isDiscrete)
.map((indexedSeries) => indexedSeries.index);

final allIndicesGroups = _splitter.split(samplesAsMatrix.rowsNum);
var score = 0.0;
var folds = 0;

for (final testRowsIndices in allIndicesGroups) {
final split = _makeSplit(testRowsIndices, discreteColumns);
final trainDataFrame = split[0];
final testDataFrame = split[1];

final splits = onDataSplit != null
? onDataSplit(trainDataFrame, testDataFrame)
: [trainDataFrame, testDataFrame];

final transformedTrainData = splits[0];
final transformedTestData = splits[1];

final transformedTrainDataColumnsNum = transformedTrainData.header.length;
final transformedTestDataColumnsNum = transformedTestData.header.length;

if (transformedTrainDataColumnsNum != sourceColumnsNum) {
throw Exception('Unexpected columns number in training data: '
'expected $sourceColumnsNum, received '
'${transformedTrainDataColumnsNum}');
}

if (transformedTestDataColumnsNum != sourceColumnsNum) {
throw Exception('Unexpected columns number in testing data: '
'expected $sourceColumnsNum, received '
'${transformedTestDataColumnsNum}');
}

score += predictorFactory(transformedTrainData, targetNames)
.assess(transformedTestData, targetNames, metricType);

folds++;
}

return score / folds;
final scores = allIndicesGroups
.map((testRowsIndices) {
final split = _makeSplit(testRowsIndices, discreteColumns);
final trainDataFrame = split[0];
final testDataFrame = split[1];
final splits = onDataSplit != null
? onDataSplit(trainDataFrame, testDataFrame)
: [trainDataFrame, testDataFrame];
final transformedTrainData = splits[0];
final transformedTestData = splits[1];
final transformedTrainDataColumnsNum = transformedTrainData.header.length;
final transformedTestDataColumnsNum = transformedTestData.header.length;

if (transformedTrainDataColumnsNum != sourceColumnsNum) {
throw InvalidTrainDataColumnsNumberException(sourceColumnsNum,
transformedTrainDataColumnsNum);
}

if (transformedTestDataColumnsNum != sourceColumnsNum) {
throw InvalidTestDataColumnsNumberException(sourceColumnsNum,
transformedTestDataColumnsNum);
}

return predictorFactory(transformedTrainData, targetNames)
.assess(transformedTestData, targetNames, metricType);
})
.toList();

return Future.value(Vector.fromList(scores, dtype: dtype));
}

List<DataFrame> _makeSplit(Iterable<int> testRowsIndices,
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 13.10.0
version: 14.0.0
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down

0 comments on commit 3886eaf

Please sign in to comment.