24 changes: 19 additions & 5 deletions mlir/examples/toy/Ch5/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
Expand All @@ -176,10 +176,10 @@ static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
printer << op.value();
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
}

/// Verifier for the constant operation. This corresponds to the
Expand Down Expand Up @@ -221,6 +221,13 @@ void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down Expand Up @@ -278,6 +285,13 @@ void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down
15 changes: 6 additions & 9 deletions mlir/examples/toy/Ch6/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// The constant operation returns a single value of TensorType.
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseConstantOp(parser, result); }];
let printer = [{ return ::print(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Add custom build methods for the constant operation. These method populates
// the `state` that MLIR uses to create operations, i.e. these are used when
Expand Down Expand Up @@ -93,9 +92,8 @@ def AddOp : Toy_Op<"add",
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseBinaryOp(parser, result); }];
let printer = [{ return ::printBinaryOp(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Allow building an AddOp with from the two input operands.
let builders = [
Expand Down Expand Up @@ -171,9 +169,8 @@ def MulOp : Toy_Op<"mul",
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseBinaryOp(parser, result); }];
let printer = [{ return ::printBinaryOp(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Allow building a MulOp with from the two input operands.
let builders = [
Expand Down
24 changes: 19 additions & 5 deletions mlir/examples/toy/Ch6/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
Expand All @@ -176,10 +176,10 @@ static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
printer << op.value();
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
}

/// Verifier for the constant operation. This corresponds to the
Expand Down Expand Up @@ -221,6 +221,13 @@ void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down Expand Up @@ -278,6 +285,13 @@ void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down
15 changes: 6 additions & 9 deletions mlir/examples/toy/Ch7/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def ConstantOp : Toy_Op<"constant",
// The constant operation returns a single value of TensorType.
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseConstantOp(parser, result); }];
let printer = [{ return ::print(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Add custom build methods for the constant operation. These method populates
// the `state` that MLIR uses to create operations, i.e. these are used when
Expand Down Expand Up @@ -112,9 +111,8 @@ def AddOp : Toy_Op<"add",
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseBinaryOp(parser, result); }];
let printer = [{ return ::printBinaryOp(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Allow building an AddOp with from the two input operands.
let builders = [
Expand Down Expand Up @@ -191,9 +189,8 @@ def MulOp : Toy_Op<"mul",
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);

// Specify a parser and printer method.
let parser = [{ return ::parseBinaryOp(parser, result); }];
let printer = [{ return ::printBinaryOp(p, *this); }];
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;

// Allow building a MulOp with from the two input operands.
let builders = [
Expand Down
24 changes: 19 additions & 5 deletions mlir/examples/toy/Ch7/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
Expand All @@ -163,10 +163,10 @@ static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
printer << op.value();
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
}

/// Verify that the given attribute value is valid for the given type.
Expand Down Expand Up @@ -248,6 +248,13 @@ void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down Expand Up @@ -305,6 +312,13 @@ void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
state.addOperands({lhs, rhs});
}

mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseBinaryOp(parser, result);
}

void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Expand Down
22 changes: 11 additions & 11 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,7 @@ def Affine_Dialect : Dialect {

// Base class for Affine dialect ops.
class Affine_Op<string mnemonic, list<Trait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
// For every affine op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<Affine_Dialect, mnemonic, traits>;

// Require regions to have affine.yield.
def ImplicitAffineTerminator
Expand Down Expand Up @@ -109,6 +101,7 @@ def AffineApplyOp : Affine_Op<"apply", [NoSideEffect]> {
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -348,6 +341,7 @@ def AffineForOp : Affine_Op<"for",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -472,6 +466,7 @@ def AffineIfOp : Affine_Op<"if",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -538,6 +533,7 @@ def AffineLoadOp : AffineLoadOpBase<"load"> {
let extraClassDeclaration = extraClassDeclarationBase;

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -567,8 +563,7 @@ class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
operands().end()};
}
}];
let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand Down Expand Up @@ -754,6 +749,7 @@ def AffineParallelOp : Affine_Op<"parallel",
}
}];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -834,6 +830,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -899,6 +896,7 @@ def AffineStoreOp : AffineStoreOpBase<"store"> {
let extraClassDeclaration = extraClassDeclarationBase;

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -990,6 +988,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1055,6 +1054,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
9 changes: 1 addition & 8 deletions mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let builders = [
OpBuilder<(ins "Value":$source, "Type":$destType), [{
impl::buildCastOp($_builder, $_state, source, destType);
}]>
];

let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}

Expand Down Expand Up @@ -1208,8 +1202,7 @@ def SelectOp : Arith_Op<"select", [
let hasVerifier = 1;

// FIXME: Switch this to use the declarative assembly format.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
}

#endif // ARITHMETIC_OPS
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def Async_ExecuteOp :
Variadic<Async_ValueType>:$results);
let regions = (region SizedRegion<1>:$body);

let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];

let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
let hasVerifier = 1;
let builders = [
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def EmitC_IncludeOp
Arg<StrAttr, "source file to include">:$include,
UnitAttr:$is_standard_include
);
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
10 changes: 3 additions & 7 deletions mlir/include/mlir/Dialect/GPU/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
/// Verifies the body of the function.
LogicalResult verifyBody();
}];

let printer = [{ printGPUFuncOp(p, *this); }];
let parser = [{ return parseGPUFuncOp(parser, result); }];
let hasCustomAssemblyFormat = 1;
}

def GPU_LaunchFuncOp : GPU_Op<"launch_func",
Expand Down Expand Up @@ -556,9 +554,8 @@ def GPU_LaunchOp : GPU_Op<"launch">,
static constexpr unsigned kNumConfigRegionAttributes = 12;
}];

