Skip to content

Conversation

@AmrDeveloper
Copy link
Member

Upstream the basic support for the ExtVectorType element expr

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels Nov 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2025

@llvm/pr-subscribers-clangir

@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

Upstream the basic support for the ExtVectorType element expr


Full diff: https://github.com/llvm/llvm-project/pull/167570.diff

6 Files Affected:

  • (modified) clang/lib/CIR/CodeGen/CIRGenExpr.cpp (+85)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+2)
  • (modified) clang/lib/CIR/CodeGen/CIRGenFunction.cpp (+2)
  • (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+6)
  • (modified) clang/lib/CIR/CodeGen/CIRGenValue.h (+32-1)
  • (added) clang/test/CIR/CodeGen/vector-ext-element.cpp (+46)
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 422fa1cf5ad2e..f898511ce2fed 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -627,10 +627,51 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
                                                  lv.getVectorIdx()));
   }
 
+  if (lv.isExtVectorElt()) {
+    return emitLoadOfExtVectorElementLValue(lv);
+  }
+
   cgm.errorNYI(loc, "emitLoadOfLValue");
   return RValue::get(nullptr);
 }
 
+int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
+                                           const mlir::ArrayAttr elts) {
+  auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
+  assert(elt && "The indices should be integer attributes");
+  return elt.getInt();
+}
+
+// If this is a reference to a subset of the elements of a vector, create an
+// appropriate shufflevector.
+RValue CIRGenFunction::emitLoadOfExtVectorElementLValue(LValue lv) {
+  mlir::Location loc = lv.getExtVectorPointer().getLoc();
+  mlir::Value vec = builder.createLoad(loc, lv.getExtVectorAddress());
+
+  // HLSL allows treating scalars as one-element vectors. Converting the scalar
+  // IR value to a vector here allows the rest of codegen to behave as normal.
+  if (getLangOpts().HLSL && !mlir::isa<cir::VectorType>(vec.getType())) {
+    cgm.errorNYI(loc, "emitLoadOfExtVectorElementLValue: HLSL");
+    return {};
+  }
+
+  const mlir::ArrayAttr elts = lv.getExtVectorElts();
+
+  // If the result of the expression is a non-vector type, we must be extracting
+  // a single element. Just codegen as an extractelement.
+  const auto *exprVecTy = lv.getType()->getAs<clang::VectorType>();
+  if (!exprVecTy) {
+    int64_t indexValue = getAccessedFieldNo(0, elts);
+    cir::ConstantOp index =
+        builder.getConstInt(loc, builder.getSInt64Ty(), indexValue);
+    return RValue::get(cir::VecExtractOp::create(builder, loc, vec, index));
+  }
+
+  cgm.errorNYI(
+      loc, "emitLoadOfExtVectorElementLValue: Result of expr is vector type");
+  return {};
+}
+
 static cir::FuncOp emitFunctionDeclPointer(CIRGenModule &cgm, GlobalDecl gd) {
   assert(!cir::MissingFeatures::weakRefReference());
   return cgm.getAddrOfFunction(gd);
@@ -1116,6 +1157,50 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
   return lv;
 }
 
+LValue CIRGenFunction::emitExtVectorElementExpr(const ExtVectorElementExpr *e) {
+  // Emit the base vector as an l-value.
+  LValue base;
+
+  // ExtVectorElementExpr's base can either be a vector or pointer to vector.
+  if (e->isArrow()) {
+    cgm.errorNYI(e->getSourceRange(),
+                 "emitExtVectorElementExpr: pointer to vector");
+    return {};
+  } else if (e->getBase()->isGLValue()) {
+    // Otherwise, if the base is an lvalue ( as in the case of foo.x.x),
+    // emit the base as an lvalue.
+    assert(e->getBase()->getType()->isVectorType());
+    base = emitLValue(e->getBase());
+  } else {
+    // Otherwise, the base is a normal rvalue (as in (V+V).x), emit it as such.
+    cgm.errorNYI(e->getSourceRange(),
+                 "emitExtVectorElementExpr: base is a normal rvalue");
+    return {};
+  }
+
+  QualType type =
+      e->getType().withCVRQualifiers(base.getQuals().getCVRQualifiers());
+
+  // Encode the element access list into a vector of unsigned indices.
+  SmallVector<uint32_t, 4> indices;
+  e->getEncodedElementAccess(indices);
+
+  if (base.isSimple()) {
+    SmallVector<int64_t> attrElts;
+    for (uint32_t i : indices) {
+      attrElts.push_back(static_cast<int64_t>(i));
+    }
+
+    mlir::ArrayAttr elts = builder.getI64ArrayAttr(attrElts);
+    return LValue::makeExtVectorElt(base.getAddress(), elts, type,
+                                    base.getBaseInfo());
+  }
+
+  cgm.errorNYI(e->getSourceRange(),
+               "emitExtVectorElementExpr: isSimple is false");
+  return {};
+}
+
 LValue CIRGenFunction::emitStringLiteralLValue(const StringLiteral *e,
                                                llvm::StringRef name) {
   cir::GlobalOp globalOp = cgm.getGlobalForStringLiteral(e, name);
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index c1a36134d8942..eeb766ff9f9fc 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -277,6 +277,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
                                 e->getSourceRange().getBegin());
   }
 
