[VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective #81852

Merged
merged 2 commits into llvm:main from reduce-trunc
Feb 19, 2024

Conversation

RKSimon
Collaborator

RKSimon commented Feb 15, 2024

Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free.

If the cost of performing the add/mul/and/or/xor reduction is cheap enough on the pre-truncated type, then avoid the vector truncation entirely.

Fixes #81469
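
For illustration, here is a minimal hand-written before/after sketch of the fold in LLVM IR (the function names are made up for this example and are not taken from the patch; the committed tests below exercise the same pattern):

; Before: truncate the vector, then reduce over the narrow type.
define i32 @example_reduce_add_trunc(<8 x i64> %a0) {
  %tr  = trunc <8 x i64> %a0 to <8 x i32>
  %red = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
  ret i32 %red
}

; After (when the cost model says the wide reduction is cheap enough):
; reduce over the wide type, then truncate the scalar result.
define i32 @example_reduce_add_trunc_folded(<8 x i64> %a0) {
  %red = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %a0)
  %tr  = trunc i64 %red to i32
  ret i32 %tr
}

declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i64 @llvm.vector.reduce.add.v8i64(<8 x i64>)

The rewrite is only sound for reductions where truncation distributes over the operation (add/mul/and/or/xor); min/max reductions are deliberately left alone, as the negative test below shows.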

@llvmbot
Collaborator

llvmbot commented Feb 15, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/81852.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+63)
  • (added) llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll (+91)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index f18711ba30b708..20fb9de7c75aa9 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -111,6 +111,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -1526,6 +1527,67 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+/// Determine if it's more efficient to fold:
+///   reduce(trunc(x)) -> trunc(reduce(x)).
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+
+  unsigned ReductionOpc = 0;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add:
+    ReductionOpc = Instruction::Add;
+    break;
+  case Intrinsic::vector_reduce_mul:
+    ReductionOpc = Instruction::Mul;
+    break;
+  case Intrinsic::vector_reduce_and:
+    ReductionOpc = Instruction::And;
+    break;
+  case Intrinsic::vector_reduce_or:
+    ReductionOpc = Instruction::Or;
+    break;
+  case Intrinsic::vector_reduce_xor:
+    ReductionOpc = Instruction::Xor;
+    break;
+  default:
+    return false;
+  }
+  Value *ReductionSrc = I.getOperand(0);
+
+  Value *TruncSrc;
+  if (!match(ReductionSrc, m_Trunc(m_OneUse(m_Value(TruncSrc)))))
+    return false;
+
+  auto *Trunc = cast<CastInst>(ReductionSrc);
+  auto *TruncTy = cast<VectorType>(TruncSrc->getType());
+  auto *ReductionTy = cast<VectorType>(ReductionSrc->getType());
+  Type *ResultTy = I.getType();
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(Instruction::Trunc, ReductionTy, TruncTy,
+                           TTI::CastContextHint::None, CostKind, Trunc) +
+      TTI.getArithmeticReductionCost(ReductionOpc, ReductionTy, std::nullopt,
+                                     CostKind);
+  InstructionCost NewCost =
+      TTI.getArithmeticReductionCost(ReductionOpc, TruncTy, std::nullopt,
+                                     CostKind) +
+      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           ReductionTy->getScalarType(),
+                           TTI::CastContextHint::None, CostKind);
+
+  if (OldCost < NewCost || !NewCost.isValid())
+    return false;
+
+  Value *NewReduction = Builder.CreateIntrinsic(
+      TruncTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
+  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+  replaceValue(I, *NewTruncation);
+  return true;
+}
+
 /// This method looks for groups of shuffles acting on binops, of the form:
 ///  %x = shuffle ...
 ///  %y = shuffle ...
@@ -1917,6 +1979,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
+        MadeChange |= foldTruncFromReductions(I);
         break;
       case Instruction::ICmp:
       case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
new file mode 100644
index 00000000000000..54d8250cdec816
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64    | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v2 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v3 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s --check-prefixes=CHECK,AVX512
+
+;
+; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
+;
+
+; Cheap AVX512 v8i64 -> v8i32 truncation
+define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0)  {
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT:    ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT:    ret i32 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+
+; No legal vXi8 multiplication so vXi16 is always cheaper
+define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0)  {
+; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i16 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
+
+define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0)  {
+; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <8 x i32> %a0 to <8 x i8>
+  %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
+
+define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0)  {
+; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i64> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
+
+; Negative Test: vXi16 multiply is much cheaper than vXi64
+define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0)  {
+; CHECK-LABEL: @reduce_mul_trunc_v8i64_i16(
+; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i16>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> [[TR]])
+; CHECK-NEXT:    ret i16 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i16>
+  %red = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> %tr)
+  ret i16 %red
+}
+declare i16 @llvm.vector.reduce.mul.v8i16(<8 x i16>)
+
+; Negative Test: min/max reductions can't use pre-truncated types.
+define i8 @reduce_smin_trunc_v16i16_i8(<16 x i16> %a0)  {
+; CHECK-LABEL: @reduce_smin_trunc_v16i16_i8(
+; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.smin.v16i8(<16 x i8>)
+

RKSimon added a commit that referenced this pull request Feb 16, 2024
…and Arithmetic Instruction Opcode and MinMax IntrinsicID / RecurKind

Noticed on #81852
Contributor

nikic left a comment


LG

RKSimon merged commit 769c22f into llvm:main Feb 19, 2024
4 checks passed
RKSimon deleted the reduce-trunc branch February 20, 2024 12:54
Development

Successfully merging this pull request may close these issues.

[X86] Prefer trunc(reduce(x)) over reduce(trunc(x))
4 participants