Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}]>
];

// Printing and parsing available in CIRDialect.cpp
// Printing and parsing available in CIRAttrs.cpp
let hasCustomAssemblyFormat = 1;

// Enable verifier.
Expand All @@ -215,6 +215,38 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];
}

//===----------------------------------------------------------------------===//
// ConstVectorAttr
//===----------------------------------------------------------------------===//

def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
[TypedAttrInterface]> {
let summary = "A constant vector from ArrayAttr";
let description = [{
A CIR vector attribute is an array of literals of the specified attribute
types.
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"mlir::ArrayAttr":$elts);

// Define a custom builder for the type; that removes the need to pass in an
// MLIRContext instance, as it can be inferred from the `type`.
let builders = [
AttrBuilderWithInferredContext<(ins "cir::VectorType":$type,
"mlir::ArrayAttr":$elts), [{
return $_get(type.getContext(), type, elts);
}]>
];

let assemblyFormat = [{
`<` $elts `>`
}];

// Enable verifier.
let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// ConstPtrAttr
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 21 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
elements, typedFiller);
}
case APValue::Vector: {
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector");
return {};
const QualType elementType =
destType->castAs<VectorType>()->getElementType();
const unsigned numElements = value.getVectorLength();

SmallVector<mlir::Attribute, 16> elements;
elements.reserve(numElements);

for (unsigned i = 0; i < numElements; ++i) {
const mlir::Attribute element =
tryEmitPrivateForMemory(value.getVectorElt(i), elementType);
if (!element)
return {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens to any elements that are already in the elements vector? Do they get cleaned up and deleted properly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understood, yes the destructor will be called when we early return and will clean the vector

elements.push_back(element);
}

const auto desiredVecTy =
mlir::cast<cir::VectorType>(cgm.convertType(destType));

return cir::ConstVectorAttr::get(
desiredVecTy,
mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements));
}
case APValue::MemberPointer: {
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer");
Expand Down
41 changes: 41 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,47 @@ void ConstArrayAttr::print(AsmPrinter &printer) const {
printer << ">";
}

//===----------------------------------------------------------------------===//
// CIR ConstVectorAttr
//===----------------------------------------------------------------------===//

LogicalResult cir::ConstVectorAttr::verify(
function_ref<::mlir::InFlightDiagnostic()> emitError, Type type,
ArrayAttr elts) {

if (!mlir::isa<cir::VectorType>(type)) {
return emitError() << "type of cir::ConstVectorAttr is not a "
"cir::VectorType: "
<< type;
}

const auto vecType = mlir::cast<cir::VectorType>(type);

if (vecType.getSize() != elts.size()) {
return emitError()
<< "number of constant elements should match vector size";
}

// Check if the types of the elements match
LogicalResult elementTypeCheck = success();
elts.walkImmediateSubElements(
[&](Attribute element) {
if (elementTypeCheck.failed()) {
// An earlier element didn't match
return;
}
auto typedElement = mlir::dyn_cast<TypedAttr>(element);
if (!typedElement ||
typedElement.getType() != vecType.getElementType()) {
elementTypeCheck = failure();
emitError() << "constant type should match vector element type";
}
},
[&](Type) {});

return elementTypeCheck;
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (mlir::isa<cir::ConstArrayAttr>(attrType))
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
return success();

assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
Expand Down
40 changes: 35 additions & 5 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,17 @@ class CIRAttrToValue {

mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
[&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}

mlir::Value visitCirAttr(cir::IntAttr intAttr);
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);

private:
Expand Down Expand Up @@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
return result;
}

/// ConstVectorAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
const mlir::Type llvmTy = converter->convertType(attr.getType());
const mlir::Location loc = parentOp->getLoc();

SmallVector<mlir::Attribute> mlirValues;
for (const mlir::Attribute elementAttr : attr.getElts()) {
mlir::Attribute mlirAttr;
if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
mlirAttr = rewriter.getIntegerAttr(
converter->convertType(intAttr.getType()), intAttr.getValue());
} else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
mlirAttr = rewriter.getFloatAttr(
converter->convertType(floatAttr.getType()), floatAttr.getValue());
} else {
llvm_unreachable(
"vector constant with an element that is neither an int nor a float");
}
mlirValues.push_back(mlirAttr);
}

return rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmTy,
mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
mlirValues));
}

/// ZeroAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
mlir::Location loc = parentOp->getLoc();
Expand Down Expand Up @@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::ZeroAttr>(init)));

// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
Expand Down Expand Up @@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
init.value())) {
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
Expand Down
11 changes: 10 additions & 1 deletion clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,23 @@ vi2 vec_c;

// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer

vd2 d;
vd2 vec_d;

// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>

// LLVM: @[[VEC_D:.*]] = dso_local global <2 x double> zeroinitialize

// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer

vi4 vec_e = { 1, 2, 3, 4 };

// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>

// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

void foo() {
vi4 a;
vi3 b;
Expand Down
9 changes: 9 additions & 0 deletions clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ vll2 c;

// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer

vi4 d = { 1, 2, 3, 4 };

// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>

// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

void vec_int_test() {
vi4 a;
vd2 b;
Expand Down
6 changes: 6 additions & 0 deletions clang/test/CIR/IR/vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
// CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>

cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2>
: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>

// CIR: cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>

cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
Expand Down