Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Apply vector patterns to drop the inner most unit dims from
vector.transfer_read and vector.transfer_write Ops by taking a subview (via
memref.subview) of the original source/destination MemRef. Since it
requires the input/ouptu to be MemRefs, this Op is only helpful
past-bufferization.
}];

let assemblyFormat = "attr-dict";
}

def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.transfer_permutation_patterns",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
}

void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns);
}

void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {

%func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
} : !transform.op<"func.func">

transform.yield
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
// RUN: mlir-opt -split-input-file \
// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s

//-----------------------------------------------------------------------------
// 1. vector.transfer_read
Expand Down
32 changes: 0 additions & 32 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,36 +344,6 @@ struct TestVectorTransferOpt
}
};

struct TestVectorTransferCollapseInnerMostContiguousDims
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferCollapseInnerMostContiguousDims)

TestVectorTransferCollapseInnerMostContiguousDims() = default;
TestVectorTransferCollapseInnerMostContiguousDims(
const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
}

StringRef getArgument() const final {
return "test-vector-transfer-collapse-inner-most-dims";
}

StringRef getDescription() const final {
return "Test lowering patterns that reduces the rank of the vector "
"transfer memory and vector operands.";
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateDropInnerMostUnitDimsXferOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestVectorSinkPatterns
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
Expand Down Expand Up @@ -1079,8 +1049,6 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorTransferOpt>();

PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

PassRegistration<TestVectorSinkPatterns>();

PassRegistration<TestVectorReduceToContractPatternsPatterns>();
Expand Down
Loading