Skip to content

Commit

Permalink
[spirv] Allow return ops to be in control flow ops
Browse files Browse the repository at this point in the history
Use `getParentOfType<FunctionOp>()` instead of `cast<FuncOp>(getParentOp())`
to avoid crash when return ops are used inside spv.selection/spv.loop.

PiperOrigin-RevId: 273006041
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Oct 5, 2019
1 parent 58e2ead commit c020480
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -1745,7 +1745,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ReturnOp returnOp) {
auto funcOp = cast<FuncOp>(returnOp.getParentOp());
auto funcOp = returnOp.getParentOfType<FuncOp>();
auto numOutputs = funcOp.getType().getNumResults();
if (numOutputs != 0)
return returnOp.emitOpError("cannot be used in functions returning value")
Expand Down Expand Up @@ -1774,7 +1774,7 @@ static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
}

static LogicalResult verify(spirv::ReturnValueOp retValOp) {
auto funcOp = cast<FuncOp>(retValOp.getParentOp());
auto funcOp = retValOp.getParentOfType<FuncOp>();
auto numFnResults = funcOp.getType().getNumResults();
if (numFnResults != 1)
return retValOp.emitOpError(
Expand Down
66 changes: 66 additions & 0 deletions mlir/test/Dialect/SPIRV/control-flow-ops.mlir
Expand Up @@ -459,6 +459,38 @@ func @only_allowed_in_last_block() -> () {
// spv.Return
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @in_selection
func @in_selection(%cond : i1) -> () {
spv.selection {
spv.BranchConditional %cond, ^then, ^merge
^then:
// CHECK: spv.Return
spv.Return
^merge:
spv._merge
}
spv.Return
}

// CHECK-LABEL: func @in_loop
func @in_loop(%cond : i1) -> () {
spv.loop {
spv.Branch ^header
^header:
spv.BranchConditional %cond, ^body, ^merge
^body:
// CHECK: spv.Return
spv.Return
^continue:
spv.Branch ^header
^merge:
spv._merge
}
spv.Return
}

// -----

"foo.function"() ({
// expected-error @+1 {{op must appear in a 'func' block}}
spv.Return
Expand Down Expand Up @@ -486,6 +518,40 @@ func @ret_val() -> (i32) {
spv.ReturnValue %0 : i32
}

// CHECK-LABEL: func @in_selection
func @in_selection(%cond : i1) -> (i32) {
spv.selection {
spv.BranchConditional %cond, ^then, ^merge
^then:
%zero = spv.constant 0 : i32
// CHECK: spv.ReturnValue
spv.ReturnValue %zero : i32
^merge:
spv._merge
}
%one = spv.constant 1 : i32
spv.ReturnValue %one : i32
}

// CHECK-LABEL: func @in_loop
func @in_loop(%cond : i1) -> (i32) {
spv.loop {
spv.Branch ^header
^header:
spv.BranchConditional %cond, ^body, ^merge
^body:
%zero = spv.constant 0 : i32
// CHECK: spv.ReturnValue
spv.ReturnValue %zero : i32
^continue:
spv.Branch ^header
^merge:
spv._merge
}
%one = spv.constant 1 : i32
spv.ReturnValue %one : i32
}

// -----

"foo.function"() ({
Expand Down

0 comments on commit c020480

Please sign in to comment.