[mlir][linalg] Reject matmul specialization when generic uses bitcast#182705
Open
meshtag wants to merge 1 commit into
Open
[mlir][linalg] Reject matmul specialization when generic uses bitcast#182705meshtag wants to merge 1 commit into
meshtag wants to merge 1 commit into
Conversation
linalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so this can lose information across specialization/generalization.
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Prathamesh Tagore (meshtag) Changeslinalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so this can lose information across specialization/generalization. Fixes #177593 Full diff: https://github.com/llvm/llvm-project/pull/182705.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..bdb33c833c829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -157,9 +157,18 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
// contains casts that cannot be represented (e.g. output casts or mixed
// signedness), return std::nullopt.
static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
+ // In addition to output casts, matmul-like named ops cannot represent bit
+ // level casts.
+ bool foundBitCastOp = false;
bool foundCastForMatmulOutput = false;
SmallVector<TypeFn> castTyFns;
genericOp.getBody()->walk([&](CastOpInterface castOp) {
+ // Early return if we encounter a bitcast op.
+ if (isa<arith::BitcastOp>(castOp)) {
+ foundBitCastOp = true;
+ return WalkResult::interrupt();
+ }
+
// Collect forward slice of the cast op to check if it is for the matmul
// output.
SetVector<Operation *> forwardSlice;
@@ -186,7 +195,7 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
return WalkResult::advance();
});
- if (foundCastForMatmulOutput)
+ if (foundBitCastOp || foundCastForMatmulOutput)
return std::nullopt;
if (!castTyFns.empty()) {
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..da4b307f12fa7 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -171,11 +171,9 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
// CHECK: linalg.generic
// CHECK-NOT: linalg.matmul
-// Bitcasts are not modeled by the cast attribute, but should not block
-// specialization.
-// NOTE: Bitcasts are not preserved by the matmul named op during
-// roundtrip, so this is potentially loosing information here.
-// See #177593 for more details.
+// Bitcasts are not modeled by the cast attribute, and would lose information
+// when roundtripped through the matmul named op (sitofp will be emitted in
+// this case), so we do not allow them for specialization.
func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
%B: tensor<8x32xi32>,
%Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -193,8 +191,8 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
}
// CHECK-LABEL: op_matmul_bitcast_int_to_float
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
// Signed float casts only use sitofp, which defaults to signed semantics.
func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
|
rengolin
requested changes
Feb 22, 2026
Member
rengolin
left a comment
There was a problem hiding this comment.
Sorry, this is a bad way of fixing the problem. This just fixes the one problem you had in the issue and ignores the actual semantics. We need to stop with the quick fixes in the form conversions and think about the whole process.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
linalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so this can lose information across specialization/generalization.
Fixes #177593