Skip to content

Conversation

amd-eochoalo
Copy link
Contributor

@amd-eochoalo amd-eochoalo commented Sep 17, 2025

spirv.FunctionCall's verifier was being too aggressive. It included verification of non-local properties by looking at the callee's definition.

This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing.

According to MLIR's developers guide

TLDR: only verify local aspects of an operation,
in particular don’t follow def-use chains
(don’t look at the producer of any operand or the user
of any results).

The fix includes adding the SymbolUserOpInterface to FunctionCall and moving most of the verification logic to verifySymbolUses.

Fixes #159295

`spirv.FunctionCall`'s verifier was being too aggressive.
It included verification of non-local properties by
looking at the callee's definition.

This caused problems in cases where callee had verification
errors and could lead to null pointer dereferencing.

According to MLIR's developers guide

  TLDR: only verify local aspects of an operation,
  in particular don’t follow def-use chains
  (don’t look at the producer of any operand or the user
  of any results).

https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier
@amd-eochoalo amd-eochoalo marked this pull request as ready for review September 17, 2025 17:02
@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

spirv.FunctionCall's verifier was being too aggressive. It included verification of non-local properties by looking at the callee's definition.

This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing.

According to MLIR's developers guide

> TLDR: only verify local aspects of an operation,
> in particular don’t follow def-use chains
> (don’t look at the producer of any operand or the user
> of any results).

Fixes #159295


Full diff: https://github.com/llvm/llvm-project/pull/159399.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (-30)
  • (modified) mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir (+29-42)
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 890406df74e72..95d63ddee6824 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -160,42 +160,12 @@ LogicalResult FunctionCallOp::verify() {
            << fnName.getValue() << "' not found in nearest symbol table";
   }
 
-  auto functionType = funcOp.getFunctionType();
-
   if (getNumResults() > 1) {
     return emitOpError(
                "expected callee function to have 0 or 1 result, but provided ")
            << getNumResults();
   }
 
-  if (functionType.getNumInputs() != getNumOperands()) {
-    return emitOpError("has incorrect number of operands for callee: expected ")
-           << functionType.getNumInputs() << ", but provided "
-           << getNumOperands();
-  }
-
-  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
-    if (getOperand(i).getType() != functionType.getInput(i)) {
-      return emitOpError("operand type mismatch: expected operand type ")
-             << functionType.getInput(i) << ", but provided "
-             << getOperand(i).getType() << " for operand number " << i;
-    }
-  }
-
-  if (functionType.getNumResults() != getNumResults()) {
-    return emitOpError(
-               "has incorrect number of results has for callee: expected ")
-           << functionType.getNumResults() << ", but provided "
-           << getNumResults();
-  }
-
-  if (getNumResults() &&
-      (getResult(0).getType() != functionType.getResult(0))) {
-    return emitOpError("result type mismatch: expected ")
-           << functionType.getResult(0) << ", but provided "
-           << getResult(0).getType();
-  }
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..bfa33559fff61 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -210,48 +210,6 @@ spirv.module Logical GLSL450 {
 
 // -----
 
-spirv.module Logical GLSL450 {
-  spirv.func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
-    // expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}}
-    %1 = spirv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32)
-    spirv.Return
-  }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
-  spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
-    // expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}}
-    spirv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> ()
-    spirv.Return
-  }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
-  spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" {
-    %0 = spirv.Constant 2.0 : f32
-    // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}}
-    spirv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> ()
-    spirv.Return
-  }
-}
-
-// -----
-
-spirv.module Logical GLSL450 {
-  spirv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 "None" {
-    %cst = spirv.Constant 0: i32
-    // expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}}
-    %0 = spirv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32
-    spirv.ReturnValue %cst: i32
-  }
-}
-
-// -----
-
 spirv.module Logical GLSL450 {
   spirv.func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 "None" {
     // expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}}
@@ -262,6 +220,35 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+"builtin.module"() ({
+  "spirv.module"() <{
+    addressing_model = #spirv.addressing_model<Logical>,
+    memory_model = #spirv.memory_model<GLSL450>
+  }> ({
+    "spirv.func"() <{
+      function_control = #spirv.function_control<None>,
+      function_type = (f32) -> f32,
+      sym_name = "bar"
+    }> ({
+    ^bb0(%arg0: f32):
+      %0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
+      "spirv.ReturnValue"(%0) : (f32) -> ()
+    }) : () -> ()
+    // expected-error @+1 {{requires attribute 'function_type'}}
+    "spirv.func"() <{
+      function_control = #spirv.function_control<None>,
+      message = "2nd parent",
+      sym_name = "foo"
+      // This is invalid MLIR because function_type is missing from spirv.func
+    }> ({
+    ^bb0(%arg0: f32):
+      "spirv.ReturnValue"(%arg0) : (f32) -> ()
+    }) : () -> ()
+  }) : () -> ()
+}) : () -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.loop
 //===----------------------------------------------------------------------===//

Comment on lines 223 to 251
"builtin.module"() ({
"spirv.module"() <{
addressing_model = #spirv.addressing_model<Logical>,
memory_model = #spirv.memory_model<GLSL450>
}> ({
"spirv.func"() <{
function_control = #spirv.function_control<None>,
function_type = (f32) -> f32,
sym_name = "bar"
}> ({
^bb0(%arg0: f32):
%0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
"spirv.ReturnValue"(%0) : (f32) -> ()
}) : () -> ()
// expected-error @+1 {{requires attribute 'function_type'}}
"spirv.func"() <{
function_control = #spirv.function_control<None>,
message = "2nd parent",
sym_name = "foo"
// This is invalid MLIR because function_type is missing from spirv.func
}> ({
^bb0(%arg0: f32):
"spirv.ReturnValue"(%arg0) : (f32) -> ()
}) : () -> ()
}) : () -> ()
}) : () -> ()

// -----

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"builtin.module"() ({
"spirv.module"() <{
addressing_model = #spirv.addressing_model<Logical>,
memory_model = #spirv.memory_model<GLSL450>
}> ({
"spirv.func"() <{
function_control = #spirv.function_control<None>,
function_type = (f32) -> f32,
sym_name = "bar"
}> ({
^bb0(%arg0: f32):
%0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
"spirv.ReturnValue"(%0) : (f32) -> ()
}) : () -> ()
// expected-error @+1 {{requires attribute 'function_type'}}
"spirv.func"() <{
function_control = #spirv.function_control<None>,
message = "2nd parent",
sym_name = "foo"
// This is invalid MLIR because function_type is missing from spirv.func
}> ({
^bb0(%arg0: f32):
"spirv.ReturnValue"(%arg0) : (f32) -> ()
}) : () -> ()
}) : () -> ()
}) : () -> ()
// -----

I think removing this test is also reasonable. Happy to make this change if reviewers agree.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kuhar, can I interpret your thumbs up as "remove this test"? (Got confused by your correction to the comment inside this test.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will merge then but happy to delete this test in another PR.

@amd-eochoalo amd-eochoalo changed the title [mlir][spirv] Remove nonlocal call verification. [mlir][spirv] Use verifySymbolUses for spirv.FunctionCall. Sep 18, 2025
@amd-eochoalo amd-eochoalo merged commit 54c5521 into llvm:main Sep 18, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR][SPIR-V] Null pointer dereference in type attribute handling while function call processing
3 participants