diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td index 5fd25e3b576f2..fbf750d643031 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -267,6 +267,15 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [ This op itself takes no operands and generates no results. Its region can take zero or more arguments and return zero or one values. + From `SPV_KHR_physical_storage_buffer`: + If a parameter of function is + - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage + class, the function parameter must be decorated with exactly one of + `Aliased` or `Restrict`. + - a pointer (or contains a pointer) and the type it points to is a pointer + in the PhysicalStorageBuffer storage class, the function parameter must + be decorated with exactly one of `AliasedPointer` or `RestrictPointer`. + ``` @@ -280,6 +289,20 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [ ```mlir spirv.func @foo() -> () "None" { ... } spirv.func @bar() -> () "Inline|Pure" { ... } + + spirv.func @aliased_pointer(%arg0: !spirv.ptr, + { spirv.decoration = #spirv.decoration }) -> () "None" { ... } + + spirv.func @restrict_pointer(%arg0: !spirv.ptr, + { spirv.decoration = #spirv.decoration }) -> () "None" { ... } + + spirv.func @aliased_pointee(%arg0: !spirv.ptr, Generic> { spirv.decoration = + #spirv.decoration }) -> () "None" { ... } + + spirv.func @restrict_pointee(%arg0: !spirv.ptr, Generic> { spirv.decoration = + #spirv.decoration }) -> () "None" { ... } ``` }]; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 8a68decc5878c..d7944d600b0a2 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -992,30 +992,39 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType, StringRef symbol = attribute.getName().strref(); Attribute attr = attribute.getValue(); - if (symbol != spirv::getInterfaceVarABIAttrName()) - return emitError(loc, "found unsupported '") - << symbol << "' attribute on region argument"; - - auto varABIAttr = llvm::dyn_cast(attr); - if (!varABIAttr) - return emitError(loc, "'") - << symbol << "' must be a spirv::InterfaceVarABIAttr"; - - if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) - return emitError(loc, "'") << symbol - << "' attribute cannot specify storage class " - "when attaching to a non-scalar value"; + if (symbol == spirv::getInterfaceVarABIAttrName()) { + auto varABIAttr = llvm::dyn_cast(attr); + if (!varABIAttr) + return emitError(loc, "'") + << symbol << "' must be a spirv::InterfaceVarABIAttr"; + + if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) + return emitError(loc, "'") << symbol + << "' attribute cannot specify storage class " + "when attaching to a non-scalar value"; + return success(); + } + if (symbol == spirv::DecorationAttr::name) { + if (!isa(attr)) + return emitError(loc, "'") + << symbol << "' must be a spirv::DecorationAttr"; + return success(); + } - return success(); + return emitError(loc, "found unsupported '") + << symbol << "' attribute on region argument"; } LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIndex, unsigned argIndex, NamedAttribute attribute) { - return verifyRegionAttribute( - op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(), - attribute); + auto funcOp = dyn_cast(op); + if (!funcOp) + return success(); + Type argType = funcOp.getArgumentTypes()[argIndex]; + + return verifyRegionAttribute(op->getLoc(), argType, attribute); } LogicalResult SPIRVDialect::verifyRegionResultAttribute( diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 5343a12132a91..3b159030cab75 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -972,8 +972,73 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::FuncOp::verifyType() { - if (getFunctionType().getNumResults() > 1) + FunctionType fnType = getFunctionType(); + if (fnType.getNumResults() > 1) return emitOpError("cannot have more than one result"); + + auto hasDecorationAttr = [&](spirv::Decoration decoration, + unsigned argIndex) { + auto func = llvm::cast(getOperation()); + for (auto argAttr : cast(func).getArgAttrs(argIndex)) { + if (argAttr.getName() != spirv::DecorationAttr::name) + continue; + if (auto decAttr = dyn_cast(argAttr.getValue())) + return decAttr.getValue() == decoration; + } + return false; + }; + + for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) { + Type param = fnType.getInputs()[i]; + auto inputPtrType = dyn_cast(param); + if (!inputPtrType) + continue; + + auto pointeePtrType = + dyn_cast(inputPtrType.getPointeeType()); + if (pointeePtrType) { + // SPIR-V spec, from SPV_KHR_physical_storage_buffer: + // > If an OpFunctionParameter is a pointer (or contains a pointer) + // > and the type it points to is a pointer in the PhysicalStorageBuffer + // > storage class, the function parameter must be decorated with exactly + // > one of AliasedPointer or RestrictPointer. + if (pointeePtrType.getStorageClass() != + spirv::StorageClass::PhysicalStorageBuffer) + continue; + + bool hasAliasedPtr = + hasDecorationAttr(spirv::Decoration::AliasedPointer, i); + bool hasRestrictPtr = + hasDecorationAttr(spirv::Decoration::RestrictPointer, i); + if (!hasAliasedPtr && !hasRestrictPtr) + return emitOpError() + << "with a pointer points to a physical buffer pointer must " + "be decorated either 'AliasedPointer' or 'RestrictPointer'"; + continue; + } + // SPIR-V spec, from SPV_KHR_physical_storage_buffer: + // > If an OpFunctionParameter is a pointer (or contains a pointer) in + // > the PhysicalStorageBuffer storage class, the function parameter must + // > be decorated with exactly one of Aliased or Restrict. + if (auto pointeeArrayType = + dyn_cast(inputPtrType.getPointeeType())) { + pointeePtrType = + dyn_cast(pointeeArrayType.getElementType()); + } else { + pointeePtrType = inputPtrType; + } + + if (!pointeePtrType || pointeePtrType.getStorageClass() != + spirv::StorageClass::PhysicalStorageBuffer) + continue; + + bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i); + bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i); + if (!hasAliased && !hasRestrict) + return emitOpError() << "with physical buffer pointer must be decorated " + "either 'Aliased' or 'Restrict'"; + } + return success(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 00645d2c45519..0c521adb11332 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { if (decorationName.empty()) { return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; } - auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); - auto symbol = opBuilder.getStringAttr(attrName); + auto symbol = getSymbolDecoration(decorationName); switch (static_cast(words[1])) { case spirv::Decoration::FPFastMathMode: if (words.size() != 3) { @@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { break; } case spirv::Decoration::Aliased: + case spirv::Decoration::AliasedPointer: case spirv::Decoration::Block: case spirv::Decoration::BufferBlock: case spirv::Decoration::Flat: @@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { case spirv::Decoration::NoUnsignedWrap: case spirv::Decoration::RelaxedPrecision: case spirv::Decoration::Restrict: + case spirv::Decoration::RestrictPointer: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target "; @@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef words) { return success(); } +LogicalResult spirv::Deserializer::setFunctionArgAttrs( + uint32_t argID, SmallVectorImpl &argAttrs, size_t argIndex) { + if (!decorations.contains(argID)) { + argAttrs[argIndex] = DictionaryAttr::get(context, {}); + return success(); + } + + spirv::DecorationAttr foundDecorationAttr; + for (NamedAttribute decAttr : decorations[argID]) { + for (auto decoration : + {spirv::Decoration::Aliased, spirv::Decoration::Restrict, + spirv::Decoration::AliasedPointer, + spirv::Decoration::RestrictPointer}) { + + if (decAttr.getName() != + getSymbolDecoration(stringifyDecoration(decoration))) + continue; + + if (foundDecorationAttr) + return emitError(unknownLoc, + "more than one Aliased/Restrict decorations for " + "function argument with result ") + << argID; + + foundDecorationAttr = spirv::DecorationAttr::get(context, decoration); + break; + } + } + + if (!foundDecorationAttr) + return emitError(unknownLoc, "unimplemented decoration support for " + "function argument with result ") + << argID; + + NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name), + foundDecorationAttr); + argAttrs[argIndex] = DictionaryAttr::get(context, attr); + return success(); +} + LogicalResult spirv::Deserializer::processFunction(ArrayRef operands) { if (curFunction) { @@ -430,6 +471,9 @@ spirv::Deserializer::processFunction(ArrayRef operands) { logger.indent(); }); + SmallVector argAttrs; + argAttrs.resize(functionType.getNumInputs()); + // Parse the op argument instructions if (functionType.getNumInputs()) { for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { @@ -463,11 +507,21 @@ spirv::Deserializer::processFunction(ArrayRef operands) { return emitError(unknownLoc, "duplicate definition of result ") << operands[1]; } + if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) { + return failure(); + } + auto argValue = funcOp.getArgument(i); valueMap[operands[1]] = argValue; } } + if (llvm::any_of(argAttrs, [](Attribute attr) { + auto argAttr = cast(attr); + return !argAttr.empty(); + })) + funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs)); + // entryBlock is needed to access the arguments, Once that is done, we can // erase the block for functions with 'Import' LinkageAttributes, since these // are essentially function declarations, so they have no body. diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 69be47851ef3c..fc9a8f5f9364b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -233,6 +233,19 @@ class Deserializer { return globalVariableMap.lookup(id); } + /// Sets the function argument's attributes. |argID| is the function + /// argument's result , and |argIndex| is its index in the function's + /// argument list. + LogicalResult setFunctionArgAttrs(uint32_t argID, + SmallVectorImpl &argAttrs, + size_t argIndex); + + /// Gets the symbol name from the name of decoration. + StringAttr getSymbolDecoration(StringRef decorationName) { + auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); + return opBuilder.getStringAttr(attrName); + } + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 7bfcca5b4dcdc..41d2c0310d000 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -177,6 +177,34 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { return success(); } +LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + + // Process decoration attributes of arguments. + auto funcOp = cast(*op); + for (auto argAttr : funcOp.getArgAttrs(idx)) { + if (argAttr.getName() != DecorationAttr::name) + continue; + + if (auto decAttr = dyn_cast(argAttr.getValue())) { + if (failed(processDecorationAttr(op->getLoc(), argValueID, + decAttr.getValue(), decAttr))) + return failure(); + } + } + + valueIDMap[arg] = argValueID; + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}); + } + return success(); +} + LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); assert(functionHeader.empty() && functionBody.empty()); @@ -229,32 +257,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { // is going to return false for this function from now on) // Hence, we'll remove the body once we are done with the serialization. op.addEntryBlock(); - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { - return failure(); - } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}); - } + if (failed(processFuncParameter(op))) + return failure(); // Don't need to process the added block, there is nothing to process, // the fake body was added just to get the arguments, remove the body, // since it's use is done. op.eraseBody(); } else { - // Declare the parameters. - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { - return failure(); - } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}); - } + if (failed(processFuncParameter(op))) + return failure(); // Some instructions (e.g., OpVariable) in a function must be in the first // block in the function. These instructions will be put in diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 9e9a16456cc10..1029fb933175f 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) { return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); } -LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, - NamedAttribute attr) { - auto attrName = attr.getName().strref(); - auto decorationName = getDecorationName(attrName); - auto decoration = spirv::symbolizeDecoration(decorationName); - if (!decoration) { - return emitError( - loc, "non-argument attributes expected to have snake-case-ified " - "decoration name, unhandled attribute with name : ") - << attrName; - } +LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, + Decoration decoration, + Attribute attr) { SmallVector args; - switch (*decoration) { + switch (decoration) { case spirv::Decoration::LinkageAttributes: { // Get the value of the Linkage Attributes // e.g., LinkageAttributes=["linkageName", linkageType]. - auto linkageAttr = llvm::dyn_cast(attr.getValue()); + auto linkageAttr = llvm::dyn_cast(attr); auto linkageName = linkageAttr.getLinkageName(); auto linkageType = linkageAttr.getLinkageType().getValue(); // Encode the Linkage Name (string literal to uint32_t). @@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, break; } case spirv::Decoration::FPFastMathMode: - if (auto intAttr = dyn_cast(attr.getValue())) { + if (auto intAttr = dyn_cast(attr)) { args.push_back(static_cast(intAttr.getValue())); break; } return emitError(loc, "expected FPFastMathModeAttr attribute for ") - << attrName; + << stringifyDecoration(decoration); case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: - if (auto intAttr = dyn_cast(attr.getValue())) { + if (auto intAttr = dyn_cast(attr)) { args.push_back(intAttr.getValue().getZExtValue()); break; } - return emitError(loc, "expected integer attribute for ") << attrName; + return emitError(loc, "expected integer attribute for ") + << stringifyDecoration(decoration); case spirv::Decoration::BuiltIn: - if (auto strAttr = dyn_cast(attr.getValue())) { + if (auto strAttr = dyn_cast(attr)) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast(*enumVal)); break; } return emitError(loc, "invalid ") - << attrName << " attribute " << strAttr.getValue(); + << stringifyDecoration(decoration) << " decoration attribute " + << strAttr.getValue(); } - return emitError(loc, "expected string attribute for ") << attrName; + return emitError(loc, "expected string attribute for ") + << stringifyDecoration(decoration); case spirv::Decoration::Aliased: + case spirv::Decoration::AliasedPointer: case spirv::Decoration::Flat: case spirv::Decoration::NonReadable: case spirv::Decoration::NonWritable: @@ -275,14 +271,34 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, case spirv::Decoration::NoUnsignedWrap: case spirv::Decoration::RelaxedPrecision: case spirv::Decoration::Restrict: - // For unit attributes, the args list has no values so we do nothing - if (auto unitAttr = dyn_cast(attr.getValue())) + case spirv::Decoration::RestrictPointer: + // For unit attributes and decoration attributes, the args list + // has no values so we do nothing. + if (isa(attr)) break; - return emitError(loc, "expected unit attribute for ") << attrName; + return emitError(loc, + "expected unit attribute or decoration attribute for ") + << stringifyDecoration(decoration); default: - return emitError(loc, "unhandled decoration ") << decorationName; + return emitError(loc, "unhandled decoration ") + << stringifyDecoration(decoration); + } + return emitDecoration(resultID, decoration, args); +} + +LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr) { + StringRef attrName = attr.getName().strref(); + std::string decorationName = getDecorationName(attrName); + std::optional decoration = + spirv::symbolizeDecoration(decorationName); + if (!decoration) { + return emitError( + loc, "non-argument attributes expected to have snake-case-ified " + "decoration name, unhandled attribute with name : ") + << attrName; } - return emitDecoration(resultID, *decoration, args); + return processDecorationAttr(loc, resultID, *decoration, attr.getValue()); } LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index 4b2ebf610bd72..9edb0f4af008d 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -127,6 +127,7 @@ class Serializer { /// Processes a SPIR-V function op. LogicalResult processFuncOp(spirv::FuncOp op); + LogicalResult processFuncParameter(spirv::FuncOp op); LogicalResult processVariableOp(spirv::VariableOp op); @@ -134,6 +135,8 @@ class Serializer { LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); /// Process attributes that translate to decorations on the result + LogicalResult processDecorationAttr(Location loc, uint32_t resultID, + Decoration decoration, Attribute attr); LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir index b3991cbdbe8af..b9c56a3fcffd0 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir @@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr) spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr) "None" // CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr) -spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr) "None" +spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr { spirv.decoration = #spirv.decoration }) "None" // CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr) spirv.func @pointerCodeSectionINTEL(!spirv.ptr) "None" diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir index 4f4a72da7c050..e289dbf28ad28 100644 --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce { - spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr) "None" { + spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr to i32 %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr to i32 spirv.Return diff --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir index 2e39421df13cc..d915f8820c4f4 100644 --- a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir +++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir @@ -17,3 +17,59 @@ spirv.module Logical GLSL450 requires #spirv.vce { } spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} } + +// ----- + +// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) "None" +spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { + spirv.Return +} + +// ----- + +// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) "None" +spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { + spirv.Return +} + +// ----- + +// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr, Generic> {spirv.decoration = #spirv.decoration}) "None" +spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr, Generic> { spirv.decoration = #spirv.decoration }) "None" { + spirv.Return +} + +// ----- + +// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr, Generic> {spirv.decoration = #spirv.decoration}) "None" +spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr, Generic> { spirv.decoration = #spirv.decoration }) "None" { + spirv.Return +} + +// ----- + +// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}} +spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr) "None" { + spirv.Return +} + +// ----- + +// expected-error @+1 {{'spirv.func' op with a pointer points to a physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}} +spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr, Function>) "None" { + spirv.Return +} + +// ----- + +// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}} +spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr { random_attr = #spirv.decoration }) "None" { + spirv.Return +} + +// ----- + +// expected-error @+1 {{'spirv.func' op arguments may only have dialect attributes}} +spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration, random_attr = #spirv.decoration }) "None" { + spirv.Return +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 4eaa21d2f94ef..931034f3d5f6e 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> } { - spirv.func @physical_ptr(%val : !spirv.ptr) "None" { + spirv.func @physical_ptr(%val : !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { spirv.Return } } diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir index 7fe0969497c3e..ede0bf30511ef 100644 --- a/mlir/test/Target/SPIRV/cast-ops.mlir +++ b/mlir/test/Target/SPIRV/cast-ops.mlir @@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce { - spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr) "None" { + spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration} ) "None" { // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr to i32 %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr to i32 spirv.Return diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir index b0f6705df9ca4..117d4ca628f76 100644 --- a/mlir/test/Target/SPIRV/function-decorations.mlir +++ b/mlir/test/Target/SPIRV/function-decorations.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file -verify-diagnostics %s | FileCheck %s spirv.module Logical GLSL450 requires #spirv.vce { spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} { @@ -17,3 +17,72 @@ spirv.module Logical GLSL450 requires #spirv.vce { } spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} } + +// ----- + +spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce { + // CHECK-LABEL: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) + spirv.func @func_arg_decoration_aliased( + %arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration } + ) "None" { + spirv.Return + } +} + +// ----- + +spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce { + // CHECK-LABEL: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) + spirv.func @func_arg_decoration_restrict( + %arg0 : !spirv.ptr { spirv.decoration = #spirv.decoration } + ) "None" { + spirv.Return + } +} + +// ----- + +spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce { + // CHECK-LABEL: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr, Generic> {spirv.decoration = #spirv.decoration}) + spirv.func @func_arg_decoration_aliased_pointer( + %arg0 : !spirv.ptr, Generic> { spirv.decoration = #spirv.decoration } + ) "None" { + spirv.Return + } +} + +// ----- + +spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce { + // CHECK-LABEL: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr, Generic> {spirv.decoration = #spirv.decoration}) + spirv.func @func_arg_decoration_restrict_pointer( + %arg0 : !spirv.ptr, Generic> { spirv.decoration = #spirv.decoration } + ) "None" { + spirv.Return + } +} + +// ----- + +spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce { + // CHECK-LABEL: spirv.func @fn1(%{{.*}}: i32, %{{.*}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) + spirv.func @fn1( + %arg0: i32, + %arg1: !spirv.ptr { spirv.decoration = #spirv.decoration } + ) "None" { + spirv.Return + } + + // CHECK-LABEL: spirv.func @fn2(%{{.*}}: !spirv.ptr {spirv.decoration = #spirv.decoration}, %{{.*}}: !spirv.ptr {spirv.decoration = #spirv.decoration}) + spirv.func @fn2( + %arg0: !spirv.ptr { spirv.decoration = #spirv.decoration }, + %arg1: !spirv.ptr { spirv.decoration = #spirv.decoration} + ) "None" { + spirv.Return + } +}