Skip to content

Commit

Permalink
collectLearningData argument added to SoftmaxRegressor default constr…
Browse files Browse the repository at this point in the history
…uctor
  • Loading branch information
gyrdym committed Jun 24, 2020
1 parent 470dc08 commit 4636074
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# Changelog

## 14.2.0
- `SoftmaxRegressor`:
- `Default constructor`: `collectLearningData` parameter added

## 14.1.1
- `README`: Advanced usage example on Logistic regression added

Expand Down
Expand Up @@ -32,6 +32,7 @@ SoftmaxRegressor createSoftmaxRegressor(
Matrix initialCoefficients,
num positiveLabel,
num negativeLabel,
bool collectLearningData,
DType dtype,
) {
if (targetNames.isNotEmpty && targetNames.length < 2) {
Expand Down Expand Up @@ -69,7 +70,9 @@ SoftmaxRegressor createSoftmaxRegressor(
final coefficientsByClasses = optimizer.findExtrema(
initialCoefficients: initialCoefficients,
isMinimizingObjective: false,
collectLearningData: collectLearningData,
);
final costPerIteration = optimizer.costPerIteration;

final regressorFactory = dependencies
.getDependency<SoftmaxRegressorFactory>();
Expand All @@ -82,6 +85,7 @@ SoftmaxRegressor createSoftmaxRegressor(
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);
}
11 changes: 11 additions & 0 deletions lib/src/classifier/softmax_regressor/softmax_regressor.dart
Expand Up @@ -105,6 +105,11 @@ abstract class SoftmaxRegressor implements
/// [negativeLabel] Defines the value that will be used for `negative` class.
/// By default, `0`.
///
/// [collectLearningData] Whether or not to collect learning data, for
/// instance cost function value per each iteration. Affects performance much.
/// If [collectLearningData] is true, one may access [costPerIteration]
/// getter in order to evaluate learning process more thoroughly.
///
/// [dtype] A data type for all the numeric values, used by the algorithm. Can
/// affect performance or accuracy of the computations. Default value is
/// [DType.float32]
Expand All @@ -129,6 +134,7 @@ abstract class SoftmaxRegressor implements
Matrix initialCoefficients,
num positiveLabel = 1,
num negativeLabel = 0,
bool collectLearningData = false,
DType dtype = DType.float32,
}
) => createSoftmaxRegressor(
Expand All @@ -150,6 +156,7 @@ abstract class SoftmaxRegressor implements
initialCoefficients,
positiveLabel,
negativeLabel,
collectLearningData,
dtype,
);

Expand Down Expand Up @@ -192,4 +199,8 @@ abstract class SoftmaxRegressor implements
/// ````
factory SoftmaxRegressor.fromJson(String json) =>
createSoftmaxRegressorFromJson(json);

/// Returns a list of cost values per each learning iteration. Returns null
/// if the parameter `collectLearningData` of the default constructor is false
List<num> get costPerIteration;
}
Expand Up @@ -12,6 +12,7 @@ abstract class SoftmaxRegressorFactory {
num interceptScale,
num positiveLabel,
num negativeLabel,
List<num> costPerIteration,
DType dtype,
);
}
Expand Up @@ -17,6 +17,7 @@ class SoftmaxRegressorFactoryImpl implements SoftmaxRegressorFactory {
num interceptScale,
num positiveLabel,
num negativeLabel,
List<num> costPerIteration,
DType dtype,
) => SoftmaxRegressorImpl(
coefficientsByClasses,
Expand All @@ -26,6 +27,7 @@ class SoftmaxRegressorFactoryImpl implements SoftmaxRegressorFactory {
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);
}
Expand Up @@ -33,6 +33,7 @@ class SoftmaxRegressorImpl
this.interceptScale,
this.positiveLabel,
this.negativeLabel,
this.costPerIteration,
this.dtype,
) {
validateClassLabels(positiveLabel, negativeLabel);
Expand Down Expand Up @@ -94,6 +95,13 @@ class SoftmaxRegressorImpl
@JsonKey(name: softmaxRegressorNegativeLabelJsonKey)
final num negativeLabel;

@override
@JsonKey(
name: softmaxRegressorCostPerIterationJsonKey,
includeIfNull: false,
)
final List<num> costPerIteration;

@override
DataFrame predict(DataFrame testFeatures) {
final allProbabilities = getProbabilitiesMatrix(testFeatures);
Expand Down
48 changes: 35 additions & 13 deletions lib/src/classifier/softmax_regressor/softmax_regressor_impl.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -6,3 +6,4 @@ const softmaxRegressorDTypeJsonKey = 'DT';
const softmaxRegressorLinkFunctionJsonKey = 'LF';
const softmaxRegressorPositiveLabelJsonKey = 'PL';
const softmaxRegressorNegativeLabelJsonKey = 'NL';
const softmaxRegressorCostPerIterationJsonKey = 'CPI';
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 14.1.1
version: 14.2.0
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down
Expand Up @@ -22,6 +22,7 @@ void main() {
final interceptScale = 1;
final positiveLabel = 1;
final negativeLabel = -1;
final costPerIteration = [1, 2, 3];
final dtype = DType.float32;

final regressor = factory.create(
Expand All @@ -32,6 +33,7 @@ void main() {
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);

Expand Down
Expand Up @@ -16,6 +16,7 @@ void main() {
final interceptScale = 10;
final positiveLabel = 1.0;
final negativeLabel = -1.0;
final costPerIteration = [10, -10, 20, 2.3];
final dtype = DType.float32;

final coefficientsByClasses = Matrix.fromList([
Expand Down Expand Up @@ -91,6 +92,7 @@ void main() {
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);
});
Expand All @@ -109,6 +111,7 @@ void main() {
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);

Expand All @@ -129,6 +132,7 @@ void main() {
interceptScale,
positiveLabel,
negativeLabel,
costPerIteration,
dtype,
);

Expand Down
2 changes: 1 addition & 1 deletion test/mocks.dart
Expand Up @@ -252,7 +252,7 @@ KnnRegressorFactory createKnnRegressorFactoryMock(KnnRegressor regressor) {
SoftmaxRegressorFactory createSoftmaxRegressorFactoryMock(
SoftmaxRegressor softmaxRegressor) {
final factory = SoftmaxRegressorFactoryMock();
when(factory.create(any, any, any, any, any, any, any, any))
when(factory.create(any, any, any, any, any, any, any, any, any))
.thenReturn(softmaxRegressor);
return factory;
}
Expand Down

0 comments on commit 4636074

Please sign in to comment.