Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GlobalIsel] Combine selects with constants #76089

Merged
merged 3 commits into from
Jan 2, 2024
Merged

Conversation

tschuett
Copy link
Member

A first small step at combining selects.

A first small step at combining selects.
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-backend-aarch64

Author: Thorsten Schütt (tschuett)

Changes

A first small step at combining selects.


Patch is 691.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76089.diff

29 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+15-3)
  • (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+7-8)
  • (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+296-56)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir (+246)
  • (modified) llvm/test/CodeGen/AArch64/andcompare.ll (+9-5)
  • (modified) llvm/test/CodeGen/AArch64/arm64-ccmp.ll (+73-51)
  • (modified) llvm/test/CodeGen/AArch64/call-rv-marker.ll (+409-38)
  • (modified) llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll (+8-12)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fold-binop-into-select.mir (+17-25)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/divergence-divergent-i1-phis-no-lane-mask-merging.ll (+4-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/divergence-structurizer.ll (+4-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshl.ll (+1088-1104)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshr.ll (+1105-1311)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wqm.demote.ll (+32-16)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/lshr.ll (+41-36)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/saddsat.ll (+173-176)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/shl.ll (+59-52)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/usubsat.ll (+587-307)
  • (modified) llvm/test/CodeGen/AMDGPU/ctlz_zero_undef.ll (+18-14)
  • (modified) llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll (+8-6)
  • (modified) llvm/test/CodeGen/AMDGPU/fptrunc.ll (+24-22)
  • (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f32.ll (+112-87)
  • (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f64.ll (+238-260)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.inverse.ballot.i64.ll (+4-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.frexp.ll (+95-62)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+22-18)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+22-18)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+8-6)
  • (modified) llvm/test/CodeGen/AMDGPU/rsq.f64.ll (+450-500)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index e7debc652a0a8b..dcc1a4580b14a2 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -769,9 +769,6 @@ class CombinerHelper {
   bool matchCombineFSubFpExtFNegFMulToFMadOrFMA(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo);
 
-  /// Fold boolean selects to logical operations.
-  bool matchSelectToLogical(MachineInstr &MI, BuildFnTy &MatchInfo);
-
   bool matchCombineFMinMaxNaN(MachineInstr &MI, unsigned &Info);
 
   /// Transform G_ADD(x, G_SUB(y, x)) to y.
@@ -814,6 +811,9 @@ class CombinerHelper {
   // Given a binop \p MI, commute operands 1 and 2.
   void applyCommuteBinOpOperands(MachineInstr &MI);
 
+  /// Combine selects.
+  bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -904,6 +904,18 @@ class CombinerHelper {
   /// select (fcmp uge x, 1.0) 1.0, x -> fminnm x, 1.0
   bool matchFPSelectToMinMax(Register Dst, Register Cond, Register TrueVal,
                              Register FalseVal, BuildFnTy &MatchInfo);
+
+  /// Try to fold selects to logical operations.
+  bool tryFoldBoolSelectToLogic(GSelect *Select, BuildFnTy &MatchInfo);
+
+  bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
+
+  bool isOneOrOneSplat(Register Src, bool AllowUndefs);
+  bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
+  bool isConstantSplatVector(Register Src, int64_t SplatValue,
+                             bool AllowUndefs);
+
+  std::optional<APInt> getConstantOrConstantSplatVector(Register Src);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 77db371adaf776..6bda80681432a0 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -437,13 +437,6 @@ def select_constant_cmp: GICombineRule<
   (apply [{ Helper.replaceSingleDefInstWithOperand(*${root}, ${matchinfo}); }])
 >;
 
-def select_to_logical : GICombineRule<
-  (defs root:$root, build_fn_matchinfo:$matchinfo),
-  (match (wip_match_opcode G_SELECT):$root,
-    [{ return Helper.matchSelectToLogical(*${root}, ${matchinfo}); }]),
-  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])
->;
-
 // Fold (C op x) -> (x op C)
 // TODO: handle more isCommutable opcodes
 // TODO: handle compares (currently not marked as isCommutable)
@@ -1242,6 +1235,12 @@ def select_to_minmax: GICombineRule<
          [{ return Helper.matchSimplifySelectToMinMax(*${root}, ${info}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
 
+def match_selects : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_SELECT):$root,
+        [{ return Helper.matchSelect(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
                                      undef_to_negative_one,
@@ -1282,7 +1281,7 @@ def width_reduction_combines : GICombineGroup<[reduce_shl_of_extend,
 def phi_combines : GICombineGroup<[extend_through_phis]>;
 
 def select_combines : GICombineGroup<[select_undef_cmp, select_constant_cmp,
-                                      select_to_logical]>;
+                                      match_selects]>;
 
 def trivial_combines : GICombineGroup<[copy_prop, mul_to_shl, add_p2i_to_ptradd,
                                        mul_by_neg_one, idempotent_prop]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 91a64d59e154df..072a73ded170bd 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5940,62 +5940,6 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
   return false;
 }
 
-bool CombinerHelper::matchSelectToLogical(MachineInstr &MI,
-                                          BuildFnTy &MatchInfo) {
-  GSelect &Sel = cast<GSelect>(MI);
-  Register DstReg = Sel.getReg(0);
-  Register Cond = Sel.getCondReg();
-  Register TrueReg = Sel.getTrueReg();
-  Register FalseReg = Sel.getFalseReg();
-
-  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI);
-  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI);
-
-  const LLT CondTy = MRI.getType(Cond);
-  const LLT OpTy = MRI.getType(TrueReg);
-  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1)
-    return false;
-
-  // We have a boolean select.
-
-  // select Cond, Cond, F --> or Cond, F
-  // select Cond, 1, F    --> or Cond, F
-  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI);
-  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, Cond, FalseReg);
-    };
-    return true;
-  }
-
-  // select Cond, T, Cond --> and Cond, T
-  // select Cond, T, 0    --> and Cond, T
-  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI);
-  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, Cond, TrueReg);
-    };
-    return true;
-  }
-
- // select Cond, T, 1 --> or (not Cond), T
-  if (MaybeCstFalse && MaybeCstFalse->isOne()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg);
-    };
-    return true;
-  }
-
-  // select Cond, 0, F --> and (not Cond), F
-  if (MaybeCstTrue && MaybeCstTrue->isZero()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg);
-    };
-    return true;
-  }
-  return false;
-}
-
 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                             unsigned &IdxToPropagate) {
   bool PropagateNaN;
@@ -6318,3 +6262,299 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
   MI.getOperand(2).setReg(LHSReg);
   Observer.changedInstr(MI);
 }
