From eb219ba9725ae9de9745943111b3930cca66f050 Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Tue, 12 Feb 2019 00:35:40 +0200 Subject: [PATCH 1/3] link function accepts vector --- .../labels_probability_calculator.dart | 5 - ...labels_probability_calculator_factory.dart | 7 - ...s_probability_calculator_factory_impl.dart | 14 -- .../labels_probability_calculator_impl.dart | 31 ---- lib/src/classifier/linear_classifier.dart | 6 + lib/src/classifier/logistic_regressor.dart | 16 +- lib/src/classifier/softmax_regressor.dart | 151 ++++++++++++++++++ lib/src/cost_function/log_likelihood.dart | 6 +- lib/src/link_function/link_function.dart | 7 +- .../link_function/link_function_factory.dart | 2 +- .../link_function_factory_impl.dart | 4 +- .../link_function/logit_link_function.dart | 19 ++- .../classifier/logistic_regressor_common.dart | 28 ++-- test/classifier/logistic_regressor_test.dart | 6 +- test/cost_function/cost_function_test.dart | 5 +- .../link_function_test.dart | 17 +- test/test_all.dart | 48 ++++-- test/test_utils/mocks.dart | 24 +-- 18 files changed, 256 insertions(+), 140 deletions(-) delete mode 100644 lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart delete mode 100644 lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart delete mode 100644 lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart delete mode 100644 lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart create mode 100644 lib/src/classifier/softmax_regressor.dart diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart deleted file mode 100644 index 99361307..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart +++ /dev/null @@ -1,5 +0,0 @@ -import 'package:ml_linalg/vector.dart'; - -abstract class LabelsProbabilityCalculator { - MLVector getProbabilities(MLVector scores); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart deleted file mode 100644 index c3f0512f..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart +++ /dev/null @@ -1,7 +0,0 @@ -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -abstract class LabelsProbabilityCalculatorFactory { - LabelsProbabilityCalculator create( - LinkFunctionType linkFunctionType, Type dtype); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart deleted file mode 100644 index 075b9aab..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart +++ /dev/null @@ -1,14 +0,0 @@ -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -class LabelsProbabilityCalculatorFactoryImpl - implements LabelsProbabilityCalculatorFactory { - const LabelsProbabilityCalculatorFactoryImpl(); - - @override - LabelsProbabilityCalculator create( - LinkFunctionType linkFunctionType, Type dtype) => - LabelsProbabilityCalculatorImpl(linkFunctionType, dtype); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart deleted file mode 100644 index 6dca02cc..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart +++ /dev/null @@ -1,31 +0,0 @@ -import 'dart:typed_data'; - -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; -import 'package:ml_linalg/vector.dart'; - -class LabelsProbabilityCalculatorImpl implements LabelsProbabilityCalculator { - final Type dtype; - final LinkFunction linkFunction; - - LabelsProbabilityCalculatorImpl( - LinkFunctionType linkFunctionType, - this.dtype, { - LinkFunctionFactory linkFnFactory = const LinkFunctionFactoryImpl(), - }) : linkFunction = linkFnFactory.fromType(linkFunctionType); - - @override - MLVector getProbabilities(MLVector scores) { - switch (dtype) { - case Float32x4: - return scores.fastMap( - (Float32x4 el, int startOffset, int endOffset) => - linkFunction.float32x4Link(el)); - default: - throw UnsupportedError('Unsupported data type - $dtype'); - } - } -} diff --git a/lib/src/classifier/linear_classifier.dart b/lib/src/classifier/linear_classifier.dart index 02edd532..59fe1488 100644 --- a/lib/src/classifier/linear_classifier.dart +++ b/lib/src/classifier/linear_classifier.dart @@ -2,6 +2,7 @@ import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/logistic_regressor.dart'; +import 'package:ml_algo/src/classifier/softmax_regressor.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; /// A factory for all the linear classifiers @@ -70,6 +71,11 @@ abstract class LinearClassifier implements Classifier { Type dtype, }) = LogisticRegressor; + /** + * Creates a softmax regression classifier + */ + factory LinearClassifier.softMaxRegressor() = SoftMaxRegressor; + factory LinearClassifier.SVM() => throw UnimplementedError(); factory LinearClassifier.naiveBayes() => throw UnimplementedError(); } diff --git a/lib/src/classifier/logistic_regressor.dart b/lib/src/classifier/logistic_regressor.dart index a2620f5a..3c2ab13e 100644 --- a/lib/src/classifier/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor.dart @@ -1,8 +1,5 @@ import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory_impl.dart'; @@ -12,6 +9,9 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; +import 'package:ml_algo/src/link_function/link_function.dart'; +import 'package:ml_algo/src/link_function/link_function_factory.dart'; +import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/metric/factory.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; @@ -30,7 +30,7 @@ class LogisticRegressor implements LinearClassifier { final Optimizer optimizer; final InterceptPreprocessor interceptPreprocessor; final LabelsProcessor labelsProcessor; - final LabelsProbabilityCalculator probabilityCalculator; + final LinkFunction linkFunction; LogisticRegressor({ // public arguments @@ -54,15 +54,13 @@ class LogisticRegressor implements LinearClassifier { const LabelsProcessorFactoryImpl(), InterceptPreprocessorFactory interceptPreprocessorFactory = const InterceptPreprocessorFactoryImpl(), - LabelsProbabilityCalculatorFactory probabilityCalculatorFactory = - const LabelsProbabilityCalculatorFactoryImpl(), + LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), }) : labelsProcessor = labelsProcessorFactory.create(dtype), interceptPreprocessor = interceptPreprocessorFactory.create(dtype, scale: fitIntercept ? interceptScale : 0.0), - probabilityCalculator = - probabilityCalculatorFactory.create(linkFunctionType, dtype), + linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype), optimizer = optimizerFactory.fromType( optimizer, dtype: dtype, @@ -136,7 +134,7 @@ class LogisticRegressor implements LinearClassifier { int i = 0; _weightsByClasses.forEach((double label, MLVector weights) { final scores = (processedFeatures * weights).toVector(); - distributions[i++] = probabilityCalculator.getProbabilities(scores); + distributions[i++] = linkFunction.linkScoresToProbs(scores); }); return MLMatrix.columns(distributions, dtype: dtype); } diff --git a/lib/src/classifier/softmax_regressor.dart b/lib/src/classifier/softmax_regressor.dart new file mode 100644 index 00000000..c6742568 --- /dev/null +++ b/lib/src/classifier/softmax_regressor.dart @@ -0,0 +1,151 @@ +import 'package:ml_algo/gradient_type.dart'; +import 'package:ml_algo/learning_rate_type.dart'; +import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; +import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; +import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory_impl.dart'; +import 'package:ml_algo/src/classifier/linear_classifier.dart'; +import 'package:ml_algo/src/cost_function/cost_function_type.dart'; +import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor.dart'; +import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; +import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; +import 'package:ml_algo/src/default_parameter_values.dart'; +import 'package:ml_algo/src/link_function/link_function.dart'; +import 'package:ml_algo/src/link_function/link_function_factory.dart'; +import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; +import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/metric/factory.dart'; +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart'; +import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator_impl.dart'; +import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; +import 'package:ml_algo/src/optimizer/optimizer.dart'; +import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; +import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart'; +import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'package:ml_linalg/vector.dart'; + +class SoftMaxRegressor implements LinearClassifier { + final Type dtype; + final Optimizer optimizer; + final InterceptPreprocessor interceptPreprocessor; + final LabelsProcessor labelsProcessor; + final LinkFunction linkFunction; + + SoftMaxRegressor({ + // public arguments + int iterationsLimit = DefaultParameterValues.iterationsLimit, + double initialLearningRate = DefaultParameterValues.initialLearningRate, + double minWeightsUpdate = DefaultParameterValues.minWeightsUpdate, + double lambda, + int randomSeed, + int batchSize = 1, + bool fitIntercept = false, + double interceptScale = 1.0, + OptimizerType optimizer = OptimizerType.gradientDescent, + GradientType gradientType = GradientType.stochastic, + LearningRateType learningRateType = LearningRateType.constant, + InitialWeightsType initialWeightsType = InitialWeightsType.zeroes, + LinkFunctionType linkFunctionType = LinkFunctionType.logit, + this.dtype = DefaultParameterValues.dtype, + + // private arguments + LabelsProcessorFactory labelsProcessorFactory = + const LabelsProcessorFactoryImpl(), + InterceptPreprocessorFactory interceptPreprocessorFactory = + const InterceptPreprocessorFactoryImpl(), + LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), + OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), + BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), + }) : labelsProcessor = labelsProcessorFactory.create(dtype), + interceptPreprocessor = interceptPreprocessorFactory.create(dtype, + scale: fitIntercept ? interceptScale : 0.0), + linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype), + optimizer = optimizerFactory.fromType( + optimizer, + dtype: dtype, + costFunctionType: CostFunctionType.logLikelihood, + linkFunctionType: linkFunctionType, + learningRateType: learningRateType, + initialWeightsType: initialWeightsType, + initialLearningRate: initialLearningRate, + minCoefficientsUpdate: minWeightsUpdate, + iterationLimit: iterationsLimit, + lambda: lambda, + batchSize: gradientType != null + ? batchSizeCalculator.calculate(gradientType, batchSize) + : null, + randomSeed: randomSeed, + ); + + @override + MLVector get weights => null; + + @override + Map get weightsByClasses => _weightsByClasses; + Map _weightsByClasses; + + @override + List get classLabels => _classLabels; + List _classLabels; + + @override + void fit(MLMatrix features, MLVector labels, + {MLVector initialWeights, bool isDataNormalized = false}) { + _classLabels = labels.unique().toList(); + final labelsAsList = _classLabels.toList(); + final processedFeatures = interceptPreprocessor.addIntercept(features); + _weightsByClasses = Map.fromIterable( + labelsAsList, + key: (dynamic label) => label as double, + value: (dynamic label) => _fitBinaryClassifier(processedFeatures, labels, + label as double, initialWeights, isDataNormalized), + ); + } + + @override + double test(MLMatrix features, MLVector origLabels, MetricType metricType) { + final evaluator = MetricFactory.createByType(metricType); + final prediction = predictClasses(features); + return evaluator.getError(prediction, origLabels); + } + + @override + MLMatrix predictProbabilities(MLMatrix features) { + final processedFeatures = interceptPreprocessor.addIntercept(features); + return _predictProbabilities(processedFeatures); + } + + @override + MLVector predictClasses(MLMatrix features) { + final processedFeatures = interceptPreprocessor.addIntercept(features); + final distributions = _predictProbabilities(processedFeatures); + final classes = List(processedFeatures.rowsNum); + for (int i = 0; i < distributions.rowsNum; i++) { + final probabilities = distributions.getRow(i); + classes[i] = probabilities.toList().indexOf(probabilities.max()) * 1.0; + } + return MLVector.from(classes, dtype: dtype); + } + + MLMatrix _predictProbabilities(MLMatrix processedFeatures) { + final numOfObservations = _weightsByClasses.length; + final distributions = List(numOfObservations); + int i = 0; + _weightsByClasses.forEach((double label, MLVector weights) { + final scores = (processedFeatures * weights).toVector(); + distributions[i++] = linkFunction.linkScoresToProbs(scores); + }); + return MLMatrix.columns(distributions, dtype: dtype); + } + + MLVector _fitBinaryClassifier(MLMatrix features, MLVector labels, + double targetLabel, MLVector initialWeights, bool arePointsNormalized) { + final binaryLabels = + labelsProcessor.makeLabelsOneVsAll(labels, targetLabel); + return optimizer.findExtrema(features, binaryLabels, + initialWeights: initialWeights, + arePointsNormalized: arePointsNormalized, + isMinimizingObjective: false); + } +} diff --git a/lib/src/cost_function/log_likelihood.dart b/lib/src/cost_function/log_likelihood.dart index 8262ab25..79830a08 100644 --- a/lib/src/cost_function/log_likelihood.dart +++ b/lib/src/cost_function/log_likelihood.dart @@ -16,7 +16,7 @@ class LogLikelihoodCost implements CostFunction { LinkFunctionType linkFunctionType, { this.dtype = DefaultParameterValues.dtype, LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), - }) : linkFunction = linkFunctionFactory.fromType(linkFunctionType); + }) : linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype); @override double getCost(double score, double yOrig) { @@ -25,10 +25,10 @@ class LogLikelihoodCost implements CostFunction { @override MLVector getGradient(MLMatrix x, MLVector w, MLVector y) { + final scores = (x * w).toVector(); switch (dtype) { case Float32x4: - return (x.transpose() * - (y - (x * w).fastMap(linkFunction.float32x4Link))) + return (x.transpose() * (y - linkFunction.linkScoresToProbs(scores))) .toVector(); default: throw throw UnsupportedError('Unsupported data type - $dtype'); diff --git a/lib/src/link_function/link_function.dart b/lib/src/link_function/link_function.dart index ab43ec84..1bd0ab7d 100644 --- a/lib/src/link_function/link_function.dart +++ b/lib/src/link_function/link_function.dart @@ -1,5 +1,8 @@ -import 'dart:typed_data'; +import 'package:ml_linalg/matrix.dart'; +import 'package:ml_linalg/vector.dart'; abstract class LinkFunction { - Float32x4 float32x4Link(Float32x4 scores); + /// Accepts a vector of scores, returns a vector of probabilities + /// Score is a multiplication of a feature value and the corresponding weight (coefficient) + MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]); } diff --git a/lib/src/link_function/link_function_factory.dart b/lib/src/link_function/link_function_factory.dart index 9d50ef75..522ecbf0 100644 --- a/lib/src/link_function/link_function_factory.dart +++ b/lib/src/link_function/link_function_factory.dart @@ -2,5 +2,5 @@ import 'package:ml_algo/src/link_function/link_function.dart'; import 'package:ml_algo/src/link_function/link_function_type.dart'; abstract class LinkFunctionFactory { - LinkFunction fromType(LinkFunctionType type); + LinkFunction fromType(LinkFunctionType type, Type dtype); } diff --git a/lib/src/link_function/link_function_factory_impl.dart b/lib/src/link_function/link_function_factory_impl.dart index a42ebec2..8dd2c696 100644 --- a/lib/src/link_function/link_function_factory_impl.dart +++ b/lib/src/link_function/link_function_factory_impl.dart @@ -7,10 +7,10 @@ class LinkFunctionFactoryImpl implements LinkFunctionFactory { const LinkFunctionFactoryImpl(); @override - LinkFunction fromType(LinkFunctionType type) { + LinkFunction fromType(LinkFunctionType type, Type dtype) { switch (type) { case LinkFunctionType.logit: - return LogitLinkFunction(); + return LogitLinkFunction(dtype); default: throw UnsupportedError('Unsupported link function type - $type'); } diff --git a/lib/src/link_function/logit_link_function.dart b/lib/src/link_function/logit_link_function.dart index 9d9d141b..db10c541 100644 --- a/lib/src/link_function/logit_link_function.dart +++ b/lib/src/link_function/logit_link_function.dart @@ -2,13 +2,30 @@ import 'dart:math' as math; import 'dart:typed_data'; import 'package:ml_algo/src/link_function/link_function.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'package:ml_linalg/vector.dart'; class LogitLinkFunction implements LinkFunction { + final Type dtype; + final float32x4Zeroes = Float32x4.zero(); final float32x4Ones = Float32x4.splat(1.0); + LogitLinkFunction(this.dtype); + @override - Float32x4 float32x4Link(Float32x4 scores) => + MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]) { + switch (dtype) { + case Float32x4: + return scores.fastMap( + (Float32x4 el, int startOffset, int endOffset) => + scoreToProbFloat32x4(el)); + default: + throw UnsupportedError('Unsupported data type - $dtype'); + } + } + + Float32x4 scoreToProbFloat32x4(Float32x4 scores) => float32x4Ones / (float32x4Ones + Float32x4( diff --git a/test/classifier/logistic_regressor_common.dart b/test/classifier/logistic_regressor_common.dart index a150f431..af3a673f 100644 --- a/test/classifier/logistic_regressor_common.dart +++ b/test/classifier/logistic_regressor_common.dart @@ -2,13 +2,13 @@ import 'dart:typed_data'; import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/classifier/logistic_regressor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; +import 'package:ml_algo/src/link_function/link_function.dart'; +import 'package:ml_algo/src/link_function/link_function_factory.dart'; import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; @@ -21,10 +21,10 @@ LabelsProcessor labelsProcessorMock; LabelsProcessorFactory labelsProcessorFactoryMock; InterceptPreprocessor interceptPreprocessorMock; InterceptPreprocessorFactory interceptPreprocessorFactoryMock; -LabelsProbabilityCalculator probabilityCalculatorMock; -LabelsProbabilityCalculatorFactory probabilityCalculatorFactoryMock; Optimizer optimizerMock; OptimizerFactory optimizerFactoryMock; +LinkFunctionFactory linkFunctionFactoryMock; +LinkFunction linkFunctionMock; void setUpLabelsProcessorFactory() { labelsProcessorMock = LabelsProcessorMock(); @@ -38,22 +38,20 @@ void setUpInterceptPreprocessorFactory() { preprocessor: interceptPreprocessorMock); } -void setUpProbabilityCalculatorFactory() { - probabilityCalculatorMock = LabelsProbabilityCalculatorMock(); - probabilityCalculatorFactoryMock = - createLabelsProbabilityCalculatorFactoryMock( - linkType: LinkFunctionType.logit, - dtype: Float32x4, - calculator: probabilityCalculatorMock, - ); -} - void setUpOptimizerFactory() { optimizerMock = OptimizerMock(); optimizerFactoryMock = createOptimizerFactoryMock( optimizers: {OptimizerType.gradientDescent: optimizerMock}); } +void setUpLinkFunctionFactory() { + linkFunctionMock = LinkFunctionMock(); + linkFunctionFactoryMock = + createLinkFunctionFactoryMock(Float32x4, linkFunctions: { + LinkFunctionType.logit: linkFunctionMock, + }); +} + LogisticRegressor createRegressor({ int iterationLimit = 100, double learningRate = 0.01, @@ -72,7 +70,7 @@ LogisticRegressor createRegressor({ labelsProcessorFactory: labelsProcessorFactoryMock, interceptPreprocessorFactory: interceptPreprocessorFactoryMock, linkFunctionType: LinkFunctionType.logit, - probabilityCalculatorFactory: probabilityCalculatorFactoryMock, + linkFunctionFactory: linkFunctionFactoryMock, optimizer: OptimizerType.gradientDescent, optimizerFactory: optimizerFactoryMock, gradientType: GradientType.stochastic, diff --git a/test/classifier/logistic_regressor_test.dart b/test/classifier/logistic_regressor_test.dart index 1fd1b692..6977c270 100644 --- a/test/classifier/logistic_regressor_test.dart +++ b/test/classifier/logistic_regressor_test.dart @@ -20,7 +20,7 @@ void main() { test('should initialize properly', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpProbabilityCalculatorFactory(); + setUpLinkFunctionFactory(); setUpOptimizerFactory(); createRegressor(); @@ -28,7 +28,7 @@ void main() { verify(labelsProcessorFactoryMock.create(Float32x4)).called(1); verify(interceptPreprocessorFactoryMock.create(Float32x4, scale: 0.0)) .called(1); - verify(probabilityCalculatorFactoryMock.create( + verify(linkFunctionFactoryMock.fromType( LinkFunctionType.logit, Float32x4)) .called(1); verify(optimizerFactoryMock.fromType( @@ -50,7 +50,7 @@ void main() { test('should make appropriate method calls when `fit` is called', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpProbabilityCalculatorFactory(); + setUpLinkFunctionFactory(); setUpOptimizerFactory(); final features = MLMatrix.from([ diff --git a/test/cost_function/cost_function_test.dart b/test/cost_function/cost_function_test.dart index 6c51d065..afc084b1 100644 --- a/test/cost_function/cost_function_test.dart +++ b/test/cost_function/cost_function_test.dart @@ -40,13 +40,14 @@ void main() { group('LogLikelihoodCost', () { final mockedLinkFn = LinkFunctionMock(); final linkFunctionFactoryMock = - createLinkFunctionFactoryMock(linkFunctions: { + createLinkFunctionFactoryMock(Float32x4, linkFunctions: { LinkFunctionType.logit: mockedLinkFn, }); final logLikelihoodCost = LogLikelihoodCost(LinkFunctionType.logit, linkFunctionFactory: linkFunctionFactoryMock); - when(mockedLinkFn.float32x4Link(any)).thenReturn(Float32x4.splat(1.0)); + when(mockedLinkFn.linkScoresToProbs(any)) + .thenReturn(MLVector.from([1.0, 1.0, 1.0])); test('should return a proper gradient vector', () { // The formula in matrix notation: diff --git a/test/score_to_prob_link_function/link_function_test.dart b/test/score_to_prob_link_function/link_function_test.dart index b9bb08c0..786f9251 100644 --- a/test/score_to_prob_link_function/link_function_test.dart +++ b/test/score_to_prob_link_function/link_function_test.dart @@ -1,19 +1,20 @@ import 'dart:typed_data'; import 'package:ml_algo/src/link_function/logit_link_function.dart'; +import 'package:ml_linalg/vector.dart'; import 'package:test/test.dart'; +import '../test_utils/helpers/floating_point_iterable_matchers.dart'; + void main() { - group('Vectorized logit link function', () { + group('Float32x4 logit link function', () { test('should properly translate score to probability', () { - final scores = Float32x4(1.0, 2.0, 3.0, 4.0); - final logitLink = LogitLinkFunction(); - final probabilities = logitLink.float32x4Link(scores); + final scores = MLVector.from([1.0, 2.0, 3.0, 4.0]); + final logitLink = LogitLinkFunction(Float32x4); + final probabilities = logitLink.linkScoresToProbs(scores); - expect(probabilities.x, inInclusiveRange(0.731, 0.732)); - expect(probabilities.y, inInclusiveRange(0.88, 0.881)); - expect(probabilities.z, inInclusiveRange(0.952, 0.953)); - expect(probabilities.w, inInclusiveRange(0.982, 0.983)); + expect(probabilities, + vectorAlmostEqualTo([0.73105, 0.88079, 0.9525, 0.98201], 1e-4)); }); }); } diff --git a/test/test_all.dart b/test/test_all.dart index d62f5c99..535c1260 100644 --- a/test/test_all.dart +++ b/test/test_all.dart @@ -1,23 +1,39 @@ -import 'classifier/logistic_regressor_integration_test.dart' as logistic_regressor_integration_test; +import 'classifier/logistic_regressor_integration_test.dart' + as logistic_regressor_integration_test; import 'classifier/logistic_regressor_test.dart' as logistic_regressor_test; import 'cost_function/cost_function_test.dart' as cost_function_test; -import 'data_preprocessing/categorical_encoder/category_values_extractor_impl_test.dart' as cat_value_extractor_test; -import 'data_preprocessing/categorical_encoder/one_hot_encoder_test.dart' as one_hot_encoder_test; -import 'data_preprocessing/categorical_encoder/ordinal_encoder_test.dart' as ordinal_encoder_test; -import 'data_preprocessing/ml_data/csv_ml_data_integration_test.dart' as csv_ml_data_integration_test; -import 'data_preprocessing/ml_data/csv_ml_data_with_categories_integration_test.dart' as csv_ml_data_with_cat_test; -import 'data_preprocessing/ml_data/ml_data_encoders_processor_impl_test.dart' as ml_data_enc_preprocessor_test; -import 'data_preprocessing/ml_data/ml_data_features_extractor_impl_test.dart' as ml_data_feature_extractor_test; -import 'data_preprocessing/ml_data/ml_data_labels_extractor_impl_test.dart' as ml_data_labels_extractor_test; -import 'data_preprocessing/ml_data/ml_data_params_validator_impl_test.dart' as ml_data_params_validator_test; -import 'data_preprocessing/ml_data/ml_data_read_mask_creator_impl_test.dart' as ml_data_read_mask_creator_test; -import 'data_preprocessing/intercept_preprocessor_test.dart' as intercept_preprocessor_test; +import 'data_preprocessing/categorical_encoder/category_values_extractor_impl_test.dart' + as cat_value_extractor_test; +import 'data_preprocessing/categorical_encoder/one_hot_encoder_test.dart' + as one_hot_encoder_test; +import 'data_preprocessing/categorical_encoder/ordinal_encoder_test.dart' + as ordinal_encoder_test; +import 'data_preprocessing/ml_data/csv_ml_data_integration_test.dart' + as csv_ml_data_integration_test; +import 'data_preprocessing/ml_data/csv_ml_data_with_categories_integration_test.dart' + as csv_ml_data_with_cat_test; +import 'data_preprocessing/ml_data/ml_data_encoders_processor_impl_test.dart' + as ml_data_enc_preprocessor_test; +import 'data_preprocessing/ml_data/ml_data_features_extractor_impl_test.dart' + as ml_data_feature_extractor_test; +import 'data_preprocessing/ml_data/ml_data_labels_extractor_impl_test.dart' + as ml_data_labels_extractor_test; +import 'data_preprocessing/ml_data/ml_data_params_validator_impl_test.dart' + as ml_data_params_validator_test; +import 'data_preprocessing/ml_data/ml_data_read_mask_creator_impl_test.dart' + as ml_data_read_mask_creator_test; +import 'data_preprocessing/intercept_preprocessor_test.dart' + as intercept_preprocessor_test; import 'data_splitter/data_splitter_test.dart' as data_splitter_test; import 'math/randomizer_test.dart' as randomizer_test; -import 'optimizer/coordinate/coordinate_optimizer_integration_test.dart' as coord_optimizer_integration_test; -import 'optimizer/gradient/gradient_optimizer_integration_test.dart' as gradient_optimizer_integration_test; -import 'optimizer/gradient/gradient_optimizer_test.dart' as gradient_optimizer_test; -import 'score_to_prob_link_function/link_function_test.dart' as link_function_test; +import 'optimizer/coordinate/coordinate_optimizer_integration_test.dart' + as coord_optimizer_integration_test; +import 'optimizer/gradient/gradient_optimizer_integration_test.dart' + as gradient_optimizer_integration_test; +import 'optimizer/gradient/gradient_optimizer_test.dart' + as gradient_optimizer_test; +import 'score_to_prob_link_function/link_function_test.dart' + as link_function_test; void main() { logistic_regressor_integration_test.main(); diff --git a/test/test_utils/mocks.dart b/test/test_utils/mocks.dart index 99c7c218..fcf83660 100644 --- a/test/test_utils/mocks.dart +++ b/test/test_utils/mocks.dart @@ -1,7 +1,5 @@ import 'package:logging/logging.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function.dart'; @@ -80,12 +78,6 @@ class InterceptPreprocessorFactoryMock extends Mock class InterceptPreprocessorMock extends Mock implements InterceptPreprocessor {} -class LabelsProbabilityCalculatorFactoryMock extends Mock - implements LabelsProbabilityCalculatorFactory {} - -class LabelsProbabilityCalculatorMock extends Mock - implements LabelsProbabilityCalculator {} - class OptimizerFactoryMock extends Mock implements OptimizerFactory {} class OptimizerMock extends Mock implements Optimizer {} @@ -137,12 +129,13 @@ InitialWeightsGeneratorFactoryMock createInitialWeightsGeneratorFactoryMock({ return factory; } -LinkFunctionFactoryMock createLinkFunctionFactoryMock({ +LinkFunctionFactoryMock createLinkFunctionFactoryMock( + Type dtype, { Map linkFunctions, }) { final factory = LinkFunctionFactoryMock(); linkFunctions.forEach((LinkFunctionType type, LinkFunction fn) { - when(factory.fromType(type)).thenReturn(fn); + when(factory.fromType(type, dtype)).thenReturn(fn); }); return factory; } @@ -165,17 +158,6 @@ LabelsProcessorFactoryMock createLabelsProcessorFactoryMock({ return factory; } -LabelsProbabilityCalculatorFactoryMock - createLabelsProbabilityCalculatorFactoryMock({ - LinkFunctionType linkType, - Type dtype, - LabelsProbabilityCalculator calculator, -}) { - final factory = LabelsProbabilityCalculatorFactoryMock(); - when(factory.create(linkType, dtype)).thenReturn(calculator); - return factory; -} - OptimizerFactoryMock createOptimizerFactoryMock({ Map optimizers, }) { From 1554c2fab86d1013f9bf4179649ec72226a543b9 Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Tue, 12 Feb 2019 01:06:41 +0200 Subject: [PATCH 2/3] link function -> score to prob mapper --- lib/src/classifier/logistic_regressor.dart | 22 +++++++++-------- lib/src/classifier/softmax_regressor.dart | 22 +++++++++-------- .../cost_function/cost_function_factory.dart | 7 +++--- .../cost_function_factory_impl.dart | 10 ++++---- lib/src/cost_function/log_likelihood.dart | 21 +++++++++------- .../link_function/link_function_factory.dart | 6 ----- .../link_function_factory_impl.dart | 18 -------------- lib/src/link_function/link_function_type.dart | 3 --- lib/src/optimizer/gradient/gradient.dart | 6 ++--- lib/src/optimizer/optimizer_factory.dart | 6 ++--- lib/src/optimizer/optimizer_factory_impl.dart | 10 ++++---- .../logit_mapper.dart} | 6 ++--- .../score_to_prob_mapper.dart} | 2 +- .../score_to_prob_mapper_factory.dart | 6 +++++ .../score_to_prob_mapper_factory_impl.dart | 18 ++++++++++++++ .../score_to_prob_mapper_type.dart | 3 +++ .../classifier/logistic_regressor_common.dart | 24 +++++++++---------- test/classifier/logistic_regressor_test.dart | 12 +++++----- test/cost_function/cost_function_test.dart | 14 +++++------ test/optimizer/gradient/gradient_common.dart | 4 ++-- .../score_to_prob_mapper_test.dart} | 8 +++---- test/test_all.dart | 10 ++++---- test/test_utils/mocks.dart | 21 ++++++++-------- 23 files changed, 134 insertions(+), 125 deletions(-) delete mode 100644 lib/src/link_function/link_function_factory.dart delete mode 100644 lib/src/link_function/link_function_factory_impl.dart delete mode 100644 lib/src/link_function/link_function_type.dart rename lib/src/{link_function/logit_link_function.dart => score_to_prob_mapper/logit_mapper.dart} (86%) rename lib/src/{link_function/link_function.dart => score_to_prob_mapper/score_to_prob_mapper.dart} (90%) create mode 100644 lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart create mode 100644 lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart create mode 100644 lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart rename test/{score_to_prob_link_function/link_function_test.dart => score_to_prob_mapper/score_to_prob_mapper_test.dart} (65%) diff --git a/lib/src/classifier/logistic_regressor.dart b/lib/src/classifier/logistic_regressor.dart index 3c2ab13e..756c8816 100644 --- a/lib/src/classifier/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor.dart @@ -9,10 +9,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/metric/factory.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart'; @@ -22,6 +18,10 @@ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; @@ -30,7 +30,7 @@ class LogisticRegressor implements LinearClassifier { final Optimizer optimizer; final InterceptPreprocessor interceptPreprocessor; final LabelsProcessor labelsProcessor; - final LinkFunction linkFunction; + final ScoreToProbMapper scoreToProbMapper; LogisticRegressor({ // public arguments @@ -46,7 +46,7 @@ class LogisticRegressor implements LinearClassifier { GradientType gradientType = GradientType.stochastic, LearningRateType learningRateType = LearningRateType.constant, InitialWeightsType initialWeightsType = InitialWeightsType.zeroes, - LinkFunctionType linkFunctionType = LinkFunctionType.logit, + ScoreToProbMapperType scoreToProbMapperType = ScoreToProbMapperType.logit, this.dtype = DefaultParameterValues.dtype, // private arguments @@ -54,18 +54,20 @@ class LogisticRegressor implements LinearClassifier { const LabelsProcessorFactoryImpl(), InterceptPreprocessorFactory interceptPreprocessorFactory = const InterceptPreprocessorFactoryImpl(), - LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), + ScoreToProbMapperFactory scoreToProbMapperFactory = + const ScoreToProbMapperFactoryImpl(), OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), }) : labelsProcessor = labelsProcessorFactory.create(dtype), interceptPreprocessor = interceptPreprocessorFactory.create(dtype, scale: fitIntercept ? interceptScale : 0.0), - linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype), + scoreToProbMapper = + scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype), optimizer = optimizerFactory.fromType( optimizer, dtype: dtype, costFunctionType: CostFunctionType.logLikelihood, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, initialLearningRate: initialLearningRate, @@ -134,7 +136,7 @@ class LogisticRegressor implements LinearClassifier { int i = 0; _weightsByClasses.forEach((double label, MLVector weights) { final scores = (processedFeatures * weights).toVector(); - distributions[i++] = linkFunction.linkScoresToProbs(scores); + distributions[i++] = scoreToProbMapper.linkScoresToProbs(scores); }); return MLMatrix.columns(distributions, dtype: dtype); } diff --git a/lib/src/classifier/softmax_regressor.dart b/lib/src/classifier/softmax_regressor.dart index c6742568..bf9058cc 100644 --- a/lib/src/classifier/softmax_regressor.dart +++ b/lib/src/classifier/softmax_regressor.dart @@ -9,10 +9,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/metric/factory.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart'; @@ -22,6 +18,10 @@ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; @@ -30,7 +30,7 @@ class SoftMaxRegressor implements LinearClassifier { final Optimizer optimizer; final InterceptPreprocessor interceptPreprocessor; final LabelsProcessor labelsProcessor; - final LinkFunction linkFunction; + final ScoreToProbMapper scoreToProbMapper; SoftMaxRegressor({ // public arguments @@ -46,7 +46,7 @@ class SoftMaxRegressor implements LinearClassifier { GradientType gradientType = GradientType.stochastic, LearningRateType learningRateType = LearningRateType.constant, InitialWeightsType initialWeightsType = InitialWeightsType.zeroes, - LinkFunctionType linkFunctionType = LinkFunctionType.logit, + ScoreToProbMapperType scoreToProbMapperType = ScoreToProbMapperType.logit, this.dtype = DefaultParameterValues.dtype, // private arguments @@ -54,18 +54,20 @@ class SoftMaxRegressor implements LinearClassifier { const LabelsProcessorFactoryImpl(), InterceptPreprocessorFactory interceptPreprocessorFactory = const InterceptPreprocessorFactoryImpl(), - LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), + ScoreToProbMapperFactory scoreToProbMapperFactory = + const ScoreToProbMapperFactoryImpl(), OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), }) : labelsProcessor = labelsProcessorFactory.create(dtype), interceptPreprocessor = interceptPreprocessorFactory.create(dtype, scale: fitIntercept ? interceptScale : 0.0), - linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype), + scoreToProbMapper = + scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype), optimizer = optimizerFactory.fromType( optimizer, dtype: dtype, costFunctionType: CostFunctionType.logLikelihood, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, initialLearningRate: initialLearningRate, @@ -134,7 +136,7 @@ class SoftMaxRegressor implements LinearClassifier { int i = 0; _weightsByClasses.forEach((double label, MLVector weights) { final scores = (processedFeatures * weights).toVector(); - distributions[i++] = linkFunction.linkScoresToProbs(scores); + distributions[i++] = scoreToProbMapper.linkScoresToProbs(scores); }); return MLMatrix.columns(distributions, dtype: dtype); } diff --git a/lib/src/cost_function/cost_function_factory.dart b/lib/src/cost_function/cost_function_factory.dart index c4f8e118..34a583ec 100644 --- a/lib/src/cost_function/cost_function_factory.dart +++ b/lib/src/cost_function/cost_function_factory.dart @@ -1,10 +1,11 @@ import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; abstract class CostFunctionFactory { CostFunction fromType(CostFunctionType type, - {Type dtype, LinkFunctionType linkFunctionType}); + {Type dtype, ScoreToProbMapperType scoreToProbMapperType}); CostFunction squared(); - CostFunction logLikelihood(LinkFunctionType linkFunctionType, {Type dtype}); + CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType, + {Type dtype}); } diff --git a/lib/src/cost_function/cost_function_factory_impl.dart b/lib/src/cost_function/cost_function_factory_impl.dart index dd55e9ed..7a66f2fd 100644 --- a/lib/src/cost_function/cost_function_factory_impl.dart +++ b/lib/src/cost_function/cost_function_factory_impl.dart @@ -4,7 +4,7 @@ import 'package:ml_algo/src/cost_function/cost_function_type.dart'; import 'package:ml_algo/src/cost_function/log_likelihood.dart'; import 'package:ml_algo/src/cost_function/squared.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; class CostFunctionFactoryImpl implements CostFunctionFactory { const CostFunctionFactoryImpl(); @@ -13,17 +13,17 @@ class CostFunctionFactoryImpl implements CostFunctionFactory { CostFunction squared() => SquaredCost(); @override - CostFunction logLikelihood(LinkFunctionType linkFunctionType, + CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType, {Type dtype = DefaultParameterValues.dtype}) => - LogLikelihoodCost(linkFunctionType, dtype: dtype); + LogLikelihoodCost(scoreToProbMapperType, dtype: dtype); @override CostFunction fromType(CostFunctionType type, {Type dtype = DefaultParameterValues.dtype, - LinkFunctionType linkFunctionType}) { + ScoreToProbMapperType scoreToProbMapperType}) { switch (type) { case CostFunctionType.logLikelihood: - return logLikelihood(linkFunctionType, dtype: dtype); + return logLikelihood(scoreToProbMapperType, dtype: dtype); case CostFunctionType.squared: return squared(); default: diff --git a/lib/src/cost_function/log_likelihood.dart b/lib/src/cost_function/log_likelihood.dart index 79830a08..9a84fbc5 100644 --- a/lib/src/cost_function/log_likelihood.dart +++ b/lib/src/cost_function/log_likelihood.dart @@ -2,21 +2,23 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/linalg.dart'; class LogLikelihoodCost implements CostFunction { - final LinkFunction linkFunction; + final ScoreToProbMapper scoreToProbMapper; final Type dtype; LogLikelihoodCost( - LinkFunctionType linkFunctionType, { + ScoreToProbMapperType scoreToProbMapperType, { this.dtype = DefaultParameterValues.dtype, - LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), - }) : linkFunction = linkFunctionFactory.fromType(linkFunctionType, dtype); + ScoreToProbMapperFactory scoreToProbMapperFactory = + const ScoreToProbMapperFactoryImpl(), + }) : scoreToProbMapper = + scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype); @override double getCost(double score, double yOrig) { @@ -28,7 +30,8 @@ class LogLikelihoodCost implements CostFunction { final scores = (x * w).toVector(); switch (dtype) { case Float32x4: - return (x.transpose() * (y - linkFunction.linkScoresToProbs(scores))) + return (x.transpose() * + (y - scoreToProbMapper.linkScoresToProbs(scores))) .toVector(); default: throw throw UnsupportedError('Unsupported data type - $dtype'); diff --git a/lib/src/link_function/link_function_factory.dart b/lib/src/link_function/link_function_factory.dart deleted file mode 100644 index 522ecbf0..00000000 --- a/lib/src/link_function/link_function_factory.dart +++ /dev/null @@ -1,6 +0,0 @@ -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -abstract class LinkFunctionFactory { - LinkFunction fromType(LinkFunctionType type, Type dtype); -} diff --git a/lib/src/link_function/link_function_factory_impl.dart b/lib/src/link_function/link_function_factory_impl.dart deleted file mode 100644 index 8dd2c696..00000000 --- a/lib/src/link_function/link_function_factory_impl.dart +++ /dev/null @@ -1,18 +0,0 @@ -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; -import 'package:ml_algo/src/link_function/logit_link_function.dart'; - -class LinkFunctionFactoryImpl implements LinkFunctionFactory { - const LinkFunctionFactoryImpl(); - - @override - LinkFunction fromType(LinkFunctionType type, Type dtype) { - switch (type) { - case LinkFunctionType.logit: - return LogitLinkFunction(dtype); - default: - throw UnsupportedError('Unsupported link function type - $type'); - } - } -} diff --git a/lib/src/link_function/link_function_type.dart b/lib/src/link_function/link_function_type.dart deleted file mode 100644 index dc6d56a5..00000000 --- a/lib/src/link_function/link_function_type.dart +++ /dev/null @@ -1,3 +0,0 @@ -enum LinkFunctionType { - logit, -} diff --git a/lib/src/optimizer/gradient/gradient.dart b/lib/src/optimizer/gradient/gradient.dart index 24f5cbd9..662a486a 100644 --- a/lib/src/optimizer/gradient/gradient.dart +++ b/lib/src/optimizer/gradient/gradient.dart @@ -3,7 +3,6 @@ import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart'; @@ -16,6 +15,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory_impl.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/range.dart'; import 'package:ml_linalg/vector.dart'; @@ -46,7 +46,7 @@ class GradientOptimizer implements Optimizer { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate = DefaultParameterValues.initialLearningRate, double minWeightsUpdate = DefaultParameterValues.minWeightsUpdate, int iterationLimit = DefaultParameterValues.iterationsLimit, @@ -62,7 +62,7 @@ class GradientOptimizer implements Optimizer { _learningRateGenerator = learningRateGeneratorFactory.fromType(learningRateType), _costFunction = costFunctionFactory.fromType(costFnType, - dtype: dtype, linkFunctionType: linkFunctionType), + dtype: dtype, scoreToProbMapperType: scoreToProbMapperType), _randomizer = randomizerFactory.create(randomSeed) { _learningRateGenerator.init(initialLearningRate ?? 1.0); } diff --git a/lib/src/optimizer/optimizer_factory.dart b/lib/src/optimizer/optimizer_factory.dart index b5c64a17..f66956c7 100644 --- a/lib/src/optimizer/optimizer_factory.dart +++ b/lib/src/optimizer/optimizer_factory.dart @@ -1,13 +1,13 @@ import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator_factory.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; abstract class OptimizerFactory { Optimizer fromType( @@ -20,7 +20,7 @@ abstract class OptimizerFactory { CostFunctionType costFunctionType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -38,7 +38,7 @@ abstract class OptimizerFactory { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, diff --git a/lib/src/optimizer/optimizer_factory_impl.dart b/lib/src/optimizer/optimizer_factory_impl.dart index 6d06284b..e1f0d7e3 100644 --- a/lib/src/optimizer/optimizer_factory_impl.dart +++ b/lib/src/optimizer/optimizer_factory_impl.dart @@ -3,7 +3,6 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart'; import 'package:ml_algo/src/optimizer/coordinate/coordinate.dart'; @@ -17,6 +16,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; class OptimizerFactoryImpl implements OptimizerFactory { const OptimizerFactoryImpl(); @@ -34,7 +34,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { CostFunctionType costFunctionType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -65,7 +65,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { costFnType: costFunctionType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, initialLearningRate: initialLearningRate, minCoefficientsUpdate: minCoefficientsUpdate, iterationLimit: iterationLimit, @@ -114,7 +114,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -130,7 +130,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { costFnType: costFnType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, initialLearningRate: initialLearningRate, minWeightsUpdate: minCoefficientsUpdate, iterationLimit: iterationLimit, diff --git a/lib/src/link_function/logit_link_function.dart b/lib/src/score_to_prob_mapper/logit_mapper.dart similarity index 86% rename from lib/src/link_function/logit_link_function.dart rename to lib/src/score_to_prob_mapper/logit_mapper.dart index db10c541..020edc88 100644 --- a/lib/src/link_function/logit_link_function.dart +++ b/lib/src/score_to_prob_mapper/logit_mapper.dart @@ -1,17 +1,17 @@ import 'dart:math' as math; import 'dart:typed_data'; -import 'package:ml_algo/src/link_function/link_function.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; -class LogitLinkFunction implements LinkFunction { +class LogitMapper implements ScoreToProbMapper { final Type dtype; final float32x4Zeroes = Float32x4.zero(); final float32x4Ones = Float32x4.splat(1.0); - LogitLinkFunction(this.dtype); + LogitMapper(this.dtype); @override MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]) { diff --git a/lib/src/link_function/link_function.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper.dart similarity index 90% rename from lib/src/link_function/link_function.dart rename to lib/src/score_to_prob_mapper/score_to_prob_mapper.dart index 1bd0ab7d..e3d47ca1 100644 --- a/lib/src/link_function/link_function.dart +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper.dart @@ -1,7 +1,7 @@ import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; -abstract class LinkFunction { +abstract class ScoreToProbMapper { /// Accepts a vector of scores, returns a vector of probabilities /// Score is a multiplication of a feature value and the corresponding weight (coefficient) MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]); diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart new file mode 100644 index 00000000..7202e7a0 --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart @@ -0,0 +1,6 @@ +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; + +abstract class ScoreToProbMapperFactory { + ScoreToProbMapper fromType(ScoreToProbMapperType type, Type dtype); +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart new file mode 100644 index 00000000..462570e0 --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart @@ -0,0 +1,18 @@ +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/logit_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; + +class ScoreToProbMapperFactoryImpl implements ScoreToProbMapperFactory { + const ScoreToProbMapperFactoryImpl(); + + @override + ScoreToProbMapper fromType(ScoreToProbMapperType type, Type dtype) { + switch (type) { + case ScoreToProbMapperType.logit: + return LogitMapper(dtype); + default: + throw UnsupportedError('Unsupported link function type - $type'); + } + } +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart new file mode 100644 index 00000000..cc3ab38f --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart @@ -0,0 +1,3 @@ +enum ScoreToProbMapperType { + logit, +} diff --git a/test/classifier/logistic_regressor_common.dart b/test/classifier/logistic_regressor_common.dart index af3a673f..cd63bc15 100644 --- a/test/classifier/logistic_regressor_common.dart +++ b/test/classifier/logistic_regressor_common.dart @@ -7,13 +7,13 @@ import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory import 'package:ml_algo/src/classifier/logistic_regressor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import '../test_utils/mocks.dart'; @@ -23,8 +23,8 @@ InterceptPreprocessor interceptPreprocessorMock; InterceptPreprocessorFactory interceptPreprocessorFactoryMock; Optimizer optimizerMock; OptimizerFactory optimizerFactoryMock; -LinkFunctionFactory linkFunctionFactoryMock; -LinkFunction linkFunctionMock; +ScoreToProbMapperFactory scoreToProbFactoryMock; +ScoreToProbMapper scoreToProbMapperMock; void setUpLabelsProcessorFactory() { labelsProcessorMock = LabelsProcessorMock(); @@ -44,11 +44,11 @@ void setUpOptimizerFactory() { optimizers: {OptimizerType.gradientDescent: optimizerMock}); } -void setUpLinkFunctionFactory() { - linkFunctionMock = LinkFunctionMock(); - linkFunctionFactoryMock = - createLinkFunctionFactoryMock(Float32x4, linkFunctions: { - LinkFunctionType.logit: linkFunctionMock, +void setUpScoreToProbMapperFactory() { + scoreToProbMapperMock = ScoreToProbMapperMock(); + scoreToProbFactoryMock = + createScoreToProbMapperFactoryMock(Float32x4, mappers: { + ScoreToProbMapperType.logit: scoreToProbMapperMock, }); } @@ -69,8 +69,8 @@ LogisticRegressor createRegressor({ lambda: lambda, labelsProcessorFactory: labelsProcessorFactoryMock, interceptPreprocessorFactory: interceptPreprocessorFactoryMock, - linkFunctionType: LinkFunctionType.logit, - linkFunctionFactory: linkFunctionFactoryMock, + scoreToProbMapperType: ScoreToProbMapperType.logit, + scoreToProbMapperFactory: scoreToProbFactoryMock, optimizer: OptimizerType.gradientDescent, optimizerFactory: optimizerFactoryMock, gradientType: GradientType.stochastic, diff --git a/test/classifier/logistic_regressor_test.dart b/test/classifier/logistic_regressor_test.dart index 6977c270..054524f3 100644 --- a/test/classifier/logistic_regressor_test.dart +++ b/test/classifier/logistic_regressor_test.dart @@ -2,9 +2,9 @@ import 'dart:typed_data'; import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; import 'package:mockito/mockito.dart'; @@ -20,7 +20,7 @@ void main() { test('should initialize properly', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpLinkFunctionFactory(); + setUpScoreToProbMapperFactory(); setUpOptimizerFactory(); createRegressor(); @@ -28,8 +28,8 @@ void main() { verify(labelsProcessorFactoryMock.create(Float32x4)).called(1); verify(interceptPreprocessorFactoryMock.create(Float32x4, scale: 0.0)) .called(1); - verify(linkFunctionFactoryMock.fromType( - LinkFunctionType.logit, Float32x4)) + verify(scoreToProbFactoryMock.fromType( + ScoreToProbMapperType.logit, Float32x4)) .called(1); verify(optimizerFactoryMock.fromType( OptimizerType.gradientDescent, @@ -37,7 +37,7 @@ void main() { costFunctionType: CostFunctionType.logLikelihood, learningRateType: LearningRateType.constant, initialWeightsType: InitialWeightsType.zeroes, - linkFunctionType: LinkFunctionType.logit, + scoreToProbMapperType: ScoreToProbMapperType.logit, initialLearningRate: 0.01, minCoefficientsUpdate: 0.001, iterationLimit: 100, @@ -50,7 +50,7 @@ void main() { test('should make appropriate method calls when `fit` is called', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpLinkFunctionFactory(); + setUpScoreToProbMapperFactory(); setUpOptimizerFactory(); final features = MLMatrix.from([ diff --git a/test/cost_function/cost_function_test.dart b/test/cost_function/cost_function_test.dart index afc084b1..29b1fef3 100644 --- a/test/cost_function/cost_function_test.dart +++ b/test/cost_function/cost_function_test.dart @@ -2,7 +2,7 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/log_likelihood.dart'; import 'package:ml_algo/src/cost_function/squared.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/linalg.dart'; import 'package:mockito/mockito.dart'; import 'package:test/test.dart'; @@ -38,13 +38,13 @@ void main() { }); group('LogLikelihoodCost', () { - final mockedLinkFn = LinkFunctionMock(); - final linkFunctionFactoryMock = - createLinkFunctionFactoryMock(Float32x4, linkFunctions: { - LinkFunctionType.logit: mockedLinkFn, + final mockedLinkFn = ScoreToProbMapperMock(); + final scoreToProbMapperFactoryMock = + createScoreToProbMapperFactoryMock(Float32x4, mappers: { + ScoreToProbMapperType.logit: mockedLinkFn, }); - final logLikelihoodCost = LogLikelihoodCost(LinkFunctionType.logit, - linkFunctionFactory: linkFunctionFactoryMock); + final logLikelihoodCost = LogLikelihoodCost(ScoreToProbMapperType.logit, + scoreToProbMapperFactory: scoreToProbMapperFactoryMock); when(mockedLinkFn.linkScoresToProbs(any)) .thenReturn(MLVector.from([1.0, 1.0, 1.0])); diff --git a/test/optimizer/gradient/gradient_common.dart b/test/optimizer/gradient/gradient_common.dart index 9d13d138..ccb83746 100644 --- a/test/optimizer/gradient/gradient_common.dart +++ b/test/optimizer/gradient/gradient_common.dart @@ -43,7 +43,7 @@ GradientOptimizer createOptimizer( costFunctionMock = CostFunctionMock(); costFunctionFactoryMock = CostFunctionFactoryMock(); when(costFunctionFactoryMock.fromType(CostFunctionType.squared, - dtype: Float32x4, linkFunctionType: null)) + dtype: Float32x4, scoreToProbMapperType: null)) .thenReturn(costFunctionMock); final randomizerFactoryMock = RandomizerFactoryMock(); @@ -69,7 +69,7 @@ GradientOptimizer createOptimizer( batchSize: batchSize); verify(costFunctionFactoryMock.fromType(CostFunctionType.squared, - dtype: Float32x4, linkFunctionType: null)); + dtype: Float32x4, scoreToProbMapperType: null)); return opt; } diff --git a/test/score_to_prob_link_function/link_function_test.dart b/test/score_to_prob_mapper/score_to_prob_mapper_test.dart similarity index 65% rename from test/score_to_prob_link_function/link_function_test.dart rename to test/score_to_prob_mapper/score_to_prob_mapper_test.dart index 786f9251..74e53b4f 100644 --- a/test/score_to_prob_link_function/link_function_test.dart +++ b/test/score_to_prob_mapper/score_to_prob_mapper_test.dart @@ -1,16 +1,16 @@ import 'dart:typed_data'; -import 'package:ml_algo/src/link_function/logit_link_function.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/logit_mapper.dart'; import 'package:ml_linalg/vector.dart'; import 'package:test/test.dart'; import '../test_utils/helpers/floating_point_iterable_matchers.dart'; void main() { - group('Float32x4 logit link function', () { - test('should properly translate score to probability', () { + group('LogitMapper', () { + test('should properly translate scores to probabilities for Float32x4', () { final scores = MLVector.from([1.0, 2.0, 3.0, 4.0]); - final logitLink = LogitLinkFunction(Float32x4); + final logitLink = LogitMapper(Float32x4); final probabilities = logitLink.linkScoresToProbs(scores); expect(probabilities, diff --git a/test/test_all.dart b/test/test_all.dart index 535c1260..d58d37c8 100644 --- a/test/test_all.dart +++ b/test/test_all.dart @@ -8,6 +8,8 @@ import 'data_preprocessing/categorical_encoder/one_hot_encoder_test.dart' as one_hot_encoder_test; import 'data_preprocessing/categorical_encoder/ordinal_encoder_test.dart' as ordinal_encoder_test; +import 'data_preprocessing/intercept_preprocessor_test.dart' + as intercept_preprocessor_test; import 'data_preprocessing/ml_data/csv_ml_data_integration_test.dart' as csv_ml_data_integration_test; import 'data_preprocessing/ml_data/csv_ml_data_with_categories_integration_test.dart' @@ -22,8 +24,6 @@ import 'data_preprocessing/ml_data/ml_data_params_validator_impl_test.dart' as ml_data_params_validator_test; import 'data_preprocessing/ml_data/ml_data_read_mask_creator_impl_test.dart' as ml_data_read_mask_creator_test; -import 'data_preprocessing/intercept_preprocessor_test.dart' - as intercept_preprocessor_test; import 'data_splitter/data_splitter_test.dart' as data_splitter_test; import 'math/randomizer_test.dart' as randomizer_test; import 'optimizer/coordinate/coordinate_optimizer_integration_test.dart' @@ -32,8 +32,8 @@ import 'optimizer/gradient/gradient_optimizer_integration_test.dart' as gradient_optimizer_integration_test; import 'optimizer/gradient/gradient_optimizer_test.dart' as gradient_optimizer_test; -import 'score_to_prob_link_function/link_function_test.dart' - as link_function_test; +import 'score_to_prob_mapper/score_to_prob_mapper_test.dart' + as score_to_prob_mapper_test; void main() { logistic_regressor_integration_test.main(); @@ -55,5 +55,5 @@ void main() { coord_optimizer_integration_test.main(); gradient_optimizer_integration_test.main(); gradient_optimizer_test.main(); - link_function_test.main(); + score_to_prob_mapper_test.main(); } diff --git a/test/test_utils/mocks.dart b/test/test_utils/mocks.dart index fcf83660..0c0c0b97 100644 --- a/test/test_utils/mocks.dart +++ b/test/test_utils/mocks.dart @@ -12,9 +12,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/ml_data/validator/ml_data_params_validator.dart'; import 'package:ml_algo/src/data_preprocessing/ml_data/value_converter/value_converter.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator.dart'; @@ -25,6 +22,9 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:mockito/mockito.dart'; class EncoderMock extends Mock implements CategoricalDataEncoder {} @@ -64,9 +64,10 @@ class InitialWeightsGeneratorFactoryMock extends Mock class InitialWeightsGeneratorMock extends Mock implements InitialWeightsGenerator {} -class LinkFunctionMock extends Mock implements LinkFunction {} +class ScoreToProbMapperMock extends Mock implements ScoreToProbMapper {} -class LinkFunctionFactoryMock extends Mock implements LinkFunctionFactory {} +class ScoreToProbMapperFactoryMock extends Mock + implements ScoreToProbMapperFactory {} class LabelsProcessorFactoryMock extends Mock implements LabelsProcessorFactory {} @@ -129,12 +130,12 @@ InitialWeightsGeneratorFactoryMock createInitialWeightsGeneratorFactoryMock({ return factory; } -LinkFunctionFactoryMock createLinkFunctionFactoryMock( +ScoreToProbMapperFactoryMock createScoreToProbMapperFactoryMock( Type dtype, { - Map linkFunctions, + Map mappers, }) { - final factory = LinkFunctionFactoryMock(); - linkFunctions.forEach((LinkFunctionType type, LinkFunction fn) { + final factory = ScoreToProbMapperFactoryMock(); + mappers.forEach((ScoreToProbMapperType type, ScoreToProbMapper fn) { when(factory.fromType(type, dtype)).thenReturn(fn); }); return factory; @@ -175,7 +176,7 @@ OptimizerFactoryMock createOptimizerFactoryMock({ costFunctionType: anyNamed('costFunctionType'), learningRateType: anyNamed('learningRateType'), initialWeightsType: anyNamed('initialWeightsType'), - linkFunctionType: anyNamed('linkFunctionType'), + scoreToProbMapperType: anyNamed('scoreToProbMapperType'), initialLearningRate: anyNamed('initialLearningRate'), minCoefficientsUpdate: anyNamed('minCoefficientsUpdate'), iterationLimit: anyNamed('iterationLimit'), From 3a44ee7ddeb6ba473a00b16fafe7817831c006da Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Tue, 12 Feb 2019 01:12:17 +0200 Subject: [PATCH 3/3] changelog and pubspec --- CHANGELOG.md | 4 + lib/src/classifier/linear_classifier.dart | 7 +- lib/src/classifier/softmax_regressor.dart | 153 ---------------------- pubspec.yaml | 2 +- 4 files changed, 6 insertions(+), 160 deletions(-) delete mode 100644 lib/src/classifier/softmax_regressor.dart diff --git a/CHANGELOG.md b/CHANGELOG.md index fece0103..d2b769bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 6.1.0 +- `LinkFunction` renamed to `ScoreToProbMapper` +- `ScoreToProbMapper` accepts vector and returns vector instead of a scalar + ## 6.0.6 - Pedantic package integration added - Some linter issues fixed diff --git a/lib/src/classifier/linear_classifier.dart b/lib/src/classifier/linear_classifier.dart index 59fe1488..f0813f58 100644 --- a/lib/src/classifier/linear_classifier.dart +++ b/lib/src/classifier/linear_classifier.dart @@ -2,7 +2,6 @@ import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/logistic_regressor.dart'; -import 'package:ml_algo/src/classifier/softmax_regressor.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; /// A factory for all the linear classifiers @@ -71,11 +70,7 @@ abstract class LinearClassifier implements Classifier { Type dtype, }) = LogisticRegressor; - /** - * Creates a softmax regression classifier - */ - factory LinearClassifier.softMaxRegressor() = SoftMaxRegressor; - + factory LinearClassifier.softMaxRegressor() => throw UnimplementedError(); factory LinearClassifier.SVM() => throw UnimplementedError(); factory LinearClassifier.naiveBayes() => throw UnimplementedError(); } diff --git a/lib/src/classifier/softmax_regressor.dart b/lib/src/classifier/softmax_regressor.dart deleted file mode 100644 index bf9058cc..00000000 --- a/lib/src/classifier/softmax_regressor.dart +++ /dev/null @@ -1,153 +0,0 @@ -import 'package:ml_algo/gradient_type.dart'; -import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; -import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; -import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory_impl.dart'; -import 'package:ml_algo/src/classifier/linear_classifier.dart'; -import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor.dart'; -import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; -import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; -import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/metric/factory.dart'; -import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart'; -import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator_impl.dart'; -import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; -import 'package:ml_algo/src/optimizer/optimizer.dart'; -import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; -import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart'; -import 'package:ml_algo/src/optimizer/optimizer_type.dart'; -import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; -import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; -import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; -import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; -import 'package:ml_linalg/matrix.dart'; -import 'package:ml_linalg/vector.dart'; - -class SoftMaxRegressor implements LinearClassifier { - final Type dtype; - final Optimizer optimizer; - final InterceptPreprocessor interceptPreprocessor; - final LabelsProcessor labelsProcessor; - final ScoreToProbMapper scoreToProbMapper; - - SoftMaxRegressor({ - // public arguments - int iterationsLimit = DefaultParameterValues.iterationsLimit, - double initialLearningRate = DefaultParameterValues.initialLearningRate, - double minWeightsUpdate = DefaultParameterValues.minWeightsUpdate, - double lambda, - int randomSeed, - int batchSize = 1, - bool fitIntercept = false, - double interceptScale = 1.0, - OptimizerType optimizer = OptimizerType.gradientDescent, - GradientType gradientType = GradientType.stochastic, - LearningRateType learningRateType = LearningRateType.constant, - InitialWeightsType initialWeightsType = InitialWeightsType.zeroes, - ScoreToProbMapperType scoreToProbMapperType = ScoreToProbMapperType.logit, - this.dtype = DefaultParameterValues.dtype, - - // private arguments - LabelsProcessorFactory labelsProcessorFactory = - const LabelsProcessorFactoryImpl(), - InterceptPreprocessorFactory interceptPreprocessorFactory = - const InterceptPreprocessorFactoryImpl(), - ScoreToProbMapperFactory scoreToProbMapperFactory = - const ScoreToProbMapperFactoryImpl(), - OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), - BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), - }) : labelsProcessor = labelsProcessorFactory.create(dtype), - interceptPreprocessor = interceptPreprocessorFactory.create(dtype, - scale: fitIntercept ? interceptScale : 0.0), - scoreToProbMapper = - scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype), - optimizer = optimizerFactory.fromType( - optimizer, - dtype: dtype, - costFunctionType: CostFunctionType.logLikelihood, - scoreToProbMapperType: scoreToProbMapperType, - learningRateType: learningRateType, - initialWeightsType: initialWeightsType, - initialLearningRate: initialLearningRate, - minCoefficientsUpdate: minWeightsUpdate, - iterationLimit: iterationsLimit, - lambda: lambda, - batchSize: gradientType != null - ? batchSizeCalculator.calculate(gradientType, batchSize) - : null, - randomSeed: randomSeed, - ); - - @override - MLVector get weights => null; - - @override - Map get weightsByClasses => _weightsByClasses; - Map _weightsByClasses; - - @override - List get classLabels => _classLabels; - List _classLabels; - - @override - void fit(MLMatrix features, MLVector labels, - {MLVector initialWeights, bool isDataNormalized = false}) { - _classLabels = labels.unique().toList(); - final labelsAsList = _classLabels.toList(); - final processedFeatures = interceptPreprocessor.addIntercept(features); - _weightsByClasses = Map.fromIterable( - labelsAsList, - key: (dynamic label) => label as double, - value: (dynamic label) => _fitBinaryClassifier(processedFeatures, labels, - label as double, initialWeights, isDataNormalized), - ); - } - - @override - double test(MLMatrix features, MLVector origLabels, MetricType metricType) { - final evaluator = MetricFactory.createByType(metricType); - final prediction = predictClasses(features); - return evaluator.getError(prediction, origLabels); - } - - @override - MLMatrix predictProbabilities(MLMatrix features) { - final processedFeatures = interceptPreprocessor.addIntercept(features); - return _predictProbabilities(processedFeatures); - } - - @override - MLVector predictClasses(MLMatrix features) { - final processedFeatures = interceptPreprocessor.addIntercept(features); - final distributions = _predictProbabilities(processedFeatures); - final classes = List(processedFeatures.rowsNum); - for (int i = 0; i < distributions.rowsNum; i++) { - final probabilities = distributions.getRow(i); - classes[i] = probabilities.toList().indexOf(probabilities.max()) * 1.0; - } - return MLVector.from(classes, dtype: dtype); - } - - MLMatrix _predictProbabilities(MLMatrix processedFeatures) { - final numOfObservations = _weightsByClasses.length; - final distributions = List(numOfObservations); - int i = 0; - _weightsByClasses.forEach((double label, MLVector weights) { - final scores = (processedFeatures * weights).toVector(); - distributions[i++] = scoreToProbMapper.linkScoresToProbs(scores); - }); - return MLMatrix.columns(distributions, dtype: dtype); - } - - MLVector _fitBinaryClassifier(MLMatrix features, MLVector labels, - double targetLabel, MLVector initialWeights, bool arePointsNormalized) { - final binaryLabels = - labelsProcessor.makeLabelsOneVsAll(labels, targetLabel); - return optimizer.findExtrema(features, binaryLabels, - initialWeights: initialWeights, - arePointsNormalized: arePointsNormalized, - isMinimizingObjective: false); - } -} diff --git a/pubspec.yaml b/pubspec.yaml index 5009bda7..f80dd599 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms written with native dart (without bindings to any popular ML libraries, just pure Dart implementation) -version: 6.0.6 +version: 6.1.0 author: Ilia Gyrdymov homepage: https://github.com/gyrdym/ml_algo