Skip to content

Commit

Permalink
[mlir][sparse] add a csr x bsr matmul test case (#73012)
Browse files Browse the repository at this point in the history
  • Loading branch information
PeimingLiu committed Nov 21, 2023
1 parent 1caaec1 commit b52eb7c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,8 +1497,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
levelReducedDep[tid][lvl]--;
if (!resolved) {
// TODO: support coiterating multiple slices
assert(loopInfo.trivialTidLvls.empty() &&
loopInfo.sliceDrivenInfo.size() == 1);
assert(loopInfo.sliceDrivenInfo.size() == 1);
auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
genSliceNextInduction(builder, loc, tid, lvl);
// Update while loop induction operands.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
doc = "X(i,j) *= A(i,j) * B(j,i)"
}

#CSR = #sparse_tensor.encoding<{
map = ( i, j ) -> (i : dense, j : compressed)
}>


#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
Expand Down Expand Up @@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
return %0 : tensor<4x4xf64>
}

func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
%arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
%2 = arith.addf %1, %z : f64
linalg.yield %2 : f64
} -> tensor<4x4xf64>
return %0 : tensor<4x4xf64>
}

func.func @mul_dense(%arg0: tensor<4x8xf64>,
%arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
Expand Down Expand Up @@ -132,18 +150,22 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,

%2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
%3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
%4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>

%d = call @mul_dense(%td, %td)
: (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
%s = call @mul(%td, %2)
: (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
%s24 = call @mul_24(%td, %3)
: (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
%scsr = call @mul_csr_bsr(%4, %2)
: (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>

// CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
// CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
call @dumpf64(%scsr) : (tensor<4x4xf64>) -> ()

return
}
Expand Down

0 comments on commit b52eb7c

Please sign in to comment.