56 changes: 28 additions & 28 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
```
}];

let arguments = (ins Optional<PDL_Type>:$type,
let arguments = (ins Optional<PDL_Type>:$valueType,
OptionalAttr<AnyAttr>:$value);
let results = (outs PDL_Attribute:$attr);
let assemblyFormat = "(`:` $type^)? (`=` $value^)? attr-dict-with-keyword";
let assemblyFormat = "(`:` $valueType^)? (`=` $value^)? attr-dict-with-keyword";

let builders = [
OpBuilder<(ins CArg<"Value", "Value()">:$type), [{
Expand Down Expand Up @@ -156,8 +156,8 @@ def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
pdl.erase %root
```
}];
let arguments = (ins PDL_Operation:$operation);
let assemblyFormat = "$operation attr-dict";
let arguments = (ins PDL_Operation:$opValue);
let assemblyFormat = "$opValue attr-dict";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -187,9 +187,9 @@ def PDL_OperandOp
```
}];

let arguments = (ins Optional<PDL_Type>:$type);
let results = (outs PDL_Value:$val);
let assemblyFormat = "(`:` $type^)? attr-dict";
let arguments = (ins Optional<PDL_Type>:$valueType);
let results = (outs PDL_Value:$value);
let assemblyFormat = "(`:` $valueType^)? attr-dict";

let builders = [
OpBuilder<(ins), [{
Expand Down Expand Up @@ -226,9 +226,9 @@ def PDL_OperandsOp
```
}];

let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$type);
let results = (outs PDL_RangeOf<PDL_Value>:$val);
let assemblyFormat = "(`:` $type^)? attr-dict";
let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$valueType);
let results = (outs PDL_RangeOf<PDL_Value>:$value);
let assemblyFormat = "(`:` $valueType^)? attr-dict";

let builders = [
OpBuilder<(ins), [{
Expand Down Expand Up @@ -341,16 +341,16 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
```
}];

let arguments = (ins OptionalAttr<StrAttr>:$name,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
Variadic<PDL_Attribute>:$attributes,
StrArrayAttr:$attributeNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
let arguments = (ins OptionalAttr<StrAttr>:$opName,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
Variadic<PDL_Attribute>:$attributeValues,
StrArrayAttr:$attributeValueNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
let results = (outs PDL_Operation:$op);
let assemblyFormat = [{
($name^)? (`(` $operands^ `:` type($operands) `)`)?
custom<OperationOpAttributes>($attributes, $attributeNames)
(`->` `(` $types^ `:` type($types) `)`)? attr-dict
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
custom<OperationOpAttributes>($attributeValues, $attributeValueNames)
(`->` `(` $typeValues^ `:` type($typeValues) `)`)? attr-dict
}];

let builders = [
Expand Down Expand Up @@ -413,9 +413,9 @@ def PDL_PatternOp : PDL_Op<"pattern", [

let arguments = (ins ConfinedAttr<I16Attr, [IntNonNegative]>:$benefit,
OptionalAttr<SymbolNameAttr>:$sym_name);
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = [{
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body
($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $bodyRegion
}];

let builders = [
Expand Down Expand Up @@ -467,11 +467,11 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
pdl.replace %root with %otherOp
```
}];
let arguments = (ins PDL_Operation:$operation,
let arguments = (ins PDL_Operation:$opValue,
Optional<PDL_Operation>:$replOperation,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
let assemblyFormat = [{
$operation `with` (`(` $replValues^ `:` type($replValues) `)`)?
$opValue `with` (`(` $replValues^ `:` type($replValues) `)`)?
($replOperation^)? attr-dict
}];
let hasVerifier = 1;
Expand Down Expand Up @@ -603,10 +603,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
let arguments = (ins Optional<PDL_Operation>:$root,
OptionalAttr<StrAttr>:$name,
Variadic<PDL_AnyType>:$externalArgs);
let regions = (region AnyRegion:$body);
let regions = (region AnyRegion:$bodyRegion);
let assemblyFormat = [{
($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
($body^)?
($bodyRegion^)?
attr-dict-with-keyword
}];
let hasRegionVerifier = 1;
Expand Down Expand Up @@ -635,9 +635,9 @@ def PDL_TypeOp : PDL_Op<"type"> {
```
}];

let arguments = (ins OptionalAttr<TypeAttr>:$type);
let arguments = (ins OptionalAttr<TypeAttr>:$constantType);
let results = (outs PDL_Type:$result);
let assemblyFormat = "attr-dict (`:` $type^)?";
let assemblyFormat = "attr-dict (`:` $constantType^)?";
let hasVerifier = 1;
}

Expand All @@ -664,9 +664,9 @@ def PDL_TypesOp : PDL_Op<"types"> {
```
}];

let arguments = (ins OptionalAttr<TypeArrayAttr>:$types);
let arguments = (ins OptionalAttr<TypeArrayAttr>:$constantTypes);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "attr-dict (`:` $types^)?";
let assemblyFormat = "attr-dict (`:` $constantTypes^)?";
let hasVerifier = 1;
}

Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,8 @@ struct WmmaElementwiseOpToNVVMLowering
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getOperands()[opIdx], i));
}
Value element =
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
extractedOperands);
Value element = createScalarOp(
rewriter, loc, subgroupMmaElementwiseOp.opType(), extractedOperands);
matrixStruct =
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(

// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.getRegion();
rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
spvModuleRegion.begin());
// The spv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());
Expand Down
44 changes: 23 additions & 21 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,16 +575,17 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {

// Collect the set of operations generated by the rewriter.
SmallVector<StringRef, 4> generatedOps;
for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.name());
for (auto op :
pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.getOpName());
ArrayAttr generatedOpsAttr;
if (!generatedOps.empty())
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);

