70 changes: 35 additions & 35 deletions mlir/docs/BufferDeallocationInternals.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ Example for breaking the invariant:
```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3()
cf.br ^bb3()
^bb2:
partial_write(%0, %0)
br ^bb3()
cf.br ^bb3()
^bb3():
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
Expand Down Expand Up @@ -74,13 +74,13 @@ untracked allocations are mixed:
func @mixedAllocation(%arg0: i1) {
%0 = memref.alloca() : memref<2xf32> // aliases: %2
%1 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb2:
use(%1)
br ^bb3(%1 : memref<2xf32>)
cf.br ^bb3(%1 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
...
}
Expand Down Expand Up @@ -129,13 +129,13 @@ BufferHoisting pass:

```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32> // aliases: %1
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
Expand All @@ -150,12 +150,12 @@ of code:
```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> // moved to bb0
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
Expand All @@ -175,14 +175,14 @@ func @condBranchDynamicType(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb3(%arg1 : memref<?xf32>)
cf.br ^bb3(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards to the data
// dependency to %0
use(%1)
br ^bb3(%1 : memref<?xf32>)
cf.br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>):
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return
Expand All @@ -201,14 +201,14 @@ allocations have already been placed:
```mlir
func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
// aliases: %2
br ^bb3(%1 : memref<2xf32>)
cf.br ^bb3(%1 : memref<2xf32>)
^bb2:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
…
return
Expand All @@ -233,16 +233,16 @@ result:
```mlir
func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
%1 = memref.alloc() : memref<2xf32>
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
br ^bb3(%3 : memref<2xf32>)
cf.br ^bb3(%3 : memref<2xf32>)
^bb2:
use(%0)
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
br ^bb3(%4 : memref<2xf32>)
cf.br ^bb3(%4 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
…
memref.dealloc %2 : memref<2xf32> // free temp buffer %2
Expand Down Expand Up @@ -273,23 +273,23 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>, // aliases: %3, %4
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb6(%arg1 : memref<?xf32>)
cf.br ^bb6(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
// dependency to %0
// aliases: %2, %3, %4
use(%1)
cond_br %arg0, ^bb3, ^bb4
cf.cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb4:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
br ^bb6(%2 : memref<?xf32>)
cf.br ^bb6(%2 : memref<?xf32>)
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
br ^bb7(%3 : memref<?xf32>)
cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return
Expand All @@ -306,25 +306,25 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
^bb1:
// temp buffer required due to alias %3
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
br ^bb6(%5 : memref<?xf32>)
cf.br ^bb6(%5 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32>
use(%1)
cond_br %arg0, ^bb3, ^bb4
cf.cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb4:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>):
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
memref.dealloc %1 : memref<?xf32>
br ^bb6(%6 : memref<?xf32>)
cf.br ^bb6(%6 : memref<?xf32>)
^bb6(%3: memref<?xf32>):
br ^bb7(%3 : memref<?xf32>)
cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>):
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/Diagnostics.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ A few examples are shown below:
```mlir
// Expect an error on the same line.
func @bad_branch() {
br ^missing // expected-error {{reference to an undefined block}}
cf.br ^missing // expected-error {{reference to an undefined block}}
}
// Expect an error on an adjacent line.
Expand Down
4 changes: 2 additions & 2 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ struct MyTarget : public ConversionTarget {
/// All operations within the GPU dialect are illegal.
addIllegalDialect<GPUDialect>();

/// Mark `std.br` and `std.cond_br` as illegal.
addIllegalOp<BranchOp, CondBranchOp>();
/// Mark `cf.br` and `cf.cond_br` as illegal.
addIllegalOp<cf::BranchOp, cf::CondBranchOp>();
}

/// Implement the default legalization handler to handle operations marked as
Expand Down
5 changes: 3 additions & 2 deletions mlir/docs/Dialects/emitc.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ argument `-declare-variables-at-top`.
Besides operations part of the EmitC dialect, the Cpp targets supports
translating the following operations:

