[mlir][sparse] group tensor id and levels into pairs in loop emitter
This addresses some unresolved review comments in https://reviews.llvm.org/D142930.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D148565
PeimingLiu committed May 4, 2023
1 parent 9c4717a commit 36c95ee
Showing 5 changed files with 156 additions and 143 deletions.
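The heart of this change is replacing parallel ArrayRef<TensorId>/ArrayRef<Level> parameters with a single ArrayRef<TensorLevel>, where one TensorLevel packs a (tensor id, level) pair into a single value. The sketch below shows one plausible packing scheme; the member names follow the patch, but the exact encoding and the LoopEmitterSketch wrapper are illustrative assumptions, not the actual LoopEmitter.h implementation.

#include <cassert>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;
// A single word identifying one level of one tensor (illustrative encoding).
using TensorLevel = unsigned;

class LoopEmitterSketch {
public:
  explicit LoopEmitterSketch(unsigned numTensors) : numTensors(numTensors) {}

  // Packs (t, l) into one value; injective as long as t < numTensors.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    assert(t < numTensors && "tensor id out of range");
    return l * numTensors + t;
  }

  // Recovers the original pair; the exact inverse of makeTensorLevel.
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
    return {tl % numTensors, tl / numTensors};
  }

private:
  unsigned numTensors;
};

Under this encoding, makeTensorLevel(2, 3) with four tensors yields 14, and unpackTensorLevel(14) returns {2, 3} again.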
19 changes: 19 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -85,6 +85,25 @@ class CodegenEnv {
return latticeMerger.getDimLevelType(b);
}

+//
+// LoopEmitter delegates.
+//
+
+constexpr TensorLevel makeTensorLevel(TensorId t, Level l) const {
+// Make sure LoopEmitter, GenericOp, and Merger agree on the number of
+// tensors. Merger has one more synthetic tensor for loop invariants.
+assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() &&
+loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1);
+return loopEmitter.makeTensorLevel(t, l);
+}
+std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
+return loopEmitter.unpackTensorLevel(tl);
+}
+template <class ContainerTy>
+auto unpackTensorLevelRange(ContainerTy &&c) const {
+return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
+}
+
//
// Code generation environment verify functions.
//
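The unpackTensorLevelRange helper delegated above can be implemented as a lazy range adaptor. A minimal sketch, assuming llvm::map_range from llvm/ADT/STLExtras.h and the packing scheme from the previous sketch; the real member in LoopEmitter.h may differ:

#include "llvm/ADT/STLExtras.h" // for llvm::map_range
#include <utility>

struct RangeSketch {
  unsigned numTensors;

  std::pair<unsigned, unsigned> unpackTensorLevel(unsigned tl) const {
    return {tl % numTensors, tl / numTensors};
  }

  // Lazily views a range of packed TensorLevels as (tid, lvl) pairs, so
  // callers can write: for (auto [tid, lvl] : unpackTensorLevelRange(c)).
  template <class ContainerTy>
  auto unpackTensorLevelRange(ContainerTy &&c) const {
    return llvm::map_range(std::forward<ContainerTy>(c),
                           [this](unsigned tl) { return unpackTensorLevel(tl); });
  }
};

Because the adaptor is lazy, the loops in the diff below iterate over (tid, lvl) pairs without materializing temporary vectors.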
101 changes: 43 additions & 58 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -456,13 +456,12 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
-ArrayRef<TensorId> tids,
-ArrayRef<Level> lvls) {
+ArrayRef<TensorLevel> tidLvls) {
// TODO: sort
assert(loopSeqStack.size() == loopStack.size());
// Prepares for all the tensors used in the current loop sequence.
std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
-for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
if (!dependentLvlMap[tid][lvl].empty()) {
bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
slicedTids.emplace_back(tid, lvl, fullyRed);
@@ -660,17 +659,19 @@ Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl(
return loop;
}

-Operation *LoopEmitter::enterLoopOverTensorAtLvl(
-OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
+Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder,
+Location loc,
+ArrayRef<TensorLevel> tidLvls,
+MutableArrayRef<Value> reduc,
+bool isParallel) {
// TODO: support multiple return on parallel for?
assert(!isParallel || reduc.size() <= 1);
bool isSparseCond = false, isSparseSliceCond = false;
-size_t tid = tids.front(), lvl = lvls.front();
+auto [tid, lvl] = unpackTensorLevel(tidLvls.front());

// Finds out the tensor level that we should use to generate loops. Amongst all
// the tensor levels, there is at most one sparse tensor level.
-for (auto [t, l] : llvm::zip(tids, lvls)) {
+for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair
assert(!coords[t][l] || // We cannot re-enter the same level
!dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
@@ -712,12 +713,9 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
Operation *l = nullptr;

// At most one tensor used as condition in for loop;
-SmallVector<TensorId, 1> condTid;
-SmallVector<Level, 1> condLvl;
-// There Might be multiple dense slice driven tensor.
-SmallVector<TensorId> sliceTids;
-SmallVector<Level> sliceLvls;
-SmallVector<bool> sliceReduc;
+SmallVector<TensorLevel, 1> condTidLvl;
+// There might be multiple dense slice driven tensor.
+SmallVector<SliceLoopInfo> sliceDrivenInfo;

// Generates loops differently depending on whether we need a slice-driven
// loop or a simple level traversal loop.
@@ -734,9 +732,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
lvl, reduc);
}
levelReducedDep[tid][lvl]++;
-sliceTids.push_back(tid);
-sliceLvls.push_back(lvl);
-sliceReduc.push_back(fullyReduced);
+sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced);
} else {
Value lo = isSparseCond ? posits[tid][lvl] // current offset
: loopSeqStack.back().first; // universal index
@@ -747,21 +743,19 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
// Adjust for loop hi for dense slice-driven loop.
if (fullyReduced) {
hi = sliceSz;
-condTid.push_back(tid);
-condLvl.push_back(lvl);
+condTidLvl.push_back(makeTensorLevel(tid, lvl));
} else {
hi = SUBI(lvlSizes[tid][lvl], sliceSz);
hi = ADDI(hi, C_IDX(1));
}
} else {
-condTid.push_back(tid);
-condLvl.push_back(lvl);
+condTidLvl.push_back(makeTensorLevel(tid, lvl));
}
l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc,
isParallel);
}
Value iv = coords[tid][lvl];
-for (auto [t, l] : llvm::zip(tids, lvls)) {
+for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
// We only need to handle slice-driven loops on dense level here.
// If it is a slice-driven loop on sparse level, it needs a while loop to
// insert break statements, and it must have been handled correctly in L692.
@@ -774,9 +768,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
} else {
// Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to
// exit it.
-sliceTids.push_back(t);
-sliceLvls.push_back(l);
-sliceReduc.push_back(fullyReduc);
+sliceDrivenInfo.emplace_back(t, l, fullyReduc);
// Update the slice information as we enter the new loop.
assert(*info.slicedOnLvl == l);
info.minCrd = info.offset = iv;
@@ -787,10 +779,10 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
}
// NOTE: we can also prepare for next dim here in advance
// Pushes the loop into stack.
-loopStack.emplace_back(condTid, condLvl, sliceTids, sliceLvls, sliceReduc, l,
+loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l,
builder.getInsertionBlock(), iv, loopTag);
// Emit extra locals.
-emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
return l;
}

@@ -854,33 +846,33 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(

// NOTE: we can also prepare for next lvl here in advance
// Push the loop into stack
-loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl),
-ArrayRef<TensorId>(), ArrayRef<Level>(),
-ArrayRef<bool>(), forOp, builder.getInsertionBlock(),
-coords[tid][lvl], nullptr);
+loopStack.emplace_back(ArrayRef<TensorLevel>(makeTensorLevel(tid, lvl)),
+ArrayRef<SliceLoopInfo>(), forOp,
+builder.getInsertionBlock(), coords[tid][lvl],
+nullptr);
return forOp;
}

void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
-TensorId tid, Level lvl,
+TensorLevel tidLvl,
AffineExpr lvlExpr) {
+auto [tid, lvl] = unpackTensorLevel(tidLvl);
assert(isDenseDLT(lvlTypes[tid][lvl]));
// For dense levels, the level-coordinate also serves as the position.
Value lvlCrd = genAffine(builder, loc, lvlExpr);
posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd);
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
-OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc) {
+OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
+bool needsUniv, MutableArrayRef<Value> reduc) {
// NOTE: the slice driven tensor-related reduction variable must
// appear before normal tensors.
-assert(tids.size() == lvls.size());
SmallVector<Type> types;
SmallVector<Value> operands;
// Construct the while-loop with a parameter for each coordinate.
const Type indexType = builder.getIndexType();
-for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
// TODO: support coiteration with slice driven tensors.
const auto lvlTp = lvlTypes[tid][lvl];
assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
@@ -922,7 +914,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
builder.setInsertionPointToStart(&whileOp.getBefore().front());
Value cond;
unsigned o = 0;
-for (auto [t, lvl] : llvm::zip(tids, lvls)) {
+for (auto [t, lvl] : unpackTensorLevelRange(tidLvls)) {
const TensorId tid = t; // Why `t` can not be captured by lambda?
const auto lvlTp = lvlTypes[tid][lvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
@@ -956,7 +948,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(

SmallVector<std::pair<Value, unsigned>> slicesPreds;
unsigned i = 0;
-for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
// Prepares for next level.
const auto lvlTp = lvlTypes[tid][lvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
@@ -1007,7 +999,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
Value min;
// Finds the minimum coordinate
if (!needsUniv) {
-for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
const auto lvlTp = lvlTypes[tid][lvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
isCompressedWithHiDLT(lvlTp)) {
@@ -1027,12 +1019,11 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
}

// Sets up the loop stack.
-loopStack.emplace_back(tids, lvls, ArrayRef<TensorId>(), ArrayRef<Level>(),
-ArrayRef<bool>(), whileOp, builder.getInsertionBlock(),
-min, loopTag);
+loopStack.emplace_back(tidLvls, ArrayRef<SliceLoopInfo>(), whileOp,
+builder.getInsertionBlock(), min, loopTag);
assert(loopStack.size() == loopSeqStack.size());

-for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, dstLvl] : unpackTensorLevelRange(tidLvls)) {
const auto reassoc = getCollapseReassociation(tid, dstLvl);
assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
// TODO: Refactors this into smaller functions.
@@ -1079,7 +1070,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
}

// Emits extra locals
-emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);

// Updates reduction variables
assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
@@ -1140,15 +1131,12 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
llvm_unreachable("Unrecognized level-type!");
}

-void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder,
-Location loc,
-ArrayRef<TensorId> tids,
-ArrayRef<Level> lvls) {
+void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
+OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls) {
// Initialize dense positions. Note that we generate dense coordinates of the
// output tensor unconditionally, since they may not appear in the lattice,
// but may be needed for linearized codegen.
-assert(tids.size() == lvls.size());
-for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
if (isDenseDLT(lvlTypes[tid][lvl])) {
// Slice-driven dense level should have been handled already.
if (!dependentLvlMap[tid][lvl].empty())
@@ -1175,8 +1163,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
-for (auto [tid, lvl, reduced] : llvm::zip(
-loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
SliceInfo &info = sliceStack[tid].back();
assert(isDenseDLT(lvlTypes[tid][lvl]));
assert(*info.slicedOnLvl == lvl && !reduced);
@@ -1253,7 +1240,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
// Finished iterating a tensor, clean up
// We only do the clean up on for loops, as while loops do not necessarily
// finish the iteration on a sparse tensor
-for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
+for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
// Reset to null.
coords[tid][lvl] = Value();
posits[tid][lvl] = Value();
@@ -1278,8 +1265,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
unsigned o = 0;
SmallVector<Value> operands;
unsigned delta = 0;
-for (auto [tid, lvl, resolved] : llvm::zip(
-loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) {
// TODO: handle dense.
assert(isCompressedDLT(lvlTypes[tid][lvl]));
levelReducedDep[tid][lvl]--;
@@ -1291,7 +1277,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// fully reduced while op for iterating one slice.
// FIXME: since we didn't implement coiteration, this must be iteration
// just on fully resolved slice.
-assert(loopInfo.slicedTids.size() == 1 && loopInfo.tids.empty());
+assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty());
// The if guard to filter out out-range coordinates.
assert(llvm::isa<scf::IfOp>(builder.getInsertionBlock()->getParentOp()));
posits[tid][lvl] = whileOp->getResult(o++);
@@ -1308,7 +1294,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
};

Value one = C_IDX(1);
-for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
+for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
const auto lvlTp = lvlTypes[tid][dstLvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
isCompressedWithHiDLT(lvlTp)) {
@@ -1376,7 +1362,6 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
// Clean up the values; it would help us to discover a potential bug at an
// earlier stage (instead of silently using a wrong value).
const LoopInfo &loopInfo = loopStack.back();
-assert(loopInfo.tids.size() == loopInfo.lvls.size());
SmallVector<Value> red;
if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
exitWhileLoop(rewriter, loc, reduc);
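On the slice-driven path, the same idea collapses three parallel vectors (slicedTids, slicedLvls, sliceReduced) into one SmallVector<SliceLoopInfo>. The struct below is a hypothetical reconstruction inferred purely from how this diff uses it (emplace_back(tid, lvl, reduced) and three-way structured bindings); the actual definition lives in LoopEmitter.h.

#include "llvm/ADT/SmallVector.h"

using TensorId = unsigned;
using Level = unsigned;

// Hypothetical shape of SliceLoopInfo, inferred from this diff's usage.
struct SliceLoopInfo {
  SliceLoopInfo(TensorId tid, Level lvl, bool reduced)
      : tid(tid), lvl(lvl), reduced(reduced) {}
  TensorId tid;
  Level lvl;
  bool reduced;
};

void exampleUsage() {
  llvm::SmallVector<SliceLoopInfo> sliceDrivenInfo;
  sliceDrivenInfo.emplace_back(/*tid=*/0, /*lvl=*/1, /*reduced=*/true);
  // One record per slice keeps the three fields in sync by construction,
  // replacing llvm::zip over three parallel vectors.
  for (auto [tid, lvl, reduced] : sliceDrivenInfo) {
    (void)tid;
    (void)lvl;
    (void)reduced;
  }
}

This is also why the patch can drop runtime checks such as assert(tids.size() == lvls.size()): with a single vector of records, the invariant holds by construction.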
