| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| //===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Define conversions from the ControlFlow dialect to the LLVM IR dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H | ||
| #define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H | ||
|
|
||
| #include <memory> | ||
|
|
||
| namespace mlir { | ||
| class LLVMTypeConverter; | ||
| class RewritePatternSet; | ||
| class Pass; | ||
|
|
||
| namespace cf { | ||
| /// Collect the patterns to convert from the ControlFlow dialect to LLVM. The | ||
| /// conversion patterns capture the LLVMTypeConverter by reference meaning the | ||
| /// references have to remain alive during the entire pattern lifetime. | ||
| void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, | ||
| RewritePatternSet &patterns); | ||
|
|
||
| /// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect. | ||
| std::unique_ptr<Pass> createConvertControlFlowToLLVMPass(); | ||
| } // namespace cf | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| //===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Provides patterns to convert ControlFlow dialect to SPIR-V dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H | ||
| #define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H | ||
|
|
||
| namespace mlir { | ||
| class RewritePatternSet; | ||
| class SPIRVTypeConverter; | ||
|
|
||
| namespace cf { | ||
| /// Appends to a pattern list additional patterns for translating ControlFLow | ||
| /// ops to SPIR-V ops. | ||
| void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter, | ||
| RewritePatternSet &patterns); | ||
| } // namespace cf | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| //===- ConvertSCFToControlFlow.h - Pass entrypoint --------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ | ||
| #define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ | ||
|
|
||
| #include <memory> | ||
|
|
||
| namespace mlir { | ||
| class Pass; | ||
| class RewritePatternSet; | ||
|
|
||
| /// Collect a set of patterns to convert SCF operations to CFG branch-based | ||
| /// operations within the ControlFlow dialect. | ||
| void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns); | ||
|
|
||
| /// Creates a pass to convert SCF operations to CFG branch-based operation in | ||
| /// the ControlFlow dialect. | ||
| std::unique_ptr<Pass> createConvertSCFToCFPass(); | ||
|
|
||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| add_subdirectory(IR) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| add_mlir_dialect(ControlFlowOps cf ControlFlowOps) | ||
| add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| //===- ControlFlow.h - ControlFlow Dialect ----------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file defines the ControlFlow dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H | ||
| #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H | ||
|
|
||
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | ||
| #include "mlir/IR/Dialect.h" | ||
|
|
||
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc" | ||
|
|
||
| #endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| //===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file defines the operations of the ControlFlow dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H | ||
| #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H | ||
|
|
||
| #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" | ||
| #include "mlir/IR/Builders.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/IR/OpImplementation.h" | ||
| #include "mlir/Interfaces/ControlFlowInterfaces.h" | ||
| #include "mlir/Interfaces/SideEffectInterfaces.h" | ||
|
|
||
| namespace mlir { | ||
| class PatternRewriter; | ||
| } // namespace mlir | ||
|
|
||
| #define GET_OP_CLASSES | ||
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc" | ||
|
|
||
| #endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,313 @@ | ||
| //===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file contains definitions for the operations within the ControlFlow | ||
| // dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef STANDARD_OPS | ||
| #define STANDARD_OPS | ||
|
|
||
| include "mlir/IR/OpAsmInterface.td" | ||
| include "mlir/Interfaces/ControlFlowInterfaces.td" | ||
| include "mlir/Interfaces/SideEffectInterfaces.td" | ||
|
|
||
| def ControlFlow_Dialect : Dialect { | ||
| let name = "cf"; | ||
| let cppNamespace = "::mlir::cf"; | ||
| let dependentDialects = ["arith::ArithmeticDialect"]; | ||
| let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; | ||
| let description = [{ | ||
| This dialect contains low-level, i.e. non-region based, control flow | ||
| constructs. These constructs generally represent control flow directly | ||
| on SSA blocks of a control flow graph. | ||
| }]; | ||
| } | ||
|
|
||
| class CF_Op<string mnemonic, list<Trait> traits = []> : | ||
| Op<ControlFlow_Dialect, mnemonic, traits>; | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // AssertOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def AssertOp : CF_Op<"assert"> { | ||
| let summary = "Assert operation with message attribute"; | ||
| let description = [{ | ||
| Assert operation with single boolean operand and an error message attribute. | ||
| If the argument is `true` this operation has no effect. Otherwise, the | ||
| program execution will abort. The provided error message may be used by a | ||
| runtime to propagate the error to the user. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| assert %b, "Expected ... to be true" | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins I1:$arg, StrAttr:$msg); | ||
|
|
||
| let assemblyFormat = "$arg `,` $msg attr-dict"; | ||
| let hasCanonicalizeMethod = 1; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // BranchOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def BranchOp : CF_Op<"br", [ | ||
| DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, | ||
| NoSideEffect, Terminator | ||
| ]> { | ||
| let summary = "branch operation"; | ||
| let description = [{ | ||
| The `cf.br` operation represents a direct branch operation to a given | ||
| block. The operands of this operation are forwarded to the successor block, | ||
| and the number and type of the operands must match the arguments of the | ||
| target block. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| ^bb2: | ||
| %2 = call @someFn() | ||
| cf.br ^bb3(%2 : tensor<*xf32>) | ||
| ^bb3(%3: tensor<*xf32>): | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins Variadic<AnyType>:$destOperands); | ||
| let successors = (successor AnySuccessor:$dest); | ||
|
|
||
| let builders = [ | ||
| OpBuilder<(ins "Block *":$dest, | ||
| CArg<"ValueRange", "{}">:$destOperands), [{ | ||
| $_state.addSuccessors(dest); | ||
| $_state.addOperands(destOperands); | ||
| }]>]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| void setDest(Block *block); | ||
|
|
||
| /// Erase the operand at 'index' from the operand list. | ||
| void eraseOperand(unsigned index); | ||
| }]; | ||
|
|
||
| let hasCanonicalizeMethod = 1; | ||
| let assemblyFormat = [{ | ||
| $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict | ||
| }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // CondBranchOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def CondBranchOp : CF_Op<"cond_br", | ||
| [AttrSizedOperandSegments, | ||
| DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, | ||
| NoSideEffect, Terminator]> { | ||
| let summary = "conditional branch operation"; | ||
| let description = [{ | ||
| The `cond_br` terminator operation represents a conditional branch on a | ||
| boolean (1-bit integer) value. If the bit is set, then the first destination | ||
| is jumped to; if it is false, the second destination is chosen. The count | ||
| and types of operands must align with the arguments in the corresponding | ||
| target blocks. | ||
|
|
||
| The MLIR conditional branch operation is not allowed to target the entry | ||
| block for a region. The two destinations of the conditional branch operation | ||
| are allowed to be the same. | ||
|
|
||
| The following example illustrates a function with a conditional branch | ||
| operation that targets the same block. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| func @select(%a: i32, %b: i32, %flag: i1) -> i32 { | ||
| // Both targets are the same, operands differ | ||
| cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32) | ||
|
|
||
| ^bb1(%x : i32) : | ||
| return %x : i32 | ||
| } | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins I1:$condition, | ||
| Variadic<AnyType>:$trueDestOperands, | ||
| Variadic<AnyType>:$falseDestOperands); | ||
| let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); | ||
|
|
||
| let builders = [ | ||
| OpBuilder<(ins "Value":$condition, "Block *":$trueDest, | ||
| "ValueRange":$trueOperands, "Block *":$falseDest, | ||
| "ValueRange":$falseOperands), [{ | ||
| build($_builder, $_state, condition, trueOperands, falseOperands, trueDest, | ||
| falseDest); | ||
| }]>, | ||
| OpBuilder<(ins "Value":$condition, "Block *":$trueDest, | ||
| "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{ | ||
| build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, | ||
| falseOperands); | ||
| }]>]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| // These are the indices into the dests list. | ||
| enum { trueIndex = 0, falseIndex = 1 }; | ||
|
|
||
| // Accessors for operands to the 'true' destination. | ||
| Value getTrueOperand(unsigned idx) { | ||
| assert(idx < getNumTrueOperands()); | ||
| return getOperand(getTrueDestOperandIndex() + idx); | ||
| } | ||
|
|
||
| void setTrueOperand(unsigned idx, Value value) { | ||
| assert(idx < getNumTrueOperands()); | ||
| setOperand(getTrueDestOperandIndex() + idx, value); | ||
| } | ||
|
|
||
| unsigned getNumTrueOperands() { return getTrueOperands().size(); } | ||
|
|
||
| /// Erase the operand at 'index' from the true operand list. | ||
| void eraseTrueOperand(unsigned index) { | ||
| getTrueDestOperandsMutable().erase(index); | ||
| } | ||
|
|
||
| // Accessors for operands to the 'false' destination. | ||
| Value getFalseOperand(unsigned idx) { | ||
| assert(idx < getNumFalseOperands()); | ||
| return getOperand(getFalseDestOperandIndex() + idx); | ||
| } | ||
| void setFalseOperand(unsigned idx, Value value) { | ||
| assert(idx < getNumFalseOperands()); | ||
| setOperand(getFalseDestOperandIndex() + idx, value); | ||
| } | ||
|
|
||
| operand_range getTrueOperands() { return getTrueDestOperands(); } | ||
| operand_range getFalseOperands() { return getFalseDestOperands(); } | ||
|
|
||
| unsigned getNumFalseOperands() { return getFalseOperands().size(); } | ||
|
|
||
| /// Erase the operand at 'index' from the false operand list. | ||
| void eraseFalseOperand(unsigned index) { | ||
| getFalseDestOperandsMutable().erase(index); | ||
| } | ||
|
|
||
| private: | ||
| /// Get the index of the first true destination operand. | ||
| unsigned getTrueDestOperandIndex() { return 1; } | ||
|
|
||
| /// Get the index of the first false destination operand. | ||
| unsigned getFalseDestOperandIndex() { | ||
| return getTrueDestOperandIndex() + getNumTrueOperands(); | ||
| } | ||
| }]; | ||
|
|
||
| let hasCanonicalizer = 1; | ||
| let assemblyFormat = [{ | ||
| $condition `,` | ||
| $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` | ||
| $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? | ||
| attr-dict | ||
| }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // SwitchOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def SwitchOp : CF_Op<"switch", | ||
| [AttrSizedOperandSegments, | ||
| DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, | ||
| NoSideEffect, Terminator]> { | ||
| let summary = "switch operation"; | ||
| let description = [{ | ||
| The `switch` terminator operation represents a switch on a signless integer | ||
| value. If the flag matches one of the specified cases, then the | ||
| corresponding destination is jumped to. If the flag does not match any of | ||
| the cases, the default destination is jumped to. The count and types of | ||
| operands must align with the arguments in the corresponding target blocks. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| switch %flag : i32, [ | ||
| default: ^bb1(%a : i32), | ||
| 42: ^bb1(%b : i32), | ||
| 43: ^bb3(%c : i32) | ||
| ] | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins | ||
| AnyInteger:$flag, | ||
| Variadic<AnyType>:$defaultOperands, | ||
| VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, | ||
| OptionalAttr<AnyIntElementsAttr>:$case_values, | ||
| I32ElementsAttr:$case_operand_segments | ||
| ); | ||
| let successors = (successor | ||
| AnySuccessor:$defaultDestination, | ||
| VariadicSuccessor<AnySuccessor>:$caseDestinations | ||
| ); | ||
| let builders = [ | ||
| OpBuilder<(ins "Value":$flag, | ||
| "Block *":$defaultDestination, | ||
| "ValueRange":$defaultOperands, | ||
| CArg<"ArrayRef<APInt>", "{}">:$caseValues, | ||
| CArg<"BlockRange", "{}">:$caseDestinations, | ||
| CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>, | ||
| OpBuilder<(ins "Value":$flag, | ||
| "Block *":$defaultDestination, | ||
| "ValueRange":$defaultOperands, | ||
| CArg<"ArrayRef<int32_t>", "{}">:$caseValues, | ||
| CArg<"BlockRange", "{}">:$caseDestinations, | ||
| CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>, | ||
| OpBuilder<(ins "Value":$flag, | ||
| "Block *":$defaultDestination, | ||
| "ValueRange":$defaultOperands, | ||
| CArg<"DenseIntElementsAttr", "{}">:$caseValues, | ||
| CArg<"BlockRange", "{}">:$caseDestinations, | ||
| CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)> | ||
| ]; | ||
|
|
||
| let assemblyFormat = [{ | ||
| $flag `:` type($flag) `,` `[` `\n` | ||
| custom<SwitchOpCases>(ref(type($flag)),$defaultDestination, | ||
| $defaultOperands, | ||
| type($defaultOperands), | ||
| $case_values, | ||
| $caseDestinations, | ||
| $caseOperands, | ||
| type($caseOperands)) | ||
| `]` | ||
| attr-dict | ||
| }]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| /// Return the operands for the case destination block at the given index. | ||
| OperandRange getCaseOperands(unsigned index) { | ||
| return getCaseOperands()[index]; | ||
| } | ||
|
|
||
| /// Return a mutable range of operands for the case destination block at the | ||
| /// given index. | ||
| MutableOperandRange getCaseOperandsMutable(unsigned index) { | ||
| return getCaseOperandsMutable()[index]; | ||
| } | ||
| }]; | ||
|
|
||
| let hasCanonicalizer = 1; | ||
| let hasVerifier = 1; | ||
| } | ||
|
|
||
| #endif // STANDARD_OPS |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| add_mlir_conversion_library(MLIRControlFlowToLLVM | ||
| ControlFlowToLLVM.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM | ||
|
|
||
| DEPENDS | ||
| MLIRConversionPassIncGen | ||
| intrinsics_gen | ||
|
|
||
| LINK_COMPONENTS | ||
| Core | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIRAnalysis | ||
| MLIRControlFlow | ||
| MLIRLLVMCommonConversion | ||
| MLIRLLVMIR | ||
| MLIRPass | ||
| MLIRTransformUtils | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file implements a pass to convert MLIR standard and builtin dialects | ||
| // into the LLVM IR dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" | ||
| #include "../PassDetail.h" | ||
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" | ||
| #include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
| #include "mlir/Conversion/LLVMCommon/VectorPattern.h" | ||
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | ||
| #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" | ||
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
| #include "mlir/IR/BuiltinOps.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include <functional> | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| #define PASS_NAME "convert-cf-to-llvm" | ||
|
|
||
| namespace { | ||
| /// Lower `std.assert`. The default lowering calls the `abort` function if the | ||
| /// assertion is violated and has no effect otherwise. The failure message is | ||
| /// ignored by the default lowering but should be propagated by any custom | ||
| /// lowering. | ||
| struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { | ||
| using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto loc = op.getLoc(); | ||
|
|
||
| // Insert the `abort` declaration if necessary. | ||
| auto module = op->getParentOfType<ModuleOp>(); | ||
| auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); | ||
| if (!abortFunc) { | ||
| OpBuilder::InsertionGuard guard(rewriter); | ||
| rewriter.setInsertionPointToStart(module.getBody()); | ||
| auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); | ||
| abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), | ||
| "abort", abortFuncTy); | ||
| } | ||
|
|
||
| // Split block at `assert` operation. | ||
| Block *opBlock = rewriter.getInsertionBlock(); | ||
| auto opPosition = rewriter.getInsertionPoint(); | ||
| Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); | ||
|
|
||
| // Generate IR to call `abort`. | ||
| Block *failureBlock = rewriter.createBlock(opBlock->getParent()); | ||
| rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None); | ||
| rewriter.create<LLVM::UnreachableOp>(loc); | ||
|
|
||
| // Generate assertion test. | ||
| rewriter.setInsertionPointToEnd(opBlock); | ||
| rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( | ||
| op, adaptor.getArg(), continuationBlock, failureBlock); | ||
|
|
||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| // Base class for LLVM IR lowering terminator operations with successors. | ||
| template <typename SourceOp, typename TargetOp> | ||
| struct OneToOneLLVMTerminatorLowering | ||
| : public ConvertOpToLLVMPattern<SourceOp> { | ||
| using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; | ||
| using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(), | ||
| op->getSuccessors(), op->getAttrs()); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| // FIXME: this should be tablegen'ed as well. | ||
| struct BranchOpLowering | ||
| : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> { | ||
| using Base::Base; | ||
| }; | ||
| struct CondBranchOpLowering | ||
| : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> { | ||
| using Base::Base; | ||
| }; | ||
| struct SwitchOpLowering | ||
| : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> { | ||
| using Base::Base; | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| void mlir::cf::populateControlFlowToLLVMConversionPatterns( | ||
| LLVMTypeConverter &converter, RewritePatternSet &patterns) { | ||
| // clang-format off | ||
| patterns.add< | ||
| AssertOpLowering, | ||
| BranchOpLowering, | ||
| CondBranchOpLowering, | ||
| SwitchOpLowering>(converter); | ||
| // clang-format on | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Pass Definition | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| namespace { | ||
| /// A pass converting MLIR operations into the LLVM IR dialect. | ||
| struct ConvertControlFlowToLLVM | ||
| : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> { | ||
| ConvertControlFlowToLLVM() = default; | ||
|
|
||
| /// Run the dialect converter on the module. | ||
| void runOnOperation() override { | ||
| LLVMConversionTarget target(getContext()); | ||
| RewritePatternSet patterns(&getContext()); | ||
|
|
||
| LowerToLLVMOptions options(&getContext()); | ||
| if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) | ||
| options.overrideIndexBitwidth(indexBitwidth); | ||
|
|
||
| LLVMTypeConverter converter(&getContext(), options); | ||
| mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); | ||
|
|
||
| if (failed(applyPartialConversion(getOperation(), target, | ||
| std::move(patterns)))) | ||
| signalPassFailure(); | ||
| } | ||
| }; | ||
| } // namespace | ||
|
|
||
| std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() { | ||
| return std::make_unique<ConvertControlFlowToLLVM>(); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| add_mlir_conversion_library(MLIRControlFlowToSPIRV | ||
| ControlFlowToSPIRV.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR | ||
|
|
||
| DEPENDS | ||
| MLIRConversionPassIncGen | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIRIR | ||
| MLIRControlFlow | ||
| MLIRPass | ||
| MLIRSPIRV | ||
| MLIRSPIRVConversion | ||
| MLIRSupport | ||
| MLIRTransformUtils | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file implements patterns to convert standard dialect to SPIR-V dialect. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" | ||
| #include "../SPIRVCommon/Pattern.h" | ||
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | ||
| #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | ||
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" | ||
| #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" | ||
| #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" | ||
| #include "mlir/IR/AffineMap.h" | ||
| #include "mlir/Support/LogicalResult.h" | ||
| #include "llvm/ADT/SetVector.h" | ||
| #include "llvm/Support/Debug.h" | ||
|
|
||
| #define DEBUG_TYPE "cf-to-spirv-pattern" | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Operation conversion | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| namespace { | ||
|
|
||
| /// Converts cf.br to spv.Branch. | ||
| struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> { | ||
| using OpConversionPattern<cf::BranchOp>::OpConversionPattern; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(), | ||
| adaptor.getDestOperands()); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| /// Converts cf.cond_br to spv.BranchConditional. | ||
| struct CondBranchOpPattern final | ||
| : public OpConversionPattern<cf::CondBranchOp> { | ||
| using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( | ||
| op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(), | ||
| op.getFalseDest(), adaptor.getFalseDestOperands()); | ||
| return success(); | ||
| } | ||
| }; | ||
| } // namespace | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Pattern population | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| void mlir::cf::populateControlFlowToSPIRVPatterns( | ||
| SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { | ||
| MLIRContext *context = patterns.getContext(); | ||
|
|
||
| patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context); | ||
| } |