Skip to content

Commit

Permalink
[AutoScheduler] Make SearchTask and ComputeDAG serializable (apache#6842
Browse files Browse the repository at this point in the history
)

* serialize task and dag

* fix test

* more tests

* format

* format

* format

* trigger ci
  • Loading branch information
comaniac authored and trevor-m committed Dec 4, 2020
1 parent 13a8e42 commit 7b6b731
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 9 deletions.
23 changes: 18 additions & 5 deletions python/tvm/auto_scheduler/compute_dag.py
Expand Up @@ -21,14 +21,14 @@

import tvm._ffi
from tvm.runtime import Object
from tvm.te import PlaceholderOp, ComputeOp
from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
from tvm.te import ComputeOp, PlaceholderOp

from . import _ffi_api
from .loop_state import State, StateObject
from .utils import get_const_tuple
from .workload_registry import workload_key_to_tensors

from . import _ffi_api


@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
Expand Down Expand Up @@ -63,7 +63,10 @@ def __init__(self, compute_or_sche):
elif isinstance(compute_or_sche, list):
for item in compute_or_sche:
if not isinstance(item, tvm.te.Tensor):
raise ValueError("The input of ComputeDAG should be a list of Tensor")
raise ValueError(
"The input of ComputeDAG should be a list of Tensor, but got %s"
% type(item)
)
compute = compute_or_sche
sche = None
elif isinstance(compute_or_sche, tvm.te.Schedule):
Expand All @@ -72,8 +75,10 @@ def __init__(self, compute_or_sche):
else:
raise ValueError(
"Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
% type(compute)
% type(compute_or_sche)
)
self.compute = compute
self.sche = sche
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)

def get_init_state(self):
Expand Down Expand Up @@ -182,3 +187,11 @@ def hash_key(self):

str_key = str_key.encode(encoding="utf-8")
return hashlib.md5(str_key).hexdigest()

def __getstate__(self):
return {"compute": SaveJSON(self.compute), "sche": SaveJSON(self.sche)}

def __setstate__(self, state):
self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return
self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)
29 changes: 29 additions & 0 deletions python/tvm/auto_scheduler/search_task.py
Expand Up @@ -42,6 +42,35 @@ class SearchTask(Object):
"""

def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
self.dag = dag
self.workload_key = workload_key
self.target = target
self.target_host = target_host
self.hardware_params = hardware_params
self.__init_handle_by_constructor__(
_ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
)

def __getstate__(self):
return {
"dag": self.dag,
"workload_key": self.workload_key,
"target": self.target,
"target_host": self.target_host,
"hardware_params": self.hardware_params,
}

def __setstate__(self, state):
self.dag = state["dag"]
self.workload_key = state["workload_key"]
self.target = state["target"]
self.target_host = state["target_host"]
self.hardware_params = state["hardware_params"]
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
self.dag,
self.workload_key,
self.target,
self.target_host,
self.hardware_params,
)
4 changes: 1 addition & 3 deletions tests/python/unittest/test_auto_scheduler_common.py
Expand Up @@ -161,14 +161,12 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
r = KW
m = tile_size
alpha = m + r - 1
A, B, G = winograd_transform_matrices(m, r, "float32")
A, B, _ = winograd_transform_matrices(m, r, "float32")

H = (H + 2 * HPAD - KH) // HSTR + 1
W = (W + 2 * WPAD - KW) // WSTR + 1
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW
r_kh = te.reduce_axis((0, KH), name="r_kh")
r_kw = te.reduce_axis((0, KW), name="r_kw")
kshape = (alpha, alpha, CI, CO)
kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")

Expand Down
30 changes: 29 additions & 1 deletion tests/python/unittest/test_auto_scheduler_compute_dag.py
Expand Up @@ -16,6 +16,7 @@
# under the License.

"""Test ComputeDAG (replay, infer bound)"""
import pickle

import tvm
from tvm import topi
Expand All @@ -32,7 +33,7 @@ def test_apply_steps():
dag, s = get_tiled_matmul()
dag.print_python_code_from_state(s)
sch, tensors = dag.apply_steps_from_state(s)
stmt = tvm.lower(sch, tensors, simple_mode=True)
tvm.lower(sch, tensors, simple_mode=True)


def test_infer_bound():
Expand Down Expand Up @@ -61,6 +62,7 @@ def test_estimate_flop():


def test_stage_order():
"""Test if the stage order is preserved when recovering a DAG."""
N = 512
A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
sch = te.create_schedule([D.op, E.op])
Expand All @@ -87,6 +89,11 @@ def test_stage_order():
elif op.name in ["B", "C"]:
assert stage_ops_1[idx + 1].name == "%s.shared" % op.name

# Serialize and deserialize the ComputeDAG constructed by a schedule.
loaded_dag = pickle.loads(pickle.dumps(dag))
assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)

# Apply the same schedule to Ansor state and it should have the same stage order
dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
state = dag.get_init_state()
Expand All @@ -105,6 +112,27 @@ def test_stage_order():
for op1, op2 in zip(stage_ops_1, stage_ops_2):
assert op1.name == op2.name

# Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
loaded_dag = pickle.loads(pickle.dumps(dag))
assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)

# Serialize and deserialize the search task.
task = auto_scheduler.SearchTask(
dag,
"test1",
tvm.target.Target("llvm"),
hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
)
task2 = pickle.loads(pickle.dumps(task))
assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops)
assert task.workload_key == task2.workload_key
assert str(task.target) == str(task2.target)
assert task.hardware_params.num_cores == task2.hardware_params.num_cores
assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes


if __name__ == "__main__":
test_apply_steps()
Expand Down

0 comments on commit 7b6b731

Please sign in to comment.