[mlir][sparse] support sparse tensor element type conversion in codegen path

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144578
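
For illustration, the kind of conversion this patch lowers directly along the codegen path is a sparse-to-sparse convert whose element type (or pointer/index bitwidth) changes while the level structure stays the same. A minimal sketch, modeled on the new codegen test added below (the encoding name #SV32 and function name @widen are placeholders, not part of the patch):

#SV32 = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

// Same encoding on both sides; only the element type changes (f32 -> f64).
func.func @widen(%arg0: tensor<?xf32, #SV32>) -> tensor<?xf64, #SV32> {
  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SV32> to tensor<?xf64, #SV32>
  return %0 : tensor<?xf64, #SV32>
}
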
PeimingLiu committed Feb 23, 2023
1 parent 230e616 commit 85dbb3f
Showing 8 changed files with 265 additions and 22 deletions.
@@ -260,6 +260,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// reset to the default/identity.
SparseTensorEncodingAttr withoutOrdering() const;

/// Constructs a new encoding with the pointer and index bitwidth
/// reset to the default.
SparseTensorEncodingAttr withoutBitWidths() const;

/// Returns true if every level is dense. Also returns true for
/// the null encoding (since dense-tensors are always all-dense).
bool isAllDense() const;
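
As a sketch of what the new helper equates (attribute definitions taken from the tests further down), the two encodings below differ only in their pointer/index bitwidths, so they compare equal once withoutBitWidths() resets both widths to the default:

#SparseVector64 = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"],
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

#SparseVector32 = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>
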
@@ -135,7 +135,7 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
}

def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
[Pure, SameOperandsAndResultElementType]>,
[Pure]>,
Arguments<(ins AnyTensor:$source)>,
Results<(outs AnyTensor:$dest)> {
string summary = "Converts between different tensor types";
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -135,6 +135,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
getPointerBitWidth(), getIndexBitWidth());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(),
getDimOrdering(), getHigherOrdering(), 0,
0);
}

bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT);
}
68 changes: 65 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1030,11 +1030,73 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
if (encDst != encSrc) {
// This should be handled by rewriting before codegen.
// Different encoding (except for different bitwidth) should be handled by
// rewriting.
if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
return failure();
}
rewriter.replaceOp(op, adaptor.getSource());

Type retElemTp = op.getResult().getType().getElementType();
Type srcElemTp = op.getSource().getType().getElementType();
// Fold the trivial cases.
if (retElemTp == srcElemTp && encDst == encSrc) {
rewriter.replaceOp(op, adaptor.getSource());
return success();
}
//
// Do element-wise type conversion without using InsertOp.
//
// for each memref in srcTensor:
// dst = memref.alloc
// if srcMemRefType != dstMemRefType:
// for every dst[i] = cast(src[i])
// else:
// dst = memref.copy(src)
Location loc = op.getLoc();
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
[&rewriter, &fields, srcDesc,
loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
DimLevelType /*dlt*/) -> bool {
// Simply reuses the storage specifier as it is an SSA value.
if (fKind == SparseTensorFieldKind::StorageSpec) {
fields.push_back(srcDesc.getSpecifier());
} else {
// Allocates new memrefs
Value srcMem = srcDesc.getMemRefField(fIdx);
// TODO: We can instead use the actual memSize in specifier, that
// would require a subViewOp to avoid overflow when copying
// values.
Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
auto dstMem = rewriter.create<memref::AllocOp>(
loc, fTp.cast<MemRefType>(), sz);
if (fTp != srcMem.getType()) {
// Converts elements type.
scf::buildLoopNest(
rewriter, loc, constantIndex(rewriter, loc, 0), sz,
constantIndex(rewriter, loc, 1),
[srcMem, &dstMem](OpBuilder &builder, Location loc,
ValueRange ivs) {
Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
Value casted = genCast(builder, loc, v,
dstMem.getType().getElementType());
builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
});
} else {
// TODO: We can even reuse the same memref for the new tensor,
// but that requires a `ref-counting` based memory management
// for shared memrefs between multiple sparse tensors.
rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
}
fields.push_back(dstMem);
}
return true;
});

rewriter.replaceOp(
op, genTuple(rewriter, loc, op.getResult().getType(), fields));
return success();
}
};
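
Concretely, for a storage field whose memref type changes (for example 64-bit overhead indices narrowed to 32-bit), the pattern allocates a new buffer and emits an element-wise copy loop roughly like the following simplified sketch (value names are illustrative; the new codegen test below checks the exact IR):

%c0  = arith.constant 0 : index
%c1  = arith.constant 1 : index
%sz  = memref.dim %src, %c0 : memref<?xi64>
%dst = memref.alloc(%sz) : memref<?xi32>
scf.for %i = %c0 to %sz step %c1 {
  %v = memref.load %src[%i] : memref<?xi64>
  %t = arith.trunci %v : i64 to i32
  memref.store %t, %dst[%i] : memref<?xi32>
}
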
15 changes: 10 additions & 5 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -616,12 +616,17 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
PatternRewriter &rewriter) const override {
auto encDst = getSparseTensorEncoding(op.getType());
auto encSrc = getSparseTensorEncoding(op.getSource().getType());
if (encDst && encSrc) {
// Trivial tensor conversion is handled in codegen.
if (encSrc == encDst)
return failure();
return sparse2SparseRewrite(op, rewriter);
if (encDst && encSrc &&
encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
// Trivial tensor conversion and simple element type conversion is handled
// in codegen.
return failure();
}
// TODO: Add a cast before generating InsertOp.
assert(op.getSource().getType().getElementType() ==
op.getDest().getType().getElementType());
if (encSrc && encDst)
return sparse2SparseRewrite(op, rewriter);
if (encSrc && !encDst)
return sparse2DenseRewrite(op, rewriter);
if (!encSrc && encDst)
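
For contrast, a conversion that changes the level structure, such as the sketch below (encodings borrowed from the integration test further down, function name illustrative), is still routed through the sparse-to-sparse rewriting path, and per the new assert it must keep the element type unchanged for now:

#Tensor2 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed", "dense" ]
}>

#Tensor3 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "dense", "compressed" ],
  dimOrdering = affine_map<(i,j,k) -> (i,k,j)>
}>

// Different dimLevelType/dimOrdering, same element type: handled by
// sparse2SparseRewrite, not by the codegen pattern above.
func.func @reorder(%a: tensor<2x3x4xf64, #Tensor2>) -> tensor<2x3x4xf64, #Tensor3> {
  %0 = sparse_tensor.convert %a : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor3>
  return %0 : tensor<2x3x4xf64, #Tensor3>
}
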
13 changes: 0 additions & 13 deletions mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -100,19 +100,6 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
// CHECK-AUTO: %[[T:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
// CHECK-AUTO: return %[[T]] : !llvm.ptr<i8>

// CHECK-RWT-LABEL: func.func @sparse_convert(
// CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-RWT: %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
// CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[A]] init(%[[DST]])
// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32, %[[T:.*]]: tensor<?xf32,
// CHECK-RWT: %[[I:.*]] = sparse_tensor.insert %[[FV2]] into %[[T]]{{\[}}%[[FI2]]]
// CHECK-RWT: sparse_tensor.yield %[[I]]
// CHECK-RWT: }
// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]]
// CHECK-RWT: return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>
72 changes: 72 additions & 0 deletions mlir/test/Dialect/SparseTensor/convert_sparse2sparse_element.mlir
@@ -0,0 +1,72 @@
// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s

#SparseVector64 = #sparse_tensor.encoding<{
dimLevelType = ["compressed"],
pointerBitWidth = 64,
indexBitWidth = 64
}>

#SparseVector32 = #sparse_tensor.encoding<{
dimLevelType = ["compressed"],
pointerBitWidth = 32,
indexBitWidth = 32
}>


// CHECK-LABEL: func.func @sparse_convert(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xi64>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xi64>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : memref<?xi64>
// CHECK: %[[VAL_7:.*]] = memref.alloc(%[[VAL_6]]) : memref<?xi32>
// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_4]] {
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi64>
// CHECK: %[[VAL_10:.*]] = arith.trunci %[[VAL_9]] : i64 to i32
// CHECK: memref.store %[[VAL_10]], %[[VAL_7]]{{\[}}%[[VAL_8]]] : memref<?xi32>
// CHECK: }
// CHECK: %[[VAL_11:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi64>
// CHECK: %[[VAL_12:.*]] = memref.alloc(%[[VAL_11]]) : memref<?xi32>
// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_4]] {
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<?xi64>
// CHECK: %[[VAL_15:.*]] = arith.trunci %[[VAL_14]] : i64 to i32
// CHECK: memref.store %[[VAL_15]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<?xi32>
// CHECK: }
// CHECK: %[[VAL_16:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf32>
// CHECK: %[[VAL_17:.*]] = memref.alloc(%[[VAL_16]]) : memref<?xf32>
// CHECK: memref.copy %[[VAL_2]], %[[VAL_17]] : memref<?xf32> to memref<?xf32>
// CHECK: return %[[VAL_7]], %[[VAL_12]], %[[VAL_17]], %[[VAL_3]] : memref<?xi32>, memref<?xi32>, memref<?xf32>, !sparse_tensor.storage_specifier
// CHECK: }
func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>
}

// CHECK-LABEL: func.func @sparse_convert_value(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : memref<?xi32>
// CHECK: %[[VAL_7:.*]] = memref.alloc(%[[VAL_6]]) : memref<?xi32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_7]] : memref<?xi32> to memref<?xi32>
// CHECK: %[[VAL_8:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
// CHECK: %[[VAL_9:.*]] = memref.alloc(%[[VAL_8]]) : memref<?xi32>
// CHECK: memref.copy %[[VAL_1]], %[[VAL_9]] : memref<?xi32> to memref<?xi32>
// CHECK: %[[VAL_10:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf32>
// CHECK: %[[VAL_11:.*]] = memref.alloc(%[[VAL_10]]) : memref<?xf64>
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_12]]] : memref<?xf32>
// CHECK: %[[VAL_14:.*]] = arith.extf %[[VAL_13]] : f32 to f64
// CHECK: memref.store %[[VAL_14]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<?xf64>
// CHECK: }
// CHECK: return %[[VAL_7]], %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : memref<?xi32>, memref<?xi32>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: }
func.func @sparse_convert_value(%arg0: tensor<?xf32, #SparseVector32>) -> tensor<?xf64, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector32> to tensor<?xf64, #SparseVector32>
return %0 : tensor<?xf64, #SparseVector32>
}
@@ -0,0 +1,107 @@
// DEFINE: %{option} = "enable-runtime-library=false s2s-strategy=2"
// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE: -e entry -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_c_runner_utils | \
// DEFINE: FileCheck %s
//
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{option} = "enable-runtime-library=false s2s-strategy=2 vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation and, if available, VLA
// vectorization.
// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
// REDEFINE: %{run} = %lli \
// REDEFINE: --entry-function=entry_lli \
// REDEFINE: --extra-module=%S/Inputs/main_for_lli.ll \
// REDEFINE: %VLA_ARCH_ATTR_OPTIONS \
// REDEFINE: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
// REDEFINE: FileCheck %s
// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}

#Tensor1 = #sparse_tensor.encoding<{
dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ]
}>

#Tensor2 = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed", "dense" ]
}>

#Tensor3 = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "dense", "compressed" ],
dimOrdering = affine_map<(i,j,k) -> (i,k,j)>
}>

module {
//
// Utility for output.
//
func.func @dump(%arg0: tensor<2x3x4xf32>) {
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %d0: tensor<2x3x4xf32>, vector<2x3x4xf32>
vector.print %0 : vector<2x3x4xf32>
return
}

//
// The first test suite (for non-singleton DimLevelTypes).
//
func.func @entry() {
//
// Initialize a 3-dim dense tensor.
//
%src = arith.constant dense<[
[ [ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ],
[ 9.0, 10.0, 11.0, 12.0 ] ],
[ [ 13.0, 14.0, 15.0, 16.0 ],
[ 17.0, 18.0, 19.0, 20.0 ],
[ 21.0, 22.0, 23.0, 24.0 ] ]
]> : tensor<2x3x4xf64>

//
// Convert dense tensor directly to various sparse tensors.
//
%s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1>
%s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2>
%s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3>

//
// Convert sparse tensor directly to another sparse format.
//
%t1 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf32, #Tensor1>
%t2 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf32, #Tensor2>
%t3 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf32, #Tensor3>

//
// Convert sparse tensor back to dense.
//
%d1 = sparse_tensor.convert %t1 : tensor<2x3x4xf32, #Tensor1> to tensor<2x3x4xf32>
%d2 = sparse_tensor.convert %t2 : tensor<2x3x4xf32, #Tensor2> to tensor<2x3x4xf32>
%d3 = sparse_tensor.convert %t3 : tensor<2x3x4xf32, #Tensor3> to tensor<2x3x4xf32>

//
// Check round-trip equality. And release dense tensors.
//
// CHECK-COUNT-3: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
call @dump(%d1) : (tensor<2x3x4xf32>) -> ()
call @dump(%d2) : (tensor<2x3x4xf32>) -> ()
call @dump(%d3) : (tensor<2x3x4xf32>) -> ()

//
// Release sparse tensors.
//
bufferization.dealloc_tensor %t1 : tensor<2x3x4xf32, #Tensor1>
bufferization.dealloc_tensor %t2 : tensor<2x3x4xf32, #Tensor2>
bufferization.dealloc_tensor %t3 : tensor<2x3x4xf32, #Tensor3>
bufferization.dealloc_tensor %s1 : tensor<2x3x4xf64, #Tensor1>
bufferization.dealloc_tensor %s2 : tensor<2x3x4xf64, #Tensor2>
bufferization.dealloc_tensor %s3 : tensor<2x3x4xf64, #Tensor3>

return
}
}
