Skip to content

[mlir][linalg] Preserve cast semantics during generic to matmul#174757

Merged
meshtag merged 1 commit into
llvm:mainfrom
meshtag:fix_linalg_matmul_cast
Jan 23, 2026
Merged

[mlir][linalg] Preserve cast semantics during generic to matmul#174757
meshtag merged 1 commit into
llvm:mainfrom
meshtag:fix_linalg_matmul_cast

Conversation

@meshtag
Copy link
Copy Markdown
Member

@meshtag meshtag commented Jan 7, 2026

Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.

TODO: Extend this to other named ops that support cast attribute.

Fixes a functional bug in the attached issue.

Fixes #174517

@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch 2 times, most recently from 6845a17 to d1fe09f Compare January 8, 2026 05:44
@meshtag meshtag changed the title [mlir][linalg] Preserve cast semantics during linalg.generic to matmul [mlir][linalg] Preserve cast semantics during generic to matmul Jan 8, 2026
@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch 2 times, most recently from 1b9891b to 0fb0238 Compare January 8, 2026 09:04
@meshtag meshtag marked this pull request as ready for review January 8, 2026 09:05
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 8, 2026

@llvm/pr-subscribers-mlir-linalg

Author: Prathamesh Tagore (meshtag)

Changes

Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.

Fixes a functional bug in #174517.


Full diff: https://github.com/llvm/llvm-project/pull/174757.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+71-4)
  • (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+123)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..6be1ca981bfd5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -134,14 +135,72 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 //  All the variants expressed as pseudo regular expression:
 //      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
 //  have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
 template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec;
+  // Only explicitly specify the cast attribute if the cast type exists and is
+  // pointing to unsigned cast (the default is signed cast for
+  // linalg.matmul/linalg.batch_matmul).
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec);
   return namedOp;
 }
 
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
+static bool
+getAndAuditMatmulCastTy(GenericOp genericOp,
+                        std::optional<TypeFn> &specializedOpCastTy) {
+  bool foundCastForMatmulOutput = false;
+  SmallVector<TypeFn> castTyFns;
+  genericOp.getBody()->walk([&](CastOpInterface castOp) {
+    // Collect forward slice of the cast op to check if it is for the matmul
+    // output.
+    SetVector<Operation *> forwardSlice;
+    getForwardSlice(castOp, &forwardSlice);
+
+    // If there is no multiplication op in the forward slice, then this cast
+    // op is for the matmul output. Cast ops on matmul output cannot be
+    // expressed by linalg.matmul and linalg.batch_matmul.
+    if (!llvm::any_of(forwardSlice, [](Operation *op) {
+          // We check explicitly for these multiplication ops in
+          // `specializeLinalgContractions()` to infer matmuls.
+          return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
+        })) {
+      foundCastForMatmulOutput = true;
+      return WalkResult::interrupt();
+    }
+
+    // Determine the cast type.
+    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_unsigned);
+    else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_signed);
+
+    return WalkResult::advance();
+  });
+
+  if (!castTyFns.empty()) {
+    // If there were multiple different cast types found, then we can't express
+    // it correctly using linalg.matmul or linalg.batch_matmul ops. They only
+    // allow a single cast type for all inputs.
+    if (!llvm::all_equal(castTyFns))
+      return false;
+    specializedOpCastTy = castTyFns.front();
+  }
+  return !foundCastForMatmulOutput;
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
@@ -230,11 +289,19 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
       (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
     return failure();
 
+  // Get the cast attribute for the named matmul op (if any).
+  std::optional<TypeFn> castTy;
+
+  // If there were invalid cast ops found for matmul, bail out. Else determine
+  // the cast type for the named matmul op (if any).
+  if (!getAndAuditMatmulCastTy(genericOp, castTy))
+    return failure();
+
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
-  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
 /// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..b1db1154fb357 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 }
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi64>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i64):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.trunci %out : i64 to i32
+     %4 = arith.muli %1, %2 : i32
+     %5 = arith.addi %3, %4 : i32
+     %6 = arith.extsi %5 : i32 to i64
+     linalg.yield %6 : i64
+   } -> tensor<16x32xi64>
+   return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+                                          %B: tensor<8x32xi32>,
+                                          %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
+   ^bb0(%in: i32, %in_0: i32, %out: f32):
+     %1 = arith.bitcast %in : i32 to f32
+     %2 = arith.bitcast %in_0 : i32 to f32
+     %3 = arith.mulf %1, %2 : f32
+     %4 = arith.addf %out, %3 : f32
+     linalg.yield %4 : f32
+   } -> tensor<16x32xf32>
+   return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_bitcast_int_to_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 8, 2026

@llvm/pr-subscribers-mlir

Author: Prathamesh Tagore (meshtag)

Changes

Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.

Fixes a functional bug in #174517.


Full diff: https://github.com/llvm/llvm-project/pull/174757.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+71-4)
  • (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+123)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..6be1ca981bfd5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -134,14 +135,72 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 //  All the variants expressed as pseudo regular expression:
 //      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
 //  have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
 template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec;
