
[mlir][sparse] support sparsifying 2:4 block sparsity #71749

Merged
merged 1 commit into llvm:main from PeimingLiu:gen_24 on Nov 10, 2023

Conversation

PeimingLiu (Member)

No description provided.
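
For context (an editor's sketch, not part of the PR): 2:4 structured sparsity, the format accelerated by NVIDIA's sparse tensor cores, requires that every aligned, contiguous block of four elements along the innermost dimension contain at most two nonzero values. A minimal C++ check of that invariant (the helper name satisfiesTwoOutOfFour is invented for illustration):

#include <cstddef>
#include <vector>

// Returns true if every aligned group of four elements in `row`
// holds at most two nonzeros, i.e. the row is 2:4 sparse.
bool satisfiesTwoOutOfFour(const std::vector<double> &row) {
  for (std::size_t base = 0; base + 4 <= row.size(); base += 4) {
    int nonzeros = 0;
    for (std::size_t j = base; j < base + 4; ++j)
      if (row[j] != 0.0)
        ++nonzeros;
    if (nonzeros > 2)
      return false;
  }
  return true;
}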

@llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Nov 9, 2023
llvmbot (Collaborator) commented Nov 9, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/71749.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+13-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir (+44-18)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 215920f8b4607b2..cde6b2d13e58217 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -540,7 +540,8 @@ class Merger {
   bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto dlt = getLoopDependentLevelType(b);
-      return isCompressedDLT(dlt) || isSingletonDLT(dlt);
+      return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+             isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
     }
     return false;
   }
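
Editor's note: this is the first of several hunks in the PR that widen the same predicate, from {compressed, singleton} to also cover loose-compressed and 2:4 levels. A self-contained sketch of the consolidated check, using a stub enum in place of the real DimLevelType (names invented for illustration):

enum class LevelKind { Dense, Compressed, LooseCompressed, Singleton, TwoOutOfFour };

// All four non-dense kinds iterate only a subset of coordinates, so they
// count as "sparse" for the purposes of loop emission and merging.
bool isIterationSparse(LevelKind k) {
  return k == LevelKind::Compressed || k == LevelKind::Singleton ||
         k == LevelKind::LooseCompressed || k == LevelKind::TwoOutOfFour;
}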
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb3c6fb56f692d9..6facc87d1b5a029 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
         positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
         coordinatesBuffers[t][l] =
             genToCoordinates(builder, loc, tensor, l, cooStart);
-      } else if (isSingletonDLT(lvlTp)) {
+      } else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
         // Singleton level, fetch coordinates.
         coordinatesBuffers[t][l] =
             genToCoordinates(builder, loc, tensor, l, cooStart);
@@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
     auto lvlType = lvlTypes[t][l];
     // Must be a recognizable DLT.
     assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
-           isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
+           isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+           is2OutOf4DLT(lvlType));
 
     bool isSparse = !isDenseDLT(lvlType);
     bool isSlice = isSparseSlices[t];
@@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
     Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
                       isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
+                      is2OutOf4DLT(lvlTypes[tid][lvl]) ||
                       isSingletonDLT(lvlTypes[tid][lvl]);
   // TODO: support dynamic slices.
   // Uses the first dimension here to build the loop bound (which is also the
@@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 
   const Value c0 = C_IDX(0);
   const Value c1 = C_IDX(1);
+  const Value c2 = C_IDX(2);
   // Either the first level, or the previous level has been set.
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   assert(lvl == 0 || posits[tid][lvl - 1]);
@@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 
     Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
     if (isLooseCompressedDLT(lvlTp))
-      pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
+      pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
     posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
 
     const Value pHi = ADDI(pLo, c1);
@@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                           : ADDI(pLo, c1);
     return;
   }
-
+  if (is2OutOf4DLT(lvlTp)) {
+    const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
+    // Each 2:4 block has exactly two specified elements.
+    posits[tid][lvl] = MULI(pLo, c2);
+    highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
+    return;
+  }
   llvm_unreachable("Unrecognized level-type!");
 }
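
Editor's note on the hunk above: because a 2:4 level stores exactly two entries per block of four, no position buffer load is needed; the stored-value range for a parent position pLo is simply [2*pLo, 2*pLo + 2). A freestanding C++ sketch of that arithmetic (illustration only, not LoopEmitter code):

#include <cstdint>
#include <cstdio>

int main() {
  // For each parent position, a 2:4 level stores exactly two entries,
  // so lo = 2 * parent and hi = lo + 2 (cf. MULI(pLo, c2) / ADDI(..., c2)).
  for (std::uint64_t parent = 0; parent < 4; ++parent) {
    std::uint64_t lo = parent * 2;
    std::uint64_t hi = lo + 2;
    std::printf("parent %llu -> stored range [%llu, %llu)\n",
                static_cast<unsigned long long>(parent),
                static_cast<unsigned long long>(lo),
                static_cast<unsigned long long>(hi));
  }
  return 0;
}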
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 85d6a6ddabf9eb6..dd121cb05c2184d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
       for (LoopId i = 0; i < numLoops; i++) {
         const auto dltI = env.dlt(tid, i);
         if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
-            isSingletonDLT(dltI)) {
+            isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
           for (LoopId j = 0; j < numLoops; j++)
             if (isUndefDLT(env.dlt(tid, j))) {
               addIterOrdering(i, j, adjM, inDegree);
@@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
         assert(ldx == env.merger().loop(b));
         Value clause;
         if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-            isLooseCompressedDLT(dlt)) {
+            isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoords()[tid][*lvl];
           const Value lvar = env.getLoopVar(ldx);
@@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       needsUniv = true;
     }
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-        isLooseCompressedDLT(dlt) || isIdxReduc) {
+        isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
       // Only when this is a index reduction loop, can the dlt be undefined.
       assert(!isUndefDLT(dlt) || isIdxReduc);
       // sparse/singleton levels, or a dense/sparse index reduction loop.
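
Editor's note: the genIf/startLoopSeq changes above simply admit 2:4 levels into the existing co-iteration machinery. For readers new to that machinery, here is a freestanding C++ sketch of the general pattern, with invented coordinate/value arrays: each merged sparse operand contributes a clause that fires only when its current coordinate equals the universal loop coordinate.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> aCrd = {0, 3, 5}, bCrd = {1, 3, 6};
  std::vector<double> aVal = {1, 2, 3}, bVal = {4, 5, 6};
  std::size_t ia = 0, ib = 0;
  while (ia < aCrd.size() && ib < bCrd.size()) {
    int crd = std::min(aCrd[ia], bCrd[ib]); // universal loop coordinate
    // Clauses (cf. genIf): an operand participates only if its current
    // coordinate matches the loop coordinate.
    double a = (aCrd[ia] == crd) ? aVal[ia] : 0.0;
    double b = (bCrd[ib] == crd) ? bVal[ib] : 0.0;
    std::printf("crd=%d a=%g b=%g a+b=%g\n", crd, a, b, a + b);
    if (aCrd[ia] == crd) ++ia;
    if (bCrd[ib] == crd) ++ib;
  }
  // A real kernel would also drain the remaining tail of either operand.
  return 0;
}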
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 18ebd608608bdcb..033b61fc872a312 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getLvlType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
-          !isLooseCompressedDLT(dlt)) {
+          !isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto dlt = getLvlType(b);
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-        isLooseCompressedDLT(dlt))
+        isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 7e9606b1caedee9..0c420f2fd426fb2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -44,19 +44,41 @@
 #BSR = #sparse_tensor.encoding<{
   map = ( i, j ) ->
   ( i floordiv 2 : dense,
-    j floordiv 3 : compressed,
+    j floordiv 2 : compressed,
     i mod 2      : dense,
-    j mod 3      : dense
+    j mod 2      : dense
   )
 }>
 
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  ),
+}>
+
 module {
 
-func.func @mul(%arg0: tensor<4x6xf64>,
-               %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
-  %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul(%arg0: tensor<4x8xf64>,
+               %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_24(%arg0: tensor<4x8xf64>,
+                  %arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
   %0 = linalg.generic #trait_mul
-    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
     outs(%out: tensor<4x4xf64>) {
       ^bb(%x: f64, %y : f64, %z : f64):
         %1 = arith.mulf %x, %y : f64
@@ -66,11 +88,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
   return %0 : tensor<4x4xf64>
 }
 
-func.func @mul_dense(%arg0: tensor<4x6xf64>,
-                     %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
-  %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul_dense(%arg0: tensor<4x8xf64>,
+                     %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
   %0 = linalg.generic #trait_mul
-    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
     outs(%out: tensor<4x4xf64>) {
       ^bb(%x: f64, %y : f64, %z : f64):
         %1 = arith.mulf %x, %y : f64
@@ -101,22 +123,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
     %c2 = arith.constant 2 : index
 
 
-    %td = arith.constant dense<[[ 0.0,  1.0,  2.0,  3.0,  4.0,  5.0],
-                                [ 6.0,  7.0,  8.0,  9.0, 10.0, 11.0],
-                                [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
-                                [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+    %td = arith.constant dense<[[ 1.0, 2.0,  0.0,  0.0,  0.0,  0.0,  4.0,  5.0],
+                                [ 6.0, 7.0,  0.0,  0.0,  0.0,  0.0, 10.0, 11.0],
+                                [ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0,  0.0,  0.0],
+                                [ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0,  0.0,  0.0]]> : tensor<4x8xf64>
 
 
-    %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+    %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
+    %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
 
     %d = call @mul_dense(%td, %td)
-         : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+         : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
     %s = call @mul(%td, %2)
-         : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+         : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+    %s24 = call @mul_24(%td, %3)
+         : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
 
-    // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+    // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
 
     return
   }
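
Editor's note on the updated test: under the #NV_24 map (i, j) -> (i : dense, j floordiv 4 : dense, j mod 4 : block2_4), each row of %td splits into two blocks of four columns, and every block stores exactly two (coordinate, value) pairs with coordinate j mod 4. A small C++ sketch (not MLIR runtime code) printing that layout for row 0:

#include <array>
#include <cstdio>

int main() {
  // Row 0 of %td: [1, 2, 0, 0, 0, 0, 4, 5] -- already 2:4 sparse.
  std::array<double, 8> row = {1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0};
  for (int block = 0; block < 2; ++block) {
    std::printf("block %d:", block);
    for (int j = 0; j < 4; ++j) {
      double v = row[block * 4 + j];
      if (v != 0.0)
        std::printf(" (crd=%d, val=%g)", j, v); // crd is j mod 4
    }
    std::printf("\n");
  }
  // Prints: block 0: (crd=0, val=1) (crd=1, val=2)
  //         block 1: (crd=2, val=4) (crd=3, val=5)
  return 0;
}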

PeimingLiu merged commit bfe08c0 into llvm:main on Nov 10, 2023
3 checks passed
PeimingLiu deleted the gen_24 branch on November 10, 2023 at 20:25
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request on Nov 20, 2023
Ag-Cu commented Dec 13, 2023

Hi, I have recently become very interested in MLIR's support for NVIDIA's sparse tensor cores. I saw your recent commits and wanted to ask: has MLIR added support only for the NV 2:4 sparse format so far, or is there an end-to-end flow that can target the sparse tensor cores via mma.sp PTX instructions?
I would greatly appreciate any information you can provide.

aartbik (Contributor) commented Dec 13, 2023

Thanks for your interest in this work!

We have an end-to-end flow that maps linalg.matmul to 2:4 using our "CUDA libgen" path, i.e., we rewrite the code into a call into the cusparseLt library for 2:4 operations. We also have a sample that uses the gpu dialect to map directly onto the mma instruction (through nvgpu.mma.sp.sync), but that is not part of any codegen... yet!

See https://discourse.llvm.org/t/sparse-compiler-and-gpu-code-generation/69786/ for a more in-depth discussion of what I wrote above.

(A follow-up comment from @Ag-Cu was marked as resolved.)

aartbik (Contributor) commented Dec 14, 2023

For sure! The current NV_2:4 syntax is just a stepping stone towards a more general N:M syntax, with N and M arbitrary integers. So expect this to improve soon in MLIR!
