Skip to content

Commit

Permalink
Merge pull request #117 from gyrdym/create-factories-to-all-predictors
Browse files Browse the repository at this point in the history
KNN classifier instantiation refactored
  • Loading branch information
gyrdym committed Oct 25, 2019
2 parents 8bfe681 + 7dc6c4e commit 3624f04
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 206 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## 13.3.2
- `KnnClassifier`: classifier instantiation refactored

## 13.3.1
- `readme`: KnnRegressor usage example fixed

Expand Down
@@ -0,0 +1,67 @@
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/helpers/validate_train_data.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_linalg/dtype.dart';

/// Creates a [KnnClassifier] from the labelled observations in [trainData].
///
/// The function:
///
/// 1. passes [trainData] and [targetName] to `validateTrainData`;
/// 2. splits [trainData] into a features part and a target part by
///    [targetName] and converts both to matrices of the given [dtype];
/// 3. collects the class labels from the target column: its discrete values
///    if the column is marked as discrete, otherwise the unique values of the
///    target matrix's first column;
/// 4. resolves [KernelFactory], [KnnSolverFactory] and [KnnClassifierFactory]
///    from the `dependencies` container and uses them to build the kernel
///    (by [kernelType]), the solver (from the train matrices, [k] and
///    [distance]) and, finally, the classifier itself.
KnnClassifier createKnnClassifier(
  DataFrame trainData,
  String targetName,
  int k,
  KernelType kernelType,
  Distance distance,
  DType dtype,
) {
  validateTrainData(trainData, [targetName]);

  final splits = featuresTargetSplit(
    trainData,
    targetNames: [targetName],
  ).toList();

  final featuresSplit = splits[0];
  final targetSplit = splits[1];

  final trainFeatures = featuresSplit.toMatrix(dtype);
  final trainLabels = targetSplit.toMatrix(dtype);

  // Look the target series up once; both branches below need it.
  final targetColumn = targetSplit[targetName];

  // `trainLabels` is already `targetSplit.toMatrix(dtype)`, so the
  // non-discrete branch reuses it instead of converting the split again.
  final classLabels = targetColumn.isDiscrete
      ? targetColumn.discreteValues
          .map((dynamic value) => value as num)
          .toList(growable: false)
      : trainLabels.getColumn(0).unique().toList(growable: false);

  final kernelFactory = dependencies.getDependency<KernelFactory>();
  final kernel = kernelFactory.createByType(kernelType);

  final solverFactory = dependencies.getDependency<KnnSolverFactory>();

  // NOTE(review): the meaning of the trailing positional `true` is defined by
  // `KnnSolverFactory.create` and is not visible from this file - confirm
  // there before changing it.
  final solver = solverFactory.create(
    trainFeatures,
    trainLabels,
    k,
    distance,
    true,
  );

  final knnClassifierFactory =
      dependencies.getDependency<KnnClassifierFactory>();

  return knnClassifierFactory.create(
    targetName,
    classLabels,
    kernel,
    solver,
    dtype,
  );
}
13 changes: 5 additions & 8 deletions lib/src/classifier/knn_classifier/knn_classifier.dart
@@ -1,6 +1,5 @@
import 'package:ml_algo/src/classifier/classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
Expand All @@ -22,14 +21,14 @@ import 'package:ml_linalg/dtype.dart';
abstract class KnnClassifier implements Assessable, Classifier {
/// Parameters:
///
/// [fittingData] Labelled observations, among which will be searched [k]
/// [trainData] Labelled observations among which the [k] nearest neighbours
/// of the given unlabelled observations will be searched. Must contain the
/// [targetName] column.
///
/// [targetName] A string, that serves as a name of the column, that contains
/// labels (or outcomes).
///
/// [k] a number of nearest neighbours to be found among [fittingData]
/// [k] a number of nearest neighbours to be found among [trainData]
///
/// [kernel] a type of a kernel function, that will be used to predict an
/// outcome for a new observation
Expand All @@ -41,15 +40,13 @@ abstract class KnnClassifier implements Assessable, Classifier {
/// affect performance or accuracy of the computations. Default value is
/// [DType.float32]
factory KnnClassifier(
DataFrame fittingData,
DataFrame trainData,
String targetName,
int k,
{
KernelType kernel = KernelType.gaussian,
Distance distance = Distance.euclidean,
DType dtype = DType.float32,
}
) => dependencies
.getDependency<KnnClassifierFactory>()
.create(fittingData, targetName, k, kernel, distance, dtype);
) => createKnnClassifier(trainData, targetName, k, kernel, distance, dtype);
}
12 changes: 5 additions & 7 deletions lib/src/classifier/knn_classifier/knn_classifier_factory.dart
@@ -1,16 +1,14 @@
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_algo/src/knn_kernel/kernel.dart';
import 'package:ml_algo/src/knn_solver/knn_solver.dart';
import 'package:ml_linalg/dtype.dart';

abstract class KnnClassifierFactory {
KnnClassifier create(
DataFrame fittingData,
String targetName,
int k,
KernelType kernelType,
Distance distance,
List<num> classLabels,
Kernel kernel,
KnnSolver solver,
DType dtype,
);
}
64 changes: 7 additions & 57 deletions lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart
@@ -1,69 +1,19 @@
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/helpers/validate_train_data.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_algo/src/knn_kernel/kernel.dart';
import 'package:ml_algo/src/knn_solver/knn_solver.dart';
import 'package:ml_linalg/dtype.dart';

