Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] More libnd4j exec tests (#7009)
* new op: evaluate_reduction_shape * couple of tests for new op * bunch of commented out tests re-enabled * - axis handling for legacy reduce ops wrappers - one more graph + test * unsorted_segment_* tweaked for last argument optional source * histogram tweaks * more skips removed
- Loading branch information
Showing
20 changed files
with
214 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
// | ||
// @author raver119@gmail.com | ||
// | ||
|
||
#include <op_boilerplate.h> | ||
#if NOT_EXCLUDED(OP_evaluate_reduction_shape) | ||
|
||
#include <ops/declarable/CustomOperations.h> | ||
|
||
namespace nd4j { | ||
namespace ops { | ||
CUSTOM_OP_IMPL(evaluate_reduction_shape, 2, 1, false, 0, 0) { | ||
auto inputShape = INPUT_VARIABLE(0); | ||
auto axis = INPUT_VARIABLE(1)->asVectorT<int>(); | ||
auto keepDims = block.numB() > 0 ? B_ARG(0) : false; | ||
auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; | ||
auto output = OUTPUT_VARIABLE(0); | ||
|
||
auto shape = inputShape->asVectorT<Nd4jLong>(); | ||
|
||
auto tempShapeInfo = ShapeBuilders::createShapeInfo(nd4j::DataType::INT64, 'c', shape, block.workspace()); | ||
auto tempReductionShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axis, tempShapeInfo, keepDims, oldFormat, block.workspace()); | ||
|
||
REQUIRE_TRUE(output->lengthOf() == shape::rank(tempReductionShapeInfo), 0, "evaluate_reduction_shape: output length should be %i, but got %i instead", shape::rank(tempReductionShapeInfo), output->lengthOf()); | ||
|
||
for (int e = 0; e < shape::rank(tempReductionShapeInfo); e++) | ||
output->p(e, tempReductionShapeInfo[e+1]); | ||
|
||
// we must release temporary shapeInfo | ||
RELEASE(tempReductionShapeInfo, block.workspace()); | ||
RELEASE(tempShapeInfo, block.workspace()); | ||
|
||
return Status::OK(); | ||
} | ||
|
||
DECLARE_TYPES(evaluate_reduction_shape) { | ||
getOpDescriptor() | ||
->setAllowedInputTypes(0, {ALL_INTS}) | ||
->setAllowedInputTypes(1, {ALL_INTS}) | ||
->setAllowedOutputTypes(0, nd4j::DataType::INT64); | ||
} | ||
|
||
DECLARE_SHAPE_FN(evaluate_reduction_shape) { | ||
auto input = INPUT_VARIABLE(0); | ||
auto axis = INPUT_VARIABLE(1)->asVectorT<int>(); | ||
|
||
auto keepDims = block.numB() > 0 ? B_ARG(0) : false; | ||
auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; | ||
|
||
Nd4jLong length = input->lengthOf(); | ||
|
||
if (keepDims) { | ||
if (oldFormat) { | ||
// for oldFormat we can't go below rank 2 | ||
length = nd4j::math::nd4j_max<int>(2, length); | ||
} | ||
} else { | ||
length -= axis.size(); | ||
if (oldFormat) { | ||
length = nd4j::math::nd4j_max<int>(2, length); | ||
} | ||
} | ||
|
||
|
||
return SHAPELIST(ShapeBuilders::createVectorShapeInfo(nd4j::DataType::INT64, length, block.workspace())); | ||
} | ||
} | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.