Skip to content

Commit

Permalink
[Relay][VM] Fix compilation of If-Elses (apache#5040)
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic authored and zhiics committed Apr 17, 2020
1 parent 84b01c3 commit 7916100
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch);

size_t true_register = last_register_;
// It saves the result of If-Else expression.
auto merge_register = NewRegister();
Emit(Instruction::Move(last_register_, merge_register));
Emit(Instruction::Goto(0));

// Finally store how many instructions there are in the
Expand All @@ -378,7 +380,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
size_t false_register = last_register_;

// In else-branch, override the then-branch register
Emit(Instruction::Move(false_register, true_register));
Emit(Instruction::Move(false_register, merge_register));
// Compute the total number of instructions
// after generating false.
auto after_false = this->instructions_.size();
Expand All @@ -397,7 +399,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// Patch the Goto.
this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

this->last_register_ = true_register;
this->last_register_ = merge_register;
}

void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ def test_simple_if():
# diff
check_result([x_data, y_data], y_data, mod=mod)

def test_multiple_ifs():
mod = tvm.IRModule({})
b = relay.var('b')
v0 = relay.var('v0')
v1 = relay.var('v1')
v2 = relay.var('v2')
v3 = relay.var('v3')
out = relay.Tuple([v2, v3])
out = relay.Let(v3, relay.If(b, v1, v0), out)
out = relay.Let(v2, relay.If(b, v0, v1), out)
out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
fn = relay.Function([b], out)
mod['main'] = fn
ctx = tvm.runtime.ndarray.context('llvm', 0)
vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
res = vmobj_to_list(vm.evaluate()(False))
assert(res == [1, 0])

def test_simple_call():
mod = tvm.IRModule({})
sum_up = relay.GlobalVar('sum_up')
Expand Down

0 comments on commit 7916100

Please sign in to comment.