+
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 1;
+  }
+  return false; // scalable vector
+}
+
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 0;
+  }
+  return false; // scalable vector
+}
+
+// Ignores COPYs during conformance checks.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return false;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  for (unsigned I = 0; I < NumSources; ++I) {
+    GImplicitDef *ImplicitDef =
+        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
+    if (ImplicitDef && AllowUndefs)
+      continue;
+    if (ImplicitDef && !AllowUndefs)
+      return false;
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (IConstant && IConstant->Value == SplatValue)
+      continue;
+    return false;
+  }
+  return true;
+}
+
+// Ignores COPYs during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+  if (IConstant)
+    return IConstant->Value;
+
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return std::nullopt;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  std::optional<APInt> Value = std::nullopt;
+  for (unsigned I = 0; I < NumSources; ++I) {
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (!IConstant)
+      return std::nullopt;
+    if (!Value)
+      Value = IConstant->Value;
+    else if (*Value != IConstant->Value)
+      return std::nullopt;
+  }
+  return Value;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Either both are scalars or both are vectors.
+  std::optional<APInt> TrueOpt = getConstantOrConstantSplatVector(True);
+  std::optional<APInt> FalseOpt = getConstantOrConstantSplatVector(False);
+
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  // These are only the splat values.
+  APInt TrueValue = *TrueOpt;
+  APInt FalseValue = *FalseOpt;
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // The shift amount must be scalar.
+      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
+      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildOr(Dest, Inner, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, C, -1 --> or (sext (not Cond)), C
+  if (FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Not = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Not, Cond);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Not);
+      B.buildOr(Dest, Inner, True, Flags);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, Cond, F --> or Cond, F
+  // select Cond, 1, F    --> or Cond, F
+  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildOr(DstReg, Ext, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, T, Cond --> and Cond, T
+  // select Cond, T, 0    --> and Cond, T
+  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildAnd(DstReg, Ext, True);
+    };
+    return true;
+  }
+
+  // select Cond, T, 1 --> or (not Cond), T
+  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Inner);
+      B.buildOr(DstReg, Ext, True, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, 0, F --> and (not Cond), F
+  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Inner);
+      B.buildAnd(DstReg, Ext, False);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
+  GSelect *Select = cast<GSelect>(&MI);
+
+  if (tryFoldSelectOfConstants(Select, MatchInfo))
+    return true;
+
+  if (tryFoldBoolSelectToLogic(Select, MatchInfo))
+    return true;
+
+  return false;
+}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index 81d38a5b080470..be2de620fa456c 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -298,3 +298,249 @@ body:             |
     %ext:_(s32) = G_ANYEXT %sel
     $w0 = COPY %ext(s32)
 ...
