Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #98 from gyrdym/weighted-knn
Weighted knn regression
- Loading branch information
Showing
16 changed files
with
371 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// 5.7 sec | ||
import 'dart:async'; | ||
|
||
import 'package:benchmark_harness/benchmark_harness.dart'; | ||
import 'package:ml_algo/ml_algo.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
|
||
const observationsNum = 500; | ||
const featuresNum = 20; | ||
|
||
class KnnRegressorBenchmark extends BenchmarkBase { | ||
KnnRegressorBenchmark() : super('KNN regression benchmark'); | ||
|
||
Matrix features; | ||
Matrix testFeatures; | ||
Matrix labels; | ||
Matrix testLabels; | ||
NoNParametricRegressor regressor; | ||
|
||
|
||
static void main() { | ||
KnnRegressorBenchmark().report(); | ||
} | ||
|
||
@override | ||
void run() { | ||
regressor.predict(testFeatures); | ||
} | ||
|
||
@override | ||
void setup() { | ||
regressor = NoNParametricRegressor.nearestNeighbor(k: 7); | ||
|
||
features = Matrix.fromRows(List.generate(observationsNum * 2, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
labels = Matrix.fromColumns([Vector.randomFilled(observationsNum * 2)]); | ||
|
||
testFeatures = Matrix.fromRows(List.generate(observationsNum, | ||
(i) => Vector.randomFilled(featuresNum))); | ||
testLabels = Matrix.fromColumns([Vector.randomFilled(observationsNum)]); | ||
|
||
regressor.fit(features, labels); | ||
} | ||
|
||
void tearDown() {} | ||
} | ||
|
||
Future main() async { | ||
KnnRegressorBenchmark.main(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import 'dart:math' as math; | ||
|
||
typedef KernelFn = double Function(double u); | ||
|
||
double uniformKernel(double u) => 1; | ||
|
||
double epanechnikovKernel(double u) => 0.75 * (1 - u * u); | ||
|
||
double cosineKernel(double u) => math.pi / 4 * math.cos(math.pi / 2 * u); | ||
|
||
double gaussianKernel(double u) => 1 / math.sqrt(2 * math.pi) * | ||
math.exp(-0.5 * u * u); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import 'package:ml_algo/src/algorithms/knn/kernel.dart'; | ||
import 'package:ml_algo/src/algorithms/knn/kernel_type.dart'; | ||
|
||
abstract class KernelFunctionFactory { | ||
KernelFn createByType(Kernel type); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import 'package:ml_algo/src/algorithms/knn/kernel.dart'; | ||
import 'package:ml_algo/src/algorithms/knn/kernel_function_factory.dart'; | ||
import 'package:ml_algo/src/algorithms/knn/kernel_type.dart'; | ||
|
||
class KernelFunctionFactoryImpl implements KernelFunctionFactory { | ||
const KernelFunctionFactoryImpl(); | ||
|
||
@override | ||
KernelFn createByType(Kernel type) { | ||
switch (type) { | ||
case Kernel.uniform: | ||
return uniformKernel; | ||
case Kernel.epanechnikov: | ||
return epanechnikovKernel; | ||
case Kernel.cosine: | ||
return cosineKernel; | ||
case Kernel.gaussian: | ||
return gaussianKernel; | ||
default: | ||
throw UnsupportedError('Unsupported kernel type - $type'); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
enum Kernel { uniform, epanechnikov, cosine, gaussian, } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,24 @@ | ||
import 'package:ml_algo/src/algorithms/knn/kernel_type.dart'; | ||
import 'package:ml_algo/src/regressor/knn_regressor.dart'; | ||
import 'package:ml_algo/src/regressor/regressor.dart'; | ||
import 'package:ml_linalg/distance.dart'; | ||
|
||
/// A factory for all the non parametric family of Machine Learning algorithms | ||
abstract class NoNParametricRegressor implements Regressor { | ||
/// Creates an instance of KNN regressor | ||
/// | ||
/// KNN here means "K nearest neighbor" | ||
/// [k] a number of neighbors | ||
/// | ||
/// [k] a number of nearest neighbours | ||
/// | ||
/// [kernel] a type of kernel function, that will be used to find an outcome | ||
/// for a new observation | ||
/// | ||
/// [distance] a distance type, that will be used to measure a distance | ||
/// between two observation vectors | ||
factory NoNParametricRegressor.nearestNeighbor({ | ||
int k, | ||
Distance distanceType, | ||
Kernel kernel, | ||
Distance distance, | ||
}) = KNNRegressor; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import 'package:ml_algo/src/algorithms/knn/kernel.dart'; | ||
import 'package:ml_algo/src/algorithms/knn/kernel_function_factory_impl.dart'; | ||
import 'package:ml_algo/src/algorithms/knn/kernel_type.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('KernelFunctionFactoryImpl', () { | ||
final factory = const KernelFunctionFactoryImpl(); | ||
|
||
test('should create proper instance for kernels', () { | ||
expect([ | ||
factory.createByType(Kernel.uniform) is KernelFn, | ||
factory.createByType(Kernel.epanechnikov) is KernelFn, | ||
factory.createByType(Kernel.cosine) is KernelFn, | ||
factory.createByType(Kernel.gaussian) is KernelFn, | ||
], equals(List<bool>.filled(4, true))); | ||
}); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import 'package:ml_algo/src/algorithms/knn/kernel.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('Kernel', () { | ||
test('uniform should always return 1', () { | ||
expect(uniformKernel(0), 1); | ||
expect(uniformKernel(100000), 1); | ||
}); | ||
|
||
test('epanechnikov should return proper value', () { | ||
expect(epanechnikovKernel(0), 0.75); | ||
expect(epanechnikovKernel(1), 0); | ||
expect(epanechnikovKernel(10), -74.25); | ||
}); | ||
|
||
test('cosine should return proper value', () { | ||
expect(cosineKernel(0), closeTo(0.7853, 1e-4)); | ||
expect(cosineKernel(1), closeTo(0.0000, 1e-4)); | ||
expect(cosineKernel(20), closeTo(0.7853, 1e-4)); | ||
}); | ||
|
||
test('gaussian should return proper value', () { | ||
expect(gaussianKernel(0), closeTo(0.3989, 1e-4)); | ||
expect(gaussianKernel(1), closeTo(0.2419, 1e-4)); | ||
expect(gaussianKernel(3), closeTo(0.0044, 1e-4)); | ||
expect(gaussianKernel(10), closeTo(0.0000, 1e-4)); | ||
}); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.