From 7798c380c05178e875ebd167dba26fa1095913fe Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Fri, 1 Nov 2019 01:08:54 +0200 Subject: [PATCH] LinearClassifierMixin extended --- .../_mixins/linear_classifier_mixin.dart | 17 +++++++++++------ .../logistic_regressor_impl.dart | 17 +---------------- .../softmax_regressor_impl.dart | 17 +---------------- .../predictor/assessable_predictor_mixin.dart | 4 ++-- 4 files changed, 15 insertions(+), 40 deletions(-) diff --git a/lib/src/classifier/_mixins/linear_classifier_mixin.dart b/lib/src/classifier/_mixins/linear_classifier_mixin.dart index 21404e7a..a5d7cad3 100644 --- a/lib/src/classifier/_mixins/linear_classifier_mixin.dart +++ b/lib/src/classifier/_mixins/linear_classifier_mixin.dart @@ -3,10 +3,20 @@ import 'package:ml_algo/src/helpers/add_intercept_if.dart'; import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart'; import 'package:ml_algo/src/helpers/validate_test_features.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; +import 'package:ml_linalg/matrix.dart'; mixin LinearClassifierMixin implements LinearClassifier { @override DataFrame predictProbabilities(DataFrame testFeatures) { + final probabilities = getProbabilitiesMatrix(testFeatures); + + return DataFrame.fromMatrix( + probabilities, + header: classNames, + ); + } + + Matrix getProbabilitiesMatrix(DataFrame testFeatures) { validateTestFeatures(testFeatures, dtype); final processedFeatures = addInterceptIf( @@ -18,12 +28,7 @@ mixin LinearClassifierMixin implements LinearClassifier { validateCoefficientsMatrix(coefficientsByClasses, processedFeatures.columnsNum); - final probabilities = linkFunction + return linkFunction .link(processedFeatures * coefficientsByClasses); - - return DataFrame.fromMatrix( - probabilities, - header: classNames, - ); } } diff --git a/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart b/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart index 4176f870..d5e02265 100644 --- a/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart +++ b/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart @@ -1,8 +1,6 @@ import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart'; import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; -import 'package:ml_algo/src/helpers/add_intercept_if.dart'; import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart'; -import 'package:ml_algo/src/helpers/validate_test_features.dart'; import 'package:ml_algo/src/link_function/link_function.dart'; import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -62,20 +60,7 @@ class LogisticRegressorImpl with LinearClassifierMixin, @override DataFrame predict(DataFrame testFeatures) { - validateTestFeatures(testFeatures, dtype); - - final processedFeatures = addInterceptIf( - fitIntercept, - testFeatures.toMatrix(dtype), - interceptScale, - ); - - validateCoefficientsMatrix(coefficientsByClasses, - processedFeatures.columnsNum); - - final probabilities = linkFunction - .link(processedFeatures * coefficientsByClasses) - .getColumn(0); + final probabilities = getProbabilitiesMatrix(testFeatures).getColumn(0); final classesList = probabilities // TODO: use SIMD diff --git a/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart b/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart index 0e572e10..746e3421 100644 --- a/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart +++ b/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart @@ -1,8 +1,5 @@ import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; -import 'package:ml_algo/src/helpers/add_intercept_if.dart'; -import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart'; -import 'package:ml_algo/src/helpers/validate_test_features.dart'; import 'package:ml_algo/src/link_function/link_function.dart'; import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -48,19 +45,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin, @override DataFrame predict(DataFrame testFeatures) { - validateTestFeatures(testFeatures, dtype); - - final processedFeatures = addInterceptIf( - fitIntercept, - testFeatures.toMatrix(dtype), - interceptScale, - ); - - validateCoefficientsMatrix(coefficientsByClasses, - processedFeatures.columnsNum); - - final allProbabilities = linkFunction - .link(processedFeatures * coefficientsByClasses); + final allProbabilities = getProbabilitiesMatrix(testFeatures); final classes = allProbabilities.mapRows((probabilities) { final positiveLabelIdx = probabilities diff --git a/lib/src/predictor/assessable_predictor_mixin.dart b/lib/src/predictor/assessable_predictor_mixin.dart index 8eeaf63d..3726aeff 100644 --- a/lib/src/predictor/assessable_predictor_mixin.dart +++ b/lib/src/predictor/assessable_predictor_mixin.dart @@ -15,8 +15,8 @@ mixin AssessablePredictorMixin implements Assessable, Predictor { final metric = MetricFactory.createByType(metricType); final prediction = predict(splits[0]); - final origLabels = splits[1].toMatrix(); + final origLabels = splits[1].toMatrix(dtype); - return metric.getScore(prediction.toMatrix(), origLabels); + return metric.getScore(prediction.toMatrix(dtype), origLabels); } }