+---
+# select cond, 1, 0 --> zext(Cond)
+name:            select_cond_1_0_to_zext_cond
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_1_0_to_zext_cond
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %c(s1)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %zero:_(s1) = G_CONSTANT i1 0
+    %one:_(s1) = G_CONSTANT i1 1
+    %sel:_(s1) = G_SELECT %c, %one, %zero
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 0, 1 --> zext(!Cond)
+name:            select_cond_0_1_to_sext_not_cond
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_0_1_to_sext_not_cond
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s1) = G_CONSTANT i1 true
+    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(s1) = G_XOR %c, %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT [[XOR]](s1)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %zero:_(s1) = G_CONSTANT i1 0
+    %one:_(s1) = G_CONSTANT i1 1
+    %sel:_(s1) = G_SELECT %c, %zero, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 2, 1 --> and (zext Cond), false
+name:            select_cond_2_1_to_and_zext_cond_false
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_2_1_to_and_zext_cond_false
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s8) = G_CONSTANT i8 101
+    ; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s8) = G_ZEXT %c(s1)
+    ; CHECK-NEXT: %sel:_(s8) = G_ADD [[ZEXT]], %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %two:_(s8) = G_CONSTANT i8 102
+    %one:_(s8) = G_CONSTANT i8 101
+    %sel:_(s8) = G_SELECT %c, %two, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 1, 2 --> and (ext Cond), false
+name:            select_cond_1_2_to_and_sext_cond_false
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_1_2_to_and_sext_cond_false
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s8) = G_CONSTANT i8 102
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s8) = G_SEXT %c(s1)
+    ; CHECK-NEXT: %sel:_(s8) = G_ADD [[SEXT]], %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %two:_(s8) = G_CONSTANT i8 101
+    %one:_(s8) = G_CONSTANT i8 102
+    %sel:_(s8) = G_SELECT %c, %two, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-llvm-globalisel

Author: Thorsten Schütt (tschuett)

Changes

A first small step at combining selects.


Patch is 691.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76089.diff

