Skip to content

Commit

Permalink
- broadcastable edge case fix
Browse files Browse the repository at this point in the history
- one more test
  • Loading branch information
raver119 committed May 27, 2018
1 parent b89d15e commit a589d2c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
19 changes: 14 additions & 5 deletions libnd4j/blas/cpu/NDArray.cpp
Expand Up @@ -2272,7 +2272,16 @@ template <typename OpName>
void NDArray<T>::applyTrueBroadcast(const NDArray<T>* other, NDArray<T>* target, const bool checkTargetShape, T *extraArgs) const {

if(target == nullptr || other == nullptr)
throw "NDArray::applyTrueBroadcast method: target or other = nullptr !";
throw std::runtime_error("NDArray::applyTrueBroadcast method: target or other = nullptr !");

if (this->isScalar() && !other->isScalar()) {
if (target->isSameShape(other)) {
target->assign(this);
target->template applyPairwiseTransform<OpName>(const_cast<NDArray<T>*>(other), extraArgs);

return;
}
};

const NDArray<T>* min(nullptr), *max(nullptr);
if(this->rankOf() >= other->rankOf()) {
Expand All @@ -2287,9 +2296,9 @@ void NDArray<T>::applyTrueBroadcast(const NDArray<T>* other, NDArray<T>* target,
if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils<T>::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, _workspace)) // the rank of target array must be equal to max->rankOf)()
throw "NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !" ;
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
if(!shape::equalsSoft(target->getShapeInfo(), newShapeInfo))
throw "NDArray::applyTrueBroadcast method: the shape of target array is wrong !";
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape of target array is wrong !");

// if workspace is not null - do not call delete.
if (_workspace == nullptr)
Expand Down Expand Up @@ -2333,8 +2342,8 @@ NDArray<T>* NDArray<T>::applyTrueBroadcast(const NDArray<T>* other, T *extraArgs

Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils<T>::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, _workspace)) // the rank of new array = max->rankOf)()
throw "NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !" ;
NDArray<T>* result = new NDArray<T>(newShapeInfo, false, this->_workspace);
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
auto result = new NDArray<T>(newShapeInfo, false, this->_workspace);

// if workspace is not null - do not call delete.
if (_workspace == nullptr)
Expand Down
11 changes: 11 additions & 0 deletions libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp
Expand Up @@ -334,4 +334,15 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_3) {

ASSERT_EQ(Status::OK(), result);
ASSERT_TRUE(e.equalsTo(z));
}

TEST_F(BroadcastableOpsTests, Test_Subtract_4) {
NDArray<float> x(1.0f);
NDArray<float> y('c', {2}, {0.0f, 1.0f});
NDArray<float> e('c', {2}, {1.0f, 0.0f});

auto z = x.template applyTrueBroadcast<simdOps::Subtract<float>>(y);

ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z));
}

0 comments on commit a589d2c

Please sign in to comment.