Skip to content

Commit

Permalink
[TVMC] use target_host when it is set (apache#6855)
Browse files Browse the repository at this point in the history
* [TVMC] add cl support in tvmc runner

* [TVMC] use target_host when it is set

* Cleanup comment and asssert device type in else case

* add a test for tvmc compiler

* remove unused func
  • Loading branch information
euntaik authored and trevor-m committed Dec 4, 2020
1 parent df19f77 commit 6ce2997
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
Expand Up @@ -178,18 +178,18 @@ def compile_model(
mod = common.convert_graph_layout(mod, alter_layout)

tvm_target = common.target_from_cli(target)
target_host = target_host or ""
target_host = tvm_target if not target_host else target_host

if tuning_records and os.path.exists(tuning_records):
logger.debug("tuning records file provided: %s", tuning_records)
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(opt_level=3):
logger.debug("building relay graph with tuning records")
graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target)
graph_module = relay.build(mod, tvm_target, params=params, target_host=target_host)
else:
with tvm.transform.PassContext(opt_level=3):
logger.debug("building relay graph (no tuning records provided)")
graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target)
graph_module = relay.build(mod, tvm_target, params=params, target_host=target_host)

# Generate output dump files with sources
dump_code = dump_code or []
Expand Down
13 changes: 13 additions & 0 deletions tests/python/driver/tvmc/conftest.py
Expand Up @@ -148,3 +148,16 @@ def imagenet_cat(tmpdir_factory):
np.savez(cat_file_full_path, input=image_data)

return cat_file_full_path


@pytest.fixture(scope="session")
def tflite_mobilenet_v1_0_25_128(tmpdir_factory):
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz"
model_file = download_and_untar(
"{}/{}".format(base_url, model_url),
"mobilenet_v1_0.25_128.tflite",
temp_dir=tmpdir_factory.mktemp("data"),
)

return model_file
18 changes: 18 additions & 0 deletions tests/python/driver/tvmc/test_compiler.py
Expand Up @@ -150,3 +150,21 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
assert type(params) is dict
assert type(dumps) is dict
assert "asm" in dumps.keys()


@tvm.testing.requires_opencl
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
tflite_mobilenet_v1_0_25_128,
target="opencl",
target_host="llvm",
alter_layout="NCHW",
)

# check for output types
assert type(graph) is str
assert type(lib) is tvm.runtime.module.Module
assert type(params) is dict
assert type(dumps) is dict

0 comments on commit 6ce2997

Please sign in to comment.