Skip to content

Commit

Permalink
[mlir][linalg] Relax convolution vectorization to support mixed types
Browse files Browse the repository at this point in the history
Support the case where convolution does float extension of the inputs.

Differential Revision: https://reviews.llvm.org/D127925
  • Loading branch information
ThomasRaoux committed Jun 16, 2022
1 parent 6ed81ec commit 046ebeb
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
25 changes: 22 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Expand Up @@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
return;
maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
// Check for single `mul` predecessor. The `mul` operands must be block
// arguments or extension of block arguments.
Operation *mulOp = nullptr;
for (Value operand : reduceOp->getOperands()) {
if (operand.isa<BlockArgument>())
continue;
if (mulOp)
return;
mulOp = operand.getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return;
}
if (!mulOp)
return;

for (Value operand : mulOp->getOperands()) {
if (Operation *def = operand.getDefiningOp()) {
if (!isa<arith::ExtFOp>(def))
return;
operand = def->getOperand(0);
}
if (!operand.isa<BlockArgument>())
return;
}
// The op is now known to be valid.
valid = true;
}
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Expand Up @@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt

// Write the result back in one shot.
// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


// -----

func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
linalg.conv_1d_nwc_wcf
{dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
outs(%output : memref<1x2x2xf32>)
return
}

// CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)

// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
// CHECK: %[[CONT:.*]] = vector.contract
// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

0 comments on commit 046ebeb

Please sign in to comment.