Skip to content

Commit

Permalink
Merge pull request #75 from gyrdym/softmax-regression
Browse files Browse the repository at this point in the history
Link function refactored
  • Loading branch information
gyrdym committed Feb 11, 2019
2 parents 7d5a9ae + 3a44ee7 commit 9a26e0f
Show file tree
Hide file tree
Showing 32 changed files with 221 additions and 250 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 6.1.0
- `LinkFunction` renamed to `ScoreToProbMapper`
- `ScoreToProbMapper` accepts vector and returns vector instead of a scalar

## 6.0.6
- Pedantic package integration added
- Some linter issues fixed
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions lib/src/classifier/linear_classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ abstract class LinearClassifier implements Classifier {
Type dtype,
}) = LogisticRegressor;

factory LinearClassifier.softMaxRegressor() => throw UnimplementedError();
factory LinearClassifier.SVM() => throw UnimplementedError();
factory LinearClassifier.naiveBayes() => throw UnimplementedError();
}
24 changes: 12 additions & 12 deletions lib/src/classifier/logistic_regressor.dart
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import 'package:ml_algo/gradient_type.dart';
import 'package:ml_algo/learning_rate_type.dart';
import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator.dart';
import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory.dart';
import 'package:ml_algo/src/classifier/labels_probability_calculator/labels_probability_calculator_factory_impl.dart';
import 'package:ml_algo/src/classifier/labels_processor/labels_processor.dart';
import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory.dart';
import 'package:ml_algo/src/classifier/labels_processor/labels_processor_factory_impl.dart';
Expand All @@ -12,7 +9,6 @@ import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_
import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory.dart';
import 'package:ml_algo/src/data_preprocessing/intercept_preprocessor/intercept_preprocessor_factory_impl.dart';
import 'package:ml_algo/src/default_parameter_values.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/metric/factory.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/optimizer/gradient/batch_size_calculator/batch_size_calculator.dart';
Expand All @@ -22,6 +18,10 @@ import 'package:ml_algo/src/optimizer/optimizer.dart';
import 'package:ml_algo/src/optimizer/optimizer_factory.dart';
import 'package:ml_algo/src/optimizer/optimizer_factory_impl.dart';
import 'package:ml_algo/src/optimizer/optimizer_type.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

Expand All @@ -30,7 +30,7 @@ class LogisticRegressor implements LinearClassifier {
final Optimizer optimizer;
final InterceptPreprocessor interceptPreprocessor;
final LabelsProcessor labelsProcessor;
final LabelsProbabilityCalculator probabilityCalculator;
final ScoreToProbMapper scoreToProbMapper;

LogisticRegressor({
// public arguments
Expand All @@ -46,28 +46,28 @@ class LogisticRegressor implements LinearClassifier {
GradientType gradientType = GradientType.stochastic,
LearningRateType learningRateType = LearningRateType.constant,
InitialWeightsType initialWeightsType = InitialWeightsType.zeroes,
LinkFunctionType linkFunctionType = LinkFunctionType.logit,
ScoreToProbMapperType scoreToProbMapperType = ScoreToProbMapperType.logit,
this.dtype = DefaultParameterValues.dtype,

// private arguments
LabelsProcessorFactory labelsProcessorFactory =
const LabelsProcessorFactoryImpl(),
InterceptPreprocessorFactory interceptPreprocessorFactory =
const InterceptPreprocessorFactoryImpl(),
LabelsProbabilityCalculatorFactory probabilityCalculatorFactory =
const LabelsProbabilityCalculatorFactoryImpl(),
ScoreToProbMapperFactory scoreToProbMapperFactory =
const ScoreToProbMapperFactoryImpl(),
OptimizerFactory optimizerFactory = const OptimizerFactoryImpl(),
BatchSizeCalculator batchSizeCalculator = const BatchSizeCalculatorImpl(),
}) : labelsProcessor = labelsProcessorFactory.create(dtype),
interceptPreprocessor = interceptPreprocessorFactory.create(dtype,
scale: fitIntercept ? interceptScale : 0.0),
probabilityCalculator =
probabilityCalculatorFactory.create(linkFunctionType, dtype),
scoreToProbMapper =
scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype),
optimizer = optimizerFactory.fromType(
optimizer,
dtype: dtype,
costFunctionType: CostFunctionType.logLikelihood,
linkFunctionType: linkFunctionType,
scoreToProbMapperType: scoreToProbMapperType,
learningRateType: learningRateType,
initialWeightsType: initialWeightsType,
initialLearningRate: initialLearningRate,
Expand Down Expand Up @@ -136,7 +136,7 @@ class LogisticRegressor implements LinearClassifier {
int i = 0;
_weightsByClasses.forEach((double label, MLVector weights) {
final scores = (processedFeatures * weights).toVector();
distributions[i++] = probabilityCalculator.getProbabilities(scores);
distributions[i++] = scoreToProbMapper.linkScoresToProbs(scores);
});
return MLMatrix.columns(distributions, dtype: dtype);
}
Expand Down
7 changes: 4 additions & 3 deletions lib/src/cost_function/cost_function_factory.dart
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import 'package:ml_algo/src/cost_function/cost_function.dart';
import 'package:ml_algo/src/cost_function/cost_function_type.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';

abstract class CostFunctionFactory {
CostFunction fromType(CostFunctionType type,
{Type dtype, LinkFunctionType linkFunctionType});
{Type dtype, ScoreToProbMapperType scoreToProbMapperType});
CostFunction squared();
CostFunction logLikelihood(LinkFunctionType linkFunctionType, {Type dtype});
CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType,
{Type dtype});
}
10 changes: 5 additions & 5 deletions lib/src/cost_function/cost_function_factory_impl.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import 'package:ml_algo/src/cost_function/cost_function_type.dart';
import 'package:ml_algo/src/cost_function/log_likelihood.dart';
import 'package:ml_algo/src/cost_function/squared.dart';
import 'package:ml_algo/src/default_parameter_values.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';

