
Revert "Don't attempt to create vectors with complex element types."
This reverts commit 91181db.
jreiffers committed Jan 12, 2023
1 parent d987fe6 commit 0bd83f9
Showing 2 changed files with 30 additions and 71 deletions.
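In brief: the reverted commit had routed vector-type creation through FailureOr-returning getVectorType helpers that bail out on element types vectors cannot hold (such as complex); this revert restores direct VectorType::get calls. Below is a minimal sketch of the difference, assuming MLIR's C++ API; the standalone main and the helper name checkedVectorType are illustrative, not part of the patch:

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Mirrors the helper the reverted commit had introduced: bail out with
// failure() instead of constructing a vector of an unsupported element type.
static FailureOr<VectorType> checkedVectorType(ArrayRef<int64_t> shape,
                                               Type elementType) {
  if (!VectorType::isValidElementType(elementType))
    return failure();
  return VectorType::get(shape, elementType);
}

int main() {
  MLIRContext ctx;
  Type f32 = Float32Type::get(&ctx);
  Type cf32 = ComplexType::get(f32); // complex<f32>: not a valid vector element

  // Guarded path: complex element types are rejected gracefully.
  assert(succeeded(checkedVectorType({4, 8}, f32)));
  assert(failed(checkedVectorType({4, 8}, cf32)));

  // Direct path restored by this revert: the caller must pass a valid
  // element type (an invalid one would fail the type's verifier, which
  // asserts in debug builds).
  VectorType vec = VectorType::get({4, 8}, f32);
  assert(vec.getRank() == 2);
  return 0;
}

The practical consequence shows up in the second file below: the complex-typed pooling regression test is deleted along with the guard.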
86 changes: 30 additions & 56 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,20 +44,6 @@ using namespace mlir::linalg;
 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                    LinalgOp convOp);
 
-/// Return the given vector type if `elementType` is valid.
-static FailureOr<VectorType> getVectorType(ArrayRef<int64_t> shape,
-                                           Type elementType) {
-  if (!VectorType::isValidElementType(elementType)) {
-    return failure();
-  }
-  return VectorType::get(shape, elementType);
-}
-
-/// Cast the given type to a vector type if its element type is valid.
-static FailureOr<VectorType> getVectorType(ShapedType type) {
-  return getVectorType(type.getShape(), type.getElementType());
-}
-
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -445,7 +431,8 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
+                                             value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -898,20 +885,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     }
 
     auto readType =
-        getVectorType(readVecShape, getElementTypeOrSelf(opOperand->get()));
-    if (!succeeded(readType))
-      return failure();
+        VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, *readType, opOperand->get(), indices, readMap);
+        loc, readType, opOperand->get(), indices, readMap);
     read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
     Value readValue = read->getResult(0);
 
     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
     // will be in-bounds.
     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
-      SmallVector<bool> inBounds(readType->getRank(), true);
+      SmallVector<bool> inBounds(readType.getRank(), true);
       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
@@ -1150,22 +1135,21 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
 
-  auto readType = getVectorType(srcType);
-  auto writeType = getVectorType(dstType);
-  if (!(succeeded(readType) && succeeded(writeType)))
-    return failure();
+  auto readType =
+      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
+  auto writeType =
+      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(srcType.getRank(), zero);
 
   Value readValue = rewriter.create<vector::TransferReadOp>(
-      loc, *readType, copyOp.getSource(), indices,
+      loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (readValue.getType().cast<VectorType>().getRank() == 0) {
     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
-    readValue =
-        rewriter.create<vector::BroadcastOp>(loc, *writeType, readValue);
+    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
       loc, readValue, copyOp.getTarget(), indices,
@@ -1216,10 +1200,6 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
 
-    // Complex is not a valid vector element type.
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
     // Copy cannot be vectorized if pad value is non-constant and source shape
     // is dynamic. In case of a dynamic source shape, padding must be appended
     // by TransferReadOp, but TransferReadOp supports only constant padding.
@@ -1568,17 +1548,15 @@ struct PadOpVectorizationWithInsertSlicePattern
     if (insertOp.getDest() == padOp.getResult())
       return failure();
 
-    auto vecType = getVectorType(padOp.getType());
-    if (!succeeded(vecType))
-      return failure();
-    unsigned vecRank = vecType->getRank();
+    auto vecType = VectorType::get(padOp.getType().getShape(),
+                                   padOp.getType().getElementType());
+    unsigned vecRank = vecType.getRank();
     unsigned tensorRank = insertOp.getType().getRank();
 
     // Check if sizes match: Insert the entire tensor into most minor dims.
     // (No permutations allowed.)
     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
-    expectedSizes.append(vecType->getShape().begin(),
-                         vecType->getShape().end());
+    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
     if (!llvm::all_of(
             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
@@ -1594,7 +1572,7 @@ struct PadOpVectorizationWithInsertSlicePattern
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), *vecType, padOp.getSource(), readIndices, padValue);
+        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
 
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because a InsertSliceOp's
@@ -1902,7 +1880,7 @@ struct Conv1DGenerator
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
-      if (rhsRank != 2 && rhsRank != 3)
+      if (rhsRank != 2 && rhsRank!= 3)
        return;
      break;
    case Pool:
@@ -2004,24 +1982,22 @@ struct Conv1DGenerator
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = getVectorType(lhsShape, lhsEltType);
-    auto rhsType = getVectorType(rhsShape, rhsEltType);
-    auto resType = getVectorType(resShape, resEltType);
-    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
-      return failure();
+    auto lhsType = VectorType::get(lhsShape, lhsEltType);
+    auto rhsType = VectorType::get(rhsShape, rhsEltType);
+    auto resType = VectorType::get(resShape, resEltType);
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == Conv)
       rhs = rewriter.create<vector::TransferReadOp>(
-          loc, *rhsType, rhsShaped, ValueRange{zero, zero, zero});
+          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, *resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
     // {n,w,f}. To reuse the base pattern vectorization case, we do pre
@@ -2188,28 +2164,26 @@ struct Conv1DGenerator
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = getVectorType(
+    VectorType lhsType = VectorType::get(
         {nSize,
         // iw = ow * sw + kw * dw - 1
         // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType);
-    auto rhsType = getVectorType({kwSize, cSize}, rhsEltType);
-    auto resType = getVectorType({nSize, wSize, cSize}, resEltType);
-    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
-      return failure();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, *rhsType, rhsShaped, ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, *resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
15 changes: 0 additions & 15 deletions mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -850,18 +850,3 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
 // CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
 // CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
 // CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
-
-
-// -----
-
-func.func @pooling_ncw_sum_memref_complex(%input: memref<4x2x5xcomplex<f32>>,
-    %filter: memref<2xcomplex<f32>>, %output: memref<4x2x3xcomplex<f32>>) {
-  linalg.pooling_ncw_sum
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<4x2x5xcomplex<f32>>, memref<2xcomplex<f32>>)
-    outs(%output : memref<4x2x3xcomplex<f32>>)
-  return
-}
-
-// Regression test: just check that this lowers successfully
-// CHECK-LABEL: @pooling_ncw_sum_memref_complex