Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions clang/include/clang/CIR/LoweringHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#ifndef LLVM_CLANG_CIR_LOWERINGHELPERS_H
#define LLVM_CLANG_CIR_LOWERINGHELPERS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand All @@ -36,7 +37,17 @@ convertToDenseElementsAttr(cir::ConstArrayAttr attr,
const llvm::SmallVectorImpl<int64_t> &dims,
mlir::Type type);

template <typename AttrTy, typename StorageTy>
mlir::DenseElementsAttr
convertToDenseElementsAttr(cir::ConstVectorAttr attr,
const llvm::SmallVectorImpl<int64_t> &dims,
mlir::Type type);

std::optional<mlir::Attribute>
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
const mlir::TypeConverter *converter);

std::optional<mlir::Attribute>
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
const mlir::TypeConverter *converter);
#endif
29 changes: 27 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1860,8 +1860,10 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
rewriter.eraseOp(op);
return mlir::success();
}
} else if (const auto recordAttr =
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
}

else if (const auto recordAttr =
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
// TODO(cir): this diverges from traditional lowering. Normally the
// initializer would be a global constant that is memcopied. Here we just
// define a local constant with llvm.undef that will be stored into the
Expand Down Expand Up @@ -2421,6 +2423,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
} else if (mlir::isa<cir::ConstArrayAttr>(init)) {
return lowerInitializerForConstArray(rewriter, op, init,
useInitializerRegion);
} else if (mlir::isa<cir::ConstVectorAttr>(init)) {
return lowerInitializerForConstVector(rewriter, op, init,
useInitializerRegion);
} else if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(init)) {
assert(lowerMod && "lower module is not available");
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
Expand All @@ -2437,6 +2442,26 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
}
llvm_unreachable("unreachable");
}

mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector(
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
mlir::Attribute &init, bool &useInitializerRegion) const {
auto constVec = mlir::cast<cir::ConstVectorAttr>(init);
if (const auto attr = mlir::dyn_cast<mlir::ArrayAttr>(constVec.getElts())) {
if (auto val = lowerConstVectorAttr(constVec, getTypeConverter());
val.has_value()) {
init = val.value();
useInitializerRegion = false;
} else
useInitializerRegion = true;
return mlir::success();
}

op.emitError() << "unsupported lowering for #cir.const_vector with value "
<< constVec.getElts();
return mlir::failure();
}

mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray(
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
mlir::Attribute &init, bool &useInitializerRegion) const {
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,11 @@ class CIRToLLVMGlobalOpLowering
cir::GlobalOp op, mlir::Attribute &init,
bool &useInitializerRegion) const;

mlir::LogicalResult
lowerInitializerForConstVector(mlir::ConversionPatternRewriter &rewriter,
cir::GlobalOp op, mlir::Attribute &init,
bool &useInitializerRegion) const;

mlir::LogicalResult
lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter,
cir::GlobalOp op, mlir::Type llvmType,
Expand Down
79 changes: 79 additions & 0 deletions clang/lib/CIR/Lowering/LoweringHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,39 @@ void convertToDenseElementsAttrImpl(
}
}

template <typename AttrTy, typename StorageTy>
void convertToDenseElementsAttrImpl(
cir::ConstVectorAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
int64_t currentIndex) {
dimIndex++;
std::size_t elementsSizeInCurrentDim = 1;
for (std::size_t i = dimIndex; i < currentDims.size(); i++)
elementsSizeInCurrentDim *= currentDims[i];

auto arrayAttr = mlir::cast<mlir::ArrayAttr>(attr.getElts());
for (auto eltAttr : arrayAttr) {
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
values[currentIndex++] = valueAttr.getValue();
continue;
}

if (auto subArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(eltAttr)) {
convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values, currentDims,
dimIndex, currentIndex);
currentIndex += elementsSizeInCurrentDim;
continue;
}

if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) {
currentIndex += elementsSizeInCurrentDim;
continue;
}

llvm_unreachable("unknown element in ConstArrayAttr");
}
}

template <typename AttrTy, typename StorageTy>
mlir::DenseElementsAttr convertToDenseElementsAttr(
cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
Expand All @@ -109,6 +142,22 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
llvm::ArrayRef(values));
}

template <typename AttrTy, typename StorageTy>
mlir::DenseElementsAttr convertToDenseElementsAttr(
cir::ConstVectorAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
mlir::Type elementType, mlir::Type convertedElementType) {
unsigned vector_size = 1;
for (auto dim : dims)
vector_size *= dim;
auto values = llvm::SmallVector<StorageTy, 8>(
vector_size, getZeroInitFromType<StorageTy>(elementType));
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
/*initialIndex=*/0);
return mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get(dims, convertedElementType),
llvm::ArrayRef(values));
}

std::optional<mlir::Attribute>
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
const mlir::TypeConverter *converter) {
Expand Down Expand Up @@ -141,3 +190,33 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,

return std::nullopt;
}

std::optional<mlir::Attribute>
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
const mlir::TypeConverter *converter) {

// Ensure ConstArrayAttr has a type.
auto typedConstArr = mlir::dyn_cast<mlir::TypedAttr>(constArr);
assert(typedConstArr && "cir::ConstArrayAttr is not a mlir::TypedAttr");

// Ensure ConstArrayAttr type is a ArrayType.
auto cirArrayType = mlir::dyn_cast<cir::VectorType>(typedConstArr.getType());
assert(cirArrayType && "cir::ConstArrayAttr is not a cir::ArrayType");

// Is a ConstArrayAttr with an cir::ArrayType: fetch element type.
mlir::Type type = cirArrayType;
auto dims = llvm::SmallVector<int64_t, 2>{};
while (auto arrayType = mlir::dyn_cast<cir::ArrayType>(type)) {
dims.push_back(arrayType.getSize());
type = arrayType.getEltType();
}

if (mlir::isa<cir::IntType>(type))
return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
constArr, dims, type, converter->convertType(type));
if (mlir::isa<cir::CIRFPTypeInterface>(type))
return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
constArr, dims, type, converter->convertType(type));

return std::nullopt;
}
9 changes: 8 additions & 1 deletion clang/test/CIR/CodeGen/vectype-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,19 @@ vi2 vec_c;

// LLVM: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer

vd2 d;
vd2 vec_d;

// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<!cir.double x 2>

// LLVM: @[[VEC_D:.*]] = global <2 x double> zeroinitializer

vi4 vec_e = { 1, 2, 3, 4 };

// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>

// LLVM: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

// CIR: cir.func {{@.*vector_int_test.*}}
// LLVM: define dso_local void {{@.*vector_int_test.*}}
void vector_int_test(int x) {
Expand Down
5 changes: 5 additions & 0 deletions clang/test/CIR/CodeGen/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ vd2 b;
vll2 c;
// CHECK: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<!s64i x 2>

vi4 d = { 1, 2, 3, 4 };

// CHECK: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
// CHECK-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>

void vector_int_test(int x, unsigned short usx) {

// Vector constant.
Expand Down
Loading