[Hexagon] Support template-free meta schedule tuning (apache#12854)
* [Metaschedule] Support template-free tuning on Hexagon

* enable multi threading

* update tests

* black
masahi authored and xinetzone committed Nov 25, 2022
1 parent 6a17d65 commit fa23ae6
Showing 3 changed files with 291 additions and 7 deletions.
57 changes: 55 additions & 2 deletions python/tvm/meta_schedule/default_config.py
@@ -174,10 +174,12 @@ def schedule_rules( # pylint: disable=redefined-outer-name
return sch_rules()
if sch_rules is not None:
raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
if target.kind.name in ["llvm", "hexagon"]:
if target.kind.name == "llvm":
return _DefaultLLVM.schedule_rules()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return _DefaultCUDA.schedule_rules()
if target.kind.name == "hexagon":
return _DefaultHexagon.schedule_rules()
raise ValueError(f"Unsupported target: {target}")


@@ -190,10 +192,12 @@ def postproc( # pylint: disable=redefined-outer-name
return postproc()
if postproc is not None:
raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
if target.kind.name in ["llvm", "hexagon"]:
if target.kind.name == "llvm":
return _DefaultLLVM.postprocs()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
return _DefaultCUDA.postprocs()
if target.kind.name == "hexagon":
return _DefaultHexagon.postprocs()
raise ValueError(f"Unsupported target: {target}")
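For illustration, a minimal sketch of the new dispatch (the positional argument order is an assumption based on the parameter names in this file; leaving the first argument as None falls through to the target-kind branches):

import tvm
from tvm.meta_schedule import default_config

# Same target construction as in the Hexagon tests below.
target = tvm.target.Target(tvm.target.hexagon("v68"), host=tvm.target.hexagon("v68"))

# A "hexagon" target now resolves to the _DefaultHexagon presets instead of the LLVM ones.
hexagon_rules = default_config.schedule_rules(None, target)
hexagon_postprocs = default_config.postproc(None, target)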


@@ -277,6 +281,55 @@ def mutator_probs() -> Dict[Mutator, float]:
}


class _DefaultHexagon:
"""Default tuning configuration for Hexagon."""

@staticmethod
def schedule_rules() -> List[ScheduleRule]:
from tvm.meta_schedule import schedule_rule as M

return [
M.AutoInline(
into_producer=False,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=True,
require_injective=True,
require_ordered=True,
disallow_op=["tir.exp"],
),
M.MultiLevelTilingWideVector(
structure="SRSRS",
vector_length_in_bits=1024,
max_innermost_factor=128,
reuse_read=None,
reuse_write=M.ReuseType(
req="may",
levels=[1, 2],
scope="global",
),
),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=16,
max_vectorize_extent=128,
unroll_max_steps=[0, 16, 64, 512],
unroll_explicit=True,
),
]

@staticmethod
def postprocs() -> List[Postproc]:
from tvm.meta_schedule import postproc as M

return [
M.DisallowDynamicLoop(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
# TODO(masahi): Fix RewriteLayout for link-params=True case
# M.RewriteLayout(),
]


class _DefaultCUDA:
"""Default tuning configuration for CUDA."""

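With the _DefaultHexagon presets above in place, template-free TIR tuning on Hexagon no longer requires user-supplied schedule rules or postprocessors. A minimal usage sketch, assuming tune_tir's sch_rules/postprocs default to None and using the builder/runner helpers from the test added below (workload and hexagon_launcher are placeholders):

import tempfile
import tvm
from tvm import meta_schedule as ms
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner

target_hexagon = tvm.target.hexagon("v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)

with tempfile.TemporaryDirectory() as work_dir:
    sch = ms.tune_tir(
        mod=workload,  # placeholder: e.g. te.create_prim_func(...) for a dense workload
        target=target,
        config=ms.TuneConfig(
            strategy="replay_trace",
            num_trials_per_iter=8,
            max_trials_per_task=8,
            max_trials_global=8,
        ),
        work_dir=work_dir,
        # sch_rules/postprocs omitted: the Hexagon defaults defined above are picked up.
        builder=get_hexagon_local_builder(),
        runner=get_hexagon_rpc_runner(hexagon_launcher, number=10),  # placeholder launcher fixture
    )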
29 changes: 26 additions & 3 deletions python/tvm/meta_schedule/tune.py
@@ -554,6 +554,7 @@ def tune_relay(
postprocs: Optional[FnPostproc] = None,
mutator_probs: Optional[FnMutatorProb] = None,
num_threads: Optional[int] = None,
executor=None,
) -> Union[Module, vm.Executable]:
"""Tune a Relay IRModule with a given target.
@@ -581,6 +582,9 @@ def tune_relay(
The callbacks used during tuning.
backend : str = "graph"
The backend to use for relay compilation (graph / vm).
executor : relay.backend.Executor
The executor to be passed to relay.build(...). In particular, its link-params
attribute affects task extraction and workload database lookup.
Returns
-------
@@ -596,8 +600,23 @@
target = default_config.target(target)
# pylint: enable=protected-access,
# parse the tuning contexts

if executor is None:
executor = relay.backend.Executor("graph")

if "link-params" in executor.attrs:
link_params = executor.attrs["link-params"]
else:
link_params = False

with Profiler.timeit("TaskExtraction"):
extracted_tasks = extract_task_from_relay(mod, target, params)
pass_config = {
"relay.FuseOps.link_params": link_params,
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config)

database = tune_extracted_tasks(
extracted_tasks,
config,
@@ -613,7 +632,7 @@
mutator_probs=mutator_probs,
num_threads=num_threads,
)
relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]

with Profiler.timeit("PostTuningCompilation"):
with target, autotvm_silencer(), database:
with PassContext(
@@ -624,4 +643,8 @@
"relay.backend.tir_converter": "default",
},
):
return relay_build(mod, target=target, params=params)
if backend == "graph":
return relay.build(mod, target=target, params=params, executor=executor)

# Executor is not supported by VM
return relay.vm.compile(mod, target=target, params=params)
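For reference, a short usage sketch of the new executor argument, mirroring the Hexagon test added below (mod, params, target, config, work_dir, builder, and runner are placeholders):

from tvm import meta_schedule as ms
from tvm.relay.backend import Executor

# With link-params=True, the linked constants are reflected both in task
# extraction and in the workload database lookup at build time.
executor = Executor("graph", {"link-params": True})

lib = ms.tune_relay(
    mod=mod,
    params=params,
    target=target,
    config=config,
    work_dir=work_dir,
    builder=builder,
    runner=runner,
    executor=executor,  # forwarded to relay.build(...) when backend == "graph"
)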
212 changes: 210 additions & 2 deletions tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -21,15 +21,20 @@
import tempfile

import tvm.testing
from tvm import te
import tvm.topi.testing
from tvm import te, relay
from tvm import meta_schedule as ms
from tvm.meta_schedule.arg_info import TensorInfo
from tvm.meta_schedule.builder import BuilderInput
from tvm.meta_schedule import postproc, schedule_rule
from tvm.script import tir as T
from tvm.tir import FloatImm
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
from tvm.meta_schedule.runner import RunnerInput
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
from tvm.relay.backend import Executor
from tvm.topi.utils import get_const_tuple
from tvm.meta_schedule.testing import te_workload

MATMUL_N = 16
MATMUL_M = 32
@@ -166,7 +171,6 @@ def verify_dense(sch, target, M, N, K, hexagon_session):
print("%f ms, %f GOPS" % (time_ms, gflops / (time_ms / 1e3)))


@pytest.mark.skip(reason="xgboost not installed on CI")
@tvm.testing.requires_hexagon
def test_vrmpy_dense(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
@@ -209,3 +213,207 @@ def schedule_dense_for_tune(sch):

with hexagon_launcher.start_session() as session:
verify_dense(sch, target, M, N, K, session)


# This is an example of a schedule found by vrmpy auto tensorization.
# It gets 440 GFLOPS on SD888.
@tvm.script.ir_module
class Module_vrmpy_auto_tensorize:
@T.prim_func
def main(
X: T.Buffer[(128, 768), "uint8"],
packedW: T.Buffer[(24, 192, 32, 4), "uint8"],
compute: T.Buffer[(128, 768), "int32"],
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0_0_i1_0_0_fused in T.parallel(
512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}
):
for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1):
with T.block("compute_o_init"):
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init)
j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init)
T.reads()
T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
for i1_1 in T.vectorized(32):
with T.block("compute_init"):
j_i_init = T.axis.spatial(32, i1_1)
T.reads()
T.writes(compute[i, j_o * 32 + j_i_init])
compute[i, j_o * 32 + j_i_init] = 0
for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1):
with T.block("compute_o_update"):
i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2)
j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1)
k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1)
T.reads(
compute[i, j_o * 32 : j_o * 32 + 32],
X[i, k_o * 4 : k_o * 4 + 4],
packedW[j_o, k_o, 0:32, 0:4],
)
T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
A = T.match_buffer(
X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1
)
B = T.match_buffer(
packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1
)
C = T.match_buffer(
compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1
)
A_u8x4: T.uint8x4 = A[0:4]
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32")
C[0:32] = T.call_llvm_pure_intrin(
4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32"
)


@tvm.testing.requires_hexagon
def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
pytest.skip(msg="Tuning on simulator not supported.")

target_hexagon = tvm.target.hexagon("v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)

M, N, K = 128, 768, 768
workload = te.create_prim_func(dense(M, N, K))

sch_rules = [
schedule_rule.MultiLevelTilingWithIntrin(
VRMPY_u8u8i32_INTRIN,
structure="SRSRS",
tile_binds=None,
max_innermost_factor=64,
vector_load_lens=None,
reuse_read=None,
reuse_write=schedule_rule.ReuseType(
req="may",
levels=[1, 2],
scope="global",
),
),
schedule_rule.ParallelizeVectorizeUnroll(
max_jobs_per_core=16,
max_vectorize_extent=128,
unroll_max_steps=[0, 16, 64, 512],
unroll_explicit=True,
),
]

postprocs = [
postproc.RewriteParallelVectorizeUnroll(),
postproc.RewriteReductionBlock(),
postproc.RewriteTensorize(vectorize_init_loop=True),
]

# Set to False to skip tuning and use the hand-written Module_vrmpy_auto_tensorize
# schedule above instead.
if True:
with tempfile.TemporaryDirectory() as work_dir:
config = ms.TuneConfig(
strategy="replay_trace",
num_trials_per_iter=8,
max_trials_per_task=8,
max_trials_global=8,
)

sch = ms.tune_tir(
mod=workload,
target=target,
config=config,
work_dir=work_dir,
sch_rules=lambda: sch_rules,
postprocs=lambda: postprocs,
builder=get_hexagon_local_builder(),
runner=get_hexagon_rpc_runner(hexagon_launcher, number=10),
)
else:
sch = tvm.tir.Schedule(Module_vrmpy_auto_tensorize, debug_mask="all")

with hexagon_launcher.start_session() as session:
verify_dense(sch, target, M, N, K, session)


@tvm.testing.requires_hexagon
def test_conv2d_relay_auto_schedule(hexagon_launcher):
if hexagon_launcher._serial_number == "simulator":
pytest.skip(msg="Tuning on simulator not supported.")

target_hexagon = tvm.target.hexagon("v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
I, O, H, W = 64, 64, 56, 56
kH = kW = 3

strides = (1, 1)
padding = (1, 1)

d_shape = (1, H, W, I)
w_shape = (kH, kW, I, O)
bias_shape = (1, 1, 1, w_shape[3])
out_channel = w_shape[3]

data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
bias = relay.var("bias", shape=bias_shape, dtype="float16")
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(kH, kW),
channels=out_channel,
padding=padding,
strides=strides,
out_dtype="float16",
data_layout="NHWC",
kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(conv2d + bias)

data_np = np.random.randn(*d_shape).astype("float16")
weight_np = np.random.randn(*w_shape).astype("float16")
bias_np = np.random.randn(*bias_shape).astype("float16")
params = {"weight": weight_np, "bias": bias_np}

target_llvm = tvm.target.Target("llvm")

with tvm.transform.PassContext(
opt_level=3,
):
lib_ref = relay.build(mod, target=target_llvm, params=params)

rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0)))

rt_mod_ref.set_input("data", data_np)

rt_mod_ref.run()

ref = rt_mod_ref.get_output(0).numpy()

config = ms.TuneConfig(
strategy="replay_trace",
num_trials_per_iter=8,
max_trials_per_task=8,
max_trials_global=8,
)

with tempfile.TemporaryDirectory() as work_dir:
executor = Executor("graph", {"link-params": True})
lib = ms.tune_relay(
mod=mod,
params=params,
target=target,
config=config,
work_dir=work_dir,
builder=get_hexagon_local_builder(),
runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
executor=executor,
)

with hexagon_launcher.start_session() as session:
rt_mod = session.get_executor_from_factory(lib)

rt_mod.set_input("data", data_np)

rt_mod.run()

out = rt_mod.get_output(0).numpy()
print(np.max(np.abs(ref - out)), np.mean(np.abs(ref - out)))
