-
-
Notifications
You must be signed in to change notification settings - Fork 29
/
cross_validator.dart
27 lines (24 loc) · 1.34 KB
/
cross_validator.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
import 'package:ml_algo/src/model_selection/data_splitter/k_fold.dart';
import 'package:ml_algo/src/model_selection/data_splitter/leave_p_out.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_linalg/matrix.dart';
/// A factory and an interface for all the cross validator types.
///
/// Obtain a validator through one of the named factory constructors,
/// [CrossValidator.kFold] or [CrossValidator.lpo]. Then call [evaluate] to
/// score a predictor against a dataset.
abstract class CrossValidator {
  /// Creates a k-fold validator to evaluate the quality of a predictor.
  ///
  /// The dataset is split into [numberOfFolds] test sets. The predictor is
  /// then evaluated on each of the produced test sets. [dtype] is forwarded
  /// as-is to the underlying validator implementation.
  factory CrossValidator.kFold({Type dtype, int numberOfFolds}) {
    final splitter = KFoldSplitter(numberOfFolds);
    return CrossValidatorImpl(dtype, splitter);
  }

  /// Creates a leave-p-out (LPO) validator to evaluate the quality of a
  /// predictor.
  ///
  /// The dataset is split into every possible test set of size [p]. The
  /// predictor's quality is then evaluated on each of the produced test sets.
  /// [dtype] is forwarded as-is to the underlying validator implementation.
  factory CrossValidator.lpo({Type dtype, int p}) {
    final splitter = LeavePOutSplitter(p);
    return CrossValidatorImpl(dtype, splitter);
  }

  /// Returns a quality score of a predictor, measured with the given [metric].
  ///
  /// [predictorFactory] builds a [Predictor] from a features matrix and an
  /// outcomes matrix. [observations] and [outcomes] form the dataset that the
  /// validator splits into test sets.
  double evaluate(
    Predictor Function(Matrix features, Matrix outcomes) predictorFactory,
    Matrix observations,
    Matrix outcomes,
    MetricType metric,
  );
}