-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[CIR] Upstream basic support for ExtVector element expr #167570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Upstream basic support for ExtVector element expr #167570
Conversation
|
@llvm/pr-subscribers-clangir @llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) ChangesUpstream the basic support for the ExtVectorType element expr Full diff: https://github.com/llvm/llvm-project/pull/167570.diff 6 Files Affected:
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 422fa1cf5ad2e..f898511ce2fed 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -627,10 +627,51 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
lv.getVectorIdx()));
}
+ if (lv.isExtVectorElt()) {
+ return emitLoadOfExtVectorElementLValue(lv);
+ }
+
cgm.errorNYI(loc, "emitLoadOfLValue");
return RValue::get(nullptr);
}
+int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
+ const mlir::ArrayAttr elts) {
+ auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
+ assert(elt && "The indices should be integer attributes");
+ return elt.getInt();
+}
+
+// If this is a reference to a subset of the elements of a vector, create an
+// appropriate shufflevector.
+RValue CIRGenFunction::emitLoadOfExtVectorElementLValue(LValue lv) {
+ mlir::Location loc = lv.getExtVectorPointer().getLoc();
+ mlir::Value vec = builder.createLoad(loc, lv.getExtVectorAddress());
+
+ // HLSL allows treating scalars as one-element vectors. Converting the scalar
+ // IR value to a vector here allows the rest of codegen to behave as normal.
+ if (getLangOpts().HLSL && !mlir::isa<cir::VectorType>(vec.getType())) {
+ cgm.errorNYI(loc, "emitLoadOfExtVectorElementLValue: HLSL");
+ return {};
+ }
+
+ const mlir::ArrayAttr elts = lv.getExtVectorElts();
+
+ // If the result of the expression is a non-vector type, we must be extracting
+ // a single element. Just codegen as an extractelement.
+ const auto *exprVecTy = lv.getType()->getAs<clang::VectorType>();
+ if (!exprVecTy) {
+ int64_t indexValue = getAccessedFieldNo(0, elts);
+ cir::ConstantOp index =
+ builder.getConstInt(loc, builder.getSInt64Ty(), indexValue);
+ return RValue::get(cir::VecExtractOp::create(builder, loc, vec, index));
+ }
+
+ cgm.errorNYI(
+ loc, "emitLoadOfExtVectorElementLValue: Result of expr is vector type");
+ return {};
+}
+
static cir::FuncOp emitFunctionDeclPointer(CIRGenModule &cgm, GlobalDecl gd) {
assert(!cir::MissingFeatures::weakRefReference());
return cgm.getAddrOfFunction(gd);
@@ -1116,6 +1157,50 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
return lv;
}
+LValue CIRGenFunction::emitExtVectorElementExpr(const ExtVectorElementExpr *e) {
+ // Emit the base vector as an l-value.
+ LValue base;
+
+ // ExtVectorElementExpr's base can either be a vector or pointer to vector.
+ if (e->isArrow()) {
+ cgm.errorNYI(e->getSourceRange(),
+ "emitExtVectorElementExpr: pointer to vector");
+ return {};
+ } else if (e->getBase()->isGLValue()) {
+ // Otherwise, if the base is an lvalue ( as in the case of foo.x.x),
+ // emit the base as an lvalue.
+ assert(e->getBase()->getType()->isVectorType());
+ base = emitLValue(e->getBase());
+ } else {
+ // Otherwise, the base is a normal rvalue (as in (V+V).x), emit it as such.
+ cgm.errorNYI(e->getSourceRange(),
+ "emitExtVectorElementExpr: base is a normal rvalue");
+ return {};
+ }
+
+ QualType type =
+ e->getType().withCVRQualifiers(base.getQuals().getCVRQualifiers());
+
+ // Encode the element access list into a vector of unsigned indices.
+ SmallVector<uint32_t, 4> indices;
+ e->getEncodedElementAccess(indices);
+
+ if (base.isSimple()) {
+ SmallVector<int64_t> attrElts;
+ for (uint32_t i : indices) {
+ attrElts.push_back(static_cast<int64_t>(i));
+ }
+
+ mlir::ArrayAttr elts = builder.getI64ArrayAttr(attrElts);
+ return LValue::makeExtVectorElt(base.getAddress(), elts, type,
+ base.getBaseInfo());
+ }
+
+ cgm.errorNYI(e->getSourceRange(),
+ "emitExtVectorElementExpr: isSimple is false");
+ return {};
+}
+
LValue CIRGenFunction::emitStringLiteralLValue(const StringLiteral *e,
llvm::StringRef name) {
cir::GlobalOp globalOp = cgm.getGlobalForStringLiteral(e, name);
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index c1a36134d8942..eeb766ff9f9fc 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -277,6 +277,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
e->getSourceRange().getBegin());
}
+ mlir::Value VisitExtVectorElementExpr(Expr *e) { return emitLoadOfLValue(e); }
+
mlir::Value VisitMemberExpr(MemberExpr *e);
mlir::Value VisitCompoundLiteralExpr(CompoundLiteralExpr *e) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
index 5d5209b9ffb60..5d3040f3b10eb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -883,6 +883,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
return emitConditionalOperatorLValue(cast<BinaryConditionalOperator>(e));
case Expr::ArraySubscriptExprClass:
return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
+ case Expr::ExtVectorElementExprClass:
+ return emitExtVectorElementExpr(cast<ExtVectorElementExpr>(e));
case Expr::UnaryOperatorClass:
return emitUnaryOpLValue(cast<UnaryOperator>(e));
case Expr::StringLiteralClass:
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index f879e580989f7..58c98c077bd0e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1264,6 +1264,8 @@ class CIRGenFunction : public CIRGenTypeCache {
QualType &baseType, Address &addr);
LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
+ LValue emitExtVectorElementExpr(const ExtVectorElementExpr *e);
+
Address emitArrayToPointerDecay(const Expr *e,
LValueBaseInfo *baseInfo = nullptr);
@@ -1329,6 +1331,8 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Value emittedE,
bool isDynamic);
+ int64_t getAccessedFieldNo(unsigned idx, mlir::ArrayAttr elts);
+
RValue emitCall(const CIRGenFunctionInfo &funcInfo,
const CIRGenCallee &callee, ReturnValueSlot returnValue,
const CallArgList &args, cir::CIRCallOpInterface *callOp,
@@ -1624,6 +1628,8 @@ class CIRGenFunction : public CIRGenTypeCache {
/// Load a complex number from the specified l-value.
mlir::Value emitLoadOfComplex(LValue src, SourceLocation loc);
+ RValue emitLoadOfExtVectorElementLValue(LValue lv);
+
/// Given an expression that represents a value lvalue, this method emits
/// the address of the lvalue, then loads the result as an rvalue,
/// returning the rvalue.
diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h
index ab245a771d72c..20a3d0ef61341 100644
--- a/clang/lib/CIR/CodeGen/CIRGenValue.h
+++ b/clang/lib/CIR/CodeGen/CIRGenValue.h
@@ -166,7 +166,8 @@ class LValue {
// this is the alignment of the whole vector)
unsigned alignment;
mlir::Value v;
- mlir::Value vectorIdx; // Index for vector subscript
+ mlir::Value vectorIdx; // Index for vector subscript
+ mlir::Attribute vectorElts; // ExtVector element subset: V.xyx
mlir::Type elementType;
LValueBaseInfo baseInfo;
const CIRGenBitFieldInfo *bitFieldInfo{nullptr};
@@ -190,6 +191,7 @@ class LValue {
bool isSimple() const { return lvType == Simple; }
bool isVectorElt() const { return lvType == VectorElt; }
bool isBitField() const { return lvType == BitField; }
+ bool isExtVectorElt() const { return lvType == ExtVectorElt; }
bool isGlobalReg() const { return lvType == GlobalReg; }
bool isVolatile() const { return quals.hasVolatile(); }
@@ -254,6 +256,22 @@ class LValue {
return vectorIdx;
}
+ // extended vector elements.
+ Address getExtVectorAddress() const {
+ assert(isExtVectorElt());
+ return Address(getExtVectorPointer(), elementType, getAlignment());
+ }
+
+ mlir::Value getExtVectorPointer() const {
+ assert(isExtVectorElt());
+ return v;
+ }
+
+ mlir::ArrayAttr getExtVectorElts() const {
+ assert(isExtVectorElt());
+ return mlir::cast<mlir::ArrayAttr>(vectorElts);
+ }
+
static LValue makeVectorElt(Address vecAddress, mlir::Value index,
clang::QualType t, LValueBaseInfo baseInfo) {
LValue r;
@@ -265,6 +283,19 @@ class LValue {
return r;
}
+ static LValue makeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
+ clang::QualType type,
+ LValueBaseInfo baseInfo) {
+ LValue r;
+ r.lvType = ExtVectorElt;
+ r.v = vecAddress.getPointer();
+ r.elementType = vecAddress.getElementType();
+ r.vectorElts = elts;
+ r.initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
+ baseInfo);
+ return r;
+ }
+
// bitfield lvalue
Address getBitFieldAddress() const {
return Address(getBitFieldPointer(), elementType, getAlignment());
diff --git a/clang/test/CIR/CodeGen/vector-ext-element.cpp b/clang/test/CIR/CodeGen/vector-ext-element.cpp
new file mode 100644
index 0000000000000..de9d53936d2eb
--- /dev/null
+++ b/clang/test/CIR/CodeGen/vector-ext-element.cpp
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
+
+typedef int vi4 __attribute__((ext_vector_type(4)));
+
+void element_expr_from_gl() {
+ vi4 a;
+ int x = a.x;
+ int y = a.y;
+}
+
+// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
+// CIR: %[[X_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init]
+// CIR: %[[Y_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init]
+// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[CONST_0:.*]] = cir.const #cir.int<0> : !s64i
+// CIR: %[[ELEM_0:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_0]] : !s64i] : !cir.vector<4 x !s32i>
+// CIR: cir.store {{.*}} %[[ELEM_0]], %[[X_ADDR]] : !s32i, !cir.ptr<!s32i>
+// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s64i
+// CIR: %[[ELEM_1:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_1]] : !s64i] : !cir.vector<4 x !s32i>
+// CIR: cir.store {{.*}} %[[ELEM_1]], %[[Y_ADDR]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[A_ADDR:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[X_ADDR:.*]] = alloca i32, i64 1, align 4
+// LLVM: %[[Y_ADDR:.*]] = alloca i32, i64 1, align 4
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// LLVM: %[[ELEM_0:.*]] = extractelement <4 x i32> %4, i64 0
+// LLVM: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// LLVM: %[[ELEM_1:.*]] = extractelement <4 x i32> %6, i64 1
+// LLVM: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4
+
+// OGCG: %[[A_ADDR:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
+// OGCG: %[[Y_ADDR:.*]] = alloca i32, align 4
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// OGCG: %[[ELEM_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 0
+// OGCG: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// OGCG: %[[ELEM_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 1
+// OGCG: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4
|
andykaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, just a few nits
clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Outdated
| lv.getVectorIdx())); | ||
| } | ||
|
|
||
| if (lv.isExtVectorElt()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Braces aren't needed here.
clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Outdated
| auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]); | ||
| assert(elt && "The indices should be integer attributes"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]); | |
| assert(elt && "The indices should be integer attributes"); | |
| auto elt = mlir::cast<mlir::IntegerAttr>(elts[idx]); |
clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Outdated
|
|
||
| if (base.isSimple()) { | ||
| SmallVector<int64_t> attrElts; | ||
| for (uint32_t i : indices) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Braces aren't needed here.
clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Outdated
| e->getEncodedElementAccess(indices); | ||
|
|
||
| if (base.isSimple()) { | ||
| SmallVector<int64_t> attrElts; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work in place of the loop below?
| SmallVector<int64_t> attrElts; | |
| SmallVector<int64_t> attrElts(indicis.begin(), indices.end()); |
Upstream the basic support for the ExtVectorType element expr