From 3a10bf394b4ba578b34812f34c5271bda52088f6 Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Tue, 22 Mar 2022 10:24:21 -0700 Subject: [PATCH] [NFC] Split BackendTypeConversion -> (BackendTypeConversion, BackendTypeConversionPasses) --- .../Transforms/BackendTypeConversion.cpp | 146 ---------------- .../BackendTypeConversionPasses.cpp | 161 ++++++++++++++++++ .../TorchConversion/Transforms/CMakeLists.txt | 1 + 3 files changed, 162 insertions(+), 146 deletions(-) create mode 100644 lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index bc408f1ef71b..299415bfdb3f 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -7,17 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; @@ -175,140 +166,3 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion( setupTorchFloatToF64Conversion(target, typeConverter); setupTorchGeneratorToI64Conversion(target, typeConverter); } - -//===----------------------------------------------------------------------===// -// FuncBackendTypeConversionPass -//===----------------------------------------------------------------------===// - -namespace { -struct FuncBackendTypeConversionPass - : public FuncBackendTypeConversionBase { - using FuncBackendTypeConversionBase< - FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { - auto module = getOperation(); - auto *context = &getContext(); - - TypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); - - populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()) && - typeConverter.isLegal(&op.getBody()); - }); - populateCallOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addLegalOp(); - - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return isNotBranchOpInterfaceOrReturnLikeOp(op) || - isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter) || - isLegalForReturnOpTypeConversionPattern(op, typeConverter); - }); - - if (failed(applyFullConversion(module, target, std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr> -mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { - return std::make_unique(); -} - -//===----------------------------------------------------------------------===// -// FinalizingBackendTypeConversionPass -//===----------------------------------------------------------------------===// - -namespace { -// In a finalizing conversion, we know that all of the source types have been -// converted to the destination types, so the materialization becomes an -// identity. -template -class FinalizeMaterialization : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpTy::Adaptor; - LogicalResult - matchAndRewrite(OpTy op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp(op, adaptor.getOperands()[0]); - return success(); - } -}; -} // namespace - -template -static void setupFinalization(ConversionTarget &target, - RewritePatternSet &patterns, - TypeConverter &typeConverter) { - target.addIllegalOp(); - patterns.add>(typeConverter, - patterns.getContext()); -} - -template -static void setupFinalization(ConversionTarget &target, - RewritePatternSet &patterns, - TypeConverter &typeConverter) { - setupFinalization(target, patterns, typeConverter); - setupFinalization(target, patterns, typeConverter); -} - -namespace { -struct FinalizingBackendTypeConversionPass - : public FinalizingBackendTypeConversionBase< - FinalizingBackendTypeConversionPass> { - using FinalizingBackendTypeConversionBase< - FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase; - - void runOnOperation() override { - auto func = getOperation(); - auto *context = &getContext(); - - TypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); - - // Mark materializations as illegal in this pass (since we are finalizing) - // and add patterns that eliminate them. - setupFinalization(target, patterns, typeConverter); - - // If all result types are legal, and all block arguments are legal, then - // all types in the program are legal. - // - // We also check that the operand types are legal to avoid creating invalid - // IR. For example, this prevents the patterns from updating - // the types of the operands to a return op without updating the enclosing - // function. - target.markUnknownOpDynamicallyLegal( - [&](Operation *op) { return typeConverter.isLegal(op); }); - - if (failed(applyFullConversion(func, target, std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr> -mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp new file mode 100644 index 000000000000..e428e76240f8 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -0,0 +1,161 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::TorchConversion; + +//===----------------------------------------------------------------------===// +// FuncBackendTypeConversionPass +//===----------------------------------------------------------------------===// + +namespace { +struct FuncBackendTypeConversionPass + : public FuncBackendTypeConversionBase { + using FuncBackendTypeConversionBase< + FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); + }); + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addLegalOp(); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); + + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { + return std::make_unique(); +} + +//===----------------------------------------------------------------------===// +// FinalizingBackendTypeConversionPass +//===----------------------------------------------------------------------===// + +namespace { +// In a finalizing conversion, we know that all of the source types have been +// converted to the destination types, so the materialization becomes an +// identity. +template +class FinalizeMaterialization : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + LogicalResult + matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getOperands()[0]); + return success(); + } +}; +} // namespace + +template +static void setupFinalization(ConversionTarget &target, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + target.addIllegalOp(); + patterns.add>(typeConverter, + patterns.getContext()); +} + +template +static void setupFinalization(ConversionTarget &target, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + setupFinalization(target, patterns, typeConverter); + setupFinalization(target, patterns, typeConverter); +} + +namespace { +struct FinalizingBackendTypeConversionPass + : public FinalizingBackendTypeConversionBase< + FinalizingBackendTypeConversionPass> { + using FinalizingBackendTypeConversionBase< + FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase; + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + // Mark materializations as illegal in this pass (since we are finalizing) + // and add patterns that eliminate them. + setupFinalization(target, patterns, typeConverter); + + // If all result types are legal, and all block arguments are legal, then + // all types in the program are legal. + // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents the patterns from updating + // the types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 782dc3ac0e3f..77f58e53cd4f 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses BackendTypeConversion.cpp + BackendTypeConversionPasses.cpp Passes.cpp VerifyInvariantsBeforeBackendLowering.cpp VerifyLinalgOnTensorsBackendContract.cpp