Added multiple input for embedding_lookup op. (#6379)
shugeo authored and raver119 committed Sep 10, 2018
1 parent b7a5ca9 commit 1255ca8
Showing 2 changed files with 102 additions and 23 deletions.
@@ -36,21 +36,42 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {

    NDArray<T>* input = INPUT_VARIABLE(0);   // lookup param
    NDArray<T>* indeces = INPUT_VARIABLE(1); // indeces, as is
    NDArray<T>* output = OUTPUT_VARIABLE(0);

    if (block.width() > 2) { // multiple input
        indeces = INPUT_VARIABLE(block.width() - 1);
        std::vector<int> dims(input->rankOf());
        int i = output->rankOf() - input->rankOf();
        for (auto& v : dims) {
            v = i++;
        }

        std::unique_ptr<ResultSet<T>> outputView(output->allTensorsAlongDimension(dims));
        REQUIRE_TRUE(block.width() > output->sizeAt(0), 0,
                     "embedding_lookup: input list should be greater than %i, but %i given.",
                     output->sizeAt(0), block.width());
        for (Nd4jLong e = 0; e < indeces->lengthOf(); ++e) {
            Nd4jLong thisIndex = static_cast<Nd4jLong>((*indeces)(e));
            input = INPUT_VARIABLE(thisIndex); // lookup param

            outputView->at(e)->assign(input);
        }
    }
    else {
        int indexRank = indeces->rankOf();
        REQUIRE_TRUE(indexRank > 0, 0, "embedding_lookup: input array of indexes can't be a single scalar; the requirement is rank > 0!");

        int inputRank = input->rankOf();
        int lastIndDim = indeces->lengthOf();
        int partition_mode = INT_ARG(0); // partition_mode == 0 - i.e. 'mod', 1 - 'div'

        nd4j::ops::gather<T> op;

        std::unique_ptr<ResultSet<T>> result(op.execute({input, indeces}, {}, {0}));
        REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "embedding_lookup: cannot retrieve results from gather op.");
        REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
        output->assign(result->at(0));
    }
    return ND4J_STATUS_OK;
}
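For context on the new branch: when more than two inputs are passed, the last input is treated as the index vector, and each index selects one whole parameter tensor, which is copied into the matching sub-tensor of the output. A minimal standalone sketch of that semantics, using plain std::vector in place of NDArray (all names here are hypothetical, not part of libnd4j):

    #include <cstdio>
    #include <vector>

    using Params = std::vector<float>; // one flattened param tensor

    // Sketch: output slice e receives the whole params[ids[e]] tensor,
    // mirroring outputView->at(e)->assign(input) in the op above.
    std::vector<Params> multiLookup(const std::vector<Params>& params,
                                    const std::vector<int>& ids) {
        std::vector<Params> out;
        out.reserve(ids.size());
        for (int id : ids)
            out.push_back(params[id]);
        return out;
    }

    int main() {
        std::vector<Params> params = {{1, 20}, {2, 20}, {3, 20}};
        std::vector<int> ids = {2, 0, 2};
        for (auto& slice : multiLookup(params, ids))
            printf("%g %g\n", slice[0], slice[1]); // prints: 3 20 / 1 20 / 3 20
        return 0;
    }

The two-input case is unchanged and still delegates to the gather op, with INT_ARG(0) selecting the partition mode (0 for 'mod', 1 for 'div') as before.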

@@ -59,23 +80,40 @@ DECLARE_SHAPE_FN(embedding_lookup) {
    auto inShapeInfo = inputShape->at(0);
    auto indecesShapeInfo = inputShape->at(1);
    int inRank = shape::rank(inShapeInfo);
    if (inputShape->size() == 2u) {
        int outRank = inRank;

        Nd4jLong* outShapeInfo = nullptr;

        ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
        std::vector<Nd4jLong> shapeInfo(outRank);

        shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements
        for (int e = 1; e < outRank; e++)
            shapeInfo[e] = shape::sizeAt(inShapeInfo, e);
        if (shape::order(inShapeInfo) == 'c')
            shape::shapeBuffer(outRank, shapeInfo.data(), outShapeInfo);
        else
            shape::shapeBufferFortran(outRank, shapeInfo.data(), outShapeInfo);

        return SHAPELIST(outShapeInfo);
    }

    Nd4jLong* outShapeInfo = nullptr;
    int outRank = inRank + 1;
    ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
    std::vector<Nd4jLong> shapeInfo(outRank);

    NDArray<T>* indeces = INPUT_VARIABLE(block.width() - 1);
    shapeInfo[0] = indeces->lengthOf(); // vector - how many elements
    for (int e = 1; e < outRank; e++)
        shapeInfo[e] = shape::sizeAt(inShapeInfo, e);
    if (shape::order(inShapeInfo) == 'c')
        shape::shapeBuffer(outRank, shapeInfo.data(), outShapeInfo);
    else
        shape::shapeBufferFortran(outRank, shapeInfo.data(), outShapeInfo);

    return SHAPELIST(outShapeInfo);
}
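The two return paths encode the output-shape rule that the new test below exercises: with a single params tensor the output keeps the input rank and takes its leading dimension from the index shape; with multiple params tensors the output gains one extra leading dimension equal to the number of indices (e.g. 3x3 params and six indices give {6, 3, 3}). A hedged sketch of that rule as a free function (lookupShape is an illustrative name, not a libnd4j API, and the multi-input case follows the shape the test expects):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // single params tensor:    {numIndices, in[1], ..., in[r-1]}  (rank kept)
    // multiple params tensors: {numIndices, in[0], ..., in[r-1]}  (rank + 1)
    std::vector<int64_t> lookupShape(const std::vector<int64_t>& in,
                                     int64_t numIndices, bool multipleParams) {
        std::vector<int64_t> out{numIndices};
        for (size_t e = multipleParams ? 0 : 1; e < in.size(); ++e)
            out.push_back(in[e]);
        return out;
    }

    int main() {
        auto s = lookupShape({3, 3}, 6, true); // as in EmbeddingLookup_3
        printf("{%lld, %lld, %lld}\n", (long long)s[0], (long long)s[1], (long long)s[2]);
        return 0;
    }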


libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp (41 additions, 0 deletions)
@@ -1643,6 +1643,47 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) {
delete result;
}

TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {

    NDArray<float> y('c', {3, 2}, {5.f, 4.f, 4.f, 5.f, 3.f, 3.f});
    NDArray<float> exp('c', {6, 3, 3}, {
        6, 20, 11, 21, 12, 22, 13, 23, 14,
        5, 20, 11, 21, 12, 22, 13, 23, 14,
        5, 20, 11, 21, 12, 22, 13, 23, 14,
        6, 20, 11, 21, 12, 22, 13, 23, 14,
        4, 20, 11, 21, 12, 22, 13, 23, 14,
        4, 20, 11, 21, 12, 22, 13, 23, 14});

    NDArray<float> p1('c', {3, 3}, {1, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p2('c', {3, 3}, {2, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p3('c', {3, 3}, {3, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p4('c', {3, 3}, {4, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p5('c', {3, 3}, {5, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p6('c', {3, 3}, {6, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p7('c', {3, 3}, {7, 20, 11, 21, 12, 22, 13, 23, 14});
    NDArray<float> p8('c', {3, 3}, {8, 20, 11, 21, 12, 22, 13, 23, 14});

    // equivalent: res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7, p8), ids, 'div')

    nd4j::ops::embedding_lookup<float> op;
    ResultSet<float>* result = op.execute({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1});
    NDArray<float>* output = result->at(0);

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}
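To trace the expected buffer: the ids in y are {5, 4, 4, 5, 3, 3}, so the op selects the param tensors at 0-based input slots 5, 4, 4, 5, 3, 3, i.e. p6, p5, p5, p6, p4, p4. Their leading elements (6, 5, 5, 6, 4, 4) are exactly the first entries of the six 3x3 slices of exp, which is how the expected buffer above was constructed.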

TEST_F(DeclarableOpsTests5, DynamicPartition_1) {

NDArray<float> x('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22,
