Skip to content

Commit

Permalink
Remove PromoteTensorLoads pass, convert ExtractOp in TensorToFlow. (#6852)

Browse files Browse the repository at this point in the history

Fixes #6756 (the [`tosa if.mlir`](https://github.com/google/iree/blob/main/iree/test/e2e/tosa_ops/if.mlir) file compiles successfully using `-iree-flow-enable-linalg-detensorize` with this change)

The `PromoteTensorLoads` pass was converting `i1` loads to `i8` loads using `ZeroExtendIOp` and `TruncateIOp`. That was producing weird cycles during compilation when detensoring was applied, and `flow` ops should be fine with i1 types. We still need to handle `i1` types when going to the HAL (since storage is incompatible) on the outside (external interface) and inside (codegen).
  • Loading branch information
ScottTodd authored Aug 25, 2021
1 parent 7fa8c20 commit d90f0fc
Show file tree
Hide file tree
Showing 17 changed files with 40 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ struct ConvertTensorExtractSlicePattern
}
};

// Converts a tensor.extract of a single element into a flow.tensor.load.
// Only applies outside of dispatch regions: inside a
// flow.dispatch.workgroups op the extract is left in place (handled later
// during codegen rather than as a host-side tensor load).
struct ConvertTensorExtractPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    // Skip ops nested under a dispatch region; those are not ours to convert.
    if (!op->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
      rewriter.replaceOpWithNewOp<IREE::Flow::TensorLoadOp>(
          op, op.getResult().getType(), op.tensor(), op.indices());
      return success();
    }
    return failure();
  }
};

struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

Expand Down Expand Up @@ -307,14 +323,19 @@ struct ConvertTensorFromElementsPattern

} // namespace

void populateTensorToFlowPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
void populateTensorToFlowPatternsBeforeDispatchFormation(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns
.insert<ConvertTensorInsertSlicePattern, ConvertTensorExtractSlicePattern,
ConvertTensorCastPattern, ConvertTensorFromElementsPattern>(
context);
}

// Adds Tensor->Flow patterns intended to run after dispatch region formation.
// tensor.extract is converted here (to flow.tensor.load) rather than in the
// before-formation set; the pattern itself bails out inside
// flow.dispatch.workgroups ops, leaving those extracts for codegen.
void populateTensorToFlowPatternsAfterDispatchFormation(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<ConvertTensorExtractPattern>(context);
}

} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ namespace iree_compiler {
namespace IREE {
namespace Flow {

// Populates rewrite patterns for Tensor->Flow.
void populateTensorToFlowPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
// Adds patterns for Tensor->Flow, for running before dispatch region formation.
void populateTensorToFlowPatternsBeforeDispatchFormation(
MLIRContext *context, OwningRewritePatternList &patterns);

// Adds patterns for Tensor->Flow, for running after dispatch region formation.
void populateTensorToFlowPatternsAfterDispatchFormation(
MLIRContext *context, OwningRewritePatternList &patterns);

} // namespace Flow
} // namespace IREE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"cast.mlir",
"extract.mlir",
"extract_slice.mlir",
"from_elements.mlir",
"insert_slice.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"cast.mlir"
"extract.mlir"
"extract_slice.mlir"
"from_elements.mlir"
"insert_slice.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
// RUN: iree-opt -split-input-file -iree-flow-promote-tensor-loads %s | IreeFileCheck %s
// RUN: iree-opt -split-input-file -iree-flow-convert-to-flow-after-dispatch-formation %s | IreeFileCheck %s

func @tensor_extract(%arg0 : tensor<1xi32>, %arg1 : index) -> i32 {
// CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1]
// CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1] : tensor<1xi32>
// CHECK: return %[[RESULT]]
%extract = tensor.extract %arg0[%arg1] : tensor<1xi32>
return %extract : i32
}

// -----

func @tensor_extract_i1(%arg0 : tensor<1xi1>, %arg1 : index) -> i1 {
// CHECK: %[[ZEXT:.*]] = zexti %arg0 : tensor<1xi1> to tensor<1xi8>
// CHECK: %[[LOADED:.*]] = flow.tensor.load %[[ZEXT]][%arg1] : tensor<1xi8>
// CHECK: %[[RESULT:.*]] = trunci %[[LOADED]] : i8 to i1
// CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1] : tensor<1xi1>
// CHECK: return %[[RESULT]]
%extract = tensor.extract %arg0[%arg1] : tensor<1xi1>
return %extract : i1
Expand Down
1 change: 0 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ cc_library(
"PassDetail.h",
"Passes.cpp",
"PromoteI1ToI8Pass.cpp",
"PromoteTensorLoads.cpp",
"StripAndSplatConstantVariables.cpp",
"TypeConverter.cpp",
"VerifyInputLegality.cpp",
Expand Down
1 change: 0 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ iree_cc_library(
"PassDetail.h"
"Passes.cpp"
"PromoteI1ToI8Pass.cpp"
"PromoteTensorLoads.cpp"
"StripAndSplatConstantVariables.cpp"
"TypeConverter.cpp"
"VerifyInputLegality.cpp"
Expand Down
3 changes: 2 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct ConvertToFlowBeforeDispatchFormation
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>>(
context);
populateTensorToFlowPatterns(context, patterns);
populateTensorToFlowPatternsBeforeDispatchFormation(context, patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);

if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
Expand All @@ -131,6 +131,7 @@ struct ConvertToFlowAfterDispatchFormation
RewritePatternSet patterns(&getContext());

patterns.insert<LinalgFillToFlowTensorSplat>(context);
populateTensorToFlowPatternsAfterDispatchFormation(context, patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);

if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
Expand Down
6 changes: 0 additions & 6 deletions iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,6 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) {
// an argument if two executables differ only in that one dimension).
passManager.addPass(IREE::Flow::createDeduplicateExecutablesPass());

// TODO: Prune and rename this pass. This runs after sending everything
// possible to the device and then legalizes any remaining h<->d loads,
// typically coming from top level flow control.
passManager.addNestedPass<mlir::FuncOp>(
IREE::Flow::createPromoteTensorLoadsPass());

// Create one function per remaining flow.executable that can be used with
// iree-benchmark-module to benchmark each dispatch individually, as well as
// exporting all original model entry points.
Expand Down
7 changes: 0 additions & 7 deletions iree/compiler/Dialect/Flow/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,6 @@ createConvertToFlowAfterDispatchFormation();
// Promote I1 tensor constants to I8 tensors to match later operations.
std::unique_ptr<OperationPass<mlir::FuncOp>> createPromoteI1ToI8Pass();

// Converts standard ops which match to flow.tensor.load (typically causing a
// read-back).
// Note that there are typically very specific phase ordering issues with
// performing such a conversion, so even though it is of fine granularity,
// this is maintained separately.
std::unique_ptr<OperationPass<mlir::FuncOp>> createPromoteTensorLoadsPass();

// Expands dynamic !shapex.ranked_shape dimensions in variables.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createExpandGlobalDynamicDimsPass();
Expand Down
6 changes: 0 additions & 6 deletions iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,6 @@ def PromoteI1ToI8 :
let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteI1ToI8Pass()";
}

def PromoteTensorLoads :
Pass<"iree-flow-promote-tensor-loads", "mlir::FuncOp"> {
let summary = "Converts standard ops which match to flow.tensor.load (typically causing a read-back)";
let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteTensorLoadsPass()";
}

def StripAndSplatConstantVariables :
Pass<"iree-flow-strip-and-splat-constant-variables", "mlir::ModuleOp"> {
let summary = "Strips constant util.globals and replaces them with splats.";
Expand Down
97 changes: 0 additions & 97 deletions iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp

This file was deleted.

1 change: 0 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ iree_lit_test_suite(
"pad_linalg_ops.mlir",
"pad_tensor_to_tensor.mlir",
"promote_i1_to_i8.mlir",
"promote_tensor_loads.mlir",
"strip_and_splat_constant_variables.mlir",
"transformation.mlir",
"verify_input_ir.mlir",
Expand Down
1 change: 0 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ iree_lit_test_suite(
"pad_linalg_ops.mlir"
"pad_tensor_to_tensor.mlir"
"promote_i1_to_i8.mlir"
"promote_tensor_loads.mlir"
"strip_and_splat_constant_variables.mlir"
"transformation.mlir"
"verify_input_ir.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func @turn_fill_into_splat(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: in
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[VAL:.+]] = tensor.extract %[[ARG1]][]
// CHECK: %[[VAL:.+]] = flow.tensor.load %[[ARG1]] : tensor<f32>
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
Expand Down
7 changes: 1 addition & 6 deletions iree/compiler/InputConversion/MHLO/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

#include "iree/compiler/InputConversion/MHLO/Passes.h"

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -41,11 +41,6 @@ void buildMHLOInputConversionPassPipeline(OpPassManager &passManager) {
passManager.addPass(createConvertShapeToStandardPass());
passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());

// Now that control flow has been lowered, promote and extract_element
// to tensor loads. This will be done again later once everything that can
// be is lowered to device.
passManager.addNestedPass<FuncOp>(IREE::Flow::createPromoteTensorLoadsPass());

// We also don't handle calls well on the old codepath; until we remove the
// use of the CFG we can continue inlining.
passManager.addPass(mlir::createInlinerPass());
Expand Down
5 changes: 0 additions & 5 deletions iree/compiler/InputConversion/TOSA/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) {
passManager.addNestedPass<FuncOp>(tosa::createTosaToSCF());
passManager.addNestedPass<FuncOp>(createTopLevelSCFToCFGPass());

// Now that control flow has been lowered, promote and extract_element
// to tensor loads. This will be done again later once everything that can
// be is lowered to device.
passManager.addNestedPass<FuncOp>(IREE::Flow::createPromoteTensorLoadsPass());

// We also don't handle calls well on the old codepath; until we remove the
// use of the CFG we can continue inlining.
passManager.addPass(mlir::createInlinerPass());
Expand Down

0 comments on commit d90f0fc

Please sign in to comment.