Skip to content

Latest commit

 

History

History
231 lines (181 loc) · 8.15 KB

Android端的Tensorflow部署二.md

File metadata and controls

231 lines (181 loc) · 8.15 KB

本文将展示一个稍微复杂点的例子,仍然不涉及模型训练,只是导入一个别人已经训好的模型,运行看效果。这个模型是用于物体分类的,打开相机拍照,识别图片中物体并给出识别结果及相应概率。

模型下载路径 https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

解压后有三个文件:

imagenet_comp_graph_label_strings.txt
LICENSE
tensorflow_inception_graph.pb

创建一个App工程,将imagenet_comp_graph_label_strings.txt和tensorflow_inception_graph.pb拷贝到assets目录下,并导入TensorFlow所需的jar包和so库文件。

接下来,为了方便使用相机,引入一个开源相机库:

// `compile` is deprecated (removed in Gradle 7); `implementation` is its replacement.
implementation 'com.flurgle:camerakit:0.9.12'

然后定义一个ImageClassifier类用于图片分类,代码如下:

/**
 * Classifies an image with a TensorFlow graph loaded from the app's assets and returns the
 * highest-scoring classes whose score exceeds {@link #THRESHOLD}.
 *
 * <p>Not safe for concurrent use: {@link #recognizeImage(Bitmap)} reuses the instance's scratch
 * buffers and priority queue on every call.
 */
public class ImageClassifier {

    /** A class must score strictly above this value to be included in the results. */
    private static final float THRESHOLD = 0.1f;
    /** Maximum number of results returned by {@link #recognizeImage(Bitmap)}. */
    private static final int MAX_RESULTS = 3;
    /** URI prefix used to refer to files bundled in the assets directory. */
    private static final String ASSET_PREFIX = "file:///android_asset/";
    /** Label used for an output index that has no corresponding line in the label file. */
    private static final String UNKNOWN_LABEL = "unknown";

    private final TensorFlowInferenceInterface mTensorflow;

    private final String mInputName;
    private final String mOutputName;
    private final String[] mOutputNames;
    /** Scratch buffer for the packed pixels read from the input bitmap. */
    private final int[] mIntValues;
    /** Scratch buffer for the normalized R, G, B values fed to the model. */
    private final float[] mFloatValues;

    /**
     * Side length of the square input image; the model input is InputSize * InputSize * 3.
     */
    private final int mInputSize;

    /**
     * For an explanation of IMAGE_MEAN and IMAGE_STD see
     * https://github.com/googlecodelabs/tensorflow-for-poets-2/issues/2
     * https://stats.stackexchange.com/questions/211436/why-do-we-normalize-images-by-subtracting-the-datasets-image-mean-and-not-the-c
     * In short, they depend on the network and are used to normalize the input into some range.
     */
    private final int mImageMean;
    private final float mImageStd;

    /** Output buffer: one score per class of the output tensor. */
    private final float[] mOutputs;

    private final List<String> mLabels = new ArrayList<>();

    /** Min-heap ordered by confidence; used to retain the MAX_RESULTS best results. */
    private final PriorityQueue<ClassifyResult> mQueue = new PriorityQueue<>(MAX_RESULTS, new Comparator<ClassifyResult>() {

        @Override
        public int compare(ClassifyResult o1, ClassifyResult o2) {
            return Float.compare(o1.confidence, o2.confidence);
        }
    });

    /**
     * Loads the labels and the model graph.
     *
     * @param assetManager  asset manager used to open the label file and the model
     * @param modelFilename model file name, passed unchanged to TensorFlowInferenceInterface
     * @param labelFilename label file name, with or without the {@code file:///android_asset/} prefix;
     *                      one label per line, line i labels output index i
     * @param inputSize     side length of the square input bitmap
     * @param imageMean     value subtracted from every color channel
     * @param imageStd      value each mean-subtracted channel is divided by
     * @param inputName     name of the graph's input node
     * @param outputName    name of the graph's output node
     * @throws RuntimeException if the label file cannot be read
     */
    public ImageClassifier(AssetManager assetManager, String modelFilename, String labelFilename,
            int inputSize, int imageMean, float imageStd,
            String inputName, String outputName) {
        mInputName = inputName;
        mOutputName = outputName;

        readLabelFile(assetManager, labelFilename);

        mTensorflow = new TensorFlowInferenceInterface(assetManager, modelFilename);

        mInputSize = inputSize;
        mImageMean = imageMean;
        mImageStd = imageStd;

        mOutputNames = new String[]{outputName};
        mIntValues = new int[inputSize * inputSize];
        mFloatValues = new float[inputSize * inputSize * 3];

        // numClasses is the number of outputs; each output holds the score of one class.
        int numClasses = (int) mTensorflow.graph().operation(outputName).output(0).shape().size(1);
        mOutputs = new float[numClasses];
    }

