-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
knn_regressor.dart
116 lines (111 loc) · 4.34 KB
/
knn_regressor.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import 'package:ml_algo/src/common/constants/default_parameters/common.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_algo/src/regressor/knn_regressor/_init_module.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_linalg/dtype.dart';
/// A class that performs regression based on the `k nearest neighbours`
/// algorithm.
///
/// The k nearest neighbours algorithm searches for the labelled observations
/// most similar to a given unlabelled one (the number of these observations
/// equals `k`).
///
/// In order to make a prediction, or rather to set a label for a given new
/// observation, the labels of the `k` found observations are summed up and
/// divided by `k`.
///
/// To get a more precise result, one may use a weighted average of the found
/// labels - the farther a found observation is from the target one, the lower
/// the weight of the observation is. To obtain these weights one may use a
/// kernel function.
abstract class KnnRegressor
    implements Assessable, Serializable, Retrainable<KnnRegressor>, Predictor {
  /// Creates a [KnnRegressor] from the given labelled observations.
  ///
  /// Construction is delegated to the [KnnRegressorFactory] resolved from the
  /// module returned by [initKnnRegressorModule].
  ///
  /// Parameters:
  ///
  /// [fittingData] Labelled observations among which the [k] nearest
  /// neighbours of a given unlabelled observation will be searched. Must
  /// contain the [targetName] column.
  ///
  /// [targetName] The name of the column that contains the labels (or
  /// outcomes).
  ///
  /// [k] The number of nearest neighbours to be found among [fittingData].
  ///
  /// [kernel] The type of the kernel function that will be used to predict an
  /// outcome for a new observation. Default value is [KernelType.gaussian].
  ///
  /// [distance] The distance type that will be used to measure the distance
  /// between two observation vectors. Default value is [Distance.euclidean].
  ///
  /// [dtype] The data type for all the numeric values used by the algorithm.
  /// Can affect performance or accuracy of the computations. Default value is
  /// [dTypeDefaultValue].
  factory KnnRegressor(
    DataFrame fittingData,
    String targetName,
    int k, {
    KernelType kernel = KernelType.gaussian,
    Distance distance = Distance.euclidean,
    DType dtype = dTypeDefaultValue,
  }) {
    final creator = initKnnRegressorModule().get<KnnRegressorFactory>();

    return creator.create(
      fittingData,
      targetName,
      k,
      kernel,
      distance,
      dtype,
    );
  }

  /// Restores a previously fitted regressor instance from the given [json].
  ///
  /// Restoration is delegated to the [KnnRegressorFactory] resolved from the
  /// module returned by [initKnnRegressorModule].
  ///
  /// ```dart
  /// import 'dart:io';
  /// import 'package:ml_dataframe/ml_dataframe.dart';
  ///
  /// final data = <Iterable>[
  ///   ['feature 1', 'feature 2', 'feature 3', 'outcome'],
  ///   [        5.0,         7.0,         6.0,       1.0],
  ///   [        1.0,         2.0,         3.0,       0.0],
  ///   [       10.0,        12.0,        31.0,       0.0],
  ///   [        9.0,         8.0,         5.0,       0.0],
  ///   [        4.0,         0.0,         1.0,       1.0],
  /// ];
  /// final targetName = 'outcome';
  /// final samples = DataFrame(data, headerExists: true);
  /// final regressor = KnnRegressor(
  ///   samples,
  ///   targetName,
  ///   3,
  /// );
  ///
  /// final pathToFile = './regressor.json';
  ///
  /// await regressor.saveAsJson(pathToFile);
  ///
  /// final file = File(pathToFile);
  /// final json = await file.readAsString();
  /// final restoredRegressor = KnnRegressor.fromJson(json);
  ///
  /// // here you can use the restored, previously fitted regressor to make
  /// // some prediction, e.g. via `restoredRegressor.predict(...)`;
  /// ```
  factory KnnRegressor.fromJson(String json) {
    final restorer = initKnnRegressorModule().get<KnnRegressorFactory>();

    return restorer.fromJson(json);
  }

  /// The number of nearest neighbours.
  ///
  /// The value is read-only, it's a hyperparameter of the model.
  int get k;

  /// The kernel type.
  ///
  /// The value is read-only, it's a hyperparameter of the model.
  KernelType get kernelType;

  /// The distance type that is used to measure the distance between two
  /// observations.
  ///
  /// The value is read-only, it's a hyperparameter of the model.
  Distance get distanceType;
}