-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Use verifySymbolUses
for spirv.FunctionCall
.
#159399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Use verifySymbolUses
for spirv.FunctionCall
.
#159399
Conversation
`spirv.FunctionCall`'s verifier was being too aggressive. It included verification of non-local properties by looking at the callee's definition. This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing. According to MLIR's developers guide TLDR: only verify local aspects of an operation, in particular don’t follow def-use chains (don’t look at the producer of any operand or the user of any results). https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Erick Ochoa Lopez (amd-eochoalo) Changes
This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing. According to MLIR's developers guide > TLDR: only verify local aspects of an operation, Fixes #159295 Full diff: https://github.com/llvm/llvm-project/pull/159399.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 890406df74e72..95d63ddee6824 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -160,42 +160,12 @@ LogicalResult FunctionCallOp::verify() {
<< fnName.getValue() << "' not found in nearest symbol table";
}
- auto functionType = funcOp.getFunctionType();
-
if (getNumResults() > 1) {
return emitOpError(
"expected callee function to have 0 or 1 result, but provided ")
<< getNumResults();
}
- if (functionType.getNumInputs() != getNumOperands()) {
- return emitOpError("has incorrect number of operands for callee: expected ")
- << functionType.getNumInputs() << ", but provided "
- << getNumOperands();
- }
-
- for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
- if (getOperand(i).getType() != functionType.getInput(i)) {
- return emitOpError("operand type mismatch: expected operand type ")
- << functionType.getInput(i) << ", but provided "
- << getOperand(i).getType() << " for operand number " << i;
- }
- }
-
- if (functionType.getNumResults() != getNumResults()) {
- return emitOpError(
- "has incorrect number of results has for callee: expected ")
- << functionType.getNumResults() << ", but provided "
- << getNumResults();
- }
-
- if (getNumResults() &&
- (getResult(0).getType() != functionType.getResult(0))) {
- return emitOpError("result type mismatch: expected ")
- << functionType.getResult(0) << ", but provided "
- << getResult(0).getType();
- }
-
return success();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..bfa33559fff61 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -210,48 +210,6 @@ spirv.module Logical GLSL450 {
// -----
-spirv.module Logical GLSL450 {
- spirv.func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
- // expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}}
- %1 = spirv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32)
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
- spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
- // expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}}
- spirv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> ()
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
- spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
- %0 = spirv.Constant 2.0 : f32
- // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}}
- spirv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> ()
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
- spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 "None" {
- %cst = spirv.Constant 0: i32
- // expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}}
- %0 = spirv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32
- spirv.ReturnValue %cst: i32
- }
-}
-
-// -----
-
spirv.module Logical GLSL450 {
spirv.func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 "None" {
// expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}}
@@ -262,6 +220,35 @@ spirv.module Logical GLSL450 {
// -----
+"builtin.module"() ({
+ "spirv.module"() <{
+ addressing_model = #spirv.addressing_model<Logical>,
+ memory_model = #spirv.memory_model<GLSL450>
+ }> ({
+ "spirv.func"() <{
+ function_control = #spirv.function_control<None>,
+ function_type = (f32) -> f32,
+ sym_name = "bar"
+ }> ({
+ ^bb0(%arg0: f32):
+ %0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
+ "spirv.ReturnValue"(%0) : (f32) -> ()
+ }) : () -> ()
+ // expected-error @+1 {{requires attribute 'function_type'}}
+ "spirv.func"() <{
+ function_control = #spirv.function_control<None>,
+ message = "2nd parent",
+ sym_name = "foo"
+ // This is invalid MLIR because function_type is missing from spirv.func
+ }> ({
+ ^bb0(%arg0: f32):
+ "spirv.ReturnValue"(%arg0) : (f32) -> ()
+ }) : () -> ()
+ }) : () -> ()
+}) : () -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
|
"builtin.module"() ({ | ||
"spirv.module"() <{ | ||
addressing_model = #spirv.addressing_model<Logical>, | ||
memory_model = #spirv.memory_model<GLSL450> | ||
}> ({ | ||
"spirv.func"() <{ | ||
function_control = #spirv.function_control<None>, | ||
function_type = (f32) -> f32, | ||
sym_name = "bar" | ||
}> ({ | ||
^bb0(%arg0: f32): | ||
%0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32 | ||
"spirv.ReturnValue"(%0) : (f32) -> () | ||
}) : () -> () | ||
// expected-error @+1 {{requires attribute 'function_type'}} | ||
"spirv.func"() <{ | ||
function_control = #spirv.function_control<None>, | ||
message = "2nd parent", | ||
sym_name = "foo" | ||
// This is invalid MLIR because function_type is missing from spirv.func | ||
}> ({ | ||
^bb0(%arg0: f32): | ||
"spirv.ReturnValue"(%arg0) : (f32) -> () | ||
}) : () -> () | ||
}) : () -> () | ||
}) : () -> () | ||
|
||
// ----- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"builtin.module"() ({ | |
"spirv.module"() <{ | |
addressing_model = #spirv.addressing_model<Logical>, | |
memory_model = #spirv.memory_model<GLSL450> | |
}> ({ | |
"spirv.func"() <{ | |
function_control = #spirv.function_control<None>, | |
function_type = (f32) -> f32, | |
sym_name = "bar" | |
}> ({ | |
^bb0(%arg0: f32): | |
%0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32 | |
"spirv.ReturnValue"(%0) : (f32) -> () | |
}) : () -> () | |
// expected-error @+1 {{requires attribute 'function_type'}} | |
"spirv.func"() <{ | |
function_control = #spirv.function_control<None>, | |
message = "2nd parent", | |
sym_name = "foo" | |
// This is invalid MLIR because function_type is missing from spirv.func | |
}> ({ | |
^bb0(%arg0: f32): | |
"spirv.ReturnValue"(%arg0) : (f32) -> () | |
}) : () -> () | |
}) : () -> () | |
}) : () -> () | |
// ----- |
I think removing this test is also reasonable. Happy to make this change if reviewers agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kuhar, can I interpret your thumbs up as "remove this test"? (Got confused by your correction to the comment inside this test.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will merge then but happy to delete this test in another PR.
verifySymbolUses
for spirv.FunctionCall
.
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
spirv.FunctionCall
's verifier was being too aggressive. It included verification of non-local properties by looking at the callee's definition.This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing.
According to MLIR's developers guide
The fix includes adding the
SymbolUserOpInterface
toFunctionCall
and moving most of the verification logic toverifySymbolUses
.Fixes #159295