-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[CIR] Introduce syntax for scalable vectors #172683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Introduce syntax for scalable vectors #172683
Conversation
As of llvm#172346, CIR supports scalable vectors. This patch updates the assembly format to introduce syntax for representing scalable dimensions. The proposed syntax follows the format used by the builtin vector type: ```mlir # Builtin scalable vector cir.vector<[16] x !cir.int<u, 8>> # CIR scalable vector !cir.vector<[16] x !cir.int<u, 8>> ``` This contrasts with LLVM IR, where scalable dimensions are modeled using the `vscale` keyword: ```llvm ; LLVM scalable vector <vscale x 16 x i8> ``` To support this change, `cir::VectorType` gains a custom parser and printer, which are small modifications of the auto-generated ones.
|
@llvm/pr-subscribers-clangir Author: Andrzej Warzyński (banach-space) ChangesAs of #172346, CIR supports scalable vectors. This patch updates the The proposed syntax follows the format used by the builtin vector type: # Builtin scalable vector
cir.vector<[16] x !cir.int<u, 8>>
# CIR scalable vector
!cir.vector<[16] x !cir.int<u, 8>>This contrasts with LLVM IR, where scalable dimensions are modeled using ; LLVM scalable vector
<vscale x 16 x i8>To support this change, Full diff: https://github.com/llvm/llvm-project/pull/172683.diff 3 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index ce64bef3270ed..fe79e3a086d4e 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -430,6 +430,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
```mlir
vector-type ::= !cir.vector<size x element-type>
+ size ::= (decimal-literal | `[` decimal-literal `]`)
element-type ::= float-type | integer-type | pointer-type
```
@@ -442,6 +443,13 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
!cir.vector<4 x !cir.int<u, 8>>
!cir.vector<2 x !cir.float>
```
+
+ Scalable vectors are indicated by enclosing size in square brackets.
+
+ Example:
+ ```mlir
+ !cir.vector<[4] x !cir.int<u, 8>>
+ ```
}];
let parameters = (ins
@@ -450,10 +458,6 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
OptionalParameter<"bool">:$is_scalable
);
- let assemblyFormat = [{
- `<` $size `x` $element_type `>`
- }];
-
let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$element_type, "uint64_t":$size, CArg<"bool",
@@ -471,6 +475,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index c7531022fdfb8..182714ddad9e9 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -828,6 +828,76 @@ mlir::LogicalResult cir::VectorType::verify(
return success();
}
+mlir::Type cir::VectorType::parse(::mlir::AsmParser &odsParser) {
+
+ llvm::SMLoc odsLoc = odsParser.getCurrentLocation();
+ mlir::Builder odsBuilder(odsParser.getContext());
+ mlir::FailureOr<::mlir::Type> elementType;
+ mlir::FailureOr<uint64_t> size;
+ bool isScalabe = false;
+
+ // Parse literal '<'
+ if (odsParser.parseLess())
+ return {};
+
+ // Parse literal '[', if present, and set the scalability flag accordingly
+ if (odsParser.parseOptionalLSquare().succeeded()) {
+ isScalabe = true;
+ }
+
+ // Parse variable 'size'
+ size = mlir::FieldParser<uint64_t>::parse(odsParser);
+ if (mlir::failed(size)) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "failed to parse CIR_VectorType parameter 'size' which "
+ "is to be a `uint64_t`");
+ return {};
+ }
+
+ // Parse literal ']', which is expected when dealing with scalable
+ // dim sizes
+ if (isScalabe && odsParser.parseRSquare().failed()) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "missing closing `]` for scalable dim size");
+ return {};
+ }
+
+ // Parse literal 'x'
+ if (odsParser.parseKeyword("x"))
+ return {};
+
+ // Parse variable 'elementType'
+ elementType = mlir::FieldParser<::mlir::Type>::parse(odsParser);
+ if (mlir::failed(elementType)) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "failed to parse CIR_VectorType parameter "
+ "'elementType' which is to be a `mlir::Type`");
+ return {};
+ }
+
+ // Parse literal '>'
+ if (odsParser.parseGreater())
+ return {};
+ return odsParser.getChecked<VectorType>(odsLoc, odsParser.getContext(),
+ mlir::Type((*elementType)),
+ uint64_t((*size)), isScalabe);
+}
+
+void cir::VectorType::print(mlir::AsmPrinter &odsPrinter) const {
+ mlir::Builder odsBuilder(getContext());
+ odsPrinter << "<";
+ if (this->getIsScalable())
+ odsPrinter << "[";
+
+ odsPrinter.printStrippedAttrOrType(getSize());
+ if (this->getIsScalable())
+ odsPrinter << "]";
+ odsPrinter << ' ' << "x";
+ odsPrinter << ' ';
+ odsPrinter.printStrippedAttrOrType(getElementType());
+ odsPrinter << ">";
+}
+
//===----------------------------------------------------------------------===//
// TargetAddressSpace definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index d274c35099ee5..ac5d0453b1b7e 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -23,6 +23,7 @@ cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
%2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
+ %3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
cir.return
}
@@ -30,6 +31,7 @@ cir.func @vec_int_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CHECK: %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CHECK: %2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
+// CHECK: %3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
// CHECK: cir.return
// CHECK: }
|
|
@llvm/pr-subscribers-clang Author: Andrzej Warzyński (banach-space) ChangesAs of #172346, CIR supports scalable vectors. This patch updates the The proposed syntax follows the format used by the builtin vector type: # Builtin scalable vector
cir.vector<[16] x !cir.int<u, 8>>
# CIR scalable vector
!cir.vector<[16] x !cir.int<u, 8>>This contrasts with LLVM IR, where scalable dimensions are modeled using ; LLVM scalable vector
<vscale x 16 x i8>To support this change, Full diff: https://github.com/llvm/llvm-project/pull/172683.diff 3 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index ce64bef3270ed..fe79e3a086d4e 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -430,6 +430,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
```mlir
vector-type ::= !cir.vector<size x element-type>
+ size ::= (decimal-literal | `[` decimal-literal `]`)
element-type ::= float-type | integer-type | pointer-type
```
@@ -442,6 +443,13 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
!cir.vector<4 x !cir.int<u, 8>>
!cir.vector<2 x !cir.float>
```
+
+ Scalable vectors are indicated by enclosing size in square brackets.
+
+ Example:
+ ```mlir
+ !cir.vector<[4] x !cir.int<u, 8>>
+ ```
}];
let parameters = (ins
@@ -450,10 +458,6 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
OptionalParameter<"bool">:$is_scalable
);
- let assemblyFormat = [{
- `<` $size `x` $element_type `>`
- }];
-
let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$element_type, "uint64_t":$size, CArg<"bool",
@@ -471,6 +475,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index c7531022fdfb8..182714ddad9e9 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -828,6 +828,76 @@ mlir::LogicalResult cir::VectorType::verify(
return success();
}
+mlir::Type cir::VectorType::parse(::mlir::AsmParser &odsParser) {
+
+ llvm::SMLoc odsLoc = odsParser.getCurrentLocation();
+ mlir::Builder odsBuilder(odsParser.getContext());
+ mlir::FailureOr<::mlir::Type> elementType;
+ mlir::FailureOr<uint64_t> size;
+ bool isScalabe = false;
+
+ // Parse literal '<'
+ if (odsParser.parseLess())
+ return {};
+
+ // Parse literal '[', if present, and set the scalability flag accordingly
+ if (odsParser.parseOptionalLSquare().succeeded()) {
+ isScalabe = true;
+ }
+
+ // Parse variable 'size'
+ size = mlir::FieldParser<uint64_t>::parse(odsParser);
+ if (mlir::failed(size)) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "failed to parse CIR_VectorType parameter 'size' which "
+ "is to be a `uint64_t`");
+ return {};
+ }
+
+ // Parse literal ']', which is expected when dealing with scalable
+ // dim sizes
+ if (isScalabe && odsParser.parseRSquare().failed()) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "missing closing `]` for scalable dim size");
+ return {};
+ }
+
+ // Parse literal 'x'
+ if (odsParser.parseKeyword("x"))
+ return {};
+
+ // Parse variable 'elementType'
+ elementType = mlir::FieldParser<::mlir::Type>::parse(odsParser);
+ if (mlir::failed(elementType)) {
+ odsParser.emitError(odsParser.getCurrentLocation(),
+ "failed to parse CIR_VectorType parameter "
+ "'elementType' which is to be a `mlir::Type`");
+ return {};
+ }
+
+ // Parse literal '>'
+ if (odsParser.parseGreater())
+ return {};
+ return odsParser.getChecked<VectorType>(odsLoc, odsParser.getContext(),
+ mlir::Type((*elementType)),
+ uint64_t((*size)), isScalabe);
+}
+
+void cir::VectorType::print(mlir::AsmPrinter &odsPrinter) const {
+ mlir::Builder odsBuilder(getContext());
+ odsPrinter << "<";
+ if (this->getIsScalable())
+ odsPrinter << "[";
+
+ odsPrinter.printStrippedAttrOrType(getSize());
+ if (this->getIsScalable())
+ odsPrinter << "]";
+ odsPrinter << ' ' << "x";
+ odsPrinter << ' ';
+ odsPrinter.printStrippedAttrOrType(getElementType());
+ odsPrinter << ">";
+}
+
//===----------------------------------------------------------------------===//
// TargetAddressSpace definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index d274c35099ee5..ac5d0453b1b7e 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -23,6 +23,7 @@ cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
%2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
+ %3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
cir.return
}
@@ -30,6 +31,7 @@ cir.func @vec_int_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CHECK: %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CHECK: %2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
+// CHECK: %3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
// CHECK: cir.return
// CHECK: }
|
andykaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great. I have only the smallest possible nit.
| return {}; | ||
|
|
||
| // Parse literal '[', if present, and set the scalability flag accordingly | ||
| if (odsParser.parseOptionalLSquare().succeeded()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't have braces.
| // Parse variable 'elementType' | ||
| elementType = mlir::FieldParser<::mlir::Type>::parse(odsParser); | ||
| if (mlir::failed(elementType)) { | ||
| odsParser.emitError(odsParser.getCurrentLocation(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add an invalid vector cir test to make sure errors are emitted correctly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be different in CIR, but in MLIR we don't really test the parser and tests in invalid.mlir (and similar) are reserved for verification errors.
I am happy to add a test if you think that that would be helpful, but we’d probably want to add a dedicated file for parser errors - perhaps one already exists? I didn’t find any.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually, for the handwritten parser or verifier, we have a invalid-<>.cir file to test it, for example, clang/test/CIR/IR/invalid-vector.cir, I think we can add a small test for scalable vector type syntax. What do you think? @andykaylor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, Only this error need to be covered "missing closing ] for scalable dim size", the other already covered in this test :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test in this commit.
IMO, it's not great :/ (suggestions for improvement are welcome) .Testing !cir.vector<[1 x !s32i> (instead of !cir.vector<[1 x] !s32i>) would be better, but the former is captured by Parser::parseDialectSymbolBody with:
/llvm-project/clang/test/CIR/IR/invalid-vector.cir:17:30: error: unbalanced '[' character in pretty dialect name
%3 = cir.alloca !cir.vector<[1 x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>
^That error is hit before getting into cir::VectorType::parse.
Btw, I don't want to come across as nit-picking or pushing back, but I see quite a few CIR parser errors that are not tested, e.g.
| parser.emitError(loc, "invalid self-reference within record"); |
Perhaps it would be better to skip testing in this case as well? WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding it. I am fine with both options, but I think it will be better to keep this test case 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, I don't want to come across as nit-picking or pushing back, but I see quite a few CIR parser errors that are not tested, e.g.
I'm sure we've been inconsistent about this. I generally only look for tests for verifier errors and have been satisfied with round-trip tests for printing/parsing. My view is that the tests for the verifier check that we are correctly catching incorrectly constructed IR, which can occur anywhere during IR generation or processing, whereas the printing and parsing are paired so they test each other for correctness.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets me keep the test for now, we can always remove it later.
AmrDeveloper
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with nits, Thanks
As of #172346, CIR supports scalable vectors. This patch updates the
assembly format to introduce syntax for representing scalable
dimensions.
The proposed syntax follows the format used by the builtin vector type:
This contrasts with LLVM IR, where scalable dimensions are modeled using
the
vscalekeyword:To support this change,
cir::VectorTypegains a custom parser andprinter, which are small modifications of the auto-generated ones.