diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index b8cf62366d25a..321ef84bc5b75 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -34,8 +34,6 @@ static constexpr uint64_t loIdx = 0; static constexpr uint64_t hiIdx = 1; static constexpr uint64_t xStartIdx = 2; -static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_"; -static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_"; static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; static constexpr const char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_"; @@ -181,27 +179,24 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args, forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair); } -/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to -/// compare each pair is create via `compareBuilder`. -static void createCompareFuncImplementation( - OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, - uint64_t ny, bool isCoo, - function_ref +/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare +/// each pair is create via `compareBuilder`. +static Value createInlinedCompareImplementation( + OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, + bool isCoo, + function_ref compareBuilder) { - OpBuilder::InsertionGuard insertionGuard(builder); - - Block *entryBlock = func.addEntryBlock(); - builder.setInsertionPointToStart(entryBlock); - Location loc = func.getLoc(); - ValueRange args = entryBlock->getArguments(); - - scf::IfOp topIfOp; + Value result; auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { - scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1)); - if (k == 0) { - topIfOp = ifOp; - } else { + bool isFirstDim = (k == 0); + bool isLastDim = (k == nx - 1); + Value val = + compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); + if (isFirstDim) { + result = val; + } else if (!isLastDim) { OpBuilder::InsertionGuard insertionGuard(builder); + auto ifOp = cast(val.getDefiningOp()); builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResult(0)); } @@ -209,38 +204,44 @@ static void createCompareFuncImplementation( forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); - builder.setInsertionPointAfter(topIfOp); - builder.create(loc, topIfOp.getResult(0)); + builder.setInsertionPointAfterValue(result); + return result; } -/// Generates an if-statement to compare whether x[i] is equal to x[j]. -static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, - Value j, Value x, bool isLastDim) { - Value f = constantI1(builder, loc, false); - Value t = constantI1(builder, loc, true); +/// Generates code to compare whether x[i] is equal to x[j] and returns the +/// result of the comparison. +static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, + Value x, bool isFirstDim, bool isLastDim) { Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); - Value cond = - builder.create(loc, arith::CmpIPredicate::eq, vi, vj); - scf::IfOp ifOp = - builder.create(loc, f.getType(), cond, /*else=*/true); - - // x[1] != x[j]: - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, f); + Value res; + if (isLastDim) { + res = builder.create(loc, arith::CmpIPredicate::eq, vi, vj); + // For 1D, we create a compare without any control flow. Otherwise, we + // create YieldOp to return the result in the nested if-stmt. + if (!isFirstDim) + builder.create(loc, res); + } else { + Value ne = + builder.create(loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + ne, /*else=*/true); + // If (x[i] != x[j]). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value f = constantI1(builder, loc, false); + builder.create(loc, f); - // x[i] == x[j]: - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - if (isLastDim == 1) { - // Finish checking all dimensions. - builder.create(loc, t); + // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that + // checks the remaining dimensions. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + res = ifOp.getResult(0); } - return ifOp; + return res; } -/// Creates a function to compare whether xs[i] is equal to xs[j]. +/// Creates code to compare whether xs[i] is equal to xs[j]. // // The generate IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) @@ -250,77 +251,68 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, // return false; // else if (x2[2] != x2[j])) // and so on ... -static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, - func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo, uint32_t nTrailingP = 0) { +static Value createInlinedEqCompare(OpBuilder &builder, Location loc, + ValueRange args, uint64_t nx, uint64_t ny, + bool isCoo, uint32_t nTrailingP = 0) { // Compare functions don't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); - createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, - createEqCompare); + return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo, + createEqCompare); } -/// Generates an if-statement to compare whether x[i] is less than x[j]. -static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc, - Value i, Value j, Value x, - bool isLastDim) { - Value f = constantI1(builder, loc, false); - Value t = constantI1(builder, loc, true); +/// Generates code to compare whether x[i] is less than x[j] and returns the +/// result of the comparison. +static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, + Value j, Value x, bool isFirstDim, + bool isLastDim) { Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); - Value cond = - builder.create(loc, arith::CmpIPredicate::ult, vi, vj); - scf::IfOp ifOp = - builder.create(loc, f.getType(), cond, /*else=*/true); - // If (x[i] < x[j]). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - builder.create(loc, t); - - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - if (isLastDim == 1) { - // Finish checking all dimensions. - builder.create(loc, f); + Value res; + if (isLastDim) { + res = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); + // For 1D, we create a compare without any control flow. Otherwise, we + // create YieldOp to return the result in the nested if-stmt. + if (!isFirstDim) + builder.create(loc, res); } else { - cond = - builder.create(loc, arith::CmpIPredicate::ult, vj, vi); - scf::IfOp ifOp2 = - builder.create(loc, f.getType(), cond, /*else=*/true); - // Otherwise if (x[j] < x[i]). - builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - builder.create(loc, f); - - // Otherwise check the remaining dimensions. - builder.setInsertionPointAfter(ifOp2); - builder.create(loc, ifOp2.getResult(0)); - // Set up the insertion point for the nested if-stmt that checks the - // remaining dimensions. - builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); + Value ne = + builder.create(loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + ne, /*else=*/true); + // If (x[i] != x[j]). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value lt = + builder.create(loc, arith::CmpIPredicate::ult, vi, vj); + builder.create(loc, lt); + + // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that + // checks the remaining dimensions. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + res = ifOp.getResult(0); } - return ifOp; + return res; } -/// Creates a function to compare whether xs[i] is less than xs[j]. +/// Creates code to compare whether xs[i] is less than xs[j]. // // The generate IR corresponds to this C like algorithm: -// if (x0[i] < x0[j]) -// return true; -// else if (x0[j] < x0[i]) -// return false; +// if (x0[i] != x0[j]) +// return x0[i] < x0[j]; +// else if (x1[j] != x1[i]) +// return x1[i] < x1[j]; // else -// if (x1[i] < x1[j]) -// return true; -// else if (x1[j] < x1[i])) // and so on ... -static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, - func::FuncOp func, uint64_t nx, uint64_t ny, - bool isCoo, uint32_t nTrailingP = 0) { +static Value createInlinedLessThan(OpBuilder &builder, Location loc, + ValueRange args, uint64_t nx, uint64_t ny, + bool isCoo, uint32_t nTrailingP = 0) { // Compare functions don't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); - createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, - createLessThanCompare); + return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo, + createLessThanCompare); } /// Creates a function to use a binary search to find the insertion point for @@ -379,15 +371,8 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, uint64_t numXBuffers = isCoo ? 1 : nx; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); - Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); - FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( - builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc, nTrailingP); - Value cond2 = builder - .create(loc, lessThanFunc, TypeRange{i1Type}, - compareOperands) - .getResult(0); - + Value cond2 = + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); // Update lo and hi for the WhileOp as follows: // if (xs[p] < xs[mid])) // hi = mid; @@ -428,15 +413,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, compareOperands.push_back(before->getArgument(0)); } compareOperands.append(xs.begin(), xs.end()); - MLIRContext *context = module.getContext(); - Type i1Type = IntegerType::get(context, 1, IntegerType::Signless); - FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( - builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc); - Value cond = builder - .create(loc, lessThanFunc, TypeRange{i1Type}, - compareOperands) - .getResult(0); + Value cond = + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); builder.create(loc, cond, before->getArguments()); Block *after = @@ -450,14 +428,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, builder.setInsertionPointAfter(whileOp); compareOperands[0] = i; compareOperands[1] = p; - FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc( - builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo, - compareOperands, createEqCompareFunc); Value compareEq = - builder - .create(loc, compareEqFunc, TypeRange{i1Type}, - compareOperands) - .getResult(0); + createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo); return std::make_pair(whileOp.getResult(0), compareEq); } @@ -485,14 +457,10 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, args.begin() + xStartIdx + numXBuffers); Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); SmallVector cmpTypes{i1Type}; - FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( - builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc); Location loc = func.getLoc(); // Compare data[mi] < data[lo]. Value cond1 = - builder.create(loc, lessThanFunc, cmpTypes, compareOperands) - .getResult(0); + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); SmallVector ifTypes{lo.getType()}; scf::IfOp ifOp1 = builder.create(loc, ifTypes, cond1, /*else=*/true); @@ -502,11 +470,9 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp { compareOperands[0] = c; compareOperands[1] = a; - // Compare data[c]] < data[a]. + // Compare data[c] < data[b]. Value cond2 = - builder - .create(loc, lessThanFunc, cmpTypes, compareOperands) - .getResult(0); + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); scf::IfOp ifOp2 = builder.create(loc, ifTypes, cond2, /*else=*/true); builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); @@ -514,9 +480,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, compareOperands[1] = b; // Compare data[c] < data[b]. Value cond3 = - builder - .create(loc, lessThanFunc, cmpTypes, compareOperands) - .getResult(0); + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); builder.create( loc, ValueRange{builder.create(loc, cond3, b, c)}); builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); @@ -758,10 +722,6 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, uint64_t numXBuffers = isCoo ? 1 : nx; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); - Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); - FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( - builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, - compareOperands, createLessThanFunc); // Generate code to inspect the children of 'r' and return the larger child // as follows: @@ -784,10 +744,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, // Compare data[left] < data[right]. compareOperands[0] = lChildIdx; compareOperands[1] = rChildIdx; - Value cond2 = builder - .create(loc, lessThanFunc, - TypeRange{i1Type}, compareOperands) - .getResult(0); + Value cond2 = + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); scf::IfOp if2 = builder.create(loc, ifTypes, cond2, /*else=*/true); builder.setInsertionPointToStart(&if2.getThenRegion().front()); @@ -818,10 +776,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, childIdx = before->getArgument(2); compareOperands[0] = start; compareOperands[1] = childIdx; - Value cond = builder - .create(loc, lessThanFunc, TypeRange{i1Type}, - compareOperands) - .getResult(0); + Value cond = + createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); builder.create(loc, cond, before->getArguments()); // The after-region of the WhileOp. diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir index 68e5c9b96b94f..84721f7479e65 100644 --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -75,54 +75,132 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f // ----- -// CHECK-LABEL: func.func private @_sparse_less_than_1_i8( -// CHECK-SAME: %[[I:arg0]]: index, -// CHECK-SAME: %[[J:.*]]: index, -// CHECK-SAME: %[[X0:.*]]: memref) -> i1 { -// CHECK: %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]] -// CHECK: %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]] -// CHECK: %[[C:.*]] = arith.cmpi ult, %[[VI]], %[[VJ]] -// CHECK: return %[[C]] -// CHECK: } - -// CHECK-LABEL: func.func private @_sparse_compare_eq_1_i8( -// CHECK-SAME: %[[I:arg0]]: index, -// CHECK-SAME: %[[J:.*]]: index, -// CHECK-SAME: %[[X0:.*]]: memref) -> i1 { -// CHECK: %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]] -// CHECK: %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]] -// CHECK: %[[C:.*]] = arith.cmpi eq, %[[VI]], %[[VJ]] -// CHECK: return %[[C]] -// CHECK: } - // CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index( -// CHECK-SAME: %[[L:arg0]]: index, -// CHECK-SAME: %[[H:.*]]: index, -// CHECK-SAME: %[[X0:.*]]: memref, -// CHECK-SAME: %[[Y0:.*]]: memref, -// CHECK-SAME: %[[Y1:.*]]: memref) -> index { -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK: %[[VAL_6:.*]] = arith.constant - -// CHECK: %[[SUM:.*]] = arith.addi %[[L]], %[[H]] -// CHECK: %[[P:.*]] = arith.shrui %[[SUM]], %[[C1]] -// CHECK: %[[J:.*]] = arith.subi %[[H]], %[[C1]] -// CHECK: %[[W:.*]]:3 = scf.while (%[[Ib:.*]] = %[[L]], %[[Jb:.*]] = %[[J]], %[[pb:.*]] = %[[P]]) : (index, index, index) -> (index, index, index) { -// CHECK: %[[Cn:.*]] = arith.cmpi ult, %[[Ib]], %[[Jb]] -// CHECK: scf.condition(%[[Cn]]) %[[Ib]], %[[Jb]], %[[pb]] +// CHECK-SAME: %[[VAL_0:.*0]]: index, +// CHECK-SAME: %[[VAL_1:.*1]]: index, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_4:.*4]]: memref) -> index { +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -1 +// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] +// CHECK: %[[VAL_8:.*]] = arith.shrui %[[VAL_7]], %[[VAL_5]] +// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_1]], %[[VAL_5]] +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]] +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] +// CHECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] +// CHECK: %[[VAL_13:.*]] = scf.if %[[VAL_12]] -> (index) { +// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] +// CHECK: %[[VAL_16:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_15]] +// CHECK: %[[VAL_17:.*]] = scf.if %[[VAL_16]] -> (index) { +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] +// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]] +// CHECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_19]] +// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_20]], %[[VAL_8]], %[[VAL_9]] +// CHECK: scf.yield %[[VAL_21]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_0]] +// CHECK: } +// CHECK: scf.yield %[[VAL_22:.*]] +// CHECK: } else { +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] +// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]] +// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_24]] +// CHECK: %[[VAL_26:.*]] = scf.if %[[VAL_25]] -> (index) { +// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] +// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] +// CHECK: %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_28]] +// CHECK: %[[VAL_30:.*]] = arith.select %[[VAL_29]], %[[VAL_0]], %[[VAL_9]] +// CHECK: scf.yield %[[VAL_30]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_8]] +// CHECK: } +// CHECK: scf.yield %[[VAL_31:.*]] +// CHECK: } +// CHECK: %[[VAL_32:.*]] = arith.cmpi ne, %[[VAL_8]], %[[VAL_13:.*]] +// CHECK: scf.if %[[VAL_32]] { +// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}} +// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]] +// CHECK: memref.store %[[VAL_35]], %[[VAL_2]] +// CHECK: memref.store %[[VAL_34]], %[[VAL_2]]{{\[}}%[[VAL_8]]] +// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_3]] +// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_8]]] +// CHECK: memref.store %[[VAL_37]], %[[VAL_3]] +// CHECK: memref.store %[[VAL_36]], %[[VAL_3]]{{\[}}%[[VAL_8]]] +// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_4]] +// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_8]]] +// CHECK: memref.store %[[VAL_39]], %[[VAL_4]] +// CHECK: memref.store %[[VAL_38]], %[[VAL_4]]{{\[}}%[[VAL_8]]] +// CHECK: } +// CHECK: %[[VAL_40:.*]]:3 = scf.while (%[[VAL_41:.*]] = %[[VAL_0]], %[[VAL_42:.*]] = %[[VAL_9]], %[[VAL_43:.*]] = %[[VAL_8]]) +// CHECK: %[[VAL_44:.*]] = arith.cmpi ult, %[[VAL_41]], %[[VAL_42]] +// CHECK: scf.condition(%[[VAL_44]]) %[[VAL_41]], %[[VAL_42]], %[[VAL_43]] // CHECK: } do { -// CHECK: ^bb0(%[[Ia:.*]]: index, %[[Ja:.*]]: index, %[[Pa:.*]]: index): -// CHECK: %[[I2:.*]] = scf.while -// CHECK: %[[Ieq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[I2:.*]], %[[Pa]], %[[X0]]) -// CHECK: %[[J2:.*]] = scf.while -// CHECK: %[[Jeq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[J2:.*]], %[[Pa]], %[[X0]]) -// CHECK: %[[Cn2:.*]] = arith.cmpi ult, %[[I2]], %[[J2]] -// CHECK: %[[If:.*]]:3 = scf.if %[[Cn2]] -> (index, index, index) { +// CHECK: ^bb0(%[[VAL_45:.*]]: index, %[[VAL_46:.*]]: index, %[[VAL_47:.*]]: index): +// CHECK: %[[VAL_48:.*]] = scf.while (%[[VAL_49:.*]] = %[[VAL_45]]) : (index) -> index { +// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_49]]] +// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]] +// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_51]] +// CHECK: scf.condition(%[[VAL_52]]) %[[VAL_49]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_53:.*]]: index): +// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_53]], %[[VAL_5]] +// CHECK: scf.yield %[[VAL_54]] +// CHECK: } +// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_56:.*]]] +// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]] +// CHECK: %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_57]] +// CHECK: %[[VAL_59:.*]] = scf.while (%[[VAL_60:.*]] = %[[VAL_46]]) : (index) -> index { +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]] +// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_60]]] +// CHECK: %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_61]], %[[VAL_62]] +// CHECK: scf.condition(%[[VAL_63]]) %[[VAL_60]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_64:.*]]: index): +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]] +// CHECK: scf.yield %[[VAL_65]] +// CHECK: } +// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_67:.*]]] +// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]] +// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] +// CHECK: %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_67]] +// CHECK: %[[VAL_71:.*]]:3 = scf.if %[[VAL_70]] -> (index, index, index) { +// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_56]]] +// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_67]]] +// CHECK: memref.store %[[VAL_73]], %[[VAL_2]]{{\[}}%[[VAL_56]]] +// CHECK: memref.store %[[VAL_72]], %[[VAL_2]]{{\[}}%[[VAL_67]]] +// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_56]]] +// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_67]]] +// CHECK: memref.store %[[VAL_75]], %[[VAL_3]]{{\[}}%[[VAL_56]]] +// CHECK: memref.store %[[VAL_74]], %[[VAL_3]]{{\[}}%[[VAL_67]]] +// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_56]]] +// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_67]]] +// CHECK: memref.store %[[VAL_77]], %[[VAL_4]]{{\[}}%[[VAL_56]]] +// CHECK: memref.store %[[VAL_76]], %[[VAL_4]]{{\[}}%[[VAL_67]]] +// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_47]] +// CHECK: %[[VAL_79:.*]] = scf.if %[[VAL_78]] -> (index) { +// CHECK: scf.yield %[[VAL_67]] +// CHECK: } else { +// CHECK: %[[VAL_80:.*]] = arith.cmpi eq, %[[VAL_67]], %[[VAL_47]] +// CHECK: %[[VAL_81:.*]] = arith.select %[[VAL_80]], %[[VAL_56]], %[[VAL_47]] +// CHECK: scf.yield %[[VAL_81]] +// CHECK: } +// CHECK: %[[VAL_82:.*]] = arith.andi %[[VAL_58]], %[[VAL_69]] +// CHECK: %[[VAL_83:.*]]:2 = scf.if %[[VAL_82]] -> (index, index) { +// CHECK: %[[VAL_84:.*]] = arith.addi %[[VAL_56]], %[[VAL_5]] +// CHECK: %[[VAL_85:.*]] = arith.subi %[[VAL_67]], %[[VAL_5]] +// CHECK: scf.yield %[[VAL_84]], %[[VAL_85]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_56]], %[[VAL_67]] +// CHECK: } +// CHECK: scf.yield %[[VAL_86:.*]]#0, %[[VAL_86]]#1, %[[VAL_87:.*]] // CHECK: } else { -// CHECK: scf.yield %[[I2]], %[[J2]], %[[Pa]] +// CHECK: scf.yield %[[VAL_56]], %[[VAL_67]], %[[VAL_47]] // CHECK: } -// CHECK: scf.yield %[[If:.*]]#0, %[[If]]#1, %[[If]]#2 +// CHECK: scf.yield %[[VAL_88:.*]]#0, %[[VAL_88]]#1, %[[VAL_88]]#2 // CHECK: } -// CHECK: return %[[W:.*]]#2 +// CHECK: return %[[VAL_89:.*]]#2 // CHECK: } // CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index( @@ -176,8 +254,6 @@ func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: mem // Only check the generated supporting function now. We have integration test // to verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { -// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_3d_quick @@ -191,12 +267,10 @@ func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: me // Only check the generated supporting function now. We have integration test // to verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: index) { // CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { -// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: i64) { // CHECK-LABEL: func.func @sparse_sort_3d_hybrid @@ -210,7 +284,6 @@ func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: m // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_3d_stable @@ -224,7 +297,6 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: index) { // CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_3d_heap @@ -238,8 +310,6 @@ func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: mem // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { -// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_coo_quick @@ -253,12 +323,10 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: index) { // CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { -// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: i64) { // CHECK-LABEL: func.func @sparse_sort_coo_hybrid @@ -272,7 +340,6 @@ func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) -> index { // CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_coo_stable @@ -286,7 +353,6 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: // Only check the generated supporting functions. We have integration test to // verify correctness of the generated code. // -// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref) -> i1 { // CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: index) { // CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref, %arg3: memref, %arg4: memref) { // CHECK-LABEL: func.func @sparse_sort_coo_heap