Skip to content

Commit

Permalink
[mlir][MemRef] Add patterns to extract address computations
Browse files Browse the repository at this point in the history
This patch adds patterns to rewrite memory accesses such that the resulting
accesses are only using a base pointer.
E.g.,
```mlir
memref.load %base[%off0, ...]
```

Will be rewritten in:
```mlir
%new_base = memref.subview %base[%off0,...][1,...][1,...]
memref.load %new_base[%c0,...]
```

The idea behind these patterns is to offer a way to more gradually lower
address computations.

These patterns are the exact opposite of FoldMemRefAliasOps.
I've implemented the support of only five operations in this patch:
- memref.load
- memref.store
- nvgpu.ldmatrix
- vector.transfer_read
- vector.transfer_write

Going forward we may want to provide an interface for these rewritings (and
the ones in FoldMemRefAliasOps.)
One step at a time!

Differential Revision: https://reviews.llvm.org/D146724
  • Loading branch information
qcolombet committed Mar 28, 2023
1 parent 86ce609 commit 54cda2e
Show file tree
Hide file tree
Showing 8 changed files with 830 additions and 0 deletions.
Expand Up @@ -49,4 +49,48 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
"$target attr-dict `:` functional-type(operands, results)";
}

def MemRefExtractAddressComputationsOp :
Op<Transform_Dialect, "memref.extract_address_computations",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
let summary = "Extract address computations from memory accesses";
let description = [{
Transformation that extracts address computations from instructions
with memory accesses such that these memory accesses use only a base
pointer.

For instance,
```mlir
memref.load %base[%off0, ...]
```

Will be rewritten in:
```mlir
%new_base = memref.subview %base[%off0,...][1,...][1,...]
memref.load %new_base[%c0,...]
```

Note: The current implementation requires that the input operation
is "isolated from above".

#### Return modes

This operation produces `definiteFailure` if the extraction fails for any
reason.
The operation always returns the handle to the target op that is expected
to be isolated from above.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$transformed);

let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &transformResults,
::mlir::transform::TransformState &state);
}];
}
#endif // MEMREF_TRANSFORM_OPS
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -0,0 +1,40 @@
//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This header declares functions that assit transformations in the MemRef
/// dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H

namespace mlir {
class RewritePatternSet;

namespace memref {
/// Appends patterns for extracting address computations from the instructions
/// with memory accesses such that these memory accesses use only a base
/// pointer.
///
/// For instance,
/// ```mlir
/// memref.load %base[%off0, ...]
/// ```
///
/// Will be rewritten in:
/// ```mlir
/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
/// memref.load %new_base[%c0,...]
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

} // namespace memref
} // namespace mlir

#endif
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
Expand Up @@ -15,5 +15,7 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRNVGPUDialect
MLIRTransformDialect
MLIRVectorDialect
)
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Expand Up @@ -11,10 +11,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
Expand Down Expand Up @@ -68,6 +72,31 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MemRefExtractAddressComputationsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MemRefExtractAddressComputationsOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
auto diag = this->emitOpError("requires isolated-from-above targets");
diag.attachNote(target->getLoc()) << "non-isolated target";
return DiagnosedSilenceableFailure::definiteFailure();
}

MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
memref::populateExtractAddressComputationsPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
return emitDefaultDefiniteFailure(target);

results.push_back(target);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand All @@ -83,6 +112,9 @@ class MemRefTransformDialectExtension
declareDependentDialect<pdl::PDLDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<memref::MemRefDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
declareGeneratedDialect<vector::VectorDialect>();

registerTransformOps<
#define GET_OP_LIST
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
ExpandOps.cpp
ExpandStridedMetadata.cpp
EmulateWideInt.cpp
ExtractAddressComputations.cpp
FoldMemRefAliasOps.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
Expand All @@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRNVGPUDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
Expand Down

0 comments on commit 54cda2e

Please sign in to comment.