-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
decision_tree_classifier_impl.dart
64 lines (53 loc) · 1.74 KB
/
decision_tree_classifier_impl.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart';
import 'package:ml_algo/src/decision_tree_solver/decision_tree_solver.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';
/// [DecisionTreeClassifier] that delegates per-sample work to a
/// [DecisionTreeSolver].
///
/// Every row of the input features is passed to the solver to obtain a leaf
/// label. That label's value (in [predict]) or probability (in
/// [predictProbabilities]) becomes one cell of a single-column output
/// [DataFrame] whose header is [classNames].
class DecisionTreeClassifierImpl with AssessablePredictorMixin
    implements DecisionTreeClassifier {
  /// Creates a classifier backed by the given solver.
  ///
  /// [className] becomes the sole entry of [classNames]. Input features are
  /// converted to matrices of [dtype] before being handed to the solver.
  DecisionTreeClassifierImpl(this._solver, String className, this.dtype)
      : classNames = [className];

  /// Maps a single feature row to the leaf label used for prediction.
  final DecisionTreeSolver _solver;

  @override
  final DType dtype;

  /// Header of every output frame.
  ///
  /// Holds exactly one entry: the class name passed to the constructor.
  @override
  final List<String> classNames;

  /// Predicts a label value for every row of [features].
  ///
  /// Returns a one-column [DataFrame] headed by [classNames]. When the
  /// features matrix has no rows, `DataFrame([<num>[]])` is returned instead.
  @override
  DataFrame predict(DataFrame features) {
    final leafValues = _samplesOf(features)
        .map(_solver.getLabelForSample)
        .map((leaf) => leaf.value)
        .toList(growable: false);

    if (leafValues.isEmpty) {
      // No samples: skip vector/matrix construction entirely.
      return DataFrame([<num>[]]);
    }

    return _singleColumnFrame(leafValues);
  }

  /// Returns, for every row of [features], the probability attached to the
  /// leaf label the solver selects for that row.
  ///
  /// The result is a one-column [DataFrame] headed by [classNames].
  @override
  DataFrame predictProbabilities(DataFrame features) {
    // NOTE(review): unlike `predict`, empty input is not special-cased here;
    // confirm that `Matrix.fromColumns` accepts an empty column.
    final leafProbabilities = _samplesOf(features)
        .map(_solver.getLabelForSample)
        .map((leaf) => leaf.probability)
        .toList(growable: false);

    return _singleColumnFrame(leafProbabilities);
  }

  /// Converts [features] to a matrix of [dtype] and exposes its rows.
  Iterable<Vector> _samplesOf(DataFrame features) =>
      features.toMatrix(dtype).rows;

  /// Wraps [column] into a single-column [DataFrame] headed by [classNames].
  DataFrame _singleColumnFrame(List<num> column) {
    final columnVector = Vector.fromList(column, dtype: dtype);

    return DataFrame.fromMatrix(
      Matrix.fromColumns([columnVector], dtype: dtype),
      header: classNames,
    );
  }
}