Skip to content

Commit

Permalink
Shugeo segment refactor (#6391)
Browse files Browse the repository at this point in the history
* usorted_segment_* ops implementation. Initial version.

* with proper unsorted_segment_sqrt_n helper signature.

* Implemented index check up for segment_* ops.

* 1D case for unsorted_segment_max was implemented and tested.

* Implemented and tested unsorted_segment_max op.

* Implemented and tested unsorted_segment_min op.

* Implemented and tested unsorted_segment_sum op.

* Implemented and tested unsorted_segment_prod op.

* Implemented unsorted_segment_mean and tests.

* Implemented and tested unsorted_segment_sqrt_n op.
  • Loading branch information
shugeo authored and raver119 committed Sep 7, 2018
1 parent d796e6e commit 5829de9
Show file tree
Hide file tree
Showing 15 changed files with 1,412 additions and 21 deletions.
Expand Up @@ -30,10 +30,10 @@ namespace nd4j {
REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

T expected = (T) 0.f, wrong = (T) 0.f;
Nd4jLong expected, wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_max: segment indices should be arranged, but %2.1f > %2.1f",
expected, wrong);
REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_max: segment indices should be arranged, but %i > %i",
wrong, expected);

helpers::segmentMaxFunctor(input, idxSegments, segmentedOutput);

Expand Down
Expand Up @@ -30,10 +30,10 @@ namespace nd4j {
REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

T expected = (T) 0.f, wrong = (T) 0.f;
Nd4jLong expected, wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_mean: segment indices should be arranged, but %2.1f > %2.1f",
expected, wrong);
REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_mean: segment indices should be arranged, but %i > %i",
wrong, expected);

helpers::segmentMeanFunctor(input, idxSegments, segmentedOutput);

Expand Down
Expand Up @@ -30,10 +30,10 @@ namespace nd4j {
REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

T expected = (T) 0.f, wrong = (T) 0.f;
Nd4jLong expected, wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_min: segment indices should be arranged, but %2.1f > %2.1f",
expected, wrong);
REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_min: segment indices should be arranged, but %i > %i",
wrong, expected);

helpers::segmentMinFunctor(input, idxSegments, segmentedOutput);

Expand Down
Expand Up @@ -30,10 +30,10 @@ namespace nd4j {
REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

T expected = (T) 0.f, wrong = (T) 0.f;
Nd4jLong expected, wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_prod: segment indices should be arranged, but %2.1f > %2.1f",
expected, wrong);
REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_prod: segment indices should be arranged, but %i > %i",
wrong, expected);

helpers::segmentProdFunctor(input, idxSegments, segmentedOutput);

Expand Down
Expand Up @@ -30,10 +30,10 @@ namespace nd4j {
REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

T expected = (T) 0.f, wrong = (T) 0.f;
Nd4jLong expected, wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_sum: segment indices should be arranged, but %2.1f > %2.1f",
expected, wrong);
REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, expected, wrong), 0, "segment_sum: segment indices should be arranged, but %i > %i",
wrong, expected);

helpers::segmentSumFunctor(input, idxSegments, segmentedOutput);

Expand Down
@@ -0,0 +1,64 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// Created by george@skymind.io on 9/6/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Nd4jLong numOfClasses = INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

Nd4jLong wrong;

REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be arranged, but %i > %i",
numOfClasses, wrong);

helpers::unsortedSegmentMaxFunctor(input, idxSegments, numOfClasses, segmentedOutput);

return ND4J_STATUS_OK;
}

DECLARE_SHAPE_FN(unsorted_segment_max) {

auto in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong numOfClasses = INT_ARG(0);
Nd4jLong* outputShape;

ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

outputShape[0] = outRank;
outputShape[1] = numOfClasses;
for(int i = 1; i < outRank; ++i)
outputShape[i + 1] = shape::sizeAt(in, i);

shape::updateStrides(outputShape, shape::order(in));

return SHAPELIST(outputShape);
}
}

}
@@ -0,0 +1,65 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Nd4jLong numOfClasses = INT_ARG(0);

REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

Nd4jLong wrong;

REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be arranged, but %i > %i",
wrong, numOfClasses);

helpers::unsortedSegmentMeanFunctor(input, idxSegments, numOfClasses, segmentedOutput);

return ND4J_STATUS_OK;
}

DECLARE_SHAPE_FN(unsorted_segment_mean) {

auto in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong* outputShape = nullptr;
Nd4jLong numOfClasses = INT_ARG(0);

ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

outputShape[0] = outRank;
outputShape[1] = numOfClasses;
for(int i = 1; i < outRank; ++i)
outputShape[i + 1] = shape::sizeAt(in, i);

shape::updateStrides(outputShape, shape::order(in));

return SHAPELIST(outputShape);
}
}

}
@@ -0,0 +1,64 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Nd4jLong numOfClasses = INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

Nd4jLong wrong;

REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be arranged, but %i > %i",
wrong, numOfClasses);

helpers::unsortedSegmentMinFunctor(input, idxSegments, numOfClasses, segmentedOutput);

return ND4J_STATUS_OK;
}

DECLARE_SHAPE_FN(unsorted_segment_min) {

Nd4jLong* in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong* outputShape = nullptr;
Nd4jLong numOfClasses = INT_ARG(0);

ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

outputShape[0] = outRank;
outputShape[1] = numOfClasses;
for(int i = 1; i < outRank; ++i)
outputShape[i + 1] = shape::sizeAt(in, i);

shape::updateStrides(outputShape, shape::order(in));

return SHAPELIST(outputShape);
}
}

}
@@ -0,0 +1,64 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Nd4jLong numOfClasses = INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

Nd4jLong wrong;

REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be arranged, but %i > %i",
wrong, numOfClasses);

helpers::unsortedSegmentProdFunctor(input, idxSegments, numOfClasses, segmentedOutput);

return ND4J_STATUS_OK;
}

DECLARE_SHAPE_FN(unsorted_segment_prod) {

auto in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong* outputShape = nullptr;
Nd4jLong numOfClasses = INT_ARG(0);

ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

outputShape[0] = outRank;
outputShape[1] = numOfClasses;
for(int i = 1; i < outRank; ++i)
outputShape[i + 1] = shape::sizeAt(in, i);

shape::updateStrides(outputShape, shape::order(in));

return SHAPELIST(outputShape);
}
}

}
@@ -0,0 +1,63 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// Created by george@skymind.io on 2/21/2018.
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/segment.h>

namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 1) {
NDArray<T>* input = INPUT_VARIABLE(0);
NDArray<T>* idxSegments = INPUT_VARIABLE(1);
NDArray<T>* segmentedOutput = OUTPUT_VARIABLE(0);
Nd4jLong numOfClasses = INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

Nd4jLong wrong;

REQUIRE_TRUE(helpers::segmentIndicesValidate(idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be arranged, but %i > %i",
wrong, numOfClasses);

helpers::unsortedSegmentSqrtNFunctor(input, idxSegments, numOfClasses, segmentedOutput);

return ND4J_STATUS_OK;
}

DECLARE_SHAPE_FN(unsorted_segment_sqrt_n) {

auto in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong* outputShape = nullptr;
Nd4jLong numOfClasses = INT_ARG(0);

ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

outputShape[0] = outRank;
outputShape[1] = numOfClasses;
for(int i = 1; i < outRank; ++i)
outputShape[i + 1] = shape::sizeAt(in, i);

shape::updateStrides(outputShape, shape::order(in));

return SHAPELIST(outputShape);
}
}
}

0 comments on commit 5829de9

Please sign in to comment.