diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index f8d3d93e49075..aeb1a122429e2 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -472,13 +472,13 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { //===--------------------------------------------------------------------===// // UnaryOp creation helpers //===--------------------------------------------------------------------===// - mlir::Value createNeg(mlir::Value value) { + mlir::Value createNeg(mlir::Value value, bool nsw = false) { if (auto intTy = mlir::dyn_cast(value.getType())) { // Source is a unsigned integer: first cast it to signed. if (intTy.isUnsigned()) value = createIntCast(value, getSIntNTy(intTy.getWidth())); - return createMinus(value.getLoc(), value); + return createMinus(value.getLoc(), value, nsw); } llvm_unreachable("negation for the given type is NYI"); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 231039ec5da29..cfb18ec535d3d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -668,10 +668,14 @@ class ScalarExprEmitter : public StmtVisitor { } } else if (const PointerType *ptr = type->getAs()) { QualType type = ptr->getPointeeType(); - if (cgf.getContext().getAsVariableArrayType(type)) { - // VLA types don't have constant size. - cgf.cgm.errorNYI(e->getSourceRange(), "Pointer arithmetic on VLA"); - return {}; + if (const VariableArrayType *vla = + cgf.getContext().getAsVariableArrayType(type)) { + mlir::Location loc = cgf.getLoc(e->getSourceRange()); + mlir::Value numElts = cgf.getVLASize(vla).numElts; + if (!e->isIncrementOp()) + numElts = cgf.getBuilder().createNeg(numElts, /*nsw=*/true); + assert(!cir::MissingFeatures::sanitizers()); + value = cgf.getBuilder().createPtrStride(loc, value, numElts); } else { // For everything else, we can just do a simple increment. mlir::Location loc = cgf.getLoc(e->getSourceRange()); @@ -1898,9 +1902,24 @@ static mlir::Value emitPointerArithmetic(CIRGenFunction &cgf, } QualType elementType = pointerType->getPointeeType(); - if (cgf.getContext().getAsVariableArrayType(elementType)) { - cgf.cgm.errorNYI("variable array type"); - return nullptr; + if (const VariableArrayType *vla = + cgf.getContext().getAsVariableArrayType(elementType)) { + mlir::Value numElements = cgf.getVLASize(vla).numElts; + mlir::Location loc = cgf.getLoc(op.e->getExprLoc()); + index = cgf.getBuilder().createCast(cir::CastKind::integral, index, + numElements.getType()); + // GEP indexes are signed, and scaling an index isn't permitted to + // signed-overflow, so we use the same semantics for our explicit + // multiply. We suppress this if overflow is not undefined behavior. + cir::OverflowBehavior overflowBehavior = + cgf.getLangOpts().PointerOverflowDefined + ? cir::OverflowBehavior::None + : cir::OverflowBehavior::NoSignedWrap; + index = + cgf.getBuilder().createMul(loc, index, numElements, overflowBehavior); + assert(!cir::MissingFeatures::sanitizers()); + return cir::PtrStrideOp::create(cgf.getBuilder(), loc, pointer.getType(), + pointer, index); } assert(!cir::MissingFeatures::sanitizers()); diff --git a/clang/test/CIR/CodeGen/vla-pointer-arith.c b/clang/test/CIR/CodeGen/vla-pointer-arith.c new file mode 100644 index 0000000000000..89c7083924002 --- /dev/null +++ b/clang/test/CIR/CodeGen/vla-pointer-arith.c @@ -0,0 +1,75 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG + +// Test pointer arithmetic on VLA types. + +void test_vla_ptr_add(int n, int i) { + int arr[n]; + int (*p)[n] = &arr; + p = p + i; +} + +// CIR-LABEL: @test_vla_ptr_add +// CIR: cir.alloca {{.*}} ["arr"] +// CIR: %[[N2:.*]] = cir.load{{.*}} !cir.ptr, !s32i +// CIR: %[[VLA_SIZE:.*]] = cir.cast integral %[[N2]] : !s32i -> !u64i +// CIR: %[[P:.*]] = cir.load{{.*}} !cir.ptr>, !cir.ptr +// CIR: %[[I:.*]] = cir.load{{.*}} !cir.ptr, !s32i +// CIR: %[[I_EXT:.*]] = cir.cast integral %[[I]] : !s32i -> !u64i +// CIR: %[[SCALED:.*]] = cir.mul nsw %[[I_EXT]], %[[VLA_SIZE]] : !u64i +// CIR: cir.ptr_stride %[[P]], %[[SCALED]] : (!cir.ptr, !u64i) -> !cir.ptr + +// LLVM-LABEL: @test_vla_ptr_add +// LLVM: %[[SCALED:.*]] = mul nsw i64 %{{.*}}, %{{.*}} +// LLVM: getelementptr i32, ptr %{{.*}}, i64 %[[SCALED]] + +// OGCG-LABEL: @test_vla_ptr_add +// OGCG: %[[IDX:.*]] = mul nsw i64 %{{.*}}, %{{.*}} +// OGCG: getelementptr inbounds i32, ptr %{{.*}}, i64 %[[IDX]] + +void test_vla_ptr_inc(int n) { + int arr[n]; + int (*p)[n] = &arr; + p++; +} + +// CIR-LABEL: @test_vla_ptr_inc +// CIR: cir.alloca {{.*}} ["arr"] +// CIR: %[[N2:.*]] = cir.load{{.*}} !cir.ptr, !s32i +// CIR: %[[VLA_SIZE:.*]] = cir.cast integral %[[N2]] : !s32i -> !u64i +// CIR: %[[P:.*]] = cir.load{{.*}} !cir.ptr>, !cir.ptr +// CIR: cir.ptr_stride %[[P]], %[[VLA_SIZE]] : (!cir.ptr, !u64i) -> !cir.ptr + +// LLVM-LABEL: @test_vla_ptr_inc +// LLVM: getelementptr i32, ptr %{{.*}}, i64 %{{.*}} + +// OGCG-LABEL: @test_vla_ptr_inc +// OGCG: getelementptr inbounds nuw i32, ptr %{{.*}}, i64 %{{.*}} + +void test_vla_ptr_dec(int n) { + int arr[n]; + int (*p)[n] = &arr; + p--; +} + +// CIR-LABEL: @test_vla_ptr_dec +// CIR: cir.alloca {{.*}} ["arr"] +// CIR: %[[N2:.*]] = cir.load{{.*}} !cir.ptr, !s32i +// CIR: %[[VLA_SIZE:.*]] = cir.cast integral %[[N2]] : !s32i -> !u64i +// CIR: %[[P:.*]] = cir.load{{.*}} !cir.ptr>, !cir.ptr +// CIR: %[[SIGNED:.*]] = cir.cast integral %[[VLA_SIZE]] : !u64i -> !s64i +// CIR: %[[NEG:.*]] = cir.minus nsw %[[SIGNED]] : !s64i +// CIR: cir.ptr_stride %[[P]], %[[NEG]] : (!cir.ptr, !s64i) -> !cir.ptr + +// LLVM-LABEL: @test_vla_ptr_dec +// LLVM: %[[NEG:.*]] = sub nsw i64 0, %{{.*}} +// LLVM: getelementptr i32, ptr %{{.*}}, i64 %[[NEG]] + +// OGCG-LABEL: @test_vla_ptr_dec +// OGCG: %[[NEG:.*]] = sub nsw i64 0, %{{.*}} +// OGCG: getelementptr inbounds i32, ptr %{{.*}}, i64 %[[NEG]] +