diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bab9467..449bc0b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 13.3.2 +- `KnnClassifier`: classifier instantiating refactored + ## 13.3.1 - `readme`: KnnRegressor usage example fixed diff --git a/lib/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart b/lib/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart new file mode 100644 index 00000000..79e4f5aa --- /dev/null +++ b/lib/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart @@ -0,0 +1,67 @@ +import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; +import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; +import 'package:ml_algo/src/di/dependencies.dart'; +import 'package:ml_algo/src/helpers/features_target_split.dart'; +import 'package:ml_algo/src/helpers/validate_train_data.dart'; +import 'package:ml_algo/src/knn_kernel/kernel_factory.dart'; +import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; +import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; +import 'package:ml_linalg/distance.dart'; +import 'package:ml_linalg/dtype.dart'; + +KnnClassifier createKnnClassifier( + DataFrame trainData, + String targetName, + int k, + KernelType kernelType, + Distance distance, + DType dtype, +) { + validateTrainData(trainData, [targetName]); + + final splits = featuresTargetSplit(trainData, + targetNames: [targetName], + ).toList(); + + final featuresSplit = splits[0]; + final targetSplit = splits[1]; + + final trainFeatures = featuresSplit.toMatrix(dtype); + final trainLabels = targetSplit.toMatrix(dtype); + + final classLabels = targetSplit[targetName].isDiscrete + ? targetSplit[targetName] + .discreteValues + .map((dynamic value) => value as num) + .toList(growable: false) + : targetSplit + .toMatrix(dtype) + .getColumn(0) + .unique() + .toList(growable: false); + + final kernelFactory = dependencies.getDependency(); + final kernel = kernelFactory.createByType(kernelType); + + final solverFactory = dependencies.getDependency(); + + final solver = solverFactory.create( + trainFeatures, + trainLabels, + k, + distance, + true, + ); + + final knnClassifierFactory = dependencies + .getDependency(); + + return knnClassifierFactory.create( + targetName, + classLabels, + kernel, + solver, + dtype, + ); +} diff --git a/lib/src/classifier/knn_classifier/knn_classifier.dart b/lib/src/classifier/knn_classifier/knn_classifier.dart index 12c1823c..c3994dc1 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier.dart @@ -1,6 +1,5 @@ import 'package:ml_algo/src/classifier/classifier.dart'; -import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; -import 'package:ml_algo/src/di/dependencies.dart'; +import 'package:ml_algo/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -22,14 +21,14 @@ import 'package:ml_linalg/dtype.dart'; abstract class KnnClassifier implements Assessable, Classifier { /// Parameters: /// - /// [fittingData] Labelled observations, among which will be searched [k] + /// [trainData] Labelled observations, among which will be searched [k] /// nearest to the given unlabelled observations neighbours. Must contain /// [targetName] column. /// /// [targetName] A string, that serves as a name of the column, that contains /// labels (or outcomes). /// - /// [k] a number of nearest neighbours to be found among [fittingData] + /// [k] a number of nearest neighbours to be found among [trainData] /// /// [kernel] a type of a kernel function, that will be used to predict an /// outcome for a new observation @@ -41,7 +40,7 @@ abstract class KnnClassifier implements Assessable, Classifier { /// affect performance or accuracy of the computations. Default value is /// [DType.float32] factory KnnClassifier( - DataFrame fittingData, + DataFrame trainData, String targetName, int k, { @@ -49,7 +48,5 @@ abstract class KnnClassifier implements Assessable, Classifier { Distance distance = Distance.euclidean, DType dtype = DType.float32, } - ) => dependencies - .getDependency() - .create(fittingData, targetName, k, kernel, distance, dtype); + ) => createKnnClassifier(trainData, targetName, k, kernel, distance, dtype); } diff --git a/lib/src/classifier/knn_classifier/knn_classifier_factory.dart b/lib/src/classifier/knn_classifier/knn_classifier_factory.dart index 4fc0f7c0..e4464ad0 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier_factory.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier_factory.dart @@ -1,16 +1,14 @@ import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; -import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; -import 'package:ml_linalg/distance.dart'; +import 'package:ml_algo/src/knn_kernel/kernel.dart'; +import 'package:ml_algo/src/knn_solver/knn_solver.dart'; import 'package:ml_linalg/dtype.dart'; abstract class KnnClassifierFactory { KnnClassifier create( - DataFrame fittingData, String targetName, - int k, - KernelType kernelType, - Distance distance, + List classLabels, + Kernel kernel, + KnnSolver solver, DType dtype, ); } diff --git a/lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart b/lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart index 24566cb1..57f79288 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart @@ -1,69 +1,19 @@ import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart'; -import 'package:ml_algo/src/helpers/features_target_split.dart'; -import 'package:ml_algo/src/helpers/validate_train_data.dart'; -import 'package:ml_algo/src/knn_kernel/kernel_factory.dart'; -import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; -import 'package:ml_linalg/distance.dart'; +import 'package:ml_algo/src/knn_kernel/kernel.dart'; +import 'package:ml_algo/src/knn_solver/knn_solver.dart'; import 'package:ml_linalg/dtype.dart'; class KnnClassifierFactoryImpl implements KnnClassifierFactory { - - KnnClassifierFactoryImpl(this._kernelFactory, this._knnSolverFactory); - - final KernelFactory _kernelFactory; - final KnnSolverFactory _knnSolverFactory; + const KnnClassifierFactoryImpl(); @override KnnClassifier create( - DataFrame fittingData, String targetName, - int k, - KernelType kernelType, - Distance distance, + List classLabels, + Kernel kernel, + KnnSolver solver, DType dtype, - ) { - validateTrainData(fittingData, [targetName]); - - final splits = featuresTargetSplit(fittingData, - targetNames: [targetName], - ).toList(); - - final featuresSplit = splits[0]; - final targetSplit = splits[1]; - - final trainFeatures = featuresSplit.toMatrix(dtype); - final trainLabels = targetSplit.toMatrix(dtype); - final classLabels = targetSplit[targetName].isDiscrete - ? targetSplit[targetName] - .discreteValues - .map((dynamic value) => value as num) - .toList(growable: false) - : targetSplit - .toMatrix(dtype) - .getColumn(0) - .unique() - .toList(growable: false); - - final kernel = _kernelFactory.createByType(kernelType); - - final solver = _knnSolverFactory.create( - trainFeatures, - trainLabels, - k, - distance, - true, - ); - - return KnnClassifierImpl( - targetName, - classLabels, - kernel, - solver, - dtype, - ); - } + ) => KnnClassifierImpl(targetName, classLabels, kernel, solver, dtype); } diff --git a/lib/src/di/dependencies.dart b/lib/src/di/dependencies.dart index 46659a41..0ec0b775 100644 --- a/lib/src/di/dependencies.dart +++ b/lib/src/di/dependencies.dart @@ -58,10 +58,7 @@ Injector get dependencies => (_) => const KnnSolverFactoryImpl()) ..registerSingleton( - (injector) => KnnClassifierFactoryImpl( - injector.getDependency(), - injector.getDependency(), - )) + (_) => const KnnClassifierFactoryImpl()) ..registerSingleton( (injector) => KnnRegressorFactoryImpl( diff --git a/pubspec.yaml b/pubspec.yaml index 38e6b7b7..166c2e83 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms written in native dart -version: 13.3.1 +version: 13.3.2 author: Ilia Gyrdymov homepage: https://github.com/gyrdym/ml_algo diff --git a/test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart b/test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart index 4f47caa4..d3e190c7 100644 --- a/test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart +++ b/test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart @@ -1,119 +1,22 @@ import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory_impl.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart'; -import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; -import 'package:ml_linalg/distance.dart'; import 'package:ml_linalg/dtype.dart'; -import 'package:mockito/mockito.dart'; import 'package:test/test.dart'; import '../../mocks.dart'; void main() { group('KnnClassifierFactoryImpl', () { - final solverMock = KnnSolverMock(); - final kernelMock = KernelMock(); - final kernelFnFactory = createKernelFactoryMock(kernelMock); - final solverFnFactory = createKnnSolverFactoryMock(solverMock); + test('should create a KnnClassifierImpl instance', () { + final factory = const KnnClassifierFactoryImpl(); + final kernelMock = KernelMock(); + final solverMock = KnnSolverMock(); + final dtype = DType.float32; - final factory = KnnClassifierFactoryImpl(kernelFnFactory, solverFnFactory); + final actual = factory.create('target', [1, 2, 3], kernelMock, solverMock, + dtype); - final data = DataFrame.fromSeries( - [ - Series('first' , [1, 1, 1, 1]), - Series('second', [2, 2, 2, 2]), - Series('third' , [2, 2, 2, 2]), - Series('fourth', [4, 4, 4, 4]), - Series('fifth' , [1, 3, 2, 1], isDiscrete: true), - ] - ); - - final targetName = 'fifth'; - - test('should return a knn classifier', () { - final classifier = factory.create( - data, - targetName, - 2, - KernelType.uniform, - Distance.hamming, - DType.float32, - ); - - verify(kernelFnFactory.createByType(KernelType.uniform)).called(1); - verify(solverFnFactory.create( - argThat(equals([ - [1, 2, 2, 4], - [1, 2, 2, 4], - [1, 2, 2, 4], - [1, 2, 2, 4], - ])), - argThat(equals([ - [1], - [3], - [2], - [1], - ])), - 2, - Distance.hamming, - true, - )).called(1); - - expect(classifier, isA()); - }); - - test('should extract class label list from target column even if the ' - 'latter is not discrete', () { - final data = DataFrame.fromSeries( - [ - Series('first' , [1, 1, 1, 1]), - Series('target' , [1, 3, 2, 1]), - ] - ); - - final classifier = factory.create( - data, - 'target', - 2, - KernelType.uniform, - Distance.hamming, - DType.float32, - ); - - verify(kernelFnFactory.createByType(KernelType.uniform)).called(1); - verify(solverFnFactory.create( - argThat(equals([ - [1], - [1], - [1], - [1], - ])), - argThat(equals([ - [1], - [3], - [2], - [1], - ])), - 2, - Distance.hamming, - true, - )).called(1); - - expect(classifier, isA()); - }); - - test('should throw an exception if target column does not exist in the ' - 'train data', () { - final actual = () => factory.create( - data, - 'unknown_column', - 2, - KernelType.uniform, - Distance.hamming, - DType.float32, - ); - - expect(actual, throwsException); + expect(actual, isA()); }); }); } diff --git a/test/classifier/knn_classifier/knn_classifier_test.dart b/test/classifier/knn_classifier/knn_classifier_test.dart index 951d7933..1acc80ec 100644 --- a/test/classifier/knn_classifier/knn_classifier_test.dart +++ b/test/classifier/knn_classifier/knn_classifier_test.dart @@ -3,6 +3,10 @@ import 'package:ml_algo/ml_algo.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; import 'package:ml_algo/src/di/injector.dart'; +import 'package:ml_algo/src/knn_kernel/kernel.dart'; +import 'package:ml_algo/src/knn_kernel/kernel_factory.dart'; +import 'package:ml_algo/src/knn_solver/knn_solver.dart'; +import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/distance.dart'; import 'package:ml_linalg/dtype.dart'; @@ -13,27 +17,99 @@ import '../../mocks.dart'; void main() { group('KnnClassifier', () { - final data = DataFrame( - >[ - [1, 2, 2, 4, 5], - [1, 2, 2, 4, 5], - [1, 2, 2, 4, 5], - [1, 2, 2, 4, 5], - ], - headerExists: false, - header: ['first', 'second', 'third', 'fourth', 'fifth'], + final data = DataFrame.fromSeries( + [ + Series('first' , [1, 1, 1, 1]), + Series('second', [2, 2, 2, 2]), + Series('third' , [2, 2, 2, 2]), + Series('fourth', [4, 4, 4, 4]), + Series('fifth' , [1, 3, 2, 1], isDiscrete: true), + ], ); final targetName = 'fifth'; - final knnClassifier = KnnClassifierMock(); - final knnClassifierFactory = createKnnClassifierFactoryMock(knnClassifier); + Kernel kernelMock; + KernelFactory kernelFactoryMock; - setUp(() => injector = Injector() - ..registerSingleton((_) => knnClassifierFactory), - ); + KnnSolver solverMock; + KnnSolverFactory solverFactoryMock; + + KnnClassifier knnClassifierMock; + KnnClassifierFactory knnClassifierFactoryMock; + + setUp(() { + kernelMock = KernelMock(); + kernelFactoryMock = createKernelFactoryMock(kernelMock); + + solverMock = KnnSolverMock(); + solverFactoryMock = createKnnSolverFactoryMock(solverMock); + + knnClassifierMock = KnnClassifierMock(); + knnClassifierFactoryMock = createKnnClassifierFactoryMock( + knnClassifierMock); + + injector = Injector() + ..registerSingleton((_) => kernelFactoryMock) + ..registerSingleton((_) => solverFactoryMock) + ..registerSingleton((_) => knnClassifierFactoryMock); + }); + + tearDown(() { + reset(kernelMock); + reset(kernelFactoryMock); + + reset(solverMock); + reset(solverFactoryMock); + + reset(knnClassifierMock); + reset(knnClassifierFactoryMock); + + injector = null; + }); - tearDown(() => injector = null); + test('should call kernel factory with proper kernel type', () { + KnnClassifier( + data, + targetName, + 2, + kernel: KernelType.uniform, + distance: Distance.cosine, + dtype: DType.float32, + ); + + verify(kernelFactoryMock.createByType(KernelType.uniform)).called(1); + }); + + test('should call solver factory with proper train features, train labels, ' + 'k parameter, distance type and standardization flag', () { + KnnClassifier( + data, + targetName, + 2, + kernel: KernelType.uniform, + distance: Distance.hamming, + dtype: DType.float32, + ); + + verify(solverFactoryMock.create( + argThat(equals([ + [1, 2, 2, 4], + [1, 2, 2, 4], + [1, 2, 2, 4], + [1, 2, 2, 4], + ])), + argThat(equals([ + [1], + [3], + [2], + [1], + ])), + 2, + Distance.hamming, + true, + )).called(1); + }); test('should call KnnClassifierFactory in order to create a classifier', () { final classifier = KnnClassifier( @@ -42,19 +118,56 @@ void main() { 2, kernel: KernelType.uniform, distance: Distance.cosine, - dtype: DType.float64, + dtype: DType.float32, ); - verify(knnClassifierFactory.create( - data, + verify(knnClassifierFactoryMock.create( targetName, - 2, - KernelType.uniform, - Distance.cosine, - DType.float64 + [1, 3, 2], + kernelMock, + solverMock, + DType.float32, )).called(1); - expect(classifier, same(knnClassifier)); + expect(classifier, same(knnClassifierMock)); + }); + + test('should extract class label list from target column even if the ' + 'latter is not marked as discrete', () { + final data = DataFrame.fromSeries( + [ + Series('first' , [1, 1, 1, 1, 1, 1, 1, 1]), + Series('target' , [1, 3, 2, 1, 3, 3, 2, 1]), + ], + ); + + KnnClassifier( + data, + 'target', + 2, + kernel: KernelType.uniform, + distance: Distance.hamming, + dtype: DType.float32, + ); + + final expectedLabels = [1, 3, 2]; + + verify(knnClassifierFactoryMock.create(any, expectedLabels, any, any, any)) + .called(1); + }); + + test('should throw an exception if target column does not exist in the ' + 'train data', () { + final actual = () => KnnClassifier( + data, + 'unknown_column', + 2, + kernel: KernelType.uniform, + distance: Distance.hamming, + dtype: DType.float32, + ); + + expect(actual, throwsException); }); }); } diff --git a/test/mocks.dart b/test/mocks.dart index b66095ae..4c68ab05 100644 --- a/test/mocks.dart +++ b/test/mocks.dart @@ -210,7 +210,7 @@ KnnSolverFactory createKnnSolverFactoryMock(KnnSolver solver) { KnnClassifierFactory createKnnClassifierFactoryMock(KnnClassifier classifier) { final factory = KnnClassifierFactoryMock(); - when(factory.create(any, any, any, any, any, any)).thenReturn(classifier); + when(factory.create(any, any, any, any, any)).thenReturn(classifier); return factory; }