87 changes: 85 additions & 2 deletions mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,14 @@ class OperandStorage final
/// 'values'.
void setOperands(Operation *owner, ValueRange values);

/// Erase an operand held by the storage.
void eraseOperand(unsigned index);
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void setOperands(Operation *owner, unsigned start, unsigned length,
ValueRange operands);

/// Erase the operands held by the storage within the given range.
void eraseOperands(unsigned start, unsigned length);

/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() {
Expand Down Expand Up @@ -653,6 +659,69 @@ class OperandRange final : public llvm::detail::indexed_accessor_range_base<
friend RangeBaseT;
};

//===----------------------------------------------------------------------===//
// MutableOperandRange

/// This class provides a mutable adaptor for a range of operands. It allows for
/// setting, inserting, and erasing operands from the given range.
class MutableOperandRange {
public:
/// A pair of a named attribute corresponding to an operand segment attribute,
/// and the index within that attribute. The attribute should correspond to an
/// i32 DenseElementsAttr.
using OperandSegment = std::pair<unsigned, NamedAttribute>;

/// Construct a new mutable range from the given operand, operand start index,
/// and range length. `operandSegments` is an optional set of operand segments
/// to be updated when mutating the operand list.
MutableOperandRange(Operation *owner, unsigned start, unsigned length,
ArrayRef<OperandSegment> operandSegments = llvm::None);
MutableOperandRange(Operation *owner);

/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment = llvm::None);

/// Append the given values to the range.
void append(ValueRange values);

/// Assign this range to the given values.
void assign(ValueRange values);

/// Assign the range to the given value.
void assign(Value value);

/// Erase the operands within the given sub-range.
void erase(unsigned subStart, unsigned subLen = 1);

/// Clear this range and erase all of the operands.
void clear();

/// Returns the current size of the range.
unsigned size() const { return length; }

/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;

/// Returns the owning operation.
Operation *getOwner() const { return owner; }

private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);

/// The owning operation of this range.
Operation *owner;

/// The start index of the operand range within the owner operand list, and
/// the length starting from `start`.
unsigned start, length;

/// Optional set of operand segments that should be updated when mutating the
/// length of this range.
SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
};

//===----------------------------------------------------------------------===//
// ResultRange

Expand Down Expand Up @@ -752,6 +821,20 @@ class ValueRange final
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};

//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//

/// This class provides utilities for computing if two operations are
/// equivalent.
struct OperationEquivalence {
/// Compute a hash for the given operation.
static llvm::hash_code computeHash(Operation *op);

/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs);
};
} // end namespace mlir

