Skip to content

[DAGCombiner][AMDGPU] Track signedness in ByteProviders #65995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

jrbyrnes
Copy link
Contributor

Provide ability to track signedness in ByteProviders & do so in SIISelLowering caculateByteProvider.

When combining into an arithmetic vectorized reduction instruction, the signedness is relevant. In such cases, the signedness may be encoded into the tree itself rather than the top-level instruction -- thus, there is a need to track whether we are handling the byte as a signed operand during traversal.

@jrbyrnes jrbyrnes requested a review from a team as a code owner September 11, 2023 18:44
@jrbyrnes jrbyrnes requested a review from arsenm September 11, 2023 18:44
@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-backend-amdgpu

Changes

Provide ability to track signedness in ByteProviders & do so in SIISelLowering caculateByteProvider.

When combining into an arithmetic vectorized reduction instruction, the signedness is relevant. In such cases, the signedness may be encoded into the tree itself rather than the top-level instruction -- thus, there is a need to track whether we are handling the byte as a signed operand during traversal.

Patch is 33.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65995.diff

4 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ByteProvider.h (+11-2)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+54-32)
  • (modified) llvm/test/CodeGen/AMDGPU/idot4s.ll (+207-14)
  • (modified) llvm/test/CodeGen/AMDGPU/idot4u.ll (+59-17)
diff --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h
index 3187b4e68c56f3a..7fc2d9a876e6710 100644
--- a/llvm/include/llvm/CodeGen/ByteProvider.h
+++ b/llvm/include/llvm/CodeGen/ByteProvider.h
@@ -32,6 +32,11 @@ template  class ByteProvider {
   ByteProvider(std::optional Src, int64_t DestOffset, int64_t SrcOffset)
       : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
 
+  ByteProvider(std::optional Src, int64_t DestOffset, int64_t SrcOffset,
+               bool IsSigned)
+      : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset),
+        IsSigned(IsSigned) {}
+
   // TODO -- use constraint in c++20
   // Does this type correspond with an operation in selection DAG
   template  class is_op {
@@ -61,13 +66,17 @@ template  class ByteProvider {
   // DestOffset
   int64_t SrcOffset = 0;
 
+  // Tracks whether or not the byte is treated as a signed operand -- useful
+  // for arithmetic combines.
+  bool IsSigned = 0;
+
   ByteProvider() = default;
 
   static ByteProvider getSrc(std::optional Val, int64_t ByteOffset,
-                             int64_t VectorOffset) {
+                             int64_t VectorOffset, bool IsSigned = 0) {
     static_assert(is_op().value,
                   "ByteProviders must contain an operation in selection DAG.");
-    return ByteProvider(Val, ByteOffset, VectorOffset);
+    return ByteProvider(Val, ByteOffset, VectorOffset, IsSigned);
   }
 
   static ByteProvider getConstantZero() {
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 85c9ed489e926ce..a7116518aaadd0d 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10513,7 +10513,7 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
 // performed.
 static const std::optional>
 calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
-                 unsigned Depth = 0) {
+                 bool IsSigned = 0, unsigned Depth = 0) {
   // We may need to recursively traverse a series of SRLs
   if (Depth >= 6)
     return std::nullopt;
@@ -10524,12 +10524,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
   switch (Op->getOpcode()) {
   case ISD::TRUNCATE: {
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND_INREG: {
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
     SDValue NarrowOp = Op->getOperand(0);
     auto NarrowVT = NarrowOp.getValueType();
     if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
@@ -10542,7 +10545,8 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
     if (SrcIndex >= NarrowByteWidth)
       return std::nullopt;
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SRA:
@@ -10558,11 +10562,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
     SrcIndex += BitShift / 8;
 
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   default: {
-    return ByteProvider::getSrc(Op, DestByte, SrcIndex);
+    if (auto L = dyn_cast(Op))
+      IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
+
+    return ByteProvider::getSrc(Op, DestByte, SrcIndex, IsSigned);
   }
   }
   llvm_unreachable("fully handled switch");
@@ -10576,7 +10584,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 // performed. \p StartingIndex is the originally requested byte of the Or
 static const std::optional>
 calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
-                      unsigned StartingIndex = 0) {
+                      unsigned StartingIndex = 0, bool IsSigned = 0) {
   // Finding Src tree of RHS of or typically requires at least 1 additional
   // depth
   if (Depth > 6)
@@ -10591,11 +10599,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   switch (Op.getOpcode()) {
   case ISD::OR: {
     auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!RHS)
       return std::nullopt;
     auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!LHS)
       return std::nullopt;
     // A well formed Or will have two ByteProviders for each byte, one of which
@@ -10626,7 +10634,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
       return ByteProvider::getConstantZero();
     }
 
-    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
+    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned);
   }
 
   case ISD::SRA:
@@ -10651,7 +10659,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     // the SRL is Index + ByteShift
     return BytesProvided - ByteShift > Index
                ? calculateSrcByte(Op->getOperand(0), StartingIndex,
-                                  Index + ByteShift)
+                                  Index + ByteShift, IsSigned)
                : ByteProvider::getConstantZero();
   }
 
