Skip to content

Commit

Permalink
[spark] Update README (#2596)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed May 12, 2023
1 parent 06c6e4c commit aa61c10
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions extensions/spark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,28 @@ Using the DJL Spark Extension is simple and straightforward. Here is an example
### Scala

```scala
import ai.djl.spark.SparkTransformer
import ai.djl.spark.translator.SparkImageClassificationTranslator
import ai.djl.spark.task.vision.ImageClassifier

val transformer = new SparkTransformer[Classifications]()
.setInputCols(Array("input_col1", "input_col2"))
.setOutputCols(Array("value"))
val classifier = new ImageClassifier()
.setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data"))
.setOutputCol("prediction")
.setEngine("PyTorch")
.setModelUrl("model_url")
.setOutputClass(classOf[Classifications])
.setTranslator(new SparkImageClassificationTranslator())
val outputDf = transformer.transform(df)
.setModelUrl("djl://ai.djl.pytorch/resnet")
.setTopK(2)
var outputDf = classifier.classify(df)
```

### Python

```python
from djl_spark.transformer import SparkTransformer
from djl_spark.translator import SparkImageClassificationTranslator

transformer = SparkTransformer(input_cols=["input_col1", "input_col2"],
output_cols=["value"],
engine="PyTorch",
model_url="model_url",
output_class="ai.djl.modality.Classifications",
translator=SparkImageClassificationTranslator())
outputDf = transformer.transform(df)
from djl_spark.task.vision import ImageClassifier

classifier = ImageClassifier(input_cols=["origin", "height", "width", "nChannels", "mode", "data"],
output_col="prediction",
engine="PyTorch",
model_url="djl://ai.djl.pytorch/resnet",
top_k=2)
outputDf = classifier.classify(df)
```

See [examples](https://github.com/deepjavalibrary/djl-demo/tree/master/apache-spark/spark3.0) for more details.

0 comments on commit aa61c10

Please sign in to comment.