[mlir][linalg] Preserve cast semantics during generic to matmul#174757
Conversation
6845a17 to
d1fe09f
Compare
1b9891b to
0fb0238
Compare
|
@llvm/pr-subscribers-mlir-linalg Author: Prathamesh Tagore (meshtag) Changes: Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and propagate it via the matmul cast attribute. Fixes a functional bug in #174517. Full diff: https://github.com/llvm/llvm-project/pull/174757.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..6be1ca981bfd5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -134,14 +135,72 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// All the variants expressed as pseudo regular expression:
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+ std::optional<TypeFn> castTy) {
+ SmallVector<NamedAttribute> castAttrVec;
+ // Only explicitly specify the cast attribute if the cast type exists and is
+ // pointing to unsigned cast (the default is signed cast for
+ // linalg.matmul/linalg.batch_matmul).
+ if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+ castAttrVec = {rewriter.getNamedAttr(
+ "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
- ValueRange{op.getDpsInits()[0]});
+ ValueRange{op.getDpsInits()[0]}, castAttrVec);
return namedOp;
}
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
+static bool
+getAndAuditMatmulCastTy(GenericOp genericOp,
+ std::optional<TypeFn> &specializedOpCastTy) {
+ bool foundCastForMatmulOutput = false;
+ SmallVector<TypeFn> castTyFns;
+ genericOp.getBody()->walk([&](CastOpInterface castOp) {
+ // Collect forward slice of the cast op to check if it is for the matmul
+ // output.
+ SetVector<Operation *> forwardSlice;
+ getForwardSlice(castOp, &forwardSlice);
+
+ // If there is no multiplication op in the forward slice, then this cast
+ // op is for the matmul output. Cast ops on matmul output cannot be
+ // expressed by linalg.matmul and linalg.batch_matmul.
+ if (!llvm::any_of(forwardSlice, [](Operation *op) {
+ // We check explicitly for these multiplication ops in
+ // `specializeLinalgContractions()` to infer matmuls.
+ return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
+ })) {
+ foundCastForMatmulOutput = true;
+ return WalkResult::interrupt();
+ }
+
+ // Determine the cast type.
+ if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
+ castTyFns.push_back(TypeFn::cast_unsigned);
+ else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
+ castTyFns.push_back(TypeFn::cast_signed);
+
+ return WalkResult::advance();
+ });
+
+ if (!castTyFns.empty()) {
+ // If there were multiple different cast types found, then we can't express
+ // it correctly using linalg.matmul or linalg.batch_matmul ops. They only
+ // allow a single cast type for all inputs.
+ if (!llvm::all_equal(castTyFns))
+ return false;
+ specializedOpCastTy = castTyFns.front();
+ }
+ return !foundCastForMatmulOutput;
+}
+
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -230,11 +289,19 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
(a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
return failure();
+ // Get the cast attribute for the named matmul op (if any).
+ std::optional<TypeFn> castTy;
+
+ // If there were invalid cast ops found for matmul, bail out. Else determine
+ // the cast type for the named matmul op (if any).
+ if (!getAndAuditMatmulCastTy(genericOp, castTy))
+ return failure();
+
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
- return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..b1db1154fb357 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
}
// CHECK-LABEL: op_matvec
// CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i64, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.trunci %in_0 : i64 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extsi %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi64>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i64):
+ %1 = arith.extsi %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.trunci %out : i64 to i32
+ %4 = arith.muli %1, %2 : i32
+ %5 = arith.addi %3, %4 : i32
+ %6 = arith.extsi %5 : i32 to i64
+ linalg.yield %6 : i64
+ } -> tensor<16x32xi64>
+ return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+ %B: tensor<8x32xi32>,
+ %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: f32):
+ %1 = arith.bitcast %in : i32 to f32
+ %2 = arith.bitcast %in_0 : i32 to f32
+ %3 = arith.mulf %1, %2 : f32
+ %4 = arith.addf %out, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16x32xf32>
+ return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_bitcast_int_to_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
|
|
@llvm/pr-subscribers-mlir Author: Prathamesh Tagore (meshtag) Changes: Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and propagate it via the matmul cast attribute. Fixes a functional bug in #174517. Full diff: https://github.com/llvm/llvm-project/pull/174757.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..6be1ca981bfd5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -134,14 +135,72 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// All the variants expressed as pseudo regular expression:
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have same number of ins/out, so its easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op.
template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+ std::optional<TypeFn> castTy) {
+ SmallVector<NamedAttribute> castAttrVec;
+ // Only explicitly specify the cast attribute if the cast type exists and is
+ // pointing to unsigned cast (the default is signed cast for
+ // linalg.matmul/linalg.batch_matmul).
+ if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+ castAttrVec = {rewriter.getNamedAttr(
+ "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
- ValueRange{op.getDpsInits()[0]});
+ ValueRange{op.getDpsInits()[0]}, castAttrVec);
return namedOp;
}
+// Determines the required cast type for the specialized matmul op (if any)
+// which is expressed in the form of the input linalg.generic op. Also audits
+// that there are no invalid cast ops for matmul inputs/outputs which can't be
+// expressed using the specialized op.
+static bool
+getAndAuditMatmulCastTy(GenericOp genericOp,
+ std::optional<TypeFn> &specializedOpCastTy) {
+ bool foundCastForMatmulOutput = false;
+ SmallVector<TypeFn> castTyFns;
+ genericOp.getBody()->walk([&](CastOpInterface castOp) {
+ // Collect forward slice of the cast op to check if it is for the matmul
+ // output.
+ SetVector<Operation *> forwardSlice;
+ getForwardSlice(castOp, &forwardSlice);
+
+ // If there is no multiplication op in the forward slice, then this cast
+ // op is for the matmul output. Cast ops on matmul output cannot be
+ // expressed by linalg.matmul and linalg.batch_matmul.
+ if (!llvm::any_of(forwardSlice, [](Operation *op) {
+ // We check explicitly for these multiplication ops in
+ // `specializeLinalgContractions()` to infer matmuls.
+ return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
+ })) {
+ foundCastForMatmulOutput = true;
+ return WalkResult::interrupt();
+ }
+
+ // Determine the cast type.
+ if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
+ castTyFns.push_back(TypeFn::cast_unsigned);
+ else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
+ castTyFns.push_back(TypeFn::cast_signed);
+
+ return WalkResult::advance();
+ });
+
+ if (!castTyFns.empty()) {
+ // If there were multiple different cast types found, then we can't express
+ // it correctly using linalg.matmul or linalg.batch_matmul ops. They only
+ // allow a single cast type for all inputs.
+ if (!llvm::all_equal(castTyFns))
+ return false;
+ specializedOpCastTy = castTyFns.front();
+ }
+ return !foundCastForMatmulOutput;
+}
+
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -230,11 +289,19 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
(a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
return failure();
+ // Get the cast attribute for the named matmul op (if any).
+ std::optional<TypeFn> castTy;
+
+ // If there were invalid cast ops found for matmul, bail out. Else determine
+ // the cast type for the named matmul op (if any).
+ if (!getAndAuditMatmulCastTy(genericOp, castTy))
+ return failure();
+
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
- return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..b1db1154fb357 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
}
// CHECK-LABEL: op_matvec
// CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i64, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.trunci %in_0 : i64 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extsi %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi64>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i64):
+ %1 = arith.extsi %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.trunci %out : i64 to i32
+ %4 = arith.muli %1, %2 : i32
+ %5 = arith.addi %3, %4 : i32
+ %6 = arith.extsi %5 : i32 to i64
+ linalg.yield %6 : i64
+ } -> tensor<16x32xi64>
+ return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+ %B: tensor<8x32xi32>,
+ %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: f32):
+ %1 = arith.bitcast %in : i32 to f32
+ %2 = arith.bitcast %in_0 : i32 to f32
+ %3 = arith.mulf %1, %2 : f32
+ %4 = arith.addf %out, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16x32xf32>
+ return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_bitcast_int_to_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
|
adam-smnk
left a comment
There was a problem hiding this comment.
Thanks for making the linalg morphism more robust 👍
Could you also add a related roundtrip test?
Maybe to roundtrip-linalg-named-ops.mlir.
|
Thanks for the fix! Aren't other contract-like ops also affected? Btw, would you mind clustering together all tests for linalg.matmul and other ops (bonus points for a block comment)? This way it will be much easier to see all the test variants for a particular Op. More on testing: https://mlir.llvm.org/getting_started/TestingGuide/#test-documentation-best-practices |
0fb0238 to
a7f302a
Compare
I haven't checked this. I'd prefer only dealing with matmul-like ops in this PR though. We can deal with other ops separately if that's fine.
I have removed input-file split indicators from between them, and am also re-using the affine maps wherever I can. Documentation and naming are also improved for them per the attached guide (thanks!). Please feel free to let me know if more things are required there. Thanks. |
ed22234 to
7a4d435
Compare
|
Thanks for the updates!
Sure, but please leave a TODO somewhere, eg the summary. Will you look into it? No worries if not, I can take a look myself.
Ultra nit - move the new tests next to other tests for matmul ops, e.g (ie cluster them together): Ultra nit - add a block comment, eg: This is just for a bit better testing hygiene :) Thanks! |
banach-space
left a comment
There was a problem hiding this comment.
I've left a few minor suggestions inline, thanks!
54b75cc to
c5d5f05
Compare
Sure, I can do this if it's not in the critical path. Updated the PR with the suggestions. EDIT: Created an issue to track the TODO here: #175885 |
c5d5f05 to
8239020
Compare
| // Bitcasts are not modeled by the cast attribute, but should not block | ||
| // specialization. |
There was a problem hiding this comment.
Very interesting example, thanks!
should not block specialization
Not sure about this 🤔 What is it specialised to?
There was a problem hiding this comment.
Leads to this op
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>.
The test checks that we do not block/modify anything for casts which are not specifying the signed/unsigned behaviour.
There was a problem hiding this comment.
Round-trip today will generate arith.sitofp instead of arith.bitcast.
There was a problem hiding this comment.
That looks like a bug, no?
There was a problem hiding this comment.
I'd think so. But unrelated to this PR.
Definitely something to revisit.
| %1 = arith.extui %in : i16 to i32 | ||
| %2 = arith.trunci %in_0 : i64 to i32 |
There was a problem hiding this comment.
These are different casts - shouldn't this fail?
There was a problem hiding this comment.
We only consider casts as conflicting if they have different signedness behaviours, and then we do not specialise if they do conflict. Since this is not such a case, we do not block specialisation. Also the roundtrip lowering back to linalg.generic for such an op is expected to produce the same thing again, so we are not losing information here.
For example:
%0 = linalg.generic
{indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
^bb0(%in: i16, %in_0: i64, %out: i32):
%1 = arith.extui %in : i16 to i32
%2 = arith.trunci %in_0 : i64 to i32
%3 = arith.muli %1, %2 : i32
%4 = arith.addi %out, %3 : i32
linalg.yield %4 : i32
} -> tensor<16x32xi32>
with --linalg-specialize-generic-ops becomes
%0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) -> tensor<16x32xi32>
and applying -linalg-generalize-named-ops on the above gives
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) {
^bb0(%in: i16, %in_0: i64, %out: i32):
%1 = arith.extui %in : i16 to i32
%2 = arith.trunci %in_0 : i64 to i32
%3 = arith.muli %1, %2 : i32
%4 = arith.addi %out, %3 : i32
linalg.yield %4 : i32
} -> tensor<16x32xi32>
So we did not lose any information here.
I can also add this in the test comment if that makes it more explicit.
There was a problem hiding this comment.
Thanks for the clarification - so truncations are excluded from that verification. Makes sense.
This makes me realise that we want the following tests:
1. @op_matmul_unsigned_cast (no truncation, only arith.extui)
2. @op_matmul_signed_cast (no truncation, only arith.extsi)
3. @negative_op_matmul_mixed_cast (no truncation, only arith.extui + arith.extsi)
4. @op_matmul_bitcast_int_to_float (no truncation, only arith.bitcast)
5. @op_matmul_signed_cast_float (no truncation, only arith.sitofp)
6. @op_matmul_unsigned_cast_float (no truncation, only arith.uitofp)
7. @op_matmul_unsigned_cast_and_truncate (arith.extui + arith.trunci)
Not sure about 4. TBH. That case should be rejected, no? @adam-smnk , WDYT?
There was a problem hiding this comment.
+1 to covering various combinations, linalg morphism test coverage clearly needs improvements.
Looks like case 4 should be rejected as round tripping doesn't preserve original semantics. AFAIK, there's no attribute to maintain them either.
I wouldn't block on this but more tests are always welcome.
There was a problem hiding this comment.
Added more tests here. I am not sure if rejecting case 4 should be done in this PR. I'd prefer handling it separately as it is not a bug introduced by this change, and I am not sure if special casing for bitcasts is the best solution forward. Would prefer taking another look and then adding tests for it separately - can also create a GitHub issue to track this.
Also, can someone please help me land this when we think it is good to go. Thanks.
There was a problem hiding this comment.
Thanks for adding more tests, LGTM
I'd prefer handling it separately as it is not a bug introduced by this change
That's fine with me, but could you create a GitHub issue and link it in the code? Thanks!
cd897b3 to
b7fcffd
Compare
Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and propagate it via the matmul cast attribute. This could otherwise lead to silent overflow/underflow errors in e2e execution. TODO: Extend this to other named ops that support cast attribute.
b7fcffd to
2dda974
Compare
Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.
TODO: Extend this to other named ops that support cast attribute.
Fixes a functional bug in the attached issue.
Fixes #174517