// Patch overview (review notes placed ahead of the first `diff --git` header,
// outside every hunk).
//
// NOTE(review): this copy of the patch appears to have lost its line breaks
// and its angle-bracket template arguments (see `cast(value.getType())`,
// `static_cast(this)` and `patterns.insert(context, controlFn)` below).
// Restore both from the original commit before applying - TODO confirm.
//
// mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
//   * Adds an include of "mlir/Dialect/Arith/IR/Arith.h".
//   * Lists `bool checkElementTypes(LinalgOp linalgOp) const;` among the
//     methods a `ConcreteType` is documented to provide.
//   * FoldConstantBase: removes the check that all operand element types are
//     equal. After the existing int/float check on the output element type,
//     it now also fails when any DPS input has a non-int/float element type,
//     and fails when the `checkElementTypes(linalgOp)` call made through
//     `this` returns false.
//   * FoldConstantTranspose::checkElementTypes: returns whether the element
//     types of all of `linalgOp->getOperands()` are equal (llvm::all_equal).
//   * New FoldConstantCast pattern:
//       - checkElementTypes always returns true.
//       - matchIndexingMaps returns true when the op has exactly one DPS
//         input.
//       - getRegionComputeFn returns nullptr unless the front block of
//         region 0 holds exactly two operations, its terminator passes the
//         `dyn_cast` and yields exactly one value, that value's defining op
//         is in the same block, and the op's `getIn()` operand is block
//         argument number 0 of that block. Otherwise it returns a lambda
//         that sign-extends (`sext`) the first integer input to the bit
//         width of the cast result type and leaves the second APIntOrFloat
//         field as std::nullopt.
//         NOTE(review): the concrete op types used for the yield and the
//         cast are not visible in this text; the accompanying comments name
//         `linalg.yield` and `arith.extsi` - verify against the commit.
//   * populateConstantFoldLinalgOperations: the `patterns.insert(...)` line
//     is replaced. The removed and added lines read identically here;
//     presumably the new one also registers FoldConstantCast - verify.
//
// mlir/test/Dialect/Linalg/constant-fold.mlir (four new cases)
//   * @cast_fold_extsi_i32_to_i64: constant tensor<4xi32> [1, 2, 3, 4];
//     expects an arith.constant of tensor<4xi64> with the same values to be
//     returned.
//   * @cast_fold_extsi_negative: constant tensor<2xi32> [-1, -2]; expects an
//     arith.constant dense<[-1, -2]> of tensor<2xi64> to be returned.
//   * @cast_nofold_non_cst_input: the input is a function argument; expects
//     `linalg.generic` to remain.
//   * @cast_nofold_multi_ops_in_region: the region contains arith.extsi
//     followed by arith.muli; expects `linalg.generic` to remain.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index ba763ec2137e7..5ec8b4ee1e945 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/Matchers.h" @@ -28,6 +29,7 @@ namespace { /// `ConcreteType` should provide methods with signatures /// /// ```c++ +/// bool checkElementTypes(LinalgOp linalgOp) const; /// bool matchIndexingMaps(LinalgOp linalgOp) const; /// RegionComputationFn getRegionComputeFn(LinalgOp) const; /// ``` @@ -75,18 +77,18 @@ class FoldConstantBase : public OpInterfaceRewritePattern { })) return failure(); - // Make sure all element types are the same. - auto getOperandElementType = [](Value value) { - return cast(value.getType()).getElementType(); - }; - if (!llvm::all_equal( - llvm::map_range(linalgOp->getOperands(), getOperandElementType))) - return failure(); - // We can only handle the case where we have int/float elements. auto elementType = outputType.getElementType(); if (!elementType.isIntOrFloat()) return failure(); + for (Value input : linalgOp.getDpsInputs()) { + Type elemTy = cast(input.getType()).getElementType(); + if (!elemTy.isIntOrFloat()) + return failure(); + } + + if (!static_cast(this)->checkElementTypes(linalgOp)) + return failure(); // Require all indexing maps to be permutations for now. This is common and // it simplifies input/output access greatly: we can do the data shuffling @@ -267,6 +269,14 @@ struct FoldConstantTranspose : public FoldConstantBase { using FoldConstantBase::FoldConstantBase; + // Transpose requires all operand element types to match. 
+ bool checkElementTypes(LinalgOp linalgOp) const { + auto getElem = [](Value v) { + return cast(v.getType()).getElementType(); + }; + return llvm::all_equal(llvm::map_range(linalgOp->getOperands(), getElem)); + } + bool matchIndexingMaps(LinalgOp linalgOp) const { // We should have one input and one output. return linalgOp.getIndexingMapsArray().size() == 2; @@ -300,10 +310,49 @@ struct FoldConstantTranspose : public FoldConstantBase { ControlFusionFn controlFn; }; + +// Folds a linalg.generic whose body is a single arith cast op on the input +// block arg, when the input is a constant. Only `arith.extsi` is supported for +// now. In the future arith ops like extui, trunci, sitofp, uitofp, extf, +// truncf, fptosi, fptoui could be added as well. +struct FoldConstantCast : public FoldConstantBase { + using FoldConstantBase::FoldConstantBase; + + // Allow differing input/output element types. + bool checkElementTypes(LinalgOp) const { return true; } + + bool matchIndexingMaps(LinalgOp linalgOp) const { + return linalgOp.getNumDpsInputs() == 1; + } + + RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const { + Block &body = linalgOp->getRegion(0).front(); + // Expect exactly two ops: the cast, then `linalg.yield`. + if (body.getOperations().size() != 2) + return nullptr; + auto yieldOp = dyn_cast(body.getTerminator()); + if (!yieldOp || yieldOp.getValues().size() != 1) + return nullptr; + + auto castOp = yieldOp.getValues().front().getDefiningOp(); + if (!castOp || castOp->getBlock() != &body) + return nullptr; + + // The cast must consume `bb0` arg 0. 
+ auto inArg = dyn_cast(castOp.getIn()); + if (!inArg || inArg.getOwner() != &body || inArg.getArgNumber() != 0) + return nullptr; + + unsigned outBW = castOp.getResult().getType().getIntOrFloatBitWidth(); + return [outBW](const APIntOrFloatArray &inputs) { + return APIntOrFloat{inputs.apInts.front().sext(outBW), std::nullopt}; + }; + } +}; } // namespace void mlir::linalg::populateConstantFoldLinalgOperations( RewritePatternSet &patterns, const ControlFusionFn &controlFn) { MLIRContext *context = patterns.getContext(); - patterns.insert(context, controlFn); + patterns.insert(context, controlFn); } diff --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir index 3929c26a3382f..ddb605a09623d 100644 --- a/mlir/test/Dialect/Linalg/constant-fold.mlir +++ b/mlir/test/Dialect/Linalg/constant-fold.mlir @@ -145,4 +145,72 @@ func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf3 // ----- +// CHECK-LABEL: @cast_fold_extsi_i32_to_i64 +func.func @cast_fold_extsi_i32_to_i64(%init: tensor<4xi64>) -> tensor<4xi64> { + %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: %[[CST:.+]] = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) { + ^bb0(%arg1: i32, %arg2: i64): + %2 = arith.extsi %arg1 : i32 to i64 + linalg.yield %2 : i64 + } -> tensor<4xi64> + // CHECK: return %[[CST]] + return %1 : tensor<4xi64> +} + +// ----- + +// CHECK-LABEL: @cast_fold_extsi_negative +func.func @cast_fold_extsi_negative(%init: tensor<2xi64>) -> tensor<2xi64> { + %input = arith.constant dense<[-1, -2]> : tensor<2xi32> + // CHECK: %[[CST:.+]] = arith.constant dense<[-1, -2]> : tensor<2xi64> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%input 
: tensor<2xi32>) outs(%init : tensor<2xi64>) { + ^bb0(%arg1: i32, %arg2: i64): + %2 = arith.extsi %arg1 : i32 to i64 + linalg.yield %2 : i64 + } -> tensor<2xi64> + // CHECK: return %[[CST]] + return %1 : tensor<2xi64> +} + +// ----- + +// CHECK-LABEL: @cast_nofold_non_cst_input +func.func @cast_nofold_non_cst_input(%input: tensor<4xi32>, %init: tensor<4xi64>) -> tensor<4xi64> { + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) { + ^bb0(%arg1: i32, %arg2: i64): + %2 = arith.extsi %arg1 : i32 to i64 + linalg.yield %2 : i64 + } -> tensor<4xi64> + return %1 : tensor<4xi64> +} + +// ----- + +// CHECK-LABEL: @cast_nofold_multi_ops_in_region +func.func @cast_nofold_multi_ops_in_region(%init: tensor<4xi64>) -> tensor<4xi64> { + %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %two = arith.constant 2 : i64 + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) { + ^bb0(%arg1: i32, %arg2: i64): + %2 = arith.extsi %arg1 : i32 to i64 + %3 = arith.muli %2, %two : i64 + linalg.yield %3 : i64 + } -> tensor<4xi64> + return %1 : tensor<4xi64> +}