diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index caf9cdb3a3eb4..91165ddeb8887 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -281,6 +281,12 @@ bool mlir::linalg::detail::isContractionBody( Value yielded = getSourceSkipUnary(terminator->getOperand(0)); Operation *reductionOp = yielded.getDefiningOp(); + + if (!reductionOp) { + errs << "expected reduction op in body"; + return false; + } + if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { errs << "expected reduction op to be binary"; return false; diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir index 542a7ed4a198b..3c5649fb63f62 100644 --- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir +++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir @@ -14,3 +14,23 @@ func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf } -> tensor<8x7x9xf32> return %0 : tensor<8x7x9xf32> } + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +// This tests checks that the pass does not crash when trying to specialize a +// contraction-like generic op with no reduction operation in its body. +// CHECK-LABEL: @test_fake_contraction +// CHECK: linalg.generic +func.func @test_fake_contraction(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> { + %0 = tensor.empty() : tensor<4x4xi32> + %1 = linalg.generic + {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg0 : tensor<4x4xi32>, tensor<4x4xi32>) outs(%0 : tensor<4x4xi32>) { + ^bb0(%in0: i32, %in1: i32, %out: i32): + linalg.yield %out : i32 + } -> tensor<4x4xi32> + return %1 : tensor<4x4xi32> +}