let parser = [{ return parseLaunchOp(parser, result); }];
let printer = [{ printLaunchOp(p, *this); }];
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -783,9 +780,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
```
}];
let builders = [OpBuilder<(ins "StringRef":$name)>];
let parser = [{ return ::parseGPUModuleOp(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let regions = (region SizedRegion<1>:$body);
let hasCustomAssemblyFormat = 1;

// We need to ensure the block inside the region is properly terminated;
// the auto-generated builders do not guarantee that.
Expand Down
60 changes: 20 additions & 40 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
predicate, lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
let printer = [{ printICmpOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
}

// Predicate for float comparisons
Expand Down Expand Up @@ -246,8 +245,7 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [
let llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
let printer = [{ printFCmpOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
}

// Floating point binary operations.
Expand Down Expand Up @@ -312,8 +310,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
build($_builder, $_state, resultType, arraySize,
$_builder.getI64IntegerAttr(alignment));
}]>];
let parser = [{ return parseAllocaOp(parser, result); }];
let printer = [{ printAllocaOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
}

def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
Expand Down Expand Up @@ -382,8 +379,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
OpBuilder<(ins "Type":$t, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
CArg<"bool", "false">:$isNonTemporal)>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand All @@ -406,8 +402,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
CArg<"bool", "false">:$isNonTemporal)>
];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand All @@ -419,8 +414,7 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
let arguments = (ins type:$arg);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
Expand Down Expand Up @@ -492,17 +486,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
unwindOps, normal, unwind);
}]>];
let parser = [{ return parseInvokeOp(parser, result); }];
let printer = [{ printInvokeOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let arguments = (ins UnitAttr:$cleanup, Variadic<LLVM_Type>);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return parseLandingpadOp(parser, result); }];
let printer = [{ printLandingpadOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -560,8 +552,7 @@ def LLVM_CallOp : LLVM_Op<"call",
build($_builder, $_state, results,
StringAttr::get($_builder.getContext(), callee), operands);
}]>];
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ printCallOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
Expand All @@ -573,8 +564,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$vector, "Value":$position,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
Expand All @@ -584,8 +574,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
$res = builder.CreateExtractValue($container, extractPosition($position));
}];
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return parseExtractValueOp(parser, result); }];
let printer = [{ printExtractValueOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand All @@ -597,8 +586,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
$res = builder.CreateInsertElement($vector, $value, $position);
}];
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return parseInsertElementOp(parser, result); }];
let printer = [{ printInsertElementOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
Expand All @@ -614,8 +602,7 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
[{
build($_builder, $_state, container.getType(), container, value, position);
}]>];
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
Expand All @@ -629,8 +616,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let parser = [{ return parseShuffleVectorOp(parser, result); }];
let printer = [{ printShuffleVectorOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -709,8 +695,7 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
builder.CreateRetVoid();
}];

let parser = [{ return parseReturnOp(parser, result); }];
let printer = [{ printReturnOp(p, *this); }];
let assemblyFormat = "attr-dict ($args^ `:` type($args))?";
let hasVerifier = 1;
}
def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
Expand Down Expand Up @@ -1152,8 +1137,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
}
}];

let printer = "printGlobalOp(p, *this);";
let parser = "return parseGlobalOp(parser, result);";
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1289,8 +1273,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
LogicalResult verifyType();
}];

let printer = [{ printLLVMFuncOp(p, *this); }];
let parser = [{ return parseLLVMFuncOp(parser, result); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1825,8 +1808,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
llvm::MaybeAlign(),
getLLVMAtomicOrdering($ordering));
}];
let parser = [{ return parseAtomicRMWOp(parser, result); }];
let printer = [{ printAtomicRMWOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1855,8 +1837,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> {
getLLVMAtomicOrdering($success_ordering),
getLLVMAtomicOrdering($failure_ordering));
}];
let parser = [{ return parseAtomicCmpXchgOp(parser, result); }];
let printer = [{ printAtomicCmpXchgOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand All @@ -1878,8 +1859,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> {
builder.CreateFence(getLLVMAtomicOrdering($ordering),
llvmContext.getOrInsertSyncScopeID($syncscope));
}];
let parser = [{ return parseFenceOp(parser, result); }];
let printer = [{ printFenceOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def NVVM_VoteBallotOp :
$res = createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
}];
let parser = [{ return parseNVVMVoteBallotOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let hasCustomAssemblyFormat = 1;
}


Expand Down
14 changes: 2 additions & 12 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,7 @@ def ROCDL_MubufLoadOp :
llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc,
$slc}, {$_resultType});
}];
let parser = [{ return parseROCDLMubufLoadOp(parser, result); }];
let printer = [{
Operation *op = this->getOperation();
p << " " << op->getOperands()
<< " : " << op->getResultTypes();
}];
let hasCustomAssemblyFormat = 1;
}

def ROCDL_MubufStoreOp :
Expand All @@ -181,12 +176,7 @@ def ROCDL_MubufStoreOp :
llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
$offset, $glc, $slc}, {vdataType});
}];
let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];
let printer = [{
Operation *op = this->getOperation();
p << " " << op->getOperands()
<< " : " << vdata().getType();
}];
let hasCustomAssemblyFormat = 1;
}

#endif // ROCDLIR_OPS
13 changes: 4 additions & 9 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,7 @@ include "mlir/Interfaces/ViewLikeInterface.td"

// Base class for Linalg dialect ops that do not correspond to library calls.
class Linalg_Op<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, traits> {
// For every linalg op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<Linalg_Dialect, mnemonic, traits>;

def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
[NoSideEffect,
Expand Down Expand Up @@ -123,6 +115,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand All @@ -141,6 +134,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
```
}];
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -423,6 +417,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
}
}];

let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseGenericOp(parser, result); }];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down
22 changes: 8 additions & 14 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ def MemRefTypeAttr
}

class MemRef_Op<string mnemonic, list<Trait> traits = []>
: Op<MemRef_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
: Op<MemRef_Dialect, mnemonic, traits>;

// Base class for ops with static/dynamic offset, sizes and strides
// attributes/arguments.
Expand Down Expand Up @@ -275,6 +272,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",

let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$bodyRegion);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -374,11 +372,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let builders = [
OpBuilder<(ins "Value":$source, "Type":$destType), [{
impl::buildCastOp($_builder, $_state, source, destType);
}]>
];

let extraClassDeclaration = [{
/// Fold the given CastOp into consumer op.
Expand All @@ -388,7 +381,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];

let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -663,6 +655,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
return getOperand(getNumOperands() - 1);
}
}];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -777,6 +770,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
return memref().getType().cast<MemRefType>();
}
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1002,6 +996,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
static StringRef getIsDataCacheAttrName() { return "isDataCache"; }
}];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -1054,8 +1049,6 @@ def MemRef_ReinterpretCastOp
attr-dict `:` type($source) `to` type($result)
}];

let parser = ?;
let printer = ?;
let hasVerifier = 1;

let builders = [
Expand Down Expand Up @@ -1249,9 +1242,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :

let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}

def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
Expand Down Expand Up @@ -1706,6 +1698,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
ShapedType getShapedType() { return in().getType().cast<ShapedType>(); }
}];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -1776,6 +1769,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
8 changes: 3 additions & 5 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ include "mlir/Dialect/OpenACC/AccCommon.td"

