Skip to content

Commit

Permalink
[mlir][sparse] Improve quick sort by using a loop to sort the bigger …
Browse files Browse the repository at this point in the history
…partition.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145440
  • Loading branch information
bixia1 committed Mar 11, 2023
1 parent 828cab5 commit f6424d1
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 68 deletions.
179 changes: 120 additions & 59 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
Expand Up @@ -918,30 +918,59 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc);
}

static void createQuickSort(OpBuilder &builder, ModuleOp module,
func::FuncOp func, ValueRange args, uint64_t nx,
uint64_t ny, bool isCoo, uint32_t nTrailingP) {
/// A helper for generating code to perform quick sort. It partitions [lo, hi),
/// recursively calls quick sort to process the smaller partition and returns
/// the bigger partition to be processed by the enclosing while-loop.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
Value lo = args[loIdx];
Value hi = args[hiIdx];
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
auto p = builder.create<func::CallOp>(loc, partitionFunc,
TypeRange{IndexType::get(context)},
args.drop_back(nTrailingP));

SmallVector<Value> lowOperands{lo, p.getResult(0)};
lowOperands.append(args.begin() + xStartIdx, args.end());
builder.create<func::CallOp>(loc, func, lowOperands);

SmallVector<Value> highOperands{
builder.create<arith::AddIOp>(loc, p.getResult(0),
constantIndex(builder, loc, 1)),
hi};
highOperands.append(args.begin() + xStartIdx, args.end());
builder.create<func::CallOp>(loc, func, highOperands);
Value p = builder
.create<func::CallOp>(loc, partitionFunc,
TypeRange{IndexType::get(context)},
args.drop_back(nTrailingP))
.getResult(0);
Value pP1 =
builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
lenLow, lenHigh);

SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);

Value c0 = constantIndex(builder, loc, 0);
auto mayRecursion = [&](Value low, Value high, Value len) {
Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> operands{low, high};
operands.append(args.begin() + xStartIdx, args.end());
builder.create<func::CallOp>(loc, func, operands);
builder.setInsertionPointAfter(ifOp);
};

// Recursively call quickSort to process the smaller partition and return
// the bigger partition to be processed by the enclosed while-loop.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
mayRecursion(lo, p, lenLow);
builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});

builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
mayRecursion(pP1, hi, lenHigh);
builder.create<scf::YieldOp>(loc, ValueRange{lo, p});

builder.setInsertionPointAfter(ifOp);
return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
}

/// Creates a function to perform insertion sort on the values in the range of
Expand Down Expand Up @@ -1036,16 +1065,21 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
// if (lo + 1 < hi) {
// while (lo + 1 < hi) {
// p = partition(lo, hi, data);
// quickSort(lo, p, data);
// quickSort(p + 1, hi, data);
// if (len(lo, p) < len(p+1, hi)) {
// quickSort(lo, p, data);
// lo = p+1;
// } else {
// quickSort(p + 1, hi, data);
// hi = p;
// }
// }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
// if (lo + 1 < hi) {
// while (lo + 1 < hi) {
// len = hi - lo;
// if (len <= limit) {
// insertionSort(lo, hi, data);
Expand All @@ -1055,10 +1089,14 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
// heapSort(lo, hi, data);
// } else {
// p = partition(lo, hi, data);
// quickSort(lo, p, data);
// quickSort(p + 1, hi, data);
// if (len(lo, p) < len(p+1, hi)) {
// quickSort(lo, p, data, depthLimit);
// lo = p+1;
// } else {
// quickSort(p + 1, hi, data, depthLimit);
// hi = p;
// }
// }
// depthLimit ++;
// }
// }
// }
Expand All @@ -1073,70 +1111,98 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(entryBlock);

Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
SmallVector<Value> args;
args.append(entryBlock->getArguments().begin(),
entryBlock->getArguments().end());
Value lo = args[loIdx];
Value hi = args[hiIdx];
Value loCmp =
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
scf::WhileOp whileOp =
builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});

// The before-region of the WhileOp.
Block *before =
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(before);
lo = before->getArgument(0);
hi = before->getArgument(1);
Value loP1 =
builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
Value needSort =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());

// The if-stmt true branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value pDepthLimit;
Value savedDepthLimit;
scf::IfOp depthIf;
// The after-region of the WhileOp.
Block *after =
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
builder.setInsertionPointToEnd(after);
lo = after->getArgument(0);
hi = after->getArgument(1);
args[0] = lo;
args[1] = hi;

if (isHybrid) {
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
Value lenLimit = constantIndex(builder, loc, 30);
Value lenCond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, len, lenLimit);
scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
scf::IfOp lenIf =
builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);

// When len <= limit.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
args.drop_back(nTrailingP), createSortStableFunc);
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
ValueRange(args.drop_back(nTrailingP)));
ValueRange(args).drop_back(nTrailingP));
builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

// When len > limit.
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
pDepthLimit = args.back();
savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
Value depthLimit = builder.create<arith::SubIOp>(
loc, savedDepthLimit, constantI64(builder, loc, 1));
builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
Value depthLimit = args.back();
depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
constantI64(builder, loc, 1));
Value depthCond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
depthLimit, constantI64(builder, loc, 0));
depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
scf::IfOp depthIf =
builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);

// When depth exceeds limit.
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
args.drop_back(nTrailingP), createHeapSortFunc);
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
ValueRange(args.drop_back(nTrailingP)));
ValueRange(args).drop_back(nTrailingP));
builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

// When depth doesn't exceed limit.
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
}

createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
args.back() = depthLimit;
std::tie(lo, hi) =
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

if (isHybrid) {
// Restore depthLimit.
builder.setInsertionPointAfter(depthIf);
builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
lo = depthIf.getResult(0);
hi = depthIf.getResult(1);
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

builder.setInsertionPointAfter(lenIf);
lo = lenIf.getResult(0);
hi = lenIf.getResult(1);
} else {
std::tie(lo, hi) =
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
}

// After the if-stmt.
builder.setInsertionPointAfter(ifOp);
// New [lo, hi) for the next while-loop iteration.
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

// After the while-loop.
builder.setInsertionPointAfter(whileOp);
builder.create<func::ReturnOp>(loc);
}

Expand Down Expand Up @@ -1171,9 +1237,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
funcName = kHybridQuickSortFuncNamePrefix;
funcGenerator = createQuickSortFunc;
nTrailingP = 1;
Value pDepthLimit = rewriter.create<memref::AllocaOp>(
loc, MemRefType::get({}, rewriter.getI64Type()));
operands.push_back(pDepthLimit);
// As a heuristics, set depthLimit = 2 * log2(n).
Value lo = operands[loIdx];
Value hi = operands[hiIdx];
Expand All @@ -1183,9 +1246,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
Value depthLimit = rewriter.create<arith::SubIOp>(
loc, constantI64(rewriter, loc, 64),
rewriter.create<math::CountLeadingZerosOp>(loc, len));
depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
constantI64(rewriter, loc, 1));
rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
operands.push_back(depthLimit);
break;
}
case SparseTensorSortKind::QuickSort:
Expand Down
29 changes: 20 additions & 9 deletions mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
Expand Up @@ -132,13 +132,24 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
// CHECK: %[[C1:.*]] = arith.constant 1
// CHECK: %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
// CHECK: scf.if %[[COND]] {
// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]])
// CHECK: %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]]
// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]]
// CHECK: scf.condition(%[[COND]]) %[[L2]], %[[H2]]
// CHECK: } do {
// CHECK: ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index)
// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index
// CHECK: %[[LenL:.*]] = arith.subi %[[P]], %[[L3]]
// CHECK: %[[LenH:.*]] = arith.subi %[[H3]], %[[P]]
// CHECK: %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]]
// CHECK: %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]]
// CHECK: %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]]
// CHECK: scf.if %[[Cmp]]
// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: else
// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: scf.yield %[[L4]], %[[H4]]
// CHECK: }
// CHECK: return
// CHECK: }
Expand Down Expand Up @@ -187,7 +198,7 @@ func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: me
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
Expand Down Expand Up @@ -249,7 +260,7 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2:
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
Expand Down

0 comments on commit f6424d1

Please sign in to comment.