namespace llvm {
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/IR/UseDefLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ template <typename DerivedT, typename IRValueTy> class IROperand {
other.back = nullptr;
nextUse = nullptr;
back = nullptr;
insertIntoCurrent();
if (value)
insertIntoCurrent();
return *this;
}

Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ class BranchOpInterface;
//===----------------------------------------------------------------------===//

namespace detail {
/// Erase an operand from a branch operation that is used as a successor
/// operand. `operandIndex` is the operand within `operands` to be erased.
void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
Operation *op);

/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if `operandIndex` is within the range of `operands`, or None if
/// `operandIndex` isn't a successor operand index.
Expand Down
26 changes: 11 additions & 15 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,25 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
}];
let methods = [
InterfaceMethod<[{
Returns a set of values that correspond to the arguments to the
Returns a mutable range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
"Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
"Optional<MutableOperandRange>", "getMutableSuccessorOperands",
(ins "unsigned":$index)
>,
InterfaceMethod<[{
Return true if this operation can erase an operand to a successor block.
}],
"bool", "canEraseSuccessorOperand"
>,
InterfaceMethod<[{
Erase the operand at `operandIndex` from the `index`-th successor. This
should only be called if `canEraseSuccessorOperand` returns true.
Returns a range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
"void", "eraseSuccessorOperand",
(ins "unsigned":$index, "unsigned":$operandIndex), [{}],
/*defaultImplementation=*/[{
"Optional<OperandRange>", "getSuccessorOperands",
(ins "unsigned":$index), [{}], [{
ConcreteOp *op = static_cast<ConcreteOp *>(this);
Optional<OperandRange> operands = op->getSuccessorOperands(index);
assert(operands && "unable to query operands for successor");
detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
auto operands = op->getMutableSuccessorOperands(index);
return operands ? Optional<OperandRange>(*operands) : llvm::None;
}]
>,
InterfaceMethod<[{
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/TableGen/OpTrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <vector>

namespace llvm {
class Init;
Expand Down Expand Up @@ -105,6 +106,10 @@ class InterfaceOpTrait : public OpTrait {

// Whether the declaration of methods for this trait should be emitted.
bool shouldDeclareMethods() const;

// Returns the methods that should always be declared if this interface is
// emitting declarations.
std::vector<StringRef> getAlwaysDeclaredMethods() const;
};

} // end namespace tblgen
Expand Down
21 changes: 9 additions & 12 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,22 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
// LLVM::BrOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
BrOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return destOperandsMutable();
}

bool BrOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
CondBrOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? trueDestOperands() : falseDestOperands();
return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
}

bool CondBrOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -257,13 +255,12 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//

Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
InvokeOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? normalDestOperands() : unwindDestOperands();
return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
}

bool InvokeOp::canEraseSuccessorOperand() { return true; }

static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
Expand Down
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,26 +987,23 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
// spv.BranchOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return targetOperandsMutable();
}

bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//

Optional<OperandRange>
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
return index == kTrueIndex ? getTrueBlockArguments()
: getFalseBlockArguments();
return index == kTrueIndex ? trueTargetOperandsMutable()
: falseTargetOperandsMutable();
}

bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }

static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
OperationState &state) {
auto &builder = parser.getBuilder();
Expand Down
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,13 +677,12 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}

Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return destOperandsMutable();
}

bool BranchOp::canEraseSuccessorOperand() { return true; }

Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1021,13 +1020,13 @@ void CondBranchOp::getCanonicalizationPatterns(
SimplifyCondBranchIdenticalSuccessors>(context);
}

Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == trueIndex ? getTrueOperands() : getFalseOperands();
return index == trueIndex ? trueDestOperandsMutable()
: falseDestOperandsMutable();
}

bool CondBranchOp::canEraseSuccessorOperand() { return true; }

Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
return condAttr.getValue() ? trueDest() : falseDest();
Expand Down
28 changes: 24 additions & 4 deletions mlir/lib/IR/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,26 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {

/// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const {
Optional<NamedAttribute> attr = getNamed(name);
return attr ? attr->second : nullptr;
}
Attribute DictionaryAttr::get(Identifier name) const {
Optional<NamedAttribute> attr = getNamed(name);
return attr ? attr->second : nullptr;
}

/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
ArrayRef<NamedAttribute> values = getValue();
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
return it != values.end() && it->first == name ? it->second : Attribute();
return it != values.end() && it->first == name ? *it
: Optional<NamedAttribute>();
}
Attribute DictionaryAttr::get(Identifier name) const {
Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
for (auto elt : getValue())
if (elt.first == name)
return elt.second;
return nullptr;
return elt;
return llvm::None;
}

DictionaryAttr::iterator DictionaryAttr::begin() const {
Expand Down Expand Up @@ -1191,6 +1202,15 @@ Attribute MutableDictionaryAttr::get(Identifier name) const {
return attrs ? attrs.get(name) : nullptr;
}

/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
}
Optional<NamedAttribute>
MutableDictionaryAttr::getNamed(Identifier name) const {
return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
}

/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void MutableDictionaryAttr::set(Identifier name, Attribute value) {
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,25 @@ void Operation::setOperands(ValueRange operands) {
assert(operands.empty() && "setting operands without an operand storage");
}

/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void Operation::setOperands(unsigned start, unsigned length,
ValueRange operands) {
assert((start + length) <= getNumOperands() &&
"invalid operand range specified");
if (LLVM_LIKELY(hasOperandStorage))
return getOperandStorage().setOperands(this, start, length, operands);
assert(operands.empty() && "setting operands without an operand storage");
}

/// Insert the given operands into the operand list at the given 'index'.
void Operation::insertOperands(unsigned index, ValueRange operands) {
if (LLVM_LIKELY(hasOperandStorage))
return setOperands(index, /*length=*/0, operands);
assert(operands.empty() && "inserting operands without an operand storage");
}

//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
Expand Down
229 changes: 215 additions & 14 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -89,6 +91,55 @@ void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
storageOperands[i].set(values[i]);
}

