Skip to content

Commit

Permalink
LinearRegressor, default constructor: collectLearningData parameter a…
Browse files Browse the repository at this point in the history
…dded; dtype parameter passed wherever it's needed
  • Loading branch information
gyrdym committed Jun 9, 2020
1 parent edf7827 commit e1d8d17
Show file tree
Hide file tree
Showing 55 changed files with 1,158 additions and 840 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# Changelog

# 13.9.0
- `LinearRegressor`:
- `Default constructor`: `collectLearningData` parameter added

## 13.8.1
- `ml_dataframe` dependency updated
- `xrange` dependency constrain removed
Expand Down
Expand Up @@ -3,6 +3,8 @@ import 'package:ml_algo/src/cost_function/cost_function_type.dart';
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/helpers/normalize_class_labels.dart';
import 'package:ml_algo/src/helpers/validate_class_labels.dart';
import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart';
import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart';
import 'package:ml_algo/src/linear_optimizer/linear_optimizer.dart';
Expand Down Expand Up @@ -33,30 +35,31 @@ LinearOptimizer createLogLikelihoodOptimizer(
LearningRateType learningRateType,
InitialCoefficientsType initialWeightsType,
Matrix initialWeights,
num positiveLabel,
num negativeLabel,
DType dtype,
}) {
final splits = featuresTargetSplit(fittingData,
targetNames: targetNames,
).toList();
validateClassLabels(positiveLabel, negativeLabel);

final splits = featuresTargetSplit(fittingData, targetNames: targetNames)
.toList();
final points = splits[0].toMatrix(dtype);
final labels = splits[1].toMatrix(dtype);

final optimizerFactory = dependencies
.getDependency<LinearOptimizerFactory>();

final costFunctionFactory = dependencies
.getDependency<CostFunctionFactory>();

final optimizerFactory = dependencies.getDependency<LinearOptimizerFactory>();
final costFunctionFactory = dependencies.getDependency<CostFunctionFactory>();
final costFunction = costFunctionFactory.createByType(
CostFunctionType.logLikelihood,
linkFunction: linkFunction,
positiveLabel: positiveLabel,
negativeLabel: negativeLabel,
);
final normalizedLabels = normalizeClassLabels(labels,
positiveLabel, negativeLabel);

return optimizerFactory.createByType(
optimizerType,
addInterceptIf(fitIntercept, points, interceptScale, dtype),
labels,
normalizedLabels,
costFunction: costFunction,
iterationLimit: iterationsLimit,
initialLearningRate: initialLearningRate,
Expand Down
1 change: 1 addition & 0 deletions lib/src/classifier/_mixins/linear_classifier_mixin.dart
Expand Up @@ -23,6 +23,7 @@ mixin LinearClassifierMixin implements LinearClassifier {
fitIntercept,
testFeatures.toMatrix(dtype),
interceptScale,
dtype,
);

validateCoefficientsMatrix(coefficientsByClasses,
Expand Down
@@ -1,9 +1,9 @@
import 'package:ml_algo/src/classifier/_helpers/log_likelihood_optimizer_factory.dart';
import 'package:ml_algo/src/classifier/_helpers/create_log_likelihood_optimizer.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_factory.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_impl.dart';
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/helpers/validate_class_labels.dart';
import 'package:ml_algo/src/helpers/validate_initial_coefficients.dart';
import 'package:ml_algo/src/helpers/validate_probability_threshold.dart';
import 'package:ml_algo/src/helpers/validate_train_data.dart';
import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart';
import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart';
Expand Down Expand Up @@ -36,10 +36,11 @@ LogisticRegressor createLogisticRegressor(
Vector initialCoefficients,
num positiveLabel,
num negativeLabel,
bool collectLearningData,
DType dtype,
) {
validateProbabilityThreshold(probabilityThreshold);
validateTrainData(trainData, [targetName]);
validateClassLabels(positiveLabel, negativeLabel);

if (initialCoefficients.isNotEmpty) {
validateInitialCoefficients(initialCoefficients, fitIntercept,
Expand All @@ -48,7 +49,6 @@ LogisticRegressor createLogisticRegressor(

final linkFunction = dependencies.getDependency<LinkFunction>(
dependencyName: dTypeToInverseLogitLinkFunctionToken[dtype]);

final optimizer = createLogLikelihoodOptimizer(
trainData,
[targetName],
Expand All @@ -66,28 +66,31 @@ LogisticRegressor createLogisticRegressor(
fitIntercept: fitIntercept,
interceptScale: interceptScale,
isFittingDataNormalized: isFittingDataNormalized,
positiveLabel: positiveLabel,
negativeLabel: negativeLabel,
dtype: dtype,
);

final coefficientsByClasses = optimizer.findExtrema(
initialCoefficients: initialCoefficients.isNotEmpty
? Matrix.fromColumns([initialCoefficients], dtype: dtype)
: null,
isMinimizingObjective: false,
collectLearningData: collectLearningData,
);
final costPerIteration = optimizer.costPerIteration.isNotEmpty
? optimizer.costPerIteration
: null;

final regressorFactory = dependencies
.getDependency<LogisticRegressorFactory>();

return regressorFactory.create(
targetName,
return LogisticRegressorImpl(
[targetName],
linkFunction,
probabilityThreshold,
fitIntercept,
interceptScale,
coefficientsByClasses,
probabilityThreshold,
negativeLabel,
positiveLabel,
costPerIteration,
dtype,
);
}
Expand Up @@ -5,5 +5,6 @@ import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_imp

LogisticRegressor createLogisticRegressorFromJson(String json) {
final decoded = jsonDecode(json) as Map<String, dynamic>;

return LogisticRegressorImpl.fromJson(decoded);
}
15 changes: 13 additions & 2 deletions lib/src/classifier/logistic_regressor/logistic_regressor.dart
Expand Up @@ -13,10 +13,10 @@ import 'package:ml_linalg/vector.dart';

/// Logistic regression-based classification.
///
/// Logistic regression is an algorithm that solves a binary classification
/// Logistic regression is an algorithm that solves the binary classification
/// problem. The algorithm uses maximization of the passed data likelihood.
/// In other words, the regressor iteratively tries to select coefficients
/// that makes combination of passed features and their coefficients most
/// that makes combination of passed features and these coefficients most
/// likely.
abstract class LogisticRegressor implements
LinearClassifier, Assessable, Serializable {
Expand Down Expand Up @@ -112,6 +112,11 @@ abstract class LogisticRegressor implements
/// [negativeLabel] Defines the value, that will be used for `negative` class.
/// By default, `0`.
///
/// [collectLearningData] Whether or not to collect learning data, for
/// instance cost function value per each iteration. Affects performance much.
/// If [collectLearningData] is true, one may access [costPerIteration]
/// getter in order to evaluate learning process more thoroughly.
///
/// [dtype] A data type for all the numeric values, used by the algorithm. Can
/// affect performance or accuracy of the computations. Default value is
/// [DType.float32]
Expand All @@ -136,6 +141,7 @@ abstract class LogisticRegressor implements
Vector initialCoefficients,
num positiveLabel = 1,
num negativeLabel = 0,
bool collectLearningData = false,
DType dtype = DType.float32,
}) => createLogisticRegressor(
trainData,
Expand All @@ -157,6 +163,7 @@ abstract class LogisticRegressor implements
initialCoefficients ?? Vector.empty(dtype: dtype),
positiveLabel,
negativeLabel,
collectLearningData,
dtype,
);

Expand Down Expand Up @@ -200,4 +207,8 @@ abstract class LogisticRegressor implements
/// ````
factory LogisticRegressor.fromJson(String json) =>
createLogisticRegressorFromJson(json);

/// Returns a list of cost values per each learning iteration. Returns null
/// if the parameter `collectLearningData` of the default constructor is false
List<num> get costPerIteration;
}

This file was deleted.

This file was deleted.

36 changes: 21 additions & 15 deletions lib/src/classifier/logistic_regressor/logistic_regressor_impl.dart
Expand Up @@ -3,7 +3,9 @@ import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_json_keys.dart';
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
import 'package:ml_algo/src/helpers/validate_class_labels.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_probability_threshold.dart';
import 'package:ml_algo/src/link_function/helpers/from_link_function_json.dart';
import 'package:ml_algo/src/link_function/helpers/link_function_to_json.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
Expand All @@ -15,7 +17,6 @@ import 'package:ml_linalg/from_dtype_json.dart';
import 'package:ml_linalg/from_matrix_json.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/matrix_to_json.dart';
import 'package:ml_linalg/vector.dart';

part 'logistic_regressor_impl.g.dart';

Expand All @@ -37,8 +38,11 @@ class LogisticRegressorImpl
this.probabilityThreshold,
this.negativeLabel,
this.positiveLabel,
this.costPerIteration,
this.dtype,
) {
validateProbabilityThreshold(probabilityThreshold);
validateClassLabels(positiveLabel, negativeLabel);
validateCoefficientsMatrix(coefficientsByClasses);

// Logistic regression specific check, it cannot be placed in
Expand Down Expand Up @@ -105,23 +109,25 @@ class LogisticRegressorImpl
final LinkFunction linkFunction;

@override
DataFrame predict(DataFrame testFeatures) {
final probabilities = getProbabilitiesMatrix(testFeatures).getColumn(0);

final classesList = probabilities
// TODO: use SIMD
.map((value) => value >= probabilityThreshold
? positiveLabel
: negativeLabel,
)
.toList(growable: false);
@JsonKey(
name: logisticRegressorCostPerIterationJsonKey,
includeIfNull: false,
)
final List<num> costPerIteration;

final classesMatrix = Matrix.fromColumns([
Vector.fromList(classesList),
]);
@override
DataFrame predict(DataFrame testFeatures) {
final predictedLabels = getProbabilitiesMatrix(testFeatures)
.mapColumns(
(column) => column.mapToVector(
(probability) => probability >= probabilityThreshold
? positiveLabel.toDouble()
: negativeLabel.toDouble()
),
);

return DataFrame.fromMatrix(
classesMatrix,
predictedLabels,
header: classNames,
);
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -7,3 +7,4 @@ const logisticRegressorProbabilityThresholdJsonKey = 'PT';
const logisticRegressorPositiveLabelJsonKey = 'PL';
const logisticRegressorNegativeLabelJsonKey = 'NL';
const logisticRegressorLinkFunctionJsonKey = 'LF';
const logisticRegressorCostPerIterationJsonKey = 'CPI';

0 comments on commit e1d8d17

Please sign in to comment.