29 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+15-3)
  • (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+7-8)
  • (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+296-56)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir (+246)
  • (modified) llvm/test/CodeGen/AArch64/andcompare.ll (+9-5)
  • (modified) llvm/test/CodeGen/AArch64/arm64-ccmp.ll (+73-51)
  • (modified) llvm/test/CodeGen/AArch64/call-rv-marker.ll (+409-38)
  • (modified) llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll (+8-12)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fold-binop-into-select.mir (+17-25)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/divergence-divergent-i1-phis-no-lane-mask-merging.ll (+4-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/divergence-structurizer.ll (+4-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshl.ll (+1088-1104)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/fshr.ll (+1105-1311)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wqm.demote.ll (+32-16)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/lshr.ll (+41-36)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/saddsat.ll (+173-176)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/shl.ll (+59-52)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/usubsat.ll (+587-307)
  • (modified) llvm/test/CodeGen/AMDGPU/ctlz_zero_undef.ll (+18-14)
  • (modified) llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll (+8-6)
  • (modified) llvm/test/CodeGen/AMDGPU/fptrunc.ll (+24-22)
  • (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f32.ll (+112-87)
  • (modified) llvm/test/CodeGen/AMDGPU/fsqrt.f64.ll (+238-260)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.inverse.ballot.i64.ll (+4-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.frexp.ll (+95-62)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+22-18)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+22-18)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+8-6)
  • (modified) llvm/test/CodeGen/AMDGPU/rsq.f64.ll (+450-500)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index e7debc652a0a8b..dcc1a4580b14a2 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -769,9 +769,6 @@ class CombinerHelper {
   bool matchCombineFSubFpExtFNegFMulToFMadOrFMA(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo);
 
-  /// Fold boolean selects to logical operations.
-  bool matchSelectToLogical(MachineInstr &MI, BuildFnTy &MatchInfo);
-
   bool matchCombineFMinMaxNaN(MachineInstr &MI, unsigned &Info);
 
   /// Transform G_ADD(x, G_SUB(y, x)) to y.
@@ -814,6 +811,9 @@ class CombinerHelper {
   // Given a binop \p MI, commute operands 1 and 2.
   void applyCommuteBinOpOperands(MachineInstr &MI);
 
+  /// Combine selects.
+  bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -904,6 +904,18 @@ class CombinerHelper {
   /// select (fcmp uge x, 1.0) 1.0, x -> fminnm x, 1.0
   bool matchFPSelectToMinMax(Register Dst, Register Cond, Register TrueVal,
                              Register FalseVal, BuildFnTy &MatchInfo);
+
+  /// Try to fold selects to logical operations.
+  bool tryFoldBoolSelectToLogic(GSelect *Select, BuildFnTy &MatchInfo);
+
+  bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
+
+  bool isOneOrOneSplat(Register Src, bool AllowUndefs);
+  bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
+  bool isConstantSplatVector(Register Src, int64_t SplatValue,
+                             bool AllowUndefs);
+
+  std::optional<APInt> getConstantOrConstantSplatVector(Register Src);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 77db371adaf776..6bda80681432a0 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -437,13 +437,6 @@ def select_constant_cmp: GICombineRule<
   (apply [{ Helper.replaceSingleDefInstWithOperand(*${root}, ${matchinfo}); }])
 >;
 
-def select_to_logical : GICombineRule<
-  (defs root:$root, build_fn_matchinfo:$matchinfo),
-  (match (wip_match_opcode G_SELECT):$root,
-    [{ return Helper.matchSelectToLogical(*${root}, ${matchinfo}); }]),
-  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])
->;
-
 // Fold (C op x) -> (x op C)
 // TODO: handle more isCommutable opcodes
 // TODO: handle compares (currently not marked as isCommutable)
@@ -1242,6 +1235,12 @@ def select_to_minmax: GICombineRule<
          [{ return Helper.matchSimplifySelectToMinMax(*${root}, ${info}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
 
+def match_selects : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_SELECT):$root,
+        [{ return Helper.matchSelect(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
                                      undef_to_negative_one,
@@ -1282,7 +1281,7 @@ def width_reduction_combines : GICombineGroup<[reduce_shl_of_extend,
 def phi_combines : GICombineGroup<[extend_through_phis]>;
 
 def select_combines : GICombineGroup<[select_undef_cmp, select_constant_cmp,
-                                      select_to_logical]>;
+                                      match_selects]>;
 
 def trivial_combines : GICombineGroup<[copy_prop, mul_to_shl, add_p2i_to_ptradd,
                                        mul_by_neg_one, idempotent_prop]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 91a64d59e154df..072a73ded170bd 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5940,62 +5940,6 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
   return false;
 }
 
-bool CombinerHelper::matchSelectToLogical(MachineInstr &MI,
-                                          BuildFnTy &MatchInfo) {
-  GSelect &Sel = cast<GSelect>(MI);
-  Register DstReg = Sel.getReg(0);
-  Register Cond = Sel.getCondReg();
-  Register TrueReg = Sel.getTrueReg();
-  Register FalseReg = Sel.getFalseReg();
-
-  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI);
-  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI);
-
-  const LLT CondTy = MRI.getType(Cond);
-  const LLT OpTy = MRI.getType(TrueReg);
-  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1)
-    return false;
-
-  // We have a boolean select.
-
-  // select Cond, Cond, F --> or Cond, F
-  // select Cond, 1, F    --> or Cond, F
-  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI);
-  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, Cond, FalseReg);
-    };
-    return true;
-  }
-
-  // select Cond, T, Cond --> and Cond, T
-  // select Cond, T, 0    --> and Cond, T
-  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI);
-  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, Cond, TrueReg);
-    };
-    return true;
-  }
-
- // select Cond, T, 1 --> or (not Cond), T
-  if (MaybeCstFalse && MaybeCstFalse->isOne()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg);
-    };
-    return true;
-  }
-
-  // select Cond, 0, F --> and (not Cond), F
-  if (MaybeCstTrue && MaybeCstTrue->isZero()) {
-    MatchInfo = [=](MachineIRBuilder &MIB) {
-      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg);
-    };
-    return true;
-  }
-  return false;
-}
-
 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                             unsigned &IdxToPropagate) {
   bool PropagateNaN;
@@ -6318,3 +6262,299 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
   MI.getOperand(2).setReg(LHSReg);
   Observer.changedInstr(MI);
 }
