New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactors onnx import adding better support for image functions #9466
Conversation
…idation for null checks in onnx runner
…dify resize test to accept 1e-1 eps for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo summary of PR here
" * SPDX-License-Identifier: Apache-2.0\n" + | ||
" ******************************************************************************/\n"; | ||
"/*\n" + | ||
" * ******************************************************************************\n" + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo codegen copyright update
@@ -84,6 +88,7 @@ | |||
typeMapping.put(DataType.DATA_TYPE, org.nd4j.linalg.api.buffer.DataType.class); | |||
typeMapping.put(DataType.LOSS_REDUCE, org.nd4j.autodiff.loss.LossReduce.class); | |||
typeMapping.put(DataType.CONDITION, Condition.class); | |||
typeMapping.put(DataType.STRING, String.class); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added missing data type here
@@ -241,7 +241,7 @@ fun SDImage() = Namespace("Image"){ | |||
Arg(BOOL, "preserveAspectRatio") { description = "Whether to preserve the aspect ratio." + | |||
" If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio" + | |||
" of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False."; defaultValue=false; } | |||
Arg(BOOL, "antialis") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; } | |||
Arg(BOOL, "antialias") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo typo fix
@@ -259,4 +259,36 @@ fun SDImage() = Namespace("Image"){ | |||
""".trimIndent() | |||
} | |||
} | |||
|
|||
Op("resizeBiLinear") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo new op for bi linear image resize.
} | ||
|
||
Op("resizeBiCubic") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.image" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo same for bicubic
dynamicVariables.forEach { name, array -> | ||
dynamicVariablesConverted[name] = convertToOnnxTensor(array,name) | ||
} | ||
val dynamicVariablesConverted = convertToOnnxTensors(dynamicVariables) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactored to be a util function
@@ -82,13 +82,6 @@ class OnnxIRGraph(graphDef: Onnx.GraphProto,opMappingRegistry: OpMappingRegistry | |||
nodeNames.add(node.nodeName()) | |||
} | |||
|
|||
if(indexToNode.isNotEmpty()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo this was removed because op definitions already happen automatically and was causing clashes.
} | ||
|
||
val modelProto = ModelProto { | ||
OpSetImport(OperatorSetIdProto { | ||
version = 12 | ||
version = 13 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo this op call update was necessary for onnx runtime to work with my current opset.
graphDefBuilder.addOutput(ValueInfoProto { | ||
name = it | ||
}) | ||
if(!graphDef.outputList.contains(it)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo this is just a simple check for detecting dups. We shiv in extra "nodes" to streamline the algorithm to work similar to TF import.
@@ -256,5 +257,9 @@ class TensorflowIRGraph(graphDef: GraphDef, opDef: OpList | |||
return false | |||
} | |||
|
|||
override fun convertToNDArray(tensorTypeInput: TensorProto): INDArray { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo I added a new utility function to graph for access during model import for dealing with ndarrays.
What changes were proposed in this pull request?
Adds support for onnx resize using pre hook model import
to express the ops needed in terms of pre existing samediff ops.
(Please fill in changes proposed in this fix)
How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
Quick checklist
The following checklist helps ensure your PR is complete: