Skip to content

Commit

Permalink
Revert "[X86] Pass to transform tdpbsud&tdpbusd&tdpbuud intrinsics to scalar operation"
Browse files Browse the repository at this point in the history

This reverts commit 275df61.
  • Loading branch information
yubingex007-a11y committed Mar 30, 2021
1 parent 4ca8607 commit 0c63b86
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 331 deletions.
70 changes: 6 additions & 64 deletions llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
Expand Up @@ -73,9 +73,6 @@ class X86LowerAMXIntrinsics {
Value *Ptr, Value *Stride, Value *Tile);
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
IntrID == Intrinsic::x86_tdpbsud_internal ||
IntrID == Intrinsic::x86_tdpbusd_internal ||
IntrID == Intrinsic::x86_tdpbuud_internal ||
IntrID == Intrinsic::x86_tdpbf16ps_internal,
Value *>::type
createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
Expand All @@ -85,9 +82,6 @@ class X86LowerAMXIntrinsics {
bool lowerTileLoadStore(Instruction *TileLoadStore);
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
IntrID == Intrinsic::x86_tdpbsud_internal ||
IntrID == Intrinsic::x86_tdpbusd_internal ||
IntrID == Intrinsic::x86_tdpbuud_internal ||
IntrID == Intrinsic::x86_tdpbf16ps_internal,
bool>::type
lowerTileDP(Instruction *TileDP);
Expand Down Expand Up @@ -229,33 +223,14 @@ Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(

template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
IntrID == Intrinsic::x86_tdpbsud_internal ||
IntrID == Intrinsic::x86_tdpbusd_internal ||
IntrID == Intrinsic::x86_tdpbuud_internal ||
IntrID == Intrinsic::x86_tdpbf16ps_internal,
Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, Value *Row,
Value *Col, Value *K, Value *Acc,
Value *LHS, Value *RHS) {
std::string IntrinName;
switch (IntrID) {
case Intrinsic::x86_tdpbssd_internal:
IntrinName = "tiledpbssd";
break;
case Intrinsic::x86_tdpbsud_internal:
IntrinName = "tiledpbsud";
break;
case Intrinsic::x86_tdpbusd_internal:
IntrinName = "tiledpbusd";
break;
case Intrinsic::x86_tdpbuud_internal:
IntrinName = "tiledpbuud";
break;
case Intrinsic::x86_tdpbf16ps_internal:
IntrinName = "tiledpbf16ps";
break;
}
std::string IntrinName =
IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps";
Loop *RowLoop = nullptr;
Loop *ColLoop = nullptr;
Loop *InnerLoop = nullptr;
Expand Down Expand Up @@ -354,7 +329,7 @@ X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
Value *NewVecC = nullptr;

if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
if (IntrID == Intrinsic::x86_tdpbssd_internal) {
// tiledpbssd.scalarize.inner.body:
// calculate idxa, idxb
// %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
Expand All @@ -376,30 +351,12 @@ X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
Value *EltB = B.CreateExtractElement(VecB, IdxB);
Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
Value *SEXTSubVecB = nullptr;
Value *SEXTSubVecA = nullptr;
switch (IntrID) {
case Intrinsic::x86_tdpbssd_internal:
SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
break;
case Intrinsic::x86_tdpbsud_internal:
SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
break;
case Intrinsic::x86_tdpbusd_internal:
SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
break;
case Intrinsic::x86_tdpbuud_internal:
SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
break;
}
Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
Value *ResElt = B.CreateAdd(EltC, SubVecR);
NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
} else {
} else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
// tiledpbf16ps.scalarize.inner.body:
// calculate idxa, idxb, idxc
// %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
Expand Down Expand Up @@ -461,9 +418,6 @@ X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,

template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
IntrID == Intrinsic::x86_tdpbsud_internal ||
IntrID == Intrinsic::x86_tdpbusd_internal ||
IntrID == Intrinsic::x86_tdpbuud_internal ||
IntrID == Intrinsic::x86_tdpbf16ps_internal,
bool>::type
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
Expand Down Expand Up @@ -573,9 +527,6 @@ bool X86LowerAMXIntrinsics::visit() {
if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::x86_tdpbssd_internal:
case Intrinsic::x86_tdpbsud_internal:
case Intrinsic::x86_tdpbusd_internal:
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tilestored64_internal:
case Intrinsic::x86_tilezero_internal:
Expand All @@ -594,15 +545,6 @@ bool X86LowerAMXIntrinsics::visit() {
case Intrinsic::x86_tdpbssd_internal:
C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
break;
case Intrinsic::x86_tdpbsud_internal:
C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
break;
case Intrinsic::x86_tdpbusd_internal:
C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
break;
case Intrinsic::x86_tdpbuud_internal:
C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
break;
case Intrinsic::x86_tdpbf16ps_internal:
C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
break;
Expand Down

0 comments on commit 0c63b86

Please sign in to comment.