class CostFunctionFactoryImpl implements CostFunctionFactory {
const CostFunctionFactoryImpl();
Expand All @@ -13,17 +13,17 @@ class CostFunctionFactoryImpl implements CostFunctionFactory {
CostFunction squared() => SquaredCost();

@override
CostFunction logLikelihood(LinkFunctionType linkFunctionType,
CostFunction logLikelihood(ScoreToProbMapperType scoreToProbMapperType,
{Type dtype = DefaultParameterValues.dtype}) =>
LogLikelihoodCost(linkFunctionType, dtype: dtype);
LogLikelihoodCost(scoreToProbMapperType, dtype: dtype);

@override
CostFunction fromType(CostFunctionType type,
{Type dtype = DefaultParameterValues.dtype,
LinkFunctionType linkFunctionType}) {
ScoreToProbMapperType scoreToProbMapperType}) {
switch (type) {
case CostFunctionType.logLikelihood:
return logLikelihood(linkFunctionType, dtype: dtype);
return logLikelihood(scoreToProbMapperType, dtype: dtype);
case CostFunctionType.squared:
return squared();
default:
Expand Down
21 changes: 12 additions & 9 deletions lib/src/cost_function/log_likelihood.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@ import 'dart:typed_data';

import 'package:ml_algo/src/cost_function/cost_function.dart';
import 'package:ml_algo/src/default_parameter_values.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/link_function/link_function_factory.dart';
import 'package:ml_algo/src/link_function/link_function_factory_impl.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_factory_impl.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';
import 'package:ml_linalg/linalg.dart';

class LogLikelihoodCost implements CostFunction {
final LinkFunction linkFunction;
final ScoreToProbMapper scoreToProbMapper;
final Type dtype;

LogLikelihoodCost(
LinkFunctionType linkFunctionType, {
ScoreToProbMapperType scoreToProbMapperType, {
this.dtype = DefaultParameterValues.dtype,
LinkFunctionFactory linkFunctionFactory = const LinkFunctionFactoryImpl(),
}) : linkFunction = linkFunctionFactory.fromType(linkFunctionType);
ScoreToProbMapperFactory scoreToProbMapperFactory =
const ScoreToProbMapperFactoryImpl(),
}) : scoreToProbMapper =
scoreToProbMapperFactory.fromType(scoreToProbMapperType, dtype);

@override
double getCost(double score, double yOrig) {
Expand All @@ -25,10 +27,11 @@ class LogLikelihoodCost implements CostFunction {

@override
MLVector getGradient(MLMatrix x, MLVector w, MLVector y) {
final scores = (x * w).toVector();
switch (dtype) {
case Float32x4:
return (x.transpose() *
(y - (x * w).fastMap<Float32x4>(linkFunction.float32x4Link)))
(y - scoreToProbMapper.linkScoresToProbs(scores)))
.toVector();
default:
throw throw UnsupportedError('Unsupported data type - $dtype');
Expand Down
5 changes: 0 additions & 5 deletions lib/src/link_function/link_function.dart

This file was deleted.

6 changes: 0 additions & 6 deletions lib/src/link_function/link_function_factory.dart

This file was deleted.

18 changes: 0 additions & 18 deletions lib/src/link_function/link_function_factory_impl.dart

This file was deleted.

3 changes: 0 additions & 3 deletions lib/src/link_function/link_function_type.dart

This file was deleted.

21 changes: 0 additions & 21 deletions lib/src/link_function/logit_link_function.dart

This file was deleted.

6 changes: 3 additions & 3 deletions lib/src/optimizer/gradient/gradient.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import 'package:ml_algo/src/cost_function/cost_function_factory.dart';
import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart';
import 'package:ml_algo/src/cost_function/cost_function_type.dart';
import 'package:ml_algo/src/default_parameter_values.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/math/randomizer/randomizer.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory_impl.dart';
Expand All @@ -16,6 +15,7 @@ import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_
import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory_impl.dart';
import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart';
import 'package:ml_algo/src/optimizer/optimizer.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/range.dart';
import 'package:ml_linalg/vector.dart';
Expand Down Expand Up @@ -46,7 +46,7 @@ class GradientOptimizer implements Optimizer {
CostFunctionType costFnType,
LearningRateType learningRateType,
InitialWeightsType initialWeightsType,
LinkFunctionType linkFunctionType,
ScoreToProbMapperType scoreToProbMapperType,
double initialLearningRate = DefaultParameterValues.initialLearningRate,
double minWeightsUpdate = DefaultParameterValues.minWeightsUpdate,
int iterationLimit = DefaultParameterValues.iterationsLimit,
Expand All @@ -62,7 +62,7 @@ class GradientOptimizer implements Optimizer {
_learningRateGenerator =
learningRateGeneratorFactory.fromType(learningRateType),
_costFunction = costFunctionFactory.fromType(costFnType,
dtype: dtype, linkFunctionType: linkFunctionType),
dtype: dtype, scoreToProbMapperType: scoreToProbMapperType),
_randomizer = randomizerFactory.create(randomSeed) {
_learningRateGenerator.init(initialLearningRate ?? 1.0);
}
Expand Down
6 changes: 3 additions & 3 deletions lib/src/optimizer/optimizer_factory.dart
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import 'package:ml_algo/learning_rate_type.dart';
import 'package:ml_algo/src/cost_function/cost_function_factory.dart';
import 'package:ml_algo/src/cost_function/cost_function_type.dart';
import 'package:ml_algo/src/link_function/link_function_type.dart';
import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart';
import 'package:ml_algo/src/optimizer/gradient/learning_rate_generator/learning_rate_generator_factory.dart';
import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_generator_factory.dart';
import 'package:ml_algo/src/optimizer/initial_weights_generator/initial_weights_type.dart';
import 'package:ml_algo/src/optimizer/optimizer.dart';
import 'package:ml_algo/src/optimizer/optimizer_type.dart';
import 'package:ml_algo/src/score_to_prob_mapper/score_to_prob_mapper_type.dart';

abstract class OptimizerFactory {
Optimizer fromType(
Expand All @@ -20,7 +20,7 @@ abstract class OptimizerFactory {
CostFunctionType costFunctionType,
LearningRateType learningRateType,
InitialWeightsType initialWeightsType,
LinkFunctionType linkFunctionType,
ScoreToProbMapperType scoreToProbMapperType,
double initialLearningRate,
double minCoefficientsUpdate,
int iterationLimit,
Expand All @@ -38,7 +38,7 @@ abstract class OptimizerFactory {
CostFunctionType costFnType,
LearningRateType learningRateType,
InitialWeightsType initialWeightsType,
LinkFunctionType linkFunctionType,
ScoreToProbMapperType scoreToProbMapperType,
double initialLearningRate,
double minCoefficientsUpdate,
int iterationLimit,
Expand Down
Loading

0 comments on commit 9a26e0f

Please sign in to comment.