Skip to content

Commit

Permalink
[WIP] Broadcast tweaks (#7334)
Browse files Browse the repository at this point in the history
This commit adds a couple of performance-related tweaks to broadcast operations
  • Loading branch information
raver119 committed Mar 22, 2019
1 parent 72ef734 commit 2a03387
Show file tree
Hide file tree
Showing 18 changed files with 932 additions and 66 deletions.
28 changes: 28 additions & 0 deletions libnd4j/blas/NativeOpExcutioner.h
Expand Up @@ -94,6 +94,20 @@ class ND4J_EXPORT NativeOpExcutioner {
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);

static void execInverseBroadcast(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadOnlyShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);


static void execBroadcastBool(int opNum,
void *x,
Expand All @@ -109,6 +123,20 @@ class ND4J_EXPORT NativeOpExcutioner {
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);

static void execInverseBroadcastBool(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadOnlyShapeInfo,
Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,
Nd4jLong *tadOffsetsZ);


/**
*
Expand Down
112 changes: 75 additions & 37 deletions libnd4j/blas/cpu/NDArray.cpp
Expand Up @@ -2983,47 +2983,60 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, void* extraArgs) {
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* other, NDArray* target, void* extraArgs) {
std::vector<int> vec(dimensions);
applyBroadcast(op, vec, tadArray, target, extraArgs);
applyBroadcast(op, vec, other, target, extraArgs);
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& dimensions, const NDArray* tadArray, NDArray* target, void* extraArgs) {
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, void* extraArgs) {
if (isS())
throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!");
if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && tadArray->isB()) || (op == broadcast::ReverseDivide && this->isB()))
if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other->isB()) || (op == broadcast::ReverseDivide && this->isB()))
throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!");

if (dimensions.size() == 0)
return;
auto result = target == nullptr ? this : target;

if(result->_dataType != DataTypeUtils::pickPairwiseResultType(_shapeInfo, tadArray->_shapeInfo))
NDArray *min(nullptr), *max(nullptr);
if(lengthOf() >= other->lengthOf()) {
max = this;
min = const_cast<NDArray*>(other);
}
else {
max = const_cast<NDArray*>(other);
min = this;
}

if(result->_dataType != DataTypeUtils::pickPairwiseResultType(_shapeInfo, other->_shapeInfo))
throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !");
if(!result->isSameShape(this))
throw std::invalid_argument("NDArray::applyBroadcast method: this and target arrays must have the same shape !");
if(!result->isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast method: max and target arrays must have the same shape !");

std::vector<int> copy(dimensions);

if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());

Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), (int) copy.size());
if (tadLength != tadArray->lengthOf())
Nd4jLong tadLength = shape::tadLength(max->_shapeInfo, copy.data(), (int) copy.size());
if (tadLength != min->lengthOf())
throw std::runtime_error("NDArray::applyBroadcast method: tad length mismatch !");

shape::TAD tad;
tad.init(this->_shapeInfo, copy.data(), copy.size());
tad.init(max->_shapeInfo, copy.data(), copy.size());
tad.createTadOnlyShapeInfo();
tad.createOffsets();

// TODO: eventually we want separate tads here
NativeOpExcutioner::execBroadcast(op, this->_buffer, this->_shapeInfo, tadArray->_buffer, tadArray->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
if(max == this)
// TODO: eventually we want separate tads here
NativeOpExcutioner::execBroadcast(op, this->_buffer, this->_shapeInfo, other->_buffer, other->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
else
NativeOpExcutioner::execInverseBroadcast(op, this->_buffer, this->_shapeInfo, other->_buffer, other->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray* tadArray, NDArray* target, void* extraArgs) {
void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, void* extraArgs) {
if (isS())
throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!");

Expand All @@ -3032,29 +3045,42 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr

auto result = target == nullptr ? this : target;

NDArray *min(nullptr), *max(nullptr);
if(lengthOf() >= other->lengthOf()) {
max = this;
min = const_cast<NDArray*>(other);
}
else {
max = const_cast<NDArray*>(other);
min = this;
}

if(result->_dataType != DataType::BOOL)
throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!");
if(!result->isSameShape(this))
throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same shape !");
if(_dataType != tadArray->_dataType)
throw std::invalid_argument("NDArray::applyBroadcast bool method: this and tad arrays must have the same type !");
if(!result->isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast bool method: max and target arrays must have the same shape !");
if(_dataType != other->_dataType)
throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !");

std::vector<int> copy(dimensions);

if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());

Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), (int) copy.size());
if (tadLength != tadArray->lengthOf())
Nd4jLong tadLength = shape::tadLength(max->_shapeInfo, copy.data(), (int) copy.size());
if (tadLength != other->lengthOf())
throw std::runtime_error("Tad length mismatch");

shape::TAD tad;
tad.init(this->_shapeInfo, copy.data(), copy.size());
tad.init(max->_shapeInfo, copy.data(), copy.size());
tad.createTadOnlyShapeInfo();
tad.createOffsets();

// TODO: eventually we want separate tads here
NativeOpExcutioner::execBroadcastBool(op, this->_buffer, this->_shapeInfo, tadArray->_buffer, tadArray->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
if(this == max)
NativeOpExcutioner::execBroadcastBool(op, this->_buffer, this->_shapeInfo, other->_buffer, other->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
else
NativeOpExcutioner::execInverseBroadcastBool(op, this->_buffer, this->_shapeInfo, other->_buffer, other->_shapeInfo, result->_buffer, result->_shapeInfo, copy.data(), (int)copy.size(), tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
}

//////////////////////////////////////////////////////////////////////////
Expand All @@ -3068,8 +3094,7 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
//Edge case: broadcastOp(x,empty) -> empty; no-op
return;
}



if (isScalar()) {
NDArray temp(target->_shapeInfo, _dataType, false, _workspace);
temp.assign(this);
Expand All @@ -3081,14 +3106,14 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
return;
}

const NDArray* min(nullptr), *max(nullptr);
NDArray* min(nullptr), *max(nullptr);
if(this->rankOf() >= other->rankOf()) {
max = this;
min = other;
max = const_cast<NDArray*>(this);
min = const_cast<NDArray*>(other);
}
else {
max = other;
min = this;
max = const_cast<NDArray*>(other);
min = const_cast<NDArray*>(this);
}

if(checkTargetShape) {
Expand All @@ -3105,7 +3130,14 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
delete[] newShapeInfo;
}

std::vector<int> maxTadAxes = ShapeUtils::tadAxesForSimpleBroadcast(*max, *min);
if(!maxTadAxes.empty()) {
max->applyBroadcast(op.b, maxTadAxes, min, target, extraArgs);
return;
}

NDArray* pTarget = (max->_dataType == target->_dataType) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->_dataType, target->_workspace);

// check whether max array has to be tiled
if(!max->isSameShape(target)) {
// evaluate repeating dimensions for tile operation
Expand All @@ -3125,11 +3157,10 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
product *= repeatMin[i-1];
}

auto pMin = const_cast<NDArray *>(min);
auto pMin = min;
if(product != 1 )
pMin = new NDArray(min->tile(repeatMin));


std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);

if(max == this) {
Expand Down Expand Up @@ -3181,14 +3212,14 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
return;
}

const NDArray* min(nullptr), *max(nullptr);
NDArray* min(nullptr), *max(nullptr);
if(this->rankOf() >= other->rankOf()) {
max = this;
min = other;
max = const_cast<NDArray*>(this);
min = const_cast<NDArray*>(other);
}
else {
max = other;
min = this;
max = const_cast<NDArray*>(other);
min = const_cast<NDArray*>(this);
}

if(checkTargetShape) {
Expand All @@ -3203,7 +3234,14 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
delete[] newShapeInfo;
}

NDArray* pTarget = (max->_dataType == target->_dataType) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->_dataType, target->_workspace);
std::vector<int> maxTadAxes = ShapeUtils::tadAxesForSimpleBroadcast(*max, *min);
if(!maxTadAxes.empty()) {
const_cast<NDArray*>(this)->applyBroadcast(op.b, maxTadAxes, other, target, extraArgs);
return;
}

NDArray* pTarget = (max->_dataType == target->_dataType) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->_dataType, target->_workspace);

// check whether max array has to be tiled
if(!max->isSameShape(target)) {
// evaluate repeating dimensions for tile operation
Expand All @@ -3224,7 +3262,7 @@ template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDAr
product *= repeatMin[i-1];
}

auto pMin = const_cast<NDArray *>(min);
auto pMin = min;
if(product != 1 )
pMin = new NDArray(min->tile(repeatMin));

Expand Down

0 comments on commit 2a03387

Please sign in to comment.