[mlir][sparse] refine heuristic for iteration graph topsort
The sparse index order must always be satisfied, but in many cases
this still leaves a choice among several valid topological sorts.
Previously, we broke ties in favor of any dense index order, since
that gives good locality. However, breaking ties in favor of pushing
unrelated indices into the sparse iteration spaces gives better
asymptotic complexity. This revision improves the heuristic accordingly.

Note that in the long run, we are really interested in using
ML for ML to find the best loop ordering as a replacement for
such heuristics.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D109100
aartbik committed Sep 3, 2021
1 parent 36895cd commit b6d1a31
Showing 2 changed files with 39 additions and 21 deletions.
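
To make the heuristic concrete, here is a minimal, self-contained C++17 sketch of the idea. It is not the MLIR implementation: the Constraint struct, pickLoopOrder, and the constraint representation are assumptions introduced purely for illustration; only the SortMask values and the order in which the masks are tried (dense and undef constraints first, sparse-only last) follow the patch below.

// A minimal standalone sketch (C++17) of the mask-driven iteration-graph
// heuristic. The data model (`Constraint`, `pickLoopOrder`) is hypothetical;
// only the SortMask values and the fallback order mirror the actual patch.
#include <algorithm>
#include <optional>
#include <vector>

enum SortMask : unsigned {
  kSparseOnly = 0x0,   // only constraints from sparse tensors
  kIncludeDense = 0x1, // also honor dense index orders (locality)
  kIncludeUndef = 0x2  // also push unrelated indices into sparse loops
};

struct Constraint {
  unsigned from, to; // loop index `from` must be placed outside loop index `to`
  bool fromDense;    // constraint stems from a dense tensor operand
  bool pushesUndef;  // constraint only pushes an unrelated (undef) index inward
};

// Depth-first traversal that appends nodes in post-order; a node revisited
// while still "in progress" signals a cycle.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       const std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] == 2; // 1 = in progress (cycle), 2 = done
  visit[i] = 1;
  for (unsigned j = 0, n = (unsigned)adjM.size(); j < n; j++)
    if (adjM[i][j] && !topSortDFS(j, visit, topSort, adjM))
      return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

// Builds the adjacency matrix from the constraints selected by `mask` and
// tries to topologically sort it into a loop order (outermost loop first).
static std::optional<std::vector<unsigned>>
computeIterationGraph(unsigned numLoops, const std::vector<Constraint> &cs,
                      unsigned mask) {
  std::vector<std::vector<bool>> adjM(numLoops,
                                      std::vector<bool>(numLoops, false));
  for (const Constraint &c : cs) {
    if (c.fromDense && !(mask & kIncludeDense))
      continue; // skip dense-tensor constraints unless requested
    if (c.pushesUndef && !(mask & kIncludeUndef))
      continue; // skip "push unrelated index inward" constraints unless requested
    adjM[c.from][c.to] = true;
  }
  std::vector<unsigned> visit(numLoops, 0), topSort;
  topSort.reserve(numLoops);
  for (unsigned i = 0; i < numLoops; i++)
    if (visit[i] == 0 && !topSortDFS(i, visit, topSort, adjM))
      return std::nullopt; // cycle: the selected constraints conflict
  std::reverse(topSort.begin(), topSort.end());
  return topSort;
}

// Try the strongest set of constraints first, then progressively relax it,
// in the same order as the patched call site.
std::optional<std::vector<unsigned>>
pickLoopOrder(unsigned numLoops, const std::vector<Constraint> &cs) {
  const unsigned masks[] = {kIncludeUndef | kIncludeDense, kIncludeUndef,
                            kIncludeDense, kSparseOnly};
  for (unsigned mask : masks)
    if (auto order = computeIterationGraph(numLoops, cs, mask))
      return order;
  return std::nullopt; // even the sparse-only constraints form a cycle
}

The fallback order drops only the optional constraints step by step (first the dense locality constraints, then the undef push, then both) until a cycle-free iteration graph remains; the mandatory sparse index order is never dropped.
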
27 changes: 22 additions & 5 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -30,6 +30,9 @@ using namespace mlir::sparse_tensor;

namespace {

+// Iteration graph sorting.
+enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Code generation.
struct CodeGen {
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
@@ -141,7 +144,7 @@ static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
-                                 bool sparseOnly) {
+                                 unsigned mask) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
unsigned n = op.getNumLoops();
@@ -152,8 +155,8 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
auto map = op.getTiedIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
assert(map.getNumDims() == n);
-    // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !enc)
+    // Skip dense tensor constraints when not requested.
+    if (!(mask & SortMask::kIncludeDense) && !enc)
continue;
// Each tensor expression and optional dimension ordering (row-major
// by default) puts an ordering constraint on the loop indices. For
@@ -164,6 +167,16 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
unsigned t = map.getDimPosition(perm(enc, d));
adjM[f][t] = true;
}
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (mask & SortMask::kIncludeUndef) {
+      unsigned tensor = t->getOperandNumber();
+      for (unsigned i = 0; i < n; i++)
+        if (merger.isDim(tensor, i, Dim::kSparse))
+          for (unsigned j = 0; j < n; j++)
+            if (merger.isDim(tensor, j, Dim::kUndef))
+              adjM[i][j] = true;
+    }
}

// Topologically sort the iteration graph to determine loop order.
@@ -1134,8 +1147,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// This assumes that higher-level passes have already put the
// tensors in each tensor expression in a feasible order.
std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
-        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+    if (!computeIterationGraph(merger, op, topSort,
+                               SortMask::kIncludeUndef |
+                                   SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
return failure();

// Builds the tensor expression for the Linalg operation in SSA form.
33 changes: 17 additions & 16 deletions mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1043,25 +1043,26 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref<?x?xf32>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK: %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK: scf.yield %[[VAL_36]] : f32
+// CHECK: }
+// CHECK: memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
-// CHECK: return %[[VAL_35]] : tensor<?x?xf32>
+// CHECK: %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
+// CHECK: return %[[VAL_38]] : tensor<?x?xf32>
// CHECK: }
func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
%arga: tensor<?x?xf32>,
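
For orientation, the new CHECK lines above describe a loop nest in which both sparse dimensions of the sampling matrix are outermost and the unrelated dense index k becomes the innermost scalar reduction (the scf.for with iter_args). A rough plain-C++ rendering of that structure, under an assumed CSR layout for the sparse operand (the field names and function signature are illustrative, not MLIR output), looks as follows.

#include <vector>

// Illustrative CSR layout for the sparse sampling matrix s (assumption).
struct CSR {
  std::vector<int> rowPtr;   // size rows + 1
  std::vector<int> colInd;   // column index of each stored entry
  std::vector<float> values; // value of each stored entry
};

// x(i,j) += s(i,j) * a(i,k) * b(k,j), computed only where s has stored
// entries. The dense reduction over k sits innermost, mirroring the
// iter_args loop in the updated test.
void sampledDenseDense(const CSR &s, const std::vector<std::vector<float>> &a,
                       const std::vector<std::vector<float>> &b,
                       std::vector<std::vector<float>> &x, int K) {
  int rows = (int)s.rowPtr.size() - 1;
  for (int i = 0; i < rows; i++) {                        // sparse i of s
    for (int p = s.rowPtr[i]; p < s.rowPtr[i + 1]; p++) { // sparse j of s
      int j = s.colInd[p];
      float acc = x[i][j];                                // load once
      for (int k = 0; k < K; k++)                         // dense reduction
        acc += s.values[p] * (a[i][k] * b[k][j]);
      x[i][j] = acc;                                      // store once
    }
  }
}

Before this change, the dense k loop enclosed the sparse j loop, so the sparse positions and coordinates of each row were re-traversed, and x(i,j) was re-loaded and re-stored, for every value of k; with the dense loop innermost, that bookkeeping happens only once per stored entry of s.
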
