Skip to content

Commit

Permalink
LinearClassifierMixin extended
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Oct 31, 2019
1 parent 724b3e4 commit 7798c38
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 40 deletions.
17 changes: 11 additions & 6 deletions lib/src/classifier/_mixins/linear_classifier_mixin.dart
Expand Up @@ -3,10 +3,20 @@ import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_test_features.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/matrix.dart';

mixin LinearClassifierMixin implements LinearClassifier {
@override
DataFrame predictProbabilities(DataFrame testFeatures) {
final probabilities = getProbabilitiesMatrix(testFeatures);

return DataFrame.fromMatrix(
probabilities,
header: classNames,
);
}

Matrix getProbabilitiesMatrix(DataFrame testFeatures) {
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
Expand All @@ -18,12 +28,7 @@ mixin LinearClassifierMixin implements LinearClassifier {
validateCoefficientsMatrix(coefficientsByClasses,
processedFeatures.columnsNum);

final probabilities = linkFunction
return linkFunction
.link(processedFeatures * coefficientsByClasses);

return DataFrame.fromMatrix(
probabilities,
header: classNames,
);
}
}
@@ -1,8 +1,6 @@
import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_test_features.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
Expand Down Expand Up @@ -62,20 +60,7 @@ class LogisticRegressorImpl with LinearClassifierMixin,

@override
DataFrame predict(DataFrame testFeatures) {
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
fitIntercept,
testFeatures.toMatrix(dtype),
interceptScale,
);

validateCoefficientsMatrix(coefficientsByClasses,
processedFeatures.columnsNum);

final probabilities = linkFunction
.link(processedFeatures * coefficientsByClasses)
.getColumn(0);
final probabilities = getProbabilitiesMatrix(testFeatures).getColumn(0);

final classesList = probabilities
// TODO: use SIMD
Expand Down
17 changes: 1 addition & 16 deletions lib/src/classifier/softmax_regressor/softmax_regressor_impl.dart
@@ -1,8 +1,5 @@
import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/validate_coefficients_matrix.dart';
import 'package:ml_algo/src/helpers/validate_test_features.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
Expand Down Expand Up @@ -48,19 +45,7 @@ class SoftmaxRegressorImpl with LinearClassifierMixin,

@override
DataFrame predict(DataFrame testFeatures) {
validateTestFeatures(testFeatures, dtype);

final processedFeatures = addInterceptIf(
fitIntercept,
testFeatures.toMatrix(dtype),
interceptScale,
);

validateCoefficientsMatrix(coefficientsByClasses,
processedFeatures.columnsNum);

final allProbabilities = linkFunction
.link(processedFeatures * coefficientsByClasses);
final allProbabilities = getProbabilitiesMatrix(testFeatures);

final classes = allProbabilities.mapRows((probabilities) {
final positiveLabelIdx = probabilities
Expand Down
4 changes: 2 additions & 2 deletions lib/src/predictor/assessable_predictor_mixin.dart
Expand Up @@ -15,8 +15,8 @@ mixin AssessablePredictorMixin implements Assessable, Predictor {

final metric = MetricFactory.createByType(metricType);
final prediction = predict(splits[0]);
final origLabels = splits[1].toMatrix();
final origLabels = splits[1].toMatrix(dtype);

return metric.getScore(prediction.toMatrix(), origLabels);
return metric.getScore(prediction.toMatrix(dtype), origLabels);
}
}

0 comments on commit 7798c38

Please sign in to comment.