@@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
563563// p = (lo+hi)/2 // pivot index
564564// i = lo
565565// j = hi-1
566- // while (i < j ) do {
566+ // while (true ) do {
567567// while (xs[i] < xs[p]) i ++;
568568// i_eq = (xs[i] == xs[p]);
569569// while (xs[j] > xs[p]) j --;
570570// j_eq = (xs[j] == xs[p]);
571+ //
572+ // if (i >= j) return j + 1;
573+ //
571574// if (i < j) {
572575// swap(xs[i], xs[j])
573576// if (i == p) {
@@ -581,8 +584,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
581584// }
582585// }
583586// }
584- // return p
585- // }
587+ // }
586588static void createPartitionFunc (OpBuilder &builder, ModuleOp module ,
587589 func::FuncOp func, uint64_t nx, uint64_t ny,
588590 bool isCoo, uint32_t nTrailingP = 0 ) {
@@ -605,22 +607,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
605607 Value i = lo;
606608 Value j = builder.create <arith::SubIOp>(loc, hi, c1);
607609 createChoosePivot (builder, module , func, nx, ny, isCoo, i, j, p, args);
608- SmallVector<Value, 3 > operands{i, j, p}; // Exactly three values.
609- SmallVector<Type, 3 > types{i.getType (), j.getType (), p.getType ()};
610+ Value trueVal = constantI1 (builder, loc, true ); // The value for while (true)
611+ SmallVector<Value, 4 > operands{i, j, p, trueVal}; // Exactly four values.
612+ SmallVector<Type, 4 > types{i.getType (), j.getType (), p.getType (),
613+ trueVal.getType ()};
610614 scf::WhileOp whileOp = builder.create <scf::WhileOp>(loc, types, operands);
611615
612616 // The before-region of the WhileOp.
613- Block *before =
614- builder. createBlock (&whileOp. getBefore (), {}, types, { loc, loc, loc});
617+ Block *before = builder. createBlock (&whileOp. getBefore (), {}, types,
618+ {loc, loc, loc, loc});
615619 builder.setInsertionPointToEnd (before);
616- Value cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
617- before->getArgument (0 ),
618- before->getArgument (1 ));
619- builder.create <scf::ConditionOp>(loc, cond, before->getArguments ());
620+ builder.create <scf::ConditionOp>(loc, before->getArgument (3 ),
621+ before->getArguments ());
620622
621623 // The after-region of the WhileOp.
622624 Block *after =
623- builder.createBlock (&whileOp.getAfter (), {}, types, {loc, loc, loc});
625+ builder.createBlock (&whileOp.getAfter (), {}, types, {loc, loc, loc, loc });
624626 builder.setInsertionPointToEnd (after);
625627 i = after->getArgument (0 );
626628 j = after->getArgument (1 );
@@ -637,7 +639,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
637639 j = jresult;
638640
639641 // If i < j:
640- cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
642+ Value cond =
643+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
641644 scf::IfOp ifOp = builder.create <scf::IfOp>(loc, types, cond, /* else=*/ true );
642645 builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
643646 SmallVector<Value> swapOperands{i, j};
@@ -675,11 +678,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
675678 builder.setInsertionPointAfter (ifOp2);
676679 builder.create <scf::YieldOp>(
677680 loc,
678- ValueRange{ifOp2.getResult (0 ), ifOp2.getResult (1 ), ifOpI.getResult (0 )});
681+ ValueRange{ifOp2.getResult (0 ), ifOp2.getResult (1 ), ifOpI.getResult (0 ),
682+ /* cont=*/ constantI1 (builder, loc, true )});
679683
680- // False branch for if i < j:
684+ // False branch for if i < j (i.e., i >= j) :
681685 builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
682- builder.create <scf::YieldOp>(loc, ValueRange{i, j, p});
686+ p = builder.create <arith::AddIOp>(loc, j,
687+ constantOne (builder, loc, j.getType ()));
688+ builder.create <scf::YieldOp>(
689+ loc, ValueRange{i, j, p, /* cont=*/ constantI1 (builder, loc, false )});
683690
684691 // Return for the whileOp.
685692 builder.setInsertionPointAfter (ifOp);
@@ -927,6 +934,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
927934 Location loc = func.getLoc ();
928935 Value lo = args[loIdx];
929936 Value hi = args[hiIdx];
937+ SmallVector<Type, 2 > types (2 , lo.getType ()); // Only two types.
938+
930939 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc (
931940 builder, func, {IndexType::get (context)}, kPartitionFuncNamePrefix , nx,
932941 ny, isCoo, args.drop_back (nTrailingP), createPartitionFunc);
@@ -935,14 +944,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
935944 TypeRange{IndexType::get (context)},
936945 args.drop_back (nTrailingP))
937946 .getResult (0 );
938- Value pP1 =
939- builder.create <arith::AddIOp>(loc, p, constantIndex (builder, loc, 1 ));
947+
940948 Value lenLow = builder.create <arith::SubIOp>(loc, p, lo);
941949 Value lenHigh = builder.create <arith::SubIOp>(loc, hi, p);
950+ // Partition already sorts array with len <= 2
951+ Value c2 = constantIndex (builder, loc, 2 );
952+ Value len = builder.create <arith::SubIOp>(loc, hi, lo);
953+ Value lenGtTwo =
954+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
955+ scf::IfOp ifLenGtTwo =
956+ builder.create <scf::IfOp>(loc, types, lenGtTwo, /* else=*/ true );
957+ builder.setInsertionPointToStart (&ifLenGtTwo.getElseRegion ().front ());
958+ // Returns an empty range to mark the entire region is fully sorted.
959+ builder.create <scf::YieldOp>(loc, ValueRange{lo, lo});
960+
961+ // Else len > 2, need recursion.
962+ builder.setInsertionPointToStart (&ifLenGtTwo.getThenRegion ().front ());
942963 Value cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
943964 lenLow, lenHigh);
944965
945- SmallVector<Type, 2 > types (2 , lo.getType ()); // Only two types.
946966 scf::IfOp ifOp = builder.create <scf::IfOp>(loc, types, cond, /* else=*/ true );
947967
948968 Value c0 = constantIndex (builder, loc, 0 );
@@ -961,14 +981,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
961981 // the bigger partition to be processed by the enclosed while-loop.
962982 builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
963983 mayRecursion (lo, p, lenLow);
964- builder.create <scf::YieldOp>(loc, ValueRange{pP1 , hi});
984+ builder.create <scf::YieldOp>(loc, ValueRange{p , hi});
965985
966986 builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
967- mayRecursion (pP1 , hi, lenHigh);
987+ mayRecursion (p , hi, lenHigh);
968988 builder.create <scf::YieldOp>(loc, ValueRange{lo, p});
969989
970990 builder.setInsertionPointAfter (ifOp);
971- return std::make_pair (ifOp.getResult (0 ), ifOp.getResult (1 ));
991+ builder.create <scf::YieldOp>(loc, ifOp.getResults ());
992+
993+ builder.setInsertionPointAfter (ifLenGtTwo);
994+ return std::make_pair (ifLenGtTwo.getResult (0 ), ifLenGtTwo.getResult (1 ));
972995}
973996
974997// / Creates a function to perform insertion sort on the values in the range of
0 commit comments