From bc5035890381f1fb52cebf4393c87172af1247b7 Mon Sep 17 00:00:00 2001 From: Zach Goldthorpe Date: Fri, 3 Oct 2025 15:44:19 -0500 Subject: [PATCH 1/4] Implemented InstCombine pattern. --- .../InstCombine/InstCombineCasts.cpp | 44 +++ .../InstCombine/fold-selective-shift.ll | 285 ++++++++++++++++++ 2 files changed, 329 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/fold-selective-shift.ll diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 9ca8194b44f8f..fb313e2a7eb22 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -756,6 +756,47 @@ static Instruction *shrinkInsertElt(CastInst &Trunc, return nullptr; } +/// Let N = 2 * M. +/// Given an N-bit integer representing a pack of two M-bit integers, +/// we can select one of the packed integers by right-shifting by either zero or +/// M, and then truncating the result to M bits. +/// +/// This function folds this shift-and-truncate into a select instruction, +/// enabling further simplification. +static Instruction *foldPackSelectingShift(TruncInst &Trunc, + InstCombinerImpl &IC) { + + const uint64_t BitWidth = Trunc.getDestTy()->getScalarSizeInBits(); + if (!isPowerOf2_64(BitWidth)) + return nullptr; + if (Trunc.getSrcTy()->getScalarSizeInBits() < 2 * BitWidth) + return nullptr; + + Value *Upper, *Lower, *ShrAmt; + if (!match(Trunc.getOperand(0), + m_OneUse(m_Shr( + m_OneUse(m_DisjointOr( + m_OneUse(m_Shl(m_Value(Upper), m_SpecificInt(BitWidth))), + m_Value(Lower))), + m_Value(ShrAmt))))) + return nullptr; + + KnownBits KnownLower = IC.computeKnownBits(Lower, nullptr); + if (!KnownLower.getMaxValue().isIntN(BitWidth)) + return nullptr; + + KnownBits KnownShr = IC.computeKnownBits(ShrAmt, nullptr); + if ((~KnownShr.Zero).getZExtValue() != BitWidth) + return nullptr; + + Value *ShrAmtZ = + IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(Trunc.getSrcTy()), + ShrAmt->getName() + ".z"); + Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper); + Select->takeName(Trunc.getOperand(0)); + return CastInst::CreateTruncOrBitCast(Select, Trunc.getDestTy()); +} + Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *Result = commonCastTransforms(Trunc)) return Result; @@ -907,6 +948,9 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *I = shrinkInsertElt(Trunc, Builder)) return I; + if (Instruction *I = foldPackSelectingShift(Trunc, *this)) + return I; + if (Src->hasOneUse() && (isa(SrcTy) || shouldChangeType(SrcTy, DestTy))) { // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the diff --git a/llvm/test/Transforms/InstCombine/fold-selective-shift.ll b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll new file mode 100644 index 0000000000000..3baeff3871e16 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll @@ -0,0 +1,285 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 +; RUN: opt -passes=instcombine %s -S | FileCheck %s + +declare void @clobber.i32(i32) + +define i16 @selective_shift_16(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i16 
[[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: ret i16 [[SEL_V]] +; + %upper.zext = zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %upper.shl, %lower.zext + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +define i16 @selective_shift_16.commute(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.commute( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i16 [[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: ret i16 [[SEL_V]] +; + %upper.zext = zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %lower.zext, %upper.shl + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +define i16 @selective_shift_16_range(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { +; CHECK-LABEL: define i16 @selective_shift_16_range( +; CHECK-SAME: i32 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 range(i32 0, 65536) [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.shl = shl nuw i32 %upper, 16 + %pack = or disjoint i32 %upper.shl, %lower + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +define <2 x i16> @selective_shift_v16(<2 x i32> %mask, <2 x i16> %upper, <2 x i16> %lower) { +; CHECK-LABEL: define <2 x i16> @selective_shift_v16( +; CHECK-SAME: <2 x i32> [[MASK:%.*]], <2 x i16> [[UPPER:%.*]], <2 x i16> [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and <2 x i32> [[MASK]], splat (i32 16) +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq <2 x i32> [[MASK_BIT]], zeroinitializer +; CHECK-NEXT: [[SEL_V:%.*]] = select <2 x i1> [[MASK_BIT_Z]], <2 x i16> [[LOWER]], <2 x i16> [[UPPER]] +; CHECK-NEXT: ret <2 x i16> [[SEL_V]] +; + %upper.zext = zext <2 x i16> %upper to <2 x i32> + %upper.shl = shl nuw <2 x i32> %upper.zext, splat(i32 16) + %lower.zext = zext <2 x i16> %lower to <2 x i32> + %pack = or disjoint <2 x i32> %upper.shl, %lower.zext + %mask.bit = and <2 x i32> %mask, splat(i32 16) + %sel = lshr <2 x i32> %pack, %mask.bit + %trunc = trunc <2 x i32> %sel to <2 x i16> + ret <2 x i16> %trunc +} + +define i16 @selective_shift_16.wide(i64 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.wide( +; CHECK-SAME: i64 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i16 [[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: ret i16 [[SEL_V]] +; + %upper.zext = zext i16 %upper to i64 + %upper.shl = shl nuw i64 %upper.zext, 16 + %lower.zext = zext i16 %lower to i64 + %pack = or disjoint i64 %upper.shl, %lower.zext + %mask.bit = and i64 %mask, 16 + %sel = lshr i64 %pack, %mask.bit + %trunc = trunc i64 %sel to i16 + ret i16 %trunc +} + +; narrow zext type blocks fold +define i16 
@selective_shift_16.narrow(i24 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.narrow( +; CHECK-SAME: i24 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i16 [[UPPER]] to i24 +; CHECK-NEXT: [[UPPER_SHL:%.*]] = shl i24 [[UPPER_ZEXT]], 16 +; CHECK-NEXT: [[LOWER_ZEXT:%.*]] = zext i16 [[LOWER]] to i24 +; CHECK-NEXT: [[PACK:%.*]] = or disjoint i24 [[UPPER_SHL]], [[LOWER_ZEXT]] +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i24 [[MASK]], 16 +; CHECK-NEXT: [[SEL:%.*]] = lshr i24 [[PACK]], [[MASK_BIT]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i24 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.zext = zext i16 %upper to i24 + %upper.shl = shl i24 %upper.zext, 16 + %lower.zext = zext i16 %lower to i24 + %pack = or disjoint i24 %upper.shl, %lower.zext + %mask.bit = and i24 %mask, 16 + %sel = lshr i24 %pack, %mask.bit + %trunc = trunc i24 %sel to i16 + ret i16 %trunc +} + +; %lower's upper bits block fold +define i16 @selective_shift_16_norange(i32 %mask, i32 %upper, i32 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16_norange( +; CHECK-SAME: i32 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_SHL:%.*]] = shl nuw i32 [[UPPER]], 16 +; CHECK-NEXT: [[PACK:%.*]] = or i32 [[UPPER_SHL]], [[LOWER]] +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[SEL:%.*]] = lshr i32 [[PACK]], [[MASK_BIT]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.shl = shl nuw i32 %upper, 16 + %pack = or i32 %upper.shl, %lower + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +define i16 @selective_shift_16.mu.0(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.mu.0( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i16 [[UPPER]] to i32 +; CHECK-NEXT: call void @clobber.i32(i32 [[UPPER_ZEXT]]) +; CHECK-NEXT: [[LOWER_ZEXT:%.*]] = zext i16 [[LOWER]] to i32 +; CHECK-NEXT: call void @clobber.i32(i32 [[LOWER_ZEXT]]) +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[TRUNC:%.*]] = select i1 [[MASK_BIT_Z]], i16 [[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.zext = zext i16 %upper to i32 + call void @clobber.i32(i32 %upper.zext) + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + call void @clobber.i32(i32 %lower.zext) + %pack = or disjoint i32 %upper.shl, %lower.zext + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +; multi-use of %pack blocks fold +define i16 @selective_shift_16.mu.1(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.mu.1( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i16 [[UPPER]] to i32 +; CHECK-NEXT: [[UPPER_SHL:%.*]] = shl nuw i32 [[UPPER_ZEXT]], 16 +; CHECK-NEXT: [[LOWER_ZEXT:%.*]] = zext i16 [[LOWER]] to i32 +; CHECK-NEXT: [[PACK:%.*]] = or disjoint i32 [[UPPER_SHL]], [[LOWER_ZEXT]] +; CHECK-NEXT: call void @clobber.i32(i32 [[PACK]]) +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[SEL:%.*]] = lshr i32 [[PACK]], [[MASK_BIT]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.zext = 
zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %upper.shl, %lower.zext + call void @clobber.i32(i32 %pack) + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +; multi-use of %sel blocks fold +define i16 @selective_shift_16.mu.2(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.mu.2( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i16 [[UPPER]] to i32 +; CHECK-NEXT: [[UPPER_SHL:%.*]] = shl nuw i32 [[UPPER_ZEXT]], 16 +; CHECK-NEXT: [[LOWER_ZEXT:%.*]] = zext i16 [[LOWER]] to i32 +; CHECK-NEXT: [[PACK:%.*]] = or disjoint i32 [[UPPER_SHL]], [[LOWER_ZEXT]] +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[SEL:%.*]] = lshr i32 [[PACK]], [[MASK_BIT]] +; CHECK-NEXT: call void @clobber.i32(i32 [[SEL]]) +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.zext = zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %upper.shl, %lower.zext + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + call void @clobber.i32(i32 %sel) + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + +; bitwidth must be a power of 2 to fold +define i24 @selective_shift_24(i48 %mask, i24 %upper, i24 %lower) { +; CHECK-LABEL: define i24 @selective_shift_24( +; CHECK-SAME: i48 [[MASK:%.*]], i24 [[UPPER:%.*]], i24 [[LOWER:%.*]]) { +; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i24 [[UPPER]] to i48 +; CHECK-NEXT: [[UPPER_SHL:%.*]] = shl nuw i48 [[UPPER_ZEXT]], 24 +; CHECK-NEXT: [[LOWER_ZEXT:%.*]] = zext i24 [[LOWER]] to i48 +; CHECK-NEXT: [[PACK:%.*]] = or disjoint i48 [[UPPER_SHL]], [[LOWER_ZEXT]] +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i48 [[MASK]], 24 +; CHECK-NEXT: [[SEL:%.*]] = lshr i48 [[PACK]], [[MASK_BIT]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i48 [[SEL]] to i24 +; CHECK-NEXT: ret i24 [[TRUNC]] +; + %upper.zext = zext i24 %upper to i48 + %upper.shl = shl nuw i48 %upper.zext, 24 + %lower.zext = zext i24 %lower to i48 + %pack = or disjoint i48 %upper.shl, %lower.zext + %mask.bit = and i48 %mask, 24 + %sel = lshr i48 %pack, %mask.bit + %trunc = trunc i48 %sel to i24 + ret i24 %trunc +} + +define i32 @selective_shift_32(i64 %mask, i32 %upper, i32 %lower) { +; CHECK-LABEL: define i32 @selective_shift_32( +; CHECK-SAME: i64 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] +; CHECK-NEXT: ret i32 [[SEL_V]] +; + %upper.zext = zext i32 %upper to i64 + %upper.shl = shl nuw i64 %upper.zext, 32 + %lower.zext = zext i32 %lower to i64 + %pack = or disjoint i64 %upper.shl, %lower.zext + %mask.bit = and i64 %mask, 32 + %sel = lshr i64 %pack, %mask.bit + %trunc = trunc i64 %sel to i32 + ret i32 %trunc +} + +define i32 @selective_shift_32.commute(i64 %mask, i32 %upper, i32 %lower) { +; CHECK-LABEL: define i32 @selective_shift_32.commute( +; CHECK-SAME: i64 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] +; CHECK-NEXT: ret i32 
[[SEL_V]] +; + %upper.zext = zext i32 %upper to i64 + %upper.shl = shl nuw i64 %upper.zext, 32 + %lower.zext = zext i32 %lower to i64 + %pack = or disjoint i64 %lower.zext, %upper.shl + %mask.bit = and i64 %mask, 32 + %sel = lshr i64 %pack, %mask.bit + %trunc = trunc i64 %sel to i32 + ret i32 %trunc +} + +define i32 @selective_shift_32_range(i64 %mask, i64 %upper, i64 range(i64 0, 4294967296) %lower) { +; CHECK-LABEL: define i32 @selective_shift_32_range( +; CHECK-SAME: i64 [[MASK:%.*]], i64 [[UPPER:%.*]], i64 range(i64 0, 4294967296) [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[MASK_BIT_Z]], i64 [[LOWER]], i64 [[UPPER]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i64 [[SEL]] to i32 +; CHECK-NEXT: ret i32 [[TRUNC]] +; + %upper.shl = shl nuw i64 %upper, 32 + %pack = or disjoint i64 %upper.shl, %lower + %mask.bit = and i64 %mask, 32 + %sel = lshr i64 %pack, %mask.bit + %trunc = trunc i64 %sel to i32 + ret i32 %trunc +} From 9630c2272bd91dbcda720982f5f9017057026e03 Mon Sep 17 00:00:00 2001 From: Zach Goldthorpe Date: Mon, 6 Oct 2025 11:05:20 -0500 Subject: [PATCH 2/4] Moved pattern to `SimplifyDemandedUseBits` --- .../InstCombine/InstCombineCasts.cpp | 44 ------------- .../InstCombineSimplifyDemanded.cpp | 42 ++++++++++++ .../InstCombine/fold-selective-shift.ll | 64 +++++++------------ 3 files changed, 66 insertions(+), 84 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index fb313e2a7eb22..9ca8194b44f8f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -756,47 +756,6 @@ static Instruction *shrinkInsertElt(CastInst &Trunc, return nullptr; } -/// Let N = 2 * M. -/// Given an N-bit integer representing a pack of two M-bit integers, -/// we can select one of the packed integers by right-shifting by either zero or -/// M, and then truncating the result to M bits. -/// -/// This function folds this shift-and-truncate into a select instruction, -/// enabling further simplification. 
-static Instruction *foldPackSelectingShift(TruncInst &Trunc, - InstCombinerImpl &IC) { - - const uint64_t BitWidth = Trunc.getDestTy()->getScalarSizeInBits(); - if (!isPowerOf2_64(BitWidth)) - return nullptr; - if (Trunc.getSrcTy()->getScalarSizeInBits() < 2 * BitWidth) - return nullptr; - - Value *Upper, *Lower, *ShrAmt; - if (!match(Trunc.getOperand(0), - m_OneUse(m_Shr( - m_OneUse(m_DisjointOr( - m_OneUse(m_Shl(m_Value(Upper), m_SpecificInt(BitWidth))), - m_Value(Lower))), - m_Value(ShrAmt))))) - return nullptr; - - KnownBits KnownLower = IC.computeKnownBits(Lower, nullptr); - if (!KnownLower.getMaxValue().isIntN(BitWidth)) - return nullptr; - - KnownBits KnownShr = IC.computeKnownBits(ShrAmt, nullptr); - if ((~KnownShr.Zero).getZExtValue() != BitWidth) - return nullptr; - - Value *ShrAmtZ = - IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(Trunc.getSrcTy()), - ShrAmt->getName() + ".z"); - Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper); - Select->takeName(Trunc.getOperand(0)); - return CastInst::CreateTruncOrBitCast(Select, Trunc.getDestTy()); -} - Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *Result = commonCastTransforms(Trunc)) return Result; @@ -948,9 +907,6 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *I = shrinkInsertElt(Trunc, Builder)) return I; - if (Instruction *I = foldPackSelectingShift(Trunc, *this)) - return I; - if (Src->hasOneUse() && (isa(SrcTy) || shouldChangeType(SrcTy, DestTy))) { // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index aa030294ff1e5..c018161fbfa44 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -800,6 +800,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, Known.Zero.setHighBits(ShiftAmt); // high bits known zero. } else { llvm::computeKnownBits(I, Known, Q, Depth); + + // Let N = 2 * M. + // Given an N-bit integer representing a pack of two M-bit integers, + // we can select one of the packed integers by right-shifting by either + // zero or M (which is the most straightforward to check if M is a power + // of 2), and then isolating the lower M bits. In this case, we can + // represent the shift as a select on whether the shr amount is nonzero. + uint64_t ShlAmt; + Value *Upper, *Lower; + if (!match(I->getOperand(0), + m_OneUse(m_DisjointOr( + m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))), + m_Value(Lower))))) + break; + if (!isPowerOf2_64(ShlAmt)) + break; + + const uint64_t DemandedBitWidth = DemandedMask.getActiveBits(); + if (DemandedBitWidth > ShlAmt) + break; + + // Check that upper demanded bits are not lost from lshift. + if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth) + break; + + KnownBits KnownLowerBits = computeKnownBits(Lower, I, Depth); + if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt)) + break; + + Value *ShrAmt = I->getOperand(1); + KnownBits KnownShrBits = computeKnownBits(ShrAmt, I, Depth); + // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or + // zero. 
+ if ((~KnownShrBits.Zero).getZExtValue() != ShlAmt) + break; + + Value *ShrAmtZ = Builder.CreateICmpEQ( + ShrAmt, Constant::getNullValue(ShrAmt->getType()), + ShrAmt->getName() + ".z"); + Value *Select = Builder.CreateSelect(ShrAmtZ, Lower, Upper); + Select->takeName(I); + return Select; } break; } diff --git a/llvm/test/Transforms/InstCombine/fold-selective-shift.ll b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll index 3baeff3871e16..18214c5edfeb9 100644 --- a/llvm/test/Transforms/InstCombine/fold-selective-shift.ll +++ b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll @@ -39,8 +39,8 @@ define i16 @selective_shift_16.commute(i32 %mask, i16 %upper, i16 %lower) { ret i16 %trunc } -define i16 @selective_shift_16_range(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { -; CHECK-LABEL: define i16 @selective_shift_16_range( +define i16 @selective_shift_16.range(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.range( ; CHECK-SAME: i32 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 range(i32 0, 65536) [[LOWER:%.*]]) { ; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 ; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 @@ -56,8 +56,27 @@ define i16 @selective_shift_16_range(i32 %mask, i32 %upper, i32 range(i32 0, 655 ret i16 %trunc } -define <2 x i16> @selective_shift_v16(<2 x i32> %mask, <2 x i16> %upper, <2 x i16> %lower) { -; CHECK-LABEL: define <2 x i16> @selective_shift_v16( +define i32 @selective_shift_16.masked(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i32 @selective_shift_16.masked( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i16 [[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: [[SEL:%.*]] = zext i16 [[SEL_V]] to i32 +; CHECK-NEXT: ret i32 [[SEL]] +; + %upper.zext = zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %lower.zext, %upper.shl + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %sel.masked = and i32 %sel, 65535 + ret i32 %sel.masked +} + +define <2 x i16> @selective_shift.v16(<2 x i32> %mask, <2 x i16> %upper, <2 x i16> %lower) { +; CHECK-LABEL: define <2 x i16> @selective_shift.v16( ; CHECK-SAME: <2 x i32> [[MASK:%.*]], <2 x i16> [[UPPER:%.*]], <2 x i16> [[LOWER:%.*]]) { ; CHECK-NEXT: [[MASK_BIT:%.*]] = and <2 x i32> [[MASK]], splat (i32 16) ; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq <2 x i32> [[MASK_BIT]], zeroinitializer @@ -183,7 +202,7 @@ define i16 @selective_shift_16.mu.1(i32 %mask, i16 %upper, i16 %lower) { ret i16 %trunc } -; multi-use of %sel blocks fold +; non-truncated use of %sel blocks fold define i16 @selective_shift_16.mu.2(i32 %mask, i16 %upper, i16 %lower) { ; CHECK-LABEL: define i16 @selective_shift_16.mu.2( ; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { @@ -248,38 +267,3 @@ define i32 @selective_shift_32(i64 %mask, i32 %upper, i32 %lower) { %trunc = trunc i64 %sel to i32 ret i32 %trunc } - -define i32 @selective_shift_32.commute(i64 %mask, i32 %upper, i32 %lower) { -; CHECK-LABEL: define i32 @selective_shift_32.commute( -; CHECK-SAME: i64 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 [[LOWER:%.*]]) { -; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 -; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 -; CHECK-NEXT: [[SEL_V:%.*]] = 
select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] -; CHECK-NEXT: ret i32 [[SEL_V]] -; - %upper.zext = zext i32 %upper to i64 - %upper.shl = shl nuw i64 %upper.zext, 32 - %lower.zext = zext i32 %lower to i64 - %pack = or disjoint i64 %lower.zext, %upper.shl - %mask.bit = and i64 %mask, 32 - %sel = lshr i64 %pack, %mask.bit - %trunc = trunc i64 %sel to i32 - ret i32 %trunc -} - -define i32 @selective_shift_32_range(i64 %mask, i64 %upper, i64 range(i64 0, 4294967296) %lower) { -; CHECK-LABEL: define i32 @selective_shift_32_range( -; CHECK-SAME: i64 [[MASK:%.*]], i64 [[UPPER:%.*]], i64 range(i64 0, 4294967296) [[LOWER:%.*]]) { -; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 -; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[MASK_BIT_Z]], i64 [[LOWER]], i64 [[UPPER]] -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i64 [[SEL]] to i32 -; CHECK-NEXT: ret i32 [[TRUNC]] -; - %upper.shl = shl nuw i64 %upper, 32 - %pack = or disjoint i64 %upper.shl, %lower - %mask.bit = and i64 %mask, 32 - %sel = lshr i64 %pack, %mask.bit - %trunc = trunc i64 %sel to i32 - ret i32 %trunc -} From 73598e0b2d92e4352c527eaf9b2bb28193624339 Mon Sep 17 00:00:00 2001 From: Zach Goldthorpe Date: Tue, 7 Oct 2025 09:05:07 -0500 Subject: [PATCH 3/4] Minor updates per reviewer feedback --- .../InstCombineSimplifyDemanded.cpp | 4 +- .../InstCombine/fold-selective-shift.ll | 54 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index c018161fbfa44..014e8a35bbe6d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -810,7 +810,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, uint64_t ShlAmt; Value *Upper, *Lower; if (!match(I->getOperand(0), - m_OneUse(m_DisjointOr( + m_OneUse(m_c_DisjointOr( m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))), m_Value(Lower))))) break; @@ -833,7 +833,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, KnownBits KnownShrBits = computeKnownBits(ShrAmt, I, Depth); // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or // zero. 
- if ((~KnownShrBits.Zero).getZExtValue() != ShlAmt) + if (~KnownShrBits.Zero != ShlAmt) break; Value *ShrAmtZ = Builder.CreateICmpEQ( diff --git a/llvm/test/Transforms/InstCombine/fold-selective-shift.ll b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll index 18214c5edfeb9..2b2296541f14a 100644 --- a/llvm/test/Transforms/InstCombine/fold-selective-shift.ll +++ b/llvm/test/Transforms/InstCombine/fold-selective-shift.ll @@ -56,6 +56,23 @@ define i16 @selective_shift_16.range(i32 %mask, i32 %upper, i32 range(i32 0, 655 ret i16 %trunc } +define i16 @selective_shift_16.range.commute(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { +; CHECK-LABEL: define i16 @selective_shift_16.range.commute( +; CHECK-SAME: i32 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 range(i32 0, 65536) [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SEL]] to i16 +; CHECK-NEXT: ret i16 [[TRUNC]] +; + %upper.shl = shl nuw i32 %upper, 16 + %pack = or disjoint i32 %lower, %upper.shl + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %trunc = trunc i32 %sel to i16 + ret i16 %trunc +} + define i32 @selective_shift_16.masked(i32 %mask, i16 %upper, i16 %lower) { ; CHECK-LABEL: define i32 @selective_shift_16.masked( ; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { @@ -75,6 +92,25 @@ define i32 @selective_shift_16.masked(i32 %mask, i16 %upper, i16 %lower) { ret i32 %sel.masked } +define i32 @selective_shift_16.masked.commute(i32 %mask, i16 %upper, i16 %lower) { +; CHECK-LABEL: define i32 @selective_shift_16.masked.commute( +; CHECK-SAME: i32 [[MASK:%.*]], i16 [[UPPER:%.*]], i16 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i32 [[MASK]], 16 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i32 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i16 [[LOWER]], i16 [[UPPER]] +; CHECK-NEXT: [[SEL:%.*]] = zext i16 [[SEL_V]] to i32 +; CHECK-NEXT: ret i32 [[SEL]] +; + %upper.zext = zext i16 %upper to i32 + %upper.shl = shl nuw i32 %upper.zext, 16 + %lower.zext = zext i16 %lower to i32 + %pack = or disjoint i32 %upper.shl, %lower.zext + %mask.bit = and i32 %mask, 16 + %sel = lshr i32 %pack, %mask.bit + %sel.masked = and i32 %sel, 65535 + ret i32 %sel.masked +} + define <2 x i16> @selective_shift.v16(<2 x i32> %mask, <2 x i16> %upper, <2 x i16> %lower) { ; CHECK-LABEL: define <2 x i16> @selective_shift.v16( ; CHECK-SAME: <2 x i32> [[MASK:%.*]], <2 x i16> [[UPPER:%.*]], <2 x i16> [[LOWER:%.*]]) { @@ -267,3 +303,21 @@ define i32 @selective_shift_32(i64 %mask, i32 %upper, i32 %lower) { %trunc = trunc i64 %sel to i32 ret i32 %trunc } + +define i32 @selective_shift_32.commute(i64 %mask, i32 %upper, i32 %lower) { +; CHECK-LABEL: define i32 @selective_shift_32.commute( +; CHECK-SAME: i64 [[MASK:%.*]], i32 [[UPPER:%.*]], i32 [[LOWER:%.*]]) { +; CHECK-NEXT: [[MASK_BIT:%.*]] = and i64 [[MASK]], 32 +; CHECK-NEXT: [[MASK_BIT_Z:%.*]] = icmp eq i64 [[MASK_BIT]], 0 +; CHECK-NEXT: [[SEL_V:%.*]] = select i1 [[MASK_BIT_Z]], i32 [[LOWER]], i32 [[UPPER]] +; CHECK-NEXT: ret i32 [[SEL_V]] +; + %upper.zext = zext i32 %upper to i64 + %upper.shl = shl nuw i64 %upper.zext, 32 + %lower.zext = zext i32 %lower to i64 + %pack = or disjoint i64 %lower.zext, %upper.shl + %mask.bit = and i64 %mask, 32 + %sel = lshr i64 %pack, %mask.bit + %trunc = trunc i64 %sel to i32 + ret i32 %trunc +} 
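(Illustration, not part of any patch hunk: the condition `~KnownShrBits.Zero == ShlAmt` introduced here says that the only bit of ShrAmt that may be nonzero is exactly the ShlAmt bit. For example, with ShlAmt = 16 = 0b10000, every bit of the shift amount except bit 4 must be known zero, so the shift amount can only be 0 (select Lower) or 16 (select Upper), which is what makes the select rewrite sound.)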
From 6d81fef00893981222c47000d1cb6692ac8b6f83 Mon Sep 17 00:00:00 2001 From: Zach Goldthorpe Date: Tue, 7 Oct 2025 09:17:58 -0500 Subject: [PATCH 4/4] Moved logic into separate function. --- .../InstCombineSimplifyDemanded.cpp | 102 ++++++++++-------- 1 file changed, 58 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 014e8a35bbe6d..127a506e440b7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -60,6 +60,58 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, return true; } +/// Let N = 2 * M. +/// Given an N-bit integer representing a pack of two M-bit integers, +/// we can select one of the packed integers by right-shifting by either +/// zero or M (which is the most straightforward to check if M is a power +/// of 2), and then isolating the lower M bits. In this case, we can +/// represent the shift as a select on whether the shr amount is nonzero. +static Value *simplifyShiftSelectingPackedElement(Instruction *I, + const APInt &DemandedMask, + InstCombinerImpl &IC, + unsigned Depth) { + assert(I->getOpcode() == Instruction::LShr && + "Only lshr instruction supported"); + + uint64_t ShlAmt; + Value *Upper, *Lower; + if (!match(I->getOperand(0), + m_OneUse(m_c_DisjointOr( + m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))), + m_Value(Lower))))) + return nullptr; + + if (!isPowerOf2_64(ShlAmt)) + return nullptr; + + const uint64_t DemandedBitWidth = DemandedMask.getActiveBits(); + if (DemandedBitWidth > ShlAmt) + return nullptr; + + // Check that upper demanded bits are not lost from lshift. + if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth) + return nullptr; + + KnownBits KnownLowerBits = IC.computeKnownBits(Lower, I, Depth); + if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt)) + return nullptr; + + Value *ShrAmt = I->getOperand(1); + KnownBits KnownShrBits = IC.computeKnownBits(ShrAmt, I, Depth); + + // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or + // zero. + if (~KnownShrBits.Zero != ShlAmt) + return nullptr; + + Value *ShrAmtZ = + IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(ShrAmt->getType()), + ShrAmt->getName() + ".z"); + Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper); + Select->takeName(I); + return Select; +} + /// Returns the bitwidth of the given scalar or pointer type. For vector types, /// returns the element type's bitwidth. static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { @@ -798,51 +850,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, Known >>= ShiftAmt; if (ShiftAmt) Known.Zero.setHighBits(ShiftAmt); // high bits known zero. - } else { - llvm::computeKnownBits(I, Known, Q, Depth); - - // Let N = 2 * M. - // Given an N-bit integer representing a pack of two M-bit integers, - // we can select one of the packed integers by right-shifting by either - // zero or M (which is the most straightforward to check if M is a power - // of 2), and then isolating the lower M bits. In this case, we can - // represent the shift as a select on whether the shr amount is nonzero. 
- uint64_t ShlAmt; - Value *Upper, *Lower; - if (!match(I->getOperand(0), - m_OneUse(m_c_DisjointOr( - m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))), - m_Value(Lower))))) - break; - if (!isPowerOf2_64(ShlAmt)) - break; - - const uint64_t DemandedBitWidth = DemandedMask.getActiveBits(); - if (DemandedBitWidth > ShlAmt) - break; - - // Check that upper demanded bits are not lost from lshift. - if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth) - break; - - KnownBits KnownLowerBits = computeKnownBits(Lower, I, Depth); - if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt)) - break; - - Value *ShrAmt = I->getOperand(1); - KnownBits KnownShrBits = computeKnownBits(ShrAmt, I, Depth); - // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or - // zero. - if (~KnownShrBits.Zero != ShlAmt) - break; - - Value *ShrAmtZ = Builder.CreateICmpEQ( - ShrAmt, Constant::getNullValue(ShrAmt->getType()), - ShrAmt->getName() + ".z"); - Value *Select = Builder.CreateSelect(ShrAmtZ, Lower, Upper); - Select->takeName(I); - return Select; + break; } + if (Value *V = + simplifyShiftSelectingPackedElement(I, DemandedMask, *this, Depth)) + return V; + + llvm::computeKnownBits(I, Known, Q, Depth); break; } case Instruction::AShr: {
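For reference, an illustrative before/after sketch of the fold, mirroring the selective_shift_16 test above (the function names are placeholders; this IR is not part of any patch hunk):

; Before: pack two i16 halves into an i32, shift right by 0 or 16, truncate.
define i16 @selective_shift_sketch(i32 %mask, i16 %upper, i16 %lower) {
  %upper.zext = zext i16 %upper to i32
  %upper.shl = shl nuw i32 %upper.zext, 16
  %lower.zext = zext i16 %lower to i32
  %pack = or disjoint i32 %upper.shl, %lower.zext
  %mask.bit = and i32 %mask, 16        ; shift amount is either 0 or 16
  %sel = lshr i32 %pack, %mask.bit
  %trunc = trunc i32 %sel to i16
  ret i16 %trunc
}

; After: the lshr/trunc collapse to a select on whether the shift amount is zero.
define i16 @selective_shift_sketch.folded(i32 %mask, i16 %upper, i16 %lower) {
  %mask.bit = and i32 %mask, 16
  %mask.bit.z = icmp eq i32 %mask.bit, 0
  %sel.v = select i1 %mask.bit.z, i16 %lower, i16 %upper
  ret i16 %sel.v
}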