-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][SPIR-V] Update the ConvertToSPIRV
pass to use dialect interfaces
#102046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…aces. This patch updates the base implementation of `ConvertToSPIRV` to be more like the implementation of `ConvertToLLVM`. `ConvertToLLVM` relies on dialect interfaces for configuring the conversion, allowing out-of-tree dialects to participate in the pass if they implement the interface. This patch introduces the `ConvertToSPIRVPatternInterface` dialect interface, allowing the configuration of the conversion to SPIR-V on a dialect per dialect basis. Finally, this patch adds the dialect interfaces for all previously supported dialects in the previous implementation of the `ConvertToSPIRV` pass. Note: The convert SCF to SPIR-V was left inside the pass, as it depends on the `ScfToSPIRVContext`, a TODO for a future patch is removing this issue.
3130abb
to
a0d77ef
Compare
@llvm/pr-subscribers-mlir-complex @llvm/pr-subscribers-mlir-memref Author: Fabian Mora (fabianmcg) ChangesThis patch updates the base implementation of This patch introduces the Finally, this patch adds the dialect interfaces for all previously supported dialects in the previous implementation of the Note: Patch is 46.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102046.diff 32 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
index bb30deb9dc10e..cadf0b2872bea 100644
--- a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
@@ -26,6 +26,10 @@ void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `arith`
+/// dialect.
+void registerConvertArithToSPIRVInterface(DialectRegistry ®istry);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
index 43578ffffae2d..276818973c3f8 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
@@ -14,6 +14,7 @@
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
namespace mlir {
+class DialectRegistry;
class RewritePatternSet;
class SPIRVTypeConverter;
@@ -22,6 +23,10 @@ namespace cf {
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `cf`
+/// dialect.
+void registerConvertControlFlowToSPIRVInterface(DialectRegistry ®istry);
} // namespace cf
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
index 3852782247527..3062eb5464c53 100644
--- a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
@@ -11,12 +11,19 @@
#include <memory>
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
class Pass;
+class DialectRegistry;
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
+/// Register the extension that will load dependent dialects for SPIR-V
+/// conversion. This is useful to implement a pass similar to
+/// "convert-to-spirv".
+void registerConvertToSPIRVDependentDialectLoading(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
new file mode 100644
index 0000000000000..917b81dd237e2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
@@ -0,0 +1,55 @@
+//===- ToSPIRVInterface.h - Conversion to SPIRV iface -*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir {
+class ConversionTarget;
+class SPIRVTypeConverter;
+class MLIRContext;
+class Operation;
+class RewritePatternSet;
+
+/// Base class for dialect interfaces providing translation to SPIR-V.
+/// Dialects that can be translated should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on SPIR-V
+/// IR. The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class ConvertToSPIRVPatternInterface
+ : public DialectInterface::Base<ConvertToSPIRVPatternInterface> {
+public:
+ ConvertToSPIRVPatternInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Hook for derived dialect interface to load the dialects they
+ /// target. The SPIRVDialect is implicitly already loaded, but this
+ /// method allows to load other intermediate dialects used in the
+ /// conversion.
+ virtual void loadDependentDialects(MLIRContext *context) const {}
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ virtual void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const = 0;
+};
+
+/// Recursively walk the IR and collect all dialects implementing the interface,
+/// and populate the conversion patterns.
+void populateConversionTargetFromOperation(Operation *op,
+ ConversionTarget &target,
+ SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
index 2fa55f40dd970..42711fb6e4b51 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
@@ -24,6 +24,9 @@ class SPIRVTypeConverter;
void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `func`
+/// dialect.
+void registerConvertFuncToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
index 58a1c5246eef9..fad570591983c 100644
--- a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -24,6 +24,10 @@ namespace index {
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `index`
+/// dialect.
+void registerConvertIndexToSPIRVInterface(DialectRegistry ®istry);
} // namespace index
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
index 10090268a4663..9a9edc87f3446 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
@@ -23,6 +23,9 @@ class SPIRVTypeConverter;
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `math`
+/// dialect.
+void registerConvertMathToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 54711c8ad727f..77f6cdd2935df 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -70,6 +70,9 @@ void convertMemRefTypesAndAttrs(
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `memref`
+/// dialect.
+void registerConvertMemRefToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..6a7d1434dd66d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -45,6 +45,8 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
"vector::VectorDialect",
];
let options = [
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "Test conversion patterns of only the specified dialects">,
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">,
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
index 3843f2707a520..88cb58df4fc69 100644
--- a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -12,7 +12,7 @@
#include <memory>
namespace mlir {
-
+class DialectRegistry;
class SPIRVTypeConverter;
class RewritePatternSet;
class Pass;
@@ -23,6 +23,10 @@ class Pass;
namespace ub {
void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `ub`
+/// dialect.
+void registerConvertUBToSPIRVInterface(DialectRegistry ®istry);
} // namespace ub
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..5184b82c33faf 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -32,6 +32,9 @@ void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `vector`
+/// dialect.
+void registerConvertVectorToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..d3aab3a0ff8df 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -15,14 +15,22 @@
#define MLIR_INITALLEXTENSIONS_H_
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -66,6 +74,16 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ // Register all conversions to SPIR-V extensions.
+ arith::registerConvertArithToSPIRVInterface(registry);
+ cf::registerConvertControlFlowToSPIRVInterface(registry);
+ registerConvertFuncToSPIRVInterface(registry);
+ index::registerConvertIndexToSPIRVInterface(registry);
+ registerConvertMathToSPIRVInterface(registry);
+ registerConvertMemRefToSPIRVInterface(registry);
+ ub::registerConvertUBToSPIRVInterface(registry);
+ registerConvertVectorToSPIRVInterface(registry);
+
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index e6c01f063e8b8..603d96462abb5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,7 +9,9 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -1367,3 +1369,39 @@ struct ConvertArithToSPIRVPass
std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
return std::make_unique<ConvertArithToSPIRVPass>();
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert arith to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+ arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+
+ // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+ // in patterns for other dialects.
+ target.addLegalOp<UnrealizedConversionCastOp>();
+
+ // Fail hard when there are any remaining 'arith' ops.
+ target.addIllegalDialect<arith::ArithDialect>();
+ }
+};
+} // namespace
+
+void mlir::arith::registerConvertArithToSPIRVInterface(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
index a5385d9cee6af..0ddb1700e4922 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToSPIRV
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRArithTransforms
MLIRFuncToSPIRV
MLIRSPIRVConversion
MLIRSPIRVDialect
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index f96bfd6f788b9..1e701f729e1ea 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -114,3 +115,32 @@ void mlir::cf::populateControlFlowToSPIRVPatterns(
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert cf to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ // TODO: We should also take care of block argument type conversion.
+ cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::cf::registerConvertControlFlowToSPIRVInterface(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23f..6befb25234d3b 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
+ ToSPIRVInterface.cpp
)
add_mlir_conversion_library(MLIRConvertToSPIRVPass
@@ -12,10 +13,6 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
- MLIRArithToSPIRV
- MLIRArithTransforms
- MLIRFuncToSPIRV
- MLIRIndexToSPIRV
MLIRIR
MLIRPass
MLIRRewrite
@@ -24,10 +21,15 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRSPIRVDialect
MLIRSPIRVTransforms
MLIRSupport
- MLIRTransforms
MLIRTransformUtils
- MLIRUBToSPIRV
- MLIRVectorDialect
- MLIRVectorToSPIRV
- MLIRVectorTransforms
)
+
+add_mlir_conversion_library(MLIRConvertToSPIRVInterface
+ ToSPIRVInterface.cpp
+
+ DEPENDS
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94..9b5780fd95dd0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -7,24 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
-#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
-#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
-#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
-#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
@@ -37,15 +28,91 @@ namespace mlir {
using namespace mlir;
namespace {
+/// This DialectExtension can be attached to the context, which will invoke the
+/// `apply()` method for every loaded dialect. If a dialect implements the
+/// `ConvertToSPIRVPatternInterface` interface, we load dependent dialects
+/// through the int...
[truncated]
|
@llvm/pr-subscribers-mlir-vector Author: Fabian Mora (fabianmcg) ChangesThis patch updates the base implementation of This patch introduces the Finally, this patch adds the dialect interfaces for all previously supported dialects in the previous implementation of the Note: Patch is 46.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102046.diff 32 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
index bb30deb9dc10e..cadf0b2872bea 100644
--- a/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h
@@ -26,6 +26,10 @@ void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertArithToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `arith`
+/// dialect.
+void registerConvertArithToSPIRVInterface(DialectRegistry ®istry);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
index 43578ffffae2d..276818973c3f8 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h
@@ -14,6 +14,7 @@
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
namespace mlir {
+class DialectRegistry;
class RewritePatternSet;
class SPIRVTypeConverter;
@@ -22,6 +23,10 @@ namespace cf {
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `cf`
+/// dialect.
+void registerConvertControlFlowToSPIRVInterface(DialectRegistry ®istry);
} // namespace cf
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
index 3852782247527..3062eb5464c53 100644
--- a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
@@ -11,12 +11,19 @@
#include <memory>
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
class Pass;
+class DialectRegistry;
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
+/// Register the extension that will load dependent dialects for SPIR-V
+/// conversion. This is useful to implement a pass similar to
+/// "convert-to-spirv".
+void registerConvertToSPIRVDependentDialectLoading(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
new file mode 100644
index 0000000000000..917b81dd237e2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h
@@ -0,0 +1,55 @@
+//===- ToSPIRVInterface.h - Conversion to SPIRV iface -*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir {
+class ConversionTarget;
+class SPIRVTypeConverter;
+class MLIRContext;
+class Operation;
+class RewritePatternSet;
+
+/// Base class for dialect interfaces providing translation to SPIR-V.
+/// Dialects that can be translated should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on SPIR-V
+/// IR. The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class ConvertToSPIRVPatternInterface
+ : public DialectInterface::Base<ConvertToSPIRVPatternInterface> {
+public:
+ ConvertToSPIRVPatternInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Hook for derived dialect interface to load the dialects they
+ /// target. The SPIRVDialect is implicitly already loaded, but this
+ /// method allows to load other intermediate dialects used in the
+ /// conversion.
+ virtual void loadDependentDialects(MLIRContext *context) const {}
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ virtual void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const = 0;
+};
+
+/// Recursively walk the IR and collect all dialects implementing the interface,
+/// and populate the conversion patterns.
+void populateConversionTargetFromOperation(Operation *op,
+ ConversionTarget &target,
+ SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_TOSPIRVINTERFACE_H
diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
index 2fa55f40dd970..42711fb6e4b51 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h
@@ -24,6 +24,9 @@ class SPIRVTypeConverter;
void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `func`
+/// dialect.
+void registerConvertFuncToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
index 58a1c5246eef9..fad570591983c 100644
--- a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -24,6 +24,10 @@ namespace index {
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `index`
+/// dialect.
+void registerConvertIndexToSPIRVInterface(DialectRegistry ®istry);
} // namespace index
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
index 10090268a4663..9a9edc87f3446 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h
@@ -23,6 +23,9 @@ class SPIRVTypeConverter;
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `math`
+/// dialect.
+void registerConvertMathToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 54711c8ad727f..77f6cdd2935df 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -70,6 +70,9 @@ void convertMemRefTypesAndAttrs(
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `memref`
+/// dialect.
+void registerConvertMemRefToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..6a7d1434dd66d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -45,6 +45,8 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
"vector::VectorDialect",
];
let options = [
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "Test conversion patterns of only the specified dialects">,
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">,
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
index 3843f2707a520..88cb58df4fc69 100644
--- a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -12,7 +12,7 @@
#include <memory>
namespace mlir {
-
+class DialectRegistry;
class SPIRVTypeConverter;
class RewritePatternSet;
class Pass;
@@ -23,6 +23,10 @@ class Pass;
namespace ub {
void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
RewritePatternSet &patterns);
+
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `ub`
+/// dialect.
+void registerConvertUBToSPIRVInterface(DialectRegistry ®istry);
} // namespace ub
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..5184b82c33faf 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -32,6 +32,9 @@ void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns);
+/// Registers the `ConvertToSPIRVPatternInterface` interface in the `vector`
+/// dialect.
+void registerConvertVectorToSPIRVInterface(DialectRegistry ®istry);
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 20a4ab6f18a28..d3aab3a0ff8df 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -15,14 +15,22 @@
#define MLIR_INITALLEXTENSIONS_H_
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -66,6 +74,16 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ // Register all conversions to SPIR-V extensions.
+ arith::registerConvertArithToSPIRVInterface(registry);
+ cf::registerConvertControlFlowToSPIRVInterface(registry);
+ registerConvertFuncToSPIRVInterface(registry);
+ index::registerConvertIndexToSPIRVInterface(registry);
+ registerConvertMathToSPIRVInterface(registry);
+ registerConvertMemRefToSPIRVInterface(registry);
+ ub::registerConvertUBToSPIRVInterface(registry);
+ registerConvertVectorToSPIRVInterface(registry);
+
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index e6c01f063e8b8..603d96462abb5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,7 +9,9 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -1367,3 +1369,39 @@ struct ConvertArithToSPIRVPass
std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
return std::make_unique<ConvertArithToSPIRVPass>();
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert arith to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+ arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+
+ // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+ // in patterns for other dialects.
+ target.addLegalOp<UnrealizedConversionCastOp>();
+
+ // Fail hard when there are any remaining 'arith' ops.
+ target.addIllegalDialect<arith::ArithDialect>();
+ }
+};
+} // namespace
+
+void mlir::arith::registerConvertArithToSPIRVInterface(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
index a5385d9cee6af..0ddb1700e4922 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToSPIRV/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToSPIRV
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRArithTransforms
MLIRFuncToSPIRV
MLIRSPIRVConversion
MLIRSPIRVDialect
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index f96bfd6f788b9..1e701f729e1ea 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -114,3 +115,32 @@ void mlir::cf::populateControlFlowToSPIRVPatterns(
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
}
+
+//===----------------------------------------------------------------------===//
+// ConvertToSPIRVPatternInterface implementation
+//===----------------------------------------------------------------------===//
+namespace {
+/// Implement the interface to convert cf to SPIR-V.
+struct ToSPIRVDialectInterface : public ConvertToSPIRVPatternInterface {
+ using ConvertToSPIRVPatternInterface::ConvertToSPIRVPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<spirv::SPIRVDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToSPIRVConversionPatterns(
+ ConversionTarget &target, SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ // TODO: We should also take care of block argument type conversion.
+ cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::cf::registerConvertControlFlowToSPIRVInterface(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
+ dialect->addInterfaces<ToSPIRVDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23f..6befb25234d3b 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
+ ToSPIRVInterface.cpp
)
add_mlir_conversion_library(MLIRConvertToSPIRVPass
@@ -12,10 +13,6 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
- MLIRArithToSPIRV
- MLIRArithTransforms
- MLIRFuncToSPIRV
- MLIRIndexToSPIRV
MLIRIR
MLIRPass
MLIRRewrite
@@ -24,10 +21,15 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRSPIRVDialect
MLIRSPIRVTransforms
MLIRSupport
- MLIRTransforms
MLIRTransformUtils
- MLIRUBToSPIRV
- MLIRVectorDialect
- MLIRVectorToSPIRV
- MLIRVectorTransforms
)
+
+add_mlir_conversion_library(MLIRConvertToSPIRVInterface
+ ToSPIRVInterface.cpp
+
+ DEPENDS
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94..9b5780fd95dd0 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -7,24 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
-#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
-#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ToSPIRVInterface.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
-#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
-#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
@@ -37,15 +28,91 @@ namespace mlir {
using namespace mlir;
namespace {
+/// This DialectExtension can be attached to the context, which will invoke the
+/// `apply()` method for every loaded dialect. If a dialect implements the
+/// `ConvertToSPIRVPatternInterface` interface, we load dependent dialects
+/// through the int...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is interesting! What's your motivation for this beyond what's explained in the PR description? Is this to better align with the llvm conversion or do you have some out of tree dialects that you would like to support?
The reason I'm asking is that converting to SPIR-V has a set of issues that are unique to it and make it harder (I think?) than conversion to llvm:
- It does not support arbitrary (1d) vector types
- It does not support arbitrary integer types
- I think arrays are also more limited
- Not all ops are converted directly and some require expansion / emulation before conversion (e.g., masked vector load / store).
Because of these complications, our plan immediate was to stick with handcrafted conversion that happens in a few phases.
It's both, I think it should align with
I noticed the vector issue with I think some of the issues you mention can be solved within the dialect conversion infrastructure, or with missing intermediate ops. We can talk further about this. |
After talking with @kuhar we decided it's best to close this PR as the conversion to SPIR-V still needs to mature more before productivizing this pass. |
This patch updates the base implementation of
ConvertToSPIRV
to be more like the implementation ofConvertToLLVM
.ConvertToLLVM
relies on dialect interfaces for configuring the conversion, allowing out-of-tree dialects to participate in the pass if they implement the interface.This patch introduces the
ConvertToSPIRVPatternInterface
dialect interface, allowing the configuration of the conversion to SPIR-V on a dialect per dialect basis.Finally, this patch adds the dialect interfaces for all previously supported dialects in the previous implementation of the
ConvertToSPIRV
pass.Note:
The convert SCF to SPIR-V was left inside the pass, as it depends on the
ScfToSPIRVContext
, a TODO for a future patch is removing this issue.