@@ -10672,7 +10680,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? ByteProvider::getConstantZero()
                : calculateByteProvider(Op.getOperand(0), Index - ByteShift,
-                                       Depth + 1, StartingIndex);
+                                       Depth + 1, StartingIndex, IsSigned);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -10691,13 +10699,17 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
+                Op->getOpcode() == ISD::AssertSext;
 
     if (Index >= NarrowByteWidth)
       return Op.getOpcode() == ISD::ZERO_EXTEND
                  ? std::optional>(
                        ByteProvider::getConstantZero())
                  : std::nullopt;
-    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
+    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
+                                 IsSigned);
   }
 
   case ISD::TRUNCATE: {
@@ -10705,7 +10717,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
     if (NarrowByteWidth >= Index) {
       return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                   StartingIndex);
+                                   StartingIndex, IsSigned);
     }
 
     return std::nullopt;
@@ -10713,13 +10725,14 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::CopyFromReg: {
     if (BitWidth / 8 > Index)
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
 
     return std::nullopt;
   }
 
   case ISD::LOAD: {
     auto L = cast(Op.getNode());
+    IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
@@ -10736,7 +10749,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     if (NarrowByteWidth > Index) {
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
     }
 
     return std::nullopt;
@@ -10744,7 +10757,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
-                                 Depth + 1, StartingIndex);
+                                 Depth + 1, StartingIndex, IsSigned);
 
   case ISD::EXTRACT_VECTOR_ELT: {
     auto IdxOp = dyn_cast(Op->getOperand(1));
@@ -10759,7 +10772,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
-                            StartingIndex, Index);
+                            StartingIndex, Index, IsSigned);
   }
 
   case AMDGPUISD::PERM: {
@@ -10775,9 +10788,10 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
     auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;
 
-    return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
-                           : ByteProvider(
-                                 ByteProvider::getConstantZero());
+    return IdxMask != 0x0c
+               ? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
+               : ByteProvider(
+                     ByteProvider::getConstantZero());
   }
 
   default: {
@@ -12587,11 +12601,7 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
     auto MulIdx = isMul(LHS) ? 0 : 1;
 
     auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
-    bool IsSigned =
-        MulOpcode == AMDGPUISD::MUL_I24 ||
-        (MulOpcode == ISD::MUL &&
-         TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
-         !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+    std::optional IsSigned;
     SmallVector, 4> Src0s;
     SmallVector, 4> Src1s;
     SmallVector Src2s;
@@ -12607,15 +12617,17 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
           (MulOpcode == ISD::MUL &&
            TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
            !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
-      if (IterIsSigned != IsSigned) {
-        break;
-      }
       auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
       if (!Src0)
         break;
       auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
       if (!Src1)
         break;
+      IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+      if (!IsSigned)
+        IsSigned = IterIsSigned;
+      if (IterIsSigned != *IsSigned)
+        break;
       placeSources(*Src0, *Src1, Src0s, Src1s, I);
       auto AddIdx = 1 - MulIdx;
       // Allow the special case where add (add (mul24, 0), mul24) became ->
@@ -12630,6 +12642,15 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
             handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
         if (!Src1)
           break;
+        auto IterIsSigned =
+            MulOpcode == AMDGPUISD::MUL_I24 ||
+            (MulOpcode == ISD::MUL &&
+             TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
+             !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+        IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+        assert(IsSigned);
+        if (IterIsSigned != *IsSigned)
+          break;
         placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
         Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
         ChainLength = I + 2;
@@ -12695,18 +12716,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
       Src1 = resolveSources(DAG, SL, Src1s, false, true);
     }
 
+    assert(IsSigned);
     SDValue Src2 =
-        DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
+        DAG.getExtOrTrunc(*IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
 
-    SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
-                                                 : Intrinsic::amdgcn_udot4,
+    SDValue IID = DAG.getTargetConstant(*IsSigned ? Intrinsic::amdgcn_sdot4
+                                                  : Intrinsic::amdgcn_udot4,
                                         SL, MVT::i64);
 
     assert(!VT.isVector());
     auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
                            Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));
 
-    return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
+    return DAG.getExtOrTrunc(*IsSigned, Dot, SL, VT);
   }
 
   if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/AMDGPU/idot4s.ll b/llvm/test/CodeGen/AMDGPU/idot4s.ll
index 7edd24f12982ebd..e521039ce9ac838 100644
--- a/llvm/test/CodeGen/AMDGPU/idot4s.ll
+++ b/llvm/test/CodeGen/AMDGPU/idot4s.ll
@@ -143,7 +143,7 @@ define amdgpu_kernel void @idot4_acc32(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -352,7 +352,7 @@ define amdgpu_kernel void @idot4_acc16(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    global_load_i16 v3, v1, s[0:1]
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b16 v1, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -732,7 +732,7 @@ define amdgpu_kernel void @idot4_multiuse_mul1(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
 ; GFX11-DL-NEXT:    v_mad_i32_i24 v2, v2, v3, s2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v3, 0
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v3, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -922,7 +922,7 @@ define amdgpu_kernel void @idot4_acc32_vecMul(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1356,7 +1356,7 @@ define amdgpu_kernel void @idot4_acc32_2ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc0c0100
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1534,7 +1534,7 @@ define amdgpu_kernel void @idot4_acc32_3ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc020100
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1719,7 +1719,7 @@ define amdgpu_kernel void @idot4_acc32_3ele_permuted(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc020003
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1885,7 +1885,7 @@ define amdgpu_kernel void @idot4_acc32_opt(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v1, v0, s[4:5]
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, 0
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, 0 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -2092,7 +2092,7 @@ define amdgpu_kernel void @idot4_acc32_3src(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_or_b32_e32 v1, v1, v2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v2, 0
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -2299,7 +2299,7 @@ define amdgpu_kernel void @idot4_acc32_3src_3ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_or_b32_e32 v1, v1, v2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v2, 0
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -2504,7 +2504,7 @@ define amdgpu_kernel void @idot4_bad_source(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    v_mad_i32_i24 v2, v2, s2, s3
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v3, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -2695,7 +2695,7 @@ define amdgpu_kernel void @idot4_commutative(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc020100
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -2897,7 +2897,7 @@ define amdgpu_kernel void @idot4_acc32_3src_3ele_src0(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_or_b32_e32 v1, v1, v2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v2, 0
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s0 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -3133,7 +3133,7 @@ define amdgpu_kernel void @idot4_4src(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v1, 0
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_2)
-; GFX11-DL-NEXT:    ...

Comment on lines 69 to 70
// Tracks whether or not the byte is treated as a signed operand -- useful
// for arithmetic combines.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what this means. The signededness only means anything in the context of the instruction producing or reading it. Is this talking about the unrelated bytes that aren't being tracked?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a flag telling the client how to create an instruction which reads the contained SDValue.

For the clients that use ByteProvider for vectorized arithmetic reduction instructions, it is relevant. These instructions typically involve a widening of type, and the IsSigned flag determines whether to use Sext or Zext. This information may only be present in the tree itself, and not in the ultimate source or top-level instruction.

@jrbyrnes jrbyrnes force-pushed the SignedCBP branch 2 times, most recently from 9da765f to 3b1345f Compare September 11, 2023 22:20
@jrbyrnes
Copy link
Contributor Author

Sorry for the immediate ping, however it, as it stands, the dot4 combine is causing failures which are blocking integration with upstream. I will either need to upstream this quickly or revert the commit and patch it offline.

@dstutt
Copy link
Collaborator

dstutt commented Sep 12, 2023

This partially fixes the issue we were seeing, but still fails if one of the src's is signed and the other is unsigned.

}

default: {
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
if (auto L = dyn_cast<LoadSDNode>(Op))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems concerning, this could miss other load times (e.g. does this catch atomic sextload?). I'd expect this to be with all load handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced we should be combining atomics to begin with.

I've blacklisted atomics / memtrinsic until we have a better idea of how to handle the signedness of such. Primarily isMemIntrinsic will need to be handled as these should be supported ByteProviders.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you aren't combining atomics, you are just interpreting the result value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are right --

Even so, I propose to blacklist for now.

Change-Id: I3c005123652412b32418b93734d2257b8195c0a4
@jrbyrnes
Copy link
Contributor Author

Thanks for swift responses.

This partially fixes the issue we were seeing, but still fails if one of the src's is signed and the other is unsigned.

  %19 = extractelement <4 x i8> %11, i64 1
  %20 = sext i8 %19 to i16
  %21 = extractelement <4 x i8> %13, i64 1
  %22 = zext i8 %21 to i16
  %23 = mul nsw i16 %22, %20

I should say this code looks a bit strange, since we are mixing types (signed / unsigned) in the mul operation. In general, I would think the unsigned type "wins" as this is the spirit of the c++ standard, but, since that was causing failures, I have just disabled the combine in such cases.

@dstutt
Copy link
Collaborator

dstutt commented Sep 12, 2023

I should say this code looks a bit strange, since we are mixing types (signed / unsigned) in the mul operation. In general, I would think the unsigned type "wins" as this is the spirit of the c++ standard, but, since that was causing failures, I have just disabled the combine in such cases.

It's not C++ - it's a Vulkan operation https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSUDot that's used in the failing test (although I appreciate what you mean by the spirit of the c++ standard).

@dstutt
Copy link
Collaborator

dstutt commented Sep 12, 2023

Confirmed that the change to disable the combine for mixed signed has fixed the issue I was seeing.

@jayfoad
Copy link
Contributor

jayfoad commented Sep 12, 2023

  %19 = extractelement <4 x i8> %11, i64 1
  %20 = sext i8 %19 to i16
  %21 = extractelement <4 x i8> %13, i64 1
  %22 = zext i8 %21 to i16
  %23 = mul nsw i16 %22, %20

I should say this code looks a bit strange, since we are mixing types (signed / unsigned) in the mul operation. In general, I would think the unsigned type "wins" as this is the spirit of the c++ standard, but, since that was causing failures, I have just disabled the combine in such cases.

This is LLVM IR. It has well defined semantics that have nothing to do with "the spirit of the c++ standard". Whatever your DAG combine does, it needs to preserve the semantics.

@arsenm
Copy link
Contributor

arsenm commented Sep 12, 2023

Thanks for swift responses. In general, I would think the unsigned type "wins"

Signed or unsigned types do not exist. There are only operations which you may associate with signed or unsigned behavior

@jrbyrnes
Copy link
Contributor Author

Thanks for swift responses. In general, I would think the unsigned type "wins"

Signed or unsigned types do not exist. There are only operations which you may associate with signed or unsigned behavior

This is the intention of the IsSigned flag -- to indicate whether or not we can associate the newly created instruction with signed behavior.

Subsequent iterations will try harder to reason about atomics / intrinsics, but we cannot combine into v_dot4 when we have mixed (signed/unsigned) behavior. In such cases, to preserve semantics of IR, the combine fails.

@jrbyrnes
Copy link
Contributor Author

The problematic commit has been reverted, patch is now part of the reland review https://reviews.llvm.org/D155995

@jrbyrnes jrbyrnes closed this Sep 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants