Skip to content

Commit

Permalink
Revert "[mlir][Vector] Re-define masking semantics in vector.transfer…
Browse files Browse the repository at this point in the history
… ops"

This reverts commit 6c59c5c.
  • Loading branch information
dcaballe committed Nov 18, 2022
1 parent f51c915 commit 847b5f8
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 99 deletions.
14 changes: 3 additions & 11 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1203,11 +1203,9 @@ def Vector_TransferReadOp :
provided to specify a fallback value in the case of out-of-bounds accesses
and/or masking.

An optional SSA value `mask` may be specified to mask out elements read from
the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
matches how elements are read from the MemRef/Tensor, *before* any
permutation or broadcasting. Elements whose corresponding mask element is
`0` are masked out and replaced with `padding`.
An optional SSA value `mask` of the same shape as the vector type may be
specified to mask out elements. Such elements will be replaces with
`padding`. Elements whose corresponding mask element is `0` are masked out.

An optional boolean array attribute `in_bounds` specifies for every vector
dimension if the transfer is guaranteed to be within the source bounds.
Expand Down Expand Up @@ -1417,12 +1415,6 @@ def Vector_TransferWriteOp :

The size of the slice is specified by the size of the vector.

An optional SSA value `mask` may be specified to mask out elements written
to the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
matches how elements are written into the MemRef/Tensor, *after* applying
any permutation. Elements whose corresponding mask element is `0` are
masked out.

An optional SSA value `mask` of the same shape as the vector type may be
specified to mask out elements. Elements whose corresponding mask element
is `0` are masked out.
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Interfaces/VectorInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace vector {
namespace detail {

/// Given the vector type and the permutation map of a vector transfer op,
/// compute the expected mask type.
VectorType transferMaskType(VectorType vecType, AffineMap map);

} // namespace detail
} // namespace vector
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/VectorInterfaces.h.inc"

Expand Down
19 changes: 5 additions & 14 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,25 +169,16 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
}]
>,
InterfaceMethod<
/*desc=*/"Return the mask operand if the op has a mask. Otherwise, "
"return a empty value.",
/*retTy=*/"Value",
/*methodName=*/"getMask",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getMask();
}]
>,
InterfaceMethod<
/*desc=*/"Return the mask type if the op has a mask. Otherwise, return "
"an empty VectorType.",
/*desc=*/"Return the mask type if the op has a mask.",
/*retTy=*/"::mlir::VectorType",
/*methodName=*/"getMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getMask() ? $_op.getMask().getType() : ::mlir::VectorType();
return $_op.getMask()
? ::mlir::vector::detail::transferMaskType(
$_op.getVectorType(), $_op.getPermutationMap())
: ::mlir::VectorType();
}]
>,
InterfaceMethod<
Expand Down
85 changes: 25 additions & 60 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2873,8 +2873,7 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
VectorType vectorType, VectorType maskType,
VectorType inferredMaskType, AffineMap permutationMap,
ArrayAttr inBounds) {
AffineMap permutationMap, ArrayAttr inBounds) {
if (op->hasAttr("masked")) {
return op->emitOpError("masked attribute has been removed. "
"Use in_bounds instead.");
Expand Down Expand Up @@ -2927,6 +2926,13 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
if (permutationMap.getNumResults() != vectorType.getRank())
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");

VectorType expectedMaskType =
vector::detail::transferMaskType(vectorType, permutationMap);
if (maskType && expectedMaskType != maskType)
return op->emitOpError("expects mask type consistent with permutation "
"map: ")
<< maskType;
}

if (permutationMap.getNumSymbols() != 0)
Expand All @@ -2936,11 +2942,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the source type");

if (maskType && maskType != inferredMaskType)
return op->emitOpError("inferred mask type (")
<< inferredMaskType << ") and mask operand type (" << maskType
<< ") don't match";

if (inBounds) {
if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
return op->emitOpError("expects the optional in_bounds attr of same rank "
Expand Down Expand Up @@ -2983,19 +2984,6 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}

/// Infers the mask type for a transfer read given its vector type and
/// permutation map. The mask in a transfer read operation applies to the
/// tensor/buffer reading part of it and its type should match the shape read
/// *before* any permutation or broadcasting.
static VectorType inferTransferReadMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
return VectorType::get(maskShape, i1Type);
}

ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
SMLoc typesLoc;
Expand Down Expand Up @@ -3026,14 +3014,13 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
VectorType vectorType = types[1].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
Attribute permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
Attribute mapAttr = result.attributes.get(permutationAttrName);
if (!mapAttr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
// Update `mapAttr` that is used later to determine mask type.
mapAttr = AffineMapAttr::get(permMap);
result.attributes.set(permutationAttrName, mapAttr);
}
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
Expand All @@ -3044,9 +3031,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (shapedType.getElementType().dyn_cast<VectorType>())
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
// Instead of adding the mask type as an op type, compute it based on the
// vector type and the permutation map (to keep the type signature small).
auto maskType = inferTransferReadMaskType(vectorType, permMap);
auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
Expand All @@ -3064,17 +3052,13 @@ LogicalResult TransferReadOp::verify() {
VectorType maskType = getMaskType();
auto paddingType = getPadding().getType();
auto permutationMap = getPermutationMap();
VectorType inferredMaskType =
maskType ? inferTransferReadMaskType(vectorType, permutationMap)
: VectorType();
auto sourceElementType = shapedType.getElementType();

if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
return emitOpError("requires ") << shapedType.getRank() << " indices";

if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
shapedType, vectorType, maskType,
inferredMaskType, permutationMap,
shapedType, vectorType, maskType, permutationMap,
getInBounds() ? *getInBounds() : ArrayAttr())))
return failure();

Expand Down Expand Up @@ -3438,18 +3422,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vector, dest, indices, permutationMap, inBounds);
}

/// Infers the mask type for a transfer write given its vector type and
/// permutation map. The mask in a transfer read operation applies to the
/// tensor/buffer writing part of it and its type should match the shape written
/// *after* any permutation.
static VectorType inferTransferWriteMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
SmallVector<int64_t, 8> maskShape =
compressUnusedDims(permMap).compose(vecType.getShape());
return VectorType::get(maskShape, i1Type);
}

ParseResult TransferWriteOp::parse(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
Expand Down Expand Up @@ -3477,14 +3449,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
ShapedType shapedType = types[1].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
auto permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
Expand All @@ -3494,7 +3463,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (shapedType.getElementType().dyn_cast<VectorType>())
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
auto maskType = inferTransferWriteMaskType(vectorType, permMap);
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
Expand All @@ -3520,9 +3489,6 @@ LogicalResult TransferWriteOp::verify() {
VectorType vectorType = getVectorType();
VectorType maskType = getMaskType();
auto permutationMap = getPermutationMap();
VectorType inferredMaskType =
maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
: VectorType();

if (llvm::size(getIndices()) != shapedType.getRank())
return emitOpError("requires ") << shapedType.getRank() << " indices";
Expand All @@ -3533,8 +3499,7 @@ LogicalResult TransferWriteOp::verify() {
return emitOpError("should not have broadcast dimensions");

if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
shapedType, vectorType, maskType,
inferredMaskType, permutationMap,
shapedType, vectorType, maskType, permutationMap,
getInBounds() ? *getInBounds() : ArrayAttr())))
return failure();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ struct TransferReadPermutationLowering
newVectorShape[pos.value()] = originalShape[pos.index()];
}

// Transpose mask operand.
Value newMask;
if (op.getMask()) {
// Remove unused dims from the permutation map. E.g.:
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
// comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
auto comp = compressUnusedDims(map);
// Get positions of remaining result dims.
// E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
// maskTransposeIndices = [ 2, 1, 0]
SmallVector<int64_t> maskTransposeIndices;
for (unsigned i = 0; i < comp.getNumResults(); ++i) {
if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
maskTransposeIndices.push_back(expr.getPosition());
}

newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
maskTransposeIndices);
}

// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
op.getInBounds() ? transposeInBoundsAttr(
Expand All @@ -94,8 +114,7 @@ struct TransferReadPermutationLowering
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);

// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
Expand Down Expand Up @@ -149,6 +168,11 @@ struct TransferWritePermutationLowering
return expr.dyn_cast<AffineDimExpr>().getPosition();
});

// Transpose mask operand.
Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
op.getLoc(), op.getMask(), indices)
: Value();

// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
op.getInBounds() ? transposeInBoundsAttr(
Expand All @@ -162,7 +186,7 @@ struct TransferWritePermutationLowering
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
op.getMask(), newInBoundsAttr);
newMask, newInBoundsAttr);

return success();
}
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Interfaces/VectorInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@

using namespace mlir;

VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
AffineMap map) {
auto i1Type = IntegerType::get(map.getContext(), 1);
SmallVector<int64_t, 8> shape;
for (int64_t i = 0; i < vecType.getRank(); ++i) {
// Only result dims have a corresponding dim in the mask.
if (map.getResult(i).template isa<AffineDimExpr>()) {
shape.push_back(vecType.getDimSize(i));
}
}
return VectorType::get(shape, i1Type);
}

//===----------------------------------------------------------------------===//
// VectorUnroll Interfaces
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
// CHECK-DAG: %[[PADDING:.*]] = arith.constant dense<-4.200000e+01> : vector<9xf32>
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<4x9xi1>
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<9x4xi1>
// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
// CHECK: memref.store %[[MASK]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
// CHECK: scf.for {{.*}} {
// CHECK: scf.if {{.*}} {
Expand All @@ -24,10 +25,11 @@
func.func @transfer_read_2d_mask_transposed(
%A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
%fm42 = arith.constant -42.0: f32
%mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
[0, 0, 1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 1],
[0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
%mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
[1, 1, 1, 1], [0, 1, 1, 0],
[1, 1, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1], [0, 0, 0, 0],
[1, 1, 1, 1]]> : vector<9x4xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
memref<?x?xf32>, vector<9x4xf32>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%v0 = vector.splat %c0 : vector<4x3xi32>
%vi0 = vector.splat %i0 : vector<4x3xindex>
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
%m2 = vector.splat %i1 : vector<4x5xi1>
%m2 = vector.splat %i1 : vector<5x4xi1>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
Expand Down
Loading

0 comments on commit 847b5f8

Please sign in to comment.