[mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)
Use SparseTensorDescriptor whenever no setters are called, to avoid creating a temporary buffer for simple query purposes.

Reviewed By: bixia, wrengr

Differential Revision: https://reviews.llvm.org/D141953
Peiming Liu committed Jan 17, 2023
1 parent bf1ba6b commit 83a5083
Showing 3 changed files with 75 additions and 87 deletions.
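To make the intent of the change concrete, here is a minimal sketch (not part of the patch) contrasting the two descriptor kinds after this commit. The wrapper functions queryValMemSize and updateDimSize are hypothetical; the helpers and setters they call (getDescriptorFromTensorTuple, getMutDescriptorFromTensorTuple, getValMemSize, setDimSize) are the ones that appear in the diffs below.

// Read-only query: an immutable SparseTensorDescriptor is built directly
// over the ValueRange of the lowered tensor tuple, so no scratch buffer of
// fields needs to be materialized.
static Value queryValMemSize(OpBuilder &builder, Location loc, Value tensor) {
  SparseTensorDescriptor desc = getDescriptorFromTensorTuple(tensor);
  return desc.getValMemSize(builder, loc);
}

// Mutation: only code that calls setters still pays for the caller-owned
// SmallVector that backs a MutSparseTensorDescriptor.
static void updateDimSize(OpBuilder &builder, Location loc, Value tensor,
                          unsigned dim, Value newSize) {
  SmallVector<Value> fields;
  MutSparseTensorDescriptor desc =
      getMutDescriptorFromTensorTuple(tensor, fields);
  desc.setDimSize(builder, loc, dim, newSize);
}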
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -593,7 +593,5 @@ Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                    Value tensor) {
-  SmallVector<Value> fields;
-  auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
-  return desc.getValMemSize(builder, loc);
-}
+  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
+}
46 changes: 16 additions & 30 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -102,11 +102,9 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 }
 
 /// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
-/// attached to the given tensor type.
-static std::optional<Value>
-sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                    const SparseTensorDescriptor &desc, unsigned dim) {
+/// original dimension 'dim'.
+static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned dim) {
   RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
@@ -119,17 +117,12 @@ sizeFromTensorAtDim(OpBuilder &builder, Location loc,
   return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
 }
 
-// Gets the dimension size at the given stored dimension 'd', either as a
+// Gets the dimension size at the given stored level 'lvl', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
-  unsigned dim = toOrigDim(rtp, d);
-  auto shape = rtp.getShape();
-  if (!ShapedType::isDynamic(shape[dim]))
-    return constantIndex(builder, loc, shape[dim]);
-
-  return desc.getDimSize(builder, loc, d);
+static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned lvl) {
+  return sizeFromTensorAtDim(builder, loc, desc,
+                             toOrigDim(desc.getTensorType(), lvl));
 }
 
 static void createPushback(OpBuilder &builder, Location loc,
@@ -174,7 +167,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, desc, r);
+    Value size = sizeFromTensorAtLvl(builder, loc, desc, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
@@ -436,7 +429,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
     // Construct the new position as:
     //   pos[d] = size * pos[d-1] + i[d]
     //   <insert @ pos[d] at next dimension d + 1>
-    Value size = sizeAtStoredDim(builder, loc, desc, d);
+    Value size = sizeFromTensorAtLvl(builder, loc, desc, d);
     Value mult = builder.create<arith::MulIOp>(loc, size, pos);
     pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
   }
@@ -517,7 +510,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
 
 /// Generations insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
+                         SparseTensorDescriptor desc) {
   RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   for (unsigned d = 0; d < rank; d++) {
@@ -654,10 +647,7 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
     auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
 
-    if (!sz)
-      return failure();
-
-    rewriter.replaceOp(op, *sz);
+    rewriter.replaceOp(op, sz);
     return success();
   }
 };
@@ -727,8 +717,7 @@ class SparseTensorDeallocConverter
 
     // Replace the sparse tensor deallocation with field deallocations.
    Location loc = op.getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     for (auto input : desc.getMemRefFields())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +735,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -780,11 +768,10 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
-    assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
       auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
-      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
     };
     // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
@@ -957,8 +944,7 @@ class SparseToIndicesBufferConverter
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getAOSMemRef());
 
     return success();
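For readability, the new sizeFromTensorAtDim is reassembled below from the two hunks above. The lines hidden behind the collapsed diff context are paraphrased, so treat this as a sketch rather than a verbatim copy of the source. Its companion sizeFromTensorAtLvl (second hunk) simply maps a stored level back to its original dimension via toOrigDim and defers to this function.

// Gets the dimension size for the given sparse tensor at the given
// original dimension 'dim'.
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
                                 SparseTensorDescriptor desc, unsigned dim) {
  RankedTensorType rtp = desc.getTensorType();
  // Access into static dimension can query original type directly.
  // Note that this is typically already done by DimOp's folding.
  auto shape = rtp.getShape();
  if (!ShapedType::isDynamic(shape[dim]))
    return constantIndex(builder, loc, shape[dim]);
  // Dynamic sizes live in the storage specifier; translate the original
  // dimension to its stored position before querying it.
  return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
}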
110 changes: 57 additions & 53 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -202,29 +202,18 @@ class SparseTensorSpecifier {
 /// field in a consistent way.
 /// Users should not make assumption on how a sparse tensor is laid out but
 /// instead relies on this class to access the right value for the right field.
-template <bool mut>
+template <typename ValueArrayRef>
 class SparseTensorDescriptorImpl {
 protected:
-  // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
-  // for mutable descriptors.
-  // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
-  // buffers to append value for some special cases, though users should be
-  // responsible to restore the buffer to legal states after their use. It is
-  // probably not a clean way, but it is the most efficient way to avoid copying
-  // the fields into another SmallVector. If a more clear way is wanted, we
-  // should change it to MutableArrayRef instead.
-  using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
-                                                  ValueRange>::type;
-
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : rType(tp.cast<RankedTensorType>()), fields(fields) {
     assert(getSparseTensorEncoding(tp) &&
            getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
                fields.size());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
-    static_assert(
-        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+    static_assert(std::is_trivially_copyable_v<
+                  SparseTensorDescriptorImpl<ValueArrayRef>>);
   }
 
 public:
@@ -262,12 +251,12 @@ class SparseTensorDescriptorImpl {
 
   Value getMemRefField(SparseTensorFieldKind kind,
                        std::optional<unsigned> dim) const {
-    return fields[getMemRefFieldIndex(kind, dim)];
+    return getField(getMemRefFieldIndex(kind, dim));
   }
 
   Value getMemRefField(unsigned fidx) const {
     assert(fidx < fields.size() - 1);
-    return fields[fidx];
+    return getField(fidx);
   }
 
   Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
@@ -293,6 +282,31 @@ class SparseTensorDescriptorImpl {
         .getElementType();
   }
 
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
+  }
+
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  std::pair<unsigned, unsigned>
+  getIdxMemRefIndexAndStride(unsigned idxDim) const {
+    StorageLayout layout(getSparseTensorEncoding(rType));
+    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
+                                         idxDim);
+  }
+
+  Value getAOSMemRef() const {
+    auto enc = getSparseTensorEncoding(rType);
+    unsigned cooStart = getCOOStart(enc);
+    assert(cooStart < enc.getDimLevelType().size());
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
+  }
+
   RankedTensorType getTensorType() const { return rType; }
   ValueArrayRef getFields() const { return fields; }
 
@@ -301,25 +315,38 @@ class SparseTensorDescriptorImpl {
   ValueArrayRef fields;
 };
 
-class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+/// Uses ValueRange for immuatable descriptors;
+class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
 public:
-  MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+  SparseTensorDescriptor(Type tp, ValueRange buffers)
+      : SparseTensorDescriptorImpl<ValueRange>(tp, buffers) {}
 
-  Value getField(unsigned fidx) const {
-    assert(fidx < fields.size());
-    return fields[fidx];
-  }
+  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+                           unsigned idxDim) const;
+};
 
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // Drop the last metadata fields.
-    return ret.slice(0, fields.size() - 1);
+/// Uses SmallVectorImpl<Value> & for mutable descriptors.
+/// Using SmallVector for mutable descriptor allows users to reuse it as a
+/// tmp buffers to append value for some special cases, though users should
+/// be responsible to restore the buffer to legal states after their use. It
+/// is probably not a clean way, but it is the most efficient way to avoid
+/// copying the fields into another SmallVector. If a more clear way is
+/// wanted, we should change it to MutableArrayRef instead.
+class MutSparseTensorDescriptor
+    : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
+public:
+  MutSparseTensorDescriptor(Type tp, SmallVectorImpl<Value> &buffers)
+      : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(tp, buffers) {}
+
+  // Allow implicit type conversion from mutable descriptors to immutable ones
+  // (but not vice versa).
+  /*implicit*/ operator SparseTensorDescriptor() const {
+    return SparseTensorDescriptor(rType, fields);
   }
 
   ///
-  /// Setters: update the value for required field (only enabled for
-  /// MutSparseTensorDescriptor).
+  /// Adds additional setters for mutable descriptor, update the value for
+  /// required field.
   ///
 
   void setMemRefField(SparseTensorFieldKind kind, std::optional<unsigned> dim,
@@ -348,29 +375,6 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }
-
-  std::pair<unsigned, unsigned>
-  getIdxMemRefIndexAndStride(unsigned idxDim) const {
-    StorageLayout layout(getSparseTensorEncoding(rType));
-    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
-                                         idxDim);
-  }
-
-  Value getAOSMemRef() const {
-    auto enc = getSparseTensorEncoding(rType);
-    unsigned cooStart = getCOOStart(enc);
-    assert(cooStart < enc.getDimLevelType().size());
-    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
-  }
 };
 
-class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
-public:
-  SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<false>(tp, buffers) {}
-
-  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
-                           unsigned idxDim) const;
-};
-
 /// Returns the "tuple" value of the adapted tensor.
@@ -386,7 +390,7 @@ inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
 }
 
 inline Value genTuple(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc) {
+                      SparseTensorDescriptor desc) {
   return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
 }
 
Expand Down