// Base class for OpenACC dialect ops.
class OpenACC_Op<string mnemonic, list<Trait> traits = []> :
Op<OpenACC_Dialect, mnemonic, traits> {

let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<OpenACC_Dialect, mnemonic, traits>;

// Reduction operation enumeration.
def OpenACC_ReductionOpAdd : I32EnumAttrCase<"redop_add", 0>;
Expand Down Expand Up @@ -152,6 +148,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
}];
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -404,6 +401,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
static StringRef getPrivateKeyword() { return "private"; }
static StringRef getReductionKeyword() { return "reduction"; }
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
24 changes: 8 additions & 16 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments,
let builders = [
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let parser = [{ return parseParallelOp(parser, result); }];
let printer = [{ return printParallelOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -215,8 +214,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {

let regions = (region SizedRegion<1>:$region);

let parser = [{ return parseSectionsOp(parser, result); }];
let printer = [{ return printSectionsOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -334,8 +332,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return reduction_vars().size(); }
}];
let parser = [{ return parseWsLoopOp(parser, result); }];
let printer = [{ return printWsLoopOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -419,8 +416,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {

let regions = (region AnyRegion:$region);

let parser = [{ return parseTargetOp(parser, result); }];
let printer = [{ return printTargetOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
}


Expand Down Expand Up @@ -608,8 +604,7 @@ def AtomicReadOp : OpenMP_Op<"atomic.read"> {
OpenMP_PointerLikeType:$v,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKindAttr>:$memory_order);
let parser = [{ return parseAtomicReadOp(parser, result); }];
let printer = [{ return printAtomicReadOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -637,8 +632,7 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> {
AnyType:$value,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKindAttr>:$memory_order);
let parser = [{ return parseAtomicWriteOp(parser, result); }];
let printer = [{ return printAtomicWriteOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -702,8 +696,7 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update"> {
AtomicBinOpKindAttr:$binop,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKindAttr>:$memory_order);
let parser = [{ return parseAtomicUpdateOp(parser, result); }];
let printer = [{ return printAtomicUpdateOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -746,8 +739,7 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture",
let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKind>:$memory_order);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return parseAtomicCaptureOp(parser, result); }];
let printer = [{ return printAtomicCaptureOp(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
5 changes: 1 addition & 4 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//

class PDL_Op<string mnemonic, list<Trait> traits = []>
: Op<PDL_Dialect, mnemonic, traits> {
let printer = [{ ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
: Op<PDL_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// pdl::ApplyNativeConstraintOp
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,7 @@ def PDLInterp_ForEachOp
/// Returns the loop variable.
BlockArgument getLoopVariable() { return region().getArgument(0); }
}];
let parser = [{ return ::parseForEachOp(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
20 changes: 11 additions & 9 deletions mlir/include/mlir/Dialect/SCF/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ def SCF_Dialect : Dialect {

// Base class for SCF dialect ops.
class SCF_Op<string mnemonic, list<Trait> traits = []> :
Op<SCF_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<SCF_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//

def ConditionOp : SCF_Op<"condition", [
HasParent<"WhileOp">,
Expand Down Expand Up @@ -107,6 +103,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
let regions = (region AnyRegion:$region);

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;

let hasFolder = 0;
let hasVerifier = 1;
Expand Down Expand Up @@ -308,6 +305,7 @@ def ForOp : SCF_Op<"for",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -401,6 +399,7 @@ def IfOp : SCF_Op<"if",
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -483,6 +482,7 @@ def ParallelOp : SCF_Op<"parallel",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -531,6 +531,7 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
];

let arguments = (ins AnyType:$operand);
let hasCustomAssemblyFormat = 1;
let regions = (region SizedRegion<1>:$reductionOperator);
let hasVerifier = 1;
}
Expand Down Expand Up @@ -684,6 +685,7 @@ def WhileOp : SCF_Op<"while",
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
Expand Down
15 changes: 0 additions & 15 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

class SPV_AtomicUpdateOp<string mnemonic, list<Trait> traits = []> :
SPV_Op<mnemonic, traits> {
let parser = [{ return ::parseAtomicUpdateOp(parser, result, false); }];
let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }];

let arguments = (ins
SPV_AnyPtr:$pointer,
SPV_ScopeAttr:$memory_scope,
Expand All @@ -32,9 +29,6 @@ class SPV_AtomicUpdateOp<string mnemonic, list<Trait> traits = []> :

class SPV_AtomicUpdateWithValueOp<string mnemonic, list<Trait> traits = []> :
SPV_Op<mnemonic, traits> {
let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }];
let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }];

let arguments = (ins
SPV_AnyPtr:$pointer,
SPV_ScopeAttr:$memory_scope,
Expand Down Expand Up @@ -163,9 +157,6 @@ def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> {
let results = (outs
SPV_Integer:$result
);

let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
}

// -----
Expand Down Expand Up @@ -215,9 +206,6 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
let results = (outs
SPV_Integer:$result
);

let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
}

// -----
Expand Down Expand Up @@ -331,9 +319,6 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
let results = (outs
SPV_Float:$result
);

let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }];
let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }];
}

// -----
Expand Down
14 changes: 5 additions & 9 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4269,12 +4269,11 @@ class SPV_Op<string mnemonic, list<Trait> traits = []> :
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:
//
// * static ParseResult parse<op-c++-class-name>(OpAsmParser &parser,
// OperationState &result)
// * static void print(OpAsmPrinter &p, <op-c++-class-name> op)
// * ParseResult <op-c++-class-name>::parse(OpAsmParser &parser,
// OperationState &result)
// * void <op-c++-class-name>::print(OpAsmPrinter &p)
// * LogicalResult <op-c++-class-name>::verify()
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(*this, p); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

// Specifies whether this op has a direct corresponding SPIR-V binary
Expand Down Expand Up @@ -4320,8 +4319,7 @@ class SPV_UnaryOp<string mnemonic, Type resultType, Type operandType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return ::parseUnaryOp(parser, result); }];
let printer = [{ return ::printUnaryOp(getOperation(), p); }];
let assemblyFormat = "$operand `:` type($operand) attr-dict";
// No additional verification needed in addition to the ODS-generated ones.
let hasVerifier = 0;
}
Expand All @@ -4338,8 +4336,6 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let hasVerifier = 0;
}
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class SPV_BitBinaryOp<string mnemonic, list<Trait> traits = []> :
// All the operands type used in bit instructions are SPV_Integer.
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;
[NoSideEffect, SameOperandsAndResultType])> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}

class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
SPV_Op<mnemonic, !listconcat(traits,
Expand Down Expand Up @@ -51,9 +53,11 @@ class SPV_BitUnaryOp<string mnemonic, list<Trait> traits = []> :
class SPV_ShiftOp<string mnemonic, list<Trait> traits = []> :
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultShape])> {
let parser = [{ return ::parseShiftOp(parser, result); }];
let printer = [{ ::printShiftOp(this->getOperation(), p); }];
[NoSideEffect, SameOperandsAndResultShape,
AllTypesMatch<["operand1", "result"]>])> {
let assemblyFormat = [{
operands attr-dict `:` type($operand1) `,` type($operand2)
}];
let hasVerifier = 1;
}

Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
);

