This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

How to run trained mxnet model in Java #4060

Closed
anjishnu opened this issue Dec 2, 2016 · 4 comments

Comments

@anjishnu
Contributor

anjishnu commented Dec 2, 2016

I want to develop and experiment with models in Python but run them in Java (e.g. in an Android app).

How do I accomplish something analogous to this?
https://medium.com/google-cloud/how-to-invoke-a-trained-tensorflow-model-from-java-programs-27ed5f4f502d

@thirdwing
Contributor

@zihaolucky
Member

If you are not targeting Android, you can build the MXNet Scala package, add the resulting jar to your project, and then write a predictor.scala that accepts Java array input and serves predictions.

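// Assumes the 2016-era MXNet Scala package (ml.dmlc.mxnet); later releases renamed it to org.apache.mxnet.
import ml.dmlc.mxnet.{FeedForward, NDArray, Shape}
import ml.dmlc.mxnet.io.NDArrayIter

class MXNetPredictor(prefix: String, epoch: Int, batchSize: Int) {

  // Load the saved checkpoint (symbol + parameters) for the given prefix and epoch.
  val model = FeedForward.load(prefix, epoch)

  /**
    * Run a forward pass on a batch of data.
    *
    * @param flat A flat feature input vector.
    * @param shape Shape of input data.
    * @return The network outputs for the batch.
    */
  def predict(flat: Array[Float], shape: Array[Int]): Array[NDArray] = {
    // Copy the flat array into an NDArray with the requested shape.
    val ndArray = NDArray.array(flat, Shape(shape: _*))

    val data: IndexedSeq[NDArray] = IndexedSeq(ndArray)
    // No labels are needed for inference.
    val label: IndexedSeq[NDArray] = IndexedSeq()

    val valData = new NDArrayIter(data, label, batchSize)
    model.predict(valData)
  }

  /**
    * Top-1 prediction results for a batch of data.
    *
    * @param flat A flat feature input vector.
    * @param shape Shape of input data.
    * @return The index of the highest-scoring class for each sample in the batch.
    */
  def predictTop1(flat: Array[Float], shape: Array[Int]): NDArray = {
    // Take the first output of the network and pick the argmax along the class axis.
    val prediction = predict(flat, shape)(0)
    NDArray.argmaxChannel(prediction)
  }

}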
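To see what `predictTop1` hands back without installing MXNet, the following plain-Java sketch reproduces the top-1 step by hand. It assumes the network output for a batch is a `[batchSize, numClasses]` matrix stored row-major in a flat `float[]`; the names `Top1` and `argmaxRows` are hypothetical and not part of the MXNet API. It only illustrates the idea behind `NDArray.argmaxChannel` and is not a replacement for it.

```java
import java.util.Arrays;

/** Hypothetical helper: top-1 class index per row of a row-major [rows, cols] matrix. */
public class Top1 {

    static int[] argmaxRows(float[] flat, int rows, int cols) {
        if (flat.length != rows * cols) {
            throw new IllegalArgumentException("flat.length must equal rows * cols");
        }
        int[] best = new int[rows];
        for (int r = 0; r < rows; r++) {
            int offset = r * cols;          // start of row r in row-major layout
            int arg = 0;
            for (int c = 1; c < cols; c++) {
                if (flat[offset + c] > flat[offset + arg]) {
                    arg = c;                // strict '>' keeps the first maximum on ties
                }
            }
            best[r] = arg;
        }
        return best;
    }

    public static void main(String[] args) {
        // Two samples, three classes each (made-up scores for illustration).
        float[] scores = {0.1f, 0.7f, 0.2f,
                          0.6f, 0.3f, 0.1f};
        System.out.println(Arrays.toString(argmaxRows(scores, 2, 3))); // prints [1, 0]
    }
}
```

Note that MXNet's `argmaxChannel` returns the indices as an `NDArray` of floats, so on the Java side you would typically cast each value to `int`.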

On the Java side, it's simple:

private MXNetPredictor predictor;
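A Scala `Array[Float]` compiles to a Java `float[]` (and `Array[Int]` to `int[]`), so Java code can call `predict` and `predictTop1` directly. The remaining work is packing the input into the flat-array-plus-shape form. Below is a minimal sketch, assuming a 2-D input of shape `[batchSize, numFeatures]` stored row-major. `BatchPacker`, `flatten`, and `shapeOf` are hypothetical names, and the commented-out call uses a placeholder model prefix and epoch.

```java
import java.util.Arrays;

/** Hypothetical helper that packs a batch of feature vectors into a (flat, shape) pair. */
public class BatchPacker {

    /** Copies a [batch][features] array into one row-major float[]. */
    static float[] flatten(float[][] batch) {
        int rows = batch.length;
        int cols = rows == 0 ? 0 : batch[0].length;
        float[] flat = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            if (batch[r].length != cols) {
                throw new IllegalArgumentException("all rows must have the same length");
            }
            System.arraycopy(batch[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    /** Shape {batch, features} matching the flat layout produced by flatten. */
    static int[] shapeOf(float[][] batch) {
        return new int[] {batch.length, batch.length == 0 ? 0 : batch[0].length};
    }

    public static void main(String[] args) {
        float[][] batch = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        float[] flat = flatten(batch);
        int[] shape = shapeOf(batch);
        // prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] [2, 3]
        System.out.println(Arrays.toString(flat) + " " + Arrays.toString(shape));
        // With the MXNet Scala jar on the classpath, the pair could then be passed on, e.g.:
        // MXNetPredictor predictor = new MXNetPredictor("model-prefix", 10, 2); // placeholder prefix/epoch
        // NDArray top1 = predictor.predictTop1(flat, shape);
    }
}
```

For image models the same idea applies, but the shape would have four dimensions (for example batch, channels, height, width), and the flat array must follow that order.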

@yajiedesign
Contributor

This issue is closed due to lack of activity in the last 90 days. Feel free to reopen if this is still an active issue. Thanks!

@sampathchanda

@zihaolucky Do you have a Git repo or a working example of this? I am new to both Java and Scala but am trying to deploy an MXNet model using Java. It would be great if you could point me to a working example.


5 participants