Skip to content
Permalink
Browse files

Added quantized model for GTSRB subset (blog post demo)

  • Loading branch information...
frogermcs committed Jun 25, 2019
1 parent 8eee377 commit 7dde56cc2022d57c9015d874dee90e43b0c5f442
Binary file not shown.
@@ -0,0 +1,10 @@
30_speed
60_speed
80_speed
80_lifted
120_speed
no_overtaking_general
right_of_way_general
stop
no_way_general
no_way_trucks
@@ -48,6 +48,27 @@ public void process(@NonNull Frame frame) {
bitmap.recycle();

ByteBuffer byteBufferToClassify = bitmapToModelsMatchingByteBuffer(toClassify);
if (modelConfig.isQuantized()) {
runInferenceOnQuantizedModel(byteBufferToClassify);
} else {
runInferenceOnFloatModel(byteBufferToClassify);
}
}

private void runInferenceOnQuantizedModel(ByteBuffer byteBufferToClassify) {
byte[][] result = new byte[1][labels.size()];
interpreter.run(byteBufferToClassify, result);
float[][] resultFloats = new float[1][labels.size()];
byte[] bytes = result[0];
for (int i = 0; i < bytes.length; i++) {
float resultF = (bytes[i] & 0xff) / 255.f;
resultFloats[0][i] = resultF;

}
classificationListener.onClassifiedFrame(getSortedResult(resultFloats));
}

private void runInferenceOnFloatModel(ByteBuffer byteBufferToClassify) {
float[][] result = new float[1][labels.size()];
interpreter.run(byteBufferToClassify, result);
classificationListener.onClassifiedFrame(getSortedResult(result));
@@ -62,8 +83,14 @@ private ByteBuffer bitmapToModelsMatchingByteBuffer(Bitmap bitmap) {
for (int i = 0; i < modelConfig.getInputWidth(); ++i) {
for (int j = 0; j < modelConfig.getInputHeight(); ++j) {
int pixelVal = intValues[pixel++];
for (float channelVal : pixelToChannelValues(pixelVal)) {
byteBuffer.putFloat(channelVal);
if (modelConfig.isQuantized()) {
for (byte channelVal : pixelToChannelValuesQuant(pixelVal)) {
byteBuffer.put(channelVal);
}
} else {
for (float channelVal : pixelToChannelValues(pixelVal)) {
byteBuffer.putFloat(channelVal);
}
}
}
}
@@ -89,6 +116,14 @@ private ByteBuffer bitmapToModelsMatchingByteBuffer(Bitmap bitmap) {
}
}

private byte[] pixelToChannelValuesQuant(int pixel) {
byte[] rgbVals = new byte[3];
rgbVals[0] = (byte) ((pixel >> 16) & 0xFF);
rgbVals[1] = (byte) ((pixel >> 8) & 0xFF);
rgbVals[2] = (byte) ((pixel) & 0xFF);
return rgbVals;
}

private List<ClassificationResult> getSortedResult(float[][] resultsArray) {
PriorityQueue<ClassificationResult> sortedResults = new PriorityQueue<>(
MAX_CLASSIFICATION_RESULTS,
@@ -6,6 +6,7 @@

import androidx.appcompat.app.AppCompatActivity;

import com.frogermcs.imageclassificationtester.configs.GtsrbQuantConfig;
import com.frogermcs.imageclassificationtester.configs.MobileNetV2Float;
import com.frogermcs.imageclassificationtester.configs.ModelConfig;
import com.otaliastudios.cameraview.CameraView;
@@ -34,8 +35,9 @@ protected void onCreate(Bundle savedInstanceState) {

private void initClassification() {
try {
ModelConfig modelConfig = new MobileNetV2Float();
// ModelConfig modelConfig = new MobileNetV2Float();
// ModelConfig modelConfig = new MnistConfig();
ModelConfig modelConfig = new GtsrbQuantConfig();
classificationFrameProcessor = new ClassificationFrameProcessor(this, this, modelConfig);
cameraView.addFrameProcessor(classificationFrameProcessor);
} catch (IOException e) {
@@ -0,0 +1,49 @@
package com.frogermcs.imageclassificationtester.configs;

public class GtsrbQuantConfig extends ModelConfig {

@Override
public String getModelFilename() {
return "gtsrb_demo.tflite";
}

@Override
public String getLabelsFilename() {
return "gtsrb_demo_labels.txt";
}

@Override
public int getInputWidth() {
return 224;
}

@Override
public int getInputHeight() {
return 224;
}

@Override
public int getInputSize() {
return getInputWidth() * getInputHeight() * getChannelsCount() * QUANT_BYTES_COUNT;
}

@Override
public int getChannelsCount() {
return 3;
}

@Override
public float getStd() {
return 128.f;
}

@Override
public float getMean() {
return 128.f;
}

@Override
public boolean isQuantized() {
return true;
}
}
@@ -40,4 +40,9 @@ public float getMean() {
public float getStd() {
return 255.f;
}

@Override
public boolean isQuantized() {
return false;
}
}
@@ -41,4 +41,9 @@ public float getStd() {
public float getMean() {
return 128.f;
}

@Override
public boolean isQuantized() {
return false;
}
}
@@ -3,6 +3,7 @@
public abstract class ModelConfig {

static final int FLOAT_BYTES_COUNT = 4;
static final int QUANT_BYTES_COUNT = 1;

public abstract String getModelFilename();

@@ -20,4 +21,6 @@

public abstract float getStd();

public abstract boolean isQuantized();

}

0 comments on commit 7dde56c

Please sign in to comment.
You can’t perform that action at this time.