-
Notifications
You must be signed in to change notification settings - Fork 5
/
Mlp.java
122 lines (111 loc) · 3.22 KB
/
Mlp.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package org.genericsystem.cv;
import java.util.ArrayList;
import java.util.Arrays;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
/**
 * Small convenience wrapper around OpenCV's {@link ANN_MLP} (multi-layer perceptron).
 *
 * <p>Usage: create an instance with the input/output vector sizes, buffer samples with
 * {@link #addData(float[], float[])}, call {@link #train()}, then call
 * {@link #predict(float[])} and read the output vector with {@link #getResult()}.
 *
 * <p>The network is built with three layers: {@code input}, a hidden layer of 8 nodes, and
 * {@code output}. Instances hold native OpenCV resources and mutable buffers; nothing in this
 * class synchronizes access to them.
 */
public class Mlp {
    static {
        // The OpenCV native library must be loaded before any OpenCV class is used.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /** Maximum number of samples buffered between two calls to {@link #train()}. */
    static final int MAX_DATA = 1000;

    /** Underlying OpenCV network. */
    ANN_MLP mlp;
    /** Number of values in one input vector (size of the first layer). */
    int input;
    /** Number of values in one output vector (size of the last layer). */
    int output;
    /** Buffered input vectors; cleared after a successful {@link #train()}. */
    ArrayList<float[]> train;
    /** Buffered target vectors, index-aligned with {@link #train}. */
    ArrayList<float[]> label;
    /** Receives the output of the most recent {@link #predict(float[])} call. */
    MatOfFloat result;

    /**
     * Demo entry point: trains on four two-bit samples and prints the network output for each.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Mlp mlp = new Mlp(2, 2);
        mlp.addData(new float[] { 0, 0 }, new float[] { 1, 0 });
        mlp.addData(new float[] { 1, 1 }, new float[] { 0, 1 });
        mlp.addData(new float[] { 0, 1 }, new float[] { 1, 0 });
        mlp.addData(new float[] { 1, 0 }, new float[] { 1, 0 });
        mlp.train();

        // NOTE(review): the messages below mention "xor" and "or", but the targets above are
        // {1,0} for every input except (1,1), which maps to {0,1}. Confirm which is intended.
        mlp.predict(new float[] { 0, 0 });
        System.out.println("0 xor 0, 0 or 0 = " + Arrays.toString(mlp.getResult()));
        mlp.predict(new float[] { 1, 1 });
        System.out.println("1 xor 1, 1 or 1 = " + Arrays.toString(mlp.getResult()));
        mlp.predict(new float[] { 0, 1 });
        System.out.println("0 xor 1, 0 or 1 = " + Arrays.toString(mlp.getResult()));
        mlp.predict(new float[] { 1, 0 });
        System.out.println("1 xor 0, 1 or 0 = " + Arrays.toString(mlp.getResult()));
    }

    /**
     * Creates an untrained network with layer sizes {@code i}, 8, {@code o}.
     *
     * @param i number of values in each input vector
     * @param o number of values in each output vector
     */
    public Mlp(int i, int o) {
        input = i;
        output = o;
        mlp = ANN_MLP.create();

        MatOfInt layerSizes = new MatOfInt(input, 8, output);
        try {
            mlp.setLayerSizes(layerSizes);
        } finally {
            // Mat data is reference counted, so dropping our handle here is safe whether or
            // not the native side kept its own reference.
            layerSizes.release();
        }

        mlp.setActivationFunction(ANN_MLP.SIGMOID_SYM);
        // No training method is set explicitly, so the library default is used.
        // Stop on whichever comes first: 100000 iterations or an epsilon of 1e-5.
        mlp.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER + TermCriteria.EPS, 100000, 0.00001f));

        result = new MatOfFloat();
        train = new ArrayList<>();
        label = new ArrayList<>();
    }

    /**
     * Buffers one training sample.
     *
     * <p>The sample is silently ignored when either array is {@code null}, when {@code t} does
     * not have {@code input} values, when {@code l} does not have {@code output} values, or when
     * {@link #MAX_DATA} samples are already buffered. Use {@link #getCount()} to check whether
     * a sample was accepted.
     *
     * @param t input vector
     * @param l target vector
     */
    void addData(float[] t, float[] l) {
        if (t == null || l == null) {
            return;
        }
        if (t.length != input || l.length != output) {
            return;
        }
        if (train.size() >= MAX_DATA) {
            return;
        }
        // Copy so that later changes to the caller's arrays cannot alter the buffered sample.
        train.add(t.clone());
        label.add(l.clone());
    }

    /**
     * Returns the number of samples currently buffered for training.
     *
     * @return number of buffered samples
     */
    int getCount() {
        return train.size();
    }

    /**
     * Trains the network on all buffered samples, then clears the buffers.
     *
     * <p>If training throws, the buffers are left untouched; the temporary native matrices are
     * released in either case.
     *
     * @throws IllegalStateException if no samples have been buffered
     */
    void train() {
        if (train.isEmpty()) {
            throw new IllegalStateException("No training data: call addData(...) before train()");
        }

        int samples = train.size();
        float[] flatInputs = toRowMajor(train, input);
        float[] flatLabels = toRowMajor(label, output);

        Mat trainData = new Mat(samples, input, CvType.CV_32FC1);
        Mat response = new Mat(samples, output, CvType.CV_32FC1);
        try {
            trainData.put(0, 0, flatInputs);
            response.put(0, 0, flatLabels);
            mlp.train(trainData, Ml.ROW_SAMPLE, response);
        } finally {
            trainData.release();
            response.release();
        }

        train.clear();
        label.clear();
    }

    /**
     * Runs the network on one input vector. The output vector is stored internally and can be
     * read with {@link #getResult()}.
     *
     * @param i input vector; must contain exactly {@code input} values
     * @return the value returned by {@code ANN_MLP.predict}, or {@code -1} when {@code i} is
     *         {@code null} or has the wrong length (in which case the stored result is not
     *         updated)
     */
    float predict(float[] i) {
        if (i == null || i.length != input) {
            return -1;
        }
        Mat test = new Mat(1, input, CvType.CV_32FC1);
        try {
            test.put(0, 0, i);
            return mlp.predict(test, result, 0);
        } finally {
            // Release the per-call native matrix; otherwise every prediction leaks one.
            test.release();
        }
    }

    /**
     * Returns a copy of the output stored by the most recent {@link #predict(float[])} call.
     *
     * @return the stored output values as a Java array
     */
    float[] getResult() {
        return result.toArray();
    }

    /**
     * Flattens a rectangular 2-D array into a single row-major array.
     *
     * <p>The column count is taken from the first row; every row must have at least that many
     * values.
     *
     * @param a rows to flatten
     * @return a new array of {@code a.length * a[0].length} values, or an empty array when
     *         {@code a} has no rows
     */
    float[] flatten(float[][] a) {
        if (a.length == 0) {
            return new float[] {};
        }
        int rowCount = a.length;
        int colCount = a[0].length;
        float[] flat = new float[rowCount * colCount];
        for (int r = 0; r < rowCount; r++) {
            System.arraycopy(a[r], 0, flat, r * colCount, colCount);
        }
        return flat;
    }

    /** Copies {@code rows}, each holding at least {@code width} values, into one row-major array. */
    private static float[] toRowMajor(ArrayList<float[]> rows, int width) {
        float[] flat = new float[rows.size() * width];
        for (int r = 0; r < rows.size(); r++) {
            System.arraycopy(rows.get(r), 0, flat, r * width, width);
        }
        return flat;
    }
}