Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors onnx import adding better support for image functions #9466

Merged
merged 12 commits into from Oct 6, 2021

Conversation

agibsonccc
Copy link
Contributor

What changes were proposed in this pull request?

Adds support for onnx resize using pre hook model import
to express the ops needed in terms of pre existing samediff ops.
(Please fill in changes proposed in this fix)

How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)

Quick checklist

The following checklist helps ensure your PR is complete:

  • [ X] Eclipse Contributor Agreement signed, and signed commits - see IP Requirements page for details
  • [X ] Reviewed the Contributing Guidelines and followed the steps within.
  • [X ] Created tests for any significant new code additions.
  • [ X] Relevant tests for your changes are passing.

dsk114u
dsk114u previously approved these changes Oct 1, 2021
@agibsonccc agibsonccc requested a review from treo October 6, 2021 10:25
Copy link
Contributor Author

@agibsonccc agibsonccc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo summary of PR here

" * SPDX-License-Identifier: Apache-2.0\n" +
" ******************************************************************************/\n";
"/*\n" +
" * ******************************************************************************\n" +
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo codegen copyright update

@@ -84,6 +88,7 @@
typeMapping.put(DataType.DATA_TYPE, org.nd4j.linalg.api.buffer.DataType.class);
typeMapping.put(DataType.LOSS_REDUCE, org.nd4j.autodiff.loss.LossReduce.class);
typeMapping.put(DataType.CONDITION, Condition.class);
typeMapping.put(DataType.STRING, String.class);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added missing data type here

@@ -241,7 +241,7 @@ fun SDImage() = Namespace("Image"){
Arg(BOOL, "preserveAspectRatio") { description = "Whether to preserve the aspect ratio." +
" If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio" +
" of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False."; defaultValue=false; }
Arg(BOOL, "antialis") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; }
Arg(BOOL, "antialias") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo typo fix

@@ -259,4 +259,36 @@ fun SDImage() = Namespace("Image"){
""".trimIndent()
}
}

Op("resizeBiLinear") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo new op for bi linear image resize.

}

Op("resizeBiCubic") {
javaPackage = "org.nd4j.linalg.api.ops.impl.image"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo same for bicubic

dynamicVariables.forEach { name, array ->
dynamicVariablesConverted[name] = convertToOnnxTensor(array,name)
}
val dynamicVariablesConverted = convertToOnnxTensors(dynamicVariables)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to be a util function

@@ -82,13 +82,6 @@ class OnnxIRGraph(graphDef: Onnx.GraphProto,opMappingRegistry: OpMappingRegistry
nodeNames.add(node.nodeName())
}

if(indexToNode.isNotEmpty()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo this was removed because op definitions already happen automatically and was causing clashes.

}

val modelProto = ModelProto {
OpSetImport(OperatorSetIdProto {
version = 12
version = 13
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo this op call update was necessary for onnx runtime to work with my current opset.

graphDefBuilder.addOutput(ValueInfoProto {
name = it
})
if(!graphDef.outputList.contains(it))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo this is just a simple check for detecting dups. We shiv in extra "nodes" to streamline the algorithm to work similar to TF import.

@@ -256,5 +257,9 @@ class TensorflowIRGraph(graphDef: GraphDef, opDef: OpList
return false
}

override fun convertToNDArray(tensorTypeInput: TensorProto): INDArray {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@treo I added a new utility function to graph for access during model import for dealing with ndarrays.

@agibsonccc agibsonccc merged commit fa20598 into master Oct 6, 2021
@agibsonccc agibsonccc deleted the ag_onnx_resize branch October 6, 2021 12:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants