diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index ef6682ab3630c..acb6467132be9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -15,6 +15,7 @@ #define MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -187,7 +188,8 @@ def SPIRV_BranchConditionalOp : SPIRV_Op<"BranchConditional", [ // ----- def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [ - InFunctionScope, DeclareOpInterfaceMethods]> { + InFunctionScope, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call a function."; let description = [{ diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 890406df74e72..f0b46e61965f4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -151,10 +151,20 @@ LogicalResult BranchConditionalOp::verify() { //===----------------------------------------------------------------------===// LogicalResult FunctionCallOp::verify() { + if (getNumResults() > 1) { + return emitOpError( + "expected callee function to have 0 or 1 result, but provided ") + << getNumResults(); + } + return success(); +} + +LogicalResult +FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto fnName = getCalleeAttr(); - auto funcOp = dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); + auto funcOp = + symbolTable.lookupNearestSymbolFrom(*this, fnName); if (!funcOp) { return emitOpError("callee function '") << fnName.getValue() << "' not found in nearest symbol table"; @@ -162,12 +172,6 @@ LogicalResult FunctionCallOp::verify() { auto functionType = funcOp.getFunctionType(); - if (getNumResults() > 1) { - return emitOpError( - "expected callee function to have 0 or 1 result, but provided ") - << getNumResults(); - } - if (functionType.getNumInputs() != getNumOperands()) { return emitOpError("has incorrect number of operands for callee: expected ") << functionType.getNumInputs() << ", but provided " diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 8ec0bf5bbaacf..8e29ff6679068 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -262,6 +262,35 @@ spirv.module Logical GLSL450 { // ----- +"builtin.module"() ({ + "spirv.module"() <{ + addressing_model = #spirv.addressing_model, + memory_model = #spirv.memory_model + }> ({ + "spirv.func"() <{ + function_control = #spirv.function_control, + function_type = (f32) -> f32, + sym_name = "bar" + }> ({ + ^bb0(%arg0: f32): + %0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32 + "spirv.ReturnValue"(%0) : (f32) -> () + }) : () -> () + // expected-error @+1 {{requires attribute 'function_type'}} + "spirv.func"() <{ + function_control = #spirv.function_control, + message = "2nd parent", + sym_name = "foo" + // This is invalid MLIR because function_type is missing from spirv.func. + }> ({ + ^bb0(%arg0: f32): + "spirv.ReturnValue"(%arg0) : (f32) -> () + }) : () -> () + }) : () -> () +}) : () -> () + +// ----- + //===----------------------------------------------------------------------===// // spirv.mlir.loop //===----------------------------------------------------------------------===//