Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,14 @@ namespace sparse_tensor {
/// Convenience method to abbreviate casting `getType()`.
///
/// Returns the value's type as a `RankedTensorType`, or a null type if the
/// value's type is not a ranked tensor (`dyn_cast` semantics — callers must
/// check the result when the input may be unranked).
///
/// NOTE: `t` is intentionally NOT forwarded in the assert. Forwarding the
/// same argument twice (once in the assert, once in `getType()`) could move
/// from an rvalue argument before its final use; reading `t` as an lvalue in
/// the assert is always safe.
template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
  assert(static_cast<bool>(t) && "getRankedTensorType got null argument");
  return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
}

/// Convenience method to abbreviate casting `getType()`.
///
/// Returns the value's type as a `MemRefType`. Uses `cast` (not `dyn_cast`),
/// so the value's type must already be a memref; violating this is a
/// checked error in assert-enabled builds.
///
/// NOTE: `t` is intentionally NOT forwarded in the assert. Forwarding the
/// same argument twice (once in the assert, once in `getType()`) could move
/// from an rvalue argument before its final use; reading `t` as an lvalue in
/// the assert is always safe.
template <typename T>
inline MemRefType getMemRefType(T &&t) {
  assert(static_cast<bool>(t) && "getMemRefType got null argument");
  return cast<MemRefType>(std::forward<T>(t).getType());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
assert(loop2InsLvl.size() == ins.size());
assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();

// For ranked types the rank must match.
// Simply return true for UnrankedTensorType
if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
return !shapedType.hasRank() ||
(m.getNumResults() == shapedType.getRank());
}
// Non-shaped (scalar) types behave like rank-0.
return m.getNumResults() == 0;
}));

itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
Expand Down