Skip to content

Commit

Permalink
[mlir][Linalg] Add a transform.structured.lower_pack op
Browse files Browse the repository at this point in the history
This revision introduces `transform.structured.lower_pack` which allows
rewriting a `tensor.pack` to `tensor.pad` + `tensor.expand_shape` + `linalg.transpose`.

The implementation is currently limited to static pack ops that do not have outer_dims permutations.

Differential Revision: https://reviews.llvm.org/D142881
  • Loading branch information
nicolasvasilache committed Jan 31, 2023
1 parent 1625530 commit 4ca52c6
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 89 deletions.
Expand Up @@ -17,11 +17,16 @@
namespace mlir {
class TilingInterface;
class RewriterBase;

namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg

namespace tensor {
class PackOp;
} // namespace tensor

namespace transform {
class TransformHandleTypeInterface;
// Types needed for builders.
Expand Down
Expand Up @@ -215,6 +215,43 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}

//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
TransformOpInterface]> {
let description = [{
Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose.

#### Return modes

This operation ignores non-pack ops and drops them in the return.
This operation produces a silenceableFailure if the padding fails for any
reason.
If all the operations referred to by the `target` are rewritten, the
transform succeeds.
Return handles to the newly produced pad, expand_shape and transpose ops.
}];

let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::tensor::PackOp target,
::mlir::transform::ApplyToEachResultList &transformResults,
::mlir::transform::TransformState &state);
}];
}

//===----------------------------------------------------------------------===//
// MatchOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 8 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -1142,12 +1142,17 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);

/// Struct to hold the result of a `pack` call.
struct PackResult {
SmallVector<tensor::PackOp> packOps;
linalg::LinalgOp packedLinalgOp;
SmallVector<tensor::UnPackOp> unPackOps;
};
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
linalg::LinalgOp linalgOp,
ArrayRef<OpFoldResult> packedSizes);
FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ArrayRef<OpFoldResult> packedSizes);

/// Struct to hold the result of a `packTranspose` call.
struct PackTransposeResult {
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Expand Up @@ -115,7 +115,7 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold);

/// Returns a GenericOp that tansposes `inputTensor` into `outputTensor` using
/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
/// `transposeVector` to permute the `inputTensor` dimensions.
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
Value outputTensor,
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Expand Up @@ -1113,7 +1113,13 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
}]>
];

let extraClassDeclaration = commonExtraClassDeclaration;
let extraClassDeclaration = commonExtraClassDeclaration # [{
static RankedTensorType
inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
static RankedTensorType
inferCollapsedType(RankedTensorType type,
SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
}

Expand Down

0 comments on commit 4ca52c6

Please sign in to comment.