+  // Only explicitly specify the cast attribute if the cast type exists and is
+  // pointing to unsigned cast (the default is signed cast for
+  // linalg.matmul/linalg.batch_matmul).
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec);
   return namedOp;
 }
 
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
+static bool
+getAndAuditMatmulCastTy(GenericOp genericOp,
+                        std::optional<TypeFn> &specializedOpCastTy) {
+  bool foundCastForMatmulOutput = false;
+  SmallVector<TypeFn> castTyFns;
+  genericOp.getBody()->walk([&](CastOpInterface castOp) {
+    // Collect forward slice of the cast op to check if it is for the matmul
+    // output.
+    SetVector<Operation *> forwardSlice;
+    getForwardSlice(castOp, &forwardSlice);
+
+    // If there is no multiplication op in the forward slice, then this cast
+    // op is for the matmul output. Cast ops on matmul output cannot be
+    // expressed by linalg.matmul and linalg.batch_matmul.
+    if (!llvm::any_of(forwardSlice, [](Operation *op) {
+          // We check explicitly for these multiplication ops in
+          // `specializeLinalgContractions()` to infer matmuls.
+          return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
+        })) {
+      foundCastForMatmulOutput = true;
+      return WalkResult::interrupt();
+    }
+
+    // Determine the cast type.
+    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_unsigned);
+    else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_signed);
+
+    return WalkResult::advance();
+  });
+
+  if (!castTyFns.empty()) {
+    // If there were multiple different cast types found, then we can't express
+    // it correctly using linalg.matmul or linalg.batch_matmul ops. They only
+    // allow a single cast type for all inputs.
+    if (!llvm::all_equal(castTyFns))
+      return false;
+    specializedOpCastTy = castTyFns.front();
+  }
+  return !foundCastForMatmulOutput;
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
@@ -230,11 +289,19 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
       (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
     return failure();
 
+  // Get the cast attribute for the named matmul op (if any).
+  std::optional<TypeFn> castTy;
+
+  // If there were invalid cast ops found for matmul, bail out. Else determine
+  // the cast type for the named matmul op (if any).
+  if (!getAndAuditMatmulCastTy(genericOp, castTy))
+    return failure();
+
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
-  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
 /// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..b1db1154fb357 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 }
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi64>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i64):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.trunci %out : i64 to i32
+     %4 = arith.muli %1, %2 : i32
+     %5 = arith.addi %3, %4 : i32
+     %6 = arith.extsi %5 : i32 to i64
+     linalg.yield %6 : i64
+   } -> tensor<16x32xi64>
+   return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+                                          %B: tensor<8x32xi32>,
+                                          %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
+   ^bb0(%in: i32, %in_0: i32, %out: f32):
+     %1 = arith.bitcast %in : i32 to f32
+     %2 = arith.bitcast %in_0 : i32 to f32
+     %3 = arith.mulf %1, %2 : f32
+     %4 = arith.addf %out, %3 : f32
+     linalg.yield %4 : f32
+   } -> tensor<16x32xf32>
+   return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_bitcast_int_to_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul

@adam-smnk adam-smnk requested a review from CoTinker January 8, 2026 09:36
Copy link
Copy Markdown
Member

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the linalg morphism more robust 👍

Could you also add a related roundtrip test?
Maybe to roundtrip-linalg-named-ops.mlir.

Comment thread mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp Outdated
Comment thread mlir/test/Dialect/Linalg/specialize-generic-ops.mlir Outdated
Comment thread mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@banach-space
Copy link
Copy Markdown
Contributor

Thanks for the fix!

Aren't other contract-like ops also affected?

Btw, would you mind clustering together all tests for linalg.matmul and other ops (bonus points for a block comment)? This way it will be much easier to see all the test variants for a particular Op.

More on testing: https://mlir.llvm.org/getting_started/TestingGuide/#test-documentation-best-practices

@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch from 0fb0238 to a7f302a Compare January 8, 2026 18:03
@meshtag
Copy link
Copy Markdown
Member Author

meshtag commented Jan 8, 2026

Aren't other contract-like ops also affected?

I haven't checked this. I'd prefer only dealing with matmul-like ops in this PR though. We can deal with other ops separately if that's fine.

Btw, would you mind clustering together all tests for linalg.matmul and other ops (bonus points for a block comment)? This way it will be much easier to see all the test variants for a particular Op.

I have removed input-file split indicators from between them, and am also re-using the affine maps wherever I can. Documentation and naming is also improved for them per the attached guide (thanks!). Please feel free to let me know if more things are required there. Thanks.

@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch 2 times, most recently from ed22234 to 7a4d435 Compare January 8, 2026 18:39
@banach-space
Copy link
Copy Markdown
Contributor

Thanks for the updates!

I haven't checked this. I'd prefer only dealing with matmul-like ops in this PR though. We can deal with other ops separately if that's fine.

Sure, but please leave a TODO somewhere, eg the summary. Will you look into it? No worries if not, I can take a look myself.

Please feel free to let me know if more things are required there.

Ultra nit - move the new tests next to other tests for matmul ops, e.g (ie cluster them together):

func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {

Ultra nit - add a block comment, eg:

///----------------------------------------------------------------------------------------

This is just for a bit better testing hygiene :)