    /**
     * Reads the label file line by line into {@link #mLabels}.
     *
     * @throws RuntimeException wrapping the IOException if the file cannot be opened or read;
     *                          continuing with no labels would make every result unlabeled
     */
    private void readLabelFile(AssetManager assetManager, String labelFilename) {
        // AssetManager.open expects a path relative to the assets directory.
        String actualFilename = labelFilename.startsWith(ASSET_PREFIX)
                ? labelFilename.substring(ASSET_PREFIX.length())
                : labelFilename;
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
                assetManager.open(actualFilename), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                mLabels.add(line);
            }
        } catch (IOException e) {
            throw new RuntimeException("Problem reading label file: " + actualFilename, e);
        }
    }

    /**
     * Runs the model on {@code bitmap} and returns up to {@link #MAX_RESULTS} results whose score
     * exceeds {@link #THRESHOLD}, ordered from highest to lowest confidence.
     *
     * @param bitmap image of exactly inputSize x inputSize pixels
     * @return the best results, possibly empty
     * @throws IllegalArgumentException if the bitmap does not match the configured input size
     */
    public List<ClassifyResult> recognizeImage(final Bitmap bitmap) {
        // getPixels would overflow mIntValues for a larger bitmap, and a smaller one would
        // leave stale pixels from a previous call in the buffer.
        if (bitmap.getWidth() != mInputSize || bitmap.getHeight() != mInputSize) {
            throw new IllegalArgumentException("Expected a " + mInputSize + "x" + mInputSize
                    + " bitmap but got " + bitmap.getWidth() + "x" + bitmap.getHeight());
        }
        bitmap.getPixels(mIntValues, 0, mInputSize, 0, 0, mInputSize, mInputSize);
        for (int i = 0; i < mIntValues.length; ++i) {
            final int val = mIntValues[i];
            // Unpack bits 16-23, 8-15 and 0-7 (R, G, B) and normalize each channel;
            // the top byte (alpha) is dropped.
            mFloatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - mImageMean) / mImageStd;
            mFloatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - mImageMean) / mImageStd;
            mFloatValues[i * 3 + 2] = ((val & 0xFF) - mImageMean) / mImageStd;
        }
        mTensorflow.feed(mInputName, mFloatValues, new long[] {1, mInputSize, mInputSize, 3});
        mTensorflow.run(mOutputNames, true);
        mTensorflow.fetch(mOutputName, mOutputs);

        mQueue.clear();
        for (int i = 0; i < mOutputs.length; i++) {
            if (mOutputs[i] > THRESHOLD) {
                // The class count comes from the graph and the label count from the file;
                // nothing guarantees they match, so don't index past the label list.
                String label = i < mLabels.size() ? mLabels.get(i) : UNKNOWN_LABEL;
                mQueue.add(new ClassifyResult(mOutputs[i], label));
                if (mQueue.size() > MAX_RESULTS) {
                    mQueue.poll(); // evict the current lowest-confidence result
                }
            }
        }

        List<ClassifyResult> results = new ArrayList<>(mQueue.size());
        while (!mQueue.isEmpty()) {
            // poll() yields ascending confidence; prepend to get descending order.
            results.add(0, mQueue.poll());
        }

        return results;
    }

    /** Closes the underlying TensorFlowInferenceInterface. */
    public void close() {
        mTensorflow.close();
    }
}

这里先读取label文件,以便之后根据输出数组的下标取得对应的label;输出数组中每个元素是对应类别的概率。然后用一个优先队列保留概率超过阈值(THRESHOLD = 0.1)且最大的三个结果并返回。

最后创建Activity:CameraView用于相机预览,点击按钮时拍照,并将图片交给ImageClassifier进行识别,然后把识别结果显示出来。

/**
 * Shows a camera preview; when the detect button is clicked, takes a picture, classifies it with
 * {@link ImageClassifier} and displays the results as text.
 */
public class MainActivity extends Activity {

    // Input geometry and normalization constants passed to ImageClassifier.
    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    // Names of the graph's input and output nodes.
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";

    private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt";

    private CameraView mCamera;

    private ImageClassifier mClassifier;

    private TextView mTvResult;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        mClassifier = new ImageClassifier(getAssets(), MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);

        mTvResult = findViewById(R.id.result);
        mCamera = findViewById(R.id.camera);
        mCamera.setCameraListener(new CameraListener() {
            // NOTE(review): inference and setText both run inside this callback; presumably
            // CameraKit delivers it on the main thread — confirm, and consider moving
            // recognizeImage off the UI thread.
            @Override
            public void onPictureTaken(byte[] jpeg) {
                if (jpeg == null) {
                    return;
                }
                Bitmap decoded = BitmapFactory.decodeByteArray(jpeg, 0, jpeg.length);
                if (decoded == null) {
                    // decodeByteArray returns null when the data cannot be decoded.
                    return;
                }
                // Scale to the model's square input; the photo's aspect ratio is not preserved.
                Bitmap scaled = Bitmap.createScaledBitmap(decoded, INPUT_SIZE, INPUT_SIZE, false);
                if (scaled != decoded) {
                    // The full-resolution photo is no longer needed; free it promptly.
                    decoded.recycle();
                }
                List<ClassifyResult> results = mClassifier.recognizeImage(scaled);
                scaled.recycle();
                mTvResult.setText(results.toString());
            }
        });

        findViewById(R.id.detect).setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                mCamera.captureImage();
            }
        });
    }

    @Override
    protected void onResume() {
        super.onResume();
        mCamera.start();
    }

    @Override
    protected void onPause() {
        mCamera.stop();
        super.onPause();
    }

    @Override
    protected void onDestroy() {
        mClassifier.close();
        super.onDestroy();
    }
}

这个项目可以和tensorflow自带的Android例子对照着看,基本是相似的。参考tensorflow/examples/android中的ClassifierActivity.java和TensorFlowImageClassifier.java。

不过奇怪的是,识别结果的概率都很低,普遍在20%左右,原因尚不明确(一个可能的因素是:拍摄的照片被非等比地直接缩放为224×224,物体会发生变形,而且代码没有处理照片的旋转方向)。下文将研究如何在已训练好的模型基础上,用我们自己的数据进行增量训练。

另外可以参考TensorFlow for Poets系列教程: https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#1 https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0 https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/

Android TensorFlow Machine Learning Example

MindorksOpenSource/AndroidTensorFlowMachineLearningExample