Skip to content

Commit

Permalink
Add scalar exponentiation by a tensor (#774)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #774

Raises a double to the power of a tensor element-wise per exponentiation.

Reviewed By: benoitsteiner

Differential Revision: D32902722

fbshipit-source-id: 12518c140529122176f7568d3955dc6f211c41e9
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Dec 9, 2021
1 parent 437f71c commit fce3034
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 7 deletions.
9 changes: 5 additions & 4 deletions flashlight/fl/tensor/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

namespace fl {

range::range(idx idx) : range(0, idx) {}
range::range(const idx& i) : range(0, i) {}

range::range(idx start, idx end) : range(start, end, /* stride */ 1) {}
range::range(const idx& start, const idx& end)
: range(start, end, /* stride */ 1) {}

range::range(idx start, idx end, Dim stride)
range::range(const idx& start, const idx& end, const Dim stride)
: // fl::end decays to int
start_(std::visit([](Dim idx) -> Dim { return idx; }, start)),
start_(std::visit([](const Dim idx) -> Dim { return idx; }, start)),
// fl::end --> -1, else idx as Dim
end_(
std::holds_alternative<fl::end_t>(end)
Expand Down
6 changes: 3 additions & 3 deletions flashlight/fl/tensor/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ class range {
/**
* Construct a range with the indices [0, idx) (i.e. [0, idx - 1])
*/
explicit range(idx idx);
explicit range(const idx& idx);

/**
* Construct a range with the indices [start, end) (i.e. [start, end - 1])
*/
range(idx start, idx end);
range(const idx& start, const idx& end);

/**
* Construct a range with the indices [start, end) (i.e. [start, end - 1])
* with the given stride.
*/
range(idx start, idx end, Dim stride);
range(const idx& start, const idx& end, const Dim stride);

Dim start() const;
Dim end() const;
Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ Tensor power(const Tensor& lhs, const double& rhs) {
return lhs.backend().power(lhs, full(lhs.shape(), rhs));
}

Tensor power(const double& lhs, const Tensor& rhs) {
return rhs.backend().power(full(rhs.shape(), lhs), rhs);
}

/******************************* BLAS ********************************/
Tensor matmul(
const Tensor& lhs,
Expand Down
1 change: 1 addition & 0 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ Tensor maximum(const double& lhs, const Tensor& rhs);
*/
Tensor power(const Tensor& lhs, const Tensor& rhs);
Tensor power(const Tensor& lhs, const double& rhs);
Tensor power(const double& lhs, const Tensor& rhs);

/******************************* BLAS ********************************/

Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,10 @@ TEST(TensorBaseTest, power) {
TEST(TensorBaseTest, powerDouble) {
auto a = fl::full({3, 3}, 2.);
ASSERT_TRUE(allClose(fl::power(a, 3), a * a * a));

auto b = fl::full({3, 3}, 2.);
ASSERT_TRUE(
allClose(fl::power(3, a), fl::full(b.shape(), 3 * 3, fl::dtype::f32)));
}

TEST(TensorBaseTest, floor) {
Expand Down

0 comments on commit fce3034

Please sign in to comment.