-
Notifications
You must be signed in to change notification settings - Fork 66
/
Classifier.java
98 lines (70 loc) · 2.87 KB
/
Classifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package mariannelinhares.mnistandroid;
import android.content.res.AssetManager;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * Runs a TensorFlow graph loaded from app assets and maps its output scores to text labels.
 *
 * <p>Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
 *
 * <p>Instances reuse a single output buffer across {@link #recognize(float[])} calls, so one
 * instance must not be used from several threads at once.
 *
 * Created by marianne-linhares on 20/04/17.
 */
public class Classifier {
    /** A label is only reported if its score is strictly greater than this value. */
    private static final float THRESHOLD = 0.1f;

    /** Prefix used by Android asset URIs; stripped before opening the labels file. */
    private static final String ASSET_PREFIX = "file:///android_asset/";

    /**
     * Number of output scores read from the graph. Presumably the ten MNIST digit classes;
     * NOTE(review): must match the width of the model's output node.
     */
    private static final int NUM_CLASSES = 10;

    private TensorFlowInferenceInterface tfHelper;
    private String inputName;
    private String outputName;
    private int inputSize;
    private List<String> labels;
    private float[] output;
    private String[] outputNames;

    /**
     * Reads one label per line from an asset file.
     *
     * @param am asset manager used to open the file
     * @param fileName asset-relative path of the labels file
     * @return the labels in file order
     * @throws IOException if the asset cannot be opened or read
     */
    private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
        List<String> labels = new ArrayList<>();
        // try-with-resources closes the asset stream even if readLine() throws.
        try (BufferedReader br =
                new BufferedReader(new InputStreamReader(am.open(fileName), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }

    /**
     * Loads the labels and the TensorFlow graph and prepares a classifier.
     *
     * @param assetManager asset manager used to read the model and labels
     * @param modelPath path of the model, passed unchanged to TensorFlow initialization
     * @param labelPath path of the labels file; a leading {@code file:///android_asset/} prefix is
     *     removed if present, otherwise the path is used as an asset name directly
     * @param inputSize side length of the square input; {@code recognize} expects
     *     {@code inputSize * inputSize} values
     * @param inputName name of the graph's input node
     * @param outputName name of the graph's output node
     * @return a ready-to-use classifier
     * @throws IOException if the labels file cannot be read
     * @throws IllegalArgumentException if the labels file has fewer than {@value #NUM_CLASSES} lines
     * @throws RuntimeException if TensorFlow initialization reports failure
     */
    static public Classifier create(AssetManager assetManager, String modelPath, String labelPath,
            int inputSize, String inputName, String outputName)
            throws IOException {
        Classifier c = new Classifier();
        c.inputName = inputName;
        c.outputName = outputName;
        c.inputSize = inputSize;

        // Accept both "file:///android_asset/labels.txt" and a bare "labels.txt"; the previous
        // split(...)[1] crashed with ArrayIndexOutOfBoundsException when the prefix was missing.
        String labelFile = labelPath.startsWith(ASSET_PREFIX)
                ? labelPath.substring(ASSET_PREFIX.length())
                : labelPath;
        c.labels = readLabels(assetManager, labelFile);
        // recognize() looks up one label per output score; fail here rather than on every call.
        if (c.labels.size() < NUM_CLASSES) {
            throw new IllegalArgumentException("Expected at least " + NUM_CLASSES
                    + " labels in " + labelPath + " but found " + c.labels.size());
        }

        c.tfHelper = new TensorFlowInferenceInterface();
        if (c.tfHelper.initializeTensorFlow(assetManager, modelPath) != 0) {
            throw new RuntimeException("TF initialization failed");
        }

        // Pre-allocate buffers reused by every recognize() call.
        c.outputNames = new String[]{ outputName };
        c.output = new float[NUM_CLASSES];
        return c;
    }

    /**
     * Runs the graph on one input and returns the best-scoring label.
     *
     * @param pixels flattened input of exactly {@code inputSize * inputSize} values
     * @return a {@code Classification} updated with the highest-scoring label whose score exceeds
     *     {@link #THRESHOLD}; if no score exceeds it, the freshly constructed
     *     {@code Classification} is returned unchanged
     * @throws IllegalArgumentException if {@code pixels} is null or has the wrong length
     */
    public Classification recognize(final float[] pixels) {
        int expected = inputSize * inputSize;
        if (pixels == null || pixels.length != expected) {
            throw new IllegalArgumentException("Expected " + expected + " pixels but got "
                    + (pixels == null ? "null" : String.valueOf(pixels.length)));
        }
        tfHelper.fillNodeFloat(inputName, new int[]{expected}, pixels);
        tfHelper.runInference(outputNames);
        tfHelper.readNodeFloat(outputName, output);

        // Keep the highest score above the threshold.
        Classification ans = new Classification();
        for (int i = 0; i < output.length; ++i) {
            float score = output[i];
            if (score > THRESHOLD && score > ans.getConf()) {
                ans.update(score, labels.get(i));
            }
        }
        return ans;
    }
}