class KnnClassifierFactoryImpl implements KnnClassifierFactory {

KnnClassifierFactoryImpl(this._kernelFactory, this._knnSolverFactory);

final KernelFactory _kernelFactory;
final KnnSolverFactory _knnSolverFactory;
const KnnClassifierFactoryImpl();

@override
KnnClassifier create(
DataFrame fittingData,
String targetName,
int k,
KernelType kernelType,
Distance distance,
List<num> classLabels,
Kernel kernel,
KnnSolver solver,
DType dtype,
) {
validateTrainData(fittingData, [targetName]);

final splits = featuresTargetSplit(fittingData,
targetNames: [targetName],
).toList();

final featuresSplit = splits[0];
final targetSplit = splits[1];

final trainFeatures = featuresSplit.toMatrix(dtype);
final trainLabels = targetSplit.toMatrix(dtype);
final classLabels = targetSplit[targetName].isDiscrete
? targetSplit[targetName]
.discreteValues
.map((dynamic value) => value as num)
.toList(growable: false)
: targetSplit
.toMatrix(dtype)
.getColumn(0)
.unique()
.toList(growable: false);

final kernel = _kernelFactory.createByType(kernelType);

final solver = _knnSolverFactory.create(
trainFeatures,
trainLabels,
k,
distance,
true,
);

return KnnClassifierImpl(
targetName,
classLabels,
kernel,
solver,
dtype,
);
}
) => KnnClassifierImpl(targetName, classLabels, kernel, solver, dtype);
}
5 changes: 1 addition & 4 deletions lib/src/di/dependencies.dart
Expand Up @@ -58,10 +58,7 @@ Injector get dependencies =>
(_) => const KnnSolverFactoryImpl())

..registerSingleton<KnnClassifierFactory>(
(injector) => KnnClassifierFactoryImpl(
injector.getDependency<KernelFactory>(),
injector.getDependency<KnnSolverFactory>(),
))
(_) => const KnnClassifierFactoryImpl())

..registerSingleton<KnnRegressorFactory>(
(injector) => KnnRegressorFactoryImpl(
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms written in native dart
version: 13.3.1
version: 13.3.2
author: Ilia Gyrdymov <ilgyrd@gmail.com>
homepage: https://github.com/gyrdym/ml_algo

Expand Down
113 changes: 8 additions & 105 deletions test/classifier/knn_classifier/knn_classifier_factory_impl_test.dart
@@ -1,119 +1,22 @@
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory_impl.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';

import '../../mocks.dart';

void main() {
group('KnnClassifierFactoryImpl', () {
final solverMock = KnnSolverMock();
final kernelMock = KernelMock();
final kernelFnFactory = createKernelFactoryMock(kernelMock);
final solverFnFactory = createKnnSolverFactoryMock(solverMock);
test('should create a KnnClassifierImpl instance', () {
final factory = const KnnClassifierFactoryImpl();
final kernelMock = KernelMock();
final solverMock = KnnSolverMock();
final dtype = DType.float32;

final factory = KnnClassifierFactoryImpl(kernelFnFactory, solverFnFactory);
final actual = factory.create('target', [1, 2, 3], kernelMock, solverMock,
dtype);

final data = DataFrame.fromSeries(
[
Series('first' , <num>[1, 1, 1, 1]),
Series('second', <num>[2, 2, 2, 2]),
Series('third' , <num>[2, 2, 2, 2]),
Series('fourth', <num>[4, 4, 4, 4]),
Series('fifth' , <num>[1, 3, 2, 1], isDiscrete: true),
]
);

final targetName = 'fifth';

test('should return a knn classifier', () {
final classifier = factory.create(
data,
targetName,
2,
KernelType.uniform,
Distance.hamming,
DType.float32,
);

verify(kernelFnFactory.createByType(KernelType.uniform)).called(1);
verify(solverFnFactory.create(
argThat(equals([
[1, 2, 2, 4],
[1, 2, 2, 4],
[1, 2, 2, 4],
[1, 2, 2, 4],
])),
argThat(equals([
[1],
[3],
[2],
[1],
])),
2,
Distance.hamming,
true,
)).called(1);

expect(classifier, isA<KnnClassifierImpl>());
});

test('should extract class label list from target column even if the '
'latter is not discrete', () {
final data = DataFrame.fromSeries(
[
Series('first' , <num>[1, 1, 1, 1]),
Series('target' , <num>[1, 3, 2, 1]),
]
);

final classifier = factory.create(
data,
'target',
2,
KernelType.uniform,
Distance.hamming,
DType.float32,
);

verify(kernelFnFactory.createByType(KernelType.uniform)).called(1);
verify(solverFnFactory.create(
argThat(equals([
[1],
[1],
[1],
[1],
])),
argThat(equals([
[1],
[3],
[2],
[1],
])),
2,
Distance.hamming,
true,
)).called(1);

expect(classifier, isA<KnnClassifierImpl>());
});

test('should throw an exception if target column does not exist in the '
'train data', () {
final actual = () => factory.create(
data,
'unknown_column',
2,
KernelType.uniform,
Distance.hamming,
DType.float32,
);

expect(actual, throwsException);
expect(actual, isA<KnnClassifierImpl>());
});
});
}

0 comments on commit 3624f04

Please sign in to comment.