+  mlir::Value VisitExtVectorElementExpr(Expr *e) { return emitLoadOfLValue(e); }
+
   mlir::Value VisitMemberExpr(MemberExpr *e);
 
   mlir::Value VisitCompoundLiteralExpr(CompoundLiteralExpr *e) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
index 5d5209b9ffb60..5d3040f3b10eb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -883,6 +883,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
     return emitConditionalOperatorLValue(cast<BinaryConditionalOperator>(e));
   case Expr::ArraySubscriptExprClass:
     return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
+  case Expr::ExtVectorElementExprClass:
+    return emitExtVectorElementExpr(cast<ExtVectorElementExpr>(e));
   case Expr::UnaryOperatorClass:
     return emitUnaryOpLValue(cast<UnaryOperator>(e));
   case Expr::StringLiteralClass:
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index f879e580989f7..58c98c077bd0e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1264,6 +1264,8 @@ class CIRGenFunction : public CIRGenTypeCache {
                               QualType &baseType, Address &addr);
   LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
 
+  LValue emitExtVectorElementExpr(const ExtVectorElementExpr *e);
+
   Address emitArrayToPointerDecay(const Expr *e,
                                   LValueBaseInfo *baseInfo = nullptr);
 
@@ -1329,6 +1331,8 @@ class CIRGenFunction : public CIRGenTypeCache {
                                               mlir::Value emittedE,
                                               bool isDynamic);
 
