Skip to content

Commit

Permalink
implement TFLite。
Browse files Browse the repository at this point in the history
  • Loading branch information
keiji committed Jul 26, 2018
1 parent ea89a94 commit 2f0bc16
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
5 changes: 4 additions & 1 deletion app/build.gradle
Expand Up @@ -22,6 +22,9 @@ android {
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
}

dependencies {
Expand All @@ -37,5 +40,5 @@ dependencies {
implementation 'io.reactivex.rxjava2:rxkotlin:2.2.0'
implementation 'io.reactivex.rxjava2:rxandroid:2.0.2'

implementation 'org.tensorflow:tensorflow-android:1.9.0'
implementation 'org.tensorflow:tensorflow-lite:1.9.0'
}
50 changes: 40 additions & 10 deletions app/src/main/java/io/keiji/foodgallery/ImageRecognizer.kt
Expand Up @@ -20,15 +20,22 @@ import android.content.res.AssetManager
import android.graphics.Bitmap
import android.os.Debug
import android.util.Log
import org.tensorflow.contrib.android.TensorFlowInferenceInterface
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel


class ImageRecognizer(assetManager: AssetManager) {

companion object {
val TAG = ImageRecognizer::class.java.simpleName

// https://github.com/keiji/food_gallery_with_tensorflow
private val MODEL_FILE_PATH = "food_model_4ch.pb"
private val MODEL_FILE_PATH = "food_model_4ch.tflite"

private val IMAGE_WIDTH = 128
private val IMAGE_HEIGHT = 128
Expand All @@ -41,22 +48,45 @@ class ImageRecognizer(assetManager: AssetManager) {
}
}

val tfInference: TensorFlowInferenceInterface = TensorFlowInferenceInterface(
assetManager.open(MODEL_FILE_PATH))
@Throws(IOException::class)
private fun loadModelFile(assets: AssetManager, modelFileName: String): MappedByteBuffer {
val fileDescriptor = assets.openFd(modelFileName)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)

val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength

return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}

val tfInference: Interpreter

init {
tfInference = Interpreter(loadModelFile(assetManager, MODEL_FILE_PATH)).apply {
// setUseNNAPI(true)
}
}

val inputBuffer = ByteBuffer
.allocateDirect(IMAGE_BYTES_LENGTH * 4)
.order(ByteOrder.nativeOrder())

val resultArray = FloatArray(1)
val resultArray = Array(1, { FloatArray(1) })

fun recognize(imageByteArray: ByteArray): Float {
val start = Debug.threadCpuTimeNanos()

tfInference.feed("input", imageByteArray, IMAGE_BYTES_LENGTH.toLong())
tfInference.run(arrayOf("result"))
tfInference.fetch("result", resultArray)
imageByteArray.forEach { inputBuffer.putFloat(it.toInt().and(0xFF).toFloat()) }
inputBuffer.rewind()

val start = Debug.threadCpuTimeNanos()
tfInference.run(inputBuffer, resultArray)
inputBuffer.rewind()

val elapsed = Debug.threadCpuTimeNanos() - start
Log.d(TAG, "Elapsed: %,3d ns".format(elapsed))

return resultArray[0]
return resultArray[0][0]
}

fun stop() {
Expand Down

0 comments on commit 2f0bc16

Please sign in to comment.