diff --git a/3rdparty/tvm b/3rdparty/tvm index 8847ba9a6..7b325acd5 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8847ba9a6562b08b77d0223a33601f34d8100404 +Subproject commit 7b325acd51b8e1a9ed102e4065f7ba206b88b84a diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 4fecc93d7..ef4986419 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -144,7 +144,10 @@ def remove_tvm_path(path): try_inline, # noqa: F401 try_inline_contiguous_spatial, # noqa: F401 ) - +from .relax import ( + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 +) from . import testing # noqa: F401 from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 diff --git a/examples/relax_end2end.py b/examples/relax_end2end.py index 69a2417f2..dc89b63c9 100644 --- a/examples/relax_end2end.py +++ b/examples/relax_end2end.py @@ -18,20 +18,16 @@ import numpy as np import os from typing import Dict -import numpy as np # type: ignore import time import bitblas from bitblas import tvm as tvm -import tvm from tvm import relay, relax, runtime, transform -from tvm.ir.module import IRModule from tvm.relax.testing import relay_translator, nn from tvm.target.target import Target -from tvm import dlight as dl -from tvm import relay import tvm.relay.testing from tvm.ir.module import IRModule from bitblas.relax import ApplyDefaultSchedule, ApplyFastTuning + fname = os.path.basename(__file__) fname = os.path.splitext(fname)[0] # get current file path @@ -41,6 +37,7 @@ bitblas.set_log_level("Debug") + def write_code(code, path, fname): global count fname = str(count) + "." + fname @@ -80,16 +77,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): input_shape = (batch_size,) + image_shape output_shape = (batch_size, 1000) - if name.startswith("resnet-"): - n_layer = int(name.split("-")[1]) - mod, params = relay.testing.resnet.get_workload( - num_layers=n_layer, - batch_size=batch_size, - layout=layout, - dtype=dtype, - image_shape=image_shape, - ) - elif name.startswith("resnet3d-"): + if name.startswith("resnet-") or name.startswith("resnet3d-"): n_layer = int(name.split("-")[1]) mod, params = relay.testing.resnet.get_workload( num_layers=n_layer, @@ -100,8 +88,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): ) elif name == "mobilenet": mod, params = relay.testing.mobilenet.get_workload( - batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape - ) + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape) elif name == "squeezenet_v1.1": assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" mod, params = relay.testing.squeezenet.get_workload( @@ -115,8 +102,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) elif name == "mlp": mod, params = relay.testing.mlp.get_workload( - batch_size=batch_size, image_shape=image_shape, dtype=dtype - ) + batch_size=batch_size, image_shape=image_shape, dtype=dtype) return mod, params, input_shape, output_shape @@ -133,9 +119,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): relay_mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) -def apply_opt_before_tuning( - relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target -): +def apply_opt_before_tuning(relay_mod: IRModule, params: Dict[str, runtime.NDArray], + target: Target): with transform.PassContext(opt_level=3): main_func = relay_mod["main"] bind_main_func = relay.build_module.bind_params_by_name(main_func, params) @@ -157,7 +142,12 @@ def apply_opt_before_tuning( write_mod(relay_mod, log_path, "FoldConstant") # opt_level=2 and select_impl_strategy are required for avoiding winograd lowering - relax_mod = relay_translator.from_relay(relay_mod["main"], opt_level=2, target=target, append_op_attrs=True, select_impl_strategy="first") + relax_mod = relay_translator.from_relay( + relay_mod["main"], + opt_level=2, + target=target, + append_op_attrs=True, + select_impl_strategy="first") write_mod(relax_mod, log_path, "relay_translator_relax") relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) write_mod(relax_mod, log_path, "AnnotateTIROpPattern") @@ -199,7 +189,6 @@ def apply_opt_before_tuning( ex = relax.build(relax_mod, target) write_code(ex.mod.imported_modules[0].imported_modules[0].get_source(), log_path, "tmp.cu") - device = tvm.cuda(0) vm = relax.VirtualMachine(ex, device) @@ -218,10 +207,9 @@ def apply_opt_before_tuning( start = time.time() -for i in range(10): +for _ in range(10): vm["main"](*input_args) - device.sync() end = time.time()