+
+bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 1, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 1;
+  }
+  return false; // scalable vector
+}
+
+bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
+  LLT SrcTy = MRI.getType(Src);
+  if (SrcTy.isFixedVector())
+    return isConstantSplatVector(Src, 0, AllowUndefs);
+  if (SrcTy.isScalar()) {
+    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
+      return true;
+    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+    return IConstant && IConstant->Value == 0;
+  }
+  return false; // scalable vector
+}
+
+// Ignores COPYs during conformance checks.
+// FIXME scalable vectors.
+bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
+                                           bool AllowUndefs) {
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return false;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  for (unsigned I = 0; I < NumSources; ++I) {
+    GImplicitDef *ImplicitDef =
+        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
+    if (ImplicitDef && AllowUndefs)
+      continue;
+    if (ImplicitDef && !AllowUndefs)
+      return false;
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (IConstant && IConstant->Value == SplatValue)
+      continue;
+    return false;
+  }
+  return true;
+}
+
+// Ignores COPYs during lookups.
+// FIXME scalable vectors
+std::optional<APInt>
+CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
+  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
+  if (IConstant)
+    return IConstant->Value;
+
+  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
+  if (!BuildVector)
+    return std::nullopt;
+  unsigned NumSources = BuildVector->getNumSources();
+
+  std::optional<APInt> Value = std::nullopt;
+  for (unsigned I = 0; I < NumSources; ++I) {
+    std::optional<ValueAndVReg> IConstant =
+        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
+    if (!IConstant)
+      return std::nullopt;
+    if (!Value)
+      Value = IConstant->Value;
+    else if (*Value != IConstant->Value)
+      return std::nullopt;
+  }
+  return Value;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register Dest = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Either both are scalars or both are vectors.
+  std::optional<APInt> TrueOpt = getConstantOrConstantSplatVector(True);
+  std::optional<APInt> FalseOpt = getConstantOrConstantSplatVector(False);
+
+  if (!TrueOpt || !FalseOpt)
+    return false;
+
+  // These are only the splat values.
+  APInt TrueValue = *TrueOpt;
+  APInt FalseValue = *FalseOpt;
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, 1, 0 --> zext (Cond)
+  if (TrueValue.isOne() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildZExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, -1, 0 --> sext (Cond)
+  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      B.buildSExtOrTrunc(Dest, Cond);
+    };
+    return true;
+  }
+
+  // select Cond, 0, 1 --> zext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isOne()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildZExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, 0, -1 --> sext (!Cond)
+  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      B.buildSExtOrTrunc(Dest, Inner);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+  if (TrueValue - 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+  if (TrueValue + 1 == FalseValue) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildAdd(Dest, Inner, False);
+    };
+    return true;
+  }
+
+  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
+  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Inner, Cond);
+      // The shift amount must be scalar.
+      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
+      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
+      B.buildShl(Dest, Inner, ShAmtC, Flags);
+    };
+    return true;
+  }
+  // select Cond, -1, C --> or (sext Cond), C
+  if (TrueValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Cond);
+      B.buildOr(Dest, Inner, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, C, -1 --> or (sext (not Cond)), C
+  if (FalseValue.isAllOnes()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Not = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Not, Cond);
+      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildSExtOrTrunc(Inner, Not);
+      B.buildOr(Dest, Inner, True, Flags);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+// TODO: use knownbits to determine zeros
+bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
+                                              BuildFnTy &MatchInfo) {
+  uint32_t Flags = Select->getFlags();
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register True = Select->getTrueReg();
+  Register False = Select->getFalseReg();
+  LLT CondTy = MRI.getType(Select->getCondReg());
+  LLT TrueTy = MRI.getType(Select->getTrueReg());
+
+  // Boolean or fixed vector of booleans.
+  if (CondTy.isScalableVector() ||
+      (CondTy.isFixedVector() &&
+       CondTy.getElementType().getScalarSizeInBits() != 1) ||
+      CondTy.getScalarSizeInBits() != 1)
+    return false;
+
+  // select Cond, Cond, F --> or Cond, F
+  // select Cond, 1, F    --> or Cond, F
+  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildOr(DstReg, Ext, False, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, T, Cond --> and Cond, T
+  // select Cond, T, 0    --> and Cond, T
+  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Cond);
+      B.buildAnd(DstReg, Ext, True);
+    };
+    return true;
+  }
+
+  // select Cond, T, 1 --> or (not Cond), T
+  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Inner);
+      B.buildOr(DstReg, Ext, True, Flags);
+    };
+    return true;
+  }
+
+  // select Cond, 0, F --> and (not Cond), F
+  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      // First the not.
+      Register Inner = MRI.createGenericVirtualRegister(CondTy);
+      B.buildNot(Inner, Cond);
+      // Then an ext to match the destination register.
+      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
+      B.buildZExtOrTrunc(Ext, Inner);
+      B.buildAnd(DstReg, Ext, False);
+    };
+    return true;
+  }
+
+  return false;
+}
+
+bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
+  GSelect *Select = cast<GSelect>(&MI);
+
+  if (tryFoldSelectOfConstants(Select, MatchInfo))
+    return true;
+
+  if (tryFoldBoolSelectToLogic(Select, MatchInfo))
+    return true;
+
+  return false;
+}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index 81d38a5b080470..be2de620fa456c 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -298,3 +298,249 @@ body:             |
     %ext:_(s32) = G_ANYEXT %sel
     $w0 = COPY %ext(s32)
 ...
