Skip to content

Commit

Permalink
Revert "[AMDGPU]: Allow combining into v_dot4" (#66158)
Browse files Browse the repository at this point in the history
This reverts commit 7fda1b7.
  • Loading branch information
jrbyrnes committed Sep 12, 2023
1 parent dfd0cd1 commit db47264
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 5,077 deletions.
317 changes: 2 additions & 315 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12501,193 +12501,6 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
return Accum;
}

// Find the value that supplies the low byte of one multiply operand and check
// that the operand is a single byte (8 bits) wide.
//
// The operand is accepted when calculateByteProvider resolves byte 0 to
// something other than a constant zero, and byte 1 is either unresolved or a
// constant zero. Note that an unresolved byte 1 does not reject the operand.
//
// Returns the provider for byte 0 on success, std::nullopt otherwise.
static std::optional<ByteProvider<SDValue>>
handleMulOperand(const SDValue &MulOperand) {
  auto LowByte = calculateByteProvider(MulOperand, 0, 0);
  const bool LowByteUsable = LowByte && !LowByte->isConstantZero();
  if (!LowByteUsable)
    return std::nullopt;

  // Any known, non-zero content in byte 1 means the operand is wider than a
  // single byte.
  auto NextByte = calculateByteProvider(MulOperand, 1, 0);
  const bool NextByteOccupied = NextByte && !NextByte->isConstantZero();
  if (NextByteOccupied)
    return std::nullopt;

  return LowByte;
}

// Merge two perm masks byte lane by byte lane.
//
// Each mask holds four selector bytes. Within every selector the 0x0c bits
// are handled separately from the remaining bits:
//   - the non-0x0c bits of both masks are OR'ed together;
//   - the 0x0c bits survive only where BOTH masks have them set.
// As a result a lane that is 0x0c in one mask and a plain selector in the
// other yields the plain selector, and a lane that is 0x0c in both stays 0x0c.
//
// Every lane is expected to carry at least one of the 0x0c bits in at least
// one of the two masks; this is asserted per lane.
static unsigned addPermMasks(unsigned First, unsigned Second) {
  constexpr unsigned ZeroSelectors = 0x0c0c0c0c;

  unsigned FirstCs = First & ZeroSelectors;
  unsigned SecondCs = Second & ZeroSelectors;
  unsigned FirstNoCs = First & ~ZeroSelectors;
  unsigned SecondNoCs = Second & ~ZeroSelectors;

  // Parenthesised explicitly: '&' already binds tighter than '|', but the
  // unparenthesised mix triggers -Wparentheses.
  assert(((FirstCs & 0xFF) | (SecondCs & 0xFF)) &&
         "lane 0 has no 0x0c bits in either mask");
  assert(((FirstCs & 0xFF00) | (SecondCs & 0xFF00)) &&
         "lane 1 has no 0x0c bits in either mask");
  assert(((FirstCs & 0xFF0000) | (SecondCs & 0xFF0000)) &&
         "lane 2 has no 0x0c bits in either mask");
  assert(((FirstCs & 0xFF000000) | (SecondCs & 0xFF000000)) &&
         "lane 3 has no 0x0c bits in either mask");

  return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
}

// Record the two byte providers of one multiply in the Src0s / Src1s groups.
//
// Each group entry pairs a source value with a perm mask. The multiply found
// at chain position Step owns byte lane (3 - Step) of the mask: that lane
// receives the provider's SrcOffset and every other lane is set to 0x0c.
//
// Placement rules:
//   - Step 0: both groups are assumed empty; Src0 goes to Src0s and Src1 goes
//     to Src1s.
//   - Otherwise, try Src0 first and then Src1 as the "lead" provider. If the
//     lead's source value is already present in one of the groups, its mask is
//     merged into that entry (via addPermMasks) and the other provider is
//     merged into, or appended to, the opposite group.
//   - If neither provider's source is already present, Src0 is appended to
//     Src0s and Src1 to Src1s.
static void placeSources(ByteProvider<SDValue> &Src0,
                         ByteProvider<SDValue> &Src1,
                         SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
                         SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
                         int Step) {
  using SrcEntry = std::pair<SDValue, unsigned>;

  assert(Src0.Src.has_value() && Src1.Src.has_value());

  // First multiply of the chain: nothing to match against yet.
  if (Step == 0) {
    Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
    Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
    return;
  }

  // Masks for this step: the provider's offset in lane (3 - Step), 0x0c in
  // all remaining lanes.
  const int Shift = 8 * (3 - Step);
  const unsigned ZeroSelectors = 0x0c0c0c0c;
  const unsigned LaneBits = 0xFF << Shift;
  const unsigned OtherLanesZero = ZeroSelectors & ~LaneBits;
  const unsigned Src0Mask = Src0.SrcOffset << Shift | OtherLanesZero;
  const unsigned Src1Mask = Src1.SrcOffset << Shift | OtherLanesZero;

  // Returns the entry of Group whose source value equals Val, or nullptr.
  auto FindEntry = [](SmallVectorImpl<SrcEntry> &Group,
                      const SDValue &Val) -> SrcEntry * {
    for (SrcEntry &Entry : Group)
      if (Entry.first == Val)
        return &Entry;
    return nullptr;
  };

  for (int Order = 0; Order < 2; ++Order) {
    const bool Src0Leads = Order == 0;
    ByteProvider<SDValue> &Lead = Src0Leads ? Src0 : Src1;
    ByteProvider<SDValue> &Partner = Src0Leads ? Src1 : Src0;
    const unsigned LeadMask = Src0Leads ? Src0Mask : Src1Mask;
    const unsigned PartnerMask = Src0Leads ? Src1Mask : Src0Mask;

    // Look for the lead's source in Src0s first, then in Src1s.
    for (int GroupIdx = 0; GroupIdx < 2; ++GroupIdx) {
      SmallVectorImpl<SrcEntry> &LeadGroup = GroupIdx == 0 ? Src0s : Src1s;
      SrcEntry *LeadHit = FindEntry(LeadGroup, *Lead.Src);
      if (!LeadHit)
        continue;

      LeadHit->second = addPermMasks(LeadMask, LeadHit->second);

      // The partner always lands in the group the lead was NOT found in.
      SmallVectorImpl<SrcEntry> &PartnerGroup = GroupIdx == 0 ? Src1s : Src0s;
      if (SrcEntry *PartnerHit = FindEntry(PartnerGroup, *Partner.Src))
        PartnerHit->second = addPermMasks(PartnerMask, PartnerHit->second);
      else
        PartnerGroup.push_back({*Partner.Src, PartnerMask});
      return;
    }
  }

  // Neither source is known to either group yet, so start new entries.
  Src0s.push_back({*Src0.Src, Src0Mask});
  Src1s.push_back({*Src1.Src, Src1Mask});
}

// Turn the (source value, perm mask) entries in Srcs into a single i32 value.
//
//   - One entry: the source is brought to i32 and returned as is when its mask
//     is 0x3020100, otherwise it is permuted with itself using that mask.
//   - Several entries: they are consumed in pairs, each pair becoming one PERM
//     node; a trailing unpaired entry is permuted with itself. The resulting
//     one or two PERM nodes are combined with an OR.
//
// Srcs is expected to be non-empty and to produce at most two PERM nodes (this
// is asserted). IsSigned and IsAny are not read by this function.
static SDValue
resolveSources(SelectionDAG &DAG, SDLoc SL,
               SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
               bool IsSigned, bool IsAny) {
  using SrcEntry = std::pair<SDValue, unsigned>;

  // Bring an entry's source value to i32.
  auto AsDWord = [&DAG, &SL](const SrcEntry &Entry) {
    return DAG.getBitcastedAnyExtOrTrunc(Entry.first, SL, MVT::i32);
  };
  // PERM of a value with itself under the given mask.
  auto PermuteWithSelf = [&DAG, &SL](SDValue Val, unsigned Mask) {
    return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, Val, Val,
                       DAG.getConstant(Mask, SL, MVT::i32));
  };

  if (Srcs.size() == 1) {
    const SrcEntry &Only = *Srcs.begin();
    SDValue OnlyVal = AsDWord(Only);

    // With this mask the permute would reproduce the input, so skip it.
    if (Only.second == 0x3020100)
      return OnlyVal;

    return PermuteWithSelf(OnlyVal, Only.second);
  }

  SmallVector<SDValue, 2> Perms;

  for (auto Cur = Srcs.begin(), End = Srcs.end(); Cur != End;) {
    auto Partner = std::next(Cur);

    // Odd entry left over at the end: permute it on its own.
    if (Partner == End) {
      Perms.push_back(PermuteWithSelf(AsDWord(*Cur), Cur->second));
      break;
    }

    // Rewrite the first entry's selectors before merging: OR-ing in 0x04 and
    // keeping the low nibble maps selector bytes 0..3 to 4..7, while a 0x0c
    // byte stays 0x0c (0x0c | 0x04 == 0x0c, and its 0x0c bits are OR'ed back).
    // NOTE(review): presumably selectors 4..7 address the first PERM operand;
    // confirm against the PERM node semantics.
    const unsigned CurMask = Cur->second;
    const unsigned Rebased =
        ((CurMask | 0x04040404) & 0x0F0F0F0F) | (CurMask & 0x0c0c0c0c);
    const unsigned PairMask = addPermMasks(Rebased, Partner->second);

    SDValue CurVal = AsDWord(*Cur);
    SDValue PartnerVal = AsDWord(*Partner);
    Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, CurVal,
                                PartnerVal,
                                DAG.getConstant(PairMask, SL, MVT::i32)));

    Cur = std::next(Partner);
  }

  assert(Perms.size() == 1 || Perms.size() == 2);
  if (Perms.size() == 2)
    return DAG.getNode(ISD::OR, SL, MVT::i32, Perms[0], Perms[1]);
  return Perms[0];
}

