Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #117 from gyrdym/create-factories-to-all-predictors
KNN classifier instantiation refactored
- Loading branch information
Showing
10 changed files
with
234 additions
and
206 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
lib/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; | ||
import 'package:ml_algo/src/di/dependencies.dart'; | ||
import 'package:ml_algo/src/helpers/features_target_split.dart'; | ||
import 'package:ml_algo/src/helpers/validate_train_data.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart'; | ||
import 'package:ml_dataframe/ml_dataframe.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_linalg/dtype.dart'; | ||
|
||
/// Creates a [KnnClassifier] from [trainData], treating the column named | ||
/// [targetName] as the target (class label) column. | ||
/// | ||
/// The function splits [trainData] into a features part and a target part, | ||
/// converts both to matrices of the given [dtype], collects the list of class | ||
/// labels from the target column, and then obtains a kernel (by [kernelType]), | ||
/// a solver (from the feature/label matrices, [k] and [distance]) and finally | ||
/// the classifier itself from factories resolved through `dependencies`. | ||
/// | ||
/// [trainData] and [targetName] are passed to `validateTrainData` first; | ||
/// presumably this rejects data lacking the target column - confirm against | ||
/// that helper. | ||
KnnClassifier createKnnClassifier( | ||
DataFrame trainData, | ||
String targetName, | ||
int k, | ||
KernelType kernelType, | ||
Distance distance, | ||
DType dtype, | ||
) { | ||
validateTrainData(trainData, [targetName]); | ||
|
||
// The code below relies on the split yielding the features part at index 0 | ||
// and the target part at index 1. | ||
final splits = featuresTargetSplit(trainData, | ||
targetNames: [targetName], | ||
).toList(); | ||
|
||
final featuresSplit = splits[0]; | ||
final targetSplit = splits[1]; | ||
|
||
final trainFeatures = featuresSplit.toMatrix(dtype); | ||
final trainLabels = targetSplit.toMatrix(dtype); | ||
|
||
// Class labels come from the series' discrete values when the target series | ||
// is marked discrete; otherwise they are the unique values of the first | ||
// column of the target matrix. | ||
// NOTE(review): the non-discrete branch calls `targetSplit.toMatrix(dtype)` | ||
// again although `trainLabels` already holds that conversion - looks like a | ||
// candidate for reuse, assuming `toMatrix` has no side effects. | ||
final classLabels = targetSplit[targetName].isDiscrete | ||
? targetSplit[targetName] | ||
.discreteValues | ||
.map((dynamic value) => value as num) | ||
.toList(growable: false) | ||
: targetSplit | ||
.toMatrix(dtype) | ||
.getColumn(0) | ||
.unique() | ||
.toList(growable: false); | ||
|
||
final kernelFactory = dependencies.getDependency<KernelFactory>(); | ||
final kernel = kernelFactory.createByType(kernelType); | ||
|
||
final solverFactory = dependencies.getDependency<KnnSolverFactory>(); | ||
|
||
// NOTE(review): the meaning of the trailing `true` flag is not visible from | ||
// here - confirm against `KnnSolverFactory.create`. | ||
final solver = solverFactory.create( | ||
trainFeatures, | ||
trainLabels, | ||
k, | ||
distance, | ||
true, | ||
); | ||
|
||
final knnClassifierFactory = dependencies | ||
.getDependency<KnnClassifierFactory>(); | ||
|
||
// The classifier factory receives only ready-made collaborators; all the | ||
// data preparation happens in this function. | ||
return knnClassifierFactory.create( | ||
targetName, | ||
classLabels, | ||
kernel, | ||
solver, | ||
dtype, | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 5 additions & 7 deletions
12
lib/src/classifier/knn_classifier/knn_classifier_factory.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,14 @@ | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
import 'package:ml_dataframe/ml_dataframe.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver.dart'; | ||
import 'package:ml_linalg/dtype.dart'; | ||
|
||
abstract class KnnClassifierFactory { | ||
KnnClassifier create( | ||
DataFrame fittingData, | ||
String targetName, | ||
int k, | ||
KernelType kernelType, | ||
Distance distance, | ||
List<num> classLabels, | ||
Kernel kernel, | ||
KnnSolver solver, | ||
DType dtype, | ||
); | ||
} |
64 changes: 7 additions & 57 deletions
64
lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,19 @@ | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart'; | ||
import 'package:ml_algo/src/helpers/features_target_split.dart'; | ||
import 'package:ml_algo/src/helpers/validate_train_data.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart'; | ||
import 'package:ml_dataframe/ml_dataframe.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver.dart'; | ||
import 'package:ml_linalg/dtype.dart'; | ||
|
||
class KnnClassifierFactoryImpl implements KnnClassifierFactory { | ||
|
||
KnnClassifierFactoryImpl(this._kernelFactory, this._knnSolverFactory); | ||
|
||
final KernelFactory _kernelFactory; | ||
final KnnSolverFactory _knnSolverFactory; | ||
const KnnClassifierFactoryImpl(); | ||
|
||
@override | ||
KnnClassifier create( | ||
DataFrame fittingData, | ||
String targetName, | ||
int k, | ||
KernelType kernelType, | ||
Distance distance, | ||
List<num> classLabels, | ||
Kernel kernel, | ||
KnnSolver solver, | ||
DType dtype, | ||
) { | ||
validateTrainData(fittingData, [targetName]); | ||
|
||
final splits = featuresTargetSplit(fittingData, | ||
targetNames: [targetName], | ||
).toList(); | ||
|
||
final featuresSplit = splits[0]; | ||
final targetSplit = splits[1]; | ||
|
||
final trainFeatures = featuresSplit.toMatrix(dtype); | ||
final trainLabels = targetSplit.toMatrix(dtype); | ||
final classLabels = targetSplit[targetName].isDiscrete | ||
? targetSplit[targetName] | ||
.discreteValues | ||
.map((dynamic value) => value as num) | ||
.toList(growable: false) | ||
: targetSplit | ||
.toMatrix(dtype) | ||
.getColumn(0) | ||
.unique() | ||
.toList(growable: false); | ||
|
||
final kernel = _kernelFactory.createByType(kernelType); | ||
|
||
final solver = _knnSolverFactory.create( | ||
trainFeatures, | ||
trainLabels, | ||
k, | ||
distance, | ||
true, | ||
); | ||
|
||
return KnnClassifierImpl( | ||
targetName, | ||
classLabels, | ||
kernel, | ||
solver, | ||
dtype, | ||
); | ||
} | ||
) => KnnClassifierImpl(targetName, classLabels, kernel, solver, dtype); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 8 additions & 105 deletions
113
test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,119 +1,22 @@ | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory_impl.dart'; | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
import 'package:ml_dataframe/ml_dataframe.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_linalg/dtype.dart'; | ||
import 'package:mockito/mockito.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
import '../../mocks.dart'; | ||
|
||
void main() { | ||
group('KnnClassifierFactoryImpl', () { | ||
final solverMock = KnnSolverMock(); | ||
final kernelMock = KernelMock(); | ||
final kernelFnFactory = createKernelFactoryMock(kernelMock); | ||
final solverFnFactory = createKnnSolverFactoryMock(solverMock); | ||
test('should create a KnnClassifierImpl instance', () { | ||
final factory = const KnnClassifierFactoryImpl(); | ||
final kernelMock = KernelMock(); | ||
final solverMock = KnnSolverMock(); | ||
final dtype = DType.float32; | ||
|
||
final factory = KnnClassifierFactoryImpl(kernelFnFactory, solverFnFactory); | ||
final actual = factory.create('target', [1, 2, 3], kernelMock, solverMock, | ||
dtype); | ||
|
||
final data = DataFrame.fromSeries( | ||
[ | ||
Series('first' , <num>[1, 1, 1, 1]), | ||
Series('second', <num>[2, 2, 2, 2]), | ||
Series('third' , <num>[2, 2, 2, 2]), | ||
Series('fourth', <num>[4, 4, 4, 4]), | ||
Series('fifth' , <num>[1, 3, 2, 1], isDiscrete: true), | ||
] | ||
); | ||
|
||
final targetName = 'fifth'; | ||
|
||
test('should return a knn classifier', () { | ||
final classifier = factory.create( | ||
data, | ||
targetName, | ||
2, | ||
KernelType.uniform, | ||
Distance.hamming, | ||
DType.float32, | ||
); | ||
|
||
verify(kernelFnFactory.createByType(KernelType.uniform)).called(1); | ||
verify(solverFnFactory.create( | ||
argThat(equals([ | ||
[1, 2, 2, 4], | ||
[1, 2, 2, 4], | ||
[1, 2, 2, 4], | ||
[1, 2, 2, 4], | ||
])), | ||
argThat(equals([ | ||
[1], | ||
[3], | ||
[2], | ||
[1], | ||
])), | ||
2, | ||
Distance.hamming, | ||
true, | ||
)).called(1); | ||
|
||
expect(classifier, isA<KnnClassifierImpl>()); | ||
}); | ||
|
||
test('should extract class label list from target column even if the ' | ||
'latter is not discrete', () { | ||
final data = DataFrame.fromSeries( | ||
[ | ||
Series('first' , <num>[1, 1, 1, 1]), | ||
Series('target' , <num>[1, 3, 2, 1]), | ||
] | ||
); | ||
|
||
final classifier = factory.create( | ||
data, | ||
'target', | ||
2, | ||
KernelType.uniform, | ||
Distance.hamming, | ||
DType.float32, | ||
); | ||
|
||
verify(kernelFnFactory.createByType(KernelType.uniform)).called(1); | ||
verify(solverFnFactory.create( | ||
argThat(equals([ | ||
[1], | ||
[1], | ||
[1], | ||
[1], | ||
])), | ||
argThat(equals([ | ||
[1], | ||
[3], | ||
[2], | ||
[1], | ||
])), | ||
2, | ||
Distance.hamming, | ||
true, | ||
)).called(1); | ||
|
||
expect(classifier, isA<KnnClassifierImpl>()); | ||
}); | ||
|
||
test('should throw an exception if target column does not exist in the ' | ||
'train data', () { | ||
final actual = () => factory.create( | ||
data, | ||
'unknown_column', | ||
2, | ||
KernelType.uniform, | ||
Distance.hamming, | ||
DType.float32, | ||
); | ||
|
||
expect(actual, throwsException); | ||
expect(actual, isA<KnnClassifierImpl>()); | ||
}); | ||
}); | ||
} |
Oops, something went wrong.