# Image Classification with Apache Spark

In this example, we use Jupyter Notebook to run image classification with Apache Spark in Scala. To run this Scala notebook, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provides extensive functionality for Scala and Spark applications.

[Almond installation instructions](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 has been tested.)

After that, you can start with DJL's Scala notebook.


## Import dependencies

First, let's import the dependencies we need. We use DJL PyTorch as the backend engine. To switch to MXNet, uncomment the two MXNet lines and comment out the two PyTorch lines.

In [None]:
import $ivy.`org.apache.spark::spark-sql:3.0.1`
import $ivy.`org.apache.spark::spark-mllib:3.0.1`
import $ivy.`ai.djl:api:0.10.0`
import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.10.0`
import $ivy.`ai.djl.pytorch:pytorch-native-auto:1.7.1`
// import $ivy.`ai.djl.mxnet:mxnet-model-zoo:0.10.0`
// import $ivy.`ai.djl.mxnet:mxnet-native-auto:1.7.0-backport`

Next, we import the packages we need. The last two lines turn off logging from the `org` (Spark) and `ai` (DJL) loggers so that log messages do not clutter the cell outputs.

In [None]:
import java.util
import ai.djl.Model
import ai.djl.modality.Classifications
import ai.djl.modality.cv.transform.{Resize, ToTensor}
import ai.djl.ndarray.types.{DataType, Shape}
import ai.djl.ndarray.{NDList, NDManager}
import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}
import ai.djl.training.util.ProgressBar
import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}
import ai.djl.util.Utils
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Encoders, Row, NotebookSparkSession}
import org.apache.log4j.{Level, Logger}
Logger.getLogger("org").setLevel(Level.OFF) // silence Spark (org.*) log output
Logger.getLogger("ai").setLevel(Level.OFF) // silence DJL (ai.*) log output

## Create Translator

A `Translator` in DJL defines the preprocessing and postprocessing logic. The translator below does the following:

- preprocess: convert a Spark DataFrame `Row` into a DJL `NDArray` (wrapped in an `NDList`).
- postprocess: convert the inference result into `Classifications`.
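
To make the preprocessing step concrete, here is a plain-Scala sketch (no DJL or Spark required) of what the BGR-to-RGB flip and `ToTensor` do to the raw bytes. It assumes Spark's image layout of interleaved bytes in HWC order with BGR channels; the helper `bgrHwcToRgbChw` is hypothetical and not part of either library, and it skips the `Resize` step.

```scala
// Hypothetical helper (not a DJL/Spark API): mimics flip(2) + ToTensor on raw Spark image bytes.
// Assumes input bytes are interleaved HWC in BGR order; output is CHW in RGB order, scaled to [0, 1].
def bgrHwcToRgbChw(data: Array[Byte], height: Int, width: Int, channels: Int): Array[Float] = {
  val out = new Array[Float](channels * height * width)
  for (h <- 0 until height; w <- 0 until width; c <- 0 until channels) {
    // Index of this byte in the interleaved HWC buffer
    val src = (h * width + w) * channels + c
    // Reverse the channel order: BGR -> RGB
    val dstChannel = channels - 1 - c
    // Index in the planar CHW output
    val dst = (dstChannel * height + h) * width + w
    // JVM bytes are signed, so mask to 0..255 before scaling
    out(dst) = (data(src) & 0xFF) / 255.0f
  }
  out
}
```

For a 1x2 image whose pixels are BGR `(10, 20, 30)` and `(255, 0, 128)`, the output holds the R plane `[30, 128]`, then G `[20, 0]`, then B `[10, 255]`, each divided by 255.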

In [None]:
  // Translator: a class used to do preprocessing and post processing
  class MyTranslator extends Translator[Row, Classifications] {

    private var classes: java.util.List[String] = new util.ArrayList[String]()
    private val pipeline: Pipeline = new Pipeline()
      .add(new Resize(224, 224))
      .add(new ToTensor())

    override def prepare(manager: NDManager, model: Model): Unit = {
      classes = Utils.readLines(model.getArtifact("synset.txt").openStream())
    }

    override def processInput(ctx: TranslatorContext, row: Row): NDList = {

      val height = ImageSchema.getHeight(row)
      val width = ImageSchema.getWidth(row)
      val channel = ImageSchema.getNChannels(row)
      var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)
      // BGR to RGB
      image = image.flip(2)
      pipeline.transform(new NDList(image))
    }

    // Handle the output: the NDList holds the inference result, usually one or more NDArrays.
    override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
      var probabilitiesNd = list.singletonOrThrow
      probabilitiesNd = probabilitiesNd.softmax(0)
      new Classifications(classes, probabilitiesNd)
    }

    override def getBatchifier: Batchifier = Batchifier.STACK
  }
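
The postprocessing step applies a softmax to the model's raw scores and pairs the resulting probabilities with the class names read from `synset.txt`. The plain-Scala sketch below shows that math on a single score vector; `softmax` and `topK` are hypothetical helpers, and the ranking is only similar in spirit to how `Classifications` orders its results.

```scala
// Hypothetical helper: numerically stable softmax over a 1-D score vector
def softmax(logits: Array[Double]): Array[Double] = {
  val max = logits.max                          // subtract the max to avoid overflow in exp
  val exps = logits.map(x => math.exp(x - max))
  val sum = exps.sum
  exps.map(_ / sum)
}

// Hypothetical helper: pair each probability with its class name and keep the k most likely
def topK(classes: Seq[String], probs: Array[Double], k: Int): Seq[(String, Double)] =
  classes.zip(probs).sortBy(-_._2).take(k)
```

For example, scores `(1.0, 3.0, 2.0)` for `cat`, `dog`, `fish` rank `dog` first and `fish` second.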

## Load the model

Now we fetch the model from a URL. The URL can use the `hdfs://`, `file://`, or `http(s)://` scheme. We use `Criteria` as a container for the model and translator information, and then load the model from it.
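
As an illustration of these URL formats, the sketch below uses `java.net.URI` to split a model URL into its scheme and query parameters. `describeModelUrl` is a hypothetical helper; we assume, as the URLs in the next cell suggest, that DJL reads the `model_name` query parameter to pick the model file inside the archive.

```scala
import java.net.URI

// Hypothetical helper (not a DJL API): returns the URL scheme and its query parameters
def describeModelUrl(url: String): (String, Map[String, String]) = {
  val uri = new URI(url)
  val params = Option(uri.getQuery).toSeq   // getQuery is null when there is no query string
    .flatMap(_.split("&"))
    .map { kv =>
      val i = kv.indexOf('=')
      if (i < 0) kv -> "" else kv.substring(0, i) -> kv.substring(i + 1)
    }
    .toMap
  (uri.getScheme, params)
}
```

Applied to the PyTorch URL used below, this returns the scheme `https` and `Map("model_name" -> "traced_resnet18")`.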

Note: DJL `Criteria` and `Model` are not serializable, so we declare them `lazy`.
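
The `lazy` keyword defers construction until the value is first accessed, and later accesses reuse the same instance. The plain-Scala sketch below shows that behavior with a hypothetical `FakeModel` standing in for the real DJL model; it illustrates only Scala's lazy initialization, not Spark's serialization.

```scala
// Counts how many times the "expensive" initializer runs
var loads = 0

// Hypothetical stand-in for an expensive, non-serializable resource such as a DJL model
class FakeModel(val name: String)

lazy val fakeModel: FakeModel = {
  loads += 1
  new FakeModel("resnet18")
}

val before = loads      // 0: declaring the lazy val has not run the initializer yet
val first = fakeModel   // the first access runs the initializer
val second = fakeModel  // later accesses return the cached instance
```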

If you are using MXNet as the backend engine, please uncomment the MXNet model URL and comment out the PyTorch one.

In [None]:
val modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18"
// val modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1"
lazy val criteria = Criteria.builder
  .setTypes(classOf[Row], classOf[Classifications])
  .optModelUrls(modelUrl)
  .optTranslator(new MyTranslator())
  .optProgress(new ProgressBar)
  .build()
lazy val model = ModelZoo.loadModel(criteria)

## Start Spark application

We can create a `NotebookSparkSession` through the Almond Spark plugin. It automatically makes all the jars the notebook depends on available to each worker node.

In [None]:
// Create Spark session
val spark = {
  NotebookSparkSession.builder()
    .master("local[*]")
    .getOrCreate()
}

Let's try to load the images from the local folder using Spark ML library:

In [None]:
val df = spark.read.format("image").option("dropInvalid", true).load("../image-classification/images")
df.select("image.origin", "image.width", "image.height").show(truncate=false)

Then we can run inference on these images. All we need to do is create a `Predictor` for each partition and run inference with DJL.

In [None]:
val result = df.select(col("image.*")).mapPartitions(partition => {
  val predictor = model.newPredictor()
  partition.map(row => {
    // image data stored as HWC format
    predictor.predict(row).toString
  })
})(Encoders.STRING)
println(result.collect().mkString("\n"))
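
The inference cell creates one `Predictor` per partition rather than one per row, so the setup cost is paid once per partition. The plain-Scala sketch below mimics that pattern with a hypothetical `FakePredictor` and two in-memory "partitions", counting how many predictors get built.

```scala
// Counts how many predictors get created
var created = 0

// Hypothetical stand-in for a DJL Predictor: costly to create, cheap to call
class FakePredictor {
  created += 1
  def predict(x: Int): String = s"class-${x % 2}"
}

// Same shape as the mapPartitions body: one predictor per partition, reused for every row
def processPartition(partition: Iterator[Int]): Iterator[String] = {
  val predictor = new FakePredictor()
  partition.map(row => predictor.predict(row))
}

// Two "partitions" holding five rows in total; toList forces the lazy iterators
val partitions = Seq(Iterator(1, 2, 3), Iterator(4, 5))
val results = partitions.flatMap(p => processPartition(p).toList)
```

Five rows produce only two predictors. Because `partition.map` returns a lazy iterator, closing the predictor immediately after that call would be premature: the rows have not been consumed yet at that point.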