diff --git a/CHANGELOG.md b/CHANGELOG.md index fece0103..d2b769bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 6.1.0 +- `LinkFunction` renamed to `ScoreToProbMapper` +- `ScoreToProbMapper` accepts vector and returns vector instead of a scalar + ## 6.0.6 - Pedantic package integration added - Some linter issues fixed diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart deleted file mode 100644 index 99361307..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator.dart +++ /dev/null @@ -1,5 +0,0 @@ -import 'package:ml_linalg/vector.dart'; - -abstract class LabelsProbabilityCalculator { - MLVector getProbabilities(MLVector scores); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart deleted file mode 100644 index c3f0512f..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart +++ /dev/null @@ -1,7 +0,0 @@ -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -abstract class LabelsProbabilityCalculatorFactory { - LabelsProbabilityCalculator create( - LinkFunctionType linkFunctionType, Type dtype); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart deleted file mode 100644 index 075b9aab..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart +++ /dev/null @@ -1,14 +0,0 @@ -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -class LabelsProbabilityCalculatorFactoryImpl - implements LabelsProbabilityCalculatorFactory { - const LabelsProbabilityCalculatorFactoryImpl(); - - @override - LabelsProbabilityCalculator create( - LinkFunctionType linkFunctionType, Type dtype) => - LabelsProbabilityCalculatorImpl(linkFunctionType, dtype); -} diff --git a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart b/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart deleted file mode 100644 index 6dca02cc..00000000 --- a/lib/src/classifier/labels_probability_calculator/labels_probability_calculator_impl.dart +++ /dev/null @@ -1,31 +0,0 @@ -import 'dart:typed_data'; - -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; -import 'package:ml_linalg/vector.dart'; - -class LabelsProbabilityCalculatorImpl implements LabelsProbabilityCalculator { - final Type dtype; - final LinkFunction linkFunction; - - LabelsProbabilityCalculatorImpl( - LinkFunctionType linkFunctionType, - this.dtype, { - LinkFunctionFactory linkFnFactory = const LinkFunctionFactoryImpl(), - }) : linkFunction = linkFnFactory.fromType(linkFunctionType); - - @override - MLVector getProbabilities(MLVector scores) { - switch (dtype) { - case Float32x4: - return scores.fastMap( - (Float32x4 el, int startOffset, int endOffset) => - linkFunction.float32x4Link(el)); - default: - throw UnsupportedError('Unsupported data type - $dtype'); - } - } -} diff --git a/lib/src/classifier/linear_classifier.dart b/lib/src/classifier/linear_classifier.dart index 02edd532..f0813f58 100644 --- a/lib/src/classifier/linear_classifier.dart +++ b/lib/src/classifier/linear_classifier.dart @@ -70,6 +70,7 @@ abstract class LinearClassifier implements Classifier { Type dtype, }) = LogisticRegressor; + factory LinearClassifier.softMaxRegressor() => throw UnimplementedError(); factory LinearClassifier.SVM() => throw UnimplementedError(); factory LinearClassifier.naiveBayes() => throw UnimplementedError(); } diff --git a/lib/src/classifier/logistic_regressor.dart b/lib/src/classifier/logistic_regressor.dart index a2620f5a..756c8816 100644 --- a/lib/src/classifier/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor.dart @@ -1,8 +1,5 @@ import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory_impl.dart'; @@ -12,7 +9,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/metric/factory.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart'; @@ -22,6 +18,10 @@ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; @@ -30,7 +30,7 @@ class LogisticRegressor implements LinearClassifier { final Optimizer optimizer; final InterceptPreprocessor interceptPreprocessor; final LabelsProcessor labelsProcessor; - final LabelsProbabilityCalculator probabilityCalculator; + final ScoreToProbMapper scoreToProbMapper; LogisticRegressor({ // public arguments @@ -46,7 +46,7 @@ class LogisticRegressor implements LinearClassifier { GradientType gradientType = GradientType.stochastic, LearningRateType learningRateType = LearningRateType.constant, InitialWeightsType initialWeightsType = InitialWeightsType.zeroes, - LinkFunctionType linkFunctionType = LinkFunctionType.logit, + ScoreToProbMapperType scoreToProbMapperType = ScoreToProbMapperType.logit, this.dtype = DefaultParameterValues.dtype, // private arguments @@ -54,20 +54,20 @@ class LogisticRegressor implements LinearClassifier { const LabelsProcessorFactoryImpl(), InterceptPreprocessorFactory interceptPreprocessorFactory = const InterceptPreprocessorFactoryImpl(), - LabelsProbabilityCalculatorFactory probabilityCalculatorFactory = - const LabelsProbabilityCalculatorFactoryImpl(), + ScoreToProbMapperFactory scoreToProbMapperFactory = + const ScoreToProbMapperFactoryImpl(), OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(), BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(), }) : labelsProcessor = labelsProcessorFactory.create(dtype), interceptPreprocessor = interceptPreprocessorFactory.create(dtype, scale: fitIntercept ? interceptScale : 0.0), - probabilityCalculator = - probabilityCalculatorFactory.create(linkFunctionType, dtype), + scoreToProbMapper = + scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype), optimizer = optimizerFactory.fromType( optimizer, dtype: dtype, costFunctionType: CostFunctionType.logLikelihood, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, initialLearningRate: initialLearningRate, @@ -136,7 +136,7 @@ class LogisticRegressor implements LinearClassifier { int i = 0; _weightsByClasses.forEach((double label, MLVector weights) { final scores = (processedFeatures * weights).toVector(); - distributions[i++] = probabilityCalculator.getProbabilities(scores); + distributions[i++] = scoreToProbMapper.linkScoresToProbs(scores); }); return MLMatrix.columns(distributions, dtype: dtype); } diff --git a/lib/src/cost_function/cost_function_factory.dart b/lib/src/cost_function/cost_function_factory.dart index c4f8e118..34a583ec 100644 --- a/lib/src/cost_function/cost_function_factory.dart +++ b/lib/src/cost_function/cost_function_factory.dart @@ -1,10 +1,11 @@ import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; abstract class CostFunctionFactory { CostFunction fromType(CostFunctionType type, - {Type dtype, LinkFunctionType linkFunctionType}); + {Type dtype, ScoreToProbMapperType scoreToProbMapperType}); CostFunction squared(); - CostFunction logLikelihood(LinkFunctionType linkFunctionType, {Type dtype}); + CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType, + {Type dtype}); } diff --git a/lib/src/cost_function/cost_function_factory_impl.dart b/lib/src/cost_function/cost_function_factory_impl.dart index dd55e9ed..7a66f2fd 100644 --- a/lib/src/cost_function/cost_function_factory_impl.dart +++ b/lib/src/cost_function/cost_function_factory_impl.dart @@ -4,7 +4,7 @@ import 'package:ml_algo/src/cost_function/cost_function_type.dart'; import 'package:ml_algo/src/cost_function/log_likelihood.dart'; import 'package:ml_algo/src/cost_function/squared.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; class CostFunctionFactoryImpl implements CostFunctionFactory { const CostFunctionFactoryImpl(); @@ -13,17 +13,17 @@ class CostFunctionFactoryImpl implements CostFunctionFactory { CostFunction squared() => SquaredCost(); @override - CostFunction logLikelihood(LinkFunctionType linkFunctionType, + CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType, {Type dtype = DefaultParameterValues.dtype}) => - LogLikelihoodCost(linkFunctionType, dtype: dtype); + LogLikelihoodCost(scoreToProbMapperType, dtype: dtype); @override CostFunction fromType(CostFunctionType type, {Type dtype = DefaultParameterValues.dtype, - LinkFunctionType linkFunctionType}) { + ScoreToProbMapperType scoreToProbMapperType}) { switch (type) { case CostFunctionType.logLikelihood: - return logLikelihood(linkFunctionType, dtype: dtype); + return logLikelihood(scoreToProbMapperType, dtype: dtype); case CostFunctionType.squared: return squared(); default: diff --git a/lib/src/cost_function/log_likelihood.dart b/lib/src/cost_function/log_likelihood.dart index 8262ab25..9a84fbc5 100644 --- a/lib/src/cost_function/log_likelihood.dart +++ b/lib/src/cost_function/log_likelihood.dart @@ -2,21 +2,23 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_factory_impl.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/linalg.dart'; class LogLikelihoodCost implements CostFunction { - final LinkFunction linkFunction; + final ScoreToProbMapper scoreToProbMapper; final Type dtype; LogLikelihoodCost( - LinkFunctionType linkFunctionType, { + ScoreToProbMapperType scoreToProbMapperType, { this.dtype = DefaultParameterValues.dtype, - LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(), - }) : linkFunction = linkFunctionFactory.fromType(linkFunctionType); + ScoreToProbMapperFactory scoreToProbMapperFactory = + const ScoreToProbMapperFactoryImpl(), + }) : scoreToProbMapper = + scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype); @override double getCost(double score, double yOrig) { @@ -25,10 +27,11 @@ class LogLikelihoodCost implements CostFunction { @override MLVector getGradient(MLMatrix x, MLVector w, MLVector y) { + final scores = (x * w).toVector(); switch (dtype) { case Float32x4: return (x.transpose() * - (y - (x * w).fastMap(linkFunction.float32x4Link))) + (y - scoreToProbMapper.linkScoresToProbs(scores))) .toVector(); default: throw throw UnsupportedError('Unsupported data type - $dtype'); diff --git a/lib/src/link_function/link_function.dart b/lib/src/link_function/link_function.dart deleted file mode 100644 index ab43ec84..00000000 --- a/lib/src/link_function/link_function.dart +++ /dev/null @@ -1,5 +0,0 @@ -import 'dart:typed_data'; - -abstract class LinkFunction { - Float32x4 float32x4Link(Float32x4 scores); -} diff --git a/lib/src/link_function/link_function_factory.dart b/lib/src/link_function/link_function_factory.dart deleted file mode 100644 index 9d50ef75..00000000 --- a/lib/src/link_function/link_function_factory.dart +++ /dev/null @@ -1,6 +0,0 @@ -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; - -abstract class LinkFunctionFactory { - LinkFunction fromType(LinkFunctionType type); -} diff --git a/lib/src/link_function/link_function_factory_impl.dart b/lib/src/link_function/link_function_factory_impl.dart deleted file mode 100644 index a42ebec2..00000000 --- a/lib/src/link_function/link_function_factory_impl.dart +++ /dev/null @@ -1,18 +0,0 @@ -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; -import 'package:ml_algo/src/link_function/logit_link_function.dart'; - -class LinkFunctionFactoryImpl implements LinkFunctionFactory { - const LinkFunctionFactoryImpl(); - - @override - LinkFunction fromType(LinkFunctionType type) { - switch (type) { - case LinkFunctionType.logit: - return LogitLinkFunction(); - default: - throw UnsupportedError('Unsupported link function type - $type'); - } - } -} diff --git a/lib/src/link_function/link_function_type.dart b/lib/src/link_function/link_function_type.dart deleted file mode 100644 index dc6d56a5..00000000 --- a/lib/src/link_function/link_function_type.dart +++ /dev/null @@ -1,3 +0,0 @@ -enum LinkFunctionType { - logit, -} diff --git a/lib/src/link_function/logit_link_function.dart b/lib/src/link_function/logit_link_function.dart deleted file mode 100644 index 9d9d141b..00000000 --- a/lib/src/link_function/logit_link_function.dart +++ /dev/null @@ -1,21 +0,0 @@ -import 'dart:math' as math; -import 'dart:typed_data'; - -import 'package:ml_algo/src/link_function/link_function.dart'; - -class LogitLinkFunction implements LinkFunction { - final float32x4Zeroes = Float32x4.zero(); - final float32x4Ones = Float32x4.splat(1.0); - - @override - Float32x4 float32x4Link(Float32x4 scores) => - float32x4Ones / - (float32x4Ones + - Float32x4( - //@TODO: find a more efficient way to raise exponent to the float power in SIMD way - math.exp(-scores.x), - math.exp(-scores.y), - math.exp(-scores.z), - math.exp(-scores.w), - )); -} diff --git a/lib/src/optimizer/gradient/gradient.dart b/lib/src/optimizer/gradient/gradient.dart index 24f5cbd9..662a486a 100644 --- a/lib/src/optimizer/gradient/gradient.dart +++ b/lib/src/optimizer/gradient/gradient.dart @@ -3,7 +3,6 @@ import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; import 'package:ml_algo/src/default_parameter_values.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart'; @@ -16,6 +15,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory_impl.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/range.dart'; import 'package:ml_linalg/vector.dart'; @@ -46,7 +46,7 @@ class GradientOptimizer implements Optimizer { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate = DefaultParameterValues.initialLearningRate, double minWeightsUpdate = DefaultParameterValues.minWeightsUpdate, int iterationLimit = DefaultParameterValues.iterationsLimit, @@ -62,7 +62,7 @@ class GradientOptimizer implements Optimizer { _learningRateGenerator = learningRateGeneratorFactory.fromType(learningRateType), _costFunction = costFunctionFactory.fromType(costFnType, - dtype: dtype, linkFunctionType: linkFunctionType), + dtype: dtype, scoreToProbMapperType: scoreToProbMapperType), _randomizer = randomizerFactory.create(randomSeed) { _learningRateGenerator.init(initialLearningRate ?? 1.0); } diff --git a/lib/src/optimizer/optimizer_factory.dart b/lib/src/optimizer/optimizer_factory.dart index b5c64a17..f66956c7 100644 --- a/lib/src/optimizer/optimizer_factory.dart +++ b/lib/src/optimizer/optimizer_factory.dart @@ -1,13 +1,13 @@ import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator_factory.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; abstract class OptimizerFactory { Optimizer fromType( @@ -20,7 +20,7 @@ abstract class OptimizerFactory { CostFunctionType costFunctionType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -38,7 +38,7 @@ abstract class OptimizerFactory { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, diff --git a/lib/src/optimizer/optimizer_factory_impl.dart b/lib/src/optimizer/optimizer_factory_impl.dart index 6d06284b..e1f0d7e3 100644 --- a/lib/src/optimizer/optimizer_factory_impl.dart +++ b/lib/src/optimizer/optimizer_factory_impl.dart @@ -3,7 +3,6 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart'; import 'package:ml_algo/src/optimizer/coordinate/coordinate.dart'; @@ -17,6 +16,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; class OptimizerFactoryImpl implements OptimizerFactory { const OptimizerFactoryImpl(); @@ -34,7 +34,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { CostFunctionType costFunctionType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -65,7 +65,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { costFnType: costFunctionType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, initialLearningRate: initialLearningRate, minCoefficientsUpdate: minCoefficientsUpdate, iterationLimit: iterationLimit, @@ -114,7 +114,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { CostFunctionType costFnType, LearningRateType learningRateType, InitialWeightsType initialWeightsType, - LinkFunctionType linkFunctionType, + ScoreToProbMapperType scoreToProbMapperType, double initialLearningRate, double minCoefficientsUpdate, int iterationLimit, @@ -130,7 +130,7 @@ class OptimizerFactoryImpl implements OptimizerFactory { costFnType: costFnType, learningRateType: learningRateType, initialWeightsType: initialWeightsType, - linkFunctionType: linkFunctionType, + scoreToProbMapperType: scoreToProbMapperType, initialLearningRate: initialLearningRate, minWeightsUpdate: minCoefficientsUpdate, iterationLimit: iterationLimit, diff --git a/lib/src/score_to_prob_mapper/logit_mapper.dart b/lib/src/score_to_prob_mapper/logit_mapper.dart new file mode 100644 index 00000000..020edc88 --- /dev/null +++ b/lib/src/score_to_prob_mapper/logit_mapper.dart @@ -0,0 +1,38 @@ +import 'dart:math' as math; +import 'dart:typed_data'; + +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'package:ml_linalg/vector.dart'; + +class LogitMapper implements ScoreToProbMapper { + final Type dtype; + + final float32x4Zeroes = Float32x4.zero(); + final float32x4Ones = Float32x4.splat(1.0); + + LogitMapper(this.dtype); + + @override + MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]) { + switch (dtype) { + case Float32x4: + return scores.fastMap( + (Float32x4 el, int startOffset, int endOffset) => + scoreToProbFloat32x4(el)); + default: + throw UnsupportedError('Unsupported data type - $dtype'); + } + } + + Float32x4 scoreToProbFloat32x4(Float32x4 scores) => + float32x4Ones / + (float32x4Ones + + Float32x4( + //@TODO: find a more efficient way to raise exponent to the float power in SIMD way + math.exp(-scores.x), + math.exp(-scores.y), + math.exp(-scores.z), + math.exp(-scores.w), + )); +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper.dart new file mode 100644 index 00000000..e3d47ca1 --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper.dart @@ -0,0 +1,8 @@ +import 'package:ml_linalg/matrix.dart'; +import 'package:ml_linalg/vector.dart'; + +abstract class ScoreToProbMapper { + /// Accepts a vector of scores, returns a vector of probabilities + /// Score is a multiplication of a feature value and the corresponding weight (coefficient) + MLVector linkScoresToProbs(MLVector scores, [MLMatrix scoresByClasses]); +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart new file mode 100644 index 00000000..7202e7a0 --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart @@ -0,0 +1,6 @@ +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; + +abstract class ScoreToProbMapperFactory { + ScoreToProbMapper fromType(ScoreToProbMapperType type, Type dtype); +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart new file mode 100644 index 00000000..462570e0 --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart @@ -0,0 +1,18 @@ +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/logit_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; + +class ScoreToProbMapperFactoryImpl implements ScoreToProbMapperFactory { + const ScoreToProbMapperFactoryImpl(); + + @override + ScoreToProbMapper fromType(ScoreToProbMapperType type, Type dtype) { + switch (type) { + case ScoreToProbMapperType.logit: + return LogitMapper(dtype); + default: + throw UnsupportedError('Unsupported link function type - $type'); + } + } +} diff --git a/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart b/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart new file mode 100644 index 00000000..cc3ab38f --- /dev/null +++ b/lib/src/score_to_prob_mapper/score_to_prob_mapper_type.dart @@ -0,0 +1,3 @@ +enum ScoreToProbMapperType { + logit, +} diff --git a/pubspec.yaml b/pubspec.yaml index 5009bda7..f80dd599 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms written with native dart (without bindings to any popular ML libraries, just pure Dart implementation) -version: 6.0.6 +version: 6.1.0 author: Ilia Gyrdymov homepage: https://github.com/gyrdym/ml_algo diff --git a/test/classifier/logistic_regressor_common.dart b/test/classifier/logistic_regressor_common.dart index a150f431..cd63bc15 100644 --- a/test/classifier/logistic_regressor_common.dart +++ b/test/classifier/logistic_regressor_common.dart @@ -2,18 +2,18 @@ import 'dart:typed_data'; import 'package:ml_algo/gradient_type.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/classifier/logistic_regressor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor.dart'; import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import '../test_utils/mocks.dart'; @@ -21,10 +21,10 @@ LabelsProcessor labelsProcessorMock; LabelsProcessorFactory labelsProcessorFactoryMock; InterceptPreprocessor interceptPreprocessorMock; InterceptPreprocessorFactory interceptPreprocessorFactoryMock; -LabelsProbabilityCalculator probabilityCalculatorMock; -LabelsProbabilityCalculatorFactory probabilityCalculatorFactoryMock; Optimizer optimizerMock; OptimizerFactory optimizerFactoryMock; +ScoreToProbMapperFactory scoreToProbFactoryMock; +ScoreToProbMapper scoreToProbMapperMock; void setUpLabelsProcessorFactory() { labelsProcessorMock = LabelsProcessorMock(); @@ -38,22 +38,20 @@ void setUpInterceptPreprocessorFactory() { preprocessor: interceptPreprocessorMock); } -void setUpProbabilityCalculatorFactory() { - probabilityCalculatorMock = LabelsProbabilityCalculatorMock(); - probabilityCalculatorFactoryMock = - createLabelsProbabilityCalculatorFactoryMock( - linkType: LinkFunctionType.logit, - dtype: Float32x4, - calculator: probabilityCalculatorMock, - ); -} - void setUpOptimizerFactory() { optimizerMock = OptimizerMock(); optimizerFactoryMock = createOptimizerFactoryMock( optimizers: {OptimizerType.gradientDescent: optimizerMock}); } +void setUpScoreToProbMapperFactory() { + scoreToProbMapperMock = ScoreToProbMapperMock(); + scoreToProbFactoryMock = + createScoreToProbMapperFactoryMock(Float32x4, mappers: { + ScoreToProbMapperType.logit: scoreToProbMapperMock, + }); +} + LogisticRegressor createRegressor({ int iterationLimit = 100, double learningRate = 0.01, @@ -71,8 +69,8 @@ LogisticRegressor createRegressor({ lambda: lambda, labelsProcessorFactory: labelsProcessorFactoryMock, interceptPreprocessorFactory: interceptPreprocessorFactoryMock, - linkFunctionType: LinkFunctionType.logit, - probabilityCalculatorFactory: probabilityCalculatorFactoryMock, + scoreToProbMapperType: ScoreToProbMapperType.logit, + scoreToProbMapperFactory: scoreToProbFactoryMock, optimizer: OptimizerType.gradientDescent, optimizerFactory: optimizerFactoryMock, gradientType: GradientType.stochastic, diff --git a/test/classifier/logistic_regressor_test.dart b/test/classifier/logistic_regressor_test.dart index 1fd1b692..054524f3 100644 --- a/test/classifier/logistic_regressor_test.dart +++ b/test/classifier/logistic_regressor_test.dart @@ -2,9 +2,9 @@ import 'dart:typed_data'; import 'package:ml_algo/learning_rate_type.dart'; import 'package:ml_algo/src/cost_function/cost_function_type.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/vector.dart'; import 'package:mockito/mockito.dart'; @@ -20,7 +20,7 @@ void main() { test('should initialize properly', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpProbabilityCalculatorFactory(); + setUpScoreToProbMapperFactory(); setUpOptimizerFactory(); createRegressor(); @@ -28,8 +28,8 @@ void main() { verify(labelsProcessorFactoryMock.create(Float32x4)).called(1); verify(interceptPreprocessorFactoryMock.create(Float32x4, scale: 0.0)) .called(1); - verify(probabilityCalculatorFactoryMock.create( - LinkFunctionType.logit, Float32x4)) + verify(scoreToProbFactoryMock.fromType( + ScoreToProbMapperType.logit, Float32x4)) .called(1); verify(optimizerFactoryMock.fromType( OptimizerType.gradientDescent, @@ -37,7 +37,7 @@ void main() { costFunctionType: CostFunctionType.logLikelihood, learningRateType: LearningRateType.constant, initialWeightsType: InitialWeightsType.zeroes, - linkFunctionType: LinkFunctionType.logit, + scoreToProbMapperType: ScoreToProbMapperType.logit, initialLearningRate: 0.01, minCoefficientsUpdate: 0.001, iterationLimit: 100, @@ -50,7 +50,7 @@ void main() { test('should make appropriate method calls when `fit` is called', () { setUpLabelsProcessorFactory(); setUpInterceptPreprocessorFactory(); - setUpProbabilityCalculatorFactory(); + setUpScoreToProbMapperFactory(); setUpOptimizerFactory(); final features = MLMatrix.from([ diff --git a/test/cost_function/cost_function_test.dart b/test/cost_function/cost_function_test.dart index 6c51d065..29b1fef3 100644 --- a/test/cost_function/cost_function_test.dart +++ b/test/cost_function/cost_function_test.dart @@ -2,7 +2,7 @@ import 'dart:typed_data'; import 'package:ml_algo/src/cost_function/log_likelihood.dart'; import 'package:ml_algo/src/cost_function/squared.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:ml_linalg/linalg.dart'; import 'package:mockito/mockito.dart'; import 'package:test/test.dart'; @@ -38,15 +38,16 @@ void main() { }); group('LogLikelihoodCost', () { - final mockedLinkFn = LinkFunctionMock(); - final linkFunctionFactoryMock = - createLinkFunctionFactoryMock(linkFunctions: { - LinkFunctionType.logit: mockedLinkFn, + final mockedLinkFn = ScoreToProbMapperMock(); + final scoreToProbMapperFactoryMock = + createScoreToProbMapperFactoryMock(Float32x4, mappers: { + ScoreToProbMapperType.logit: mockedLinkFn, }); - final logLikelihoodCost = LogLikelihoodCost(LinkFunctionType.logit, - linkFunctionFactory: linkFunctionFactoryMock); + final logLikelihoodCost = LogLikelihoodCost(ScoreToProbMapperType.logit, + scoreToProbMapperFactory: scoreToProbMapperFactoryMock); - when(mockedLinkFn.float32x4Link(any)).thenReturn(Float32x4.splat(1.0)); + when(mockedLinkFn.linkScoresToProbs(any)) + .thenReturn(MLVector.from([1.0, 1.0, 1.0])); test('should return a proper gradient vector', () { // The formula in matrix notation: diff --git a/test/optimizer/gradient/gradient_common.dart b/test/optimizer/gradient/gradient_common.dart index 9d13d138..ccb83746 100644 --- a/test/optimizer/gradient/gradient_common.dart +++ b/test/optimizer/gradient/gradient_common.dart @@ -43,7 +43,7 @@ GradientOptimizer createOptimizer( costFunctionMock = CostFunctionMock(); costFunctionFactoryMock = CostFunctionFactoryMock(); when(costFunctionFactoryMock.fromType(CostFunctionType.squared, - dtype: Float32x4, linkFunctionType: null)) + dtype: Float32x4, scoreToProbMapperType: null)) .thenReturn(costFunctionMock); final randomizerFactoryMock = RandomizerFactoryMock(); @@ -69,7 +69,7 @@ GradientOptimizer createOptimizer( batchSize: batchSize); verify(costFunctionFactoryMock.fromType(CostFunctionType.squared, - dtype: Float32x4, linkFunctionType: null)); + dtype: Float32x4, scoreToProbMapperType: null)); return opt; } diff --git a/test/score_to_prob_link_function/link_function_test.dart b/test/score_to_prob_link_function/link_function_test.dart deleted file mode 100644 index b9bb08c0..00000000 --- a/test/score_to_prob_link_function/link_function_test.dart +++ /dev/null @@ -1,19 +0,0 @@ -import 'dart:typed_data'; - -import 'package:ml_algo/src/link_function/logit_link_function.dart'; -import 'package:test/test.dart'; - -void main() { - group('Vectorized logit link function', () { - test('should properly translate score to probability', () { - final scores = Float32x4(1.0, 2.0, 3.0, 4.0); - final logitLink = LogitLinkFunction(); - final probabilities = logitLink.float32x4Link(scores); - - expect(probabilities.x, inInclusiveRange(0.731, 0.732)); - expect(probabilities.y, inInclusiveRange(0.88, 0.881)); - expect(probabilities.z, inInclusiveRange(0.952, 0.953)); - expect(probabilities.w, inInclusiveRange(0.982, 0.983)); - }); - }); -} diff --git a/test/score_to_prob_mapper/score_to_prob_mapper_test.dart b/test/score_to_prob_mapper/score_to_prob_mapper_test.dart new file mode 100644 index 00000000..74e53b4f --- /dev/null +++ b/test/score_to_prob_mapper/score_to_prob_mapper_test.dart @@ -0,0 +1,20 @@ +import 'dart:typed_data'; + +import 'package:ml_algo/src/score_to_prob_mapper/logit_mapper.dart'; +import 'package:ml_linalg/vector.dart'; +import 'package:test/test.dart'; + +import '../test_utils/helpers/floating_point_iterable_matchers.dart'; + +void main() { + group('LogitMapper', () { + test('should properly translate scores to probabilities for Float32x4', () { + final scores = MLVector.from([1.0, 2.0, 3.0, 4.0]); + final logitLink = LogitMapper(Float32x4); + final probabilities = logitLink.linkScoresToProbs(scores); + + expect(probabilities, + vectorAlmostEqualTo([0.73105, 0.88079, 0.9525, 0.98201], 1e-4)); + }); + }); +} diff --git a/test/test_all.dart b/test/test_all.dart index d62f5c99..d58d37c8 100644 --- a/test/test_all.dart +++ b/test/test_all.dart @@ -1,23 +1,39 @@ -import 'classifier/logistic_regressor_integration_test.dart' as logistic_regressor_integration_test; +import 'classifier/logistic_regressor_integration_test.dart' + as logistic_regressor_integration_test; import 'classifier/logistic_regressor_test.dart' as logistic_regressor_test; import 'cost_function/cost_function_test.dart' as cost_function_test; -import 'data_preprocessing/categorical_encoder/category_values_extractor_impl_test.dart' as cat_value_extractor_test; -import 'data_preprocessing/categorical_encoder/one_hot_encoder_test.dart' as one_hot_encoder_test; -import 'data_preprocessing/categorical_encoder/ordinal_encoder_test.dart' as ordinal_encoder_test; -import 'data_preprocessing/ml_data/csv_ml_data_integration_test.dart' as csv_ml_data_integration_test; -import 'data_preprocessing/ml_data/csv_ml_data_with_categories_integration_test.dart' as csv_ml_data_with_cat_test; -import 'data_preprocessing/ml_data/ml_data_encoders_processor_impl_test.dart' as ml_data_enc_preprocessor_test; -import 'data_preprocessing/ml_data/ml_data_features_extractor_impl_test.dart' as ml_data_feature_extractor_test; -import 'data_preprocessing/ml_data/ml_data_labels_extractor_impl_test.dart' as ml_data_labels_extractor_test; -import 'data_preprocessing/ml_data/ml_data_params_validator_impl_test.dart' as ml_data_params_validator_test; -import 'data_preprocessing/ml_data/ml_data_read_mask_creator_impl_test.dart' as ml_data_read_mask_creator_test; -import 'data_preprocessing/intercept_preprocessor_test.dart' as intercept_preprocessor_test; +import 'data_preprocessing/categorical_encoder/category_values_extractor_impl_test.dart' + as cat_value_extractor_test; +import 'data_preprocessing/categorical_encoder/one_hot_encoder_test.dart' + as one_hot_encoder_test; +import 'data_preprocessing/categorical_encoder/ordinal_encoder_test.dart' + as ordinal_encoder_test; +import 'data_preprocessing/intercept_preprocessor_test.dart' + as intercept_preprocessor_test; +import 'data_preprocessing/ml_data/csv_ml_data_integration_test.dart' + as csv_ml_data_integration_test; +import 'data_preprocessing/ml_data/csv_ml_data_with_categories_integration_test.dart' + as csv_ml_data_with_cat_test; +import 'data_preprocessing/ml_data/ml_data_encoders_processor_impl_test.dart' + as ml_data_enc_preprocessor_test; +import 'data_preprocessing/ml_data/ml_data_features_extractor_impl_test.dart' + as ml_data_feature_extractor_test; +import 'data_preprocessing/ml_data/ml_data_labels_extractor_impl_test.dart' + as ml_data_labels_extractor_test; +import 'data_preprocessing/ml_data/ml_data_params_validator_impl_test.dart' + as ml_data_params_validator_test; +import 'data_preprocessing/ml_data/ml_data_read_mask_creator_impl_test.dart' + as ml_data_read_mask_creator_test; import 'data_splitter/data_splitter_test.dart' as data_splitter_test; import 'math/randomizer_test.dart' as randomizer_test; -import 'optimizer/coordinate/coordinate_optimizer_integration_test.dart' as coord_optimizer_integration_test; -import 'optimizer/gradient/gradient_optimizer_integration_test.dart' as gradient_optimizer_integration_test; -import 'optimizer/gradient/gradient_optimizer_test.dart' as gradient_optimizer_test; -import 'score_to_prob_link_function/link_function_test.dart' as link_function_test; +import 'optimizer/coordinate/coordinate_optimizer_integration_test.dart' + as coord_optimizer_integration_test; +import 'optimizer/gradient/gradient_optimizer_integration_test.dart' + as gradient_optimizer_integration_test; +import 'optimizer/gradient/gradient_optimizer_test.dart' + as gradient_optimizer_test; +import 'score_to_prob_mapper/score_to_prob_mapper_test.dart' + as score_to_prob_mapper_test; void main() { logistic_regressor_integration_test.main(); @@ -39,5 +55,5 @@ void main() { coord_optimizer_integration_test.main(); gradient_optimizer_integration_test.main(); gradient_optimizer_test.main(); - link_function_test.main(); + score_to_prob_mapper_test.main(); } diff --git a/test/test_utils/mocks.dart b/test/test_utils/mocks.dart index 99c7c218..0c0c0b97 100644 --- a/test/test_utils/mocks.dart +++ b/test/test_utils/mocks.dart @@ -1,7 +1,5 @@ import 'package:logging/logging.dart'; import 'package:ml_algo/learning_rate_type.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart'; -import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart'; import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart'; import 'package:ml_algo/src/cost_function/cost_function.dart'; @@ -14,9 +12,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart'; import 'package:ml_algo/src/data_preprocessing/ml_data/validator/ml_data_params_validator.dart'; import 'package:ml_algo/src/data_preprocessing/ml_data/value_converter/value_converter.dart'; -import 'package:ml_algo/src/link_function/link_function.dart'; -import 'package:ml_algo/src/link_function/link_function_factory.dart'; -import 'package:ml_algo/src/link_function/link_function_type.dart'; import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator.dart'; @@ -27,6 +22,9 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_ import 'package:ml_algo/src/optimizer/optimizer.dart'; import 'package:ml_algo/src/optimizer/optimizer_factory.dart'; import 'package:ml_algo/src/optimizer/optimizer_type.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart'; +import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart'; import 'package:mockito/mockito.dart'; class EncoderMock extends Mock implements CategoricalDataEncoder {} @@ -66,9 +64,10 @@ class InitialWeightsGeneratorFactoryMock extends Mock class InitialWeightsGeneratorMock extends Mock implements InitialWeightsGenerator {} -class LinkFunctionMock extends Mock implements LinkFunction {} +class ScoreToProbMapperMock extends Mock implements ScoreToProbMapper {} -class LinkFunctionFactoryMock extends Mock implements LinkFunctionFactory {} +class ScoreToProbMapperFactoryMock extends Mock + implements ScoreToProbMapperFactory {} class LabelsProcessorFactoryMock extends Mock implements LabelsProcessorFactory {} @@ -80,12 +79,6 @@ class InterceptPreprocessorFactoryMock extends Mock class InterceptPreprocessorMock extends Mock implements InterceptPreprocessor {} -class LabelsProbabilityCalculatorFactoryMock extends Mock - implements LabelsProbabilityCalculatorFactory {} - -class LabelsProbabilityCalculatorMock extends Mock - implements LabelsProbabilityCalculator {} - class OptimizerFactoryMock extends Mock implements OptimizerFactory {} class OptimizerMock extends Mock implements Optimizer {} @@ -137,12 +130,13 @@ InitialWeightsGeneratorFactoryMock createInitialWeightsGeneratorFactoryMock({ return factory; } -LinkFunctionFactoryMock createLinkFunctionFactoryMock({ - Map linkFunctions, +ScoreToProbMapperFactoryMock createScoreToProbMapperFactoryMock( + Type dtype, { + Map mappers, }) { - final factory = LinkFunctionFactoryMock(); - linkFunctions.forEach((LinkFunctionType type, LinkFunction fn) { - when(factory.fromType(type)).thenReturn(fn); + final factory = ScoreToProbMapperFactoryMock(); + mappers.forEach((ScoreToProbMapperType type, ScoreToProbMapper fn) { + when(factory.fromType(type, dtype)).thenReturn(fn); }); return factory; } @@ -165,17 +159,6 @@ LabelsProcessorFactoryMock createLabelsProcessorFactoryMock({ return factory; } -LabelsProbabilityCalculatorFactoryMock - createLabelsProbabilityCalculatorFactoryMock({ - LinkFunctionType linkType, - Type dtype, - LabelsProbabilityCalculator calculator, -}) { - final factory = LabelsProbabilityCalculatorFactoryMock(); - when(factory.create(linkType, dtype)).thenReturn(calculator); - return factory; -} - OptimizerFactoryMock createOptimizerFactoryMock({ Map optimizers, }) { @@ -193,7 +176,7 @@ OptimizerFactoryMock createOptimizerFactoryMock({ costFunctionType: anyNamed('costFunctionType'), learningRateType: anyNamed('learningRateType'), initialWeightsType: anyNamed('initialWeightsType'), - linkFunctionType: anyNamed('linkFunctionType'), + scoreToProbMapperType: anyNamed('scoreToProbMapperType'), initialLearningRate: anyNamed('initialLearningRate'), minCoefficientsUpdate: anyNamed('minCoefficientsUpdate'), iterationLimit: anyNamed('iterationLimit'),