[mlir][sparse] Converting SparseTensorCOO to use standard C++-style iterators.

This differential comprises three related changes: (1) it gives SparseTensorCOO standard C++-style iterators; (2) it removes the old iterator interface from SparseTensorCOO; and (3) it introduces a new SparseTensorIterator class that behaves like the old SparseTensorCOO iterator interface did.

The SparseTensorIterator class is needed because the MLIR codegen cannot easily use the C++-style iterators (which is why SparseTensorCOO had the old iterator interface in the first place).  Distinguishing SparseTensorIterator from SparseTensorCOO also improves API hygiene, since the two classes serve distinct purposes.  And having SparseTensorIterator as its own class makes it possible to change the underlying implementation in the future without having to update all the codegen tests, etc.
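To make the first change concrete, here is a minimal sketch (an editorial illustration, not part of this commit) of consuming the new C++-style iterators from ordinary C++ code; the element type, shape, and values below are invented for the example:

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  #include "mlir/ExecutionEngine/SparseTensor/COO.h"

  using mlir::sparse_tensor::Element;
  using mlir::sparse_tensor::SparseTensorCOO;

  // Build a tiny 3x4 COO tensor and walk it with the new C++-style iterators.
  void dumpExampleCOO() {
    const std::vector<uint64_t> dimSizes{3, 4};
    SparseTensorCOO<double> coo(dimSizes, /*capacity=*/2);
    coo.add({0, 1}, 1.5);
    coo.add({2, 3}, 2.5);
    coo.sort(); // lexicographic order by index
    // The begin()/end() members added by this commit make range-for work.
    for (const Element<double> &elem : coo)
      std::printf("(%llu, %llu) = %f\n",
                  static_cast<unsigned long long>(elem.indices[0]),
                  static_cast<unsigned long long>(elem.indices[1]),
                  elem.value);
  }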

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135485
wrengr committed Oct 11, 2022
1 parent 1079662 commit 90fd13b
Showing 6 changed files with 109 additions and 49 deletions.
53 changes: 26 additions & 27 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
@@ -80,12 +80,30 @@ using ElementConsumer =
/// an intermediate representation; e.g., for reading sparse tensors
/// from external formats into memory, or for certain conversions between
/// different `SparseTensorStorage` formats.
///
/// This class provides all the typedefs required by the "Container"
/// concept (<https://en.cppreference.com/w/cpp/named_req/Container>);
/// however, beware that it cannot fully implement that concept since
/// it cannot have a default ctor (because the `dimSizes` field is const).
/// Thus these typedefs are provided for familiarity reasons, rather
/// than as a proper implementation of the concept.
template <typename V>
class SparseTensorCOO final {
public:
using value_type = const Element<V>;
using reference = value_type &;
using const_reference = reference;
// The types associated with `std::vector` differ significantly between
// C++11/17 vs C++20; so we explicitly defer to whatever `std::vector`
// says the types should be.
using vector_type = std::vector<Element<V>>;
using iterator = typename vector_type::const_iterator;
using const_iterator = iterator;
using difference_type = typename vector_type::difference_type;
using size_type = typename vector_type::size_type;

SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
: dimSizes(dimSizes), isSorted(true), iteratorLocked(false),
iteratorPos(0) {
: dimSizes(dimSizes), isSorted(true) {
if (capacity) {
elements.reserve(capacity);
indices.reserve(capacity * getRank());
@@ -129,12 +147,12 @@ class SparseTensorCOO final {
/// Resolving such conflicts is left up to clients of the iterator
/// interface.
///
/// This method invalidates all iterators.
///
/// Asserts:
/// * is not in iterator mode
/// * the `ind` is valid for `rank`
/// * the elements of `ind` are valid for `dimSizes`.
void add(const std::vector<uint64_t> &ind, V val) {
assert(!iteratorLocked && "Attempt to add() after startIterator()");
const uint64_t *base = indices.data();
uint64_t size = indices.size();
uint64_t rank = getRank();
@@ -161,44 +179,25 @@
elements.push_back(addedElem);
}

const_iterator begin() const { return elements.cbegin(); }
const_iterator end() const { return elements.cend(); }

/// Sorts elements lexicographically by index. If an index is mapped to
/// multiple values, then the relative order of those values is unspecified.
///
/// Asserts: is not in iterator mode.
/// This method invalidates all iterators.
void sort() {
assert(!iteratorLocked && "Attempt to sort() after startIterator()");
if (isSorted)
return;
std::sort(elements.begin(), elements.end(), getElementLT());
isSorted = true;
}

/// Switches into iterator mode. If already in iterator mode, then
/// resets the position to the first element.
void startIterator() {
iteratorLocked = true;
iteratorPos = 0;
}

/// Gets the next element. If there are no remaining elements, then
/// returns nullptr and switches out of iterator mode.
///
/// Asserts: is in iterator mode.
const Element<V> *getNext() {
assert(iteratorLocked && "Attempt to getNext() before startIterator()");
if (iteratorPos < elements.size())
return &(elements[iteratorPos++]);
iteratorLocked = false;
return nullptr;
}

private:
const std::vector<uint64_t> dimSizes; // per-dimension sizes
std::vector<Element<V>> elements; // all COO elements
std::vector<uint64_t> indices; // shared index pool
bool isSorted;
bool iteratorLocked;
unsigned iteratorPos;
};

} // namespace sparse_tensor
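As a side note, a hedged sketch of the new Container-style typedefs in use (editorial illustration, not part of the commit; the element type is invented):

  #include <cstdint>
  #include <iterator>

  #include "mlir/ExecutionEngine/SparseTensor/COO.h"

  // Count the stored elements via the new Container-style typedefs.
  // (As the comment above notes, the class still lacks a default ctor,
  // so it is not a full Container.)
  uint64_t countElements(const mlir::sparse_tensor::SparseTensorCOO<float> &coo) {
    using COO = mlir::sparse_tensor::SparseTensorCOO<float>;
    COO::const_iterator first = coo.begin(), last = coo.end();
    COO::difference_type n = std::distance(first, last);
    return static_cast<uint64_t>(n);
  }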
13 changes: 10 additions & 3 deletions mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -43,8 +43,8 @@ extern "C" {
/// This is the "swiss army knife" method for materializing sparse
/// tensors into the computation. The types of the `ptr` argument and
/// the result depend on the action, as explained in the following table
/// (where "STS" means a sparse-tensor-storage object, and "COO" means
/// a coordinate-scheme object).
/// (where "STS" means a sparse-tensor-storage object, "COO" means
/// a coordinate-scheme object, and "Iterator" means an iterator object).
///
/// Action: `ptr`: Returns:
/// kEmpty unused STS, empty
@@ -53,7 +53,8 @@ extern "C" {
/// kFromCOO COO STS, copied from the COO source
/// kToCOO STS COO, copied from the STS source
/// kSparseToSparse STS STS, copied from the STS source
/// kToIterator STS COO-Iterator, call @getNext to use
/// kToIterator STS Iterator, call @getNext to use and
/// @delSparseTensorIterator to free.
MLIR_CRUNNERUTILS_EXPORT void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
StridedMemRefType<index_type, 1> *sref,
@@ -150,6 +151,12 @@ MLIR_CRUNNERUTILS_EXPORT void delSparseTensor(void *tensor);
MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
#undef DECL_DELCOO

/// Releases the memory for an iterator object.
#define DECL_DELITER(VNAME, V) \
MLIR_CRUNNERUTILS_EXPORT void delSparseTensorIterator##VNAME(void *iter);
MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELITER)
#undef DECL_DELITER

/// Helper function to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
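As a hedged illustration of the kToIterator contract documented above: the iterator handle below is assumed to have been produced by _mlir_ciface_newSparseTensor with Action::kToIterator for a rank-2 F64 tensor (the F64 suffix follows the value-type naming convention used in the tests), and the descriptor fields set here are the ones the runtime actually reads:

  #include <cstdio>

  #include "mlir/ExecutionEngine/SparseTensorUtils.h"

  // Drain an iterator handle `iter` and print its elements.
  void drainIterator(void *iter) {
    index_type ind[2]; // out-parameter buffer for the indices
    double val;        // out-parameter buffer for the value
    StridedMemRefType<index_type, 1> iref;
    iref.basePtr = iref.data = ind;
    iref.offset = 0;
    iref.sizes[0] = 2;
    iref.strides[0] = 1;
    StridedMemRefType<double, 0> vref;
    vref.basePtr = vref.data = &val;
    vref.offset = 0;
    while (_mlir_ciface_getNextF64(iter, &iref, &vref))
      std::printf("(%llu, %llu) = %f\n",
                  static_cast<unsigned long long>(ind[0]),
                  static_cast<unsigned long long>(ind[1]), val);
    // Per the updated table, the caller releases the iterator explicitly.
    delSparseTensorIteratorF64(iter);
  }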
@@ -319,6 +319,14 @@ static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}

/// Generates a call to release/delete a `SparseTensorIterator`.
static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
Value iter) {
SmallString<26> name{"delSparseTensorIterator",
primaryTypeFunctionSuffix(elemTp)};
createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
/// val = a[i1,..,ik];
@@ -335,7 +343,7 @@ static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
/// the return value is false.
static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
Value ind, Value elemPtr) {
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
@@ -572,7 +580,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
params[7] = coo;
Value dst = genNewCall(rewriter, loc, params);
genDelCOOCall(rewriter, loc, elemTp, coo);
genDelCOOCall(rewriter, loc, elemTp, iter);
genDelIteratorCall(rewriter, loc, elemTp, iter);
rewriter.replaceOp(op, dst);
return success();
}
@@ -584,6 +592,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
// }
// TODO: It can be used by other operators' (ReshapeOp, ConvertOp) conversions
// to reduce code repetition!
// TODO: rename to `genSparseIterationLoop`?
static void genSparseCOOIterationLoop(
ConversionPatternRewriter &rewriter, Location loc, Value t,
RankedTensorType tensorTp,
@@ -624,7 +633,7 @@ static void genSparseCOOIterationLoop(
rewriter.setInsertionPointAfter(whileOp);

// Free memory for iterator.
genDelCOOCall(rewriter, loc, elemTp, iter);
genDelIteratorCall(rewriter, loc, elemTp, iter);
}

// Generate loop that iterates over a dense tensor.
@@ -875,11 +884,11 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
if (!encDst && encSrc) {
// This is sparse => dense conversion, which is handled as follows:
// dst = new Tensor(0);
// iter = src->toCOO();
// iter->startIterator();
// iter = new SparseTensorIterator(src);
// while (elem = iter->getNext()) {
// dst[elem.indices] = elem.value;
// }
// delete iter;
RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
unsigned rank = dstTensorTp.getRank();
@@ -918,7 +927,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, ivs);
rewriter.create<scf::YieldOp>(loc);
rewriter.setInsertionPointAfter(whileOp);
genDelCOOCall(rewriter, loc, elemTp, iter);
genDelIteratorCall(rewriter, loc, elemTp, iter);
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
// Deallocate the buffer.
if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
53 changes: 49 additions & 4 deletions mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -68,6 +68,44 @@ using namespace mlir::sparse_tensor;

namespace {

/// Wrapper class to avoid memory leakage issues. The `SparseTensorCOO<V>`
/// class provides a standard C++ iterator interface, where the iterator
/// is implemented as per `std::vector`'s iterator. However, for MLIR's
/// usage we need to have an iterator which also holds onto the underlying
/// `SparseTensorCOO<V>` so that it can be freed whenever the iterator
/// is freed.
//
// We name this `SparseTensorIterator` rather than `SparseTensorCOOIterator`
// for future-proofing, since the use of `SparseTensorCOO` is an
// implementation detail that we eventually want to change (e.g., to
// use `SparseTensorEnumerator` directly, rather than constructing the
// intermediate `SparseTensorCOO` at all).
template <typename V>
class SparseTensorIterator final {
public:
/// This ctor requires `coo` to be a non-null pointer to a dynamically
/// allocated object, and takes ownership of that object. Therefore,
/// callers must not free the underlying COO object, since the iterator's
/// dtor will do so.
explicit SparseTensorIterator(const SparseTensorCOO<V> *coo)
: coo(coo), it(coo->begin()), end(coo->end()) {}

~SparseTensorIterator() { delete coo; }

// Disable copy-ctor and copy-assignment, to prevent double-free.
SparseTensorIterator(const SparseTensorIterator<V> &) = delete;
SparseTensorIterator<V> &operator=(const SparseTensorIterator<V> &) = delete;

/// Gets the next element. If there are no remaining elements, then
/// returns nullptr.
const Element<V> *getNext() { return it < end ? &*it++ : nullptr; }

private:
const SparseTensorCOO<V> *const coo; // Owning pointer.
typename SparseTensorCOO<V>::const_iterator it;
const typename SparseTensorCOO<V>::const_iterator end;
};

/// Initializes sparse tensor from an external COO-flavored format.
/// Used by `IMPL_CONVERTTOMLIRSPARSETENSOR`.
// TODO: generalize beyond 64-bit indices.
@@ -194,7 +232,7 @@ extern "C" {
return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \
coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
if (action == Action::kToIterator) { \
coo->startIterator(); \
return new SparseTensorIterator<V>(coo); \
} else { \
assert(action == Action::kToCOO); \
} \
@@ -398,16 +436,16 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
#undef IMPL_ADDELT

#define IMPL_GETNEXT(VNAME, V) \
bool _mlir_ciface_getNext##VNAME(void *coo, \
bool _mlir_ciface_getNext##VNAME(void *iter, \
StridedMemRefType<index_type, 1> *iref, \
StridedMemRefType<V, 0> *vref) { \
assert(coo &&iref &&vref); \
assert(iter &&iref &&vref); \
assert(iref->strides[0] == 1); \
index_type *indx = iref->data + iref->offset; \
V *value = vref->data + vref->offset; \
const uint64_t isize = iref->sizes[0]; \
const Element<V> *elem = \
static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
static_cast<SparseTensorIterator<V> *>(iter)->getNext(); \
if (elem == nullptr) \
return false; \
for (uint64_t r = 0; r < isize; r++) \
@@ -490,6 +528,13 @@ void delSparseTensor(void *tensor) {
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO

#define IMPL_DELITER(VNAME, V) \
void delSparseTensorIterator##VNAME(void *iter) { \
delete static_cast<SparseTensorIterator<V> *>(iter); \
}
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELITER)
#undef IMPL_DELITER

char *getTensorFilename(index_type id) {
char var[80];
sprintf(var, "TENSOR%" PRIu64, id);
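To make the ownership contract of the new wrapper concrete, here is a hedged sketch. SparseTensorIterator is defined in an anonymous namespace inside SparseTensorUtils.cpp, so code outside the runtime cannot name it; this is purely illustrative, not a usable external API:

  // The iterator takes ownership of the heap-allocated `coo`;
  // callers must not delete `coo` themselves.
  void drainAndFree(mlir::sparse_tensor::SparseTensorCOO<double> *coo) {
    SparseTensorIterator<double> it(coo);
    while (const mlir::sparse_tensor::Element<double> *elem = it.getNext()) {
      // ... consume elem->indices and elem->value ...
    }
  } // `it`'s dtor runs here and deletes the underlying COO.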
10 changes: 5 additions & 5 deletions mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -58,7 +58,7 @@
// CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_13]], %[[TMP_14]]] : memref<5x4xf64>
// CHECK: scf.yield
// CHECK: }
// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<5x4xf64>
// CHECK: return %[[TMP_11]] : tensor<5x4xf64>
// CHECK: }
@@ -141,7 +141,7 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
// CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
// CHECK: scf.yield
// CHECK: }
// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
@@ -225,7 +225,7 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
// CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
// CHECK: scf.yield
// CHECK: }
// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
@@ -287,7 +287,7 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3
// CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_12]], %[[TMP_14]]] : memref<4x5xf64>
// CHECK: scf.yield
// CHECK: }
// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<4x5xf64>
// CHECK: return %[[TMP_11]] : tensor<4x5xf64>
// CHECK: }
@@ -348,7 +348,7 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
// CHECK: memref.store %[[TMP_16]], %[[TMP_0]][%[[TMP_13]], %[[TMP_15]]] : memref<3x5xf64>
// CHECK: scf.yield
// CHECK: }
// CHECK: call @delSparseTensorCOOF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64>
// CHECK: return %[[TMP_12]] : tensor<?x?xf64>
// CHECK: }
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -35,7 +35,7 @@
// CHECK-CONV: }
// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorIteratorF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
// rewrite for codegen:
@@ -97,7 +97,7 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-CONV: }
// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorIteratorF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
// rewrite for codegen:
@@ -172,7 +172,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-CONV: }
// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorIteratorF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
// rewrite for codegen:
@@ -244,7 +244,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-CONV: }
// CHECK-CONV: %[[N:.*]] = call @newSparseTensor
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: call @delSparseTensorIteratorF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
// rewrite for codegen:
