Skip to content

Commit

Permalink
e2e tests task added
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Sep 8, 2020
1 parent d38a809 commit 30bf630
Show file tree
Hide file tree
Showing 11 changed files with 16 additions and 26 deletions.
2 changes: 1 addition & 1 deletion lib/src/classifier/_mixins/linear_classifier_mixin.dart
Expand Up @@ -12,7 +12,7 @@ mixin LinearClassifierMixin implements LinearClassifier {

return DataFrame.fromMatrix(
probabilities,
header: classNames,
header: targetNames,
);
}

Expand Down
2 changes: 0 additions & 2 deletions lib/src/classifier/linear_classifier.dart
Expand Up @@ -4,8 +4,6 @@ import 'package:ml_linalg/matrix.dart';

/// An interface for all types of linear classifiers
abstract class LinearClassifier implements Classifier {
Iterable<String> get classNames;

/// A function that is used for converting learned coefficients into
/// probabilities
LinkFunction get linkFunction;
Expand Down
Expand Up @@ -30,7 +30,7 @@ class LogisticRegressorImpl
LogisticRegressor {

LogisticRegressorImpl(
this.classNames,
this.targetNames,
this.linkFunction,
this.fitIntercept,
this.interceptScale,
Expand Down Expand Up @@ -72,9 +72,8 @@ class LogisticRegressorImpl
final Matrix coefficientsByClasses;

@override
@deprecated
@JsonKey(name: logisticRegressorClassNamesJsonKey)
final Iterable<String> classNames;
final Iterable<String> targetNames;

@override
@JsonKey(name: logisticRegressorFitInterceptJsonKey)
Expand Down Expand Up @@ -118,9 +117,6 @@ class LogisticRegressorImpl
)
final List<num> costPerIteration;

@override
Iterable<String> get targetNames => classNames;

@override
DataFrame predict(DataFrame testFeatures) {
final predictedLabels = getProbabilitiesMatrix(testFeatures)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -27,7 +27,7 @@ class SoftmaxRegressorImpl

SoftmaxRegressorImpl(
this.coefficientsByClasses,
this.classNames,
this.targetNames,
this.linkFunction,
this.fitIntercept,
this.interceptScale,
Expand All @@ -54,9 +54,8 @@ class SoftmaxRegressorImpl
Map<String, dynamic> toJson() => _$SoftmaxRegressorImplToJson(this);

@override
@deprecated
@JsonKey(name: softmaxRegressorClassNamesJsonKey)
final Iterable<String> classNames;
final Iterable<String> targetNames;

@override
@JsonKey(name: softmaxRegressorFitInterceptJsonKey)
Expand Down Expand Up @@ -105,9 +104,6 @@ class SoftmaxRegressorImpl
)
final List<num> costPerIteration;

@override
Iterable<String> get targetNames => classNames;

@override
DataFrame predict(DataFrame testFeatures) {
final allProbabilities = getProbabilitiesMatrix(testFeatures);
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -316,7 +316,7 @@ void main() {
expect(restoredClassifier.dtype, classifier.dtype);
expect(restoredClassifier.linkFunction,
isA<Float32InverseLogitLinkFunction>());
expect(restoredClassifier.classNames, [targetName]);
expect(restoredClassifier.targetNames, [targetName]);
});

test('should restore a classifier instance from json file, '
Expand All @@ -335,7 +335,7 @@ void main() {
expect(restoredClassifier.dtype, classifier.dtype);
expect(restoredClassifier.linkFunction,
isA<Float64InverseLogitLinkFunction>());
expect(restoredClassifier.classNames, [targetName]);
expect(restoredClassifier.targetNames, [targetName]);
});
});
}
Expand Up @@ -68,7 +68,7 @@ void main() {
group('default constructor', () {
test('should create the instance with `classNames` list of just one '
'element', () {
expect(regressor.classNames, equals([className]));
expect(regressor.targetNames, equals([className]));
});

test('should throw an exception if probability threshold is less '
Expand Down
Expand Up @@ -274,7 +274,7 @@ void main() {
expect(classifier.fitIntercept, fitIntercept);
expect(classifier.dtype, dtype);
expect(classifier.coefficientsByClasses, learnedCoefficients);
expect(classifier.classNames, [targetName]);
expect(classifier.targetNames, [targetName]);
});
});
}
Expand Up @@ -230,7 +230,7 @@ void main() {

expect(restoredClassifier.interceptScale, classifier.interceptScale);
expect(restoredClassifier.fitIntercept, classifier.fitIntercept);
expect(restoredClassifier.classNames, classifier.classNames);
expect(restoredClassifier.targetNames, classifier.targetNames);
expect(restoredClassifier.coefficientsByClasses,
classifier.coefficientsByClasses);
expect(restoredClassifier.linkFunction.runtimeType,
Expand All @@ -249,7 +249,7 @@ void main() {

expect(restoredClassifier.interceptScale, classifier.interceptScale);
expect(restoredClassifier.fitIntercept, classifier.fitIntercept);
expect(restoredClassifier.classNames, classifier.classNames);
expect(restoredClassifier.targetNames, classifier.targetNames);
expect(restoredClassifier.coefficientsByClasses,
classifier.coefficientsByClasses);
expect(restoredClassifier.linkFunction.runtimeType,
Expand Down
Expand Up @@ -47,7 +47,7 @@ void main() {
test('should persist data passed to the `create` method', () {
expect(regressor.costPerIteration, costPerIteration);
expect(regressor.dtype, dtype);
expect(regressor.classNames, classNames);
expect(regressor.targetNames, classNames);
expect(regressor.interceptScale, interceptScale);
expect(regressor.linkFunction, linkFunction);
expect(regressor.fitIntercept, fitIntercept);
Expand Down

0 comments on commit 30bf630

Please sign in to comment.