Skip to content

Commit

Permalink
[WIP] few more fixes (#6999)
Browse files Browse the repository at this point in the history
* broadcastable ops output for non-experimental build

* linspace mapped into LegacyRandomOp

* multi output exports

* - INDArray.toFlatArray fix for empty data type
- one more graph + test
  • Loading branch information
raver119 authored and sshepel committed Jan 15, 2019
1 parent a1bfb0e commit 2159ea1
Show file tree
Hide file tree
Showing 15 changed files with 239 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// @author raver119@gmail.com
//

#include <exceptions/unresolved_output_exception.h>
#include <StringUtils.h>
#include <utility>

namespace nd4j {
namespace graph {
unresolved_output_exception::unresolved_output_exception(std::string message) : std::runtime_error(message) {
//
}

unresolved_output_exception unresolved_output_exception::build(std::string message, std::pair<int, int> &varIndex) {
auto nodeId = StringUtils::valueToString<int>(varIndex.first);
auto outputIdx = StringUtils::valueToString<int>(varIndex.second);
message += "; Variable: [" + nodeId + ":" + outputIdx + "]";
return unresolved_output_exception(message);
}

unresolved_output_exception unresolved_output_exception::build(std::string message, int nodeId, int outputIndex) {
std::pair<int, int> p(nodeId, outputIndex);
return build(message, p);
}

unresolved_output_exception unresolved_output_exception::build(std::string message, std::string &varName, int outputIndex) {
auto outputIdx = StringUtils::valueToString<int>(outputIndex);
message += "; Variable: [" + varName + ":" + outputIdx + "]";
return unresolved_output_exception(message);
}
}
}
43 changes: 43 additions & 0 deletions libnd4j/include/graph/exceptions/unresolved_output_exception.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef DEV_TESTS_UNRESOLVED_OUTPUT_H
#define DEV_TESTS_UNRESOLVED_OUTPUT_H

#include <utility>
#include <string>
#include <stdexcept>

namespace nd4j {
namespace graph {
class unresolved_output_exception : public std::runtime_error {
public:
unresolved_output_exception(std::string message);
~unresolved_output_exception() = default;

static unresolved_output_exception build(std::string message, int nodeId, int outputIndex);
static unresolved_output_exception build(std::string message, std::pair<int, int> &varIndex);
static unresolved_output_exception build(std::string message, std::string &varName, int outputIndex);
};
}
}


#endif //DEV_TESTS_UNRESOLVED_INPUT_H
2 changes: 1 addition & 1 deletion libnd4j/include/graph/execution/impl/LogicConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ namespace nd4j {
// now fetch and transfer variables to Conditional node
// but only if return wasn't called at the end of scope
if (!isReturn) {
for (int e = 0; e < 65536; e++) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int, int> pair(lastNode, e);
std::pair<int, int> pairNew(node->id(), e);
if (__variableSpace->hasVariable(pair)) {
Expand Down
16 changes: 14 additions & 2 deletions libnd4j/include/graph/impl/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <graph/VariableProxy.h>
#include <graph/exceptions/graph_exception.h>
#include <graph/exceptions/unresolved_input_exception.h>
#include <graph/exceptions/unresolved_output_exception.h>

namespace nd4j {
namespace graph {
Expand Down Expand Up @@ -229,8 +230,19 @@ namespace nd4j {

nd4j_debug("Graph output size: %i\n", _output.size());
for (int e = 0; e < (int) _output.size(); e++) {
nd4j_debug("Output node: %i\n", _output.at(e));
res->push_back(_variableSpace->getVariable(_output.at(e)));
auto nodeId = _output.at(e);
nd4j_debug("Output node: %i\n", nodeId);

for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
if (_variableSpace->hasVariable(nodeId, e)) {
res->push_back(_variableSpace->getVariable(nodeId, e));
} else {
if (e == 0) {
throw unresolved_output_exception::build("Can't find output variable", nodeId, e);
} else
break;
}
}
}

return res;
Expand Down
10 changes: 7 additions & 3 deletions libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ namespace nd4j {
auto outputs = _descriptor->getOutputTypesForOutput(0);
nd4j::DataType dtype = block.dataType(0);
if (block.dataType(0) != nd4j::DataType::BOOL && !(outputs.size() == 1 && outputs[0] == nd4j::DataType::BOOL)) {
if (shape::length(y) > shape::length(x)) {
dtype = DataTypeUtils::pickPairwiseResultType(y, x);
if (Environment::getInstance()->isExperimentalBuild()) {
if (shape::length(y) > shape::length(x)) {
dtype = DataTypeUtils::pickPairwiseResultType(y, x);
} else {
dtype = DataTypeUtils::pickPairwiseResultType(x, y);
}
} else {
dtype = DataTypeUtils::pickPairwiseResultType(x, y);
dtype = ArrayOptions::dataType(x);
}
} else
dtype = nd4j::DataType::BOOL;
Expand Down
2 changes: 1 addition & 1 deletion libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ namespace nd4j {
auto res = new ResultSet();
res->setStatus(result);

for (int e = 0; e < 65536; e++) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e);
if (varSpace.hasVariable(pair)) {
auto var = varSpace.getVariable(pair);
Expand Down
4 changes: 2 additions & 2 deletions libnd4j/include/ops/declarable/impl/DeclarableOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ namespace nd4j {

// checking optionally available outputs
auto varSpace = block.getVariableSpace();
for (int index = 0; index < 65536; index++) {
for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
if (varSpace->hasVariable(block.nodeId(), index)) {
auto var = block.variable(block.nodeId(), index);

Expand Down Expand Up @@ -725,7 +725,7 @@ namespace nd4j {
return arrayList;


for (int e = 0; e < 65536; e++) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e);
if (variableSpace.hasVariable(pair)) {
auto var = variableSpace.getVariable(pair);
Expand Down
11 changes: 10 additions & 1 deletion libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ namespace nd4j {
RandomLauncher::applyAlphaDropOut(block.randomGenerator(), z, prob, a, b, pa);
}
break;
case nd4j::random::Linspace: {
auto z = OUTPUT_VARIABLE(0);
auto start = INPUT_VARIABLE(0);
auto finish = INPUT_VARIABLE(1);
auto numOfElements = INPUT_VARIABLE(2);

z->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
}
break;
default: {
nd4j_printf("Unknown random op requested: [%i]\n", opNum);
return ND4J_STATUS_KERNEL_FAILURE;
Expand Down Expand Up @@ -412,7 +421,7 @@ namespace nd4j {
return arrayList;


for (int e = 0; e < 65536; e++) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e);
if (variableSpace.hasVariable(pair)) {
auto var = variableSpace.getVariable(pair);
Expand Down
110 changes: 110 additions & 0 deletions libnd4j/tests_cpu/layers_tests/OneOffTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,115 @@ TEST_F(OneOffTests, test_tensor_array_4) {

ASSERT_EQ(e, *z);

delete graph;
}

TEST_F(OneOffTests, test_assert_4) {
auto e = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 1, 1, 1});

auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assert_type_rank2_int64.fb");
ASSERT_TRUE(graph != nullptr);

graph->printOut();


Nd4jStatus status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1));

auto z = graph->getVariableSpace()->getVariable(1)->getNDArray();
ASSERT_TRUE(z != nullptr);

ASSERT_EQ(e, *z);

delete graph;
}

TEST_F(OneOffTests, test_cond_true_1) {
auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});

auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_true.fb");
ASSERT_TRUE(graph != nullptr);

graph->printOut();


Nd4jStatus status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));

auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
ASSERT_TRUE(z != nullptr);

z->printIndexedBuffer("z buffer");

ASSERT_EQ(e, *z);

delete graph;
}
/*
TEST_F(OneOffTests, test_cond_false_1) {
auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_false.fb");
ASSERT_TRUE(graph != nullptr);
graph->printOut();
Nd4jStatus status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
ASSERT_TRUE(z != nullptr);
z->printIndexedBuffer("z buffer");
ASSERT_EQ(e, *z);
delete graph;
}
*/

TEST_F(OneOffTests, test_identity_n_2) {
auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f});

auto graph = GraphExecutioner::importFromFlatBuffers("./resources/identity_n_2.fb");
ASSERT_TRUE(graph != nullptr);

graph->printOut();


Nd4jStatus status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1));
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1));

auto z = graph->getVariableSpace()->getVariable(1)->getNDArray();
ASSERT_TRUE(z != nullptr);

ASSERT_EQ(e, *z);

delete graph;
}

TEST_F(OneOffTests, test_non2d_1) {
auto e = NDArrayFactory::create<float>('c', {1, 1}, {5.42746449f});

auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_1.fb");
ASSERT_TRUE(graph != nullptr);

graph->printOut();

Nd4jStatus status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);

auto z = graph->getVariableSpace()->getVariable(3)->getNDArray();
ASSERT_TRUE(z != nullptr);

ASSERT_EQ(e, *z);


delete graph;
}
Binary file not shown.
Binary file added libnd4j/tests_cpu/resources/cond_true.fb
Binary file not shown.
Binary file added libnd4j/tests_cpu/resources/identity_n_2.fb
Binary file not shown.
Binary file added libnd4j/tests_cpu/resources/non2d_1.fb
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -6368,7 +6368,7 @@ public int toFlatArray(FlatBufferBuilder builder) {
}
int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong());
int buffer = this.isEmpty() ? 0 : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) : FlatArray.createBufferVector(builder, this.data().asBytes());
val type = this.isEmpty() ? FlatBuffersMapper.getDataTypeAsByte(Nd4j.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType());
val type = this.isEmpty() ? FlatBuffersMapper.getDataTypeAsByte(this.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType());
int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE);

return array;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ public static SameDiff getGraphAfterExec(String baseDir, String modelFilename, S
val executioner = new NativeGraphExecutioner();
val results = executioner.executeGraph(graph, configuration);

//graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"));
//graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/non2d_1.fb"));
} else if (executeWith.equals(ExecuteWith.JUST_PRINT)) {
for (String input : inputs.keySet()) {
graph.associateArrayWithVariable(inputs.get(input), graph.variableMap().get(input));
Expand Down

0 comments on commit 2159ea1

Please sign in to comment.