+---
+# select cond, 1, 0 --> zext(Cond)
+name:            select_cond_1_0_to_zext_cond
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_1_0_to_zext_cond
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %c(s1)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %zero:_(s1) = G_CONSTANT i1 0
+    %one:_(s1) = G_CONSTANT i1 1
+    %sel:_(s1) = G_SELECT %c, %one, %zero
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 0, 1 --> zext(!Cond)
+name:            select_cond_0_1_to_sext_not_cond
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_0_1_to_sext_not_cond
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s1) = G_CONSTANT i1 true
+    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(s1) = G_XOR %c, %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT [[XOR]](s1)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %zero:_(s1) = G_CONSTANT i1 0
+    %one:_(s1) = G_CONSTANT i1 1
+    %sel:_(s1) = G_SELECT %c, %zero, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 2, 1 --> and (zext Cond), false
+name:            select_cond_2_1_to_and_zext_cond_false
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_2_1_to_and_zext_cond_false
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s8) = G_CONSTANT i8 101
+    ; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s8) = G_ZEXT %c(s1)
+    ; CHECK-NEXT: %sel:_(s8) = G_ADD [[ZEXT]], %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %two:_(s8) = G_CONSTANT i8 102
+    %one:_(s8) = G_CONSTANT i8 101
+    %sel:_(s8) = G_SELECT %c, %two, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, 1, 2 --> and (ext Cond), false
+name:            select_cond_1_2_to_and_sext_cond_false
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_1_2_to_and_sext_cond_false
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %one:_(s8) = G_CONSTANT i8 102
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s8) = G_SEXT %c(s1)
+    ; CHECK-NEXT: %sel:_(s8) = G_ADD [[SEXT]], %one
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %c:_(s1) = G_TRUNC %0
+    %t:_(s1) = G_TRUNC %1
+    %f:_(s1) = G_TRUNC %2
+    %two:_(s8) = G_CONSTANT i8 101
+    %one:_(s8) = G_CONSTANT i8 102
+    %sel:_(s8) = G_SELECT %c, %two, %one
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32...
[truncated]

}

