Skip to content

Commit

Permalink
Reapply "[X86][AMX] Try to hoist AMX shapes' def"
Browse files Browse the repository at this point in the history
We require that there be no intersection between AMX instructions and their
shapes' defs when we insert ldtilecfg. However, this is not always true, not
only because users may not follow the AMX API model, but also because of
compiler optimizations.

This patch adds a mechanism that tries to hoist AMX shapes' defs as well.
It only hoists shapes within a BB; we can extend it to handle cases across
BBs in the future. Currently, it only hoists shapes whose sources are all
defined above the first AMX instruction. We can later also handle the case
where a source that moves an immediate value into a register sits below an
AMX instruction.

Reviewed By: xiangzhangllvm

Differential Revision: https://reviews.llvm.org/D101067
  • Loading branch information
phoebewang committed Apr 27, 2021
1 parent 9360430 commit 016092d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 17 deletions.
69 changes: 52 additions & 17 deletions llvm/lib/Target/X86/X86PreTileConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ struct MIRef {
++I, ++Pos)
MI = &*I;
}
MIRef(MachineInstr *MI)
: MI(MI), MBB(MI->getParent()),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
: MI(MI), MBB(MBB),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
Expand All @@ -66,6 +69,7 @@ struct MIRef {
bool operator==(const MIRef &RHS) const {
return MI == RHS.MI && MBB == RHS.MBB;
}
bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
bool operator<(const MIRef &RHS) const {
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
}
Expand All @@ -77,7 +81,7 @@ struct MIRef {
/// Per-basic-block facts collected before deciding where to insert ldtilecfg.
struct BBInfo {
// Reference to the first AMX instruction in this BB; used below to check
// whether shape defs intersect AMX instructions (see hoistShapesInBB).
MIRef FirstAMX;
// Reference to the last call in this BB. NOTE(review): its use is not
// visible in this excerpt — presumably calls can clobber the tile config;
// confirm against the full pass.
MIRef LastCall;
// NOTE(review): this is a diff rendering — LastShape is the field REMOVED
// by this commit (shape tracking moved into the ShapeBBs map) and
// HasAMXRegLiveIn is the field ADDED; the real struct has only one of them.
MIRef LastShape;
// True when an AMX register may be live-in to this BB (propagated from
// predecessors with AMX instructions, excluding loop back-edges).
bool HasAMXRegLiveIn = false;
// True when inserting a tile configuration in this BB is not allowed.
bool TileCfgForbidden = false;
// True when the tile config must already be live on entry to this BB.
bool NeedTileCfgLiveIn = false;
};
Expand All @@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass {
MachineRegisterInfo *MRI;
const MachineLoopInfo *MLI;
SmallSet<MachineInstr *, 8> DefVisited;
SmallSet<MachineBasicBlock *, 8> ShapeBBs;
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

/// Check if the callee will clobber AMX registers.
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
Expand Down Expand Up @@ -124,6 +128,32 @@ class X86PreTileConfig : public MachineFunctionPass {
/// Collect the shape def information for later use.
void collectShapeInfo(MachineInstr &MI);

/// Try to hoist shapes defined below AMX instructions.
///
/// Moves every shape-defining instruction in \p Shapes that currently sits
/// below the BB's first AMX instruction to just above that instruction, so
/// ldtilecfg can later be placed after all shape defs. On success, \p Shapes
/// is collapsed to a single entry: the last shape def now preceding the
/// first AMX instruction. Returns false when some shape cannot be hoisted
/// safely (instructions hoisted in earlier iterations stay moved).
bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
// Shapes is kept sorted by program position (see collectShapeInfo), so
// lower_bound yields the first shape def at or below the first AMX
// instruction; everything from here to the end needs hoisting.
auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
auto InsertPoint = FirstAMX.MI->getIterator();
for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
// Do not hoist instructions that access memory.
if (I->MI->mayLoadOrStore())
return false;
for (auto &MO : I->MI->operands()) {
if (MO.isDef())
continue;
// Do not hoist an instruction if any of its sources is itself defined
// below the first AMX instruction: hoisting would break the use-def
// order (assumes vreg SSA, hence the single getVRegDef).
// TODO: We can handle isMoveImmediate MI here.
if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
return false;
// TODO: Maybe need more checks here.
}
// Re-insert the shape def immediately before the first AMX instruction.
MBB->insert(InsertPoint, I->MI->removeFromParent());
}
// All shape defs now precede FirstAMX; --InsertPoint is the last one.
// We only need to mark the last shape in the BB now.
Shapes.clear();
Shapes.push_back(MIRef(&*--InsertPoint, MBB));
return true;
}

public:
X86PreTileConfig() : MachineFunctionPass(ID) {}

Expand Down Expand Up @@ -165,9 +195,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
MIRef MIR(MI, MBB);
if (BBVisitedInfo[MBB].LastShape < MIR)
BBVisitedInfo[MBB].LastShape = MIR;
ShapeBBs.insert(MBB);
auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
if (I == ShapeBBs[MBB].end() || *I != MIR)
ShapeBBs[MBB].insert(I, MIR);
};

SmallVector<Register, 8> WorkList(
Expand Down Expand Up @@ -229,6 +259,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
else
CfgLiveInBBs.push_back(&MBB);
}
if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
for (auto *Succ : MBB.successors())
if (!isLoopBackEdge(Succ, &MBB))
BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
}

