Skip to content

[mlir][linalg] Reject matmul specialization when generic uses bitcast#182705

Open
meshtag wants to merge 1 commit into
llvm:mainfrom
meshtag:fix_specialize_bitcast
Open

[mlir][linalg] Reject matmul specialization when generic uses bitcast#182705
meshtag wants to merge 1 commit into
llvm:mainfrom
meshtag:fix_specialize_bitcast

Conversation

@meshtag
Copy link
Copy Markdown
Member

@meshtag meshtag commented Feb 21, 2026

linalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so this can lose information across specialization/generalization.

Fixes #177593

linalg-specialize-generic-ops currently allows matmul-like specialization
even when the generic body contains arith.bitcast. The matmul cast attribute
cannot represent bit-level reinterpretation semantics, so this can lose
information across specialization/generalization.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 21, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Prathamesh Tagore (meshtag)

Changes

linalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so this can lose information across specialization/generalization.

Fixes #177593


Full diff: https://github.com/llvm/llvm-project/pull/182705.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+10-1)
  • (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+5-7)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..bdb33c833c829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -157,9 +157,18 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
 // contains casts that cannot be represented (e.g. output casts or mixed
 // signedness), return std::nullopt.
 static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
+  // In addition to output casts, matmul-like named ops cannot represent bit
+  // level casts.
+  bool foundBitCastOp = false;
   bool foundCastForMatmulOutput = false;
   SmallVector<TypeFn> castTyFns;
   genericOp.getBody()->walk([&](CastOpInterface castOp) {
+    // Early return if we encounter a bitcast op.
+    if (isa<arith::BitcastOp>(castOp)) {
+      foundBitCastOp = true;
+      return WalkResult::interrupt();
+    }
+
     // Collect forward slice of the cast op to check if it is for the matmul
     // output.
     SetVector<Operation *> forwardSlice;
@@ -186,7 +195,7 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
     return WalkResult::advance();
   });
 
-  if (foundCastForMatmulOutput)
+  if (foundBitCastOp || foundCastForMatmulOutput)
     return std::nullopt;
 
   if (!castTyFns.empty()) {
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..da4b307f12fa7 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -171,11 +171,9 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
 // CHECK: linalg.generic
 // CHECK-NOT: linalg.matmul
 
-// Bitcasts are not modeled by the cast attribute, but should not block
-// specialization.
-// NOTE: Bitcasts are not preserved by the matmul named op during
-// roundtrip, so this is potentially loosing information here.
-// See #177593 for more details.
+// Bitcasts are not modeled by the cast attribute, and would lose information
+// when roundtripped through the matmul named op (sitofp will be emitted in
+// this case), so we do not allow them for specialization.
 func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
                                           %B: tensor<8x32xi32>,
                                           %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -193,8 +191,8 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
 }
 
 // CHECK-LABEL: op_matmul_bitcast_int_to_float
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul
+// CHECK:     linalg.generic
+// CHECK-NOT: linalg.matmul
 
 // Signed float casts only use sitofp, which defaults to signed semantics.
 func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,

Copy link
Copy Markdown
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this is a bad way of fixing the problem. This just fixes the one problem you had in the issue and ignores the actual semantics. We need to stop with the quick fixes in the form conversions and think about the whole process.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir][linalg] bitcast is not preserved during matmul specialisation from generic

3 participants