Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDISel] Teach the type legalizer about ADDRSPACECAST #90969

Merged
merged 3 commits into from
May 7, 2024

Conversation

qcolombet
Copy link
Collaborator

Vectorized ADDRSPACECASTs were not supported by the type legalizer.

This patch adds the support for:

  • splitting the vector result: <2 x ptr> => 2 x <1 x ptr>
  • scalarization: <1 x ptr> => ptr
  • widening: <3 x ptr> => <4 x ptr>

This is all exercised by the added NVPTX tests.

@jholewinski to double check the NVPTX test results.

Vectorized ADDRSPACECASTs were not supported by the type legalizer.
This patch adds the support for:
- splitting the vector result: <2 x ptr> => 2 x <1 x ptr>
- scalarization: <1 x ptr> => ptr
- widening: <3 x ptr> => <4 x ptr>

This is all exercised by the added NVPTX tests.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label May 3, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 3, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Quentin Colombet (qcolombet)

Changes

Vectorized ADDRSPACECASTs were not supported by the type legalizer.

This patch adds the support for:

  • splitting the vector result: <2 x ptr> => 2 x <1 x ptr>
  • scalarization: <1 x ptr> => ptr
  • widening: <3 x ptr> => <4 x ptr>

This is all exercised by the added NVPTX tests.

@jholewinski to double check the NVPTX test results.


Full diff: https://github.com/llvm/llvm-project/pull/90969.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+66)
  • (modified) llvm/test/CodeGen/NVPTX/addrspacecast.ll (+92)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 0252e3d6febca9..dcd547d231c70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -785,6 +785,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue ScalarizeVecRes_InregOp(SDNode *N);
   SDValue ScalarizeVecRes_VecInregOp(SDNode *N);
 
+  SDValue ScalarizeVecRes_ADDRSPACECAST(SDNode *N);
   SDValue ScalarizeVecRes_BITCAST(SDNode *N);
   SDValue ScalarizeVecRes_BUILD_VECTOR(SDNode *N);
   SDValue ScalarizeVecRes_EXTRACT_SUBVECTOR(SDNode *N);
@@ -852,6 +853,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -955,6 +957,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   // Widen Vector Result Promotion.
   void WidenVectorResult(SDNode *N, unsigned ResNo);
   SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
+  SDValue WidenVecRes_ADDRSPACECAST(SDNode *N);
   SDValue WidenVecRes_AssertZext(SDNode* N);
   SDValue WidenVecRes_BITCAST(SDNode* N);
   SDValue WidenVecRes_BUILD_VECTOR(SDNode* N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index cab4dc5f3c1565..14501e5d01d568 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TypeSize.h"
@@ -116,6 +117,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FCANONICALIZE:
     R = ScalarizeVecRes_UnaryOp(N);
     break;
+  case ISD::ADDRSPACECAST:
+    R = ScalarizeVecRes_ADDRSPACECAST(N);
+    break;
   case ISD::FFREXP:
     R = ScalarizeVecRes_FFREXP(N, ResNo);
     break;
@@ -475,6 +479,31 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_VecInregOp(SDNode *N) {
   llvm_unreachable("Illegal extend_vector_inreg opcode");
 }
 
+SDValue DAGTypeLegalizer::ScalarizeVecRes_ADDRSPACECAST(SDNode *N) {
+  EVT DestVT = N->getValueType(0).getVectorElementType();
+  SDValue Op = N->getOperand(0);
+  EVT OpVT = Op.getValueType();
+  SDLoc DL(N);
+  // The result needs scalarizing, but it's not a given that the source does.
+  // This is a workaround for targets where it's impossible to scalarize the
+  // result of a conversion, because the source type is legal.
+  // For instance, this happens on AArch64: v1i1 is illegal but v1i{8,16,32}
+  // are widened to v8i8, v4i16, and v2i32, which is legal, because v1i64 is
+  // legal and was not scalarized.
+  // See the similar logic in ScalarizeVecRes_SETCC
+  if (getTypeAction(OpVT) == TargetLowering::TypeScalarizeVector) {
+    Op = GetScalarizedVector(Op);
+  } else {
+    EVT VT = OpVT.getVectorElementType();
+    Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op,
+                     DAG.getVectorIdxConstant(0, DL));
+  }
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+  unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
+  unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
+  return DAG.getAddrSpaceCast(SDLoc(N), DestVT, Op, SrcAS, DestAS);
+}
+
 SDValue DAGTypeLegalizer::ScalarizeVecRes_SCALAR_TO_VECTOR(SDNode *N) {
   // If the operand is wider than the vector element type then it is implicitly
   // truncated.  Make that explicit here.
@@ -1122,6 +1151,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FCANONICALIZE:
     SplitVecRes_UnaryOp(N, Lo, Hi);
     break;
+  case ISD::ADDRSPACECAST:
+    SplitVecRes_ADDRSPACECAST(N, Lo, Hi);
+    break;
   case ISD::FFREXP:
     SplitVecRes_FFREXP(N, ResNo, Lo, Hi);
     break;
@@ -2353,6 +2385,27 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo,
   Hi = DAG.getNode(Opcode, dl, HiVT, {Hi, MaskHi, EVLHi}, Flags);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo,
+                                                 SDValue &Hi) {
+  EVT LoVT, HiVT;
+  SDLoc dl(N);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
+
+  // If the input also splits, handle it directly for a compile time speedup.
+  // Otherwise split it by hand.
+  EVT InVT = N->getOperand(0).getValueType();
+  if (getTypeAction(InVT) == TargetLowering::TypeSplitVector)
+    GetSplitVector(N->getOperand(0), Lo, Hi);
+  else
+    std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+  unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
+  unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
+  Lo = DAG.getAddrSpaceCast(dl, LoVT, Lo, SrcAS, DestAS);
+  Hi = DAG.getAddrSpaceCast(dl, HiVT, Hi, SrcAS, DestAS);
+}
+
 void DAGTypeLegalizer::SplitVecRes_FFREXP(SDNode *N, unsigned ResNo,
                                           SDValue &Lo, SDValue &Hi) {
   SDLoc dl(N);
@@ -4121,6 +4174,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to widen the result of this operator!");
 
   case ISD::MERGE_VALUES:      Res = WidenVecRes_MERGE_VALUES(N, ResNo); break;
+  case ISD::ADDRSPACECAST:
+    Res = WidenVecRes_ADDRSPACECAST(N);
+    break;
   case ISD::AssertZext:        Res = WidenVecRes_AssertZext(N); break;
   case ISD::BITCAST:           Res = WidenVecRes_BITCAST(N); break;
   case ISD::BUILD_VECTOR:      Res = WidenVecRes_BUILD_VECTOR(N); break;
@@ -5086,6 +5142,16 @@ SDValue DAGTypeLegalizer::WidenVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo) {
   return GetWidenedVector(WidenVec);
 }
 
