Skip to content

Commit

Permalink
[mlir] Remove iterator_types() method from LinalgStructuredInterface.
Browse files Browse the repository at this point in the history
`getIteratorTypesArray` should be used instead. It's a better substitute for all the current usages of the interface.

The current `ArrayAttr iterator_types()` has a few problems:
* It creates an assumption operation has iterators types as an attribute, but it's not always the case. Sometime iterator types can be inferred from other attribute, or they're just static.
* ArrayAttr is an obscure contained and required extracting values in the client code.
* Makes it hard to migrate iterator types from strings to enums ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535/9)).

Concrete ops, like `linalg.generic` will still have iterator types as an attribute if needed.

As a side effect, this change helps a bit with migration to prefixed accessors.

Differential Revision: https://reviews.llvm.org/D135765
  • Loading branch information
olegshyshkov committed Oct 13, 2022
1 parent c5d950f commit c38d9cf
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 42 deletions.
22 changes: 6 additions & 16 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Expand Up @@ -497,28 +497,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return $_op.getBody();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the iterator types attribute within the current operation.
}],
/*retTy=*/"ArrayAttr",
/*methodName=*/"iterator_types",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getIteratorTypes();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return iterator types in the current operation.

Default implementation assumes that the operation has an attribute
`iterator_types`, but it's not always the case. Sometimes iterator types
can be infered from other parameters and in such cases default
getIteratorTypesArray should be overriden.
}],
/*retTy=*/"SmallVector<StringRef>",
/*methodName=*/"getIteratorTypesArray",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
return {range.begin(), range.end()};
}]
>,
Expand Down Expand Up @@ -773,9 +766,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);

// TODO: Remove once prefixing is flipped.
ArrayAttr getIteratorTypes() { return iterator_types(); }

SmallVector<StringRef> getIteratorTypeNames() {
return getIteratorTypesArray();
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Expand Up @@ -264,7 +264,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
ArrayAttr getIteratorTypes();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down Expand Up @@ -334,7 +334,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
ArrayAttr getIteratorTypes();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -1393,11 +1393,9 @@ LogicalResult MapOp::verify() {
return success();
}

ArrayAttr MapOp::getIteratorTypes() {
SmallVector<StringRef> MapOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
return Builder(getContext())
.getStrArrayAttr(
SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}

ArrayAttr MapOp::getIndexingMaps() {
Expand Down Expand Up @@ -1435,13 +1433,13 @@ void ReduceOp::getAsmResultNames(
setNameFn(getResults().front(), "reduced");
}

ArrayAttr ReduceOp::getIteratorTypes() {
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
getParallelIteratorTypeName());
for (int64_t reductionDim : getDimensions())
iteratorTypes[reductionDim] = getReductionIteratorTypeName();
return Builder(getContext()).getStrArrayAttr(iteratorTypes);
return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
Expand Down
8 changes: 3 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Expand Up @@ -92,11 +92,9 @@ struct LinalgOpTilingInterface
/// Return the loop iterator type.
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
return llvm::to_vector(
llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
return utils::symbolizeIteratorType(
strAttr.cast<StringAttr>().getValue())
.value();
return llvm::to_vector(llvm::map_range(
concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
return utils::symbolizeIteratorType(iteratorType).value();
}));
}

Expand Down
Expand Up @@ -250,7 +250,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
// Fuse producer and consumer into a new generic op.
auto fusedOp = rewriter.create<GenericOp>(
loc, op.getResult(0).getType(), inputOps, outputOps,
rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
/*doc=*/nullptr, /*library_call=*/nullptr);
Block &prodBlock = prod.getRegion().front();
Block &consBlock = op.getRegion().front();
Expand Down
Expand Up @@ -1857,7 +1857,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (op.getNumOutputs() != 1)
return failure();
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numLoops = op.iterator_types().getValue().size();
unsigned numLoops = op.getNumLoops();
Merger merger(numTensors, numLoops);
if (!findSparseAnnotations(merger, op))
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -2816,7 +2816,7 @@ def TestLinalgConvOp :
return &regionBuilder;
}

mlir::ArrayAttr iterator_types() {
mlir::ArrayAttr getIteratorTypes() {
return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
}

Expand Down Expand Up @@ -2875,7 +2875,7 @@ def TestLinalgFillOp :
return &regionBuilder;
}

mlir::ArrayAttr iterator_types() {
mlir::ArrayAttr getIteratorTypes() {
return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
}

Expand Down
Expand Up @@ -235,7 +235,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: value

# IMPL: Test3Op::iterator_types() {
# IMPL: Test3Op::getIteratorTypesArray() {
# IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0));

# IMPL: Test3Op::getIndexingMaps() {
Expand Down
16 changes: 8 additions & 8 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Expand Up @@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
ArrayAttr iterator_types();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Expand Down Expand Up @@ -587,24 +587,24 @@ static const char structuredOpBuilderFormat[] = R"FMT(
}]>
)FMT";

// The iterator_types() method for structured ops. Parameters:
// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
R"FMT(
ArrayAttr {0}::iterator_types() {{
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
SmallVector<StringRef> {0}::getIteratorTypesArray() {{
return SmallVector<StringRef>{{ {1} };
}
)FMT";

// The iterator_types() method for rank polymorphic structured ops. Parameters:
// The getIteratorTypesArray() method for rank polymorphic structured ops.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
ArrayAttr {0}::iterator_types() {{
SmallVector<StringRef> {0}::getIteratorTypesArray() {{
int64_t rank = getRank(getOutputOperand(0));
return Builder(getContext()).getStrArrayAttr(
SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
)FMT";

Expand Down

0 comments on commit c38d9cf

Please sign in to comment.