Skip to content

Commit

Permalink
LinearRegressor: collect learning data parameter added to default con…
Browse files Browse the repository at this point in the history
…structor
  • Loading branch information
gyrdym committed Jun 20, 2020
1 parent e1d8d17 commit ba56110
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 27 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -1,9 +1,13 @@
# Changelog

# 13.9.0
## 13.10.0
- `LinearRegressor`:
- `Default constructor`: `collectLearningData` parameter added

## 13.9.0
- `LogisticRegressor`:
- `Default constructor`: `collectLearningData` parameter added

## 13.8.1
- `ml_dataframe` dependency updated
- `xrange` dependency constrain removed
Expand Down
13 changes: 13 additions & 0 deletions lib/src/regressor/linear_regressor/linear_regressor.dart
Expand Up @@ -94,6 +94,11 @@ abstract class LinearRegressor implements Assessable, Serializable, Predictor {
/// optimization algorithm. [initialCoefficients] should have length that is
/// equal to the number of features in the [fittingData].
///
/// [collectLearningData] Whether or not to collect learning data, for
/// instance, the cost function value at each iteration. Note that this can
/// noticeably affect performance.
/// If [collectLearningData] is true, one may access [costPerIteration]
/// getter in order to evaluate learning process more thoroughly.
///
/// [dtype] A data type for all the numeric values, used by the algorithm. Can
/// affect performance or accuracy of the computations. Default value is
/// [DType.float32].
Expand All @@ -113,6 +118,7 @@ abstract class LinearRegressor implements Assessable, Serializable, Predictor {
int batchSize = 1,
Matrix initialCoefficients,
bool isFittingDataNormalized = false,
bool collectLearningData = false,
DType dtype = DType.float32,
}) {
final optimizer = createSquaredCostOptimizer(
Expand All @@ -137,13 +143,16 @@ abstract class LinearRegressor implements Assessable, Serializable, Predictor {
final coefficients = optimizer.findExtrema(
initialCoefficients: initialCoefficients,
isMinimizingObjective: true,
collectLearningData: collectLearningData,
).getColumn(0);
final costPerIteration = optimizer.costPerIteration;

return LinearRegressorImpl(
coefficients,
targetName,
fitIntercept: fitIntercept,
interceptScale: interceptScale,
costPerIteration: costPerIteration,
dtype: dtype,
);
}
Expand Down Expand Up @@ -205,4 +214,8 @@ abstract class LinearRegressor implements Assessable, Serializable, Predictor {
/// A value defining a size of the intercept if [fitIntercept] is
/// `true`
num get interceptScale;

/// Returns a list of cost values, one for each learning iteration. Returns
/// `null` if the `collectLearningData` parameter of the default constructor
/// is `false`.
List<num> get costPerIteration;
}
8 changes: 8 additions & 0 deletions lib/src/regressor/linear_regressor/linear_regressor_impl.dart
Expand Up @@ -25,6 +25,7 @@ class LinearRegressorImpl
LinearRegressorImpl(this.coefficients, this.targetName, {
bool fitIntercept = false,
double interceptScale = 1.0,
this.costPerIteration,
this.dtype = DType.float32,
}) :
fitIntercept = fitIntercept,
Expand Down Expand Up @@ -56,6 +57,13 @@ class LinearRegressorImpl
)
final Vector coefficients;

@override
@JsonKey(
name: linearRegressorCostPerIterationJsonKey,
includeIfNull: false,
)
final List<num> costPerIteration;

@override
@JsonKey(
name: linearRegressorDTypeJsonKey,
Expand Down
32 changes: 22 additions & 10 deletions lib/src/regressor/linear_regressor/linear_regressor_impl.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -3,3 +3,4 @@ const linearRegressorFitInterceptJsonKey = 'FI';
const linearRegressorInterceptScaleJsonKey = 'IS';
const linearRegressorCoefficientsJsonKey = 'CS';
const linearRegressorDTypeJsonKey = 'DT';
const linearRegressorCostPerIterationJsonKey = 'CPI';
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 13.9.0
version: 13.10.0
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down
5 changes: 5 additions & 0 deletions test/regressor/linear_regressor_integration_test.dart
Expand Up @@ -32,6 +32,7 @@ void main() {
fitIntercept: fitIntercept,
interceptScale: interceptScale,
iterationsLimit: iterationsLimit,
collectLearningData: true,
dtype: dtype,
);
});
Expand All @@ -53,6 +54,7 @@ void main() {
linearRegressorInterceptScaleJsonKey: interceptScale,
linearRegressorCoefficientsJsonKey: regressor.coefficients.toJson(),
linearRegressorDTypeJsonKey: dTypeToJson(dtype),
linearRegressorCostPerIterationJsonKey: regressor.costPerIteration,
});
});

