Skip to content

Commit

Permalink
Constants for allowed metrics added
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Sep 11, 2020
1 parent 8b041f2 commit 0eded86
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# Changelog

## 15.2.1
- `classificationMetrics` constant list added
- `regressionMetrics` constant list added

## 15.2.0
- Recall metric added

Expand Down
16 changes: 16 additions & 0 deletions e2e/decision_tree_classifier_test.dart
Expand Up @@ -69,5 +69,21 @@ void main() async {

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float32', () async {
final scores = await evaluateClassifier(
MetricType.recall, DType.float32);

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float64', () async {
final scores = await evaluateClassifier(
MetricType.recall, DType.float64);

expect(scores.mean(), greaterThan(0.5));
});
});
}
16 changes: 16 additions & 0 deletions e2e/knn_classifier_test.dart
Expand Up @@ -69,5 +69,21 @@ void main() async {

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float32', () async {
final scores = await evaluateKnnClassifier(MetricType.recall,
DType.float32);

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float64', () async {
final scores = await evaluateKnnClassifier(MetricType.recall,
DType.float64);

expect(scores.mean(), greaterThan(0.5));
});
});
}
16 changes: 16 additions & 0 deletions e2e/logistic_regressor_test.dart
Expand Up @@ -65,5 +65,21 @@ Future main() async {

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on pima indians diabetes dataset using '
'recall metric, dtype=DType.float32', () async {
final scores = await evaluateLogisticRegressor(MetricType.recall,
DType.float32);

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on pima indians diabetes dataset using '
'recall metric, dtype=DType.float64', () async {
final scores = await evaluateLogisticRegressor(MetricType.recall,
DType.float32);

expect(scores.mean(), greaterThan(0.5));
});
});
}
16 changes: 16 additions & 0 deletions e2e/softmax_regressor_test.dart
Expand Up @@ -72,5 +72,21 @@ Future main() async {

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float32', () async {
final scores = await evaluateSoftmaxRegressor(MetricType.recall,
DType.float32);

expect(scores.mean(), greaterThan(0.5));
});

test('should return adequate score on iris dataset using recall '
'metric, dtype=DType.float64', () async {
final scores = await evaluateSoftmaxRegressor(MetricType.recall,
DType.float64);

expect(scores.mean(), greaterThan(0.5));
});
});
}
12 changes: 12 additions & 0 deletions lib/src/metric/metric.constants.dart
@@ -0,0 +1,12 @@
import 'package:ml_algo/src/metric/metric_type.dart';

const classificationMetrics = [
MetricType.recall,
MetricType.precision,
MetricType.accuracy,
];

const regressionMetrics = [
MetricType.mape,
MetricType.rmse,
];
10 changes: 3 additions & 7 deletions lib/src/model_selection/model_assessor/classifier_assessor.dart
Expand Up @@ -3,6 +3,7 @@ import 'package:ml_algo/src/common/exception/invalid_metric_type_exception.dart'
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/helpers/features_target_split_interface.dart';
import 'package:ml_algo/src/helpers/normalize_class_labels_interface.dart';
import 'package:ml_algo/src/metric/metric.constants.dart';
import 'package:ml_algo/src/metric/metric_factory.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
Expand All @@ -16,11 +17,6 @@ class ClassifierAssessor implements ModelAssessor<Classifier> {
this._normalizeClassLabels,
);

static const List<MetricType> _allowedMetricTypes = [
MetricType.precision,
MetricType.accuracy,
];

final MetricFactory _metricFactory;
final EncoderFactory _encoderFactory;
final FeaturesTargetSplit _featuresTargetSplit;
Expand All @@ -32,9 +28,9 @@ class ClassifierAssessor implements ModelAssessor<Classifier> {
MetricType metricType,
DataFrame samples,
) {
if (!_allowedMetricTypes.contains(metricType)) {
if (!classificationMetrics.contains(metricType)) {
throw InvalidMetricTypeException(
metricType, _allowedMetricTypes);
metricType, classificationMetrics);
}

final splits = _featuresTargetSplit(
Expand Down
10 changes: 3 additions & 7 deletions lib/src/model_selection/model_assessor/regressor_assessor.dart
@@ -1,6 +1,7 @@
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/common/exception/invalid_metric_type_exception.dart';
import 'package:ml_algo/src/helpers/features_target_split_interface.dart';
import 'package:ml_algo/src/metric/metric.constants.dart';
import 'package:ml_algo/src/metric/metric_factory.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
Expand All @@ -13,11 +14,6 @@ class RegressorAssessor implements ModelAssessor<Predictor> {
this._featuresTargetSplit,
);

static const List<MetricType> _allowedMetricTypes = [
MetricType.rmse,
MetricType.mape,
];

final MetricFactory _metricFactory;
final FeaturesTargetSplit _featuresTargetSplit;

Expand All @@ -27,9 +23,9 @@ class RegressorAssessor implements ModelAssessor<Predictor> {
MetricType metricType,
DataFrame samples,
) {
if (!_allowedMetricTypes.contains(metricType)) {
if (!regressionMetrics.contains(metricType)) {
throw InvalidMetricTypeException(
metricType, _allowedMetricTypes);
metricType, regressionMetrics);
}

final splits = _featuresTargetSplit(
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 15.2.0
version: 15.2.1
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down

0 comments on commit 0eded86

Please sign in to comment.