Skip to content

Commit

Permalink
Revert "[X86][AMX] Try to hoist AMX shapes' def"
Browse files Browse the repository at this point in the history
This reverts commit 9011856.

Reason: Broke the MSan buildbots.
https://lab.llvm.org/buildbot/#/builders/5/builds/6967/steps/9/logs/stdio

More details can be found in the original phabricator review:
https://reviews.llvm.org/D101067
  • Loading branch information
hctim committed Apr 23, 2021
1 parent e10d7d4 commit caea37b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 68 deletions.
70 changes: 17 additions & 53 deletions llvm/lib/Target/X86/X86PreTileConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ struct MIRef {
++I, ++Pos)
MI = &*I;
}
MIRef(MachineInstr *MI)
: MI(MI), MBB(MI->getParent()),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
: MI(MI), MBB(MBB),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
Expand All @@ -69,7 +66,6 @@ struct MIRef {
bool operator==(const MIRef &RHS) const {
return MI == RHS.MI && MBB == RHS.MBB;
}
bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
bool operator<(const MIRef &RHS) const {
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
}
Expand All @@ -81,7 +77,7 @@ struct MIRef {
/// Per-basic-block bookkeeping; instances are stored in BBVisitedInfo, keyed
/// by MachineBasicBlock.
struct BBInfo {
// Tested in boolean context to ask whether the block contains an AMX
// instruction, and compared against shape positions in the same block.
MIRef FirstAMX;
// Presumably the last call instruction seen in the block — TODO confirm;
// no use is visible in this view.
MIRef LastCall;
// Set on every non-loop-back-edge successor of a block that either has a
// FirstAMX or is itself marked HasAMXRegLiveIn.
bool HasAMXRegLiveIn = false;
// Greatest MIRef recorded for this block by collectShapeInfo's RecordShape.
MIRef LastShape;
// NOTE(review): name suggests ldtilecfg must not be placed in this block;
// the code that sets it is not visible here — confirm.
bool TileCfgForbidden = false;
// NOTE(review): updated for predecessors in runOnMachineFunction; the exact
// propagation rule is not visible here — confirm.
bool NeedTileCfgLiveIn = false;
};
Expand All @@ -90,8 +86,8 @@ class X86PreTileConfig : public MachineFunctionPass {
MachineRegisterInfo *MRI;
const MachineLoopInfo *MLI;
SmallSet<MachineInstr *, 8> DefVisited;
SmallSet<MachineBasicBlock *, 8> ShapeBBs;
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

/// Check if the callee will clobber AMX registers.
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
Expand Down Expand Up @@ -128,33 +124,6 @@ class X86PreTileConfig : public MachineFunctionPass {
/// Collect the shape def information for later use.
void collectShapeInfo(MachineInstr &MI);

/// Try to hoist shapes defined below AMX instructions.
///
/// Walks the shape defs recorded in ShapeBBs[MBB] that do not order before
/// BBVisitedInfo[MBB].FirstAMX and moves each one to just in front of the
/// first AMX instruction of \p MBB.
///
/// \returns true when every such shape def was moved; ShapeBBs[MBB] is then
/// reduced to a single entry. Returns false as soon as one shape def either
/// may load or store, or has a register use whose defining instruction orders
/// after FirstAMX.
///
/// NOTE(review): on a false return, shape defs visited earlier in the loop
/// have already been moved and ShapeBBs[MBB] is left untouched.
bool hoistShapesInBB(MachineBasicBlock *MBB) {
// First recorded shape that is not ordered before the first AMX instruction.
auto FirstShapeBelowAMX =
llvm::lower_bound(ShapeBBs[MBB], BBVisitedInfo[MBB].FirstAMX);
// Hoisted instructions are inserted in front of the first AMX instruction.
auto InsertPoint = BBVisitedInfo[MBB].FirstAMX.MI->getIterator();
for (auto I = FirstShapeBelowAMX, E = ShapeBBs[MBB].end(); I != E; ++I) {
// Do not hoist instructions that access memory.
if (I->MI->mayLoadOrStore())
return false;
for (auto &MO : I->MI->operands()) {
// Only source (use) operands constrain hoisting.
if (MO.isDef())
continue;
// Do not hoist the instruction if a source's def is below the first AMX
// instruction.
// NOTE(review): getVRegDef's result is passed to MIRef without a null
// check, and the register is not checked to be virtual — confirm both
// always hold here. MIRef's operator> is not visible in this view;
// operator< orders by block pointer first, so a def outside MBB is
// presumably ordered by pointer value — TODO confirm.
// TODO: We can handle isMoveImmediate MI here.
if (MO.isReg() &&
MIRef(MRI->getVRegDef(MO.getReg())) > BBVisitedInfo[MBB].FirstAMX)
return false;
// TODO: Maybe need more checks here.
}
// Detach the shape def and re-insert it before the first AMX instruction.
MBB->insert(InsertPoint, I->MI->removeFromParent());
}
// We only need to mark the last shape in the BB now: the instruction right
// before the first AMX instruction, i.e. the last shape def hoisted above.
ShapeBBs[MBB].clear();
ShapeBBs[MBB].push_back(MIRef(&*--InsertPoint, MBB));
return true;
}

public:
X86PreTileConfig() : MachineFunctionPass(ID) {}

Expand Down Expand Up @@ -196,9 +165,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
MIRef MIR(MI, MBB);
auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
if (*I != MIR)
ShapeBBs[MBB].insert(I, MIR);
if (BBVisitedInfo[MBB].LastShape < MIR)
BBVisitedInfo[MBB].LastShape = MIR;
ShapeBBs.insert(MBB);
};

SmallVector<Register, 8> WorkList(
Expand Down Expand Up @@ -260,10 +229,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
else
CfgLiveInBBs.push_back(&MBB);
}
if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
for (auto *Succ : MBB.successors())
if (!isLoopBackEdge(Succ, &MBB))
BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
}

// Update NeedTileCfgLiveIn for predecessors.
Expand All @@ -287,17 +252,8 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
return false;

// Avoid inserting ldtilecfg before any shape defs.
SmallVector<MachineBasicBlock *, 8> WorkList;
for (auto &I : ShapeBBs) {
// TODO: We can hoist shapes across BBs here.
if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
REPORT_CONFIG_FAIL
if (BBVisitedInfo[I.first].FirstAMX &&
BBVisitedInfo[I.first].FirstAMX < ShapeBBs[I.first].back() &&
!hoistShapesInBB(I.first))
REPORT_CONFIG_FAIL
WorkList.push_back(I.first);
}
SmallVector<MachineBasicBlock *, 8> WorkList(
make_range(ShapeBBs.begin(), ShapeBBs.end()));
while (!WorkList.empty()) {
MachineBasicBlock *MBB = WorkList.pop_back_val();
for (auto *Pred : MBB->predecessors()) {
Expand Down Expand Up @@ -326,6 +282,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
} else {
// Avoid visiting the BB more than once.
VisitedOrInserted.insert(I);
// We cannot sink it across any AMX instruction.
if (BBVisitedInfo[I.MBB].FirstAMX)
REPORT_CONFIG_FAIL;
// Sink the inserting point along the chain with NeedTileCfgLiveIn =
// true when MBB isn't all shapes reachable.
for (auto *Succ : I.MBB->successors())
Expand All @@ -337,9 +296,14 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {

// A given point might be forked because shape conditions are not met.
for (MIRef I : InsertPoints) {
// Even if MBB is all-shapes-reachable, we still need to check whether there
// is an AMX instruction that intersects with shapes in the same MBB.
if (BBVisitedInfo[I.MBB].FirstAMX &&
BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
REPORT_CONFIG_FAIL;
// Make sure we insert ldtilecfg after the last shape def in MBB.
if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
I = ShapeBBs[I.MBB].back();
if (I < BBVisitedInfo[I.MBB].LastShape)
I = BBVisitedInfo[I.MBB].LastShape;
// The MBB may be sunk more than once. Record it to avoid inserting
// ldtilecfg multiple times.
if (VisitedOrInserted.insert(I).second) {
Expand Down
15 changes: 0 additions & 15 deletions llvm/test/CodeGen/X86/AMX/amx-sched.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
; CHECK-LABEL: test_shape_sched:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%c1 = bitcast <256 x i32> %c to x86_amx
Expand All @@ -13,19 +12,5 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
ret <256 x i32> %res
}

define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
; CHECK-LABEL: test_shape_sched2:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%aa = lshr i16 %k, 2
%c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
%a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
%b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
%t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
%res = bitcast x86_amx %t to <256 x i32>
ret <256 x i32> %res
}

declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

0 comments on commit caea37b

Please sign in to comment.