diff --git a/lib/src/classifier/_mixins/linear_classifier_mixin.dart b/lib/src/classifier/_mixins/linear_classifier_mixin.dart index 4da421bf..21404e7a 100644 --- a/lib/src/classifier/_mixins/linear_classifier_mixin.dart +++ b/lib/src/classifier/_mixins/linear_classifier_mixin.dart @@ -1,22 +1,25 @@ import 'package:ml_algo/src/classifier/linear_classifier.dart'; import 'package:ml_algo/src/helpers/add_intercept_if.dart'; -import 'package:ml_algo/src/helpers/get_probabilities.dart'; +import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart'; +import 'package:ml_algo/src/helpers/validate_test_features.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; mixin LinearClassifierMixin implements LinearClassifier { @override - DataFrame predictProbabilities(DataFrame features) { + DataFrame predictProbabilities(DataFrame testFeatures) { + validateTestFeatures(testFeatures, dtype); + final processedFeatures = addInterceptIf( fitIntercept, - features.toMatrix(), + testFeatures.toMatrix(dtype), interceptScale, ); - final probabilities = getProbabilities( - processedFeatures, - coefficientsByClasses, - linkFunction, - ); + validateCoefficientsMatrix(coefficientsByClasses, + processedFeatures.columnsNum); + + final probabilities = linkFunction + .link(processedFeatures * coefficientsByClasses); return DataFrame.fromMatrix( probabilities, diff --git a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart index fd27d98c..4c6dc659 100644 --- a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart +++ b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart @@ -9,10 +9,11 @@ import 'package:ml_linalg/vector.dart'; class DecisionTreeClassifierImpl with AssessablePredictorMixin implements DecisionTreeClassifier { - DecisionTreeClassifierImpl(this._solver, String className, this._dtype) + DecisionTreeClassifierImpl(this._solver, String className, this.dtype) : classNames = [className]; - final DType _dtype; + @override + final DType dtype; final DecisionTreeSolver _solver; @@ -22,7 +23,7 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin @override DataFrame predict(DataFrame features) { final predictedLabels = features - .toMatrix(_dtype) + .toMatrix(dtype) .rows .map(_solver.getLabelForSample); @@ -33,10 +34,10 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin final outcomeList = predictedLabels .map((label) => label.value) .toList(growable: false); - final outcomeVector = Vector.fromList(outcomeList, dtype: _dtype); + final outcomeVector = Vector.fromList(outcomeList, dtype: dtype); return DataFrame.fromMatrix( - Matrix.fromColumns([outcomeVector], dtype: _dtype), + Matrix.fromColumns([outcomeVector], dtype: dtype), header: classNames, ); } @@ -46,14 +47,14 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin final probabilities = Matrix.fromColumns([ Vector.fromList( features - .toMatrix(_dtype) + .toMatrix(dtype) .rows .map(_solver.getLabelForSample) .map((label) => label.probability) .toList(growable: false), - dtype: _dtype, + dtype: dtype, ), - ], dtype: _dtype); + ], dtype: dtype); return DataFrame.fromMatrix( probabilities, diff --git a/lib/src/classifier/knn_classifier/knn_classifier_impl.dart b/lib/src/classifier/knn_classifier/knn_classifier_impl.dart index 6b9afcfc..16983baa 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier_impl.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier_impl.dart @@ -16,7 +16,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { this._classLabels, this._kernel, this._solver, - this._dtype, + this.dtype, ) : classNames = [targetName] { validateClassLabelList(_classLabels); } @@ -24,15 +24,17 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { @override final List classNames; + @override + final DType dtype; + final List _classLabels; final Kernel _kernel; final KnnSolver _solver; - final DType _dtype; - final String _columPrefix = 'Class label'; + final String _columnPrefix = 'Class label'; @override DataFrame predict(DataFrame features) { - validateTestFeatures(features, _dtype); + validateTestFeatures(features, dtype); final labelsToProbabilities = _getLabelToProbabilityMapping(features); final labels = labelsToProbabilities.keys.toList(); @@ -41,10 +43,10 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { .map((row) => labels[row.toList().indexOf(row.max())]) .toList(); - final outcomesAsVector = Vector.fromList(predictedOutcomes, dtype: _dtype); + final outcomesAsVector = Vector.fromList(predictedOutcomes, dtype: dtype); return DataFrame.fromMatrix( - Matrix.fromColumns([outcomesAsVector], dtype: _dtype), + Matrix.fromColumns([outcomesAsVector], dtype: dtype), header: classNames, ); } @@ -56,7 +58,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { final header = labelsToProbabilities .keys - .map((label) => '${_columPrefix} ${label.toString()}'); + .map((label) => '${_columnPrefix} ${label.toString()}'); return DataFrame.fromMatrix(probabilityMatrix, header: header); } @@ -84,7 +86,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { /// where each row is a classes probability distribution for the appropriate /// feature record from test feature matrix Map> _getLabelToProbabilityMapping(DataFrame features) { - final kNeighbourGroups = _solver.findKNeighbours(features.toMatrix(_dtype)); + final kNeighbourGroups = _solver.findKNeighbours(features.toMatrix(dtype)); final classLabelsAsSet = Set.from(_classLabels); return kNeighbourGroups.fold>>( @@ -133,11 +135,11 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier { Matrix _getProbabilityMatrix(Map> allLabelsToProbabilities) { final probabilityVectors = allLabelsToProbabilities .values - .map((probabilities) => Vector.fromList(probabilities, dtype: _dtype)) + .map((probabilities) => Vector.fromList(probabilities, dtype: dtype)) .toList(growable: false); return Matrix - .fromColumns(probabilityVectors, dtype: _dtype); + .fromColumns(probabilityVectors, dtype: dtype); } Map _updateLabelToWeightMapping( diff --git a/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart b/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart index 5ed3e577..4176f870 100644 --- a/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart +++ b/lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart @@ -22,7 +22,7 @@ class LogisticRegressorImpl with LinearClassifierMixin, this._probabilityThreshold, this._negativeLabel, this._positiveLabel, - this._dtype, + this.dtype, ) : classNames = [className] { validateCoefficientsMatrix(coefficientsByClasses); @@ -53,18 +53,20 @@ class LogisticRegressorImpl with LinearClassifierMixin, @override final LinkFunction linkFunction; - final DType _dtype; + @override + final DType dtype; + final num _probabilityThreshold; final num _positiveLabel; final num _negativeLabel; @override DataFrame predict(DataFrame testFeatures) { - validateTestFeatures(testFeatures, _dtype); + validateTestFeatures(testFeatures, dtype); final processedFeatures = addInterceptIf( fitIntercept, - testFeatures.toMatrix(_dtype), + testFeatures.toMatrix(dtype), interceptScale, ); diff --git a/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart b/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart index 0ffe8059..0e572e10 100644 --- a/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart +++ b/lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart @@ -1,7 +1,6 @@ import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; import 'package:ml_algo/src/helpers/add_intercept_if.dart'; -import 'package:ml_algo/src/helpers/get_probabilities.dart'; import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart'; import 'package:ml_algo/src/helpers/validate_test_features.dart'; import 'package:ml_algo/src/link_function/link_function.dart'; @@ -22,7 +21,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin, this.interceptScale, this._positiveLabel, this._negativeLabel, - this._dtype, + this.dtype, ); @override @@ -37,7 +36,8 @@ class SoftmaxRegressorImpl with LinearClassifierMixin, @override final Matrix coefficientsByClasses; - final DType _dtype; + @override + final DType dtype; @override final LinkFunction linkFunction; @@ -48,11 +48,11 @@ class SoftmaxRegressorImpl with LinearClassifierMixin, @override DataFrame predict(DataFrame testFeatures) { - validateTestFeatures(testFeatures, _dtype); + validateTestFeatures(testFeatures, dtype); final processedFeatures = addInterceptIf( fitIntercept, - testFeatures.toMatrix(_dtype), + testFeatures.toMatrix(dtype), interceptScale, ); @@ -74,7 +74,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin, predictedRow[positiveLabelIdx] = _positiveLabel; - return Vector.fromList(predictedRow, dtype: _dtype); + return Vector.fromList(predictedRow, dtype: dtype); }); return DataFrame.fromMatrix( diff --git a/lib/src/predictor/predictor.dart b/lib/src/predictor/predictor.dart index f6f73622..f1883b36 100644 --- a/lib/src/predictor/predictor.dart +++ b/lib/src/predictor/predictor.dart @@ -1,7 +1,10 @@ import 'package:ml_dataframe/ml_dataframe.dart'; +import 'package:ml_linalg/dtype.dart'; /// A common interface for all types of classifiers and regressors abstract class Predictor { /// Returns prediction, based on the model learned parameters DataFrame predict(DataFrame testFeatures); + + DType get dtype; } diff --git a/lib/src/regressor/knn_regressor/knn_regressor_impl.dart b/lib/src/regressor/knn_regressor/knn_regressor_impl.dart index 7fdbef03..9768bc8f 100644 --- a/lib/src/regressor/knn_regressor/knn_regressor_impl.dart +++ b/lib/src/regressor/knn_regressor/knn_regressor_impl.dart @@ -13,25 +13,27 @@ class KnnRegressorImpl with AssessablePredictorMixin implements KnnRegressor { this._targetName, this._solver, this._kernel, - this._dtype, + this.dtype, ); + @override + final DType dtype; + final String _targetName; final KnnSolver _solver; final Kernel _kernel; - final DType _dtype; - Vector get _zeroVector => _cachedZeroVector ??= Vector.zero(1, dtype: _dtype); + Vector get _zeroVector => _cachedZeroVector ??= Vector.zero(1, dtype: dtype); Vector _cachedZeroVector; @override DataFrame predict(DataFrame testFeatures) { - validateTestFeatures(testFeatures, _dtype); + validateTestFeatures(testFeatures, dtype); final prediction = Matrix.fromRows( - _predictOutcomes(testFeatures.toMatrix(_dtype)) + _predictOutcomes(testFeatures.toMatrix(dtype)) .toList(growable: false), - dtype: _dtype, + dtype: dtype, ); return DataFrame.fromMatrix( diff --git a/test/classifier/softmax_regressor/softmax_regressor_impl_test.dart b/test/classifier/softmax_regressor/softmax_regressor_impl_test.dart index f55adfb3..e945ef66 100644 --- a/test/classifier/softmax_regressor/softmax_regressor_impl_test.dart +++ b/test/classifier/softmax_regressor/softmax_regressor_impl_test.dart @@ -143,11 +143,73 @@ void main() { expect(actual.toMatrix(dtype), equals(expectedOutcomeMatrix)); }); - test('should return a dataframe with proper header', () { + test('should predict the first class if outcome is equiprobable', () { + reset(linkFunctionMock); + + when(linkFunctionMock.link(any)).thenReturn(Matrix.fromList([ + [0.33, 0.33, 0.33], + ])); + + final actual = regressor.predict(testFeatures); + final expectedOutcome = Matrix.fromList([ + [positiveLabel, negativeLabel, negativeLabel], + ]); + + expect(actual.toMatrix(dtype), equals(expectedOutcome)); + }); + + test('should return a dataframe with a proper header', () { final actual = regressor.predict(testFeatures); expect(actual.header, equals(targetNames)); }); }); + + group('`predictProbabilities` method', () { + test('should throw an exception if no features provided', () { + final testFeatures = DataFrame.fromMatrix(Matrix.empty()); + + expect(() => regressor.predictProbabilities(testFeatures), + throwsException); + }); + + test('should throw an exception if too few features provided', () { + final testFeatures = DataFrame.fromMatrix(Matrix.fromList([ + [1, 2], + ])); + + expect(() => regressor.predictProbabilities(testFeatures), + throwsException); + }); + + test('should throw an exception if too many features provided', () { + final testFeatures = DataFrame.fromMatrix(Matrix.fromList([ + [1, 2, 4, 4, 5, 6], + ])); + + expect(() => regressor.predictProbabilities(testFeatures), + throwsException); + }); + + test('should consider intercept term', () { + regressor.predictProbabilities(testFeatures); + + verify(linkFunctionMock.link( + testFeaturesMatrixWithIntercept * coefficientsByClasses) + ).called(1); + }); + + test('should return probabilities as dataframe', () { + final probabilities = regressor.predictProbabilities(testFeatures); + + expect(probabilities.rows, equals(mockedProbabilities)); + }); + + test('should return a dataframe with a proper header', () { + final probabilities = regressor.predictProbabilities(testFeatures); + + expect(probabilities.header, equals(targetNames)); + }); + }); }); }