Skip to content

Commit

Permalink
[mlir][sparse] Improve sort operation by generating inlined code to c…
Browse files Browse the repository at this point in the history
…ompare values.

Previously, we generate function calls to compare values for sorting. It turns
out that the compiler doesn't inline those function calls. We now directly
generate inlined code. Also, modify the code for comparing values to use less
number of branches.

This improves all sort implementation in general. For arabic-2005.mtx CSR, the
improvement is around 25%.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145442
  • Loading branch information
bixia1 committed Mar 14, 2023
1 parent c1125ae commit 2ef4162
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 197 deletions.
240 changes: 98 additions & 142 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
Expand Up @@ -34,8 +34,6 @@ static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
"_sparse_binary_search_";
Expand Down Expand Up @@ -181,66 +179,69 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
}

/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
/// compare each pair is create via `compareBuilder`.
static void createCompareFuncImplementation(
OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
uint64_t ny, bool isCoo,
function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is create via `compareBuilder`.
static Value createInlinedCompareImplementation(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo,
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
compareBuilder) {
OpBuilder::InsertionGuard insertionGuard(builder);

Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();

scf::IfOp topIfOp;
Value result;
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
if (k == 0) {
topIfOp = ifOp;
} else {
bool isFirstDim = (k == 0);
bool isLastDim = (k == nx - 1);
Value val =
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
if (isFirstDim) {
result = val;
} else if (!isLastDim) {
OpBuilder::InsertionGuard insertionGuard(builder);
auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
}
};

forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);

builder.setInsertionPointAfter(topIfOp);
builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
builder.setInsertionPointAfterValue(result);
return result;
}

/// Generates an if-statement to compare whether x[i] is equal to x[j].
static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
Value j, Value x, bool isLastDim) {
Value f = constantI1(builder, loc, false);
Value t = constantI1(builder, loc, true);
/// Generates code to compare whether x[i] is equal to x[j] and returns the
/// result of the comparison.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
Value x, bool isFirstDim, bool isLastDim) {
Value vi = builder.create<memref::LoadOp>(loc, x, i);
Value vj = builder.create<memref::LoadOp>(loc, x, j);

Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
scf::IfOp ifOp =
builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);

// x[1] != x[j]:
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, f);
Value res;
if (isLastDim) {
res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
// For 1D, we create a compare without any control flow. Otherwise, we
// create YieldOp to return the result in the nested if-stmt.
if (!isFirstDim)
builder.create<scf::YieldOp>(loc, res);
} else {
Value ne =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
ne, /*else=*/true);
// If (x[i] != x[j]).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value f = constantI1(builder, loc, false);
builder.create<scf::YieldOp>(loc, f);

// x[i] == x[j]:
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
if (isLastDim == 1) {
// Finish checking all dimensions.
builder.create<scf::YieldOp>(loc, t);
// If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
// checks the remaining dimensions.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
res = ifOp.getResult(0);
}

return ifOp;
return res;
}

/// Creates a function to compare whether xs[i] is equal to xs[j].
/// Creates code to compare whether xs[i] is equal to xs[j].
//
// The generate IR corresponds to this C like algorithm:
// if (x0[i] != x0[j])
Expand All @@ -250,77 +251,68 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
// return false;
// else if (x2[2] != x2[j]))
// and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
createEqCompare);
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
createEqCompare);
}

/// Generates an if-statement to compare whether x[i] is less than x[j].
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
Value i, Value j, Value x,
bool isLastDim) {
Value f = constantI1(builder, loc, false);
Value t = constantI1(builder, loc, true);
/// Generates code to compare whether x[i] is less than x[j] and returns the
/// result of the comparison.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
Value j, Value x, bool isFirstDim,
bool isLastDim) {
Value vi = builder.create<memref::LoadOp>(loc, x, i);
Value vj = builder.create<memref::LoadOp>(loc, x, j);

Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
scf::IfOp ifOp =
builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
// If (x[i] < x[j]).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
builder.create<scf::YieldOp>(loc, t);

builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
if (isLastDim == 1) {
// Finish checking all dimensions.
builder.create<scf::YieldOp>(loc, f);
Value res;
if (isLastDim) {
res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
// For 1D, we create a compare without any control flow. Otherwise, we
// create YieldOp to return the result in the nested if-stmt.
if (!isFirstDim)
builder.create<scf::YieldOp>(loc, res);
} else {
cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
scf::IfOp ifOp2 =
builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
// Otherwise if (x[j] < x[i]).
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
builder.create<scf::YieldOp>(loc, f);

// Otherwise check the remaining dimensions.
builder.setInsertionPointAfter(ifOp2);
builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
// Set up the insertion point for the nested if-stmt that checks the
// remaining dimensions.
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value ne =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
ne, /*else=*/true);
// If (x[i] != x[j]).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value lt =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
builder.create<scf::YieldOp>(loc, lt);

// If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
// checks the remaining dimensions.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
res = ifOp.getResult(0);
}

return ifOp;
return res;
}

/// Creates a function to compare whether xs[i] is less than xs[j].
/// Creates code to compare whether xs[i] is less than xs[j].
//
// The generate IR corresponds to this C like algorithm:
// if (x0[i] < x0[j])
// return true;
// else if (x0[j] < x0[i])
// return false;
// if (x0[i] != x0[j])
// return x0[i] < x0[j];
// else if (x1[j] != x1[i])
// return x1[i] < x1[j];
// else
// if (x1[i] < x1[j])
// return true;
// else if (x1[j] < x1[i]))
// and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
createLessThanCompare);
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
Expand Down Expand Up @@ -379,15 +371,8 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
uint64_t numXBuffers = isCoo ? 1 : nx;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
compareOperands, createLessThanFunc, nTrailingP);
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);

Value cond2 =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
// Update lo and hi for the WhileOp as follows:
// if (xs[p] < xs[mid]))
// hi = mid;
Expand Down Expand Up @@ -428,15 +413,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands.push_back(before->getArgument(0));
}
compareOperands.append(xs.begin(), xs.end());
MLIRContext *context = module.getContext();
Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
compareOperands, createLessThanFunc);
Value cond = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);
Value cond =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

Block *after =
Expand All @@ -450,14 +428,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
builder.setInsertionPointAfter(whileOp);
compareOperands[0] = i;
compareOperands[1] = p;
FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
compareOperands, createEqCompareFunc);
Value compareEq =
builder
.create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);
createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);

return std::make_pair(whileOp.getResult(0), compareEq);
}
Expand Down Expand Up @@ -485,14 +457,10 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
args.begin() + xStartIdx + numXBuffers);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
SmallVector<Type, 1> cmpTypes{i1Type};
FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
compareOperands, createLessThanFunc);
Location loc = func.getLoc();
// Compare data[mi] < data[lo].
Value cond1 =
builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
.getResult(0);
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
SmallVector<Type, 1> ifTypes{lo.getType()};
scf::IfOp ifOp1 =
builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
Expand All @@ -502,21 +470,17 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
compareOperands[0] = c;
compareOperands[1] = a;
// Compare data[c]] < data[a].
// Compare data[c] < data[b].
Value cond2 =
builder
.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
.getResult(0);
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
scf::IfOp ifOp2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
compareOperands[0] = c;
compareOperands[1] = b;
// Compare data[c] < data[b].
Value cond3 =
builder
.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
.getResult(0);
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
builder.create<scf::YieldOp>(
loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Expand Down Expand Up @@ -758,10 +722,6 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
uint64_t numXBuffers = isCoo ? 1 : nx;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
compareOperands, createLessThanFunc);

// Generate code to inspect the children of 'r' and return the larger child
// as follows:
Expand All @@ -784,10 +744,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
// Compare data[left] < data[right].
compareOperands[0] = lChildIdx;
compareOperands[1] = rChildIdx;
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc,
TypeRange{i1Type}, compareOperands)
.getResult(0);
Value cond2 =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
scf::IfOp if2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
Expand Down Expand Up @@ -818,10 +776,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = before->getArgument(2);
compareOperands[0] = start;
compareOperands[1] = childIdx;
Value cond = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
.getResult(0);
Value cond =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

// The after-region of the WhileOp.
Expand Down

0 comments on commit 2ef4162

Please sign in to comment.