diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index d0a3f01afe871..43e48a6d34026 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -158,16 +158,14 @@ namespace sparse_tensor { /// Convenience method to abbreviate casting `getType()`. template <typename T> inline RankedTensorType getRankedTensorType(T &&t) { - assert(static_cast<bool>(std::forward<T>(t)) && -        "getRankedTensorType got null argument"); + assert(static_cast<bool>(t) && "getRankedTensorType got null argument"); return dyn_cast<RankedTensorType>(std::forward<T>(t).getType()); } /// Convenience method to abbreviate casting `getType()`. template <typename T> inline MemRefType getMemRefType(T &&t) { - assert(static_cast<bool>(std::forward<T>(t)) && -        "getMemRefType got null argument"); + assert(static_cast<bool>(t) && "getMemRefType got null argument"); return cast<MemRefType>(std::forward<T>(t).getType()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index 73e0f3d2891d7..f53d2727c9b00 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter( loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), strategy(strategy) { // One map per tensor. - assert(loop2InsLvl.size() == ins.size()); + assert(this->loop2InsLvl.size() == this->ins.size()); // All the affine maps have the same number of dimensions (loops). assert(llvm::all_equal(llvm::map_range( - loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); + this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); // The number of results of the map should match the rank of the tensor. 
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { + assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) { auto [m, v] = mvPair; - return m.getNumResults() == cast<ShapedType>(v.getType()).getRank(); + + // For ranked types the rank must match. + // Simply return true for UnrankedTensorType + if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) { + return !shapedType.hasRank() || + (m.getNumResults() == shapedType.getRank()); + } + // Non-shaped (scalar) types behave like rank-0. + return m.getNumResults() == 0; })); itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));