docs: Adding document and notebooks for ONNXModel #1164

Merged · 10 commits · Aug 17, 2021
@@ -50,7 +50,9 @@ object DatabricksUtilities extends HasHttpClient {
     Map("pypi" -> Map("package" -> "nltk")),
     Map("pypi" -> Map("package" -> "bs4")),
     Map("pypi" -> Map("package" -> "plotly")),
-    Map("pypi" -> Map("package" -> "Pillow"))
+    Map("pypi" -> Map("package" -> "Pillow")),
+    Map("pypi" -> Map("package" -> "onnxmltools")),
+    Map("pypi" -> Map("package" -> "lightgbm"))
  ).toJson.compactPrint

// Execution Params
111 changes: 111 additions & 0 deletions deep-learning/src/main/python/mmlspark/onnx/ONNXModel.py
@@ -1,13 +1,29 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

from abc import ABCMeta
import sys
from typing import Mapping, List
from py4j import java_gateway

if sys.version >= "3":
basestring = str

from mmlspark.onnx._ONNXModel import _ONNXModel
from pyspark.ml.common import inherit_doc
from py4j.java_gateway import JavaObject


class NodeInfo(object):
def __init__(self, name: str, value_info: JavaObject):
self.name = name
self.value_info = ValueInfo.from_java(value_info)

def __str__(self) -> str:
return "NodeInfo(name=" + self.name + ",info=" + str(self.value_info) + ")"

def __repr__(self) -> str:
return self.__str__()


@inherit_doc
@@ -27,3 +43,98 @@ def setMiniBatchSize(self, n):
self._java_obj = self._java_obj.setMiniBatchSize(n)
return self

def __parse_node_info(self, node_info: JavaObject) -> "NodeInfo":
name = node_info.getName()
value_info = node_info.getInfo()
return NodeInfo(name, value_info)

def getModelInputs(self) -> Mapping[str, NodeInfo]:
self._transfer_params_to_java()
mi = self._java_obj.modelInputJava()
return {name: self.__parse_node_info(info) for name, info in mi.items()}

def getModelOutputs(self) -> Mapping[str, NodeInfo]:
self._transfer_params_to_java()
mo = self._java_obj.modelOutputJava()
return {name: self.__parse_node_info(info) for name, info in mo.items()}


class ValueInfo(metaclass=ABCMeta):
@classmethod
def from_java(cls, java_value_info: JavaObject) -> "ValueInfo":
className = java_value_info.getClass().getName()
if className == "ai.onnxruntime.TensorInfo":
return TensorInfo.from_java(java_value_info)
elif className == "ai.onnxruntime.MapInfo":
return MapInfo.from_java(java_value_info)
else:
return SequenceInfo.from_java(java_value_info)


class TensorInfo(ValueInfo):
def __init__(self, shape: List[int], type: str):
self.shape = shape
self.type = type

def __repr__(self):
return str(self)

def __str__(self):
return "TensorInfo(shape={}, type={})".format("[" + ",".join(map(str, self.shape)) + "]", self.type)

@classmethod
def from_java(cls, java_tensor_info: JavaObject) -> "TensorInfo":
shape = list(java_tensor_info.getShape())
type = java_gateway.get_field(java_tensor_info, "type").toString()
return cls(shape, type)


class MapInfo(ValueInfo):
def __init__(self, key_type: str, value_type: str, size: int = -1):
self.key_type = key_type
self.value_type = value_type
self.size = size

def __repr__(self) -> str:
return str(self)

def __str__(self) -> str:
initial = "MapInfo(size=UNKNOWN" if self.size == -1 else "MapInfo(size=" + str(self.size)
return initial + ",keyType=" + self.key_type + ",valueType=" + self.value_type + ")"

@classmethod
def from_java(cls, java_map_info: JavaObject) -> "MapInfo":
if java_map_info is None:
return None
else:
key_type = java_gateway.get_field(java_map_info, "keyType").toString()
value_type = java_gateway.get_field(java_map_info, "valueType").toString()
size = java_gateway.get_field(java_map_info, "size")
return cls(key_type, value_type, size)


class SequenceInfo(ValueInfo):
def __init__(self, length: int, sequence_of_maps: bool, map_info: MapInfo, sequence_type: str):
self.length = length
self.sequence_of_maps = sequence_of_maps
self.map_info = map_info
self.sequence_type = sequence_type

def __repr__(self) -> str:
return str(self)

def __str__(self) -> str:
initial = "SequenceInfo(length=" + ("UNKNOWN" if self.length == -1 else str(self.length))
if self.sequence_of_maps:
initial += ",type=" + str(self.map_info) + ")"
else:
initial += ",type=" + str(self.sequence_type) + ")"
return initial

@classmethod
def from_java(cls, java_sequence_info: JavaObject) -> "SequenceInfo":
length = java_gateway.get_field(java_sequence_info, "length")
sequence_of_maps = java_gateway.get_field(java_sequence_info, "sequenceOfMaps")
map_info = MapInfo.from_java(java_gateway.get_field(java_sequence_info, "mapInfo"))
sequence_type = java_gateway.get_field(java_sequence_info, "sequenceType").toString()
return cls(length, sequence_of_maps, map_info, sequence_type)
@@ -31,8 +31,10 @@ import org.apache.spark.{SparkContext, TaskContext}
import spray.json.DefaultJsonProtocol._

import java.nio._
+import java.util
import scala.annotation.tailrec
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters.mapAsScalaMapConverter
import scala.reflect.ClassTag

@@ -262,16 +264,32 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
}
}

-  private def flattenNested[T: ClassTag](nestedSeq: Seq[_]): Seq[T] = {
-    nestedSeq.flatMap {
-      case x: T => Array(x)
+  private def writeNestedSeqToBuffer[T: ClassTag](nestedSeq: Seq[_], bufferWrite: T => Unit): Unit = {
+    nestedSeq.foreach {
+      case x: T => bufferWrite(x)
      case s: Seq[_] =>
-        flattenNested(s)
-      case a: Array[_] =>
-        flattenNested(a)
+        writeNestedSeqToBuffer(s, bufferWrite)
}
}

+  private def writeNestedSeqToStringBuffer(nestedSeq: Seq[_], size: Int): ArrayBuffer[String] = {
[Inline review thread]

Collaborator: Could you share a few details on why this is necessary?

Contributor (author): writeNestedSeqToStringBuffer, together with writeNestedSeqToBuffer, replaces the flattenNested method. flattenNested flattens a nested Seq into a 1-D array, but it is very slow due to repeated memory allocation. writeNestedSeqToStringBuffer and writeNestedSeqToBuffer both allocate the memory buffer once and write the nested Seq into it; in local testing they were 5-10 times faster than flattenNested. Two methods are needed because there is no nio.Buffer class for the string type, so an ArrayBuffer[String] is used in that case.
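As an editorial illustration (not part of the PR), here is a minimal sketch of what the buffer-filling approach does for a batch of two 2x2 Float tensors, using the writeNestedSeqToBuffer method from this diff:

```scala
import java.nio.FloatBuffer

// Depth-first traversal writes each leaf value straight into a pre-allocated
// buffer, avoiding the intermediate collections flattenNested created on
// every recursive call.
val batch = Seq(Seq(Seq(1f, 2f), Seq(3f, 4f)), Seq(Seq(5f, 6f), Seq(7f, 8f)))
val buffer = FloatBuffer.allocate(8) // 2 batch items * 2 rows * 2 columns
writeNestedSeqToBuffer[Float](batch, buffer.put(_))
buffer.rewind() // buffer now holds 1,2,3,4,5,6,7,8 in row-major order
```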

+    var i = 0
+    val buffer = ArrayBuffer.fill[String](size)("")
+
+    def innerWrite(nestedSeq: Seq[_]): Unit = {
+      nestedSeq.foreach {
+        case x: String =>
+          buffer.update(i, x)
+          i = i + 1
+        case s: Seq[_] =>
+          innerWrite(s)
+      }
+    }
+
+    innerWrite(nestedSeq)
+    buffer
+  }

private[onnx] def selectGpuDevice(deviceType: Option[String]): Option[Int] = {
deviceType match {
case None | Some("CUDA") =>
@@ -336,36 +354,47 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
})
}

-  private def createTensor(env: OrtEnvironment, tensorInfo: TensorInfo, batchedValues: Seq[_]) = {
-    val classTag = ClassTag(tensorInfo.`type`.clazz)
-    val flattened: Array[_] = flattenNested(batchedValues)(classTag).toArray
+  private def createTensor(env: OrtEnvironment, tensorInfo: TensorInfo, batchedValues: Seq[_]): OnnxTensor = {
    val shape: Array[Long] = tensorInfo.getShape
    // the first dimension of the shape can be -1 when multiple inputs are allowed. Setting it to the real
    // input size. Otherwise we cannot create the tensor from the 1D array buffer.
    shape(0) = batchedValues.length
+    val size = shape.product.toInt

    tensorInfo.`type` match {
      case OnnxJavaType.FLOAT =>
-        val buffer = FloatBuffer.wrap(flattened.map(_.asInstanceOf[Float]))
+        val buffer = FloatBuffer.allocate(size)
+        writeNestedSeqToBuffer[Float](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.DOUBLE =>
-        val buffer = DoubleBuffer.wrap(flattened.map(_.asInstanceOf[Double]))
+        val buffer = DoubleBuffer.allocate(size)
+        writeNestedSeqToBuffer[Double](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT8 =>
-        val buffer = ByteBuffer.wrap(flattened.map(_.asInstanceOf[Byte]))
+        val buffer = ByteBuffer.allocate(size)
+        writeNestedSeqToBuffer[Byte](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT16 =>
-        val buffer = ShortBuffer.wrap(flattened.map(_.asInstanceOf[Short]))
+        val buffer = ShortBuffer.allocate(size)
+        writeNestedSeqToBuffer[Short](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT32 =>
-        val buffer = IntBuffer.wrap(flattened.map(_.asInstanceOf[Int]))
+        val buffer = IntBuffer.allocate(size)
+        writeNestedSeqToBuffer[Int](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT64 =>
-        val buffer = LongBuffer.wrap(flattened.map(_.asInstanceOf[Long]))
+        val buffer = LongBuffer.allocate(size)
+        writeNestedSeqToBuffer[Long](batchedValues, buffer.put(_))
+        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.STRING =>
-        OnnxTensor.createTensor(env, flattened.map(_.asInstanceOf[String]), shape)
+        val flattened = writeNestedSeqToStringBuffer(batchedValues, size).toArray
+        OnnxTensor.createTensor(env, flattened, shape)
      case other =>
        throw new NotImplementedError(s"Tensor input type $other not supported. " +
          s"Only FLOAT, DOUBLE, INT8, INT16, INT32, INT64, STRING types are supported.")
@@ -414,6 +443,10 @@ class ONNXModel(override val uid: String)
}.flatten.get
}

+  def modelInputJava: util.Map[String, NodeInfo] = {
+    collection.mutable.Map(modelInput.toSeq: _*).asJava
+  }

def modelOutput: Map[String, NodeInfo] = {
using(OrtEnvironment.getEnvironment) {
env =>
Expand All @@ -423,6 +456,10 @@ class ONNXModel(override val uid: String)
}.flatten.get
}

+  def modelOutputJava: util.Map[String, NodeInfo] = {
+    collection.mutable.Map(modelOutput.toSeq: _*).asJava
+  }

private var broadcastedModelPayload: Option[Broadcast[Array[Byte]]] = None

def setModelPayload(value: Array[Byte]): this.type = {
47 changes: 47 additions & 0 deletions docs/onnx.md
@@ -0,0 +1,47 @@
---
title: ONNX model inferencing on Spark
description: Learn how to use the ONNX model transformer to run inference for an ONNX model on Spark.
---

# ONNX model inferencing on Spark

## ONNX

[ONNX](https://onnx.ai/) is an open format to represent both deep learning and traditional machine learning models. With ONNX, AI developers can more easily move models between state-of-the-art tools and choose the combination that is best for them.

MMLSpark now includes a Spark transformer to bring a trained ONNX model to Apache Spark, so you can run inference on your data with Spark's large-scale data processing power.

## Usage

1. Create a `com.microsoft.ml.spark.onnx.ONNXModel` object and use `setModelLocation` or `setModelPayload` to load the ONNX model.

For example:

```scala
val onnx = new ONNXModel().setModelLocation("/path/to/model.onnx")
```
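
   If the model bytes are already in memory (for example, read from a file or a blob store), `setModelPayload` can be used instead. A minimal sketch, assuming the model sits in a local file:

```scala
import java.nio.file.{Files, Paths}

// Read the serialized ONNX model and hand the raw bytes to the transformer.
val payload: Array[Byte] = Files.readAllBytes(Paths.get("/path/to/model.onnx"))
val onnx = new ONNXModel().setModelPayload(payload)
```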

2. Use an ONNX visualization tool (for example, [Netron](https://netron.app/)) to inspect the ONNX model's input and output nodes.

![Screenshot that illustrates an ONNX model's input and output nodes](https://mmlspark.blob.core.windows.net/graphics/ONNXModelInputsOutputs.png)

3. Set the `ONNXModel` object's parameters appropriately.

The `com.microsoft.ml.spark.onnx.ONNXModel` class provides a set of parameters to control the behavior of the inference.

| Parameter | Description | Default Value |
|:------------------|:------------|:------------|
| feedDict | Map the ONNX model's expected input node names to the input DataFrame's column names. Make sure the input DataFrame's column schema matches the shape of the corresponding ONNX model input. For example, an image classification model may have an input node of shape `[1, 3, 224, 224]` with type Float, where the first dimension (1) is assumed to be the batch size. The input DataFrame's corresponding column should then be of type `ArrayType(ArrayType(ArrayType(FloatType)))`. | None |
| fetchDict | Map the output DataFrame's column names to the ONNX model's output node names. | None |
| miniBatcher | Specify the MiniBatcher to use. | `FixedMiniBatchTransformer` with batch size 10 |
| softMaxDict | A map between output DataFrame columns, where the value column is computed by taking the softmax of the key column. If the `rawPrediction` column contains logits, set softMaxDict to `Map("rawPrediction" -> "probability")` to obtain probability outputs. | None |
| argMaxDict | A map between output DataFrame columns, where the value column is computed by taking the argmax of the key column. This can be used to convert probability or logit outputs to predicted labels. | None |
| deviceType | Specify a device type the model inference runs on. Supported types are: CPU or CUDA. If not specified, auto detection will be used. | None |
| optimizationLevel | Specify the [optimization level](https://onnxruntime.ai/docs/resources/graph-optimizations.html#graph-optimization-levels) for the ONNX graph optimizations. Supported values are: `NO_OPT`, `BASIC_OPT`, `EXTENDED_OPT`, `ALL_OPT`. | `ALL_OPT` |

4. Call the `transform` method to run inference on the input DataFrame.
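
Putting the steps together, here is a hedged end-to-end sketch. The input/output node names (`input`, `output`), the DataFrame column names, and the `set*` setters for the parameters in the table above are illustrative assumptions, not confirmed API:

```scala
import com.microsoft.ml.spark.onnx.ONNXModel

// feedDict maps the model's input node to a DataFrame column whose schema
// matches the node's shape; fetchDict maps output columns to output nodes.
val onnx = new ONNXModel()
  .setModelLocation("/path/to/model.onnx")
  .setFeedDict(Map("input" -> "features"))
  .setFetchDict(Map("rawPrediction" -> "output"))
  .setSoftMaxDict(Map("rawPrediction" -> "probability")) // logits -> probabilities
  .setArgMaxDict(Map("rawPrediction" -> "prediction"))   // probabilities -> label index
  .setDeviceType("CPU")

val scored = onnx.transform(inputDf) // inputDf: the DataFrame to score
```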

## Example

- [Interpretability - Image Explainers](../notebooks/Interpretability%20-%20Image%20Explainers.ipynb)
- [ONNX - Inference on Spark](../notebooks/ONNX%20-%20Inference%20on%20Spark.ipynb)