From 7b6b7318c8af2a0afa13a4cdbd7717485197c891 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Sat, 7 Nov 2020 11:10:25 -0800
Subject: [PATCH] [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)

* serialize task and dag

* fix test

* more tests

* format

* format

* format

* trigger ci
---
 python/tvm/auto_scheduler/compute_dag.py       | 23 ++++++++++----
 python/tvm/auto_scheduler/search_task.py       | 29 ++++++++++++++++++
 .../unittest/test_auto_scheduler_common.py     |  4 +--
 .../test_auto_scheduler_compute_dag.py         | 30 ++++++++++++++++++-
 4 files changed, 77 insertions(+), 9 deletions(-)

diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d50ff395b679..9390a9c4589a 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -21,14 +21,14 @@
 
 import tvm._ffi
 from tvm.runtime import Object
-from tvm.te import PlaceholderOp, ComputeOp
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
+from tvm.te import ComputeOp, PlaceholderOp
 
+from . import _ffi_api
 from .loop_state import State, StateObject
 from .utils import get_const_tuple
 from .workload_registry import workload_key_to_tensors
 
-from . import _ffi_api
-
 
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
@@ -63,7 +63,10 @@ def __init__(self, compute_or_sche):
         elif isinstance(compute_or_sche, list):
             for item in compute_or_sche:
                 if not isinstance(item, tvm.te.Tensor):
-                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+                    raise ValueError(
+                        "The input of ComputeDAG should be a list of Tensor, but got %s"
+                        % type(item)
+                    )
             compute = compute_or_sche
             sche = None
         elif isinstance(compute_or_sche, tvm.te.Schedule):
@@ -72,8 +75,10 @@ def __init__(self, compute_or_sche):
         else:
             raise ValueError(
                 "Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
-                % type(compute)
+                % type(compute_or_sche)
             )
+        self.compute = compute
+        self.sche = sche
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
 
     def get_init_state(self):
@@ -182,3 +187,11 @@ def hash_key(self):
             str_key = str_key.encode(encoding="utf-8")
 
         return hashlib.md5(str_key).hexdigest()
+
+    def __getstate__(self):
+        return {"compute": SaveJSON(self.compute), "sche": SaveJSON(self.sche)}
+
+    def __setstate__(self, state):
+        self.compute = LoadJSON(state["compute"])  # pylint: disable=assignment-from-no-return
+        self.sche = LoadJSON(state["sche"])  # pylint: disable=assignment-from-no-return
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 92c4f48bf371..7c5021b3f9b7 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -42,6 +42,35 @@ class SearchTask(Object):
     """
 
     def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+        self.dag = dag
+        self.workload_key = workload_key
+        self.target = target
+        self.target_host = target_host
+        self.hardware_params = hardware_params
         self.__init_handle_by_constructor__(
             _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
         )
+
+    def __getstate__(self):
+        return {
+            "dag": self.dag,
+            "workload_key": self.workload_key,
+            "target": self.target,
+            "target_host": self.target_host,
+            "hardware_params": self.hardware_params,
+        }
+
+    def __setstate__(self, state):
+        self.dag = state["dag"]
+        self.workload_key = state["workload_key"]
+        self.target = state["target"]
+        self.target_host = state["target_host"]
+        self.hardware_params = state["hardware_params"]
+        self.__init_handle_by_constructor__(
+            _ffi_api.SearchTask,
+            self.dag,
+            self.workload_key,
+            self.target,
+            self.target_host,
+            self.hardware_params,
+        )
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 6a3fe4e82c99..5b7add9733de 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -161,14 +161,12 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
     r = KW
     m = tile_size
     alpha = m + r - 1
-    A, B, G = winograd_transform_matrices(m, r, "float32")
+    A, B, _ = winograd_transform_matrices(m, r, "float32")
 
     H = (H + 2 * HPAD - KH) // HSTR + 1
     W = (W + 2 * WPAD - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
-    r_kh = te.reduce_axis((0, KH), name="r_kh")
-    r_kw = te.reduce_axis((0, KW), name="r_kw")
     kshape = (alpha, alpha, CI, CO)
     kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
 
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 2ccedef9e2de..e7774753796c 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -16,6 +16,7 @@
 # under the License.
"""Test ComputeDAG (replay, infer bound)""" +import pickle import tvm from tvm import topi @@ -32,7 +33,7 @@ def test_apply_steps(): dag, s = get_tiled_matmul() dag.print_python_code_from_state(s) sch, tensors = dag.apply_steps_from_state(s) - stmt = tvm.lower(sch, tensors, simple_mode=True) + tvm.lower(sch, tensors, simple_mode=True) def test_infer_bound(): @@ -61,6 +62,7 @@ def test_estimate_flop(): def test_stage_order(): + """Test if the stage order is preserved when recovering a DAG.""" N = 512 A, B, C, D, E = parallel_matmul_auto_scheduler_test(N) sch = te.create_schedule([D.op, E.op]) @@ -87,6 +89,11 @@ def test_stage_order(): elif op.name in ["B", "C"]: assert stage_ops_1[idx + 1].name == "%s.shared" % op.name + # Serialize and deserialize the ComputeDAG constructed by a schedule. + loaded_dag = pickle.loads(pickle.dumps(dag)) + assert str(loaded_dag.get_init_state()) == str(dag.get_init_state()) + assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops) + # Apply the same schedule to Ansor state and it should have the same stage order dag = auto_scheduler.ComputeDAG([A, B, C, D, E]) state = dag.get_init_state() @@ -105,6 +112,27 @@ def test_stage_order(): for op1, op2 in zip(stage_ops_1, stage_ops_2): assert op1.name == op2.name + # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops. + loaded_dag = pickle.loads(pickle.dumps(dag)) + assert str(loaded_dag.get_init_state()) == str(dag.get_init_state()) + assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops) + + # Serialize and deserialize the search task. + task = auto_scheduler.SearchTask( + dag, + "test1", + tvm.target.Target("llvm"), + hardware_params=auto_scheduler.HardwareParams(100000, 16, 64), + ) + task2 = pickle.loads(pickle.dumps(task)) + assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state()) + assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops) + assert task.workload_key == task2.workload_key + assert str(task.target) == str(task2.target) + assert task.hardware_params.num_cores == task2.hardware_params.num_cores + assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes + assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes + if __name__ == "__main__": test_apply_steps()