let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
}

// -----
Expand Down Expand Up @@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
SPV_ScalarOrVectorOrPtr:$result
);

let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];

let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasCanonicalizer = 1;
}

Expand Down
23 changes: 9 additions & 14 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@ class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return parseUnaryOp(parser, result); }];

let printer = [{ return printUnaryOp(getOperation(), p); }];

let hasVerifier = 0;
}

// Base class for GLSL Unary arithmetic ops where return type matches
// the operand type.
class SPV_GLSLUnaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_GLSLUnaryOp<mnemonic, opcode, type, type, traits>;
SPV_GLSLUnaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "$operand `:` type($operand) attr-dict";
}

// Base class for GLSL binary ops.
class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
Expand All @@ -72,18 +71,17 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];

let printer = [{ return impl::printOneResultOp(getOperation(), p); }];

let hasVerifier = 0;
}

// Base class for GLSL Binary arithmetic ops where operand types and
// return type matches.
class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
SPV_GLSLBinaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}

// Base class for GLSL ternary ops.
class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
Expand All @@ -100,10 +98,7 @@ class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
SPV_ScalarOrVectorOf<type>:$result
);

let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];

let printer = [{ return impl::printOneResultOp(getOperation(), p); }];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 0;
}

Expand Down
26 changes: 17 additions & 9 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ class SPV_LogicalBinaryOp<string mnemonic, Type operandsType,
list<Trait> traits = []> :
// Result type is SPV_Bool.
SPV_BinaryOp<mnemonic, SPV_Bool, operandsType,
!listconcat(traits,
[NoSideEffect, SameTypeOperands,
SameOperandsAndResultShape])> {
let parser = [{ return ::parseLogicalBinaryOp(parser, result); }];
let printer = [{ return ::printLogicalOp(getOperation(), p); }];
!listconcat(traits, [
NoSideEffect, SameTypeOperands,
SameOperandsAndResultShape,
TypesMatchWith<"type of result to correspond to the `i1` "
"equivalent of the operand",
"operand1", "result",
"getUnaryOpResultType($_self)"
>])> {
let assemblyFormat = "$operand1 `,` $operand2 `:` type($operand1) attr-dict";

let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs),
Expand All @@ -37,10 +41,14 @@ class SPV_LogicalUnaryOp<string mnemonic, Type operandType,
list<Trait> traits = []> :
// Result type is SPV_Bool.
SPV_UnaryOp<mnemonic, SPV_Bool, operandType,
!listconcat(traits, [NoSideEffect, SameTypeOperands,
SameOperandsAndResultShape])> {
let parser = [{ return ::parseLogicalUnaryOp(parser, result); }];
let printer = [{ return ::printLogicalOp(getOperation(), p); }];
!listconcat(traits, [
NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
TypesMatchWith<"type of result to correspond to the `i1` "
"equivalent of the operand",
"operand", "result",
"getUnaryOpResultType($_self)"
>])> {
let assemblyFormat = "$operand `:` type($operand) attr-dict";

let builders = [
OpBuilder<(ins "Value":$value),
Expand Down
3 changes: 0 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ class SPV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
let results = (outs
SPV_ScalarOrVectorOf<type>:$result
);

let parser = [{ return parseGroupNonUniformArithmeticOp(parser, result); }];
let printer = [{ printGroupNonUniformArithmeticOp(getOperation(), p); }];
}

// -----
Expand Down
16 changes: 7 additions & 9 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return parseUnaryOp(parser, result); }];

let printer = [{ return printUnaryOp(getOperation(), p); }];
let assemblyFormat = "$operand `:` type($operand) attr-dict";

let hasVerifier = 0;
}
Expand All @@ -55,7 +53,8 @@ class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
// the operand type.
class SPV_OCLUnaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_OCLUnaryOp<mnemonic, opcode, type, type, traits>;
SPV_OCLUnaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]>;

// Base class for OpenCL binary ops.
class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
Expand All @@ -71,18 +70,17 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];

let printer = [{ return impl::printOneResultOp(getOperation(), p); }];

let hasVerifier = 0;
}

// Base class for OpenCL Binary arithmetic ops where operand types and
// return type matches.
class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
SPV_OCLBinaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}

// -----

Expand Down
24 changes: 11 additions & 13 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define SHAPE_OPS

include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -121,9 +122,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape",
let arguments = (ins IndexElementsAttr:$shape);
let results = (outs Shape_ShapeOrExtentTensorType:$result);

// TODO: Move this to main so that all shape ops implement these.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;

Expand Down Expand Up @@ -331,7 +330,9 @@ def Shape_RankOp : Shape_Op<"rank",
}];
}

def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
]> {
let summary = "Creates a dimension tensor from a shape";
let description = [{
Converts a shape to a 1D integral tensor of extents. The number of elements
Expand Down Expand Up @@ -594,9 +595,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",

let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}

def Shape_ShapeOfOp : Shape_Op<"shape_of",
Expand Down Expand Up @@ -624,7 +624,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
}];
}

