ND4J: Add convenience method for where_np op #6184
The second snippet points to a … From the WhereNumpy test code, it is not clear how to return the element indexes (which is the np.where behavior) rather than the values. Do I need to preinitialize the …
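For example, is it possible to use it for something like this?

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

private INDArray getIndexWhere(INDArray input, Predicate<Double> predicate) {
    // expects only a vector or scalar!
    List<Integer> indexes = new ArrayList<>();
    // length() rather than size(0): a row vector of shape [1, n] has size(0) == 1
    for (int i = 0; i < input.length(); i++) {
        if (predicate.test(input.getDouble(i)))
            indexes.add(i);
    }
    return indexes.isEmpty()
            ? Nd4j.empty(DataBuffer.Type.FLOAT)
            : Nd4j.create(indexes.stream().mapToDouble(Integer::doubleValue).toArray());
}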
OK, so looking at this again: one usage of the where_np op takes 3 input variables: condition, x and y. For the single-input case (i.e., only the "condition" array is specified), it seems we return a 2d array of coordinates instead of a tuple of 1d coordinate arrays:
deeplearning4j/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp Lines 86 to 113 in 451dd76
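To make the difference between these layouts concrete, here is a plain-Java sketch of numpy's where modes on a 2-D array. It makes no ND4J calls, the class and method names are hypothetical, and it assumes "true" means non-zero. It shows the single-argument form (one 1-D index array per dimension), the [numMatches, rank] coordinate-matrix layout described above, and the three-argument condition/x/y select.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration of numpy-style where() on a 2-D array.
// Not ND4J code; all names here are hypothetical.
public class WhereSemantics {

    // Single-argument form, numpy style: a tuple of 1-D index arrays,
    // (rowIndexes, colIndexes), one entry per non-zero element.
    static int[][] whereTuple(double[][] cond) {
        List<Integer> rows = new ArrayList<>();
        List<Integer> cols = new ArrayList<>();
        for (int r = 0; r < cond.length; r++)
            for (int c = 0; c < cond[r].length; c++)
                if (cond[r][c] != 0) { rows.add(r); cols.add(c); }
        return new int[][] {
            rows.stream().mapToInt(Integer::intValue).toArray(),
            cols.stream().mapToInt(Integer::intValue).toArray()
        };
    }

    // Coordinate-matrix layout: shape [numMatches, rank], one row per match.
    static int[][] whereCoords(double[][] cond) {
        int[][] t = whereTuple(cond);
        int n = t[0].length;
        int[][] coords = new int[n][2];
        for (int i = 0; i < n; i++) {
            coords[i][0] = t[0][i];
            coords[i][1] = t[1][i];
        }
        return coords;
    }

    // Three-argument form: element-wise select, x where cond is non-zero, else y.
    static double[][] whereSelect(double[][] cond, double[][] x, double[][] y) {
        double[][] out = new double[cond.length][];
        for (int r = 0; r < cond.length; r++) {
            out[r] = new double[cond[r].length];
            for (int c = 0; c < cond[r].length; c++)
                out[r][c] = cond[r][c] != 0 ? x[r][c] : y[r][c];
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] cond = {{0, 1}, {1, 1}};
        System.out.println(Arrays.deepToString(whereTuple(cond)));  // [[0, 1, 1], [1, 0, 1]]
        System.out.println(Arrays.deepToString(whereCoords(cond))); // [[0, 1], [1, 0], [1, 1]]
    }
}
```

`whereTuple` corresponds to numpy's `np.where(cond)` (equivalently `np.nonzero`), while `whereCoords` matches the layout of `np.argwhere`: the same information, transposed.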
So, to answer your question: essentially yes, but (as in numpy) the input is the array after the predicate has been applied, not the predicate passed as an argument. That said, it seems the where_np shape function is wrong:
deeplearning4j/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp Lines 121 to 145 in 451dd76
A workaround for now is simply to calculate the number of matches yourself, for example:
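A minimal sketch of that idea, in plain Java on a flat double[] rather than an INDArray (the names are hypothetical, not ND4J API): count the matches first, then size the output from that count instead of relying on the op's shape function. With ND4J, the count would presumably be used to allocate the output array passed to the op.

```java
import java.util.Arrays;

// Two-pass sketch: count matches, then allocate an output of exactly that size.
// Plain Java; not ND4J API. Names are hypothetical.
public class CountThenAllocate {

    // Pass 1: number of non-zero ("true") entries, i.e. the output length.
    static int countMatches(double[] cond) {
        int n = 0;
        for (double v : cond) if (v != 0) n++;
        return n;
    }

    // Pass 2: allocate using the count, then write the matching indexes.
    static long[] indexesWhere(double[] cond) {
        long[] out = new long[countMatches(cond)];
        int k = 0;
        for (int i = 0; i < cond.length; i++)
            if (cond[i] != 0) out[k++] = i;
        return out;
    }

    public static void main(String[] args) {
        double[] cond = {0, 1, 0, 1, 1};
        System.out.println(countMatches(cond));                  // 3
        System.out.println(Arrays.toString(indexesWhere(cond))); // [1, 3, 4]
    }
}
```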
* #6228 fit with backprop(false) error
* Early stopping epoch termination conditions - handle maximize case
* #6159 ND4J permute - wrong arg length case
* #6184 add Nd4j.where - where_np convenience method
* Early stopping fixes for pretrain layers
* Update to EarlyStoppingParallelTrainer given termination condition API change
* ConvolutionUtils permute fix
* Small indexing fix for CG
* Fix Nd4j.where
* Exclude a couple of TF import tests (libnd4j exec) for now
* BaseNDArray fixes
* #6282 Add Nd4j.stack(...) convenience method
This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.
We have it, but users have to interact with the op directly and allocate the output shapes themselves:
https://github.com/deeplearning4j/deeplearning4j/blob/451dd76b50355358dc176f2b704e98c43423c5b8/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp
https://github.com/deeplearning4j/deeplearning4j/blob/451dd76b50355358dc176f2b704e98c43423c5b8/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhereNumpy.java
So, we need something like Nd4j.where(...).
For anyone reading this: the workaround is to use this, plus Nd4j.getExecutioner().calculateOutputShapes(op):
https://github.com/deeplearning4j/dl4j-examples/blob/049b072a8a082012e5f813269884e8ed524a484d/nd4j-examples/src/main/java/org/nd4j/examples/CustomOpsExamples.java
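Roughly, a convenience method like the proposed Nd4j.where(...) would wrap the three steps callers currently do by hand: compute the output shape, allocate the output, then execute the op into it. The following is a hypothetical plain-Java sketch of that wrapper pattern. ShapeAwareOp, NonZeroIndexesOp and where(...) are illustrative names, not ND4J classes, and the real op works on INDArrays rather than double[].

```java
import java.util.Arrays;

// Hypothetical sketch of what a convenience method hides:
// ask the op for its output size, allocate, then execute.
// None of these names are ND4J API.
public class ConvenienceWrapper {

    interface ShapeAwareOp {
        int outputLength(double[] input);         // analogous to an output-shape calculation
        void exec(double[] input, long[] output); // writes into a pre-allocated output
    }

    // Toy op: indexes of the non-zero elements of a 1-D input.
    static class NonZeroIndexesOp implements ShapeAwareOp {
        public int outputLength(double[] input) {
            int n = 0;
            for (double v : input) if (v != 0) n++;
            return n;
        }

        public void exec(double[] input, long[] output) {
            int k = 0;
            for (int i = 0; i < input.length; i++)
                if (input[i] != 0) output[k++] = i;
        }
    }

    // The manual steps, wrapped into a single call.
    static long[] where(double[] condition) {
        ShapeAwareOp op = new NonZeroIndexesOp();
        long[] out = new long[op.outputLength(condition)]; // 1. compute size, 2. allocate
        op.exec(condition, out);                           // 3. execute into the buffer
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(where(new double[]{1, 0, 2}))); // [0, 2]
    }
}
```

Keeping the size calculation as a separate method lets callers who already know the output size skip it, while where(...) gives everyone else a one-call entry point.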