diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 75bb1757a55f5..35abdc027f194 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -64,6 +64,14 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, outputs, indexingMaps, iterators); rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); + + // Discardable attributes carry user-defined metadata (e.g., annotations for + // downstream passes). Generalization is a semantics-preserving + // transformation, so dropping this metadata would be unexpected. This is safe + // because discardable attributes are by definition independent of op + // semantics. + genericOp->setDiscardableAttrs(linalgOp->getDiscardableAttrDictionary()); + rewriter.replaceOp(linalgOp, genericOp->getResults()); return genericOp; } diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index dcdd6c8db4b21..e346bee901f1d 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -1261,3 +1261,21 @@ func.func @contract_matmul_bcast_a_b( outs(%C: memref<3x7xf32>) return } + +// ----- + +// Test that discardable (user-defined) attributes are preserved during +// generalization. + +func.func @preserve_discardable_attrs(%A : tensor<16x8xf32>, + %B : tensor<8x32xf32>, + %C : tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.matmul {my_custom_attr = "preserved", another_attr = 42 : i64} + ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +// CHECK-LABEL: func @preserve_discardable_attrs +// CHECK: linalg.generic +// CHECK-SAME: attrs = {another_attr = 42 : i64, my_custom_attr = "preserved"}