[mlir][sparse] support coiteration over sparse tensor slices
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140736
PeimingLiu committed Feb 15, 2023
1 parent 22b7685 commit e2e83f4
Showing 6 changed files with 304 additions and 48 deletions.
133 changes: 95 additions & 38 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -87,6 +87,30 @@ static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
return std::make_pair(v, rem);
}

static std::pair<Value, Value>
genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord,
SparseTensorEncodingAttr enc, unsigned lvl) {
std::pair<Value, Value> trans = fromSliceCoord(builder, loc, coord, enc, lvl);
// First, coord >= offset (TODO: an unsigned >= 0 comparison does not seem to
// be folded; skip the check when the offset is zero).
auto geOffset =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, coord,
getSliceOffset(builder, loc, enc, lvl));
// Second, coord_in_slice < length
auto ltLength =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
getSliceSize(builder, loc, enc, lvl));

// Third, rem == 0; confirmed that (a % 1) will be folded to 0
auto fitStride =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
constantIndex(builder, loc, 0));

auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
return {trans.first, pred};
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
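In scalar terms, the new genSliceLegitPredicate helper decides whether an absolute coordinate lands inside a strided slice and, if so, what its slice-relative coordinate is. A minimal C++ sketch of the check, assuming (per fromSliceCoord above) that the translation yields (coord - offset) / stride and (coord - offset) % stride:

#include <cstdint>
#include <utility>

// Scalar model of genSliceLegitPredicate; a sketch, not the emitted IR.
// Returns {slice-relative coordinate, whether the coordinate is legit}.
std::pair<uint64_t, bool> sliceLegit(uint64_t coord, uint64_t offset,
                                     uint64_t size, uint64_t stride) {
  bool geOffset = coord >= offset;                        // first check
  uint64_t rel = geOffset ? (coord - offset) / stride : 0;
  uint64_t rem = geOffset ? (coord - offset) % stride : 0;
  bool ltLength = rel < size;                             // second check
  bool fitStride = rem == 0;                              // third check
  return {rel, geOffset && ltLength && fitStride};
}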
@@ -353,31 +377,14 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
if (isSparseSlices[tid] && isSparseInput) {
// For sparse level slices, we need to filter out invalid coordinates that
// are not included in the slice.
std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
SmallVector<Type> types;
for (Value red : reduc)
types.push_back(red.getType());

// First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
// the check if the offset is zero).
auto geOff =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
getSliceOffset(builder, loc, enc, dim));
// Second, coords < length
auto ltLen = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, trans.first,
getSliceSize(builder, loc, enc, dim));

// Third, rem == 0; confirmed that (a % 1) will be folded to 0
auto fitStride = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, trans.second,
constantIndex(builder, loc, 0));

auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, enc, dim);
bool hasReduc = !types.empty();
scf::IfOp ifOp =
builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
/*else*/ hasReduc);
if (hasReduc) {
// scf.for (a) -> v
// %s = scf.if (a) -> v
@@ -392,15 +399,15 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
}
// Set the insertion point to matched branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
c = trans.first;
c = trans;
}

assert(c);
coord[tid][dim] = c;
// NOTE: we can also prepare for next dim here in advance
// Push the loop into stack
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
coord[tid][dim], loopTag);
builder.getInsertionBlock(), coord[tid][dim], loopTag);
// Emit extra locals.
emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);

@@ -470,7 +477,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
// NOTE: we can also prepare for next dim here in advance
// Push the loop into stack
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
coord[tid][dim], nullptr);
builder.getInsertionBlock(), coord[tid][dim], nullptr);
return forOp;
}

@@ -536,7 +543,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(

// Generates while body.
builder.setInsertionPointToStart(&whileOp.getAfter().front());
Value min;

SmallVector<std::pair<Value, unsigned>> slicesPreds;
unsigned i = 0;
for (auto [tid, dim] : llvm::zip(tids, dims)) {
// Prepares for next level.
if (isCompressedDLT(dimTypes[tid][dim]) ||
@@ -545,26 +554,73 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
Value s = pidxs[tid][dim];
Value load = genIndexLoad(builder, loc, ptr, s);
coord[tid][dim] = load;
if (!needsUniv) {
if (isSparseSlices[tid]) {
auto enc = getSparseTensorEncoding(tensors[tid].getType());
auto [trans, pred] =
genSliceLegitPredicate(builder, loc, load, enc, dim);
slicesPreds.emplace_back(pred, i);
// Update to the coordinate relative to the slice.
coord[tid][dim] = trans;
}
i++;
}
}

if (!slicesPreds.empty()) {
// Skips invalid loop iteration when slice coordinate is inapplicable.
SmallVector<Value> yields(after->getArguments());
// Generates a list of if statements
//   pidx = in_slice ? pidx : pidx + 1
// TODO: instead of always picking pidx + 1, we should set pidx = high to
// break the loop when the coordinate is larger than the slice size.
for (auto [pred, idx] : slicesPreds) {
Value nextPidx = builder.create<arith::AddIOp>(
loc, yields[idx], constantIndex(builder, loc, 1));
yields[idx] =
builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPidx);
}

Value pred = slicesPreds.front().first;
for (int i = 1, e = slicesPreds.size(); i < e; i++) {
pred = builder.create<arith::AndIOp>(loc, pred, slicesPreds[i].first);
}
auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
ifOp->setAttr(getLoopEmitterLoopAttrName(),
StringAttr::get(builder.getContext(), "slice"));
builder.create<scf::YieldOp>(loc, ifOp->getResults());
assert(types.size() == yields.size());
// If not all slices are legit, yield the updated positions to skip this
// iteration.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, yields);

// If all slices are legit, start the user generated code.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

Value min;
// Finds the minimum coordinate
if (!needsUniv) {
for (auto [tid, dim] : llvm::zip(tids, dims)) {
if (isCompressedDLT(dimTypes[tid][dim]) ||
isSingletonDLT(dimTypes[tid][dim])) {
if (min) {
Value cmp = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, load, min);
min = builder.create<arith::SelectOp>(loc, cmp, load, min);
loc, arith::CmpIPredicate::ult, coord[tid][dim], min);
min = builder.create<arith::SelectOp>(loc, cmp, coord[tid][dim], min);
} else {
min = load;
min = coord[tid][dim];
}
}
}
}

if (needsUniv) {
} else {
assert(!min);
// Otherwise, universal index is the minimal pidx.
min = after->getArguments().back();
}

// Sets up the loop stack.
loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min,
loopTag);
assert(loopStack.size() == loopSeqStack.size());

// Emits extra locals
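The slice handling in the while body above can be read as the following scalar sketch. It is a model under simplifying assumptions (a hypothetical SliceLevel record per sparse slice level, failing positions always advanced by one), not the generated scf ops:

#include <cstdint>
#include <vector>

// Hypothetical per-level state for the sketch below.
struct SliceLevel {
  size_t pidx;                   // position into the coordinate array
  std::vector<uint64_t> coords;  // stored absolute coordinates
  uint64_t offset, size, stride; // slice parameters
};

// One filtered co-iteration step: user code may run only when every slice
// predicate holds; otherwise each failing level advances one position
// ("pidx = in_slice ? pidx : pidx + 1") and the while loop retries.
bool filteredStep(std::vector<SliceLevel> &levels) {
  bool allLegit = true;
  for (SliceLevel &l : levels) {
    uint64_t c = l.coords[l.pidx];
    bool legit = c >= l.offset && (c - l.offset) % l.stride == 0 &&
                 (c - l.offset) / l.stride < l.size;
    if (!legit) {
      ++l.pidx;        // TODO in the patch: jump to `high` instead
      allLegit = false;
    }
  }
  return allLegit;     // true: execute the scf.if branch tagged "slice"
}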
@@ -642,6 +698,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
LoopLevelInfo &loopInfo = loopStack.back();
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
auto &dims = loopStack.back().dims;
auto &tids = loopStack.back().tids;
auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
@@ -722,12 +779,12 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,

void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
auto &dims = loopStack.back().dims;
auto &tids = loopStack.back().tids;
Value iv = loopStack.back().iv;
// Generation while loop induction at the end.
builder.setInsertionPointToEnd(&whileOp.getAfter().front());
const LoopLevelInfo &loopInfo = loopStack.back();
auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
auto &dims = loopInfo.dims;
auto &tids = loopInfo.tids;
Value iv = loopInfo.iv;
// Finalize the induction. Note that the induction could be performed
// in the individual if-branches to avoid re-evaluating the conditions.
// However, that would result in a rather elaborate forest of yield
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -170,8 +170,8 @@ class LoopEmitter {
private:
struct LoopLevelInfo {
LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
Value iv, StringAttr loopTag)
: tids(tids), dims(dims), loop(loop), iv(iv) {
Block *userBlock, Value iv, StringAttr loopTag)
: tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) {
// Attached a special tag to loop emitter generated loop.
if (loopTag)
loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
@@ -181,8 +181,9 @@ class LoopEmitter {
const llvm::SmallVector<size_t> tids;
// The corresponding dims for the tensors
const llvm::SmallVector<size_t> dims;
const Operation *loop; // the loop operation
const Value iv; // the induction variable for the loop
const Operation *loop; // the loop operation
Block *const userCodeBlock; // the block holding users' generated code.
const Value iv; // the induction variable for the loop
};

/// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
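The new userCodeBlock field records the block where user code is emitted: with slice filtering this is the then-block of a nested scf.if, not the loop body itself, so the exit paths can no longer derive the insertion point from the loop operation alone. A simplified sketch of the pattern (hypothetical signature, assuming the surrounding LoopEmitter definitions):

// Both exit paths rewind to the recorded block before finalizing.
void exitLoopSketch(OpBuilder &builder, const LoopLevelInfo &info) {
  // "End of the loop body" may now be inside an scf.if; the recorded
  // block is always where the user's code was generated.
  builder.setInsertionPointToEnd(info.userCodeBlock);
  // ... finalize reductions and the loop induction here ...
}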
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1066,6 +1066,11 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
builder.getInsertionBlock()->getParentOp())) {
// Break on IfOp for slicing filtering.
if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
StringAttr::get(ifOp->getContext(), "slice"))
break;

unsigned y = 0;
SmallVector<Value> yields;
if (env.isReduc()) {
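finalizeWhileOp walks the chain of enclosing scf.if operations to thread reduction and insertion values outward; values produced under the slice filter must not be yielded past it, hence the early break. A small sketch of recognizing the tagged op (assuming the attribute name returned by getLoopEmitterLoopAttrName, as used above):

// Detect the slice-filtering scf.if by its loop-emitter tag.
static bool isSliceFilterIf(scf::IfOp ifOp) {
  auto tag = ifOp->getAttrOfType<StringAttr>(
      LoopEmitter::getLoopEmitterLoopAttrName());
  return tag && tag.getValue() == "slice";
}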
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1300,11 +1300,11 @@ func.func @four_tensors_op(%arga: tensor<?xf64>,
// CHECK: scf.condition(%[[VAL_33]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]] : index, index, index, f64
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: f64):
// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
// CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
// CHECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index
// CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index
// CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1128,11 +1128,11 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
// CHECK: scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, f32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: f32):
// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
// CHECK: %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index
// CHECK: %[[VAL_64:.*]] = arith.select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index
// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
// CHECK: %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index
// CHECK: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index
// CHECK: %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
