[mlir][sparse] Implement insertion sort for the stable sort operator.
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135182
bixia1 committed Oct 6, 2022
1 parent 6ebc3ab commit 9409bbb
Showing 3 changed files with 242 additions and 28 deletions.
187 changes: 173 additions & 14 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -28,9 +28,19 @@ using namespace mlir::sparse_tensor;
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

-constexpr uint64_t loIdx = 0;
-constexpr uint64_t hiIdx = 1;
-constexpr uint64_t xStartIdx = 2;
+static constexpr uint64_t loIdx = 0;
+static constexpr uint64_t hiIdx = 1;
+static constexpr uint64_t xStartIdx = 2;

static constexpr const char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
"_sparse_binary_search_";
static constexpr const char kSortNonstableFuncNamePrefix[] =
"_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";

typedef function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>
FuncGeneratorType;
@@ -201,6 +211,79 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
}

/// Creates a function that uses binary search to find the insertion point
/// for inserting xs[hi] into the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C-like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);

Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value p = args[hiIdx];
SmallVector<Type, 2> types(2, p.getType());
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

// The before-region of the WhileOp.
Block *before =
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(before);
Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
before->getArgument(0),
before->getArgument(1));
builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

// The after-region of the WhileOp.
Block *after =
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(after);
Value lo = after->getArgument(0);
Value hi = after->getArgument(1);
// Compute mid = (lo + hi) >> 1.
Value c1 = constantIndex(builder, loc, 1);
Value mid = builder.create<arith::ShRUIOp>(
loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

// Compare xs[p] < xs[mid].
SmallVector<Value, 6> compareOperands{p, mid};
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + dim);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc =
getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
dim, compareOperands, createLessThanFunc);
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);

// Update lo and hi for the WhileOp as follows:
// if (xs[p] < xs[mid])
// hi = mid;
// else
// lo = mid + 1;
Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

builder.setInsertionPointAfter(whileOp);
builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}
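The loop above implements an upper-bound binary search: it returns the first position in xs[lo..hi) whose key compares greater than the key at p, which is what lets equal keys keep their original order. A standalone C++ model of the generated loop, assuming a single key array in place of the dim parallel memrefs (in the generated function the key is xs[p], with p = hi on entry):

#include <cstddef>
#include <vector>

// Find the insertion point for key in the sorted range xs[lo..hi):
// the first index whose value is strictly greater than key.
size_t insertionPoint(const std::vector<int> &xs, size_t lo, size_t hi,
                      int key) {
  while (lo < hi) {
    size_t mid = (lo + hi) >> 1;
    if (key < xs[mid])
      hi = mid;     // insertion point is at mid or to its left
    else
      lo = mid + 1; // insertion point is to the right of mid
  }
  return lo;
}

// For xs = {1, 3, 3, 5} and key = 3 this returns 3, so a new 3 lands
// after the existing ones: the property that makes the sort stable.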

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
@@ -243,7 +326,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
compareOperands.append(xs.begin(), xs.end());
Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc =
-getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
dim, compareOperands, createLessThanFunc);
Value cond = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
@@ -258,9 +341,9 @@ builder
builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
SmallVector<Value, 6> swapOperands{i1, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
-FlatSymbolRefAttr swapFunc =
-getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
-dim, swapOperands, createMaySwapFunc);
+FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
+builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
+createMaySwapFunc);
builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
builder.create<scf::YieldOp>(loc, i1);

@@ -292,8 +375,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
// quickSort(p + 1, hi, data);
// }
// }
-static void createSortFunc(OpBuilder &builder, ModuleOp module,
-func::FuncOp func, size_t dim) {
+static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
+func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -310,8 +393,8 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
// The if-stmt true branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
-createPartitionFunc);
+builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
+args, createPartitionFunc);
auto p = builder.create<func::CallOp>(
loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));

@@ -331,6 +414,78 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc);
}
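A compact C++ model of the quicksort driver above, assuming a Lomuto-style partition; the generated _sparse_partition_* helper may partition differently, but the recursion shape (sort [lo, p), then [p+1, hi)) is the same:

#include <cstddef>
#include <utility>
#include <vector>

// Partition data[lo..hi) around its last element; returns the pivot's
// final index p, with data[lo..p) < pivot and data[p+1..hi) >= pivot.
static size_t partition(std::vector<int> &data, size_t lo, size_t hi) {
  int pivot = data[hi - 1];
  size_t p = lo;
  for (size_t i = lo; i + 1 < hi; ++i)
    if (data[i] < pivot)
      std::swap(data[i], data[p++]);
  std::swap(data[p], data[hi - 1]);
  return p;
}

// Mirror of the generated recursion: guard on lo < hi, partition, then
// recurse into both sides of the pivot.
static void quickSort(std::vector<int> &data, size_t lo, size_t hi) {
  if (lo < hi) {
    size_t p = partition(data, lo, hi);
    quickSort(data, lo, p);
    quickSort(data, p + 1, hi);
  }
}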

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
// void insertionSort(lo, hi, data) {
//   for (i = lo+1; i < hi; i++) {
//     d = data[i];
//     p = binarySearch(lo, i, data);
//     for (j = 0; j < i - p; j++)
//       data[i-j] = data[i-j-1];
//     data[p] = d;
//   }
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);

MLIRContext *context = module.getContext();
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value c1 = constantIndex(builder, loc, 1);
Value lo = args[loIdx];
Value hi = args[hiIdx];
Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

// Start the outer for-stmt with induction variable i.
scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
builder.setInsertionPointToStart(forOpI.getBody());
Value i = forOpI.getInductionVar();

// Binary search to find the insertion point p.
SmallVector<Value, 6> operands{lo, i};
operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
dim, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
.getResult(0);

// Move the value at data[i] to a temporary location.
ValueRange data = args.drop_front(xStartIdx);
SmallVector<Value, 6> d;
for (Value v : data)
d.push_back(builder.create<memref::LoadOp>(loc, v, i));

// Start the inner for-stmt with induction variable j, for moving data[p..i)
// to data[p+1..i+1).
Value imp = builder.create<arith::SubIOp>(loc, i, p);
Value c0 = constantIndex(builder, loc, 0);
scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
builder.setInsertionPointToStart(forOpJ.getBody());
Value j = forOpJ.getInductionVar();
Value imj = builder.create<arith::SubIOp>(loc, i, j);
Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
for (Value v : data) {
Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
builder.create<memref::StoreOp>(loc, t, v, imj);
}

// Store the value at data[i] to data[p].
builder.setInsertionPointAfter(forOpJ);
for (auto it : llvm::zip(d, data))
builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);

builder.setInsertionPointAfter(forOpI);
builder.create<func::ReturnOp>(loc);
}
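A standalone C++ model of the insertion sort generated above, using the same upper-bound search so equal keys keep their relative order; a (key, payload) pair stands in for the parallel xs/ys memrefs:

#include <cstddef>
#include <utility>
#include <vector>

using KV = std::pair<int, char>; // key plus a payload to observe stability

// Stable insertion sort on data[lo..hi): for each i, find the upper-bound
// insertion point p among data[lo..i), shift data[p..i) right by one, and
// place the saved element at p.
void insertionSortStable(std::vector<KV> &data, size_t lo, size_t hi) {
  for (size_t i = lo + 1; i < hi; ++i) {
    KV d = data[i];
    size_t p = lo, h = i;
    while (p < h) { // upper-bound binary search on the key
      size_t mid = (p + h) >> 1;
      if (d.first < data[mid].first)
        h = mid;
      else
        p = mid + 1;
    }
    for (size_t j = i; j > p; --j) // move data[p..i) to data[p+1..i+1)
      data[j] = data[j - 1];
    data[p] = d;
  }
}

// {{2,'a'}, {1,'b'}, {2,'c'}} sorts to {{1,'b'}, {2,'a'}, {2,'c'}}: the
// two entries with key 2 keep their original 'a'-before-'c' order.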

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//
@@ -425,9 +580,13 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
addValues(xs);
addValues(op.getYs());
auto insertPoint = op->getParentOfType<func::FuncOp>();
-FlatSymbolRefAttr func = getMangledSortHelperFunc(
-rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
-operands, createSortFunc);
+SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
+: kSortNonstableFuncNamePrefix);
+FuncGeneratorType funcGenerator =
+op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+FlatSymbolRefAttr func =
+getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+xs.size(), operands, funcGenerator);
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
return success();
}
32 changes: 26 additions & 6 deletions mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s

// CHECK-LABEL: func @sparse_push_back(
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
@@ -26,6 +26,8 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
return %0 : memref<?xf64>
}

// -----

// CHECK-LABEL: func @sparse_push_back_inbound(
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
@@ -42,6 +44,8 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
return %0 : memref<?xf64>
}

// -----

// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
// CHECK-SAME: %[[I:arg0]]: index,
// CHECK-SAME: %[[J:.*]]: index,
@@ -101,7 +105,7 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
// CHECK: return %[[I3p1]]
// CHECK: }

-// CHECK-LABEL: func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-LABEL: func.func private @_sparse_sort_nonstable_1_i8_f32_index(
// CHECK-SAME: %[[L:arg0]]: index,
// CHECK-SAME: %[[H:.*]]: index,
// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
@@ -111,9 +115,9 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
// CHECK: scf.if %[[COND]] {
// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: }
// CHECK: return
// CHECK: }
@@ -126,7 +130,7 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
// CHECK: %[[C0:.*]] = arith.constant 0
// CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
// CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK: call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK: call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
// CHECK: return %[[X0]], %[[Y0]], %[[Y1]]
// CHECK: }
func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
@@ -135,15 +139,31 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
}

// -----

// Only check the generated supporting functions here. We have integration
// tests to verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d
func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}

// -----

// Only check the generated supporting functions. We have integration tests
// to verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d_stable
func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}