Skip to content

Commit 5759d94

Browse files
committed
Revert "Apply shortened printing/parsing form to linalg.reduce."
This reverts commit 281c2d4. This broke the windows mlir buildbot: https://lab.llvm.org/buildbot/#/builders/13/builds/30167
1 parent 828b476 commit 5759d94

File tree

4 files changed

+14
-61
lines changed

4 files changed

+14
-61
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
255255
linalg.yield %0: f32
256256
}
257257
```
258-
259-
Shortened print form is available. Applies to simple maps with one
260-
non-yield operation inside the body.
261-
262-
The example above will be printed as:
263-
```
264-
%add = linalg.map { arith.addf }
265-
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
266-
outs(%init: tensor<64xf32>)
267-
```
268258
}];
269259

270260
let arguments = (ins
@@ -339,22 +329,10 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
339329
outs(%init:tensor<16x64xf32>)
340330
dimensions = [1]
341331
(%in: f32, %out: f32) {
342-
%0 = arith.addf %out, %in: f32
332+
%0 = arith.addf %in, %out: f32
343333
linalg.yield %0: f32
344334
}
345335
```
346-
347-
Shortened print form is available. Applies to simple (not variadic) reduces
348-
with one non-yield operation inside the body. Applies only if the operation
349-
takes `%out` as the first argument.
350-
351-
The example above will be printed as:
352-
```
353-
%reduce = linalg.reduce { arith.addf }
354-
ins(%input:tensor<16x32x64xf32>)
355-
outs(%init:tensor<16x64xf32>)
356-
dimensions = [1]
357-
```
358336
}];
359337

360338
let arguments = (ins

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,7 @@ void MapOp::build(
10161016
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
10171017
const OperationName &payloadOpName,
10181018
const NamedAttrList &payloadOpAttrs,
1019-
ArrayRef<Value> operands,
1020-
bool initFirst = false) {
1019+
ArrayRef<Value> operands) {
10211020
OpBuilder b(parser.getContext());
10221021
Region *body = result.addRegion();
10231022
Block &block = body->emplaceBlock();
@@ -1027,24 +1026,14 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
10271026
block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
10281027
b.getUnknownLoc());
10291028
}
1030-
SmallVector<Value> payloadOpOperands;
1031-
// If initFirst flag is enabled, we consider init as the first position of
1032-
// payload operands.
1033-
if (initFirst) {
1034-
payloadOpOperands.push_back(block.getArguments().back());
1035-
for (const auto& arg : block.getArguments().drop_back())
1036-
payloadOpOperands.push_back(arg);
1037-
} else {
1038-
payloadOpOperands = {block.getArguments().begin(),
1039-
block.getArguments().end()};
1040-
}
10411029

10421030
Operation *payloadOp = b.create(
10431031
result.location, b.getStringAttr(payloadOpName.getStringRef()),
1044-
payloadOpOperands,
1032+
block.getArguments(),
10451033
TypeRange{
10461034
result.operands.back().getType().cast<ShapedType>().getElementType()},
10471035
payloadOpAttrs);
1036+
10481037
b.create<YieldOp>(result.location, payloadOp->getResults());
10491038
}
10501039

@@ -1083,9 +1072,7 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
10831072

10841073
// Retrieve the operation from the body, if it is the only one (except
10851074
// yield) and if it gets the same amount of arguments as the body does.
1086-
// If initFirst flag is enabled, we check that init takes the first position in
1087-
// operands of payload.
1088-
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1075+
static Operation *findPayloadOp(Block *body) {
10891076
if (body->getOperations().size() != 2)
10901077
return nullptr;
10911078
Operation &payload = body->getOperations().front();
@@ -1094,22 +1081,10 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {
10941081
if (payload.getNumOperands() == 0 ||
10951082
payload.getNumOperands() != body->getNumArguments())
10961083
return nullptr;
1097-
if (initFirst) {
1098-
// check init
1099-
if (payload.getOperands().back() != body->getArgument(0))
1084+
for (const auto &[bbArg, operand] :
1085+
llvm::zip(payload.getOperands(), body->getArguments())) {
1086+
if (bbArg != operand)
11001087
return nullptr;
1101-
// check rest
1102-
for (int i = 1; i < body->getNumArguments(); ++i) {
1103-
if (payload.getOperand(i - 1) != body->getArgument(i)) {
1104-
return nullptr;
1105-
}
1106-
}
1107-
} else {
1108-
for (const auto &[bbArg, operand] :
1109-
llvm::zip(payload.getOperands(), body->getArguments())) {
1110-
if (bbArg != operand)
1111-
return nullptr;
1112-
}
11131088
}
11141089
return &payload;
11151090
}
@@ -1308,7 +1283,7 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
13081283

13091284
if (payloadOpName.has_value()) {
13101285
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1311-
makeArrayRef(result.operands), /*initFirst=*/true);
1286+
makeArrayRef(result.operands));
13121287
} else {
13131288
SmallVector<OpAsmParser::Argument> regionArgs;
13141289
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1331,7 +1306,7 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
13311306

13321307
void ReduceOp::print(OpAsmPrinter &p) {
13331308
Block *mapper = getBody();
1334-
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1309+
Operation *payloadOp = findPayloadOp(mapper);
13351310
if (payloadOp) {
13361311
printShortForm(p, payloadOp);
13371312
}

mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
363363
outs(%init:tensor<16x64xf32>)
364364
dimensions = [1]
365365
(%in: f32, %out: f32) {
366-
%0 = arith.addf %out, %in: f32
366+
%0 = arith.addf %in, %out: f32
367367
linalg.yield %0: f32
368368
}
369369
func.return %reduce : tensor<16x64xf32>

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
414414
outs(%init:tensor<16x64xf32>)
415415
dimensions = [1]
416416
(%in: f32, %out: f32) {
417-
%0 = arith.addf %out, %in: f32
417+
%0 = arith.addf %in, %out: f32
418418
linalg.yield %0: f32
419419
}
420420
func.return %reduce : tensor<16x64xf32>
@@ -433,7 +433,7 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
433433
outs(%init:memref<16x64xf32>)
434434
dimensions = [1]
435435
(%in: f32, %out: f32) {
436-
%0 = arith.addf %out, %in: f32
436+
%0 = arith.addf %in, %out: f32
437437
linalg.yield %0: f32
438438
}
439439
func.return
@@ -587,7 +587,7 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
587587
outs(%init:tensor<16x64xf32>)
588588
dimensions = [1]
589589
(%in: f32, %out: f32) {
590-
%0 = arith.addf %out, %in fastmath<fast> : f32
590+
%0 = arith.addf %in, %out fastmath<fast> : f32
591591
linalg.yield %0: f32
592592
}
593593
func.return %reduce : tensor<16x64xf32>

0 commit comments

Comments
 (0)