-
Notifications
You must be signed in to change notification settings - Fork 15.5k
Implement select/selectsh builtins in CIR #172299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Priyanshu Kumar (Priyanshu3820) ChangesFull diff: https://github.com/llvm/llvm-project/pull/172299.diff 2 Files Affected:
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index fb17e31bf36d6..b3182fc83776e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -370,6 +370,22 @@ static mlir::Value emitX86vpcom(CIRGenBuilderTy &builder, mlir::Location loc,
return builder.createVecCompare(loc, pred, op0, op1);
}
+static mlir::Value emitX86Select(CIRGenBuilderTy &builder, mlir::Location loc,
+ mlir::Value mask, mlir::Value Op0,
+ mlir::Value Op1) {
+ return builder.create<cir::VecTernaryOp>(loc, Op0.getType(), mask, Op0, Op1);
+}
+
+static mlir::Value emitX86ScalarSelect(CIRGenBuilderTy &builder,
+ mlir::Location loc, mlir::Value mask,
+ mlir::Value Op0, mlir::Value Op1) {
+
+ mlir::Value zero = builder.getZero(mask.getType(), loc);
+ mlir::Value cond = builder.createCompare(loc, CmpOpKind::ne, mask, zero);
+
+ return builder.createSelect(loc, Op0.getType(), cond, Op0, Op1);
+}
+
mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
const CallExpr *expr) {
if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -1186,10 +1202,14 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_selectpd_128:
case X86::BI__builtin_ia32_selectpd_256:
case X86::BI__builtin_ia32_selectpd_512:
+ return emitX86Select(builder, getLoc(expr->getExprLoc()), ops[0], ops[1],
+ ops[2]);
case X86::BI__builtin_ia32_selectsh_128:
case X86::BI__builtin_ia32_selectsbf_128:
case X86::BI__builtin_ia32_selectss_128:
case X86::BI__builtin_ia32_selectsd_128:
+ return emitX86ScalarSelect(builder, getLoc(expr->getExprLoc()), ops[0],
+ ops[1], ops[2]);
case X86::BI__builtin_ia32_cmpb128_mask:
case X86::BI__builtin_ia32_cmpb256_mask:
case X86::BI__builtin_ia32_cmpb512_mask:
diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512-select-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512-select-builtins.c
new file mode 100644
index 0000000000000..68edcbee8fcb6
--- /dev/null
+++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512-select-builtins.c
@@ -0,0 +1,292 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512bw -target-feature +avx512vl -fclangir -emit-cir %s -o - | FileCheck %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512bw -target-feature +avx512vl -emit-llvm %s -o - | FileCheck %s --check-prefix=LLVM
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512bw -target-feature +avx512vl -emit-llvm %s -o - | FileCheck %s --check-prefix=OGCG
+
+// REQUIRES: avx512bw
+// REQUIRES: avx512vl
+
+#include <immintrin.h>
+
+// CIR-LABEL: test_selectb_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectb_128
+// LLVM: select <16 x i8>
+// OGCG-LABEL: test_selectb_128
+// OGCG: select <16 x i8>
+__m128i test_selectb_128(__mmask16 k, __m128i a, __m128i b) {
+ return _mm_selectb_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectb_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectb_256
+// LLVM: select <32 x i8>
+// OGCG-LABEL: test_selectb_256
+// OGCG: select <32 x i8>
+__m256i test_selectb_256(__mmask32 k, __m256i a, __m256i b) {
+ return _mm256_selectb_epi8(k, a, b);
+}
+
+// CIR-LABEL: test_selectb_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectb_512
+// LLVM: select <64 x i8>
+// OGCG-LABEL: test_selectb_512
+// OGCG: select <64 x i8>
+__m512i test_selectb_512(__mmask64 k, __m512i a, __m512i b) {
+ return _mm512_selectb_epi8(k, a, b);
+}
+
+// CIR-LABEL: test_selectw_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectw_128
+// LLVM: select <8 x i16>
+// OGCG-LABEL: test_selectw_128
+// OGCG: select <8 x i16>
+__m128i test_selectw_128(__mmask8 k, __m128i a, __m128i b) {
+ return _mm_selectw_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectw_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectw_256
+// LLVM: select <16 x i16>
+// OGCG-LABEL: test_selectw_256
+// OGCG: select <16 x i16>
+__m256i test_selectw_256(__mmask16 k, __m256i a, __m256i b) {
+ return _mm256_selectw_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectw_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectw_512
+// LLVM: select <32 x i16>
+// OGCG-LABEL: test_selectw_512
+// OGCG: select <32 x i16>
+__m512i test_selectw_512(__mmask32 k, __m512i a, __m512i b) {
+ return _mm512_selectw_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectd_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectd_128
+// LLVM: select <4 x i32>
+// OGCG-LABEL: test_selectd_128
+// OGCG: select <4 x i32>
+__m128i test_selectd_128(__mmask4 k, __m128i a, __m128i b) {
+ return _mm_selectd_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectd_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectd_256
+// LLVM: select <8 x i32>
+// OGCG-LABEL: test_selectd_256
+// OGCG: select <8 x i32>
+__m256i test_selectd_256(__mmask8 k, __m256i a, __m256i b) {
+ return _mm256_selectd_epi32(k, a, b);
+}
+
+// CIR-LABEL: test_selectd_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectd_512
+// LLVM: select <16 x i32>
+// OGCG-LABEL: test_selectd_512
+// OGCG: select <16 x i32>
+__m512i test_selectd_512(__mmask16 k, __m512i a, __m512i b) {
+ return _mm512_selectd_epi32(k, a, b);
+}
+
+// CIR-LABEL: test_selectq_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectq_128
+// LLVM: select <2 x i64>
+// OGCG-LABEL: test_selectq_128
+// OGCG: select <2 x i64>
+__m128i test_selectq_128(__mmask2 k, __m128i a, __m128i b) {
+ return _mm_selectq_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectq_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectq_256
+// LLVM: select <4 x i64>
+// OGCG-LABEL: test_selectq_256
+// OGCG: select <4 x i64>
+__m256i test_selectq_256(__mmask4 k, __m256i a, __m256i b) {
+ return _mm256_selectq_epi64(k, a, b);
+}
+
+// CIR-LABEL: test_selectq_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectq_512
+// LLVM: select <8 x i64>
+// OGCG-LABEL: test_selectq_512
+// OGCG: select <8 x i64>
+__m512i test_selectq_512(__mmask8 k, __m512i a, __m512i b) {
+ return _mm512_selectq_epi64(k, a, b);
+}
+
+// CIR-LABEL: test_selectph_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectph_128
+// LLVM: select
+// OGCG-LABEL: test_selectph_128
+// OGCG: select
+__m128i test_selectph_128(__mmask8 k, __m128i a, __m128i b) {
+ return _mm_selectph_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectph_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectph_256
+// LLVM: select
+// OGCG-LABEL: test_selectph_256
+// OGCG: select
+__m256i test_selectph_256(__mmask16 k, __m256i a, __m256i b) {
+ return _mm256_selectph_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectph_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectph_512
+// LLVM: select
+// OGCG-LABEL: test_selectph_512
+// OGCG: select
+__m512i test_selectph_512(__mmask32 k, __m512i a, __m512i b) {
+ return _mm512_selectph_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectpbf_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpbf_128
+// LLVM: select
+// OGCG-LABEL: test_selectpbf_128
+// OGCG: select
+__m128i test_selectpbf_128(__mmask8 k, __m128i a, __m128i b) {
+ return _mm_selectpbf_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectpbf_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpbf_256
+// LLVM: select
+// OGCG-LABEL: test_selectpbf_256
+// OGCG: select
+__m256i test_selectpbf_256(__mmask16 k, __m256i a, __m256i b) {
+ return _mm256_selectpbf_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectpbf_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpbf_512
+// LLVM: select
+// OGCG-LABEL: test_selectpbf_512
+// OGCG: select
+__m512i test_selectpbf_512(__mmask32 k, __m512i a, __m512i b) {
+ return _mm512_selectpbf_epi16(k, a, b);
+}
+
+// CIR-LABEL: test_selectps_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectps_128
+// LLVM: select
+// OGCG-LABEL: test_selectps_128
+// OGCG: select
+__m128 test_selectps_128(__mmask8 k, __m128 a, __m128 b) {
+ return _mm_selectps_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectps_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectps_256
+// LLVM: select
+// OGCG-LABEL: test_selectps_256
+// OGCG: select
+__m256 test_selectps_256(__mmask8 k, __m256 a, __m256 b) {
+ return _mm256_selectps(k, a, b);
+}
+
+// CIR-LABEL: test_selectps_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectps_512
+// LLVM: select
+// OGCG-LABEL: test_selectps_512
+// OGCG: select
+__m512 test_selectps_512(__mmask16 k, __m512 a, __m512 b) {
+ return _mm512_selectps(k, a, b);
+}
+
+// CIR-LABEL: test_selectpd_128
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpd_128
+// LLVM: select
+// OGCG-LABEL: test_selectpd_128
+// OGCG: select
+__m128d test_selectpd_128(__mmask8 k, __m128d a, __m128d b) {
+ return _mm_selectpd_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectpd_256
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpd_256
+// LLVM: select
+// OGCG-LABEL: test_selectpd_256
+// OGCG: select
+__m256d test_selectpd_256(__mmask8 k, __m256d a, __m256d b) {
+ return _mm256_selectpd(k, a, b);
+}
+
+// CIR-LABEL: test_selectpd_512
+// CIR: cir.vec.ternary
+// LLVM-LABEL: test_selectpd_512
+// LLVM: select
+// OGCG-LABEL: test_selectpd_512
+// OGCG: select
+__m512d test_selectpd_512(__mmask8 k, __m512d a, __m512d b) {
+ return _mm512_selectpd(k, a, b);
+}
+
+// CIR-LABEL: test_selectsh_128
+// CIR: cir.cmp {{.*}} ne
+// CIR: cir.select
+// LLVM-LABEL: test_selectsh_128
+// LLVM: select
+// OGCG-LABEL: test_selectsh_128
+// OGCG: select
+__m128i test_selectsh_128(unsigned short k, __m128i a, __m128i b) {
+ return _mm_selectsh_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectsbf_128
+// CIR: cir.cmp {{.*}} ne
+// CIR: cir.select
+// LLVM-LABEL: test_selectsbf_128
+// LLVM: select
+// OGCG-LABEL: test_selectsbf_128
+// OGCG: select
+__m128i test_selectsbf_128(unsigned short k, __m128i a, __m128i b) {
+ return _mm_selectsbf_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectss_128
+// CIR: cir.cmp {{.*}} ne
+// CIR: cir.select
+// LLVM-LABEL: test_selectss_128
+// LLVM: select
+// OGCG-LABEL: test_selectss_128
+// OGCG: select
+__m128 test_selectss_128(unsigned short k, __m128 a, __m128 b) {
+ return _mm_selectss_128(k, a, b);
+}
+
+// CIR-LABEL: test_selectsd_128
+// CIR: cir.cmp {{.*}} ne
+// CIR: cir.select
+// LLVM-LABEL: test_selectsd_128
+// LLVM: select
+// OGCG-LABEL: test_selectsd_128
+// OGCG: select
+__m128d test_selectsd_128(unsigned short k, __m128d a, __m128d b) {
+ return _mm_selectsd_128(k, a, b);
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
🐧 Linux x64 Test Results
All executed tests passed, but another part of the build failed. Click on a failure below to see the details. tools/clang/lib/CIR/CodeGen/CMakeFiles/obj.clangCIR.dir/CIRGenBuiltinX86.cpp.oIf these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the |
| case X86::BI__builtin_ia32_selectpd_128: | ||
| case X86::BI__builtin_ia32_selectpd_256: | ||
| case X86::BI__builtin_ia32_selectpd_512: | ||
| return emitX86VectorSelect(builder, getLoc(expr->getExprLoc()), ops[0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| return emitX86VectorSelect(builder, getLoc(expr->getExprLoc()), ops[0], | |
| return emitX86Select(builder, getLoc(expr->getExprLoc()), ops[0], |
This function has already been upstreamed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes but it emits selectop while we need to emit vecTernaryOp to match the semantics of vector select.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you pass vector operands to cir.select it performs a vector select. I just tested it and the existing function works as required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok then I'll update it
| auto zeroAttr = builder.getZeroAttr(mask.getType()); | ||
| mlir::Value zero = | ||
| cir::ConstantOp::create(builder, loc, mask.getType(), zeroAttr); | ||
| mlir::Value cond = builder.createCompare(loc, cir::CmpOpKind::ne, mask, zero); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The classic codegen implementation of this function is casting the mask to a vector of i1 and extracting element zero. You need to do something equivalent to that. Effectively, it's this (assuming mask is an integer):
cond = (bool)(mask & 0x1);
So, for example, if the value of mask is 2, cond should be zero.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
| case X86::BI__builtin_ia32_selectsbf_128: | ||
| case X86::BI__builtin_ia32_selectss_128: | ||
| case X86::BI__builtin_ia32_selectsd_128: | ||
| return emitX86ScalarSelect(builder, getLoc(expr->getExprLoc()), ops[0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not correct. Compare the classic codegen implementation:
Value *A = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
Value *B = Builder.CreateExtractElement(Ops[2], (uint64_t)0);
A = EmitX86ScalarSelect(*this, Ops[0], A, B);
return Builder.CreateInsertElement(Ops[1], A, (uint64_t)0);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
| // LLVM-LABEL: test_selectb_128 | ||
| // LLVM: select <16 x i8> | ||
| // OGCG-LABEL: test_selectb_128 | ||
| // OGCG: select <16 x i8> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add more checks so that we are verifying all the operations that are generated and the correct parameter usage.
Related to: #167765