// Update NeedTileCfgLiveIn for predecessors.
Expand All @@ -252,8 +286,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
return false;

// Avoid to insert ldtilecfg before any shape defs.
SmallVector<MachineBasicBlock *, 8> WorkList(
make_range(ShapeBBs.begin(), ShapeBBs.end()));
SmallVector<MachineBasicBlock *, 8> WorkList;
for (auto &I : ShapeBBs) {
// TODO: We can hoist shapes across BBs here.
if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
REPORT_CONFIG_FAIL
if (BBVisitedInfo[I.first].FirstAMX &&
BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
!hoistShapesInBB(I.first, I.second))
REPORT_CONFIG_FAIL
WorkList.push_back(I.first);
}
while (!WorkList.empty()) {
MachineBasicBlock *MBB = WorkList.pop_back_val();
for (auto *Pred : MBB->predecessors()) {
Expand Down Expand Up @@ -282,9 +325,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
} else {
// Avoid the BB to be multi visited.
VisitedOrInserted.insert(I);
// We cannot sink it across any AMX instruction.
if (BBVisitedInfo[I.MBB].FirstAMX)
REPORT_CONFIG_FAIL;
// Sink the inserting point along the chain with NeedTileCfgLiveIn =
// true when MBB isn't all shapes reachable.
for (auto *Succ : I.MBB->successors())
Expand All @@ -296,14 +336,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {

// A given point might be forked due to shape conditions are not met.
for (MIRef I : InsertPoints) {
// Even MBB is all shapes reachable, we still need to check if there's
// AMX that intersects with shapes in the same MBB.
if (BBVisitedInfo[I.MBB].FirstAMX &&
BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
REPORT_CONFIG_FAIL;
// Make sure we insert ldtilecfg after the last shape def in MBB.
if (I < BBVisitedInfo[I.MBB].LastShape)
I = BBVisitedInfo[I.MBB].LastShape;
if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
I = ShapeBBs[I.MBB].back();
// There're chances the MBB is sunk more than once. Record it to avoid
// multi insert.
if (VisitedOrInserted.insert(I).second) {
Expand Down
15 changes: 15 additions & 0 deletions llvm/test/CodeGen/X86/AMX/amx-sched.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
; CHECK-LABEL: test_shape_sched:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%c1 = bitcast <256 x i32> %c to x86_amx
Expand All @@ -12,5 +13,19 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
ret <256 x i32> %res
}

; Exercises the shape-hoisting path added by this patch: the shape %aa is
; computed by an lshr that, without hoisting, could end up below the tile
; loads in the emitted code; X86PreTileConfig must hoist it above ldtilecfg.
define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
; CHECK-LABEL: test_shape_sched2:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%aa = lshr i16 %k, 2
%c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
%a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
%b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
%t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
%res = bitcast x86_amx %t to <256 x i32>
ret <256 x i32> %res
}

declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

0 comments on commit 016092d

Please sign in to comment.