Thanks!

Copy link
Copy Markdown
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left a few minor suggestions inline, thanks!

Comment thread mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp Outdated
Comment thread mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Comment thread mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp Outdated
Comment thread mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp Outdated
@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch 3 times, most recently from 54b75cc to c5d5f05 Compare January 10, 2026 08:14
@meshtag
Copy link
Copy Markdown
Member Author

meshtag commented Jan 10, 2026

Sure, but please leave a TODO somewhere, eg the summary. Will you look into it? No worries if not, I can take a look myself.

Sure, I can do this if it's not in the critical path. Updated the PR with the suggestions.

EDIT: Created an issue to track the TODO here: #175885

Comment on lines +149 to +150
// Bitcasts are not modeled by the cast attribute, but should not block
// specialization.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very interesting example, thanks!

should not block specialization

Not sure about this 🤔 What is it specialised to?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leads to this op
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>.

The test checks that we do not block/modify anything for casts which are not specifying the signed/unsigned behaviour.

Copy link
Copy Markdown
Member

@adam-smnk adam-smnk Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Round-trip today will generate arith.sitofp instead of arith.bitcast.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks like a bug, no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think so. But unrelated to this PR.
Definitely sth to revisit.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bug, or "canonicalization"?

Comment on lines +75 to +76
%1 = arith.extui %in : i16 to i32
%2 = arith.trunci %in_0 : i64 to i32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are different casts - shouldn't this fail?

Copy link
Copy Markdown
Member Author

@meshtag meshtag Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only consider casts as conflicting if they have different signedness behaviours, and then we do not specialise if they do conflict. Since this is not such a case, we do not block specialisation. Also the roundtrip lowering back to linalg.generic for such an op is expected to produce the same thing again, so we are not loosing information here.

For example:

%0 = linalg.generic
         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
   ^bb0(%in: i16, %in_0: i64, %out: i32):
     %1 = arith.extui %in : i16 to i32
     %2 = arith.trunci %in_0 : i64 to i32
     %3 = arith.muli %1, %2 : i32
     %4 = arith.addi %out, %3 : i32
     linalg.yield %4 : i32
   } -> tensor<16x32xi32>

with --linalg-specialize-generic-ops becomes

%0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) -> tensor<16x32xi32>

and applying -linalg-generalize-named-ops on the above gives

%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) {
    ^bb0(%in: i16, %in_0: i64, %out: i32):
      %1 = arith.extui %in : i16 to i32
      %2 = arith.trunci %in_0 : i64 to i32
      %3 = arith.muli %1, %2 : i32
      %4 = arith.addi %out, %3 : i32
      linalg.yield %4 : i32
    } -> tensor<16x32xi32>

So we did not loose any information here.

I can also add this in the test comment if that makes it more explicit.

Copy link
Copy Markdown
Contributor

@banach-space banach-space Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification - so truncations are excluded from that verification. Makes sense.

This makes me realise that we want the following tests:

  1. @op_matmul_unsigned_cast (no truncation, only arith.extui)
  2. @op_matmul_signed_cast (no truncation, only arith.extsi)
  3. @negative_op_matmul_mixed_cast(no truncation, only arith.extui + arith.extsi)
  4. @op_matmul_bitcast_int_to_float (no truncation, only arith.bitcast)
  5. @op_matmul_signed_cast_float (no truncation, only arith.sitofp)
  6. @op_matmul_unsigned_cast_float (no truncation, only arith.uitofp)
  7. @op_matmul_unsigned_cast_and_truncate (arith.extui + arith.trunci)

Not sure about 4. TBH. That case should be rejected, no? @adam-smnk , WDYT?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to covering various combinations, linalg morphism test coverage clearly needs improvements.
Looks like case 4 should be rejected as round tripping doesn't preserve original semantics. AFAIK, there's no attribute to maintain them either.

I wouldn't block on this but more tests are always welcome.

Copy link
Copy Markdown
Member Author

@meshtag meshtag Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more tests here. I am not sure if rejecting case 4 should be done in this PR. I'd prefer handling it separately as it is not a bug introduced by this change, and I am not sure if special casing for bitcasts is the best solution forward. Would prefer taking another look and then adding tests for it separately - can also create a GitHub issue to track this.

Also, can someone please help me land this when we think it is good to go. Thanks.

Copy link
Copy Markdown
Contributor

@banach-space banach-space Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding more tests, LGTM

I'd prefer handling it separately as it is not a bug introduced by this change

That's fine with me, but could you create a GitHub issue and link it in the code? Thanks!

@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch 3 times, most recently from cd897b3 to b7fcffd Compare January 23, 2026 14:20
Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.

TODO: Extend this to other named ops that support cast attribute.
@meshtag meshtag force-pushed the fix_linalg_matmul_cast branch from b7fcffd to 2dda974 Compare January 23, 2026 14:22
@meshtag meshtag merged commit 0df9098 into llvm:main Jan 23, 2026
11 checks passed
@meshtag meshtag deleted the fix_linalg_matmul_cast branch January 23, 2026 14:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR] -linalg-specialize-generic-ops causes wrong-code for linalg.matmul

6 participants