New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] support sparsifying 2:4 block sparsity #71749
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu). Changes — full diff: https://github.com/llvm/llvm-project/pull/71749.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 215920f8b4607b2..cde6b2d13e58217 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -540,7 +540,8 @@ class Merger {
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto dlt = getLoopDependentLevelType(b);
- return isCompressedDLT(dlt) || isSingletonDLT(dlt);
+ return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
}
return false;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb3c6fb56f692d9..6facc87d1b5a029 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
- } else if (isSingletonDLT(lvlTp)) {
+ } else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
// Singleton level, fetch coordinates.
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
@@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
auto lvlType = lvlTypes[t][l];
// Must be a recognizable DLT.
assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
- isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
+ isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+ is2OutOf4DLT(lvlType));
bool isSparse = !isDenseDLT(lvlType);
bool isSlice = isSparseSlices[t];
@@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
+ is2OutOf4DLT(lvlTypes[tid][lvl]) ||
isSingletonDLT(lvlTypes[tid][lvl]);
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
@@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
const Value c0 = C_IDX(0);
const Value c1 = C_IDX(1);
+ const Value c2 = C_IDX(2);
// Either the first level, or the previous level has been set.
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
assert(lvl == 0 || posits[tid][lvl - 1]);
@@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
if (isLooseCompressedDLT(lvlTp))
- pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
+ pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
const Value pHi = ADDI(pLo, c1);
@@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
: ADDI(pLo, c1);
return;
}
-
+ if (is2OutOf4DLT(lvlTp)) {
+ const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
+ // Each 2:4 block has exactly two specified elements.
+ posits[tid][lvl] = MULI(pLo, c2);
+ highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
+ return;
+ }
llvm_unreachable("Unrecognized level-type!");
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 85d6a6ddabf9eb6..dd121cb05c2184d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
for (LoopId i = 0; i < numLoops; i++) {
const auto dltI = env.dlt(tid, i);
if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
- isSingletonDLT(dltI)) {
+ isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
for (LoopId j = 0; j < numLoops; j++)
if (isUndefDLT(env.dlt(tid, j))) {
addIterOrdering(i, j, adjM, inDegree);
@@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
assert(ldx == env.merger().loop(b));
Value clause;
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt)) {
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoords()[tid][*lvl];
const Value lvar = env.getLoopVar(ldx);
@@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
needsUniv = true;
}
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt) || isIdxReduc) {
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
// Only when this is a index reduction loop, can the dlt be undefined.
assert(!isUndefDLT(dlt) || isIdxReduc);
// sparse/singleton levels, or a dense/sparse index reduction loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 18ebd608608bdcb..033b61fc872a312 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto dlt = getLvlType(b);
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
- !isLooseCompressedDLT(dlt)) {
+ !isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
if (reset)
simple.reset(b);
reset = true;
@@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto dlt = getLvlType(b);
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt))
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 7e9606b1caedee9..0c420f2fd426fb2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -44,19 +44,41 @@
#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
- j floordiv 3 : compressed,
+ j floordiv 2 : compressed,
i mod 2 : dense,
- j mod 3 : dense
+ j mod 2 : dense
)
}>
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ ),
+}>
+
module {
-func.func @mul(%arg0: tensor<4x6xf64>,
- %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
- %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+ %out = arith.constant dense<0.0> : tensor<4x4xf64>
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
+ outs(%out: tensor<4x4xf64>) {
+ ^bb(%x: f64, %y : f64, %z : f64):
+ %1 = arith.mulf %x, %y : f64
+ %2 = arith.addf %1, %z : f64
+ linalg.yield %2 : f64
+ } -> tensor<4x4xf64>
+ return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_24(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
+ %out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
- ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
@@ -66,11 +88,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
return %0 : tensor<4x4xf64>
}
-func.func @mul_dense(%arg0: tensor<4x6xf64>,
- %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
- %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul_dense(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
+ %out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
- ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
@@ -101,22 +123,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
%c2 = arith.constant 2 : index
- %td = arith.constant dense<[[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
- [ 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
- [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
- [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+ %td = arith.constant dense<[[ 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0],
+ [ 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 10.0, 11.0],
+ [ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0, 0.0, 0.0],
+ [ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0, 0.0, 0.0]]> : tensor<4x8xf64>
- %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+ %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
+ %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
%d = call @mul_dense(%td, %td)
- : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+ : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
%s = call @mul(%td, %2)
- : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+ : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+ %s24 = call @mul_24(%td, %3)
+ : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
- // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+ // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+ call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
return
}
|
Hi, recently I have been very interested in MLIR's support for NVIDIA's sparse tensor core. I saw your recent commits and wanted to ask: has MLIR only added support for the NV_2:4 sparse format so far, or has it implemented an end-to-end flow that can generate calls to the sparse tensor core via mma.sp PTX instructions? |
Thanks for your interest in this work! We have an end-to-end flow that maps linalg.matmul to 2:4 using our "CUDA libgen" path, i.e. we change the code into a call into the cusparseLt library for 2:4 operations. We do have a sample example that uses the gpu dialect to actually map to the mma instruction (through nvgpu.mma.sp.sync), but that is not part of any codegen... yet! See https://discourse.llvm.org/t/sparse-compiler-and-gpu-code-generation/69786/ for a more in-depth discussion of what I wrote above. |
This comment was marked as resolved.
This comment was marked as resolved.
For sure! The current NV_2:4 syntax is just a stepping stone towards a more general N:M syntax with N and M flexible integers. So, expect this to improve soon in MLIR! |
No description provided.