Add Tensor broadcasting for binary ops (flashlight#775)
Summary:
Pull Request resolved: flashlight#775

Internally use ArrayFire's `af::batchFunc` to add Tensor broadcasting behavior for all binary operations. This doesn't add broadcasting to some other operations (e.g. in-place ops) or to indexed-assignment ops; those can be added later after some discussion on their behavior, whereas binary op broadcasting semantics are unambiguous. A short usage sketch follows the commit metadata below.

Differential Revision: D32678121

fbshipit-source-id: 209964ae03b6ec4ca8958cfdec99dc0cbabf7122
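
For illustration only (not part of the diff below): a minimal sketch of the behavior this change enables, assuming the public fl Tensor API exercised by the tests in this commit (fl::rand, fl::tile, fl::allClose, operator+). The header paths and the fl::init() call are assumptions.

#include "flashlight/fl/tensor/Init.h"    // assumed location of fl::init()
#include "flashlight/fl/tensor/Random.h"  // assumed location of fl::rand()
#include "flashlight/fl/tensor/TensorBase.h"

int main() {
  fl::init();
  auto lhs = fl::rand({3, 3});
  auto rhs = fl::rand({3, 1}); // the size-1 dimension broadcasts along axis 1
  auto sum = lhs + rhs;        // binary ops now dispatch to doBinaryOpOrBroadcast
  // The broadcast result matches explicitly tiling the smaller operand:
  bool ok = fl::allClose(sum, lhs + fl::tile(rhs, {1, 3}));
  return ok ? 0 : 1;
}
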
jacobkahn authored and facebook-github-bot committed Jan 11, 2022
1 parent d385b21 commit 5ac0dea
Showing 4 changed files with 199 additions and 29 deletions.
74 changes: 54 additions & 20 deletions flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp
@@ -17,7 +17,10 @@
#include <af/random.h>

#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>

#include "flashlight/fl/tensor/TensorBase.h"
@@ -27,7 +30,7 @@
namespace fl {
namespace {

typedef af::array (*reduceFunc_t)(const af::array&, const int);
using reduceFunc_t = af::array (*)(const af::array&, const int);

template <typename T = reduceFunc_t>
af::array afReduceAxes(
@@ -72,6 +75,49 @@ bool isAllAxisReduction(const Tensor& input, const std::vector<int>& axes) {
return true;
}

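// Returns true if lhs and rhs can be broadcast against one another: each
// dimension present in both shapes must either match or be 1; extra
// dimensions on one side are allowed.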
bool canBroadcast(const Shape& lhs, const Shape& rhs) {
unsigned nDim = std::max(lhs.ndim(), rhs.ndim());

for (unsigned i = 0; i < nDim; ++i) {
if (i + 1 > lhs.ndim() || i + 1 > rhs.ndim()) {
// One Shape has more dimensions than the other; the tensor with fewer
// dimensions will be broadcast along the missing dimensions
continue;
}
if (lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1) {
return false;
}
}
return true;
}

// A binary operation on two ArrayFire arrays
using binaryOpFunc_t =
af::array (*)(const af::array& lhs, const af::array& rhs);

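// Applies func directly when the shapes already match (or both operands are
// scalars/single-element tensors); otherwise broadcasts via af::batchFunc if
// the shapes are compatible, and throws std::invalid_argument if they aren't.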
Tensor doBinaryOpOrBroadcast(
const Tensor& lhs,
const Tensor& rhs,
binaryOpFunc_t func) {
// Shapes are identical, or both operands are scalars/single-element tensors -
// no broadcasting needed
if (lhs.shape() == rhs.shape() || (lhs.size() <= 1 && rhs.size() <= 1)) {
return toTensor<ArrayFireTensor>(
func(toArray(lhs), toArray(rhs)), lhs.ndim());
}

if (canBroadcast(lhs.shape(), rhs.shape())) {
return toTensor<ArrayFireTensor>(
af::batchFunc(toArray(lhs), toArray(rhs), func),
std::max(lhs.ndim(), rhs.ndim()));
} else {
std::stringstream ss;
ss << "doBinaryOpOrBroadcast: cannot perform operation "
"or broadcasting with tensors of shapes "
<< lhs.shape() << " and " << rhs.shape() << " - dimension mismatch.";
throw std::invalid_argument(ss.str());
}
}

} // namespace

ArrayFireBackend::ArrayFireBackend() {
@@ -498,19 +544,10 @@ Tensor ArrayFireBackend::argsort(

// Operations on fl::Tensor call the respective operator overloads that are
// already defined on af::arrays
#define FL_AF_BINARY_OP_DEF(OP, FUNC) \
Tensor ArrayFireBackend::FUNC(const Tensor& lhs, const Tensor& rhs) { \
if (lhs.ndim() != rhs.ndim()) { \
std::stringstream ss; \
ss << "ArrayFireTensor arguments to operator " << std::string(#OP) \
<< " (" << std::string(#FUNC) << ") " \
<< "have a differing number of dimensions " << lhs.shape() << " and " \
<< rhs.shape(); \
throw std::invalid_argument(ss.str()); \
} \
return toTensor<ArrayFireTensor>( \
toArray(lhs) OP toArray(rhs), lhs.ndim()); \
} \
#define FL_AF_BINARY_OP_DEF(OP, FUNC) \
Tensor ArrayFireBackend::FUNC(const Tensor& lhs, const Tensor& rhs) { \
return doBinaryOpOrBroadcast(lhs, rhs, af::operator OP); \
} \
FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP);

// Definitions
@@ -539,18 +576,15 @@ FL_AF_BINARY_OP_DEF(>>, rShift);
#undef FL_AF_BINARY_OP_LITERALS_DEF

Tensor ArrayFireBackend::minimum(const Tensor& lhs, const Tensor& rhs) {
return toTensor<ArrayFireTensor>(
af::min(toArray(lhs), toArray(rhs)), lhs.ndim());
return doBinaryOpOrBroadcast(lhs, rhs, af::min);
}

Tensor ArrayFireBackend::maximum(const Tensor& lhs, const Tensor& rhs) {
return toTensor<ArrayFireTensor>(
af::max(toArray(lhs), toArray(rhs)), lhs.ndim());
return doBinaryOpOrBroadcast(lhs, rhs, af::max);
}

Tensor ArrayFireBackend::power(const Tensor& lhs, const Tensor& rhs) {
return toTensor<ArrayFireTensor>(
af::pow(toArray(lhs), toArray(rhs)), lhs.ndim());
return doBinaryOpOrBroadcast(lhs, rhs, af::pow);
}

/************************** BLAS ***************************/
16 changes: 11 additions & 5 deletions flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp
@@ -7,6 +7,7 @@

#include "flashlight/fl/tensor/backend/af/ArrayFireTensor.h"

#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>
@@ -245,13 +246,16 @@ Tensor ArrayFireTensor::index(const std::vector<Index>& indices) {
"ArrayFire tensors support up to 4 dimensions.");
}

// If indexing with a single element and it's an Array, don't use spans
// TODO: vet and stress test this a lot more/add proper support for
// multi-tensor
bool tensorIndex = indices.size() == 1 &&
indices.front().type() == detail::IndexType::Tensor;
// If indexing by a single element and it's a tensor with the same number of
// indices as the array being indexed, do a flat index as this is probably a
// filter-based index (for example: a(a < 5)).
bool completeTensorIndex = indices.size() == 1 &&
indices.front().type() == detail::IndexType::Tensor &&
indices.front().get<Tensor>().size() == getHandle().elements();
std::vector<af::index> afIndices;
if (tensorIndex) {
if (completeTensorIndex) {
afIndices = {af::index(0)};
} else {
afIndices = {af::span, af::span, af::span, af::span}; // implicit spans
@@ -277,9 +281,11 @@ Tensor ArrayFireTensor::index(const std::vector<Index>& indices) {

getHandle(); // if this tensor was a view, run indexing and promote

assert(afIndices.size() == indexTypes.size());
// Compute numDims for the new Tensor
unsigned newNumDims = numDims();
if (tensorIndex) {

if (completeTensorIndex) {
// TODO/FIXME: compute this based on the number of els in the indexing
// tensor(s)
newNumDims = 1;
17 changes: 14 additions & 3 deletions flashlight/fl/test/tensor/IndexTest.cpp
@@ -201,6 +201,11 @@ TEST(IndexTest, flat) {
.scalar<float>(),
i + 1 - 10);
}

// Range flat assignment
auto rA = fl::rand({6});
a.flat(fl::range(1, 7)) = rA;
ASSERT_TRUE(allClose(rA, a.flatten()(fl::range(1, 7))));
}

TEST(IndexTest, TensorIndex) {
@@ -223,13 +228,19 @@
auto i = fl::arange({10}, 0, fl::dtype::u32);
auto b = fl::rand({20, 20});
auto ref = b;
ASSERT_TRUE(allClose(b(i), b(fl::range(10), 0)));
ASSERT_EQ(b(i).shape(), b(fl::range(10)).shape());
ASSERT_TRUE(allClose(b(i), b(fl::range(10))));

b(i) += 3.;
ASSERT_TRUE(allClose(b(i), b(fl::range(10), 0)));
ASSERT_TRUE(allClose(b(i), b(fl::range(10))));
ASSERT_TRUE(allClose(b(i), (ref + 3)(i)));
b(i) += fl::full({(Dim)i.size()}, 10.);
b(i) += fl::full({(Dim)i.size(), b.dim(1)}, 10.);
ASSERT_EQ(b(i).shape(), (ref + 13)(i).shape());
ASSERT_TRUE(allClose(b(i), (ref + 13)(i)));

// Tensor-index a tensor with more than one dimension
auto c = fl::rand({10, 10, 10});
ASSERT_EQ(c(fl::arange({5})).shape(), Shape({5, 10, 10}));
}

TEST(IndexTest, ExpressionIndex) {
121 changes: 120 additions & 1 deletion flashlight/fl/test/tensor/TensorBaseTest.cpp
@@ -6,6 +6,10 @@
*/

#include <cmath>
#include <functional>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <gtest/gtest.h>

@@ -349,6 +353,116 @@ TEST(TensorBaseTest, maximum) {
ASSERT_TRUE(allClose(fl::maximum(b, 1).astype(a.type()), b));
}

using binaryOpFunc_t = Tensor (*)(const Tensor& lhs, const Tensor& rhs);

TEST(TensorBaseTest, broadcasting) {
// Collection of {lhs, rhs, tileShapeLhs, tileShapeRhs} corresponding to
// broadcasting [lhs] to [rhs] by tiling each by the respective tileShapes
struct ShapeData {
Shape lhs; // broadcast from
Shape rhs; // broadcast to
Shape tileShapeLhs;
Shape tileShapeRhs;
};
std::vector<ShapeData> shapes = {
{{3, 1}, {3, 3}, {1, 3}, {1, 1}},
{{3}, {3, 3}, {1, 3}, {1, 1}},
{{3, 1, 4}, {3, 6, 4}, {1, 6, 1}, {1, 1, 1}},
{{3, 1, 4, 1}, {3, 2, 4, 5}, {1, 2, 1, 5}, {1, 1, 1, 1}},
{{1, 10}, {8, 10}, {8, 1}, {1, 1}},
{{2, 1, 5, 1}, {2, 3, 5, 3}, {1, 3, 1, 3}, {1, 1, 1, 1}},
{{3, 1, 2, 1}, {1, 4, 1, 5}, {1, 4, 1, 5}, {3, 1, 2, 1}},
{{3, 2, 1}, {3, 1, 4, 1}, {1, 1, 4}, {1, 2, 1, 1}}};

std::unordered_map<binaryOpFunc_t, std::string> functions = {
{fl::minimum, "minimum"},
{fl::maximum, "maximum"},
{fl::power, "power"},
{fl::add, "add"},
{fl::sub, "sub"},
{fl::mul, "mul"},
{fl::div, "div"},
{fl::eq, "eq"},
{fl::neq, "neq"},
{fl::lessThan, "lessThan"},
{fl::lessThanEqual, "lessThanEqual"},
{fl::greaterThan, "greaterThan"},
{fl::greaterThanEqual, "greaterThanEqual"},
{fl::logicalOr, "logicalOr"},
{fl::logicalAnd, "logicalAnd"},
{fl::mod, "mod"},
{fl::bitwiseOr, "bitwiseOr"},
{fl::bitwiseXor, "bitwiseXor"},
{fl::lShift, "lShift"},
{fl::rShift, "rShift"}};

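// Applies func both to the original operands and to explicitly tiled copies;
// if broadcasting is implemented correctly the two results should match.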
auto doBinaryOp = [](const Tensor& lhs,
const Tensor& rhs,
const Shape& tileShapeLhs,
const Shape& tileShapeRhs,
binaryOpFunc_t func) -> std::pair<Tensor, Tensor> {
assert(lhs.ndim() <= rhs.ndim());
return {
func(lhs, rhs), func(tile(lhs, tileShapeLhs), tile(rhs, tileShapeRhs))};
};

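// Computes the expected output shape of broadcasting lhsShape with rhsShape:
// size-1 dimensions expand to the other operand's size, and mismatched
// non-1 sizes are an error.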
auto computeBroadcastShape = [](const Shape& lhsShape,
const Shape& rhsShape) -> Shape {
unsigned maxnDim = std::max(lhsShape.ndim(), rhsShape.ndim());
Shape outShape{std::vector<Dim>(maxnDim)};
for (unsigned i = 0; i < maxnDim; ++i) {
if (i > lhsShape.ndim() - 1) {
outShape[i] = rhsShape[i];
} else if (i > rhsShape.ndim() - 1) {
outShape[i] = lhsShape[i];
} else if (lhsShape[i] == 1) {
outShape[i] = rhsShape[i];
} else if (rhsShape[i] == 1) {
outShape[i] = lhsShape[i];
} else if (lhsShape[i] == rhsShape[i]) {
outShape[i] = lhsShape[i];
} else if (lhsShape[i] != rhsShape[i]) {
throw std::runtime_error(
"computeBroadcastShape - cannot broadcast shape");
}
}
return outShape;
};

for (auto funcp : functions) {
for (auto& shapeData : shapes) {
auto lhs = (fl::rand(shapeData.lhs) * 10).astype(fl::dtype::s32);
auto rhs = (fl::rand(shapeData.rhs) * 10).astype(fl::dtype::s32);

auto [actualOut, expectedOut] = doBinaryOp(
lhs,
rhs,
shapeData.tileShapeLhs,
shapeData.tileShapeRhs,
funcp.first);

Shape expectedShape = computeBroadcastShape(shapeData.lhs, shapeData.rhs);

std::stringstream ss;
ss << "lhs: " << shapeData.lhs << " rhs: " << shapeData.rhs
<< " function: " << funcp.second;
auto testData = ss.str();

ASSERT_EQ(actualOut.shape(), expectedShape) << testData;
ASSERT_TRUE(allClose(actualOut, expectedOut)) << testData;
}

// Scalar broadcasting
const double scalarVal = 4;
const Shape inShape = {2, 3, 4};
const auto lhs = fl::rand(inShape).astype(fl::dtype::s32);
const auto rhs = fl::fromScalar(scalarVal, fl::dtype::s32);
const auto rhsTiled = fl::full(inShape, scalarVal, fl::dtype::s32);
ASSERT_TRUE(allClose(funcp.first(lhs, rhs), funcp.first(lhs, rhsTiled)));
}
}

TEST(TensorBaseTest, argmin) {
Tensor in = Tensor::fromVector<float>({2, 3}, {4, 8, 6, 3, 5, 9});
auto a0 = fl::argmin(in, 0);
@@ -572,7 +686,12 @@ TEST(TensorBaseTest, topk) {
allClose(values, Tensor::fromVector<float>({3, 2}, {9, 8, 7, 9, 8, 7})));

fl::topk(
values, indices, a, /* k = */ 4, /* axis = */ 0, fl::SortMode::Ascending);
values,
indices,
a,
/* k = */ 4,
/* axis = */ 0,
fl::SortMode::Ascending);
ASSERT_TRUE(allClose(
values, Tensor::fromVector<float>({4, 2}, {0, 1, 2, 3, 0, 1, 2, 3})));
}