* 'cf' Dialect
* `cf.br`
* `cf.cond_br`
* 'std' Dialect
* `std.br`
* `std.call`
* `std.cond_br`
* `std.constant`
* `std.return`
* 'scf' Dialect
Expand Down
12 changes: 6 additions & 6 deletions mlir/docs/LangRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,21 +391,21 @@ arguments:
```mlir
func @simple(i64, i1) -> i64 {
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
br ^bb3(%a: i64) // Branch passes %a as the argument
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2:
%b = arith.addi %a, %a : i64
br ^bb3(%b: i64) // Branch passes %b as the argument
cf.br ^bb3(%b: i64) // Branch passes %b as the argument
// ^bb3 receives an argument, named %c, from predecessors
// and passes it on to bb4 along with %a. %a is referenced
// directly from its defining operation and is not passed through
// an argument of ^bb3.
^bb3(%c: i64):
br ^bb4(%c, %a : i64, i64)
cf.br ^bb4(%c, %a : i64, i64)
^bb4(%d : i64, %e : i64):
%0 = arith.addi %d, %e : i64
Expand Down Expand Up @@ -525,12 +525,12 @@ Example:
```mlir
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
// This def for %value does not dominate ^bb2
%value = "op.convert"(%a) : (i64) -> i64
br ^bb3(%a: i64) // Branch passes %a as the argument
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2:
accelerator.launch() { // An SSACFG region
Expand Down
16 changes: 8 additions & 8 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -356,24 +356,24 @@ Example output is shown below:
```
//===-------------------------------------------===//
Processing operation : 'std.cond_br'(0x60f000001120) {
"std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
Processing operation : 'cf.cond_br'(0x60f000001120) {
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
* Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' {
* Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
} -> failure : pattern failed to match
* Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' {
** Insert : 'std.br'(0x60b000003690)
** Replace : 'std.cond_br'(0x60f000001120)
* Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' {
** Insert : 'cf.br'(0x60b000003690)
** Replace : 'cf.cond_br'(0x60f000001120)
} -> success : pattern applied successfully
} -> success : pattern matched
//===-------------------------------------------===//
```
This output is describing the processing of a `std.cond_br` operation. We first
This output is describing the processing of a `cf.cond_br` operation. We first
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
`std.cond_br` and replaces it with a `std.br`.
`cf.cond_br` and replaces it with a `cf.br`.
## Debugging
Expand Down
10 changes: 5 additions & 5 deletions mlir/docs/Rationale/Rationale.md
Original file line number Diff line number Diff line change
Expand Up @@ -560,24 +560,24 @@ func @search(%A: memref<?x?xi32>, %S: <?xi32>, %key : i32) {
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
%nj = memref.dim %A, 1 : memref<?x?xi32>
br ^bb1(0)
cf.br ^bb1(0)
^bb1(%j: i32)
%p1 = arith.cmpi "lt", %j, %nj : i32
cond_br %p1, ^bb2, ^bb5
cf.cond_br %p1, ^bb2, ^bb5
^bb2:
%v = affine.load %A[%i, %j] : memref<?x?xi32>
%p2 = arith.cmpi "eq", %v, %key : i32
cond_br %p2, ^bb3(%j), ^bb4
cf.cond_br %p2, ^bb3(%j), ^bb4
^bb3(%j: i32)
affine.store %j, %S[%i] : memref<?xi32>
br ^bb5
cf.br ^bb5
^bb4:
%jinc = arith.addi %j, 1 : i32
br ^bb1(%jinc)
cf.br ^bb1(%jinc)
^bb5:
return
Expand Down
5 changes: 3 additions & 2 deletions mlir/docs/Tutorials/Toy/Ch-6.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ multiple stages by relying on
```c++
mlir::RewritePatternSet patterns(&getContext());
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
mlir::cf::populateSCFToControlFlowConversionPatterns(patterns, &getContext());
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(patterns, &getContext());
// The only remaining operation, to lower from the `toy` dialect, is the
// PrintOp.
Expand Down Expand Up @@ -207,7 +208,7 @@ define void @main() {
%109 = memref.load double, double* %108
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
%111 = add i64 %100, 1
br label %99
cf.br label %99
...
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/includes/img/branch_example_post_move.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion mlir/docs/includes/img/branch_example_pre_move.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones.
RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);

// The only remaining operation to lower from the `toy` dialect, is the
Expand Down
6 changes: 4 additions & 2 deletions mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones.
RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);

// The only remaining operation to lower from the `toy` dialect, is the
Expand Down
35 changes: 35 additions & 0 deletions mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Define conversions from the ControlFlow dialect to the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
#define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H

#include <memory>

namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;

namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
} // namespace cf
} // namespace mlir

#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides patterns to convert ControlFlow dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H

namespace mlir {
class RewritePatternSet;
class SPIRVTypeConverter;

namespace cf {
/// Appends to a pattern list additional patterns for translating ControlFLow
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace cf
} // namespace mlir

#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
Expand All @@ -35,10 +37,10 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
Expand Down
46 changes: 34 additions & 12 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,28 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard", "FuncOp"> {
let dependentDialects = ["math::MathDialect"];
}

//===----------------------------------------------------------------------===//
// ControlFlowToLLVM
//===----------------------------------------------------------------------===//

def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
let summary = "Convert ControlFlow operations to the LLVM dialect";
let description = [{
Convert ControlFlow operations into LLVM IR dialect operations.

If other operations are present and their results are required by the LLVM
IR dialect operations, the pass will fail. Any LLVM IR operations or types
already present in the IR will be kept as is.
}];
let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
];
}

//===----------------------------------------------------------------------===//
// GPUCommon
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -460,6 +482,17 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
let constructor = "mlir::createReconcileUnrealizedCastsPass()";
}

//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//

def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createConvertSCFToCFPass()";
let dependentDialects = ["cf::ControlFlowDialect"];
}

//===----------------------------------------------------------------------===//
// SCFToOpenMP
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -488,17 +521,6 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// SCFToStandard
//===----------------------------------------------------------------------===//

def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
let dependentDialects = ["StandardOpsDialect"];
}

//===----------------------------------------------------------------------===//
// SCFToGPU
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -547,7 +569,7 @@ def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
computation lowering.
}];
let constructor = "mlir::createConvertShapeConstraintsPass()";
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"];
}

//===----------------------------------------------------------------------===//
Expand Down
28 changes: 28 additions & 0 deletions mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- ConvertSCFToControlFlow.h - Pass entrypoint --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_

#include <memory>

namespace mlir {
class Pass;
class RewritePatternSet;

/// Collect a set of patterns to convert SCF operations to CFG branch-based
/// operations within the ControlFlow dialect.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);

/// Creates a pass to convert SCF operations to CFG branch-based operation in
/// the ControlFlow dialect.
std::unique_ptr<Pass> createConvertSCFToCFPass();

} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
31 changes: 0 additions & 31 deletions mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h

This file was deleted.

12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
linalg.generic {
Expand All @@ -40,7 +40,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
Expand All @@ -55,11 +55,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1: // pred: ^bb0
%0 = memref.alloc() : memref<2xf32>
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb2: // pred: ^bb0
%1 = memref.alloc() : memref<2xf32>
linalg.generic {
Expand All @@ -74,7 +74,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%2 = memref.alloc() : memref<2xf32>
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
dealloc %1 : memref<2xf32>
br ^bb3(%2 : memref<2xf32>)
cf.br ^bb3(%2 : memref<2xf32>)
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
dealloc %3 : memref<2xf32>
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_subdirectory(ArmSVE)
add_subdirectory(AMX)
add_subdirectory(Bufferization)
add_subdirectory(Complex)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(GPU)
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/ControlFlow/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ControlFlow/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_mlir_dialect(ControlFlowOps cf ControlFlowOps)
add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc)
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- ControlFlow.h - ControlFlow Dialect ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Dialect.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc"

#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations of the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir {
class PatternRewriter;
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc"

#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
313 changes: 313 additions & 0 deletions mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for the operations within the ControlFlow
// dialect.
//
//===----------------------------------------------------------------------===//

#ifndef STANDARD_OPS
#define STANDARD_OPS

include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
let dependentDialects = ["arith::ArithmeticDialect"];
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
on SSA blocks of a control flow graph.
}];
}

class CF_Op<string mnemonic, list<Trait> traits = []> :
Op<ControlFlow_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//

def AssertOp : CF_Op<"assert"> {
let summary = "Assert operation with message attribute";
let description = [{
Assert operation with single boolean operand and an error message attribute.
If the argument is `true` this operation has no effect. Otherwise, the
program execution will abort. The provided error message may be used by a
runtime to propagate the error to the user.

Example:

```mlir
assert %b, "Expected ... to be true"
```
}];

let arguments = (ins I1:$arg, StrAttr:$msg);

let assemblyFormat = "$arg `,` $msg attr-dict";
let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

def BranchOp : CF_Op<"br", [
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator
]> {
let summary = "branch operation";
let description = [{
The `cf.br` operation represents a direct branch operation to a given
block. The operands of this operation are forwarded to the successor block,
and the number and type of the operands must match the arguments of the
target block.

Example:

```mlir
^bb2:
%2 = call @someFn()
cf.br ^bb3(%2 : tensor<*xf32>)
^bb3(%3: tensor<*xf32>):
```
}];

let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);

let builders = [
OpBuilder<(ins "Block *":$dest,
CArg<"ValueRange", "{}">:$destOperands), [{
$_state.addSuccessors(dest);
$_state.addOperands(destOperands);
}]>];

let extraClassDeclaration = [{
void setDest(Block *block);

/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];

let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}

//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//

def CondBranchOp : CF_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The `cond_br` terminator operation represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination
is jumped to; if it is false, the second destination is chosen. The count
and types of operands must align with the arguments in the corresponding
target blocks.

The MLIR conditional branch operation is not allowed to target the entry
block for a region. The two destinations of the conditional branch operation
are allowed to be the same.

The following example illustrates a function with a conditional branch
operation that targets the same block.

Example:

```mlir
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
// Both targets are the same, operands differ
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)

^bb1(%x : i32) :
return %x : i32
}
```
}];

let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);

let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];

let extraClassDeclaration = [{
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };

// Accessors for operands to the 'true' destination.
Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}

void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}

unsigned getNumTrueOperands() { return getTrueOperands().size(); }

/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getTrueDestOperandsMutable().erase(index);
}

// Accessors for operands to the 'false' destination.
Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}

operand_range getTrueOperands() { return getTrueDestOperands(); }
operand_range getFalseOperands() { return getFalseDestOperands(); }

unsigned getNumFalseOperands() { return getFalseOperands().size(); }

/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getFalseDestOperandsMutable().erase(index);
}

private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }

/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
}];

let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

def SwitchOp : CF_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "switch operation";
let description = [{
The `switch` terminator operation represents a switch on a signless integer
value. If the flag matches one of the specified cases, then the
corresponding destination is jumped to. If the flag does not match any of
the cases, the default destination is jumped to. The count and types of
operands must align with the arguments in the corresponding target blocks.

Example:

```mlir
switch %flag : i32, [
default: ^bb1(%a : i32),
42: ^bb1(%b : i32),
43: ^bb3(%c : i32)
]
```
}];

let arguments = (ins
AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
I32ElementsAttr:$case_operand_segments
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
];

let assemblyFormat = [{
$flag `:` type($flag) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
$defaultOperands,
type($defaultOperands),
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands))
`]`
attr-dict
}];

let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index) {
return getCaseOperands()[index];
}

/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return getCaseOperandsMutable()[index];
}
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

#endif // STANDARD_OPS
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SCF/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
affine.for %i = 0 to 100 {
"foo"() : () -> ()
%v = scf.execute_region -> i64 {
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2

^bb1:
%c1 = arith.constant 1 : i64
br ^bb3(%c1 : i64)
cf.br ^bb3(%c1 : i64)

^bb2:
%c2 = arith.constant 2 : i64
br ^bb3(%c2 : i64)
cf.br ^bb3(%c2 : i64)

^bb3(%x : i64):
scf.yield %x : i64
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
Expand All @@ -24,7 +24,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
Expand Down
279 changes: 1 addition & 278 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"

def StandardOps_Dialect : Dialect {
let name = "std";
let cppNamespace = "::mlir";
let dependentDialects = ["arith::ArithmeticDialect"];
let dependentDialects = ["cf::ControlFlowDialect"];
let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
Expand All @@ -42,78 +41,6 @@ class Std_Op<string mnemonic, list<Trait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//

def AssertOp : Std_Op<"assert"> {
let summary = "Assert operation with message attribute";
let description = [{
Assert operation with single boolean operand and an error message attribute.
If the argument is `true` this operation has no effect. Otherwise, the
program execution will abort. The provided error message may be used by a
runtime to propagate the error to the user.

Example:

```mlir
assert %b, "Expected ... to be true"
```
}];

let arguments = (ins I1:$arg, StrAttr:$msg);

let assemblyFormat = "$arg `,` $msg attr-dict";
let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

def BranchOp : Std_Op<"br",
[DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "branch operation";
let description = [{
The `br` operation represents a branch operation in a function.
The operation takes variable number of operands and produces no results.
The operand number and types for each successor must match the arguments of
the block successor.

Example:

```mlir
^bb2:
%2 = call @someFn()
br ^bb3(%2 : tensor<*xf32>)
^bb3(%3: tensor<*xf32>):
```
}];

let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);

let builders = [
OpBuilder<(ins "Block *":$dest,
CArg<"ValueRange", "{}">:$destOperands), [{
$_state.addSuccessors(dest);
$_state.addOperands(destOperands);
}]>];

let extraClassDeclaration = [{
void setDest(Block *block);

/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];

let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -246,121 +173,6 @@ def CallIndirectOp : Std_Op<"call_indirect", [
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
}

//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//

def CondBranchOp : Std_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The `cond_br` terminator operation represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination
is jumped to; if it is false, the second destination is chosen. The count
and types of operands must align with the arguments in the corresponding
target blocks.

The MLIR conditional branch operation is not allowed to target the entry
block for a region. The two destinations of the conditional branch operation
are allowed to be the same.

The following example illustrates a function with a conditional branch
operation that targets the same block.

Example:

```mlir
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
// Both targets are the same, operands differ
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)

