Skip to content

Commit

Permalink
[mlir][sparse] support sparsifying 2:4 block sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
PeimingLiu committed Nov 9, 2023
1 parent 30e4b09 commit b6c7492
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 26 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,8 @@ class Merger {
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto dlt = getLoopDependentLevelType(b);
return isCompressedDLT(dlt) || isSingletonDLT(dlt);
return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
}
return false;
}
Expand Down
17 changes: 13 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
} else if (isSingletonDLT(lvlTp)) {
} else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
// Singleton level, fetch coordinates.
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
Expand Down Expand Up @@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
auto lvlType = lvlTypes[t][l];
// Must be a recognizable DLT.
assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
is2OutOf4DLT(lvlType));

bool isSparse = !isDenseDLT(lvlType);
bool isSlice = isSparseSlices[t];
Expand Down Expand Up @@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
is2OutOf4DLT(lvlTypes[tid][lvl]) ||
isSingletonDLT(lvlTypes[tid][lvl]);
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
Expand Down Expand Up @@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

const Value c0 = C_IDX(0);
const Value c1 = C_IDX(1);
const Value c2 = C_IDX(2);
// Either the first level, or the previous level has been set.
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
assert(lvl == 0 || posits[tid][lvl - 1]);
Expand All @@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
if (isLooseCompressedDLT(lvlTp))
pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);

const Value pHi = ADDI(pLo, c1);
Expand All @@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
: ADDI(pLo, c1);
return;
}

if (is2OutOf4DLT(lvlTp)) {
const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
// Each 2:4 block has exactly two specified elements.
posits[tid][lvl] = MULI(pLo, c2);
highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
return;
}
llvm_unreachable("Unrecognized level-type!");
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
for (LoopId i = 0; i < numLoops; i++) {
const auto dltI = env.dlt(tid, i);
if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
isSingletonDLT(dltI)) {
isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
for (LoopId j = 0; j < numLoops; j++)
if (isUndefDLT(env.dlt(tid, j))) {
addIterOrdering(i, j, adjM, inDegree);
Expand Down Expand Up @@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
assert(ldx == env.merger().loop(b));
Value clause;
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
isLooseCompressedDLT(dlt)) {
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoords()[tid][*lvl];
const Value lvar = env.getLoopVar(ldx);
Expand Down Expand Up @@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
needsUniv = true;
}
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
isLooseCompressedDLT(dlt) || isIdxReduc) {
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
// Only when this is a index reduction loop, can the dlt be undefined.
assert(!isUndefDLT(dlt) || isIdxReduc);
// sparse/singleton levels, or a dense/sparse index reduction loop.
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto dlt = getLvlType(b);
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
!isLooseCompressedDLT(dlt)) {
!isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
if (reset)
simple.reset(b);
reset = true;
Expand Down Expand Up @@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto dlt = getLvlType(b);
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
isLooseCompressedDLT(dlt))
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
return true;
}
return hasSparseIdxReduction(bits);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,41 @@
#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
j floordiv 2 : compressed,
i mod 2 : dense,
j mod 3 : dense
j mod 2 : dense
)
}>

#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
j mod 4 : block2_4
),
}>

module {

func.func @mul(%arg0: tensor<4x6xf64>,
%arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
func.func @mul(%arg0: tensor<4x8xf64>,
%arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
%2 = arith.addf %1, %z : f64
linalg.yield %2 : f64
} -> tensor<4x4xf64>
return %0 : tensor<4x4xf64>
}

func.func @mul_24(%arg0: tensor<4x8xf64>,
%arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
Expand All @@ -69,11 +91,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
return %0 : tensor<4x4xf64>
}

func.func @mul_dense(%arg0: tensor<4x6xf64>,
%arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
func.func @mul_dense(%arg0: tensor<4x8xf64>,
%arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
Expand Down Expand Up @@ -104,22 +126,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
%c2 = arith.constant 2 : index


%td = arith.constant dense<[[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
[12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
[18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
%td = arith.constant dense<[[ 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0],
[ 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 10.0, 11.0],
[ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0, 0.0, 0.0],
[ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0, 0.0, 0.0]]> : tensor<4x8xf64>


%2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
%2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
%3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>

%d = call @mul_dense(%td, %td)
: (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
: (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
%s = call @mul(%td, %2)
: (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
: (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
%s24 = call @mul_24(%td, %3)
: (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>

// CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
// CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()

return
}
Expand Down

0 comments on commit b6c7492

Please sign in to comment.