+  int64_t getAccessedFieldNo(unsigned idx, mlir::ArrayAttr elts);
+
   RValue emitCall(const CIRGenFunctionInfo &funcInfo,
                   const CIRGenCallee &callee, ReturnValueSlot returnValue,
                   const CallArgList &args, cir::CIRCallOpInterface *callOp,
@@ -1624,6 +1628,8 @@ class CIRGenFunction : public CIRGenTypeCache {
   /// Load a complex number from the specified l-value.
   mlir::Value emitLoadOfComplex(LValue src, SourceLocation loc);
 
+  RValue emitLoadOfExtVectorElementLValue(LValue lv);
+
   /// Given an expression that represents a value lvalue, this method emits
   /// the address of the lvalue, then loads the result as an rvalue,
   /// returning the rvalue.
diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h
index ab245a771d72c..20a3d0ef61341 100644
--- a/clang/lib/CIR/CodeGen/CIRGenValue.h
+++ b/clang/lib/CIR/CodeGen/CIRGenValue.h
@@ -166,7 +166,8 @@ class LValue {
   // this is the alignment of the whole vector)
   unsigned alignment;
   mlir::Value v;
-  mlir::Value vectorIdx; // Index for vector subscript
+  mlir::Value vectorIdx;      // Index for vector subscript
+  mlir::Attribute vectorElts; // ExtVector element subset: V.xyx
   mlir::Type elementType;
   LValueBaseInfo baseInfo;
   const CIRGenBitFieldInfo *bitFieldInfo{nullptr};
@@ -190,6 +191,7 @@ class LValue {
   bool isSimple() const { return lvType == Simple; }
   bool isVectorElt() const { return lvType == VectorElt; }
   bool isBitField() const { return lvType == BitField; }
+  bool isExtVectorElt() const { return lvType == ExtVectorElt; }
   bool isGlobalReg() const { return lvType == GlobalReg; }
   bool isVolatile() const { return quals.hasVolatile(); }
 
@@ -254,6 +256,22 @@ class LValue {
     return vectorIdx;
   }
 
+  // extended vector elements.
+  Address getExtVectorAddress() const {
+    assert(isExtVectorElt());
+    return Address(getExtVectorPointer(), elementType, getAlignment());
+  }
+
+  mlir::Value getExtVectorPointer() const {
+    assert(isExtVectorElt());
+    return v;
+  }
+
+  mlir::ArrayAttr getExtVectorElts() const {
+    assert(isExtVectorElt());
+    return mlir::cast<mlir::ArrayAttr>(vectorElts);
+  }
+
   static LValue makeVectorElt(Address vecAddress, mlir::Value index,
                               clang::QualType t, LValueBaseInfo baseInfo) {
     LValue r;
@@ -265,6 +283,19 @@ class LValue {
     return r;
   }
 
+  static LValue makeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
+                                 clang::QualType type,
+                                 LValueBaseInfo baseInfo) {
+    LValue r;
+    r.lvType = ExtVectorElt;
+    r.v = vecAddress.getPointer();
+    r.elementType = vecAddress.getElementType();
+    r.vectorElts = elts;
+    r.initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
+                 baseInfo);
+    return r;
+  }
+
   // bitfield lvalue
   Address getBitFieldAddress() const {
     return Address(getBitFieldPointer(), elementType, getAlignment());
diff --git a/clang/test/CIR/CodeGen/vector-ext-element.cpp b/clang/test/CIR/CodeGen/vector-ext-element.cpp
new file mode 100644
index 0000000000000..de9d53936d2eb
--- /dev/null
+++ b/clang/test/CIR/CodeGen/vector-ext-element.cpp
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
+
+typedef int vi4 __attribute__((ext_vector_type(4)));
+
+void element_expr_from_gl() {
+  vi4 a;
+  int x = a.x;
+  int y = a.y;
+}
+
+// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
+// CIR: %[[X_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init]
+// CIR: %[[Y_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init]
+// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[CONST_0:.*]] = cir.const #cir.int<0> : !s64i
+// CIR: %[[ELEM_0:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_0]] : !s64i] : !cir.vector<4 x !s32i>
+// CIR: cir.store {{.*}} %[[ELEM_0]], %[[X_ADDR]] : !s32i, !cir.ptr<!s32i>
+// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s64i
+// CIR: %[[ELEM_1:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_1]] : !s64i] : !cir.vector<4 x !s32i>
+// CIR: cir.store {{.*}} %[[ELEM_1]], %[[Y_ADDR]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[A_ADDR:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[X_ADDR:.*]] = alloca i32, i64 1, align 4
+// LLVM: %[[Y_ADDR:.*]] = alloca i32, i64 1, align 4
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// LLVM: %[[ELEM_0:.*]] = extractelement <4 x i32> %4, i64 0
+// LLVM: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// LLVM: %[[ELEM_1:.*]] = extractelement <4 x i32> %6, i64 1
+// LLVM: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4
+
+// OGCG: %[[A_ADDR:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
+// OGCG: %[[Y_ADDR:.*]] = alloca i32, align 4
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// OGCG: %[[ELEM_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 0
+// OGCG: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
+// OGCG: %[[ELEM_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 1
+// OGCG: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, just a few nits

lv.getVectorIdx()));
}

if (lv.isExtVectorElt()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Braces aren't needed here.

Comment on lines 640 to 641
auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
assert(elt && "The indices should be integer attributes");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
assert(elt && "The indices should be integer attributes");
auto elt = mlir::cast<mlir::IntegerAttr>(elts[idx]);


if (base.isSimple()) {
SmallVector<int64_t> attrElts;
for (uint32_t i : indices) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Braces aren't needed here.

e->getEncodedElementAccess(indices);

if (base.isSimple()) {
SmallVector<int64_t> attrElts;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work in place of the loop below?

Suggested change
SmallVector<int64_t> attrElts;
SmallVector<int64_t> attrElts(indicis.begin(), indices.end());

@AmrDeveloper AmrDeveloper enabled auto-merge (squash) November 13, 2025 16:54
@AmrDeveloper AmrDeveloper merged commit 9216e17 into llvm:main Nov 13, 2025
8 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants