[mlir][sparse] enable more sparse convolution kernels.
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147670
PeimingLiu committed Apr 17, 2023
1 parent 7ea597e commit 6a148c5
Showing 5 changed files with 126 additions and 135 deletions.
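In brief: this patch threads a new optional `firstResLvl` parameter through `genUnResolvedSliceTreeTraverse` so slice traversal can start directly from an already-resolved position instead of always materializing an outer loop; `genResolvedSliceBegin` now loads `pLo`/`pHi` from the positions buffer for non-first levels; and the integration test switches the convolution filters to dense tensors and adds `enable-index-reduction=true` to every pipeline. A minimal, self-contained sketch of the dispatch this enables (illustrative stand-ins only, not the LoopEmitter API):

#include <cstdio>
#include <optional>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;

// With firstResLvl set, traversal starts at the cached resolved position;
// otherwise an outer loop over the front unresolved slice must be emitted.
void traverse(std::optional<std::pair<TensorId, Level>> firstResLvl) {
  if (firstResLvl.has_value())
    std::printf("start at resolved position: tensor %u, level %u\n",
                firstResLvl->first, firstResLvl->second);
  else
    std::printf("emit an outer loop over the front unresolved slice\n");
}

int main() {
  traverse(std::nullopt);           // old behavior: always build the loop
  traverse(std::make_pair(0u, 0u)); // new: reuse the resolved position
  return 0;
}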
150 changes: 88 additions & 62 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1521,65 +1521,69 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
// }
ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
OpBuilder &builder, Location loc, TensorId tid,
ArrayRef<const SliceInfo *> unResLvls, ValueRange userReduc,
ArrayRef<const SliceInfo *> unResLvls,
std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
LoopBodyBuilder bodyBuilder) {
// assert(unResLvls.size() == 1 && "TODO");
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);

const SliceInfo &frontSlice = *unResLvls.back();
Level firstLvl = *frontSlice.slicedOnLvl;
assert(!lvlFullyResolved(tid, firstLvl) && "TODO");

// FIXME: it is not zero when the first level is fully resolved.
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
Value pos = c0;
OpBuilder::InsertPoint ip;
SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
scf::ForOp outerMost = nullptr;
if (!lvlFullyResolved(tid, firstLvl)) {
if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
unsigned depth = frontSlice.depth - 1;
Value offset = frontSlice.offset;
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
outerMost = builder.create<scf::ForOp>(
loc, c2, mSz, c2, innerArgs,
[this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos, &innerArgs](
OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
// Generate the traversal for each level.
Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
ValueRange itArgs =
genSliceLvlTraverseLoop(
builder, loc, loopLo, loopHi, offset,
sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
false,
[&](OpBuilder &builder, Location, Value iv,
MutableArrayRef<Value> reduc) {
ip = builder.saveInsertionPoint();
pos = iv;
innerArgs.assign(reduc.begin(), reduc.end());
})
.second;
YIELD(itArgs);
});
} else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
assert(firstLvl == 0); // This must be the first level.
Value lb = frontSlice.offset;
Value sliceSz =
sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
Value ub = ADDI(lb, sliceSz);
outerMost = builder.create<scf::ForOp>(
loc, lb, ub, c1, innerArgs,
[&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
ip = builder.saveInsertionPoint();
pos = iv;
innerArgs.assign(iterArgs.begin(), iterArgs.end());
});
scf::ForOp outerMost = nullptr; // the outermost loop.
if (firstResLvl.has_value()) {
// Overwrite position when the first level is fully resolved.
pos = posits[firstResLvl->first][firstResLvl->second];
ip = builder.saveInsertionPoint();
} else {
const SliceInfo &frontSlice = *unResLvls.back();
Level firstLvl = *frontSlice.slicedOnLvl;
if (!lvlFullyResolved(tid, firstLvl)) {
if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
unsigned depth = frontSlice.depth - 1;
Value offset = frontSlice.offset;
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
outerMost = builder.create<scf::ForOp>(
loc, c2, mSz, c2, innerArgs,
[this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
&innerArgs](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
// Generate the traversal for each level.
Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
ValueRange itArgs =
genSliceLvlTraverseLoop(
builder, loc, loopLo, loopHi, offset,
sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
false,
[&](OpBuilder &builder, Location, Value iv,
MutableArrayRef<Value> reduc) {
ip = builder.saveInsertionPoint();
pos = iv;
innerArgs.assign(reduc.begin(), reduc.end());
})
.second;
YIELD(itArgs);
});
} else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
assert(firstLvl == 0); // This must be the first level.
Value lb = frontSlice.offset;
Value sliceSz =
sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
Value ub = ADDI(lb, sliceSz);
outerMost = builder.create<scf::ForOp>(
loc, lb, ub, c1, innerArgs,
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
ip = builder.saveInsertionPoint();
pos = iv;
innerArgs.assign(iterArgs.begin(), iterArgs.end());
});
}
// We generated the loop for the first slice above; now remove it.
unResLvls = unResLvls.drop_back();
}
// We generated the loop for the first slice above; now remove it.
unResLvls = unResLvls.drop_back();
}

// Reset the insertion point into the loop body.
builder.restoreInsertionPoint(ip);
if (!unResLvls.empty()) {
@@ -1611,20 +1615,28 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
bodyBuilder(builder, loc, pos, innerArgs);
return innerArgs;
});
YIELD(denseNest.results);

if (!outerMost) {
// If no outermost loop has been set yet, the dense nest is the outermost loop.
outerMost = denseNest.loops.front();
} else {
// Otherwise we need to generate yield operations to link the SSA chain.
YIELD(denseNest.results);
}
} else {
assert(outerMost);
// Generates the user-requested loop body.
bodyBuilder(builder, loc, pos, innerArgs);
YIELD(innerArgs);
}
assert(outerMost);
// Insert after the outermost loop operation.
builder.setInsertionPointAfter(outerMost);
return outerMost.getResults();
}
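The yield logic above is the subtle part: when no outer slice loop was created, the dense loop nest itself becomes the outermost loop and its results are returned directly; otherwise each nest must yield its results so the iteration arguments form an unbroken SSA chain back to the enclosing loop. A rough scalar analogue (hypothetical names; plain integer reductions stand in for loop iter_args):

#include <cstdio>
#include <utility>
#include <vector>

// Models how reduction values thread through nested loops: the inner nest
// either *is* the outermost loop (results returned directly) or must hand
// its results back to an enclosing loop (the YIELD branch).
std::vector<int> runNest(bool haveOuterLoop, std::vector<int> reduc) {
  auto innerNest = [](std::vector<int> args) {
    for (int iv = 0; iv < 3; ++iv)
      args[0] += iv; // stand-in for the user loop body
    return args;
  };
  if (!haveOuterLoop)
    return innerNest(std::move(reduc)); // the nest is the outermost loop
  for (int outer = 0; outer < 2; ++outer)
    reduc = innerNest(std::move(reduc)); // "yield" links each iteration
  return reduc;
}

int main() {
  std::printf("%d\n", runNest(false, {0})[0]); // prints 3
  std::printf("%d\n", runNest(true, {0})[0]);  // prints 6
  return 0;
}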

void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
assert(lvl == 0 && "TODO: handle non-first level");
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
c4 = C_IDX(4);
if (isDenseDLT(lvlTypes[tid][lvl])) {
@@ -1634,14 +1646,23 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
lvl, /*depth=*/1);
return;
}
Value size = sliceSizes[tid][0][0];
Value sPtrBuf = slicePosBuffer[tid][0][0];
Value pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
Value size = sliceSizes[tid][lvl][0];
Value sPtrBuf = slicePosBuffer[tid][lvl][0];
Value pHi, pLo;
if (lvl == 0) {
pLo = c0;
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
} else {
pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
posits[tid][lvl - 1]);
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
ADDI(posits[tid][lvl - 1], c1));
}
// Fills out sPtrBuf (= slicePosBuffer[tid][lvl][0]) with [/*memSize =*/4, 0, pLo, pHi]
builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0); // memSize = 4
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1); // index = 0
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c2); // pLo = 0;
builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // loaded pHi.
builder.create<memref::StoreOp>(loc, pLo, sPtrBuf, c2); // pLo
builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // pHi

// This is a non-empty tensor if 0 < pHi.
Value isNonEmpty = CMPI(ult, c0, pHi);
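For reference, the slice position buffer written here has layout [memSize, index, pLo0, pHi0, pLo1, pHi1, ...], and the traversal in genUnResolvedSliceTreeTraverse walks the (pLo, pHi) pairs starting at slot 2 with stride 2. A small standalone model of that layout (our own names; assumes the single pair stored above):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<uint64_t> sPtrBuf(8, 0);
  uint64_t pLo = 3, pHi = 7; // hypothetical position range for the slice
  sPtrBuf[0] = 4;            // memSize: four slots in use
  sPtrBuf[1] = 0;            // loop index
  sPtrBuf[2] = pLo;          // pLo (no longer hard-coded to 0 for lvl > 0)
  sPtrBuf[3] = pHi;          // pHi loaded from the positions buffer
  // Mirrors the scf.for over (c2, mSz, c2): visit each [pLo, pHi) pair.
  for (uint64_t i = 2; i < sPtrBuf[0]; i += 2)
    std::printf("traverse positions [%llu, %llu)\n",
                (unsigned long long)sPtrBuf[i],
                (unsigned long long)sPtrBuf[i + 1]);
  return 0;
}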
@@ -1703,10 +1724,15 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
assert(slicePosBuffer[tid][lvl - 1].size() == sliceStack[tid].back().depth);

SmallVector<const SliceInfo *> unResSlices;
std::optional<std::pair<TensorId, Level>> firstResLvl;
for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
Level prevLvl = curLvl - 1;
if (lvlFullyResolved(tid, prevLvl)) {
firstResLvl = std::make_pair(tid, prevLvl);
break;
}
unResSlices.push_back(&getMostRecentSliceOnLvl(tid, prevLvl));
if (!isDenseDLT(lvlTypes[tid][prevLvl]) || lvlFullyResolved(tid, prevLvl)) {
if (!isDenseDLT(lvlTypes[tid][prevLvl])) {
break;
}
}
@@ -1722,7 +1748,7 @@
};

ValueRange result = genUnResolvedSliceTreeTraverse(
builder, loc, tid, unResSlices, reduc,
builder, loc, tid, unResSlices, firstResLvl, reduc,
[this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
Value iv,
MutableArrayRef<Value> reduc) {
@@ -1869,7 +1895,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
for (unsigned i = 0; i <= lvl; i++) {
if (!isDenseDLT(lvlTypes[tid][i])) {
if (!isDenseDLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
builder.create<memref::StoreOp>(loc, C_IDX(0),
slicePosBuffer[tid][i].back(), C_IDX(1));
}
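The added dependentLvlMap check matters because only non-dense levels that participate in a dependent (affine) index expression own a slice position buffer; unconditionally resetting one that was never allocated would be invalid. A hedged sketch of the guard (our own data model, not the LoopEmitter members):

#include <array>
#include <vector>

// A slice position buffer exists only for non-dense levels involved in a
// dependent index expression, so only those levels get their index reset.
struct Lvl {
  bool dense;
  std::vector<int> dependentLvlMap;    // empty => no slice buffer
  std::array<unsigned, 4> slicePosBuf; // [memSize, index, pLo, pHi]
};

void invalidateSliceIterIdx(std::vector<Lvl> &lvls, unsigned lvl) {
  for (unsigned i = 0; i <= lvl; i++)
    if (!lvls[i].dense && !lvls[i].dependentLvlMap.empty())
      lvls[i].slicePosBuf[1] = 0; // reset the loop-index slot
}

int main() {
  std::vector<Lvl> lvls = {{true, {}, {0, 0, 0, 0}},
                           {false, {0}, {4, 9, 0, 0}}};
  invalidateSliceIterIdx(lvls, 1); // resets only lvls[1].slicePosBuf[1]
  return 0;
}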
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -452,11 +452,11 @@ class LoopEmitter {

/// Generates a nested loop that iterates over all the coordinates of tid on
/// level `lvl`.
ValueRange
genUnResolvedSliceTreeTraverse(OpBuilder &builder, Location loc, TensorId tid,
ArrayRef<const SliceInfo *> unResLvls,
ValueRange userReduc,
LoopBodyBuilder bodyBuilder);
ValueRange genUnResolvedSliceTreeTraverse(
OpBuilder &builder, Location loc, TensorId tid,
ArrayRef<const SliceInfo *> unResLvls,
std::optional<std::pair<TensorId, Level>> firstResLvl,
ValueRange userReduc, LoopBodyBuilder bodyBuilder);

/// Generates code to get the first non-empty slice of tid on lvl, when all
/// the levels before `lvl` are resolved (or `lvl` is the first level).
@@ -1,9 +1,4 @@
// UNSUPPORTED: target={{.*}}
// FIXME: The test case is disabled (for now) because affine index on sparse tensor
// are not handled efficiently by sparse compiler, the test case will be re-enabled
// after new algorithm is implemented.

// DEFINE: %{option} = enable-runtime-library=true
// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE: -e entry -entry-point-result=void \
@@ -13,16 +8,16 @@
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true enable-index-reduction=true"
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true enable-index-reduction=true"
// RUN: %{compile} | %{run}

// Do the same run, but now with direct IR generation and, if available, VLA
// vectorization.
// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA enable-index-reduction=true"
// REDEFINE: %{run} = %lli \
// REDEFINE: --entry-function=entry_lli \
// REDEFINE: --extra-module=%S/Inputs/main_for_lli.ll \
@@ -55,26 +50,26 @@ func.func @conv_1d_nwc_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %
return %ret : tensor<?x?x?xf32>
}

func.func @conv_1d_nwc_wcf_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC> {
func.func @conv_1d_nwc_wcf_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CCC> {
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
%s = bufferization.alloc_tensor(%c3, %c6, %c1) : tensor<?x?x?xf32, #CCC>
%ret = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
strides = dense<1> : tensor<1xi64>}
ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>)
ins (%arg0, %arg1: tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>)
outs (%s: tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC>
return %ret : tensor<?x?x?xf32, #CCC>
}

func.func @conv_1d_nwc_wcf_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC> {
func.func @conv_1d_nwc_wcf_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CDC> {
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
%s = bufferization.alloc_tensor(%c3, %c6, %c1) : tensor<?x?x?xf32, #CDC>
%ret = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
strides = dense<1> : tensor<1xi64>}
ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>)
ins (%arg0, %arg1: tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>)
outs (%s: tensor<?x?x?xf32, #CDC>) -> tensor<?x?x?xf32, #CDC>
return %ret : tensor<?x?x?xf32, #CDC>
}
@@ -91,22 +86,18 @@ func.func @entry() {

%in1D_tmp = call @alloc_3d_filled_f32(%c3, %c8, %c1, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
%in1D_nwc = tensor.insert %f10 into %in1D_tmp[%c0, %c3, %c0] : tensor<?x?x?xf32>

%filter1D_nwc = call @alloc_3d_filled_f32(%c3, %c1, %c1, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
%out1D_nwc = call @alloc_3d_filled_f32(%c3, %c6, %c1, %zero) : (index, index, index, f32) -> (tensor<?x?x?xf32>)

%in1D_nwc_CCC = sparse_tensor.convert %in1D_nwc
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
%filter1D_nwc_CCC = sparse_tensor.convert %filter1D_nwc
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>

%in1D_nwc_CDC = sparse_tensor.convert %in1D_nwc
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
%filter1D_nwc_CDC = sparse_tensor.convert %filter1D_nwc
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>

%dense_ret = call @conv_1d_nwc_wcf(%in1D_nwc, %filter1D_nwc, %out1D_nwc) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
%CCC_ret = call @conv_1d_nwc_wcf_CCC(%in1D_nwc_CCC, %filter1D_nwc_CCC) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32, #CCC>) -> (tensor<?x?x?xf32, #CCC>)
%CDC_ret = call @conv_1d_nwc_wcf_CDC(%in1D_nwc_CDC, %filter1D_nwc_CDC) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32, #CDC>) -> (tensor<?x?x?xf32, #CDC>)
%CCC_ret = call @conv_1d_nwc_wcf_CCC(%in1D_nwc_CCC, %filter1D_nwc) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
%CDC_ret = call @conv_1d_nwc_wcf_CDC(%in1D_nwc_CDC, %filter1D_nwc) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)

// CHECK: ( ( ( 12 ), ( 28 ), ( 28 ), ( 28 ), ( 12 ), ( 12 ) ),
// CHECK-SAME: ( ( 12 ), ( 12 ), ( 12 ), ( 12 ), ( 12 ), ( 12 ) ),
@@ -139,9 +130,7 @@ func.func @entry() {
bufferization.dealloc_tensor %out1D_nwc : tensor<?x?x?xf32>

bufferization.dealloc_tensor %in1D_nwc_CDC : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %filter1D_nwc_CDC : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %in1D_nwc_CCC : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %filter1D_nwc_CCC : tensor<?x?x?xf32, #CCC>

bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
