Skip to content

Commit

Permalink
Shugeo segment bp (#6402)
Browse files Browse the repository at this point in the history
* Added declarations for all segment_*_bp ops (sorted and unsorted).

* segment_sum_bp op implementation. Initial version.

* Added implementation for segment_min/max/mean/prod_bp ops.

* Added implementation for unsorted_segment_*_bp ops.

* Moved functionality of backprop ops to helpers.

* Implemented and tested 1D case for segment_*_bp ops.

* Implemented ND version for segment_*_bp ops.

* Added tests for all segment_*_bp ops.
  • Loading branch information
shugeo authored and raver119 committed Sep 10, 2018
1 parent 1255ca8 commit 46f84ca
Show file tree
Hide file tree
Showing 15 changed files with 1,067 additions and 16 deletions.
21 changes: 21 additions & 0 deletions libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp
Expand Up @@ -62,6 +62,27 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
CUSTOM_OP_IMPL(segment_max_bp, 3, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto indices = INPUT_VARIABLE(1);
auto gradOut = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
auto outIndices = OUTPUT_VARIABLE(1);
outIndices->assign(indices);
return helpers::segmentMaxFunctorBP(input, indices, gradOut, output);
}
DECLARE_SHAPE_FN(segment_max_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}

}

}
Expand Up @@ -62,6 +62,25 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
}

CUSTOM_OP_IMPL(segment_mean_bp, 3, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto indices = INPUT_VARIABLE(1);
auto gradOut = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
auto outIndices = OUTPUT_VARIABLE(1);
outIndices->assign(indices);
return helpers::segmentMeanFunctorBP(input, indices, gradOut, output);
}
DECLARE_SHAPE_FN(segment_mean_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);
}
}
}
Expand Up @@ -62,6 +62,24 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
}
CUSTOM_OP_IMPL(segment_min_bp, 3, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto indices = INPUT_VARIABLE(1);
auto gradOut = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
auto outIndices = OUTPUT_VARIABLE(1);
outIndices->assign(indices);
return helpers::segmentMinFunctorBP(input, indices, gradOut, output);
}
DECLARE_SHAPE_FN(segment_min_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);
}
}
}
Expand Up @@ -61,6 +61,26 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
}

CUSTOM_OP_IMPL(segment_prod_bp, 3, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto indices = INPUT_VARIABLE(1);
auto gradOut = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
auto outIndices = OUTPUT_VARIABLE(1);
outIndices->assign(indices);
return helpers::segmentProdFunctorBP(input, indices, gradOut, output);
}

DECLARE_SHAPE_FN(segment_prod_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);
}
}
}
17 changes: 17 additions & 0 deletions libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp
Expand Up @@ -62,6 +62,23 @@ namespace nd4j {

return SHAPELIST(outputShape);
}

CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) {

return helpers::segmentSumFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_VARIABLE(0));
}
DECLARE_SHAPE_FN(segment_sum_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}
}


}
Expand Up @@ -59,6 +59,19 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
}

CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentMaxFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}
DECLARE_SHAPE_FN(unsorted_segment_max_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);
}
}
}
Expand Up @@ -15,15 +15,15 @@
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
// Created by george@skymind.io on 9/6/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) {
CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Expand Down Expand Up @@ -60,6 +60,20 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
}

CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentMeanFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}
DECLARE_SHAPE_FN(unsorted_segment_mean_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}
}
}
Expand Up @@ -15,15 +15,15 @@
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
// Created by george@skymind.io on 9/6/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) {
CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Expand Down Expand Up @@ -59,6 +59,23 @@ namespace nd4j {

return SHAPELIST(outputShape);
}

CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentMinFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}

DECLARE_SHAPE_FN(unsorted_segment_min_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}

}

}
Expand Up @@ -15,15 +15,15 @@
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
// Created by george@skymind.io on 9/6/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) {
CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Expand Down Expand Up @@ -59,6 +59,23 @@ namespace nd4j {

return SHAPELIST(outputShape);
}

CUSTOM_OP_IMPL(unsorted_segment_prod_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentProdFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}

DECLARE_SHAPE_FN(unsorted_segment_prod_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}

}

}
Expand Up @@ -15,7 +15,7 @@
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
// Created by george@skymind.io on 9/6/2018.
//

#include <ops/declarable/CustomOperations.h>
Expand Down Expand Up @@ -59,5 +59,22 @@ namespace nd4j {

return SHAPELIST(outputShape);
}

CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentSqrtNFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}

DECLARE_SHAPE_FN(unsorted_segment_sqrt_n_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}

}
}
Expand Up @@ -23,7 +23,7 @@

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) {
CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Expand Down Expand Up @@ -59,6 +59,22 @@ namespace nd4j {

return SHAPELIST(outputShape);
}
CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) {
return helpers::unsortedSegmentSumFunctorBP(INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
}

DECLARE_SHAPE_FN(unsorted_segment_sum_bp){
Nd4jLong* in = inputShape->at(0);
Nd4jLong* inIdx = inputShape->at(1);

Nd4jLong* outShape;
Nd4jLong* outIndex;
COPY_SHAPE(in, outShape);
COPY_SHAPE(inIdx, outIndex);
return SHAPELIST(outShape, outIndex);

}

}

}

0 comments on commit 46f84ca

Please sign in to comment.