diff --git a/clang/include/clang/CIR/LoweringHelpers.h b/clang/include/clang/CIR/LoweringHelpers.h index 771a382591fa..32f9f3b3a98b 100644 --- a/clang/include/clang/CIR/LoweringHelpers.h +++ b/clang/include/clang/CIR/LoweringHelpers.h @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #ifndef LLVM_CLANG_CIR_LOWERINGHELPERS_H #define LLVM_CLANG_CIR_LOWERINGHELPERS_H + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -36,7 +37,17 @@ convertToDenseElementsAttr(cir::ConstArrayAttr attr, const llvm::SmallVectorImpl &dims, mlir::Type type); +template +mlir::DenseElementsAttr +convertToDenseElementsAttr(cir::ConstVectorAttr attr, + const llvm::SmallVectorImpl &dims, + mlir::Type type); + std::optional lowerConstArrayAttr(cir::ConstArrayAttr constArr, const mlir::TypeConverter *converter); + +std::optional +lowerConstVectorAttr(cir::ConstVectorAttr constArr, + const mlir::TypeConverter *converter); #endif diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 0a0dc2de0b44..051b83d2b850 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1860,8 +1860,10 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( rewriter.eraseOp(op); return mlir::success(); } - } else if (const auto recordAttr = - mlir::dyn_cast(op.getValue())) { + } + + else if (const auto recordAttr = + mlir::dyn_cast(op.getValue())) { // TODO(cir): this diverges from traditional lowering. Normally the // initializer would be a global constant that is memcopied. Here we just // define a local constant with llvm.undef that will be stored into the @@ -2421,6 +2423,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer( } else if (mlir::isa(init)) { return lowerInitializerForConstArray(rewriter, op, init, useInitializerRegion); + } else if (mlir::isa(init)) { + return lowerInitializerForConstVector(rewriter, op, init, + useInitializerRegion); } else if (auto dataMemberAttr = mlir::dyn_cast(init)) { assert(lowerMod && "lower module is not available"); mlir::DataLayout layout(op->getParentOfType()); @@ -2437,6 +2442,26 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer( } llvm_unreachable("unreachable"); } + +mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector( + mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, + mlir::Attribute &init, bool &useInitializerRegion) const { + auto constVec = mlir::cast(init); + if (const auto attr = mlir::dyn_cast(constVec.getElts())) { + if (auto val = lowerConstVectorAttr(constVec, getTypeConverter()); + val.has_value()) { + init = val.value(); + useInitializerRegion = false; + } else + useInitializerRegion = true; + return mlir::success(); + } + + op.emitError() << "unsupported lowering for #cir.const_vector with value " + << constVec.getElts(); + return mlir::failure(); +} + mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray( mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, mlir::Attribute &init, bool &useInitializerRegion) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 64a296092c8a..9820dee369c2 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -623,6 +623,11 @@ class CIRToLLVMGlobalOpLowering cir::GlobalOp op, mlir::Attribute &init, bool &useInitializerRegion) const; + mlir::LogicalResult + lowerInitializerForConstVector(mlir::ConversionPatternRewriter &rewriter, + cir::GlobalOp op, mlir::Attribute &init, + bool &useInitializerRegion) const; + mlir::LogicalResult lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, mlir::Type llvmType, diff --git a/clang/lib/CIR/Lowering/LoweringHelpers.cpp b/clang/lib/CIR/Lowering/LoweringHelpers.cpp index 26c2945105cc..f4dfad43cf20 100644 --- a/clang/lib/CIR/Lowering/LoweringHelpers.cpp +++ b/clang/lib/CIR/Lowering/LoweringHelpers.cpp @@ -93,6 +93,39 @@ void convertToDenseElementsAttrImpl( } } +template +void convertToDenseElementsAttrImpl( + cir::ConstVectorAttr attr, llvm::SmallVectorImpl &values, + const llvm::SmallVectorImpl ¤tDims, int64_t dimIndex, + int64_t currentIndex) { + dimIndex++; + std::size_t elementsSizeInCurrentDim = 1; + for (std::size_t i = dimIndex; i < currentDims.size(); i++) + elementsSizeInCurrentDim *= currentDims[i]; + + auto arrayAttr = mlir::cast(attr.getElts()); + for (auto eltAttr : arrayAttr) { + if (auto valueAttr = mlir::dyn_cast(eltAttr)) { + values[currentIndex++] = valueAttr.getValue(); + continue; + } + + if (auto subArrayAttr = mlir::dyn_cast(eltAttr)) { + convertToDenseElementsAttrImpl(subArrayAttr, values, currentDims, + dimIndex, currentIndex); + currentIndex += elementsSizeInCurrentDim; + continue; + } + + if (mlir::isa(eltAttr)) { + currentIndex += elementsSizeInCurrentDim; + continue; + } + + llvm_unreachable("unknown element in ConstArrayAttr"); + } +} + template mlir::DenseElementsAttr convertToDenseElementsAttr( cir::ConstArrayAttr attr, const llvm::SmallVectorImpl &dims, @@ -109,6 +142,22 @@ mlir::DenseElementsAttr convertToDenseElementsAttr( llvm::ArrayRef(values)); } +template +mlir::DenseElementsAttr convertToDenseElementsAttr( + cir::ConstVectorAttr attr, const llvm::SmallVectorImpl &dims, + mlir::Type elementType, mlir::Type convertedElementType) { + unsigned vector_size = 1; + for (auto dim : dims) + vector_size *= dim; + auto values = llvm::SmallVector( + vector_size, getZeroInitFromType(elementType)); + convertToDenseElementsAttrImpl(attr, values, dims, /*currentDim=*/0, + /*initialIndex=*/0); + return mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get(dims, convertedElementType), + llvm::ArrayRef(values)); +} + std::optional lowerConstArrayAttr(cir::ConstArrayAttr constArr, const mlir::TypeConverter *converter) { @@ -141,3 +190,33 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr, return std::nullopt; } + +std::optional +lowerConstVectorAttr(cir::ConstVectorAttr constArr, + const mlir::TypeConverter *converter) { + + // Ensure ConstArrayAttr has a type. + auto typedConstArr = mlir::dyn_cast(constArr); + assert(typedConstArr && "cir::ConstArrayAttr is not a mlir::TypedAttr"); + + // Ensure ConstArrayAttr type is a ArrayType. + auto cirArrayType = mlir::dyn_cast(typedConstArr.getType()); + assert(cirArrayType && "cir::ConstArrayAttr is not a cir::ArrayType"); + + // Is a ConstArrayAttr with an cir::ArrayType: fetch element type. + mlir::Type type = cirArrayType; + auto dims = llvm::SmallVector{}; + while (auto arrayType = mlir::dyn_cast(type)) { + dims.push_back(arrayType.getSize()); + type = arrayType.getEltType(); + } + + if (mlir::isa(type)) + return convertToDenseElementsAttr( + constArr, dims, type, converter->convertType(type)); + if (mlir::isa(type)) + return convertToDenseElementsAttr( + constArr, dims, type, converter->convertType(type)); + + return std::nullopt; +} diff --git a/clang/test/CIR/CodeGen/vectype-ext.cpp b/clang/test/CIR/CodeGen/vectype-ext.cpp index 969600c0ebf5..b928372413c5 100644 --- a/clang/test/CIR/CodeGen/vectype-ext.cpp +++ b/clang/test/CIR/CodeGen/vectype-ext.cpp @@ -25,12 +25,19 @@ vi2 vec_c; // LLVM: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer -vd2 d; +vd2 vec_d; // CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector // LLVM: @[[VEC_D:.*]] = global <2 x double> zeroinitializer +vi4 vec_e = { 1, 2, 3, 4 }; + +// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector + +// LLVM: @[[VEC_E:.*]] = global <4 x i32> + // CIR: cir.func {{@.*vector_int_test.*}} // LLVM: define dso_local void {{@.*vector_int_test.*}} void vector_int_test(int x) { diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index 0389ad8b87ca..29f0eac2656c 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -14,6 +14,11 @@ vd2 b; vll2 c; // CHECK: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector +vi4 d = { 1, 2, 3, 4 }; + +// CHECK: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CHECK-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector + void vector_int_test(int x, unsigned short usx) { // Vector constant.