# Image Classification with DJL Spark Support

In this example, we use Jupyter Notebook to run image classification with the DJL Spark extension in Scala. To run this notebook, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provides extensive functionality for Scala and Spark applications.

[Almond installation instructions](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 has been tested.)

After that, you can start with DJL's Scala notebook.
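Because only Scala 2.12 has been tested, you may want to confirm which Scala version the kernel is running before going further. The cell below is a minimal sketch that uses the standard library's `scala.util.Properties`. The warning message is our own and is not produced by DJL or Almond.

```scala
import scala.util.Properties

// Version of the Scala library the kernel is running, e.g. "2.12.17"
val scalaVersion: String = Properties.versionNumberString

// This notebook has only been tested with Scala 2.12, so flag any other version
val isTestedVersion: Boolean = scalaVersion.startsWith("2.12.")
println(s"Scala version: $scalaVersion")
if (!isTestedVersion) {
  println("Warning: this notebook has only been tested with Scala 2.12")
}
```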


## Import dependencies

First, let's import the dependencies we need. We use DJL PyTorch as the backend engine. To switch to MXNet instead, uncomment the two MXNet lines and comment out the two PyTorch lines.

```scala
import $ivy.`org.apache.spark::spark-sql:3.2.2`
import $ivy.`ai.djl:api:0.21.0`
import $ivy.`ai.djl.spark:spark:0.21.0`
import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.21.0`
import $ivy.`ai.djl.pytorch:pytorch-native-cpu-precxx11:1.13.1`
// import $ivy.`ai.djl.mxnet:mxnet-engine:0.21.0`
// import $ivy.`ai.djl.mxnet:mxnet-native-mkl:1.9.1`
```

Next, we import the packages we need. The last two lines turn off logging for the `org` (which includes Spark) and `ai` (which includes DJL) packages so that log messages don't clutter the cell outputs.

```scala
import org.apache.spark.sql.NotebookSparkSession
import ai.djl.modality.Classifications
import ai.djl.spark.SparkTransformer
import ai.djl.spark.translator.SparkImageClassificationTranslator
import org.apache.spark.sql.SparkSession

import org.apache.log4j.{Level, Logger}
Logger.getLogger("org").setLevel(Level.OFF) // silence Spark (org.apache.*) log output
Logger.getLogger("ai").setLevel(Level.OFF)  // silence DJL (ai.djl.*) log output
```

## Start Spark application

We create a `NotebookSparkSession` through the Almond Spark plugin. It automatically makes the jars loaded in the notebook available to each worker node.

```scala
// Create a Spark session that runs locally, using all available cores
val spark = {
  NotebookSparkSession.builder()
    .master("local[*]")
    .getOrCreate()
}
```

Next, let's load the images from the local folder using Spark's built-in `image` data source:

```scala
// Read every image in the folder; files that cannot be decoded are skipped
val df = spark.read.format("image").option("dropInvalid", true).load("../image-classification/images")
df.select("image.origin", "image.width", "image.height").show(truncate = false)
```
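Each row's `image` struct stores the decoded pixels in its `data` field as a flat byte array. According to Spark's `ImageSchema` documentation, these bytes are in OpenCV-compatible order: row by row, with channels interleaved in BGR (or BGRA) order. Models such as ResNet usually expect a channel-first RGB float tensor instead, so converting between the two layouts is part of what an image translator has to do. The sketch below shows only that layout conversion in plain Scala; `toChwRgb` is a hypothetical helper (not part of Spark or DJL), and it deliberately leaves out resizing and normalization.

```scala
// Hypothetical helper: convert Spark image bytes (row-major HWC, BGR channel order)
// into a channel-first (CHW) RGB float array with values scaled to [0, 1].
def toChwRgb(data: Array[Byte], height: Int, width: Int, nChannels: Int): Array[Float] = {
  require(nChannels >= 3, "expects at least 3 channels (BGR or BGRA)")
  require(data.length == height * width * nChannels, "data size does not match dimensions")
  val planeSize = height * width
  val out = new Array[Float](3 * planeSize)
  for (y <- 0 until height; x <- 0 until width) {
    // Offset of pixel (x, y) in the interleaved source array
    val src = (y * width + x) * nChannels
    val dst = y * width + x
    // JVM bytes are signed, so mask with 0xFF to recover 0..255
    val b = (data(src) & 0xFF) / 255f
    val g = (data(src + 1) & 0xFF) / 255f
    val r = (data(src + 2) & 0xFF) / 255f
    // Write each channel into its own plane, reordering BGR -> RGB
    out(dst) = r
    out(planeSize + dst) = g
    out(2 * planeSize + dst) = b
  }
  out
}

// Example: a 1x2 image with a pure-blue pixel followed by a pure-red pixel (BGR bytes)
val pixels = Array[Byte](255.toByte, 0, 0, 0, 0, 255.toByte)
val chw = toChwRgb(pixels, height = 1, width = 2, nChannels = 3)
println(chw.mkString(", ")) // R plane, then G plane, then B plane
```

In the output, the red plane is `0.0, 1.0` (only the second pixel is red) and the blue plane is `1.0, 0.0` (only the first pixel is blue).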

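Then we can run inference on these images. All we need to do is create a `SparkTransformer` and call its `transform` method, which runs inference with DJL.

```scala
// Use the pre-CXX11 ABI build of the PyTorch native library (matches the precxx11 dependency above)
System.setProperty("PYTORCH_PRECXX11", "true")
val transformer = new SparkTransformer[Classifications]()
  .setInputCol("image.*")   // select all fields of the `image` struct as input
  .setOutputCol("value")    // column that holds the predictions
  // Traced ResNet-18; the model_name query parameter names the model file inside the archive
  .setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
  .setOutputClass(classOf[Classifications])
  .setTranslator(new SparkImageClassificationTranslator())
// Run inference on every row of the DataFrame
val outputDf = transformer.transform(df)
// Collect the results to the driver; fine for a handful of images
println(outputDf.collect().mkString("\n"))
```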
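Each prediction is a DJL `Classifications`, which pairs class names with probabilities. Conceptually, getting there from a classifier means turning the model's raw scores (logits) into probabilities and ranking them. The sketch below illustrates that post-processing with a numerically stable softmax followed by top-k selection. It is not DJL's implementation: `softmax`, `topK`, the class names, and the scores are all hypothetical.

```scala
// Hypothetical helper: numerically stable softmax over raw model scores (logits)
def softmax(logits: Array[Double]): Array[Double] = {
  val max = logits.max // subtract the max so exp() cannot overflow
  val exps = logits.map(z => math.exp(z - max))
  val sum = exps.sum
  exps.map(_ / sum)
}

// Hypothetical helper: return the k most probable (className, probability) pairs
def topK(classNames: Seq[String], logits: Array[Double], k: Int): Seq[(String, Double)] = {
  require(classNames.length == logits.length, "one class name per logit")
  classNames.zip(softmax(logits)).sortBy(-_._2).take(k)
}

// Example with made-up labels and scores
val labels = Seq("tabby cat", "golden retriever", "sports car")
val best = topK(labels, Array(2.0, 1.0, 0.1), k = 2)
best.foreach { case (name, p) => println(f"$name%-18s $p%.3f") }
```

Subtracting the maximum logit before exponentiating does not change the result, because the common factor cancels in the division, but it keeps `math.exp` from overflowing on large scores.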