diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e91b4f637..899b8c87d0df77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
                                      ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
     // All loops must be parallel.
     if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
       return false;
diff --git a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
index 6c2292be161a53..b769acdc7825ce 100644
--- a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs = {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}
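
Reviewer note on the early exit: the patch only calls sparse_tensor::hasAnySparseOperand and does not define it. As a rough sketch of the kind of check such a helper performs, the code below walks an operation's operands and reports whether any of their tensor types carries a #sparse_tensor.encoding. The function name anyOperandHasSparseEncoding is made up for this illustration; the sketch assumes the SparseTensor dialect's getSparseTensorEncoding accessor and is not the upstream definition of the helper.

// Illustrative sketch only; not the upstream implementation of
// sparse_tensor::hasAnySparseOperand.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Operation.h"

static bool anyOperandHasSparseEncoding(mlir::Operation *op) {
  for (mlir::Value operand : op->getOperands()) {
    // An operand counts as sparse iff its type carries a
    // #sparse_tensor.encoding attribute (null attribute otherwise).
    if (mlir::sparse_tensor::getSparseTensorEncoding(operand.getType()))
      return true;
  }
  return false;
}

In the patched bufferizesToElementwiseAccess, returning false for any such operand keeps the analysis conservative: sparse storage is accessed through positions/coordinates arrays, so the op cannot be assumed to touch corresponding elements of its operands.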