From ec7d8ac7f385ec380f6db44a3f7b99db0d0a533d Mon Sep 17 00:00:00 2001 From: AdityaK Date: Mon, 31 Mar 2025 18:37:06 +0000 Subject: [PATCH] Verify entry block in SPIRV dialect --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 18 ++++++++++++++++++ .../Dialect/SPIRV/IR/function-decorations.mlir | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index da9855b02860d..16e91b0cb2cfc 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1021,6 +1021,24 @@ LogicalResult spirv::FuncOp::verifyType() { LogicalResult spirv::FuncOp::verifyBody() { FunctionType fnType = getFunctionType(); + if (!isExternal()) { + Block &entryBlock = front(); + + unsigned numArguments = this->getNumArguments(); + if (entryBlock.getNumArguments() != numArguments) + return emitOpError("entry block must have ") + << numArguments << " arguments to match function signature"; + + for (auto [index, fnArgType, blockArgType] : + llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { + if (blockArgType != fnArgType) { + return emitOpError("type of entry block argument #") + << index << '(' << blockArgType + << ") must match the type of the corresponding argument in " + << "function signature(" << fnArgType << ')'; + } + } + } auto walkResult = walk([fnType](Operation *op) -> WalkResult { if (auto retOp = dyn_cast(op)) { diff --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir index 07e187e6a7d68..f09767a416f6b 100644 --- a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir +++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir @@ -73,3 +73,20 @@ spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration, random_attr = #spirv.decoration }) "None" { spirv.Return } + +// ----- + +// expected-error @+1 {{'spirv.func' op entry block must have 1 arguments to match function signature}} +spirv.func @f(f32) "None" { + %c0 = arith.constant 0 : index + spirv.Return +} + +// ----- + +// expected-error @+1 {{'spirv.func' op type of entry block argument #0('f64') must match the type of the corresponding argument in function signature('f32')}} +spirv.func @f(f32) "None" { + ^bb0(%arg0: f64): + %c0 = arith.constant 0 : index + spirv.Return +}