Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Support alias/restrict function argument decorations #76353

Merged
merged 11 commits into from Jan 6, 2024

Conversation

sott0n
Copy link
Contributor

@sott0n sott0n commented Dec 25, 2023

Closes #76106

Copy link

github-actions bot commented Dec 26, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@sott0n sott0n force-pushed the decoration-pointer-func-arg branch 2 times, most recently from fe1f249 to 50200c4 Compare December 28, 2023 05:27
@sott0n sott0n marked this pull request as ready for review December 28, 2023 09:24
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 28, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Kohei Yamaguchi (sott0n)

Changes

Closes #76106


Patch is 31.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76353.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td (+8)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+23-14)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+68-1)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+50-2)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+12)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+33-21)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+40-25)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+3)
  • (modified) mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/cast-ops.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+42)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir (+1-1)
  • (modified) mlir/test/Target/SPIRV/cast-ops.mlir (+1-1)
  • (modified) mlir/test/Target/SPIRV/function-decorations.mlir (+35-1)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     This op itself takes no operands and generates no results. Its region
     can take zero or more arguments and return zero or one values.
 
+    From `SPV_KHR_physical_storage_buffer`:
+    If a parameter of function is
+    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+    - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
     <!-- End of AutoGen section -->
 
     ```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     ```mlir
     spirv.func @foo() -> () "None" { ... }
     spirv.func @bar() -> () "Inline|Pure" { ... }
+
+    spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+    spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) "None)
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..66ec520cfeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,19 +992,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   StringRef symbol = attribute.getName().strref();
   Attribute attr = attribute.getValue();
 
-  if (symbol != spirv::getInterfaceVarABIAttrName())
+  if (symbol == spirv::getInterfaceVarABIAttrName()) {
+    auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+    if (!varABIAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+    if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+      return emitError(loc, "'") << symbol
+                                 << "' attribute cannot specify storage class "
+                                    "when attaching to a non-scalar value";
+  } else if (symbol == spirv::DecorationAttr::name) {
+    auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+    if (!decAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::DecorationAttr";
+  } else {
     return emitError(loc, "found unsupported '")
            << symbol << "' attribute on region argument";
-
-  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
-  if (!varABIAttr)
-    return emitError(loc, "'")
-           << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
-  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
-    return emitError(loc, "'") << symbol
-                               << "' attribute cannot specify storage class "
-                                  "when attaching to a non-scalar value";
+  }
 
   return success();
 }
@@ -1013,9 +1019,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned regionIndex,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(
-      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
-      attribute);
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type argType = funcOp.getArgumentTypes()[argIndex];
+
+  return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..d6064f446b4454 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,75 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  if (getFunctionType().getNumResults() > 1)
+  FunctionType fnType = getFunctionType();
+  if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
+
+  auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+                                                 unsigned argIndex) {
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      for (auto argAttr : funcOp.getArgAttrs(argIndex))
+        if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+          return decAttr.getValue() == decoration;
+    }
+    return false;
+  };
+
+  auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+  unsigned numArgs = funcOp.getNumArguments();
+  if (numArgs < 1)
+    return success();
+
+  for (unsigned i = 0; i < numArgs; ++i) {
+    auto param = fnType.getInputs()[i];
+    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+    if (!inputPtrType)
+      continue;
+
+    auto pointeePtrType =
+        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+    if (pointeePtrType) {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer)
+      // > and the type it points to is a pointer in the PhysicalStorageBuffer
+      // > storage class, the function parameter must be decorated with exactly
+      // > one of AliasedPointer or RestrictPointer.
+      if (pointeePtrType.getStorageClass() ==
+          spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliasedPtr =
+            hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+        bool hasRestrictPtr =
+            hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+        if (!hasAliasedPtr && !hasRestrictPtr)
+          return emitOpError()
+                 << "with a pointer points to a physical buffer pointer must "
+                    "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+      }
+    } else {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+      // > the PhysicalStorageBuffer storage class, the function parameter must
+      // > be decorated with exactly one of Aliased or Restrict.
+      if (auto pointeeArrayType =
+              dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+        pointeePtrType =
+            dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+      } else {
+        pointeePtrType = inputPtrType;
+      }
+      if (pointeePtrType && pointeePtrType.getStorageClass() ==
+                                spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+        bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+        if (!hasAliased && !hasRestrict)
+          return emitOpError()
+                 << "with physical buffer pointer must be decorated "
+                    "either 'Aliased' or 'Restrict'";
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..24748007bbb175 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   if (decorationName.empty()) {
     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
   }
-  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
-  auto symbol = opBuilder.getStringAttr(attrName);
+  auto symbol = getSymbolDecoration(decorationName);
   switch (static_cast<spirv::Decoration>(words[1])) {
   case spirv::Decoration::FPFastMathMode:
     if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
     break;
   }
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Block:
   case spirv::Decoration::BufferBlock:
   case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
+  case spirv::Decoration::RestrictPointer:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
   return success();
 }
 
+void spirv::Deserializer::setArgAttrs(uint32_t argID) {
+  if (!decorations.count(argID)) {
+    argAttrs.push_back(DictionaryAttr::get(context, {}));
+    return;
+  }
+
+  // Replace a decoration as UnitAttr with DecorationAttr for the physical
+  // buffer pointer in the function parameter.
+  // e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>").
+  for (auto decAttr : decorations[argID]) {
+    if (decAttr.getName() ==
+        getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(
+          spirv::DecorationAttr::name,
+          spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::Restrict))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(
+          spirv::DecorationAttr::name,
+          spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::AliasedPointer))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(spirv::DecorationAttr::name,
+                             spirv::DecorationAttr::get(
+                                 context, spirv::Decoration::AliasedPointer));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::RestrictPointer))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(spirv::DecorationAttr::name,
+                             spirv::DecorationAttr::get(
+                                 context, spirv::Decoration::RestrictPointer));
+    }
+  }
+
+  argAttrs.push_back(decorations[argID].getDictionary(context));
+}
+
 LogicalResult
 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   if (curFunction) {
@@ -463,11 +504,18 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
         return emitError(unknownLoc, "duplicate definition of result <id> ")
                << operands[1];
       }
