@@ -1016,8 +1016,7 @@ void MapOp::build(
1016
1016
static void addBodyWithPayloadOp (OpAsmParser &parser, OperationState &result,
1017
1017
const OperationName &payloadOpName,
1018
1018
const NamedAttrList &payloadOpAttrs,
1019
- ArrayRef<Value> operands,
1020
- bool initFirst = false ) {
1019
+ ArrayRef<Value> operands) {
1021
1020
OpBuilder b (parser.getContext ());
1022
1021
Region *body = result.addRegion ();
1023
1022
Block &block = body->emplaceBlock ();
@@ -1027,24 +1026,14 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1027
1026
block.addArgument (operand.getType ().cast <ShapedType>().getElementType (),
1028
1027
b.getUnknownLoc ());
1029
1028
}
1030
- SmallVector<Value> payloadOpOperands;
1031
- // If initFirst flag is enabled, we consider init as the first position of
1032
- // payload operands.
1033
- if (initFirst) {
1034
- payloadOpOperands.push_back (block.getArguments ().back ());
1035
- for (const auto & arg : block.getArguments ().drop_back ())
1036
- payloadOpOperands.push_back (arg);
1037
- } else {
1038
- payloadOpOperands = {block.getArguments ().begin (),
1039
- block.getArguments ().end ()};
1040
- }
1041
1029
1042
1030
Operation *payloadOp = b.create (
1043
1031
result.location , b.getStringAttr (payloadOpName.getStringRef ()),
1044
- payloadOpOperands ,
1032
+ block. getArguments () ,
1045
1033
TypeRange{
1046
1034
result.operands .back ().getType ().cast <ShapedType>().getElementType ()},
1047
1035
payloadOpAttrs);
1036
+
1048
1037
b.create <YieldOp>(result.location , payloadOp->getResults ());
1049
1038
}
1050
1039
@@ -1083,9 +1072,7 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1083
1072
1084
1073
// Retrieve the operation from the body, if it is the only one (except
1085
1074
// yield) and if it gets the same amount of arguments as the body does.
1086
- // If initFirst flag is enabled, we check that init takes the first position in
1087
- // operands of payload.
1088
- static Operation *findPayloadOp (Block *body, bool initFirst = false ) {
1075
+ static Operation *findPayloadOp (Block *body) {
1089
1076
if (body->getOperations ().size () != 2 )
1090
1077
return nullptr ;
1091
1078
Operation &payload = body->getOperations ().front ();
@@ -1094,22 +1081,10 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1094
1081
if (payload.getNumOperands () == 0 ||
1095
1082
payload.getNumOperands () != body->getNumArguments ())
1096
1083
return nullptr ;
1097
- if (initFirst) {
1098
- // check init
1099
- if (payload. getOperands (). back () != body-> getArgument ( 0 ) )
1084
+ for ( const auto &[bbArg, operand] :
1085
+ llvm::zip (payload. getOperands (), body-> getArguments ())) {
1086
+ if (bbArg != operand )
1100
1087
return nullptr ;
1101
- // check rest
1102
- for (int i = 1 ; i < body->getNumArguments (); ++i) {
1103
- if (payload.getOperand (i - 1 ) != body->getArgument (i)) {
1104
- return nullptr ;
1105
- }
1106
- }
1107
- } else {
1108
- for (const auto &[bbArg, operand] :
1109
- llvm::zip (payload.getOperands (), body->getArguments ())) {
1110
- if (bbArg != operand)
1111
- return nullptr ;
1112
- }
1113
1088
}
1114
1089
return &payload;
1115
1090
}
@@ -1308,7 +1283,7 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1308
1283
1309
1284
if (payloadOpName.has_value ()) {
1310
1285
addBodyWithPayloadOp (parser, result, payloadOpName.value (), payloadOpAttrs,
1311
- makeArrayRef (result.operands ), /* initFirst= */ true );
1286
+ makeArrayRef (result.operands ));
1312
1287
} else {
1313
1288
SmallVector<OpAsmParser::Argument> regionArgs;
1314
1289
if (parser.parseArgumentList (regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1331,7 +1306,7 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1331
1306
1332
1307
void ReduceOp::print (OpAsmPrinter &p) {
1333
1308
Block *mapper = getBody ();
1334
- Operation *payloadOp = findPayloadOp (mapper, /* initFirst= */ true );
1309
+ Operation *payloadOp = findPayloadOp (mapper);
1335
1310
if (payloadOp) {
1336
1311
printShortForm (p, payloadOp);
1337
1312
}
0 commit comments