Skip to content

Commit

Permalink
[mlir][sparse] code cleanup (#73047)
Browse files Browse the repository at this point in the history
removed two unused methods, removed obsoleted FIXME
  • Loading branch information
aartbik committed Nov 21, 2023
1 parent 2743b30 commit d2d2928
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;

std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const;
std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;

//
Expand Down
32 changes: 7 additions & 25 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,6 @@ SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
return getDimSlice(dim).getStaticSize();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
return getDimSlice(dim).getStaticStride();
Expand All @@ -384,12 +379,6 @@ SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
return getStaticDimSliceOffset(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
// FIXME: `toOrigDim` is deprecated.
return getStaticDimSliceSize(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
// FIXME: `toOrigDim` is deprecated.
Expand Down Expand Up @@ -1744,33 +1733,26 @@ LogicalResult SortOp::verify() {
if (!xPerm.isPermutation())
emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
std::optional<int64_t> cn = getConstantIntValue(getN());
if (!cn)
return success();

uint64_t n = cn.value();
uint64_t ny = 0;
if (auto nyAttr = getNyAttr()) {
ny = nyAttr.getInt();
}

// FIXME: update the types of variables used in expressions bassed as
// the `minSize` argument, to avoid implicit casting at the callsites
// of this lambda.
// Verify dimensions.
const auto checkDim = [&](Value v, Size minSize, const char *message) {
const Size sh = getMemRefType(v).getShape()[0];
if (!ShapedType::isDynamic(sh) && sh < minSize)
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};

uint64_t n = cn.value();
uint64_t ny = 0;
if (auto nyAttr = getNyAttr())
ny = nyAttr.getInt();
checkDim(getXy(), n * (nx + ny),
"Expected dimension(xy) >= n * (rank(perm_map) + ny)");

for (Value opnd : getYs()) {
for (Value opnd : getYs())
checkDim(opnd, n, "Expected dimension(y) >= n");
}

return success();
}
Expand Down

0 comments on commit d2d2928

Please sign in to comment.