diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index 42c431d13f8ec..0c0b08509ff43 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -626,7 +626,8 @@ let verifier = [{ ``` Code placed in `verifier` will be called after the auto-generated verification -code. +code. The order of trait verification excluding those of `verifier` should not +be relied upon. ### Declarative Assembly Format diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index b0cc4dd7a6eb3..36b2ee9b5a8ae 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -254,7 +254,7 @@ func @reduce_op_and_body(%arg0 : f32) { // ----- func @reduce_invalid_op(%arg0 : f32) { - // expected-error@+1 {{gpu.all_reduce' op attribute 'op' failed to satisfy constraint}} + // expected-error@+1 {{attribute 'op' failed to satisfy constraint}} %res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32) return } @@ -321,14 +321,14 @@ func @reduce_incorrect_yield(%arg0 : f32) { // ----- func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) { - // expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}} + // expected-error@+1 {{requires the same type for value operand and result}} %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1) } // ----- func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) { - // expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}} + // expected-error@+1 {{requires value operand type to be f32 or i32}} %shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index } diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir index 0b97a8ebb1e55..b5b5639a5bd9e 100644 --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -65,12 +65,12 @@ func @references() { // ----- -// expected-error @+1 {{op requires string attribute 'sym_name'}} +// expected-error @+1 {{requires string attribute 'sym_name'}} "llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> () // ----- -// expected-error @+1 {{op requires attribute 'type'}} +// expected-error @+1 {{requires attribute 'type'}} "llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> () // ----- diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir index ca3f603115767..04153162e0dc9 100644 --- a/mlir/test/Dialect/SPIRV/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -124,7 +124,7 @@ func @composite_extract_invalid_index_type_1() -> () { // ----- func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}} + // expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}} %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>> return } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 1f6da8190baeb..52d0586e98f2c 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1069,7 +1069,7 @@ func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 { // ----- func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 { - // expected-error@+1 {{'vector.reduction' op attribute 'kind' failed to satisfy constraint: string attribute}} + // expected-error@+1 {{attribute 'kind' failed to satisfy constraint: string attribute}} %0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32 } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index c8e40c520139d..1ccf322ee8b57 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -58,7 +58,7 @@ func @constant_wrong_type() { func @affine_apply_no_map() { ^bb0: %i = constant 0 : index - %x = "affine.apply" (%i) { } : (index) -> (index) // expected-error {{'affine.apply' op requires attribute 'map'}} + %x = "affine.apply" (%i) { } : (index) -> (index) // expected-error {{requires attribute 'map'}} return } @@ -1205,7 +1205,7 @@ func @assume_alignment(%0: memref<4x4xf16>) { // 0 alignment value. func @assume_alignment(%0: memref<4x4xf16>) { - // expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} std.assume_alignment %0, 0 : memref<4x4xf16> return } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 522dc2459fcaf..b4c850269a1d2 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -30,6 +30,20 @@ def AOp : NS_Op<"a_op", []> { // DEF-LABEL: AOp definitions +// Test verify method +// --- + +// DEF: LogicalResult AOpOperandAdaptor::verify +// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); +// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); +// DEF: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); +// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr"); +// DEF-NEXT: if (tblgen_bAttr) { +// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); +// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr"); +// DEF-NEXT: if (tblgen_cAttr) { +// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); + // Test getter methods // --- @@ -80,20 +94,6 @@ def AOp : NS_Op<"a_op", []> { // DEF: ArrayRef attributes // DEF: odsState.addAttributes(attributes); -// Test verify method -// --- - -// DEF: AOp::verify() -// DEF: auto tblgen_aAttr = this->getAttr("aAttr"); -// DEF-NEXT: if (!tblgen_aAttr) return emitOpError("requires attribute 'aAttr'"); -// DEF: if (!((some-condition))) return emitOpError("attribute 'aAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_bAttr = this->getAttr("bAttr"); -// DEF-NEXT: if (tblgen_bAttr) { -// DEF-NEXT: if (!((some-condition))) return emitOpError("attribute 'bAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_cAttr = this->getAttr("cAttr"); -// DEF-NEXT: if (tblgen_cAttr) { -// DEF-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind"); - def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">; def BOp : NS_Op<"b_op", []> { @@ -114,27 +114,11 @@ def BOp : NS_Op<"b_op", []> { ); } -// Test common attribute kind getters' return types -// --- - -// DEF: Attribute BOp::any_attr() -// DEF: bool BOp::bool_attr() -// DEF: APInt BOp::i32_attr() -// DEF: APInt BOp::i64_attr() -// DEF: APFloat BOp::f32_attr() -// DEF: APFloat BOp::f64_attr() -// DEF: StringRef BOp::str_attr() -// DEF: ElementsAttr BOp::elements_attr() -// DEF: StringRef BOp::function_attr() -// DEF: SomeType BOp::type_attr() -// DEF: ArrayAttr BOp::array_attr() -// DEF: ArrayAttr BOp::some_attr_array() -// DEF: Type BOp::type_attr() // Test common attribute kinds' constraints // --- -// DEF-LABEL: BOp::verify +// DEF-LABEL: BOpOperandAdaptor::verify // DEF: if (!((true))) // DEF: if (!((tblgen_bool_attr.isa()))) // DEF: if (!(((tblgen_i32_attr.isa())) && ((tblgen_i32_attr.cast().getType().isSignlessInteger(32))))) @@ -149,6 +133,23 @@ def BOp : NS_Op<"b_op", []> { // DEF: if (!(((tblgen_some_attr_array.isa())) && (llvm::all_of(tblgen_some_attr_array.cast(), [](Attribute attr) { return (some-condition); })))) // DEF: if (!(((tblgen_type_attr.isa())) && ((tblgen_type_attr.cast().getValue().isa())))) +// Test common attribute kind getters' return types +// --- + +// DEF: Attribute BOp::any_attr() +// DEF: bool BOp::bool_attr() +// DEF: APInt BOp::i32_attr() +// DEF: APInt BOp::i64_attr() +// DEF: APFloat BOp::f32_attr() +// DEF: APFloat BOp::f64_attr() +// DEF: StringRef BOp::str_attr() +// DEF: ElementsAttr BOp::elements_attr() +// DEF: StringRef BOp::function_attr() +// DEF: SomeType BOp::type_attr() +// DEF: ArrayAttr BOp::array_attr() +// DEF: ArrayAttr BOp::some_attr_array() +// DEF: Type BOp::type_attr() + // Test building constant values for array attribute kinds // --- diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index aa7b50710cde5..a617208d157a0 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -1,4 +1,4 @@ -// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --dump-input-on-failure include "mlir/IR/OpBase.td" @@ -32,41 +32,41 @@ def OpF : NS_Op<"op_for_int_min_val", []> { let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpF::verify() +// CHECK-LABEL: OpFOperandAdaptor::verify // CHECK: (tblgen_attr.cast().getInt() >= 10) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"); +// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10" def OpFX : NS_Op<"op_for_int_max_val", []> { let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpFX::verify() +// CHECK-LABEL: OpFXOperandAdaptor::verify // CHECK: (tblgen_attr.cast().getInt() <= 10) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"); +// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10" def OpG : NS_Op<"op_for_arr_min_count", []> { let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpG::verify() +// CHECK-LABEL: OpGOperandAdaptor::verify // CHECK: (tblgen_attr.cast().size() >= 8) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"); +// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements" def OpH : NS_Op<"op_for_arr_value_at_index", []> { let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpH::verify() +// CHECK-LABEL: OpHOperandAdaptor::verify // CHECK: (((tblgen_attr.cast().size() > 0)) && ((tblgen_attr.cast()[0].cast().getInt() == 8))))) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"); +// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8" def OpI: NS_Op<"op_for_arr_min_value_at_index", []> { let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpI::verify() +// CHECK-LABEL: OpIOperandAdaptor::verify // CHECK: (((tblgen_attr.cast().size() > 0)) && ((tblgen_attr.cast()[0].cast().getInt() >= 8))))) -// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"); +// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8" def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [ PredOpTrait<"operands indexed at 0, 2, 3 should all have " @@ -80,11 +80,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [ ); } -// CHECK-LABEL: OpJ::verify() +// CHECK-LABEL: OpJOperandAdaptor::verify // CHECK: llvm::is_splat(llvm::map_range( // CHECK-SAME: llvm::ArrayRef({0, 2, 3}), // CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); })) -// CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type"); +// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type" def OpK : NS_Op<"op_for_AnyTensorOf", []> { let arguments = (ins TensorOf<[F32, I32]>:$x); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7b0cd9d7a4826..21dccd4f3d5ad 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -321,6 +321,116 @@ class OpEmitter { }; } // end anonymous namespace +// Populate the format context `ctx` with substitutions of attributes, operands +// and results. +// - attrGet corresponds to the name of the function to call to get value of +// attribute (the generated function call returns an Attribute); +// - operandGet corresponds to the name of the function with which to retrieve +// an operand (the generaed function call returns an OperandRange); +// - reultGet corresponds to the name of the function to get an result (the +// generated function call returns a ValueRange); +static void populateSubstitutions(const Operator &op, const char *attrGet, + const char *operandGet, const char *resultGet, + FmtContext &ctx) { + // Populate substitutions for attributes and named operands. + for (const auto &namedAttr : op.getAttributes()) + ctx.addSubst(namedAttr.name, + formatv("{0}(\"{1}\")", attrGet, namedAttr.name)); + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + auto &value = op.getOperand(i); + if (value.name.empty()) + continue; + + if (value.isVariadic()) + ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i)); + else + ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i)); + } + + // Populate substitutions for results. + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + auto &value = op.getResult(i); + if (value.name.empty()) + continue; + + if (value.isVariadic()) + ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i)); + else + ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i)); + } +} + +// Generate attribute verification. If emitVerificationRequiringOp is set then +// only verification for attributes whose value depend on op being known are +// emitted, else only verification that doesn't depend on the op being known are +// generated. +// - emitErrorPrefix is the prefix for the error emitting call which consists +// of the entire function call up to start of error message fragment; +// - emitVerificationRequiringOp specifies whether verification should be +// emitted for verification that require the op to exist; +static void genAttributeVerifier(const Operator &op, const char *attrGet, + const Twine &emitErrorPrefix, + bool emitVerificationRequiringOp, + FmtContext &ctx, OpMethodBody &body) { + for (const auto &namedAttr : op.getAttributes()) { + const auto &attr = namedAttr.attr; + if (attr.isDerivedAttr()) + continue; + + auto attrName = namedAttr.name; + bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); + auto attrPred = attr.getPredicate(); + auto condition = attrPred.isNull() ? "" : attrPred.getCondition(); + // There is a condition to emit only if the use of $_op and whether to + // emit verifications for op matches. + bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^ + emitVerificationRequiringOp); + + // Prefix with `tblgen_` to avoid hiding the attribute accessor. + auto varName = tblgenNamePrefix + attrName; + + // If the attribute is + // 1. Required (not allowed missing) and not in op verification, or + // 2. Has a condition that will get verified + // then the variable will be used. + // + // Therefore, for optional attributes whose verification requires that an + // op already exists for verification/emitVerificationRequiringOp is set + // has nothing that can be verified here. + if ((allowMissingAttr || emitVerificationRequiringOp) && + !hasConditionToEmit) + continue; + + body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet, + attrName); + + if (!emitVerificationRequiringOp && !allowMissingAttr) { + body << " if (!" << varName << ") return " << emitErrorPrefix + << "\"requires attribute '" << attrName << "'\");\n"; + } + + if (!hasConditionToEmit) { + body << " }\n"; + continue; + } + + if (allowMissingAttr) { + // If the attribute has a default value, then only verify the predicate if + // set. This does effectively assume that the default value is valid. + // TODO: verify the debug value is valid (perhaps in debug mode only). + body << " if (" << varName << ") {\n"; + } + + body << tgfmt(" if (!($0)) return $1\"attribute '$2' " + "failed to satisfy constraint: $3\");\n", + /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)), + emitErrorPrefix, attrName, attr.getDescription()); + if (allowMissingAttr) + body << " }\n"; + body << " }\n"; + } +} + OpEmitter::OpEmitter(const Operator &op) : def(op.getDef()), op(op), opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { @@ -1512,110 +1622,27 @@ void OpEmitter::genPrinter() { } void OpEmitter::genVerifier() { - auto valueInit = def.getValueInit("verifier"); - CodeInit *codeInit = dyn_cast(valueInit); - bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); - auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); auto &body = method.body(); + body << " if (failed(" << op.getAdaptorName() + << "(*this).verify(this->getLoc()))) " + << "return failure();\n"; - const char *checkAttrSizedValueSegmentsCode = R"( - { - auto sizeAttr = getAttrOfType("{0}"); - auto numElements = sizeAttr.getType().cast().getNumElements(); - if (numElements != {1}) {{ - return emitOpError("'{0}' attribute for specifying {2} segments " - "must have {1} elements"); - } - } - )"; - - // Verify a few traits first so that we can use - // getODSOperands()/getODSResults() in the rest of the verifier. - for (auto &trait : op.getTraits()) { - if (auto *t = dyn_cast(&trait)) { - if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") { - body << formatv(checkAttrSizedValueSegmentsCode, - "operand_segment_sizes", op.getNumOperands(), - "operand"); - } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") { - body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", - op.getNumResults(), "result"); - } - } - } - - // Populate substitutions for attributes and named operands and results. - for (const auto &namedAttr : op.getAttributes()) - verifyCtx.addSubst(namedAttr.name, - formatv("this->getAttr(\"{0}\")", namedAttr.name)); - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - auto &value = op.getOperand(i); - if (value.name.empty()) - continue; - - if (value.isVariadic()) - verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i)); - else - verifyCtx.addSubst(value.name, - formatv("(*this->getODSOperands({0}).begin())", i)); - } - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - auto &value = op.getResult(i); - if (value.name.empty()) - continue; - - if (value.isVariadic()) - verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i)); - else - verifyCtx.addSubst(value.name, - formatv("(*this->getODSResults({0}).begin())", i)); - } - - // Verify the attributes have the correct type. - for (const auto &namedAttr : op.getAttributes()) { - const auto &attr = namedAttr.attr; - if (attr.isDerivedAttr()) - continue; - - auto attrName = namedAttr.name; - // Prefix with `tblgen_` to avoid hiding the attribute accessor. - auto varName = tblgenNamePrefix + attrName; - body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, - attrName); - - bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); - if (allowMissingAttr) { - // If the attribute has a default value, then only verify the predicate if - // set. This does effectively assume that the default value is valid. - // TODO: verify the debug value is valid (perhaps in debug mode only). - body << " if (" << varName << ") {\n"; - } else { - body << " if (!" << varName - << ") return emitOpError(\"requires attribute '" << attrName - << "'\");\n {\n"; - } - - auto attrPred = attr.getPredicate(); - if (!attrPred.isNull()) { - body << tgfmt( - " if (!($0)) return emitOpError(\"attribute '$1' " - "failed to satisfy constraint: $2\");\n", - /*ctx=*/nullptr, - tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)), - attrName, attr.getDescription()); - } - - body << " }\n"; - } + auto *valueInit = def.getValueInit("verifier"); + CodeInit *codeInit = dyn_cast(valueInit); + bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); + populateSubstitutions(op, "this->getAttr", "this->getODSOperands", + "this->getODSResults", verifyCtx); + genAttributeVerifier(op, "this->getAttr", "emitOpError(", + /*emitVerificationRequiringOp=*/true, verifyCtx, body); genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); for (auto &trait : op.getTraits()) { if (auto *t = dyn_cast(&trait)) { - body << tgfmt(" if (!($0)) {\n " - "return emitOpError(\"failed to verify that $1\");\n }\n", + body << tgfmt(" if (!($0))\n " + "return emitOpError(\"failed to verify that $1\");\n", &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), t->getDescription()); } @@ -1890,12 +1917,17 @@ class OpOperandAdaptorEmitter { private: explicit OpOperandAdaptorEmitter(const Operator &op); + // Add verification function. This generates a verify method for the adaptor + // which verifies all the op-independent attribute constraints. + void addVerification(); + + const Operator &op; Class adaptor; }; } // end namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) - : adaptor(op.getAdaptorName()) { + : op(op), adaptor(op.getAdaptorName()) { adaptor.newField("ValueRange", "odsOperands"); adaptor.newField("DictionaryAttr", "odsAttrs"); const auto *attrSizedOperands = @@ -1957,6 +1989,50 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) if (!attr.isDerivedAttr()) emitAttr(name, attr); } + + // Add verification function. + addVerification(); +} + +void OpOperandAdaptorEmitter::addVerification() { + auto &method = adaptor.newMethod("LogicalResult", "verify", + /*params=*/"Location loc"); + auto &body = method.body(); + + const char *checkAttrSizedValueSegmentsCode = R"( + { + auto sizeAttr = odsAttrs.get("{0}").cast(); + auto numElements = sizeAttr.getType().cast().getNumElements(); + if (numElements != {1}) + return emitError(loc, "'{0}' attribute for specifying {2} segments " + "must have {1} elements"); + } + )"; + + // Verify a few traits first so that we can use + // getODSOperands()/getODSResults() in the rest of the verifier. + for (auto &trait : op.getTraits()) { + if (auto *t = dyn_cast(&trait)) { + if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") { + body << formatv(checkAttrSizedValueSegmentsCode, + "operand_segment_sizes", op.getNumOperands(), + "operand"); + } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") { + body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", + op.getNumResults(), "result"); + } + } + } + + FmtContext verifyCtx; + populateSubstitutions(op, "odsAttrs.get", "getODSOperands", + "", verifyCtx); + genAttributeVerifier(op, "odsAttrs.get", + Twine("emitError(loc, \"'") + op.getOperationName() + + "' op \"", + /*emitVerificationRequiringOp*/ false, verifyCtx, body); + + body << " return success();"; } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {