From 4ceae744833df00df8929d8660d438da71982352 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 2 Dec 2025 22:17:28 -0800 Subject: [PATCH 1/6] Implement extractf, tests are from clang/test/CodeGen/X86/avx512f-builtins.c --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 6 +- .../CIR/Dialect/IR/CIRTypeConstraints.td | 14 +- clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp | 91 ++++++++- .../CodeGenBuiltins/X86/avx512f-builtins.c | 178 ++++++++++++++++++ 4 files changed, 283 insertions(+), 6 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index ae199f35cb10e..1540fd022860b 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1870,8 +1870,8 @@ def CIR_SelectOp : CIR_Op<"select", [ let summary = "Yield one of two values based on a boolean value"; let description = [{ The `cir.select` operation takes three operands. The first operand - `condition` is a boolean value of type `!cir.bool`. The second and the third - operand can be of any CIR types, but their types must be the same. If the + `condition` is either a boolean value of type `!cir.bool` or a boolean vector of type `!cir.bool`. + The second and the third operand can be of any CIR types, but their types must be the same. If the first operand is `true`, the operation yields its second operand. Otherwise, the operation yields its third operand. @@ -1885,7 +1885,7 @@ def CIR_SelectOp : CIR_Op<"select", [ ``` }]; - let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value, + let arguments = (ins CIR_ScalarOrVectorOf:$condition, CIR_AnyType:$true_value, CIR_AnyType:$false_value); let results = (outs CIR_AnyType:$result); diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index ddca98eac93ab..dd514d755ce24 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -250,8 +250,8 @@ def CIR_PtrToArray : CIR_PtrToType; def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">; -def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType], - "any cir integer, floating point or pointer type" +def CIR_VectorElementType : AnyTypeOf<[CIR_AnyBoolType, CIR_AnyIntOrFloatType, CIR_AnyPtrType], + "any cir boolean, integer, floating point or pointer type" > { let cppFunctionName = "isValidVectorTypeElementType"; } @@ -266,6 +266,16 @@ class CIR_VectorTypeOf types, string summary = ""> "vector of " # CIR_TypeSummaries.value, summary)>; +class CIR_VectorOf : CIR_ConfinedType< + CIR_AnyVectorType, + [CIR_ElementTypePred], + "CIR vector of " # T.summary>; + +// Type constraint accepting a either a type T or a vector of type T +// Mimicking LLVMIR's LLVM_ScalarOrVectorOf +class CIR_ScalarOrVectorOf : + AnyTypeOf<[T, CIR_VectorOf]>; + // Vector of integral type def IntegerVector : Type< And<[ diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp index 1b2e3f41479db..16d23e1ae0bfc 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp @@ -152,6 +152,71 @@ computeFullLaneShuffleMask(CIRGenFunction &cgf, const mlir::Value vec, outIndices.resize(numElts); } +static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder, + mlir::Location loc, mlir::Value mask, + unsigned numElems) { + + cir::BoolType boolTy = builder.getBoolTy(); + auto maskTy = cir::VectorType::get( + boolTy, cast(mask.getType()).getWidth()); + mlir::Value maskVec = builder.createBitcast(mask, maskTy); + + if (numElems < 8) { + SmallVector indices; + mlir::Type i32Ty = builder.getSInt32Ty(); + for (auto i : llvm::seq(0, numElems)) + indices.push_back(cir::IntAttr::get(i32Ty, i)); + + maskVec = builder.createVecShuffle(loc, maskVec, maskVec, indices); + } + return maskVec; +} + +// Helper function mirroring OG's bool Constant::isAllOnesValue() +static bool isAllOnesValue(mlir::Value value) { + auto constOp = mlir::dyn_cast_or_null(value.getDefiningOp()); + if (!constOp) + return false; + + // Check for -1 integers + if (auto intAttr = constOp.getValueAttr()) { + return intAttr.getValue().isAllOnes(); + } + + // Check for FP which are bitcasted from -1 integers + if (auto fpAttr = constOp.getValueAttr()) { + return fpAttr.getValue().bitcastToAPInt().isAllOnes(); + } + + // Check for constant vectors with splat values + if (cir::VectorType v = dyn_cast(constOp.getType())) { + if (auto vecAttr = constOp.getValueAttr()) { + if (vecAttr.isSplat()) { + auto splatAttr = vecAttr.getSplatValue(); + if (auto splatInt = mlir::dyn_cast(splatAttr)) { + return splatInt.getValue().isAllOnes(); + } + } + } + } + + return false; +} + +static mlir::Value emitX86Select(CIRGenBuilderTy &builder, mlir::Location loc, + mlir::Value mask, mlir::Value op0, + mlir::Value op1) { + + // If the mask is all ones just return first argument. + if (isAllOnesValue(mask)) + return op0; + + mask = getBoolMaskVecValue(builder, loc, mask, + cast(op0.getType()).getSize()); + + return builder.createSelect(loc, mask, op0, op1); +} + static mlir::Value emitX86MaskAddLogic(CIRGenBuilderTy &builder, mlir::Location loc, const std::string &intrinsicName, @@ -887,7 +952,31 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, case X86::BI__builtin_ia32_extractf64x2_256_mask: case X86::BI__builtin_ia32_extracti64x2_256_mask: case X86::BI__builtin_ia32_extractf64x2_512_mask: - case X86::BI__builtin_ia32_extracti64x2_512_mask: + case X86::BI__builtin_ia32_extracti64x2_512_mask: { + mlir::Location loc = getLoc(expr->getExprLoc()); + cir::VectorType dstTy = cast(convertType(expr->getType())); + unsigned numElts = dstTy.getSize(); + unsigned srcNumElts = cast(ops[0].getType()).getSize(); + unsigned subVectors = srcNumElts / numElts; + unsigned index = + ops[1].getDefiningOp().getIntValue().getZExtValue(); + + index &= subVectors - 1; // Remove any extra bits. + index *= numElts; + + int64_t indices[16]; + for (unsigned i = 0; i != numElts; ++i) + indices[i] = i + index; + + mlir::Value zero = builder.getNullValue(ops[0].getType(), loc); + mlir::Value res = + builder.createVecShuffle(loc, ops[0], zero, ArrayRef(indices, numElts)); + if (ops.size() == 4) { + res = emitX86Select(builder, loc, ops[3], res, ops[2]); + } + + return res; + } case X86::BI__builtin_ia32_vinsertf128_pd256: case X86::BI__builtin_ia32_vinsertf128_ps256: case X86::BI__builtin_ia32_vinsertf128_si256: diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c index cdcdad42b2845..31ba14c88fae9 100644 --- a/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c +++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c @@ -695,3 +695,181 @@ void test_mm512_mask_i64scatter_epi32(void *__addr, __mmask8 __mask, __m512i __i // OGCG: @llvm.x86.avx512.mask.scatter.qpi.512 return _mm512_mask_i64scatter_epi32(__addr, __mask, __index, __v1, 2); } + +__m256d test_mm512_extractf64x4_pd(__m512d a) +{ + // CIR-LABEL: test_mm512_extractf64x4_pd + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + + // LLVM-LABEL: test_mm512_extractf64x4_pd + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + + // OGCG-LABEL: test_mm512_extractf64x4_pd + // OGCG: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> + return _mm512_extractf64x4_pd(a, 1); +} + +__m256d test_mm512_mask_extractf64x4_pd(__m256d __W,__mmask8 __U,__m512d __A){ + // CIR-LABEL: test_mm512_mask_extractf64x4_pd + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !cir.double>, !cir.vector<4 x !cir.double>) -> !cir.vector<4 x !cir.double> + + // LLVM-LABEL: test_mm512_mask_extractf64x4_pd + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} + + // OGCG-LABEL: test_mm512_mask_extractf64x4_pd + // OGCG: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} + return _mm512_mask_extractf64x4_pd( __W, __U, __A, 1); +} + +__m256d test_mm512_maskz_extractf64x4_pd(__mmask8 __U,__m512d __A){ + // CIR-LABEL: test_mm512_maskz_extractf64x4_pd + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !cir.double>, !cir.vector<4 x !cir.double>) -> !cir.vector<4 x !cir.double> + + // LLVM-LABEL: test_mm512_maskz_extractf64x4_pd + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} + + // OGCG-LABEL: test_mm512_maskz_extractf64x4_pd + // OGCG: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} + return _mm512_maskz_extractf64x4_pd( __U, __A, 1); +} + +__m128 test_mm512_extractf32x4_ps(__m512 a) +{ + // CIR-LABEL: test_mm512_extractf32x4_ps + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + + // LLVM-LABEL: test_mm512_extractf32x4_ps + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + + // OGCG-LABEL: test_mm512_extractf32x4_ps + // OGCG: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> + return _mm512_extractf32x4_ps(a, 1); +} + +__m128 test_mm512_mask_extractf32x4_ps(__m128 __W, __mmask8 __U,__m512 __A){ + // CIR-LABEL: test_mm512_mask_extractf32x4_ps + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + + // LLVM-LABEL: test_mm512_mask_extractf32x4_ps + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} + + // OGCG-LABEL: test_mm512_mask_extractf32x4_ps + // OGCG: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} + return _mm512_mask_extractf32x4_ps( __W, __U, __A, 1); +} + +__m128 test_mm512_maskz_extractf32x4_ps( __mmask8 __U,__m512 __A){ + // CIR-LABEL: test_mm512_maskz_extractf32x4_ps + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + + // LLVM-LABEL: test_mm512_maskz_extractf32x4_ps + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} + + // OGCG-LABEL: test_mm512_maskz_extractf32x4_ps + // OGCG: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} + return _mm512_maskz_extractf32x4_ps(__U, __A, 1); +} + +__m128i test_mm512_extracti32x4_epi32(__m512i __A) { + // CIR-LABEL: test_mm512_extracti32x4_epi32 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + + // LLVM-LABEL: test_mm512_extracti32x4_epi32 + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + + // OGCG-LABEL: test_mm512_extracti32x4_epi32 + // OGCG: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> + return _mm512_extracti32x4_epi32(__A, 3); +} + +__m128i test_mm512_mask_extracti32x4_epi32(__m128i __W, __mmask8 __U, __m512i __A) { + // CIR-LABEL: test_mm512_mask_extracti32x4_epi32 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i> + + // LLVM-LABEL: test_mm512_mask_extracti32x4_epi32 + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} + + // OGCG-LABEL: test_mm512_mask_extracti32x4_epi32 + // OGCG: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} + return _mm512_mask_extracti32x4_epi32(__W, __U, __A, 3); +} + +__m128i test_mm512_maskz_extracti32x4_epi32(__mmask8 __U, __m512i __A) { + // CIR-LABEL: test_mm512_maskz_extracti32x4_epi32 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i> + + // LLVM-LABEL: test_mm512_maskz_extracti32x4_epi32 + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} + + // OGCG-LABEL: test_mm512_maskz_extracti32x4_epi32 + // OGCG: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} + return _mm512_maskz_extracti32x4_epi32(__U, __A, 3); +} + +__m256i test_mm512_extracti64x4_epi64(__m512i __A) { + // CIR-LABEL: test_mm512_extracti64x4_epi64 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + + // LLVM-LABEL: test_mm512_extracti64x4_epi64 + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + + // OGCG-LABEL: test_mm512_extracti64x4_epi64 + // OGCG: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> + return _mm512_extracti64x4_epi64(__A, 1); +} + +__m256i test_mm512_mask_extracti64x4_epi64(__m256i __W, __mmask8 __U, __m512i __A) { + // CIR-LABEL: test_mm512_mask_extracti64x4_epi64 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + + // LLVM-LABEL: test_mm512_mask_extracti64x4_epi64 + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} + + // OGCG-LABEL: test_mm512_mask_extracti64x4_epi64 + // OGCG: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} + return _mm512_mask_extracti64x4_epi64(__W, __U, __A, 1); +} + +__m256i test_mm512_maskz_extracti64x4_epi64(__mmask8 __U, __m512i __A) { + // CIR-LABEL: test_mm512_maskz_extracti64x4_epi64 + // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + + // LLVM-LABEL: test_mm512_maskz_extracti64x4_epi64 + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + // LLVM: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} + + // OGCG-LABEL: test_mm512_maskz_extracti64x4_epi64 + // OGCG: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> + // OGCG: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} + return _mm512_maskz_extracti64x4_epi64(__U, __A, 1); +} From 7de533b101132f647d4239d1e911bd8a4aef650c Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Wed, 3 Dec 2025 10:56:56 -0800 Subject: [PATCH 2/6] Resolve PR reviews --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 34 ++++++++++++++- .../CIR/Dialect/IR/CIRTypeConstraints.td | 9 +--- clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp | 42 +++---------------- 3 files changed, 38 insertions(+), 47 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 1540fd022860b..731243cd1ba92 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -425,6 +425,32 @@ def CIR_ConstantOp : CIR_Op<"const", [ return boolAttr.getValue(); llvm_unreachable("Expected a BoolAttr in ConstantOp"); } + static bool isAllOnesValue(mlir::Value value) { + auto constOp = mlir::dyn_cast_or_null(value.getDefiningOp()); + if (!constOp) + return false; + + // Check for -1 integers + if (auto intAttr = constOp.getValueAttr()) + return intAttr.getValue().isAllOnes(); + + // Check for FP which are bitcasted from -1 integers + if (auto fpAttr = constOp.getValueAttr()) + return fpAttr.getValue().bitcastToAPInt().isAllOnes(); + + + // Check for constant vectors with splat values + if (cir::VectorType v = mlir::dyn_cast(constOp.getType())) + if (auto vecAttr = constOp.getValueAttr()) + if (vecAttr.isSplat()) { + auto splatAttr = vecAttr.getSplatValue(); + if (auto splatInt = mlir::dyn_cast(splatAttr)) { + return splatInt.getValue().isAllOnes(); + } + } + + return false; + } }]; let hasFolder = 1; @@ -1885,8 +1911,12 @@ def CIR_SelectOp : CIR_Op<"select", [ ``` }]; - let arguments = (ins CIR_ScalarOrVectorOf:$condition, CIR_AnyType:$true_value, - CIR_AnyType:$false_value); + let arguments = (ins + CIR_ScalarOrVectorOf:$condition, + CIR_AnyType:$true_value, + CIR_AnyType:$false_value + ); + let results = (outs CIR_AnyType:$result); let assemblyFormat = [{ diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index dd514d755ce24..fa315c60587fb 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -266,15 +266,8 @@ class CIR_VectorTypeOf types, string summary = ""> "vector of " # CIR_TypeSummaries.value, summary)>; -class CIR_VectorOf : CIR_ConfinedType< - CIR_AnyVectorType, - [CIR_ElementTypePred], - "CIR vector of " # T.summary>; - // Type constraint accepting a either a type T or a vector of type T -// Mimicking LLVMIR's LLVM_ScalarOrVectorOf -class CIR_ScalarOrVectorOf : - AnyTypeOf<[T, CIR_VectorOf]>; +class CIR_ScalarOrVectorOf : AnyTypeOf<[T, CIR_VectorTypeOf<[T]>]>; // Vector of integral type def IntegerVector : Type< diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp index 16d23e1ae0bfc..3a4d9caa76e0c 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp @@ -163,6 +163,7 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder, if (numElems < 8) { SmallVector indices; + indices.reserve(numElems); mlir::Type i32Ty = builder.getSInt32Ty(); for (auto i : llvm::seq(0, numElems)) indices.push_back(cir::IntAttr::get(i32Ty, i)); @@ -172,43 +173,11 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder, return maskVec; } -// Helper function mirroring OG's bool Constant::isAllOnesValue() -static bool isAllOnesValue(mlir::Value value) { - auto constOp = mlir::dyn_cast_or_null(value.getDefiningOp()); - if (!constOp) - return false; - - // Check for -1 integers - if (auto intAttr = constOp.getValueAttr()) { - return intAttr.getValue().isAllOnes(); - } - - // Check for FP which are bitcasted from -1 integers - if (auto fpAttr = constOp.getValueAttr()) { - return fpAttr.getValue().bitcastToAPInt().isAllOnes(); - } - - // Check for constant vectors with splat values - if (cir::VectorType v = dyn_cast(constOp.getType())) { - if (auto vecAttr = constOp.getValueAttr()) { - if (vecAttr.isSplat()) { - auto splatAttr = vecAttr.getSplatValue(); - if (auto splatInt = mlir::dyn_cast(splatAttr)) { - return splatInt.getValue().isAllOnes(); - } - } - } - } - - return false; -} - static mlir::Value emitX86Select(CIRGenBuilderTy &builder, mlir::Location loc, mlir::Value mask, mlir::Value op0, mlir::Value op1) { - // If the mask is all ones just return first argument. - if (isAllOnesValue(mask)) + if (cir::ConstantOp::isAllOnesValue(mask)) return op0; mask = getBoolMaskVecValue(builder, loc, mask, @@ -958,6 +927,7 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, unsigned numElts = dstTy.getSize(); unsigned srcNumElts = cast(ops[0].getType()).getSize(); unsigned subVectors = srcNumElts / numElts; + assert(llvm::isPowerOf2_32(subVectors) && "Expected power of 2 subvectors"); unsigned index = ops[1].getDefiningOp().getIntValue().getZExtValue(); @@ -965,15 +935,13 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, index *= numElts; int64_t indices[16]; - for (unsigned i = 0; i != numElts; ++i) - indices[i] = i + index; + std::iota(indices, indices + numElts, index); mlir::Value zero = builder.getNullValue(ops[0].getType(), loc); mlir::Value res = builder.createVecShuffle(loc, ops[0], zero, ArrayRef(indices, numElts)); - if (ops.size() == 4) { + if (ops.size() == 4) res = emitX86Select(builder, loc, ops[3], res, ops[2]); - } return res; } From f1d16e88a9c9021129831cae9a4124cb7003f7ba Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 9 Dec 2025 03:31:17 -0800 Subject: [PATCH 3/6] Address PR comments except verifiers --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 29 ++++---- .../CIR/Dialect/IR/CIRTypeConstraints.td | 8 +-- clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp | 13 ++-- .../CodeGenBuiltins/X86/avx512f-builtins.c | 72 +++++++++---------- 4 files changed, 63 insertions(+), 59 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 731243cd1ba92..bb11aaa542984 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -425,23 +425,18 @@ def CIR_ConstantOp : CIR_Op<"const", [ return boolAttr.getValue(); llvm_unreachable("Expected a BoolAttr in ConstantOp"); } - static bool isAllOnesValue(mlir::Value value) { - auto constOp = mlir::dyn_cast_or_null(value.getDefiningOp()); - if (!constOp) - return false; - + bool isAllOnesValue() { // Check for -1 integers - if (auto intAttr = constOp.getValueAttr()) + if (auto intAttr = getValueAttr()) return intAttr.getValue().isAllOnes(); // Check for FP which are bitcasted from -1 integers - if (auto fpAttr = constOp.getValueAttr()) + if (auto fpAttr = getValueAttr()) return fpAttr.getValue().bitcastToAPInt().isAllOnes(); - // Check for constant vectors with splat values - if (cir::VectorType v = mlir::dyn_cast(constOp.getType())) - if (auto vecAttr = constOp.getValueAttr()) + if (cir::VectorType v = mlir::dyn_cast(getType())) + if (auto vecAttr = getValueAttr()) if (vecAttr.isSplat()) { auto splatAttr = vecAttr.getSplatValue(); if (auto splatInt = mlir::dyn_cast(splatAttr)) { @@ -1896,10 +1891,16 @@ def CIR_SelectOp : CIR_Op<"select", [ let summary = "Yield one of two values based on a boolean value"; let description = [{ The `cir.select` operation takes three operands. The first operand - `condition` is either a boolean value of type `!cir.bool` or a boolean vector of type `!cir.bool`. - The second and the third operand can be of any CIR types, but their types must be the same. If the - first operand is `true`, the operation yields its second operand. Otherwise, - the operation yields its third operand. + `condition` is either a boolean value of type `!cir.bool` or a boolean + vector of type `!cir.bool`. The second and the third operand can be of + any CIR types, but their types must be the same. If the first operand + is `true`, the operation yields its second operand. Otherwise, the + operation yields its third operand. + + In the case where the first operand is a boolean vector, then the second + and third operand needs to also be of some vectors of the same type to + each other and that the number of elements of all three operands needs to + be the same as well. Example: diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index fa315c60587fb..7ff4c4b70d0d2 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -250,10 +250,10 @@ def CIR_PtrToArray : CIR_PtrToType; def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">; -def CIR_VectorElementType : AnyTypeOf<[CIR_AnyBoolType, CIR_AnyIntOrFloatType, CIR_AnyPtrType], - "any cir boolean, integer, floating point or pointer type" -> { - let cppFunctionName = "isValidVectorTypeElementType"; +def CIR_VectorElementType + : AnyTypeOf<[CIR_AnyBoolType, CIR_AnyIntOrFloatType, CIR_AnyPtrType], + "any cir boolean, integer, floating point or pointer type"> { + let cppFunctionName = "isValidVectorTypeElementType"; } class CIR_ElementTypePred : SubstLeaves<"$_self", diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp index 3a4d9caa76e0c..8f02d6d989efd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp @@ -20,6 +20,7 @@ #include "clang/Basic/TargetBuiltins.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/MissingFeatures.h" +#include "llvm/Support/ErrorHandling.h" using namespace clang; using namespace clang::CIRGen; @@ -162,7 +163,7 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder, mlir::Value maskVec = builder.createBitcast(mask, maskTy); if (numElems < 8) { - SmallVector indices; + SmallVector indices; indices.reserve(numElems); mlir::Type i32Ty = builder.getSInt32Ty(); for (auto i : llvm::seq(0, numElems)) @@ -176,8 +177,9 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder, static mlir::Value emitX86Select(CIRGenBuilderTy &builder, mlir::Location loc, mlir::Value mask, mlir::Value op0, mlir::Value op1) { + auto constOp = mlir::dyn_cast_or_null(mask.getDefiningOp()); // If the mask is all ones just return first argument. - if (cir::ConstantOp::isAllOnesValue(mask)) + if (constOp && constOp.isAllOnesValue()) return op0; mask = getBoolMaskVecValue(builder, loc, mask, @@ -937,9 +939,10 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, int64_t indices[16]; std::iota(indices, indices + numElts, index); - mlir::Value zero = builder.getNullValue(ops[0].getType(), loc); - mlir::Value res = - builder.createVecShuffle(loc, ops[0], zero, ArrayRef(indices, numElts)); + mlir::Value poison = + builder.getConstant(loc, cir::PoisonAttr::get(ops[0].getType())); + mlir::Value res = builder.createVecShuffle(loc, ops[0], poison, + ArrayRef(indices, numElts)); if (ops.size() == 4) res = emitX86Select(builder, loc, ops[3], res, ops[2]); diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c index 31ba14c88fae9..33756537317c6 100644 --- a/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c +++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c @@ -699,11 +699,11 @@ void test_mm512_mask_i64scatter_epi32(void *__addr, __mmask8 __mask, __m512i __i __m256d test_mm512_extractf64x4_pd(__m512d a) { // CIR-LABEL: test_mm512_extractf64x4_pd - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> // LLVM-LABEL: test_mm512_extractf64x4_pd - // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> // OGCG-LABEL: test_mm512_extractf64x4_pd // OGCG: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> @@ -712,12 +712,12 @@ __m256d test_mm512_extractf64x4_pd(__m512d a) __m256d test_mm512_mask_extractf64x4_pd(__m256d __W,__mmask8 __U,__m512d __A){ // CIR-LABEL: test_mm512_mask_extractf64x4_pd - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !cir.double>, !cir.vector<4 x !cir.double>) -> !cir.vector<4 x !cir.double> // LLVM-LABEL: test_mm512_mask_extractf64x4_pd - // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} // OGCG-LABEL: test_mm512_mask_extractf64x4_pd @@ -728,12 +728,12 @@ __m256d test_mm512_mask_extractf64x4_pd(__m256d __W,__mmask8 __U,__m512d __A){ __m256d test_mm512_maskz_extractf64x4_pd(__mmask8 __U,__m512d __A){ // CIR-LABEL: test_mm512_maskz_extractf64x4_pd - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !cir.double> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !cir.double> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !cir.double>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.double> // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !cir.double>, !cir.vector<4 x !cir.double>) -> !cir.vector<4 x !cir.double> // LLVM-LABEL: test_mm512_maskz_extractf64x4_pd - // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x double> %{{.*}}, <8 x double> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x double> %{{.*}}, <4 x double> %{{.*}} // OGCG-LABEL: test_mm512_maskz_extractf64x4_pd @@ -745,11 +745,11 @@ __m256d test_mm512_maskz_extractf64x4_pd(__mmask8 __U,__m512d __A){ __m128 test_mm512_extractf32x4_ps(__m512 a) { // CIR-LABEL: test_mm512_extractf32x4_ps - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> // LLVM-LABEL: test_mm512_extractf32x4_ps - // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> // OGCG-LABEL: test_mm512_extractf32x4_ps // OGCG: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> @@ -758,11 +758,11 @@ __m128 test_mm512_extractf32x4_ps(__m512 a) __m128 test_mm512_mask_extractf32x4_ps(__m128 __W, __mmask8 __U,__m512 __A){ // CIR-LABEL: test_mm512_mask_extractf32x4_ps - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> // LLVM-LABEL: test_mm512_mask_extractf32x4_ps - // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} // OGCG-LABEL: test_mm512_mask_extractf32x4_ps @@ -773,11 +773,11 @@ __m128 test_mm512_mask_extractf32x4_ps(__m128 __W, __mmask8 __U,__m512 __A){ __m128 test_mm512_maskz_extractf32x4_ps( __mmask8 __U,__m512 __A){ // CIR-LABEL: test_mm512_maskz_extractf32x4_ps - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !cir.float> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !cir.float> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !cir.float>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.float> // LLVM-LABEL: test_mm512_maskz_extractf32x4_ps - // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x float> %{{.*}}, <16 x float> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}} // OGCG-LABEL: test_mm512_maskz_extractf32x4_ps @@ -788,11 +788,11 @@ __m128 test_mm512_maskz_extractf32x4_ps( __mmask8 __U,__m512 __A){ __m128i test_mm512_extracti32x4_epi32(__m512i __A) { // CIR-LABEL: test_mm512_extracti32x4_epi32 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> // LLVM-LABEL: test_mm512_extracti32x4_epi32 - // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> // OGCG-LABEL: test_mm512_extracti32x4_epi32 // OGCG: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> @@ -801,12 +801,12 @@ __m128i test_mm512_extracti32x4_epi32(__m512i __A) { __m128i test_mm512_mask_extracti32x4_epi32(__m128i __W, __mmask8 __U, __m512i __A) { // CIR-LABEL: test_mm512_mask_extracti32x4_epi32 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i> // LLVM-LABEL: test_mm512_mask_extracti32x4_epi32 - // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} // OGCG-LABEL: test_mm512_mask_extracti32x4_epi32 @@ -817,12 +817,12 @@ __m128i test_mm512_mask_extracti32x4_epi32(__m128i __W, __mmask8 __U, __m512i __ __m128i test_mm512_maskz_extracti32x4_epi32(__mmask8 __U, __m512i __A) { // CIR-LABEL: test_mm512_maskz_extracti32x4_epi32 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<16 x !s32i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<16 x !s32i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<16 x !s32i>) [#cir.int<12> : !s32i, #cir.int<13> : !s32i, #cir.int<14> : !s32i, #cir.int<15> : !s32i] : !cir.vector<4 x !s32i> // CIR: cir.select if %{{.*}} then %{{.*}} else %{{.*}} : (!cir.vector<4 x !cir.bool>, !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i> // LLVM-LABEL: test_mm512_maskz_extracti32x4_epi32 - // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> zeroinitializer, <4 x i32> + // LLVM: shufflevector <16 x i32> %{{.*}}, <16 x i32> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}} // OGCG-LABEL: test_mm512_maskz_extracti32x4_epi32 @@ -833,11 +833,11 @@ __m128i test_mm512_maskz_extracti32x4_epi32(__mmask8 __U, __m512i __A) { __m256i test_mm512_extracti64x4_epi64(__m512i __A) { // CIR-LABEL: test_mm512_extracti64x4_epi64 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> // LLVM-LABEL: test_mm512_extracti64x4_epi64 - // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> // OGCG-LABEL: test_mm512_extracti64x4_epi64 // OGCG: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> @@ -846,11 +846,11 @@ __m256i test_mm512_extracti64x4_epi64(__m512i __A) { __m256i test_mm512_mask_extracti64x4_epi64(__m256i __W, __mmask8 __U, __m512i __A) { // CIR-LABEL: test_mm512_mask_extracti64x4_epi64 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> // LLVM-LABEL: test_mm512_mask_extracti64x4_epi64 - // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} // OGCG-LABEL: test_mm512_mask_extracti64x4_epi64 @@ -861,11 +861,11 @@ __m256i test_mm512_mask_extracti64x4_epi64(__m256i __W, __mmask8 __U, __m512i __ __m256i test_mm512_maskz_extracti64x4_epi64(__mmask8 __U, __m512i __A) { // CIR-LABEL: test_mm512_maskz_extracti64x4_epi64 - // CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<8 x !s64i> - // CIR: cir.vec.shuffle(%{{.*}}, [[ZERO]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> + // CIR: [[POISON:%.*]] = cir.const #cir.poison : !cir.vector<8 x !s64i> + // CIR: cir.vec.shuffle(%{{.*}}, [[POISON]] : !cir.vector<8 x !s64i>) [#cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !s64i> // LLVM-LABEL: test_mm512_maskz_extracti64x4_epi64 - // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> zeroinitializer, <4 x i32> + // LLVM: shufflevector <8 x i64> %{{.*}}, <8 x i64> poison, <4 x i32> // LLVM: select <4 x i1> %{{.*}}, <4 x i64> %{{.*}}, <4 x i64> %{{.*}} // OGCG-LABEL: test_mm512_maskz_extracti64x4_epi64 From b50c2d77038d8370e401059fc2bd92036f5d81d9 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 9 Dec 2025 04:55:51 -0800 Subject: [PATCH 4/6] Add verifier for select op --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 1 + clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 22 ++++++++++++++++++++ clang/test/CIR/IR/select.cir | 16 ++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 clang/test/CIR/IR/select.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index bb11aaa542984..39a6fb21447e6 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1930,6 +1930,7 @@ def CIR_SelectOp : CIR_Op<"select", [ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 396d97ddd794e..752c31a7ecdc3 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -2271,6 +2271,28 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) { return {}; } +LogicalResult cir::SelectOp::verify() { + // INFO: No need to check if trueTy == falseTy here, it's verified by + // the AllTypesMatch trait already. + // We can go straight into getting the vector type. + + auto condVecTy = + mlir::dyn_cast(this->getCondition().getType()); + auto trueVecTy = + mlir::dyn_cast(this->getTrueValue().getType()); + auto falseVecTy = + mlir::dyn_cast(this->getFalseValue().getType()); + + if (condVecTy && (!trueVecTy || !falseVecTy)) { + // INFO: No need to check for size of vector here, it's verified by + // the AllTypesMatch trait already + return emitOpError() + << "second and third operand must both be of the same " + "vector type when" + " the conditional operand is of vector boolean type"; + } + return mlir::success(); +} //===----------------------------------------------------------------------===// // ShiftOp diff --git a/clang/test/CIR/IR/select.cir b/clang/test/CIR/IR/select.cir new file mode 100644 index 0000000000000..23b4044cdb976 --- /dev/null +++ b/clang/test/CIR/IR/select.cir @@ -0,0 +1,16 @@ +// RUN: cir-opt %s -verify-diagnostics +!s32i = !cir.int +module { + cir.func @select_int_wrong_size(%arg0 : !cir.vector<8 x !cir.bool>, %arg1 : !cir.vector<7 x !s32i>, %arg2 : !cir.vector<8 x !s32i>) -> !cir.vector<8 x !s32i> { + // expected-error @below {{failed to verify that all of {true_value, false_value, result} have same type}} + %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.vector<8 x !cir.bool>, !cir.vector<7 x !s32i>, !cir.vector<8 x !s32i>) -> !cir.vector<8 x !s32i> + cir.return %0 : !cir.vector<8 x !s32i> + } + + + cir.func @select_int_wrong_type(%arg0 : !cir.vector<8 x !cir.bool>, %arg1 : !s32i, %arg2 : !s32i) -> !s32i { + // expected-error @below {{second and third operand must both be of the same vector type when the conditional operand is of vector boolean type}} + %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.vector<8 x !cir.bool>, !s32i, !s32i) -> !s32i + cir.return %0 : !s32i + } +} From ced00b3afbb120a4f461fcd677deacc35930a6dd Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 9 Dec 2025 05:13:33 -0800 Subject: [PATCH 5/6] Add a valid test case where we validly select the whole vector from a single boolean cond --- clang/test/CIR/IR/select.cir | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/clang/test/CIR/IR/select.cir b/clang/test/CIR/IR/select.cir index 23b4044cdb976..9d0133d33ce57 100644 --- a/clang/test/CIR/IR/select.cir +++ b/clang/test/CIR/IR/select.cir @@ -13,4 +13,9 @@ module { %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.vector<8 x !cir.bool>, !s32i, !s32i) -> !s32i cir.return %0 : !s32i } + + cir.func @select_int_valid(%arg0 : !cir.bool, %arg1 : !cir.vector<8 x !s32i>, %arg2 : !cir.vector<8 x !s32i>) -> !cir.vector<8 x !s32i> { + %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !cir.vector<8 x !s32i>, !cir.vector<8 x !s32i>) -> !cir.vector<8 x !s32i> + cir.return %0 : !cir.vector<8 x !s32i> + } } From e0aeba83e4f5614ca741891562e658c75054ff61 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Wed, 10 Dec 2025 02:45:31 -0800 Subject: [PATCH 6/6] Redo select diagnostics, redo seperating lines --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 1 + clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 33 +++++++++----------- clang/test/CIR/IR/select.cir | 2 +- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 9bd6e5dce74c7..655e57690d169 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -425,6 +425,7 @@ def CIR_ConstantOp : CIR_Op<"const", [ return boolAttr.getValue(); llvm_unreachable("Expected a BoolAttr in ConstantOp"); } + bool isAllOnesValue() { // Check for -1 integers if (auto intAttr = getValueAttr()) diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 4733c4e904304..6a0c1ee64dc25 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -2330,27 +2330,24 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) { return {}; } + LogicalResult cir::SelectOp::verify() { - // INFO: No need to check if trueTy == falseTy here, it's verified by - // the AllTypesMatch trait already. - // We can go straight into getting the vector type. - - auto condVecTy = - mlir::dyn_cast(this->getCondition().getType()); - auto trueVecTy = - mlir::dyn_cast(this->getTrueValue().getType()); - auto falseVecTy = - mlir::dyn_cast(this->getFalseValue().getType()); - - if (condVecTy && (!trueVecTy || !falseVecTy)) { - // INFO: No need to check for size of vector here, it's verified by - // the AllTypesMatch trait already + // AllTypesMatch already guarantees trueVal and falseVal have matching types. + auto condTy = dyn_cast(getCondition().getType()); + + // If condition is not a vector, no further checks are needed. + if (!condTy) + return success(); + + // When condition is a vector, both other operands must also be vectors. + if (!isa(getTrueValue().getType()) || + !isa(getFalseValue().getType())) { return emitOpError() - << "second and third operand must both be of the same " - "vector type when" - " the conditional operand is of vector boolean type"; + << "expected both true and false operands to be vector types " + "when the condition is a vector boolean type"; } - return mlir::success(); + + return success(); } //===----------------------------------------------------------------------===// diff --git a/clang/test/CIR/IR/select.cir b/clang/test/CIR/IR/select.cir index 9d0133d33ce57..ba97e562d04fb 100644 --- a/clang/test/CIR/IR/select.cir +++ b/clang/test/CIR/IR/select.cir @@ -9,7 +9,7 @@ module { cir.func @select_int_wrong_type(%arg0 : !cir.vector<8 x !cir.bool>, %arg1 : !s32i, %arg2 : !s32i) -> !s32i { - // expected-error @below {{second and third operand must both be of the same vector type when the conditional operand is of vector boolean type}} + // expected-error @below {{expected both true and false operands to be vector types when the condition is a vector boolean type}} %0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.vector<8 x !cir.bool>, !s32i, !s32i) -> !s32i cir.return %0 : !s32i }