// Grab the root kind if present.
StringAttr rootKindAttr;
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
if (Optional<StringRef> rootKind = rootOp.name())
if (Optional<StringRef> rootKind = rootOp.getOpName())
rootKindAttr = builder.getStringAttr(*rootKind);

builder.setInsertionPointToEnd(currentBlock);
Expand Down Expand Up @@ -620,12 +621,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.typeAttr()) {
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
if (ArrayAttr type = typeOp.typesAttr()) {
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
return newValue = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), type);
}
Expand Down Expand Up @@ -699,18 +700,18 @@ void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
mapRewriteValue(eraseOp.operation()));
mapRewriteValue(eraseOp.getOpValue()));
}

void PatternLowering::generateRewriter(
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> operands;
for (Value operand : operationOp.operands())
for (Value operand : operationOp.getOperandValues())
operands.push_back(mapRewriteValue(operand));

SmallVector<Value, 4> attributes;
for (Value attr : operationOp.attributes())
for (Value attr : operationOp.getAttributeValues())
attributes.push_back(mapRewriteValue(attr));

bool hasInferredResultTypes = false;
Expand All @@ -721,14 +722,14 @@ void PatternLowering::generateRewriter(
// Create the new operation.
Location loc = operationOp.getLoc();
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.name(), types, hasInferredResultTypes, operands,
attributes, operationOp.attributeNames());
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.op()] = createdOp;

// Generate accesses for any results that have their types constrained.
// Handle the case where there is a single range representing all of the
// result types.
OperandRange resultTys = operationOp.types();
OperandRange resultTys = operationOp.getTypeValues();
if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
Expand Down Expand Up @@ -772,8 +773,8 @@ void PatternLowering::generateRewriter(
// user facing.
if (Value replOp = replaceOp.replOperation()) {
// Don't use replace if we know the replaced operation has no results.
auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.types().empty()) {
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.getTypeValues().empty()) {
replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
replOp.getLoc(), mapRewriteValue(replOp)));
}
Expand All @@ -784,13 +785,14 @@ void PatternLowering::generateRewriter(

// If there are no replacement values, just create an erase instead.
if (replOperands.empty()) {
builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.operation()));
builder.create<pdl_interp::EraseOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
return;
}

builder.create<pdl_interp::ReplaceOp>(
replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
mapRewriteValue(replaceOp.getOpValue()),
replOperands);
}

