Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 0 additions & 146 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -175,140 +166,3 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchGeneratorToI64Conversion(target, typeConverter);
}

//===----------------------------------------------------------------------===//
// FuncBackendTypeConversionPass
//===----------------------------------------------------------------------===//

namespace {
struct FuncBackendTypeConversionPass
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
using FuncBackendTypeConversionBase<
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return typeConverter.isLegal(op); });

populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addLegalOp<ModuleOp>();

target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
isLegalForBranchOpInterfaceTypeConversionPattern(op,
typeConverter) ||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
});

if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
return std::make_unique<FuncBackendTypeConversionPass>();
}

//===----------------------------------------------------------------------===//
// FinalizingBackendTypeConversionPass
//===----------------------------------------------------------------------===//

namespace {
// In a finalizing conversion, we know that all of the source types have been
// converted to the destination types, so the materialization becomes an
// identity.
template <typename OpTy>
class FinalizeMaterialization : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getOperands()[0]);
return success();
}
};
} // namespace

template <typename OpTy>
static void setupFinalization(ConversionTarget &target,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
target.addIllegalOp<OpTy>();
patterns.add<FinalizeMaterialization<OpTy>>(typeConverter,
patterns.getContext());
}

template <typename OpTy, typename OpTy2, typename... OpTys>
static void setupFinalization(ConversionTarget &target,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
setupFinalization<OpTy>(target, patterns, typeConverter);
setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter);
}

namespace {
struct FinalizingBackendTypeConversionPass
: public FinalizingBackendTypeConversionBase<
FinalizingBackendTypeConversionPass> {
using FinalizingBackendTypeConversionBase<
FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase;

void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

// Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
GeneratorToI64Op>(target, patterns, typeConverter);

// If all result types are legal, and all block arguments are legal, then
// all types in the program are legal.
//
// We also check that the operand types are legal to avoid creating invalid
// IR. For example, this prevents the patterns from updating
// the types of the operands to a return op without updating the enclosing
// function.
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return typeConverter.isLegal(op); });

if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>();
}
161 changes: 161 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;

//===----------------------------------------------------------------------===//
// FuncBackendTypeConversionPass
//===----------------------------------------------------------------------===//

namespace {
struct FuncBackendTypeConversionPass
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
using FuncBackendTypeConversionBase<
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return typeConverter.isLegal(op); });

populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addLegalOp<ModuleOp>();

target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
isLegalForBranchOpInterfaceTypeConversionPattern(op,
typeConverter) ||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
});

if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
return std::make_unique<FuncBackendTypeConversionPass>();
}

//===----------------------------------------------------------------------===//
// FinalizingBackendTypeConversionPass
//===----------------------------------------------------------------------===//

namespace {
// In a finalizing conversion, we know that all of the source types have been
// converted to the destination types, so the materialization becomes an
// identity.
template <typename OpTy>
class FinalizeMaterialization : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getOperands()[0]);
return success();
}
};
} // namespace

template <typename OpTy>
static void setupFinalization(ConversionTarget &target,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
target.addIllegalOp<OpTy>();
patterns.add<FinalizeMaterialization<OpTy>>(typeConverter,
patterns.getContext());
}

template <typename OpTy, typename OpTy2, typename... OpTys>
static void setupFinalization(ConversionTarget &target,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
setupFinalization<OpTy>(target, patterns, typeConverter);
setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter);
}

namespace {
struct FinalizingBackendTypeConversionPass
: public FinalizingBackendTypeConversionBase<
FinalizingBackendTypeConversionPass> {
using FinalizingBackendTypeConversionBase<
FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase;

void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

// Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
GeneratorToI64Op>(target, patterns, typeConverter);

// If all result types are legal, and all block arguments are legal, then
// all types in the program are legal.
//
// We also check that the operand types are legal to avoid creating invalid
// IR. For example, this prevents the patterns from updating
// the types of the operands to a return op without updating the enclosing
// function.
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return typeConverter.isLegal(op); });

if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>();
}
1 change: 1 addition & 0 deletions lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
Passes.cpp
VerifyInvariantsBeforeBackendLowering.cpp
VerifyLinalgOnTensorsBackendContract.cpp
Expand Down