Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove IREE usage of the Global Dialect Registry #3036

Merged
merged 6 commits into from Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/python/pyiree/compiler/BUILD
Expand Up @@ -90,6 +90,7 @@ pybind_cc_library(
"//iree/tools:init_iree_passes_and_dialects",
"//iree/tools:init_mlir_passes_and_dialects",
"//iree/tools:init_targets",
"//iree/tools:init_xla_dialects",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
Expand Down
1 change: 1 addition & 0 deletions bindings/python/pyiree/compiler/CMakeLists.txt
Expand Up @@ -62,6 +62,7 @@ iree_pyext_library(
iree::tools::init_iree_passes_and_dialects
iree::tools::init_mlir_passes_and_dialects
iree::tools::init_targets
iree::tools::init_xla_dialects
LLVMSupport
MLIRIR
MLIRSCFTransforms
Expand Down
18 changes: 10 additions & 8 deletions bindings/python/pyiree/compiler/compiler.cc
Expand Up @@ -31,13 +31,15 @@
#include "iree/tools/init_mlir_dialects.h"
#include "iree/tools/init_mlir_passes.h"
#include "iree/tools/init_targets.h"
#include "iree/tools/init_xla_dialects.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -73,13 +75,6 @@ bool LLVMOnceInit() {
llvm::sys::DefaultOneShotPipeSignalHandler);
llvm::sys::PrintStackTraceOnErrorSignal("pyiree");

mlir::enableGlobalDialectRegistry(true);
// Register built-in MLIR dialects.
mlir::registerMlirDialects();

// Register IREE dialects, compiler module dialects, and HAL target backends.
mlir::iree_compiler::registerIreeDialects();
mlir::iree_compiler::registerIreeCompilerModuleDialects();
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();

Expand All @@ -98,6 +93,13 @@ bool LLVMOnceInit() {
return true;
}

void registerDialects(DialectRegistry& registry) {
mlir::registerMlirDialects(registry);
mlir::registerXLADialects(registry);
mlir::iree_compiler::registerIreeDialects(registry);
mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
}

void SetupLLVMModule(pybind11::module m) {
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
m.def(
Expand Down Expand Up @@ -286,7 +288,7 @@ void DiagnosticCapture::ClearDiagnostics() { diagnostics_.clear(); }

CompilerContextBundle::CompilerContextBundle()
: default_capture_(&mlir_context_, nullptr) {
mlir_context_.loadAllGloballyRegisteredDialects();
registerDialects(mlir_context_.getDialectRegistry());
}
CompilerContextBundle::~CompilerContextBundle() = default;

Expand Down
2 changes: 1 addition & 1 deletion experimental/ModelBuilder/test/TestMatMulVulkan.cpp
Expand Up @@ -72,7 +72,7 @@ void testMatMul() {
const int height = 4;
const int width = 4;
StringLiteral funcName = "kernel_matmul";
MLIRContext context;
MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
auto typeA = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
auto typeB = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
Expand Down
2 changes: 1 addition & 1 deletion experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
Expand Up @@ -43,7 +43,7 @@ using namespace mlir; // NOLINT

template <unsigned vecSize>
void testVectorAdd1d() {
MLIRContext context;
MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
constexpr int workgroupSize = 32;
auto typeA = modelBuilder.getMemRefType(vecSize, modelBuilder.f32);
Expand Down
2 changes: 1 addition & 1 deletion experimental/ModelBuilder/test/TestVectorToGPU.cpp
Expand Up @@ -89,7 +89,7 @@ void testVecAdd() {
// Simple test a single warp.
const int width = warpSize;
StringLiteral funcName = "kernel_vecadd";
MLIRContext context;
MLIRContext context(/*loadAllDialects=*/false);
ModelBuilder modelBuilder;
auto nVectorType = modelBuilder.getVectorType(width, modelBuilder.f32);
auto typeA = modelBuilder.getMemRefType({width}, modelBuilder.f32);
Expand Down
18 changes: 17 additions & 1 deletion integrations/tensorflow/compiler/BUILD
Expand Up @@ -56,9 +56,25 @@ cc_library(

cc_binary(
name = "iree-tf-opt",
srcs = ["tf_opt_main.cc"],
deps = [
":tensorflow",
"//iree/tools:iree_opt_main",
"//integrations/tensorflow/compiler/dialect/tf_strings/ir:dialect",
"//integrations/tensorflow/compiler/dialect/tf_tensorlist/ir:tf_tensorlist_dialect",
"//iree/compiler/Conversion:init_conversions",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/Conversion:Passes",
"//iree/tools:init_compiler_modules",
"//iree/tools:init_iree_passes_and_dialects",
"//iree/tools:init_mlir_passes_and_dialects",
"//iree/tools:init_targets",
"//iree/tools:init_xla_dialects",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
Expand Down
Expand Up @@ -16,10 +16,12 @@
#include "iree/base/signature_mangle.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/SymbolTable.h"
Expand Down Expand Up @@ -164,6 +166,10 @@ class TFSavedModelLowerGlobalTensors
: public PassWrapper<TFSavedModelLowerGlobalTensors,
OperationPass<ModuleOp>> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect, IREEDialect>();
}

void runOnOperation() override {
if (failed(importTfSavedModelGlobalTensorsToIREEFlow(getOperation()))) {
signalPassFailure();
Expand Down
Expand Up @@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/Modules/Strings/IR/Dialect.h"
#include "iree/compiler/Dialect/Modules/Strings/IR/Types.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

Expand All @@ -37,7 +38,11 @@ void populateTFStringsToHALPatterns(MLIRContext *ctx,
// use tensor types.
class TfStringsToHALConversionInterface : public HALConversionDialectInterface {
public:
using HALConversionDialectInterface::HALConversionDialectInterface;
TfStringsToHALConversionInterface(Dialect *dialect)
: HALConversionDialectInterface(dialect) {
dialect->getContext()->loadDialect<IREE::Strings::StringsDialect>();
}

void setupConversionTarget(ConversionTarget &target,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) const override {
Expand Down
Expand Up @@ -21,6 +21,7 @@
#include "integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
Expand Down Expand Up @@ -89,6 +90,10 @@ class LowerTensorflowToStringsPass
: public PassWrapper<LowerTensorflowToStringsPass,
OperationPass<ModuleOp>> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TFStringsDialect>();
}

void runOnOperation() override {
if (failed(run())) {
signalPassFailure();
Expand Down
Expand Up @@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListDialect.h"
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
Expand All @@ -36,7 +37,10 @@ void populateTensorListToHALPatterns(MLIRContext *context,
class TfTensorListToHALConversionInterface
: public HALConversionDialectInterface {
public:
using HALConversionDialectInterface::HALConversionDialectInterface;
TfTensorListToHALConversionInterface(Dialect *dialect)
: HALConversionDialectInterface(dialect) {
dialect->getContext()->loadDialect<IREE::TensorList::TensorListDialect>();
}

void setupConversionTarget(ConversionTarget &target,
OwningRewritePatternList &patterns,
Expand Down
Expand Up @@ -14,6 +14,7 @@

#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_dialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -27,6 +28,9 @@ namespace tf_tensorlist {
class ConvertTfToTfTensorList
: public PassWrapper<ConvertTfToTfTensorList, OperationPass<FuncOp>> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TfTensorListDialect>();
}
void runOnOperation() override;
};

Expand Down
163 changes: 163 additions & 0 deletions integrations/tensorflow/compiler/tf_opt_main.cc
@@ -0,0 +1,163 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Main entry function for iree-tf-opt and derived binaries.
//
// Based on iree-opt with the addition of TF dialects and passes

#include "integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.h"
#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_dialect.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/init_conversions.h"
#include "iree/compiler/Dialect/HAL/Conversion/Passes.h"
#include "iree/tools/init_compiler_modules.h"
#include "iree/tools/init_iree_dialects.h"
#include "iree/tools/init_iree_passes.h"
#include "iree/tools/init_mlir_dialects.h"
#include "iree/tools/init_mlir_passes.h"
#include "iree/tools/init_targets.h"
#include "iree/tools/init_xla_dialects.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

#ifdef IREE_HAVE_EMITC_DIALECT
#include "emitc/InitDialect.h"
#endif // IREE_HAVE_EMITC_DIALECT

static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));

static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));

static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));

static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));

static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));

static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(true));

static llvm::cl::opt<bool> showDialects(
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
llvm::cl::init(false));

void registerTFDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::TF::TensorFlowDialect,
mlir::tf_executor::TensorFlowExecutorDialect,
mlir::tf_device::TensorFlowDeviceDialect,
mlir::tf_saved_model::TensorFlowSavedModelDialect>();
}

void registerExtensionDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::iree_compiler::tf_strings::TFStringsDialect,
mlir::tf_tensorlist::TfTensorListDialect>();
}

int main(int argc, char **argv) {
// TODO(#2958): There's a lot of duplication with iree-opt here. Factor out
// the common functionality.
llvm::InitLLVM y(argc, argv);

mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
mlir::registerMlirPasses();
#ifdef IREE_HAVE_EMITC_DIALECT
mlir::registerEmitCDialect(registry);
#endif // IREE_HAVE_EMITC_DIALECT
mlir::registerXLADialects(registry);
mlir::iree_compiler::registerIreeDialects(registry);
mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
registerTFDialects(registry);
registerExtensionDialects(registry);

mlir::iree_compiler::registerAllIreePasses();
mlir::iree_compiler::registerHALConversionPasses();
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerLinalgToSPIRVPasses();
mlir::iree_compiler::registerHLOToLinalgPasses();
mlir::iree_compiler::registerLinalgToLLVMPasses();

// Register MLIRContext command-line options like
// -mlir-print-op-on-diagnostic.
mlir::registerMLIRContextCLOptions();
// Register assembly printer command-line options like
// -mlir-print-op-generic.
mlir::registerAsmPrinterCLOptions();
// Register pass manager command-line options like -print-ir-*.
mlir::registerPassManagerCLOptions();

mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");

// Parse pass names in main to ensure static initialization completed.
llvm::cl::ParseCommandLineOptions(argc, argv,
"IREE modular optimizer driver\n");

if (showDialects) {
llvm::outs() << "Available Dialects:\n";
interleave(
registry, llvm::outs(),
[](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
return 0;
}

// Set up the input file.
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}

auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
exit(1);
}

if (failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
registry, splitInputFile, verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
/*preloadDialectsInContext=*/false))) {
GMNGeoffrey marked this conversation as resolved.
Show resolved Hide resolved
return 1;
}
}