void PatternLowering::generateRewriter(
Expand All @@ -814,7 +816,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (TypeAttr typeAttr = typeOp.typeAttr()) {
if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
rewriteValues[typeOp] =
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
}
Expand All @@ -825,7 +827,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (ArrayAttr typeAttr = typeOp.typesAttr()) {
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), typeAttr);
}
Expand All @@ -840,7 +842,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// Try to handle resolution for each of the result types individually. This is
// preferred over type inferrence because it will allow for us to use existing
// types directly, as opposed to trying to rebuild the type list.
OperandRange resultTypeValues = op.types();
OperandRange resultTypeValues = op.getTypeValues();
auto tryResolveResultTypes = [&] {
types.reserve(resultTypeValues.size());
for (const auto &it : llvm::enumerate(resultTypeValues)) {
Expand Down Expand Up @@ -886,7 +888,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// rewrites only have single block regions, so if the op isn't in the
// rewriter block (i.e. the current block of the operation) we already know
// it dominates (i.e. it's in the matcher).
Value replOpVal = replOpUser.operation();
Value replOpVal = replOpUser.getOpValue();
Operation *replacedOp = replOpVal.getDefiningOp();
if (replacedOp->getBlock() == rewriterBlock &&
!replacedOp->isBeforeInBlock(op))
Expand Down
45 changes: 24 additions & 21 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());

// If the attribute has a type or value, add a constraint.
if (Value type = attr.type())
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.valueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
Expand All @@ -76,7 +76,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
predList.emplace_back(pos, builder.getIsNotNull());

if (Value type = op.type())
if (Value type = op.getValueType())
getTreePredicates(predList, type, builder, inputs,
builder.getType(pos));
})
Expand Down Expand Up @@ -120,12 +120,12 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());

// Check that this is the correct root operation.
if (Optional<StringRef> opName = op.name())
if (Optional<StringRef> opName = op.getOpName())
predList.emplace_back(pos, builder.getOperationName(*opName));

// Check that the operation has the proper number of operands. If there are
// any variable length operands, we check a minimum instead of an exact count.
OperandRange operands = op.operands();
OperandRange operands = op.getOperandValues();
unsigned minOperands = getNumNonRangeValues(operands);
if (minOperands != operands.size()) {
if (minOperands)
Expand All @@ -136,19 +136,19 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,

// Check that the operation has the proper number of results. If there are
// any variable length results, we check a minimum instead of an exact count.
OperandRange types = op.types();
OperandRange types = op.getTypeValues();
unsigned minResults = getNumNonRangeValues(types);
if (minResults == types.size())
predList.emplace_back(pos, builder.getResultCount(types.size()));
else if (minResults)
predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));

// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
for (auto [attrName, attr] :
llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
getTreePredicates(
predList, std::get<1>(it), builder, inputs,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
predList, attr, builder, inputs,
builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
}

// Process the operands and results of the operation. For all values up to
Expand Down Expand Up @@ -208,10 +208,10 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
TypePosition *pos) {
// Check for a constraint on a constant type.
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
if (Attribute type = typeOp.typeAttr())
if (Attribute type = typeOp.getConstantTypeAttr())
predList.emplace_back(pos, builder.getTypeConstraint(type));
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
if (Attribute typeAttr = typeOp.typesAttr())
if (Attribute typeAttr = typeOp.getConstantTypesAttr())
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
}
}
Expand Down Expand Up @@ -327,7 +327,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
for (Operation &op : pattern.getBodyRegion().getOps()) {
TypeSwitch<Operation *>(&op)
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
Expand All @@ -340,11 +340,13 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
})
.Case([&](pdl::TypeOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
inputs);
})
.Case([&](pdl::TypesOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
inputs);
});
}
}
Expand All @@ -369,8 +371,8 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// First, collect all the operations that are used as operands
// to other operations. These are not roots by default.
DenseSet<Value> used;
for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
for (Value operand : operationOp.operands())
for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
for (Value operand : operationOp.getOperandValues())
TypeSwitch<Operation *>(operand.getDefiningOp())
.Case<pdl::ResultOp, pdl::ResultsOp>(
[&used](auto resultOp) { used.insert(resultOp.parent()); });
Expand All @@ -383,7 +385,7 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {

// Finally, collect all the unused operations.
SmallVector<Value> roots;
for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
if (!used.contains(operationOp))
roots.push_back(operationOp);

Expand Down Expand Up @@ -451,7 +453,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
// are expensive to join on.
TypeSwitch<Operation *>(entry.value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
OperandRange operands = operationOp.operands();
OperandRange operands = operationOp.getOperandValues();
// Special case when we pass all the operands in one range.
// For those, the index is empty.
if (operands.size() == 1 &&
Expand All @@ -462,7 +464,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
}

// Default case: visit all the operands.
for (const auto &p : llvm::enumerate(operationOp.operands()))
for (const auto &p :
llvm::enumerate(operationOp.getOperandValues()))
toVisit.emplace(p.value(), entry.value, p.index(),
entry.depth + 1);
})
Expand Down Expand Up @@ -507,7 +510,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
/// Returns true if the operand at the given index needs to be queried using an
/// operand group, i.e., if it is variadic itself or follows a variadic operand.
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
OperandRange operands = op.operands();
OperandRange operands = op.getOperandValues();
assert(index < operands.size() && "operand index out of range");
for (unsigned i = 0; i <= index; ++i)
if (operands[i].getType().isa<pdl::RangeType>())
Expand Down Expand Up @@ -537,7 +540,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
operandPos = builder.getAllOperands(opPos);
} else if (useOperandGroup(operationOp, *opIndex.index)) {
// We are querying an operand group.
Type type = operationOp.operands()[*opIndex.index].getType();
Type type = operationOp.getOperandValues()[*opIndex.index].getType();
bool variadic = type.isa<pdl::RangeType>();
operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
} else {
Expand Down
26 changes: 14 additions & 12 deletions mlir/lib/Dialect/Async/IR/Async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ LogicalResult YieldOp::verify() {
// Get the underlying value types from async values returned from the
// parent `async.execute` operation.
auto executeOp = (*this)->getParentOfType<ExecuteOp>();
auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
return result.getType().cast<ValueType>().getValueType();
});
auto types =
llvm::map_range(executeOp.bodyResults(), [](const OpResult &result) {
return result.getType().cast<ValueType>().getValueType();
});

if (getOperandTypes() != types)
return emitOpError("operand types do not match the types returned from "
Expand All @@ -61,7 +62,7 @@ constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
return operands();
return bodyOperands();
}

bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
Expand All @@ -79,12 +80,13 @@ void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
// The `body` region branch back to the parent operation.
if (index) {
assert(*index == 0 && "invalid region index");
regions.push_back(RegionSuccessor(results()));
regions.push_back(RegionSuccessor(bodyResults()));
return;
}

// Otherwise the successor is the body region.
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
regions.push_back(
RegionSuccessor(&bodyRegion(), bodyRegion().getArguments()));
}

