Skip to content

Commit

Permalink
[mlir][sparse] Adding wrappers for constantOverheadTypeEncoding
Browse files Browse the repository at this point in the history
Minor code cleanup

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D114392
  • Loading branch information
wrengr committed Nov 24, 2021
1 parent 17eb6b6 commit d7d7ffe
Showing 1 changed file with 20 additions and 8 deletions.
Expand Up @@ -85,6 +85,22 @@ static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}

/// Generates a constant of the internal type encoding for pointer
/// overhead storage.
static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc,
SparseTensorEncodingAttr &enc) {
return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
}

/// Generates a constant of the internal type encoding for index overhead
/// storage.
static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc,
SparseTensorEncodingAttr &enc) {
return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
}

/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, Type tp) {
Expand Down Expand Up @@ -277,10 +293,8 @@ static void newParams(ConversionPatternRewriter &rewriter,
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
params.push_back(
constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
params.push_back(
constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
params.push_back(
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
// User action and pointer.
Expand Down Expand Up @@ -598,10 +612,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantOverheadTypeEncoding(rewriter, loc,
encDst.getPointerBitWidth());
params[4] = constantOverheadTypeEncoding(rewriter, loc,
encDst.getIndexBitWidth());
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
Expand Down

0 comments on commit d7d7ffe

Please sign in to comment.