// Adjust the perm masks in Srcs for a chain shorter than four multiplies.
//
// Every mask is shifted right by (4 - ChainLength) bytes and a constant is
// then added that puts 0x0c into the vacated top byte lanes: 0x0c0c0000 when
// ChainLength is 2, 0x0c000000 for any other ChainLength.
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
                     unsigned ChainLength) {
  const unsigned UnusedBits = (4 - ChainLength) * 8;
  const unsigned TopLanesZero = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;

  for (auto &Entry : Srcs)
    Entry.second = (Entry.second >> UnusedBits) + TopLanesZero;
}

// Returns true when Op's opcode is ISD::MUL, AMDGPUISD::MUL_U24 or
// AMDGPUISD::MUL_I24, and false for every other opcode.
static bool isMul(const SDValue Op) {
  switch (Op.getOpcode()) {
  case ISD::MUL:
  case AMDGPUISD::MUL_U24:
  case AMDGPUISD::MUL_I24:
    return true;
  default:
    return false;
  }
}

SDValue SITargetLowering::performAddCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand All @@ -12701,140 +12514,14 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
if (SDValue Folded = tryFoldToMad64_32(N, DCI))
return Folded;
}

return SDValue();
}

if (SDValue V = reassociateScalarOps(N, DAG)) {
return V;
}

if ((isMul(LHS) || isMul(RHS)) && Subtarget->hasDot7Insts() &&
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
SDValue TempNode(N, 0);
auto MulIdx = isMul(LHS) ? 0 : 1;

auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
bool IsSigned =
MulOpcode == AMDGPUISD::MUL_I24 ||
(MulOpcode == ISD::MUL &&
TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
!TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
SmallVector<SDValue, 4> Src2s;

// Match the v_dot4 tree, while collecting src nodes.
int ChainLength = 0;
for (int I = 0; I < 4; I++) {
auto MulIdx = isMul(LHS) ? 0 : isMul(RHS) ? 1 : -1;
if (MulIdx == -1)
break;
auto IterIsSigned =
MulOpcode == AMDGPUISD::MUL_I24 ||
(MulOpcode == ISD::MUL &&
TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
!TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
if (IterIsSigned != IsSigned) {
break;
}
auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
if (!Src0)
break;
auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
if (!Src1)
break;
placeSources(*Src0, *Src1, Src0s, Src1s, I);
auto AddIdx = 1 - MulIdx;
// Allow the special case where add (add (mul24, 0), mul24) became ->
// add (mul24, mul24)
if (I == 2 && isMul(TempNode->getOperand(AddIdx))) {
Src2s.push_back(TempNode->getOperand(AddIdx));
auto Src0 =
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(0));
if (!Src0)
break;
auto Src1 =
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
if (!Src1)
break;
placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
ChainLength = I + 2;
break;
}

TempNode = TempNode->getOperand(AddIdx);
Src2s.push_back(TempNode);
ChainLength = I + 1;
if (TempNode->getNumOperands() < 2)
break;
LHS = TempNode->getOperand(0);
RHS = TempNode->getOperand(1);
}

if (ChainLength < 2)
return SDValue();

// Masks were constructed with assumption that we would find a chain of
// length 4. If not, then we need to 0 out the MSB bits (via perm mask of
// 0x0c) so they do not affect dot calculation.
if (ChainLength < 4) {
fixMasks(Src0s, ChainLength);
fixMasks(Src1s, ChainLength);
}

SDValue Src0, Src1;

// If we are just using a single source for both, and have permuted the
// bytes consistently, we can just use the sources without permuting
// (commutation)
bool UseOriginalSrc = false;
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
Src0s.begin()->second == Src1s.begin()->second &&
Src0s.begin()->first.getValueSizeInBits() == 32 &&
Src1s.begin()->first.getValueSizeInBits() == 32) {
SmallVector<unsigned, 4> SrcBytes;
auto Src0Mask = Src0s.begin()->second;
SrcBytes.push_back(Src0Mask & 0xFF000000);
bool UniqueEntries = true;
for (auto I = 1; I < 4; I++) {
auto NextByte = Src0Mask & (0xFF << ((3 - I) * 8));

if (is_contained(SrcBytes, NextByte)) {
UniqueEntries = false;
break;
}
SrcBytes.push_back(NextByte);
}

if (UniqueEntries) {
UseOriginalSrc = true;
// Must be 32 bits to enter above conditional
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
}
}

if (!UseOriginalSrc) {
Src0 = resolveSources(DAG, SL, Src0s, false, true);
Src1 = resolveSources(DAG, SL, Src1s, false, true);
}

SDValue Src2 =
DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);

SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
: Intrinsic::amdgcn_udot4,
SL, MVT::i64);

assert(!VT.isVector());
auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));

return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
}

if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
return SDValue();

Expand Down
23 changes: 12 additions & 11 deletions llvm/test/CodeGen/AMDGPU/idot2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2823,18 +2823,18 @@ define amdgpu_kernel void @notsdot2_sext8(ptr addrspace(1) %src1,
; GFX9-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
; GFX9-DL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
; GFX9-DL-NEXT: v_lshlrev_b32_e32 v0, 1, v0
; GFX9-DL-NEXT: s_mov_b32 s1, 0xc0c0001
; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
; GFX9-DL-NEXT: global_load_ushort v1, v0, s[4:5]
; GFX9-DL-NEXT: global_load_ushort v2, v0, s[6:7]
; GFX9-DL-NEXT: s_load_dword s0, s[2:3], 0x0
; GFX9-DL-NEXT: v_mov_b32_e32 v0, 0
; GFX9-DL-NEXT: s_waitcnt vmcnt(1)
; GFX9-DL-NEXT: v_perm_b32 v1, v1, v1, s1
; GFX9-DL-NEXT: s_waitcnt vmcnt(0)
; GFX9-DL-NEXT: v_perm_b32 v2, v2, v2, s1
; GFX9-DL-NEXT: v_mul_i32_i24_sdwa v3, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
; GFX9-DL-NEXT: v_lshrrev_b16_e32 v1, 8, v1
; GFX9-DL-NEXT: v_lshrrev_b16_e32 v2, 8, v2
; GFX9-DL-NEXT: v_mul_i32_i24_sdwa v1, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
; GFX9-DL-NEXT: v_dot4_i32_i8 v1, v2, v1, s0
; GFX9-DL-NEXT: v_add3_u32 v1, v1, s0, v3
; GFX9-DL-NEXT: global_store_dword v0, v1, s[2:3]
; GFX9-DL-NEXT: s_endpgm
;
Expand All @@ -2843,20 +2843,21 @@ define amdgpu_kernel void @notsdot2_sext8(ptr addrspace(1) %src1,
; GFX10-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
; GFX10-DL-NEXT: v_lshlrev_b32_e32 v0, 1, v0
; GFX10-DL-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x34
; GFX10-DL-NEXT: v_mov_b32_e32 v3, 0
; GFX10-DL-NEXT: s_waitcnt lgkmcnt(0)
; GFX10-DL-NEXT: s_clause 0x1
; GFX10-DL-NEXT: global_load_ushort v1, v0, s[4:5]
; GFX10-DL-NEXT: global_load_ushort v2, v0, s[6:7]
; GFX10-DL-NEXT: s_load_dword s2, s[0:1], 0x0
; GFX10-DL-NEXT: s_waitcnt vmcnt(1)
; GFX10-DL-NEXT: v_perm_b32 v0, v1, v1, 0xc0c0001
; GFX10-DL-NEXT: v_lshrrev_b16 v0, 8, v1
; GFX10-DL-NEXT: s_waitcnt vmcnt(0)
; GFX10-DL-NEXT: v_perm_b32 v1, v2, v2, 0xc0c0001
; GFX10-DL-NEXT: v_lshrrev_b16 v3, 8, v2
; GFX10-DL-NEXT: v_mul_i32_i24_sdwa v1, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
; GFX10-DL-NEXT: v_mov_b32_e32 v2, 0
; GFX10-DL-NEXT: v_mul_i32_i24_sdwa v0, sext(v3), sext(v0) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
; GFX10-DL-NEXT: s_waitcnt lgkmcnt(0)
; GFX10-DL-NEXT: v_mov_b32_e32 v2, s2
; GFX10-DL-NEXT: v_dot4c_i32_i8_e32 v2, v1, v0
; GFX10-DL-NEXT: global_store_dword v3, v2, s[0:1]
; GFX10-DL-NEXT: v_add3_u32 v0, v0, s2, v1
; GFX10-DL-NEXT: global_store_dword v2, v0, s[0:1]
; GFX10-DL-NEXT: s_endpgm
ptr addrspace(1) %src2,
ptr addrspace(1) nocapture %dst) {
Expand Down

0 comments on commit db47264

Please sign in to comment.