Skip to content

Commit

Permalink
SoftmacRegressorImpl.predictProbabilities method covered with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Oct 31, 2019
1 parent 9a5bd64 commit 724b3e4
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 43 deletions.
19 changes: 11 additions & 8 deletions lib/src/classifier/_mixins/linear_classifier_mixin.dart
@@ -1,22 +1,25 @@
import 'package:ml_algo/src/classifier/linear_classifier.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/get_probabilities.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_test_features.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

mixin LinearClassifierMixin implements LinearClassifier {
@override
DataFrame predictProbabilities(DataFrame features) {
DataFrame predictProbabilities(DataFrame testFeatures) {
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
fitIntercept,
features.toMatrix(),
testFeatures.toMatrix(dtype),
interceptScale,
);

final probabilities = getProbabilities(
processedFeatures,
coefficientsByClasses,
linkFunction,
);
validateCoefficientsMatrix(coefficientsByClasses,
processedFeatures.columnsNum);

final probabilities = linkFunction
.link(processedFeatures * coefficientsByClasses);

return DataFrame.fromMatrix(
probabilities,
Expand Down
Expand Up @@ -9,10 +9,11 @@ import 'package:ml_linalg/vector.dart';
class DecisionTreeClassifierImpl with AssessablePredictorMixin
implements DecisionTreeClassifier {

DecisionTreeClassifierImpl(this._solver, String className, this._dtype)
DecisionTreeClassifierImpl(this._solver, String className, this.dtype)
: classNames = [className];

final DType _dtype;
@override
final DType dtype;

final DecisionTreeSolver _solver;

Expand All @@ -22,7 +23,7 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin
@override
DataFrame predict(DataFrame features) {
final predictedLabels = features
.toMatrix(_dtype)
.toMatrix(dtype)
.rows
.map(_solver.getLabelForSample);

Expand All @@ -33,10 +34,10 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin
final outcomeList = predictedLabels
.map((label) => label.value)
.toList(growable: false);
final outcomeVector = Vector.fromList(outcomeList, dtype: _dtype);
final outcomeVector = Vector.fromList(outcomeList, dtype: dtype);

return DataFrame.fromMatrix(
Matrix.fromColumns([outcomeVector], dtype: _dtype),
Matrix.fromColumns([outcomeVector], dtype: dtype),
header: classNames,
);
}
Expand All @@ -46,14 +47,14 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin
final probabilities = Matrix.fromColumns([
Vector.fromList(
features
.toMatrix(_dtype)
.toMatrix(dtype)
.rows
.map(_solver.getLabelForSample)
.map((label) => label.probability)
.toList(growable: false),
dtype: _dtype,
dtype: dtype,
),
], dtype: _dtype);
], dtype: dtype);

return DataFrame.fromMatrix(
probabilities,
Expand Down
22 changes: 12 additions & 10 deletions lib/src/classifier/knn_classifier/knn_classifier_impl.dart
Expand Up @@ -16,23 +16,25 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {
this._classLabels,
this._kernel,
this._solver,
this._dtype,
this.dtype,
) : classNames = [targetName] {
validateClassLabelList(_classLabels);
}

@override
final List<String> classNames;

@override
final DType dtype;

final List<num> _classLabels;
final Kernel _kernel;
final KnnSolver _solver;
final DType _dtype;
final String _columPrefix = 'Class label';
final String _columnPrefix = 'Class label';

@override
DataFrame predict(DataFrame features) {
validateTestFeatures(features, _dtype);
validateTestFeatures(features, dtype);

final labelsToProbabilities = _getLabelToProbabilityMapping(features);
final labels = labelsToProbabilities.keys.toList();
Expand All @@ -41,10 +43,10 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {
.map((row) => labels[row.toList().indexOf(row.max())])
.toList();

final outcomesAsVector = Vector.fromList(predictedOutcomes, dtype: _dtype);
final outcomesAsVector = Vector.fromList(predictedOutcomes, dtype: dtype);

return DataFrame.fromMatrix(
Matrix.fromColumns([outcomesAsVector], dtype: _dtype),
Matrix.fromColumns([outcomesAsVector], dtype: dtype),
header: classNames,
);
}
Expand All @@ -56,7 +58,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {

final header = labelsToProbabilities
.keys
.map((label) => '${_columPrefix} ${label.toString()}');
.map((label) => '${_columnPrefix} ${label.toString()}');

return DataFrame.fromMatrix(probabilityMatrix, header: header);
}
Expand Down Expand Up @@ -84,7 +86,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {
/// where each row is a classes probability distribution for the appropriate
/// feature record from test feature matrix
Map<num, List<num>> _getLabelToProbabilityMapping(DataFrame features) {
final kNeighbourGroups = _solver.findKNeighbours(features.toMatrix(_dtype));
final kNeighbourGroups = _solver.findKNeighbours(features.toMatrix(dtype));
final classLabelsAsSet = Set<num>.from(_classLabels);

return kNeighbourGroups.fold<Map<num, List<num>>>(
Expand Down Expand Up @@ -133,11 +135,11 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {
Matrix _getProbabilityMatrix(Map<num, List<num>> allLabelsToProbabilities) {
final probabilityVectors = allLabelsToProbabilities
.values
.map((probabilities) => Vector.fromList(probabilities, dtype: _dtype))
.map((probabilities) => Vector.fromList(probabilities, dtype: dtype))
.toList(growable: false);

return Matrix
.fromColumns(probabilityVectors, dtype: _dtype);
.fromColumns(probabilityVectors, dtype: dtype);
}

Map<num, num> _updateLabelToWeightMapping(
Expand Down
Expand Up @@ -22,7 +22,7 @@ class LogisticRegressorImpl with LinearClassifierMixin,
this._probabilityThreshold,
this._negativeLabel,
this._positiveLabel,
this._dtype,
this.dtype,
) : classNames = [className] {
validateCoefficientsMatrix(coefficientsByClasses);

Expand Down Expand Up @@ -53,18 +53,20 @@ class LogisticRegressorImpl with LinearClassifierMixin,
@override
final LinkFunction linkFunction;

final DType _dtype;
@override
final DType dtype;

final num _probabilityThreshold;
final num _positiveLabel;
final num _negativeLabel;

@override
DataFrame predict(DataFrame testFeatures) {
validateTestFeatures(testFeatures, _dtype);
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
fitIntercept,
testFeatures.toMatrix(_dtype),
testFeatures.toMatrix(dtype),
interceptScale,
);

Expand Down
12 changes: 6 additions & 6 deletions lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart
@@ -1,7 +1,6 @@
import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/get_probabilities.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_test_features.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
Expand All @@ -22,7 +21,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin,
this.interceptScale,
this._positiveLabel,
this._negativeLabel,
this._dtype,
this.dtype,
);

@override
Expand All @@ -37,7 +36,8 @@ class SoftmaxRegressorImpl with LinearClassifierMixin,
@override
final Matrix coefficientsByClasses;

final DType _dtype;
@override
final DType dtype;

@override
final LinkFunction linkFunction;
Expand All @@ -48,11 +48,11 @@ class SoftmaxRegressorImpl with LinearClassifierMixin,

@override
DataFrame predict(DataFrame testFeatures) {
validateTestFeatures(testFeatures, _dtype);
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
fitIntercept,
testFeatures.toMatrix(_dtype),
testFeatures.toMatrix(dtype),
interceptScale,
);

Expand All @@ -74,7 +74,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin,

predictedRow[positiveLabelIdx] = _positiveLabel;

return Vector.fromList(predictedRow, dtype: _dtype);
return Vector.fromList(predictedRow, dtype: dtype);
});

return DataFrame.fromMatrix(
Expand Down
3 changes: 3 additions & 0 deletions lib/src/predictor/predictor.dart
@@ -1,7 +1,10 @@
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';

/// A common interface for all types of classifiers and regressors
abstract class Predictor {
/// Returns prediction, based on the model learned parameters
DataFrame predict(DataFrame testFeatures);

DType get dtype;
}
14 changes: 8 additions & 6 deletions lib/src/regressor/knn_regressor/knn_regressor_impl.dart
Expand Up @@ -13,25 +13,27 @@ class KnnRegressorImpl with AssessablePredictorMixin implements KnnRegressor {
this._targetName,
this._solver,
this._kernel,
this._dtype,
this.dtype,
);

@override
final DType dtype;

final String _targetName;
final KnnSolver _solver;
final Kernel _kernel;
final DType _dtype;

Vector get _zeroVector => _cachedZeroVector ??= Vector.zero(1, dtype: _dtype);
Vector get _zeroVector => _cachedZeroVector ??= Vector.zero(1, dtype: dtype);
Vector _cachedZeroVector;

@override
DataFrame predict(DataFrame testFeatures) {
validateTestFeatures(testFeatures, _dtype);
validateTestFeatures(testFeatures, dtype);

final prediction = Matrix.fromRows(
_predictOutcomes(testFeatures.toMatrix(_dtype))
_predictOutcomes(testFeatures.toMatrix(dtype))
.toList(growable: false),
dtype: _dtype,
dtype: dtype,
);

return DataFrame.fromMatrix(
Expand Down
Expand Up @@ -143,11 +143,73 @@ void main() {
expect(actual.toMatrix(dtype), equals(expectedOutcomeMatrix));
});

test('should return a dataframe with proper header', () {
test('should predict the first class if outcome is equiprobable', () {
reset(linkFunctionMock);

when(linkFunctionMock.link(any)).thenReturn(Matrix.fromList([
[0.33, 0.33, 0.33],
]));

final actual = regressor.predict(testFeatures);
final expectedOutcome = Matrix.fromList([
[positiveLabel, negativeLabel, negativeLabel],
]);

expect(actual.toMatrix(dtype), equals(expectedOutcome));
});

test('should return a dataframe with a proper header', () {
final actual = regressor.predict(testFeatures);

expect(actual.header, equals(targetNames));
});
});

group('`predictProbabilities` method', () {
test('should throw an exception if no features provided', () {
final testFeatures = DataFrame.fromMatrix(Matrix.empty());

expect(() => regressor.predictProbabilities(testFeatures),
throwsException);
});

test('should throw an exception if too few features provided', () {
final testFeatures = DataFrame.fromMatrix(Matrix.fromList([
[1, 2],
]));

expect(() => regressor.predictProbabilities(testFeatures),
throwsException);
});

test('should throw an exception if too many features provided', () {
final testFeatures = DataFrame.fromMatrix(Matrix.fromList([
[1, 2, 4, 4, 5, 6],
]));

expect(() => regressor.predictProbabilities(testFeatures),
throwsException);
});

test('should consider intercept term', () {
regressor.predictProbabilities(testFeatures);

verify(linkFunctionMock.link(
testFeaturesMatrixWithIntercept * coefficientsByClasses)
).called(1);
});

test('should return probabilities as dataframe', () {
final probabilities = regressor.predictProbabilities(testFeatures);

expect(probabilities.rows, equals(mockedProbabilities));
});

test('should return a dataframe with a proper header', () {
final probabilities = regressor.predictProbabilities(testFeatures);

expect(probabilities.header, equals(targetNames));
});
});
});
}

0 comments on commit 724b3e4

Please sign in to comment.