def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
]> {
let summary = "Casts between index types of the shape and standard dialect";
let description = [{
Converts a `shape.size` to a standard index. This operation and its
Expand Down Expand Up @@ -878,9 +880,6 @@ def Shape_AssumingOp : Shape_Op<"assuming", [
let regions = (region SizedRegion<1>:$doRegion);
let results = (outs Variadic<AnyType>:$results);

let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];

let extraClassDeclaration = [{
// Inline the region into the region containing the AssumingOp and delete
// the AssumingOp.
Expand All @@ -895,6 +894,7 @@ def Shape_AssumingOp : Shape_Op<"assuming", [
];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1081,9 +1081,7 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",

let builders = [OpBuilder<(ins "StringRef":$name)>];
let skipDefaultBuilders = 1;

let printer = [{ ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
}

#endif // SHAPE_OPS
5 changes: 1 addition & 4 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//

class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
: Op<SparseTensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
: Op<SparseTensor_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// Sparse Tensor Operations.
Expand Down
10 changes: 1 addition & 9 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,7 @@ def StandardOps_Dialect : Dialect {

// Base class for Standard dialect ops.
class Std_Op<string mnemonic, list<Trait> traits = []> :
Op<StandardOps_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<StandardOps_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// CallOp
Expand Down
8 changes: 2 additions & 6 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

class Tensor_Op<string mnemonic, list<Trait> traits = []>
: Op<Tensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
: Op<Tensor_Dialect, mnemonic, traits>;

// Base class for ops with static/dynamic offset, sizes and strides
// attributes/arguments.
Expand Down Expand Up @@ -737,9 +734,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :

let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}

def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
Expand Down
17 changes: 8 additions & 9 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,7 @@ def Vector_Dialect : Dialect {

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits> {
// For every vector op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Op<Vector_Dialect, mnemonic, traits>;

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : BitEnumAttrCaseBit<"ADD", 0, "add">;
Expand Down Expand Up @@ -253,6 +245,7 @@ def Vector_ContractionOp :
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -289,6 +282,7 @@ def Vector_ReductionOp :
return vector().getType().cast<VectorType>();
}
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -478,6 +472,7 @@ def Vector_ShuffleOp :
return vector().getType().cast<VectorType>();
}
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -559,6 +554,7 @@ def Vector_ExtractOp :
}
}];
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -968,6 +964,7 @@ def Vector_OuterProductOp :
return CombiningKind::ADD;
}
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1350,6 +1347,7 @@ def Vector_TransferReadOp :
CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
];
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -1489,6 +1487,7 @@ def Vector_TransferWriteOp :
];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/IR/BuiltinOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def FuncOp : Builtin_Op<"func", [

bool isDeclaration() { return isExternal(); }
}];
let parser = [{ return ::parseFuncOp(parser, result); }];
let printer = [{ return ::print(*this, p); }];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
16 changes: 13 additions & 3 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2442,14 +2442,24 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
// provided.
bit skipDefaultBuilders = 0;

// Custom parser.
// Custom parser and printer.
// NOTE: These fields are deprecated in favor of `assemblyFormat` or
// `hasCustomAssemblyFormat`, and are slated for deletion.
code parser = ?;

// Custom printer.
code printer = ?;

// Custom assembly format.
/// This field corresponds to a declarative description of the assembly format
/// for this operation. If populated, the `hasCustomAssemblyFormat` field is
/// ignored.
string assemblyFormat = ?;
/// This field indicates that the operation has a custom assembly format
/// implemented in C++. When set to `1` a `parse` and `print` method are generated
/// on the operation class. The operation should implement these methods to
/// support the custom format of the operation. The methods have the form:
/// * ParseResult parse(OpAsmParser &parser, OperationState &result)
/// * void print(OpAsmPrinter &p)
bit hasCustomAssemblyFormat = 0;

// A bit indicating if the operation has additional invariants that need to
// verified (aside from those verified by other ODS constructs). If set to `1`,
Expand Down
33 changes: 1 addition & 32 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -1897,26 +1897,9 @@ class OpInterface
};

//===----------------------------------------------------------------------===//
// Common Operation Folders/Parsers/Printers
// CastOpInterface utilities
//===----------------------------------------------------------------------===//

// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
namespace impl {
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
OperationState &result);

void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
Value rhs);
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result);

// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same time. Otherwise, prints the generic assembly
// form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl

// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
Expand All @@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op,
/// Attempt to verify the given cast operation.
LogicalResult verifyCastInterfaceOp(
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);

// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // namespace mlir

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,11 +1003,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
loc, value, IntegerType::get(rewriter.getContext(), 64));
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
loc, value, IntegerType::get(rewriter.getContext(), 64));
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::None:
break;
Expand Down
218 changes: 111 additions & 107 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1506,16 +1506,7 @@ OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

static void print(OpAsmPrinter &p, arith::SelectOp op) {
p << " " << op.getOperands();
p.printOptionalAttrDict(op->getAttrs());
p << " : ";
if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
p << condType << ", ";
p << op.getType();
}

static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
Type conditionType, resultType;
SmallVector<OpAsmParser::OperandType, 3> operands;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
Expand All @@ -1538,6 +1529,15 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
p << " " << getOperands();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : ";
if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
p << condType << ", ";
p << getType();
}

LogicalResult arith::SelectOp::verify() {
Type conditionType = getCondition().getType();
if (conditionType.isSignlessInteger(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ struct IndexCastOpInterface
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());

replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
resultType);
replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
source);
return success();
}
};
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Async/IR/Async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,31 +125,31 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
}
}

static void print(OpAsmPrinter &p, ExecuteOp op) {
void ExecuteOp::print(OpAsmPrinter &p) {
// [%tokens,...]
if (!op.dependencies().empty())
p << " [" << op.dependencies() << "]";
if (!dependencies().empty())
p << " [" << dependencies() << "]";

// (%value as %unwrapped: !async.value<!arg.type>, ...)
if (!op.operands().empty()) {
if (!operands().empty()) {
p << " (";
Block *entry = op.body().empty() ? nullptr : &op.body().front();
llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
Block *entry = body().empty() ? nullptr : &body().front();
llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable {
Value argument = entry ? entry->getArgument(n++) : Value();
p << operand << " as " << argument << ": " << operand.getType();
});
p << ")";
}

// -> (!async.value<!return.type>, ...)
p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{kOperandSegmentSizesAttr});
p << ' ';
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printRegion(body(), /*printEntryBlockArgs=*/false);
}

static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
MLIRContext *ctx = result.getContext();

// Sizes of parsed variadic operands, will be updated below after parsing.
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,15 +835,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
scalingFactor);
}
Value numWorkersIndex =
b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
Value scaledNumInt =
b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);

Value maxComputeBlocks = b.create<arith::MaxSIOp>(
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,18 @@ OpFoldResult emitc::ConstantOp::fold(ArrayRef<Attribute> operands) {
// IncludeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, IncludeOp &op) {
bool standardInclude = op.is_standard_include();
void IncludeOp::print(OpAsmPrinter &p) {
bool standardInclude = is_standard_include();

p << " ";
if (standardInclude)
p << "<";
p << "\"" << op.include() << "\"";
p << "\"" << include() << "\"";
if (standardInclude)
p << ">";
}

static ParseResult parseIncludeOp(OpAsmParser &parser, OperationState &result) {
ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
bool standardInclude = !parser.parseOptionalLess();

StringAttr include;
Expand Down
82 changes: 41 additions & 41 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,21 +445,21 @@ static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
p << size.z << " = " << operands.z << ')';
}

static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
void LaunchOp::print(OpAsmPrinter &p) {
// Print the launch configuration.
p << ' ' << op.getBlocksKeyword();
printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(),
op.getBlockIds());
p << ' ' << op.getThreadsKeyword();
printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
op.getThreadIds());
if (op.dynamicSharedMemorySize())
p << ' ' << op.getDynamicSharedMemorySizeKeyword() << ' '
<< op.dynamicSharedMemorySize();
p << ' ' << getBlocksKeyword();
printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
getBlockIds());
p << ' ' << getThreadsKeyword();
printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
getThreadIds());
if (dynamicSharedMemorySize())
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
<< dynamicSharedMemorySize();