^bb1(%x : i32) :
return %x : i32
}
```
}];

let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);

let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];

let extraClassDeclaration = [{
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };

// Accessors for operands to the 'true' destination.
Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}

void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}

unsigned getNumTrueOperands() { return getTrueOperands().size(); }

/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getTrueDestOperandsMutable().erase(index);
}

// Accessors for operands to the 'false' destination.
Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}

operand_range getTrueOperands() { return getTrueDestOperands(); }
operand_range getFalseOperands() { return getFalseDestOperands(); }

unsigned getNumFalseOperands() { return getFalseOperands().size(); }

/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getFalseDestOperandsMutable().erase(index);
}

private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }

/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
}];

let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -443,93 +255,4 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

def SwitchOp : Std_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "switch operation";
let description = [{
The `switch` terminator operation represents a switch on a signless integer
value. If the flag matches one of the specified cases, then the
corresponding destination is jumped to. If the flag does not match any of
the cases, the default destination is jumped to. The count and types of
operands must align with the arguments in the corresponding target blocks.

Example:

```mlir
switch %flag : i32, [
default: ^bb1(%a : i32),
42: ^bb1(%b : i32),
43: ^bb3(%c : i32)
]
```
}];

let arguments = (ins
AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
I32ElementsAttr:$case_operand_segments
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
];

let assemblyFormat = [{
$flag `:` type($flag) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
$defaultOperands,
type($defaultOperands),
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands))
`]`
attr-dict
}];

let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index) {
return getCaseOperands()[index];
}

/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return getCaseOperandsMutable()[index];
}
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

#endif // STANDARD_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arm_neon::ArmNeonDialect,
async::AsyncDialect,
bufferization::BufferizationDialect,
cf::ControlFlowDialect,
complex::ComplexDialect,
DLTIDialect,
emitc::EmitCDialect,
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToStandard)
add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(GPUCommon)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
Expand All @@ -25,10 +27,10 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Conversion/ControlFlowToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
add_mlir_conversion_library(MLIRControlFlowToLLVM
ControlFlowToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM

DEPENDS
MLIRConversionPassIncGen
intrinsics_gen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRAnalysis
MLIRControlFlow
MLIRLLVMCommonConversion
MLIRLLVMIR
MLIRPass
MLIRTransformUtils
)
148 changes: 148 additions & 0 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <functional>

using namespace mlir;

#define PASS_NAME "convert-cf-to-llvm"

namespace {
/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();

// Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}

// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

// Generate IR to call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
rewriter.create<LLVM::UnreachableOp>(loc);

// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getArg(), continuationBlock, failureBlock);

return success();
}
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
return success();
}
};

// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
using Base::Base;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
using Base::Base;
};
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
using Base::Base;
};

} // namespace

void mlir::cf::populateControlFlowToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
// clang-format on
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct ConvertControlFlowToLLVM
: public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
ConvertControlFlowToLLVM() = default;

/// Run the dialect converter on the module.
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());

LowerToLLVMOptions options(&getContext());
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);

LLVMTypeConverter converter(&getContext(), options);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
return std::make_unique<ConvertControlFlowToLLVM>();
}
19 changes: 19 additions & 0 deletions mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRControlFlowToSPIRV
ControlFlowToSPIRV.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRControlFlow
MLIRPass
MLIRSPIRV
MLIRSPIRVConversion
MLIRSupport
MLIRTransformUtils
)
73 changes: 73 additions & 0 deletions mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert standard dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "cf-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

namespace {

/// Converts cf.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
}
};

/// Converts cf.cond_br to spv.BranchConditional.
struct CondBranchOpPattern final
: public OpConversionPattern<cf::CondBranchOp> {
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
op.getFalseDest(), adaptor.getFalseDestOperands());
return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

void mlir::cf::populateControlFlowToSPIRVPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();

patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
}
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
Expand Down Expand Up @@ -172,8 +173,8 @@ struct LowerGpuOpsToNVVMOpsPass
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));

mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
llvmPatterns);
arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ add_mlir_conversion_library(MLIRLinalgToLLVM
MLIRLLVMCommonConversion
MLIRLLVMIR
MLIRMemRefToLLVM
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRTransforms
MLIRVectorToLLVM
MLIRVectorToSCF
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// | cf.br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
Expand All @@ -444,7 +444,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | | cf.cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
Expand Down Expand Up @@ -66,7 +67,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
// Convert to OpenMP operations with LLVM IR dialect
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
populateMemRefToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ namespace arith {
class ArithmeticDialect;
} // namespace arith

namespace cf {
class ControlFlowDialect;
} // namespace cf

namespace complex {
class ComplexDialect;
} // namespace complex
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
add_mlir_conversion_library(MLIRSCFToStandard
SCFToStandard.cpp
add_mlir_conversion_library(MLIRSCFToControlFlow
SCFToControlFlow.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow

DEPENDS
MLIRConversionPassIncGen
Expand All @@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRSCFToStandard

LINK_LIBS PUBLIC
MLIRArithmetic
MLIRControlFlow
MLIRSCF
MLIRTransforms
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -11,11 +11,11 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand All @@ -29,7 +29,8 @@ using namespace mlir::scf;

namespace {

struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
struct SCFToControlFlowPass
: public SCFToControlFlowBase<SCFToControlFlowPass> {
void runOnOperation() override;
};

Expand Down Expand Up @@ -57,15 +58,15 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | <code before the ForOp> |
// | <definitions of %init...> |
// | <compute initial %iv value> |
// | br cond(%iv, %init...) |
// | cf.br cond(%iv, %init...) |
// +---------------------------------+
// |
// -------| |
// | v v
// | +--------------------------------+
// | | cond(%iv, %init...): |
// | | <compare %iv to upper bound> |
// | | cond_br %r, body, end |
// | | cf.cond_br %r, body, end |
// | +--------------------------------+
// | | |
// | | -------------|
Expand All @@ -83,7 +84,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | | <body contents> | |
// | | <operands of yield = %yields>| |
// | | %new_iv =<add step to %iv> | |
// | | br cond(%new_iv, %yields) | |
// | | cf.br cond(%new_iv, %yields) | |
// | +--------------------------------+ |
// | | |
// |----------- |--------------------
Expand Down Expand Up @@ -125,23 +126,23 @@ struct ForLowering : public OpRewritePattern<ForOp> {
//
// +--------------------------------+
// | <code before the IfOp> |
// | cond_br %cond, %then, %else |
// | cf.cond_br %cond, %then, %else |
// +--------------------------------+
// | |
// | --------------|
// v |
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br continue | |
// | cf.br continue | |
// +--------------------------------+ |
// | |
// |---------- |-------------
// | V
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br continue |
// | | cf.br continue |
// | +--------------------------------+
// | |
// ------| |
Expand All @@ -155,30 +156,30 @@ struct ForLowering : public OpRewritePattern<ForOp> {
//
// +--------------------------------+
// | <code before the IfOp> |
// | cond_br %cond, %then, %else |
// | cf.cond_br %cond, %then, %else |
// +--------------------------------+
// | |
// | --------------|
// v |
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br dom(%args...) | |
// | cf.br dom(%args...) | |
// +--------------------------------+ |
// | |
// |---------- |-------------
// | V
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br dom(%args...) |
// | | cf.br dom(%args...) |
// | +--------------------------------+
// | |
// ------| |
// v v
// +--------------------------------+
// | dom(%args...): |
// | br continue |
// | cf.br continue |
// +--------------------------------+
// |
// v
Expand Down Expand Up @@ -218,7 +219,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
///
/// +---------------------------------+
/// | <code before the WhileOp> |
/// | br ^before(%operands...) |
/// | cf.br ^before(%operands...) |
/// +---------------------------------+
/// |
/// -------| |
Expand All @@ -233,7 +234,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+
/// | | ^before-last:
/// | | %cond = <compute condition> |
/// | | cond_br %cond, |
/// | | cf.cond_br %cond, |
/// | | ^after(%vals...), ^cont |
/// | +--------------------------------+
/// | | |
Expand All @@ -249,7 +250,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+ |
/// | | ^after-last: | |
/// | | %yields... = <some payload> | |
/// | | br ^before(%yields...) | |
/// | | cf.br ^before(%yields...) | |
/// | +--------------------------------+ |
/// | | |
/// |----------- |--------------------
Expand Down Expand Up @@ -321,7 +322,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
SmallVector<Value, 8> loopCarried;
loopCarried.push_back(stepped);
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.eraseOp(terminator);

// Compute loop bounds before branching to the condition.
Expand All @@ -337,15 +338,16 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
destOperands.push_back(lowerBound);
auto iterOperands = forOp.getIterOperands();
destOperands.append(iterOperands.begin(), iterOperands.end());
rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);

// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, iv, upperBound);

rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock,
ArrayRef<Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
Expand All @@ -369,7 +371,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
SmallVector<Location>(ifOp.getNumResults(), loc));
rewriter.create<BranchOp>(loc, remainingOpsBlock);
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
}

// Move blocks from the "then" region to the region containing 'scf.if',
Expand All @@ -379,7 +381,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock);

Expand All @@ -393,15 +395,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}

rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
/*falseArgs=*/ArrayRef<Value>());
rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
/*falseArgs=*/ArrayRef<Value>());

// Ok, we're done!
rewriter.replaceOp(ifOp, continueBlock->getArguments());
Expand All @@ -419,13 +421,13 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,

auto &region = op.getRegion();
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<BranchOp>(loc, &region.front());
rewriter.create<cf::BranchOp>(loc, &region.front());

for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block);
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.eraseOp(terminator);
}
}
Expand Down Expand Up @@ -538,20 +540,21 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,

// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(loc, before, whileOp.getInits());
rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());

// Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());

rewriter.setInsertionPointToEnd(afterLast);
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.getResults());
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
yieldOp.getResults());

// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
Expand Down Expand Up @@ -593,14 +596,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,

// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());

// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(),
continuation, ValueRange());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(),
continuation, ValueRange());

// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
Expand All @@ -609,17 +612,18 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}

void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
ExecuteRegionLowering>(patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}

void SCFToStandardPass::runOnOperation() {
void SCFToControlFlowPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLoopToStdConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
populateSCFToControlFlowConversionPatterns(patterns);

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
scf::ExecuteRegionOp>();
Expand All @@ -629,6 +633,6 @@ void SCFToStandardPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
return std::make_unique<SCFToStandardPass>();
std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
return std::make_unique<SCFToControlFlowPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
Expand All @@ -29,7 +30,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override {
rewriter.create<AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRArithmeticToLLVM
MLIRControlFlowToLLVM
MLIRDataLayoutInterfaces
MLIRLLVMCommonConversion
MLIRLLVMIR
Expand Down
81 changes: 3 additions & 78 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
Expand Down Expand Up @@ -387,48 +388,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
}
};

/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();

// Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}

// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

// Generate IR to call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
rewriter.create<LLVM::UnreachableOp>(loc);

// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getArg(), continuationBlock, failureBlock);

return success();
}
};

struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -550,22 +509,6 @@ struct UnrealizedConversionCastOpLowering
}
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
return success();
}
};

// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
Expand Down Expand Up @@ -633,21 +576,6 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
return success();
}
};

// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
using Super::Super;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
using Super::Super;
};
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
using Super::Super;
};

} // namespace

void mlir::populateStdToLLVMFuncOpConversionPattern(
Expand All @@ -663,14 +591,10 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,
CondBranchOpLowering,
ConstantOpLowering,
ReturnOpLowering,
SwitchOpLowering>(converter);
ReturnOpLowering>(converter);
// clang-format on
}

Expand Down Expand Up @@ -721,6 +645,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
RewritePatternSet patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV

LINK_LIBS PUBLIC
MLIRArithmeticToSPIRV
MLIRControlFlowToSPIRV
MLIRIR
MLIRMathToSPIRV
MLIRMemRef
Expand Down
46 changes: 1 addition & 45 deletions mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,6 @@ class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Converts std.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
using OpConversionPattern<BranchOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Converts std.cond_br to spv.BranchConditional.
struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
using OpConversionPattern<CondBranchOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Converts tensor.extract into loading using access chains from SPIR-V local
/// variables.
class TensorExtractPattern final
Expand Down Expand Up @@ -146,31 +128,6 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
return success();
}

//===----------------------------------------------------------------------===//
// BranchOpPattern
//===----------------------------------------------------------------------===//

LogicalResult
BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
}

//===----------------------------------------------------------------------===//
// CondBranchOpPattern
//===----------------------------------------------------------------------===//

LogicalResult CondBranchOpPattern::matchAndRewrite(
CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
op.getFalseDest(), adaptor.getFalseDestOperands());
return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
Expand All @@ -189,8 +146,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,

ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter,
context);
ReturnOpPattern>(typeConverter, context);
}

void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
Expand Down Expand Up @@ -40,9 +41,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

// TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV
// TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to
// StandardToSPIRV
RewritePatternSet patterns(context);
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
populateMathToSPIRVPatterns(typeConverter, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns);
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -169,11 +170,11 @@ class AsyncRuntimeRefCountingPass
///
/// ^entry:
/// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^bb2
/// cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1:
/// async.runtime.await %token
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^bb2:
/// return
///
Expand All @@ -185,14 +186,14 @@ class AsyncRuntimeRefCountingPass
///
/// ^entry:
/// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^reference_counting
/// cf.cond_br %cond, ^bb1, ^reference_counting
/// ^bb1:
/// async.runtime.await %token
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^reference_counting:
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^bb2:
/// return
///
Expand All @@ -208,7 +209,7 @@ class AsyncRuntimeRefCountingPass
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
/// ^resume:
/// %0 = async.runtime.load %value
/// br ^cleanup
/// cf.br ^cleanup
/// ^cleanup:
/// ...
/// ^suspend:
Expand Down Expand Up @@ -406,7 +407,7 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
refCountingBlock = &successor->getParent()->emplaceBlock();
refCountingBlock->moveBefore(successor);
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
builder.create<BranchOp>(value.getLoc(), successor);
builder.create<cf::BranchOp>(value.getLoc(), successor);
}

OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
Expand Down
Loading