New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Support alias/restrict function argument decorations #76353
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
fe1f249
to
50200c4
Compare
50200c4
to
2b64b3e
Compare
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Kohei Yamaguchi (sott0n) ChangesCloses #76106 Patch is 31.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76353.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or one values.
+ From `SPV_KHR_physical_storage_buffer`:
+ If a parameter of function is
+ - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+ - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
<!-- End of AutoGen section -->
```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
```mlir
spirv.func @foo() -> () "None" { ... }
spirv.func @bar() -> () "Inline|Pure" { ... }
+
+ spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+ spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) "None)
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..66ec520cfeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,19 +992,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.getValue();
- if (symbol != spirv::getInterfaceVarABIAttrName())
+ if (symbol == spirv::getInterfaceVarABIAttrName()) {
+ auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+ if (!varABIAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+ if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+ return emitError(loc, "'") << symbol
+ << "' attribute cannot specify storage class "
+ "when attaching to a non-scalar value";
+ } else if (symbol == spirv::DecorationAttr::name) {
+ auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+ if (!decAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::DecorationAttr";
+ } else {
return emitError(loc, "found unsupported '")
<< symbol << "' attribute on region argument";
-
- auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
- if (!varABIAttr)
- return emitError(loc, "'")
- << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
- if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
- return emitError(loc, "'") << symbol
- << "' attribute cannot specify storage class "
- "when attaching to a non-scalar value";
+ }
return success();
}
@@ -1013,9 +1019,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) {
- return verifyRegionAttribute(
- op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
- attribute);
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+ Type argType = funcOp.getArgumentTypes()[argIndex];
+
+ return verifyRegionAttribute(op->getLoc(), argType, attribute);
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..d6064f446b4454 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,75 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::FuncOp::verifyType() {
- if (getFunctionType().getNumResults() > 1)
+ FunctionType fnType = getFunctionType();
+ if (fnType.getNumResults() > 1)
return emitOpError("cannot have more than one result");
+
+ auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+ unsigned argIndex) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ for (auto argAttr : funcOp.getArgAttrs(argIndex))
+ if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+ return decAttr.getValue() == decoration;
+ }
+ return false;
+ };
+
+ auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+ unsigned numArgs = funcOp.getNumArguments();
+ if (numArgs < 1)
+ return success();
+
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto param = fnType.getInputs()[i];
+ auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+ if (!inputPtrType)
+ continue;
+
+ auto pointeePtrType =
+ dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+ if (pointeePtrType) {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer)
+ // > and the type it points to is a pointer in the PhysicalStorageBuffer
+ // > storage class, the function parameter must be decorated with exactly
+ // > one of AliasedPointer or RestrictPointer.
+ if (pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliasedPtr =
+ hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+ bool hasRestrictPtr =
+ hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+ if (!hasAliasedPtr && !hasRestrictPtr)
+ return emitOpError()
+ << "with a pointer points to a physical buffer pointer must "
+ "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+ }
+ } else {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+ // > the PhysicalStorageBuffer storage class, the function parameter must
+ // > be decorated with exactly one of Aliased or Restrict.
+ if (auto pointeeArrayType =
+ dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+ pointeePtrType =
+ dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+ } else {
+ pointeePtrType = inputPtrType;
+ }
+ if (pointeePtrType && pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+ bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+ if (!hasAliased && !hasRestrict)
+ return emitOpError()
+ << "with physical buffer pointer must be decorated "
+ "either 'Aliased' or 'Restrict'";
+ }
+ }
+ }
+
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..24748007bbb175 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
if (decorationName.empty()) {
return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
}
- auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
- auto symbol = opBuilder.getStringAttr(attrName);
+ auto symbol = getSymbolDecoration(decorationName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::FPFastMathMode:
if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
break;
}
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Block:
case spirv::Decoration::BufferBlock:
case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::RestrictPointer:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
return success();
}
+void spirv::Deserializer::setArgAttrs(uint32_t argID) {
+ if (!decorations.count(argID)) {
+ argAttrs.push_back(DictionaryAttr::get(context, {}));
+ return;
+ }
+
+ // Replace a decoration as UnitAttr with DecorationAttr for the physical
+ // buffer pointer in the function parameter.
+ // e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>").
+ for (auto decAttr : decorations[argID]) {
+ if (decAttr.getName() ==
+ getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::Restrict))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::AliasedPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::AliasedPointer));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::RestrictPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::RestrictPointer));
+ }
+ }
+
+ argAttrs.push_back(decorations[argID].getDictionary(context));
+}
+
LogicalResult
spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
@@ -463,11 +504,18 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "duplicate definition of result <id> ")
<< operands[1];
}
+ setArgAttrs(operands[1]);
auto argValue = funcOp.getArgument(i);
valueMap[operands[1]] = argValue;
}
}
+ if (llvm::any_of(argAttrs, [](Attribute attr) {
+ auto argAttr = cast<DictionaryAttr>(attr);
+ return !argAttr.empty();
+ }))
+ funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
// entryBlock is needed to access the arguments, Once that is done, we can
// erase the block for functions with 'Import' LinkageAttributes, since these
// are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..115addd49f949a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,15 @@ class Deserializer {
return globalVariableMap.lookup(id);
}
+ /// Sets the argument's attributes with the given argument <id>.
+ void setArgAttrs(uint32_t argID);
+
+ /// Gets the symbol name from the name of decoration.
+ StringAttr getSymbolDecoration(StringRef decorationName) {
+ auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+ return opBuilder.getStringAttr(attrName);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -605,6 +614,9 @@ class Deserializer {
/// A list of all structs which have unresolved member types.
SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
+ /// A list of argument attributes of function.
+ SmallVector<Attribute, 0> argAttrs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..2efb0ee64c9253 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,35 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
return success();
}
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+ unsigned numArgs = op.getNumArguments();
+ if (numArgs != 0) {
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto arg = op.getArgument(i);
+ uint32_t argTypeID = 0;
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+ return failure();
+ }
+ auto argValueID = getNextID();
+
+ // Process decoration attributes of arguments.
+ auto funcOp = cast<FunctionOpInterface>(*op);
+ for (auto argAttr : funcOp.getArgAttrs(i)) {
+ if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+ if (failed(processDecorationAttr(op->getLoc(), argValueID,
+ decAttr.getValue(), decAttr)))
+ return failure();
+ }
+ }
+
+ valueIDMap[arg] = argValueID;
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+ {argTypeID, argValueID});
+ }
+ }
+ return success();
+}
+
LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +258,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
// is going to return false for this function from now on)
// Hence, we'll remove the body once we are done with the serialization.
op.addEntryBlock();
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Don't need to process the added block, there is nothing to process,
// the fake body was added just to get the arguments, remove the body,
// since it's use is done.
op.eraseBody();
} else {
- // Declare the parameters.
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Some instructions (e.g., OpVariable) in a function must be in the first
// block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..6c3aaa35457470 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
- NamedAttribute attr) {
- auto attrName = attr.getName().strref();
- auto decorationName = getDecorationName(attrName);
- auto decoration = spirv::symbolizeDecoration(decorationName);
- if (!decoration) {
- return emitError(
- loc, "non-argument attributes expected to have snake-case-ified "
- "decoration name, unhandled attribute with name : ")
- << attrName;
- }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration,
+ Attribute attr) {
SmallVector<uint32_t, 1> args;
- switch (*decoration) {
+ switch (decoration) {
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
break;
}
case spirv::Decoration::FPFastMathMode:
- if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
break;
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
- << attrName;
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
- if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
- return emitError(loc, "expected integer attribute for ") << attrName;
+ return emitError(loc, "expected integer attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::BuiltIn:
- if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
break;
}
return emitError(loc, "invalid ")
- << attrName << " attribute " << strAttr.getValue();
+ << stringifyDecoration(decoration) << " decoration attribute "
+ << strAttr.getValue();
}
- return emitError(loc, "expected string attribute for ") << attrName;
+ return emitError(loc, "expected string attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Flat:
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
@@ -275,14 +271,33 @@ LogicalResult Serializer::processDecoration(Location loc, uin...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks for working on this @sott0n!
I left some suggestions, mostly on the coding style. I'm not super familiar with the deserializer, so it would be best if @antiagainst could take a look to.
@kuhar Thanks for your review! I addressed your comments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but please wait for a second approval from @antiagainst before merging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks for the contribution!
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], [] | |||
// ----- | |||
|
|||
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> { | |||
spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" { | |||
spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is okay for now; but later we might want to have custom parser/printer to simpify it into something like %arg0 : !spirv.ptr<...> Aliased
or something (cannot recall exactly the restrictions on attribute parsing/printing). Also given we are here for spirv.func
, it would be nice to omit the function control when it's "None"
. Not for this patch though.
@kuhar as FYI
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
later we might want to have custom parser/printer to simpify it into something like %arg0 : !spirv.ptr<...> Aliased or something (cannot recall exactly the restrictions on attribute parsing/printing).
It looks good! After applying this patch, I would like to think about the custom parse/printer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG!
spirv::Decoration::AliasedPointer, | ||
spirv::Decoration::RestrictPointer}) { | ||
if (decAttr.getName() == | ||
getSymbolDecoration(stringifyDecoration(decoration))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd think it's better to just break once found one of the above four decorations. We can create the dictionary attribute out of the for loop. This way we don't mutate decorations
and it's clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean that during deserialization, only these four decorations should be added to argAttrs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually only one of these four for now. That reminds, we should error out if there are more we didn't recognize to avoid silently ignore decorations. So, make sure the decoration is only one of the four.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I understand and updated this function.
we should error out if there are more we didn't recognize to avoid silently ignore decorations.
Should we add this invalid test case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine to do without a test for now, given that the possible error cases are unbound. Especially for the deserialization, right now it's meant to only cover the cases where we can serialize.
BTW we do have some deserialization test here: https://github.com/llvm/llvm-project/blob/main/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp but it's a bit painful to write those.
@sott0n I didn't see your fixes. Did you forget to push out or still working on it? |
@antiagainst Sorry.. I pushed it. And Some are still in the works. |
Thanks a lot for the contribution and bearing with me for the nitpicking. :) To avoid burden you further, I rebased and revised some comment/error message slightly. :) I'll land once bots are happy. |
@kuhar @antiagainst Thanks for your comment and support to land this PR! If you have any issues or TODO tasks in SPIR-V dialect, I would be happy if you could share it with me. :) |
Thanks for your further interest! Please feel free to take a look at issues with the |
…lvm#76353) Closes llvm#76106 --------- Co-authored-by: Lei Zhang <antiagainst@gmail.com>
Closes #76106