diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp index 40c182f9dbb37..0a8f29cd0b417 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp @@ -26,8 +26,10 @@ static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl &convTypes, SmallVectorImpl *extraTypes, bool directOut) { for (auto type : types) { - // All "dense" data passes through unmodified. - if (!getSparseTensorEncoding(type)) { + // All "dense" data passes through unmodified. Note: getSparseTensorEncoding + // also returns non-null for StorageSpecifierType (which is not a + // RankedTensorType), so we must check isa as well. + if (!getSparseTensorEncoding(type) || !isa(type)) { convTypes.push_back(type); continue; } @@ -62,8 +64,10 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types, bool directOut) { unsigned idx = 0; for (auto type : types) { - // All "dense" data passes through unmodified. - if (!getSparseTensorEncoding(type)) { + // All "dense" data passes through unmodified. Note: getSparseTensorEncoding + // also returns non-null for StorageSpecifierType (which is not a + // RankedTensorType), so we must check isa as well. + if (!getSparseTensorEncoding(type) || !isa(type)) { toVals.push_back(fromVals[idx++]); continue; } diff --git a/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir new file mode 100644 index 0000000000000..b217900c64498 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-assembler | FileCheck %s + +// Regression test for https://github.com/llvm/llvm-project/issues/183776: +// Running --sparse-assembler after --sparse-tensor-codegen must not crash. +// After codegen, sparse tensor arguments are replaced by memrefs and +// \!sparse_tensor.storage_specifier types. getSparseTensorEncoding() returns +// non-null for StorageSpecifierType, but convTypes()/convVals() must not +// attempt cast on it. Instead, non-RankedTensorType types +// with a sparse encoding should pass through unchanged. + +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0: dense, d1: compressed) }> + +// Storage_specifier types from codegen must pass through sparse-assembler +// unchanged (not be treated as sparse tensor arguments to wrap). +// CHECK-LABEL: func.func @storage_specifier_passthrough( +// CHECK-SAME: storage_specifier +// CHECK-SAME: storage_specifier +// CHECK: return %{{.*}} : tensor<32x32xf32> +func.func @storage_specifier_passthrough(%arg0: tensor<32x32xf32, #CSR>, + %arg1: tensor<32x32xf32, #CSR>) + -> tensor<32x32xf32> { + %cst = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor<32x32xf32> + %out = linalg.fill ins(%cst : f32) outs(%init : tensor<32x32xf32>) + -> tensor<32x32xf32> + %3 = linalg.add + ins(%arg0, %arg1 : tensor<32x32xf32, #CSR>, tensor<32x32xf32, #CSR>) + outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32> + return %3 : tensor<32x32xf32> +}