p << ' ';
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op->getAttrs());
p.printRegion(body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
Expand Down Expand Up @@ -492,12 +492,14 @@ parseSizeAssignment(OpAsmParser &parser,
return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
// region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
/// Parses a Launch operation.
/// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in`
/// ssa-reassignment
/// `threads` `(` ssa-id-list `)` `in`
/// ssa-reassignment
/// region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
// Sizes of the grid and block.
SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
LaunchOp::kNumConfigOperands);
Expand Down Expand Up @@ -778,7 +780,7 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
/// (`->` function-result-list)? memory-attribution `kernel`?
/// function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType> entryArgs;
SmallVector<NamedAttrList> argAttrs;
SmallVector<NamedAttrList> resultAttrs;
Expand Down Expand Up @@ -853,27 +855,26 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
p << ')';
}

/// Prints a GPU Func op.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
void GPUFuncOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(op.getName());
p.printSymbolName(getName());

FunctionType type = op.getType();
function_interface_impl::printFunctionSignature(
p, op.getOperation(), type.getInputs(),
/*isVariadic=*/false, type.getResults());
FunctionType type = getType();
function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
/*isVariadic=*/false,
type.getResults());

printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
if (op.isKernel())
p << ' ' << op.getKernelKeyword();
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
if (isKernel())
p << ' ' << getKernelKeyword();

function_interface_impl::printFunctionAttributes(
p, op.getOperation(), type.getNumInputs(), type.getNumResults(),
{op.getNumWorkgroupAttributionsAttrName(),
p, *this, type.getNumInputs(), type.getNumResults(),
{getNumWorkgroupAttributionsAttrName(),
GPUDialect::getKernelFuncAttrName()});
p << ' ';
p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

LogicalResult GPUFuncOp::verifyType() {
Expand Down Expand Up @@ -970,10 +971,9 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

static ParseResult parseGPUModuleOp(OpAsmParser &parser,
OperationState &result) {
ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();

Expand All @@ -991,13 +991,13 @@ static ParseResult parseGPUModuleOp(OpAsmParser &parser,
return success();
}

static void print(OpAsmPrinter &p, GPUModuleOp op) {
void GPUModuleOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
{SymbolTable::getSymbolAttrName()});
p.printSymbolName(getName());
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{mlir::SymbolTable::getSymbolAttrName()});
p << ' ';
p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}

Expand Down
299 changes: 137 additions & 162 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
MLIRContext *context = parser.getContext();
auto int32Ty = IntegerType::get(context, 32);
auto int1Ty = IntegerType::get(context, 1);
Expand All @@ -62,6 +61,8 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

LogicalResult CpAsyncOp::verify() {
if (size() != 4 && size() != 8 && size() != 16)
return emitError("expected byte size to be either 4, 8 or 16.");
Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ using namespace ROCDL;
// <operation> ::=
// `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
// result_type`
static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
OperationState &result) {
ParseResult MubufLoadOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
if (parser.parseOperandList(ops, 5) || parser.parseColonType(type) ||
Expand All @@ -56,11 +55,14 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands);
}

void MubufLoadOp::print(OpAsmPrinter &p) {
p << " " << getOperands() << " : " << (*this)->getResultTypes();
}

// <operation> ::=
// `llvm.amdgcn.buffer.store.* %vdata, %rsrc, %vindex, %offset, %glc, %slc :
// result_type`
static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
OperationState &result) {
ParseResult MubufStoreOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
Expand All @@ -78,6 +80,10 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
return success();
}

void MubufStoreOp::print(OpAsmPrinter &p) {
p << " " << getOperands() << " : " << vdata().getType();
}

//===----------------------------------------------------------------------===//
// ROCDLDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
72 changes: 36 additions & 36 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,50 +517,51 @@ void GenericOp::build(
/*libraryCall=*/"", bodyBuild, attributes);
}

static void print(OpAsmPrinter &p, GenericOp op) {
void GenericOp::print(OpAsmPrinter &p) {
p << " ";

// Print extra attributes.
auto genericAttrNames = op.linalgTraitAttrNames();
auto genericAttrNames = linalgTraitAttrNames();

llvm::StringSet<> genericAttrNamesSet;
genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
SmallVector<NamedAttribute, 8> genericAttrs;
for (auto attr : op->getAttrs())
for (auto attr : (*this)->getAttrs())
if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
genericAttrs.push_back(attr);
if (!genericAttrs.empty()) {
auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
p << genericDictAttr;
}

// Printing is shared with named ops, except for the region and attributes
printCommonStructuredOpParts(p, op);
printCommonStructuredOpParts(p, *this);

genericAttrNames.push_back("operand_segment_sizes");
genericAttrNamesSet.insert(genericAttrNames.back());

bool hasExtraAttrs = false;
for (NamedAttribute n : op->getAttrs()) {
for (NamedAttribute n : (*this)->getAttrs()) {
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
break;
}
if (hasExtraAttrs) {
p << " attrs = ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames);
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/genericAttrNames);
}

// Print region.
if (!op.region().empty()) {
if (!region().empty()) {
p << ' ';
p.printRegion(op.region());
p.printRegion(region());
}

// Print results.
printNamedStructuredOpResults(p, op.result_tensors().getTypes());
printNamedStructuredOpResults(p, result_tensors().getTypes());
}

static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
DictionaryAttr dictAttr;
// Parse the core linalg traits that must check into a dictAttr.
// The name is unimportant as we will overwrite result.attributes.
Expand Down Expand Up @@ -988,15 +989,15 @@ LogicalResult InitTensorOp::reifyResultShapes(
// YieldOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, linalg::YieldOp op) {
if (op.getNumOperands() > 0)
p << ' ' << op.getOperands();
p.printOptionalAttrDict(op->getAttrs());
if (op.getNumOperands() > 0)
p << " : " << op.getOperandTypes();
void linalg::YieldOp::print(OpAsmPrinter &p) {
if (getNumOperands() > 0)
p << ' ' << getOperands();
p.printOptionalAttrDict((*this)->getAttrs());
if (getNumOperands() > 0)
p << " : " << getOperandTypes();
}

static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
SMLoc loc = parser.getCurrentLocation();
Expand Down Expand Up @@ -1137,48 +1138,47 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
}
}

static void print(OpAsmPrinter &p, TiledLoopOp op) {
p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to ("
<< op.upperBound() << ") step (" << op.step() << ")";
void TiledLoopOp::print(OpAsmPrinter &p) {
p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
<< upperBound() << ") step (" << step() << ")";

if (!op.inputs().empty()) {
if (!inputs().empty()) {
p << " ins (";
llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
llvm::interleaveComma(llvm::zip(getRegionInputArgs(), inputs()), p,
[&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it)
<< ": " << std::get<1>(it).getType();
});
p << ")";
}
if (!op.outputs().empty()) {
if (!outputs().empty()) {
p << " outs (";
llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), outputs()), p,
[&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it)
<< ": " << std::get<1>(it).getType();
});
p << ")";
}

if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
if (llvm::any_of(iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() !=
getParallelIteratorTypeName();
}))
p << " iterators" << op.iterator_types();
p << " iterators" << iterator_types();