+      setArgAttrs(operands[1]);
       auto argValue = funcOp.getArgument(i);
       valueMap[operands[1]] = argValue;
     }
   }
 
+  if (llvm::any_of(argAttrs, [](Attribute attr) {
+        auto argAttr = cast<DictionaryAttr>(attr);
+        return !argAttr.empty();
+      }))
+    funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
   // entryBlock is needed to access the arguments, Once that is done, we can
   // erase the block for functions with 'Import' LinkageAttributes, since these
   // are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..115addd49f949a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,15 @@ class Deserializer {
     return globalVariableMap.lookup(id);
   }
 
+  /// Sets the argument's attributes with the given argument <id>.
+  void setArgAttrs(uint32_t argID);
+
+  /// Gets the symbol name from the name of decoration.
+  StringAttr getSymbolDecoration(StringRef decorationName) {
+    auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+    return opBuilder.getStringAttr(attrName);
+  }
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//
@@ -605,6 +614,9 @@ class Deserializer {
   /// A list of all structs which have unresolved member types.
   SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
 
+  /// A list of argument attributes of function.
+  SmallVector<Attribute, 0> argAttrs;
+
 #ifndef NDEBUG
   /// A logger used to emit information during the deserialzation process.
   llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..2efb0ee64c9253 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,35 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   return success();
 }
 
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+  unsigned numArgs = op.getNumArguments();
+  if (numArgs != 0) {
+    for (unsigned i = 0; i < numArgs; ++i) {
+      auto arg = op.getArgument(i);
+      uint32_t argTypeID = 0;
+      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+        return failure();
+      }
+      auto argValueID = getNextID();
+
+      // Process decoration attributes of arguments.
+      auto funcOp = cast<FunctionOpInterface>(*op);
+      for (auto argAttr : funcOp.getArgAttrs(i)) {
+        if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+          if (failed(processDecorationAttr(op->getLoc(), argValueID,
+                                           decAttr.getValue(), decAttr)))
+            return failure();
+        }
+      }
+
+      valueIDMap[arg] = argValueID;
+      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+                            {argTypeID, argValueID});
+    }
+  }
+  return success();
+}
+
 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
   assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +258,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
     // is going to return false for this function from now on)
     // Hence, we'll remove the body once we are done with the serialization.
     op.addEntryBlock();
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
     // Don't need to process the added block, there is nothing to process,
     // the fake body was added just to get the arguments, remove the body,
     // since it's use is done.
     op.eraseBody();
   } else {
-    // Declare the parameters.
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
 
     // Some instructions (e.g., OpVariable) in a function must be in the first
     // block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..6c3aaa35457470 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
 }
 
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
-                                            NamedAttribute attr) {
-  auto attrName = attr.getName().strref();
-  auto decorationName = getDecorationName(attrName);
-  auto decoration = spirv::symbolizeDecoration(decorationName);
-  if (!decoration) {
-    return emitError(
-               loc, "non-argument attributes expected to have snake-case-ified "
-                    "decoration name, unhandled attribute with name : ")
-           << attrName;
-  }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+                                                Decoration decoration,
+                                                Attribute attr) {
   SmallVector<uint32_t, 1> args;
-  switch (*decoration) {
+  switch (decoration) {
   case spirv::Decoration::LinkageAttributes: {
     // Get the value of the Linkage Attributes
     // e.g., LinkageAttributes=["linkageName", linkageType].
-    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
     auto linkageName = linkageAttr.getLinkageName();
     auto linkageType = linkageAttr.getLinkageType().getValue();
     // Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     break;
   }
   case spirv::Decoration::FPFastMathMode:
-    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
       args.push_back(static_cast<uint32_t>(intAttr.getValue()));
       break;
     }
     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
-           << attrName;
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
       args.push_back(intAttr.getValue().getZExtValue());
       break;
     }
-    return emitError(loc, "expected integer attribute for ") << attrName;
+    return emitError(loc, "expected integer attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::BuiltIn:
-    if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+    if (auto strAttr = dyn_cast<StringAttr>(attr)) {
       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
       if (enumVal) {
         args.push_back(static_cast<uint32_t>(*enumVal));
         break;
       }
       return emitError(loc, "invalid ")
-             << attrName << " attribute " << strAttr.getValue();
+             << stringifyDecoration(decoration) << " decoration attribute "
+             << strAttr.getValue();
     }
-    return emitError(loc, "expected string attribute for ") << attrName;
+    return emitError(loc, "expected string attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Flat:
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
@@ -275,14 +271,33 @@ LogicalResult Serializer::processDecoration(Location loc, uin...
[truncated]

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for working on this @sott0n!
I left some suggestions, mostly on the coding style. I'm not super familiar with the deserializer, so it would be best if @antiagainst could take a look to.

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Outdated Show resolved Hide resolved
@sott0n
Copy link
Contributor Author

sott0n commented Dec 29, 2023

@kuhar Thanks for your review! I addressed your comments.

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Show resolved Hide resolved
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp Outdated Show resolved Hide resolved
mlir/test/Target/SPIRV/function-decorations.mlir Outdated Show resolved Hide resolved
mlir/test/Target/SPIRV/function-decorations.mlir Outdated Show resolved Hide resolved
mlir/test/Target/SPIRV/function-decorations.mlir Outdated Show resolved Hide resolved
mlir/test/Target/SPIRV/function-decorations.mlir Outdated Show resolved Hide resolved
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but please wait for a second approval from @antiagainst before merging

Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for the contribution!

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp Show resolved Hide resolved
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----

spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is okay for now; but later we might want to have custom parser/printer to simpify it into something like %arg0 : !spirv.ptr<...> Aliased or something (cannot recall exactly the restrictions on attribute parsing/printing). Also given we are here for spirv.func, it would be nice to omit the function control when it's "None". Not for this patch though.

@kuhar as FYI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later we might want to have custom parser/printer to simpify it into something like %arg0 : !spirv.ptr<...> Aliased or something (cannot recall exactly the restrictions on attribute parsing/printing).

It looks good! After applying this patch, I would like to think about the custom parse/printer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG!

mlir/test/Dialect/SPIRV/IR/structure-ops.mlir Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h Outdated Show resolved Hide resolved
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp Outdated Show resolved Hide resolved
spirv::Decoration::AliasedPointer,
spirv::Decoration::RestrictPointer}) {
if (decAttr.getName() ==
getSymbolDecoration(stringifyDecoration(decoration))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think it's better to just break once found one of the above four decorations. We can create the dictionary attribute out of the for loop. This way we don't mutate decorations and it's clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean that during deserialization, only these four decorations should be added to argAttrs ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually only one of these four for now. That reminds, we should error out if there are more we didn't recognize to avoid silently ignore decorations. So, make sure the decoration is only one of the four.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I understand and updated this function.

we should error out if there are more we didn't recognize to avoid silently ignore decorations.

Should we add this invalid test case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine to do without a test for now, given that the possible error cases are unbound. Especially for the deserialization, right now it's meant to only cover the cases where we can serialize.

BTW we do have some deserialization test here: https://github.com/llvm/llvm-project/blob/main/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp but it's a bit painful to write those.

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h Outdated Show resolved Hide resolved
@antiagainst
Copy link
Member

@sott0n I didn't see your fixes. Did you forget to push out or still working on it?

@sott0n
Copy link
Contributor Author

sott0n commented Jan 5, 2024

@antiagainst Sorry.. I pushed it. And Some are still in the works.

@antiagainst
Copy link
Member

Thanks a lot for the contribution and bearing with me for the nitpicking. :) To avoid burden you further, I rebased and revised some comment/error message slightly. :) I'll land once bots are happy.

@antiagainst antiagainst changed the title [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer [mlir][spirv] Support alias/restrict function argument decorations Jan 6, 2024
@antiagainst antiagainst merged commit 747d8fb into llvm:main Jan 6, 2024
4 checks passed
@sott0n
Copy link
Contributor Author

sott0n commented Jan 7, 2024

@kuhar @antiagainst Thanks for your comment and support to land this PR! If you have any issues or TODO tasks in SPIR-V dialect, I would be happy if you could share it with me. :)

@sott0n sott0n deleted the decoration-pointer-func-arg branch January 7, 2024 02:17
@antiagainst
Copy link
Member

antiagainst commented Jan 7, 2024

@kuhar @antiagainst Thanks for your comment and support to land this PR! If you have any issues or TODO tasks in SPIR-V dialect, I would be happy if you could share it with me. :)

Thanks for your further interest! Please feel free to take a look at issues with the mlir:spirv label (https://github.com/llvm/llvm-project/issues?q=is%3Aopen+is%3Aissue+label%3Amlir%3Aspirv) to see if any one of them is interesting to you, especially those with the good first issue label at the same time!

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…lvm#76353)

Closes llvm#76106

---------

Co-authored-by: Lei Zhang <antiagainst@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Support function argument decorations
4 participants