From 7d90e64e5b5ec7ba07020471057a6a4932f6dc82 Mon Sep 17 00:00:00 2001
From: Gabor Spaits
Date: Sun, 28 Sep 2025 23:58:13 +0200
Subject: [PATCH 1/3] Extend vector.reduce.add splat transform to scalable vectors

---
 .../Transforms/InstCombine/InstCombineCalls.cpp | 14 ++++++++++----
 .../Transforms/InstCombine/vector-reductions.ll |  7 ++++---
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cf6d0ecab4f69..02b46b0161ad8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3785,13 +3785,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
 
       // vector.reduce.add.vNiM(splat(%x)) -> mul(%x, N)
       if (Value *Splat = getSplatValue(Arg)) {
-        ElementCount VecToReduceCount =
-            cast<VectorType>(Arg->getType())->getElementCount();
+        VectorType *VecToReduceTy = cast<VectorType>(Arg->getType());
+        ElementCount VecToReduceCount = VecToReduceTy->getElementCount();
+        Value *RHS;
         if (VecToReduceCount.isFixed()) {
           unsigned VectorSize = VecToReduceCount.getFixedValue();
-          return BinaryOperator::CreateMul(
-              Splat, ConstantInt::get(Splat->getType(), VectorSize));
+          RHS = ConstantInt::get(Splat->getType(), VectorSize);
         }
+
+        RHS = Builder.CreateElementCount(Type::getInt64Ty(II->getContext()),
+                                         VecToReduceCount);
+        if (Splat->getType() != RHS->getType())
+          RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
+        return BinaryOperator::CreateMul(Splat, RHS);
       }
     }
     [[fallthrough]];
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index f1e0dd9bd06d7..34f0570c2698d 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -469,9 +469,10 @@ define i2 @constant_multiplied_7xi2(i2 %0) {
 
 define i32 @negative_scalable_vector(i32 %0) {
 ; CHECK-LABEL: @negative_scalable_vector(
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[TMP0:%.*]], i64 0
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <vscale x 4 x i32> [[TMP2]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[DOTTR]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP0:%.*]], [[TMP3]]
 ; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
   %2 = insertelement <vscale x 4 x i32> poison, i32 %0, i64 0
From f5189ce8500389e24894879b19c1f97208f0a36f Mon Sep 17 00:00:00 2001
From: Gabor Spaits
Date: Mon, 29 Sep 2025 10:30:35 +0200
Subject: [PATCH 2/3] Throw out redundant branch and redundant check

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 02b46b0161ad8..3eb472f53936e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3787,16 +3787,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (Value *Splat = getSplatValue(Arg)) {
         VectorType *VecToReduceTy = cast<VectorType>(Arg->getType());
         ElementCount VecToReduceCount = VecToReduceTy->getElementCount();
-        Value *RHS;
-        if (VecToReduceCount.isFixed()) {
-          unsigned VectorSize = VecToReduceCount.getFixedValue();
-          RHS = ConstantInt::get(Splat->getType(), VectorSize);
-        }
-
-        RHS = Builder.CreateElementCount(Type::getInt64Ty(II->getContext()),
-                                         VecToReduceCount);
-        if (Splat->getType() != RHS->getType())
-          RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
+        Value *RHS = Builder.CreateElementCount(
+            Type::getInt64Ty(II->getContext()), VecToReduceCount);
+        RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
         return BinaryOperator::CreateMul(Splat, RHS);
       }
     }
From 556e00455de2ba25411169bec7f0252bfc7d433b Mon Sep 17 00:00:00 2001
From: Gabor Spaits
Date: Mon, 29 Sep 2025 10:46:02 +0200
Subject: [PATCH 3/3] Add more tests

---
 .../InstCombine/vector-reductions.ll | 78 ++++++++++++++++++-
 1 file changed, 76 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 34f0570c2698d..56b3e5726d460 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -467,8 +467,8 @@ define i2 @constant_multiplied_7xi2(i2 %0) {
   ret i2 %4
 }
 
-define i32 @negative_scalable_vector(i32 %0) {
-; CHECK-LABEL: @negative_scalable_vector(
+define i32 @reduce_add_splat_to_mul_vscale_4xi32(i32 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi32(
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[DOTTR]], 2
@@ -480,3 +480,77 @@ define i32 @negative_scalable_vector(i32 %0) {
   %4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
   ret i32 %4
 }
+
+define i64 @reduce_add_splat_to_mul_vscale_4xi64(i64 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi64(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i64 [[TMP4]]
+;
+  %2 = insertelement <vscale x 4 x i64> poison, i64 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i64> %2, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+  %4 = tail call i64 @llvm.vector.reduce.add.nxv4i64(<vscale x 4 x i64> %3)
+  ret i64 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_4xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi2(
+; CHECK-NEXT:    ret i2 0
+;
+  %2 = insertelement <vscale x 4 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i2> %2, <vscale x 4 x i2> poison, <vscale x 4 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv4i2(<vscale x 4 x i2> %3)
+  ret i2 %4
+}
+
+define i1 @reduce_add_splat_to_mul_vscale_8xi1(i1 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_8xi1(
+; CHECK-NEXT:    ret i1 false
+;
+  %2 = insertelement <vscale x 8 x i1> poison, i1 %0, i64 0
+  %3 = shufflevector <vscale x 8 x i1> %2, <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
+  %4 = tail call i1 @llvm.vector.reduce.add.nxv8i1(<vscale x 8 x i1> %3)
+  ret i1 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_5xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_5xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i2 [[TMP4]]
+;
+  %2 = insertelement <vscale x 5 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 5 x i2> %2, <vscale x 5 x i2> poison, <vscale x 5 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv5i2(<vscale x 5 x i2> %3)
+  ret i2 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_6xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_6xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i2 [[DOTTR]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i2 [[TMP4]]
+;
+  %2 = insertelement <vscale x 6 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 6 x i2> %2, <vscale x 6 x i2> poison, <vscale x 6 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv6i2(<vscale x 6 x i2> %3)
+  ret i2 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_7xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_7xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sub i2 0, [[TMP4]]
+; CHECK-NEXT:    ret i2 [[TMP5]]
+;
+  %2 = insertelement <vscale x 7 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 7 x i2> %2, <vscale x 7 x i2> poison, <vscale x 7 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv7i2(<vscale x 7 x i2> %3)
+  ret i2 %4
+}
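
Note for reviewers (not part of the patches above): the combine now relies on IRBuilder::CreateElementCount to materialize the element count of the reduced vector. The standalone sketch below, which assumes an LLVM development tree and uses the made-up name element_count_demo, shows what that helper emits for a scalable count such as the one in <vscale x 4 x i32>: a call to llvm.vscale scaled by the known minimum lane count, which the transform then zext/truncs to the splatted scalar's type and multiplies in.

// Sketch only; "element_count_demo" is an illustrative name, not part of LLVM.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("element_count_demo", Ctx);
  IRBuilder<> Builder(Ctx);

  // A small function body to hold the emitted instructions.
  FunctionType *FTy =
      FunctionType::get(Builder.getInt64Ty(), /*isVarArg=*/false);
  Function *F =
      Function::Create(FTy, Function::ExternalLinkage, "element_count_demo", M);
  Builder.SetInsertPoint(BasicBlock::Create(Ctx, "entry", F));

  // The element count of <vscale x 4 x ...>: scalable, minimum of 4 lanes.
  ElementCount EC = ElementCount::getScalable(4);

  // For a scalable count this emits llvm.vscale multiplied by the known
  // minimum (4 here); for a fixed count it is simply the constant N.
  Value *N = Builder.CreateElementCount(Builder.getInt64Ty(), EC);
  Builder.CreateRet(N);

  M.print(outs(), /*AAW=*/nullptr);
  return 0;
}

For the i32 tests, later InstCombine runs shrink the i64 vscale expression to i32 and turn the multiply-by-4 into the shl seen in the CHECK lines; for i2 element types the scaling constant wraps, which is why the nxv4i2 and nxv8i1 cases fold to 0/false.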