diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 1366e920039bf4..a7855e6327b20e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -300,7 +300,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, Example: ```mlir - %1 = linalg.transpose %0 (i, j) -> (j, i) : memref + %1 = linalg.transpose %0 (i, j) -> (j, i) : memref to memref ``` }]; @@ -308,13 +308,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, "OpBuilder &b, OperationState &result, Value view, " "AffineMapAttr permutation, ArrayRef attrs = {}">]; - let verifier = [{ - if (!permutation().isPermutation()) - return emitOpError("expected a permutation map"); - if (permutation().getNumDims() != getShapedType().getRank()) - return emitOpError("expected a permutation map of same rank as the view"); - return success(); - }]; + let verifier = [{ return ::verify(*this); }]; let extraClassDeclaration = [{ static StringRef getPermutationAttrName() { return "permutation"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index fcead984dfe55c..77eb6448947791 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -846,13 +846,9 @@ Value SliceOp::getViewSource() { return view(); } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// -void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result, - Value view, AffineMapAttr permutation, - ArrayRef attrs) { - auto permutationMap = permutation.getValue(); - assert(permutationMap); - auto memRefType = view.getType().cast(); +static MemRefType inferTransposeResultType(MemRefType memRefType, + AffineMap permutationMap) { auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); // Compute permuted sizes. @@ -867,11 +863,21 @@ void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result, auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && strides.size() == static_cast(rank)); (void)res; - auto map = makeStridedLinearLayoutMap(strides, offset, b.getContext()); + auto map = + makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); map = permutationMap ? map.compose(permutationMap) : map; + return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); +} + +void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result, + Value view, AffineMapAttr permutation, + ArrayRef attrs) { + auto permutationMap = permutation.getValue(); + assert(permutationMap); + + auto memRefType = view.getType().cast(); // Compute result type. - MemRefType resultType = - MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); + MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); build(b, result, resultType, view, attrs); result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); @@ -881,19 +887,20 @@ static void print(OpAsmPrinter &p, TransposeOp op) { p << op.getOperationName() << " " << op.view() << " " << op.permutation(); p.printOptionalAttrDict(op.getAttrs(), {TransposeOp::getPermutationAttrName()}); - p << " : " << op.view().getType(); + p << " : " << op.view().getType() << " to " << op.getType(); } static ParseResult parseTransposeOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType view; AffineMap permutation; - MemRefType type; + MemRefType srcType, dstType; if (parser.parseOperand(view) || parser.parseAffineMap(permutation) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperand(view, type, result.operands) || - parser.addTypeToList(type, result.types)) + parser.parseColonType(srcType) || + parser.resolveOperand(view, srcType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)) return failure(); result.addAttribute(TransposeOp::getPermutationAttrName(), @@ -901,6 +908,21 @@ static ParseResult parseTransposeOp(OpAsmParser &parser, return success(); } +static LogicalResult verify(TransposeOp op) { + if (!op.permutation().isPermutation()) + return op.emitOpError("expected a permutation map"); + if (op.permutation().getNumDims() != op.getShapedType().getRank()) + return op.emitOpError( + "expected a permutation map of same rank as the view"); + + auto srcType = op.view().getType().cast(); + auto dstType = op.getType().cast(); + if (dstType != inferTransposeResultType(srcType, op.permutation())) + return op.emitOpError("output type ") + << dstType << " does not match transposed input type " << srcType; + return success(); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index ca59ecd387ec3c..c631c47099b083 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -35,14 +35,21 @@ func @store_number_of_indices(%v : memref) { func @transpose_not_permutation(%v : memref(off + M * i + j)>>) { // expected-error @+1 {{expected a permutation map}} - linalg.transpose %v (i, j) -> (i, i) : memref(off + M * i + j)>> + linalg.transpose %v (i, j) -> (i, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> } // ----- func @transpose_bad_rank(%v : memref(off + M * i + j)>>) { // expected-error @+1 {{expected a permutation map of same rank as the view}} - linalg.transpose %v (i) -> (i) : memref(off + M * i + j)>> + linalg.transpose %v (i) -> (i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> +} + +// ----- + +func @transpose_wrong_type(%v : memref(off + M * i + j)>>) { + // expected-error @+1 {{output type 'memref (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref (d0 * s1 + s0 + d1)>>'}} + linalg.transpose %v (i, j) -> (j, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> } // ----- diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir index 02693e5d1be464..c8031824d63073 100644 --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -70,7 +70,7 @@ func @slice_with_range_and_index(%arg0: memref, ptr, i64, array<1 x i64>, array<1 x i64>)> func @transpose(%arg0: memref) { - %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref + %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref to memref (d2 * s1 + s0 + d0 * s2 + d1)>> return } // CHECK-LABEL: func @transpose diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 26966432469726..404c978fa61bb0 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -123,14 +123,15 @@ func @fill_view(%arg0: memref, %arg1: f32) { // ----- // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> +// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)> func @transpose(%arg0: memref) { - %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref + %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref to memref (d2 * s1 + s0 + d1 * s2 + d0)>> return } // CHECK-LABEL: func @transpose // CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : -// CHECK-SAME: memref +// CHECK-SAME: memref to memref // -----