[mlir][sparse] Replace sparse_tensor.sort with sparse_tensor.sort_coo for sorting COO tensors.

Add codegen pattern for sparse_tensor.indices_buffer.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140871
bixia1 committed Jan 5, 2023
1 parent 47232be commit 81e3079
Showing 5 changed files with 88 additions and 33 deletions.
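In effect, the path for sorting a 2-D COO tensor changes as follows (a minimal sketch distilled from the test updates below; SSA names are illustrative and type annotations are elided):

  // Before: gather one indices memref per dimension, then sort them jointly.
  %i0 = sparse_tensor.indices %coo {dimension = 0 : index}
  %i1 = sparse_tensor.indices %coo {dimension = 1 : index}
  sparse_tensor.sort %nnz, %i0, %i1 jointly %v

  // After: sort the single interleaved (AOS) indices buffer in place.
  %i = sparse_tensor.indices_buffer %coo
  sparse_tensor.sort_coo %nnz, %i jointly %v {nx = 2 : index, ny = 0 : index}

Here nx is the number of leading index columns that form the sort key and ny the number of trailing columns carried along per entry.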
26 changes: 23 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -937,6 +937,26 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
}
};

/// Sparse codegen rule for accessing the linear indices buffer.
class SparseToIndicesBufferConverter
: public OpConversionPattern<ToIndicesBufferOp> {
public:
using OpAdaptor = typename ToIndicesBufferOp::Adaptor;
using OpConversionPattern<ToIndicesBufferOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToIndicesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
rewriter.replaceOp(op, desc.getAOSMemRef());

return success();
}
};

/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
@@ -1005,9 +1025,9 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseToPointersConverter, SparseToIndicesConverter,
SparseToValuesConverter, SparseConvertConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
SparseToIndicesBufferConverter, SparseToValuesConverter,
SparseConvertConverter, SparseNumberOfEntriesConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
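The new SparseToIndicesBufferConverter follows the same shape as the other field accessors: it replaces the op with the corresponding field of the expanded storage tuple, here the AOS indices memref, so no new code is emitted. A sketch of the lowering for the 3-D ccoo tensor exercised by the codegen test below (names are illustrative):

  // Input IR:
  %0 = sparse_tensor.indices_buffer %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
  // After codegen the result is just the AOS indices field of the storage
  // tuple (%A3 in the CHECK lines of the test), forwarded as-is.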
65 changes: 41 additions & 24 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -152,7 +152,8 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
// TODO: The dim level property of the COO type relies on input tensors, the
// shape relies on the output tensor
// Helpers to setup a COO type.
static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
static RankedTensorType
getUnorderedCOOFromTypeWithOrdering(RankedTensorType src, AffineMap ordering) {
auto *ctx = src.getContext();
auto rank = src.getRank();
SmallVector<DimLevelType> dims;
@@ -176,12 +177,16 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
// default value.
unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
auto enc = SparseTensorEncodingAttr::get(
ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
pointerBitWidth, indexBitWidth);
auto enc = SparseTensorEncodingAttr::get(ctx, dims, ordering, AffineMap(),
pointerBitWidth, indexBitWidth);
return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
}

static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
return getUnorderedCOOFromTypeWithOrdering(
src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()));
}

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
@@ -771,6 +776,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
int64_t rank = dstTp.getRank();

SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
@@ -788,16 +794,21 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// the overhead types.
SmallVector<Value> dynSrcSizes;
getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
srcTp = getUnorderedCOOFromType(srcTp);
srcTp =
getUnorderedCOOFromTypeWithOrdering(srcTp, encDst.getDimOrdering());
tmpCoo =
rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, tmpCoo,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
// The resulting COO tensor has identity ordering.
auto t = builder.create<InsertOp>(loc, v, reduc.front(),
args.slice(0, srcTp.getRank()));
SmallVector<Value> dstIndices(srcTp.getRank(), Value());
for (int64_t i = 0; i < rank; i++) {
uint64_t dim = toStoredDim(encDst, i);
dstIndices[dim] = args[i];
}
auto t =
builder.create<InsertOp>(loc, v, reduc.front(), dstIndices);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
@@ -806,29 +817,35 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// Only need to sort if the srcTp is not already sorted (we faithfully take
// the guarantee from the sparse tensor encoding).
if (!isAllDimOrdered(srcTp)) {
// Sort the COO tensor so that its elements are ordered via increasing
// indices for the storage ordering of the dst tensor.
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
uint64_t rank = dstTp.getRank();
uint64_t cooStart = getCOOStart(encSrc);
// Gather the indices-arrays in the dst tensor storage order.
SmallVector<Value> xs(rank, Value());
for (uint64_t i = 0; i < rank; i++) {
uint64_t orgDim = toOrigDim(encSrc, i);
xs[toStoredDim(encDst, orgDim)] =
genToIndices(rewriter, loc, src, i, cooStart);
}

// Retrieve NNZ.
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
nnz = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
nnz);

// Retrieve the values-array.
Value y = genToValues(rewriter, loc, src);

// Sort the COO tensor.
rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
// Sort the COO tensor so that its elements are ordered via increasing
// indices for the storage ordering of the dst tensor. Use SortCoo if the
// COO tensor has the same dim ordering as the dst tensor.
if (rank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
MemRefType indTp =
get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
/*withLayout=*/false);
Value xs = rewriter.create<ToIndicesBufferOp>(loc, indTp, src);
rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y},
rewriter.getIndexAttr(rank),
rewriter.getIndexAttr(0));
} else {
// Gather the indices-arrays in the dst tensor storage order.
SmallVector<Value> xs(rank, Value());
for (uint64_t i = 0; i < rank; i++) {
uint64_t orgDim = toOrigDim(encSrc, i);
xs[toStoredDim(encDst, orgDim)] =
genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
}
rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
}
}

// For each element in the COO tensor, insert the element to the dst tensor.
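Two aspects of the rewritten ConvertRewriter deserve a note. The temporary COO tensor is now allocated with the destination's dimension ordering, and the foreach body permutes each coordinate through toStoredDim before insertion, so the later sort already operates in destination storage order. The sort itself takes the sort_coo fast path only when rank > 1 and source and destination share a dimension ordering; otherwise the per-dimension sort remains. A sketch of a destination encoding with a non-identity ordering that exercises the permutation (a hypothetical CSC-style attribute, in the encoding syntax of this vintage):

  #CSC = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    dimOrdering = affine_map<(i, j) -> (j, i)>
  }>
  // For this destination, toStoredDim maps dimension 0 to 1 and 1 to 0, so an
  // element read at (i, j) is inserted at stored position (j, i).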
@@ -390,6 +390,13 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
idxDim);
}

Value getAOSMemRef() const {
auto enc = getSparseTensorEncoding(rType);
unsigned cooStart = getCOOStart(enc);
assert(cooStart < enc.getDimLevelType().size());
return getIdxMemRef(cooStart);
}
};

class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
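getAOSMemRef exposes the array-of-structs ("AOS") region of the index storage: from the first COO dimension (cooStart) onward, all coordinates of an entry are stored interleaved in one flat memref instead of one buffer per dimension. A sketch of the assumed layout (illustrative):

  // For a COO region covering k dimensions and entries e = 0 .. nnz-1:
  //   aos[e * k + d] = index of entry e in COO dimension d
  // sort_coo can thus reorder whole entries by moving k consecutive elements
  // at a time, jointly with the values array.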
13 changes: 13 additions & 0 deletions mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -270,6 +270,19 @@ func.func @sparse_indices_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex
return %0 : memref<?xindex, strided<[?], offset: ?>>
}

// CHECK-LABEL: func.func @sparse_indices_buffer_coo(
// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A3]] : memref<?xindex>
func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex> {
%0 = sparse_tensor.indices_buffer %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
return %0 : memref<?xindex>
}

// CHECK-LABEL: func @sparse_noe(
// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
10 changes: 4 additions & 6 deletions mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -122,11 +122,10 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-RWT: sparse_tensor.yield %[[IFR]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
@@ -182,11 +181,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-RWT: sparse_tensor.yield %[[L0T2]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
