Skip to content

Commit

Permalink
[VPlan] Simplify redundant trunc (zext A) pairs to A.
Browse files Browse the repository at this point in the history
Add simplification for redundant trunc(zext A) pairs. Generally apply a
transform from D149903.

Depends on D159200.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D159202
  • Loading branch information
fhahn committed Oct 22, 2023
1 parent 4aeb7a0 commit 0c8e5be
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 44 deletions.
29 changes: 27 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,23 +812,48 @@ static bool isConstantOne(VPValue *V) {
static unsigned getOpcodeForRecipe(VPRecipeBase &R) {
if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R))
return WidenR->getUnderlyingInstr()->getOpcode();
if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R))
return WidenC->getOpcode();
if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R))
return RepR->getUnderlyingInstr()->getOpcode();
if (auto *VPI = dyn_cast<VPInstruction>(&R))
return VPI->getOpcode();
return 0;
}

/// Return the scalar size in bits for \p VPV if possible.
static Type *getTypeForVPValue(VPValue *VPV) {
// TODO: Replace with VPlan type inference once ready.
if (auto *VPC = dyn_cast<VPWidenCastRecipe>(VPV))
return VPC->getResultType();
auto *UV = VPV->getUnderlyingValue();
return UV->getType();
}

/// Try to simplify recipe \p R.
static void simplifyRecipe(VPRecipeBase &R) {
unsigned Opcode = getOpcodeForRecipe(R);
if (Opcode == Instruction::Mul) {
switch (getOpcodeForRecipe(R)) {
case Instruction::Mul: {
VPValue *A = R.getOperand(0);
VPValue *B = R.getOperand(1);
if (isConstantOne(A))
return R.getVPSingleValue()->replaceAllUsesWith(B);
if (isConstantOne(B))
return R.getVPSingleValue()->replaceAllUsesWith(A);
break;
}
case Instruction::Trunc: {
VPRecipeBase *Zext = R.getOperand(0)->getDefiningRecipe();
if (!Zext || getOpcodeForRecipe(*Zext) != Instruction::ZExt)
break;
VPValue *A = Zext->getOperand(0);
VPValue *Trunc = R.getVPSingleValue();
if (getTypeForVPValue(Trunc) == getTypeForVPValue(A))
Trunc->replaceAllUsesWith(A);
break;
}
default:
break;
}
}

Expand Down
62 changes: 23 additions & 39 deletions llvm/test/Transforms/LoopVectorize/if-pred-stores.ll
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,13 @@ define void @minimal_bit_widths(i1 %c) {
; UNROLL-NEXT: [[TMP5:%.*]] = load i8, ptr [[TMP3]], align 1
; UNROLL-NEXT: br i1 [[C:%.*]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE3]]
; UNROLL: pred.store.if:
; UNROLL-NEXT: [[TMP6:%.*]] = zext i8 [[TMP4]] to i32
; UNROLL-NEXT: [[TMP7:%.*]] = trunc i32 [[TMP6]] to i8
; UNROLL-NEXT: store i8 [[TMP7]], ptr [[TMP2]], align 1
; UNROLL-NEXT: [[TMP8:%.*]] = zext i8 [[TMP5]] to i32
; UNROLL-NEXT: [[TMP9:%.*]] = trunc i32 [[TMP8]] to i8
; UNROLL-NEXT: store i8 [[TMP9]], ptr [[TMP3]], align 1
; UNROLL-NEXT: store i8 [[TMP4]], ptr [[TMP2]], align 1
; UNROLL-NEXT: store i8 [[TMP5]], ptr [[TMP3]], align 1
; UNROLL-NEXT: br label [[PRED_STORE_CONTINUE3]]
; UNROLL: pred.store.continue3:
; UNROLL-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
; UNROLL-NEXT: [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; UNROLL-NEXT: br i1 [[TMP10]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; UNROLL-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; UNROLL-NEXT: br i1 [[TMP6]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; UNROLL: for.end:
; UNROLL-NEXT: ret void
;
Expand All @@ -461,21 +457,17 @@ define void @minimal_bit_widths(i1 %c) {
; UNROLL-NOSIMPLIFY-NEXT: [[TMP5:%.*]] = load i8, ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[C:%.*]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
; UNROLL-NOSIMPLIFY: pred.store.if:
; UNROLL-NOSIMPLIFY-NEXT: [[TMP6:%.*]] = zext i8 [[TMP4]] to i32
; UNROLL-NOSIMPLIFY-NEXT: [[TMP7:%.*]] = trunc i32 [[TMP6]] to i8
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP7]], ptr [[TMP2]], align 1
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP4]], ptr [[TMP2]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br label [[PRED_STORE_CONTINUE]]
; UNROLL-NOSIMPLIFY: pred.store.continue:
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[C]], label [[PRED_STORE_IF2:%.*]], label [[PRED_STORE_CONTINUE3]]
; UNROLL-NOSIMPLIFY: pred.store.if2:
; UNROLL-NOSIMPLIFY-NEXT: [[TMP8:%.*]] = zext i8 [[TMP5]] to i32
; UNROLL-NOSIMPLIFY-NEXT: [[TMP9:%.*]] = trunc i32 [[TMP8]] to i8
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP9]], ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP5]], ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br label [[PRED_STORE_CONTINUE3]]
; UNROLL-NOSIMPLIFY: pred.store.continue3:
; UNROLL-NOSIMPLIFY-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
; UNROLL-NOSIMPLIFY-NEXT: [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[TMP10]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; UNROLL-NOSIMPLIFY-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; UNROLL-NOSIMPLIFY: middle.block:
; UNROLL-NOSIMPLIFY-NEXT: br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
; UNROLL-NOSIMPLIFY: scalar.ph:
Expand Down Expand Up @@ -515,27 +507,23 @@ define void @minimal_bit_widths(i1 %c) {
; VEC-NEXT: [[TMP3:%.*]] = extractelement <2 x i1> [[BROADCAST_SPLAT]], i32 0
; VEC-NEXT: br i1 [[TMP3]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
; VEC: pred.store.if:
; VEC-NEXT: [[TMP4:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 0
; VEC-NEXT: [[TMP5:%.*]] = zext i8 [[TMP4]] to i32
; VEC-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr undef, i64 [[TMP0]]
; VEC-NEXT: [[TMP7:%.*]] = trunc i32 [[TMP5]] to i8
; VEC-NEXT: store i8 [[TMP7]], ptr [[TMP6]], align 1
; VEC-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr undef, i64 [[TMP0]]
; VEC-NEXT: [[TMP5:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 0
; VEC-NEXT: store i8 [[TMP5]], ptr [[TMP4]], align 1
; VEC-NEXT: br label [[PRED_STORE_CONTINUE]]
; VEC: pred.store.continue:
; VEC-NEXT: [[TMP8:%.*]] = extractelement <2 x i1> [[BROADCAST_SPLAT]], i32 1
; VEC-NEXT: br i1 [[TMP8]], label [[PRED_STORE_IF2:%.*]], label [[PRED_STORE_CONTINUE3]]
; VEC-NEXT: [[TMP6:%.*]] = extractelement <2 x i1> [[BROADCAST_SPLAT]], i32 1
; VEC-NEXT: br i1 [[TMP6]], label [[PRED_STORE_IF2:%.*]], label [[PRED_STORE_CONTINUE3]]
; VEC: pred.store.if2:
; VEC-NEXT: [[TMP9:%.*]] = add i64 [[INDEX]], 1
; VEC-NEXT: [[TMP10:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 1
; VEC-NEXT: [[TMP11:%.*]] = zext i8 [[TMP10]] to i32
; VEC-NEXT: [[TMP12:%.*]] = getelementptr i8, ptr undef, i64 [[TMP9]]
; VEC-NEXT: [[TMP13:%.*]] = trunc i32 [[TMP11]] to i8
; VEC-NEXT: store i8 [[TMP13]], ptr [[TMP12]], align 1
; VEC-NEXT: [[TMP7:%.*]] = add i64 [[INDEX]], 1
; VEC-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr undef, i64 [[TMP7]]
; VEC-NEXT: [[TMP9:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 1
; VEC-NEXT: store i8 [[TMP9]], ptr [[TMP8]], align 1
; VEC-NEXT: br label [[PRED_STORE_CONTINUE3]]
; VEC: pred.store.continue3:
; VEC-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
; VEC-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; VEC-NEXT: br i1 [[TMP14]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; VEC-NEXT: [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1000
; VEC-NEXT: br i1 [[TMP10]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; VEC: for.end:
; VEC-NEXT: ret void
;
Expand Down Expand Up @@ -606,21 +594,17 @@ define void @minimal_bit_widths_with_aliasing_store(i1 %c, ptr %ptr) {
; UNROLL-NOSIMPLIFY-NEXT: store i8 0, ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[C:%.*]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
; UNROLL-NOSIMPLIFY: pred.store.if:
; UNROLL-NOSIMPLIFY-NEXT: [[TMP6:%.*]] = zext i8 [[TMP4]] to i32
; UNROLL-NOSIMPLIFY-NEXT: [[TMP7:%.*]] = trunc i32 [[TMP6]] to i8
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP7]], ptr [[TMP2]], align 1
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP4]], ptr [[TMP2]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br label [[PRED_STORE_CONTINUE]]
; UNROLL-NOSIMPLIFY: pred.store.continue:
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[C]], label [[PRED_STORE_IF2:%.*]], label [[PRED_STORE_CONTINUE3]]
; UNROLL-NOSIMPLIFY: pred.store.if2:
; UNROLL-NOSIMPLIFY-NEXT: [[TMP8:%.*]] = zext i8 [[TMP5]] to i32
; UNROLL-NOSIMPLIFY-NEXT: [[TMP9:%.*]] = trunc i32 [[TMP8]] to i8
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP9]], ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: store i8 [[TMP5]], ptr [[TMP3]], align 1
; UNROLL-NOSIMPLIFY-NEXT: br label [[PRED_STORE_CONTINUE3]]
; UNROLL-NOSIMPLIFY: pred.store.continue3:
; UNROLL-NOSIMPLIFY-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
; UNROLL-NOSIMPLIFY-NEXT: [[TMP10:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[TMP10]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; UNROLL-NOSIMPLIFY-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
; UNROLL-NOSIMPLIFY-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; UNROLL-NOSIMPLIFY: middle.block:
; UNROLL-NOSIMPLIFY-NEXT: br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
; UNROLL-NOSIMPLIFY: scalar.ph:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ define void @sink_with_sideeffects(i1 %c, ptr %ptr) {
; CHECK-NEXT: Successor(s): pred.store.if, pred.store.continue

; CHECK: pred.store.if:
; CHECK-NEXT: CLONE ir<%tmp4> = zext ir<%tmp3>
; CHECK-NEXT: CLONE ir<%tmp5> = trunc ir<%tmp4>
; CHECK-NEXT: CLONE store ir<%tmp5>, ir<%tmp2>
; CHECK-NEXT: CLONE store ir<%tmp3>, ir<%tmp2>
; CHECK-NEXT: Successor(s): pred.store.continue

; CHECK: pred.store.continue:
Expand Down

0 comments on commit 0c8e5be

Please sign in to comment.