/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
unsigned length, ValueRange operands) {
// If the new size is the same, we can update inplace.
unsigned newSize = operands.size();
if (newSize == length) {
MutableArrayRef<OpOperand> storageOperands = getOperands();
for (unsigned i = 0, e = length; i != e; ++i)
storageOperands[start + i].set(operands[i]);
return;
}
// If the new size is greater, remove the extra operands and set the rest
// inplace.
if (newSize < length) {
eraseOperands(start + operands.size(), length - newSize);
setOperands(owner, start, newSize, operands);
return;
}
// Otherwise, the new size is greater so we need to grow the storage.
auto storageOperands = resize(owner, size() + (newSize - length));

// Shift operands to the right to make space for the new operands.
unsigned rotateSize = storageOperands.size() - (start + length);
auto rbegin = storageOperands.rbegin();
std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);

// Update the operands inplace.
for (unsigned i = 0, e = operands.size(); i != e; ++i)
storageOperands[start + i].set(operands[i]);
}

/// Erase an operand held by the storage.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
TrailingOperandStorage &storage = getStorage();
MutableArrayRef<OpOperand> operands = storage.getOperands();
assert((start + length) <= operands.size());
storage.numOperands -= length;

// Shift all operands down if the operand to remove is not at the end.
if (start != storage.numOperands) {
auto indexIt = std::next(operands.begin(), start);
std::rotate(indexIt, std::next(indexIt, length), operands.end());
}
for (unsigned i = 0; i != length; ++i)
operands[storage.numOperands + i].~OpOperand();
}

/// Resize the storage to the given size. Returns the array containing the new
/// operands.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
Expand Down Expand Up @@ -149,20 +200,6 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
return newOperands;
}

/// Erase an operand held by the storage.
void detail::OperandStorage::eraseOperand(unsigned index) {
assert(index < size());
TrailingOperandStorage &storage = getStorage();
MutableArrayRef<OpOperand> operands = storage.getOperands();
--storage.numOperands;

// Shift all operands down by 1 if the operand to remove is not at the end.
auto indexIt = std::next(operands.begin(), index);
if (index != storage.numOperands)
std::rotate(indexIt, std::next(indexIt), operands.end());
operands[storage.numOperands].~OpOperand();
}

//===----------------------------------------------------------------------===//
// ResultStorage
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -235,6 +272,95 @@ unsigned OperandRange::getBeginOperandIndex() const {
return base->getOperandNumber();
}

//===----------------------------------------------------------------------===//
// MutableOperandRange

/// Construct a new mutable range from the given operand, operand start index,
/// and range length.
MutableOperandRange::MutableOperandRange(
Operation *owner, unsigned start, unsigned length,
ArrayRef<OperandSegment> operandSegments)
: owner(owner), start(start), length(length),
operandSegments(operandSegments.begin(), operandSegments.end()) {
assert((start + length) <= owner->getNumOperands() && "invalid range");
}
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}

/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment) {
assert((subStart + subLen) <= length && "invalid sub-range");
MutableOperandRange subSlice(owner, start + subStart, subLen,
operandSegments);
if (segment)
subSlice.operandSegments.push_back(*segment);
return subSlice;
}

/// Append the given values to the range.
void MutableOperandRange::append(ValueRange values) {
if (values.empty())
return;
owner->insertOperands(start + length, values);
updateLength(length + values.size());
}

/// Assign this range to the given values.
void MutableOperandRange::assign(ValueRange values) {
owner->setOperands(start, length, values);
if (length != values.size())
updateLength(/*newLength=*/values.size());
}

/// Assign the range to the given value.
void MutableOperandRange::assign(Value value) {
if (length == 1) {
owner->setOperand(start, value);
} else {
owner->setOperands(start, length, value);
updateLength(/*newLength=*/1);
}
}

/// Erase the operands within the given sub-range.
void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
assert((subStart + subLen) <= length && "invalid sub-range");
if (length == 0)
return;
owner->eraseOperands(start + subStart, subLen);
updateLength(length - subLen);
}

/// Clear this range and erase all of the operands.
void MutableOperandRange::clear() {
if (length != 0) {
owner->eraseOperands(start, length);
updateLength(/*newLength=*/0);
}
}

/// Allow implicit conversion to an OperandRange.
MutableOperandRange::operator OperandRange() const {
return owner->getOperands().slice(start, length);
}