void ExecuteOp::build(OpBuilder &builder, OperationState &result,
Expand Down Expand Up @@ -138,10 +140,10 @@ void ExecuteOp::print(OpAsmPrinter &p) {
p << " [" << dependencies() << "]";

// (%value as %unwrapped: !async.value<!arg.type>, ...)
if (!operands().empty()) {
if (!bodyOperands().empty()) {
p << " (";
Block *entry = body().empty() ? nullptr : &body().front();
llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable {
Block *entry = bodyRegion().empty() ? nullptr : &bodyRegion().front();
llvm::interleaveComma(bodyOperands(), p, [&, n = 0](Value operand) mutable {
Value argument = entry ? entry->getArgument(n++) : Value();
p << operand << " as " << argument << ": " << operand.getType();
});
Expand All @@ -153,7 +155,7 @@ void ExecuteOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{kOperandSegmentSizesAttr});
p << ' ';
p.printRegion(body(), /*printEntryBlockArgs=*/false);
p.printRegion(bodyRegion(), /*printEntryBlockArgs=*/false);
}

ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
Expand Down Expand Up @@ -226,12 +228,12 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {

LogicalResult ExecuteOp::verifyRegions() {
// Unwrap async.execute value operands types.
auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) {
auto unwrappedTypes = llvm::map_range(bodyOperands(), [](Value operand) {
return operand.getType().cast<ValueType>().getValueType();
});

// Verify that unwrapped argument types matches the body region arguments.
if (body().getArgumentTypes() != unwrappedTypes)
if (bodyRegion().getArgumentTypes() != unwrappedTypes)
return emitOpError("async body region argument types do not match the "
"execute operation arguments types");

Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,14 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {

// Make sure that all constants will be inside the outlined async function to
// reduce the number of function arguments.
cloneConstantsIntoTheRegion(execute.body());
cloneConstantsIntoTheRegion(execute.bodyRegion());

// Collect all outlined function inputs.
SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
functionInputs.insert(execute.operands().begin(), execute.operands().end());
getUsedValuesDefinedAbove(execute.body(), functionInputs);
functionInputs.insert(execute.bodyOperands().begin(),
execute.bodyOperands().end());
getUsedValuesDefinedAbove(execute.bodyRegion(), functionInputs);

// Collect types for the outlined function inputs and outputs.
auto typesRange = llvm::map_range(
Expand All @@ -279,7 +280,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Prepare for coroutine conversion by creating the body of the function.
{
size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();
size_t numOperands = execute.bodyOperands().size();

// Await on all dependencies before starting to execute the body region.
for (size_t i = 0; i < numDependencies; ++i)
Expand All @@ -296,11 +297,11 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);
valueMapping.map(execute.bodyRegion().getArguments(), unwrappedOperands);

// Clone all operations from the execute operation body into the outlined
// function body.
for (Operation &op : execute.body().getOps())
for (Operation &op : execute.bodyRegion().getOps())
builder.clone(op, valueMapping);
}

Expand Down
15 changes: 6 additions & 9 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Ignore launch ops with missing attributes here. The errors will be
// reported by the verifiers of those ops.
if (!launchOp->getAttrOfType<SymbolRefAttr>(
LaunchFuncOp::getKernelAttrName()))
LaunchFuncOp::getKernelAttrName(launchOp->getName())))
return success();

// Check that `launch_func` refers to a well-formed GPU kernel module.
Expand Down Expand Up @@ -703,7 +703,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
auto kernelSymbol =
SymbolRefAttr::get(kernelModule.getNameAttr(),
{SymbolRefAttr::get(kernelFunc.getNameAttr())});
result.addAttribute(getKernelAttrName(), kernelSymbol);
result.addAttribute(getKernelAttrName(result.name), kernelSymbol);
SmallVector<int32_t, 9> segmentSizes(9, 1);
segmentSizes.front() = asyncDependencies.size();
segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
Expand All @@ -718,9 +718,11 @@ StringAttr LaunchFuncOp::getKernelModuleName() {

StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

unsigned LaunchFuncOp::getNumKernelOperands() { return operands().size(); }
unsigned LaunchFuncOp::getNumKernelOperands() {
return kernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) { return operands()[i]; }
Value LaunchFuncOp::getKernelOperand(unsigned i) { return kernelOperands()[i]; }

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
auto operands = getOperands().drop_front(asyncDependencies().size());
Expand All @@ -743,11 +745,6 @@ LogicalResult LaunchFuncOp::verify() {
GPUDialect::getContainerModuleAttrName() +
"' attribute");

auto kernelAttr = (*this)->getAttrOfType<SymbolRefAttr>(getKernelAttrName());
if (!kernelAttr)
return emitOpError("symbol reference attribute '" + getKernelAttrName() +
"' must be specified");

return success();
}

Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
OpBuilder builder(executeOp);
auto newOp = builder.create<async::ExecuteOp>(
executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
executeOp.dependencies(), executeOp.operands());
executeOp.dependencies(), executeOp.bodyOperands());
BlockAndValueMapping mapper;
newOp.getRegion().getBlocks().clear();
executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
Expand Down Expand Up @@ -258,7 +258,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// Set `it` to the beginning of the region and add asyncTokens to the
// async.execute operands.
it = executeOp.getBody()->begin();
executeOp.operandsMutable().append(asyncTokens);
executeOp.bodyOperandsMutable().append(asyncTokens);
SmallVector<Type, 1> tokenTypes(
asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
Expand Down Expand Up @@ -301,7 +301,7 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
void operator()(async::ExecuteOp executeOp) {
// Extract !gpu.async.token results which have multiple uses.
auto multiUseResults =
llvm::make_filter_range(executeOp.results(), [](OpResult result) {
llvm::make_filter_range(executeOp.bodyResults(), [](OpResult result) {
if (result.use_empty() || result.hasOneUse())
return false;
auto valueType = result.getType().dyn_cast<async::ValueType>();
Expand All @@ -319,16 +319,16 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
});

for (auto index : indices) {
assert(!executeOp.results()[index].getUses().empty());
assert(!executeOp.bodyResults()[index].getUses().empty());
// Repeat async.yield token result, one for each use after the first one.
auto uses = llvm::drop_begin(executeOp.results()[index].getUses());
auto uses = llvm::drop_begin(executeOp.bodyResults()[index].getUses());
auto count = std::distance(uses.begin(), uses.end());
auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
executeOp = addExecuteResults(executeOp, operands);
// Update 'uses' to refer to the new executeOp.
uses = llvm::drop_begin(executeOp.results()[index].getUses());
auto results = executeOp.results().take_back(count);
uses = llvm::drop_begin(executeOp.bodyResults()[index].getUses());
auto results = executeOp.bodyResults().take_back(count);
for (auto pair : llvm::zip(uses, results))
std::get<0>(pair).set(std::get<1>(pair));
}
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

if (executionMapping != acc::OpenACCExecMapping::NONE)
result.addAttribute(LoopOp::getExecutionMappingAttrName(),
result.addAttribute(LoopOp::getExecutionMappingAttrStrName(),
builder.getI64IntegerAttr(executionMapping));

// Parse optional results in case there is a reduce.
Expand Down Expand Up @@ -662,16 +662,16 @@ void LoopOp::print(OpAsmPrinter &printer) {
/*printBlockTerminators=*/true);

printer.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
(*this)->getAttrs(), {LoopOp::getExecutionMappingAttrStrName(),
LoopOp::getOperandSegmentSizeAttr()});
}

LogicalResult acc::LoopOp::verify() {
// auto, independent and seq attribute are mutually exclusive.
if ((auto_() && (independent() || seq())) || (independent() && seq())) {
return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
acc::LoopOp::getIndependentAttrName() + ", " +
acc::LoopOp::getSeqAttrName() +
return emitError("only one of " + acc::LoopOp::getAutoAttrStrName() + ", " +
acc::LoopOp::getIndependentAttrStrName() + ", " +
acc::LoopOp::getSeqAttrStrName() +
" can be present at the same time");
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ LogicalResult ReductionOp::verify() {
"reduction clause interface";
while (op) {
for (const auto &var :
cast<ReductionClauseInterface>(op).getReductionVars())
cast<ReductionClauseInterface>(op).getAllReductionVars())
if (var == accumulator())
return success();
op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
Expand All @@ -689,7 +689,7 @@ LogicalResult TaskGroupOp::verify() {
//===----------------------------------------------------------------------===//
// TaskLoopOp
//===----------------------------------------------------------------------===//
SmallVector<Value> TaskLoopOp::getReductionVars() {
SmallVector<Value> TaskLoopOp::getAllReductionVars() {
SmallVector<Value> allReductionNvars(in_reduction_vars().begin(),
in_reduction_vars().end());
allReductionNvars.insert(allReductionNvars.end(), reduction_vars().begin(),
Expand Down
34 changes: 18 additions & 16 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
// Traverse the operands / parent.
TypeSwitch<Operation *>(op)
.Case<OperationOp>([&visited](auto operation) {
for (Value operand : operation.operands())
for (Value operand : operation.getOperandValues())
visit(operand.getDefiningOp(), visited);
})
.Case<ResultOp, ResultsOp>([&visited](auto result) {
Expand Down Expand Up @@ -111,7 +111,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult AttributeOp::verify() {
Value attrType = type();
Value attrType = getValueType();
Optional<Attribute> attrValue = value();

if (!attrValue) {
Expand Down Expand Up @@ -189,7 +189,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (!replOpUser || use.getOperandNumber() == 0)
return false;
// Make sure the replaced operation was defined before this one.
Operation *replacedOp = replOpUser.operation().getDefiningOp();
Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
return replacedOp->getBlock() != rewriterBlock ||
replacedOp->isBeforeInBlock(op);
};
Expand All @@ -203,7 +203,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (resultTypes.empty()) {
// If we don't know the concrete operation, don't attempt any verification.
// We can't make assumptions if we don't know the concrete operation.
Optional<StringRef> rawOpName = op.name();
Optional<StringRef> rawOpName = op.getOpName();
if (!rawOpName)
return success();
Optional<RegisteredOperationName> opName =
Expand Down Expand Up @@ -246,10 +246,12 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
isa<OperandOp, OperandsOp, OperationOp>(user);
};
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
if (typeOp.getConstantType() ||
llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
if (typeOp.getConstantTypes() ||
llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
}

Expand All @@ -264,11 +266,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,

LogicalResult OperationOp::verify() {
bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp());
if (isWithinRewrite && !name())
if (isWithinRewrite && !getOpName())
return emitOpError("must have an operation name when nested within "
"a `pdl.rewrite`");
ArrayAttr attributeNames = attributeNamesAttr();
auto attributeValues = attributes();
ArrayAttr attributeNames = getAttributeValueNamesAttr();
auto attributeValues = getAttributeValues();
if (attributeNames.size() != attributeValues.size()) {
return emitOpError()
<< "expected the same number of attribute values and attribute "
Expand All @@ -280,23 +282,23 @@ LogicalResult OperationOp::verify() {
// If the operation is within a rewrite body and doesn't have type inference,
// ensure that the result types can be resolved.
if (isWithinRewrite && !mightHaveTypeInference()) {
if (failed(verifyResultTypesAreInferrable(*this, types())))
if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
return failure();
}

return verifyHasBindingUse(*this);
}

bool OperationOp::hasTypeInference() {
if (Optional<StringRef> rawOpName = name()) {
if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.hasInterface<InferTypeOpInterface>();
}
return false;
}

bool OperationOp::mightHaveTypeInference() {
if (Optional<StringRef> rawOpName = name()) {
if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.mightHaveInterface<InferTypeOpInterface>();
}
Expand Down Expand Up @@ -387,7 +389,7 @@ void PatternOp::build(OpBuilder &builder, OperationState &state,

/// Returns the rewrite operation of this pattern.
RewriteOp PatternOp::getRewriter() {
return cast<RewriteOp>(body().front().getTerminator());
return cast<RewriteOp>(getBodyRegion().front().getTerminator());
}

/// The default dialect is `pdl`.
Expand Down Expand Up @@ -441,7 +443,7 @@ LogicalResult ResultsOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult RewriteOp::verifyRegions() {
Region &rewriteRegion = body();
Region &rewriteRegion = getBodyRegion();

// Handle the case where the rewrite is external.
if (name()) {
Expand Down Expand Up @@ -477,7 +479,7 @@ StringRef RewriteOp::getDefaultDialect() {
//===----------------------------------------------------------------------===//

LogicalResult TypeOp::verify() {
if (!typeAttr())
if (!getConstantTypeAttr())
return verifyHasBindingUse(*this);
return success();
}
Expand All @@ -487,7 +489,7 @@ LogicalResult TypeOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult TypesOp::verify() {
if (!typesAttr())
if (!getConstantTypesAttr())
return verifyHasBindingUse(*this);
return success();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
pdl::RewriteOp rewrite =
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
/*externalArgs=*/ValueRange());
builder.createBlock(&rewrite.body());
builder.createBlock(&rewrite.getBodyRegion());
}
}

Expand Down
38 changes: 20 additions & 18 deletions mlir/python/mlir/dialects/_pdl_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ class AttributeOp:
"""Specialization for PDL attribute op class."""

def __init__(self,
type: Optional[Union[OpView, Operation, Value]] = None,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_value(type)
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, type=type, value=value, loc=loc, ip=ip)
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)


class EraseOp:
Expand All @@ -118,7 +118,7 @@ def __init__(self,
ip=None):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, type=type, loc=loc, ip=ip)
super().__init__(result, valueType=type, loc=loc, ip=ip)


class OperandsOp:
Expand All @@ -131,7 +131,7 @@ def __init__(self,
ip=None):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, type=types, loc=loc, ip=ip)
super().__init__(result, valueType=types, loc=loc, ip=ip)


class OperationOp:
Expand All @@ -147,15 +147,15 @@ def __init__(self,
ip=None):
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
attributeNames = []
attributeValues = []
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attributeNames.append(StringAttr.get(attrName))
attributeValues.append(_get_value(attrValue))
attributeNames = ArrayAttr.get(attributeNames)
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)


class PatternOp:
Expand Down Expand Up @@ -255,24 +255,26 @@ class TypeOp:
"""Specialization for PDL type op class."""

def __init__(self,
type: Optional[Union[TypeAttr, Type]] = None,
constantType: Optional[Union[TypeAttr, Type]] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_type_attr(type)
constantType = constantType if constantType is None else _get_type_attr(
constantType)
result = pdl.TypeType.get()
super().__init__(result, type=type, loc=loc, ip=ip)
super().__init__(result, constantType=constantType, loc=loc, ip=ip)


class TypesOp:
"""Specialization for PDL types op class."""

def __init__(self,
types: Sequence[Union[TypeAttr, Type]] = [],
constantTypes: Sequence[Union[TypeAttr, Type]] = [],
*,
loc=None,
ip=None):
types = _get_array_attr([_get_type_attr(ty) for ty in types])
types = None if not types else types
constantTypes = _get_array_attr(
[_get_type_attr(ty) for ty in constantTypes])
constantTypes = None if not constantTypes else constantTypes
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, types=types, loc=loc, ip=ip)
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
12 changes: 6 additions & 6 deletions mlir/test/Dialect/Async/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ func.func @empty_async_execute() -> !async.token {
// CHECK-LABEL: @return_async_value
func.func @return_async_value() -> !async.value<f32> {
// CHECK: async.execute -> !async.value<f32>
%token, %results = async.execute -> !async.value<f32> {
%token, %bodyResults = async.execute -> !async.value<f32> {
%cst = arith.constant 1.000000e+00 : f32
async.yield %cst : f32
}

// CHECK: return %results : !async.value<f32>
return %results : !async.value<f32>
// CHECK: return %bodyResults : !async.value<f32>
return %bodyResults : !async.value<f32>
}

// CHECK-LABEL: @return_captured_value
Expand All @@ -49,14 +49,14 @@ func.func @return_captured_value() -> !async.token {

// CHECK-LABEL: @return_async_values
func.func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
%token, %results:2 = async.execute -> (!async.value<f32>, !async.value<f32>) {
%token, %bodyResults:2 = async.execute -> (!async.value<f32>, !async.value<f32>) {
%cst1 = arith.constant 1.000000e+00 : f32
%cst2 = arith.constant 2.000000e+00 : f32
async.yield %cst1, %cst2 : f32, f32
}

// CHECK: return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
// CHECK: return %bodyResults#0, %bodyResults#1 : !async.value<f32>, !async.value<f32>
return %bodyResults#0, %bodyResults#1 : !async.value<f32>, !async.value<f32>
}

// CHECK-LABEL: @async_token_dependencies
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/GPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func.func @reduce_incorrect_yield(%arg0 : f32) {
// -----

func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
// expected-error@+1 {{op failed to verify that all of {value, result} have same type}}
// expected-error@+1 {{op failed to verify that all of {value, shuffleResult} have same type}}
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = #gpu<shuffle_mode xor> } : (f32, i32, i32) -> (i32, i1)
return
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/PDL/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
// expected-error@below {{expected the same number of attribute values and attribute names, got 1 names and 0 values}}
%op = "pdl.operation"() {
attributeNames = ["attr"],
attributeValueNames = ["attr"],
operand_segment_sizes = array<i32: 0, 0, 0>
} : () -> (!pdl.operation)
rewrite %op with "rewriter"
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input,
.create<gpu::ShuffleOp>(loc, laneVal, i,
/*width=*/size,
/*mode=*/gpu::ShuffleMode::XOR)
.result();
.getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
Expand Down