-
-
Notifications
You must be signed in to change notification settings - Fork 29
/
cross_validator_impl.dart
102 lines (91 loc) · 3.71 KB
/
cross_validator_impl.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart';
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';
import 'package:quiver/iterables.dart';
/// A [CrossValidator] that evaluates a predictor on [samples] by repeatedly
/// splitting the rows into a train part and a test part, one split per group
/// of test-row indices produced by a [SplitIndicesProvider].
class CrossValidatorImpl implements CrossValidator {
  CrossValidatorImpl(
    this.samples,
    this.targetNames,
    this._splitter,
    this.dtype,
  );

  /// The whole dataset that gets split into train and test parts.
  final DataFrame samples;

  /// Data type used to convert [samples] to a matrix, to rebuild the split
  /// matrices and to build the resulting score vector.
  final DType dtype;

  /// Names of the target columns, passed both to the predictor factory and
  /// to `assess`.
  final Iterable<String> targetNames;

  /// Given the number of rows, yields the groups of test-row indices; one
  /// evaluation is performed per group.
  final SplitIndicesProvider _splitter;

  /// Returns a vector holding one score per test-indices group produced by
  /// the splitter.
  ///
  /// For each group, rows whose indices belong to the group form the test
  /// data and all other rows form the train data. If [onDataSplit] is given,
  /// it is called with the (train, test) pair and the first two data frames
  /// it returns are used as train and test data respectively. A predictor is
  /// created from the train data via [predictorFactory] and assessed on the
  /// test data with [metricType].
  ///
  /// Throws [InvalidTrainDataColumnsNumberException] or
  /// [InvalidTestDataColumnsNumberException] if the train or test data that
  /// is finally used has a different number of columns than [samples].
  ///
  /// All work is performed synchronously: the exceptions above are thrown
  /// directly from this call and the returned future is already completed.
  @override
  Future<Vector> evaluate(
    PredictorFactory predictorFactory,
    MetricType metricType, {
    DataPreprocessFn onDataSplit,
  }) {
    // Converted once and reused for every fold.
    final samplesAsMatrix = samples.toMatrix(dtype);
    final sourceColumnsNum = samplesAsMatrix.columnsNum;

    // Materialized so that the series are enumerated only once instead of
    // once per data frame created in every fold.
    final discreteColumns = enumerate(samples.series)
        .where((indexedSeries) => indexedSeries.value.isDiscrete)
        .map((indexedSeries) => indexedSeries.index)
        .toList();

    final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum);

    final scores = allIndicesGroups.map((testRowsIndices) {
      final split =
          _makeSplit(samplesAsMatrix, testRowsIndices, discreteColumns);
      final trainDataFrame = split[0];
      final testDataFrame = split[1];

      final splits = onDataSplit != null
          ? onDataSplit(trainDataFrame, testDataFrame)
          : [trainDataFrame, testDataFrame];
      final transformedTrainData = splits[0];
      final transformedTestData = splits[1];

      // The preprocessing callback must not add or drop columns.
      final transformedTrainDataColumnsNum =
          transformedTrainData.header.length;
      final transformedTestDataColumnsNum = transformedTestData.header.length;

      if (transformedTrainDataColumnsNum != sourceColumnsNum) {
        throw InvalidTrainDataColumnsNumberException(
            sourceColumnsNum, transformedTrainDataColumnsNum);
      }

      if (transformedTestDataColumnsNum != sourceColumnsNum) {
        throw InvalidTestDataColumnsNumberException(
            sourceColumnsNum, transformedTestDataColumnsNum);
      }

      return predictorFactory(transformedTrainData, targetNames)
          .assess(transformedTestData, targetNames, metricType);
    }).toList();

    return Future.value(Vector.fromList(scores, dtype: dtype));
  }

  /// Splits the rows of [samplesAsMatrix] into two data frames.
  ///
  /// The first data frame holds rows whose indices are NOT in
  /// [testRowsIndices] (train). The second holds rows whose indices ARE in
  /// [testRowsIndices] (test). Both data frames preserve the original row
  /// order, reuse the header of [samples] and mark [discreteColumns] as
  /// discrete. Indices in [testRowsIndices] that do not correspond to a row
  /// of [samplesAsMatrix] are ignored.
  List<DataFrame> _makeSplit(
    Matrix samplesAsMatrix,
    Iterable<int> testRowsIndices,
    Iterable<int> discreteColumns,
  ) {
    final testRowsIndicesAsSet = testRowsIndices.toSet();
    // Growable lists: no need to precompute sizes, and no unfilled slots.
    final trainSamples = <Vector>[];
    final testSamples = <Vector>[];

    for (final i in samplesAsMatrix.rowIndices) {
      final row = samplesAsMatrix[i];

      if (testRowsIndicesAsSet.contains(i)) {
        testSamples.add(row);
      } else {
        trainSamples.add(row);
      }
    }

    return [
      DataFrame.fromMatrix(
        Matrix.fromRows(trainSamples, dtype: dtype),
        header: samples.header,
        discreteColumns: discreteColumns,
      ),
      DataFrame.fromMatrix(
        Matrix.fromRows(testSamples, dtype: dtype),
        header: samples.header,
        discreteColumns: discreteColumns,
      ),
    ];
  }
}