Skip to content

Commit

Permalink
[VM] Move param bind to OptimizeModule (apache#7451)
Browse files Browse the repository at this point in the history
* [VM] Move param bind to OptimizeModule

* add test to verify the number of free vars after opt

* remove const from OptimizeModule
  • Loading branch information
masahi authored and trevor-m committed Mar 2, 2021
1 parent 4182823 commit d7a72d2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
20 changes: 10 additions & 10 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -892,15 +892,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
}

void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
if (params_.size()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}

exec_ = make_object<Executable>();
targets_ = targets;
target_host_ = target_host;
Expand Down Expand Up @@ -1005,8 +996,17 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
return transform::Sequential(pass_seqs);
}

IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
const Target& target_host) {
if (params_.size()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
Expand Down
3 changes: 1 addition & 2 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ class VMCompiler : public runtime::ModuleNode {
*
* \return The optimized IRModule.
*/
IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets,
const Target& target_host);
IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host);

/*!
* \brief Populate the global function names in a map where the value is used
Expand Down
4 changes: 4 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,10 @@ def test_vm_optimize():
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, target="llvm", params=params)

free_vars = relay.analysis.free_vars(opt_mod["main"].body)
# Paremeters should all be bound, so the only free var is data
assert len(free_vars) == 1


@tvm.testing.uses_gpu
def test_loop_free_var():
Expand Down

0 comments on commit d7a72d2

Please sign in to comment.