-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
softmax_regressor_impl.dart
70 lines (54 loc) · 1.71 KB
/
softmax_regressor_impl.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/linalg.dart';
import 'package:ml_linalg/matrix.dart';
/// Concrete [SoftmaxRegressor] built from a per-class coefficient matrix.
///
/// Probability computation comes from the mixed-in behaviour
/// ([LinearClassifierMixin] / [AssessablePredictorMixin]); this class turns
/// those probabilities into one-hot encoded class labels.
class SoftmaxRegressorImpl
    with LinearClassifierMixin, AssessablePredictorMixin
    implements SoftmaxRegressor {
  SoftmaxRegressorImpl(
    this.coefficientsByClasses,
    this.classNames,
    this.linkFunction,
    this.fitIntercept,
    this.interceptScale,
    this._positiveLabel,
    this._negativeLabel,
    this.dtype,
  );

  /// Column headers of the [DataFrame] returned by [predict].
  @override
  final List<String> classNames;

  @override
  final bool fitIntercept;

  @override
  final num interceptScale;

  /// Coefficient matrix; its column count is used by [predict] as the width
  /// of every predicted row.
  @override
  final Matrix coefficientsByClasses;

  /// Data type of the vectors produced by [predict].
  @override
  final DType dtype;

  @override
  final LinkFunction linkFunction;

  /// Value written into the winning class column of a predicted row.
  final num _positiveLabel;

  /// Value written into every non-winning column of a predicted row.
  final num _negativeLabel;

  /// Predicts a one-hot encoded class label for every row of [testFeatures].
  ///
  /// Returns a [DataFrame] headed by [classNames] in which, for each input
  /// row, the column with the highest probability holds the positive label
  /// and all remaining columns hold the negative label.
  @override
  DataFrame predict(DataFrame testFeatures) {
    final probabilityMatrix = getProbabilitiesMatrix(testFeatures);
    final labelMatrix = probabilityMatrix.mapRows(_toOneHotRow);

    return DataFrame.fromMatrix(labelMatrix, header: classNames);
  }

  /// Converts one row of class probabilities into a one-hot label vector.
  ///
  /// On ties the first column holding the maximum wins, because
  /// [List.indexOf] returns the first matching position.
  // NOTE(review): if `max()` can yield NaN for a row, `indexOf` returns -1
  // and the indexed assignment below throws a RangeError — confirm whether
  // probabilities can ever be NaN here.
  // NOTE(review): assumes `coefficientsByClasses.columnsNum` equals
  // `classNames.length` so the row width matches the header — verify.
  Vector _toOneHotRow(Vector rowProbabilities) {
    final rowValues = rowProbabilities.toList();
    final winnerIdx = rowValues.indexOf(rowProbabilities.max());
    final encoded = List.filled(
      coefficientsByClasses.columnsNum,
      _negativeLabel,
    );

    encoded[winnerIdx] = _positiveLabel;

    return Vector.fromList(encoded, dtype: dtype);
  }
}