diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index ebf20acab4171..794fe97ead57e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator { maybeKind = getCombinerOpKind(reduceOp); if (!maybeKind || *maybeKind != vector::CombiningKind::ADD) return; - maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front())); - if (!maybeKind || *maybeKind != vector::CombiningKind::MUL) + // Check for single `mul` predecessor. The `mul` operands must be block + // arguments or extension of block arguments. + Operation *mulOp = nullptr; + for (Value operand : reduceOp->getOperands()) { + if (operand.isa()) + continue; + if (mulOp) + return; + mulOp = operand.getDefiningOp(); + if (!mulOp || !isa(mulOp)) + return; + } + if (!mulOp) return; - + for (Value operand : mulOp->getOperands()) { + if (Operation *def = operand.getDefiningOp()) { + if (!isa(def)) + return; + operand = def->getOperand(0); + } + if (!operand.isa()) + return; + } // The op is now known to be valid. valid = true; } diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir index a4eb9d26e9c8e..7e1f39cbda3e9 100644 --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt // Write the result back in one shot. // CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + + +// ----- + +func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) { + linalg.conv_1d_nwc_wcf + {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} + ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>) + outs(%output : memref<1x2x2xf32>) + return +} + +// CHECK: func @conv_1d_nwc_wcf_mixed_type_memref +// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. +// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16> +// CHECK: %[[CONT:.*]] = vector.contract +// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32> +// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]