Expand All @@ -75,6 +77,8 @@ void main() {
test('should restore from json file', () async {
await regressor.saveAsJson(filePath);

print(regressor.costPerIteration);

final file = File(filePath);
final json = await file.readAsString();
final restoredModel = LinearRegressor.fromJson(json);
Expand All @@ -84,6 +88,7 @@ void main() {
expect(restoredModel.fitIntercept, regressor.fitIntercept);
expect(restoredModel.coefficients, regressor.coefficients);
expect(restoredModel.targetName, regressor.targetName);
expect(restoredModel.costPerIteration, regressor.costPerIteration);
});
});
}
52 changes: 37 additions & 15 deletions test/regressor/linear_regressor_test.dart
Expand Up @@ -20,18 +20,10 @@ import '../mocks.dart';

void main() {
group('LinearRegressor', () {
final initialCoefficients = Matrix.fromList([
[1],
[2],
[3],
[4],
[5],
]);

final initialCoefficients = Matrix.column([1, 2, 3, 4, 5]);
final learnedCoefficients = Matrix.fromColumns([
Vector.fromList([55, 66, 77, 88, 99]),
]);

final observations = DataFrame(
[
<num>[10, 20, 30, 40, 200],
Expand All @@ -40,17 +32,16 @@ void main() {
header: ['feature_1', 'feature_2', 'feature_3', 'feature_4', 'target'],
headerExists: false,
);
final costPerIteration = [1, 2, 3, 100];

CostFunction costFunctionMock;
CostFunctionFactory costFunctionFactoryMock;

LinearOptimizer linearOptimizerMock;
LinearOptimizerFactory linearOptimizerFactoryMock;

setUp(() {
costFunctionMock = CostFunctionMock();
costFunctionFactoryMock = createCostFunctionFactoryMock(costFunctionMock);

linearOptimizerMock = LinearOptimizerMock();
linearOptimizerFactoryMock = createLinearOptimizerFactoryMock(
linearOptimizerMock);
Expand All @@ -64,7 +55,9 @@ void main() {
when(linearOptimizerMock.findExtrema(
initialCoefficients: anyNamed('initialCoefficients'),
isMinimizingObjective: anyNamed('isMinimizingObjective'),
collectLearningData: anyNamed('collectLearningData'),
)).thenReturn(learnedCoefficients);
when(linearOptimizerMock.costPerIteration).thenReturn(costPerIteration);
});

tearDownAll(() => injector = null);
Expand Down Expand Up @@ -150,27 +143,56 @@ void main() {
fitIntercept: true,
interceptScale: 2.0,
);

final features = Matrix.fromList([
[55, 44, 33, 22],
[10, 88, 77, 11],
[12, 22, 39, 13],
]);

final featuresWithIntercept = Matrix.fromColumns([
Vector.filled(3, 2),
...features.columns,
]);

final prediction = predictor.predict(
DataFrame.fromMatrix(features),
);

expect(prediction.header, equals(['target']));

expect(prediction.toMatrix(), equals(
featuresWithIntercept * learnedCoefficients,
));
});

test('should collect cost values per iteration if collectLearningData is '
'true', () {
final regressor = LinearRegressor(
observations,
'target',
initialCoefficients: initialCoefficients,
collectLearningData: true,
);

expect(regressor.costPerIteration, same(costPerIteration));
verify(linearOptimizerMock.findExtrema(
initialCoefficients: anyNamed('initialCoefficients'),
isMinimizingObjective: anyNamed('isMinimizingObjective'),
collectLearningData: true,
)).called(1);
});

test('should not collect cost values per iteration if collectLearningData is '
'false', () {
LinearRegressor(
observations,
'target',
initialCoefficients: initialCoefficients,
collectLearningData: false,
);

verify(linearOptimizerMock.findExtrema(
initialCoefficients: anyNamed('initialCoefficients'),
isMinimizingObjective: anyNamed('isMinimizingObjective'),
collectLearningData: false,
)).called(1);
});
});
}

0 comments on commit ba56110

Please sign in to comment.