-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactors onnx import adding better support for image functions #9466
Changes from all commits
b8fb66b
865cb81
4a5bbc2
d5ee31d
34b544f
e168433
497a909
cd4b919
93345cd
b16f43e
1b6e8b2
7a9c510
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,21 +57,25 @@ public class Nd4jNamespaceGenerator { | |
private static Map<Config, TypeName> configMapping = new HashMap<>(); | ||
public static Count exactlyOne = new Exactly(1); | ||
private static String copyright = | ||
"/*******************************************************************************\n" + | ||
" * Copyright (c) 2019-2020 Konduit K.K.\n" + | ||
" *\n" + | ||
" * This program and the accompanying materials are made available under the\n" + | ||
" * terms of the Apache License, Version 2.0 which is available at\n" + | ||
" * https://www.apache.org/licenses/LICENSE-2.0.\n" + | ||
" *\n" + | ||
" * Unless required by applicable law or agreed to in writing, software\n" + | ||
" * distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n" + | ||
" * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n" + | ||
" * License for the specific language governing permissions and limitations\n" + | ||
" * under the License.\n" + | ||
" *\n" + | ||
" * SPDX-License-Identifier: Apache-2.0\n" + | ||
" ******************************************************************************/\n"; | ||
"/*\n" + | ||
" * ******************************************************************************\n" + | ||
" * *\n" + | ||
" * *\n" + | ||
" * * This program and the accompanying materials are made available under the\n" + | ||
" * * terms of the Apache License, Version 2.0 which is available at\n" + | ||
" * * https://www.apache.org/licenses/LICENSE-2.0.\n" + | ||
" * *\n" + | ||
" * * See the NOTICE file distributed with this work for additional\n" + | ||
" * * information regarding copyright ownership.\n" + | ||
" * * Unless required by applicable law or agreed to in writing, software\n" + | ||
" * * distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n" + | ||
" * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n" + | ||
" * * License for the specific language governing permissions and limitations\n" + | ||
" * * under the License.\n" + | ||
" * *\n" + | ||
" * * SPDX-License-Identifier: Apache-2.0\n" + | ||
" * *****************************************************************************\n" + | ||
" */\n"; | ||
private static String codeGenWarning = | ||
"\n//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================\n\n"; | ||
|
||
|
@@ -84,6 +88,7 @@ public class Nd4jNamespaceGenerator { | |
typeMapping.put(DataType.DATA_TYPE, org.nd4j.linalg.api.buffer.DataType.class); | ||
typeMapping.put(DataType.LOSS_REDUCE, org.nd4j.autodiff.loss.LossReduce.class); | ||
typeMapping.put(DataType.CONDITION, Condition.class); | ||
typeMapping.put(DataType.STRING, String.class); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added missing data type here |
||
|
||
validationMapping.put(DataType.BOOL, "validateBool"); | ||
validationMapping.put(DataType.FLOATING_POINT, "validateFloatingPoint"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -241,7 +241,7 @@ fun SDImage() = Namespace("Image"){ | |
Arg(BOOL, "preserveAspectRatio") { description = "Whether to preserve the aspect ratio." + | ||
" If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio" + | ||
" of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False."; defaultValue=false; } | ||
Arg(BOOL, "antialis") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; } | ||
Arg(BOOL, "antialias") { description = "Whether to use an anti-aliasing filter when downsampling an image"; defaultValue=false; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo typo fix |
||
Arg(ENUM, "ImageResizeMethod") { possibleValues = listOf( "ResizeBilinear", "ResizeBicubic", "ResizeNearest", "ResizeGaussian", | ||
"ResizeLanczos5", "ResizeMitchelcubic", "ResizeArea"); description = "ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.\n" + | ||
"ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.\n" + | ||
|
@@ -259,4 +259,36 @@ fun SDImage() = Namespace("Image"){ | |
""".trimIndent() | ||
} | ||
} | ||
|
||
Op("resizeBiLinear") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo new op for bi linear image resize. |
||
javaPackage = "org.nd4j.linalg.api.ops.impl.image" | ||
javaOpClass = "ResizeBilinear" | ||
Input(NUMERIC,"input") { description = "4D image"} | ||
Arg(INT ,"height") { description = "target height for resizing to "} | ||
Arg(INT ,"width") { description = "target width for resizing to"} | ||
Arg(BOOL ,"alignCorners") { description = "whether to align corners during resizing. Images are aligned to preserve corners."} | ||
Arg(BOOL,"halfPixelCenters") { description = "When resizing, assumes pixels are centered at 0.5."} | ||
Output(NUMERIC, "output"){ description = "Output image" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Resize images to size using the specified method. | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
Op("resizeBiCubic") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.image" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo same for bicubic |
||
javaOpClass = "ResizeBicubic" | ||
Input(NUMERIC,"input") { description = "4D image"} | ||
Input(INT ,"size") { description = "the target size to resize to "} | ||
Arg(BOOL ,"alignCorners") { description = "whether to align corners during resizing. Images are aligned to preserve corners."} | ||
Arg(BOOL,"alignPixelCenters") { description = "When resizing, assumes pixels are centered at 0.5."} | ||
Output(NUMERIC, "output"){ description = "Output image" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Resize images to size using the specified method. | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
package org.nd4j.codegen.ops | ||
|
||
import org.nd4j.codegen.api.AtLeast | ||
import org.nd4j.codegen.api.DataType | ||
import org.nd4j.codegen.api.Language | ||
import org.nd4j.codegen.api.doc.DocScope | ||
import org.nd4j.codegen.dsl.* | ||
|
@@ -480,6 +481,38 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
} | ||
} | ||
|
||
|
||
Op("sparseToDense") { | ||
javaPackage = "org.nd4j.linalg.api.ops.compat" | ||
javaOpClass = "CompatSparseToDense" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo sparse to dense is also a new op. Used in model import and works fine, just never fully mapped or exposed. |
||
Input(NUMERIC, "indices") { description = "The indices of the sparse matrix" } | ||
Input(NUMERIC, "shape") { description = "The output shape" } | ||
Input(NUMERIC, "values") { description = "The values for the array" } | ||
Output(NUMERIC, "output"){ description = "Populated dense INDArray with given values and indices" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Create a dense matrix equivalent of a sparse matrix based on the given input. | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
Op("sparseToDense") { | ||
javaPackage = "org.nd4j.linalg.api.ops.compat" | ||
javaOpClass = "CompatSparseToDense" | ||
Input(NUMERIC, "indices") { description = "The indices of the sparse matrix" } | ||
Input(NUMERIC, "shape") { description = "The output shape" } | ||
Input(NUMERIC, "values") { description = "The values for the array" } | ||
Input(NUMERIC,"defaultValue") { description = "Default value" } | ||
Output(NUMERIC, "output"){ description = "Populated dense INDArray with given values and indices" } | ||
|
||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Create a dense matrix equivalent of a sparse matrix based on the given input. | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
|
||
Op("lt") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.scalar.comparison" | ||
javaOpClass = "ScalarLessThan" | ||
|
@@ -588,6 +621,63 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
useMixin(keepDimsDoc) | ||
} | ||
|
||
Op("whereNumpy") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.controlflow" | ||
javaOpClass = "WhereNumpy" | ||
Input(NUMERIC, "x") { description = "The first array" } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo we've always had 2 where implementations: I'm exposing both here and letting user pick. |
||
Input(NUMERIC, "y") { description = "The second array" } | ||
Input(NUMERIC, "condition") { description = "Condition array determining which elements at which indices should be picked from. If true, picks from x, other wise y" } | ||
Output(NUMERIC, "output"){ description = "Number of elements that the condition is satisfied for" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
As implemented in numpy, Return elements chosen from x or y depending on condition. | ||
""".trimIndent() | ||
} | ||
useMixin(keepDimsDoc) | ||
} | ||
|
||
Op("where") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.controlflow" | ||
javaOpClass = "Where" | ||
Input(NUMERIC, "x") { description = "The first array" } | ||
Input(NUMERIC, "y") { description = "The second array" } | ||
Input(BOOL, "condition") { description = "Condition array determining which elements at which indices should be picked from. If true, picks from x, other wise y" } | ||
Output(NUMERIC, "output"){ description = "Number of elements that the condition is satisfied for" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Similar to numpy where, takes elements from x or y depending on whether the condition at a given element is true or false | ||
""".trimIndent() | ||
} | ||
useMixin(keepDimsDoc) | ||
} | ||
|
||
Op("where") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.controlflow" | ||
javaOpClass = "Where" | ||
Input(NUMERIC, "x") { description = "The first array" } | ||
Input(BOOL, "condition") { description = "Condition array determining which elements at which indices should be picked from. If true, picks from x, other wise y" } | ||
Output(NUMERIC, "output"){ description = "Number of elements that the condition is satisfied for" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Similar to numpy where, takes elements from x or y depending on whether the condition at a given element is true or false | ||
""".trimIndent() | ||
} | ||
useMixin(keepDimsDoc) | ||
} | ||
|
||
Op("where") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.controlflow" | ||
javaOpClass = "Where" | ||
Input(BOOL, "condition") { description = "Condition array determining which elements at which indices should be picked from. If true, picks from x, other wise y" } | ||
Output(NUMERIC, "output"){ description = "Number of elements that the condition is satisfied for" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Returns elements that are true from the given condition array | ||
""".trimIndent() | ||
} | ||
useMixin(keepDimsDoc) | ||
} | ||
|
||
Op("max") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.reduce.same" | ||
legacy = true | ||
|
@@ -819,6 +909,25 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
} | ||
} | ||
|
||
|
||
Op("create") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo this is the only op we have for anything like "ones" or "zeros" |
||
javaPackage = "org.nd4j.linalg.api.ops.impl.shape" | ||
javaOpClass = "Create" | ||
Input(NUMERIC, "shape") { description = "Input INDArray " } | ||
Arg(DATA_TYPE,"dataType") {description = "Data type of array"} | ||
//note when generating strings you need to add quotes as well or it's generated incorrectly | ||
Arg(STRING,"order") {description = "Order of array "; defaultValue="\"c\""} | ||
Arg(BOOL,"initialize") {description = "Whether to initialize the array or not "; defaultValue=false} | ||
Output(NUMERIC, "output"){ description = "A new INDArray with the same (dynamic) shape as the input" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Return a newly created variable, with the specified shape and data type. | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
|
||
|
||
Op("onesLike") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.shape" | ||
Input(NUMERIC, "input") { description = "Input INDArray " } | ||
|
@@ -1388,7 +1497,7 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
Output(NUMERIC, "output"){ description = "reduced array of rank (input rank - num dimensions)" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Stardard deviation array reduction operation, optionally along specified dimensions | ||
Standard deviation array reduction operation, optionally along specified dimensions | ||
""".trimIndent() | ||
} | ||
useMixin(keepDimsDoc) | ||
|
@@ -1420,6 +1529,36 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
} | ||
} | ||
|
||
|
||
Op("stridedSlice") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.shape" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo new op: just exposing what's already there. A minor pain was the default values for the masks but I got that worked out. |
||
Input(NUMERIC, "in") { description = "Variable to get subset of" } | ||
Input(NUMERIC, "begin") { description = "The beginning indices for the slice" } | ||
Input(NUMERIC, "end") { description = "The ending indicesof the slice" } | ||
Input(NUMERIC, "strides") { description = "The strides for each dimension" } | ||
Arg(INT, "beginMask") { description = "Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension"; defaultValue=0 } | ||
Arg(INT, "endMask") { description = "Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension"; defaultValue=0 } | ||
Arg(INT, "ellipsisMask") { description = "Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position"; defaultValue=0 } | ||
Arg(INT, "newAxisMask") { description = "Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point"; defaultValue=0 } | ||
Arg(INT, "shrinkAxisMask") { description = "Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions"; defaultValue=0 } | ||
Output(NUMERIC, "output"){ description = "A subset of the input array" } | ||
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Get a subset of the specified input, by specifying the first element, last element, and the strides. | ||
For example, if input is: | ||
[a, b, c] | ||
[d, e, f] | ||
[g, h, i] | ||
then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return: | ||
[b, c] | ||
[h, i] | ||
""".trimIndent() | ||
} | ||
} | ||
|
||
|
||
|
||
|
||
Op("sum") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.reduce.same" | ||
legacy = true | ||
|
@@ -1446,7 +1585,7 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
Doc(Language.ANY, DocScope.ALL){ | ||
""" | ||
Switch operation | ||
Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output | ||
Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output | ||
""".trimIndent() | ||
} | ||
} | ||
|
@@ -1605,6 +1744,9 @@ fun SDBaseOps() = Namespace("BaseOps"){ | |
useMixin(keepDimsDoc) | ||
} | ||
|
||
|
||
|
||
|
||
Op("zerosLike") { | ||
javaPackage = "org.nd4j.linalg.api.ops.impl.shape" | ||
Input(NUMERIC, "input") { description = "Input " } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2593,6 +2593,9 @@ public SDVariable one(String name, int... shape) { | |
return one(name, Nd4j.defaultFloatingPointType(), shape); | ||
} | ||
|
||
|
||
|
||
|
||
/** | ||
* See {@link #one(String, DataType, long...)}. | ||
* Creates a constant - i.e., CONSTANT type SDVariable. | ||
|
@@ -2987,9 +2990,6 @@ public SDVariable var(INDArray arr) { | |
public SDVariable var(String name, @NonNull INDArray arr) { | ||
if (variables.containsKey(name) && variables.get(name).getVariable().getArr() != null) | ||
throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); | ||
Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @treo this was removed because it disallowed booleans among other things for variable creation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In what context do non-fp variables make even sense? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my import use case, booleans and strings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. eg: We've been supporting strings and booleans for a while and can run string manipulation among other ops that are not numeric. Var shouldn't always have to be just trainable parameters it could be ETL. |
||
" provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" + | ||
"For non floating point types, these should be created as placeholders or constants instead.", arr.dataType()); | ||
Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr); | ||
|
||
if (name == null || name.length() < 1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@treo codegen copyright update