[mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142929
PeimingLiu committed Mar 20, 2023
1 parent 8d024a7 commit 1328bb6
Showing 7 changed files with 201 additions and 124 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -399,11 +399,17 @@ class Merger {
/// to sparse level-type.
bool hasAnySparse(const BitVector &bits) const;

/// Returns true if `bits` contains a dependent index reduction condition
/// on sparse levels.
bool hasSparseIdxReduction(const BitVector &bits) const;

/// Gets the level-type of the `t`th tensor on `i`th loop.
DimLevelType getDimLevelType(TensorId t, LoopId i) const {
assert(t < numTensors && i < numLoops);
return lvlTypes[t][i];
}

/// Gets the level-type of the TensorLoopId.
DimLevelType getDimLevelType(TensorLoopId b) const {
return getDimLevelType(tensor(b), loop(b));
}
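For context, a minimal sketch of what `hasSparseIdxReduction` could look like (an illustration under assumptions, not necessarily the committed body; the helper `isLvlWithNonTrivialIdxExp` is an assumed predicate that tests whether a `TensorLoopId` carries a non-trivial index expression):

// Hypothetical sketch: scan the set bits of a lattice-point bitvector
// and report whether any of them carries a dependent (non-trivial)
// index expression on a sparse level.
bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}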
58 changes: 51 additions & 7 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,6 +28,23 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the elements of `target` according to their positions in `order`.
static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
ArrayRef<LoopId> order) {
std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
assert(l != r);
int idxL = -1, idxR = -1;
for (int i = 0, e = order.size(); i < e; i++) {
if (order[i] == l)
idxL = i;
if (order[i] == r)
idxR = i;
}
assert(idxL >= 0 && idxR >= 0);
return idxL < idxR;
});
}
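// A quick standalone illustration of the helper above (values invented
// for this sketch):
//
//   std::vector<LoopId> target = {2, 3, 1};
//   sortArrayBasedOnOrder(target, /*order=*/{3, 1, 2});
//   // target is now {3, 1, 2}: elements sort by position in `order`.
//
// Each comparison scans `order` linearly; that is acceptable here since
// loop counts stay small.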

//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//
@@ -57,15 +74,42 @@ void CodegenEnv::startEmit() {
insChain = sparseOut->get();
latticeMerger.setHasSparseOut(true);
}

// Sort the related loop arrays such that they appear in the same order
// as in `topSort`.
// TODO: since we only handle affine addition for slice-based codegen, and
// addition is associative, the order in which we evaluate the expression
// does not matter. However, to support multiplication, the order of the
// loop indices should match the evaluation order of the affine expression
// AST.

// Initialize loop emitter.
SmallVector<Value> tensors;
for (OpOperand &t : linalgOp->getOpOperands())
SmallVector<Value> tensors; // input tensors passed to loop emitter
for (OpOperand &t : linalgOp->getOpOperands()) {
tensors.push_back(t.get());
loopEmitter.initialize(tensors,
StringAttr::get(linalgOp.getContext(),
linalg::GenericOp::getOperationName()),
/*hasOutput=*/true,
/*isSparseOut=*/sparseOut != nullptr, topSort);
Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
for (Level lvl = 0; lvl < rank; lvl++) {
sortArrayBasedOnOrder(
latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
}
}

loopEmitter.initialize(
tensors,
StringAttr::get(linalgOp.getContext(),
linalg::GenericOp::getOperationName()),
/*hasOutput=*/true,
/*isSparseOut=*/sparseOut != nullptr, topSort,
// TODO: compute the map and pass it to loop emitter directly instead of
// passing in a callback.
[this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
// Translates a list of loop indices into a list of [tid, dim] pairs.
std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
std::vector<std::pair<TensorId, Level>> ret;
ret.reserve(rLoops.size());
for (LoopId l : rLoops)
ret.emplace_back(this->merger().getLoopDefiningLvl(l));
return ret;
});
}
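// To make the callback's contract concrete, a hypothetical scenario
// (tensor ids and levels invented for illustration): if level 0 of
// tensor 1 is addressed by the affine index (d0 + d1), where d0 is
// defined by level 0 of tensor 0 and d1 by level 1 of tensor 0, then
//
//   getter(/*t=*/1, /*lvl=*/0) returns {{0, 0}, {0, 1}}
//
// while a level addressed by a trivial index (a bare d_i) yields {}.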

std::optional<Operation *> CodegenEnv::genLoopBoundary(
1 change: 0 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,7 +99,6 @@ class CodegenEnv {
topSort.reserve(capacity);
}

ArrayRef<LoopId> getTopSort() const { return topSort; };
ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
ArrayRef<LoopId> getCurrentLoopStack() const;
23 changes: 17 additions & 6 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -208,12 +208,14 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
}

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, ArrayRef<LoopId> topSort) {
initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
bool isSparseOut, ArrayRef<LoopId> topSort,
DependentLvlGetter dimGetter) {
initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, ArrayRef<LoopId> topSort) {
bool isSparseOut, ArrayRef<LoopId> topSort,
DependentLvlGetter dimGetter) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
@@ -242,6 +244,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->loopStack.reserve(numLoops);
this->loopSeqStack.reserve(numLoops);

this->dependentLvlMap.assign(
numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());

// Initialize nested types of `TensorId`-indexed fields.
for (TensorId tid = 0; tid < numTensors; tid++) {
const Value t = tensors[tid];
@@ -283,6 +288,12 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
coordinatesBuffers[tid].assign(lvlRank, Value());
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());

dependentLvlMap[tid].assign(lvlRank,
std::vector<std::pair<TensorId, Level>>());
if (dimGetter)
for (Level l = 0; l < lvlRank; l++)
dependentLvlMap[tid][l] = dimGetter(tid, l);
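// (Descriptive note, assumed from the code above: dependentLvlMap[tid][l]
// now lists the [TensorId, Level] pairs whose coordinates the
// slice-based iteration of level `l` depends on; it stays empty for
// trivially indexed levels.)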
}

// Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1008,8 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
}
}

void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1093,7 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
assert(loopInfo.tids.size() == loopInfo.lvls.size());
SmallVector<Value> red;
if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
exitCoIterationLoop(rewriter, loc, reduc);
exitWhileLoop(rewriter, loc, reduc);
} else {
exitForLoop(rewriter, loc, reduc);
}
27 changes: 23 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -76,6 +76,14 @@ class LoopEmitter {
/// initializing the loop emitter (e.g., to fill a dense output with zeros).
using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
Value memref, Value tensor)>;
// Map from [tid, dim] to a list of dependent [tid, dim] pairs for an
// affine index expression on sparse tensors.
// E.g., the affine index (d0 + d1) depends on the two [tid, dim] pairs
// that define d0 and d1 (for affine expression reduction).
// If the list is empty, there is no affine expression on the input
// [tid, dim].
using DependentLvlGetter =
function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
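// A hedged caller-side sketch (assumed usage, not part of this header):
// a trivial getter reporting no dependent levels, so every level
// iterates without slicing. Note that `function_ref` does not own the
// callable, so the lambda must outlive the call:
//
//   auto trivial = [](TensorId, Level)
//       -> std::vector<std::pair<TensorId, Level>> { return {}; };
//   emitter.initialize(tensors, /*loopTag=*/nullptr, /*hasOutput=*/true,
//                      /*isSparseOut=*/false, /*topSort=*/{}, trivial);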

LoopEmitter() = default;

@@ -89,11 +97,13 @@
/// to `LoopId`.
void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
ArrayRef<LoopId> topSort = {});
ArrayRef<LoopId> topSort = {},
DependentLvlGetter getter = nullptr);

explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
ArrayRef<LoopId> topSort = {});
ArrayRef<LoopId> topSort = {},
DependentLvlGetter getter = nullptr);

/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
@@ -295,8 +305,8 @@ class LoopEmitter {
MutableArrayRef<Value> reduc);

/// Exits a while loop, returns the reduction results.
void exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc);
void exitWhileLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc);

//
// View-based-reshape methods.
@@ -380,6 +390,15 @@ class LoopEmitter {
std::vector<std::vector<Value>> sliceOffsets;
std::vector<std::vector<Value>> sliceStrides;

// Map from [tid, level] to a list of dependent [tid, level].
// See comments for `DependentLvlGetter`.
std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
dependentLvlMap;

//
// View-based reshape related fields and methods.
//

/// Collapse reassociations related to a specific tensor.
// TODO: support expand.
std::vector<ArrayAttr> collapseReassoc;