/// Update the length of this range to the one provided.
void MutableOperandRange::updateLength(unsigned newLength) {
int32_t diff = int32_t(newLength) - int32_t(length);
length = newLength;

// Update any of the provided segment attributes.
for (OperandSegment &segment : operandSegments) {
auto attr = segment.second.second.cast<DenseIntElementsAttr>();
SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
segments[segment.first] += diff;
segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
owner->setAttr(segment.second.first, segment.second.second);
}
}

//===----------------------------------------------------------------------===//
// ResultRange

Expand Down Expand Up @@ -281,3 +407,78 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
return operation->getResult(owner.startIndex + index);
}

//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//

llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
// Hash operations based upon their:
// - Operation Name
// - Attributes
llvm::hash_code hash = llvm::hash_combine(
op->getName(), op->getMutableAttrDict().getDictionary());

// - Result Types
ArrayRef<Type> resultTypes = op->getResultTypes();
switch (resultTypes.size()) {
case 0:
// We don't need to add anything to the hash.
break;
case 1:
// Add in the result type.
hash = llvm::hash_combine(hash, resultTypes.front());
break;
default:
// Use the type buffer as the hash, as we can guarantee it is the same for
// any given range of result types. This takes advantage of the fact the
// result types >1 are stored in a TupleType and uniqued.
hash = llvm::hash_combine(hash, resultTypes.data());
break;
}

// - Operands
// TODO: Allow commutative operations to have different ordering.
return llvm::hash_combine(
hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
}

bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
if (lhs == rhs)
return true;

// Compare the operation name.
if (lhs->getName() != rhs->getName())
return false;
// Check operand counts.
if (lhs->getNumOperands() != rhs->getNumOperands())
return false;
// Compare attributes.
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
return false;
// Compare result types.
ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
if (lhsResultTypes.size() != rhsResultTypes.size())
return false;
switch (lhsResultTypes.size()) {
case 0:
break;
case 1:
// Compare the single result type.
if (lhsResultTypes.front() != rhsResultTypes.front())
return false;
break;
default:
// Use the type buffer for the comparison, as we can guarantee it is the
// same for any given range of result types. This takes advantage of the
// fact the result types >1 are stored in a TupleType and uniqued.
if (lhsResultTypes.data() != rhsResultTypes.data())
return false;
break;
}
// Compare operands.
// TODO: Allow commutative operations to have different ordering.
return std::equal(lhs->operand_begin(), lhs->operand_end(),
rhs->operand_begin());
}
33 changes: 0 additions & 33 deletions mlir/lib/Interfaces/ControlFlowInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,6 @@ using namespace mlir;
// BranchOpInterface
//===----------------------------------------------------------------------===//

/// Erase an operand from a branch operation that is used as a successor
/// operand. 'operandIndex' is the operand within 'operands' to be erased.
void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
unsigned operandIndex,
Operation *op) {
assert(operandIndex < operands.size() &&
"invalid index for successor operands");

// Erase the operand from the operation.
size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
op->eraseOperand(fullOperandIndex);

// If this operation has an OperandSegmentSizeAttr, keep it up to date.
auto operandSegmentAttr =
op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
if (!operandSegmentAttr)
return;

// Find the segment containing the full operand index and decrement it.
// TODO: This seems like a general utility that could be added somewhere.
SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
unsigned currentSize = 0;
for (unsigned i = 0, e = values.size(); i != e; ++i) {
currentSize += values[i];
if (fullOperandIndex < currentSize) {
--values[i];
break;
}
}
op->setAttr("operand_segment_sizes",
DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
}

/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/OpTrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ llvm::StringRef InterfaceOpTrait::getTrait() const {
bool InterfaceOpTrait::shouldDeclareMethods() const {
return def->isSubClassOf("DeclareOpInterfaceMethods");
}

std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const {
return def->getValueAsListOfStrings("alwaysOverriddenMethods");
}
32 changes: 3 additions & 29 deletions mlir/lib/Transforms/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,9 @@
using namespace mlir;

namespace {
// TODO(riverriddle) Handle commutative operations.
struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
static unsigned getHashValue(const Operation *opC) {
auto *op = const_cast<Operation *>(opC);
// Hash the operations based upon their:
// - Operation Name
// - Attributes
// - Result Types
// - Operands
return llvm::hash_combine(
op->getName(), op->getMutableAttrDict().getDictionary(),
op->getResultTypes(),
llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
}
static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
auto *lhs = const_cast<Operation *>(lhsC);
Expand All @@ -48,24 +38,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;

// Compare the operation name.
if (lhs->getName() != rhs->getName())
return false;
// Check operand and result type counts.
if (lhs->getNumOperands() != rhs->getNumOperands() ||
lhs->getNumResults() != rhs->getNumResults())
return false;
// Compare attributes.
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
return false;
// Compare operands.
if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
rhs->operand_begin()))
return false;
// Compare result types.
return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
rhs->result_type_begin());
return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
const_cast<Operation *>(rhsC));
}
};
} // end anonymous namespace
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {

// Check to see if we can reason about the successor operands and mutate them.
BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
if (!branchInterface) {
for (Block *successor : op->getSuccessors())
for (BlockArgument arg : successor->getArguments())
liveMap.setProvedLive(arg);
Expand All @@ -219,7 +219,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// If we can't reason about the operands to a successor, conservatively mark
// all arguments as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
if (!branchInterface.getSuccessorOperands(i))
if (!branchInterface.getMutableSuccessorOperands(i))
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
liveMap.setProvedLive(arg);
}
Expand Down Expand Up @@ -278,7 +278,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
Optional<MutableOperandRange> succOperands =
branchOp.getMutableSuccessorOperands(succ);
if (!succOperands)
continue;
Block *successor = terminator->getSuccessor(succ);
Expand All @@ -288,7 +289,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
branchOp.eraseSuccessorOperand(succ, arg);
succOperands->erase(arg);
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,12 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
// TestBranchOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return targetOperandsMutable();
}

