Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Builder #68969

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
112 changes: 68 additions & 44 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,60 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {

namespace mlir {

//===----------------------------------------------------------------------===//
// CopyOnWriteArrayRef<T>
//===----------------------------------------------------------------------===//

// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
// modification. This is for use in the mlir::<Type>::Builders.
template <typename T>
class CopyOnWriteArrayRef {
MacDue marked this conversation as resolved.
Show resolved Hide resolved
public:
CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};

CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
nonOwning = array;
owningStorage = {};
return *this;
}

void insert(size_t index, T value) {
SmallVector<T> &vector = ensureCopy();
vector.insert(vector.begin() + index, value);
}

void erase(size_t index) {
SmallVector<T> &vector = ensureCopy();
vector.erase(vector.begin() + index);
}

void set(size_t index, T value) { ensureCopy()[index] = value; }

size_t size() const { return ArrayRef<T>(*this).size(); }

bool empty() const { return ArrayRef<T>(*this).empty(); }

operator ArrayRef<T>() const {
return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
}

private:
SmallVector<T> &ensureCopy() {
// Empty non-owning storage signals the array has been copied to the owning
// storage (or both are empty). Note: `nonOwning` should never reference
// `owningStorage`. This can lead to dangling references if the
// CopyOnWriteArrayRef<T> is copied.
if (!nonOwning.empty()) {
owningStorage = SmallVector<T>(nonOwning);
nonOwning = {};
}
return owningStorage;
}

ArrayRef<T> nonOwning;
SmallVector<T> owningStorage;
};

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -274,20 +328,14 @@ class RankedTensorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.erase(storage.begin() + pos);
shape = {storage.data(), storage.size()};
shape.erase(pos);
return *this;
}

/// Insert a val into shape @pos.
Builder &insertDim(int64_t val, unsigned pos) {
assert(pos <= shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.insert(storage.begin() + pos, val);
shape = {storage.data(), storage.size()};
shape.insert(pos, val);
return *this;
}

Expand All @@ -296,9 +344,7 @@ class RankedTensorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
CopyOnWriteArrayRef<int64_t> shape;
Type elementType;
Attribute encoding;
};
Expand All @@ -313,27 +359,18 @@ class VectorType::Builder {
public:
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()),
: elementType(other.getElementType()), shape(other.getShape()),
scalableDims(other.getScalableDims()) {}

/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
: shape(shape), elementType(elementType) {
if (scalableDims.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
this->scalableDims = scalableDims;
}
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}

Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
if (newIsScalableDim.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
scalableDims = newIsScalableDim;

shape = newShape;
scalableDims = newIsScalableDim;
return *this;
}

Expand All @@ -345,25 +382,16 @@ class VectorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
if (storageScalableDims.empty())
storageScalableDims.append(scalableDims.begin(), scalableDims.end());
storage.erase(storage.begin() + pos);
storageScalableDims.erase(storageScalableDims.begin() + pos);
shape = {storage.data(), storage.size()};
scalableDims =
ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
shape.erase(pos);
if (!scalableDims.empty())
scalableDims.erase(pos);
return *this;
}

/// Set a dim in shape @pos to val.
Builder &setDim(unsigned pos, int64_t val) {
if (storage.empty())
storage.append(shape.begin(), shape.end());
assert(pos < storage.size() && "overflow");
storage[pos] = val;
shape = {storage.data(), storage.size()};
assert(pos < shape.size() && "overflow");
shape.set(pos, val);
return *this;
}

Expand All @@ -372,13 +400,9 @@ class VectorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
ArrayRef<bool> scalableDims;
// Owning scalableDims data for copy-on-write operations.
SmallVector<bool> storageScalableDims;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
Expand Down
95 changes: 95 additions & 0 deletions mlir/unittests/IR/ShapedTypeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,99 @@ TEST(ShapedTypeTest, CloneVector) {
VectorType::get(vectorNewShape, vectorNewType));
}

TEST(ShapedTypeTest, VectorTypeBuilder) {
MLIRContext context;
Type f32 = FloatType::getF32(&context);

SmallVector<int64_t> shape{2, 4, 8, 9, 1};
SmallVector<bool> scalableDims{true, false, true, false, false};
VectorType vectorType = VectorType::get(shape, f32, scalableDims);

{
// Drop some dims.
VectorType dropFrontTwoDims =
VectorType::Builder(vectorType).dropDim(0).dropDim(0);
ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
dropFrontTwoDims.getScalableDims());
}

{
// Set some dims.
VectorType setTwoDims =
VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
}

{
// Test for bug from:
// https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
// Constructs a temporary builder, modifies it, copies it to `builder`.
// This used to lead to a use-after-free. Running under sanitizers will
// catch any issues.
VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
VectorType newVectorType = VectorType(builder);
ASSERT_EQ(newVectorType.getDimSize(0), 16);
}

{
// Make builder from scratch (without scalable dims) -- this use to lead to
// a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
// Running under sanitizers will catch any issues.
SmallVector<int64_t> shape{1, 2, 3, 4};
VectorType::Builder builder(shape, f32);
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
}

{
// Set vector shape (without scalable dims) -- this use to lead to
// a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
// Running under sanitizers will catch any issues.
VectorType::Builder builder(vectorType);
SmallVector<int64_t> newShape{2, 2};
builder.setShape(newShape);
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
}
}

TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
MLIRContext context;
Type f32 = FloatType::getF32(&context);

SmallVector<int64_t> shape{2, 4, 8, 16, 32};
RankedTensorType tensorType = RankedTensorType::get(shape, f32);

{
// Drop some dims.
RankedTensorType dropFrontTwoDims =
RankedTensorType::Builder(tensorType).dropDim(0).dropDim(0);
ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
ASSERT_EQ(tensorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
}

{
// Insert some dims.
RankedTensorType insertTwoDims =
RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
ASSERT_EQ(insertTwoDims.getShape(),
ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
}

{
// Test for bug from:
// https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
// Constructs a temporary builder, modifies it, copies it to `builder`.
// This used to lead to a use-after-free. Running under sanitizers will
// catch any issues.
RankedTensorType::Builder builder =
RankedTensorType::Builder(tensorType).dropDim(0);
RankedTensorType newTensorType = RankedTensorType(builder);
ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
}
}

} // namespace