-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Add definitions and (de)serialization for FPRoundingMode #101546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add definitions and (de)serialization for FPRoundingMode #101546
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Andrea Faulds (andfau-amd) ChangesFull diff: https://github.com/llvm/llvm-project/pull/101546.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6ec97e17c5dcc..b38978272c5bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
];
}
+def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
+def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
+def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
+def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
+
+// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
+// FPRoundingMode on a value stored to certain storage classes?
+// (The OpenCL environment also has FPRoundingMode rules, but different.)
+def SPIRV_FPRoundingModeAttr :
+ SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
+ SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
+ ]>;
+
def SPIRV_FunctionControlAttr :
SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d7a308548cf4d..12980879b20ab 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
static_cast<FPFastMathMode>(words[2])));
break;
+ case spirv::Decoration::FPRoundingMode:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
+ static_cast<FPRoundingMode>(words[2])));
+ break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4c4fef177317e..714a3edfb5657 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
// expected FPFastMathMode.
if (attrName == "fp_fast_math_mode")
return "FPFastMathMode";
+ // similar here
+ if (attrName == "fp_rounding_mode")
+ return "FPRoundingMode";
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
<< stringifyDecoration(decoration);
+ case spirv::Decoration::FPRoundingMode:
+ if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
+ args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+ break;
+ }
+ return emitError(loc, "expected FPRoundingModeAttr attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 195773735431e..0a29290b6a6fa 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
spirv.ReturnValue %0 : f32
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ spirv.ReturnValue %0 : f16
+}
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add some negative tests for when fp_rounding_mode
is not a #spirv.fp_rounding_mode
attribute?
Hmm, I'm not sure if it's an important enough case to test. The code that rejects it is very straightforward, I couldn't find any negative tests for similar attributes, and it doesn't seem that likely to me that someone would generate invalid MLIR like that anyway. One observation I have there is that this is, effectively, validated by the serializer rather than a proper verifier method. I'm not sure if that's good, but it is what the other attributes do too, so... |
OK, that makes sense then. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.