if (op.distribution_types().hasValue())
p << " distribution" << op.distribution_types().getValue();
if (distribution_types().hasValue())
p << " distribution" << distribution_types().getValue();

p << ' ';
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
getIteratorTypesAttrName(),
getDistributionTypesAttrName()});
p.printRegion(region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
TiledLoopOp::getOperandSegmentSizeAttr(),
getIteratorTypesAttrName(),
getDistributionTypesAttrName()});
}

static ParseResult parseTiledLoopOp(OpAsmParser &parser,
OperationState &result) {
ParseResult TiledLoopOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::OperandType, 4> ivs;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
auto i32Vec = broadcast(builder.getI32Type(), shape);

// exp2(k)
Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
Value exp2KValue = exp2I32(builder, k);

// exp(x) = exp(y) * exp2(k)
Expand Down Expand Up @@ -1042,7 +1042,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(

auto i32Vec = broadcast(builder.getI32Type(), shape);
auto fPToSingedInteger = [&](Value a) -> Value {
return builder.create<arith::FPToSIOp>(a, i32Vec);
return builder.create<arith::FPToSIOp>(i32Vec, a);
};

auto modulo4 = [&](Value a) -> Value {
Expand Down
123 changes: 60 additions & 63 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}

LogicalResult memref::CastOp::verify() {
return impl::verifyCastOp(*this, areCastCompatible);
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -165,7 +161,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
alloc.alignmentAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast =
rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);

rewriter.replaceOp(alloc, {resultCast});
return success();
Expand Down Expand Up @@ -210,23 +206,22 @@ void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
void AllocaScopeOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;

p << ' ';
if (!op.results().empty()) {
p << " -> (" << op.getResultTypes() << ")";
if (!results().empty()) {
p << " -> (" << getResultTypes() << ")";
printBlockTerminators = true;
}
p << ' ';
p.printRegion(op.bodyRegion(),
p.printRegion(bodyRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
p.printOptionalAttrDict(op->getAttrs());
p.printOptionalAttrDict((*this)->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
OperationState &result) {
ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
// Create a region for the body.
result.regions.reserve(1);
Region *bodyRegion = result.addRegion();
Expand Down Expand Up @@ -782,17 +777,16 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result,
result.addOperands({stride, elementsPerStride});
}

static void print(OpAsmPrinter &p, DmaStartOp op) {
p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
<< op.getDstMemRef() << '[' << op.getDstIndices() << "], "
<< op.getNumElements() << ", " << op.getTagMemRef() << '['
<< op.getTagIndices() << ']';
if (op.isStrided())
p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
void DmaStartOp::print(OpAsmPrinter &p) {
p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
<< getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
<< ", " << getTagMemRef() << '[' << getTagIndices() << ']';
if (isStrided())
p << ", " << getStride() << ", " << getNumElementsPerStride();

p.printOptionalAttrDict(op->getAttrs());
p << " : " << op.getSrcMemRef().getType() << ", "
<< op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
<< ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
Expand All @@ -803,8 +797,7 @@ static void print(OpAsmPrinter &p, DmaStartOp op) {
// memref<1024 x f32, 2>,
// memref<1 x i32>
//
static ParseResult parseDmaStartOp(OpAsmParser &parser,
OperationState &result) {
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcMemRefInfo;
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
OpAsmParser::OperandType dstMemRefInfo;
Expand Down Expand Up @@ -997,8 +990,8 @@ LogicalResult GenericAtomicRMWOp::verify() {
return hasSideEffects ? failure() : success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType memref;
Type memrefType;
SmallVector<OpAsmParser::OperandType, 4> ivs;
Expand All @@ -1019,11 +1012,11 @@ static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
p << ' ' << op.memref() << "[" << op.indices()
<< "] : " << op.memref().getType() << ' ';
p.printRegion(op.getRegion());
p.printOptionalAttrDict(op->getAttrs());
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
<< ' ';
p.printRegion(getRegion());
p.printOptionalAttrDict((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1167,20 +1160,19 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
p << " " << op.memref() << '[';
p.printOperands(op.indices());
p << ']' << ", " << (op.isWrite() ? "write" : "read");
p << ", locality<" << op.localityHint();
p << ">, " << (op.isDataCache() ? "data" : "instr");
void PrefetchOp::print(OpAsmPrinter &p) {
p << " " << memref() << '[';
p.printOperands(indices());
p << ']' << ", " << (isWrite() ? "write" : "read");
p << ", locality<" << localityHint();
p << ">, " << (isDataCache() ? "data" : "instr");
p.printOptionalAttrDict(
op->getAttrs(),
(*this)->getAttrs(),
/*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
p << " : " << op.getMemRefType();
p << " : " << getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
OperationState &result) {
ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
IntegerAttr localityHint;
Expand Down Expand Up @@ -1378,12 +1370,19 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}

static void print(OpAsmPrinter &p, ExpandShapeOp op) {
::mlir::printReshapeOp<ExpandShapeOp>(p, op);
ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void ExpandShapeOp::print(OpAsmPrinter &p) {
::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
}

static void print(OpAsmPrinter &p, CollapseShapeOp op) {
::mlir::printReshapeOp<CollapseShapeOp>(p, op);
ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void CollapseShapeOp::print(OpAsmPrinter &p) {
::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
}

/// Detect whether memref dims [dim, dim + extent) can be reshaped without
Expand Down Expand Up @@ -2156,8 +2155,8 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
rewriter.replaceOp(subViewOp, subViewOp.source());
return success();
}
rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
subViewOp.getType());
rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
subViewOp.source());
return success();
}
};
Expand All @@ -2177,7 +2176,7 @@ struct SubViewReturnTypeCanonicalizer {
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
};

Expand Down Expand Up @@ -2245,15 +2244,13 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
p << " " << op.in() << " " << op.permutation();
p.printOptionalAttrDict(op->getAttrs(),
{TransposeOp::getPermutationAttrName()});
p << " : " << op.in().getType() << " to " << op.getType();
void TransposeOp::print(OpAsmPrinter &p) {
p << " " << in() << " " << permutation();
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
p << " : " << in().getType() << " to " << getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
OperationState &result) {
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType in;
AffineMap permutation;
MemRefType srcType, dstType;
Expand Down Expand Up @@ -2296,7 +2293,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
ParseResult ViewOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcInfo;
SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
Expand All @@ -2321,12 +2318,12 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
p << ' ' << op.getOperand(0) << '[';
p.printOperand(op.byte_shift());
p << "][" << op.sizes() << ']';
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
void ViewOp::print(OpAsmPrinter &p) {
p << ' ' << getOperand(0) << '[';
p.printOperand(byte_shift());
p << "][" << sizes() << ']';
p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getOperand(0).getType() << " to " << getType();
}

LogicalResult ViewOp::verify() {
Expand Down Expand Up @@ -2422,7 +2419,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
viewOp.getOperand(0),
viewOp.byte_shift(), newOperands);
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
if (!size.getType().isa<IndexType>())
size = rewriter.create<arith::IndexCastOp>(loc, size,
rewriter.getIndexType());
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
Expand Down
89 changes: 41 additions & 48 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
/// `private` `(` value-list `)`?
/// `firstprivate` `(` value-list `)`?
/// region attr-dict?
static ParseResult parseParallelOp(OpAsmParser &parser,
OperationState &result) {
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
SmallVector<OpAsmParser::OperandType, 8> privateOperands,
firstprivateOperands, copyOperands, copyinOperands,
Expand Down Expand Up @@ -390,99 +389,94 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
return success();
}

static void print(OpAsmPrinter &printer, ParallelOp &op) {
void ParallelOp::print(OpAsmPrinter &printer) {
// async()?
if (Value async = op.async())
if (Value async = this->async())
printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
<< async.getType() << ")";

// wait()?
printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer);

// num_gangs()?
if (Value numGangs = op.numGangs())
if (Value numGangs = this->numGangs())
printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
<< ": " << numGangs.getType() << ")";

// num_workers()?
if (Value numWorkers = op.numWorkers())
if (Value numWorkers = this->numWorkers())
printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
<< ": " << numWorkers.getType() << ")";

// vector_length()?
if (Value vectorLength = op.vectorLength())
if (Value vectorLength = this->vectorLength())
printer << " " << ParallelOp::getVectorLengthKeyword() << "("
<< vectorLength << ": " << vectorLength.getType() << ")";

// if()?
if (Value ifCond = op.ifCond())
if (Value ifCond = this->ifCond())
printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";

// self()?
if (Value selfCond = op.selfCond())
if (Value selfCond = this->selfCond())
printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";

// reduction()?
printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(),
printer);

// copy()?
printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer);

// copyin()?
printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
printer);
printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer);

// copyin_readonly()?
printOperandList(op.copyinReadonlyOperands(),
printOperandList(copyinReadonlyOperands(),
ParallelOp::getCopyinReadonlyKeyword(), printer);

// copyout()?
printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
printer);
printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer);

// copyout_zero()?
printOperandList(op.copyoutZeroOperands(),
ParallelOp::getCopyoutZeroKeyword(), printer);
printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(),
printer);

// create()?
printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
printer);
printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer);

// create_zero()?
printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
printer);

// no_create()?
printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(),
printer);

// present()?
printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
printer);
printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer);

// deviceptr()?
printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
printer);

// attach()?
printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
printer);
printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer);

// private()?
printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
printer);

// firstprivate()?
printOperandList(op.gangFirstPrivateOperands(),
printOperandList(gangFirstPrivateOperands(),
ParallelOp::getFirstPrivateKeyword(), printer);

printer << ' ';
printer.printRegion(op.region(),
printer.printRegion(region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
printer.printOptionalAttrDictWithKeyword(
op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
(*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
}

unsigned ParallelOp::getNumDataOperands() {
Expand Down Expand Up @@ -518,7 +512,7 @@ Value ParallelOp::getDataOperand(unsigned i) {
/// (`private` `(` value-list `)`)?
/// (`reduction` `(` value-list `)`)?
/// region attr-dict?
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
unsigned executionMapping = OpenACCExecMapping::NONE;
SmallVector<Type, 8> operandTypes;
Expand Down Expand Up @@ -606,12 +600,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
return success();
}

static void print(OpAsmPrinter &printer, LoopOp &op) {
unsigned execMapping = op.exec_mapping();
void LoopOp::print(OpAsmPrinter &printer) {
unsigned execMapping = exec_mapping();
if (execMapping & OpenACCExecMapping::GANG) {
printer << " " << LoopOp::getGangKeyword();
Value gangNum = op.gangNum();
Value gangStatic = op.gangStatic();
Value gangNum = this->gangNum();
Value gangStatic = this->gangStatic();

// Print optional gang operands
if (gangNum || gangStatic) {
Expand All @@ -633,39 +627,38 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
printer << " " << LoopOp::getWorkerKeyword();

// Print optional worker operand if present
if (Value workerNum = op.workerNum())
if (Value workerNum = this->workerNum())
printer << "(" << workerNum << ": " << workerNum.getType() << ")";
}

if (execMapping & OpenACCExecMapping::VECTOR) {
printer << " " << LoopOp::getVectorKeyword();

// Print optional vector operand if present
if (Value vectorLength = op.vectorLength())
if (Value vectorLength = this->vectorLength())
printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
}

// tile()?
printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer);

// private()?
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer);

// reduction()?
printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
printer);
printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer);

if (op.getNumResults() > 0)
printer << " -> (" << op.getResultTypes() << ")";
if (getNumResults() > 0)
printer << " -> (" << getResultTypes() << ")";

printer << ' ';
printer.printRegion(op.region(),
printer.printRegion(region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);

printer.printOptionalAttrDictWithKeyword(
op->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
LoopOp::getOperandSegmentSizeAttr()});
(*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
LoopOp::getOperandSegmentSizeAttr()});
}

LogicalResult acc::LoopOp::verify() {
Expand Down
Loading