Skip to content

Commit

Permalink
[mlir] Update how scalable indices are printed
Browse files Browse the repository at this point in the history
This patch makes sure that scalable indices (that would normally
represent scalable tile or vector sizes) are printed correctly, i.e.
with additional square brackets:
```
%1, %loop = transform.structured.tile %0 [2, 8, [4]]
```

This change complements https://reviews.llvm.org/D150944 and is a part
of a larger effort to enable scalable vectorisation in Linalg. See this
RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Differential Revision: https://reviews.llvm.org/D151978
  • Loading branch information
banach-space committed Jun 2, 2023
1 parent 5c2072e commit 726835c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 8 deletions.
7 changes: 6 additions & 1 deletion mlir/include/mlir/Interfaces/ViewLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,15 @@ namespace mlir {
/// indicating their types. This allows idiomatic printing of mixed value and
/// integer attributes in a list. E.g.
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
///
/// If `isTrailingIdxScalable` is true, then wrap the trailing index with
/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
/// used for scalable tile or vector sizes.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
bool isTrailingIdxScalable = false);

/// Parser hook for custom directive in assemblyFormat.
///
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2555,7 +2555,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,

void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
/*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
getLastTileSizeScalable());
printOptionalInterchange(p, getInterchange());
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
if (succeeded(parser.parseOptionalKeyword("in"))) {
// Parse upper bounds.
if (parseDynamicIndexList(
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
Expand All @@ -1274,7 +1274,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse lower bounds.
if (parser.parseEqual() ||
parseDynamicIndexList(
parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||

parser.resolveOperands(dynamicLbs, indexType, result.operands))
Expand All @@ -1283,7 +1283,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse upper bounds.
if (parser.parseKeyword("to") ||
parseDynamicIndexList(
parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Transform/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
return success();
}

return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
&valueTypes);
return parseDynamicIndexList(parser, values, integers,
/*isTrailingIdxScalable=*/nullptr, &valueTypes);
}
20 changes: 19 additions & 1 deletion mlir/lib/Interfaces/ViewLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,23 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
TypeRange valueTypes,
AsmParser::Delimiter delimiter) {
AsmParser::Delimiter delimiter,
bool isTrailingIdxScalable) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
printer << leftDelimiter;
if (integers.empty()) {
printer << rightDelimiter;
return;
}

int64_t trailingScalableInteger;
if (isTrailingIdxScalable) {
// ATM only the trailing idx can be scalable
trailingScalableInteger = integers.back();
integers = integers.drop_back();
}

unsigned idx = 0;
llvm::interleaveComma(integers, printer, [&](int64_t integer) {
if (ShapedType::isDynamic(integer)) {
Expand All @@ -122,6 +131,15 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
printer << integer;
}
});

// Print the trailing scalable index
if (isTrailingIdxScalable) {
printer << ", ";
printer << "[";
printer << trailingScalableInteger;
printer << "]";
}

printer << rightDelimiter;
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Transform/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ transform.sequence failures(propagate) {
transform.print %arg0 {name = "test"} : !transform.any_op
transform.print {name = "test"}
}

// CHECK: transform.sequence
// CHECK: transform.structured.tile %0[4, 4, [4]]
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
}

0 comments on commit 726835c

Please sign in to comment.