From 2a0f59cc1f5442241b5ea15dd48bc154ca504a28 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 10 Nov 2024 16:10:46 +0000 Subject: [PATCH 1/3] relax transform update --- bitblas/__init__.py | 2 - bitblas/base/__init__.py | 1 - bitblas/relax/__init__.py | 6 +- bitblas/relax/transform/__init__.py | 7 +- .../transform/apply_fast_tuning.py} | 6 +- examples/.gitignore | 1 + examples/relax_end2end.py | 229 ++++++++++++++++++ 7 files changed, 243 insertions(+), 9 deletions(-) rename bitblas/{base/transform.py => relax/transform/apply_fast_tuning.py} (97%) create mode 100644 examples/.gitignore create mode 100644 examples/relax_end2end.py diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 661556c56..4fecc93d7 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -137,8 +137,6 @@ def remove_tvm_path(path): from .base import ( TileDevice, # noqa: F401 fast_tune, # noqa: F401 - ApplyDefaultSchedule, # noqa: F401 - ApplyFastTuning, # noqa: F401 BlockInfo, # noqa: F401 IterInfo, # noqa: F401 ScheduleRule, # noqa: F401 diff --git a/bitblas/base/__init__.py b/bitblas/base/__init__.py index c6235ea42..0bee489a8 100644 --- a/bitblas/base/__init__.py +++ b/bitblas/base/__init__.py @@ -12,7 +12,6 @@ ) # noqa: F401 from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401 from .schedule_rule import ScheduleRule # noqa: F401 -from .transform import ApplyDefaultSchedule, ApplyFastTuning # noqa: F401 from .utils import fast_tune, fast_tune_with_dynamic_range # noqa: F401 from .roller import * from .arch import CUDA, CDNA # noqa: F401 diff --git a/bitblas/relax/__init__.py b/bitblas/relax/__init__.py index a7230fd9e..5d056b856 100644 --- a/bitblas/relax/__init__.py +++ b/bitblas/relax/__init__.py @@ -1,5 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .transform import AnnotateDecodeInformation, WeightOnlyLayoutPropagation # noqa: F401 +from .transform import ( + WeightOnlyLayoutPropagation, # noqa: F401 + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 +) from .op import tir_interleave_weight # noqa: F401 diff --git a/bitblas/relax/transform/__init__.py b/bitblas/relax/transform/__init__.py index b92f2c0b4..21bd9ba4b 100644 --- a/bitblas/relax/transform/__init__.py +++ b/bitblas/relax/transform/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .annotate_decode_block import AnnotateDecodeInformation -from .weight_only_propagate import WeightOnlyLayoutPropagation +from .weight_only_propagate import WeightOnlyLayoutPropagation # noqa: F401 +from .apply_fast_tuning import ( + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 +) diff --git a/bitblas/base/transform.py b/bitblas/relax/transform/apply_fast_tuning.py similarity index 97% rename from bitblas/base/transform.py rename to bitblas/relax/transform/apply_fast_tuning.py index ec2cbc1e7..873cb6773 100644 --- a/bitblas/base/transform.py +++ b/bitblas/relax/transform/apply_fast_tuning.py @@ -15,9 +15,9 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext, module_pass from tvm.target import Target -from .schedule_rule import ScheduleRule -from ..base.analysis import check_func_with_dynamic -from .utils import fast_tune, fast_tune_with_dynamic_range +from bitblas.base.schedule_rule import ScheduleRule +from bitblas.base.analysis import check_func_with_dynamic +from bitblas.base.utils import fast_tune, fast_tune_with_dynamic_range import logging logger = logging.getLogger(__name__) diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 000000000..1ed8a9f77 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1 @@ +progress/ \ No newline at end of file diff --git a/examples/relax_end2end.py b/examples/relax_end2end.py new file mode 100644 index 000000000..69a2417f2 --- /dev/null +++ b/examples/relax_end2end.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import os +from typing import Dict +import numpy as np # type: ignore +import time +import bitblas +from bitblas import tvm as tvm +import tvm +from tvm import relay, relax, runtime, transform +from tvm.ir.module import IRModule +from tvm.relax.testing import relay_translator, nn +from tvm.target.target import Target +from tvm import dlight as dl +from tvm import relay +import tvm.relay.testing +from tvm.ir.module import IRModule +from bitblas.relax import ApplyDefaultSchedule, ApplyFastTuning +fname = os.path.basename(__file__) +fname = os.path.splitext(fname)[0] +# get current file path +log_path = os.path.dirname(os.path.abspath(__file__)) + "/progress/" + fname + +count = 0 + +bitblas.set_log_level("Debug") + +def write_code(code, path, fname): + global count + fname = str(count) + "." + fname + count += 1 + if not os.path.exists(path): + os.makedirs(path) + fname = os.path.join(path, fname) + with open(fname, "w") as f: + f.write(code) + + +def write_sch(sch, path, fname): + py_fname = fname + ".py" + write_code(sch.mod["main"].script(), path, py_fname) + cu_fname = fname + ".cu" + write_code(sch.mod.astext(), path, cu_fname) + + +def write_mod(mod, path, fname): + py_fname = fname + ".py" + write_code(mod.script(show_meta=False), path, py_fname) + cu_fname = fname + ".cu" + write_code(mod.astext(show_meta_data=False), path, cu_fname) + + +def get_network(name, batch_size, layout="NHWC", dtype="float32"): + """Get the symbol definition and random weight of a network""" + + # auto-scheduler prefers NHWC layout + if layout == "NHWC": + image_shape = (224, 224, 3) + elif layout == "NCHW": + image_shape = (3, 224, 224) + else: + raise ValueError("Invalid layout: " + layout) + + input_shape = (batch_size,) + image_shape + output_shape = (batch_size, 1000) + + if name.startswith("resnet-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name.startswith("resnet3d-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "mobilenet": + mod, params = relay.testing.mobilenet.get_workload( + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape + ) + elif name == "squeezenet_v1.1": + assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" + mod, params = relay.testing.squeezenet.get_workload( + version="1.1", + batch_size=batch_size, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "mlp": + mod, params = relay.testing.mlp.get_workload( + batch_size=batch_size, image_shape=image_shape, dtype=dtype + ) + + return mod, params, input_shape, output_shape + + +# Define the neural network and compilation target. +network = "mlp" +# network = "resnet-18" +batch_size = 128 +layout = "NHWC" +# Path to cross compiler +target = tvm.target.Target("nvidia/nvidia-a100") +dtype = "float32" + +relay_mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) + + +def apply_opt_before_tuning( + relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target +): + with transform.PassContext(opt_level=3): + main_func = relay_mod["main"] + bind_main_func = relay.build_module.bind_params_by_name(main_func, params) + relay_mod = IRModule.from_expr(bind_main_func) + write_mod(relay_mod, log_path, "create_mod") + relay_mod = relay.transform.SimplifyInference()(relay_mod) + write_mod(relay_mod, log_path, "SimplifyInference") + relay_mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(relay_mod) + write_mod(relay_mod, log_path, "ConvertLayout") + relay_mod = relay.transform.FoldConstant()(relay_mod) + write_mod(relay_mod, log_path, "FoldConstant") + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + write_mod(relay_mod, log_path, "FoldScaleAxis") + relay_mod = relay.transform.CanonicalizeOps()(relay_mod) + write_mod(relay_mod, log_path, "CanonicalizeOps") + relay_mod = relay.transform.AlterOpLayout()(relay_mod) + write_mod(relay_mod, log_path, "AlterOpLayout") + relay_mod = relay.transform.FoldConstant()(relay_mod) + write_mod(relay_mod, log_path, "FoldConstant") + + # opt_level=2 and select_impl_strategy are required for avoiding winograd lowering + relax_mod = relay_translator.from_relay(relay_mod["main"], opt_level=2, target=target, append_op_attrs=True, select_impl_strategy="first") + write_mod(relax_mod, log_path, "relay_translator_relax") + relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) + write_mod(relax_mod, log_path, "AnnotateTIROpPattern") + relax_mod = relax.transform.FuseOps()(relax_mod) + write_mod(relax_mod, log_path, "FuseOps") + relax_mod = relax.transform.FuseTIR()(relax_mod) + write_mod(relax_mod, log_path, "FuseTIR") + return relax_mod + + +relax_mod = apply_opt_before_tuning(relay_mod, params, target) +start_tune_time = time.time() +relax_mod = ApplyFastTuning(topk=20, target=target, parallel_build=True)(relax_mod) +end_tune_time = time.time() + +write_mod(relax_mod, log_path, "ApplyFastTuning") +print("Time cost of Fast Dlight tuniing: {:.3f} s".format((end_tune_time - start_tune_time))) + +with target: + schedule_rules = [ + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + ] + for rule in schedule_rules: + relax_mod = ApplyDefaultSchedule(rule)(relax_mod) + +write_mod(relax_mod, log_path, "ApplyFastTuning") + +relax_mod = relax.transform.RunCodegen()(relax_mod) + +write_mod(relax_mod, log_path, "run_codegen") + +relax_mod = tvm.tir.transform.MakePackedAPI()(relax_mod) +write_mod(relax_mod, log_path, "make_packed_api") + +ex = relax.build(relax_mod, target) +write_code(ex.mod.imported_modules[0].imported_modules[0].get_source(), log_path, "tmp.cu") + + +device = tvm.cuda(0) +vm = relax.VirtualMachine(ex, device) + +# init parameters +params = nn.init_params(relax_mod) + +input_args = [] + +input_args.append(tvm.nd.array(np.random.uniform(-1, 1, size=input_shape).astype(dtype), device)) + +res = vm["main"](*input_args) + +print(res) + +device.sync() + +start = time.time() + +for i in range(10): + vm["main"](*input_args) + + +device.sync() + +end = time.time() + +print("Time cost is: ", (end - start) * 100, "ms") From b4754073b711e980820d13d3ceb985e519f880c7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 11 Nov 2024 09:28:47 +0000 Subject: [PATCH 2/3] End2end Fix --- 3rdparty/tvm | 2 +- bitblas/__init__.py | 5 ++++- examples/relax_end2end.py | 40 ++++++++++++++------------------------- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 8847ba9a6..7b325acd5 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8847ba9a6562b08b77d0223a33601f34d8100404 +Subproject commit 7b325acd51b8e1a9ed102e4065f7ba206b88b84a diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 4fecc93d7..ef4986419 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -144,7 +144,10 @@ def remove_tvm_path(path): try_inline, # noqa: F401 try_inline_contiguous_spatial, # noqa: F401 ) - +from .relax import ( + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 +) from . import testing # noqa: F401 from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 diff --git a/examples/relax_end2end.py b/examples/relax_end2end.py index 69a2417f2..dc89b63c9 100644 --- a/examples/relax_end2end.py +++ b/examples/relax_end2end.py @@ -18,20 +18,16 @@ import numpy as np import os from typing import Dict -import numpy as np # type: ignore import time import bitblas from bitblas import tvm as tvm -import tvm from tvm import relay, relax, runtime, transform -from tvm.ir.module import IRModule from tvm.relax.testing import relay_translator, nn from tvm.target.target import Target -from tvm import dlight as dl -from tvm import relay import tvm.relay.testing from tvm.ir.module import IRModule from bitblas.relax import ApplyDefaultSchedule, ApplyFastTuning + fname = os.path.basename(__file__) fname = os.path.splitext(fname)[0] # get current file path @@ -41,6 +37,7 @@ bitblas.set_log_level("Debug") + def write_code(code, path, fname): global count fname = str(count) + "." + fname @@ -80,16 +77,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): input_shape = (batch_size,) + image_shape output_shape = (batch_size, 1000) - if name.startswith("resnet-"): - n_layer = int(name.split("-")[1]) - mod, params = relay.testing.resnet.get_workload( - num_layers=n_layer, - batch_size=batch_size, - layout=layout, - dtype=dtype, - image_shape=image_shape, - ) - elif name.startswith("resnet3d-"): + if name.startswith("resnet-") or name.startswith("resnet3d-"): n_layer = int(name.split("-")[1]) mod, params = relay.testing.resnet.get_workload( num_layers=n_layer, @@ -100,8 +88,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): ) elif name == "mobilenet": mod, params = relay.testing.mobilenet.get_workload( - batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape - ) + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape) elif name == "squeezenet_v1.1": assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" mod, params = relay.testing.squeezenet.get_workload( @@ -115,8 +102,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) elif name == "mlp": mod, params = relay.testing.mlp.get_workload( - batch_size=batch_size, image_shape=image_shape, dtype=dtype - ) + batch_size=batch_size, image_shape=image_shape, dtype=dtype) return mod, params, input_shape, output_shape @@ -133,9 +119,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): relay_mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) -def apply_opt_before_tuning( - relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target -): +def apply_opt_before_tuning(relay_mod: IRModule, params: Dict[str, runtime.NDArray], + target: Target): with transform.PassContext(opt_level=3): main_func = relay_mod["main"] bind_main_func = relay.build_module.bind_params_by_name(main_func, params) @@ -157,7 +142,12 @@ def apply_opt_before_tuning( write_mod(relay_mod, log_path, "FoldConstant") # opt_level=2 and select_impl_strategy are required for avoiding winograd lowering - relax_mod = relay_translator.from_relay(relay_mod["main"], opt_level=2, target=target, append_op_attrs=True, select_impl_strategy="first") + relax_mod = relay_translator.from_relay( + relay_mod["main"], + opt_level=2, + target=target, + append_op_attrs=True, + select_impl_strategy="first") write_mod(relax_mod, log_path, "relay_translator_relax") relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) write_mod(relax_mod, log_path, "AnnotateTIROpPattern") @@ -199,7 +189,6 @@ def apply_opt_before_tuning( ex = relax.build(relax_mod, target) write_code(ex.mod.imported_modules[0].imported_modules[0].get_source(), log_path, "tmp.cu") - device = tvm.cuda(0) vm = relax.VirtualMachine(ex, device) @@ -218,10 +207,9 @@ def apply_opt_before_tuning( start = time.time() -for i in range(10): +for _ in range(10): vm["main"](*input_args) - device.sync() end = time.time() From f23a2ecd934de40e6aa9fde1ebe26b63abdc5db1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 11 Nov 2024 09:30:29 +0000 Subject: [PATCH 3/3] lint fix --- examples/relax_end2end.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/relax_end2end.py b/examples/relax_end2end.py index 9ac4fba86..dc89b63c9 100644 --- a/examples/relax_end2end.py +++ b/examples/relax_end2end.py @@ -37,6 +37,7 @@ bitblas.set_log_level("Debug") + def write_code(code, path, fname): global count fname = str(count) + "." + fname