
[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. #83859

Merged — 3 commits merged into llvm:main from codegen-scalable-frem on Mar 8, 2024

Conversation

paulwalker-arm
Collaborator

This removes, at least when a vector library is available, a failure case for scalable vectors. Doing so means we can confidently cost vector FREM instructions without making an assumption that later passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented but the same mechanism can be used for the other libm-related ISD nodes.
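For illustration, here is a minimal IR sketch of the kind of input that previously had no scalable-vector lowering (the function name is illustrative; it mirrors the frem_nxv4f32 test added below). With --vector-library=ArmPL, the legalizer can now expand the frem into a call to armpl_svfmod_f32_x instead of failing to unroll a scalable vector:

; Minimal example: scalable-vector frem that the legalizer can now expand
; into a vector math call when a vector library is available.
define <vscale x 4 x float> @frem_example(<vscale x 4 x float> %a, <vscale x 4 x float> %b) "target-features"="+sve" {
  %res = frem <vscale x 4 x float> %a, %b
  ret <vscale x 4 x float> %res
}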

[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call.

This removes, at least when a vector library is available, a failure
case for scalable vectors. Doing so means we can confidently cost
vector FREM instructions without making an assumption that later
passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented but the same mechanism
can be used for the other libm-related ISD nodes.
@llvmbot added the backend:AArch64 and llvm:SelectionDAG labels on Mar 4, 2024

llvmbot commented Mar 4, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Paul Walker (paulwalker-arm)

Changes

This removes, at least when a vector library is available, a failure case for scalable vectors. Doing so means we can confidently cost vector FREM instructions without making an assumption that later passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented but the same mechanism can be used for the other libm-related ISD nodes.


Full diff: https://github.com/llvm/llvm-project/pull/83859.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+127)
  • (modified) llvm/lib/CodeGen/TargetPassConfig.cpp (+5-1)
  • (added) llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll (+116)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144ff..ebd6f62a63ac4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,116 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into a call to a vector math function. Callers
+// provide the LibFunc equivalent of the passed-in Node, which is used to look
+// up mappings within TargetLibraryInfo. Only simple mappings are considered,
+// whereby only matching vector operands are allowed and masked functions are
+// passed an all-true vector (i.e. Node cannot be a predicated operation).
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // The chain must be propagated, but currently strict fp operations are
+  // down-converted to their non-strict counterparts.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants, but generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function.
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4cab..8832b51333d910 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden,
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 00000000000000..67c056c780cc80
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f64
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <2 x double> %a, %b
+  ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f32
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <4 x float> %a, %b
+  ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.s
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f32_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.s
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.d
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f64_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.d
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }

const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
if (!VD)
VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
Contributor
Suggested change
VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/ true);

Collaborator Author

Done.


github-actions bot commented Mar 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

paschalis-mpeis (Member) left a comment

Thanks for the changes.


/// Try to expand the node to a vector libcall based on the result type.
bool VectorLegalizer::tryExpandVecMathCall(
SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
Collaborator

Is there a reason the Libcall variants are parameters to this function, instead of being hardcoded? (Looking at other functions like DAGTypeLegalizer::SoftenFloatRes_FREM I see the list being hardcoded).

Collaborator Author

I'm expecting this function to be used by other ISD nodes (e.g. ISD::FSIN, ISD::FCOS, etc.) so I followed the idiom used by ExpandFPLibCall.
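As a hypothetical illustration (not part of this patch), another node could be routed through the same hook by adding a case to VectorLegalizer::Expand that mirrors the ISD::FREM case above, using the existing RTLIB sine libcalls:

  // Hypothetical extension: reuse the vector math call expansion for sine,
  // falling through to unrolling if no vector library mapping exists.
  case ISD::FSIN:
    if (tryExpandVecMathCall(Node, RTLIB::SIN_F32, RTLIB::SIN_F64,
                             RTLIB::SIN_F80, RTLIB::SIN_F128,
                             RTLIB::SIN_PPCF128, Results))
      return;
    break;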

huntergr-arm (Collaborator) left a comment

LGTM

@paulwalker-arm merged commit bd6eb54 into llvm:main on Mar 8, 2024
3 of 4 checks passed
@paulwalker-arm deleted the codegen-scalable-frem branch on March 19, 2024 at 18:17
paschalis-mpeis added a commit that referenced this pull request Jun 4, 2024
Refactor the pass to only support IntrinsicInst calls.

ReplaceWithVecLib used to support instructions, as AArch64 was using this
pass to replace a vectorized frem instruction with the fmod vector library
call (through TLI).

As this replacement is now done by the codegen (#83859), there is no need
for this pass to support instructions.

Additionally, removed 'frem' tests from:
- AArch64/replace-with-veclib-armpl.ll
- AArch64/replace-with-veclib-sleef-scalable.ll
- AArch64/replace-with-veclib-sleef.ll

Such testing is now done at the codegen level:
- #83859
paschalis-mpeis added a commit that referenced this pull request Jun 7, 2024