Skip to content

Commit

Permalink
Add scalar exponentiation by a tensor (flashlight#774)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#774

Raises a double to the power of each element of a tensor (element-wise exponentiation with a scalar base).

Reviewed By: benoitsteiner

Differential Revision: D32902722

fbshipit-source-id: 10e901744fc0f863224c5971ed8fd2ade2656190
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Jan 11, 2022
1 parent 5ac0dea commit b9e1d7f
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 12 deletions.
9 changes: 5 additions & 4 deletions flashlight/fl/tensor/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

namespace fl {

range::range(idx idx) : range(0, idx) {}
range::range(const idx& i) : range(0, i) {}

range::range(idx start, idx end) : range(start, end, /* stride */ 1) {}
range::range(const idx& start, const idx& end)
: range(start, end, /* stride */ 1) {}

range::range(idx start, idx end, Dim stride)
range::range(const idx& start, const idx& end, const Dim stride)
: // fl::end decays to int
start_(std::visit([](Dim idx) -> Dim { return idx; }, start)),
start_(std::visit([](const Dim idx) -> Dim { return idx; }, start)),
// fl::end --> -1, else idx as Dim
end_(
std::holds_alternative<fl::end_t>(end)
Expand Down
6 changes: 3 additions & 3 deletions flashlight/fl/tensor/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ class range {
/**
* Construct a range with the indices [0, idx) (i.e. [0, idx - 1])
*/
explicit range(idx idx);
explicit range(const idx& idx);

/**
* Construct a range with the indices [start, end) (i.e. [start, end - 1])
*/
range(idx start, idx end);
range(const idx& start, const idx& end);

/**
* Construct a range with the indices [start, end) (i.e. [start, end - 1])
* with the given stride.
*/
range(idx start, idx end, Dim stride);
range(const idx& start, const idx& end, const Dim stride);

Dim start() const;
Dim end() const;
Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ Tensor power(const Tensor& lhs, const double& rhs) {
return lhs.backend().power(lhs, full(lhs.shape(), rhs));
}

Tensor power(const double& lhs, const Tensor& rhs) {
  // Element-wise scalar-base exponentiation: broadcast the scalar base to
  // the exponent tensor's shape, then delegate to the backend's power op.
  const Tensor base = full(rhs.shape(), lhs);
  return rhs.backend().power(base, rhs);
}

/******************************* BLAS ********************************/
Tensor matmul(
const Tensor& lhs,
Expand Down
5 changes: 3 additions & 2 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,9 @@ class Tensor {
case dtype::f64:
return scalar<double>();
case dtype::s32:
return scalar<long>();
return scalar<int>();
case dtype::u32:
return scalar<unsigned long>();
return scalar<unsigned int>();
case dtype::b8:
return scalar<char>();
case dtype::u8:
Expand Down Expand Up @@ -1090,6 +1090,7 @@ Tensor maximum(const double& lhs, const Tensor& rhs);
*/
Tensor power(const Tensor& lhs, const Tensor& rhs);
Tensor power(const Tensor& lhs, const double& rhs);
Tensor power(const double& lhs, const Tensor& rhs);

/******************************* BLAS ********************************/

Expand Down
3 changes: 0 additions & 3 deletions flashlight/fl/tensor/backend/af/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,6 @@ af::array condenseIndices(
af::dim4 newDims(1, 1, 1, 1);
unsigned newDimIdx = 0;
for (unsigned i = 0; i < AF_MAX_DIMS; ++i) {
if (dims[i] == 1 && indexTypes && indexTypes.value().size() > i) {
}

// If we're doing an index op (indexTypes is non-empty), then only collapse
// the dimension if it contains an index literal
if (dims[i] == 1 && indexTypes && indexTypes.value().size() > i &&
Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,10 @@ TEST(TensorBaseTest, power) {
TEST(TensorBaseTest, powerDouble) {
  // Tensor base raised to a scalar exponent: 2^3 == 2 * 2 * 2 element-wise.
  auto base = fl::full({3, 3}, 2.);
  ASSERT_TRUE(allClose(fl::power(base, 3), base * base * base));

  auto shapeRef = fl::full({3, 3}, 2.);
  // Scalar base raised to a tensor exponent: 3^2 == 3 * 3 element-wise.
  ASSERT_TRUE(allClose(
      fl::power(3, base), fl::full(shapeRef.shape(), 3 * 3, fl::dtype::f32)));
}

TEST(TensorBaseTest, floor) {
Expand Down

0 comments on commit b9e1d7f

Please sign in to comment.