/
basics.py
137 lines (119 loc) · 13.6 KB
/
basics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn_porter import Porter
# Load the Iris dataset shipped with scikit-learn; `data` is the feature
# matrix and `target` holds the class label of each sample.
iris_data = load_iris()
X, y = iris_data.data, iris_data.target

# Fit a k-nearest-neighbours classifier: 3 neighbours, equal (uniform) vote
# weights, and exhaustive brute-force distance search.
clf = KNeighborsClassifier(
    n_neighbors=3,
    weights='uniform',
    algorithm='brute',
)
clf.fit(X, y)

# Hand the fitted estimator to sklearn-porter, export it as source code and
# print the result (a sample of the exported Java is kept in the string
# literal that follows this script).
porter = Porter(clf)
output = porter.export()
print(output)
"""
import java.util.*;
class KNeighborsClassifier {
private int nNeighbors;
private int nTemplates;
private int nClasses;
private double power;
private double[][] X;
private int[] y;
public KNeighborsClassifier(int nNeighbors, int nTemplates, int nClasses, double power, double[][] X, int[] y) {
this.nNeighbors = nNeighbors;
this.nTemplates = nTemplates;
this.nClasses = nClasses;
this.power = power;
this.X = X;
this.y = y;
}
private static class Neighbor {
Integer clazz;
Double dist;
public Neighbor(int clazz, double dist) {
this.clazz = clazz;
this.dist = dist;
}
}
private static double compute(double[] temp, double[] cand, double q) {
double dist = 0.;
double diff;
for (int i = 0, l = temp.length; i < l; i++) {
diff = Math.abs(temp[i] - cand[i]);
if (q==1) {
dist += diff;
} else if (q==2) {
dist += diff*diff;
} else if (q==Double.POSITIVE_INFINITY) {
if (diff > dist) {
dist = diff;
}
} else {
dist += Math.pow(diff, q);
}
}
if (q==1 || q==Double.POSITIVE_INFINITY) {
return dist;
} else if (q==2) {
return Math.sqrt(dist);
} else {
return Math.pow(dist, 1. / q);
}
}
public int predict(double[] features) {
int classIdx = -1;
if (this.nNeighbors == 1) {
double minDist = Double.POSITIVE_INFINITY;
double curDist;
for (int i = 0; i < this.nTemplates; i++) {
curDist = KNeighborsClassifier.compute(this.X[i], features, this.power);
if (curDist <= minDist) {
minDist = curDist;
classIdx = y[i];
}
}
} else {
int[] classes = new int[this.nClasses];
ArrayList<Neighbor> dists = new ArrayList<Neighbor>();
for (int i = 0; i < this.nTemplates; i++) {
dists.add(new Neighbor(y[i], KNeighborsClassifier.compute(this.X[i], features, this.power)));
}
Collections.sort(dists, new Comparator<Neighbor>() {
@Override
public int compare(Neighbor n1, Neighbor n2) {
return n1.dist.compareTo(n2.dist);
}
});
for (Neighbor neighbor : dists.subList(0, this.nNeighbors)) {
classes[neighbor.clazz]++;
}
int classVal = -1;
for (int i = 0; i < this.nClasses; i++) {
if (classes[i] > classVal) {
classVal = classes[i];
classIdx = i;
}
}
}
return classIdx;
}
public static void main(String[] args) {
if (args.length == 4) {
// Features:
double[] features = new double[args.length];
for (int i = 0, l = args.length; i < l; i++) {
features[i] = Double.parseDouble(args[i]);
}
// Parameters:
double[][] X = {{5.0999999999999996, 3.5, 1.3999999999999999, 0.20000000000000001}, {4.9000000000000004, 3.0, 1.3999999999999999, 0.20000000000000001}, {4.7000000000000002, 3.2000000000000002, 1.3, 0.20000000000000001}, {4.5999999999999996, 3.1000000000000001, 1.5, 0.20000000000000001}, {5.0, 3.6000000000000001, 1.3999999999999999, 0.20000000000000001}, {5.4000000000000004, 3.8999999999999999, 1.7, 0.40000000000000002}, {4.5999999999999996, 3.3999999999999999, 1.3999999999999999, 0.29999999999999999}, {5.0, 3.3999999999999999, 1.5, 0.20000000000000001}, {4.4000000000000004, 2.8999999999999999, 1.3999999999999999, 0.20000000000000001}, {4.9000000000000004, 3.1000000000000001, 1.5, 0.10000000000000001}, {5.4000000000000004, 3.7000000000000002, 1.5, 0.20000000000000001}, {4.7999999999999998, 3.3999999999999999, 1.6000000000000001, 0.20000000000000001}, {4.7999999999999998, 3.0, 1.3999999999999999, 0.10000000000000001}, {4.2999999999999998, 3.0, 1.1000000000000001, 0.10000000000000001}, {5.7999999999999998, 4.0, 1.2, 0.20000000000000001}, {5.7000000000000002, 4.4000000000000004, 1.5, 0.40000000000000002}, {5.4000000000000004, 3.8999999999999999, 1.3, 0.40000000000000002}, {5.0999999999999996, 3.5, 1.3999999999999999, 0.29999999999999999}, {5.7000000000000002, 3.7999999999999998, 1.7, 0.29999999999999999}, {5.0999999999999996, 3.7999999999999998, 1.5, 0.29999999999999999}, {5.4000000000000004, 3.3999999999999999, 1.7, 0.20000000000000001}, {5.0999999999999996, 3.7000000000000002, 1.5, 0.40000000000000002}, {4.5999999999999996, 3.6000000000000001, 1.0, 0.20000000000000001}, {5.0999999999999996, 3.2999999999999998, 1.7, 0.5}, {4.7999999999999998, 3.3999999999999999, 1.8999999999999999, 0.20000000000000001}, {5.0, 3.0, 1.6000000000000001, 0.20000000000000001}, {5.0, 3.3999999999999999, 1.6000000000000001, 0.40000000000000002}, {5.2000000000000002, 3.5, 1.5, 0.20000000000000001}, {5.2000000000000002, 3.3999999999999999, 1.3999999999999999, 0.20000000000000001}, 
{4.7000000000000002, 3.2000000000000002, 1.6000000000000001, 0.20000000000000001}, {4.7999999999999998, 3.1000000000000001, 1.6000000000000001, 0.20000000000000001}, {5.4000000000000004, 3.3999999999999999, 1.5, 0.40000000000000002}, {5.2000000000000002, 4.0999999999999996, 1.5, 0.10000000000000001}, {5.5, 4.2000000000000002, 1.3999999999999999, 0.20000000000000001}, {4.9000000000000004, 3.1000000000000001, 1.5, 0.10000000000000001}, {5.0, 3.2000000000000002, 1.2, 0.20000000000000001}, {5.5, 3.5, 1.3, 0.20000000000000001}, {4.9000000000000004, 3.1000000000000001, 1.5, 0.10000000000000001}, {4.4000000000000004, 3.0, 1.3, 0.20000000000000001}, {5.0999999999999996, 3.3999999999999999, 1.5, 0.20000000000000001}, {5.0, 3.5, 1.3, 0.29999999999999999}, {4.5, 2.2999999999999998, 1.3, 0.29999999999999999}, {4.4000000000000004, 3.2000000000000002, 1.3, 0.20000000000000001}, {5.0, 3.5, 1.6000000000000001, 0.59999999999999998}, {5.0999999999999996, 3.7999999999999998, 1.8999999999999999, 0.40000000000000002}, {4.7999999999999998, 3.0, 1.3999999999999999, 0.29999999999999999}, {5.0999999999999996, 3.7999999999999998, 1.6000000000000001, 0.20000000000000001}, {4.5999999999999996, 3.2000000000000002, 1.3999999999999999, 0.20000000000000001}, {5.2999999999999998, 3.7000000000000002, 1.5, 0.20000000000000001}, {5.0, 3.2999999999999998, 1.3999999999999999, 0.20000000000000001}, {7.0, 3.2000000000000002, 4.7000000000000002, 1.3999999999999999}, {6.4000000000000004, 3.2000000000000002, 4.5, 1.5}, {6.9000000000000004, 3.1000000000000001, 4.9000000000000004, 1.5}, {5.5, 2.2999999999999998, 4.0, 1.3}, {6.5, 2.7999999999999998, 4.5999999999999996, 1.5}, {5.7000000000000002, 2.7999999999999998, 4.5, 1.3}, {6.2999999999999998, 3.2999999999999998, 4.7000000000000002, 1.6000000000000001}, {4.9000000000000004, 2.3999999999999999, 3.2999999999999998, 1.0}, {6.5999999999999996, 2.8999999999999999, 4.5999999999999996, 1.3}, {5.2000000000000002, 2.7000000000000002, 3.8999999999999999, 
1.3999999999999999}, {5.0, 2.0, 3.5, 1.0}, {5.9000000000000004, 3.0, 4.2000000000000002, 1.5}, {6.0, 2.2000000000000002, 4.0, 1.0}, {6.0999999999999996, 2.8999999999999999, 4.7000000000000002, 1.3999999999999999}, {5.5999999999999996, 2.8999999999999999, 3.6000000000000001, 1.3}, {6.7000000000000002, 3.1000000000000001, 4.4000000000000004, 1.3999999999999999}, {5.5999999999999996, 3.0, 4.5, 1.5}, {5.7999999999999998, 2.7000000000000002, 4.0999999999999996, 1.0}, {6.2000000000000002, 2.2000000000000002, 4.5, 1.5}, {5.5999999999999996, 2.5, 3.8999999999999999, 1.1000000000000001}, {5.9000000000000004, 3.2000000000000002, 4.7999999999999998, 1.8}, {6.0999999999999996, 2.7999999999999998, 4.0, 1.3}, {6.2999999999999998, 2.5, 4.9000000000000004, 1.5}, {6.0999999999999996, 2.7999999999999998, 4.7000000000000002, 1.2}, {6.4000000000000004, 2.8999999999999999, 4.2999999999999998, 1.3}, {6.5999999999999996, 3.0, 4.4000000000000004, 1.3999999999999999}, {6.7999999999999998, 2.7999999999999998, 4.7999999999999998, 1.3999999999999999}, {6.7000000000000002, 3.0, 5.0, 1.7}, {6.0, 2.8999999999999999, 4.5, 1.5}, {5.7000000000000002, 2.6000000000000001, 3.5, 1.0}, {5.5, 2.3999999999999999, 3.7999999999999998, 1.1000000000000001}, {5.5, 2.3999999999999999, 3.7000000000000002, 1.0}, {5.7999999999999998, 2.7000000000000002, 3.8999999999999999, 1.2}, {6.0, 2.7000000000000002, 5.0999999999999996, 1.6000000000000001}, {5.4000000000000004, 3.0, 4.5, 1.5}, {6.0, 3.3999999999999999, 4.5, 1.6000000000000001}, {6.7000000000000002, 3.1000000000000001, 4.7000000000000002, 1.5}, {6.2999999999999998, 2.2999999999999998, 4.4000000000000004, 1.3}, {5.5999999999999996, 3.0, 4.0999999999999996, 1.3}, {5.5, 2.5, 4.0, 1.3}, {5.5, 2.6000000000000001, 4.4000000000000004, 1.2}, {6.0999999999999996, 3.0, 4.5999999999999996, 1.3999999999999999}, {5.7999999999999998, 2.6000000000000001, 4.0, 1.2}, {5.0, 2.2999999999999998, 3.2999999999999998, 1.0}, {5.5999999999999996, 2.7000000000000002, 4.2000000000000002, 
1.3}, {5.7000000000000002, 3.0, 4.2000000000000002, 1.2}, {5.7000000000000002, 2.8999999999999999, 4.2000000000000002, 1.3}, {6.2000000000000002, 2.8999999999999999, 4.2999999999999998, 1.3}, {5.0999999999999996, 2.5, 3.0, 1.1000000000000001}, {5.7000000000000002, 2.7999999999999998, 4.0999999999999996, 1.3}, {6.2999999999999998, 3.2999999999999998, 6.0, 2.5}, {5.7999999999999998, 2.7000000000000002, 5.0999999999999996, 1.8999999999999999}, {7.0999999999999996, 3.0, 5.9000000000000004, 2.1000000000000001}, {6.2999999999999998, 2.8999999999999999, 5.5999999999999996, 1.8}, {6.5, 3.0, 5.7999999999999998, 2.2000000000000002}, {7.5999999999999996, 3.0, 6.5999999999999996, 2.1000000000000001}, {4.9000000000000004, 2.5, 4.5, 1.7}, {7.2999999999999998, 2.8999999999999999, 6.2999999999999998, 1.8}, {6.7000000000000002, 2.5, 5.7999999999999998, 1.8}, {7.2000000000000002, 3.6000000000000001, 6.0999999999999996, 2.5}, {6.5, 3.2000000000000002, 5.0999999999999996, 2.0}, {6.4000000000000004, 2.7000000000000002, 5.2999999999999998, 1.8999999999999999}, {6.7999999999999998, 3.0, 5.5, 2.1000000000000001}, {5.7000000000000002, 2.5, 5.0, 2.0}, {5.7999999999999998, 2.7999999999999998, 5.0999999999999996, 2.3999999999999999}, {6.4000000000000004, 3.2000000000000002, 5.2999999999999998, 2.2999999999999998}, {6.5, 3.0, 5.5, 1.8}, {7.7000000000000002, 3.7999999999999998, 6.7000000000000002, 2.2000000000000002}, {7.7000000000000002, 2.6000000000000001, 6.9000000000000004, 2.2999999999999998}, {6.0, 2.2000000000000002, 5.0, 1.5}, {6.9000000000000004, 3.2000000000000002, 5.7000000000000002, 2.2999999999999998}, {5.5999999999999996, 2.7999999999999998, 4.9000000000000004, 2.0}, {7.7000000000000002, 2.7999999999999998, 6.7000000000000002, 2.0}, {6.2999999999999998, 2.7000000000000002, 4.9000000000000004, 1.8}, {6.7000000000000002, 3.2999999999999998, 5.7000000000000002, 2.1000000000000001}, {7.2000000000000002, 3.2000000000000002, 6.0, 1.8}, {6.2000000000000002, 2.7999999999999998, 
4.7999999999999998, 1.8}, {6.0999999999999996, 3.0, 4.9000000000000004, 1.8}, {6.4000000000000004, 2.7999999999999998, 5.5999999999999996, 2.1000000000000001}, {7.2000000000000002, 3.0, 5.7999999999999998, 1.6000000000000001}, {7.4000000000000004, 2.7999999999999998, 6.0999999999999996, 1.8999999999999999}, {7.9000000000000004, 3.7999999999999998, 6.4000000000000004, 2.0}, {6.4000000000000004, 2.7999999999999998, 5.5999999999999996, 2.2000000000000002}, {6.2999999999999998, 2.7999999999999998, 5.0999999999999996, 1.5}, {6.0999999999999996, 2.6000000000000001, 5.5999999999999996, 1.3999999999999999}, {7.7000000000000002, 3.0, 6.0999999999999996, 2.2999999999999998}, {6.2999999999999998, 3.3999999999999999, 5.5999999999999996, 2.3999999999999999}, {6.4000000000000004, 3.1000000000000001, 5.5, 1.8}, {6.0, 3.0, 4.7999999999999998, 1.8}, {6.9000000000000004, 3.1000000000000001, 5.4000000000000004, 2.1000000000000001}, {6.7000000000000002, 3.1000000000000001, 5.5999999999999996, 2.3999999999999999}, {6.9000000000000004, 3.1000000000000001, 5.0999999999999996, 2.2999999999999998}, {5.7999999999999998, 2.7000000000000002, 5.0999999999999996, 1.8999999999999999}, {6.7999999999999998, 3.2000000000000002, 5.9000000000000004, 2.2999999999999998}, {6.7000000000000002, 3.2999999999999998, 5.7000000000000002, 2.5}, {6.7000000000000002, 3.0, 5.2000000000000002, 2.2999999999999998}, {6.2999999999999998, 2.5, 5.0, 1.8999999999999999}, {6.5, 3.0, 5.2000000000000002, 2.0}, {6.2000000000000002, 3.3999999999999999, 5.4000000000000004, 2.2999999999999998}, {5.9000000000000004, 3.0, 5.0999999999999996, 1.8}};
int[] y = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
// Prediction:
KNeighborsClassifier clf = new KNeighborsClassifier(3, 150, 3, 2, X, y);
int estimation = clf.predict(features);
System.out.println(estimation);
}
}
}
"""