+SDValue DAGTypeLegalizer::WidenVecRes_ADDRSPACECAST(SDNode *N) {
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDValue InOp = GetWidenedVector(N->getOperand(0));
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+
+  return DAG.getAddrSpaceCast(SDLoc(N), WidenVT, InOp,
+                              AddrSpaceCastN->getSrcAddressSpace(),
+                              AddrSpaceCastN->getDestAddressSpace());
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
   SDValue InOp = N->getOperand(0);
   EVT InVT = InOp.getValueType();
diff --git a/llvm/test/CodeGen/NVPTX/addrspacecast.ll b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
index b680490ac5b124..85752bb95eb31f 100644
--- a/llvm/test/CodeGen/NVPTX/addrspacecast.ll
+++ b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
@@ -98,3 +98,95 @@ define i32 @conv8(ptr %ptr) {
   %val = load i32, ptr addrspace(5) %specptr
   ret i32 %val
 }
+
+; Check that we support addrspacecast when splitting the vector
+; result (<2 x ptr> => 2 x <1 x ptr>).
+; This also checks that scalarization works for addrspacecast
+; (when going from <1 x ptr> to ptr.)
+; ALL-LABEL: split1To0
+define void @split1To0(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+; ALL: st.u32
+; ALL: st.u32
+  %vec_addr = load <2 x ptr addrspace(1)>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <2 x ptr addrspace(1)> %vec_addr to <2 x ptr>
+  %extractelement0 = extractelement <2 x ptr> %addrspacecast, i64 0
+  store float 0.5, ptr %extractelement0, align 4
+  %extractelement1 = extractelement <2 x ptr> %addrspacecast, i64 1
+  store float 1.0, ptr %extractelement1, align 4
+  ret void
+}
+
+; Same as split1To0 but from 0 to 1, to make sure the addrspacecast preserve
+; the source and destination addrspaces properly.
+; ALL-LABEL: split0To1
+define void @split0To1(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+; ALL: st.global.u32
+; ALL: st.global.u32
+  %vec_addr = load <2 x ptr>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <2 x ptr> %vec_addr to <2 x ptr addrspace(1)>
+  %extractelement0 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 0
+  store float 0.5, ptr addrspace(1) %extractelement0, align 4
+  %extractelement1 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 1
+  store float 1.0, ptr addrspace(1) %extractelement1, align 4
+  ret void
+}
+
+; Check that we support addrspacecast when a widening is required
+; (3 x ptr => 4 x ptr).
+; ALL-LABEL: widen1To0
+define void @widen1To0(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+
+; ALL: st.u32
+; ALL: st.u32
+; ALL: st.u32
+  %vec_addr = load <3 x ptr addrspace(1)>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <3 x ptr addrspace(1)> %vec_addr to <3 x ptr>
+  %extractelement0 = extractelement <3 x ptr> %addrspacecast, i64 0
+  store float 0.5, ptr %extractelement0, align 4
+  %extractelement1 = extractelement <3 x ptr> %addrspacecast, i64 1
+  store float 1.0, ptr %extractelement1, align 4
+  %extractelement2 = extractelement <3 x ptr> %addrspacecast, i64 2
+  store float 1.5, ptr %extractelement2, align 4
+  ret void
+}
+
+; Same as widen1To0 but from 0 to 1, to make sure the addrspacecast preserve
+; the source and destination addrspaces properly.
+; ALL-LABEL: widen0To1
+define void @widen0To1(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+
+; ALL: st.global.u32
+; ALL: st.global.u32
+; ALL: st.global.u32
+  %vec_addr = load <3 x ptr>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <3 x ptr> %vec_addr to <3 x ptr addrspace(1)>
+  %extractelement0 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 0
+  store float 0.5, ptr addrspace(1) %extractelement0, align 4
+  %extractelement1 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 1
+  store float 1.0, ptr addrspace(1) %extractelement1, align 4
+  %extractelement2 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 2
+  store float 1.5, ptr addrspace(1) %extractelement2, align 4
+  ret void
+}

qcolombet and others added 2 commits May 5, 2024 09:58
Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
@qcolombet
Copy link
Collaborator Author

Thanks for the suggestions and review @arsenm !

@qcolombet qcolombet merged commit 6ce0474 into llvm:main May 7, 2024
3 of 4 checks passed
@qcolombet qcolombet deleted the sdisel_addrspcast branch May 7, 2024 09:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants