Skip to content

Commit

Permalink
Add shape indexing protection (flashlight#810)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#810

See title. We control the overload, so check bounds in `Shape::operator[]`. Ended up revealing several subtle stack overwriting bugs that caused some nondeterministic behavior which are now resolved upstream.

Differential Revision: D33143590

fbshipit-source-id: afccbdbf97ce4fe1687003c6e612c2a6d89142b1
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Jan 11, 2022
1 parent aa87044 commit f3d484d
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 10 deletions.
21 changes: 15 additions & 6 deletions flashlight/fl/tensor/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <algorithm>
#include <limits>
#include <numeric>
#include <sstream>
#include <stdexcept>

namespace fl {
Expand All @@ -19,31 +20,39 @@ Shape::Shape(std::initializer_list<Dim> d) : Shape(std::vector<Dim>(d)) {}

const Dim kEmptyShapeNumberOfElements = 1;

void Shape::checkDimsOrThrow(const size_t dim) const {
if (dim > ndim() - 1) {
std::stringstream ss;
ss << "Shape index " << std::to_string(dim)
<< " out of bounds for shape with " << std::to_string(dims_.size())
<< " dimensions.";
throw std::invalid_argument(ss.str());
}
}

Dim Shape::elements() const {
if (dims_.size() == 0) {
return kEmptyShapeNumberOfElements;
}
return std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies<Dim>());
}

size_t Shape::ndim() const {
int Shape::ndim() const {
return dims_.size();
}

Dim Shape::dim(const size_t dim) const {
if (dim >= dims_.size()) {
throw std::invalid_argument(
"fl::Shape::dim - passed dimension is larger than "
"the number of dimensions in the shape");
}
checkDimsOrThrow(dim);
return dims_[dim];
}

Dim& Shape::operator[](const size_t dim) {
checkDimsOrThrow(dim);
return dims_[dim];
}

const Dim& Shape::operator[](const size_t dim) const {
checkDimsOrThrow(dim);
return dims_[dim];
}

Expand Down
8 changes: 7 additions & 1 deletion flashlight/fl/tensor/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class Shape {
// {} is a scalar shape.
std::vector<Dim> dims_;

/**
* Check if a dimension is valid (i.e. in bounds) given the current size of
* the shape. If not valid, throws an exception.
*/
void checkDimsOrThrow(const size_t dim) const;

public:
Shape() = default;
~Shape() = default;
Expand Down Expand Up @@ -74,7 +80,7 @@ class Shape {
/**
* @return Number of dimensions in the shape.
*/
size_t ndim() const;
int ndim() const;

/**
* Get the size of a given dimension in the number of arguments. Throws if the
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Dim Tensor::dim(const size_t dim) const {
return shape().dim(dim);
}

size_t Tensor::ndim() const {
int Tensor::ndim() const {
return shape().ndim();
}

Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class Tensor {
*
* @return the number of dimensions
*/
size_t ndim() const;
int ndim() const;

/**
* Returns true if the tensor has zero elements, else false.
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ Tensor ArrayFireBackend::pad(
endPadding,
detail::flToAfPadType(type)),
/* numDims = */ // TODO: check
std::max(input.ndim(), padWidths.size()));
std::max(input.ndim(), static_cast<int>(padWidths.size())));
}

/************************** Unary Operators ***************************/
Expand Down
9 changes: 9 additions & 0 deletions flashlight/fl/test/tensor/ShapeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ TEST(ShapeTest, Equality) {
ASSERT_EQ(Shape({5, 2, 3, 1}), Shape({5, 2, 3, 1}));
ASSERT_NE(Shape({5, 2, 1, 1}), Shape({5, 2, 1, 4}));
}

TEST(ShapeTest, Indexing) {
auto a = Shape({3, 4, 5, 2});
ASSERT_EQ(a[0], 3);
ASSERT_EQ(a[1], 4);
ASSERT_EQ(a[2], 5);
ASSERT_EQ(a[3], 2);
ASSERT_THROW(a[4], std::invalid_argument);
}

0 comments on commit f3d484d

Please sign in to comment.