
ND4J: Add convenience method for where_np op #6184

Closed
AlexDBlack opened this issue Aug 17, 2018 · 4 comments
Labels: ND4J


AlexDBlack added the Java and ND4J labels on Aug 17, 2018
tzolov commented Aug 17, 2018

The second snippet points to an Nd4j.getExecutioner().exec(op) example, not an Nd4j.getExecutioner().calculateOutputShapes(op) example?

From the WhereNumpy test code, it is not clear how to return the element indexes (which is the np.where behavior) rather than the values. Do I need to preinitialize the output array with the indexes?
Are there some more elaborate WhereNumpy examples?

tzolov commented Aug 18, 2018

For example, is it possible to use WhereNumpy to implement the following np.where(...)-like logic:

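import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

private INDArray getIndexWhere(INDArray input, Predicate<Double> predicate) {
   // expects only a vector or scalar! length() also covers [1, N] row vectors
   List<Integer> indexes = new ArrayList<>();
   for (int i = 0; i < input.length(); i++) {
      if (predicate.test(input.getDouble(i)))
         indexes.add(i);
   }

   // convert the collected indexes to a double[] for Nd4j.create(double[])
   return indexes.isEmpty()
         ? Nd4j.empty(DataBuffer.Type.FLOAT)
         : Nd4j.create(indexes.stream().mapToDouble(Integer::doubleValue).toArray());
}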

AlexDBlack (contributor, issue author) commented

OK, so looking at this again:

One usage of the where_np op takes 3 input variables (condition, x, y) and 1 output variable. Condition is a 0/1-valued array of the same shape as x and y.
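For reference, the 3-operand form can be sketched in plain Java on flat arrays, assuming it follows NumPy's np.where(condition, x, y): take x where condition is non-zero and y everywhere else. The class and method names are hypothetical and are not ND4J API.

```java
import java.util.Arrays;

// Hypothetical helper (plain Java, not ND4J/libnd4j API): assumes the 3-operand
// where_np follows np.where(condition, x, y) element-wise on equal-length arrays.
final class WhereSelect {

    /** out[i] = condition[i] != 0 ? x[i] : y[i] */
    static double[] where(double[] condition, double[] x, double[] y) {
        if (condition.length != x.length || x.length != y.length) {
            throw new IllegalArgumentException("condition, x and y must have the same length");
        }
        double[] out = new double[x.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = condition[i] != 0.0 ? x[i] : y[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] cond = {1, 0, 1, 0};
        double[] x = {10, 20, 30, 40};
        double[] y = {-1, -2, -3, -4};
        System.out.println(Arrays.toString(where(cond, x, y))); // prints [10.0, -2.0, 30.0, -4.0]
    }
}
```

The output has the same shape as x, which is consistent with the 3-operand branch of the shape function quoted further down (it copies the shape of input 1).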

For the single-input case (i.e., only the "condition" array is specified), it seems we return a 2D array of coordinates (one row per true element) instead of a tuple of 1D coordinate arrays:

// in this case we return 2D matrix, which basically contains coordinates of true elements
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
int width = condition->rankOf();
std::vector<int> dims = ShapeUtils<T>::convertAxisToTadTarget(width, {0});
NDArrayList<T> list(0, true);
int cnt = 0;
Nd4jLong idx[MAX_RANK];
for (int e = 0; e < condition->lengthOf(); e++) {
    shape::ind2subC(condition->rankOf(), condition->shapeOf(), e, idx);
    auto offset = shape::getOffset(0, condition->shapeOf(), condition->stridesOf(), idx, condition->rankOf());
    T v = condition->buffer()[offset];
    if (v != (T) 0.0f) {
        auto array = new NDArray<T>('c', {1, condition->rankOf()});
        for (int f = 0; f < condition->rankOf(); f++)
            array->putIndexedScalar(f, (T) idx[f]);
        list.write(cnt++, array);
    }
}
auto result = list.stack();
OVERWRITE_RESULT(result);
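To make the single-operand behavior concrete, here is a plain-Java sketch of what the loop above computes, assuming a contiguous c-order array passed as flat data (the C++ also handles strided views through shape::getOffset, which this sketch skips). The class and method names are hypothetical. The resulting [count, rank] layout is the one NumPy's np.argwhere returns; np.where(condition) instead returns a tuple of 1D index arrays, one per dimension, i.e. the transpose of this matrix.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (plain Java, not ND4J/libnd4j API): mirrors the single-operand
// loop above for a contiguous c-order array described by its shape and flat data.
final class WhereCoordinates {

    /** Returns one row [i0, ..., i(rank-1)] per non-zero element, in c order. */
    static long[][] nonZeroCoordinates(long[] shape, double[] data) {
        List<long[]> rows = new ArrayList<>();
        long[] idx = new long[shape.length];
        for (int e = 0; e < data.length; e++) {
            // c-order ind2sub: the last dimension varies fastest
            long rem = e;
            for (int d = shape.length - 1; d >= 0; d--) {
                idx[d] = rem % shape[d];
                rem /= shape[d];
            }
            // contiguous data, so the linear index e is also the buffer offset
            if (data[e] != 0.0) {
                rows.add(idx.clone()); // one [1, rank] row, like list.write(...)
            }
        }
        return rows.toArray(new long[0][]); // "stack" the rows into [count, rank]
    }

    public static void main(String[] args) {
        // 2x3 condition array in c order: [[0, 1, 0], [1, 0, 1]]
        long[][] coords = nonZeroCoordinates(new long[] {2, 3}, new double[] {0, 1, 0, 1, 0, 1});
        System.out.println(Arrays.deepToString(coords)); // prints [[0, 1], [1, 0], [1, 2]]
    }
}
```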

So, to answer your question: essentially yes, but (as in numpy) the input is the array after the predicate has been applied, not the predicate itself as an argument.
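A minimal plain-Java sketch of that two-step split for the 1D case in the question, assuming a vector input: first apply the predicate to build a 0/1 condition array, then take the indices of its non-zero entries (the single-operand where). The names are hypothetical and this is not ND4J API; with ND4J, the first step would presumably be a comparison or condition op run before where_np.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;
import java.util.stream.IntStream;

// Hypothetical helpers (plain Java, not ND4J API) illustrating predicate -> mask -> indices.
final class PredicateThenWhere {

    /** Step 1: apply the predicate element-wise to build a 0/1 condition array. */
    static double[] toCondition(double[] input, DoublePredicate predicate) {
        return Arrays.stream(input).map(v -> predicate.test(v) ? 1.0 : 0.0).toArray();
    }

    /** Step 2: single-operand where on a vector, i.e. the indices of non-zero entries. */
    static int[] whereIndices(double[] condition) {
        return IntStream.range(0, condition.length).filter(i -> condition[i] != 0.0).toArray();
    }

    public static void main(String[] args) {
        double[] input = {0.5, -1.0, 2.0, 0.0, 3.5};
        double[] condition = toCondition(input, v -> v > 0);
        System.out.println(Arrays.toString(whereIndices(condition))); // prints [0, 2, 4]
    }
}
```

Keeping the predicate out of the op means where_np only has to deal with a plain array, which is also how np.where is used.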

That said, it seems the where_np shape function is wrong for the single-input case, where it returns a hardcoded 10x10 shape:

DECLARE_SHAPE_FN(where_np) {
    if (block.width() == 3) {
        auto inShape = inputShape->at(1);
        Nd4jLong *newshape;
        COPY_SHAPE(inShape, newshape);
        return SHAPELIST(newshape);
    } else {
        // FIXME: we can't estimate result here in this case
        auto inShape = inputShape->at(0);
        Nd4jLong *newshape;
        ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong);
        newshape[0] = 2;
        newshape[1] = 10;
        newshape[2] = 10;
        newshape[3] = 1;
        newshape[4] = 1;
        newshape[5] = 0;
        newshape[6] = 1;
        newshape[7] = 99;
        return SHAPELIST(newshape);
    }
}
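The correct single-input output shape is [number of non-zero elements, rank of condition], which depends on the values in condition and not only on its shape; that is presumably why the shape function cannot compute it and falls back to a fixed [10, 10]. A plain-Java sketch of the descriptor it would need, assuming the shape-info layout suggested by the snippet (rank, two dimensions, two strides, offset, element-wise stride, order with 99 = 'c'); the class and method names are hypothetical, and the strides used are ordinary c-order strides.

```java
import java.util.Arrays;

// Hypothetical helper (plain Java, not libnd4j API). Layout ASSUMED from the snippet:
// {rank, dim0, dim1, stride0, stride1, offset, elementWiseStride, order}
final class WhereShape {

    /** Number of non-zero entries in the flattened condition array. */
    static long countNonZero(double[] condition) {
        long n = 0;
        for (double v : condition) {
            if (v != 0.0) n++;
        }
        return n;
    }

    /** Descriptor for a c-order [count, conditionRank] result matrix. */
    static long[] resultShapeInfo(double[] condition, int conditionRank) {
        long count = countNonZero(condition);
        long cols = conditionRank;
        // c-order strides for [count, cols] are {cols, 1}; order 'c' widens to 99
        return new long[] {2, count, cols, cols, 1, 0, 1, 'c'};
    }

    public static void main(String[] args) {
        double[] condition = {0, 1, 0, 1, 0, 1}; // a 2x3 condition array, flattened
        System.out.println(Arrays.toString(resultShapeInfo(condition, 2)));
        // prints [2, 3, 2, 2, 1, 0, 1, 99]
    }
}
```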

A workaround for now is simply calculating the number of matches yourself - for example:

int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.equals(0))).z().getInt(0);

AlexDBlack self-assigned this on Aug 22, 2018
AlexDBlack added a commit that referenced this issue on Aug 27, 2018:
* #6228 fit with backprop(false) error

* Early stopping epoch termination conditions - handle maximize case

* #6159 ND4J permute - wrong arg length case

* #6184 add Nd4j.where - where_np convenience method

* Early stopping fixes for pretrain layers

* Update to EarlyStoppingParallelTrainer given termination condition API change

* ConvolutionUtils permute fix

* Small indexing fix for CG

* Fix Nd4j.where

* Exclude a couple of TF import tests (libnd4j exec) for now

* BaseNDArray fixes

* #6282 Add Nd4j.stack(...) convenience method
lock bot commented Sep 26, 2018

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

lock bot locked and limited conversation to collaborators on Sep 26, 2018
2 participants