diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index c55864de9c170..0eaaa037ad578 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -816,15 +816,28 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { break; } case Instruction::Trunc: { - VPRecipeBase *Zext = R.getOperand(0)->getDefiningRecipe(); - if (!Zext || getOpcodeForRecipe(*Zext) != Instruction::ZExt) + VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe(); + if (!Ext) break; - VPValue *A = Zext->getOperand(0); + unsigned ExtOpcode = getOpcodeForRecipe(*Ext); + if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) + break; + VPValue *A = Ext->getOperand(0); VPValue *Trunc = R.getVPSingleValue(); - Type *TruncToTy = TypeInfo.inferScalarType(Trunc); - if (TruncToTy && TruncToTy == TypeInfo.inferScalarType(A)) + Type *TruncTy = TypeInfo.inferScalarType(Trunc); + Type *ATy = TypeInfo.inferScalarType(A); + if (TruncTy == ATy) { Trunc->replaceAllUsesWith(A); - + } else if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) { + auto *VPC = + new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy); + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) { + auto *VPC = new VPWidenCastRecipe(Instruction::Trunc, A, TruncTy); + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } #ifndef NDEBUG // Verify that the cached type info is for both A and its users is still // accurate by comparing it to freshly computed types.