/
cross_validator.dart
144 lines (137 loc) · 4.73 KB
/
cross_validator.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_factory.dart';
import 'package:ml_algo/src/model_selection/data_splitter/data_splitter_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/linalg.dart';
/// Builds the [Assessable] predictor that a [CrossValidator] evaluates.
///
/// Receives the data to build the predictor from ([observations]) and the
/// names of the outcome columns in it ([targetNames]).
typedef PredictorFactory = Assessable Function(
  DataFrame observations,
  Iterable<String> targetNames
);
/// Transforms a single train/test split before it reaches the predictor.
///
/// Takes the [trainData] and [testData] of a split and returns a list whose
/// first element is the transformed train data and whose second element is
/// the transformed test data.
typedef DataPreprocessFn = List<DataFrame> Function(
  DataFrame trainData,
  DataFrame testData
);
/// A factory and an interface for all the cross validator types.
abstract class CrossValidator {
  /// Creates a k-fold validator to evaluate the quality of a predictor.
  ///
  /// The validator splits [samples] into [numberOfFolds] test sets and then
  /// evaluates the given predictor on each of the produced test sets.
  ///
  /// Parameters:
  ///
  /// [samples] A dataset that will be split into parts to iteratively
  /// evaluate the predictor's performance.
  ///
  /// [targetColumnNames] Names of the columns of [samples] that contain
  /// outcomes.
  ///
  /// [numberOfFolds] Number of splits of [samples]. Defaults to `5`.
  ///
  /// [dtype] A type for all the numerical data. Defaults to
  /// [DType.float32].
  factory CrossValidator.kFold(
    DataFrame samples,
    Iterable<String> targetColumnNames, {
    int numberOfFolds = 5,
    DType dtype = DType.float32,
  }) =>
      CrossValidatorImpl(
        samples,
        targetColumnNames,
        _splitterFactory.createByType(
          DataSplitterType.kFold,
          numberOfFolds: numberOfFolds,
        ),
        dtype,
      );

  /// Creates a leave-p-out (LPO) validator to evaluate the quality of a
  /// predictor.
  ///
  /// The validator splits [samples] into all possible test sets of size [p]
  /// and then evaluates the predictor on each of the produced test sets.
  ///
  /// Parameters:
  ///
  /// [samples] A dataset that will be split into parts to iteratively
  /// evaluate the predictor's performance.
  ///
  /// [targetColumnNames] Names of the columns of [samples] that contain
  /// outcomes.
  ///
  /// [p] Size of a split of [samples].
  ///
  /// [dtype] A type for all the numerical data. Defaults to
  /// [DType.float32].
  factory CrossValidator.lpo(
    DataFrame samples,
    Iterable<String> targetColumnNames,
    int p, {
    DType dtype = DType.float32,
  }) =>
      CrossValidatorImpl(
        samples,
        targetColumnNames,
        _splitterFactory.createByType(DataSplitterType.lpo, p: p),
        dtype,
      );

  /// Looks up the [DataSplitterFactory] in the dependency container.
  ///
  /// Shared by both factory constructors; the lookup is performed anew on
  /// every access, exactly once per constructed validator.
  static DataSplitterFactory get _splitterFactory =>
      dependencies.getDependency<DataSplitterFactory>();

  /// Returns a future that resolves with a vector of quality scores of the
  /// predictor, computed with the given [metricType].
  ///
  /// Parameters:
  ///
  /// [predictorFactory] A factory function that returns the predictor to be
  /// evaluated.
  ///
  /// [metricType] A metric used to assess the predictor created by
  /// [predictorFactory].
  ///
  /// [onDataSplit] A callback that is called when a new train-test split is
  /// ready to be passed to the predictor being evaluated. This is the place
  /// for additional data-dependent logic, e.g. data preprocessing. The
  /// callback accepts the train and the test data of the new split and
  /// returns the transformed split as a list, where the first element is the
  /// train data and the second one is the test data, both of [DataFrame]
  /// type. The transformed split is then passed to the predictor.
  ///
  /// Example:
  ///
  /// ```dart
  /// final header = ['col_0', 'col_1', 'col_2', 'col_3'];
  /// final data = DataFrame(
  ///   <Iterable<num>>[
  ///     [ 1,  1,  1,   1],
  ///     [ 2,  3,  4,   5],
  ///     [18, 71, 15,  61],
  ///     [19,  0, 21, 331],
  ///     [11, 10,  9,  40],
  ///   ],
  ///   header: header,
  ///   headerExists: false,
  /// );
  /// final predictorFactory = (trainData, _) =>
  ///     KnnRegressor(trainData, 'col_3', k: 4);
  /// final onDataSplit = (trainData, testData) {
  ///   final standardizer = Standardizer(trainData);
  ///   return [
  ///     standardizer.process(trainData),
  ///     standardizer.process(testData),
  ///   ];
  /// };
  /// final validator = CrossValidator.kFold(data, ['col_3']);
  /// final scores = await validator.evaluate(
  ///   predictorFactory,
  ///   MetricType.mape,
  ///   onDataSplit: onDataSplit,
  /// );
  /// final averageScore = scores.mean();
  ///
  /// print(averageScore);
  /// ```
  Future<Vector> evaluate(
    PredictorFactory predictorFactory,
    MetricType metricType, {
    DataPreprocessFn onDataSplit,
  });
}