Skip to content

Commit

Permalink
[AMDGPU] Match udot4 pattern.
Browse files Browse the repository at this point in the history
Summary: D.u32 = S0.u8[0] * S1.u8[0] +
                 S0.u8[1] * S1.u8[1] +
                 S0.u8[2] * S1.u8[2] +
                 S0.u8[3] * S1.u8[3] + S2.u32

Author: FarhanaAleen

Reviewed By: arsenm

Subscribers: llvm-commits, AMDGPU

Differential Revision: https://reviews.llvm.org/D50921

llvm-svn: 340936
  • Loading branch information
Farhana Aleen committed Aug 29, 2018
1 parent bc2f06c commit 9250c92
Show file tree
Hide file tree
Showing 2 changed files with 663 additions and 0 deletions.
39 changes: 39 additions & 0 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Expand Up @@ -165,6 +165,39 @@ def V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3_Profile<VOP_F16_F16
defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
}

// Pattern fragment matching an `srl` (shift right) of $src by the constant
// i32 amount ShiftAmt.
class Srl<int ShiftAmt>
  : PatFrag<(ops node:$src), (srl node:$src, (i32 ShiftAmt))>;

// Instantiate one shift fragment per non-zero byte boundary of a 32-bit
// value: srl8, srl16 and srl24.
foreach ShiftAmt = [8, 16, 24] in {
  def srl#ShiftAmt : Srl<ShiftAmt>;
}

// Pattern fragment matching $val and'ed with the i32 constant 255, i.e. with
// everything above the low 8 bits masked off.
def and_255 : PatFrag<(ops node:$val), (and node:$val, (i32 255))>;

// Pattern fragment that extracts the 8-bit field of $src starting at bit
// FromBitIndex:
//   * FromBitIndex == 24: shift right by 24 only (no mask is applied).
//   * FromBitIndex == 0:  mask with 255 only (no shift is applied).
//   * otherwise:          shift right by FromBitIndex, then mask with 255.
// The shift fragment is looked up by name ("srl" # FromBitIndex), so a def
// with that name must exist for every non-zero FromBitIndex that is used.
class Extract_U8<int FromBitIndex> : PatFrag<(
ops node:$src),
!if (!eq (FromBitIndex, 24), // last element
(!cast<Srl>("srl"#FromBitIndex) node:$src),
!if (!eq (FromBitIndex, 0), // first element
(and_255 node:$src),
(and_255 (!cast<Srl>("srl"#FromBitIndex) node:$src))))>;

// UElt1..UElt4: fragments selecting the 1st through 4th 8-bit element of a
// 32-bit value. Element number EltNo starts at bit (EltNo - 1) * 8.
foreach EltNo = [1, 2, 3, 4] in {
  def UElt#EltNo : Extract_U8<!shl(!add(EltNo, -1), 3)>;
}

// MulU_Elt1..MulU_Elt4: fragments matching AMDGPUmul_u24_oneuse applied to the
// same-numbered 8-bit element (UElt<EltNo>) of each of its two 32-bit inputs.
foreach EltNo = [1, 2, 3, 4] in {
  def MulU_Elt#EltNo
    : PatFrag<(ops node:$op0, node:$op1),
              (AMDGPUmul_u24_oneuse
                (!cast<Extract_U8>("UElt"#EltNo) node:$op0),
                (!cast<Extract_U8>("UElt"#EltNo) node:$op1))>;
}

class UDot2Pat<Instruction Inst> : GCNPat <
(add (add_oneuse (AMDGPUmul_u24_oneuse (srl i32:$src0, (i32 16)),
(srl i32:$src1, (i32 16))), i32:$src2),
Expand Down Expand Up @@ -212,6 +245,12 @@ defm : DotPats<int_amdgcn_udot8, V_DOT8_U32_U4>;
def : UDot2Pat<V_DOT2_U32_U16>;
def : SDot2Pat<V_DOT2_I32_I16>;

// Select V_DOT4_U32_U8 for an unsigned 4 x 8-bit dot product with accumulate:
//   $src2 + MulU_Elt1($src0, $src1) + MulU_Elt2($src0, $src1)
//         + MulU_Elt3($src0, $src1) + MulU_Elt4($src0, $src1)
// The !foldl starts from (i32 i32:$src2) and, for each y in [1, 2, 3, 4],
// wraps the running dag as
//   (add_oneuse <running dag>, (MulU_Elt<y> i32:$src0, i32:$src1)),
// so the adds nest left to right with $src2 innermost.
// NOTE(review): the (i32 8) preceding each source and the trailing (i1 0) are
// presumably the per-source modifier and clamp operands of the instruction --
// confirm against the V_DOT4_U32_U8 operand list.
def : GCNPat <
!cast<dag>(!foldl((i32 i32:$src2), [1, 2, 3, 4], lhs, y,
(add_oneuse lhs, (!cast<PatFrag>("MulU_Elt"#y) i32:$src0, i32:$src1)))),
(V_DOT4_U32_U8 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
>;

} // End SubtargetPredicate = HasDLInsts

multiclass VOP3P_Real_vi<bits<10> op> {
Expand Down

0 comments on commit 9250c92

Please sign in to comment.