Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
74 changed files
with
2,001 additions
and
733 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// MacBook Air 13.3 mid 2017: ~ 5 sec | ||
import 'package:benchmark_harness/benchmark_harness.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver.dart'; | ||
import 'package:ml_algo/src/knn_solver/knn_solver_impl.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
|
||
const k = 10; | ||
const trainObservationsNum = 2000; | ||
const observationsNum = 100; | ||
const featuresNum = 100; | ||
|
||
class KnnSolverBenchmark extends BenchmarkBase { | ||
KnnSolverBenchmark() : super('KnnSolver benchmark'); | ||
|
||
KnnSolver solver; | ||
Matrix features; | ||
|
||
static void main() { | ||
KnnSolverBenchmark().report(); | ||
} | ||
|
||
@override | ||
void run() { | ||
solver.findKNeighbours(features).toList(growable: false); | ||
} | ||
|
||
@override | ||
void setup() { | ||
final trainFeatures = Matrix.fromRows(List.generate(trainObservationsNum, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
final trainLabels = Matrix.fromColumns([Vector.randomFilled(trainObservationsNum)]); | ||
|
||
solver = KnnSolverImpl(trainFeatures, trainLabels, k, Distance.euclidean, | ||
false); | ||
|
||
features = Matrix.fromRows(List.generate(observationsNum, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
} | ||
} | ||
|
||
void main() { | ||
KnnSolverBenchmark.main(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
export 'package:ml_algo/src/algorithms/knn/kernel_type.dart'; | ||
export 'package:ml_algo/src/classifier/decision_tree_classifier.dart'; | ||
export 'package:ml_algo/src/classifier/logistic_regressor.dart'; | ||
export 'package:ml_algo/src/classifier/softmax_regressor.dart'; | ||
export 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart'; | ||
export 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; | ||
export 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; | ||
export 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; | ||
export 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
export 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; | ||
export 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; | ||
export 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; | ||
export 'package:ml_algo/src/metric/classification/type.dart'; | ||
export 'package:ml_algo/src/metric/metric_type.dart'; | ||
export 'package:ml_algo/src/metric/regression/type.dart'; | ||
export 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; | ||
export 'package:ml_algo/src/regressor/knn_regressor.dart'; | ||
export 'package:ml_algo/src/regressor/linear_regressor.dart'; | ||
export 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; | ||
export 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart'; |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
.../classifier/decision_tree_classifier.dart → ..._classifier/decision_tree_classifier.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...sifier/decision_tree_classifier_impl.dart → ...sifier/decision_tree_classifier_impl.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import 'package:ml_algo/src/classifier/classifier.dart'; | ||
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; | ||
import 'package:ml_algo/src/di/dependencies.dart'; | ||
import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; | ||
import 'package:ml_algo/src/model_selection/assessable.dart'; | ||
import 'package:ml_dataframe/ml_dataframe.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
import 'package:ml_linalg/dtype.dart'; | ||
|
||
/// A class that performs classification basing on `k nearest neighbours` (KNN) | ||
/// algorithm | ||
/// | ||
/// K nearest neighbours algorithm is an algorithm that is targeted to search | ||
/// most similar labelled observations (number of these observations equals `k`) | ||
/// for the given unlabelled one. | ||
/// | ||
/// It is possible to use majority class among the k found observations as a | ||
/// prediction for the given unlabelled observation, but it may lead to the | ||
/// imprecise result. Thus the weighted version of KNN algorithm is used in the | ||
/// classifier. To get weight of a particular observation one may use a kernel | ||
/// function. | ||
abstract class KnnClassifier implements Assessable, Classifier { | ||
/// Parameters: | ||
/// | ||
/// [fittingData] Labelled observations, among which will be searched [k] | ||
/// nearest to the given unlabelled observations neighbours. Must contain | ||
/// [targetName] column. | ||
/// | ||
/// [targetName] A string, that serves as a name of the column, that contains | ||
/// labels (or outcomes). | ||
/// | ||
/// [k] a number of nearest neighbours to be found among [fittingData] | ||
/// | ||
/// [kernel] a type of a kernel function, that will be used to predict an | ||
/// outcome for a new observation | ||
/// | ||
/// [distance] a distance type, that will be used to measure a distance | ||
/// between two observation vectors | ||
/// | ||
/// [dtype] A data type for all the numeric values, used by the algorithm. Can | ||
/// affect performance or accuracy of the computations. Default value is | ||
/// [DType.float32] | ||
factory KnnClassifier( | ||
DataFrame fittingData, | ||
String targetName, | ||
int k, | ||
{ | ||
KernelType kernel = KernelType.gaussian, | ||
Distance distance = Distance.euclidean, | ||
DType dtype = DType.float32, | ||
} | ||
) => dependencies | ||
.getDependency<KnnClassifierFactory>() | ||
.create(fittingData, targetName, k, kernel, distance, dtype); | ||
} |
Oops, something went wrong.