diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 7ac759f635f87..0477815f329bf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2110,7 +2110,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv) { - if (isa(op.getOperation())) + if (isaConvolutionOpInterface(op)) return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv); if (hasReductionIterator(op)) diff --git a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir index 271d6169609e9..59d51e432b743 100644 --- a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir @@ -112,6 +112,36 @@ module attributes {transform.with_named_sequence} { // ----- +/// Dynamic spatial dims for non-depthwise conv is not supported. This is already +/// being tested for named ops and the following lit test checks that the same is +/// applicable to linalg.generic conv ops as well. +func.func @generic_conv1d_ncw_fcw_dyn_spatial(%input: tensor<1x8x?xf16>, %filter: tensor<4x8x1xf16>, %output: tensor<1x4x?xf16>) -> tensor<1x4x?xf16> { + // expected-error @+1 {{Attempted to vectorize, but failed}} + %0 = linalg.generic + {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%input, %filter : tensor<1x8x?xf16>, tensor<4x8x1xf16>) + outs(%output : tensor<1x4x?xf16>) { + ^bb0(%in: f16, %filt: f16, %out: f16): + %mul = arith.mulf %in, %filt : f16 + %add = arith.addf %out, %mul : f16 + linalg.yield %add : f16 + } -> tensor<1x4x?xf16> + return %0 : tensor<1x4x?xf16> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter: tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) { // expected-error @+1 {{Attempted to vectorize, but failed}} linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>