Skip to content

Commit

Permalink
[mlir][sparse] Enhancing sparse=>sparse conversion.
Browse files Browse the repository at this point in the history
Fixes: #51652

Depends On D122060

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122061
  • Loading branch information
wrengr committed May 16, 2022
1 parent e0c3b94 commit 8cb3324
Show file tree
Hide file tree
Showing 5 changed files with 539 additions and 28 deletions.
76 changes: 61 additions & 15 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Expand Up @@ -355,6 +355,32 @@ static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determine if the runtime library supports direct conversion to the
/// given target `dimTypes`.
static bool canUseDirectConversion(
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
bool alreadyCompressed = false;
for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
switch (dimTypes[r]) {
case SparseTensorEncodingAttr::DimLevelType::Compressed:
if (alreadyCompressed)
return false; // Multiple compressed dimensions not yet supported.
alreadyCompressed = true;
break;
case SparseTensorEncodingAttr::DimLevelType::Dense:
if (alreadyCompressed)
return false; // Dense after Compressed not yet supported.
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
// Although Singleton isn't generally supported yet, the direct
// conversion method doesn't have any particular problems with
// singleton after compressed.
break;
}
}
return true;
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -492,21 +518,41 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value, 8> params;
ShapedType stp = srcType.cast<ShapedType>();
sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
// Set up encoding with right mix of src and dst so that the two
// method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, op, params);
genDelCOOCall(rewriter, op, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
bool useDirectConversion;
switch (options.sparseToSparseStrategy) {
case SparseToSparseConversionStrategy::kViaCOO:
useDirectConversion = false;
break;
case SparseToSparseConversionStrategy::kDirect:
useDirectConversion = true;
assert(canUseDirectConversion(encDst.getDimLevelType()) &&
"Unsupported target for direct sparse-to-sparse conversion");
break;
case SparseToSparseConversionStrategy::kAuto:
useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
break;
}
if (useDirectConversion) {
newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
sizes, src);
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
} else { // use via-COO conversion.
// Set up encoding with right mix of src and dst so that the two
// method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, op, params);
genDelCOOCall(rewriter, op, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
}
return success();
}
if (!encDst && encSrc) {
Expand Down

0 comments on commit 8cb3324

Please sign in to comment.