/
create_knn_classifier.dart
67 lines (57 loc) · 1.9 KB
/
create_knn_classifier.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/helpers/validate_train_data.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_linalg/dtype.dart';
/// Creates a [KnnClassifier] from [trainData].
///
/// The steps are:
///
/// 1. [trainData] is checked with [validateTrainData] against [targetName].
/// 2. [featuresTargetSplit] divides the data into a features part and a
///    target part; both are converted to matrices of type [dtype].
/// 3. The class labels are collected from the [targetName] column.
/// 4. A kernel is created for [kernelType], and a solver is created from the
///    features matrix, the labels matrix, [k] and [distance].
/// 5. The classifier is produced by the [KnnClassifierFactory].
///
/// All factories are resolved through [dependencies].
KnnClassifier createKnnClassifier(
  DataFrame trainData,
  String targetName,
  int k,
  KernelType kernelType,
  Distance distance,
  DType dtype,
) {
  validateTrainData(trainData, [targetName]);

  final parts = featuresTargetSplit(
    trainData,
    targetNames: [targetName],
  ).toList();
  final featureFrame = parts[0];
  final targetFrame = parts[1];

  final featureMatrix = featureFrame.toMatrix(dtype);
  final labelMatrix = targetFrame.toMatrix(dtype);
  final classLabels = _collectClassLabels(targetFrame, targetName, dtype);

  final kernel = dependencies.get<KernelFactory>().createByType(kernelType);

  // NOTE(review): the meaning of the trailing `true` is defined by
  // `KnnSolverFactory.create`, which is not visible here - confirm there.
  final solver = dependencies.get<KnnSolverFactory>().create(
    featureMatrix,
    labelMatrix,
    k,
    distance,
    true,
  );

  return dependencies.get<KnnClassifierFactory>().create(
    targetName,
    classLabels,
    kernel,
    solver,
    dtype,
  );
}

/// Returns the class labels found in the [targetName] column of
/// [targetFrame] as a fixed-length list.
///
/// If the column reports itself as discrete, its `discreteValues` are cast
/// to [num]. Otherwise the unique values of the first column of the
/// [dtype] matrix built from [targetFrame] are returned.
List<num> _collectClassLabels(
  DataFrame targetFrame,
  String targetName,
  DType dtype,
) {
  final targetSeries = targetFrame[targetName];

  if (targetSeries.isDiscrete) {
    return targetSeries.discreteValues
        .map((dynamic value) => value as num)
        .toList(growable: false);
  }

  return targetFrame
      .toMatrix(dtype)
      .getColumn(0)
      .unique()
      .toList(growable: false);
}