[AMDGPU] Rework dot4 signedness checks #68757
Conversation
@llvm/pr-subscribers-backend-amdgpu

Author: Jeffrey Byrnes (jrbyrnes)

Changes: Rely on AMDGPUISD::*_MUL for most cases, as signedness semantics have already been calculated and encoded into these ops. In cases where we have an ISD::MUL, try harder to reason about the semantics given the signedness info in the ByteProvider. Solves some edge cases.

Patch is 23.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68757.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 9bd0f5390b19e31..f9cb443f3223f72 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -12952,30 +12952,51 @@ static bool isMul(const SDValue Op) {
static std::optional<bool> checkSignedness(const SDValue &N,
ByteProvider<SDValue> &Src0,
- ByteProvider<SDValue> &Src1) {
+ ByteProvider<SDValue> &Src1,
+ const SDValue &S1Op,
+ const SDValue &S0Op) {
auto MulOpcode = N.getOpcode();
- std::optional<bool> IterIsSigned;
- // Both sides of the tree must have the same signedness semantics.
- if ((Src0.IsSigned != Src1.IsSigned) ||
- (Src0.IsSigned.value_or(false) != Src1.IsSigned.value_or(false)))
- return IterIsSigned;
- // If we have a MUL_U24 op with signed semantics, then fail.
- if (Src0.IsSigned.value_or(false) && MulOpcode == AMDGPUISD::MUL_U24)
- return IterIsSigned;
- // If we have a MUL_I24 op with unsigned semantics, then fail.
- if (!Src0.IsSigned.value_or(true) && MulOpcode == AMDGPUISD::MUL_I24)
- return IterIsSigned;
-
- bool TopLevelSignedness =
- MulOpcode == AMDGPUISD::MUL_I24 ||
- (MulOpcode == ISD::MUL && N.getNode()->getFlags().hasNoSignedWrap() &&
- !N.getNode()->getFlags().hasNoUnsignedWrap());
-
- // In cases where we are accumulating into an i8 (for v_dot4), the
- // ByteProvider will not have signedness info since the MSBs are dont-cares.
- // In this case, we simply use the TopLevelSignedness of the instruction.
- IterIsSigned = Src0.IsSigned.value_or(TopLevelSignedness);
- return IterIsSigned;
+
+ // We have previously determined the signedness semantics
+ if (MulOpcode == AMDGPUISD::MUL_U24 || MulOpcode == AMDGPUISD::MUL_I24)
+ return MulOpcode == AMDGPUISD::MUL_I24;
+
+ // We don't know the signedness semantics, try harder to determine.
+
+ // Case1: both Srcs have signedness info, use it.
+ if (Src0.IsSigned && Src1.IsSigned)
+ return *Src0.IsSigned == *Src1.IsSigned
+ ? std::optional<bool>(*Src0.IsSigned)
+ : std::nullopt;
+
+ // Case2: we are missing signedness info from one.
+ // We can determine unsigned semantics if one has unsigned info and the other
+ // has leading zeros.
+ if (Src0.IsSigned != Src1.IsSigned) {
+ auto SrcIsSigned = Src0.IsSigned ? *Src0.IsSigned : *Src1.IsSigned;
+ // If we only have signed info, then fail
+ if (SrcIsSigned)
+ return std::nullopt;
+ bool Mismatch = false;
+ auto OtherOp = Src0.IsSigned ? S1Op : S0Op;
+ for (int I = 1; I < OtherOp.getValueSizeInBits() / 8; I++) {
+ auto BP = calculateByteProvider(OtherOp, I, 0);
+ if (!BP->isConstantZero()) {
+ Mismatch = true;
+ break;
+ }
+ }
+
+ // If we are missing unsigned semantics, but all the MSBs are zero, then it
+ // is okay.
+ return Mismatch ? std::nullopt : std::optional<bool>(false);
+ }
+
+ // Case3: neither src has signedness info.
+ // In this case, we either any_extended and will mask out the MSBs, or we
+ // truly don't care about the MSBs. Either way, using unsigned semantics is
+ // correct.
+ return false;
}
SDValue SITargetLowering::performAddCombine(SDNode *N,
@@ -13019,7 +13040,9 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
break;
auto IterIsSigned =
- checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1);
+ checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1,
+ TempNode->getOperand(MulIdx)->getOperand(0),
+ TempNode->getOperand(MulIdx)->getOperand(1));
if (!IterIsSigned)
break;
if (!IsSigned)
@@ -13041,7 +13064,9 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
if (!Src1)
break;
auto IterIsSigned =
- checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1);
+ checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1,
+ TempNode->getOperand(AddIdx)->getOperand(0),
+ TempNode->getOperand(AddIdx)->getOperand(1));
if (!IterIsSigned)
break;
assert(IsSigned);
diff --git a/llvm/test/CodeGen/AMDGPU/idot4u.ll b/llvm/test/CodeGen/AMDGPU/idot4u.ll
index a82c5215f3b2c65..ba9ccb4c636f948 100644
--- a/llvm/test/CodeGen/AMDGPU/idot4u.ll
+++ b/llvm/test/CodeGen/AMDGPU/idot4u.ll
@@ -1735,6 +1735,268 @@ entry:
ret void
}
+
+define amdgpu_kernel void @notdot4_mixedtypes2(ptr addrspace(1) %src1,
+; GFX7-LABEL: notdot4_mixedtypes2:
+; GFX7: ; %bb.0: ; %entry
+; GFX7-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x9
+; GFX7-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0xd
+; GFX7-NEXT: s_mov_b32 s3, 0xf000
+; GFX7-NEXT: s_mov_b32 s10, 0
+; GFX7-NEXT: s_mov_b32 s11, s3
+; GFX7-NEXT: s_waitcnt lgkmcnt(0)
+; GFX7-NEXT: s_mov_b64 s[8:9], s[4:5]
+; GFX7-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX7-NEXT: v_mov_b32_e32 v1, 0
+; GFX7-NEXT: buffer_load_dword v2, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT: s_mov_b64 s[8:9], s[6:7]
+; GFX7-NEXT: buffer_load_dword v0, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT: s_mov_b32 s2, -1
+; GFX7-NEXT: buffer_load_ushort v1, off, s[0:3], 0
+; GFX7-NEXT: s_waitcnt vmcnt(2)
+; GFX7-NEXT: v_bfe_i32 v3, v2, 0, 8
+; GFX7-NEXT: v_bfe_u32 v4, v2, 8, 8
+; GFX7-NEXT: s_waitcnt vmcnt(1)
+; GFX7-NEXT: v_bfe_i32 v7, v0, 8, 8
+; GFX7-NEXT: v_and_b32_e32 v7, 0xffff, v7
+; GFX7-NEXT: v_bfe_i32 v5, v2, 16, 8
+; GFX7-NEXT: v_and_b32_e32 v3, 0xffff, v3
+; GFX7-NEXT: v_and_b32_e32 v6, 0xff, v0
+; GFX7-NEXT: s_waitcnt vmcnt(0)
+; GFX7-NEXT: v_mad_u32_u24 v1, v4, v7, v1
+; GFX7-NEXT: v_and_b32_e32 v5, 0xffff, v5
+; GFX7-NEXT: v_bfe_u32 v8, v0, 16, 8
+; GFX7-NEXT: v_ashrrev_i32_e32 v0, 24, v0
+; GFX7-NEXT: v_mad_u32_u24 v1, v3, v6, v1
+; GFX7-NEXT: v_lshrrev_b32_e32 v2, 24, v2
+; GFX7-NEXT: v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT: v_mad_u32_u24 v1, v5, v8, v1
+; GFX7-NEXT: v_mad_u32_u24 v0, v2, v0, v1
+; GFX7-NEXT: buffer_store_short v0, off, s[0:3], 0
+; GFX7-NEXT: s_endpgm
+;
+; GFX8-LABEL: notdot4_mixedtypes2:
+; GFX8: ; %bb.0: ; %entry
+; GFX8-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX8-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x34
+; GFX8-NEXT: v_lshlrev_b32_e32 v2, 2, v0
+; GFX8-NEXT: v_mov_b32_e32 v5, 0xff
+; GFX8-NEXT: s_waitcnt lgkmcnt(0)
+; GFX8-NEXT: v_mov_b32_e32 v1, s5
+; GFX8-NEXT: v_add_u32_e32 v0, vcc, s4, v2
+; GFX8-NEXT: v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT: flat_load_dword v3, v[0:1]
+; GFX8-NEXT: v_mov_b32_e32 v1, s7
+; GFX8-NEXT: v_add_u32_e32 v0, vcc, s6, v2
+; GFX8-NEXT: v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT: flat_load_dword v2, v[0:1]
+; GFX8-NEXT: v_mov_b32_e32 v0, s0
+; GFX8-NEXT: v_mov_b32_e32 v1, s1
+; GFX8-NEXT: flat_load_ushort v4, v[0:1]
+; GFX8-NEXT: s_waitcnt vmcnt(2)
+; GFX8-NEXT: v_lshrrev_b32_e32 v9, 8, v3
+; GFX8-NEXT: v_and_b32_e32 v9, 0xff, v9
+; GFX8-NEXT: v_lshrrev_b32_e32 v6, 16, v3
+; GFX8-NEXT: v_bfe_i32 v7, v3, 0, 8
+; GFX8-NEXT: v_bfe_i32 v6, v6, 0, 8
+; GFX8-NEXT: v_lshrrev_b32_e32 v3, 24, v3
+; GFX8-NEXT: s_waitcnt vmcnt(1)
+; GFX8-NEXT: v_lshrrev_b32_e32 v10, 8, v2
+; GFX8-NEXT: v_bfe_i32 v10, v10, 0, 8
+; GFX8-NEXT: v_and_b32_e32 v8, 0xff, v2
+; GFX8-NEXT: s_waitcnt vmcnt(0)
+; GFX8-NEXT: v_mad_u16 v4, v9, v10, v4
+; GFX8-NEXT: v_and_b32_sdwa v5, v2, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX8-NEXT: v_lshrrev_b32_e32 v2, 24, v2
+; GFX8-NEXT: v_mad_u16 v4, v7, v8, v4
+; GFX8-NEXT: v_bfe_i32 v2, v2, 0, 8
+; GFX8-NEXT: v_mad_u16 v4, v6, v5, v4
+; GFX8-NEXT: v_mad_u16 v2, v3, v2, v4
+; GFX8-NEXT: flat_store_short v[0:1], v2
+; GFX8-NEXT: s_endpgm
+;
+; GFX9-NODL-LABEL: notdot4_mixedtypes2:
+; GFX9-NODL: ; %bb.0: ; %entry
+; GFX9-NODL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-NODL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-NODL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-NODL-NEXT: s_movk_i32 s0, 0xff
+; GFX9-NODL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT: global_load_dword v1, v0, s[4:5]
+; GFX9-NODL-NEXT: global_load_dword v2, v0, s[6:7]
+; GFX9-NODL-NEXT: v_mov_b32_e32 v0, 0
+; GFX9-NODL-NEXT: global_load_ushort v3, v0, s[2:3]
+; GFX9-NODL-NEXT: s_waitcnt vmcnt(2)
+; GFX9-NODL-NEXT: v_lshrrev_b32_e32 v7, 8, v1
+; GFX9-NODL-NEXT: s_waitcnt vmcnt(1)
+; GFX9-NODL-NEXT: v_lshrrev_b32_e32 v8, 8, v2
+; GFX9-NODL-NEXT: v_and_b32_e32 v7, 0xff, v7
+; GFX9-NODL-NEXT: v_bfe_i32 v8, v8, 0, 8
+; GFX9-NODL-NEXT: v_lshrrev_b32_e32 v4, 16, v1
+; GFX9-NODL-NEXT: v_bfe_i32 v5, v1, 0, 8
+; GFX9-NODL-NEXT: v_and_b32_e32 v6, 0xff, v2
+; GFX9-NODL-NEXT: s_waitcnt vmcnt(0)
+; GFX9-NODL-NEXT: v_mad_legacy_u16 v3, v7, v8, v3
+; GFX9-NODL-NEXT: v_and_b32_sdwa v9, v2, s0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX9-NODL-NEXT: v_lshrrev_b32_e32 v2, 24, v2
+; GFX9-NODL-NEXT: v_bfe_i32 v4, v4, 0, 8
+; GFX9-NODL-NEXT: v_mad_legacy_u16 v3, v5, v6, v3
+; GFX9-NODL-NEXT: v_lshrrev_b32_e32 v1, 24, v1
+; GFX9-NODL-NEXT: v_bfe_i32 v2, v2, 0, 8
+; GFX9-NODL-NEXT: v_mad_legacy_u16 v3, v4, v9, v3
+; GFX9-NODL-NEXT: v_mad_legacy_u16 v1, v1, v2, v3
+; GFX9-NODL-NEXT: global_store_short v0, v1, s[2:3]
+; GFX9-NODL-NEXT: s_endpgm
+;
+; GFX9-DL-LABEL: notdot4_mixedtypes2:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-DL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-DL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-DL-NEXT: s_movk_i32 s0, 0xff
+; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT: global_load_dword v1, v0, s[4:5]
+; GFX9-DL-NEXT: global_load_dword v2, v0, s[6:7]
+; GFX9-DL-NEXT: v_mov_b32_e32 v0, 0
+; GFX9-DL-NEXT: global_load_ushort v3, v0, s[2:3]
+; GFX9-DL-NEXT: s_waitcnt vmcnt(2)
+; GFX9-DL-NEXT: v_lshrrev_b32_e32 v7, 8, v1
+; GFX9-DL-NEXT: s_waitcnt vmcnt(1)
+; GFX9-DL-NEXT: v_lshrrev_b32_e32 v8, 8, v2
+; GFX9-DL-NEXT: v_and_b32_e32 v7, 0xff, v7
+; GFX9-DL-NEXT: v_bfe_i32 v8, v8, 0, 8
+; GFX9-DL-NEXT: v_lshrrev_b32_e32 v4, 16, v1
+; GFX9-DL-NEXT: v_bfe_i32 v5, v1, 0, 8
+; GFX9-DL-NEXT: v_and_b32_e32 v6, 0xff, v2
+; GFX9-DL-NEXT: s_waitcnt vmcnt(0)
+; GFX9-DL-NEXT: v_mad_legacy_u16 v3, v7, v8, v3
+; GFX9-DL-NEXT: v_and_b32_sdwa v9, v2, s0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX9-DL-NEXT: v_lshrrev_b32_e32 v2, 24, v2
+; GFX9-DL-NEXT: v_bfe_i32 v4, v4, 0, 8
+; GFX9-DL-NEXT: v_mad_legacy_u16 v3, v5, v6, v3
+; GFX9-DL-NEXT: v_lshrrev_b32_e32 v1, 24, v1
+; GFX9-DL-NEXT: v_bfe_i32 v2, v2, 0, 8
+; GFX9-DL-NEXT: v_mad_legacy_u16 v3, v4, v9, v3
+; GFX9-DL-NEXT: v_mad_legacy_u16 v1, v1, v2, v3
+; GFX9-DL-NEXT: global_store_short v0, v1, s[2:3]
+; GFX9-DL-NEXT: s_endpgm
+;
+; GFX10-DL-LABEL: notdot4_mixedtypes2:
+; GFX10-DL: ; %bb.0: ; %entry
+; GFX10-DL-NEXT: s_clause 0x1
+; GFX10-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX10-DL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX10-DL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX10-DL-NEXT: v_mov_b32_e32 v8, 0xff
+; GFX10-DL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX10-DL-NEXT: s_clause 0x1
+; GFX10-DL-NEXT: global_load_dword v1, v0, s[4:5]
+; GFX10-DL-NEXT: global_load_dword v2, v0, s[6:7]
+; GFX10-DL-NEXT: v_mov_b32_e32 v0, 0
+; GFX10-DL-NEXT: global_load_ushort v3, v0, s[2:3]
+; GFX10-DL-NEXT: s_waitcnt vmcnt(2)
+; GFX10-DL-NEXT: v_lshrrev_b32_e32 v4, 8, v1
+; GFX10-DL-NEXT: s_waitcnt vmcnt(1)
+; GFX10-DL-NEXT: v_lshrrev_b32_e32 v5, 8, v2
+; GFX10-DL-NEXT: v_lshrrev_b32_e32 v6, 16, v1
+; GFX10-DL-NEXT: v_bfe_i32 v7, v1, 0, 8
+; GFX10-DL-NEXT: v_and_b32_e32 v9, 0xff, v2
+; GFX10-DL-NEXT: v_and_b32_e32 v4, 0xff, v4
+; GFX10-DL-NEXT: v_bfe_i32 v5, v5, 0, 8
+; GFX10-DL-NEXT: v_lshrrev_b32_e32 v1, 24, v1
+; GFX10-DL-NEXT: s_waitcnt vmcnt(0)
+; GFX10-DL-NEXT: v_mad_u16 v3, v4, v5, v3
+; GFX10-DL-NEXT: v_bfe_i32 v4, v6, 0, 8
+; GFX10-DL-NEXT: v_and_b32_sdwa v5, v2, v8 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX10-DL-NEXT: v_lshrrev_b32_e32 v2, 24, v2
+; GFX10-DL-NEXT: v_mad_u16 v3, v7, v9, v3
+; GFX10-DL-NEXT: v_bfe_i32 v2, v2, 0, 8
+; GFX10-DL-NEXT: v_mad_u16 v3, v4, v5, v3
+; GFX10-DL-NEXT: v_mad_u16 v1, v1, v2, v3
+; GFX10-DL-NEXT: global_store_short v0, v1, s[2:3]
+; GFX10-DL-NEXT: s_endpgm
+;
+; GFX11-DL-LABEL: notdot4_mixedtypes2:
+; GFX11-DL: ; %bb.0: ; %entry
+; GFX11-DL-NEXT: s_clause 0x1
+; GFX11-DL-NEXT: s_load_b128 s[4:7], s[0:1], 0x24
+; GFX11-DL-NEXT: s_load_b64 s[0:1], s[0:1], 0x34
+; GFX11-DL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX11-DL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX11-DL-NEXT: s_clause 0x1
+; GFX11-DL-NEXT: global_load_b32 v1, v0, s[4:5]
+; GFX11-DL-NEXT: global_load_b32 v0, v0, s[6:7]
+; GFX11-DL-NEXT: v_mov_b32_e32 v2, 0
+; GFX11-DL-NEXT: s_waitcnt vmcnt(1)
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v4, 8, v1
+; GFX11-DL-NEXT: s_waitcnt vmcnt(0)
+; GFX11-DL-NEXT: v_and_b32_e32 v9, 0xff, v0
+; GFX11-DL-NEXT: global_load_u16 v3, v2, s[0:1]
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v5, 8, v0
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v6, 16, v1
+; GFX11-DL-NEXT: v_and_b32_e32 v4, 0xff, v4
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v7, 16, v0
+; GFX11-DL-NEXT: v_bfe_i32 v8, v1, 0, 8
+; GFX11-DL-NEXT: v_bfe_i32 v5, v5, 0, 8
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v0, 24, v0
+; GFX11-DL-NEXT: v_lshrrev_b32_e32 v1, 24, v1
+; GFX11-DL-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_4) | instid1(VALU_DEP_3)
+; GFX11-DL-NEXT: v_bfe_i32 v0, v0, 0, 8
+; GFX11-DL-NEXT: s_waitcnt vmcnt(0)
+; GFX11-DL-NEXT: v_mad_u16 v3, v4, v5, v3
+; GFX11-DL-NEXT: v_bfe_i32 v4, v6, 0, 8
+; GFX11-DL-NEXT: v_and_b32_e32 v5, 0xff, v7
+; GFX11-DL-NEXT: v_mad_u16 v3, v8, v9, v3
+; GFX11-DL-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-DL-NEXT: v_mad_u16 v3, v4, v5, v3
+; GFX11-DL-NEXT: v_mad_u16 v0, v1, v0, v3
+; GFX11-DL-NEXT: global_store_b16 v2, v0, s[0:1]
+; GFX11-DL-NEXT: s_nop 0
+; GFX11-DL-NEXT: s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; GFX11-DL-NEXT: s_endpgm
+ ptr addrspace(1) %src2,
+ ptr addrspace(1) nocapture %dst) {
+entry:
+ %idx = call i32 @llvm.amdgcn.workitem.id.x()
+ %gep1 = getelementptr <4 x i8>, ptr addrspace(1) %src1, i32 %idx
+ %vec1 = load <4 x i8>, ptr addrspace(1) %gep1
+ %gep2 = getelementptr <4 x i8>, ptr addrspace(1) %src2, i32 %idx
+ %vec2 = load <4 x i8>, ptr addrspace(1) %gep2
+
+ %v1e0 = extractelement <4 x i8> %vec1, i64 0
+ %cv1e0 = sext i8 %v1e0 to i16
+ %v2e0 = extractelement <4 x i8> %vec2, i64 0
+ %cv2e0 = zext i8 %v2e0 to i16
+ %mul1 = mul nuw nsw i16 %cv1e0, %cv2e0
+
+ %v1e1 = extractelement <4 x i8> %vec1, i64 1
+ %cv1e1 = zext i8 %v1e1 to i16
+ %v2e1 = extractelement <4 x i8> %vec2, i64 1
+ %cv2e1 = sext i8 %v2e1 to i16
+ %mul2 = mul nuw nsw i16 %cv1e1, %cv2e1
+
+ %v1e2 = extractelement <4 x i8> %vec1, i64 2
+ %cv1e2 = sext i8 %v1e2 to i16
+ %v2e2 = extractelement <4 x i8> %vec2, i64 2
+ %cv2e2 = zext i8 %v2e2 to i16
+ %mul3 = mul nuw nsw i16 %cv1e2, %cv2e2
+
+ %v1e3 = extractelement <4 x i8> %vec1, i64 3
+ %cv1e3 = zext i8 %v1e3 to i16
+ %v2e3 = extractelement <4 x i8> %vec2, i64 3
+ %cv2e3 = sext i8 %v2e3 to i16
+ %mul4 = mul nuw nsw i16 %cv1e3, %cv2e3
+
+ %acc = load i16, ptr addrspace(1) %dst, align 2
+ %add1 = add i16 %mul2, %acc
+ %add2 = add i16 %add1, %mul1
+ %add3 = add i16 %add2, %mul3
+ %add4 = add i16 %add3, %mul4
+
+ store i16 %add4, ptr addrspace(1) %dst, align 2
+ ret void
+}
+
; TODO: cleanup s_lshr_b32
define amdgpu_kernel void @udot4_acc32_vecMul(ptr addrspace(1) %src1,
; GFX7-LABEL: udot4_acc32_vecMul:
@@ -4622,4 +4884,170 @@ entry:
ret void
}
+define amdgpu_kernel void @idot4_acc32_anyext(ptr addrspace(1) %src1,
+; GFX7-LABEL: idot4_acc32_anyext:
+; GFX7: ; %bb.0: ; %entry
+; GFX7-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x9
+; GFX7-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0xd
+; GFX7-NEXT: s_mov_b32 s3, 0xf000
+; GFX7-NEXT: s_mov_b32 s10, 0
+; GFX7-NEXT: s_mov_b32 s11, s3
+; GFX7-NEXT: s_waitcnt lgkmcnt(0)
+; GFX7-NEXT: s_mov_b64 s[8:9], s[4:5]
+; GFX7-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX7-NEXT: v_mov_b32_e32 v1, 0
+; GFX7-NEXT: buffer_load_dword v2, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT: s_mov_b64 s[8:9], s[6:7]
+; GFX7-NEXT: buffer_load_dword v0, v[0:1], s[8:11], 0 addr64
+; GFX7-NEXT: s_load_dword s4, s[0:1], 0x0
+; GFX7-NEXT: s_mov_b32 s2, -1
+; GFX7-NEXT: s_waitcnt vmcnt(1)
+; GFX7-NEXT: v_and_b32_e32 v1, 0xff, v2
+; GFX7-NEXT: v_bfe_u32 v2, v2, 8, 8
+; GFX7-NEXT: s_waitcnt vmcnt(0)
+; GFX7-NEXT: v_bfe_u32 v0, v0, 8, 8
+; GFX7-NEXT: s_waitcnt lgkmcnt(0)
+; GFX7-NEXT: v_mad_u32_u24 v1, v1, v1, s4
+; GFX7-NEXT: v_mad_u32_u24 v0, v2, v0, v1
+; GFX7-NEXT: buffer_store_dword v0, off, s[0:3], 0
+; GFX7-NEXT: s_endpgm
+;
+; GFX8-LABEL: idot4_acc32_anyext:
+; GFX8: ; %bb.0: ; %entry
+; GFX8-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX8-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x34
+; GFX8-NEXT: v_lshlrev_b32_e32 v2, 2, v0
+; GFX8-NEXT: s_waitcnt lgkmcnt(0)
+; GFX8-NEXT: v_mov_b32_e32 v1, s5
+; GFX8-NEXT: v_add_u32_e32 v0, vcc, s4, v2
+; GFX8-NEXT: v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT: flat_load_dword v3, v[0:1]
+; GFX8-NEXT: v_mov_b32_e32 v1, s7
+; GFX8-NEXT: v_add_u32_e32 v0, vcc, s6, v2
+; GFX8-NEXT: v_addc_u32_e32 v1, vcc, 0, v1, vcc
+; GFX8-NEXT: flat_load_dword v0, v[0:1]
+; GFX8-NEXT: s_load_dword s2, s[0:1], 0x0
+; GFX8-NEXT: s_waitcnt vmcnt(1)
+; GFX8-NEXT: v_and_b32_e32 v1, 0xff, v3
+; GFX8-NEXT: v_bfe_u32 v2, v3, 8, 8
+; GFX8-NEXT: s_waitcnt lgkmcnt(0)
+; GFX8-NEXT: v_mad_u32_u24 v1, v1, v1, s2
+; GFX8-NEXT: s_waitcnt vmcnt(0)
+; GFX8-NEXT: v_bfe_u32 v0, v0, 8, 8
+; GFX8-NEXT: v_mad_u32_u24 v2, v2, v0, v1
+; GFX8-NEXT: v_mov_b32_e32 v0, s0
+; GFX8-NEXT: v_mov_b32_e32 v1, s1
+; GFX8-NEXT: flat_store_dword v[0:1], v2
+; GFX8-NEXT: s_endpgm
+;
+; GFX9-NODL-LABEL: idot4_acc32_anyext:
+; GFX9-NODL: ; %bb.0: ; %entry
+; GFX9-NODL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-NODL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-NODL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-NODL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT: global_load_dword v1, v0, s[4:5]
+; GFX9-NODL-NEXT: global_load_dword v2, v0, s[6:7]
+; GFX9-NODL-NEXT: s_load_dword s0, s[2:3], 0x0
+; GFX9-NODL-NEXT: v_mov_b32_e32 v0, 0
+; GFX9-NODL-NEXT: s_waitcnt vmcnt(1)
+; GFX9-NODL-NEXT: v_mul_u32_u24_sdwa v3, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
+; GFX9-NODL-NEXT: s_waitcnt vmcnt(0)
+; GFX9-NODL-NEXT: v_mul_u32_u24_sdwa v1, v1, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1
+; GFX9-NODL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-NODL-NEXT: v_add3_u32 v1, v3, s0, v1
+; GFX9-NODL-NEXT: global_store_dword v0, v1, s[2:3]
+; GFX9-NODL-NEXT: s_endpgm
+;
+; GFX9-DL-LABEL: idot4_acc32_anyext:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
+; GFX9-DL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
+; GFX9-DL-NEXT: v_lshlrev_b32_e32 v0, 2, v0
+; GFX9-DL-NEXT: s_mov_b32 s1, 0xc0c0500
+; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT: global_load_dword v1, v0, s[4:5]
+; GFX9-DL-NEXT: global_load_dword v2, v0, s[6:7]
+; GFX9-DL-NEXT: s_load_dword s0, s[2:3], 0x0
+; GFX9-DL-NEXT: s_mov_b32 s4, 0xc0c0100
+; GFX9-DL-NEXT: v_mov_b32_e32 v0, 0
+; GFX9-DL-NEXT: s_waitcnt vmcnt(0)
+; GFX9-DL-NEXT: v_perm_b32 v2, v2, v1, s1
+; GFX9-DL-NEXT: v_perm_b32 v1, v1, v1, s4
+; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
+; GFX9-DL-NEXT: v_...
[truncated]
Force-pushed c3a12c8 to dc80000
Force-pushed dc80000 to 6f37de9
Passes PSDB.
Force-pushed 6f37de9 to 2dc29fd
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed 2dc29fd to 8955321
A more complete rework of the signedness semantics matching to remove ByteProvider.IsSigned. Use computeKnownBits as suggested by @arsenm.

Special case when we don't know the sign bit for one or both ops: given that we have valid ByteProviders for each op in the candidate mul, the upper bits must be extension bits. Thus, if we don't know the sign bit, one of two things must have occurred:

1. sign extend from an unknown bit, or
2. any extend.

In either case, we can use the sign-extend version of dot4 for this op. Thus, if the other op is also unknown, or known to have its sign bit set (which implies it must have been sign extended), we can use the signed version of dot.

A draft while I test the new implementation.
Force-pushed 8955321 to 15fb49e
It passes PSDB and has good CK dot4 lowering / compile time.
Ping -- |
                ByteProvider<SDValue> &Src0,
                ByteProvider<SDValue> &Src1) {
static std::optional<bool>
checkSignedness(const SDValue &N, ByteProvider<SDValue> &Src0,
Can you rename this function to show it's for multiplies? As far as I can tell this is just performing sign bit logic on a multiply. As in, you're effectively reproducing what computeKnownBits is doing for a multiply. Is it possible to just check the result of computeKnownBits on the top multiply?
computeKnownBits on the top multiply does not provide sufficient information to determine dot lowering behavior for combine for all permutations.
For example, existing tests produce the following data:
| S0SignBit | S1SignBit | Multiply KnownSignBit | Should Use Dot | Signed Dot |
|-----------|-----------|-----------------------|----------------|------------|
| 0         | ?         | ?                     | 0              | X          |
| ?         | 0         | ?                     | 0              | X          |
| ?         | ?         | ?                     | 1              | 1          |
The S0SignBit and S1SignBit are sufficient to determine the behavior of the combine in all cases, whereas different expected combine behaviors map to the same sign bit of the KnownBits of the top multiply, so that bit alone cannot distinguish them. The logic of checkSignedness deviates from computeKnownBits: in checkSignedness we know certain properties of the operands (all but the least-significant 8 bits are extension bits), so we can do more special-case reasoning.
Force-pushed 15fb49e to 25c7d33
Ping
Force-pushed 25c7d33 to 873d37c
No change / rebase push -- retrigger code formatting checks.
Change-Id: I97ace66cd6d60eb7f5892d762af8912374ac5a8d
Force-pushed 873d37c to 1f26cf8
Thanks for the review, @arsenm.