bool TestBranchOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/mlir-tblgen/op-decl.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: Operation::operand_range getODSOperands(unsigned index);
// CHECK: Value a();
// CHECK: Operation::operand_range b();
// CHECK: ::mlir::MutableOperandRange aMutable();
// CHECK: ::mlir::MutableOperandRange bMutable();
// CHECK: Operation::result_range getODSResults(unsigned index);
// CHECK: Value r();
// CHECK: Region &someRegion();
Expand Down Expand Up @@ -119,6 +121,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {

// CHECK-LABEL: NS::EOp declarations
// CHECK: Value a();
// CHECK: ::mlir::MutableOperandRange aMutable();
// CHECK: Value b();
// CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)

Expand Down
20 changes: 20 additions & 0 deletions mlir/test/mlir-tblgen/op-interface.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL --dump-input-on-failure
// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=OP_DECL --dump-input-on-failure

include "mlir/IR/OpBase.td"

Expand All @@ -12,6 +13,14 @@ def TestOpInterface : OpInterface<"TestOpInterface"> {
/*methodName=*/"foo",
/*args=*/(ins "int":$input)
>,
InterfaceMethod<
/*desc=*/[{some function comment}],
/*retTy=*/"int",
/*methodName=*/"default_foo",
/*args=*/(ins "int":$input),
/*body=*/[{}],
/*defaultBody=*/[{ return 0; }]
>,
];
}

Expand All @@ -27,8 +36,19 @@ def OpInterfaceOp : Op<TestDialect, "op_interface_op", [TestOpInterface]>;
def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
[DeclareOpInterfaceMethods<TestOpInterface>]>;

