
[AMDGPU] Rework dot4 signedness checks #68757

Merged: 1 commit into llvm:main on Nov 30, 2023

Conversation

jrbyrnes (Contributor)

Rely on AMDGPUISD::*_MUL for most cases -- as signedness semantics have already been calculated and encoded into these ops.

In cases where we have an ISD::MUL, try harder to reason about the semantics given the signedness info in the ByteProvider.

Solves some edge cases.
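
In other words, the check first dispatches on the mul opcode. A minimal sketch of that dispatch (illustrative names only, not the patch's code; the real logic is the reworked checkSignedness in the diff below):

```cpp
#include <optional>
#include "llvm/CodeGen/SelectionDAG.h"
#include "AMDGPUISelLowering.h" // in-tree AMDGPU target header for AMDGPUISD

using namespace llvm;

// Illustrative sketch (not the patch's code): the 24-bit mul nodes already
// carry the signedness decision made by an earlier combine, so it can simply
// be read back; a plain ISD::MUL is the only case that still needs the
// harder ByteProvider-based reasoning.
static std::optional<bool> signednessFromMulOpcode(const SDValue &Mul) {
  switch (Mul.getOpcode()) {
  case AMDGPUISD::MUL_I24:
    return true; // signed semantics -> candidate for signed dot4
  case AMDGPUISD::MUL_U24:
    return false; // unsigned semantics -> candidate for unsigned dot4
  default:
    return std::nullopt; // ISD::MUL: fall back to ByteProvider signedness info
  }
}
```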

llvmbot (Collaborator) commented Oct 11, 2023

@llvm/pr-subscribers-backend-amdgpu

Author: Jeffrey Byrnes (jrbyrnes)

Changes

Rely on AMDGPUISD::*_MUL for most cases -- as signedness semantics have already been calculated and encoded into these ops.

In cases where we have an ISD::MUL, try harder to reason about the semantics given the signedness info in the ByteProvider.

Solves some edge cases.


Patch is 23.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68757.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+50-25)
  • (modified) llvm/test/CodeGen/AMDGPU/idot4u.ll (+428)
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 9bd0f5390b19e31..f9cb443f3223f72 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -12952,30 +12952,51 @@ static bool isMul(const SDValue Op) {
 
 static std::optional<bool> checkSignedness(const SDValue &N,
                                            ByteProvider<SDValue> &Src0,
-                                           ByteProvider<SDValue> &Src1) {
+                                           ByteProvider<SDValue> &Src1,
+                                           const SDValue &S1Op,
+                                           const SDValue &S0Op) {
   auto MulOpcode = N.getOpcode();
-  std::optional<bool> IterIsSigned;
-  // Both sides of the tree must have the same signedness semantics.
-  if ((Src0.IsSigned != Src1.IsSigned) ||
-      (Src0.IsSigned.value_or(false) != Src1.IsSigned.value_or(false)))
-    return IterIsSigned;
-  // If we have a MUL_U24 op with signed semantics, then fail.
-  if (Src0.IsSigned.value_or(false) && MulOpcode == AMDGPUISD::MUL_U24)
-    return IterIsSigned;
-  // If we have a MUL_I24 op with unsigned semantics, then fail.
-  if (!Src0.IsSigned.value_or(true) && MulOpcode == AMDGPUISD::MUL_I24)
-    return IterIsSigned;
-
-  bool TopLevelSignedness =
-      MulOpcode == AMDGPUISD::MUL_I24 ||
-      (MulOpcode == ISD::MUL && N.getNode()->getFlags().hasNoSignedWrap() &&
-       !N.getNode()->getFlags().hasNoUnsignedWrap());
-
-  // In cases where we are accumulating into an i8 (for v_dot4), the
-  // ByteProvider will not have signedness info since the MSBs are dont-cares.
-  // In this case, we simply use the TopLevelSignedness of the instruction.
-  IterIsSigned = Src0.IsSigned.value_or(TopLevelSignedness);
-  return IterIsSigned;
+
+  // We have previously determined the signedness semantics
+  if (MulOpcode == AMDGPUISD::MUL_U24 || MulOpcode == AMDGPUISD::MUL_I24)
+    return MulOpcode == AMDGPUISD::MUL_I24;
+
+  // We don't know the signedness semantics, try harder to determine.
+
+  // Case1: both Srcs have signedness info, use it.
+  if (Src0.IsSigned && Src1.IsSigned)
+    return *Src0.IsSigned == *Src1.IsSigned
+               ? std::optional<bool>(*Src0.IsSigned)
+               : std::nullopt;
+
+  // Case2: we are missing signedness info from one.
+  // We can determine unsigned semantics if one has unsigned info and the other
+  // has leading zeros.
+  if (Src0.IsSigned != Src1.IsSigned) {
+    auto SrcIsSigned = Src0.IsSigned ? *Src0.IsSigned : *Src1.IsSigned;
+    // If we only have signed info, then fail
+    if (SrcIsSigned)
+      return std::nullopt;
+    bool Mismatch = false;
+    auto OtherOp = Src0.IsSigned ? S1Op : S0Op;
+    for (int I = 1; I < OtherOp.getValueSizeInBits() / 8; I++) {
+      auto BP = calculateByteProvider(OtherOp, I, 0);
+      if (!BP->isConstantZero()) {
+        Mismatch = true;
+        break;
+      }
+    }
+
+    // If we are missing unsigned semantics, but all the MSBs are zero, then it
+    // is okay.
+    return Mismatch ? std::nullopt : std::optional<bool>(false);
+  }
+
+  // Case3: neither src has signedness info.
+  // In this case, we either any_extended and will mask out the MSBs, or we
+  // truly don't care about the MSBs. Either way, using unsigned semantics is
+  // correct.
+  return false;
 }
 
 SDValue SITargetLowering::performAddCombine(SDNode *N,
@@ -13019,7 +13040,9 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
         break;
 
       auto IterIsSigned =
-          checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1);
+          checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1,
+                          TempNode->getOperand(MulIdx)->getOperand(0),
+                          TempNode->getOperand(MulIdx)->getOperand(1));
       if (!IterIsSigned)
         break;
       if (!IsSigned)
@@ -13041,7 +13064,9 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
         if (!Src1)
           break;
         auto IterIsSigned =
-            checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1);
+            checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1,
+                            TempNode->getOperand(AddIdx)->getOperand(0),
+                            TempNode->getOperand(AddIdx)->getOperand(1));
         if (!IterIsSigned)
           break;
         assert(IsSigned);
diff --git a/llvm/test/CodeGen/AMDGPU/idot4u.ll b/llvm/test/CodeGen/AMDGPU/idot4u.ll
index a82c5215f3b2c65..ba9ccb4c636f948 100644
--- a/llvm/test/CodeGen/AMDGPU/idot4u.ll
+++ b/llvm/test/CodeGen/AMDGPU/idot4u.ll
@@ -1735,6 +1735,268 @@ entry:
   ret void
 }
 
+
+define amdgpu_kernel void @notdot4_mixedtypes2(ptr addrspace(1) %src1,
+; GFX7-LABEL: notdot4_mixedtypes2:
+; GFX7:       ; %bb.0: ; %entry
+; GFX7-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x9
+; GFX7-NEXT:    s_load_dwordx2 s[0:1], s[0:1], 0xd
+; GFX7-NEXT:    s_mov_b32 s3, 0xf000
+; GFX7-NEXT:    s_mov_b32 s10, 0
+; GFX7-NEXT:    s_mov_b32 s11, s3
+; GFX7-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b64 s[8:9], s[4:5]
+; GFX7-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX7-NEXT:    v_mov_b32_e32 v1, 0
+; GFX7-NEXT:    buffer_load_dword v2, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT:    s_mov_b64 s[8:9], s[6:7]
+; GFX7-NEXT:    buffer_load_dword v0, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT:    s_mov_b32 s2, -1
+; GFX7-NEXT:    buffer_load_ushort v1, off, s[0:3], 0
+; GFX7-NEXT:    s_waitcnt vmcnt(2)
+; GFX7-NEXT:    v_bfe_i32 v3, v2, 0, 8
+; GFX7-NEXT:    v_bfe_u32 v4, v2, 8, 8
+; GFX7-NEXT:    s_waitcnt vmcnt(1)
+; GFX7-NEXT:    v_bfe_i32 v7, v0, 8, 8
+; GFX7-NEXT:    v_and_b32_e32 v7, 0xffff, v7
+; GFX7-NEXT:    v_bfe_i32 v5, v2, 16, 8
+; GFX7-NEXT:    v_and_b32_e32 v3, 0xffff, v3
+; GFX7-NEXT:    v_and_b32_e32 v6, 0xff, v0
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_mad_u32_u24 v1, v4, v7, v1
+; GFX7-NEXT:    v_and_b32_e32 v5, 0xffff, v5
+; GFX7-NEXT:    v_bfe_u32 v8, v0, 16, 8
+; GFX7-NEXT:    v_ashrrev_i32_e32 v0, 24, v0
+; GFX7-NEXT:    v_mad_u32_u24 v1, v3, v6, v1
+; GFX7-NEXT:    v_lshrrev_b32_e32 v2, 24, v2
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_mad_u32_u24 v1, v5, v8, v1
+; GFX7-NEXT:    v_mad_u32_u24 v0, v2, v0, v1
+; GFX7-NEXT:    buffer_store_short v0, off, s[0:3], 0
+; GFX7-NEXT:    s_endpgm
+;
+; GFX8-LABEL: notdot4_mixedtypes2:
+; GFX8:       ; %bb.0: ; %entry
+; GFX8-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX8-NEXT:    s_load_dwordx2 s[0:1], s[0:1], 0x34
+; GFX8-NEXT:    v_lshlrev_b32_e32 v2, 2, v0
+; GFX8-NEXT:    v_mov_b32_e32 v5, 0xff
+; GFX8-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX8-NEXT:    v_mov_b32_e32 v1, s5
+; GFX8-NEXT:    v_add_u32_e32 v0, vcc, s4, v2
+; GFX8-NEXT:    v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT:    flat_load_dword v3, v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v1, s7
+; GFX8-NEXT:    v_add_u32_e32 v0, vcc, s6, v2
+; GFX8-NEXT:    v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT:    flat_load_dword v2, v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v0, s0
+; GFX8-NEXT:    v_mov_b32_e32 v1, s1
+; GFX8-NEXT:    flat_load_ushort v4, v[0:1]
+; GFX8-NEXT:    s_waitcnt vmcnt(2)
+; GFX8-NEXT:    v_lshrrev_b32_e32 v9, 8, v3
+; GFX8-NEXT:    v_and_b32_e32 v9, 0xff, v9
+; GFX8-NEXT:    v_lshrrev_b32_e32 v6, 16, v3
+; GFX8-NEXT:    v_bfe_i32 v7, v3, 0, 8
+; GFX8-NEXT:    v_bfe_i32 v6, v6, 0, 8
+; GFX8-NEXT:    v_lshrrev_b32_e32 v3, 24, v3
+; GFX8-NEXT:    s_waitcnt vmcnt(1)
+; GFX8-NEXT:    v_lshrrev_b32_e32 v10, 8, v2
+; GFX8-NEXT:    v_bfe_i32 v10, v10, 0, 8
+; GFX8-NEXT:    v_and_b32_e32 v8, 0xff, v2
+; GFX8-NEXT:    s_waitcnt vmcnt(0)
+; GFX8-NEXT:    v_mad_u16 v4, v9, v10, v4
+; GFX8-NEXT:    v_and_b32_sdwa v5, v2, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX8-NEXT:    v_lshrrev_b32_e32 v2, 24, v2
+; GFX8-NEXT:    v_mad_u16 v4, v7, v8, v4
+; GFX8-NEXT:    v_bfe_i32 v2, v2, 0, 8
+; GFX8-NEXT:    v_mad_u16 v4, v6, v5, v4
+; GFX8-NEXT:    v_mad_u16 v2, v3, v2, v4
+; GFX8-NEXT:    flat_store_short v[0:1], v2
+; GFX8-NEXT:    s_endpgm
+;
+; GFX9-NODL-LABEL: notdot4_mixedtypes2:
+; GFX9-NODL:       ; %bb.0: ; %entry
+; GFX9-NODL-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-NODL-NEXT:    s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-NODL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-NODL-NEXT:    s_movk_i32 s0, 0xff
+; GFX9-NODL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT:    global_load_dword v1, v0, s[4:5]
+; GFX9-NODL-NEXT:    global_load_dword v2, v0, s[6:7]
+; GFX9-NODL-NEXT:    v_mov_b32_e32 v0, 0
+; GFX9-NODL-NEXT:    global_load_ushort v3, v0, s[2:3]
+; GFX9-NODL-NEXT:    s_waitcnt vmcnt(2)
+; GFX9-NODL-NEXT:    v_lshrrev_b32_e32 v7, 8, v1
+; GFX9-NODL-NEXT:    s_waitcnt vmcnt(1)
+; GFX9-NODL-NEXT:    v_lshrrev_b32_e32 v8, 8, v2
+; GFX9-NODL-NEXT:    v_and_b32_e32 v7, 0xff, v7
+; GFX9-NODL-NEXT:    v_bfe_i32 v8, v8, 0, 8
+; GFX9-NODL-NEXT:    v_lshrrev_b32_e32 v4, 16, v1
+; GFX9-NODL-NEXT:    v_bfe_i32 v5, v1, 0, 8
+; GFX9-NODL-NEXT:    v_and_b32_e32 v6, 0xff, v2
+; GFX9-NODL-NEXT:    s_waitcnt vmcnt(0)
+; GFX9-NODL-NEXT:    v_mad_legacy_u16 v3, v7, v8, v3
+; GFX9-NODL-NEXT:    v_and_b32_sdwa v9, v2, s0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX9-NODL-NEXT:    v_lshrrev_b32_e32 v2, 24, v2
+; GFX9-NODL-NEXT:    v_bfe_i32 v4, v4, 0, 8
+; GFX9-NODL-NEXT:    v_mad_legacy_u16 v3, v5, v6, v3
+; GFX9-NODL-NEXT:    v_lshrrev_b32_e32 v1, 24, v1
+; GFX9-NODL-NEXT:    v_bfe_i32 v2, v2, 0, 8
+; GFX9-NODL-NEXT:    v_mad_legacy_u16 v3, v4, v9, v3
+; GFX9-NODL-NEXT:    v_mad_legacy_u16 v1, v1, v2, v3
+; GFX9-NODL-NEXT:    global_store_short v0, v1, s[2:3]
+; GFX9-NODL-NEXT:    s_endpgm
+;
+; GFX9-DL-LABEL: notdot4_mixedtypes2:
+; GFX9-DL:       ; %bb.0: ; %entry
+; GFX9-DL-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-DL-NEXT:    s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-DL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-DL-NEXT:    s_movk_i32 s0, 0xff
+; GFX9-DL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT:    global_load_dword v1, v0, s[4:5]
+; GFX9-DL-NEXT:    global_load_dword v2, v0, s[6:7]
+; GFX9-DL-NEXT:    v_mov_b32_e32 v0, 0
+; GFX9-DL-NEXT:    global_load_ushort v3, v0, s[2:3]
+; GFX9-DL-NEXT:    s_waitcnt vmcnt(2)
+; GFX9-DL-NEXT:    v_lshrrev_b32_e32 v7, 8, v1
+; GFX9-DL-NEXT:    s_waitcnt vmcnt(1)
+; GFX9-DL-NEXT:    v_lshrrev_b32_e32 v8, 8, v2
+; GFX9-DL-NEXT:    v_and_b32_e32 v7, 0xff, v7
+; GFX9-DL-NEXT:    v_bfe_i32 v8, v8, 0, 8
+; GFX9-DL-NEXT:    v_lshrrev_b32_e32 v4, 16, v1
+; GFX9-DL-NEXT:    v_bfe_i32 v5, v1, 0, 8
+; GFX9-DL-NEXT:    v_and_b32_e32 v6, 0xff, v2
+; GFX9-DL-NEXT:    s_waitcnt vmcnt(0)
+; GFX9-DL-NEXT:    v_mad_legacy_u16 v3, v7, v8, v3
+; GFX9-DL-NEXT:    v_and_b32_sdwa v9, v2, s0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX9-DL-NEXT:    v_lshrrev_b32_e32 v2, 24, v2
+; GFX9-DL-NEXT:    v_bfe_i32 v4, v4, 0, 8
+; GFX9-DL-NEXT:    v_mad_legacy_u16 v3, v5, v6, v3
+; GFX9-DL-NEXT:    v_lshrrev_b32_e32 v1, 24, v1
+; GFX9-DL-NEXT:    v_bfe_i32 v2, v2, 0, 8
+; GFX9-DL-NEXT:    v_mad_legacy_u16 v3, v4, v9, v3
+; GFX9-DL-NEXT:    v_mad_legacy_u16 v1, v1, v2, v3
+; GFX9-DL-NEXT:    global_store_short v0, v1, s[2:3]
+; GFX9-DL-NEXT:    s_endpgm
+;
+; GFX10-DL-LABEL: notdot4_mixedtypes2:
+; GFX10-DL:       ; %bb.0: ; %entry
+; GFX10-DL-NEXT:    s_clause 0x1
+; GFX10-DL-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX10-DL-NEXT:    s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX10-DL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX10-DL-NEXT:    v_mov_b32_e32 v8, 0xff
+; GFX10-DL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX10-DL-NEXT:    s_clause 0x1
+; GFX10-DL-NEXT:    global_load_dword v1, v0, s[4:5]
+; GFX10-DL-NEXT:    global_load_dword v2, v0, s[6:7]
+; GFX10-DL-NEXT:    v_mov_b32_e32 v0, 0
+; GFX10-DL-NEXT:    global_load_ushort v3, v0, s[2:3]
+; GFX10-DL-NEXT:    s_waitcnt vmcnt(2)
+; GFX10-DL-NEXT:    v_lshrrev_b32_e32 v4, 8, v1
+; GFX10-DL-NEXT:    s_waitcnt vmcnt(1)
+; GFX10-DL-NEXT:    v_lshrrev_b32_e32 v5, 8, v2
+; GFX10-DL-NEXT:    v_lshrrev_b32_e32 v6, 16, v1
+; GFX10-DL-NEXT:    v_bfe_i32 v7, v1, 0, 8
+; GFX10-DL-NEXT:    v_and_b32_e32 v9, 0xff, v2
+; GFX10-DL-NEXT:    v_and_b32_e32 v4, 0xff, v4
+; GFX10-DL-NEXT:    v_bfe_i32 v5, v5, 0, 8
+; GFX10-DL-NEXT:    v_lshrrev_b32_e32 v1, 24, v1
+; GFX10-DL-NEXT:    s_waitcnt vmcnt(0)
+; GFX10-DL-NEXT:    v_mad_u16 v3, v4, v5, v3
+; GFX10-DL-NEXT:    v_bfe_i32 v4, v6, 0, 8
+; GFX10-DL-NEXT:    v_and_b32_sdwa v5, v2, v8 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX10-DL-NEXT:    v_lshrrev_b32_e32 v2, 24, v2
+; GFX10-DL-NEXT:    v_mad_u16 v3, v7, v9, v3
+; GFX10-DL-NEXT:    v_bfe_i32 v2, v2, 0, 8
+; GFX10-DL-NEXT:    v_mad_u16 v3, v4, v5, v3
+; GFX10-DL-NEXT:    v_mad_u16 v1, v1, v2, v3
+; GFX10-DL-NEXT:    global_store_short v0, v1, s[2:3]
+; GFX10-DL-NEXT:    s_endpgm
+;
+; GFX11-DL-LABEL: notdot4_mixedtypes2:
+; GFX11-DL:       ; %bb.0: ; %entry
+; GFX11-DL-NEXT:    s_clause 0x1
+; GFX11-DL-NEXT:    s_load_b128 s[4:7], s[0:1], 0x24
+; GFX11-DL-NEXT:    s_load_b64 s[0:1], s[0:1], 0x34
+; GFX11-DL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX11-DL-NEXT:    s_clause 0x1
+; GFX11-DL-NEXT:    global_load_b32 v1, v0, s[4:5]
+; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
+; GFX11-DL-NEXT:    v_mov_b32_e32 v2, 0
+; GFX11-DL-NEXT:    s_waitcnt vmcnt(1)
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v4, 8, v1
+; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
+; GFX11-DL-NEXT:    v_and_b32_e32 v9, 0xff, v0
+; GFX11-DL-NEXT:    global_load_u16 v3, v2, s[0:1]
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v5, 8, v0
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v6, 16, v1
+; GFX11-DL-NEXT:    v_and_b32_e32 v4, 0xff, v4
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v7, 16, v0
+; GFX11-DL-NEXT:    v_bfe_i32 v8, v1, 0, 8
+; GFX11-DL-NEXT:    v_bfe_i32 v5, v5, 0, 8
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v0, 24, v0
+; GFX11-DL-NEXT:    v_lshrrev_b32_e32 v1, 24, v1
+; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_4) | instid1(VALU_DEP_3)
+; GFX11-DL-NEXT:    v_bfe_i32 v0, v0, 0, 8
+; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
+; GFX11-DL-NEXT:    v_mad_u16 v3, v4, v5, v3
+; GFX11-DL-NEXT:    v_bfe_i32 v4, v6, 0, 8
+; GFX11-DL-NEXT:    v_and_b32_e32 v5, 0xff, v7
+; GFX11-DL-NEXT:    v_mad_u16 v3, v8, v9, v3
+; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-DL-NEXT:    v_mad_u16 v3, v4, v5, v3
+; GFX11-DL-NEXT:    v_mad_u16 v0, v1, v0, v3
+; GFX11-DL-NEXT:    global_store_b16 v2, v0, s[0:1]
+; GFX11-DL-NEXT:    s_nop 0
+; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; GFX11-DL-NEXT:    s_endpgm
+                                              ptr addrspace(1) %src2,
+                                              ptr addrspace(1) nocapture %dst) {
+entry:
+  %idx = call i32 @llvm.amdgcn.workitem.id.x()
+  %gep1 = getelementptr <4 x i8>, ptr addrspace(1) %src1, i32 %idx
+  %vec1 = load <4 x i8>, ptr addrspace(1) %gep1
+  %gep2 = getelementptr <4 x i8>, ptr addrspace(1) %src2, i32 %idx
+  %vec2 = load <4 x i8>, ptr addrspace(1) %gep2
+
+  %v1e0 = extractelement <4 x i8> %vec1, i64 0
+  %cv1e0 = sext i8 %v1e0 to i16
+  %v2e0 = extractelement <4 x i8> %vec2, i64 0
+  %cv2e0 = zext i8 %v2e0 to i16
+  %mul1 = mul nuw nsw i16 %cv1e0, %cv2e0
+
+  %v1e1 = extractelement <4 x i8> %vec1, i64 1
+  %cv1e1 = zext i8 %v1e1 to i16
+  %v2e1 = extractelement <4 x i8> %vec2, i64 1
+  %cv2e1 = sext i8 %v2e1 to i16
+  %mul2 = mul nuw nsw i16 %cv1e1, %cv2e1
+
+  %v1e2 = extractelement <4 x i8> %vec1, i64 2
+  %cv1e2 = sext i8 %v1e2 to i16
+  %v2e2 = extractelement <4 x i8> %vec2, i64 2
+  %cv2e2 = zext i8 %v2e2 to i16
+  %mul3 = mul nuw nsw i16 %cv1e2, %cv2e2
+
+  %v1e3 = extractelement <4 x i8> %vec1, i64 3
+  %cv1e3 = zext i8 %v1e3 to i16
+  %v2e3 = extractelement <4 x i8> %vec2, i64 3
+  %cv2e3 = sext i8 %v2e3 to i16
+  %mul4 = mul nuw nsw i16 %cv1e3, %cv2e3
+
+  %acc = load i16, ptr addrspace(1) %dst, align 2
+  %add1 = add i16 %mul2, %acc
+  %add2 = add i16 %add1, %mul1
+  %add3 = add i16 %add2, %mul3
+  %add4 = add i16 %add3, %mul4
+
+  store i16 %add4, ptr addrspace(1) %dst, align 2
+  ret void
+}
+
 ; TODO: cleanup s_lshr_b32
 define amdgpu_kernel void @udot4_acc32_vecMul(ptr addrspace(1) %src1,
 ; GFX7-LABEL: udot4_acc32_vecMul:
@@ -4622,4 +4884,170 @@ entry:
   ret void
 }
 
+define amdgpu_kernel void @idot4_acc32_anyext(ptr addrspace(1) %src1,
+; GFX7-LABEL: idot4_acc32_anyext:
+; GFX7:       ; %bb.0: ; %entry
+; GFX7-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x9
+; GFX7-NEXT:    s_load_dwordx2 s[0:1], s[0:1], 0xd
+; GFX7-NEXT:    s_mov_b32 s3, 0xf000
+; GFX7-NEXT:    s_mov_b32 s10, 0
+; GFX7-NEXT:    s_mov_b32 s11, s3
+; GFX7-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b64 s[8:9], s[4:5]
+; GFX7-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX7-NEXT:    v_mov_b32_e32 v1, 0
+; GFX7-NEXT:    buffer_load_dword v2, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT:    s_mov_b64 s[8:9], s[6:7]
+; GFX7-NEXT:    buffer_load_dword v0, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT:    s_load_dword s4, s[0:1], 0x0
+; GFX7-NEXT:    s_mov_b32 s2, -1
+; GFX7-NEXT:    s_waitcnt vmcnt(1)
+; GFX7-NEXT:    v_and_b32_e32 v1, 0xff, v2
+; GFX7-NEXT:    v_bfe_u32 v2, v2, 8, 8
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_bfe_u32 v0, v0, 8, 8
+; GFX7-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX7-NEXT:    v_mad_u32_u24 v1, v1, v1, s4
+; GFX7-NEXT:    v_mad_u32_u24 v0, v2, v0, v1
+; GFX7-NEXT:    buffer_store_dword v0, off, s[0:3], 0
+; GFX7-NEXT:    s_endpgm
+;
+; GFX8-LABEL: idot4_acc32_anyext:
+; GFX8:       ; %bb.0: ; %entry
+; GFX8-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX8-NEXT:    s_load_dwordx2 s[0:1], s[0:1], 0x34
+; GFX8-NEXT:    v_lshlrev_b32_e32 v2, 2, v0
+; GFX8-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX8-NEXT:    v_mov_b32_e32 v1, s5
+; GFX8-NEXT:    v_add_u32_e32 v0, vcc, s4, v2
+; GFX8-NEXT:    v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT:    flat_load_dword v3, v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v1, s7
+; GFX8-NEXT:    v_add_u32_e32 v0, vcc, s6, v2
+; GFX8-NEXT:    v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT:    flat_load_dword v0, v[0:1]
+; GFX8-NEXT:    s_load_dword s2, s[0:1], 0x0
+; GFX8-NEXT:    s_waitcnt vmcnt(1)
+; GFX8-NEXT:    v_and_b32_e32 v1, 0xff, v3
+; GFX8-NEXT:    v_bfe_u32 v2, v3, 8, 8
+; GFX8-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX8-NEXT:    v_mad_u32_u24 v1, v1, v1, s2
+; GFX8-NEXT:    s_waitcnt vmcnt(0)
+; GFX8-NEXT:    v_bfe_u32 v0, v0, 8, 8
+; GFX8-NEXT:    v_mad_u32_u24 v2, v2, v0, v1
+; GFX8-NEXT:    v_mov_b32_e32 v0, s0
+; GFX8-NEXT:    v_mov_b32_e32 v1, s1
+; GFX8-NEXT:    flat_store_dword v[0:1], v2
+; GFX8-NEXT:    s_endpgm
+;
+; GFX9-NODL-LABEL: idot4_acc32_anyext:
+; GFX9-NODL:       ; %bb.0: ; %entry
+; GFX9-NODL-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-NODL-NEXT:    s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-NODL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-NODL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT:    global_load_dword v1, v0, s[4:5]
+; GFX9-NODL-NEXT:    global_load_dword v2, v0, s[6:7]
+; GFX9-NODL-NEXT:    s_load_dword s0, s[2:3], 0x0
+; GFX9-NODL-NEXT:    v_mov_b32_e32 v0, 0
+; GFX9-NODL-NEXT:    s_waitcnt vmcnt(1)
+; GFX9-NODL-NEXT:    v_mul_u32_u24_sdwa v3, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
+; GFX9-NODL-NEXT:    s_waitcnt vmcnt(0)
+; GFX9-NODL-NEXT:    v_mul_u32_u24_sdwa v1, v1, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1
+; GFX9-NODL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT:    v_add3_u32 v1, v3, s0, v1
+; GFX9-NODL-NEXT:    global_store_dword v0, v1, s[2:3]
+; GFX9-NODL-NEXT:    s_endpgm
+;
+; GFX9-DL-LABEL: idot4_acc32_anyext:
+; GFX9-DL:       ; %bb.0: ; %entry
+; GFX9-DL-NEXT:    s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-DL-NEXT:    s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-DL-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-DL-NEXT:    s_mov_b32 s1, 0xc0c0500
+; GFX9-DL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT:    global_load_dword v1, v0, s[4:5]
+; GFX9-DL-NEXT:    global_load_dword v2, v0, s[6:7]
+; GFX9-DL-NEXT:    s_load_dword s0, s[2:3], 0x0
+; GFX9-DL-NEXT:    s_mov_b32 s4, 0xc0c0100
+; GFX9-DL-NEXT:    v_mov_b32_e32 v0, 0
+; GFX9-DL-NEXT:    s_waitcnt vmcnt(0)
+; GFX9-DL-NEXT:    v_perm_b32 v2, v2, v1, s1
+; GFX9-DL-NEXT:    v_perm_b32 v1, v1, v1, s4
+; GFX9-DL-NEXT:    s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT:    v_...
[truncated]

jrbyrnes (Contributor, Author) commented Oct 27, 2023

Passes PSDB

github-actions bot commented Nov 7, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

jrbyrnes (Contributor, Author) commented Nov 7, 2023

This is a more complete rework of the signedness-semantics matching that removes ByteProvider.IsSigned.

It uses computeKnownBits, as suggested by @arsenm.

Special case when we don't know the sign bit for one or both ops: given that we have valid ByteProviders for each op in the candidate mul, the upper bits must be extension bits. Thus, if we don't know the sign bit, one of two things must have occurred: 1. a sign extend from an unknown bit, or 2. an any extend. In either case, we can use the sign-extend version of dot4 for this op. So, if the other op is also unknown, or is known to have its sign bit set (which implies it must have been sign extended), we can use the signed version of dot.

This is a draft while I test the new implementation.
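
A minimal sketch of that reasoning, assuming the per-operand sign bits are queried with computeKnownBits (the function name and structure are illustrative, not the patch's actual code):

```cpp
#include <optional>
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

// Illustrative sketch of the special-case reasoning described above.
static std::optional<bool> chooseDotSignedness(SelectionDAG &DAG, SDValue S0,
                                               SDValue S1) {
  KnownBits K0 = DAG.computeKnownBits(S0);
  KnownBits K1 = DAG.computeKnownBits(S1);

  // Sign bit known zero on both sides: both bytes were zero extended, so the
  // unsigned dot4 is safe.
  if (K0.isNonNegative() && K1.isNonNegative())
    return false;

  // Because all upper bits are extension bits, a sign bit that is unknown or
  // known set can only come from a sign extend or an any extend, and the
  // signed dot4 covers both. If that holds for both operands, go signed.
  if (!K0.isNonNegative() && !K1.isNonNegative())
    return true;

  // Mixed case: one operand is provably unsigned, the other possibly signed,
  // so no single signedness works and the combine must bail.
  return std::nullopt;
}
```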

jrbyrnes marked this pull request as ready for review November 8, 2023 18:16
jrbyrnes (Contributor, Author) commented Nov 8, 2023

It passes PSDB and has good CK dot4 lowering and compile time.

jrbyrnes (Contributor, Author)

Ping --

A reviewer (Contributor) commented on llvm/lib/Target/AMDGPU/SIISelLowering.cpp, at the reworked checkSignedness signature:

Can you rename this function to show it's for multiplies? As far as I can tell this is just performing sign bit logic on a multiply. As in, you're effectively reproducing what computeKnownBits is doing for a multiply. Is it possible to just check the result of computeKnownBits on the top multiply?

jrbyrnes (Contributor, Author) replied on Nov 15, 2023:

computeKnownBits on the top multiply does not provide sufficient information to determine the dot lowering behavior of the combine for all permutations.

For example, existing tests produce the following data:

| S0SignBit | S1SignBit | Multiply KnownSignBit | Should Use Dot | Signed Dot |
|-----------|-----------|-----------------------|----------------|------------|
| 0         | ?         | ?                     | 0              | X          |
| ?         | 0         | ?                     | 0              | X          |
| ?         | ?         | ?                     | 1              | 1          |

The S0SignBit and S1SignBit are sufficient to determine the behavior of the combine in all cases, and we see that different expected combine behaviors map to the same sign bit in the KnownBits of the top multiply. The logic of checkSignedness deviates from computeKnownBits: in checkSignedness, we know certain properties of the operands (all but the least significant 8 bits are extension bits), so we are able to do more special-case reasoning.
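
A small sketch of that distinction, modeling the two kinds of extension directly with KnownBits (hypothetical values, not code from the patch):

```cpp
#include "llvm/Support/KnownBits.h"

using namespace llvm;

// Model a zero-extended byte (upper 24 bits known zero) and a sign-extended
// byte with an unknown sign bit (nothing known) as 32-bit KnownBits values.
static void productSignBitIsNotEnough() {
  KnownBits ZextByte(32);
  ZextByte.Zero.setBitsFrom(8); // bits 8..31 known zero -> sign bit known 0
  KnownBits SextByte(32);       // nothing known -> sign bit unknown

  // Row 1 of the table: zext * sext -> product sign bit unknown.
  KnownBits Row1 = KnownBits::mul(ZextByte, SextByte);
  // Row 3 of the table: sext * sext -> product sign bit also unknown.
  KnownBits Row3 = KnownBits::mul(SextByte, SextByte);

  // Both products report an unknown sign bit, yet only row 3 should form a
  // (signed) dot4 -- so the per-operand sign bits, not the product's, carry
  // the information the combine needs.
  (void)Row1;
  (void)Row3;
}
```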

jrbyrnes (Contributor, Author)

Ping

jrbyrnes (Contributor, Author) commented Nov 30, 2023

No-change rebase push to retrigger the code formatting checks.

Change-Id: I97ace66cd6d60eb7f5892d762af8912374ac5a8d
jrbyrnes merged commit 1b02f59 into llvm:main on Nov 30, 2023
2 of 3 checks passed
jrbyrnes (Contributor, Author)

Thanks for the review @arsenm
