// cross_validator.dart
// Observed benchmark run time: ~8.5 sec (machine-dependent).
import 'package:benchmark_harness/benchmark_harness.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';
/// How many random rows (observations) the benchmark dataset contains.
const int observationsNum = 1000;

/// How many columns each generated row has.
const int columnsNum = 21;
/// Benchmarks k-fold cross-validation of a [KnnRegressor] on random data.
///
/// [setup] builds an [observationsNum] x [columnsNum] matrix of random values
/// and wraps it in a 5-fold [CrossValidator]. Each [run] evaluates a
/// [KnnRegressor] (third constructor argument `7`, presumably k) with
/// [MetricType.mape].
class CrossValidatorBenchmark extends BenchmarkBase {
  CrossValidatorBenchmark() : super('Cross validator benchmark');

  /// Name of the column the regressor predicts (the last generated column).
  ///
  /// Derived from [columnsNum] so it stays in sync if the column count
  /// changes. NOTE(review): assumes `DataFrame.fromMatrix` names columns
  /// `col_0` ... `col_<n-1>`, as the previously hard-coded `'col_20'` implied.
  static const _targetName = 'col_${columnsNum - 1}';

  /// Cross-validator built in [setup] and reused by every [run].
  CrossValidator crossValidator;

  /// Runs this benchmark and reports the result via [report].
  static void main() {
    CrossValidatorBenchmark().report();
  }

  /// The measured body: one full k-fold evaluation.
  @override
  void run() {
    // NOTE(review): in some ml_algo versions `evaluate` returns a Future; if
    // so, only its synchronous part is timed here. Confirm against the pinned
    // ml_algo version.
    crossValidator.evaluate(
      (trainSamples) => KnnRegressor(trainSamples, _targetName, 7),
      MetricType.mape,
    );
  }

  /// Builds the random dataset and the 5-fold cross-validator.
  @override
  void setup() {
    final samples = Matrix.fromRows(
      List.generate(observationsNum, (_) => Vector.randomFilled(columnsNum)),
    );
    crossValidator = CrossValidator.kFold(
      DataFrame.fromMatrix(samples),
      numberOfFolds: 5,
    );
  }

  /// Lifecycle hook invoked by [BenchmarkBase]; nothing needs releasing.
  @override
  void teardown() {}

  /// Kept only for backward compatibility.
  ///
  /// This spelling does not override [BenchmarkBase.teardown], so the harness
  /// never calls it; override [teardown] instead.
  @Deprecated('Use teardown(), which BenchmarkBase actually invokes')
  void tearDown() => teardown();
}
/// Entry point: runs [CrossValidatorBenchmark] and reports its result.
///
/// Typed `Future<void>` instead of a raw `Future` because it produces no value.
/// It stays `async` so any caller that `await`s it keeps compiling.
Future<void> main() async {
  CrossValidatorBenchmark.main();
}