def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
[DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;

// DECL-LABEL: TestOpInterfaceInterfaceTraits
// DECL: class TestOpInterface : public OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
// DECL: int foo(int input);

// DECL-NOT: TestOpInterface

// OP_DECL-LABEL: class DeclareMethodsOp : public
// OP_DECL: int foo(int input);
// OP_DECL-NOT: int default_foo(int input);

// OP_DECL-LABEL: class DeclareMethodsWithDefaultOp : public
// OP_DECL: int foo(int input);
// OP_DECL: int default_foo(int input);
166 changes: 105 additions & 61 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,23 @@ static const char *const builderOpState = "odsState";
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: The begin iterator of the actual values.
// {5}: "operand" or "result".
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;
// Calculate how many dynamic values a static variadic {5} corresponds to.
// This assumes all static variadic {5}s have the same dynamic value count.
// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {5} (variadic or not) as size 1. So here for each previous static variadic
// {5}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {5} starts.
int offset = index + (variadicSize - 1) * prevVariadicCount;
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{std::next({4}, offset), std::next({4}, offset + size)};
return {{start, size};
)";

// The logic to calculate the actual value range for a declared operand/result
Expand All @@ -72,14 +70,23 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
// {1}: The begin iterator of the actual values.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
return {{std::next({1}, start), std::next({1}, end)};
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
return {{start, size};
)";

// The logic to build a range of either operand or result values.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";

static const char *const opCommentHeader = R"(
Expand Down Expand Up @@ -177,6 +184,9 @@ class OpEmitter {
// Generates getters for named operands.
void genNamedOperandGetters();

// Generates setters for named operands.
void genNamedOperandSetters();

// Generates getters for named results.
void genNamedResultGetters();

Expand Down Expand Up @@ -310,6 +320,7 @@ OpEmitter::OpEmitter(const Operator &op)
genOpAsmInterface();
genOpNameGetter();
genNamedOperandGetters();
genNamedOperandSetters();
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
Expand Down Expand Up @@ -478,6 +489,37 @@ void OpEmitter::genAttrSetters() {
}
}

// Generates the code to compute the start and end index of an operand or result
// range.
template <typename RangeT>
static void
generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
int numVariadic, int numNonVariadic,
StringRef rangeSizeCall, bool hasAttrSegmentSize,
StringRef segmentSizeAttr, RangeT &&odsValues) {
auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
"unsigned index");

if (numVariadic == 0) {
method.body() << " return {index, 1};\n";
} else if (hasAttrSegmentSize) {
method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
segmentSizeAttr);
} else {
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(llvm::size(odsValues));
for (auto &it : odsValues)
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNonVariadic, numVariadic, rangeSizeCall,
"operand");
}
}

// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value ` and each
Expand Down Expand Up @@ -519,32 +561,16 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
"'SameVariadicOperandSize' traits");
}

// First emit a "sink" getter method upon which we layer all nicer named
// First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
generateValueRangeStartAndEnd(
opClass, "getODSOperandIndexAndLength", numVariadicOperands,
numNormalOperands, rangeSizeCall, attrSizedOperands,
"operand_segment_sizes", const_cast<Operator &>(op).getOperands());

if (numVariadicOperands == 0) {
// We still need to match the return type, which is a range.
m.body() << " return {std::next(" << rangeBeginCall
<< ", index), std::next(" << rangeBeginCall << ", index + 1)};";
} else if (attrSizedOperands) {
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
"operand_segment_sizes", rangeBeginCall);
} else {
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numOperands);
for (int i = 0; i < numOperands; ++i)
isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");

m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNormalOperands, numVariadicOperands, rangeSizeCall,
rangeBeginCall, "operand");
}
auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
"getODSOperandIndexAndLength(index)");

// Then we emit nicer named getter methods by redirecting to the "sink" getter
// method.
Expand Down Expand Up @@ -579,6 +605,26 @@ void OpEmitter::genNamedOperandGetters() {
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}

void OpEmitter::genNamedOperandSetters() {
auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments");
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
auto &m = opClass.newMethod("::mlir::MutableOperandRange",
(operand.name + "Mutable").str());
auto &body = m.body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " return ::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands)
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
<< "u, *getOperation()->getMutableAttrDict().getNamed("
"\"operand_segment_sizes\"))";
body << ");\n";
}
}

void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariableLengthResults();
Expand Down Expand Up @@ -607,29 +653,14 @@ void OpEmitter::genNamedResultGetters() {
"'SameVariadicResultSize' traits");
}

generateValueRangeStartAndEnd(
opClass, "getODSResultIndexAndLength", numVariadicResults,
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
"result_segment_sizes", op.getResults());
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
"unsigned index");

if (numVariadicResults == 0) {
m.body() << " return {std::next(getOperation()->result_begin(), index), "
"std::next(getOperation()->result_begin(), index + 1)};";
} else if (attrSizedResults) {
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
"result_segment_sizes",
"getOperation()->result_begin()");
} else {
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numResults);
for (int i = 0; i < numResults; ++i)
isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");

m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
numNormalResults, numVariadicResults,
"getOperation()->getNumResults()",
"getOperation()->result_begin()", "result");
}
m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
"getODSResultIndexAndLength(index)");

for (int i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
Expand Down Expand Up @@ -1251,10 +1282,23 @@ void OpEmitter::genOpInterfaceMethods() {
if (!opTrait || !opTrait->shouldDeclareMethods())
continue;
auto interface = opTrait->getOpInterface();
for (auto method : interface.getMethods()) {
// Don't declare if the method has a body or a default implementation.
if (method.getBody() || method.getDefaultImplementation())

// Get the set of methods that should always be declared.
auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
llvm::StringSet<> alwaysDeclaredMethods;
alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
alwaysDeclaredMethodsVec.end());

for (const OpInterfaceMethod &method : interface.getMethods()) {
// Don't declare if the method has a body.
if (method.getBody())
continue;
// Don't declare if the method has a default implementation and the op
// didn't request that it always be declared.
if (method.getDefaultImplementation() &&
!alwaysDeclaredMethods.count(method.getName()))
continue;

std::string args;
llvm::raw_string_ostream os(args);
interleaveComma(method.getArguments(), os,
Expand Down
8 changes: 6 additions & 2 deletions mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("struct Trait : public OpInterface<{0},"
<< llvm::formatv("struct {0}Trait : public OpInterface<{0},"
" detail::{1}>::Trait<ConcreteOp> {{\n",
interfaceName, interfaceTraitsName);

Expand All @@ -171,13 +171,17 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
tblgen::FmtContext traitCtx;
traitCtx.withOp("op");
if (auto verify = interface.getVerify()) {
os << " static LogicalResult verifyTrait(Operation* op) {\n"
os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << extraTraitDecls << "\n";

os << " };\n";

// Emit a utility using directive for the trait class.
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("using Trait = {0}Trait<ConcreteOp>;\n", interfaceName);
}

static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
Expand Down
77 changes: 75 additions & 2 deletions mlir/unittests/IR/OperationSupportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(OperandStorageTest, NonResizable) {
Value operand = useOp->getResult(0);

// Create a non-resizable operation with one operand.
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
Operation *user = createOp(&context, operand);

// The same number of operands is okay.
user->setOperands(operand);
Expand All @@ -57,7 +57,7 @@ TEST(OperandStorageTest, Resizable) {
Value operand = useOp->getResult(0);

// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand, builder.getIntegerType(16));
Operation *user = createOp(&context, operand);

// The same number of operands is okay.
user->setOperands(operand);
Expand All @@ -76,4 +76,77 @@ TEST(OperandStorageTest, Resizable) {
useOp->destroy();
}

TEST(OperandStorageTest, RangeReplace) {
MLIRContext context;
Builder builder(&context);

Operation *useOp =
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
Value operand = useOp->getResult(0);

// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand);

// Check setting with the same number of operands.
user->setOperands(/*start=*/0, /*length=*/1, operand);
EXPECT_EQ(user->getNumOperands(), 1u);

// Check setting with more operands.
user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand});
EXPECT_EQ(user->getNumOperands(), 3u);

// Check setting with less operands.
user->setOperands(/*start=*/1, /*length=*/2, {operand});
EXPECT_EQ(user->getNumOperands(), 2u);

// Check inserting without replacing operands.
user->setOperands(/*start=*/2, /*length=*/0, {operand});
EXPECT_EQ(user->getNumOperands(), 3u);

// Check erasing operands.
user->setOperands(/*start=*/0, /*length=*/3, {});
EXPECT_EQ(user->getNumOperands(), 0u);

// Destroy the operations.
user->destroy();
useOp->destroy();
}

TEST(OperandStorageTest, MutableRange) {
MLIRContext context;
Builder builder(&context);

Operation *useOp =
createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
Value operand = useOp->getResult(0);

// Create a resizable operation with one operand.
Operation *user = createOp(&context, operand);

// Check setting with the same number of operands.
MutableOperandRange mutableOperands(user);
mutableOperands.assign(operand);
EXPECT_EQ(mutableOperands.size(), 1u);
EXPECT_EQ(user->getNumOperands(), 1u);

// Check setting with more operands.
mutableOperands.assign({operand, operand, operand});
EXPECT_EQ(mutableOperands.size(), 3u);
EXPECT_EQ(user->getNumOperands(), 3u);

// Check with inserting a new operand.
mutableOperands.append({operand, operand});
EXPECT_EQ(mutableOperands.size(), 5u);
EXPECT_EQ(user->getNumOperands(), 5u);

// Check erasing operands.
mutableOperands.clear();
EXPECT_EQ(mutableOperands.size(), 0u);
EXPECT_EQ(user->getNumOperands(), 0u);

// Destroy the operations.
user->destroy();
useOp->destroy();
}

} // end namespace