[mlir][sparse] support dynamic sparse tensor slices.
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D141532
PeimingLiu committed Mar 10, 2023
1 parent 8a712bf commit 6db397a
Showing 14 changed files with 450 additions and 112 deletions.
22 changes: 9 additions & 13 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -570,7 +570,7 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
 /// We normalize the sparse tensor encoding attribute by always using
 /// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
 /// as other variants) lead to the same storage specifier type, and stripping
-/// irrelevant fields that does not alter the sparse tensor memory layout.
+/// irrelevant fields that do not alter the sparse tensor memory layout.
 static SparseTensorEncodingAttr
 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   SmallVector<DimLevelType> dlts;
@@ -582,13 +582,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       AffineMap(), // dimOrdering (irrelevant to storage specifier)
       AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
       // Always use `index` for memSize and lvlSize instead of reusing
-      // `getPosWidth`/`getCrdWidth`.
-      // It allows us to reuse the same SSA value for different bitwidth,
-      // It also avoids casting between index/integer (returned by DimOp)
-      0, 0,
-      // FIXME: we should keep the slice information, for now it is okay as only
-      // constant can be used for slice
-      ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
+      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
+      // value for different bitwidth; it also avoids casting between index and
+      // integer (returned by DimOp).
+      0, 0, enc.getDimSlices());
}

StorageSpecifierType
@@ -620,11 +617,10 @@ static LogicalResult verifySparsifierGetterSetter(
const auto enc = md.getType().getEncoding();
const Level lvlRank = enc.getLvlRank();

-  // TODO:
-  // if (mdKind == StorageSpecifierKind::DimOffset ||
-  //     mdKind == StorageSpecifierKind::DimStride)
-  //   if (!enc.isSlice())
-  //     return op->emitError("requested slice data on non-slice tensor");
+  if (mdKind == StorageSpecifierKind::DimOffset ||
+      mdKind == StorageSpecifierKind::DimStride)
+    if (!enc.isSlice())
+      return op->emitError("requested slice data on non-slice tensor");

if (mdKind != StorageSpecifierKind::ValMemSize) {
if (!lvl)
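The verifier hunk above promotes a previously commented-out TODO into an enforced rule. In scalar terms (editor's sketch, not project code; the enum is abbreviated to the kinds this diff names):

enum class StorageSpecifierKind { ValMemSize, DimOffset, DimStride /* others elided */ };

// Querying slice metadata (DimOffset/DimStride) on a storage specifier is
// only legal when the underlying encoding is itself a slice.
bool isLegalSpecifierQuery(StorageSpecifierKind kind, bool encodingIsSlice) {
  bool needsSlice = kind == StorageSpecifierKind::DimOffset ||
                    kind == StorageSpecifierKind::DimStride;
  return !needsSlice || encodingIsSlice;
}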
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -694,3 +694,23 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

+Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
+                                               Value tensor, Dimension dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
+  if (offset.has_value())
+    return constantIndex(builder, loc, *offset);
+  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
+}
+
+Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
+                                               Value tensor, Dimension dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
+  if (stride.has_value())
+    return constantIndex(builder, loc, *stride);
+  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
+}
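Both helpers follow the same create-or-fold pattern: fold to a constant when the slice parameter is static, otherwise emit a runtime query op. A dependency-free sketch of that pattern (editor's illustration; all names here are hypothetical):

#include <cstdint>
#include <functional>
#include <optional>

struct SliceParam {
  std::optional<uint64_t> staticValue; // present iff known at compile time
};

// Mirrors createOrFoldSliceOffsetOp/createOrFoldSliceStrideOp: prefer the
// compile-time constant, and only fall back to a runtime lookup (the
// ToSliceOffsetOp/ToSliceStrideOp above) when the value is dynamic.
uint64_t materialize(const SliceParam &p,
                     const std::function<uint64_t()> &runtimeQuery) {
  return p.staticValue ? *p.staticValue : runtimeQuery();
}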
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,15 @@ Value genToValues(OpBuilder &builder, Location loc, Value tensor);
/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);

+/// Generates code to retrieve the slice offset for the sparse tensor slice,
+/// returning a constant if the offset is statically known.
+Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
+                                Dimension dim);
+
+/// Generates code to retrieve the slice stride for the sparse tensor slice,
+/// returning a constant if the stride is statically known.
+Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
+                                Dimension dim);
} // namespace sparse_tensor
} // namespace mlir

139 changes: 79 additions & 60 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -43,29 +43,25 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
return load;
}

-// TODO: Support dynamic sized slice.
-static Value getSliceOffset(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  // FIXME: `toOrigDim` is deprecated
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
}

-static Value getSliceSize(OpBuilder &builder, Location loc,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
-}
-
-static Value getSliceStride(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  // FIXME: `toOrigDim` is deprecated
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
}

// Converts a coordinate relative to the slice to the coordinate relative
// to the underlying tensor.
static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
+                          Value offset, Value stride, Value tensor,
+                          unsigned lvl) {
// iv = iv * stride + offset
v = builder.create<arith::MulIOp>(loc, v, stride);
v = builder.create<arith::AddIOp>(loc, v, offset);
@@ -75,40 +71,58 @@ static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
 // Converts a coordinate relative to the underlying tensor to the coordinate
 // relative to the slice; returns an extra remainder value.
static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
-                                            Value v,
-                                            SparseTensorEncodingAttr enc,
+                                            Value iv, Value offset,
+                                            Value stride, Value tensor,
                                             unsigned lvl) {
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
   // iv = (iv - offset) / stride
-  v = builder.create<arith::SubIOp>(loc, v, offset);
-  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
-  v = builder.create<arith::DivUIOp>(loc, v, stride);
-  return std::make_pair(v, rem);
+  iv = builder.create<arith::SubIOp>(loc, iv, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
+  iv = builder.create<arith::DivUIOp>(loc, iv, stride);
+  return std::make_pair(iv, rem);
}
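The two transforms are inverses whenever the tensor coordinate lies on the stride grid; the remainder reports whether it does. A standalone worked example on plain integers (editor's sketch, not project code):

#include <cassert>
#include <utility>

// fromSliceCrd on plain unsigned integers: returns the slice-local
// coordinate plus the remainder that signals whether `crd` is on the
// stride grid at all.
std::pair<unsigned, unsigned> fromSliceCrd(unsigned crd, unsigned offset,
                                           unsigned stride) {
  unsigned shifted = crd - offset;
  return {shifted / stride, shifted % stride};
}

int main() {
  const unsigned offset = 1, stride = 2;
  // toSliceCoord: slice coordinate 3 maps to tensor coordinate 3 * 2 + 1 = 7.
  unsigned tensorCrd = 3 * stride + offset;
  auto [sliceCrd, rem] = fromSliceCrd(tensorCrd, offset, stride);
  assert(sliceCrd == 3 && rem == 0); // the transforms round-trip exactly
  // Tensor coordinate 8 is not in the slice: (8 - 1) % 2 == 1.
  assert(fromSliceCrd(8, offset, stride).second == 1);
  return 0;
}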

-static std::pair<Value, Value>
-genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
-                       SparseTensorEncodingAttr enc, unsigned lvl) {
-  std::pair<Value, Value> trans = fromSliceCrd(builder, loc, crd, enc, lvl);
-  // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
-  // the check if the offset is zero).
-  auto geOffset =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
-                                    getSliceOffset(builder, loc, enc, lvl));
+std::pair<Value, Value>
+LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
+                                    unsigned tid, unsigned lvl) {
+  assert(isSparseSlices[tid]);
+  Value slice = tensors[tid];
+  Value offset = sliceOffsets[tid][lvl];
+  Value stride = sliceStrides[tid][lvl];
+  auto enc = getSparseTensorEncoding(slice.getType());
+
+  std::pair<Value, Value> transformedCrd =
+      fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
+
+  SmallVector<Value, 3> conds; // at most 3 conditions
+
+  // First, coord >= offset (skip the check if offset is known to be 0).
+  if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
+      !(staticOffset.has_value() && *staticOffset == 0)) {
+    auto geOffset = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::uge, crd, offset);
+    conds.push_back(geOffset);
+  }

   // Second, coord_in_slice < length
-  auto ltLength =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
-                                    getSliceSize(builder, loc, enc, lvl));
-
-  // Third, rem == 0; confirmed that (a % 1) will be folded to 0
-  auto fitStride =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
-                                    constantIndex(builder, loc, 0));
-
-  auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
-  pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
-  return {trans.first, pred};
+  auto ltLength = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, transformedCrd.first, lvlSizes[tid][lvl]);
+  conds.push_back(ltLength);
+
+  // Third, rem == 0 (skip the check if stride is known to be 1).
+  if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
+      !(staticStride.has_value() && *staticStride == 1)) {
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, transformedCrd.second,
+        constantIndex(builder, loc, 0));
+    conds.push_back(fitStride);
+  }
+
+  // Must meet all conditions to be a valid coordinate in the slice.
+  auto pred = conds.front();
+  for (auto cond : ValueRange(conds).drop_front())
+    pred = builder.create<arith::AndIOp>(loc, pred, cond);
+
+  return {transformedCrd.first, pred};
}
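In scalar terms, the generated predicate checks three things, with the first and third checks elided when the static encoding already guarantees them. A plain-C++ rendering (editor's sketch):

// A tensor-space coordinate is a valid slice member iff it is past the
// offset, lands inside the slice length after the transform, and sits
// exactly on the stride grid.
bool isValidSliceCrd(unsigned crd, unsigned offset, unsigned stride,
                     unsigned length) {
  if (crd < offset)
    return false;                   // first condition: crd >= offset
  unsigned shifted = crd - offset;
  if (shifted / stride >= length)
    return false;                   // second condition: within slice length
  return shifted % stride == 0;     // third condition: remainder == 0
}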

//===----------------------------------------------------------------------===//
@@ -119,10 +133,9 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
size_t dim, Value iv) {
Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
-  if (isSparseSlices[tid]) {
-    auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    iv = toSliceCoord(builder, loc, iv, enc, dim);
-  }
+  if (isSparseSlices[tid])
+    iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
+                      sliceStrides[tid][dim], tensors[tid], dim);
Value add = builder.create<arith::AddIOp>(loc, mul, iv);
return add;
}
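genAddress linearizes a level coordinate against the parent position, translating slice-local coordinates into tensor coordinates first. On plain integers (editor's sketch; parameter names are illustrative):

unsigned genAddress(unsigned parentPos, unsigned levelHigh, unsigned iv,
                    bool isSlice, unsigned offset, unsigned stride) {
  if (isSlice)
    iv = iv * stride + offset;       // toSliceCoord: slice -> tensor space
  return parentPos * levelHigh + iv; // linearized position at this level
}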
@@ -204,6 +217,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->isSparseOut = isSparseOut;
this->tensors.assign(ts.begin(), ts.end());
this->isSparseSlices.assign(tensors.size(), false);
+  this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
+  this->sliceStrides.assign(tensors.size(), std::vector<Value>());
this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
this->pidxs.assign(tensors.size(), std::vector<Value>());
this->segHi.assign(tensors.size(), std::vector<Value>());
@@ -246,6 +261,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
dimTypes[tid].assign(rank, DimLevelType::Dense);

// Initialize using empty value.
+    sliceOffsets[tid].assign(rank, Value());
+    sliceStrides[tid].assign(rank, Value());
pidxs[tid].assign(rank, Value());
segHi[tid].assign(rank, Value());
coord[tid].assign(rank, Value());
@@ -300,11 +317,17 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
assert(isDenseDLT(dlt));
}

-      // Find upper bound in current dimension.
       // FIXME: `toOrigDim` is deprecated
-      const Dimension d = toOrigDim(enc, l);
-      lvlSizes[t][l] = highs[t][l] =
-          mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+      // Since we do not have HigherOrdering now, we can always rely on the 1:1
+      // mapping from level to dimension to retrieve the level size.
+      Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
+                                                    toOrigDim(enc, l));
+      // Find upper bound in current dimension.
+      highs[t][l] = lvlSizes[t][l] = lvlSz;
+      if (isSparseSlices[t]) {
+        sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
+        sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
+      }
}

// Perform the required bufferization. Dense inputs materialize
@@ -405,7 +428,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
isSparseInput = isSparseInput || isSparse;
}

-  auto enc = getSparseTensorEncoding(tensors[tid].getType());
const auto reassoc = getCollapseReassociation(tid, dim);
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +490,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
for (Value red : reduc)
types.push_back(red.getType());

-    auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim);
+    auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
bool hasReduc = !types.empty();
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
/*else*/ hasReduc);
@@ -660,11 +682,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
isSingletonDLT(dimTypes[tid][dim])) {
coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
if (isSparseSlices[tid]) {
-        Value load =
-            genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
-        auto enc = getSparseTensorEncoding(tensors[tid].getType());
         auto [trans, pred] =
-            genSliceLegitPredicate(builder, loc, load, enc, dim);
+            genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
         slicesPreds.emplace_back(pred, i);
         // Update to the coordinate relative to the slice.
         coord[tid][dim] = trans;
@@ -679,7 +698,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
       // Generates a list of if statements
       //   pidx = in_slice ? pidx : pidx + 1
       // TODO: instead of always picking pidx + 1, we should set pidx = high to
-      // break to loop the coordinates is larger than the slice size.
+      // break the loop if the coordinate is larger than the slice size.
for (auto [pred, idx] : slicesPreds) {
Value nextPidx = builder.create<arith::AddIOp>(
loc, yields[idx], constantIndex(builder, loc, 1));
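During co-iteration, each predicate gates how the position index advances: coordinates that fall outside the slice are skipped rather than emitted. The selection each generated if-statement performs, on plain values (editor's sketch):

// Mirrors the generated select: pidx = in_slice ? pidx : pidx + 1.
// A position whose coordinate is not a valid slice member is skipped by
// advancing past it (per the TODO above, setting pidx = high would stop
// the scan even earlier).
unsigned selectNextPidx(bool inSlice, unsigned pidx) {
  return inSlice ? pidx : pidx + 1;
}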
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -202,6 +202,13 @@ class LoopEmitter {
Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
size_t dstLvl);

+  /// Generates a predicate to determine whether the transformed coordinates
+  /// are in the given slice.
+  /// Returns std::pair<Transformed coordinates, Predicate>.
+  std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
+                                                 Location loc, Value crd,
+                                                 unsigned tid, unsigned lvl);

bool isOutputTensor(size_t tid) {
return hasOutput && tid == tensors.size() - 1;
}
@@ -278,6 +285,9 @@ class LoopEmitter {

/// Whether the sparse input is a slice.
std::vector<bool> isSparseSlices;
+  /// Slice offset and stride values (per tensor, per level).
+  std::vector<std::vector<Value>> sliceOffsets;
+  std::vector<std::vector<Value>> sliceStrides;

/// Loop Stack, stores the information of all the nested loops that are
/// alive.
Expand Down
@@ -130,17 +130,18 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th offset from the descriptor.
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
Dimension dim) const {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
}

/// Builds IR inserting the pos-th offset into the descriptor.
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
Dimension dim, Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  insertField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
+      size);
}

/// Builds IR extracting the `lvl`-th level-size from the descriptor.
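This hunk routes the direct LLVM::ExtractValueOp/InsertValueOp calls through extractField/insertField helpers, centralizing field access on the specifier struct. The helper bodies are not shown in this excerpt; a plausible shape, inferred purely from the call sites above (hypothetical, not the committed code):

Value SpecifierStructBuilder::extractField(OpBuilder &builder, Location loc,
                                           ArrayRef<int64_t> indices) const {
  // Pull one field out of the specifier struct `value`.
  return builder.create<LLVM::ExtractValueOp>(loc, value, indices);
}

void SpecifierStructBuilder::insertField(OpBuilder &builder, Location loc,
                                         ArrayRef<int64_t> indices, Value v) {
  // Write one field back, updating the tracked struct `value`.
  value = builder.create<LLVM::InsertValueOp>(loc, value, v, indices);
}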