// TODO: use knownbits to determine zeros
bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a port of foldBoolSelectToLogic from the DAG.

}

// TODO: use knownbits to determine zeros
bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a partial port of foldSelectOfConstants from the DAG.

@@ -749,8 +751,7 @@ body: |
; CHECK-NEXT: %reg:_(s32) = COPY $vgpr0
; CHECK-NEXT: %zero:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: %cond:_(s1) = G_ICMP intpred(eq), %reg(s32), %zero
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
; CHECK-NEXT: %srem:_(s32) = G_SELECT %cond(s1), [[C]], %zero
; CHECK-NEXT: %srem:_(s32) = G_ZEXT %cond(s1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improvement

@@ -802,8 +803,7 @@ body: |
; CHECK-NEXT: %reg:_(s32) = COPY $vgpr0
; CHECK-NEXT: %zero:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: %cond:_(s1) = G_ICMP intpred(eq), %reg(s32), %zero
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
; CHECK-NEXT: %udiv:_(s32) = G_SELECT %cond(s1), [[C]], %zero
; CHECK-NEXT: %udiv:_(s32) = G_ZEXT %cond(s1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improvement

Comment on lines 6355 to 6358
// Only do this before legalization to avoid conflicting with target-specific
// transforms in the other direction.
if (CondTy != LLT::scalar(1))
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To implement your comment this check only works for AArch64, you should use isPreLegalize() instead.

However I think you actually want what your code does here, i.e. the s1 check. So probably rewrite the comment to remove the part about before-legalization. For non-s1 types you might need to care about getBooleanContents() etc so easier to do it for s1 only.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The text was inspired by the DAG combiner. I updated it.

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp Outdated Show resolved Hide resolved
Comment on lines +6450 to +6451
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can fold this into the MIRBuilder,

auto Inner = B.buildSextOrTrunc(TrueTy, Cond)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the verbose version. It is more readable. I read code more often than I write code. It is completely obvious to any reader that I need a temporary register. Feeding the LLT into the builder is shorter, but it is confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's idiomatic for GlobalISel code to directly pass LLT to the builder when possible, you'll notice that we tend to rarely directly call createGenericVirtualRegister(). That said, if you feel very strongly about this I won't block this patch on that basis, but you should also be aware that this code will be in a minority w.r.t the rest of the codebase.

Copy link
Contributor

@aemerson aemerson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved, I'll leave the final call to you with the coding style.

@tschuett tschuett merged commit 4b91949 into llvm:main Jan 2, 2024
4 checks passed
@tschuett tschuett deleted the gisel-select1 branch January 2, 2024 16:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants